Move the Lua IL patching bullshit to a separate service

This commit is contained in:
Evil Factory
2026-02-03 19:37:47 -03:00
committed by Maplewheels
parent ea602f6d2f
commit 70dd602bcf
7 changed files with 547 additions and 471 deletions

View File

@@ -190,6 +190,7 @@ namespace Barotrauma
servicesProvider.RegisterServiceResolver<IPluginManagementService>(factory => factory.GetInstance<IAssemblyManagementService>());
servicesProvider.RegisterServiceType<ILuaScriptManagementService, LuaScriptManagementService>(ServiceLifetime.Singleton);
servicesProvider.RegisterServiceType<IDefaultLuaRegistrar, DefaultLuaRegistrar>(ServiceLifetime.Singleton);
servicesProvider.RegisterServiceType<ILuaPatcher, LuaPatcherService>(ServiceLifetime.Singleton);
servicesProvider.RegisterServiceType<ILuaUserDataService, LuaUserDataService>(ServiceLifetime.Singleton);
servicesProvider.RegisterServiceType<ISafeLuaUserDataService, SafeLuaUserDataService>(ServiceLifetime.Singleton);

View File

@@ -1,10 +1,9 @@
using System;
using System.Reflection;
using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
namespace Barotrauma.LuaCs.Services.Compatibility;
public interface ILuaCsHook : ILuaCsShim
public interface ILuaCsHook : ILuaPatcher, ILuaCsShim
{
// Event Services
[Obsolete("ACsMod is deprecated. Use ILuaEventService.Add() instead.")]
@@ -15,18 +14,8 @@ public interface ILuaCsHook : ILuaCsShim
//bool Exists(string eventName, string identifier);
object Call(string eventName, params object[] args);
T Call<T>(string eventName, params object[] args);
// Hook/Method Patching
string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType);
bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType);
void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, IAssemblyPlugin owner = null);
// Needs to be here instead of ILuaPatcher for compatiility purposes
public enum HookMethodType
{
Before, After

View File

@@ -1,15 +1,18 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Barotrauma.Extensions;
using Barotrauma.Extensions;
using Barotrauma.LuaCs.Events;
using Barotrauma.LuaCs.Services.Compatibility;
using FluentResults;
using FluentResults.LuaCs;
using HarmonyLib;
using Microsoft.Toolkit.Diagnostics;
using OneOf;
using RestSharp;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Reflection;
namespace Barotrauma.LuaCs.Services;
@@ -51,13 +54,14 @@ public partial class EventService : IEventService
public static implicit operator TypeStringKey(string typeName) => new(typeName);
}
private ILoggerService _loggerService;
private readonly ILoggerService _loggerService;
private readonly ILuaPatcher _luaPatcher;
private readonly AsyncReaderWriterLock _operationsLock = new();
private readonly ConcurrentDictionary<TypeStringKey, ConcurrentDictionary<OneOf<IEvent, string>, IEvent>> _subscribers = new();
private readonly ConcurrentDictionary<TypeStringKey, (TypeStringKey Event, Func<LuaCsFunc, IEvent> RunnerFactory)> _luaAliasEventFactory = new();
private readonly ConcurrentDictionary<string, ConcurrentDictionary<string, LuaCsFunc>> _luaLegacyEventsSubscribers = new();
#region Disposal
#region LifeCycle
public void Dispose()
{
@@ -70,13 +74,15 @@ public partial class EventService : IEventService
_luaLegacyEventsSubscribers.Clear();
_luaAliasEventFactory.Clear();
_subscribers.Clear();
_luaPatcher.Dispose();
}
private int _isDisposed;
public EventService(ILoggerService loggerService)
public EventService(ILoggerService loggerService, ILuaPatcher luaPatcher)
{
_loggerService = loggerService;
_luaPatcher = luaPatcher;
}
public bool IsDisposed
@@ -91,6 +97,7 @@ public partial class EventService : IEventService
_luaLegacyEventsSubscribers.Clear();
_luaAliasEventFactory.Clear();
_subscribers.Clear();
_luaPatcher.Reset();
return FluentResults.Result.Ok();
}
@@ -299,4 +306,41 @@ public partial class EventService : IEventService
return results;
}
#region LuaPatcherAdapter
public string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before)
{
return _luaPatcher.Patch(identifier, className, methodName, parameterTypes, patch, hookType);
}
public string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before)
{
return _luaPatcher.Patch(identifier, className, methodName, patch, hookType);
}
public string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before)
{
return _luaPatcher.Patch(className, methodName, parameterTypes, patch, hookType);
}
public string Patch(string className, string methodName, LuaCsPatchFunc patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before)
{
return _luaPatcher.Patch(className, methodName, patch, hookType);
}
public bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsHook.HookMethodType hookType)
{
return _luaPatcher.RemovePatch(className, methodName, methodName, parameterTypes, hookType);
}
public bool RemovePatch(string identifier, string className, string methodName, LuaCsHook.HookMethodType hookType)
{
return _luaPatcher.RemovePatch(className, methodName, methodName, hookType);
}
public void HookMethod(string identifier, MethodBase method, LuaCsPatch patch, LuaCsHook.HookMethodType hookType = LuaCsHook.HookMethodType.Before, IAssemblyPlugin owner = null)
{
_luaPatcher.HookMethod(identifier, method, patch, hookType, owner);
}
#endregion
}

