From 6daf410e50c92b07a27d603f13e5435d8122ae93 Mon Sep 17 00:00:00 2001 From: zhurengong <2731412072@qq.com> Date: Tue, 25 Jan 2022 21:54:56 +0800 Subject: [PATCH] Improve Harmony Hook 1. Hook.HookMethod can patch overloaded methods based on parameter types 2. Hook.HookMethod can add more patches to the method 3. Find all hook methods by the address of the function (origin method), not by the method path (string). 4. Fixed unable to add postfix patch to method (__params changed to __args) 5. Fix patching method will cause duplicate patches to be added --- .../BarotraumaShared/Lua/DefaultHook.lua | 109 ++++++++---- .../SharedSource/Lua/LuaClasses.cs | 163 +++++++++++------- 2 files changed, 184 insertions(+), 88 deletions(-) diff --git a/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua b/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua index 05fe51ecc..15a3731d4 100644 --- a/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua +++ b/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua @@ -1,36 +1,85 @@ -Hook.HookMethod("Barotrauma.Item", "TryInteract", function (instance, p) - if Hook.Call("itemInteract", instance, p.picker, p.ignoreRequiredItems, p.forceSelectKey, p.forceActionKey) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "TryInteract", + { + "Barotrauma.Character", + "System.Boolean", + "System.Boolean", + "System.Boolean" + }, + function (instance, p) + if Hook.Call("itemInteract", instance, p.picker, p.ignoreRequiredItems, p.forceSelectKey, p.forceActionKey) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "ApplyTreatment", function (instance, p) - if Hook.Call("itemApplyTreatment", instance, p.user, p.character, p.targetLimb) then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "ApplyTreatment", + { + "Barotrauma.Character", + "Barotrauma.Character", + "Barotrauma.Limb" + }, + function (instance, p) + if Hook.Call("itemApplyTreatment", instance, p.user, p.character, p.targetLimb) then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Combine", function (instance, p) - if Hook.Call("itemCombine", instance, p.item, p.user) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "Combine", + { + "Barotrauma.Item", + "Barotrauma.Character" + }, + function (instance, p) + if Hook.Call("itemCombine", instance, p.item, p.user) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Drop", function (instance, p) - if Hook.Call("itemDrop", instance, p.dropper) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "Drop", + { + "Barotrauma.Character", + "System.Boolean" + }, + function (instance, p) + if Hook.Call("itemDrop", instance, p.dropper) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Equip", function (instance, p) - if Hook.Call("itemEquip", instance, p.character) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "Equip", + { + "Barotrauma.Character" + }, + function (instance, p) + if Hook.Call("itemEquip", instance, p.character) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Unequip", function (instance, p) - if Hook.Call("itemUnequip", instance, p.character) == true then - return false - end -end, Hook.HookMethodType.Before) \ No newline at end of file +Hook.HookMethod( + "Barotrauma.Item", "Unequip", + { + "Barotrauma.Character" + }, + function (instance, p) + if Hook.Call("itemUnequip", instance, p.character) == true then + return false + end + end, + Hook.HookMethodType.Before +) diff --git a/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs b/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs index 234df3108..09e283d1a 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs @@ -799,7 +799,8 @@ namespace Barotrauma public LuaHook(LuaSetup e) { env = e; - _hookMethods = new Dictionary(); + _hookPrefixMethods = new Dictionary>(); + _hookPostfixMethods = new Dictionary>(); } public class HookFunction @@ -818,7 +819,8 @@ namespace Barotrauma private Dictionary> hookFunctions = new Dictionary>(); - private static Dictionary _hookMethods; + private static Dictionary> _hookPrefixMethods; + private static Dictionary> _hookPostfixMethods; private Queue> queuedFunctionCalls = new Queue>(); @@ -829,29 +831,36 @@ namespace Barotrauma static void _hookLuaPatch(MethodBase __originalMethod, object[] __args, object __instance, out LuaResult result, HookMethodType hookMethodType) { - // Although it works correctly, the performance is low result = new LuaResult(null); - + try { - var classType = __originalMethod.DeclaringType; - var methodPath = $"{hookMethodType}:{classType.Namespace}.{classType.Name}.{__originalMethod.Name}"; - var @params = __originalMethod.GetParameters(); var ptable = new Dictionary(); for (int i = 0; i < @params.Length; i++) { ptable.Add(@params[i].Name, __args[i]); - } - if (_hookMethods.TryGetValue(methodPath, out object hookMethod)) - { - result = new LuaResult(luaSetup.hook.env.lua.Call(hookMethod, __instance, ptable)); - } - else + var funcAddr = ((long)__originalMethod.MethodHandle.GetFunctionPointer()); + HashSet methodSet = null; + switch (hookMethodType) { - luaSetup.PrintError($"No hook method found in _hookMethods[{methodPath}]"); + case HookMethodType.Before: + methodSet = _hookPrefixMethods[funcAddr]; + break; + case HookMethodType.After: + methodSet = _hookPostfixMethods[funcAddr]; + break; + default: + break; + } + if (methodSet != null) + { + foreach (var hookMethod in methodSet) + { + result = new LuaResult(luaSetup.lua.Call(hookMethod, __instance, ptable)); + } } } @@ -881,14 +890,14 @@ namespace Barotrauma return true; } - static void HookLuaPatchPostfix(MethodBase __originalMethod, object[] __params, object __instance) + static void HookLuaPatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) { - _hookLuaPatch(__originalMethod, __params, __instance, out LuaResult result, HookMethodType.After); + _hookLuaPatch(__originalMethod, __args, __instance, out LuaResult result, HookMethodType.After); } - static void HookLuaPatchRetPostfix(MethodBase __originalMethod, object[] __params, ref object __result, object __instance) + static void HookLuaPatchRetPostfix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) { - _hookLuaPatch(__originalMethod, __params, __instance, out LuaResult result, HookMethodType.After); + _hookLuaPatch(__originalMethod, __args, __instance, out LuaResult result, HookMethodType.After); if (!result.IsNull()) __result = result.Object(); @@ -898,61 +907,99 @@ namespace Barotrauma private static MethodInfo _miHookLuaPatchRetPrefix = typeof(LuaHook).GetMethod("HookLuaPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaPatchPostfix = typeof(LuaHook).GetMethod("HookLuaPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaPatchRetPostfix = typeof(LuaHook).GetMethod("HookLuaPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); - public void HookMethod(string className, string methodName, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) + public void HookMethod(string className, string methodName, string[] parameterNames, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) { + if (hookMethod == null) + { + env.PrintError("hookMethod cannot be null"); + return; + } + var classType = Type.GetType(className); - var methodInfos = classType.GetMethods(); - HarmonyMethod harmonyMethod = new HarmonyMethod(); - HarmonyMethod harmonyMethodRet = new HarmonyMethod(); + MethodInfo methodInfo = null; + + if (parameterNames.Length > 0) + { + Type[] parameterTypes = parameterNames.Select(x => AccessTools.TypeByName(x)).ToArray(); + methodInfo = classType.GetMethod(methodName, parameterTypes); + } + else + { + methodInfo = classType.GetMethod(methodName); + } + + if (methodInfo == null) + { + env.PrintError($"can't find method({className}.{methodName}) with these parameters' types({string.Join(", ", parameterNames)})"); + return; + } + + var funcAddr = ((long)methodInfo.MethodHandle.GetFunctionPointer()); + var patches = Harmony.GetPatchInfo(methodInfo); if (hookMethodType == HookMethodType.Before) { - harmonyMethod = new HarmonyMethod(_miHookLuaPatchPrefix); - harmonyMethodRet = new HarmonyMethod(_miHookLuaPatchRetPrefix); + if (methodInfo.ReturnType == typeof(void)) + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaPatchPrefix) == null) + { + env.harmony.Patch(methodInfo, prefix: new HarmonyMethod(_miHookLuaPatchPrefix)); + } + } + else + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaPatchRetPrefix) == null) + { + env.harmony.Patch(methodInfo, prefix: new HarmonyMethod(_miHookLuaPatchRetPrefix)); + } + } + + if (_hookPrefixMethods.TryGetValue(funcAddr, out HashSet methodSet)) + { + methodSet.Add(hookMethod); + } + else + { + _hookPrefixMethods.Add(funcAddr, new HashSet() { hookMethod }); + } + } else if (hookMethodType == HookMethodType.After) - { - harmonyMethod = new HarmonyMethod(_miHookLuaPatchPostfix); - harmonyMethodRet = new HarmonyMethod(_miHookLuaPatchRetPostfix); - } - - foreach (var methodInfo in methodInfos) - { - if (methodInfo.Name == methodName) + { + if (methodInfo.ReturnType == typeof(void)) { - if (hookMethodType == HookMethodType.Before) - { - if (methodInfo.ReturnType == typeof(void)) - env.harmony.Patch(methodInfo, prefix: harmonyMethod); - else - env.harmony.Patch(methodInfo, prefix: harmonyMethodRet); - } - else if (hookMethodType == HookMethodType.After) - { - if (methodInfo.ReturnType == typeof(void)) - env.harmony.Patch(methodInfo, postfix: harmonyMethod); - else - env.harmony.Patch(methodInfo, postfix: harmonyMethodRet); - } - - // build an unique method path by patch type, class, method self - var methodPath = $"{hookMethodType}:{classType.Namespace}.{classType.Name}.{methodInfo.Name}"; - - if (hookMethod != null) + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaPatchPostfix) == null) { - if (!_hookMethods.TryAdd(methodPath, hookMethod)) - env.PrintError($"Failed to add key-value in {nameof(_hookMethods)}\n[{methodPath}, {hookMethod.ToString()}]"); -#if DEBUG - else - env.PrintMessage($"Sucessfully added key-value in {nameof(_hookMethods)}\n[{methodPath}, {hookMethod.ToString()}]"); -#endif + env.harmony.Patch(methodInfo, postfix: new HarmonyMethod(_miHookLuaPatchPostfix)); } - break; } + else + { + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaPatchRetPostfix) == null) + { + env.harmony.Patch(methodInfo, postfix: new HarmonyMethod(_miHookLuaPatchRetPostfix)); + } + } + + if (_hookPostfixMethods.TryGetValue(funcAddr, out HashSet methodSet)) + { + methodSet.Add(hookMethod); + } + else + { + _hookPostfixMethods.Add(funcAddr, new HashSet() { hookMethod }); + } + } } + public void HookMethod(string className, string methodName, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) + { + HookMethod(className, methodName, new string[] {}, hookMethod, hookMethodType); + } + + public void EnqueueFunction(object function, params object[] args) { queuedFunctionCalls.Enqueue(new Tuple(function, args));