diff --git a/Barotrauma/BarotraumaClient/LinuxClient.csproj b/Barotrauma/BarotraumaClient/LinuxClient.csproj index d51ae3431..01106d746 100644 --- a/Barotrauma/BarotraumaClient/LinuxClient.csproj +++ b/Barotrauma/BarotraumaClient/LinuxClient.csproj @@ -142,6 +142,7 @@ + @@ -222,4 +223,4 @@ - \ No newline at end of file + diff --git a/Barotrauma/BarotraumaClient/MacClient.csproj b/Barotrauma/BarotraumaClient/MacClient.csproj index fb9a8c51f..97be07204 100644 --- a/Barotrauma/BarotraumaClient/MacClient.csproj +++ b/Barotrauma/BarotraumaClient/MacClient.csproj @@ -134,6 +134,7 @@ + @@ -224,4 +225,4 @@ - \ No newline at end of file + diff --git a/Barotrauma/BarotraumaClient/WindowsClient.csproj b/Barotrauma/BarotraumaClient/WindowsClient.csproj index 1c96d533e..2114560b0 100644 --- a/Barotrauma/BarotraumaClient/WindowsClient.csproj +++ b/Barotrauma/BarotraumaClient/WindowsClient.csproj @@ -141,6 +141,7 @@ + diff --git a/Barotrauma/BarotraumaServer/LinuxServer.csproj b/Barotrauma/BarotraumaServer/LinuxServer.csproj index 0828ce53f..9d7662827 100644 --- a/Barotrauma/BarotraumaServer/LinuxServer.csproj +++ b/Barotrauma/BarotraumaServer/LinuxServer.csproj @@ -89,6 +89,7 @@ + diff --git a/Barotrauma/BarotraumaServer/MacServer.csproj b/Barotrauma/BarotraumaServer/MacServer.csproj index 731599318..518422501 100644 --- a/Barotrauma/BarotraumaServer/MacServer.csproj +++ b/Barotrauma/BarotraumaServer/MacServer.csproj @@ -86,6 +86,7 @@ + diff --git a/Barotrauma/BarotraumaServer/WindowsServer.csproj b/Barotrauma/BarotraumaServer/WindowsServer.csproj index d1e747e39..b4e15a59c 100644 --- a/Barotrauma/BarotraumaServer/WindowsServer.csproj +++ b/Barotrauma/BarotraumaServer/WindowsServer.csproj @@ -88,6 +88,7 @@ + diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs index 581becdc0..a0a4fd308 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs @@ -4,10 +4,10 @@ using System.Text; using MoonSharp.Interpreter; using Microsoft.Xna.Framework; using FarseerPhysics.Dynamics; +using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch; namespace Barotrauma { - public static class LuaCustomConverters { public static void RegisterAll() @@ -31,8 +31,12 @@ namespace Barotrauma v => (LuaCsFunc)(args => GameMain.LuaCs.CallLuaFunction(v.Function, args))); Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion( DataType.Function, - typeof(LuaCsPatch), - v => (LuaCsPatch)((self, args) => GameMain.LuaCs.CallLuaFunction(v.Function, self, args))); + typeof(LuaCsCompatPatchFunc), + v => (LuaCsCompatPatchFunc)((self, args) => GameMain.LuaCs.CallLuaFunction(v.Function, self, args))); + Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion( + DataType.Function, + typeof(LuaCsPatchFunc), + v => (LuaCsPatchFunc)((self, args) => GameMain.LuaCs.CallLuaFunction(v.Function, self, args))); #if CLIENT RegisterAction(); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs index 4558c8afa..1d6f53405 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs @@ -1,61 +1,524 @@ using System; -using System.Linq; -using System.Reflection; -using MoonSharp.Interpreter; -using HarmonyLib; +using System.Collections; using System.Collections.Generic; -using System.Text; -using MoonSharp.Interpreter.Interop; -using static Barotrauma.LuaCsSetup; -using System.Threading; using System.Diagnostics; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Reflection.Emit; +using System.Text; +using System.Text.RegularExpressions; +using HarmonyLib; +using Microsoft.Xna.Framework; +using MoonSharp.Interpreter; +using MoonSharp.Interpreter.Interop; +using Sigil; +using Sigil.NonGeneric; +using static Barotrauma.LuaCsSetup; namespace Barotrauma { - public delegate void LuaCsAction(params object[] args); - public delegate DynValue LuaCsFunc(params object[] args); - public delegate object LuaCsPatch(object self, Dictionary args); + public delegate void LuaCsAction(params object[] args); + public delegate DynValue LuaCsFunc(params object[] args); + public delegate DynValue LuaCsPatchFunc(object instance, LuaCsHook.ParameterTable ptable); - public partial class LuaCsHook - { - public enum HookMethodType - { - Before, After - } + internal static class SigilExtensions + { + /// + /// Puts a type on the stack, as a object instead of a + /// runtime type token. + /// + /// The IL emitter. + /// The type to put on the stack. + public static void LoadType(this Emit il, Type type) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + il.LoadConstant(type); // ldtoken + // This converts the type token into a Type object + il.Call(typeof(Type).GetMethod( + name: nameof(Type.GetTypeFromHandle), + bindingAttr: BindingFlags.Public | BindingFlags.Static, + binder: null, + types: new Type[] { typeof(RuntimeTypeHandle) }, + modifiers: null)); + } - private class LuaHookFunction - { - public string name; - public string hookName; - public object function; + /// + /// Converts the value on the stack to . + /// + /// The IL emitter. + /// The type of the value on the stack. + public static void ToObject(this Emit il, Type type) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + il.DerefIfByRef(ref type); + if (type.IsValueType) + { + il.Box(type); + } + else if (type != typeof(object)) + { + il.CastClass(); + } + } - public LuaHookFunction(string n, string hn, object func) - { - name = n; - hookName = hn; - function = func; - } - } - private class LuaCsHookCallback - { - public string name; - public string hookName; - public LuaCsFunc func; + /// + /// Deferences the value on stack if the provided type is ByRef. + /// + /// The IL emitter. + /// The type to check if ByRef. + public static void DerefIfByRef(this Emit il, Type type) => il.DerefIfByRef(ref type); - public LuaCsHookCallback(string name, string hookName, LuaCsFunc func) - { - this.name = name; - this.hookName = hookName; - this.func = func; - } - } + /// + /// Deferences the value on stack if the provided type is ByRef. + /// + /// The IL emitter. + /// The type to check if ByRef. + public static void DerefIfByRef(this Emit il, ref Type type) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + if (type.IsByRef) + { + type = type.GetElementType(); + if (type.IsValueType) + { + il.LoadObject(type); + } + else + { + il.LoadIndirect(type); + } + } + } - private const BindingFlags DefaultBindingFlags = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; - private static readonly string[] prohibitedHooks = { - "Barotrauma.Lua", - "Barotrauma.Cs", - "ContentPackageManager", - }; + // Copied from https://github.com/evilfactory/moonsharp/blob/5264656c6442e783f3c75082cce69a93d66d4cc0/src/MoonSharp.Interpreter/Interop/Converters/ScriptToClrConversions.cs#L79-L99 + private static MethodInfo HasImplicitConversion(Type baseType, Type targetType) + { + try + { + return Expression.Convert(Expression.Parameter(baseType, null), targetType).Method; + } + catch + { + if (baseType.BaseType != null) + { + return HasImplicitConversion(baseType.BaseType, targetType); + } + + if (targetType.BaseType != null) + { + return HasImplicitConversion(baseType, targetType.BaseType); + } + + return null; + } + } + + /// + /// Loads a local variable and casts it to the target type. + /// + /// The IL emitter. + /// The value to cast. Must be of type . + /// The type to cast into. + public static void LoadLocalAndCast(this Emit il, Local value, Type targetType) + { + if (value == null) throw new ArgumentNullException(nameof(value)); + if (targetType == null) throw new ArgumentNullException(nameof(targetType)); + if (value.LocalType != typeof(object)) + { + throw new ArgumentException($"Expected local type {typeof(object)}; got {value.LocalType}.", nameof(value)); + } + + var guid = Guid.NewGuid().ToString("N"); + + if (targetType.IsByRef) + { + targetType = targetType.GetElementType(); + } + + // IL: var baseType = value.GetType(); + var baseType = il.DeclareLocal(typeof(Type), $"cast_baseType_{guid}"); + il.LoadLocal(value); + il.Call(typeof(object).GetMethod("GetType")); + il.StoreLocal(baseType); + + // IL: var implicitConversionMethod = SigilExtensions.HasImplicitConversion(baseType, ); + var implicitConversionMethod = il.DeclareLocal(typeof(MethodInfo), $"cast_implicitConversionMethod_{guid}"); + il.LoadLocal(baseType); + il.LoadType(targetType); + il.Call(typeof(SigilExtensions).GetMethod(nameof(HasImplicitConversion), BindingFlags.NonPublic | BindingFlags.Static)); + il.StoreLocal(implicitConversionMethod); + + // IL: castValue; + var castValue = il.DeclareLocal(targetType, $"cast_castValue_{guid}"); + + // IL: if (implicitConversionMethod != null) + il.LoadLocal(implicitConversionMethod); + il.Branch((il) => + { + // IL: var methodInvokeParams = new object[1]; + var methodInvokeParams = il.DeclareLocal(typeof(object[]), $"cast_methodInvokeParams_{guid}"); + il.LoadConstant(1); + il.NewArray(typeof(object)); + il.StoreLocal(methodInvokeParams); + + // IL: methodInvokeParams[0] = value; + il.LoadLocal(methodInvokeParams); + il.LoadConstant(0); + il.LoadLocal(value); + il.StoreElement(); + + // IL: castValue = ()implicitConversionMethod.Invoke(null, methodInvokeParams); + il.LoadLocal(implicitConversionMethod); + il.LoadNull(); // first parameter is null because implicit cast operators are static + il.LoadLocal(methodInvokeParams); + il.Call(typeof(MethodInfo).GetMethod("Invoke", new[] { typeof(object), typeof(object[]) })); + if (targetType.IsValueType) + { + il.UnboxAny(targetType); + } + else + { + il.CastClass(targetType); + } + il.StoreLocal(castValue); + }, + (il) => + { + // IL: castValue = ()value; + il.LoadLocal(value); + if (targetType.IsValueType) + { + il.UnboxAny(targetType); + } + else + { + il.CastClass(targetType); + } + il.StoreLocal(castValue); + }); + + il.LoadLocal(castValue); + } + + /// + /// Emits a call to . + /// + /// The IL emitter. + /// The string format. + /// The local variables passed to string.Format. + public static void FormatString(this Emit il, string format, params Local[] args) + { + if (format == null) throw new ArgumentNullException(nameof(format)); + if (args == null) throw new ArgumentNullException(nameof(args)); + + var guid = Guid.NewGuid().ToString("N"); + + var listType = typeof(List<>).MakeGenericType(typeof(object)); + var list = il.DeclareLocal(listType, $"formatString_list_{guid}"); + il.NewObject(listType); + il.StoreLocal(list); + + foreach (var arg in args) + { + il.LoadLocal(list); + il.LoadLocal(arg); + il.ToObject(arg.LocalType); + il.CallVirtual(listType.GetMethod("Add", new[] { typeof(object) })); + } + + var arr = il.DeclareLocal($"formatString_arr_{guid}"); + il.LoadLocal(list); + il.CallVirtual(listType.GetMethod("ToArray", new Type[0])); + il.StoreLocal(arr); + + il.LoadConstant(format); + il.LoadLocal(arr); + il.Call(typeof(string).GetMethod("Format", new[] { typeof(string), typeof(object[]) })); + } + + /// + /// Emits a call to . + /// + /// The IL emitter. + /// The message to print. + public static void NewMessage(this Emit il, string message) + { + var newMessage = typeof(DebugConsole).GetMethod( + name: nameof(DebugConsole.NewMessage), + bindingAttr: BindingFlags.Public | BindingFlags.Static, + binder: null, + types: new Type[] { typeof(string), typeof(Color?), typeof(bool) }, + modifiers: null); + il.LoadConstant(message); + il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod()); + il.LoadConstant(false); + il.Call(newMessage); + } + + /// + /// Emits a call to , + /// using the string on the stack. + /// + /// The IL emitter. + public static void NewMessage(this Emit il) + { + var newMessage = typeof(DebugConsole).GetMethod( + name: nameof(DebugConsole.NewMessage), + bindingAttr: BindingFlags.Public | BindingFlags.Static, + binder: null, + types: new Type[] { typeof(string), typeof(Color?), typeof(bool) }, + modifiers: null); + il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod()); + il.LoadConstant(false); + il.Call(newMessage); + } + + /// + /// Emits a foreach loop that iterates over an local variable. + /// + /// The type of elements in the enumerable. + /// The IL emitter. + /// The enumerable. + /// The body of code to run on each iteration. + public static void ForEachEnumerable(this Emit il, Local enumerable, Action action) + { + if (enumerable == null) throw new ArgumentNullException(nameof(enumerable)); + if (action == null) throw new ArgumentNullException(nameof(action)); + if (!typeof(IEnumerable).IsAssignableFrom(enumerable.LocalType)) + { + throw new ArgumentException($"Expected local type {typeof(IEnumerator)}; got {enumerable.LocalType}.", nameof(enumerable)); + } + + var guid = Guid.NewGuid().ToString("N"); + + var enumerator = il.DeclareLocal>($"forEachEnumerable_enumerator_{guid}"); + il.LoadLocal(enumerable); + il.CallVirtual(typeof(IEnumerable).GetMethod("GetEnumerator")); + il.StoreLocal(enumerator); + ForEachEnumerator(il, enumerator, action); + } + + /// + /// Emits a foreach loop that iterates over an local variable. + /// + /// The type of elements in the enumerable. + /// The IL emitter. + /// The enumerator. + /// The body of code to run on each iteration. + public static void ForEachEnumerator(this Emit il, Local enumerator, Action action) + { + if (enumerator == null) throw new ArgumentNullException(nameof(enumerator)); + if (action == null) throw new ArgumentNullException(nameof(action)); + if (!typeof(IEnumerator).IsAssignableFrom(enumerator.LocalType)) + { + throw new ArgumentException($"Expected local type {typeof(IEnumerator)}; got {enumerator.LocalType}.", nameof(enumerator)); + } + + var guid = Guid.NewGuid().ToString("N"); + var labelLoopStart = il.DefineLabel($"forEach_loopStart_{guid}"); + var labelMoveNext = il.DefineLabel($"forEach_moveNext_{guid}"); + var labelLeave = il.DefineLabel($"forEach_leave_{guid}"); + + il.BeginExceptionBlock(out var exceptionBlock); + il.Branch(labelMoveNext); // MoveNext() needs to be called at least once before iterating + il.MarkLabel(labelLoopStart); + + // IL: var current = enumerator.Current; + var current = il.DeclareLocal($"forEachEnumerator_current_{guid}"); + il.LoadLocal(enumerator); + il.CallVirtual(enumerator.LocalType.GetProperty("Current").GetGetMethod()); + il.StoreLocal(current); + + action(il, current, labelLeave); + + il.MarkLabel(labelMoveNext); + il.LoadLocal(enumerator); + il.CallVirtual(typeof(IEnumerator).GetMethod("MoveNext")); + il.BranchIfTrue(labelLoopStart); // loop if MoveNext() returns true + + // IL: finally { enumerator.Dispose(); } + il.BeginFinallyBlock(exceptionBlock, out var finallyBlock); + il.LoadLocal(enumerator); + il.CallVirtual(typeof(IDisposable).GetMethod("Dispose")); + il.EndFinallyBlock(finallyBlock); + + il.EndExceptionBlock(exceptionBlock); + + il.MarkLabel(labelLeave); + } + + /// + /// Emits a branch that only executes if the last value on the stack + /// is truthy (e.g. non-null references, 1, etc). + /// + /// The IL emitter. + /// The body of code to run if the value is truthy. + public static void If(this Emit il, Action action) + { + if (action == null) throw new ArgumentNullException(nameof(action)); + il.Branch(@if: action); + } + + /// + /// Emits a branch that only executes if the last value on the stack + /// is falsy (e.g. null references, 0, etc). + /// + /// The IL emitter. + /// The body of code to run if the value is falsy. + public static void IfNot(this Emit il, Action action) + { + if (action == null) throw new ArgumentNullException(nameof(action)); + il.Branch(@else: action); + } + + /// + /// Emits two branches that diverge based on a condition -- analogous + /// to an if-else statement. If either + /// or are omitted, it behaves the same as + /// + /// and . + /// + /// The IL emitter. + /// The body of code to run if the value is truthy. + /// The body of code to run if the value is falsy. + public static void Branch(this Emit il, Action @if = null, Action @else = null) + { + if (@if == null && @else == null) throw new ArgumentException("At least one of the two branches must be defined."); + + var guid = Guid.NewGuid().ToString("N"); + var labelEnd = il.DefineLabel($"branch_end_{guid}"); + if (@if != null && @else != null) + { + var labelElse = il.DefineLabel($"branch_else_{guid}"); + il.BranchIfFalse(labelElse); + @if(il); + il.Branch(labelEnd); + il.MarkLabel(labelElse); + @else(il); + } + else if (@if != null) + { + il.BranchIfFalse(labelEnd); + @if(il); + } + else + { + il.BranchIfTrue(labelEnd); + @else(il); + } + il.MarkLabel(labelEnd); + } + } + + public partial class LuaCsHook + { + public enum HookMethodType + { + Before, After + } + + private class LuaCsHookCallback + { + public string name; + public string hookName; + public LuaCsFunc func; + + public LuaCsHookCallback(string name, string hookName, LuaCsFunc func) + { + this.name = name; + this.hookName = hookName; + this.func = func; + } + } + + private class LuaCsPatch + { + public string Identifier { get; set; } + + public LuaCsPatchFunc PatchFunc { get; set; } + } + + private class PatchedMethod + { + public PatchedMethod(MethodInfo harmonyPrefix, MethodInfo harmonyPostfix) + { + HarmonyPrefixMethod = harmonyPrefix; + HarmonyPostfixMethod = harmonyPostfix; + Prefixes = new Dictionary(); + Postfixes = new Dictionary(); + } + + public MethodInfo HarmonyPrefixMethod { get; } + + public MethodInfo HarmonyPostfixMethod { get; } + + public IEnumerator GetPrefixEnumerator() => Prefixes.Values.GetEnumerator(); + + public IEnumerator GetPostfixEnumerator() => Postfixes.Values.GetEnumerator(); + + public Dictionary Prefixes { get; } + + public Dictionary Postfixes { get; } + } + + public class ParameterTable + { + private readonly Dictionary parameters; + private bool returnValueModified; + private object returnValue; + + public ParameterTable(Dictionary dict) + { + parameters = dict; + } + + public object this[string paramName] + { + get + { + if (ModifiedParameters.TryGetValue(paramName, out var value)) + { + return value; + } + return OriginalParameters[paramName]; + } + set + { + ModifiedParameters[paramName] = value; + } + } + + public object OriginalReturnValue { get; private set; } + + public object ReturnValue + { + get + { + if (returnValueModified) return returnValue; + return OriginalReturnValue; + } + set + { + returnValueModified = true; + returnValue = value; + } + } + + public bool PreventExecution { get; set; } + + public Dictionary OriginalParameters => parameters; + + [MoonSharpHidden] + public Dictionary ModifiedParameters { get; } = new Dictionary(); + } + + private static readonly string[] prohibitedHooks = + { + "Barotrauma.Lua", + "Barotrauma.Cs", + "ContentPackageManager", + }; private static void ValidatePatchTarget(MethodInfo methodInfo) { @@ -65,409 +528,705 @@ namespace Barotrauma } } - private Harmony harmony; - - private Dictionary> hookFunctions; - private Dictionary> hookPrefixMethods; - private Dictionary> hookPostfixMethods; - - private static LuaCsHook instance; - - public LuaCsHook() { - instance = this; - - hookFunctions = new Dictionary>(); - - hookPrefixMethods = new Dictionary>(); - hookPostfixMethods = new Dictionary>(); - - compatHookPrefixMethods = new Dictionary>(); - compatHookPostfixMethods = new Dictionary>(); - } - - public void Initialize() + private static string NormalizeIdentifier(string identifier) { - harmony = new Harmony("LuaCsForBarotrauma"); - - var hookType = UserData.RegisterType(); - var hookDesc = (StandardUserDataDescriptor)hookType; - typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => { - if ( - m.Name.Contains("HookMethod") || - m.Name.Contains("UnhookMethod") || - m.Name.Contains("EnqueueFunction") || - m.Name.Contains("EnqueueTimedFunction") - ) - { - hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default)); - } - }); - } - - public void Add(string name, string hookName, LuaCsFunc hook, ACsMod owner = null) - { - if (name == null || hookName == null || hook == null) - { - throw new ScriptRuntimeException("Hook.Add: name, hookName and hook must not be null."); - } - - name = name.ToLower(); - - if (!hookFunctions.ContainsKey(name)) - { - hookFunctions.Add(name, new Dictionary()); - } - - hookFunctions[name][hookName] = (new LuaCsHookCallback(name, hookName, hook), owner); - } - - public void Remove(string name, string hookName) - { - if (name == null || hookName == null) { return; } - - name = name.ToLower(); - - if (hookFunctions.ContainsKey(name) && hookFunctions[name].ContainsKey(hookName)) - { - hookFunctions[name].Remove(hookName); - } - } - - public void Clear() - { - hookFunctions.Clear(); - - hookPrefixMethods.Clear(); - hookPostfixMethods.Clear(); - - compatHookPrefixMethods.Clear(); - compatHookPostfixMethods.Clear(); - - harmony?.UnpatchAll(); - } - - - public void Update() - { - - } - - private Stopwatch performanceMeasurement = new Stopwatch(); - - [MoonSharpHidden] - public T Call(string name, params object[] args) - { - if (GameMain.LuaCs == null) return default; // FIXME: should this throw an exception? - if (name == null) throw new ArgumentNullException(name); - if (args == null) args = new object[0]; - - name = name.ToLower(); - - if (!hookFunctions.ContainsKey(name)) - { - return default; - } - - T lastResult = default; - - if (!hookFunctions.ContainsKey(name)) - { - return lastResult; - } - - var hooksToRemove = new List(); - foreach ((var key, var tuple) in hookFunctions[name]) - { - if (tuple.Item2 != null && tuple.Item2.IsDisposed) - { - hooksToRemove.Add(key); - continue; - } - - try - { - if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) - { - performanceMeasurement.Start(); - } - - var result = tuple.Item1.func(args); - if (result != null && !result.IsNil()) - { - lastResult = result.ToObject(); - } - - if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) - { - performanceMeasurement.Stop(); - GameMain.LuaCs.PerformanceCounter.SetHookElapsedTicks(name, key, performanceMeasurement.ElapsedTicks); - performanceMeasurement.Reset(); - } - } - catch (Exception e) - { - StringBuilder argsSb = new StringBuilder(); - foreach (var arg in args) argsSb.Append(arg + " "); - GameMain.LuaCs.HandleException(e, $"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", ExceptionType.Both); - } - } - foreach (var key in hooksToRemove) - { - hookFunctions[name].Remove(key); - } - - return lastResult; - } - - public object Call(string name, params object[] args) - { - if (name == null) throw new ScriptRuntimeException("Hook.Call: name must not be null."); - return Call(name, args); + return identifier?.Trim().ToLowerInvariant(); } - private static bool PatchPrefix(MethodBase __originalMethod, object[] __args, object __instance) - { - ExecutePatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); - return result == null; - } - private static void PatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) => - ExecutePatch(__originalMethod, __args, __instance, out object _, HookMethodType.After); + private Harmony harmony; - private static bool PatchPrefixWithReturn(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) - { - ExecutePatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); - if (result != null) - { - __result = result; - return false; - } - else { return true; } - } - private static void PatchPostfixWithReturn(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) - { - ExecutePatch(__originalMethod, __args, __instance, out object result, HookMethodType.After); - if (result != null) { __result = result; } - } + private Lazy patchModuleBuilder; + private readonly Dictionary> hookFunctions = new Dictionary>(); - private static readonly MethodInfo miPatchPrefix = typeof(LuaCsHook).GetMethod("PatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); - private static readonly MethodInfo miPatchPostfix = typeof(LuaCsHook).GetMethod("PatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); - private static readonly MethodInfo miPatchPrefixWithReturn = typeof(LuaCsHook).GetMethod("PatchPrefixWithReturn", BindingFlags.NonPublic | BindingFlags.Static); - private static readonly MethodInfo miPatchPostfixWithReturn = typeof(LuaCsHook).GetMethod("PatchPostfixWithReturn", BindingFlags.NonPublic | BindingFlags.Static); + private readonly Dictionary registeredPatches = new Dictionary(); - private static MethodInfo ResolveMethod(string className, string methodName, string[] parameterNames) - { - var classType = LuaUserData.GetType(className); + private static LuaCsHook instance; - if (classType == null) - { - throw new ArgumentNullException($"Invalid class name '{className}'."); - } + private struct MethodKey : IEquatable + { + public ModuleHandle ModuleHandle { get; set; } - MethodInfo methodInfo = null; + public int MetadataToken { get; set; } - if (parameterNames != null) - { - Type[] parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray(); - methodInfo = classType.GetMethod(methodName, DefaultBindingFlags, null, parameterTypes, null); - } - else - { - methodInfo = classType.GetMethod(methodName, DefaultBindingFlags); - } + public override bool Equals(object obj) + { + return obj is MethodKey key && Equals(key); + } - if (methodInfo == null) - { - string parameterNamesStr = parameterNames == null ? "" : string.Join(", ", parameterNames); - throw new ArgumentNullException($"Method '{methodName}' with parameters '{parameterNamesStr}' not found in class '{className}'"); - } + public bool Equals(MethodKey other) + { + return ModuleHandle.Equals(other.ModuleHandle) && MetadataToken == other.MetadataToken; + } - return methodInfo; - } + public override int GetHashCode() + { + return HashCode.Combine(ModuleHandle, MetadataToken); + } - public void Patch(string identifier, MethodInfo method, LuaCsFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) - { - if (identifier == null || method == null || patch == null) - { - throw new ArgumentNullException("Identifier, Method and Patch arguments must not be null."); - } - ValidatePatchTarget(method); + public static bool operator ==(MethodKey left, MethodKey right) + { + return left.Equals(right); + } - var funcAddr = (long)method.MethodHandle.GetFunctionPointer(); - var patches = Harmony.GetPatchInfo(method); + public static bool operator !=(MethodKey left, MethodKey right) + { + return !(left == right); + } - if (hookType == HookMethodType.Before) - { - if (method.ReturnType != typeof(void)) - { - if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == miPatchPrefixWithReturn) == null) - { - harmony.Patch(method, prefix: new HarmonyMethod(miPatchPrefixWithReturn)); - } - } - else - { - if (patches == null || patches.Prefixes == null || patches.Prefixes.Find(patch => patch.PatchMethod == miPatchPrefix) == null) - { - harmony.Patch(method, prefix: new HarmonyMethod(miPatchPrefix)); - } - } + public static MethodKey Create(MethodInfo method) => new MethodKey + { + ModuleHandle = method.Module.ModuleHandle, + MetadataToken = method.MetadataToken, + }; + } - if (hookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsFunc, ACsMod)> methodSet)) - { - if (identifier != "") - { - methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); - } + public LuaCsHook() + { + instance = this; + } - methodSet.Add((identifier, patch, owner)); - } - else if (patch != null) - { - hookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsFunc, ACsMod)>() { (identifier, patch, owner) }); - } + public void Initialize() + { + harmony = new Harmony("LuaCsForBarotrauma"); + patchModuleBuilder = new Lazy(CreateModuleBuilder); - } - else if (hookType == HookMethodType.After) - { - if (method.ReturnType != typeof(void)) - { - if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == miPatchPostfixWithReturn) == null) - { - harmony.Patch(method, postfix: new HarmonyMethod(miPatchPostfixWithReturn)); - } - } - else - { - if (patches == null || patches.Postfixes == null || patches.Postfixes.Find(patch => patch.PatchMethod == miPatchPostfix) == null) - { - harmony.Patch(method, postfix: new HarmonyMethod(miPatchPostfix)); - } - } + UserData.RegisterType(); + var hookType = UserData.RegisterType(); + var hookDesc = (StandardUserDataDescriptor)hookType; + typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => { + if ( + m.Name.Contains("HookMethod") || + m.Name.Contains("UnhookMethod") || + m.Name.Contains("EnqueueFunction") || + m.Name.Contains("EnqueueTimedFunction") + ) + { + hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default)); + } + }); + } - if (hookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsFunc, ACsMod)> methodSet)) - { - if (identifier != "") - { - methodSet.RemoveWhere(tuple => tuple.Item1 == identifier); - } + private ModuleBuilder CreateModuleBuilder() + { + var assemblyName = $"LuaCsHookPatch-{Guid.NewGuid():N}"; + var assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly(new AssemblyName(assemblyName), AssemblyBuilderAccess.RunAndCollect); + var moduleBuilder = assemblyBuilder.DefineDynamicModule("LuaCsHookPatch"); - methodSet.Add((identifier, patch, owner)); - } - else if (patch != null) - { - hookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsFunc, ACsMod)>() { (identifier, patch, owner) }); - } + // This code emits the Roslyn attribute + // "IgnoresAccessChecksToAttribute" so we can freely access + // the Barotrauma assembly from our dynamic patches. + // This is important because the generated IL references + // non-public types/members. - } + // class IgnoresAccessChecksToAttribute { + var typeBuilder = moduleBuilder.DefineType( + name: "System.Runtime.CompilerServices.IgnoresAccessChecksToAttribute", + attr: TypeAttributes.NotPublic | TypeAttributes.Sealed | TypeAttributes.Class, + parent: typeof(Attribute)); - } + // [AttributeUsage(AllowMultiple = true)] + var attributeUsageAttribute = new CustomAttributeBuilder( + con: typeof(AttributeUsageAttribute).GetConstructor(new[] { typeof(AttributeTargets) }), + constructorArgs: new object[] { AttributeTargets.Assembly }, + namedProperties: new[] { typeof(AttributeUsageAttribute).GetProperty("AllowMultiple") }, + propertyValues: new object[] { true }); + typeBuilder.SetCustomAttribute(attributeUsageAttribute); - private static void ExecutePatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType) - { - result = null; + // private readonly string assemblyName; + var attributeTypeFieldBuilder = typeBuilder.DefineField( + fieldName: "assemblyName", + type: typeof(string), + attributes: FieldAttributes.Private | FieldAttributes.InitOnly); - try - { - long funcAddr = (long)__originalMethod.MethodHandle.GetFunctionPointer(); + var ctor = Emit.BuildConstructor( + parameterTypes: new[] { typeof(string) }, + type: typeBuilder, + attributes: MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName, + callingConvention: CallingConventions.Standard | CallingConventions.HasThis); + // IL: this.assemblyName = arg; + ctor.LoadArgument(0); + ctor.LoadArgument(1); + ctor.StoreField(attributeTypeFieldBuilder); + ctor.Return(); + ctor.CreateConstructor(); - HashSet<(string, LuaCsFunc, ACsMod)> methodSet = null; - switch (hookMethodType) - { - case HookMethodType.Before: - instance.hookPrefixMethods.TryGetValue(funcAddr, out methodSet); - break; - case HookMethodType.After: - instance.hookPostfixMethods.TryGetValue(funcAddr, out methodSet); - break; - default: - break; - } + // public string AttributeName => this.assemblyName; + var attributeNameGetter = Emit.BuildMethod( + returnType: typeof(string), + parameterTypes: new Type[0], + type: typeBuilder, + name: "get_AttributeName", + attributes: MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName, + callingConvention: CallingConventions.Standard | CallingConventions.HasThis); + attributeNameGetter.LoadArgument(0); + attributeNameGetter.LoadField(attributeTypeFieldBuilder); + attributeNameGetter.Return(); - if (methodSet == null) - { - return; - } + var attributeName = typeBuilder.DefineProperty( + name: "AttributeName", + attributes: PropertyAttributes.None, + returnType: typeof(string), + parameterTypes: null); + attributeName.SetGetMethod(attributeNameGetter.CreateMethod()); + // } - var patchesToRemove = new HashSet<(string, LuaCsFunc, ACsMod)>(); - foreach (var tuple in methodSet) - { - if (tuple.Item3 != null && tuple.Item3.IsDisposed) - { - patchesToRemove.Add(tuple); - continue; - } + var type = typeBuilder.CreateTypeInfo().AsType(); - var args = Enumerable.Empty() - .Concat(__args) - .Prepend(__instance) - .ToArray(); - var _result = tuple.Item2(args); + // The assembly names are hardcoded, otherwise it would + // break unit tests. + var assembliesToExpose = new[] { "Barotrauma", "DedicatedServer" }; + foreach (var name in assembliesToExpose) + { + var attr = new CustomAttributeBuilder( + con: type.GetConstructor(new[] { typeof(string)}), + constructorArgs: new[] { name }); + assemblyBuilder.SetCustomAttribute(attr); + } - if (_result != null && !_result.IsNil()) - { - if (__originalMethod is MethodInfo mi && mi.ReturnType != typeof(void)) - { - result = _result.ToObject(mi.ReturnType); - } - else - { - result = _result.ToObject(); - } - } - } + return moduleBuilder; + } - foreach (var tuple in patchesToRemove) - { - methodSet.Remove(tuple); - } - } - catch (Exception ex) - { - GameMain.LuaCs.HandleException(ex, $"Error in {__originalMethod.Name}:", exceptionType: LuaCsSetup.ExceptionType.Both); - } - } + public void Add(string name, string identifier, LuaCsFunc func, ACsMod owner = null) + { + if (name == null) throw new ArgumentNullException(nameof(name)); + if (identifier == null) throw new ArgumentNullException(nameof(identifier)); + if (func == null) throw new ArgumentNullException(nameof(func)); + name = NormalizeIdentifier(name); + identifier = NormalizeIdentifier(identifier); - public void Patch(string identifier, string className, string methodName, string[] parameterNames, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) - { - MethodInfo methodInfo = ResolveMethod(className, methodName, parameterNames); - if (methodInfo == null) return; - Patch(identifier, methodInfo, patch, hookMethodType); - } - public void Patch(string identifier, string className, string methodName, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => - Patch(identifier, className, methodName, null, patch, hookMethodType); - public void Patch(string className, string methodName, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => - Patch("", className, methodName, null, patch, hookMethodType); - public void Patch(string className, string methodName, string[] parameterNames, LuaCsFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => - Patch("", className, methodName, parameterNames, patch, hookMethodType); + if (!hookFunctions.ContainsKey(name)) + { + hookFunctions.Add(name, new Dictionary()); + } + hookFunctions[name][identifier] = (new LuaCsHookCallback(name, identifier, func), owner); + } - public void RemovePatch(string identifier, MethodInfo method, HookMethodType hookType = HookMethodType.Before) - { - var funcAddr = (long)method.MethodHandle.GetFunctionPointer(); + public void Remove(string name, string identifier) + { + if (name == null) throw new ArgumentNullException(nameof(name)); + if (identifier == null) throw new ArgumentNullException(nameof(identifier)); - Dictionary> methods; - if (hookType == HookMethodType.Before) { methods = hookPrefixMethods; } - else if (hookType == HookMethodType.After) { methods = hookPostfixMethods; } - else { throw new NotImplementedException(); } + name = NormalizeIdentifier(name); + identifier = NormalizeIdentifier(identifier); - if (methods.ContainsKey(funcAddr)) - { - methods[funcAddr]?.RemoveWhere(t => t.Item1 == identifier); - } - } + if (hookFunctions.ContainsKey(name) && hookFunctions[name].ContainsKey(identifier)) + { + hookFunctions[name].Remove(identifier); + } + } - public void RemovePatch(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before) - { - MethodInfo methodInfo = ResolveMethod(className, methodName, parameterNames); + public void Clear() + { + harmony?.UnpatchAll(); - if (methodInfo == null) - { - return; - } + hookFunctions.Clear(); + registeredPatches.Clear(); + patchModuleBuilder = null; - RemovePatch(identifier, methodInfo, hookType); - } - } -} \ No newline at end of file + compatHookPrefixMethods.Clear(); + compatHookPostfixMethods.Clear(); + } + + public void Update() { } + + private Stopwatch performanceMeasurement = new Stopwatch(); + + [MoonSharpHidden] + public T Call(string name, params object[] args) + { + if (GameMain.LuaCs == null) throw new InvalidOperationException("Can't call hooks before LuaCsHook is initialized."); + if (name == null) throw new ArgumentNullException(name); + if (args == null) args = new object[0]; + + name = NormalizeIdentifier(name); + if (!hookFunctions.ContainsKey(name)) return default; + + T lastResult = default; + + var hooksToRemove = new List(); + foreach ((var key, var tuple) in hookFunctions[name]) + { + if (tuple.Item2 != null && tuple.Item2.IsDisposed) + { + hooksToRemove.Add(key); + continue; + } + + try + { + if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) + { + performanceMeasurement.Start(); + } + + var result = tuple.Item1.func(args); + // TODO(BREAKING): change this to !result.IsVoid() + if (result != null && !result.IsNil()) + { + lastResult = result.ToObject(); + } + + if (GameMain.LuaCs.PerformanceCounter.EnablePerformanceCounter) + { + performanceMeasurement.Stop(); + GameMain.LuaCs.PerformanceCounter.SetHookElapsedTicks(name, key, performanceMeasurement.ElapsedTicks); + performanceMeasurement.Reset(); + } + } + catch (Exception e) + { + var argsSb = new StringBuilder(); + foreach (var arg in args) + { + argsSb.Append(arg + " "); + } + GameMain.LuaCs.HandleException(e, $"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", ExceptionType.Both); + } + } + foreach (var key in hooksToRemove) + { + hookFunctions[name].Remove(key); + } + + return lastResult; + } + + public object Call(string name, params object[] args) => Call(name, args); + + private static MethodInfo ResolveMethod(string className, string methodName, string[] parameterNames) + { + var classType = LuaUserData.GetType(className); + if (classType == null) throw new InvalidOperationException($"Invalid class name '{className}'"); + + const BindingFlags BINDING_FLAGS = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + MethodInfo methodInfo = null; + if (parameterNames != null) + { + var parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray(); + methodInfo = classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null); + } + else + { + methodInfo = classType.GetMethod(methodName, BINDING_FLAGS); + } + + if (methodInfo == null) + { + var parameterNamesStr = parameterNames == null ? "" : string.Join(", ", parameterNames); + throw new InvalidOperationException($"Method '{methodName}({parameterNamesStr})' not found in class '{className}'"); + } + + return methodInfo; + } + + private class DynamicParameterMapping + { + public DynamicParameterMapping(string name, Type originalMethodParamType, Type harmonyPatchParamType) + { + ParameterName = name; + OriginalMethodParamType = originalMethodParamType; + HarmonyPatchParamType = harmonyPatchParamType; + } + + public string ParameterName { get; set; } + + public Type OriginalMethodParamType { get; set; } + + public Type HarmonyPatchParamType { get; set; } + } + + private static readonly Regex InvalidIdentifierCharsRegex = new Regex(@"[^\w\d]", RegexOptions.Compiled); + + // If you need to debug this: + // - use https://sharplab.io ; it's a very useful for resource for writing IL by hand. + // - use il.NewMessage("") or il.WriteLine("") to see where the IL crashes at runtime. + private MethodInfo CreateDynamicHarmonyPatch(string identifier, MethodInfo original, HookMethodType hookType) + { + var parameters = new List + { + new DynamicParameterMapping("__originalMethod", null, typeof(MethodBase)), + new DynamicParameterMapping("__instance", null, typeof(object)), + }; + + var hasReturnType = original.ReturnType != typeof(void); + if (hasReturnType) + { + parameters.Add(new DynamicParameterMapping("__result", null, typeof(object).MakeByRefType())); + } + + foreach (var parameter in original.GetParameters()) + { + var paramName = parameter.Name; + var originalMethodParamType = parameter.ParameterType; + var harmonyPatchParamType = originalMethodParamType.IsByRef + ? originalMethodParamType + // Make all parameters modifiable by the harmony patch + : originalMethodParamType.MakeByRefType(); + parameters.Add(new DynamicParameterMapping(paramName, originalMethodParamType, harmonyPatchParamType)); + } + + static string MangleName(object o) => InvalidIdentifierCharsRegex.Replace(o?.ToString(), "_"); + + var moduleBuilder = patchModuleBuilder.Value; + var mangledName = original.DeclaringType != null + ? $"{MangleName(original.DeclaringType)}-{MangleName(original)}" + : MangleName(original); + var typeBuilder = moduleBuilder.DefineType($"Patch_{identifier}_{Guid.NewGuid():N}_{mangledName}", TypeAttributes.Public); + + var methodName = hookType == HookMethodType.Before ? "HarmonyPrefix" : "HarmonyPostfix"; + var il = Emit.BuildMethod( + returnType: hookType == HookMethodType.Before ? typeof(bool) : typeof(void), + parameterTypes: parameters.Select(x => x.HarmonyPatchParamType).ToArray(), + type: typeBuilder, + name: methodName, + attributes: MethodAttributes.Public | MethodAttributes.Static, + callingConvention: CallingConventions.Standard); + + var labelReturn = il.DefineLabel("endOfFunction"); + + il.BeginExceptionBlock(out var exceptionBlock); + + // IL: var harmonyReturnValue = true; + var harmonyReturnValue = il.DeclareLocal("harmonyReturnValue"); + il.LoadConstant(true); + il.StoreLocal(harmonyReturnValue); + + // IL: var patchKey = MethodKey.Create(__originalMethod); + var patchKey = il.DeclareLocal("patchKey"); + il.LoadArgument(0); // load __originalMethod + il.CastClass(); + il.Call(typeof(MethodKey).GetMethod(nameof(MethodKey.Create))); + il.StoreLocal(patchKey); + + // IL: var patchExists = instance.registeredPatches.TryGetValue(patchKey, out MethodPatches patches) + var patchExists = il.DeclareLocal("patchExists"); + var patches = il.DeclareLocal("patches"); + il.LoadField(typeof(LuaCsHook).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static)); + il.LoadField(typeof(LuaCsHook).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance)); + il.LoadLocal(patchKey); + il.LoadLocalAddress(patches); // out parameter + il.Call(typeof(Dictionary).GetMethod("TryGetValue")); + il.StoreLocal(patchExists); + + // IL: if (!patchExists) + il.LoadLocal(patchExists); + il.IfNot((il) => + { + // XXX: if we get here, it's probably because a patched + // method was running when `reloadlua` was executed. + // This can happen with a postfix on + // `Barotrauma.Networking.GameServer#Update`. + il.Leave(labelReturn); + }); + + // IL: var parameterDict = new Dictionary(); + var parameterDict = il.DeclareLocal>("parameterDict"); + il.LoadConstant(parameters.Count(x => x.OriginalMethodParamType != null)); // preallocate the dictionary using the # of args + il.NewObject(typeof(Dictionary), typeof(int)); + il.StoreLocal(parameterDict); + + for (ushort i = 0; i < parameters.Count; i++) + { + // Skip parameters that don't exist in the original method + if (parameters[i].OriginalMethodParamType == null) continue; + + // IL: parameterDict.Add(, ); + il.LoadLocal(parameterDict); + il.LoadConstant(parameters[i].ParameterName); + il.LoadArgument(i); + il.ToObject(parameters[i].HarmonyPatchParamType); + il.Call(typeof(Dictionary).GetMethod("Add")); + } + + // IL: var ptable = new ParameterTable(parameterDict); + var ptable = il.DeclareLocal("ptable"); + il.LoadLocal(parameterDict); + il.NewObject(typeof(ParameterTable), typeof(Dictionary)); + il.StoreLocal(ptable); + + if (hasReturnType && hookType == HookMethodType.After) + { + // IL: ptable.OriginalReturnValue = __result; + il.LoadLocal(ptable); + il.LoadArgument(2); // ref __result + il.ToObject(parameters[2].HarmonyPatchParamType); + il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.OriginalReturnValue)).GetSetMethod(nonPublic: true)); + } + + // IL: var enumerator = patches.GetPrefixEnumerator(); + var enumerator = il.DeclareLocal>("enumerator"); + il.LoadLocal(patches); + il.CallVirtual(typeof(PatchedMethod).GetMethod( + name: hookType == HookMethodType.Before + ? nameof(PatchedMethod.GetPrefixEnumerator) + : nameof(PatchedMethod.GetPostfixEnumerator), + bindingAttr: BindingFlags.Public | BindingFlags.Instance)); + il.StoreLocal(enumerator); + + var labelUpdateParameters = il.DefineLabel("updateParameters"); + + // Iterate over prefixes/postfixes + il.ForEachEnumerator(enumerator, (il, current, labelLeave) => + { + // IL: var luaReturnValue = current.PatchFunc.Invoke(__instance, ptable); + var luaReturnValue = il.DeclareLocal("luaReturnValue"); + il.LoadLocal(current); + il.Call(typeof(LuaCsPatch).GetProperty(nameof(LuaCsPatch.PatchFunc)).GetGetMethod()); + il.LoadArgument(1); // __instance + il.LoadLocal(ptable); + il.CallVirtual(typeof(LuaCsPatchFunc).GetMethod("Invoke")); + il.StoreLocal(luaReturnValue); + + if (hasReturnType) + { + // IL: var ptableReturnValue = ptable.ReturnValue; + var ptableReturnValue = il.DeclareLocal("ptableReturnValue"); + il.LoadLocal(ptable); + il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.ReturnValue)).GetGetMethod()); + il.StoreLocal(ptableReturnValue); + + // IL: if (ptableReturnValue != null) + il.LoadLocal(ptableReturnValue); + il.If((il) => + { + // IL: __result = ptableReturnValue; + il.LoadArgument(2); // ref __result + il.LoadLocal(ptableReturnValue); + il.StoreIndirect(typeof(object)); + il.Break(); + }); + + // IL: if (luaReturnValue != null) + il.LoadLocal(luaReturnValue); + il.If((il) => + { + // IL: if (!luaReturnValue.IsVoid()) + il.LoadLocal(luaReturnValue); + il.Call(typeof(DynValue).GetMethod(nameof(DynValue.IsVoid))); + il.IfNot((il) => + { + // IL: var csReturnType = Type.GetTypeFromHandle(); + var csReturnType = il.DeclareLocal("csReturnType"); + il.LoadType(original.ReturnType); + il.StoreLocal(csReturnType); + + // IL: var csReturnValue = luaReturnValue.ToObject(csReturnValueType); + var csReturnValue = il.DeclareLocal("csReturnValue"); + il.LoadLocal(luaReturnValue); + il.LoadLocal(csReturnType); + il.Call(typeof(DynValue).GetMethod( + name: nameof(DynValue.ToObject), + bindingAttr: BindingFlags.Public | BindingFlags.Instance, + binder: null, + types: new Type[] { typeof(Type) }, + modifiers: null)); + il.StoreLocal(csReturnValue); + + // IL: __result = csReturnValue; + il.LoadArgument(2); // ref __result + il.LoadLocal(csReturnValue); + il.StoreIndirect(typeof(object)); + }); + }); + } + + // IL: if (ptable.PreventExecution) + il.LoadLocal(ptable); + il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.PreventExecution)).GetGetMethod()); + il.If((il) => + { + // IL: harmonyReturnValue = false; + il.LoadConstant(false); + il.StoreLocal(harmonyReturnValue); + + // IL: break; + il.Leave(labelLeave); + }); + }); + + // IL: var modifiedParameters = ptable.ModifiedParameters; + var modifiedParameters = il.DeclareLocal>("modifiedParameters"); + il.LoadLocal(ptable); + il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.ModifiedParameters)).GetGetMethod()); + il.StoreLocal(modifiedParameters); + // IL: object modifiedValue; + var modifiedValue = il.DeclareLocal("modifiedValue"); + + // Update the parameters + for (ushort i = 0; i < parameters.Count; i++) + { + // Skip parameters that don't exist in the original method + if (parameters[i].OriginalMethodParamType == null) continue; + + // IL: if (modifiedParameters.TryGetValue("parameterName", out modifiedValue)) + il.LoadLocal(modifiedParameters); + il.LoadConstant(parameters[i].ParameterName); + il.LoadLocalAddress(modifiedValue); // out parameter + il.Call(typeof(Dictionary).GetMethod(nameof(Dictionary.TryGetValue))); + il.If((il) => + { + // XXX: GetElementType() gets the "real" type behind + // the ByRef. This is safe because all the parameters + // are made into ByRef to support modification. + var paramType = parameters[i].HarmonyPatchParamType.GetElementType(); + + // IL: ref argName = modifiedValue; + il.LoadArgument(i); + il.LoadLocalAndCast(modifiedValue, paramType); + if (paramType.IsValueType) + { + il.StoreObject(paramType); + } + else + { + il.StoreIndirect(paramType); + } + }); + } + + il.MarkLabel(labelReturn); + + // IL: catch (Exception exception) + il.BeginCatchAllBlock(exceptionBlock, out var catchBlock); + var exception = il.DeclareLocal("exception"); + il.StoreLocal(exception); + + // IL: var luaCsSetup = GameMain.LuaCs; + var luaCsSetup = il.DeclareLocal("luaCsSetup"); + il.LoadField(typeof(GameMain).GetField(nameof(GameMain.LuaCs), BindingFlags.Public | BindingFlags.Static)); + il.StoreLocal(luaCsSetup); + + // IL: luaCsSetup.HandleException(exception, "", ExceptionType.Lua); + il.LoadLocal(luaCsSetup); + il.LoadLocal(exception); + il.LoadConstant(""); + il.LoadConstant((int)ExceptionType.Lua); // underlying enum type is int + il.Call(typeof(LuaCsSetup).GetMethod(nameof(LuaCsSetup.HandleException))); + il.EndCatchBlock(catchBlock); + + il.EndExceptionBlock(exceptionBlock); + + // Only prefixes return a bool + if (hookType == HookMethodType.Before) + { + il.LoadLocal(harmonyReturnValue); + } + il.Return(); + + var method = il.CreateMethod(); + for (var i = 0; i < parameters.Count; i++) + { + method.DefineParameter(i + 1, ParameterAttributes.None, parameters[i].ParameterName); + } + + var type = typeBuilder.CreateType(); + return type.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static); + } + + private string Patch(string identifier, MethodInfo method, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) + { + if (method == null) throw new ArgumentNullException(nameof(method)); + if (patch == null) throw new ArgumentNullException(nameof(patch)); + ValidatePatchTarget(method); + + identifier ??= Guid.NewGuid().ToString("N"); + identifier = NormalizeIdentifier(identifier); + + var patchKey = MethodKey.Create(method); + if (!registeredPatches.TryGetValue(patchKey, out var methodPatches)) + { + var harmonyPrefix = CreateDynamicHarmonyPatch(identifier, method, HookMethodType.Before); + var harmonyPostfix = CreateDynamicHarmonyPatch(identifier, method, HookMethodType.After); + harmony.Patch(method, prefix: new HarmonyMethod(harmonyPrefix), postfix: new HarmonyMethod(harmonyPostfix)); + methodPatches = registeredPatches[patchKey] = new PatchedMethod(harmonyPrefix, harmonyPostfix); + } + + if (hookType == HookMethodType.Before) + { + if (methodPatches.Prefixes.Remove(identifier)) + { + PrintLogMessage($"Replacing existing prefix: {identifier}"); + } + + methodPatches.Prefixes.Add(identifier, new LuaCsPatch + { + Identifier = identifier, + PatchFunc = patch, + }); + } + else if (hookType == HookMethodType.After) + { + if (methodPatches.Postfixes.Remove(identifier)) + { + PrintLogMessage($"Replacing existing postfix: {identifier}"); + } + + methodPatches.Postfixes.Add(identifier, new LuaCsPatch + { + Identifier = identifier, + PatchFunc = patch, + }); + } + + return identifier; + } + + public string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) + { + var methodInfo = ResolveMethod(className, methodName, parameterTypes); + return Patch(identifier, methodInfo, patch, hookType); + } + + public string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) + { + var methodInfo = ResolveMethod(className, methodName, null); + return Patch(identifier, methodInfo, patch, hookType); + } + + public string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) + { + var methodInfo = ResolveMethod(className, methodName, parameterTypes); + return Patch(null, methodInfo, patch, hookType); + } + + public string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before) + { + var methodInfo = ResolveMethod(className, methodName, null); + return Patch(null, methodInfo, patch, hookType); + } + + private bool RemovePatch(string identifier, MethodInfo method, HookMethodType hookType) + { + if (identifier == null) throw new ArgumentNullException(nameof(identifier)); + identifier = NormalizeIdentifier(identifier); + + var patchKey = MethodKey.Create(method); + if (!registeredPatches.TryGetValue(patchKey, out var methodPatches)) + { + return false; + } + + return hookType switch + { + HookMethodType.Before => methodPatches.Prefixes.Remove(identifier), + HookMethodType.After => methodPatches.Postfixes.Remove(identifier), + _ => throw new ArgumentException($"Invalid {nameof(HookMethodType)} enum value.", nameof(hookType)), + }; + } + + public bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType) + { + var methodInfo = ResolveMethod(className, methodName, parameterTypes); + return RemovePatch(identifier, methodInfo, hookType); + } + + public bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType) + { + var methodInfo = ResolveMethod(className, methodName, null); + return RemovePatch(identifier, methodInfo, hookType); + } + } +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs index c009a8d24..0967b6165 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHookCompat.cs @@ -1,31 +1,31 @@ using System; using System.Linq; using System.Reflection; -using MoonSharp.Interpreter; using HarmonyLib; using System.Collections.Generic; -using System.Text; -using MoonSharp.Interpreter.Interop; +using MoonSharp.Interpreter; using static Barotrauma.LuaCsSetup; -using System.Threading; -using System.Diagnostics; +using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch; namespace Barotrauma { - partial class LuaCsHook - { - private Dictionary> compatHookPrefixMethods; - private Dictionary> compatHookPostfixMethods; + // XXX: this can't be renamed because of backward compatibility with C# mods + public delegate object LuaCsPatch(object self, Dictionary args); - private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType) + partial class LuaCsHook + { + private Dictionary> compatHookPrefixMethods = new Dictionary>(); + private Dictionary> compatHookPostfixMethods = new Dictionary>(); + + private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookType) { result = null; try { var funcAddr = ((long)__originalMethod.MethodHandle.GetFunctionPointer()); - HashSet<(string, LuaCsPatch, ACsMod)> methodSet = null; - switch (hookMethodType) + HashSet<(string, LuaCsCompatPatchFunc, ACsMod)> methodSet = null; + switch (hookType) { case HookMethodType.Before: instance.compatHookPrefixMethods.TryGetValue(funcAddr, out methodSet); @@ -34,7 +34,7 @@ namespace Barotrauma instance.compatHookPostfixMethods.TryGetValue(funcAddr, out methodSet); break; default: - break; + throw new ArgumentException($"Invalid {nameof(HookMethodType)} enum value.", nameof(hookType)); } if (methodSet != null) @@ -46,7 +46,7 @@ namespace Barotrauma args.Add(@params[i].Name, __args[i]); } - var outOfSocpe = new HashSet<(string, LuaCsPatch, ACsMod)>(); + var outOfSocpe = new HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>(); foreach (var tuple in methodSet) { if (tuple.Item3 != null && tuple.Item3.IsDisposed) @@ -94,6 +94,7 @@ namespace Barotrauma _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.Before); return result == null; } + private static void HookLuaCsPatchPostfix(MethodBase __originalMethod, object[] __args, object __instance) => _hookLuaCsPatch(__originalMethod, __args, __instance, out object _, HookMethodType.After); @@ -107,50 +108,19 @@ namespace Barotrauma } else return true; } + private static void HookLuaCsPatchRetPostfix(MethodBase __originalMethod, object[] __args, ref object __result, object __instance) { _hookLuaCsPatch(__originalMethod, __args, __instance, out object result, HookMethodType.After); if (result != null) __result = result; } - private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo ResolveMethod(string where, string className, string methodName, string[] parameterNames) - { - var classType = LuaUserData.GetType(className); - - if (classType == null) - { - GameMain.LuaCs.HandleException(new Exception($"Tried to use {where} with an invalid class name '{className}'.")); - return null; - } - - MethodInfo methodInfo = null; - - if (parameterNames != null) - { - Type[] parameterTypes = parameterNames.Select(x => LuaUserData.GetType(x)).ToArray(); - methodInfo = classType.GetMethod(methodName, DefaultBindingFlags, null, parameterTypes, null); - } - else - { - methodInfo = classType.GetMethod(methodName, DefaultBindingFlags); - } - - if (methodInfo == null) - { - string parameterNamesStr = parameterNames == null ? "" : string.Join(", ", parameterNames); - GameMain.LuaCs.HandleException(new Exception($"Method '{methodName}' with parameters '{parameterNamesStr}' not found in class '{className}'")); - } - - return methodInfo; - } - - public void HookMethod(string identifier, MethodInfo method, LuaCsPatch patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) + public void HookMethod(string identifier, MethodInfo method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) { if (identifier == null || method == null || patch == null) { @@ -179,7 +149,7 @@ namespace Barotrauma } } - if (compatHookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet)) + if (compatHookPrefixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsCompatPatchFunc, ACsMod)> methodSet)) { if (identifier != "") { @@ -190,7 +160,7 @@ namespace Barotrauma } else if (patch != null) { - compatHookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) }); + compatHookPrefixMethods.Add(funcAddr, new HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>() { (identifier, patch, owner) }); } } @@ -211,7 +181,7 @@ namespace Barotrauma } } - if (compatHookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsPatch, ACsMod)> methodSet)) + if (compatHookPostfixMethods.TryGetValue(funcAddr, out HashSet<(string, LuaCsCompatPatchFunc, ACsMod)> methodSet)) { if (identifier != "") { @@ -222,25 +192,25 @@ namespace Barotrauma } else if (patch != null) { - compatHookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsPatch, ACsMod)>() { (identifier, patch, owner) }); + compatHookPostfixMethods.Add(funcAddr, new HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>() { (identifier, patch, owner) }); } - } - } - - protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) + protected void HookMethod(string identifier, string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) { - - MethodInfo methodInfo = ResolveMethod("HookMethod", className, methodName, parameterNames); + var methodInfo = ResolveMethod(className, methodName, parameterNames); if (methodInfo == null) return; + if (methodInfo.GetParameters().Any(x => x.ParameterType.IsByRef)) + { + throw new InvalidOperationException($"{nameof(HookMethod)} doesn't support ByRef parameters; use {nameof(Patch)} instead."); + } HookMethod(identifier, methodInfo, patch, hookMethodType); } - protected void HookMethod(string identifier, string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => + protected void HookMethod(string identifier, string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => HookMethod(identifier, className, methodName, null, patch, hookMethodType); - protected void HookMethod(string className, string methodName, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => + protected void HookMethod(string className, string methodName, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => HookMethod("", className, methodName, null, patch, hookMethodType); - protected void HookMethod(string className, string methodName, string[] parameterNames, LuaCsPatch patch, HookMethodType hookMethodType = HookMethodType.Before) => + protected void HookMethod(string className, string methodName, string[] parameterNames, LuaCsCompatPatchFunc patch, HookMethodType hookMethodType = HookMethodType.Before) => HookMethod("", className, methodName, parameterNames, patch, hookMethodType); @@ -248,7 +218,7 @@ namespace Barotrauma { var funcAddr = ((long)method.MethodHandle.GetFunctionPointer()); - Dictionary> methods; + Dictionary> methods; if (hookType == HookMethodType.Before) methods = compatHookPrefixMethods; else if (hookType == HookMethodType.After) methods = compatHookPostfixMethods; else throw null; @@ -257,7 +227,7 @@ namespace Barotrauma } protected void UnhookMethod(string identifier, string className, string methodName, string[] parameterNames, HookMethodType hookType = HookMethodType.Before) { - MethodInfo methodInfo = ResolveMethod("UnhookMathod", className, methodName, parameterNames); + var methodInfo = ResolveMethod(className, methodName, parameterNames); if (methodInfo == null) return; UnhookMethod(identifier, methodInfo, hookType); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index 9df6d4931..e5ca284cd 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -12,6 +12,7 @@ using System.Runtime.CompilerServices; using System.Linq; using System.Reflection; using System.Threading; +using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch; [assembly: InternalsVisibleTo(Barotrauma.CsScriptBase.CsScriptAssembly, AllInternalsVisible = true)] [assembly: InternalsVisibleTo(Barotrauma.CsScriptBase.CsOneTimeScriptAssembly, AllInternalsVisible = true)] @@ -396,7 +397,8 @@ namespace Barotrauma UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); - UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType();