From 2cdd3f3ec5ccf7e64d781baac286f700a033cd3f Mon Sep 17 00:00:00 2001 From: peelz Date: Thu, 15 Sep 2022 18:09:32 -0400 Subject: [PATCH] Add constructor support to Hook.Patch --- .../SharedSource/LuaCs/LuaCsHook.cs | 69 +++++---- .../SharedSource/LuaCs/LuaCsHookCompat.cs | 29 ++-- .../BarotraumaTest/LuaCs/HookPatchHelpers.cs | 136 +++++++++++------- .../BarotraumaTest/LuaCs/HookPatchTests.cs | 130 ++++++++++++++--- 4 files changed, 247 insertions(+), 117 deletions(-) diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs index e1f3e8c84..344856c46 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs @@ -519,9 +519,9 @@ namespace Barotrauma "ContentPackageManager", }; - private static void ValidatePatchTarget(MethodInfo methodInfo) + private static void ValidatePatchTarget(MethodBase method) { - if (prohibitedHooks.Any(h => methodInfo.DeclaringType.FullName.StartsWith(h))) + if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h))) { throw new ArgumentException("Hooks into the modding environment are prohibited."); } @@ -575,7 +575,7 @@ namespace Barotrauma return !(left == right); } - public static MethodKey Create(MethodInfo method) => new MethodKey + public static MethodKey Create(MethodBase method) => new MethodKey { ModuleHandle = method.Module.ModuleHandle, MetadataToken = method.MetadataToken, @@ -814,30 +814,37 @@ namespace Barotrauma public object Call(string name, params object[] args) => Call(name, args); - private static MethodInfo ResolveMethod(string className, string methodName, string[] parameterNames) + private static MethodBase ResolveMethod(string className, string methodName, string[] parameters) { var classType = LuaUserData.GetType(className); if (classType == null) throw new InvalidOperationException($"Invalid class name '{className}'"); const BindingFlags BINDING_FLAGS = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; - MethodInfo methodInfo = null; - if (parameterNames != null) + const string CTOR = ".ctor"; + + MethodBase method = null; + if (parameters != null) { - var parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray(); - methodInfo = classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null); + var parameterTypes = parameters.Select(x => LuaUserData.GetType(x)).ToArray(); + // TODO: remove the casts once we can use C# 9 features + method = methodName == CTOR + ? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, parameterTypes, null) + : (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null); } else { - methodInfo = classType.GetMethod(methodName, BINDING_FLAGS); + method = methodName == CTOR + ? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, Array.Empty(), null) + : (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS); } - if (methodInfo == null) + if (method == null) { - var parameterNamesStr = parameterNames == null ? "" : string.Join(", ", parameterNames); + var parameterNamesStr = parameters == null ? "" : string.Join(", ", parameters); throw new InvalidOperationException($"Method '{methodName}({parameterNamesStr})' not found in class '{className}'"); } - return methodInfo; + return method; } private class DynamicParameterMapping @@ -863,7 +870,7 @@ namespace Barotrauma // If you need to debug this: // - use https://sharplab.io ; it's a very useful for resource for writing IL by hand. // - use il.NewMessage("") or il.WriteLine("") to see where the IL crashes at runtime. - private MethodInfo CreateDynamicHarmonyPatch(string identifier, MethodInfo original, HookMethodType hookType) + private MethodInfo CreateDynamicHarmonyPatch(string identifier, MethodBase original, HookMethodType hookType) { var parameters = new List { @@ -871,7 +878,7 @@ namespace Barotrauma new DynamicParameterMapping("__instance", null, typeof(object)), }; - var hasReturnType = original.ReturnType != typeof(void); + var hasReturnType = original is MethodInfo mi && mi.ReturnType != typeof(void); if (hasReturnType) { parameters.Add(new DynamicParameterMapping("__result", null, typeof(object).MakeByRefType())); @@ -919,7 +926,7 @@ namespace Barotrauma // IL: var patchKey = MethodKey.Create(__originalMethod); var patchKey = il.DeclareLocal("patchKey"); il.LoadArgument(0); // load __originalMethod - il.CastClass(); + il.CastClass(); il.Call(typeof(MethodKey).GetMethod(nameof(MethodKey.Create))); il.StoreLocal(patchKey); @@ -1032,10 +1039,10 @@ namespace Barotrauma { // IL: var csReturnType = Type.GetTypeFromHandle(); var csReturnType = il.DeclareLocal("csReturnType"); - il.LoadType(original.ReturnType); + il.LoadType(((MethodInfo)original).ReturnType); il.StoreLocal(csReturnType); - // IL: var csReturnValue = luaReturnValue.ToObject(csReturnValueType); + // IL: var csReturnValue = luaReturnValue.ToObject(csReturnType); var csReturnValue = il.DeclareLocal("csReturnValue"); il.LoadLocal(luaReturnValue); il.LoadLocal(csReturnType); @@ -1149,7 +1156,7 @@ namespace Barotrauma return type.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static); } - private string Patch(string identifier, MethodInfo method, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) + private string Patch(string identifier, MethodBase method, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) { if (method == null) throw new ArgumentNullException(nameof(method)); if (patch == null) throw new ArgumentNullException(nameof(patch)); @@ -1199,29 +1206,29 @@ namespace Barotrauma public string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) { - var methodInfo = ResolveMethod(className, methodName, parameterTypes); - return Patch(identifier, methodInfo, patch, hookType); + var method = ResolveMethod(className, methodName, parameterTypes); + return Patch(identifier, method, patch, hookType); } public string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) { - var methodInfo = ResolveMethod(className, methodName, null); - return Patch(identifier, methodInfo, patch, hookType); + var method = ResolveMethod(className, methodName, null); + return Patch(identifier, method, patch, hookType); } public string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) { - var methodInfo = ResolveMethod(className, methodName, parameterTypes); - return Patch(null, methodInfo, patch, hookType); + var method = ResolveMethod(className, methodName, parameterTypes); + return Patch(null, method, patch, hookType); } public string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) { - var methodInfo = ResolveMethod(className, methodName, null); - return Patch(null, methodInfo, patch, hookType); + var method = ResolveMethod(className, methodName, null); + return Patch(null, method, patch, hookType); } - private bool RemovePatch(string identifier, MethodInfo method, HookMethodType hookType) + private bool RemovePatch(string identifier, MethodBase method, HookMethodType hookType) { if (identifier == null) throw new ArgumentNullException(nameof(identifier)); identifier = NormalizeIdentifier(identifier); @@ -1242,14 +1249,14 @@ namespace Barotrauma public bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType) { - var methodInfo = ResolveMethod(className, methodName, parameterTypes); - return RemovePatch(identifier, methodInfo, hookType); + var method = ResolveMethod(className, methodName, parameterTypes); + return RemovePatch(identifier, method, hookType); } public bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType) { - var methodInfo = ResolveMethod(className, methodName, null); - return RemovePatch(identifier, methodInfo, hookType); + var method = ResolveMethod(className, methodName, null); + return RemovePatch(identifier, method, hookType); } } } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs index b3a4eb9be..dd8f39261 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs @@ -1,10 +1,9 @@ -using System; +using System; using System.Linq; using System.Reflection; using HarmonyLib; using System.Collections.Generic; using MoonSharp.Interpreter; -using static Barotrauma.LuaCsSetup; using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch; namespace Barotrauma @@ -122,7 +121,7 @@ namespace Barotrauma private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); // TODO: deprecate this - public void HookMethod(string identifier, MethodInfo method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) + public void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) { if (identifier == null || method == null || patch == null) { @@ -136,7 +135,7 @@ namespace Barotrauma if (hookType == HookMethodType.Before) { - if (method.ReturnType != typeof(void)) + if (method is MethodInfo mi && mi.ReturnType != typeof(void)) { if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPrefix) == null) { @@ -168,7 +167,7 @@ namespace Barotrauma } else if (hookType == HookMethodType.After) { - if (method.ReturnType != typeof(void)) + if (method is MethodInfo mi && mi.ReturnType != typeof(void)) { if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPostfix) == null) { @@ -200,13 +199,13 @@ namespace Barotrauma } protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) { - var methodInfo = ResolveMethod(className, methodName, parameterNames); - if (methodInfo == null) return; - if (methodInfo.GetParameters().Any(x => x.ParameterType.IsByRef)) + var method = ResolveMethod(className, methodName, parameterNames); + if (method == null) return; + if (method.GetParameters().Any(x => x.ParameterType.IsByRef)) { throw new InvalidOperationException($"{nameof(HookMethod)} doesn't support ByRef parameters; use {nameof(Patch)} instead."); } - HookMethod(identifier, methodInfo, patch, hookMethodType); + HookMethod(identifier, method, patch, hookMethodType); } protected void HookMethod(string identifier, string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => HookMethod(identifier, className, methodName, null, patch, hookMethodType); @@ -216,9 +215,9 @@ namespace Barotrauma HookMethod("", className, methodName, parameterNames, patch, hookMethodType); - public void UnhookMethod(string identifier, MethodInfo method, HookMethodType hookType = HookMethodType.Before) + public void UnhookMethod(string identifier, MethodBase method, HookMethodType hookType = HookMethodType.Before) { - var funcAddr = ((long)method.MethodHandle.GetFunctionPointer()); + var funcAddr = (long)method.MethodHandle.GetFunctionPointer(); Dictionary> methods; if (hookType == HookMethodType.Before) methods = compatHookPrefixMethods; @@ -229,9 +228,9 @@ namespace Barotrauma } protected void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before) { - var methodInfo = ResolveMethod(className, methodName, parameterNames); - if (methodInfo == null) return; - UnhookMethod(identifier, methodInfo, hookType); + var method = ResolveMethod(className, methodName, parameterNames); + if (method == null) return; + UnhookMethod(identifier, method, hookType); } } -} \ No newline at end of file +} diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs index c48510345..632ceca38 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs @@ -2,6 +2,8 @@ using MoonSharp.Interpreter; using System; using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Text; using System.Threading; using Xunit; @@ -24,70 +26,108 @@ namespace TestProject.LuaCs public void Dispose() => disposeAction(); } - public static PatchHandle AddPrefix(this LuaCsSetup luaCs, string body, string methodName = "Run", string? patchId = null) + private static List BuildHookPatchArgsList( + string? patchId, + string className, + string methodName, + string[]? parameters) { - var className = typeof(T).FullName; - DynValue returnValue; - if (patchId != null) - { - returnValue = luaCs.Lua.DoString(@$" - return Hook.Patch('{patchId}', '{className}', '{methodName}', function(instance, ptable) - {body} - end, Hook.HookMethodType.Before) - "); - } - else - { - returnValue = luaCs.Lua.DoString(@$" - return Hook.Patch('{className}', '{methodName}', function(instance, ptable) - {body} - end, Hook.HookMethodType.Before) - "); - } + static string Stringify(object value) => + "\"" + value.ToString()!.Replace(@"\", @"\\").Replace("\"", "\\\"") + "\""; + var args = new List(); + if (patchId != null) args.Add(Stringify(patchId)); + args.Add(Stringify(className)); + args.Add(Stringify(methodName)); + if (parameters != null && parameters.Length > 0) + { + var sb = new StringBuilder(); + sb.Append("{ "); + foreach (var param in parameters) + { + sb.Append(Stringify(param)); + sb.Append(", "); + } + sb.Append(" }"); + args.Add(sb.ToString()); + } + return args; + } + + private static DynValue DoHookPatch( + this LuaCsSetup luaCs, + string? patchId, + string className, + string methodName, + string[]? parameters, + string function, + LuaCsHook.HookMethodType patchType) + { + var args = BuildHookPatchArgsList(patchId, className, methodName, parameters); + args.Add(function); + args.Add(patchType switch + { + LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before", + LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After", + _ => throw new NotImplementedException(), + }); + return luaCs.Lua.DoString($"return Hook.Patch({string.Join(", ", args)})"); + } + + private static DynValue DoHookRemovePatch( + this LuaCsSetup luaCs, + string? patchId, + string className, + string methodName, + string[]? parameters, + LuaCsHook.HookMethodType patchType) + { + var args = BuildHookPatchArgsList(patchId, className, methodName, parameters); + args.Add(patchType switch + { + LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before", + LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After", + _ => throw new NotImplementedException(), + }); + return luaCs.Lua.DoString($"return Hook.RemovePatch({string.Join(", ", args)})"); + } + + public static PatchHandle AddPrefix(this LuaCsSetup luaCs, string body, string methodName, string[]? parameters = null, string? patchId = null) + { + var className = typeof(T).FullName!; + var returnValue = luaCs.DoHookPatch(patchId, className, methodName, parameters, @$" + function(instance, ptable) + {body} + end + ", LuaCsHook.HookMethodType.Before); Assert.Equal(DataType.String, returnValue.Type); return new(returnValue.String, () => luaCs.RemovePrefix(returnValue.String, methodName)); } - public static PatchHandle AddPostfix(this LuaCsSetup luaCs, string body, string methodName = "Run", string? patchId = null) + public static PatchHandle AddPostfix(this LuaCsSetup luaCs, string body, string methodName, string[]? parameters = null, string? patchId = null) { - var className = typeof(T).FullName; - DynValue returnValue; - if (patchId != null) - { - returnValue = luaCs.Lua.DoString(@$" - return Hook.Patch('{patchId}', '{className}', '{methodName}', function(instance, ptable) - {body} - end, Hook.HookMethodType.After) - "); - } - else - { - returnValue = luaCs.Lua.DoString(@$" - return Hook.Patch('{className}', '{methodName}', function(instance, ptable) - {body} - end, Hook.HookMethodType.After) - "); - } + var className = typeof(T).FullName!; + var returnValue = luaCs.DoHookPatch(patchId, className, methodName, parameters, @$" + function(instance, ptable) + {body} + end + ", LuaCsHook.HookMethodType.After); + Assert.Equal(DataType.String, returnValue.Type); return new(returnValue.String, () => luaCs.RemovePostfix(returnValue.String, methodName)); } - public static bool RemovePrefix(this LuaCsSetup luaCs, string patchId, string methodName = "Run") + public static bool RemovePrefix(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null) { - var className = typeof(T).FullName; - var returnValue = luaCs.Lua.DoString($@" - return Hook.RemovePatch('{patchId}', '{className}', '{methodName}', Hook.HookMethodType.Before) - "); + var className = typeof(T).FullName!; + var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.Before); Assert.Equal(DataType.Boolean, returnValue.Type); return returnValue.Boolean; } - public static bool RemovePostfix(this LuaCsSetup luaCs, string patchId, string methodName = "Run") + public static bool RemovePostfix(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null) { - var className = typeof(T).FullName; - var returnValue = luaCs.Lua.DoString($@" - return Hook.RemovePatch('{patchId}', '{className}', '{methodName}', Hook.HookMethodType.After) - "); + var className = typeof(T).FullName!; + var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.After); Assert.Equal(DataType.Boolean, returnValue.Type); return returnValue.Boolean; } diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs index e215421ac..bd69ff172 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs @@ -31,6 +31,7 @@ namespace TestProject.LuaCs UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); + UserData.RegisterType(); UserData.RegisterType(); luaCs.Initialize(); @@ -53,7 +54,9 @@ namespace TestProject.LuaCs { using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); var target = new PatchTargetSimple(); - using var patchHandle = luaCs.AddPrefix("ptable.PreventExecution = true"); + using var patchHandle = luaCs.AddPrefix(@" + ptable.PreventExecution = true + ", nameof(PatchTargetSimple.Run)); target.Run(); Assert.False(target.ran); } @@ -66,7 +69,7 @@ namespace TestProject.LuaCs using var patchHandle = luaCs.AddPrefix(@" ptable.PreventExecution = true originalPatchRan = true - ", patchId: "test"); + ", nameof(PatchTargetSimple.Run), patchId: "test"); target.Run(); Assert.False(target.ran); Assert.True(luaCs.Lua.Globals["originalPatchRan"] as bool?); @@ -76,7 +79,9 @@ namespace TestProject.LuaCs luaCs.Lua.Globals["originalPatchRan"] = false; // Replace the existing prefix, but don't prevent execution this time - luaCs.AddPrefix("replacementPatchRan = true", patchId: "test"); + luaCs.AddPrefix(@" + replacementPatchRan = true + ", nameof(PatchTargetSimple.Run), patchId: "test"); target.Run(); Assert.True(target.ran); @@ -95,7 +100,7 @@ namespace TestProject.LuaCs using (var patchHandle = luaCs.AddPrefix(@" ptable.PreventExecution = true patchRan = true - ")) + ", nameof(PatchTargetSimple.Run))) { target.Run(); Assert.False(target.ran); @@ -116,7 +121,7 @@ namespace TestProject.LuaCs var target = new PatchTargetSimple(); using (var patchHandle = luaCs.AddPostfix(@" patchRan = true - ")) + ", nameof(PatchTargetSimple.Run))) { target.Run(); Assert.True(target.ran); @@ -177,7 +182,7 @@ namespace TestProject.LuaCs using var patchHandle = luaCs.AddPrefix(@" ptable.PreventExecution = true return 123 - "); + ", nameof(PatchTargetReturnsObject.Run)); var returnValue = target.Run(); Assert.False(target.ran); Assert.Equal(123, (int)(double)returnValue); @@ -189,7 +194,9 @@ namespace TestProject.LuaCs using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); var target = new PatchTargetReturnsObject(); // This should have no effect - using var patchHandle = luaCs.AddPrefix("return"); + using var patchHandle = luaCs.AddPrefix(@" + return + ", nameof(PatchTargetReturnsObject.Run)); var returnValue = target.Run(); Assert.True(target.ran); Assert.Equal(5, returnValue); @@ -201,7 +208,9 @@ namespace TestProject.LuaCs using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); var target = new PatchTargetReturnsObject(); // This should modify the return value to "null" - using var patchHandle = luaCs.AddPostfix("return nil"); + using var patchHandle = luaCs.AddPostfix(@" + return nil + ", nameof(PatchTargetReturnsObject.Run)); var returnValue = target.Run(); Assert.True(target.ran); Assert.Null(returnValue); @@ -214,7 +223,7 @@ namespace TestProject.LuaCs var target = new PatchTargetReturnsObject(); using var patchHandle = luaCs.AddPostfix(@" return TestValueType.__new(100) - "); + ", nameof(PatchTargetSimple.Run)); var returnValue = target.Run(); Assert.True(target.ran); Assert.IsType(returnValue); @@ -238,8 +247,8 @@ namespace TestProject.LuaCs using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); var target = new PatchTargetReturnsInterface(); using var patchHandle = luaCs.AddPostfix(@" - return InterfaceImplementingType.__new(100); - "); + return InterfaceImplementingType.__new(100) + ", nameof(PatchTargetReturnsInterface.Run)); var returnValue = target.Run()!; Assert.True(target.ran); Assert.Equal(100, returnValue.GetFoo()); @@ -265,7 +274,7 @@ namespace TestProject.LuaCs ptable['a'] = Int32(100) ptable['b'] = 'abc' ptable['refByte'] = Byte(4) - "); + ", nameof(PatchTargetModifyParams.Run)); byte refByte = 123; target.Run(5, out var outString, ref refByte, "foo"); Assert.True(target.ran); @@ -288,13 +297,88 @@ namespace TestProject.LuaCs { using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); var target = new PatchTargetVector2(); - using var patchHandle = luaCs.AddPrefix("patchRan = true"); + using var patchHandle = luaCs.AddPrefix(@" + patchRan = true + ", nameof(PatchTargetVector2.Run)); var returnValue = target.Run(new Vector2(1, 2)); Assert.True(target.ran); Assert.True(luaCs.Lua.Globals["patchRan"] as bool?); Assert.Equal("{X:1 Y:2}", returnValue); } + private class PatchTargetConstructor + { + public enum CtorType + { + None, + Patched, + Default, + Int, + StringString, + } + + public CtorType Ctor { get; set; } + + public bool PrefixRan { get; set; } + + public PatchTargetConstructor() + { + Ctor = CtorType.Default; + } + + public PatchTargetConstructor(int a = default) + { + Ctor = CtorType.Int; + } + + public PatchTargetConstructor(string a, string b) + { + Ctor = CtorType.StringString; + } + } + + [Fact] + public void TestPatchConstructor() + { + using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); + + { + using var postfixHandle = luaCs.AddPostfix(@$" + instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched} + ", ".ctor"); + using var prefixHandle = luaCs.AddPrefix(@$" + instance.PrefixRan = true + ", ".ctor"); + var target = new PatchTargetConstructor(); + Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor); + Assert.True(target.PrefixRan); + } + + { + using var postfixHandle = luaCs.AddPostfix(@$" + instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched} + ", ".ctor", new[] { typeof(int).FullName! }); + using var prefixHandle = luaCs.AddPrefix(@$" + instance.PrefixRan = true + ", ".ctor", new[] { typeof(int).FullName! }); + var target = new PatchTargetConstructor(1); + Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor); + Assert.True(target.PrefixRan); + } + + { + using var postfixHandle = luaCs.AddPostfix(@$" + instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched} + ", ".ctor", new[] { typeof(string).FullName!, typeof(string).FullName! }); + using var prefixHandle = luaCs.AddPrefix(@$" + instance.PrefixRan = true + ", ".ctor", new[] { typeof(string).FullName!, typeof(string).FullName! }); + var target = new PatchTargetConstructor("", ""); + Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor); + Assert.True(target.PrefixRan); + } + } + private class PatchTargetNumbers { public bool ran; @@ -367,7 +451,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = SByte(-6) - ", methodName: nameof(PatchTargetNumbers.RunSByte)); + ", nameof(PatchTargetNumbers.RunSByte)); var returnValue = target.RunSByte(-5); Assert.True(target.ran); Assert.Equal(-6, returnValue); @@ -380,7 +464,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = Byte(6) - ", methodName: nameof(PatchTargetNumbers.RunByte)); + ", nameof(PatchTargetNumbers.RunByte)); var returnValue = target.RunByte(5); Assert.True(target.ran); Assert.Equal(6, returnValue); @@ -393,7 +477,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = Int16(-25000) - ", methodName: nameof(PatchTargetNumbers.RunInt16)); + ", nameof(PatchTargetNumbers.RunInt16)); var returnValue = target.RunInt16(30000); Assert.True(target.ran); Assert.Equal(-25000, returnValue); @@ -406,7 +490,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = UInt16(60000) - ", methodName: nameof(PatchTargetNumbers.RunUInt16)); + ", nameof(PatchTargetNumbers.RunUInt16)); var returnValue = target.RunUInt16(50000); Assert.True(target.ran); Assert.Equal(60000, returnValue); @@ -419,7 +503,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = Int32('7FFFFF00', 16) - ", methodName: nameof(PatchTargetNumbers.RunInt32)); + ", nameof(PatchTargetNumbers.RunInt32)); var returnValue = target.RunInt32(900000); Assert.True(target.ran); Assert.Equal(0x7FFFFF00, returnValue); @@ -432,7 +516,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = UInt32('AFFFFFFF', 16) - ", methodName: nameof(PatchTargetNumbers.RunUInt32)); + ", nameof(PatchTargetNumbers.RunUInt32)); var returnValue = target.RunUInt32(300500); Assert.True(target.ran); Assert.Equal(0xAFFFFFFF, returnValue); @@ -445,7 +529,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = Int64('7555555555555555', 16) - ", methodName: nameof(PatchTargetNumbers.RunInt64)); + ", nameof(PatchTargetNumbers.RunInt64)); var returnValue = target.RunInt64(0x7FFFFFFF00000000); Assert.True(target.ran); Assert.Equal(0x7555555555555555, returnValue); @@ -458,7 +542,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = UInt64('F555555555555555', 16) - ", methodName: nameof(PatchTargetNumbers.RunUInt64)); + ", nameof(PatchTargetNumbers.RunUInt64)); var returnValue = target.RunUInt64(0xFFFFFFFF00000000); Assert.True(target.ran); Assert.Equal(0xF555555555555555, returnValue); @@ -471,7 +555,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = Single(123.456) - ", methodName: nameof(PatchTargetNumbers.RunSingle)); + ", nameof(PatchTargetNumbers.RunSingle)); var returnValue = target.RunSingle(111.111f); Assert.True(target.ran); Assert.Equal(123.456f, returnValue); @@ -484,7 +568,7 @@ namespace TestProject.LuaCs var target = new PatchTargetNumbers(); using var patchHandle = luaCs.AddPrefix(@" ptable['v'] = Double(123.456) - ", methodName: nameof(PatchTargetNumbers.RunDouble)); + ", nameof(PatchTargetNumbers.RunDouble)); var returnValue = target.RunDouble(111.111d); Assert.True(target.ran); Assert.Equal(123.456d, returnValue);