Merge pull request #109 from notpeelz/fix-patch-confusing-ctor-with-cctor

Fix Hook.Patch confusing .ctor with .cctor
This commit is contained in:
Evil Factory
2022-09-17 10:15:43 -03:00
committed by GitHub
4 changed files with 88 additions and 21 deletions

View File

@@ -1,4 +1,4 @@
using System;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
@@ -11,8 +11,10 @@ namespace Barotrauma
{
public static Type GetType(string typeName)
{
if (typeName == null || typeName.Length == 0) return null;
var byRef = false;
if (typeName.StartsWith("out ") || typeName.StartsWith("ref "))
if (typeName.StartsWith("out ") || typeName.StartsWith("ref "))
{
typeName = typeName.Remove(0, 4);
byRef = true;
@@ -266,4 +268,4 @@ namespace Barotrauma
return CreateUserDataFromDescriptor(scriptObject, descriptor);
}
}
}
}

View File

@@ -817,31 +817,59 @@ namespace Barotrauma
private static MethodBase ResolveMethod(string className, string methodName, string[] parameters)
{
var classType = LuaUserData.GetType(className);
if (classType == null) throw new InvalidOperationException($"Invalid class name '{className}'");
if (classType == null) throw new ScriptRuntimeException($"invalid class name '{className}'");
const BindingFlags BINDING_FLAGS = BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
const string CTOR = ".ctor";
MethodBase method = null;
if (parameters != null)
try
{
var parameterTypes = parameters.Select(x => LuaUserData.GetType(x)).ToArray();
// TODO: remove the casts once we can use C# 9 features
method = methodName == CTOR
? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, parameterTypes, null)
: (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null);
if (parameters != null)
{
var parameterTypes = parameters.Select(x => LuaUserData.GetType(x)).ToArray();
method = methodName switch
{
".cctor" => classType.TypeInitializer,
".ctor" => classType.GetConstructors(BINDING_FLAGS)
.Except(new[] { classType.TypeInitializer })
.Where(x => x.GetParameters().Select(x => x.ParameterType).SequenceEqual(parameterTypes))
.SingleOrDefault(),
_ => classType.GetMethod(methodName, BINDING_FLAGS, null, parameterTypes, null),
};
}
else
{
ConstructorInfo GetCtor()
{
var ctors = classType.GetConstructors(BINDING_FLAGS)
.Except(new[] { classType.TypeInitializer })
.GetEnumerator();
if (!ctors.MoveNext()) return null;
var ctor = ctors.Current;
if (ctors.MoveNext()) throw new AmbiguousMatchException();
return ctor;
}
method = methodName switch
{
".cctor" => throw new ScriptRuntimeException("type initializers can't have parameters"),
".ctor" => GetCtor(),
_ => classType.GetMethod(methodName, BINDING_FLAGS),
};
}
}
else
catch (AmbiguousMatchException)
{
method = methodName == CTOR
? (MethodBase)classType.GetConstructor(BINDING_FLAGS, null, Array.Empty<Type>(), null)
: (MethodBase)classType.GetMethod(methodName, BINDING_FLAGS);
throw new ScriptRuntimeException("ambiguous method signature");
}
if (method == null)
{
var parameterNamesStr = parameters == null ? "" : string.Join(", ", parameters);
throw new InvalidOperationException($"Method '{methodName}({parameterNamesStr})' not found in class '{className}'");
throw new ScriptRuntimeException($"method '{methodName}({parameterNamesStr})' not found in class '{className}'");
}
return method;

View File

@@ -39,7 +39,7 @@ namespace TestProject.LuaCs
if (patchId != null) args.Add(Stringify(patchId));
args.Add(Stringify(className));
args.Add(Stringify(methodName));
if (parameters != null && parameters.Length > 0)
if (parameters != null)
{
var sb = new StringBuilder();
sb.Append("{ ");
@@ -101,7 +101,7 @@ namespace TestProject.LuaCs
end
", LuaCsHook.HookMethodType.Before);
Assert.Equal(DataType.String, returnValue.Type);
return new(returnValue.String, () => luaCs.RemovePrefix<T>(returnValue.String, methodName));
return new(returnValue.String, () => luaCs.RemovePrefix<T>(returnValue.String, methodName, parameters));
}
public static PatchHandle AddPostfix<T>(this LuaCsSetup luaCs, string body, string methodName, string[]? parameters = null, string? patchId = null)
@@ -113,7 +113,7 @@ namespace TestProject.LuaCs
end
", LuaCsHook.HookMethodType.After);
Assert.Equal(DataType.String, returnValue.Type);
return new(returnValue.String, () => luaCs.RemovePostfix<T>(returnValue.String, methodName));
return new(returnValue.String, () => luaCs.RemovePostfix<T>(returnValue.String, methodName, parameters));
}
public static bool RemovePrefix<T>(this LuaCsSetup luaCs, string patchId, string methodName, string[]? parameters = null)

View File

@@ -1,6 +1,7 @@
using Barotrauma;
using Microsoft.Xna.Framework;
using MoonSharp.Interpreter;
using System;
using Xunit;
using Xunit.Abstractions;
@@ -31,6 +32,7 @@ namespace TestProject.LuaCs
UserData.RegisterType<PatchTargetReturnsInterface>();
UserData.RegisterType<PatchTargetModifyParams>();
UserData.RegisterType<PatchTargetVector2>();
UserData.RegisterType<PatchTargetAmbiguous>();
UserData.RegisterType<PatchTargetConstructor>();
UserData.RegisterType<PatchTargetNumbers>();
@@ -306,6 +308,41 @@ namespace TestProject.LuaCs
Assert.Equal("{X:1 Y:2}", returnValue);
}
private class PatchTargetAmbiguous
{
public PatchTargetAmbiguous() { }
public PatchTargetAmbiguous(int a) { }
public void Blah() { }
public void Blah(int a) { }
}
[Fact]
public void TestPatchAmbiguous()
{
using var patchTargetHandle = HookPatchHelpers.LockPatchTarget<PatchTargetAmbiguous>();
Assert.Throws<ScriptRuntimeException>(() =>
{
using var postfixHandle = luaCs.AddPostfix<PatchTargetAmbiguous>("", ".ctor");
});
Assert.Throws<ScriptRuntimeException>(() =>
{
using var prefixHandle = luaCs.AddPrefix<PatchTargetAmbiguous>("", ".ctor");
});
Assert.Throws<ScriptRuntimeException>(() =>
{
using var postfixHandle = luaCs.AddPostfix<PatchTargetAmbiguous>("", nameof(PatchTargetAmbiguous.Blah));
});
Assert.Throws<ScriptRuntimeException>(() =>
{
using var prefixHandle = luaCs.AddPrefix<PatchTargetAmbiguous>("", nameof(PatchTargetAmbiguous.Blah));
});
}
private class PatchTargetConstructor
{
public enum CtorType
@@ -345,10 +382,10 @@ namespace TestProject.LuaCs
{
using var postfixHandle = luaCs.AddPostfix<PatchTargetConstructor>(@$"
instance.Ctor = {(int)PatchTargetConstructor.CtorType.Patched}
", ".ctor");
", ".ctor", Array.Empty<string>());
using var prefixHandle = luaCs.AddPrefix<PatchTargetConstructor>(@$"
instance.PrefixRan = true
", ".ctor");
", ".ctor", Array.Empty<string>());
var target = new PatchTargetConstructor();
Assert.Equal(PatchTargetConstructor.CtorType.Patched, target.Ctor);
Assert.True(target.PrefixRan);