diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs index 3eb723a2e..e5a1c7b09 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Compatibility/ILuaCsHook.cs @@ -12,4 +12,10 @@ public interface ILuaCsHook : ILuaCsShim [Obsolete("Only Lua subscribers will receive events from call. Use ILuaEventService.Add() instead.")] T Call(string eventName, params object[] args); object Call(string eventName, params object[] args); + string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before); + string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before); + string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before); + string Patch(string className, string methodName, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before); + bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, EventService.HookMethodType hookType); + bool RemovePatch(string identifier, string className, string methodName, EventService.HookMethodType hookType); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs index 9839afa1b..8f825ae6e 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/EventService.cs @@ -11,7 +11,7 @@ using OneOf; namespace Barotrauma.LuaCs.Services; -public class EventService : IEventService, IEventAssemblyContextUnloading +public partial class EventService : IEventService, IEventAssemblyContextUnloading { private readonly record struct TypeStringKey : IEqualityComparer, IEquatable { @@ -74,6 +74,8 @@ public class EventService : IEventService, IEventAssemblyContextUnloading { _pluginManagementService = pluginManagementService ?? throw new ArgumentNullException(nameof(pluginManagementService)); this.Subscribe(this); + + InitPatcher(); } public bool IsDisposed { get; private set; } = false; @@ -344,12 +346,8 @@ public class EventService : IEventService, IEventAssemblyContextUnloading public void Dispose() { + Reset(); IsDisposed = true; - _subscriptions.Clear(); - _luaSubscriptionFactories.Clear(); - _eventTypeNameAliases.Clear(); - _luaLegacySubscriptionFactories.Clear(); - _luaOrphanSubscribers.Clear(); GC.SuppressFinalize(this); } @@ -361,6 +359,8 @@ public class EventService : IEventService, IEventAssemblyContextUnloading _eventTypeNameAliases.Clear(); _luaLegacySubscriptionFactories.Clear(); _luaOrphanSubscribers.Clear(); + ResetPatcher(); + InitPatcher(); return FluentResults.Result.Ok(); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/LuaScriptManagementService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/LuaScriptManagementService.cs index 370c5103d..d15b0e682 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/LuaScriptManagementService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/LuaScriptManagementService.cs @@ -169,7 +169,7 @@ class LuaScriptManagementService : ILuaScriptManagementService, ILuaDataService Script.GlobalOptions.ShouldPCallCatchException = (Exception ex) => { return true; }; RegisterType(typeof(LuaGame)); - RegisterType(typeof(ILuaCsHook)); + RegisterType(typeof(EventService)); RegisterType(typeof(ILuaCsNetworking)); RegisterType(typeof(ILuaCsUtility)); RegisterType(typeof(ILuaCsTimer)); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs index 532173247..ea6fedaf6 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs @@ -127,6 +127,12 @@ public class PluginManagementService : IAssemblyManagementService public Type GetType(string typeName, bool isByRefType = false, bool includeInterfaces = false, bool includeDefaultContext = true) { + if (typeName.StartsWith("out ") || typeName.StartsWith("ref ")) + { + typeName = typeName.Remove(0, 4); + isByRefType = true; + } + if (includeDefaultContext) { var type = Type.GetType(typeName, false, false); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcher.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcher.cs index d34fe2844..3f5f47cb0 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcher.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcher.cs @@ -1,4 +1,11 @@ -using System; +using Barotrauma.LuaCs.Services; +using HarmonyLib; +using Microsoft.Xna.Framework; +using MoonSharp.Interpreter; +using MoonSharp.Interpreter.Interop; +using Sigil; +using Sigil.NonGeneric; +using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics; @@ -8,19 +15,16 @@ using System.Reflection; using System.Reflection.Emit; using System.Text; using System.Text.RegularExpressions; -using HarmonyLib; -using Microsoft.Xna.Framework; -using MoonSharp.Interpreter; -using MoonSharp.Interpreter.Interop; -using Sigil; -using Sigil.NonGeneric; namespace Barotrauma { public delegate void LuaCsAction(params object[] args); public delegate object LuaCsFunc(params object[] args); - public delegate DynValue LuaCsPatchFunc(object instance, LuaCsHook.ParameterTable ptable); + public delegate DynValue LuaCsPatchFunc(object instance, EventService.ParameterTable ptable); +} +namespace Barotrauma.LuaCs.Services +{ internal static class SigilExtensions { /// @@ -410,7 +414,7 @@ namespace Barotrauma } } - public partial class LuaCsHook + partial class EventService { public enum HookMethodType { @@ -536,13 +540,11 @@ namespace Barotrauma private Lazy patchModuleBuilder; - private readonly Dictionary> hookFunctions = new Dictionary>(); - private readonly Dictionary registeredPatches = new Dictionary(); private LuaCsSetup luaCs; - private static LuaCsHook instance; + private static EventService instance; private struct MethodKey : IEquatable { @@ -582,21 +584,17 @@ namespace Barotrauma }; } - internal LuaCsHook(LuaCsSetup luaCs) + public void InitPatcher() { instance = this; - this.luaCs = luaCs; - } - public void Initialize() - { harmony = new Harmony("LuaCsForBarotrauma"); patchModuleBuilder = new Lazy(CreateModuleBuilder); UserData.RegisterType(); - var hookType = UserData.RegisterType(); + var hookType = UserData.RegisterType(); var hookDesc = (StandardUserDataDescriptor)hookType; - typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => { + typeof(EventService).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => { if ( m.Name.Contains("HookMethod") || m.Name.Contains("UnhookMethod") || @@ -609,6 +607,29 @@ namespace Barotrauma }); } + public void ResetPatcher() + { + harmony?.UnpatchSelf(); + + foreach (var (_, patch) in registeredPatches) + { + // Remove references stored in our dynamic types so the generated + // assembly can be garbage-collected. + patch.HarmonyPrefixMethod.DeclaringType + .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) + .SetValue(null, null); + patch.HarmonyPostfixMethod.DeclaringType + .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) + .SetValue(null, null); + } + + registeredPatches.Clear(); + patchModuleBuilder = null; + + compatHookPrefixMethods.Clear(); + compatHookPostfixMethods.Clear(); + } + private ModuleBuilder CreateModuleBuilder() { var assemblyName = $"LuaCsHookPatch-{Guid.NewGuid():N}"; @@ -689,143 +710,6 @@ namespace Barotrauma return moduleBuilder; } - public void Add(string name, LuaCsFunc func, ACsMod owner = null) => Add(name, name, func, owner); - - public void Add(string name, string identifier, LuaCsFunc func, ACsMod owner = null) - { - if (name == null) throw new ArgumentNullException(nameof(name)); - if (identifier == null) throw new ArgumentNullException(nameof(identifier)); - if (func == null) throw new ArgumentNullException(nameof(func)); - - name = NormalizeIdentifier(name); - identifier = NormalizeIdentifier(identifier); - - if (!hookFunctions.ContainsKey(name)) - { - hookFunctions.Add(name, new Dictionary()); - } - - hookFunctions[name][identifier] = (new LuaCsHookCallback(name, identifier, func), owner); - } - - public bool Exists(string name, string identifier) - { - if (name == null) throw new ArgumentNullException(nameof(name)); - if (identifier == null) throw new ArgumentNullException(nameof(identifier)); - - name = NormalizeIdentifier(name); - identifier = NormalizeIdentifier(identifier); - - if (!hookFunctions.ContainsKey(name)) - { - return false; - } - - return hookFunctions[name].ContainsKey(identifier); - } - - public void Remove(string name, string identifier) - { - if (name == null) throw new ArgumentNullException(nameof(name)); - if (identifier == null) throw new ArgumentNullException(nameof(identifier)); - - name = NormalizeIdentifier(name); - identifier = NormalizeIdentifier(identifier); - - if (hookFunctions.ContainsKey(name) && hookFunctions[name].ContainsKey(identifier)) - { - hookFunctions[name].Remove(identifier); - } - } - - public void Clear() - { - harmony?.UnpatchSelf(); - - foreach (var (_, patch) in registeredPatches) - { - // Remove references stored in our dynamic types so the generated - // assembly can be garbage-collected. - patch.HarmonyPrefixMethod.DeclaringType - .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) - .SetValue(null, null); - patch.HarmonyPostfixMethod.DeclaringType - .GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static) - .SetValue(null, null); - } - - hookFunctions.Clear(); - registeredPatches.Clear(); - patchModuleBuilder = null; - - compatHookPrefixMethods.Clear(); - compatHookPostfixMethods.Clear(); - } - - private Stopwatch performanceMeasurement = new Stopwatch(); - - [MoonSharpHidden] - public T Call(string name, params object[] args) - { - if (name == null) throw new ArgumentNullException(name); - if (args == null) args = new object[0]; - - name = NormalizeIdentifier(name); - if (!hookFunctions.ContainsKey(name)) return default; - - T lastResult = default; - - var hooks = hookFunctions[name].ToArray(); - foreach ((string key, var tuple) in hooks) - { - if (tuple.Item2 != null && tuple.Item2.IsDisposed) - { - hookFunctions[name].Remove(key); - continue; - } - - try - { - var result = tuple.Item1.func(args); - - if (result is DynValue luaResult) - { - if (luaResult.Type == DataType.Tuple) - { - bool replaceNil = luaResult.Tuple.Length > 1 && luaResult.Tuple[1].CastToBool(); - - if (!luaResult.Tuple[0].IsNil() || replaceNil) - { - lastResult = luaResult.ToObject(); - } - } - else if (!luaResult.IsNil()) - { - lastResult = luaResult.ToObject(); - } - } - else - { - lastResult = (T)result; - } - } - catch (Exception e) - { - var argsSb = new StringBuilder(); - foreach (var arg in args) - { - argsSb.Append(arg + " "); - } - LuaCsLogger.LogError($"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", LuaCsMessageOrigin.Unknown); - LuaCsLogger.HandleException(e, LuaCsMessageOrigin.Unknown); - } - } - - return lastResult; - } - - public object Call(string name, params object[] args) => Call(name, args); - private static MethodBase ResolveMethod(string className, string methodName, string[] parameters) { var classType = GameMain.LuaCs.PluginManagementService.GetType(className); @@ -984,8 +868,8 @@ namespace Barotrauma // IL: var patchExists = instance.registeredPatches.TryGetValue(patchKey, out MethodPatches patches) var patchExists = il.DeclareLocal("patchExists"); var patches = il.DeclareLocal("patches"); - il.LoadField(typeof(LuaCsHook).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static)); - il.LoadField(typeof(LuaCsHook).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance)); + il.LoadField(typeof(EventService).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static)); + il.LoadField(typeof(EventService).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance)); il.LoadLocal(patchKey); il.LoadLocalAddress(patches); // out parameter il.Call(typeof(Dictionary).GetMethod("TryGetValue")); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcherCompat.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcherCompat.cs index 957f77ba9..99b6bb68f 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcherCompat.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaClasses/LuaPatcherCompat.cs @@ -1,4 +1,6 @@ -using System; +global using LuaCsHook = Barotrauma.LuaCs.Services.EventService; + +using System; using System.Linq; using System.Reflection; using HarmonyLib; @@ -10,8 +12,11 @@ namespace Barotrauma { // XXX: this can't be renamed because of backward compatibility with C# mods public delegate object LuaCsPatch(object self, Dictionary args); +} - partial class LuaCsHook +namespace Barotrauma.LuaCs.Services +{ + partial class EventService { private Dictionary> compatHookPrefixMethods = new Dictionary>(); private Dictionary> compatHookPostfixMethods = new Dictionary>(); @@ -115,10 +120,10 @@ namespace Barotrauma if (result != null) __result = result; } - private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); - private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchPrefix = typeof(EventService).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchPostfix = typeof(EventService).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(EventService).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static); + private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(EventService).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static); // TODO: deprecate this public void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs index 49037a3c3..8aaeb45c0 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs @@ -1,6 +1,7 @@ extern alias Client; -using Client::Barotrauma; +using Client::Barotrauma.LuaCs.Services; +using Client::Barotrauma; using MoonSharp.Interpreter; using System; using System.Collections.Concurrent; @@ -63,14 +64,14 @@ namespace TestProject.LuaCs string methodName, string[]? parameters, string function, - LuaCsHook.HookMethodType patchType) + EventService.HookMethodType patchType) { var args = BuildHookPatchArgsList(patchId, className, methodName, parameters); args.Add(function); args.Add(patchType switch { - LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before", - LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After", + EventService.HookMethodType.Before => "Hook.HookMethodType.Before", + EventService.HookMethodType.After => "Hook.HookMethodType.After", _ => throw new NotImplementedException(), }); throw new NotImplementedException(); @@ -83,13 +84,13 @@ namespace TestProject.LuaCs string className, string methodName, string[]? parameters, - LuaCsHook.HookMethodType patchType) + EventService.HookMethodType patchType) { var args = BuildHookPatchArgsList(patchId, className, methodName, parameters); args.Add(patchType switch { - LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before", - LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After", + EventService.HookMethodType.Before => "Hook.HookMethodType.Before", + EventService.HookMethodType.After => "Hook.HookMethodType.After", _ => throw new NotImplementedException(), }); throw new NotImplementedException(); @@ -103,7 +104,7 @@ namespace TestProject.LuaCs function(instance, ptable) {body} end - ", LuaCsHook.HookMethodType.Before); + ", EventService.HookMethodType.Before); Assert.Equal(DataType.String, returnValue.Type); return new(returnValue.String, () => luaCs.RemovePrefix(returnValue.String, methodName, parameters)); } @@ -115,7 +116,7 @@ namespace TestProject.LuaCs function(instance, ptable) {body} end - ", LuaCsHook.HookMethodType.After); + ", EventService.HookMethodType.After); Assert.Equal(DataType.String, returnValue.Type); return new(returnValue.String, () => luaCs.RemovePostfix(returnValue.String, methodName, parameters)); } @@ -123,7 +124,7 @@ namespace TestProject.LuaCs public static bool RemovePrefix(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null) { var className = typeof(T).FullName!; - var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.Before); + var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, EventService.HookMethodType.Before); Assert.Equal(DataType.Boolean, returnValue.Type); return returnValue.Boolean; } @@ -131,7 +132,7 @@ namespace TestProject.LuaCs public static bool RemovePostfix(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null) { var className = typeof(T).FullName!; - var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.After); + var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, EventService.HookMethodType.After); Assert.Equal(DataType.Boolean, returnValue.Type); return returnValue.Boolean; }