View File

@@ -0,0 +1,16 @@
using System.Reflection;
using static Barotrauma.LuaCs.Services.Compatibility.ILuaCsHook;
using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
namespace Barotrauma.LuaCs.Services;
public interface ILuaPatcher : IReusableService
{
string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before);
bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType);
bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType);
void HookMethod(string identifier, MethodBase method, LuaCsCompatPatchFunc patch, HookMethodType hookType = HookMethodType.Before, IAssemblyPlugin owner = null);
}

View File

@@ -1,5 +1,4 @@
//global using LuaCsHook = Barotrauma.LuaCs.Services.EventService;
global using LuaCsHook = Barotrauma.LuaCs.Services.Compatibility.ILuaCsHook;
global using LuaCsHook = Barotrauma.LuaCs.Services.Compatibility.ILuaCsHook;
using System;
using System.Linq;
@@ -18,8 +17,10 @@ namespace Barotrauma
namespace Barotrauma.LuaCs.Services
{
partial class EventService
partial class LuaPatcherService
{
private static LuaPatcherService instance;
private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, IAssemblyPlugin)>> compatHookPrefixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, IAssemblyPlugin)>>();
private Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, IAssemblyPlugin)>> compatHookPostfixMethods = new Dictionary<long, HashSet<(string, LuaCsCompatPatchFunc, IAssemblyPlugin)>>();
@@ -122,10 +123,10 @@ namespace Barotrauma.LuaCs.Services
if (result != null) __result = result;
}
private static MethodInfo _miHookLuaCsPatchPrefix = typeof(EventService).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchPostfix = typeof(EventService).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(EventService).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(EventService).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchPrefix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchPostfix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchPostfix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPrefix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchRetPrefix", BindingFlags.NonPublic | BindingFlags.Static);
private static MethodInfo _miHookLuaCsPatchRetPostfix = typeof(LuaPatcherService).GetMethod("HookLuaCsPatchRetPostfix", BindingFlags.NonPublic | BindingFlags.Static);
// TODO: deprecate this

View File

