diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs index 34b587b56..495b2907e 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs @@ -16,7 +16,7 @@ namespace Barotrauma public delegate object LuaCsFunc(params object[] args); public delegate object LuaCsPatch(object self, Dictionary args); - public class LuaCsHook + public partial class LuaCsHook { public enum HookMethodType { @@ -50,12 +50,18 @@ namespace Barotrauma } } + private const BindingFlags DefaultBindingFlags = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + private static readonly string[] prohibitedHooks = { + "Barotrauma.Lua", + "Barotrauma.Cs", + "ContentPackageManager", + }; + private Harmony harmony; private Dictionary> hookFunctions; - - private Dictionary> hookPrefixMethods; - private Dictionary> hookPostfixMethods; + private Dictionary> hookPrefixMethods; + private Dictionary> hookPostfixMethods; private static LuaCsHook instance; @@ -64,8 +70,11 @@ namespace Barotrauma hookFunctions = new Dictionary>(); - hookPrefixMethods = new Dictionary>(); - hookPostfixMethods = new Dictionary>(); + hookPrefixMethods = new Dictionary>(); + hookPostfixMethods = new Dictionary>(); + + compatHookPrefixMethods = new Dictionary>(); + compatHookPostfixMethods = new Dictionary>(); } public void Initialize() @@ -86,118 +95,178 @@ namespace Barotrauma } }); } - - private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType) - { - result = null; - try + public void Add(string name, string hookName, LuaCsFunc hook, ACsMod owner = null) + { + name = name.ToLower(); + + if (name == null || hookName == null || hook == null) throw new ArgumentNullException("Names and Hook must not be null"); + + if (!hookFunctions.ContainsKey(name)) { - var funcAddr = ((long)__originalMethod.MethodHandle.GetFunctionPointer()); - HashSet<(string, LuaCsPatch, ACsMod)> methodSet = null; - switch (hookMethodType) + hookFunctions.Add(name, new Dictionary()); + } + + hookFunctions[name][hookName] = (new LuaCsHookCallback(name, hookName, hook), owner); + } + + public void Remove(string name, string hookName) + { + if (name == null || hookName == null) return; + + name = name.ToLower(); + + if (hookFunctions.ContainsKey(name) && hookFunctions[name].ContainsKey(hookName)) + hookFunctions[name].Remove(hookName); + } + + public void Clear() + { + hookFunctions.Clear(); + + hookPrefixMethods.Clear(); + hookPostfixMethods.Clear(); + + compatHookPrefixMethods.Clear(); + compatHookPostfixMethods.Clear(); + + harmony?.UnpatchAll(); + } + + + public void Update() + { + + } + + private Stopwatch performanceMeasurement = new Stopwatch(); + + [MoonSharpHidden] + public T Call(string name, params object[] args) + { + if (GameMain.LuaCs == null) { return default(T); } + if (name == null) { return default(T); } + if (args == null) { args = new object[] { }; } + + name = name.ToLower(); + + if (!hookFunctions.ContainsKey(name)) + { + return default(T); + } + + T lastResult = default(T); + + if (!hookFunctions.ContainsKey(name)) + { + return lastResult; + } + + var hooksToRemove = new List(); + foreach ((var key, var tuple) in hookFunctions[name]) + { + if (tuple.Item2 != null && tuple.Item2.IsDisposed) { - case HookMethodType.Before: - instance.hookPrefixMethods.TryGetValue(funcAddr, out methodSet); - break; - case HookMethodType.After: - instance.hookPostfixMethods.TryGetValue(funcAddr, out methodSet); - break; - default: - break; + hooksToRemove.Add(key); + continue; } - if (methodSet != null) + try { - var @params = __originalMethod.GetParameters(); - var args = new Dictionary(); - for (int i = 0; i < @params.Length; i++) + if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) { - args.Add(@params[i].Name, __args[i]); + performanceMeasurement.Start(); } - var outOfSocpe = new HashSet<(string, LuaCsPatch, ACsMod)>(); - foreach (var tuple in methodSet) + var result = tuple.Item1.func(args); + if (result != null) { - if (tuple.Item3 != null && tuple.Item3.IsDisposed) + if (typeof(object) != typeof(T)) { - outOfSocpe.Add(tuple); + if (result is LuaResult lRes) + { + if (!lRes.IsNull()) { lastResult = lRes.DynValue().ToObject(); } + } + else if (result is T cRes && cRes != null) + { + lastResult = cRes; + } } else { - var _result = tuple.Item2(__instance, args); - if (_result != null) + if (result is LuaResult lRes) { - if (_result is LuaResult res) - { - if (!res.IsNull()) - { - if (__originalMethod is MethodInfo mi && mi.ReturnType != typeof(void)) - { - result = res.DynValue().ToObject(mi.ReturnType); - } - else - { - result = res.DynValue().ToObject(); - } - } - } - else - { - result = _result; - } + if (!lRes.IsNull()) { lastResult = (T)(object)lRes.DynValue(); } + } + else + { + lastResult = (T)result; } } } - foreach (var tuple in outOfSocpe) { methodSet.Remove(tuple); } + + if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) + { + performanceMeasurement.Stop(); + GameMain.LuaCs.PerformanceCounter.SetHookElapsedTicks(name, key, performanceMeasurement.ElapsedTicks); + performanceMeasurement.Reset(); + } + } + catch (Exception e) + { + StringBuilder argsSb = new StringBuilder(); + foreach (var arg in args) argsSb.Append(arg + " "); + GameMain.LuaCs.HandleException( + e, $"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", ExceptionType.Both); } } - catch (Exception ex) - { - GameMain.LuaCs.HandleException(ex, $"Error in {__originalMethod.Name}:", exceptionType: LuaCsSetup.ExceptionType.Both); + foreach (var key in hooksToRemove) + { + hookFunctions[name].Remove(key); } + + return lastResult; } + public object Call(string name, params object[] args) => Call(name, args); - private static bool HookLuaCsPatchPrefix(MethodBase __originalMethod, object[] __args, object __instance) + private static bool PatchPrefix(MethodBase __originalMethod, object[] __args, object __instance) { - _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); + ExecutePatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); return result == null; } - private static void HookLuaCsPatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) => - _hookLuaCsPatch(__originalMethod, __args, __instance, out object _, HookMethodType.After); + private static void PatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) => + ExecutePatch(__originalMethod, __args, __instance, out object _, HookMethodType.After); - private static bool HookLuaCsPatchRetPrefix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) + private static bool PatchPrefixWithReturn(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) { - _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); + ExecutePatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); if (result != null) { __result = result; return false; } - else return true; + else { return true; } } - private static void HookLuaCsPatchRetPostfix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) + private static void PatchPostfixWithReturn(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) { - _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.After); - if (result != null) __result = result; + ExecutePatch(__originalMethod, __args, __instance, out object result, HookMethodType.After); + if (result != null) { __result = result; } } - private const BindingFlags DefaultBindingFlags = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + private static MethodInfo miPatchPrefix = typeof(LuaCsHook).GetMethod("PatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo miPatchPostfix = typeof(LuaCsHook).GetMethod("PatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo miPatchPrefixWithReturn = typeof(LuaCsHook).GetMethod("PatchPrefixWithReturn", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo miPatchPostfixWithReturn = typeof(LuaCsHook).GetMethod("PatchPostfixWithReturn", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); - - private static MethodInfo ResolveMethod(string where, string className, string methodName, string[] parameterNames) - { + private static MethodInfo ResolveMethod(string className, string methodName, string[] parameterNames) + { var classType = LuaUserData.GetType(className); if (classType == null) { - GameMain.LuaCs.HandleException(new Exception($"Tried to use {where} with an invalid class name '{className}'.")); + GameMain.LuaCs.HandleException(new Exception($"Invalid class name '{className}'.")); return null; } @@ -222,52 +291,45 @@ namespace Barotrauma return methodInfo; } - private static readonly string[] prohibitedHooks = { - "Barotrauma.Lua", - "Barotrauma.Cs", - "ContentPackageManager", - }; - public void HookMethod(string identifier, MethodInfo method, LuaCsPatch patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) + public void Patch(string identifier, MethodInfo method, LuaCsFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) { if (identifier == null || method == null || patch == null) { - GameMain.LuaCs.HandleException(new ArgumentNullException("Identifier, Method and Patch arguments must not be null."), exceptionType: ExceptionType.Both); - return; + throw new ArgumentNullException("Identifier, Method and Patch arguments must not be null."); } if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h))) { - GameMain.LuaCs.HandleException(new ArgumentException("Hooks into Modding Environment are prohibited."), exceptionType: ExceptionType.Both); - return; + throw new ArgumentException("Hooks into Modding Environment are prohibited."); } - var funcAddr = ((long)method.MethodHandle.GetFunctionPointer()); + var funcAddr = (long)method.MethodHandle.GetFunctionPointer(); var patches = Harmony.GetPatchInfo(method); if (hookType == HookMethodType.Before) { if (method.ReturnType != typeof(void)) - { - if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPrefix) == null) + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == miPatchPrefixWithReturn) == null) { - harmony.Patch(method, prefix: new HarmonyMethod(_miHookLuaCsPatchRetPrefix)); + harmony.Patch(method, prefix: new HarmonyMethod(miPatchPrefixWithReturn)); } } else - { - if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchPrefix) == null) + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == miPatchPrefix) == null) { - harmony.Patch(method, prefix: new HarmonyMethod(_miHookLuaCsPatchPrefix)); + harmony.Patch(method, prefix: new HarmonyMethod(miPatchPrefix)); } } - if (hookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet)) + if (hookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsFunc, ACsMod)> methodSet)) { methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); methodSet.Add((identifier, patch, owner)); } else if (patch != null) { - hookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) }); + hookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsFunc, ACsMod)>() { (identifier, patch, owner) }); } } @@ -275,178 +337,147 @@ namespace Barotrauma { if (method.ReturnType != typeof(void)) { - if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPostfix) == null) + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == miPatchPostfixWithReturn) == null) { - harmony.Patch(method, postfix: new HarmonyMethod(_miHookLuaCsPatchRetPostfix)); + harmony.Patch(method, postfix: new HarmonyMethod(miPatchPostfixWithReturn)); } } else - { - if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchPostfix) == null) + { + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == miPatchPostfix) == null) { - harmony.Patch(method, postfix: new HarmonyMethod(_miHookLuaCsPatchPostfix)); + harmony.Patch(method, postfix: new HarmonyMethod(miPatchPostfix)); } } - if (hookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet)) + if (hookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsFunc, ACsMod)> methodSet)) { methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); methodSet.Add((identifier, patch, owner)); } else if (patch != null) { - hookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) }); + hookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsFunc, ACsMod)>() { (identifier, patch, owner) }); } } } - protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) + private static void ExecutePatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType) { + result = null; - MethodInfo methodInfo = ResolveMethod("HookMethod", className, methodName, parameterNames); - if (methodInfo == null) return; - HookMethod(identifier, methodInfo, patch, hookMethodType); - } - protected void HookMethod(string identifier, string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => - HookMethod(identifier, className, methodName, null, patch, hookMethodType); - protected void HookMethod(string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => - HookMethod("", className, methodName, null, patch, hookMethodType); - protected void HookMethod(string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => - HookMethod("", className, methodName, parameterNames, patch, hookMethodType); - - - public void UnhookMethod(string identifier, MethodInfo method, HookMethodType hookType = HookMethodType.Before) - { - var funcAddr = ((long)method.MethodHandle.GetFunctionPointer()); - - Dictionary> methods; - if (hookType == HookMethodType.Before) methods = hookPrefixMethods; - else if (hookType == HookMethodType.After) methods = hookPostfixMethods; - else throw null; - - if (methods.ContainsKey(funcAddr)) methods[funcAddr]?.RemoveWhere(t => t.Item1 == identifier); - } - protected void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before) - { - MethodInfo methodInfo = ResolveMethod("UnhookMathod", className, methodName, parameterNames); - if (methodInfo == null) return; - UnhookMethod(identifier, methodInfo, hookType); - } - - public void Add(string name, string hookName, LuaCsFunc hook, ACsMod owner = null) - { - name = name.ToLower(); - - if (name == null || hookName == null || hook == null) throw new ArgumentNullException("Names and Hook must not be null"); - - if (!hookFunctions.ContainsKey(name)) - hookFunctions.Add(name, new Dictionary()); - - hookFunctions[name][hookName] = (new LuaCsHookCallback(name, hookName, hook), owner); - } - - public void Remove(string name, string hookName) - { - if (name == null || hookName == null) return; - - name = name.ToLower(); - - if (hookFunctions.ContainsKey(name) && hookFunctions[name].ContainsKey(hookName)) - hookFunctions[name].Remove(hookName); - } - - public void Clear() - { - hookFunctions.Clear(); - - hookPrefixMethods.Clear(); - hookPostfixMethods.Clear(); - - harmony?.UnpatchAll(); - } - - - public void Update() - { - - } - - private Stopwatch performanceMeasurement = new Stopwatch(); - - [MoonSharpHidden] - public T Call(string name, params object[] args) - { - if (GameMain.LuaCs == null) return default(T); - if (name == null) return default(T); - if (args == null) { args = new object[] { }; } - - name = name.ToLower(); - - if (!hookFunctions.ContainsKey(name)) - return default(T); - - T lastResult = default(T); - - if (hookFunctions.ContainsKey(name)) + try { - var outOfScope = new List(); - foreach ((var key, var tuple) in hookFunctions[name]) + long funcAddr = (long)__originalMethod.MethodHandle.GetFunctionPointer(); + + HashSet<(string, LuaCsFunc, ACsMod)> methodSet = null; + switch (hookMethodType) { - if (tuple.Item2 != null && tuple.Item2.IsDisposed) - outOfScope.Add(key); - else + case HookMethodType.Before: + instance.hookPrefixMethods.TryGetValue(funcAddr, out methodSet); + break; + case HookMethodType.After: + instance.hookPostfixMethods.TryGetValue(funcAddr, out methodSet); + break; + default: + break; + } + + if (methodSet == null) + { + return; + } + + var patchesToRemove = new HashSet<(string, LuaCsFunc, ACsMod)>(); + foreach (var tuple in methodSet) + { + if (tuple.Item3 != null && tuple.Item3.IsDisposed) { - try - { - if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) - { - performanceMeasurement.Start(); - } + patchesToRemove.Add(tuple); + continue; + } - var result = tuple.Item1.func(args); - if (result != null) - { - if (typeof(object) != typeof(T)) - { - if (result is LuaResult lRes) - { - if (!lRes.IsNull()) lastResult = lRes.DynValue().ToObject(); - } - else if (result is T cRes && cRes != null) lastResult = cRes; - } - else - { - if (result is LuaResult lRes) - { - if (!lRes.IsNull()) lastResult = (T)(object)lRes.DynValue(); - } - else lastResult = (T)result; - } - } + object[] args = new object[] { __instance }.Concat(__args).ToArray(); + object _result = tuple.Item2(args); - if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) - { - performanceMeasurement.Stop(); - GameMain.LuaCs.PerformanceCounter.SetHookElapsedTicks(name, key, performanceMeasurement.ElapsedTicks); - performanceMeasurement.Reset(); - } - } - catch (Exception e) + if (_result == null) + { + continue; + } + + if (_result is LuaResult res) + { + if (!res.IsNull()) { - StringBuilder argsSb = new StringBuilder(); - foreach (var arg in args) argsSb.Append(arg + " "); - GameMain.LuaCs.HandleException( - e, $"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", ExceptionType.Both); + if (__originalMethod is MethodInfo mi && mi.ReturnType != typeof(void)) + { + result = res.DynValue().ToObject(mi.ReturnType); + } + else + { + result = res.DynValue().ToObject(); + } } } + else + { + result = _result; + } } - foreach (var key in outOfScope) hookFunctions[name].Remove(key); + + foreach (var tuple in patchesToRemove) + { + methodSet.Remove(tuple); + } + } + catch (Exception ex) + { + GameMain.LuaCs.HandleException(ex, $"Error in {__originalMethod.Name}:", exceptionType: LuaCsSetup.ExceptionType.Both); + } + } + + + public void Patch(string identifier, string className, string methodName, string[] parameterNames, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) + { + MethodInfo methodInfo = ResolveMethod(className, methodName, parameterNames); + if (methodInfo == null) return; + Patch(identifier, methodInfo, patch, hookMethodType); + } + public void Patch(string identifier, string className, string methodName, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => + Patch(identifier, className, methodName, null, patch, hookMethodType); + public void Patch(string className, string methodName, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => + Patch("", className, methodName, null, patch, hookMethodType); + public void Patch(string className, string methodName, string[] parameterNames, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => + Patch("", className, methodName, parameterNames, patch, hookMethodType); + + + public void RemovePatch(string identifier, MethodInfo method, HookMethodType hookType = HookMethodType.Before) + { + var funcAddr = (long)method.MethodHandle.GetFunctionPointer(); + + Dictionary> methods; + if (hookType == HookMethodType.Before) { methods = hookPrefixMethods; } + else if (hookType == HookMethodType.After) { methods = hookPostfixMethods; } + else { throw new NotImplementedException(); } + + if (methods.ContainsKey(funcAddr)) + { + methods[funcAddr]?.RemoveWhere(t => t.Item1 == identifier); + } + } + + public void RemovePatch(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before) + { + MethodInfo methodInfo = ResolveMethod(className, methodName, parameterNames); + + if (methodInfo == null) + { + return; } - return lastResult; + RemovePatch(identifier, methodInfo, hookType); } - public object Call(string name, params object[] args) => Call(name, args); } } \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs new file mode 100644 index 000000000..e3795e7fc --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs @@ -0,0 +1,261 @@ +using System; +using System.Linq; +using System.Reflection; +using MoonSharp.Interpreter; +using HarmonyLib; +using System.Collections.Generic; +using System.Text; +using MoonSharp.Interpreter.Interop; +using static Barotrauma.LuaCsSetup; +using System.Threading; +using System.Diagnostics; + +namespace Barotrauma +{ + partial class LuaCsHook + { + private Dictionary> compatHookPrefixMethods; + private Dictionary> compatHookPostfixMethods; + + private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType) + { + result = null; + + try + { + var funcAddr = ((long)__originalMethod.MethodHandle.GetFunctionPointer()); + HashSet<(string, LuaCsPatch, ACsMod)> methodSet = null; + switch (hookMethodType) + { + case HookMethodType.Before: + instance.compatHookPrefixMethods.TryGetValue(funcAddr, out methodSet); + break; + case HookMethodType.After: + instance.compatHookPostfixMethods.TryGetValue(funcAddr, out methodSet); + break; + default: + break; + } + + if (methodSet != null) + { + var @params = __originalMethod.GetParameters(); + var args = new Dictionary(); + for (int i = 0; i < @params.Length; i++) + { + args.Add(@params[i].Name, __args[i]); + } + + var outOfSocpe = new HashSet<(string, LuaCsPatch, ACsMod)>(); + foreach (var tuple in methodSet) + { + if (tuple.Item3 != null && tuple.Item3.IsDisposed) + { + outOfSocpe.Add(tuple); + } + else + { + var _result = tuple.Item2(__instance, args); + if (_result != null) + { + if (_result is LuaResult res) + { + if (!res.IsNull()) + { + if (__originalMethod is MethodInfo mi && mi.ReturnType != typeof(void)) + { + result = res.DynValue().ToObject(mi.ReturnType); + } + else + { + result = res.DynValue().ToObject(); + } + } + } + else + { + result = _result; + } + } + } + } + foreach (var tuple in outOfSocpe) { methodSet.Remove(tuple); } + } + } + catch (Exception ex) + { + GameMain.LuaCs.HandleException(ex, $"Error in {__originalMethod.Name}:", exceptionType: LuaCsSetup.ExceptionType.Both); + } + } + + + private static bool HookLuaCsPatchPrefix(MethodBase __originalMethod, object[] __args, object __instance) + { + _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); + return result == null; + } + private static void HookLuaCsPatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) => + _hookLuaCsPatch(__originalMethod, __args, __instance, out object _, HookMethodType.After); + + private static bool HookLuaCsPatchRetPrefix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) + { + _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); + if (result != null) + { + __result = result; + return false; + } + else return true; + } + private static void HookLuaCsPatchRetPostfix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) + { + _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.After); + if (result != null) __result = result; + } + + + private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); + + private static MethodInfo ResolveMethod(string where, string className, string methodName, string[] parameterNames) + { + var classType = LuaUserData.GetType(className); + + if (classType == null) + { + GameMain.LuaCs.HandleException(new Exception($"Tried to use {where} with an invalid class name '{className}'.")); + return null; + } + + MethodInfo methodInfo = null; + + if (parameterNames != null) + { + Type[] parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray(); + methodInfo = classType.GetMethod(methodName, DefaultBindingFlags, null, parameterTypes, null); + } + else + { + methodInfo = classType.GetMethod(methodName, DefaultBindingFlags); + } + + if (methodInfo == null) + { + string parameterNamesStr = parameterNames == null ? "" : string.Join(", ", parameterNames); + GameMain.LuaCs.HandleException(new Exception($"Method '{methodName}' with parameters '{parameterNamesStr}' not found in class '{className}'")); + } + + return methodInfo; + } + + public void HookMethod(string identifier, MethodInfo method, LuaCsPatch patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) + { + if (identifier == null || method == null || patch == null) + { + GameMain.LuaCs.HandleException(new ArgumentNullException("Identifier, Method and Patch arguments must not be null."), exceptionType: ExceptionType.Both); + return; + } + if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h))) + { + GameMain.LuaCs.HandleException(new ArgumentException("Hooks into Modding Environment are prohibited."), exceptionType: ExceptionType.Both); + return; + } + + var funcAddr = ((long)method.MethodHandle.GetFunctionPointer()); + var patches = Harmony.GetPatchInfo(method); + + if (hookType == HookMethodType.Before) + { + if (method.ReturnType != typeof(void)) + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPrefix) == null) + { + harmony.Patch(method, prefix: new HarmonyMethod(_miHookLuaCsPatchRetPrefix)); + } + } + else + { + if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchPrefix) == null) + { + harmony.Patch(method, prefix: new HarmonyMethod(_miHookLuaCsPatchPrefix)); + } + } + + if (compatHookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet)) + { + methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); + methodSet.Add((identifier, patch, owner)); + } + else if (patch != null) + { + compatHookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) }); + } + + } + else if (hookType == HookMethodType.After) + { + if (method.ReturnType != typeof(void)) + { + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchRetPostfix) == null) + { + harmony.Patch(method, postfix: new HarmonyMethod(_miHookLuaCsPatchRetPostfix)); + } + } + else + { + if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == _miHookLuaCsPatchPostfix) == null) + { + harmony.Patch(method, postfix: new HarmonyMethod(_miHookLuaCsPatchPostfix)); + } + } + + if (compatHookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet)) + { + methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); + methodSet.Add((identifier, patch, owner)); + } + else if (patch != null) + { + compatHookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) }); + } + + } + + } + + protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) + { + + MethodInfo methodInfo = ResolveMethod("HookMethod", className, methodName, parameterNames); + if (methodInfo == null) return; + HookMethod(identifier, methodInfo, patch, hookMethodType); + } + protected void HookMethod(string identifier, string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => + HookMethod(identifier, className, methodName, null, patch, hookMethodType); + protected void HookMethod(string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => + HookMethod("", className, methodName, null, patch, hookMethodType); + protected void HookMethod(string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => + HookMethod("", className, methodName, parameterNames, patch, hookMethodType); + + + public void UnhookMethod(string identifier, MethodInfo method, HookMethodType hookType = HookMethodType.Before) + { + var funcAddr = ((long)method.MethodHandle.GetFunctionPointer()); + + Dictionary> methods; + if (hookType == HookMethodType.Before) methods = compatHookPrefixMethods; + else if (hookType == HookMethodType.After) methods = compatHookPostfixMethods; + else throw null; + + if (methods.ContainsKey(funcAddr)) methods[funcAddr]?.RemoveWhere(t => t.Item1 == identifier); + } + protected void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before) + { + MethodInfo methodInfo = ResolveMethod("UnhookMathod", className, methodName, parameterNames); + if (methodInfo == null) return; + UnhookMethod(identifier, methodInfo, hookType); + } + } +} \ No newline at end of file