Working Hook.Patch and old patch methods

This commit is contained in:
Evil Factory
2026-01-29 22:20:16 -03:00
committed by Maplewheels
parent 6b8a0a7dca
commit 4f02cb4967
7 changed files with 84 additions and 182 deletions

View File

@@ -12,4 +12,10 @@ public interface ILuaCsHook : ILuaCsShim
[Obsolete("Only Lua subscribers will receive events from call. Use ILuaEventService.Add() instead.")]
T Call<T>(string eventName, params object[] args);
object Call(string eventName, params object[] args);
string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before);
string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before);
string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before);
string Patch(string className, string methodName, LuaCsPatchFunc patch, EventService.HookMethodType hookType = EventService.HookMethodType.Before);
bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, EventService.HookMethodType hookType);
bool RemovePatch(string identifier, string className, string methodName, EventService.HookMethodType hookType);
}

View File

@@ -11,7 +11,7 @@ using OneOf;
namespace Barotrauma.LuaCs.Services;
public class EventService : IEventService, IEventAssemblyContextUnloading
public partial class EventService : IEventService, IEventAssemblyContextUnloading
{
private readonly record struct TypeStringKey : IEqualityComparer<TypeStringKey>, IEquatable<TypeStringKey>
{
@@ -74,6 +74,8 @@ public class EventService : IEventService, IEventAssemblyContextUnloading
{
_pluginManagementService = pluginManagementService ?? throw new ArgumentNullException(nameof(pluginManagementService));
this.Subscribe<IEventAssemblyContextUnloading>(this);
InitPatcher();
}
public bool IsDisposed { get; private set; } = false;
@@ -344,12 +346,8 @@ public class EventService : IEventService, IEventAssemblyContextUnloading
public void Dispose()
{
Reset();
IsDisposed = true;
_subscriptions.Clear();
_luaSubscriptionFactories.Clear();
_eventTypeNameAliases.Clear();
_luaLegacySubscriptionFactories.Clear();
_luaOrphanSubscribers.Clear();
GC.SuppressFinalize(this);
}
@@ -361,6 +359,8 @@ public class EventService : IEventService, IEventAssemblyContextUnloading
_eventTypeNameAliases.Clear();
_luaLegacySubscriptionFactories.Clear();
_luaOrphanSubscribers.Clear();
ResetPatcher();
InitPatcher();
return FluentResults.Result.Ok();
}

View File

@@ -169,7 +169,7 @@ class LuaScriptManagementService : ILuaScriptManagementService, ILuaDataService
Script.GlobalOptions.ShouldPCallCatchException = (Exception ex) => { return true; };
RegisterType(typeof(LuaGame));
RegisterType(typeof(ILuaCsHook));
RegisterType(typeof(EventService));
RegisterType(typeof(ILuaCsNetworking));
RegisterType(typeof(ILuaCsUtility));
RegisterType(typeof(ILuaCsTimer));

View File

@@ -127,6 +127,12 @@ public class PluginManagementService : IAssemblyManagementService
public Type GetType(string typeName, bool isByRefType = false, bool includeInterfaces = false,
bool includeDefaultContext = true)
{
if (typeName.StartsWith("out ") || typeName.StartsWith("ref "))
{
typeName = typeName.Remove(0, 4);
isByRefType = true;
}
if (includeDefaultContext)
{
var type = Type.GetType(typeName, false, false);

View File

@@ -1,4 +1,11 @@
using System;
using Barotrauma.LuaCs.Services;
using HarmonyLib;
using Microsoft.Xna.Framework;
using MoonSharp.Interpreter;
using MoonSharp.Interpreter.Interop;
using Sigil;
using Sigil.NonGeneric;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
@@ -8,19 +15,16 @@ using System.Reflection;
using System.Reflection.Emit;
using System.Text;
using System.Text.RegularExpressions;
using HarmonyLib;
using Microsoft.Xna.Framework;
using MoonSharp.Interpreter;
using MoonSharp.Interpreter.Interop;
using Sigil;
using Sigil.NonGeneric;
namespace Barotrauma
{
public delegate void LuaCsAction(params object[] args);
public delegate object LuaCsFunc(params object[] args);
public delegate DynValue LuaCsPatchFunc(object instance, LuaCsHook.ParameterTable ptable);
public delegate DynValue LuaCsPatchFunc(object instance, EventService.ParameterTable ptable);
}
namespace Barotrauma.LuaCs.Services
{
internal static class SigilExtensions
{
/// <summary>
@@ -410,7 +414,7 @@ namespace Barotrauma
}
}
public partial class LuaCsHook
partial class EventService
{
public enum HookMethodType
{
@@ -536,13 +540,11 @@ namespace Barotrauma
private Lazy<ModuleBuilder> patchModuleBuilder;
private readonly Dictionary<string, Dictionary<string, (LuaCsHookCallback, ACsMod)>> hookFunctions = new Dictionary<string, Dictionary<string, (LuaCsHookCallback, ACsMod)>>();
private readonly Dictionary<MethodKey, PatchedMethod> registeredPatches = new Dictionary<MethodKey, PatchedMethod>();
private LuaCsSetup luaCs;
private static LuaCsHook instance;
private static EventService instance;
private struct MethodKey : IEquatable<MethodKey>
{
@@ -582,21 +584,17 @@ namespace Barotrauma
};
}
internal LuaCsHook(LuaCsSetup luaCs)
public void InitPatcher()
{
instance = this;
this.luaCs = luaCs;
}
public void Initialize()
{
harmony = new Harmony("LuaCsForBarotrauma");
patchModuleBuilder = new Lazy<ModuleBuilder>(CreateModuleBuilder);
UserData.RegisterType<ParameterTable>();
var hookType = UserData.RegisterType<LuaCsHook>();
var hookType = UserData.RegisterType<EventService>();
var hookDesc = (StandardUserDataDescriptor)hookType;
typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
typeof(EventService).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
if (
m.Name.Contains("HookMethod") ||
m.Name.Contains("UnhookMethod") ||
@@ -609,6 +607,29 @@ namespace Barotrauma
});
}
public void ResetPatcher()
{
harmony?.UnpatchSelf();
foreach (var (_, patch) in registeredPatches)
{
// Remove references stored in our dynamic types so the generated
// assembly can be garbage-collected.
patch.HarmonyPrefixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
patch.HarmonyPostfixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
}
registeredPatches.Clear();
patchModuleBuilder = null;
compatHookPrefixMethods.Clear();
compatHookPostfixMethods.Clear();
}
private ModuleBuilder CreateModuleBuilder()
{
var assemblyName = $"LuaCsHookPatch-{Guid.NewGuid():N}";
@@ -689,143 +710,6 @@ namespace Barotrauma
return moduleBuilder;
}
public void Add(string name, LuaCsFunc func, ACsMod owner = null) => Add(name, name, func, owner);
public void Add(string name, string identifier, LuaCsFunc func, ACsMod owner = null)
{
if (name == null) throw new ArgumentNullException(nameof(name));
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
if (func == null) throw new ArgumentNullException(nameof(func));
name = NormalizeIdentifier(name);
identifier = NormalizeIdentifier(identifier);
if (!hookFunctions.ContainsKey(name))
{
hookFunctions.Add(name, new Dictionary<string, (LuaCsHookCallback, ACsMod)>());
}
hookFunctions[name][identifier] = (new LuaCsHookCallback(name, identifier, func), owner);
}
public bool Exists(string name, string identifier)
{
if (name == null) throw new ArgumentNullException(nameof(name));
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
name = NormalizeIdentifier(name);
identifier = NormalizeIdentifier(identifier);
if (!hookFunctions.ContainsKey(name))
{
return false;
}
return hookFunctions[name].ContainsKey(identifier);
}
public void Remove(string name, string identifier)
{
if (name == null) throw new ArgumentNullException(nameof(name));
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
name = NormalizeIdentifier(name);
identifier = NormalizeIdentifier(identifier);
if (hookFunctions.ContainsKey(name) && hookFunctions[name].ContainsKey(identifier))
{
hookFunctions[name].Remove(identifier);
}
}
public void Clear()
{
harmony?.UnpatchSelf();
foreach (var (_, patch) in registeredPatches)
{
// Remove references stored in our dynamic types so the generated
// assembly can be garbage-collected.
patch.HarmonyPrefixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
patch.HarmonyPostfixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
}
hookFunctions.Clear();
registeredPatches.Clear();
patchModuleBuilder = null;
compatHookPrefixMethods.Clear();
compatHookPostfixMethods.Clear();
}
private Stopwatch performanceMeasurement = new Stopwatch();
[MoonSharpHidden]
public T Call<T>(string name, params object[] args)
{
if (name == null) throw new ArgumentNullException(name);
if (args == null) args = new object[0];
name = NormalizeIdentifier(name);
if (!hookFunctions.ContainsKey(name)) return default;
T lastResult = default;
var hooks = hookFunctions[name].ToArray();
foreach ((string key, var tuple) in hooks)
{
if (tuple.Item2 != null && tuple.Item2.IsDisposed)
{
hookFunctions[name].Remove(key);
continue;
}
try
{
var result = tuple.Item1.func(args);
if (result is DynValue luaResult)
{
if (luaResult.Type == DataType.Tuple)
{
bool replaceNil = luaResult.Tuple.Length > 1 && luaResult.Tuple[1].CastToBool();
if (!luaResult.Tuple[0].IsNil() || replaceNil)
{
lastResult = luaResult.ToObject<T>();
}
}
else if (!luaResult.IsNil())
{
lastResult = luaResult.ToObject<T>();
}
}
else
{
lastResult = (T)result;
}
}
catch (Exception e)
{
var argsSb = new StringBuilder();
foreach (var arg in args)
{
argsSb.Append(arg + " ");
}
LuaCsLogger.LogError($"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", LuaCsMessageOrigin.Unknown);
LuaCsLogger.HandleException(e, LuaCsMessageOrigin.Unknown);
}
}
return lastResult;
}
public object Call(string name, params object[] args) => Call<object>(name, args);
private static MethodBase ResolveMethod(string className, string methodName, string[] parameters)
{
var classType = GameMain.LuaCs.PluginManagementService.GetType(className);
@@ -984,8 +868,8 @@ namespace Barotrauma
// IL: var patchExists = instance.registeredPatches.TryGetValue(patchKey, out MethodPatches patches)
var patchExists = il.DeclareLocal<bool>("patchExists");
var patches = il.DeclareLocal<PatchedMethod>("patches");
il.LoadField(typeof(LuaCsHook).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static));
il.LoadField(typeof(LuaCsHook).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance));
il.LoadField(typeof(EventService).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static));
il.LoadField(typeof(EventService).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance));
il.LoadLocal(patchKey);
il.LoadLocalAddress(patches); // out parameter
il.Call(typeof(Dictionary<MethodKey, PatchedMethod>).GetMethod("TryGetValue"));