@@ -20,401 +20,12 @@ namespace Barotrauma
{
public delegate void LuaCsAction(params object[] args);
public delegate object LuaCsFunc(params object[] args);
public delegate DynValue LuaCsPatchFunc(object instance, EventService.ParameterTable ptable);
public delegate DynValue LuaCsPatchFunc(object instance, LuaPatcherService.ParameterTable ptable);
}
namespace Barotrauma.LuaCs.Services
{
internal static class SigilExtensions
{
/// <summary>
/// Puts a type on the stack, as a <see cref="Type" /> object instead of a
/// runtime type token.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type to put on the stack.</param>
public static void LoadType(this Emit il, Type type)
{
if (type == null) throw new ArgumentNullException(nameof(type));
il.LoadConstant(type); // ldtoken
// This converts the type token into a Type object
il.Call(typeof(Type).GetMethod(
name: nameof(Type.GetTypeFromHandle),
bindingAttr: BindingFlags.Public | BindingFlags.Static,
binder: null,
types: new Type[] { typeof(RuntimeTypeHandle) },
modifiers: null));
}
/// <summary>
/// Converts the value on the stack to <see cref="object" />.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type of the value on the stack.</param>
public static void ToObject(this Emit il, Type type)
{
if (type == null) throw new ArgumentNullException(nameof(type));
il.DerefIfByRef(ref type);
if (type.IsValueType)
{
il.Box(type);
}
else if (type != typeof(object))
{
il.CastClass<object>();
}
}
/// <summary>
/// Deferences the value on stack if the provided type is ByRef.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type to check if ByRef.</param>
public static void DerefIfByRef(this Emit il, Type type) => il.DerefIfByRef(ref type);
/// <summary>
/// Deferences the value on stack if the provided type is ByRef.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type to check if ByRef.</param>
public static void DerefIfByRef(this Emit il, ref Type type)
{
if (type == null) throw new ArgumentNullException(nameof(type));
if (type.IsByRef)
{
type = type.GetElementType();
if (type.IsValueType)
{
il.LoadObject(type);
}
else
{
il.LoadIndirect(type);
}
}
}
// Copied from https://github.com/evilfactory/moonsharp/blob/5264656c6442e783f3c75082cce69a93d66d4cc0/src/MoonSharp.Interpreter/Interop/Converters/ScriptToClrConversions.cs#L79-L99
private static MethodInfo GetImplicitOperatorMethod(Type baseType, Type targetType)
{
try
{
return Expression.Convert(Expression.Parameter(baseType, null), targetType).Method;
}
catch
{
if (baseType.BaseType != null)
{
return GetImplicitOperatorMethod(baseType.BaseType, targetType);
}
if (targetType.BaseType != null)
{
return GetImplicitOperatorMethod(baseType, targetType.BaseType);
}
return null;
}
}
/// <summary>
/// Loads a local variable and casts it to the target type.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="value">The value to cast. Must be of type <see cref="object" />.</param>
/// <param name="targetType">The type to cast into.</param>
public static void LoadLocalAndCast(this Emit il, Local value, Type targetType)
{
if (value == null) throw new ArgumentNullException(nameof(value));
if (targetType == null) throw new ArgumentNullException(nameof(targetType));
if (value.LocalType != typeof(object))
{
throw new ArgumentException($"Expected local type {typeof(object)}; got {value.LocalType}.", nameof(value));
}
var guid = Guid.NewGuid().ToString("N");
if (targetType.IsByRef)
{
targetType = targetType.GetElementType();
}
// IL: var baseType = value.GetType();
var baseType = il.DeclareLocal(typeof(Type), $"cast_baseType_{guid}");
il.LoadLocal(value);
il.Call(typeof(object).GetMethod("GetType"));
il.StoreLocal(baseType);
// IL: var implicitOperatorMethod = SigilExtensions.GetImplicitOperatorMethod(baseType, <targetType>);
var implicitOperatorMethod = il.DeclareLocal(typeof(MethodInfo), $"cast_implicitOperatorMethod_{guid}");
il.LoadLocal(baseType);
il.LoadType(targetType);
il.Call(typeof(SigilExtensions).GetMethod(nameof(GetImplicitOperatorMethod), BindingFlags.NonPublic | BindingFlags.Static));
il.StoreLocal(implicitOperatorMethod);
// IL: <TargetType> castValue;
var castValue = il.DeclareLocal(targetType, $"cast_castValue_{guid}");
// IL: if (implicitConversionMethod != null)
il.LoadLocal(implicitOperatorMethod);
il.Branch((il) =>
{
// IL: var methodInvokeParams = new object[1];
var methodInvokeParams = il.DeclareLocal(typeof(object[]), $"cast_methodInvokeParams_{guid}");
il.LoadConstant(1);
il.NewArray(typeof(object));
il.StoreLocal(methodInvokeParams);
// IL: methodInvokeParams[0] = value;
il.LoadLocal(methodInvokeParams);
il.LoadConstant(0);
il.LoadLocal(value);
il.StoreElement<object>();
// IL: castValue = (<TargetType>)implicitConversionMethod.Invoke(null, methodInvokeParams);
il.LoadLocal(implicitOperatorMethod);
il.LoadNull(); // first parameter is null because implicit cast operators are static
il.LoadLocal(methodInvokeParams);
il.Call(typeof(MethodInfo).GetMethod("Invoke", new[] { typeof(object), typeof(object[]) }));
if (targetType.IsValueType)
{
il.UnboxAny(targetType);
}
else
{
il.CastClass(targetType);
}
il.StoreLocal(castValue);
},
(il) =>
{
// IL: castValue = (<TargetType>)value;
il.LoadLocal(value);
if (targetType.IsValueType)
{
il.UnboxAny(targetType);
}
else
{
il.CastClass(targetType);
}
il.StoreLocal(castValue);
});
il.LoadLocal(castValue);
}
/// <summary>
/// Emits a call to <see cref="string.Format(string, object[])"/>.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="format">The string format.</param>
/// <param name="args">The local variables passed to string.Format.</param>
public static void FormatString(this Emit il, string format, params Local[] args)
{
if (format == null) throw new ArgumentNullException(nameof(format));
if (args == null) throw new ArgumentNullException(nameof(args));
var guid = Guid.NewGuid().ToString("N");
var listType = typeof(List<>).MakeGenericType(typeof(object));
var list = il.DeclareLocal(listType, $"formatString_list_{guid}");
il.NewObject(listType);
il.StoreLocal(list);
foreach (var arg in args)
{
il.LoadLocal(list);
il.LoadLocal(arg);
il.ToObject(arg.LocalType);
il.CallVirtual(listType.GetMethod("Add", new[] { typeof(object) }));
}
var arr = il.DeclareLocal<object[]>($"formatString_arr_{guid}");
il.LoadLocal(list);
il.CallVirtual(listType.GetMethod("ToArray", new Type[0]));
il.StoreLocal(arr);
il.LoadConstant(format);
il.LoadLocal(arr);
il.Call(typeof(string).GetMethod("Format", new[] { typeof(string), typeof(object[]) }));
}
/// <summary>
/// Emits a call to <see cref="DebugConsole.NewMessage(string, Color?, bool)" />.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="message">The message to print.</param>
public static void NewMessage(this Emit il, string message)
{
var newMessage = typeof(DebugConsole).GetMethod(
name: nameof(DebugConsole.NewMessage),
bindingAttr: BindingFlags.Public | BindingFlags.Static,
binder: null,
types: new Type[] { typeof(string), typeof(Color?), typeof(bool) },
modifiers: null);
il.LoadConstant(message);
il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod());
il.LoadConstant(false);
il.Call(newMessage);
}
/// <summary>
/// Emits a call to <see cref="DebugConsole.NewMessage(string, Color?, bool)" />,
/// using the string on the stack.
/// </summary>
/// <param name="il">The IL emitter.</param>
public static void NewMessage(this Emit il)
{
var newMessage = typeof(DebugConsole).GetMethod(
name: nameof(DebugConsole.NewMessage),
bindingAttr: BindingFlags.Public | BindingFlags.Static,
binder: null,
types: new Type[] { typeof(string), typeof(Color?), typeof(bool) },
modifiers: null);
il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod());
il.LoadConstant(false);
il.Call(newMessage);
}
/// <summary>
/// Emits a <c>foreach</c> loop that iterates over an <see cref="IEnumerable{T}"/> local variable.
/// </summary>
/// <typeparam name="T">The type of elements in the enumerable.</typeparam>
/// <param name="il">The IL emitter.</param>
/// <param name="enumerable">The enumerable.</param>
/// <param name="action">The body of code to run on each iteration.</param>
public static void ForEachEnumerable<T>(this Emit il, Local enumerable, Action<Emit, Local, Sigil.Label> action)
{
if (enumerable == null) throw new ArgumentNullException(nameof(enumerable));
if (action == null) throw new ArgumentNullException(nameof(action));
if (!typeof(IEnumerable<T>).IsAssignableFrom(enumerable.LocalType))
{
throw new ArgumentException($"Expected local type {typeof(IEnumerator<T>)}; got {enumerable.LocalType}.", nameof(enumerable));
}
var guid = Guid.NewGuid().ToString("N");
var enumerator = il.DeclareLocal<IEnumerator<T>>($"forEachEnumerable_enumerator_{guid}");
il.LoadLocal(enumerable);
il.CallVirtual(typeof(IEnumerable<T>).GetMethod("GetEnumerator"));
il.StoreLocal(enumerator);
ForEachEnumerator<T>(il, enumerator, action);
}
/// <summary>
/// Emits a <c>foreach</c> loop that iterates over an <see cref="IEnumerator{T}"/> local variable.
/// </summary>
/// <typeparam name="T">The type of elements in the enumerable.</typeparam>
/// <param name="il">The IL emitter.</param>
/// <param name="enumerator">The enumerator.</param>
/// <param name="action">The body of code to run on each iteration.</param>
public static void ForEachEnumerator<T>(this Emit il, Local enumerator, Action<Emit, Local, Sigil.Label> action)
{
if (enumerator == null) throw new ArgumentNullException(nameof(enumerator));
if (action == null) throw new ArgumentNullException(nameof(action));
if (!typeof(IEnumerator<T>).IsAssignableFrom(enumerator.LocalType))
{
throw new ArgumentException($"Expected local type {typeof(IEnumerator<T>)}; got {enumerator.LocalType}.", nameof(enumerator));
}
var guid = Guid.NewGuid().ToString("N");
var labelLoopStart = il.DefineLabel($"forEach_loopStart_{guid}");
var labelMoveNext = il.DefineLabel($"forEach_moveNext_{guid}");
var labelLeave = il.DefineLabel($"forEach_leave_{guid}");
il.BeginExceptionBlock(out var exceptionBlock);
il.Branch(labelMoveNext); // MoveNext() needs to be called at least once before iterating
il.MarkLabel(labelLoopStart);
// IL: var current = enumerator.Current;
var current = il.DeclareLocal<T>($"forEachEnumerator_current_{guid}");
il.LoadLocal(enumerator);
il.CallVirtual(enumerator.LocalType.GetProperty("Current").GetGetMethod());
il.StoreLocal(current);
action(il, current, labelLeave);
il.MarkLabel(labelMoveNext);
il.LoadLocal(enumerator);
il.CallVirtual(typeof(IEnumerator).GetMethod("MoveNext"));
il.BranchIfTrue(labelLoopStart); // loop if MoveNext() returns true
// IL: finally { enumerator.Dispose(); }
il.BeginFinallyBlock(exceptionBlock, out var finallyBlock);
il.LoadLocal(enumerator);
il.CallVirtual(typeof(IDisposable).GetMethod("Dispose"));
il.EndFinallyBlock(finallyBlock);
il.EndExceptionBlock(exceptionBlock);
il.MarkLabel(labelLeave);
}
/// <summary>
/// Emits a branch that only executes if the last value on the stack
/// is truthy (e.g. non-null references, 1, etc).
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="action">The body of code to run if the value is truthy.</param>
public static void If(this Emit il, Action<Emit> action)
{
if (action == null) throw new ArgumentNullException(nameof(action));
il.Branch(@if: action);
}
/// <summary>
/// Emits a branch that only executes if the last value on the stack
/// is falsy (e.g. null references, 0, etc).
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="action">The body of code to run if the value is falsy.</param>
public static void IfNot(this Emit il, Action<Emit> action)
{
if (action == null) throw new ArgumentNullException(nameof(action));
il.Branch(@else: action);
}
/// <summary>
/// Emits two branches that diverge based on a condition -- analogous
/// to an if-else statement. If either <paramref name="if"/>
/// or <paramref name="else"/> are omitted, it behaves the same as
/// <see cref="If(Emit, Action{Emit})"/>
/// and <see cref="IfNot(Emit, Action{Emit})"/>.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="if">The body of code to run if the value is truthy.</param>
/// <param name="else">The body of code to run if the value is falsy.</param>
public static void Branch(this Emit il, Action<Emit> @if = null, Action<Emit> @else = null)
{
if (@if == null && @else == null) throw new ArgumentException("At least one of the two branches must be defined.");
var guid = Guid.NewGuid().ToString("N");
var labelEnd = il.DefineLabel($"branch_end_{guid}");
if (@if != null && @else != null)
{
var labelElse = il.DefineLabel($"branch_else_{guid}");
il.BranchIfFalse(labelElse);
@if(il);
il.Branch(labelEnd);
il.MarkLabel(labelElse);
@else(il);
}
else if (@if != null)
{
il.BranchIfFalse(labelEnd);
@if(il);
}
else
{
il.BranchIfTrue(labelEnd);
@else(il);
}
il.MarkLabel(labelEnd);
}
}
partial class EventService
public partial class LuaPatcherService : ILuaPatcher
{
private class LuaCsHookCallback
{
@@ -511,36 +122,6 @@ namespace Barotrauma.LuaCs.Services
public Dictionary<string, object> ModifiedParameters { get; } = new Dictionary<string, object>();
}
private static readonly string[] prohibitedHooks =
{
"Barotrauma.Lua",
"Barotrauma.Cs",
"Barotrauma.ContentPackageManager",
};
private static void ValidatePatchTarget(MethodBase method)
{
if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h)))
{
throw new ArgumentException("Hooks into the modding environment are prohibited.");
}
}
private static string NormalizeIdentifier(string identifier)
{
return identifier?.Trim().ToLowerInvariant();
}
private Harmony harmony;
private Lazy<ModuleBuilder> patchModuleBuilder;
private readonly Dictionary<MethodKey, PatchedMethod> registeredPatches = new Dictionary<MethodKey, PatchedMethod>();
private LuaCsSetup luaCs;
private static EventService instance;
private struct MethodKey : IEquatable<MethodKey>
{
public ModuleHandle ModuleHandle { get; set; }
@@ -579,7 +160,19 @@ namespace Barotrauma.LuaCs.Services
};
}
public void InitPatcher()
private static readonly string[] prohibitedHooks =
{
"Barotrauma.Lua",
"Barotrauma.Cs",
"Barotrauma.ContentPackageManager",
};
private Harmony harmony;
private Lazy<ModuleBuilder> patchModuleBuilder;
private readonly Dictionary<MethodKey, PatchedMethod> registeredPatches = new Dictionary<MethodKey, PatchedMethod>();
public LuaPatcherService()
{
instance = this;
@@ -587,6 +180,9 @@ namespace Barotrauma.LuaCs.Services
patchModuleBuilder = new Lazy<ModuleBuilder>(CreateModuleBuilder);
UserData.RegisterType<ParameterTable>();
// whats this for?
/*
var hookType = UserData.RegisterType<EventService>();
var hookDesc = (StandardUserDataDescriptor)hookType;
typeof(EventService).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
@@ -600,29 +196,20 @@ namespace Barotrauma.LuaCs.Services
hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
}
});
*/
}
public void ResetPatcher()
private static void ValidatePatchTarget(MethodBase method)
{
harmony?.UnpatchSelf();
foreach (var (_, patch) in registeredPatches)
if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h)))
{
// Remove references stored in our dynamic types so the generated
// assembly can be garbage-collected.
patch.HarmonyPrefixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
patch.HarmonyPostfixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
throw new ArgumentException("Hooks into the modding environment are prohibited.");
}
}
registeredPatches.Clear();
patchModuleBuilder = null;
compatHookPrefixMethods.Clear();
compatHookPostfixMethods.Clear();
private static string NormalizeIdentifier(string identifier)
{
return identifier?.Trim().ToLowerInvariant();
}
private ModuleBuilder CreateModuleBuilder()
@@ -797,6 +384,8 @@ namespace Barotrauma.LuaCs.Services
private const string FIELD_LUACS = "LuaCs";
public bool IsDisposed { get; private set; }
// If you need to debug this:
// - use https://sharplab.io ; it's a very useful for resource for writing IL by hand.
// - use il.NewMessage("") or il.WriteLine("") to see where the IL crashes at runtime.
@@ -863,8 +452,8 @@ namespace Barotrauma.LuaCs.Services
// IL: var patchExists = instance.registeredPatches.TryGetValue(patchKey, out MethodPatches patches)
var patchExists = il.DeclareLocal<bool>("patchExists");
var patches = il.DeclareLocal<PatchedMethod>("patches");
il.LoadField(typeof(EventService).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static));
il.LoadField(typeof(EventService).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance));
il.LoadField(typeof(LuaPatcherService).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static));
il.LoadField(typeof(LuaPatcherService).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance));
il.LoadLocal(patchKey);
il.LoadLocalAddress(patches); // out parameter
il.Call(typeof(Dictionary<MethodKey, PatchedMethod>).GetMethod("TryGetValue"));
@@ -1081,7 +670,7 @@ namespace Barotrauma.LuaCs.Services
}
var type = typeBuilder.CreateType();
type.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static).SetValue(null, luaCs);
type.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static).SetValue(null, GameMain.LuaCs);
return type.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static);
}
@@ -1187,5 +776,42 @@ namespace Barotrauma.LuaCs.Services
var method = ResolveMethod(className, methodName, null);
return RemovePatch(identifier, method, hookType);
}
private void ClearAll()
{
harmony?.UnpatchSelf();
foreach (var (_, patch) in registeredPatches)
{
// Remove references stored in our dynamic types so the generated
// assembly can be garbage-collected.
patch.HarmonyPrefixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
patch.HarmonyPostfixMethod.DeclaringType
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
.SetValue(null, null);
}
registeredPatches.Clear();
patchModuleBuilder = null;
compatHookPrefixMethods.Clear();
compatHookPostfixMethods.Clear();
}
public void Dispose()
{
IsDisposed = true;
ClearAll();
}
public FluentResults.Result Reset()
{
ClearAll();
return FluentResults.Result.Ok();
}
}
}

