1325 lines
55 KiB
C#
1325 lines
55 KiB
C#
using System;
|
|
using System.Collections;
|
|
using System.Collections.Generic;
|
|
using System.Diagnostics;
|
|
using System.Linq;
|
|
using System.Linq.Expressions;
|
|
using System.Reflection;
|
|
using System.Reflection.Emit;
|
|
using System.Text;
|
|
using System.Text.RegularExpressions;
|
|
using HarmonyLib;
|
|
using Microsoft.Xna.Framework;
|
|
using MoonSharp.Interpreter;
|
|
using MoonSharp.Interpreter.Interop;
|
|
using Sigil;
|
|
using Sigil.NonGeneric;
|
|
|
|
namespace Barotrauma
|
|
{
|
|
public delegate void LuaCsAction(params object[] args);
|
|
public delegate object LuaCsFunc(params object[] args);
|
|
public delegate DynValue LuaCsPatchFunc(object instance, LuaCsHook.ParameterTable ptable);
|
|
|
|
internal static class SigilExtensions
|
|
{
|
|
/// <summary>
|
|
/// Puts a type on the stack, as a <see cref="Type" /> object instead of a
|
|
/// runtime type token.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="type">The type to put on the stack.</param>
|
|
public static void LoadType(this Emit il, Type type)
|
|
{
|
|
if (type == null) throw new ArgumentNullException(nameof(type));
|
|
il.LoadConstant(type); // ldtoken
|
|
// This converts the type token into a Type object
|
|
il.Call(typeof(Type).GetMethod(
|
|
name: nameof(Type.GetTypeFromHandle),
|
|
bindingAttr: BindingFlags.Public | BindingFlags.Static,
|
|
binder: null,
|
|
types: new Type[] { typeof(RuntimeTypeHandle) },
|
|
modifiers: null));
|
|
}
|
|
|
|
/// <summary>
|
|
/// Converts the value on the stack to <see cref="object" />.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="type">The type of the value on the stack.</param>
|
|
public static void ToObject(this Emit il, Type type)
|
|
{
|
|
if (type == null) throw new ArgumentNullException(nameof(type));
|
|
il.DerefIfByRef(ref type);
|
|
if (type.IsValueType)
|
|
{
|
|
il.Box(type);
|
|
}
|
|
else if (type != typeof(object))
|
|
{
|
|
il.CastClass<object>();
|
|
}
|
|
}
|
|
|
|
/// <summary>
|
|
/// Deferences the value on stack if the provided type is ByRef.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="type">The type to check if ByRef.</param>
|
|
public static void DerefIfByRef(this Emit il, Type type) => il.DerefIfByRef(ref type);
|
|
|
|
/// <summary>
|
|
/// Deferences the value on stack if the provided type is ByRef.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="type">The type to check if ByRef.</param>
|
|
public static void DerefIfByRef(this Emit il, ref Type type)
|
|
{
|
|
if (type == null) throw new ArgumentNullException(nameof(type));
|
|
if (type.IsByRef)
|
|
{
|
|
type = type.GetElementType();
|
|
if (type.IsValueType)
|
|
{
|
|
il.LoadObject(type);
|
|
}
|
|
else
|
|
{
|
|
il.LoadIndirect(type);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Copied from https://github.com/evilfactory/moonsharp/blob/5264656c6442e783f3c75082cce69a93d66d4cc0/src/MoonSharp.Interpreter/Interop/Converters/ScriptToClrConversions.cs#L79-L99
|
|
private static MethodInfo GetImplicitOperatorMethod(Type baseType, Type targetType)
|
|
{
|
|
try
|
|
{
|
|
return Expression.Convert(Expression.Parameter(baseType, null), targetType).Method;
|
|
}
|
|
catch
|
|
{
|
|
if (baseType.BaseType != null)
|
|
{
|
|
return GetImplicitOperatorMethod(baseType.BaseType, targetType);
|
|
}
|
|
|
|
if (targetType.BaseType != null)
|
|
{
|
|
return GetImplicitOperatorMethod(baseType, targetType.BaseType);
|
|
}
|
|
|
|
return null;
|
|
}
|
|
}
|
|
|
|
/// <summary>
|
|
/// Loads a local variable and casts it to the target type.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="value">The value to cast. Must be of type <see cref="object" />.</param>
|
|
/// <param name="targetType">The type to cast into.</param>
|
|
public static void LoadLocalAndCast(this Emit il, Local value, Type targetType)
|
|
{
|
|
if (value == null) throw new ArgumentNullException(nameof(value));
|
|
if (targetType == null) throw new ArgumentNullException(nameof(targetType));
|
|
if (value.LocalType != typeof(object))
|
|
{
|
|
throw new ArgumentException($"Expected local type {typeof(object)}; got {value.LocalType}.", nameof(value));
|
|
}
|
|
|
|
var guid = Guid.NewGuid().ToString("N");
|
|
|
|
if (targetType.IsByRef)
|
|
{
|
|
targetType = targetType.GetElementType();
|
|
}
|
|
|
|
// IL: var baseType = value.GetType();
|
|
var baseType = il.DeclareLocal(typeof(Type), $"cast_baseType_{guid}");
|
|
il.LoadLocal(value);
|
|
il.Call(typeof(object).GetMethod("GetType"));
|
|
il.StoreLocal(baseType);
|
|
|
|
// IL: var implicitOperatorMethod = SigilExtensions.GetImplicitOperatorMethod(baseType, <targetType>);
|
|
var implicitOperatorMethod = il.DeclareLocal(typeof(MethodInfo), $"cast_implicitOperatorMethod_{guid}");
|
|
il.LoadLocal(baseType);
|
|
il.LoadType(targetType);
|
|
il.Call(typeof(SigilExtensions).GetMethod(nameof(GetImplicitOperatorMethod), BindingFlags.NonPublic | BindingFlags.Static));
|
|
il.StoreLocal(implicitOperatorMethod);
|
|
|
|
// IL: <TargetType> castValue;
|
|
var castValue = il.DeclareLocal(targetType, $"cast_castValue_{guid}");
|
|
|
|
// IL: if (implicitConversionMethod != null)
|
|
il.LoadLocal(implicitOperatorMethod);
|
|
il.Branch((il) =>
|
|
{
|
|
// IL: var methodInvokeParams = new object[1];
|
|
var methodInvokeParams = il.DeclareLocal(typeof(object[]), $"cast_methodInvokeParams_{guid}");
|
|
il.LoadConstant(1);
|
|
il.NewArray(typeof(object));
|
|
il.StoreLocal(methodInvokeParams);
|
|
|
|
// IL: methodInvokeParams[0] = value;
|
|
il.LoadLocal(methodInvokeParams);
|
|
il.LoadConstant(0);
|
|
il.LoadLocal(value);
|
|
il.StoreElement<object>();
|
|
|
|
// IL: castValue = (<TargetType>)implicitConversionMethod.Invoke(null, methodInvokeParams);
|
|
il.LoadLocal(implicitOperatorMethod);
|
|
il.LoadNull(); // first parameter is null because implicit cast operators are static
|
|
il.LoadLocal(methodInvokeParams);
|
|
il.Call(typeof(MethodInfo).GetMethod("Invoke", new[] { typeof(object), typeof(object[]) }));
|
|
if (targetType.IsValueType)
|
|
{
|
|
il.UnboxAny(targetType);
|
|
}
|
|
else
|
|
{
|
|
il.CastClass(targetType);
|
|
}
|
|
il.StoreLocal(castValue);
|
|
},
|
|
(il) =>
|
|
{
|
|
// IL: castValue = (<TargetType>)value;
|
|
il.LoadLocal(value);
|
|
if (targetType.IsValueType)
|
|
{
|
|
il.UnboxAny(targetType);
|
|
}
|
|
else
|
|
{
|
|
il.CastClass(targetType);
|
|
}
|
|
il.StoreLocal(castValue);
|
|
});
|
|
|
|
il.LoadLocal(castValue);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits a call to <see cref="string.Format(string, object[])"/>.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="format">The string format.</param>
|
|
/// <param name="args">The local variables passed to string.Format.</param>
|
|
public static void FormatString(this Emit il, string format, params Local[] args)
|
|
{
|
|
if (format == null) throw new ArgumentNullException(nameof(format));
|
|
if (args == null) throw new ArgumentNullException(nameof(args));
|
|
|
|
var guid = Guid.NewGuid().ToString("N");
|
|
|
|
var listType = typeof(List<>).MakeGenericType(typeof(object));
|
|
var list = il.DeclareLocal(listType, $"formatString_list_{guid}");
|
|
il.NewObject(listType);
|
|
il.StoreLocal(list);
|
|
|
|
foreach (var arg in args)
|
|
{
|
|
il.LoadLocal(list);
|
|
il.LoadLocal(arg);
|
|
il.ToObject(arg.LocalType);
|
|
il.CallVirtual(listType.GetMethod("Add", new[] { typeof(object) }));
|
|
}
|
|
|
|
var arr = il.DeclareLocal<object[]>($"formatString_arr_{guid}");
|
|
il.LoadLocal(list);
|
|
il.CallVirtual(listType.GetMethod("ToArray", new Type[0]));
|
|
il.StoreLocal(arr);
|
|
|
|
il.LoadConstant(format);
|
|
il.LoadLocal(arr);
|
|
il.Call(typeof(string).GetMethod("Format", new[] { typeof(string), typeof(object[]) }));
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits a call to <see cref="DebugConsole.NewMessage(string, Color?, bool)" />.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="message">The message to print.</param>
|
|
public static void NewMessage(this Emit il, string message)
|
|
{
|
|
var newMessage = typeof(DebugConsole).GetMethod(
|
|
name: nameof(DebugConsole.NewMessage),
|
|
bindingAttr: BindingFlags.Public | BindingFlags.Static,
|
|
binder: null,
|
|
types: new Type[] { typeof(string), typeof(Color?), typeof(bool) },
|
|
modifiers: null);
|
|
il.LoadConstant(message);
|
|
il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod());
|
|
il.LoadConstant(false);
|
|
il.Call(newMessage);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits a call to <see cref="DebugConsole.NewMessage(string, Color?, bool)" />,
|
|
/// using the string on the stack.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
public static void NewMessage(this Emit il)
|
|
{
|
|
var newMessage = typeof(DebugConsole).GetMethod(
|
|
name: nameof(DebugConsole.NewMessage),
|
|
bindingAttr: BindingFlags.Public | BindingFlags.Static,
|
|
binder: null,
|
|
types: new Type[] { typeof(string), typeof(Color?), typeof(bool) },
|
|
modifiers: null);
|
|
il.Call(typeof(Color).GetProperty(nameof(Color.LightBlue), BindingFlags.Public | BindingFlags.Static).GetGetMethod());
|
|
il.LoadConstant(false);
|
|
il.Call(newMessage);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits a <c>foreach</c> loop that iterates over an <see cref="IEnumerable{T}"/> local variable.
|
|
/// </summary>
|
|
/// <typeparam name="T">The type of elements in the enumerable.</typeparam>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="enumerable">The enumerable.</param>
|
|
/// <param name="action">The body of code to run on each iteration.</param>
|
|
public static void ForEachEnumerable<T>(this Emit il, Local enumerable, Action<Emit, Local, Sigil.Label> action)
|
|
{
|
|
if (enumerable == null) throw new ArgumentNullException(nameof(enumerable));
|
|
if (action == null) throw new ArgumentNullException(nameof(action));
|
|
if (!typeof(IEnumerable<T>).IsAssignableFrom(enumerable.LocalType))
|
|
{
|
|
throw new ArgumentException($"Expected local type {typeof(IEnumerator<T>)}; got {enumerable.LocalType}.", nameof(enumerable));
|
|
}
|
|
|
|
var guid = Guid.NewGuid().ToString("N");
|
|
|
|
var enumerator = il.DeclareLocal<IEnumerator<T>>($"forEachEnumerable_enumerator_{guid}");
|
|
il.LoadLocal(enumerable);
|
|
il.CallVirtual(typeof(IEnumerable<T>).GetMethod("GetEnumerator"));
|
|
il.StoreLocal(enumerator);
|
|
ForEachEnumerator<T>(il, enumerator, action);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits a <c>foreach</c> loop that iterates over an <see cref="IEnumerator{T}"/> local variable.
|
|
/// </summary>
|
|
/// <typeparam name="T">The type of elements in the enumerable.</typeparam>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="enumerator">The enumerator.</param>
|
|
/// <param name="action">The body of code to run on each iteration.</param>
|
|
public static void ForEachEnumerator<T>(this Emit il, Local enumerator, Action<Emit, Local, Sigil.Label> action)
|
|
{
|
|
if (enumerator == null) throw new ArgumentNullException(nameof(enumerator));
|
|
if (action == null) throw new ArgumentNullException(nameof(action));
|
|
if (!typeof(IEnumerator<T>).IsAssignableFrom(enumerator.LocalType))
|
|
{
|
|
throw new ArgumentException($"Expected local type {typeof(IEnumerator<T>)}; got {enumerator.LocalType}.", nameof(enumerator));
|
|
}
|
|
|
|
var guid = Guid.NewGuid().ToString("N");
|
|
var labelLoopStart = il.DefineLabel($"forEach_loopStart_{guid}");
|
|
var labelMoveNext = il.DefineLabel($"forEach_moveNext_{guid}");
|
|
var labelLeave = il.DefineLabel($"forEach_leave_{guid}");
|
|
|
|
il.BeginExceptionBlock(out var exceptionBlock);
|
|
il.Branch(labelMoveNext); // MoveNext() needs to be called at least once before iterating
|
|
il.MarkLabel(labelLoopStart);
|
|
|
|
// IL: var current = enumerator.Current;
|
|
var current = il.DeclareLocal<T>($"forEachEnumerator_current_{guid}");
|
|
il.LoadLocal(enumerator);
|
|
il.CallVirtual(enumerator.LocalType.GetProperty("Current").GetGetMethod());
|
|
il.StoreLocal(current);
|
|
|
|
action(il, current, labelLeave);
|
|
|
|
il.MarkLabel(labelMoveNext);
|
|
il.LoadLocal(enumerator);
|
|
il.CallVirtual(typeof(IEnumerator).GetMethod("MoveNext"));
|
|
il.BranchIfTrue(labelLoopStart); // loop if MoveNext() returns true
|
|
|
|
// IL: finally { enumerator.Dispose(); }
|
|
il.BeginFinallyBlock(exceptionBlock, out var finallyBlock);
|
|
il.LoadLocal(enumerator);
|
|
il.CallVirtual(typeof(IDisposable).GetMethod("Dispose"));
|
|
il.EndFinallyBlock(finallyBlock);
|
|
|
|
il.EndExceptionBlock(exceptionBlock);
|
|
|
|
il.MarkLabel(labelLeave);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits a branch that only executes if the last value on the stack
|
|
/// is truthy (e.g. non-null references, 1, etc).
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="action">The body of code to run if the value is truthy.</param>
|
|
public static void If(this Emit il, Action<Emit> action)
|
|
{
|
|
if (action == null) throw new ArgumentNullException(nameof(action));
|
|
il.Branch(@if: action);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits a branch that only executes if the last value on the stack
|
|
/// is falsy (e.g. null references, 0, etc).
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="action">The body of code to run if the value is falsy.</param>
|
|
public static void IfNot(this Emit il, Action<Emit> action)
|
|
{
|
|
if (action == null) throw new ArgumentNullException(nameof(action));
|
|
il.Branch(@else: action);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Emits two branches that diverge based on a condition -- analogous
|
|
/// to an if-else statement. If either <paramref name="if"/>
|
|
/// or <paramref name="else"/> are omitted, it behaves the same as
|
|
/// <see cref="If(Emit, Action{Emit})"/>
|
|
/// and <see cref="IfNot(Emit, Action{Emit})"/>.
|
|
/// </summary>
|
|
/// <param name="il">The IL emitter.</param>
|
|
/// <param name="if">The body of code to run if the value is truthy.</param>
|
|
/// <param name="else">The body of code to run if the value is falsy.</param>
|
|
public static void Branch(this Emit il, Action<Emit> @if = null, Action<Emit> @else = null)
|
|
{
|
|
if (@if == null && @else == null) throw new ArgumentException("At least one of the two branches must be defined.");
|
|
|
|
var guid = Guid.NewGuid().ToString("N");
|
|
var labelEnd = il.DefineLabel($"branch_end_{guid}");
|
|
if (@if != null && @else != null)
|
|
{
|
|
var labelElse = il.DefineLabel($"branch_else_{guid}");
|
|
il.BranchIfFalse(labelElse);
|
|
@if(il);
|
|
il.Branch(labelEnd);
|
|
il.MarkLabel(labelElse);
|
|
@else(il);
|
|
}
|
|
else if (@if != null)
|
|
{
|
|
il.BranchIfFalse(labelEnd);
|
|
@if(il);
|
|
}
|
|
else
|
|
{
|
|
il.BranchIfTrue(labelEnd);
|
|
@else(il);
|
|
}
|
|
il.MarkLabel(labelEnd);
|
|
}
|
|
}
|
|
|
|
public partial class LuaCsHook
|
|
{
|
|
public enum HookMethodType
|
|
{
|
|
Before, After
|
|
}
|
|
|
|
private class LuaCsHookCallback
|
|
{
|
|
public string name;
|
|
public string hookName;
|
|
public LuaCsFunc func;
|
|
|
|
public LuaCsHookCallback(string name, string hookName, LuaCsFunc func)
|
|
{
|
|
this.name = name;
|
|
this.hookName = hookName;
|
|
this.func = func;
|
|
}
|
|
}
|
|
|
|
private class LuaCsPatch
|
|
{
|
|
public string Identifier { get; set; }
|
|
|
|
public LuaCsPatchFunc PatchFunc { get; set; }
|
|
}
|
|
|
|
private class PatchedMethod
|
|
{
|
|
public PatchedMethod(MethodInfo harmonyPrefix, MethodInfo harmonyPostfix)
|
|
{
|
|
HarmonyPrefixMethod = harmonyPrefix;
|
|
HarmonyPostfixMethod = harmonyPostfix;
|
|
Prefixes = new Dictionary<string, LuaCsPatch>();
|
|
Postfixes = new Dictionary<string, LuaCsPatch>();
|
|
}
|
|
|
|
public MethodInfo HarmonyPrefixMethod { get; }
|
|
|
|
public MethodInfo HarmonyPostfixMethod { get; }
|
|
|
|
public IEnumerator<LuaCsPatch> GetPrefixEnumerator() => Prefixes.Values.GetEnumerator();
|
|
|
|
public IEnumerator<LuaCsPatch> GetPostfixEnumerator() => Postfixes.Values.GetEnumerator();
|
|
|
|
public Dictionary<string, LuaCsPatch> Prefixes { get; }
|
|
|
|
public Dictionary<string, LuaCsPatch> Postfixes { get; }
|
|
}
|
|
|
|
public class ParameterTable
|
|
{
|
|
private readonly Dictionary<string, object> parameters;
|
|
private bool returnValueModified;
|
|
private object returnValue;
|
|
|
|
public ParameterTable(Dictionary<string, object> dict)
|
|
{
|
|
parameters = dict;
|
|
}
|
|
|
|
public object this[string paramName]
|
|
{
|
|
get
|
|
{
|
|
if (ModifiedParameters.TryGetValue(paramName, out var value))
|
|
{
|
|
return value;
|
|
}
|
|
return OriginalParameters[paramName];
|
|
}
|
|
set
|
|
{
|
|
ModifiedParameters[paramName] = value;
|
|
}
|
|
}
|
|
|
|
public object OriginalReturnValue { get; private set; }
|
|
|
|
public object ReturnValue
|
|
{
|
|
get
|
|
{
|
|
if (returnValueModified) return returnValue;
|
|
return OriginalReturnValue;
|
|
}
|
|
set
|
|
{
|
|
returnValueModified = true;
|
|
returnValue = value;
|
|
}
|
|
}
|
|
|
|
public bool PreventExecution { get; set; }
|
|
|
|
public Dictionary<string, object> OriginalParameters => parameters;
|
|
|
|
[MoonSharpHidden]
|
|
public Dictionary<string, object> ModifiedParameters { get; } = new Dictionary<string, object>();
|
|
}
|
|
|
|
private static readonly string[] prohibitedHooks =
|
|
{
|
|
"Barotrauma.Lua",
|
|
"Barotrauma.Cs",
|
|
"Barotrauma.ContentPackageManager",
|
|
};
|
|
|
|
private static void ValidatePatchTarget(MethodBase method)
|
|
{
|
|
if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h)))
|
|
{
|
|
throw new ArgumentException("Hooks into the modding environment are prohibited.");
|
|
}
|
|
}
|
|
|
|
private static string NormalizeIdentifier(string identifier)
|
|
{
|
|
return identifier?.Trim().ToLowerInvariant();
|
|
}
|
|
|
|
private Harmony harmony;
|
|
|
|
private Lazy<ModuleBuilder> patchModuleBuilder;
|
|
|
|
private readonly Dictionary<string, Dictionary<string, (LuaCsHookCallback, ACsMod)>> hookFunctions = new Dictionary<string, Dictionary<string, (LuaCsHookCallback, ACsMod)>>();
|
|
|
|
private readonly Dictionary<MethodKey, PatchedMethod> registeredPatches = new Dictionary<MethodKey, PatchedMethod>();
|
|
|
|
private LuaCsSetup luaCs;
|
|
|
|
private static LuaCsHook instance;
|
|
|
|
private struct MethodKey : IEquatable<MethodKey>
|
|
{
|
|
public ModuleHandle ModuleHandle { get; set; }
|
|
|
|
public int MetadataToken { get; set; }
|
|
|
|
public override bool Equals(object obj)
|
|
{
|
|
return obj is MethodKey key && Equals(key);
|
|
}
|
|
|
|
public bool Equals(MethodKey other)
|
|
{
|
|
return ModuleHandle.Equals(other.ModuleHandle) && MetadataToken == other.MetadataToken;
|
|
}
|
|
|
|
public override int GetHashCode()
|
|
{
|
|
return HashCode.Combine(ModuleHandle, MetadataToken);
|
|
}
|
|
|
|
public static bool operator ==(MethodKey left, MethodKey right)
|
|
{
|
|
return left.Equals(right);
|
|
}
|
|
|
|
public static bool operator !=(MethodKey left, MethodKey right)
|
|
{
|
|
return !(left == right);
|
|
}
|
|
|
|
public static MethodKey Create(MethodBase method) => new MethodKey
|
|
{
|
|
ModuleHandle = method.Module.ModuleHandle,
|
|
MetadataToken = method.MetadataToken,
|
|
};
|
|
}
|
|
|
|
internal LuaCsHook(LuaCsSetup luaCs)
|
|
{
|
|
instance = this;
|
|
this.luaCs = luaCs;
|
|
}
|
|
|
|
public void Initialize()
|
|
{
|
|
harmony = new Harmony("LuaCsForBarotrauma");
|
|
patchModuleBuilder = new Lazy<ModuleBuilder>(CreateModuleBuilder);
|
|
|
|
UserData.RegisterType<ParameterTable>();
|
|
var hookType = UserData.RegisterType<LuaCsHook>();
|
|
var hookDesc = (StandardUserDataDescriptor)hookType;
|
|
typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
|
|
if (
|
|
m.Name.Contains("HookMethod") ||
|
|
m.Name.Contains("UnhookMethod") ||
|
|
m.Name.Contains("EnqueueFunction") ||
|
|
m.Name.Contains("EnqueueTimedFunction")
|
|
)
|
|
{
|
|
hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
|
|
}
|
|
});
|
|
}
|
|
|
|
private ModuleBuilder CreateModuleBuilder()
|
|
{
|
|
var assemblyName = $"LuaCsHookPatch-{Guid.NewGuid():N}";
|
|
var assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly(new AssemblyName(assemblyName), AssemblyBuilderAccess.RunAndCollect);
|
|
var moduleBuilder = assemblyBuilder.DefineDynamicModule("LuaCsHookPatch");
|
|
|
|
// This code emits the Roslyn attribute
|
|
// "IgnoresAccessChecksToAttribute" so we can freely access
|
|
// the Barotrauma assembly from our dynamic patches.
|
|
// This is important because the generated IL references
|
|
// non-public types/members.
|
|
|
|
// class IgnoresAccessChecksToAttribute {
|
|
var typeBuilder = moduleBuilder.DefineType(
|
|
name: "System.Runtime.CompilerServices.IgnoresAccessChecksToAttribute",
|
|
attr: TypeAttributes.NotPublic | TypeAttributes.Sealed | TypeAttributes.Class,
|
|
parent: typeof(Attribute));
|
|
|
|
// [AttributeUsage(AllowMultiple = true)]
|
|
var attributeUsageAttribute = new CustomAttributeBuilder(
|
|
con: typeof(AttributeUsageAttribute).GetConstructor(new[] { typeof(AttributeTargets) }),
|
|
constructorArgs: new object[] { AttributeTargets.Assembly },
|
|
namedProperties: new[] { typeof(AttributeUsageAttribute).GetProperty("AllowMultiple") },
|
|
propertyValues: new object[] { true });
|
|
typeBuilder.SetCustomAttribute(attributeUsageAttribute);
|
|
|
|
// private readonly string assemblyName;
|
|
var attributeTypeFieldBuilder = typeBuilder.DefineField(
|
|
fieldName: "assemblyName",
|
|
type: typeof(string),
|
|
attributes: FieldAttributes.Private | FieldAttributes.InitOnly);
|
|
|
|
var ctor = Emit.BuildConstructor(
|
|
parameterTypes: new[] { typeof(string) },
|
|
type: typeBuilder,
|
|
attributes: MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName,
|
|
callingConvention: CallingConventions.Standard | CallingConventions.HasThis);
|
|
// IL: this.assemblyName = arg;
|
|
ctor.LoadArgument(0);
|
|
ctor.LoadArgument(1);
|
|
ctor.StoreField(attributeTypeFieldBuilder);
|
|
ctor.Return();
|
|
ctor.CreateConstructor();
|
|
|
|
// public string AttributeName => this.assemblyName;
|
|
var attributeNameGetter = Emit.BuildMethod(
|
|
returnType: typeof(string),
|
|
parameterTypes: new Type[0],
|
|
type: typeBuilder,
|
|
name: "get_AttributeName",
|
|
attributes: MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName,
|
|
callingConvention: CallingConventions.Standard | CallingConventions.HasThis);
|
|
attributeNameGetter.LoadArgument(0);
|
|
attributeNameGetter.LoadField(attributeTypeFieldBuilder);
|
|
attributeNameGetter.Return();
|
|
|
|
var attributeName = typeBuilder.DefineProperty(
|
|
name: "AttributeName",
|
|
attributes: PropertyAttributes.None,
|
|
returnType: typeof(string),
|
|
parameterTypes: null);
|
|
attributeName.SetGetMethod(attributeNameGetter.CreateMethod());
|
|
// }
|
|
|
|
var type = typeBuilder.CreateTypeInfo().AsType();
|
|
|
|
// The assembly names are hardcoded, otherwise it would
|
|
// break unit tests.
|
|
var assembliesToExpose = new[] { "Barotrauma", "DedicatedServer" };
|
|
foreach (var name in assembliesToExpose)
|
|
{
|
|
var attr = new CustomAttributeBuilder(
|
|
con: type.GetConstructor(new[] { typeof(string)}),
|
|
constructorArgs: new[] { name });
|
|
assemblyBuilder.SetCustomAttribute(attr);
|
|
}
|
|
|
|
return moduleBuilder;
|
|
}
|
|
|
|
public void Add(string name, LuaCsFunc func, ACsMod owner = null) => Add(name, name, func, owner);
|
|
|
|
public void Add(string name, string identifier, LuaCsFunc func, ACsMod owner = null)
|
|
{
|
|
if (name == null) throw new ArgumentNullException(nameof(name));
|
|
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
|
|
if (func == null) throw new ArgumentNullException(nameof(func));
|
|
|
|
name = NormalizeIdentifier(name);
|
|
identifier = NormalizeIdentifier(identifier);
|
|
|
|
if (!hookFunctions.ContainsKey(name))
|
|
{
|
|
hookFunctions.Add(name, new Dictionary<string, (LuaCsHookCallback, ACsMod)>());
|
|
}
|
|
|
|
hookFunctions[name][identifier] = (new LuaCsHookCallback(name, identifier, func), owner);
|
|
}
|
|
|
|
public bool Exists(string name, string identifier)
|
|
{
|
|
if (name == null) throw new ArgumentNullException(nameof(name));
|
|
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
|
|
|
|
name = NormalizeIdentifier(name);
|
|
identifier = NormalizeIdentifier(identifier);
|
|
|
|
if (!hookFunctions.ContainsKey(name))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
return hookFunctions[name].ContainsKey(identifier);
|
|
}
|
|
|
|
public void Remove(string name, string identifier)
|
|
{
|
|
if (name == null) throw new ArgumentNullException(nameof(name));
|
|
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
|
|
|
|
name = NormalizeIdentifier(name);
|
|
identifier = NormalizeIdentifier(identifier);
|
|
|
|
if (hookFunctions.ContainsKey(name) && hookFunctions[name].ContainsKey(identifier))
|
|
{
|
|
hookFunctions[name].Remove(identifier);
|
|
}
|
|
}
|
|
|
|
public void Clear()
|
|
{
|
|
harmony?.UnpatchSelf();
|
|
|
|
foreach (var (_, patch) in registeredPatches)
|
|
{
|
|
// Remove references stored in our dynamic types so the generated
|
|
// assembly can be garbage-collected.
|
|
patch.HarmonyPrefixMethod.DeclaringType
|
|
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
|
|
.SetValue(null, null);
|
|
patch.HarmonyPostfixMethod.DeclaringType
|
|
.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static)
|
|
.SetValue(null, null);
|
|
}
|
|
|
|
hookFunctions.Clear();
|
|
registeredPatches.Clear();
|
|
patchModuleBuilder = null;
|
|
|
|
compatHookPrefixMethods.Clear();
|
|
compatHookPostfixMethods.Clear();
|
|
}
|
|
|
|
private Stopwatch performanceMeasurement = new Stopwatch();
|
|
|
|
[MoonSharpHidden]
|
|
public T Call<T>(string name, params object[] args)
|
|
{
|
|
if (name == null) throw new ArgumentNullException(name);
|
|
if (args == null) args = new object[0];
|
|
|
|
name = NormalizeIdentifier(name);
|
|
if (!hookFunctions.ContainsKey(name)) return default;
|
|
|
|
T lastResult = default;
|
|
|
|
var hooks = hookFunctions[name].ToArray();
|
|
foreach ((string key, var tuple) in hooks)
|
|
{
|
|
if (tuple.Item2 != null && tuple.Item2.IsDisposed)
|
|
{
|
|
hookFunctions[name].Remove(key);
|
|
continue;
|
|
}
|
|
|
|
try
|
|
{
|
|
if (luaCs.PerformanceCounter.EnablePerformanceCounter)
|
|
{
|
|
performanceMeasurement.Start();
|
|
}
|
|
|
|
var result = tuple.Item1.func(args);
|
|
|
|
if (result is DynValue luaResult)
|
|
{
|
|
if (luaResult.Type == DataType.Tuple)
|
|
{
|
|
bool replaceNil = luaResult.Tuple.Length > 1 && luaResult.Tuple[1].CastToBool();
|
|
|
|
if (!luaResult.Tuple[0].IsNil() || replaceNil)
|
|
{
|
|
lastResult = luaResult.ToObject<T>();
|
|
}
|
|
}
|
|
else if (!luaResult.IsNil())
|
|
{
|
|
lastResult = luaResult.ToObject<T>();
|
|
}
|
|
}
|
|
else
|
|
{
|
|
lastResult = (T)result;
|
|
}
|
|
|
|
if (luaCs.PerformanceCounter.EnablePerformanceCounter)
|
|
{
|
|
performanceMeasurement.Stop();
|
|
luaCs.PerformanceCounter.SetHookElapsedTicks(name, key, performanceMeasurement.ElapsedTicks);
|
|
performanceMeasurement.Reset();
|
|
}
|
|
}
|
|
catch (Exception e)
|
|
{
|
|
var argsSb = new StringBuilder();
|
|
foreach (var arg in args)
|
|
{
|
|
argsSb.Append(arg + " ");
|
|
}
|
|
LuaCsLogger.LogError($"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", LuaCsMessageOrigin.Unknown);
|
|
LuaCsLogger.HandleException(e, LuaCsMessageOrigin.Unknown);
|
|
}
|
|
}
|
|
|
|
return lastResult;
|
|
}
|
|
|
|
public object Call(string name, params object[] args) => Call<object>(name, args);
|
|
|
|
private static MethodBase ResolveMethod(string className, string methodName, string[] parameters)
|
|
{
|
|
var classType = LuaUserData.GetType(className);
|
|
if (classType == null) throw new ScriptRuntimeException($"invalid class name '{className}'");
|
|
|
|
const BindingFlags BINDING_FLAGS = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
|
|
|
|
MethodBase method = null;
|
|
|
|
try
|
|
{
|
|
if (parameters != null)
|
|
{
|
|
Type[] parameterTypes = new Type[parameters.Length];
|
|
|
|
for (int i = 0; i < parameters.Length; i++)
|
|
{
|
|
Type type = LuaUserData.GetType(parameters[i]);
|
|
if (type == null)
|
|
{
|
|
throw new ScriptRuntimeException($"invalid parameter type '{parameters[i]}'");
|
|
}
|
|
parameterTypes[i] = type;
|
|
}
|
|
|
|
method = methodName switch
|
|
{
|
|
".cctor" => classType.TypeInitializer,
|
|
".ctor" => classType.GetConstructors(BINDING_FLAGS)
|
|
.Except(new[] { classType.TypeInitializer })
|
|
.Where(x => x.GetParameters().Select(x => x.ParameterType).SequenceEqual(parameterTypes))
|
|
.SingleOrDefault(),
|
|
_ => classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null),
|
|
};
|
|
}
|
|
else
|
|
{
|
|
ConstructorInfo GetCtor()
|
|
{
|
|
var ctors = classType.GetConstructors(BINDING_FLAGS)
|
|
.Except(new[] { classType.TypeInitializer })
|
|
.GetEnumerator();
|
|
|
|
if (!ctors.MoveNext()) return null;
|
|
var ctor = ctors.Current;
|
|
|
|
if (ctors.MoveNext()) throw new AmbiguousMatchException();
|
|
return ctor;
|
|
}
|
|
|
|
method = methodName switch
|
|
{
|
|
".cctor" => throw new ScriptRuntimeException("type initializers can't have parameters"),
|
|
".ctor" => GetCtor(),
|
|
_ => classType.GetMethod(methodName, BINDING_FLAGS),
|
|
};
|
|
}
|
|
}
|
|
catch (AmbiguousMatchException)
|
|
{
|
|
throw new ScriptRuntimeException("ambiguous method signature");
|
|
}
|
|
|
|
if (method == null)
|
|
{
|
|
var parameterNamesStr = parameters == null ? "" : string.Join(", ", parameters);
|
|
throw new ScriptRuntimeException($"method '{methodName}({parameterNamesStr})' not found in class '{className}'");
|
|
}
|
|
|
|
return method;
|
|
}
|
|
|
|
private class DynamicParameterMapping
|
|
{
|
|
public DynamicParameterMapping(string name, Type originalMethodParamType, Type harmonyPatchParamType)
|
|
{
|
|
ParameterName = name;
|
|
OriginalMethodParamType = originalMethodParamType;
|
|
HarmonyPatchParamType = harmonyPatchParamType;
|
|
}
|
|
|
|
public string ParameterName { get; set; }
|
|
|
|
public Type OriginalMethodParamType { get; set; }
|
|
|
|
public Type HarmonyPatchParamType { get; set; }
|
|
}
|
|
|
|
private static readonly Regex InvalidIdentifierCharsRegex = new Regex(@"[^\w\d]", RegexOptions.Compiled);
|
|
|
|
private const string FIELD_LUACS = "LuaCs";
|
|
|
|
// If you need to debug this:
|
|
// - use https://sharplab.io ; it's a very useful for resource for writing IL by hand.
|
|
// - use il.NewMessage("") or il.WriteLine("") to see where the IL crashes at runtime.
|
|
private MethodInfo CreateDynamicHarmonyPatch(string identifier, MethodBase original, HookMethodType hookType)
|
|
{
|
|
var parameters = new List<DynamicParameterMapping>
|
|
{
|
|
new DynamicParameterMapping("__originalMethod", null, typeof(MethodBase)),
|
|
new DynamicParameterMapping("__instance", null, typeof(object)),
|
|
};
|
|
|
|
var hasReturnType = original is MethodInfo mi && mi.ReturnType != typeof(void);
|
|
if (hasReturnType)
|
|
{
|
|
parameters.Add(new DynamicParameterMapping("__result", null, typeof(object).MakeByRefType()));
|
|
}
|
|
|
|
foreach (var parameter in original.GetParameters())
|
|
{
|
|
var paramName = parameter.Name;
|
|
var originalMethodParamType = parameter.ParameterType;
|
|
var harmonyPatchParamType = originalMethodParamType.IsByRef
|
|
? originalMethodParamType
|
|
// Make all parameters modifiable by the harmony patch
|
|
: originalMethodParamType.MakeByRefType();
|
|
parameters.Add(new DynamicParameterMapping(paramName, originalMethodParamType, harmonyPatchParamType));
|
|
}
|
|
|
|
static string MangleName(object o) => InvalidIdentifierCharsRegex.Replace(o?.ToString(), "_");
|
|
|
|
var moduleBuilder = patchModuleBuilder.Value;
|
|
var mangledName = original.DeclaringType != null
|
|
? $"{MangleName(original.DeclaringType)}-{MangleName(original)}"
|
|
: MangleName(original);
|
|
var typeBuilder = moduleBuilder.DefineType($"Patch_{identifier}_{Guid.NewGuid():N}_{mangledName}", TypeAttributes.Public);
|
|
|
|
var luaCsField = typeBuilder.DefineField(FIELD_LUACS, typeof(LuaCsSetup), FieldAttributes.Public | FieldAttributes.Static);
|
|
|
|
var methodName = hookType == HookMethodType.Before ? "HarmonyPrefix" : "HarmonyPostfix";
|
|
var il = Emit.BuildMethod(
|
|
returnType: hookType == HookMethodType.Before ? typeof(bool) : typeof(void),
|
|
parameterTypes: parameters.Select(x => x.HarmonyPatchParamType).ToArray(),
|
|
type: typeBuilder,
|
|
name: methodName,
|
|
attributes: MethodAttributes.Public | MethodAttributes.Static,
|
|
callingConvention: CallingConventions.Standard);
|
|
|
|
var labelReturn = il.DefineLabel("endOfFunction");
|
|
|
|
il.BeginExceptionBlock(out var exceptionBlock);
|
|
|
|
// IL: var harmonyReturnValue = true;
|
|
var harmonyReturnValue = il.DeclareLocal<bool>("harmonyReturnValue");
|
|
il.LoadConstant(true);
|
|
il.StoreLocal(harmonyReturnValue);
|
|
|
|
// IL: var patchKey = MethodKey.Create(__originalMethod);
|
|
var patchKey = il.DeclareLocal<MethodKey>("patchKey");
|
|
il.LoadArgument(0); // load __originalMethod
|
|
il.CastClass<MethodBase>();
|
|
il.Call(typeof(MethodKey).GetMethod(nameof(MethodKey.Create)));
|
|
il.StoreLocal(patchKey);
|
|
|
|
// IL: var patchExists = instance.registeredPatches.TryGetValue(patchKey, out MethodPatches patches)
|
|
var patchExists = il.DeclareLocal<bool>("patchExists");
|
|
var patches = il.DeclareLocal<PatchedMethod>("patches");
|
|
il.LoadField(typeof(LuaCsHook).GetField(nameof(instance), BindingFlags.NonPublic | BindingFlags.Static));
|
|
il.LoadField(typeof(LuaCsHook).GetField(nameof(registeredPatches), BindingFlags.NonPublic | BindingFlags.Instance));
|
|
il.LoadLocal(patchKey);
|
|
il.LoadLocalAddress(patches); // out parameter
|
|
il.Call(typeof(Dictionary<MethodKey, PatchedMethod>).GetMethod("TryGetValue"));
|
|
il.StoreLocal(patchExists);
|
|
|
|
// IL: if (!patchExists)
|
|
il.LoadLocal(patchExists);
|
|
il.IfNot((il) =>
|
|
{
|
|
// XXX: if we get here, it's probably because a patched
|
|
// method was running when `reloadlua` was executed.
|
|
// This can happen with a postfix on
|
|
// `Barotrauma.Networking.GameServer#Update`.
|
|
il.Leave(labelReturn);
|
|
});
|
|
|
|
// IL: var parameterDict = new Dictionary<string, object>(<paramCount>);
|
|
var parameterDict = il.DeclareLocal<Dictionary<string, object>>("parameterDict");
|
|
il.LoadConstant(parameters.Count(x => x.OriginalMethodParamType != null)); // preallocate the dictionary using the # of args
|
|
il.NewObject(typeof(Dictionary<string, object>), typeof(int));
|
|
il.StoreLocal(parameterDict);
|
|
|
|
for (ushort i = 0; i < parameters.Count; i++)
|
|
{
|
|
// Skip parameters that don't exist in the original method
|
|
if (parameters[i].OriginalMethodParamType == null) continue;
|
|
|
|
// IL: parameterDict.Add(<paramName>, <paramValue>);
|
|
il.LoadLocal(parameterDict);
|
|
il.LoadConstant(parameters[i].ParameterName);
|
|
il.LoadArgument(i);
|
|
il.ToObject(parameters[i].HarmonyPatchParamType);
|
|
il.Call(typeof(Dictionary<string, object>).GetMethod("Add"));
|
|
}
|
|
|
|
// IL: var ptable = new ParameterTable(parameterDict);
|
|
var ptable = il.DeclareLocal<ParameterTable>("ptable");
|
|
il.LoadLocal(parameterDict);
|
|
il.NewObject(typeof(ParameterTable), typeof(Dictionary<string, object>));
|
|
il.StoreLocal(ptable);
|
|
|
|
if (hasReturnType && hookType == HookMethodType.After)
|
|
{
|
|
// IL: ptable.OriginalReturnValue = __result;
|
|
il.LoadLocal(ptable);
|
|
il.LoadArgument(2); // ref __result
|
|
il.ToObject(parameters[2].HarmonyPatchParamType);
|
|
il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.OriginalReturnValue)).GetSetMethod(nonPublic: true));
|
|
}
|
|
|
|
// IL: var enumerator = patches.GetPrefixEnumerator();
|
|
var enumerator = il.DeclareLocal<IEnumerator<LuaCsPatch>>("enumerator");
|
|
il.LoadLocal(patches);
|
|
il.CallVirtual(typeof(PatchedMethod).GetMethod(
|
|
name: hookType == HookMethodType.Before
|
|
? nameof(PatchedMethod.GetPrefixEnumerator)
|
|
: nameof(PatchedMethod.GetPostfixEnumerator),
|
|
bindingAttr: BindingFlags.Public | BindingFlags.Instance));
|
|
il.StoreLocal(enumerator);
|
|
|
|
var labelUpdateParameters = il.DefineLabel("updateParameters");
|
|
|
|
// Iterate over prefixes/postfixes
|
|
il.ForEachEnumerator<LuaCsPatch>(enumerator, (il, current, labelLeave) =>
|
|
{
|
|
// IL: var luaReturnValue = current.PatchFunc.Invoke(__instance, ptable);
|
|
var luaReturnValue = il.DeclareLocal<DynValue>("luaReturnValue");
|
|
il.LoadLocal(current);
|
|
il.Call(typeof(LuaCsPatch).GetProperty(nameof(LuaCsPatch.PatchFunc)).GetGetMethod());
|
|
il.LoadArgument(1); // __instance
|
|
il.LoadLocal(ptable);
|
|
il.CallVirtual(typeof(LuaCsPatchFunc).GetMethod("Invoke"));
|
|
il.StoreLocal(luaReturnValue);
|
|
|
|
if (hasReturnType)
|
|
{
|
|
// IL: var ptableReturnValue = ptable.ReturnValue;
|
|
var ptableReturnValue = il.DeclareLocal<object>("ptableReturnValue");
|
|
il.LoadLocal(ptable);
|
|
il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.ReturnValue)).GetGetMethod());
|
|
il.StoreLocal(ptableReturnValue);
|
|
|
|
// IL: if (ptableReturnValue != null)
|
|
il.LoadLocal(ptableReturnValue);
|
|
il.If((il) =>
|
|
{
|
|
// IL: __result = ptableReturnValue;
|
|
il.LoadArgument(2); // ref __result
|
|
il.LoadLocal(ptableReturnValue);
|
|
il.StoreIndirect(typeof(object));
|
|
il.Break();
|
|
});
|
|
|
|
// IL: if (luaReturnValue != null)
|
|
il.LoadLocal(luaReturnValue);
|
|
il.If((il) =>
|
|
{
|
|
// IL: if (!luaReturnValue.IsVoid())
|
|
il.LoadLocal(luaReturnValue);
|
|
il.Call(typeof(DynValue).GetMethod(nameof(DynValue.IsVoid)));
|
|
il.IfNot((il) =>
|
|
{
|
|
// IL: var csReturnType = Type.GetTypeFromHandle(<original.ReturnType>);
|
|
var csReturnType = il.DeclareLocal<Type>("csReturnType");
|
|
il.LoadType(((MethodInfo)original).ReturnType);
|
|
il.StoreLocal(csReturnType);
|
|
|
|
// IL: var csReturnValue = luaReturnValue.ToObject(csReturnType);
|
|
var csReturnValue = il.DeclareLocal<object>("csReturnValue");
|
|
il.LoadLocal(luaReturnValue);
|
|
il.LoadLocal(csReturnType);
|
|
il.Call(typeof(DynValue).GetMethod(
|
|
name: nameof(DynValue.ToObject),
|
|
bindingAttr: BindingFlags.Public | BindingFlags.Instance,
|
|
binder: null,
|
|
types: new Type[] { typeof(Type) },
|
|
modifiers: null));
|
|
il.StoreLocal(csReturnValue);
|
|
|
|
// IL: __result = csReturnValue;
|
|
il.LoadArgument(2); // ref __result
|
|
il.LoadLocal(csReturnValue);
|
|
il.StoreIndirect(typeof(object));
|
|
});
|
|
});
|
|
}
|
|
|
|
// IL: if (ptable.PreventExecution)
|
|
il.LoadLocal(ptable);
|
|
il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.PreventExecution)).GetGetMethod());
|
|
il.If((il) =>
|
|
{
|
|
// IL: harmonyReturnValue = false;
|
|
il.LoadConstant(false);
|
|
il.StoreLocal(harmonyReturnValue);
|
|
|
|
// IL: break;
|
|
il.Leave(labelLeave);
|
|
});
|
|
});
|
|
|
|
// IL: var modifiedParameters = ptable.ModifiedParameters;
|
|
var modifiedParameters = il.DeclareLocal<Dictionary<string, object>>("modifiedParameters");
|
|
il.LoadLocal(ptable);
|
|
il.Call(typeof(ParameterTable).GetProperty(nameof(ParameterTable.ModifiedParameters)).GetGetMethod());
|
|
il.StoreLocal(modifiedParameters);
|
|
// IL: object modifiedValue;
|
|
var modifiedValue = il.DeclareLocal<object>("modifiedValue");
|
|
|
|
// Update the parameters
|
|
for (ushort i = 0; i < parameters.Count; i++)
|
|
{
|
|
// Skip parameters that don't exist in the original method
|
|
if (parameters[i].OriginalMethodParamType == null) continue;
|
|
|
|
// IL: if (modifiedParameters.TryGetValue("parameterName", out modifiedValue))
|
|
il.LoadLocal(modifiedParameters);
|
|
il.LoadConstant(parameters[i].ParameterName);
|
|
il.LoadLocalAddress(modifiedValue); // out parameter
|
|
il.Call(typeof(Dictionary<string, object>).GetMethod(nameof(Dictionary<string, object>.TryGetValue)));
|
|
il.If((il) =>
|
|
{
|
|
// XXX: GetElementType() gets the "real" type behind
|
|
// the ByRef. This is safe because all the parameters
|
|
// are made into ByRef to support modification.
|
|
var paramType = parameters[i].HarmonyPatchParamType.GetElementType();
|
|
|
|
// IL: ref argName = modifiedValue;
|
|
il.LoadArgument(i);
|
|
il.LoadLocalAndCast(modifiedValue, paramType);
|
|
if (paramType.IsValueType)
|
|
{
|
|
il.StoreObject(paramType);
|
|
}
|
|
else
|
|
{
|
|
il.StoreIndirect(paramType);
|
|
}
|
|
});
|
|
}
|
|
|
|
il.MarkLabel(labelReturn);
|
|
|
|
// IL: catch (Exception exception)
|
|
il.BeginCatchAllBlock(exceptionBlock, out var catchBlock);
|
|
var exception = il.DeclareLocal<Exception>("exception");
|
|
il.StoreLocal(exception);
|
|
|
|
// IL: if (LuaCs != null)
|
|
il.LoadField(luaCsField);
|
|
il.If((il) =>
|
|
{
|
|
// IL: LuaCs.HandleException(exception, LuaCsMessageOrigin.LuaMod);
|
|
il.LoadLocal(exception);
|
|
il.LoadConstant((int)LuaCsMessageOrigin.LuaMod); // underlying enum type is int
|
|
il.Call(typeof(LuaCsLogger).GetMethod(nameof(LuaCsLogger.HandleException), BindingFlags.Public | BindingFlags.Static));
|
|
});
|
|
|
|
il.EndCatchBlock(catchBlock);
|
|
|
|
il.EndExceptionBlock(exceptionBlock);
|
|
|
|
// Only prefixes return a bool
|
|
if (hookType == HookMethodType.Before)
|
|
{
|
|
il.LoadLocal(harmonyReturnValue);
|
|
}
|
|
il.Return();
|
|
|
|
var method = il.CreateMethod();
|
|
for (var i = 0; i < parameters.Count; i++)
|
|
{
|
|
method.DefineParameter(i + 1, ParameterAttributes.None, parameters[i].ParameterName);
|
|
}
|
|
|
|
var type = typeBuilder.CreateType();
|
|
type.GetField(FIELD_LUACS, BindingFlags.Public | BindingFlags.Static).SetValue(null, luaCs);
|
|
return type.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static);
|
|
}
|
|
|
|
private string Patch(string identifier, MethodBase method, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
|
{
|
|
if (method == null) throw new ArgumentNullException(nameof(method));
|
|
if (patch == null) throw new ArgumentNullException(nameof(patch));
|
|
ValidatePatchTarget(method);
|
|
|
|
identifier ??= Guid.NewGuid().ToString("N");
|
|
identifier = NormalizeIdentifier(identifier);
|
|
|
|
var patchKey = MethodKey.Create(method);
|
|
if (!registeredPatches.TryGetValue(patchKey, out var methodPatches))
|
|
{
|
|
var harmonyPrefix = CreateDynamicHarmonyPatch(identifier, method, HookMethodType.Before);
|
|
var harmonyPostfix = CreateDynamicHarmonyPatch(identifier, method, HookMethodType.After);
|
|
harmony.Patch(method, prefix: new HarmonyMethod(harmonyPrefix), postfix: new HarmonyMethod(harmonyPostfix));
|
|
methodPatches = registeredPatches[patchKey] = new PatchedMethod(harmonyPrefix, harmonyPostfix);
|
|
}
|
|
|
|
if (hookType == HookMethodType.Before)
|
|
{
|
|
if (methodPatches.Prefixes.Remove(identifier))
|
|
{
|
|
LuaCsLogger.LogMessage($"Replacing existing prefix: {identifier}");
|
|
}
|
|
|
|
methodPatches.Prefixes.Add(identifier, new LuaCsPatch
|
|
{
|
|
Identifier = identifier,
|
|
PatchFunc = patch,
|
|
});
|
|
}
|
|
else if (hookType == HookMethodType.After)
|
|
{
|
|
if (methodPatches.Postfixes.Remove(identifier))
|
|
{
|
|
LuaCsLogger.LogMessage($"Replacing existing postfix: {identifier}");
|
|
}
|
|
|
|
methodPatches.Postfixes.Add(identifier, new LuaCsPatch
|
|
{
|
|
Identifier = identifier,
|
|
PatchFunc = patch,
|
|
});
|
|
}
|
|
|
|
return identifier;
|
|
}
|
|
|
|
public string Patch(string identifier, string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
|
{
|
|
var method = ResolveMethod(className, methodName, parameterTypes);
|
|
return Patch(identifier, method, patch, hookType);
|
|
}
|
|
|
|
public string Patch(string identifier, string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
|
{
|
|
var method = ResolveMethod(className, methodName, null);
|
|
return Patch(identifier, method, patch, hookType);
|
|
}
|
|
|
|
public string Patch(string className, string methodName, string[] parameterTypes, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
|
{
|
|
var method = ResolveMethod(className, methodName, parameterTypes);
|
|
return Patch(null, method, patch, hookType);
|
|
}
|
|
|
|
public string Patch(string className, string methodName, LuaCsPatchFunc patch, HookMethodType hookType = HookMethodType.Before)
|
|
{
|
|
var method = ResolveMethod(className, methodName, null);
|
|
return Patch(null, method, patch, hookType);
|
|
}
|
|
|
|
private bool RemovePatch(string identifier, MethodBase method, HookMethodType hookType)
|
|
{
|
|
if (identifier == null) throw new ArgumentNullException(nameof(identifier));
|
|
identifier = NormalizeIdentifier(identifier);
|
|
|
|
var patchKey = MethodKey.Create(method);
|
|
if (!registeredPatches.TryGetValue(patchKey, out var methodPatches))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
return hookType switch
|
|
{
|
|
HookMethodType.Before => methodPatches.Prefixes.Remove(identifier),
|
|
HookMethodType.After => methodPatches.Postfixes.Remove(identifier),
|
|
_ => throw new ArgumentException($"Invalid {nameof(HookMethodType)} enum value.", nameof(hookType)),
|
|
};
|
|
}
|
|
|
|
public bool RemovePatch(string identifier, string className, string methodName, string[] parameterTypes, HookMethodType hookType)
|
|
{
|
|
var method = ResolveMethod(className, methodName, parameterTypes);
|
|
return RemovePatch(identifier, method, hookType);
|
|
}
|
|
|
|
public bool RemovePatch(string identifier, string className, string methodName, HookMethodType hookType)
|
|
{
|
|
var method = ResolveMethod(className, methodName, null);
|
|
return RemovePatch(identifier, method, hookType);
|
|
}
|
|
}
|
|
}
|