diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index 18ade84cd..c2b75b325 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -190,6 +190,7 @@ namespace Barotrauma servicesProvider.RegisterServiceResolver(factory => factory.GetInstance()); servicesProvider.RegisterServiceType(ServiceLifetime.Singleton); servicesProvider.RegisterServiceType(ServiceLifetime.Singleton); + servicesProvider.RegisterServiceType(ServiceLifetime.Singleton); servicesProvider.RegisterServiceType(ServiceLifetime.Singleton); servicesProvider.RegisterServiceType(ServiceLifetime.Singleton); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs index e596135e7..6ef5d256e 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs @@ -1,10 +1,9 @@ using System; using System.Reflection; -using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch; namespace Barotrauma.LuaCs.Services.Compatibility; -public interface ILuaCsHook : ILuaCsShim +public interface ILuaCsHook : ILuaPatcher, ILuaCsShim { // Event Services [Obsolete("ACsMod is deprecated. Use ILuaEventService.Add() instead.")] @@ -15,18 +14,8 @@ public interface ILuaCsHook : ILuaCsShim //bool Exists(string eventName, string identifier); object Call(string eventName, params object[] args); T Call(string eventName, params object[] args); - - // Hook/Method Patching - string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); - string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); - string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); - string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); - bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType); - bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType); - void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, IAssemblyPlugin owner = null); - - + // Needs to be here instead of ILuaPatcher for compatiility purposes public enum HookMethodType { Before, After diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs index c4c96d7c6..e8141201c 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs @@ -1,15 +1,18 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Linq; -using Barotrauma.Extensions; +using Barotrauma.Extensions; using Barotrauma.LuaCs.Events; using Barotrauma.LuaCs.Services.Compatibility; using FluentResults; using FluentResults.LuaCs; +using HarmonyLib; using Microsoft.Toolkit.Diagnostics; using OneOf; +using RestSharp; +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Reflection; namespace Barotrauma.LuaCs.Services; @@ -51,13 +54,14 @@ public partial class EventService : IEventService public static implicit operator TypeStringKey(string typeName) => new(typeName); } - private ILoggerService _loggerService; + private readonly ILoggerService _loggerService; + private readonly ILuaPatcher _luaPatcher; private readonly AsyncReaderWriterLock _operationsLock = new(); private readonly ConcurrentDictionary, IEvent>> _subscribers = new(); private readonly ConcurrentDictionary RunnerFactory)> _luaAliasEventFactory = new(); private readonly ConcurrentDictionary> _luaLegacyEventsSubscribers = new(); - #region Disposal + #region LifeCycle public void Dispose() { @@ -70,13 +74,15 @@ public partial class EventService : IEventService _luaLegacyEventsSubscribers.Clear(); _luaAliasEventFactory.Clear(); _subscribers.Clear(); + _luaPatcher.Dispose(); } private int _isDisposed; - public EventService(ILoggerService loggerService) + public EventService(ILoggerService loggerService, ILuaPatcher luaPatcher) { _loggerService = loggerService; + _luaPatcher = luaPatcher; } public bool IsDisposed @@ -91,6 +97,7 @@ public partial class EventService : IEventService _luaLegacyEventsSubscribers.Clear(); _luaAliasEventFactory.Clear(); _subscribers.Clear(); + _luaPatcher.Reset(); return FluentResults.Result.Ok(); } @@ -299,4 +306,41 @@ public partial class EventService : IEventService return results; } + + #region LuaPatcherAdapter + public string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before) + { + return _luaPatcher.Patch(identifier, className, methodName, parameterTypes, patch, hookType); + } + + public string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before) + { + return _luaPatcher.Patch(identifier, className, methodName, patch, hookType); + } + + public string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before) + { + return _luaPatcher.Patch(className, methodName, parameterTypes, patch, hookType); + } + + public string Patch(string className, string methodName, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before) + { + return _luaPatcher.Patch(className, methodName, patch, hookType); + } + + public bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsHook.HookMethodType hookType) + { + return _luaPatcher.RemovePatch(className, methodName, methodName, parameterTypes, hookType); + } + + public bool RemovePatch(string identifier, string className, string methodName, LuaCsHook.HookMethodType hookType) + { + return _luaPatcher.RemovePatch(className, methodName, methodName, hookType); + } + + public void HookMethod(string identifier, MethodBase method, LuaCsPatch patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before, IAssemblyPlugin owner = null) + { + _luaPatcher.HookMethod(identifier, method, patch, hookType, owner); + } + #endregion } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ILuaPatcher.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ILuaPatcher.cs new file mode 100644 index 000000000..eba27adba --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ILuaPatcher.cs @@ -0,0 +1,16 @@ +using System.Reflection; +using static Barotrauma.LuaCs.Services.Compatibility.ILuaCsHook; +using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch; + +namespace Barotrauma.LuaCs.Services; + +public interface ILuaPatcher : IReusableService +{ + string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); + string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); + string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); + string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before); + bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType); + bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType); + void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, IAssemblyPlugin owner = null); +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcherCompat.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaPatcherCompat.cs similarity index 95% rename from Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcherCompat.cs rename to Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaPatcherCompat.cs index 8b0d93306..b41111378 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcherCompat.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaPatcherCompat.cs @@ -1,5 +1,4 @@ -//global using LuaCsHook = Barotrauma.LuaCs.Services.EventService; -global using LuaCsHook = Barotrauma.LuaCs.Services.Compatibility.ILuaCsHook; +global using LuaCsHook = Barotrauma.LuaCs.Services.Compatibility.ILuaCsHook; using System; using System.Linq; @@ -18,8 +17,10 @@ namespace Barotrauma namespace Barotrauma.LuaCs.Services { - partial class EventService + partial class LuaPatcherService { + private static LuaPatcherService instance; + private Dictionary> compatHookPrefixMethods = new Dictionary>(); private Dictionary> compatHookPostfixMethods = new Dictionary>(); @@ -122,10 +123,10 @@ namespace Barotrauma.LuaCs.Services if (result != null) __result = result; } - private static MethodInfo _miHookLuaCsPatchPrefix = typeof(EventService).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchPostfix = typeof(EventService).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(EventService).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(EventService).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); // TODO: deprecate this diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcher.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaPatcherService.cs similarity index 66% rename from Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcher.cs rename to Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaPatcherService.cs index f45c73e4e..5d51769f7 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcher.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaPatcherService.cs @@ -20,401 +20,12 @@ namespace Barotrauma { public delegate void LuaCsAction(params object[] args); public delegate object LuaCsFunc(params object[] args); - public delegate DynValue LuaCsPatchFunc(object instance, EventService.ParameterTable ptable); + public delegate DynValue LuaCsPatchFunc(object instance, LuaPatcherService.ParameterTable ptable); } namespace Barotrauma.LuaCs.Services { - internal static class SigilExtensions - { - /// - /// Puts a type on the stack, as a object instead of a - /// runtime type token. - /// - /// The IL emitter. - /// The type to put on the stack. - public static void LoadType(this Emit il, Type type) - { - if (type == null) throw new ArgumentNullException(nameof(type)); - il.LoadConstant(type); // ldtoken - // This converts the type token into a Type object - il.Call(typeof(Type).GetMethod( - name: nameof(Type.GetTypeFromHandle), - bindingAttr: BindingFlags.Public | BindingFlags.Static, - binder: null, - types: new Type[] { typeof(RuntimeTypeHandle) }, - modifiers: null)); - } - - /// - /// Converts the value on the stack to . - /// - /// The IL emitter. - /// The type of the value on the stack. - public static void ToObject(this Emit il, Type type) - { - if (type == null) throw new ArgumentNullException(nameof(type)); - il.DerefIfByRef(ref type); - if (type.IsValueType) - { - il.Box(type); - } - else if (type != typeof(object)) - { - il.CastClass(); - } - } - - /// - /// Deferences the value on stack if the provided type is ByRef. - /// - /// The IL emitter. - /// The type to check if ByRef. - public static void DerefIfByRef(this Emit il, Type type) => il.DerefIfByRef(ref type); - - /// - /// Deferences the value on stack if the provided type is ByRef. - /// - /// The IL emitter. - /// The type to check if ByRef. - public static void DerefIfByRef(this Emit il, ref Type type) - { - if (type == null) throw new ArgumentNullException(nameof(type)); - if (type.IsByRef) - { - type = type.GetElementType(); - if (type.IsValueType) - { - il.LoadObject(type); - } - else - { - il.LoadIndirect(type); - } - } - } - - // Copied from https://github.com/evilfactory/moonsharp/blob/5264656c6442e783f3c75082cce69a93d66d4cc0/src/MoonSharp.Interpreter/Interop/Converters/ScriptToClrConversions.cs#L79-L99 - private static MethodInfo GetImplicitOperatorMethod(Type baseType, Type targetType) - { - try - { - return Expression.Convert(Expression.Parameter(baseType, null), targetType).Method; - } - catch - { - if (baseType.BaseType != null) - { - return GetImplicitOperatorMethod(baseType.BaseType, targetType); - } - - if (targetType.BaseType != null) - { - return GetImplicitOperatorMethod(baseType, targetType.BaseType); - } - - return null; - } - } - - /// - /// Loads a local variable and casts it to the target type. - /// - /// The IL emitter. - /// The value to cast. Must be of type . - /// The type to cast into. - public static void LoadLocalAndCast(this Emit il, Local value, Type targetType) - { - if (value == null) throw new ArgumentNullException(nameof(value)); - if (targetType == null) throw new ArgumentNullException(nameof(targetType)); - if (value.LocalType != typeof(object)) - { - throw new ArgumentException($"Expected local type {typeof(object)}; got {value.LocalType}.", nameof(value)); - } - - var guid = Guid.NewGuid().ToString("N"); - - if (targetType.IsByRef) - { - targetType = targetType.GetElementType(); - } - - // IL: var baseType = value.GetType(); - var baseType = il.DeclareLocal(typeof(Type), $"cast_baseType_{guid}"); - il.LoadLocal(value); - il.Call(typeof(object).GetMethod("GetType")); - il.StoreLocal(baseType); - - // IL: var implicitOperatorMethod = SigilExtensions.GetImplicitOperatorMethod(baseType, ); - var implicitOperatorMethod = il.DeclareLocal(typeof(MethodInfo), $"cast_implicitOperatorMethod_{guid}"); - il.LoadLocal(baseType); - il.LoadType(targetType); - il.Call(typeof(SigilExtensions).GetMethod(nameof(GetImplicitOperatorMethod), BindingFlags.NonPublic | BindingFlags.Static)); - il.StoreLocal(implicitOperatorMethod); - - // IL: castValue; - var castValue = il.DeclareLocal(targetType, $"cast_castValue_{guid}"); - - // IL: if (implicitConversionMethod != null) - il.LoadLocal(implicitOperatorMethod); - il.Branch((il) => - { - // IL: var methodInvokeParams = new object[1]; - var methodInvokeParams = il.DeclareLocal(typeof(object[]), $"cast_methodInvokeParams_{guid}"); - il.LoadConstant(1); - il.NewArray(typeof(object)); - il.StoreLocal(methodInvokeParams); - - // IL: methodInvokeParams[0] = value; - il.LoadLocal(methodInvokeParams); - il.LoadConstant(0); - il.LoadLocal(value); - il.StoreElement(); - - // IL: castValue = ()implicitConversionMethod.Invoke(null, methodInvokeParams); - il.LoadLocal(implicitOperatorMethod); - il.LoadNull(); // first parameter is null because implicit cast operators are static - il.LoadLocal(methodInvokeParams); - il.Call(typeof(MethodInfo).GetMethod("Invoke", new[] { typeof(object), typeof(object[]) })); - if (targetType.IsValueType) - { - il.UnboxAny(targetType); - } - else - { - il.CastClass(targetType); - } - il.StoreLocal(castValue); - }, - (il) => - { - // IL: castValue = ()value; - il.LoadLocal(value); - if (targetType.IsValueType) - { - il.UnboxAny(targetType); - } - else - { - il.CastClass(targetType); - } - il.StoreLocal(castValue); - }); - - il.LoadLocal(castValue); - } - - /// - /// Emits a call to . - /// - /// The IL emitter. - /// The string format. - /// The local variables passed to string.Format. - public static void FormatString(this Emit il, string format, params Local[] args) - { - if (format == null) throw new ArgumentNullException(nameof(format)); - if (args == null) throw new ArgumentNullException(nameof(args)); - - var guid = Guid.NewGuid().ToString("N"); - - var listType = typeof(List<>).MakeGenericType(typeof(object)); - var list = il.DeclareLocal(listType, $"formatString_list_{guid}"); - il.NewObject(listType); - il.StoreLocal(list); - - foreach (var arg in args) - { - il.LoadLocal(list); - il.LoadLocal(arg); - il.ToObject(arg.LocalType); - il.CallVirtual(listType.GetMethod("Add", new[] { typeof(object) })); - } - - var arr = il.DeclareLocal($"formatString_arr_{guid}"); - il.LoadLocal(list); - il.CallVirtual(listType.GetMethod("ToArray", new Type[0])); - il.StoreLocal(arr); - - il.LoadConstant(format); - il.LoadLocal(arr); - il.Call(typeof(string).GetMethod("Format", new[] { typeof(string), typeof(object[]) })); - } - - /// - /// Emits a call to . - /// - /// The IL emitter. - /// The message to print. - public static void NewMessage(this Emit il, string message) - { - var newMessage = typeof(DebugConsole).GetMethod( - name: nameof(DebugConsole.NewMessage), - bindingAttr: BindingFlags.Public | BindingFlags.Static, - binder: null, - types: new Type[] { typeof(string), typeof(Color?), typeof(bool) }, - modifiers: null); - il.LoadConstant(message); - il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod()); - il.LoadConstant(false); - il.Call(newMessage); - } - - /// - /// Emits a call to , - /// using the string on the stack. - /// - /// The IL emitter. - public static void NewMessage(this Emit il) - { - var newMessage = typeof(DebugConsole).GetMethod( - name: nameof(DebugConsole.NewMessage), - bindingAttr: BindingFlags.Public | BindingFlags.Static, - binder: null, - types: new Type[] { typeof(string), typeof(Color?), typeof(bool) }, - modifiers: null); - il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod()); - il.LoadConstant(false); - il.Call(newMessage); - } - - /// - /// Emits a foreach loop that iterates over an local variable. - /// - /// The type of elements in the enumerable. - /// The IL emitter. - /// The enumerable. - /// The body of code to run on each iteration. - public static void ForEachEnumerable(this Emit il, Local enumerable, Action action) - { - if (enumerable == null) throw new ArgumentNullException(nameof(enumerable)); - if (action == null) throw new ArgumentNullException(nameof(action)); - if (!typeof(IEnumerable).IsAssignableFrom(enumerable.LocalType)) - { - throw new ArgumentException($"Expected local type {typeof(IEnumerator)}; got {enumerable.LocalType}.", nameof(enumerable)); - } - - var guid = Guid.NewGuid().ToString("N"); - - var enumerator = il.DeclareLocal>($"forEachEnumerable_enumerator_{guid}"); - il.LoadLocal(enumerable); - il.CallVirtual(typeof(IEnumerable).GetMethod("GetEnumerator")); - il.StoreLocal(enumerator); - ForEachEnumerator(il, enumerator, action); - } - - /// - /// Emits a foreach loop that iterates over an local variable. - /// - /// The type of elements in the enumerable. - /// The IL emitter. - /// The enumerator. - /// The body of code to run on each iteration. - public static void ForEachEnumerator(this Emit il, Local enumerator, Action action) - { - if (enumerator == null) throw new ArgumentNullException(nameof(enumerator)); - if (action == null) throw new ArgumentNullException(nameof(action)); - if (!typeof(IEnumerator).IsAssignableFrom(enumerator.LocalType)) - { - throw new ArgumentException($"Expected local type {typeof(IEnumerator)}; got {enumerator.LocalType}.", nameof(enumerator)); - } - - var guid = Guid.NewGuid().ToString("N"); - var labelLoopStart = il.DefineLabel($"forEach_loopStart_{guid}"); - var labelMoveNext = il.DefineLabel($"forEach_moveNext_{guid}"); - var labelLeave = il.DefineLabel($"forEach_leave_{guid}"); - - il.BeginExceptionBlock(out var exceptionBlock); - il.Branch(labelMoveNext); // MoveNext() needs to be called at least once before iterating - il.MarkLabel(labelLoopStart); - - // IL: var current = enumerator.Current; - var current = il.DeclareLocal($"forEachEnumerator_current_{guid}"); - il.LoadLocal(enumerator); - il.CallVirtual(enumerator.LocalType.GetProperty("Current").GetGetMethod()); - il.StoreLocal(current); - - action(il, current, labelLeave); - - il.MarkLabel(labelMoveNext); - il.LoadLocal(enumerator); - il.CallVirtual(typeof(IEnumerator).GetMethod("MoveNext")); - il.BranchIfTrue(labelLoopStart); // loop if MoveNext() returns true - - // IL: finally { enumerator.Dispose(); } - il.BeginFinallyBlock(exceptionBlock, out var finallyBlock); - il.LoadLocal(enumerator); - il.CallVirtual(typeof(IDisposable).GetMethod("Dispose")); - il.EndFinallyBlock(finallyBlock); - - il.EndExceptionBlock(exceptionBlock); - - il.MarkLabel(labelLeave); - } - - /// - /// Emits a branch that only executes if the last value on the stack - /// is truthy (e.g. non-null references, 1, etc). - /// - /// The IL emitter. - /// The body of code to run if the value is truthy. - public static void If(this Emit il, Action action) - { - if (action == null) throw new ArgumentNullException(nameof(action)); - il.Branch(@if: action); - } - - /// - /// Emits a branch that only executes if the last value on the stack - /// is falsy (e.g. null references, 0, etc). - /// - /// The IL emitter. - /// The body of code to run if the value is falsy. - public static void IfNot(this Emit il, Action action) - { - if (action == null) throw new ArgumentNullException(nameof(action)); - il.Branch(@else: action); - } - - /// - /// Emits two branches that diverge based on a condition -- analogous - /// to an if-else statement. If either - /// or are omitted, it behaves the same as - /// - /// and . - /// - /// The IL emitter. - /// The body of code to run if the value is truthy. - /// The body of code to run if the value is falsy. - public static void Branch(this Emit il, Action @if = null, Action @else = null) - { - if (@if == null && @else == null) throw new ArgumentException("At least one of the two branches must be defined."); - - var guid = Guid.NewGuid().ToString("N"); - var labelEnd = il.DefineLabel($"branch_end_{guid}"); - if (@if != null && @else != null) - { - var labelElse = il.DefineLabel($"branch_else_{guid}"); - il.BranchIfFalse(labelElse); - @if(il); - il.Branch(labelEnd); - il.MarkLabel(labelElse); - @else(il); - } - else if (@if != null) - { - il.BranchIfFalse(labelEnd); - @if(il); - } - else - { - il.BranchIfTrue(labelEnd); - @else(il); - } - il.MarkLabel(labelEnd); - } - } - - partial class EventService + public partial class LuaPatcherService : ILuaPatcher { private class LuaCsHookCallback { @@ -511,36 +122,6 @@ namespace Barotrauma.LuaCs.Services public Dictionary ModifiedParameters { get; } = new Dictionary(); } - private static readonly string[] prohibitedHooks = - { - "Barotrauma.Lua", - "Barotrauma.Cs", - "Barotrauma.ContentPackageManager", - }; - - private static void ValidatePatchTarget(MethodBase method) - { - if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h))) - { - throw new ArgumentException("Hooks into the modding environment are prohibited."); - } - } - - private static string NormalizeIdentifier(string identifier) - { - return identifier?.Trim().ToLowerInvariant(); - } - - private Harmony harmony; - - private Lazy patchModuleBuilder; - - private readonly Dictionary registeredPatches = new Dictionary(); - - private LuaCsSetup luaCs; - - private static EventService instance; - private struct MethodKey : IEquatable { public ModuleHandle ModuleHandle { get; set; } @@ -579,7 +160,19 @@ namespace Barotrauma.LuaCs.Services }; } - public void InitPatcher() + private static readonly string[] prohibitedHooks = + { + "Barotrauma.Lua", + "Barotrauma.Cs", + "Barotrauma.ContentPackageManager", + }; + + + private Harmony harmony; + private Lazy patchModuleBuilder; + private readonly Dictionary registeredPatches = new Dictionary(); + + public LuaPatcherService() { instance = this; @@ -587,6 +180,9 @@ namespace Barotrauma.LuaCs.Services patchModuleBuilder = new Lazy(CreateModuleBuilder); UserData.RegisterType(); + + // whats this for? + /* var hookType = UserData.RegisterType(); var hookDesc = (StandardUserDataDescriptor)hookType; typeof(EventService).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => { @@ -600,29 +196,20 @@ namespace Barotrauma.LuaCs.Services hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default)); } }); + */ } - public void ResetPatcher() + private static void ValidatePatchTarget(MethodBase method) { - harmony?.UnpatchSelf(); - - foreach (var (_, patch) in registeredPatches) + if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h))) { - // Remove references stored in our dynamic types so the generated - // assembly can be garbage-collected. - patch.HarmonyPrefixMethod.DeclaringType - .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) - .SetValue(null, null); - patch.HarmonyPostfixMethod.DeclaringType - .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) - .SetValue(null, null); + throw new ArgumentException("Hooks into the modding environment are prohibited."); } + } - registeredPatches.Clear(); - patchModuleBuilder = null; - - compatHookPrefixMethods.Clear(); - compatHookPostfixMethods.Clear(); + private static string NormalizeIdentifier(string identifier) + { + return identifier?.Trim().ToLowerInvariant(); } private ModuleBuilder CreateModuleBuilder() @@ -797,6 +384,8 @@ namespace Barotrauma.LuaCs.Services private const string FIELD_LUACS = "LuaCs"; + public bool IsDisposed { get; private set; } + // If you need to debug this: // - use https://sharplab.io ; it's a very useful for resource for writing IL by hand. // - use il.NewMessage("") or il.WriteLine("") to see where the IL crashes at runtime. @@ -863,8 +452,8 @@ namespace Barotrauma.LuaCs.Services // IL: var patchExists = instance.registeredPatches.TryGetValue(patchKey, out MethodPatches patches) var patchExists = il.DeclareLocal("patchExists"); var patches = il.DeclareLocal("patches"); - il.LoadField(typeof(EventService).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static)); - il.LoadField(typeof(EventService).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance)); + il.LoadField(typeof(LuaPatcherService).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static)); + il.LoadField(typeof(LuaPatcherService).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance)); il.LoadLocal(patchKey); il.LoadLocalAddress(patches); // out parameter il.Call(typeof(Dictionary).GetMethod("TryGetValue")); @@ -1081,7 +670,7 @@ namespace Barotrauma.LuaCs.Services } var type = typeBuilder.CreateType(); - type.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static).SetValue(null, luaCs); + type.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static).SetValue(null, GameMain.LuaCs); return type.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static); } @@ -1187,5 +776,42 @@ namespace Barotrauma.LuaCs.Services var method = ResolveMethod(className, methodName, null); return RemovePatch(identifier, method, hookType); } + + private void ClearAll() + { + harmony?.UnpatchSelf(); + + foreach (var (_, patch) in registeredPatches) + { + // Remove references stored in our dynamic types so the generated + // assembly can be garbage-collected. + patch.HarmonyPrefixMethod.DeclaringType + .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) + .SetValue(null, null); + patch.HarmonyPostfixMethod.DeclaringType + .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) + .SetValue(null, null); + } + + registeredPatches.Clear(); + patchModuleBuilder = null; + + compatHookPrefixMethods.Clear(); + compatHookPostfixMethods.Clear(); + } + + public void Dispose() + { + IsDisposed = true; + + ClearAll(); + } + + public FluentResults.Result Reset() + { + ClearAll(); + + return FluentResults.Result.Ok(); + } } } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/SigilExtensions.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/SigilExtensions.cs new file mode 100644 index 000000000..81143ffce --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/SigilExtensions.cs @@ -0,0 +1,399 @@ +using Microsoft.Xna.Framework; +using Sigil; +using Sigil.NonGeneric; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace Barotrauma.LuaCs; + +internal static class SigilExtensions +{ + /// + /// Puts a type on the stack, as a object instead of a + /// runtime type token. + /// + /// The IL emitter. + /// The type to put on the stack. + public static void LoadType(this Emit il, Type type) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + il.LoadConstant(type); // ldtoken + // This converts the type token into a Type object + il.Call(typeof(Type).GetMethod( + name: nameof(Type.GetTypeFromHandle), + bindingAttr: BindingFlags.Public | BindingFlags.Static, + binder: null, + types: new Type[] { typeof(RuntimeTypeHandle) }, + modifiers: null)); + } + + /// + /// Converts the value on the stack to . + /// + /// The IL emitter. + /// The type of the value on the stack. + public static void ToObject(this Emit il, Type type) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + il.DerefIfByRef(ref type); + if (type.IsValueType) + { + il.Box(type); + } + else if (type != typeof(object)) + { + il.CastClass(); + } + } + + /// + /// Deferences the value on stack if the provided type is ByRef. + /// + /// The IL emitter. + /// The type to check if ByRef. + public static void DerefIfByRef(this Emit il, Type type) => il.DerefIfByRef(ref type); + + /// + /// Deferences the value on stack if the provided type is ByRef. + /// + /// The IL emitter. + /// The type to check if ByRef. + public static void DerefIfByRef(this Emit il, ref Type type) + { + if (type == null) throw new ArgumentNullException(nameof(type)); + if (type.IsByRef) + { + type = type.GetElementType(); + if (type.IsValueType) + { + il.LoadObject(type); + } + else + { + il.LoadIndirect(type); + } + } + } + + // Copied from https://github.com/evilfactory/moonsharp/blob/5264656c6442e783f3c75082cce69a93d66d4cc0/src/MoonSharp.Interpreter/Interop/Converters/ScriptToClrConversions.cs#L79-L99 + private static MethodInfo GetImplicitOperatorMethod(Type baseType, Type targetType) + { + try + { + return Expression.Convert(Expression.Parameter(baseType, null), targetType).Method; + } + catch + { + if (baseType.BaseType != null) + { + return GetImplicitOperatorMethod(baseType.BaseType, targetType); + } + + if (targetType.BaseType != null) + { + return GetImplicitOperatorMethod(baseType, targetType.BaseType); + } + + return null; + } + } + + /// + /// Loads a local variable and casts it to the target type. + /// + /// The IL emitter. + /// The value to cast. Must be of type . + /// The type to cast into. + public static void LoadLocalAndCast(this Emit il, Local value, Type targetType) + { + if (value == null) throw new ArgumentNullException(nameof(value)); + if (targetType == null) throw new ArgumentNullException(nameof(targetType)); + if (value.LocalType != typeof(object)) + { + throw new ArgumentException($"Expected local type {typeof(object)}; got {value.LocalType}.", nameof(value)); + } + + var guid = Guid.NewGuid().ToString("N"); + + if (targetType.IsByRef) + { + targetType = targetType.GetElementType(); + } + + // IL: var baseType = value.GetType(); + var baseType = il.DeclareLocal(typeof(Type), $"cast_baseType_{guid}"); + il.LoadLocal(value); + il.Call(typeof(object).GetMethod("GetType")); + il.StoreLocal(baseType); + + // IL: var implicitOperatorMethod = SigilExtensions.GetImplicitOperatorMethod(baseType, ); + var implicitOperatorMethod = il.DeclareLocal(typeof(MethodInfo), $"cast_implicitOperatorMethod_{guid}"); + il.LoadLocal(baseType); + il.LoadType(targetType); + il.Call(typeof(SigilExtensions).GetMethod(nameof(GetImplicitOperatorMethod), BindingFlags.NonPublic | BindingFlags.Static)); + il.StoreLocal(implicitOperatorMethod); + + // IL: castValue; + var castValue = il.DeclareLocal(targetType, $"cast_castValue_{guid}"); + + // IL: if (implicitConversionMethod != null) + il.LoadLocal(implicitOperatorMethod); + il.Branch((il) => + { + // IL: var methodInvokeParams = new object[1]; + var methodInvokeParams = il.DeclareLocal(typeof(object[]), $"cast_methodInvokeParams_{guid}"); + il.LoadConstant(1); + il.NewArray(typeof(object)); + il.StoreLocal(methodInvokeParams); + + // IL: methodInvokeParams[0] = value; + il.LoadLocal(methodInvokeParams); + il.LoadConstant(0); + il.LoadLocal(value); + il.StoreElement(); + + // IL: castValue = ()implicitConversionMethod.Invoke(null, methodInvokeParams); + il.LoadLocal(implicitOperatorMethod); + il.LoadNull(); // first parameter is null because implicit cast operators are static + il.LoadLocal(methodInvokeParams); + il.Call(typeof(MethodInfo).GetMethod("Invoke", new[] { typeof(object), typeof(object[]) })); + if (targetType.IsValueType) + { + il.UnboxAny(targetType); + } + else + { + il.CastClass(targetType); + } + il.StoreLocal(castValue); + }, + (il) => + { + // IL: castValue = ()value; + il.LoadLocal(value); + if (targetType.IsValueType) + { + il.UnboxAny(targetType); + } + else + { + il.CastClass(targetType); + } + il.StoreLocal(castValue); + }); + + il.LoadLocal(castValue); + } + + /// + /// Emits a call to . + /// + /// The IL emitter. + /// The string format. + /// The local variables passed to string.Format. + public static void FormatString(this Emit il, string format, params Local[] args) + { + if (format == null) throw new ArgumentNullException(nameof(format)); + if (args == null) throw new ArgumentNullException(nameof(args)); + + var guid = Guid.NewGuid().ToString("N"); + + var listType = typeof(List<>).MakeGenericType(typeof(object)); + var list = il.DeclareLocal(listType, $"formatString_list_{guid}"); + il.NewObject(listType); + il.StoreLocal(list); + + foreach (var arg in args) + { + il.LoadLocal(list); + il.LoadLocal(arg); + il.ToObject(arg.LocalType); + il.CallVirtual(listType.GetMethod("Add", new[] { typeof(object) })); + } + + var arr = il.DeclareLocal($"formatString_arr_{guid}"); + il.LoadLocal(list); + il.CallVirtual(listType.GetMethod("ToArray", new Type[0])); + il.StoreLocal(arr); + + il.LoadConstant(format); + il.LoadLocal(arr); + il.Call(typeof(string).GetMethod("Format", new[] { typeof(string), typeof(object[]) })); + } + + /// + /// Emits a call to . + /// + /// The IL emitter. + /// The message to print. + public static void NewMessage(this Emit il, string message) + { + var newMessage = typeof(DebugConsole).GetMethod( + name: nameof(DebugConsole.NewMessage), + bindingAttr: BindingFlags.Public | BindingFlags.Static, + binder: null, + types: new Type[] { typeof(string), typeof(Color?), typeof(bool) }, + modifiers: null); + il.LoadConstant(message); + il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod()); + il.LoadConstant(false); + il.Call(newMessage); + } + + /// + /// Emits a call to , + /// using the string on the stack. + /// + /// The IL emitter. + public static void NewMessage(this Emit il) + { + var newMessage = typeof(DebugConsole).GetMethod( + name: nameof(DebugConsole.NewMessage), + bindingAttr: BindingFlags.Public | BindingFlags.Static, + binder: null, + types: new Type[] { typeof(string), typeof(Color?), typeof(bool) }, + modifiers: null); + il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod()); + il.LoadConstant(false); + il.Call(newMessage); + } + + /// + /// Emits a foreach loop that iterates over an local variable. + /// + /// The type of elements in the enumerable. + /// The IL emitter. + /// The enumerable. + /// The body of code to run on each iteration. + public static void ForEachEnumerable(this Emit il, Local enumerable, Action action) + { + if (enumerable == null) throw new ArgumentNullException(nameof(enumerable)); + if (action == null) throw new ArgumentNullException(nameof(action)); + if (!typeof(IEnumerable).IsAssignableFrom(enumerable.LocalType)) + { + throw new ArgumentException($"Expected local type {typeof(IEnumerator)}; got {enumerable.LocalType}.", nameof(enumerable)); + } + + var guid = Guid.NewGuid().ToString("N"); + + var enumerator = il.DeclareLocal>($"forEachEnumerable_enumerator_{guid}"); + il.LoadLocal(enumerable); + il.CallVirtual(typeof(IEnumerable).GetMethod("GetEnumerator")); + il.StoreLocal(enumerator); + ForEachEnumerator(il, enumerator, action); + } + + /// + /// Emits a foreach loop that iterates over an local variable. + /// + /// The type of elements in the enumerable. + /// The IL emitter. + /// The enumerator. + /// The body of code to run on each iteration. + public static void ForEachEnumerator(this Emit il, Local enumerator, Action action) + { + if (enumerator == null) throw new ArgumentNullException(nameof(enumerator)); + if (action == null) throw new ArgumentNullException(nameof(action)); + if (!typeof(IEnumerator).IsAssignableFrom(enumerator.LocalType)) + { + throw new ArgumentException($"Expected local type {typeof(IEnumerator)}; got {enumerator.LocalType}.", nameof(enumerator)); + } + + var guid = Guid.NewGuid().ToString("N"); + var labelLoopStart = il.DefineLabel($"forEach_loopStart_{guid}"); + var labelMoveNext = il.DefineLabel($"forEach_moveNext_{guid}"); + var labelLeave = il.DefineLabel($"forEach_leave_{guid}"); + + il.BeginExceptionBlock(out var exceptionBlock); + il.Branch(labelMoveNext); // MoveNext() needs to be called at least once before iterating + il.MarkLabel(labelLoopStart); + + // IL: var current = enumerator.Current; + var current = il.DeclareLocal($"forEachEnumerator_current_{guid}"); + il.LoadLocal(enumerator); + il.CallVirtual(enumerator.LocalType.GetProperty("Current").GetGetMethod()); + il.StoreLocal(current); + + action(il, current, labelLeave); + + il.MarkLabel(labelMoveNext); + il.LoadLocal(enumerator); + il.CallVirtual(typeof(IEnumerator).GetMethod("MoveNext")); + il.BranchIfTrue(labelLoopStart); // loop if MoveNext() returns true + + // IL: finally { enumerator.Dispose(); } + il.BeginFinallyBlock(exceptionBlock, out var finallyBlock); + il.LoadLocal(enumerator); + il.CallVirtual(typeof(IDisposable).GetMethod("Dispose")); + il.EndFinallyBlock(finallyBlock); + + il.EndExceptionBlock(exceptionBlock); + + il.MarkLabel(labelLeave); + } + + /// + /// Emits a branch that only executes if the last value on the stack + /// is truthy (e.g. non-null references, 1, etc). + /// + /// The IL emitter. + /// The body of code to run if the value is truthy. + public static void If(this Emit il, Action action) + { + if (action == null) throw new ArgumentNullException(nameof(action)); + il.Branch(@if: action); + } + + /// + /// Emits a branch that only executes if the last value on the stack + /// is falsy (e.g. null references, 0, etc). + /// + /// The IL emitter. + /// The body of code to run if the value is falsy. + public static void IfNot(this Emit il, Action action) + { + if (action == null) throw new ArgumentNullException(nameof(action)); + il.Branch(@else: action); + } + + /// + /// Emits two branches that diverge based on a condition -- analogous + /// to an if-else statement. If either + /// or are omitted, it behaves the same as + /// + /// and . + /// + /// The IL emitter. + /// The body of code to run if the value is truthy. + /// The body of code to run if the value is falsy. + public static void Branch(this Emit il, Action @if = null, Action @else = null) + { + if (@if == null && @else == null) throw new ArgumentException("At least one of the two branches must be defined."); + + var guid = Guid.NewGuid().ToString("N"); + var labelEnd = il.DefineLabel($"branch_end_{guid}"); + if (@if != null && @else != null) + { + var labelElse = il.DefineLabel($"branch_else_{guid}"); + il.BranchIfFalse(labelElse); + @if(il); + il.Branch(labelEnd); + il.MarkLabel(labelElse); + @else(il); + } + else if (@if != null) + { + il.BranchIfFalse(labelEnd); + @if(il); + } + else + { + il.BranchIfTrue(labelEnd); + @else(il); + } + il.MarkLabel(labelEnd); + } +}