View File

@@ -0,0 +1,399 @@
using Microsoft.Xna.Framework;
using Sigil;
using Sigil.NonGeneric;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
namespace Barotrauma.LuaCs;
internal static class SigilExtensions
{
/// <summary>
/// Puts a type on the stack, as a <see cref="Type" /> object instead of a
/// runtime type token.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type to put on the stack.</param>
public static void LoadType(this Emit il, Type type)
{
if (type == null) throw new ArgumentNullException(nameof(type));
il.LoadConstant(type); // ldtoken
// This converts the type token into a Type object
il.Call(typeof(Type).GetMethod(
name: nameof(Type.GetTypeFromHandle),
bindingAttr: BindingFlags.Public | BindingFlags.Static,
binder: null,
types: new Type[] { typeof(RuntimeTypeHandle) },
modifiers: null));
}
/// <summary>
/// Converts the value on the stack to <see cref="object" />.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type of the value on the stack.</param>
public static void ToObject(this Emit il, Type type)
{
if (type == null) throw new ArgumentNullException(nameof(type));
il.DerefIfByRef(ref type);
if (type.IsValueType)
{
il.Box(type);
}
else if (type != typeof(object))
{
il.CastClass<object>();
}
}
/// <summary>
/// Deferences the value on stack if the provided type is ByRef.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type to check if ByRef.</param>
public static void DerefIfByRef(this Emit il, Type type) => il.DerefIfByRef(ref type);
/// <summary>
/// Deferences the value on stack if the provided type is ByRef.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="type">The type to check if ByRef.</param>
public static void DerefIfByRef(this Emit il, ref Type type)
{
if (type == null) throw new ArgumentNullException(nameof(type));
if (type.IsByRef)
{
type = type.GetElementType();
if (type.IsValueType)
{
il.LoadObject(type);
}
else
{
il.LoadIndirect(type);
}
}
}
// Copied from https://github.com/evilfactory/moonsharp/blob/5264656c6442e783f3c75082cce69a93d66d4cc0/src/MoonSharp.Interpreter/Interop/Converters/ScriptToClrConversions.cs#L79-L99
private static MethodInfo GetImplicitOperatorMethod(Type baseType, Type targetType)
{
try
{
return Expression.Convert(Expression.Parameter(baseType, null), targetType).Method;
}
catch
{
if (baseType.BaseType != null)
{
return GetImplicitOperatorMethod(baseType.BaseType, targetType);
}
if (targetType.BaseType != null)
{
return GetImplicitOperatorMethod(baseType, targetType.BaseType);
}
return null;
}
}
/// <summary>
/// Loads a local variable and casts it to the target type.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="value">The value to cast. Must be of type <see cref="object" />.</param>
/// <param name="targetType">The type to cast into.</param>
public static void LoadLocalAndCast(this Emit il, Local value, Type targetType)
{
if (value == null) throw new ArgumentNullException(nameof(value));
if (targetType == null) throw new ArgumentNullException(nameof(targetType));
if (value.LocalType != typeof(object))
{
throw new ArgumentException($"Expected local type {typeof(object)}; got {value.LocalType}.", nameof(value));
}
var guid = Guid.NewGuid().ToString("N");
if (targetType.IsByRef)
{
targetType = targetType.GetElementType();
}
// IL: var baseType = value.GetType();
var baseType = il.DeclareLocal(typeof(Type), $"cast_baseType_{guid}");
il.LoadLocal(value);
il.Call(typeof(object).GetMethod("GetType"));
il.StoreLocal(baseType);
// IL: var implicitOperatorMethod = SigilExtensions.GetImplicitOperatorMethod(baseType, <targetType>);
var implicitOperatorMethod = il.DeclareLocal(typeof(MethodInfo), $"cast_implicitOperatorMethod_{guid}");
il.LoadLocal(baseType);
il.LoadType(targetType);
il.Call(typeof(SigilExtensions).GetMethod(nameof(GetImplicitOperatorMethod), BindingFlags.NonPublic | BindingFlags.Static));
il.StoreLocal(implicitOperatorMethod);
// IL: <TargetType> castValue;
var castValue = il.DeclareLocal(targetType, $"cast_castValue_{guid}");
// IL: if (implicitConversionMethod != null)
il.LoadLocal(implicitOperatorMethod);
il.Branch((il) =>
{
// IL: var methodInvokeParams = new object[1];
var methodInvokeParams = il.DeclareLocal(typeof(object[]), $"cast_methodInvokeParams_{guid}");
il.LoadConstant(1);
il.NewArray(typeof(object));
il.StoreLocal(methodInvokeParams);
// IL: methodInvokeParams[0] = value;
il.LoadLocal(methodInvokeParams);
il.LoadConstant(0);
il.LoadLocal(value);
il.StoreElement<object>();
// IL: castValue = (<TargetType>)implicitConversionMethod.Invoke(null, methodInvokeParams);
il.LoadLocal(implicitOperatorMethod);
il.LoadNull(); // first parameter is null because implicit cast operators are static
il.LoadLocal(methodInvokeParams);
il.Call(typeof(MethodInfo).GetMethod("Invoke", new[] { typeof(object), typeof(object[]) }));
if (targetType.IsValueType)
{
il.UnboxAny(targetType);
}
else
{
il.CastClass(targetType);
}
il.StoreLocal(castValue);
},
(il) =>
{
// IL: castValue = (<TargetType>)value;
il.LoadLocal(value);
if (targetType.IsValueType)
{
il.UnboxAny(targetType);
}
else
{
il.CastClass(targetType);
}
il.StoreLocal(castValue);
});
il.LoadLocal(castValue);
}
/// <summary>
/// Emits a call to <see cref="string.Format(string, object[])"/>.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="format">The string format.</param>
/// <param name="args">The local variables passed to string.Format.</param>
public static void FormatString(this Emit il, string format, params Local[] args)
{
if (format == null) throw new ArgumentNullException(nameof(format));
if (args == null) throw new ArgumentNullException(nameof(args));
var guid = Guid.NewGuid().ToString("N");
var listType = typeof(List<>).MakeGenericType(typeof(object));
var list = il.DeclareLocal(listType, $"formatString_list_{guid}");
il.NewObject(listType);
il.StoreLocal(list);
foreach (var arg in args)
{
il.LoadLocal(list);
il.LoadLocal(arg);
il.ToObject(arg.LocalType);
il.CallVirtual(listType.GetMethod("Add", new[] { typeof(object) }));
}
var arr = il.DeclareLocal<object[]>($"formatString_arr_{guid}");
il.LoadLocal(list);
il.CallVirtual(listType.GetMethod("ToArray", new Type[0]));
il.StoreLocal(arr);
il.LoadConstant(format);
il.LoadLocal(arr);
il.Call(typeof(string).GetMethod("Format", new[] { typeof(string), typeof(object[]) }));
}
/// <summary>
/// Emits a call to <see cref="DebugConsole.NewMessage(string, Color?, bool)" />.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="message">The message to print.</param>
public static void NewMessage(this Emit il, string message)
{
var newMessage = typeof(DebugConsole).GetMethod(
name: nameof(DebugConsole.NewMessage),
bindingAttr: BindingFlags.Public | BindingFlags.Static,
binder: null,
types: new Type[] { typeof(string), typeof(Color?), typeof(bool) },
modifiers: null);
il.LoadConstant(message);
il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod());
il.LoadConstant(false);
il.Call(newMessage);
}
/// <summary>
/// Emits a call to <see cref="DebugConsole.NewMessage(string, Color?, bool)" />,
/// using the string on the stack.
/// </summary>
/// <param name="il">The IL emitter.</param>
public static void NewMessage(this Emit il)
{
var newMessage = typeof(DebugConsole).GetMethod(
name: nameof(DebugConsole.NewMessage),
bindingAttr: BindingFlags.Public | BindingFlags.Static,
binder: null,
types: new Type[] { typeof(string), typeof(Color?), typeof(bool) },
modifiers: null);
il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod());
il.LoadConstant(false);
il.Call(newMessage);
}
/// <summary>
/// Emits a <c>foreach</c> loop that iterates over an <see cref="IEnumerable{T}"/> local variable.
/// </summary>
/// <typeparam name="T">The type of elements in the enumerable.</typeparam>
/// <param name="il">The IL emitter.</param>
/// <param name="enumerable">The enumerable.</param>
/// <param name="action">The body of code to run on each iteration.</param>
public static void ForEachEnumerable<T>(this Emit il, Local enumerable, Action<Emit, Local, Sigil.Label> action)
{
if (enumerable == null) throw new ArgumentNullException(nameof(enumerable));
if (action == null) throw new ArgumentNullException(nameof(action));
if (!typeof(IEnumerable<T>).IsAssignableFrom(enumerable.LocalType))
{
throw new ArgumentException($"Expected local type {typeof(IEnumerator<T>)}; got {enumerable.LocalType}.", nameof(enumerable));
}
var guid = Guid.NewGuid().ToString("N");
var enumerator = il.DeclareLocal<IEnumerator<T>>($"forEachEnumerable_enumerator_{guid}");
il.LoadLocal(enumerable);
il.CallVirtual(typeof(IEnumerable<T>).GetMethod("GetEnumerator"));
il.StoreLocal(enumerator);
ForEachEnumerator<T>(il, enumerator, action);
}
/// <summary>
/// Emits a <c>foreach</c> loop that iterates over an <see cref="IEnumerator{T}"/> local variable.
/// </summary>
/// <typeparam name="T">The type of elements in the enumerable.</typeparam>
/// <param name="il">The IL emitter.</param>
/// <param name="enumerator">The enumerator.</param>
/// <param name="action">The body of code to run on each iteration.</param>
public static void ForEachEnumerator<T>(this Emit il, Local enumerator, Action<Emit, Local, Sigil.Label> action)
{
if (enumerator == null) throw new ArgumentNullException(nameof(enumerator));
if (action == null) throw new ArgumentNullException(nameof(action));
if (!typeof(IEnumerator<T>).IsAssignableFrom(enumerator.LocalType))
{
throw new ArgumentException($"Expected local type {typeof(IEnumerator<T>)}; got {enumerator.LocalType}.", nameof(enumerator));
}
var guid = Guid.NewGuid().ToString("N");
var labelLoopStart = il.DefineLabel($"forEach_loopStart_{guid}");
var labelMoveNext = il.DefineLabel($"forEach_moveNext_{guid}");
var labelLeave = il.DefineLabel($"forEach_leave_{guid}");
il.BeginExceptionBlock(out var exceptionBlock);
il.Branch(labelMoveNext); // MoveNext() needs to be called at least once before iterating
il.MarkLabel(labelLoopStart);
// IL: var current = enumerator.Current;
var current = il.DeclareLocal<T>($"forEachEnumerator_current_{guid}");
il.LoadLocal(enumerator);
il.CallVirtual(enumerator.LocalType.GetProperty("Current").GetGetMethod());
il.StoreLocal(current);
action(il, current, labelLeave);
il.MarkLabel(labelMoveNext);
il.LoadLocal(enumerator);
il.CallVirtual(typeof(IEnumerator).GetMethod("MoveNext"));
il.BranchIfTrue(labelLoopStart); // loop if MoveNext() returns true
// IL: finally { enumerator.Dispose(); }
il.BeginFinallyBlock(exceptionBlock, out var finallyBlock);
il.LoadLocal(enumerator);
il.CallVirtual(typeof(IDisposable).GetMethod("Dispose"));
il.EndFinallyBlock(finallyBlock);
il.EndExceptionBlock(exceptionBlock);
il.MarkLabel(labelLeave);
}
/// <summary>
/// Emits a branch that only executes if the last value on the stack
/// is truthy (e.g. non-null references, 1, etc).
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="action">The body of code to run if the value is truthy.</param>
public static void If(this Emit il, Action<Emit> action)
{
if (action == null) throw new ArgumentNullException(nameof(action));
il.Branch(@if: action);
}
/// <summary>
/// Emits a branch that only executes if the last value on the stack
/// is falsy (e.g. null references, 0, etc).
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="action">The body of code to run if the value is falsy.</param>
public static void IfNot(this Emit il, Action<Emit> action)
{
if (action == null) throw new ArgumentNullException(nameof(action));
il.Branch(@else: action);
}
/// <summary>
/// Emits two branches that diverge based on a condition -- analogous
/// to an if-else statement. If either <paramref name="if"/>
/// or <paramref name="else"/> are omitted, it behaves the same as
/// <see cref="If(Emit, Action{Emit})"/>
/// and <see cref="IfNot(Emit, Action{Emit})"/>.
/// </summary>
/// <param name="il">The IL emitter.</param>
/// <param name="if">The body of code to run if the value is truthy.</param>
/// <param name="else">The body of code to run if the value is falsy.</param>
public static void Branch(this Emit il, Action<Emit> @if = null, Action<Emit> @else = null)
{
if (@if == null && @else == null) throw new ArgumentException("At least one of the two branches must be defined.");
var guid = Guid.NewGuid().ToString("N");
var labelEnd = il.DefineLabel($"branch_end_{guid}");
if (@if != null && @else != null)
{
var labelElse = il.DefineLabel($"branch_else_{guid}");
il.BranchIfFalse(labelElse);
@if(il);
il.Branch(labelEnd);
il.MarkLabel(labelElse);
@else(il);
}
else if (@if != null)
{
il.BranchIfFalse(labelEnd);
@if(il);
}
else
{
il.BranchIfTrue(labelEnd);
@else(il);
}
il.MarkLabel(labelEnd);
}
}