diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs index 94b74963e..914caefa9 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Linq; using System.Reflection; @@ -11,8 +11,10 @@ namespace Barotrauma { public static Type GetType(string typeName) { + if (typeName == null || typeName.Length == 0) return null; + var byRef = false; - if (typeName.StartsWith("out ") || typeName.StartsWith("ref ")) + if (typeName.StartsWith("out ") || typeName.StartsWith("ref ")) { typeName = typeName.Remove(0, 4); byRef = true; @@ -266,4 +268,4 @@ namespace Barotrauma return CreateUserDataFromDescriptor(scriptObject, descriptor); } } -} \ No newline at end of file +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs index 344856c46..38df2d536 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs @@ -817,31 +817,59 @@ namespace Barotrauma private static MethodBase ResolveMethod(string className, string methodName, string[] parameters) { var classType = LuaUserData.GetType(className); - if (classType == null) throw new InvalidOperationException($"Invalid class name '{className}'"); + if (classType == null) throw new ScriptRuntimeException($"invalid class name '{className}'"); const BindingFlags BINDING_FLAGS = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; - const string CTOR = ".ctor"; MethodBase method = null; - if (parameters != null) + + try { - var parameterTypes = parameters.Select(x => LuaUserData.GetType(x)).ToArray(); - // TODO: remove the casts once we can use C# 9 features - method = methodName == CTOR - ? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, parameterTypes, null) - : (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null); + if (parameters != null) + { + var parameterTypes = parameters.Select(x => LuaUserData.GetType(x)).ToArray(); + method = methodName switch + { + ".cctor" => classType.TypeInitializer, + ".ctor" => classType.GetConstructors(BINDING_FLAGS) + .Except(new[] { classType.TypeInitializer }) + .Where(x => x.GetParameters().Select(x => x.ParameterType).SequenceEqual(parameterTypes)) + .SingleOrDefault(), + _ => classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null), + }; + } + else + { + ConstructorInfo GetCtor() + { + var ctors = classType.GetConstructors(BINDING_FLAGS) + .Except(new[] { classType.TypeInitializer }) + .GetEnumerator(); + + if (!ctors.MoveNext()) return null; + var ctor = ctors.Current; + + if (ctors.MoveNext()) throw new AmbiguousMatchException(); + return ctor; + } + + method = methodName switch + { + ".cctor" => throw new ScriptRuntimeException("type initializers can't have parameters"), + ".ctor" => GetCtor(), + _ => classType.GetMethod(methodName, BINDING_FLAGS), + }; + } } - else + catch (AmbiguousMatchException) { - method = methodName == CTOR - ? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, Array.Empty(), null) - : (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS); + throw new ScriptRuntimeException("ambiguous method signature"); } if (method == null) { var parameterNamesStr = parameters == null ? "" : string.Join(", ", parameters); - throw new InvalidOperationException($"Method '{methodName}({parameterNamesStr})' not found in class '{className}'"); + throw new ScriptRuntimeException($"method '{methodName}({parameterNamesStr})' not found in class '{className}'"); } return method; diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs index 632ceca38..e506b2682 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchHelpers.cs @@ -39,7 +39,7 @@ namespace TestProject.LuaCs if (patchId != null) args.Add(Stringify(patchId)); args.Add(Stringify(className)); args.Add(Stringify(methodName)); - if (parameters != null && parameters.Length > 0) + if (parameters != null) { var sb = new StringBuilder(); sb.Append("{ "); @@ -101,7 +101,7 @@ namespace TestProject.LuaCs end ", LuaCsHook.HookMethodType.Before); Assert.Equal(DataType.String, returnValue.Type); - return new(returnValue.String, () => luaCs.RemovePrefix(returnValue.String, methodName)); + return new(returnValue.String, () => luaCs.RemovePrefix(returnValue.String, methodName, parameters)); } public static PatchHandle AddPostfix(this LuaCsSetup luaCs, string body, string methodName, string[]? parameters = null, string? patchId = null) @@ -113,7 +113,7 @@ namespace TestProject.LuaCs end ", LuaCsHook.HookMethodType.After); Assert.Equal(DataType.String, returnValue.Type); - return new(returnValue.String, () => luaCs.RemovePostfix(returnValue.String, methodName)); + return new(returnValue.String, () => luaCs.RemovePostfix(returnValue.String, methodName, parameters)); } public static bool RemovePrefix(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null) diff --git a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs index bd69ff172..de2169909 100644 --- a/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs +++ b/Barotrauma/BarotraumaTest/LuaCs/HookPatchTests.cs @@ -1,6 +1,7 @@ using Barotrauma; using Microsoft.Xna.Framework; using MoonSharp.Interpreter; +using System; using Xunit; using Xunit.Abstractions; @@ -31,6 +32,7 @@ namespace TestProject.LuaCs UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); + UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); @@ -306,6 +308,41 @@ namespace TestProject.LuaCs Assert.Equal("{X:1 Y:2}", returnValue); } + private class PatchTargetAmbiguous + { + public PatchTargetAmbiguous() { } + + public PatchTargetAmbiguous(int a) { } + + public void Blah() { } + + public void Blah(int a) { } + } + + [Fact] + public void TestPatchAmbiguous() + { + using var patchTargetHandle = HookPatchHelpers.LockPatchTarget(); + + Assert.Throws(() => + { + using var postfixHandle = luaCs.AddPostfix("", ".ctor"); + }); + Assert.Throws(() => + { + using var prefixHandle = luaCs.AddPrefix("", ".ctor"); + }); + + Assert.Throws(() => + { + using var postfixHandle = luaCs.AddPostfix("", nameof(PatchTargetAmbiguous.Blah)); + }); + Assert.Throws(() => + { + using var prefixHandle = luaCs.AddPrefix("", nameof(PatchTargetAmbiguous.Blah)); + }); + } + private class PatchTargetConstructor { public enum CtorType @@ -345,10 +382,10 @@ namespace TestProject.LuaCs { using var postfixHandle = luaCs.AddPostfix(@$" instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched} - ", ".ctor"); + ", ".ctor", Array.Empty()); using var prefixHandle = luaCs.AddPrefix(@$" instance.PrefixRan = true - ", ".ctor"); + ", ".ctor", Array.Empty()); var target = new PatchTargetConstructor(); Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor); Assert.True(target.PrefixRan);