Refactor hooking API

This completely changes how method patching works under the hood. Unlike
the previous API (`Hook.HookMethod`), the new API (`Hook.Patch`)
generates a Harmony patch method at runtime, using IL generation.

This fixes methods with ByRef (out/ref) parameters getting silently corrupted
due to the ByRef semantics being lost when passed through the `object[] __args`
parameter.

This new API also makes it possible to:
  - modify parameters (including ByRef params)
  - change the return value to `null` (old API would use `return nil`
    for Harmony control flow)
  - prevent execution of the original method (and other harmony
    patches), independently of modifying the return value
This commit is contained in:
peelz
2022-08-03 21:34:41 -04:00
parent 768abd5ce1
commit 08836088fb
10 changed files with 1223 additions and 482 deletions

View File

@@ -142,6 +142,7 @@
<PackageReference Include="MonoMod.Common" Version="22.5.1.1" />
<PackageReference Include="NVorbis" Version="0.8.6" />
<PackageReference Include="RestSharp" Version="106.13.0" />
<PackageReference Include="Sigil" Version="5.0.0" />
</ItemGroup>
<ItemGroup>
@@ -222,4 +223,4 @@
</PropertyGroup>
<Import Project="../BarotraumaShared/DeployGameAnalytics.props" />
</Project>
</Project>

View File

@@ -134,6 +134,7 @@
<PackageReference Include="MonoMod.Common" Version="22.5.1.1" />
<PackageReference Include="NVorbis" Version="0.8.6" />
<PackageReference Include="RestSharp" Version="106.13.0" />
<PackageReference Include="Sigil" Version="5.0.0" />
</ItemGroup>
<!-- Sourced from https://stackoverflow.com/a/45248069 -->
@@ -224,4 +225,4 @@
</PropertyGroup>
<Import Project="../BarotraumaShared/DeployGameAnalytics.props" />
</Project>
</Project>

View File

@@ -141,6 +141,7 @@
<PackageReference Include="MonoMod.Common" Version="22.5.1.1" />
<PackageReference Include="NVorbis" Version="0.8.6" />
<PackageReference Include="RestSharp" Version="106.13.0" />
<PackageReference Include="Sigil" Version="5.0.0" />
</ItemGroup>
<ItemGroup>

View File

@@ -89,6 +89,7 @@
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.1.0" />
<PackageReference Include="MonoMod.Common" Version="22.5.1.1" />
<PackageReference Include="RestSharp" Version="106.13.0" />
<PackageReference Include="Sigil" Version="5.0.0" />
</ItemGroup>
<ItemGroup>

View File

@@ -86,6 +86,7 @@
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.1.0" />
<PackageReference Include="MonoMod.Common" Version="22.5.1.1" />
<PackageReference Include="RestSharp" Version="106.13.0" />
<PackageReference Include="Sigil" Version="5.0.0" />
</ItemGroup>
<!-- Sourced from https://stackoverflow.com/a/45248069 -->

View File

@@ -88,6 +88,7 @@
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.1.0" />
<PackageReference Include="MonoMod.Common" Version="22.5.1.1" />
<PackageReference Include="RestSharp" Version="106.13.0" />
<PackageReference Include="Sigil" Version="5.0.0" />
</ItemGroup>
<ItemGroup>

View File

@@ -4,10 +4,10 @@ using System.Text;
using MoonSharp.Interpreter;
using Microsoft.Xna.Framework;
using FarseerPhysics.Dynamics;
using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
namespace Barotrauma
{
public static class LuaCustomConverters
{
public static void RegisterAll()
@@ -31,8 +31,12 @@ namespace Barotrauma
v => (LuaCsFunc)(args => GameMain.LuaCs.CallLuaFunction(v.Function, args)));
Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(
DataType.Function,
typeof(LuaCsPatch),
v => (LuaCsPatch)((self, args) => GameMain.LuaCs.CallLuaFunction(v.Function, self, args)));
typeof(LuaCsCompatPatchFunc),
v => (LuaCsCompatPatchFunc)((self, args) => GameMain.LuaCs.CallLuaFunction(v.Function, self, args)));
Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(
DataType.Function,
typeof(LuaCsPatchFunc),
v => (LuaCsPatchFunc)((self, args) => GameMain.LuaCs.CallLuaFunction(v.Function, self, args)));
#if CLIENT
RegisterAction<float>();

