diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs index 344856c46..66c9f14cd 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs @@ -820,22 +820,42 @@ namespace Barotrauma if (classType == null) throw new InvalidOperationException($"Invalid class name '{className}'"); const BindingFlags BINDING_FLAGS = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; - const string CTOR = ".ctor"; MethodBase method = null; if (parameters != null) { var parameterTypes = parameters.Select(x => LuaUserData.GetType(x)).ToArray(); - // TODO: remove the casts once we can use C# 9 features - method = methodName == CTOR - ? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, parameterTypes, null) - : (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null); + method = methodName switch + { + ".cctor" => classType.TypeInitializer, + ".ctor" => classType.GetConstructors(BINDING_FLAGS) + .Except(new[] { classType.TypeInitializer }) + .Where(x => x.GetParameters().Select(x => x.ParameterType).SequenceEqual(parameterTypes)) + .SingleOrDefault(), + _ => classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null), + }; } else { - method = methodName == CTOR - ? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, Array.Empty(), null) - : (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS); + ConstructorInfo GetCtor() + { + var ctors = classType.GetConstructors(BINDING_FLAGS) + .Except(new[] { classType.TypeInitializer }) + .GetEnumerator(); + + if (!ctors.MoveNext()) return null; + var ctor = ctors.Current; + + if (ctors.MoveNext()) throw new AmbiguousMatchException(); + return ctor; + } + + method = methodName switch + { + ".cctor" => throw new InvalidOperationException("Type initializers can't have parameters."), + ".ctor" => GetCtor(), + _ => classType.GetMethod(methodName, BINDING_FLAGS), + }; } if (method == null) diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs index 632ceca38..e506b2682 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs @@ -39,7 +39,7 @@ namespace TestProject.LuaCs if (patchId != null) args.Add(Stringify(patchId)); args.Add(Stringify(className)); args.Add(Stringify(methodName)); - if (parameters != null && parameters.Length > 0) + if (parameters != null) { var sb = new StringBuilder(); sb.Append("{ "); @@ -101,7 +101,7 @@ namespace TestProject.LuaCs end ", LuaCsHook.HookMethodType.Before); Assert.Equal(DataType.String, returnValue.Type); - return new(returnValue.String, () => luaCs.RemovePrefix(returnValue.String, methodName)); + return new(returnValue.String, () => luaCs.RemovePrefix(returnValue.String, methodName, parameters)); } public static PatchHandle AddPostfix(this LuaCsSetup luaCs, string body, string methodName, string[]? parameters = null, string? patchId = null) @@ -113,7 +113,7 @@ namespace TestProject.LuaCs end ", LuaCsHook.HookMethodType.After); Assert.Equal(DataType.String, returnValue.Type); - return new(returnValue.String, () => luaCs.RemovePostfix(returnValue.String, methodName)); + return new(returnValue.String, () => luaCs.RemovePostfix(returnValue.String, methodName, parameters)); } public static bool RemovePrefix(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null) diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs index bd69ff172..57537317d 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs @@ -1,6 +1,7 @@ using Barotrauma; using Microsoft.Xna.Framework; using MoonSharp.Interpreter; +using System; using Xunit; using Xunit.Abstractions; @@ -345,10 +346,10 @@ namespace TestProject.LuaCs { using var postfixHandle = luaCs.AddPostfix(@$" instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched} - ", ".ctor"); + ", ".ctor", Array.Empty()); using var prefixHandle = luaCs.AddPrefix(@$" instance.PrefixRan = true - ", ".ctor"); + ", ".ctor", Array.Empty()); var target = new PatchTargetConstructor(); Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor); Assert.True(target.PrefixRan);