Add constructor support to Hook.Patch
This commit is contained in:
@@ -519,9 +519,9 @@ namespace Barotrauma
|
||||
"ContentPackageManager",
|
||||
};
|
||||
|
||||
private static void ValidatePatchTarget(MethodInfo methodInfo)
|
||||
private static void ValidatePatchTarget(MethodBase method)
|
||||
{
|
||||
if (prohibitedHooks.Any(h => methodInfo.DeclaringType.FullName.StartsWith(h)))
|
||||
if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h)))
|
||||
{
|
||||
throw new ArgumentException("Hooks into the modding environment are prohibited.");
|
||||
}
|
||||
@@ -575,7 +575,7 @@ namespace Barotrauma
|
||||
return !(left == right);
|
||||
}
|
||||
|
||||
public static MethodKey Create(MethodInfo method) => new MethodKey
|
||||
public static MethodKey Create(MethodBase method) => new MethodKey
|
||||
{
|
||||
ModuleHandle = method.Module.ModuleHandle,
|
||||
MetadataToken = method.MetadataToken,
|
||||
@@ -814,30 +814,37 @@ namespace Barotrauma
|
||||
|
||||
public object Call(string name, params object[] args) => Call<object>(name, args);
|
||||
|
||||
private static MethodInfo ResolveMethod(string className, string methodName, string[] parameterNames)
|
||||
private static MethodBase ResolveMethod(string className, string methodName, string[] parameters)
|
||||
{
|
||||
var classType = LuaUserData.GetType(className);
|
||||
if (classType == null) throw new InvalidOperationException($"Invalid class name '{className}'");
|
||||
|
||||
const BindingFlags BINDING_FLAGS = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
|
||||
MethodInfo methodInfo = null;
|
||||
if (parameterNames != null)
|
||||
const string CTOR = ".ctor";
|
||||
|
||||
MethodBase method = null;
|
||||
if (parameters != null)
|
||||
{
|
||||
var parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray();
|
||||
methodInfo = classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null);
|
||||
var parameterTypes = parameters.Select(x => LuaUserData.GetType(x)).ToArray();
|
||||
// TODO: remove the casts once we can use C# 9 features
|
||||
method = methodName == CTOR
|
||||
? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, parameterTypes, null)
|
||||
: (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null);
|
||||
}
|
||||
else
|
||||
{
|
||||
methodInfo = classType.GetMethod(methodName, BINDING_FLAGS);
|
||||
method = methodName == CTOR
|
||||
? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, Array.Empty<Type>(), null)
|
||||
: (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS);
|
||||
}
|
||||
|
||||
if (methodInfo == null)
|
||||
if (method == null)
|
||||
{
|
||||
var parameterNamesStr = parameterNames == null ? "" : string.Join(", ", parameterNames);
|
||||
var parameterNamesStr = parameters == null ? "" : string.Join(", ", parameters);
|
||||
throw new InvalidOperationException($"Method '{methodName}({parameterNamesStr})' not found in class '{className}'");
|
||||
}
|
||||
|
||||
return methodInfo;
|
||||
return method;
|
||||
}
|
||||
|
||||
private class DynamicParameterMapping
|
||||
@@ -863,7 +870,7 @@ namespace Barotrauma
|
||||
// If you need to debug this:
|
||||
// - use https://sharplab.io ; it's a very useful for resource for writing IL by hand.
|
||||
// - use il.NewMessage("") or il.WriteLine("") to see where the IL crashes at runtime.
|
||||
private MethodInfo CreateDynamicHarmonyPatch(string identifier, MethodInfo original, HookMethodType hookType)
|
||||
private MethodInfo CreateDynamicHarmonyPatch(string identifier, MethodBase original, HookMethodType hookType)
|
||||
{
|
||||
var parameters = new List<DynamicParameterMapping>
|
||||
{
|
||||
@@ -871,7 +878,7 @@ namespace Barotrauma
|
||||
new DynamicParameterMapping("__instance", null, typeof(object)),
|
||||
};
|
||||
|
||||
var hasReturnType = original.ReturnType != typeof(void);
|
||||
var hasReturnType = original is MethodInfo mi && mi.ReturnType != typeof(void);
|
||||
if (hasReturnType)
|
||||
{
|
||||
parameters.Add(new DynamicParameterMapping("__result", null, typeof(object).MakeByRefType()));
|
||||
@@ -919,7 +926,7 @@ namespace Barotrauma
|
||||
// IL: var patchKey = MethodKey.Create(__originalMethod);
|
||||
var patchKey = il.DeclareLocal<MethodKey>("patchKey");
|
||||
il.LoadArgument(0); // load __originalMethod
|
||||
il.CastClass<MethodInfo>();
|
||||
il.CastClass<MethodBase>();
|
||||
il.Call(typeof(MethodKey).GetMethod(nameof(MethodKey.Create)));
|
||||
il.StoreLocal(patchKey);
|
||||
|
||||
@@ -1032,10 +1039,10 @@ namespace Barotrauma
|
||||
{
|
||||
// IL: var csReturnType = Type.GetTypeFromHandle(<original.ReturnType>);
|
||||
var csReturnType = il.DeclareLocal<Type>("csReturnType");
|
||||
il.LoadType(original.ReturnType);
|
||||
il.LoadType(((MethodInfo)original).ReturnType);
|
||||
il.StoreLocal(csReturnType);
|
||||
|
||||
// IL: var csReturnValue = luaReturnValue.ToObject(csReturnValueType);
|
||||
// IL: var csReturnValue = luaReturnValue.ToObject(csReturnType);
|
||||
var csReturnValue = il.DeclareLocal<object>("csReturnValue");
|
||||
il.LoadLocal(luaReturnValue);
|
||||
il.LoadLocal(csReturnType);
|
||||
@@ -1149,7 +1156,7 @@ namespace Barotrauma
|
||||
return type.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static);
|
||||
}
|
||||
|
||||
private string Patch(string identifier, MethodInfo method, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
||||
private string Patch(string identifier, MethodBase method, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
||||
{
|
||||
if (method == null) throw new ArgumentNullException(nameof(method));
|
||||
if (patch == null) throw new ArgumentNullException(nameof(patch));
|
||||
@@ -1199,29 +1206,29 @@ namespace Barotrauma
|
||||
|
||||
public string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, parameterTypes);
|
||||
return Patch(identifier, methodInfo, patch, hookType);
|
||||
var method = ResolveMethod(className, methodName, parameterTypes);
|
||||
return Patch(identifier, method, patch, hookType);
|
||||
}
|
||||
|
||||
public string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, null);
|
||||
return Patch(identifier, methodInfo, patch, hookType);
|
||||
var method = ResolveMethod(className, methodName, null);
|
||||
return Patch(identifier, method, patch, hookType);
|
||||
}
|
||||
|
||||
public string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, parameterTypes);
|
||||
return Patch(null, methodInfo, patch, hookType);
|
||||
var method = ResolveMethod(className, methodName, parameterTypes);
|
||||
return Patch(null, method, patch, hookType);
|
||||
}
|
||||
|
||||
public string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, null);
|
||||
return Patch(null, methodInfo, patch, hookType);
|
||||
var method = ResolveMethod(className, methodName, null);
|
||||
return Patch(null, method, patch, hookType);
|
||||
}
|
||||
|
||||
private bool RemovePatch(string identifier, MethodInfo method, HookMethodType hookType)
|
||||
private bool RemovePatch(string identifier, MethodBase method, HookMethodType hookType)
|
||||
{
|
||||
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
|
||||
identifier = NormalizeIdentifier(identifier);
|
||||
@@ -1242,14 +1249,14 @@ namespace Barotrauma
|
||||
|
||||
public bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, parameterTypes);
|
||||
return RemovePatch(identifier, methodInfo, hookType);
|
||||
var method = ResolveMethod(className, methodName, parameterTypes);
|
||||
return RemovePatch(identifier, method, hookType);
|
||||
}
|
||||
|
||||
public bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, null);
|
||||
return RemovePatch(identifier, methodInfo, hookType);
|
||||
var method = ResolveMethod(className, methodName, null);
|
||||
return RemovePatch(identifier, method, hookType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
using System;
|
||||
using System;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using HarmonyLib;
|
||||
using System.Collections.Generic;
|
||||
using MoonSharp.Interpreter;
|
||||
using static Barotrauma.LuaCsSetup;
|
||||
using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
|
||||
|
||||
namespace Barotrauma
|
||||
@@ -122,7 +121,7 @@ namespace Barotrauma
|
||||
private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static);
|
||||
|
||||
// TODO: deprecate this
|
||||
public void HookMethod(string identifier, MethodInfo method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null)
|
||||
public void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null)
|
||||
{
|
||||
if (identifier == null || method == null || patch == null)
|
||||
{
|
||||
@@ -136,7 +135,7 @@ namespace Barotrauma
|
||||
|
||||
if (hookType == HookMethodType.Before)
|
||||
{
|
||||
if (method.ReturnType != typeof(void))
|
||||
if (method is MethodInfo mi && mi.ReturnType != typeof(void))
|
||||
{
|
||||
if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPrefix) == null)
|
||||
{
|
||||
@@ -168,7 +167,7 @@ namespace Barotrauma
|
||||
}
|
||||
else if (hookType == HookMethodType.After)
|
||||
{
|
||||
if (method.ReturnType != typeof(void))
|
||||
if (method is MethodInfo mi && mi.ReturnType != typeof(void))
|
||||
{
|
||||
if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPostfix) == null)
|
||||
{
|
||||
@@ -200,13 +199,13 @@ namespace Barotrauma
|
||||
}
|
||||
protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, parameterNames);
|
||||
if (methodInfo == null) return;
|
||||
if (methodInfo.GetParameters().Any(x => x.ParameterType.IsByRef))
|
||||
var method = ResolveMethod(className, methodName, parameterNames);
|
||||
if (method == null) return;
|
||||
if (method.GetParameters().Any(x => x.ParameterType.IsByRef))
|
||||
{
|
||||
throw new InvalidOperationException($"{nameof(HookMethod)} doesn't support ByRef parameters; use {nameof(Patch)} instead.");
|
||||
}
|
||||
HookMethod(identifier, methodInfo, patch, hookMethodType);
|
||||
HookMethod(identifier, method, patch, hookMethodType);
|
||||
}
|
||||
protected void HookMethod(string identifier, string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) =>
|
||||
HookMethod(identifier, className, methodName, null, patch, hookMethodType);
|
||||
@@ -216,9 +215,9 @@ namespace Barotrauma
|
||||
HookMethod("", className, methodName, parameterNames, patch, hookMethodType);
|
||||
|
||||
|
||||
public void UnhookMethod(string identifier, MethodInfo method, HookMethodType hookType = HookMethodType.Before)
|
||||
public void UnhookMethod(string identifier, MethodBase method, HookMethodType hookType = HookMethodType.Before)
|
||||
{
|
||||
var funcAddr = ((long)method.MethodHandle.GetFunctionPointer());
|
||||
var funcAddr = (long)method.MethodHandle.GetFunctionPointer();
|
||||
|
||||
Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> methods;
|
||||
if (hookType == HookMethodType.Before) methods = compatHookPrefixMethods;
|
||||
@@ -229,9 +228,9 @@ namespace Barotrauma
|
||||
}
|
||||
protected void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before)
|
||||
{
|
||||
var methodInfo = ResolveMethod(className, methodName, parameterNames);
|
||||
if (methodInfo == null) return;
|
||||
UnhookMethod(identifier, methodInfo, hookType);
|
||||
var method = ResolveMethod(className, methodName, parameterNames);
|
||||
if (method == null) return;
|
||||
UnhookMethod(identifier, method, hookType);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
using MoonSharp.Interpreter;
|
||||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using Xunit;
|
||||
|
||||
@@ -24,70 +26,108 @@ namespace TestProject.LuaCs
|
||||
public void Dispose() => disposeAction();
|
||||
}
|
||||
|
||||
public static PatchHandle AddPrefix<T>(this LuaCsSetup luaCs, string body, string methodName = "Run", string? patchId = null)
|
||||
private static List<string> BuildHookPatchArgsList(
|
||||
string? patchId,
|
||||
string className,
|
||||
string methodName,
|
||||
string[]? parameters)
|
||||
{
|
||||
var className = typeof(T).FullName;
|
||||
DynValue returnValue;
|
||||
if (patchId != null)
|
||||
{
|
||||
returnValue = luaCs.Lua.DoString(@$"
|
||||
return Hook.Patch('{patchId}', '{className}', '{methodName}', function(instance, ptable)
|
||||
{body}
|
||||
end, Hook.HookMethodType.Before)
|
||||
");
|
||||
}
|
||||
else
|
||||
{
|
||||
returnValue = luaCs.Lua.DoString(@$"
|
||||
return Hook.Patch('{className}', '{methodName}', function(instance, ptable)
|
||||
{body}
|
||||
end, Hook.HookMethodType.Before)
|
||||
");
|
||||
}
|
||||
static string Stringify(object value) =>
|
||||
"\"" + value.ToString()!.Replace(@"\", @"\\").Replace("\"", "\\\"") + "\"";
|
||||
|
||||
var args = new List<string>();
|
||||
if (patchId != null) args.Add(Stringify(patchId));
|
||||
args.Add(Stringify(className));
|
||||
args.Add(Stringify(methodName));
|
||||
if (parameters != null && parameters.Length > 0)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
sb.Append("{ ");
|
||||
foreach (var param in parameters)
|
||||
{
|
||||
sb.Append(Stringify(param));
|
||||
sb.Append(", ");
|
||||
}
|
||||
sb.Append(" }");
|
||||
args.Add(sb.ToString());
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
private static DynValue DoHookPatch(
|
||||
this LuaCsSetup luaCs,
|
||||
string? patchId,
|
||||
string className,
|
||||
string methodName,
|
||||
string[]? parameters,
|
||||
string function,
|
||||
LuaCsHook.HookMethodType patchType)
|
||||
{
|
||||
var args = BuildHookPatchArgsList(patchId, className, methodName, parameters);
|
||||
args.Add(function);
|
||||
args.Add(patchType switch
|
||||
{
|
||||
LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before",
|
||||
LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After",
|
||||
_ => throw new NotImplementedException(),
|
||||
});
|
||||
return luaCs.Lua.DoString($"return Hook.Patch({string.Join(", ", args)})");
|
||||
}
|
||||
|
||||
private static DynValue DoHookRemovePatch(
|
||||
this LuaCsSetup luaCs,
|
||||
string? patchId,
|
||||
string className,
|
||||
string methodName,
|
||||
string[]? parameters,
|
||||
LuaCsHook.HookMethodType patchType)
|
||||
{
|
||||
var args = BuildHookPatchArgsList(patchId, className, methodName, parameters);
|
||||
args.Add(patchType switch
|
||||
{
|
||||
LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before",
|
||||
LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After",
|
||||
_ => throw new NotImplementedException(),
|
||||
});
|
||||
return luaCs.Lua.DoString($"return Hook.RemovePatch({string.Join(", ", args)})");
|
||||
}
|
||||
|
||||
public static PatchHandle AddPrefix<T>(this LuaCsSetup luaCs, string body, string methodName, string[]? parameters = null, string? patchId = null)
|
||||
{
|
||||
var className = typeof(T).FullName!;
|
||||
var returnValue = luaCs.DoHookPatch(patchId, className, methodName, parameters, @$"
|
||||
function(instance, ptable)
|
||||
{body}
|
||||
end
|
||||
", LuaCsHook.HookMethodType.Before);
|
||||
Assert.Equal(DataType.String, returnValue.Type);
|
||||
return new(returnValue.String, () => luaCs.RemovePrefix<T>(returnValue.String, methodName));
|
||||
}
|
||||
|
||||
public static PatchHandle AddPostfix<T>(this LuaCsSetup luaCs, string body, string methodName = "Run", string? patchId = null)
|
||||
public static PatchHandle AddPostfix<T>(this LuaCsSetup luaCs, string body, string methodName, string[]? parameters = null, string? patchId = null)
|
||||
{
|
||||
var className = typeof(T).FullName;
|
||||
DynValue returnValue;
|
||||
if (patchId != null)
|
||||
{
|
||||
returnValue = luaCs.Lua.DoString(@$"
|
||||
return Hook.Patch('{patchId}', '{className}', '{methodName}', function(instance, ptable)
|
||||
{body}
|
||||
end, Hook.HookMethodType.After)
|
||||
");
|
||||
}
|
||||
else
|
||||
{
|
||||
returnValue = luaCs.Lua.DoString(@$"
|
||||
return Hook.Patch('{className}', '{methodName}', function(instance, ptable)
|
||||
{body}
|
||||
end, Hook.HookMethodType.After)
|
||||
");
|
||||
}
|
||||
var className = typeof(T).FullName!;
|
||||
var returnValue = luaCs.DoHookPatch(patchId, className, methodName, parameters, @$"
|
||||
function(instance, ptable)
|
||||
{body}
|
||||
end
|
||||
", LuaCsHook.HookMethodType.After);
|
||||
Assert.Equal(DataType.String, returnValue.Type);
|
||||
return new(returnValue.String, () => luaCs.RemovePostfix<T>(returnValue.String, methodName));
|
||||
}
|
||||
|
||||
public static bool RemovePrefix<T>(this LuaCsSetup luaCs, string patchId, string methodName = "Run")
|
||||
public static bool RemovePrefix<T>(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null)
|
||||
{
|
||||
var className = typeof(T).FullName;
|
||||
var returnValue = luaCs.Lua.DoString($@"
|
||||
return Hook.RemovePatch('{patchId}', '{className}', '{methodName}', Hook.HookMethodType.Before)
|
||||
");
|
||||
var className = typeof(T).FullName!;
|
||||
var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.Before);
|
||||
Assert.Equal(DataType.Boolean, returnValue.Type);
|
||||
return returnValue.Boolean;
|
||||
}
|
||||
|
||||
public static bool RemovePostfix<T>(this LuaCsSetup luaCs, string patchId, string methodName = "Run")
|
||||
public static bool RemovePostfix<T>(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null)
|
||||
{
|
||||
var className = typeof(T).FullName;
|
||||
var returnValue = luaCs.Lua.DoString($@"
|
||||
return Hook.RemovePatch('{patchId}', '{className}', '{methodName}', Hook.HookMethodType.After)
|
||||
");
|
||||
var className = typeof(T).FullName!;
|
||||
var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.After);
|
||||
Assert.Equal(DataType.Boolean, returnValue.Type);
|
||||
return returnValue.Boolean;
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ namespace TestProject.LuaCs
|
||||
UserData.RegisterType<PatchTargetReturnsInterface>();
|
||||
UserData.RegisterType<PatchTargetModifyParams>();
|
||||
UserData.RegisterType<PatchTargetVector2>();
|
||||
UserData.RegisterType<PatchTargetConstructor>();
|
||||
UserData.RegisterType<PatchTargetNumbers>();
|
||||
|
||||
luaCs.Initialize();
|
||||
@@ -53,7 +54,9 @@ namespace TestProject.LuaCs
|
||||
{
|
||||
using var patchTargetHandle = HookPatchHelpers.LockPatchTarget<PatchTargetSimple>();
|
||||
var target = new PatchTargetSimple();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetSimple>("ptable.PreventExecution = true");
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetSimple>(@"
|
||||
ptable.PreventExecution = true
|
||||
", nameof(PatchTargetSimple.Run));
|
||||
target.Run();
|
||||
Assert.False(target.ran);
|
||||
}
|
||||
@@ -66,7 +69,7 @@ namespace TestProject.LuaCs
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetSimple>(@"
|
||||
ptable.PreventExecution = true
|
||||
originalPatchRan = true
|
||||
", patchId: "test");
|
||||
", nameof(PatchTargetSimple.Run), patchId: "test");
|
||||
target.Run();
|
||||
Assert.False(target.ran);
|
||||
Assert.True(luaCs.Lua.Globals["originalPatchRan"] as bool?);
|
||||
@@ -76,7 +79,9 @@ namespace TestProject.LuaCs
|
||||
luaCs.Lua.Globals["originalPatchRan"] = false;
|
||||
|
||||
// Replace the existing prefix, but don't prevent execution this time
|
||||
luaCs.AddPrefix<PatchTargetSimple>("replacementPatchRan = true", patchId: "test");
|
||||
luaCs.AddPrefix<PatchTargetSimple>(@"
|
||||
replacementPatchRan = true
|
||||
", nameof(PatchTargetSimple.Run), patchId: "test");
|
||||
target.Run();
|
||||
Assert.True(target.ran);
|
||||
|
||||
@@ -95,7 +100,7 @@ namespace TestProject.LuaCs
|
||||
using (var patchHandle = luaCs.AddPrefix<PatchTargetSimple>(@"
|
||||
ptable.PreventExecution = true
|
||||
patchRan = true
|
||||
"))
|
||||
", nameof(PatchTargetSimple.Run)))
|
||||
{
|
||||
target.Run();
|
||||
Assert.False(target.ran);
|
||||
@@ -116,7 +121,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetSimple();
|
||||
using (var patchHandle = luaCs.AddPostfix<PatchTargetSimple>(@"
|
||||
patchRan = true
|
||||
"))
|
||||
", nameof(PatchTargetSimple.Run)))
|
||||
{
|
||||
target.Run();
|
||||
Assert.True(target.ran);
|
||||
@@ -177,7 +182,7 @@ namespace TestProject.LuaCs
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetReturnsObject>(@"
|
||||
ptable.PreventExecution = true
|
||||
return 123
|
||||
");
|
||||
", nameof(PatchTargetReturnsObject.Run));
|
||||
var returnValue = target.Run();
|
||||
Assert.False(target.ran);
|
||||
Assert.Equal(123, (int)(double)returnValue);
|
||||
@@ -189,7 +194,9 @@ namespace TestProject.LuaCs
|
||||
using var patchTargetHandle = HookPatchHelpers.LockPatchTarget<PatchTargetReturnsObject>();
|
||||
var target = new PatchTargetReturnsObject();
|
||||
// This should have no effect
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetReturnsObject>("return");
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetReturnsObject>(@"
|
||||
return
|
||||
", nameof(PatchTargetReturnsObject.Run));
|
||||
var returnValue = target.Run();
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(5, returnValue);
|
||||
@@ -201,7 +208,9 @@ namespace TestProject.LuaCs
|
||||
using var patchTargetHandle = HookPatchHelpers.LockPatchTarget<PatchTargetReturnsObject>();
|
||||
var target = new PatchTargetReturnsObject();
|
||||
// This should modify the return value to "null"
|
||||
using var patchHandle = luaCs.AddPostfix<PatchTargetReturnsObject>("return nil");
|
||||
using var patchHandle = luaCs.AddPostfix<PatchTargetReturnsObject>(@"
|
||||
return nil
|
||||
", nameof(PatchTargetReturnsObject.Run));
|
||||
var returnValue = target.Run();
|
||||
Assert.True(target.ran);
|
||||
Assert.Null(returnValue);
|
||||
@@ -214,7 +223,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetReturnsObject();
|
||||
using var patchHandle = luaCs.AddPostfix<PatchTargetReturnsObject>(@"
|
||||
return TestValueType.__new(100)
|
||||
");
|
||||
", nameof(PatchTargetSimple.Run));
|
||||
var returnValue = target.Run();
|
||||
Assert.True(target.ran);
|
||||
Assert.IsType<TestValueType>(returnValue);
|
||||
@@ -238,8 +247,8 @@ namespace TestProject.LuaCs
|
||||
using var patchTargetHandle = HookPatchHelpers.LockPatchTarget<PatchTargetReturnsInterface>();
|
||||
var target = new PatchTargetReturnsInterface();
|
||||
using var patchHandle = luaCs.AddPostfix<PatchTargetReturnsInterface>(@"
|
||||
return InterfaceImplementingType.__new(100);
|
||||
");
|
||||
return InterfaceImplementingType.__new(100)
|
||||
", nameof(PatchTargetReturnsInterface.Run));
|
||||
var returnValue = target.Run()!;
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(100, returnValue.GetFoo());
|
||||
@@ -265,7 +274,7 @@ namespace TestProject.LuaCs
|
||||
ptable['a'] = Int32(100)
|
||||
ptable['b'] = 'abc'
|
||||
ptable['refByte'] = Byte(4)
|
||||
");
|
||||
", nameof(PatchTargetModifyParams.Run));
|
||||
byte refByte = 123;
|
||||
target.Run(5, out var outString, ref refByte, "foo");
|
||||
Assert.True(target.ran);
|
||||
@@ -288,13 +297,88 @@ namespace TestProject.LuaCs
|
||||
{
|
||||
using var patchTargetHandle = HookPatchHelpers.LockPatchTarget<PatchTargetVector2>();
|
||||
var target = new PatchTargetVector2();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetVector2>("patchRan = true");
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetVector2>(@"
|
||||
patchRan = true
|
||||
", nameof(PatchTargetVector2.Run));
|
||||
var returnValue = target.Run(new Vector2(1, 2));
|
||||
Assert.True(target.ran);
|
||||
Assert.True(luaCs.Lua.Globals["patchRan"] as bool?);
|
||||
Assert.Equal("{X:1 Y:2}", returnValue);
|
||||
}
|
||||
|
||||
private class PatchTargetConstructor
|
||||
{
|
||||
public enum CtorType
|
||||
{
|
||||
None,
|
||||
Patched,
|
||||
Default,
|
||||
Int,
|
||||
StringString,
|
||||
}
|
||||
|
||||
public CtorType Ctor { get; set; }
|
||||
|
||||
public bool PrefixRan { get; set; }
|
||||
|
||||
public PatchTargetConstructor()
|
||||
{
|
||||
Ctor = CtorType.Default;
|
||||
}
|
||||
|
||||
public PatchTargetConstructor(int a = default)
|
||||
{
|
||||
Ctor = CtorType.Int;
|
||||
}
|
||||
|
||||
public PatchTargetConstructor(string a, string b)
|
||||
{
|
||||
Ctor = CtorType.StringString;
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestPatchConstructor()
|
||||
{
|
||||
using var patchTargetHandle = HookPatchHelpers.LockPatchTarget<PatchTargetConstructor>();
|
||||
|
||||
{
|
||||
using var postfixHandle = luaCs.AddPostfix<PatchTargetConstructor>(@$"
|
||||
instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched}
|
||||
", ".ctor");
|
||||
using var prefixHandle = luaCs.AddPrefix<PatchTargetConstructor>(@$"
|
||||
instance.PrefixRan = true
|
||||
", ".ctor");
|
||||
var target = new PatchTargetConstructor();
|
||||
Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor);
|
||||
Assert.True(target.PrefixRan);
|
||||
}
|
||||
|
||||
{
|
||||
using var postfixHandle = luaCs.AddPostfix<PatchTargetConstructor>(@$"
|
||||
instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched}
|
||||
", ".ctor", new[] { typeof(int).FullName! });
|
||||
using var prefixHandle = luaCs.AddPrefix<PatchTargetConstructor>(@$"
|
||||
instance.PrefixRan = true
|
||||
", ".ctor", new[] { typeof(int).FullName! });
|
||||
var target = new PatchTargetConstructor(1);
|
||||
Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor);
|
||||
Assert.True(target.PrefixRan);
|
||||
}
|
||||
|
||||
{
|
||||
using var postfixHandle = luaCs.AddPostfix<PatchTargetConstructor>(@$"
|
||||
instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched}
|
||||
", ".ctor", new[] { typeof(string).FullName!, typeof(string).FullName! });
|
||||
using var prefixHandle = luaCs.AddPrefix<PatchTargetConstructor>(@$"
|
||||
instance.PrefixRan = true
|
||||
", ".ctor", new[] { typeof(string).FullName!, typeof(string).FullName! });
|
||||
var target = new PatchTargetConstructor("", "");
|
||||
Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor);
|
||||
Assert.True(target.PrefixRan);
|
||||
}
|
||||
}
|
||||
|
||||
private class PatchTargetNumbers
|
||||
{
|
||||
public bool ran;
|
||||
@@ -367,7 +451,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = SByte(-6)
|
||||
", methodName: nameof(PatchTargetNumbers.RunSByte));
|
||||
", nameof(PatchTargetNumbers.RunSByte));
|
||||
var returnValue = target.RunSByte(-5);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(-6, returnValue);
|
||||
@@ -380,7 +464,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = Byte(6)
|
||||
", methodName: nameof(PatchTargetNumbers.RunByte));
|
||||
", nameof(PatchTargetNumbers.RunByte));
|
||||
var returnValue = target.RunByte(5);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(6, returnValue);
|
||||
@@ -393,7 +477,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = Int16(-25000)
|
||||
", methodName: nameof(PatchTargetNumbers.RunInt16));
|
||||
", nameof(PatchTargetNumbers.RunInt16));
|
||||
var returnValue = target.RunInt16(30000);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(-25000, returnValue);
|
||||
@@ -406,7 +490,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = UInt16(60000)
|
||||
", methodName: nameof(PatchTargetNumbers.RunUInt16));
|
||||
", nameof(PatchTargetNumbers.RunUInt16));
|
||||
var returnValue = target.RunUInt16(50000);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(60000, returnValue);
|
||||
@@ -419,7 +503,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = Int32('7FFFFF00', 16)
|
||||
", methodName: nameof(PatchTargetNumbers.RunInt32));
|
||||
", nameof(PatchTargetNumbers.RunInt32));
|
||||
var returnValue = target.RunInt32(900000);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(0x7FFFFF00, returnValue);
|
||||
@@ -432,7 +516,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = UInt32('AFFFFFFF', 16)
|
||||
", methodName: nameof(PatchTargetNumbers.RunUInt32));
|
||||
", nameof(PatchTargetNumbers.RunUInt32));
|
||||
var returnValue = target.RunUInt32(300500);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(0xAFFFFFFF, returnValue);
|
||||
@@ -445,7 +529,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = Int64('7555555555555555', 16)
|
||||
", methodName: nameof(PatchTargetNumbers.RunInt64));
|
||||
", nameof(PatchTargetNumbers.RunInt64));
|
||||
var returnValue = target.RunInt64(0x7FFFFFFF00000000);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(0x7555555555555555, returnValue);
|
||||
@@ -458,7 +542,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = UInt64('F555555555555555', 16)
|
||||
", methodName: nameof(PatchTargetNumbers.RunUInt64));
|
||||
", nameof(PatchTargetNumbers.RunUInt64));
|
||||
var returnValue = target.RunUInt64(0xFFFFFFFF00000000);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(0xF555555555555555, returnValue);
|
||||
@@ -471,7 +555,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = Single(123.456)
|
||||
", methodName: nameof(PatchTargetNumbers.RunSingle));
|
||||
", nameof(PatchTargetNumbers.RunSingle));
|
||||
var returnValue = target.RunSingle(111.111f);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(123.456f, returnValue);
|
||||
@@ -484,7 +568,7 @@ namespace TestProject.LuaCs
|
||||
var target = new PatchTargetNumbers();
|
||||
using var patchHandle = luaCs.AddPrefix<PatchTargetNumbers>(@"
|
||||
ptable['v'] = Double(123.456)
|
||||
", methodName: nameof(PatchTargetNumbers.RunDouble));
|
||||
", nameof(PatchTargetNumbers.RunDouble));
|
||||
var returnValue = target.RunDouble(111.111d);
|
||||
Assert.True(target.ran);
|
||||
Assert.Equal(123.456d, returnValue);
|
||||
|
||||
Reference in New Issue
Block a user