View File

@@ -1,4 +1,6 @@
using System;
global using LuaCsHook = Barotrauma.LuaCs.Services.EventService;
using System;
using System.Linq;
using System.Reflection;
using HarmonyLib;
@@ -10,8 +12,11 @@ namespace Barotrauma
{
// XXX: this can't be renamed because of backward compatibility with C# mods
public delegate object LuaCsPatch(object self, Dictionary<string, object> args);
}
partial class LuaCsHook
namespace Barotrauma.LuaCs.Services
{
partial class EventService
{
private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> compatHookPrefixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>>();
private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>> compatHookPostfixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, ACsMod)>>();
@@ -115,10 +120,10 @@ namespace Barotrauma
if (result != null) __result = result;
}
private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaCsHook).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchPrefix = typeof(EventService).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchPostfix = typeof(EventService).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(EventService).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(EventService).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static);
// TODO: deprecate this
public void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null)

View File

@@ -1,6 +1,7 @@
extern alias Client;
using Client::Barotrauma;
using Client::Barotrauma.LuaCs.Services;
using Client::Barotrauma;
using MoonSharp.Interpreter;
using System;
using System.Collections.Concurrent;
@@ -63,14 +64,14 @@ namespace TestProject.LuaCs
string methodName,
string[]? parameters,
string function,
LuaCsHook.HookMethodType patchType)
EventService.HookMethodType patchType)
{
var args = BuildHookPatchArgsList(patchId, className, methodName, parameters);
args.Add(function);
args.Add(patchType switch
{
LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before",
LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After",
EventService.HookMethodType.Before => "Hook.HookMethodType.Before",
EventService.HookMethodType.After => "Hook.HookMethodType.After",
_ => throw new NotImplementedException(),
});
throw new NotImplementedException();
@@ -83,13 +84,13 @@ namespace TestProject.LuaCs
string className,
string methodName,
string[]? parameters,
LuaCsHook.HookMethodType patchType)
EventService.HookMethodType patchType)
{
var args = BuildHookPatchArgsList(patchId, className, methodName, parameters);
args.Add(patchType switch
{
LuaCsHook.HookMethodType.Before => "Hook.HookMethodType.Before",
LuaCsHook.HookMethodType.After => "Hook.HookMethodType.After",
EventService.HookMethodType.Before => "Hook.HookMethodType.Before",
EventService.HookMethodType.After => "Hook.HookMethodType.After",
_ => throw new NotImplementedException(),
});
throw new NotImplementedException();
@@ -103,7 +104,7 @@ namespace TestProject.LuaCs
function(instance, ptable)
{body}
end
", LuaCsHook.HookMethodType.Before);
", EventService.HookMethodType.Before);
Assert.Equal(DataType.String, returnValue.Type);
return new(returnValue.String, () => luaCs.RemovePrefix<T>(returnValue.String, methodName, parameters));
}
@@ -115,7 +116,7 @@ namespace TestProject.LuaCs
function(instance, ptable)
{body}
end
", LuaCsHook.HookMethodType.After);
", EventService.HookMethodType.After);
Assert.Equal(DataType.String, returnValue.Type);
return new(returnValue.String, () => luaCs.RemovePostfix<T>(returnValue.String, methodName, parameters));
}
@@ -123,7 +124,7 @@ namespace TestProject.LuaCs
public static bool RemovePrefix<T>(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null)
{
var className = typeof(T).FullName!;
var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.Before);
var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, EventService.HookMethodType.Before);
Assert.Equal(DataType.Boolean, returnValue.Type);
return returnValue.Boolean;
}
@@ -131,7 +132,7 @@ namespace TestProject.LuaCs
public static bool RemovePostfix<T>(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null)
{
var className = typeof(T).FullName!;
var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, LuaCsHook.HookMethodType.After);
var returnValue = luaCs.DoHookRemovePatch(patchId, className, methodName, parameters, EventService.HookMethodType.After);
Assert.Equal(DataType.Boolean, returnValue.Type);
return returnValue.Boolean;
}