File diff suppressed because it is too large Load Diff

View File

@@ -1,31 +1,31 @@
using System;
using System.Linq;
using System.Reflection;
using MoonSharp.Interpreter;
using HarmonyLib;
using System.Collections.Generic;
using System.Text;
using MoonSharp.Interpreter.Interop;
using MoonSharp.Interpreter;
using static Barotrauma.LuaCsSetup;
using System.Threading;
using System.Diagnostics;
using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
namespace Barotrauma
{
partial class LuaCsHook
{
private Dictionary<long, HashSet<(string, LuaCsPatch, ACsMod)>> compatHookPrefixMethods;
private Dictionary<long, HashSet<(string, LuaCsPatch, ACsMod)>> compatHookPostfixMethods;
// XXX: this can't be renamed because of backward compatibility with C# mods
public delegate object LuaCsPatch(object self, Dictionary<string, object> args);
private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType)
partial class LuaCsHook
{
private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> compatHookPrefixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>>();
private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> compatHookPostfixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>>();
private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookType)
{
result = null;
try
{
var funcAddr = ((long)__originalMethod.MethodHandle.GetFunctionPointer());
HashSet<(string, LuaCsPatch, ACsMod)> methodSet = null;
switch (hookMethodType)
HashSet<(string, LuaCsCompatPatchFunc, ACsMod)> methodSet = null;
switch (hookType)
{
case HookMethodType.Before:
instance.compatHookPrefixMethods.TryGetValue(funcAddr, out methodSet);
@@ -34,7 +34,7 @@ namespace Barotrauma
instance.compatHookPostfixMethods.TryGetValue(funcAddr, out methodSet);
break;
default:
break;
throw new ArgumentException($"Invalid {nameof(HookMethodType)} enum value.", nameof(hookType));
}
if (methodSet != null)
@@ -46,7 +46,7 @@ namespace Barotrauma
args.Add(@params[i].Name, __args[i]);
}
var outOfSocpe = new HashSet<(string, LuaCsPatch, ACsMod)>();
var outOfSocpe = new HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>();
foreach (var tuple in methodSet)
{
if (tuple.Item3 != null && tuple.Item3.IsDisposed)
@@ -94,6 +94,7 @@ namespace Barotrauma
_hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before);
return result == null;
}
private static void HookLuaCsPatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) =>
_hookLuaCsPatch(__originalMethod, __args, __instance, out object _, HookMethodType.After);
@@ -107,50 +108,19 @@ namespace Barotrauma
}
else return true;
}
private static void HookLuaCsPatchRetPostfix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance)
{
_hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.After);
if (result != null) __result = result;
}
private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo ResolveMethod(string where, string className, string methodName, string[] parameterNames)
{
var classType = LuaUserData.GetType(className);
if (classType == null)
{
GameMain.LuaCs.HandleException(new Exception($"Tried to use {where} with an invalid class name '{className}'."));
return null;
}
MethodInfo methodInfo = null;
if (parameterNames != null)
{
Type[] parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray();
methodInfo = classType.GetMethod(methodName, DefaultBindingFlags, null, parameterTypes, null);
}
else
{
methodInfo = classType.GetMethod(methodName, DefaultBindingFlags);
}
if (methodInfo == null)
{
string parameterNamesStr = parameterNames == null ? "" : string.Join(", ", parameterNames);
GameMain.LuaCs.HandleException(new Exception($"Method '{methodName}' with parameters '{parameterNamesStr}' not found in class '{className}'"));
}
return methodInfo;
}
public void HookMethod(string identifier, MethodInfo method, LuaCsPatch patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null)
public void HookMethod(string identifier, MethodInfo method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null)
{
if (identifier == null || method == null || patch == null)
{
@@ -179,7 +149,7 @@ namespace Barotrauma
}
}
if (compatHookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet))
if (compatHookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsCompatPatchFunc, ACsMod)> methodSet))
{
if (identifier != "")
{
@@ -190,7 +160,7 @@ namespace Barotrauma
}
else if (patch != null)
{
compatHookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) });
compatHookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>() { (identifier, patch, owner) });
}
}
@@ -211,7 +181,7 @@ namespace Barotrauma
}
}
if (compatHookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet))
if (compatHookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsCompatPatchFunc, ACsMod)> methodSet))
{
if (identifier != "")
{
@@ -222,25 +192,25 @@ namespace Barotrauma
}
else if (patch != null)
{
compatHookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) });
compatHookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>() { (identifier, patch, owner) });
}
}
}
protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before)
protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before)
{
MethodInfo methodInfo = ResolveMethod("HookMethod", className, methodName, parameterNames);
var methodInfo = ResolveMethod(className, methodName, parameterNames);
if (methodInfo == null) return;
if (methodInfo.GetParameters().Any(x => x.ParameterType.IsByRef))
{
throw new InvalidOperationException($"{nameof(HookMethod)} doesn't support ByRef parameters; use {nameof(Patch)} instead.");
}
HookMethod(identifier, methodInfo, patch, hookMethodType);
}
protected void HookMethod(string identifier, string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) =>
protected void HookMethod(string identifier, string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) =>
HookMethod(identifier, className, methodName, null, patch, hookMethodType);
protected void HookMethod(string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) =>
protected void HookMethod(string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) =>
HookMethod("", className, methodName, null, patch, hookMethodType);
protected void HookMethod(string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) =>
protected void HookMethod(string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) =>
HookMethod("", className, methodName, parameterNames, patch, hookMethodType);
@@ -248,7 +218,7 @@ namespace Barotrauma
{
var funcAddr = ((long)method.MethodHandle.GetFunctionPointer());
Dictionary<long, HashSet<(string, LuaCsPatch, ACsMod)>> methods;
Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> methods;
if (hookType == HookMethodType.Before) methods = compatHookPrefixMethods;
else if (hookType == HookMethodType.After) methods = compatHookPostfixMethods;
else throw null;
@@ -257,7 +227,7 @@ namespace Barotrauma
}
protected void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before)
{
MethodInfo methodInfo = ResolveMethod("UnhookMathod", className, methodName, parameterNames);
var methodInfo = ResolveMethod(className, methodName, parameterNames);
if (methodInfo == null) return;
UnhookMethod(identifier, methodInfo, hookType);
}

View File

@@ -12,6 +12,7 @@ using System.Runtime.CompilerServices;
using System.Linq;
using System.Reflection;
using System.Threading;
using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
[assembly: InternalsVisibleTo(Barotrauma.CsScriptBase.CsScriptAssembly, AllInternalsVisible = true)]
[assembly: InternalsVisibleTo(Barotrauma.CsScriptBase.CsOneTimeScriptAssembly, AllInternalsVisible = true)]
@@ -396,7 +397,8 @@ namespace Barotrauma
UserData.RegisterType<LuaCsConfig>();
UserData.RegisterType<LuaCsAction>();
UserData.RegisterType<LuaCsFile>();
UserData.RegisterType<LuaCsPatch>();
UserData.RegisterType<LuaCsCompatPatchFunc>();
UserData.RegisterType<LuaCsPatchFunc>();
UserData.RegisterType<LuaCsConfig>();
UserData.RegisterType<CsScriptRunner>();
UserData.RegisterType<LuaGame>();