diff --git a/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses/LuaHook.cs b/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses/LuaHook.cs index 0629a791d..8df80320f 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses/LuaHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/Lua/LuaClasses/LuaHook.cs @@ -12,8 +12,8 @@ namespace Barotrauma { public LuaHook() { - _hookPrefixMethods = new Dictionary>(); - _hookPostfixMethods = new Dictionary>(); + _hookPrefixMethods = new Dictionary>>(); + _hookPostfixMethods = new Dictionary>>(); } public class HookFunction @@ -32,8 +32,8 @@ namespace Barotrauma private Dictionary> hookFunctions = new Dictionary>(); - private static Dictionary> _hookPrefixMethods; - private static Dictionary> _hookPostfixMethods; + private static Dictionary>> _hookPrefixMethods; + private static Dictionary>> _hookPostfixMethods; private Queue> queuedFunctionCalls = new Queue>(); @@ -53,38 +53,38 @@ namespace Barotrauma try { - var @params = __originalMethod.GetParameters(); - var ptable = new Dictionary(); - for (int i = 0; i < @params.Length; i++) - { - ptable.Add(@params[i].Name, __args[i]); - } - var funcAddr = ((long)__originalMethod.MethodHandle.GetFunctionPointer()); - HashSet methodSet = null; + HashSet> methodSet = null; switch (hookMethodType) { case HookMethodType.Before: - methodSet = _hookPrefixMethods[funcAddr]; + _hookPrefixMethods.TryGetValue(funcAddr, out methodSet); break; case HookMethodType.After: - methodSet = _hookPostfixMethods[funcAddr]; + _hookPostfixMethods.TryGetValue(funcAddr, out methodSet); break; default: break; } + if (methodSet != null) { - foreach (var hookMethod in methodSet) + var @params = __originalMethod.GetParameters(); + var ptable = new Dictionary(); + for (int i = 0; i < @params.Length; i++) { - result = new LuaResult(GameMain.Lua.lua.Call(hookMethod, __instance, ptable)); + ptable.Add(@params[i].Name, __args[i]); + } + + foreach (var tuple in methodSet) + { + result = new LuaResult(GameMain.Lua.lua.Call(tuple.Item2, __instance, ptable)); } } - } catch (Exception ex) { - GameMain.Lua.HandleLuaException(ex, nameof(_hookLuaPatch)); + GameMain.Lua.HandleLuaException(ex); } } @@ -126,20 +126,14 @@ namespace Barotrauma private static MethodInfo _miHookLuaPatchRetPrefix = typeof(LuaHook).GetMethod("HookLuaPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaPatchPostfix = typeof(LuaHook).GetMethod("HookLuaPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaPatchRetPostfix = typeof(LuaHook).GetMethod("HookLuaPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); - public void HookMethod(string className, string methodName, string[] parameterNames, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) + public void HookMethod(string identifier, string className, string methodName, string[] parameterNames, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) { - if (hookMethod == null) - { - GameMain.Lua.PrintError("hookMethod cannot be null"); - return; - } - - var classType = Type.GetType(className); + var classType = LuaUserData.GetType(className); MethodInfo methodInfo = null; if (parameterNames != null) { - Type[] parameterTypes = parameterNames.Select(x => AccessTools.TypeByName(x)).ToArray(); + Type[] parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray(); methodInfo = classType.GetMethod(methodName, DefaultBindingFlags, null, parameterTypes, null); } else @@ -149,7 +143,7 @@ namespace Barotrauma if (methodInfo == null) { - GameMain.Lua.PrintError($"can't find method({className}.{methodName}) with these parameters' types({string.Join(", ", parameterNames)})"); + GameMain.Lua.PrintError($"Method '{methodName}' with parameter '{string.Join(", ", parameterNames)}' not found from Class '{className}'"); return; } @@ -173,13 +167,20 @@ namespace Barotrauma } } - if (_hookPrefixMethods.TryGetValue(funcAddr, out HashSet methodSet)) + if (_hookPrefixMethods.TryGetValue(funcAddr, out HashSet> methodSet)) { - methodSet.Add(hookMethod); + if (identifier != "") + { + methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); + } + if (hookMethod != null) + { + methodSet.Add(Tuple.Create(identifier, hookMethod)); + } } - else + else if (hookMethod != null) { - _hookPrefixMethods.Add(funcAddr, new HashSet() { hookMethod }); + _hookPrefixMethods.Add(funcAddr, new HashSet>() { Tuple.Create(identifier, hookMethod) }); } } @@ -200,13 +201,20 @@ namespace Barotrauma } } - if (_hookPostfixMethods.TryGetValue(funcAddr, out HashSet methodSet)) + if (_hookPostfixMethods.TryGetValue(funcAddr, out HashSet> methodSet)) { - methodSet.Add(hookMethod); + if (identifier != "") + { + methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); + } + if (hookMethod != null) + { + methodSet.Add(Tuple.Create(identifier, hookMethod)); + } } - else + else if (hookMethod != null) { - _hookPostfixMethods.Add(funcAddr, new HashSet() { hookMethod }); + _hookPostfixMethods.Add(funcAddr, new HashSet>() { Tuple.Create(identifier, hookMethod) }); } } @@ -215,9 +223,18 @@ namespace Barotrauma public void HookMethod(string className, string methodName, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) { - HookMethod(className, methodName, null, hookMethod, hookMethodType); + HookMethod("", className, methodName, null, hookMethod, hookMethodType); } + public void HookMethod(string className, string methodName, string[] parameterNames, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) + { + HookMethod("", className, methodName, parameterNames, hookMethod, hookMethodType); + } + + public void HookMethod(string identifier, string className, string methodName, object hookMethod, HookMethodType hookMethodType = HookMethodType.Before) + { + HookMethod(identifier, className, methodName, null, hookMethod, hookMethodType); + } public void EnqueueFunction(object function, params object[] args) {