diff --git a/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua b/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua index 05fe51ecc..15a3731d4 100644 --- a/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua +++ b/Barotrauma/BarotraumaShared/Lua/DefaultHook.lua @@ -1,36 +1,85 @@ -Hook.HookMethod("Barotrauma.Item", "TryInteract", function (instance, p) - if Hook.Call("itemInteract", instance, p.picker, p.ignoreRequiredItems, p.forceSelectKey, p.forceActionKey) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "TryInteract", + { + "Barotrauma.Character", + "System.Boolean", + "System.Boolean", + "System.Boolean" + }, + function (instance, p) + if Hook.Call("itemInteract", instance, p.picker, p.ignoreRequiredItems, p.forceSelectKey, p.forceActionKey) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "ApplyTreatment", function (instance, p) - if Hook.Call("itemApplyTreatment", instance, p.user, p.character, p.targetLimb) then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "ApplyTreatment", + { + "Barotrauma.Character", + "Barotrauma.Character", + "Barotrauma.Limb" + }, + function (instance, p) + if Hook.Call("itemApplyTreatment", instance, p.user, p.character, p.targetLimb) then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Combine", function (instance, p) - if Hook.Call("itemCombine", instance, p.item, p.user) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "Combine", + { + "Barotrauma.Item", + "Barotrauma.Character" + }, + function (instance, p) + if Hook.Call("itemCombine", instance, p.item, p.user) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Drop", function (instance, p) - if Hook.Call("itemDrop", instance, p.dropper) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "Drop", + { + "Barotrauma.Character", + "System.Boolean" + }, + function (instance, p) + if Hook.Call("itemDrop", instance, p.dropper) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Equip", function (instance, p) - if Hook.Call("itemEquip", instance, p.character) == true then - return false - end -end, Hook.HookMethodType.Before) +Hook.HookMethod( + "Barotrauma.Item", "Equip", + { + "Barotrauma.Character" + }, + function (instance, p) + if Hook.Call("itemEquip", instance, p.character) == true then + return false + end + end, + Hook.HookMethodType.Before +) -Hook.HookMethod("Barotrauma.Item", "Unequip", function (instance, p) - if Hook.Call("itemUnequip", instance, p.character) == true then - return false - end -end, Hook.HookMethodType.Before) \ No newline at end of file +Hook.HookMethod( + "Barotrauma.Item", "Unequip", + { + "Barotrauma.Character" + }, + function (instance, p) + if Hook.Call("itemUnequip", instance, p.character) == true then + return false + end + end, + Hook.HookMethodType.Before +) diff --git a/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs b/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs index 234df3108..09e283d1a 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses.cs @@ -799,7 +799,8 @@ namespace Barotrauma public LuaHook(LuaSetup e) { env = e; - _hookMethods = new Dictionary(); + _hookPrefixMethods = new Dictionary>(); + _hookPostfixMethods = new Dictionary>(); } public class HookFunction @@ -818,7 +819,8 @@ namespace Barotrauma private Dictionary> hookFunctions = new Dictionary>(); - private static Dictionary _hookMethods; + private static Dictionary> _hookPrefixMethods; + private static Dictionary> _hookPostfixMethods; private Queue> queuedFunctionCalls = new Queue>(); @@ -829,29 +831,36 @@ namespace Barotrauma static void _hookLuaPatch(MethodBase __originalMethod, object[] __args, object __instance, out LuaResult result, HookMethodType hookMethodType) { - // Although it works correctly, the performance is low result = new LuaResult(null); - + try { - var classType = __originalMethod.DeclaringType; - var methodPath = $"{hookMethodType}:{classType.Namespace}.{classType.Name}.{__originalMethod.Name}"; - var @params = __originalMethod.GetParameters(); var ptable = new Dictionary(); for (int i = 0; i < @params.Length; i++) { ptable.Add(@params[i].Name, __args[i]); - } - if (_hookMethods.TryGetValue(methodPath, out object hookMethod)) - { - result = new LuaResult(luaSetup.hook.env.lua.Call(hookMethod, __instance, ptable)); - } - else + var funcAddr = ((long)__originalMethod.MethodHandle.GetFunctionPointer()); + HashSet methodSet = null; + switch (hookMethodType) { - luaSetup.PrintError($"No hook method found in _hookMethods[{methodPath}]"); + case HookMethodType.Before: + methodSet = _hookPrefixMethods[funcAddr]; + break; + case HookMethodType.After: + methodSet = _hookPostfixMethods[funcAddr]; + break; + default: + break; + } + if (methodSet != null) + { + foreach (var hookMethod in methodSet) + { + result = new LuaResult(luaSetup.lua.Call(hookMethod, __instance, ptable)); + } } } @@ -881,14 +890,14 @@ namespace Barotrauma return true; } - static void HookLuaPatchPostfix(MethodBase __originalMethod, object[] __params, object __instance) + static void HookLuaPatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) { - _hookLuaPatch(__originalMethod, __params, __instance, out LuaResult result, HookMethodType.After); + _hookLuaPatch(__originalMethod, __args, __instance, out LuaResult result, HookMethodType.After); } - static void HookLuaPatchRetPostfix(MethodBase __originalMethod, object[] __params, ref object __result, object __instance) + static void HookLuaPatchRetPostfix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) { - _hookLuaPatch(__originalMethod, __params, __instance, out LuaResult result, HookMethodType.After); + _hookLuaPatch(__originalMethod, __args, __instance, out LuaResult result, HookMethodType.After); if (!result.IsNull()) __result = result.Object(); @@ -898,61 +907,99 @@ namespace Barotrauma private static MethodInfo _miHookLuaPatchRetPrefix = typeof(LuaHook).GetMethod("HookLuaPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaPatchPostfix = typeof(LuaHook).GetMethod("HookLuaPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaPatchRetPostfix = typeof(LuaHook).GetMethod("HookLuaPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); - public void HookMethod(string className, string methodName, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) + public void HookMethod(string className, string methodName, string[] parameterNames, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) { + if (hookMethod == null) + { + env.PrintError("hookMethod cannot be null"); + return; + } + var classType = Type.GetType(className); - var methodInfos = classType.GetMethods(); - HarmonyMethod harmonyMethod = new HarmonyMethod(); - HarmonyMethod harmonyMethodRet = new HarmonyMethod(); + MethodInfo methodInfo = null; + + if (parameterNames.Length > 0) + { + Type[] parameterTypes = parameterNames.Select(x => AccessTools.TypeByName(x)).ToArray(); + methodInfo = classType.GetMethod(methodName, parameterTypes); + } + else + { + methodInfo = classType.GetMethod(methodName); + } + + if (methodInfo == null) + { + env.PrintError($"can't find method({className}.{methodName}) with these parameters' types({string.Join(", ", parameterNames)})"); + return; + } + + var funcAddr = ((long)methodInfo.MethodHandle.GetFunctionPointer()); + var patches = Harmony.GetPatchInfo(methodInfo); if (hookMethodType == HookMethodType.Before) { - harmonyMethod = new HarmonyMethod(_miHookLuaPatchPrefix); - harmonyMethodRet = new HarmonyMethod(_miHookLuaPatchRetPrefix); + if (methodInfo.ReturnType == typeof(void)) + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaPatchPrefix) == null) + { + env.harmony.Patch(methodInfo, prefix: new HarmonyMethod(_miHookLuaPatchPrefix)); + } + } + else + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaPatchRetPrefix) == null) + { + env.harmony.Patch(methodInfo, prefix: new HarmonyMethod(_miHookLuaPatchRetPrefix)); + } + } + + if (_hookPrefixMethods.TryGetValue(funcAddr, out HashSet methodSet)) + { + methodSet.Add(hookMethod); + } + else + { + _hookPrefixMethods.Add(funcAddr, new HashSet() { hookMethod }); + } + } else if (hookMethodType == HookMethodType.After) - { - harmonyMethod = new HarmonyMethod(_miHookLuaPatchPostfix); - harmonyMethodRet = new HarmonyMethod(_miHookLuaPatchRetPostfix); - } - - foreach (var methodInfo in methodInfos) - { - if (methodInfo.Name == methodName) + { + if (methodInfo.ReturnType == typeof(void)) { - if (hookMethodType == HookMethodType.Before) - { - if (methodInfo.ReturnType == typeof(void)) - env.harmony.Patch(methodInfo, prefix: harmonyMethod); - else - env.harmony.Patch(methodInfo, prefix: harmonyMethodRet); - } - else if (hookMethodType == HookMethodType.After) - { - if (methodInfo.ReturnType == typeof(void)) - env.harmony.Patch(methodInfo, postfix: harmonyMethod); - else - env.harmony.Patch(methodInfo, postfix: harmonyMethodRet); - } - - // build an unique method path by patch type, class, method self - var methodPath = $"{hookMethodType}:{classType.Namespace}.{classType.Name}.{methodInfo.Name}"; - - if (hookMethod != null) + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaPatchPostfix) == null) { - if (!_hookMethods.TryAdd(methodPath, hookMethod)) - env.PrintError($"Failed to add key-value in {nameof(_hookMethods)}\n[{methodPath}, {hookMethod.ToString()}]"); -#if DEBUG - else - env.PrintMessage($"Sucessfully added key-value in {nameof(_hookMethods)}\n[{methodPath}, {hookMethod.ToString()}]"); -#endif + env.harmony.Patch(methodInfo, postfix: new HarmonyMethod(_miHookLuaPatchPostfix)); } - break; } + else + { + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaPatchRetPostfix) == null) + { + env.harmony.Patch(methodInfo, postfix: new HarmonyMethod(_miHookLuaPatchRetPostfix)); + } + } + + if (_hookPostfixMethods.TryGetValue(funcAddr, out HashSet methodSet)) + { + methodSet.Add(hookMethod); + } + else + { + _hookPostfixMethods.Add(funcAddr, new HashSet() { hookMethod }); + } + } } + public void HookMethod(string className, string methodName, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) + { + HookMethod(className, methodName, new string[] {}, hookMethod, hookMethodType); + } + + public void EnqueueFunction(object function, params object[] args) { queuedFunctionCalls.Enqueue(new Tuple(function, args));