fix CsScriptRunner

This commit is contained in:
Oiltanker
2022-04-17 22:59:12 +03:00
parent 94f0068509
commit 5c714ce124
3 changed files with 100 additions and 63 deletions

View File

@@ -9,7 +9,7 @@ using System.Reflection.Metadata;
namespace Barotrauma {
class CsScriptFilter
{
private static readonly string[] typesPermitted = new string[] {
private static readonly string[] typesPermitted = {
// Basics
"System",
// Barotrauma
@@ -24,7 +24,7 @@ namespace Barotrauma {
"MoonSharp.Interpreter.TailCallData",
"MoonSharp.Interpreter.DataType",
};
private static readonly string[] typesProhibited = new string[] {
private static readonly string[] typesProhibited = {
"System.IO",
"Moonsharp",
"Barotrauma.IO",
@@ -58,6 +58,24 @@ namespace Barotrauma {
return null;
}
private static string ResolveTypeRef(MetadataReader reader, TypeReferenceHandle t)
{
var tRef = reader.GetTypeReference(t);
var typeName = $"{reader.GetString(tRef.Name)}";
EntityHandle handle = tRef.ResolutionScope;
TypeReference tr = tRef;
while (!handle.IsNil && handle.Kind == HandleKind.TypeReference)
{
tr = reader.GetTypeReference((TypeReferenceHandle)handle);
handle = tr.ResolutionScope;
typeName = $"{reader.GetString(tr.Name)}.{typeName}";
}
typeName = $"{reader.GetString(tr.Namespace)}.{typeName}";
return typeName;
}
public static string FilterMetadata(MetadataReader reader)
{
if (reader == null) throw new ArgumentNullException("Metadata Reader must not be null.");
@@ -65,18 +83,7 @@ namespace Barotrauma {
var conflictingTypes = new List<string>();
reader.TypeReferences.ToList().ForEach(t =>
{
var tRef = reader.GetTypeReference(t);
var typeName = $"{reader.GetString(tRef.Name)}";
EntityHandle handle = tRef.ResolutionScope;
TypeReference tr = tRef;
while (!handle.IsNil && handle.Kind == HandleKind.TypeReference)
{
tr = reader.GetTypeReference((TypeReferenceHandle)handle);
handle = tr.ResolutionScope;
typeName = $"{reader.GetString(tr.Name)}.{typeName}";
}
typeName = $"{reader.GetString(tr.Namespace)}.{typeName}";
var typeName = ResolveTypeRef(reader, t);
if (!IsTypeAllowed(typeName)) conflictingTypes.Add(typeName);
});
@@ -90,5 +97,48 @@ namespace Barotrauma {
return null;
}
private static readonly string[] permitedDefinitions = {
".<Module>",
"NetOneTimeScript.NetOneTimeScriptRunner"
};
private static readonly string[] permitedMethods = {
".ctor",
"Run"
};
public static string FilterOneTimeMetadata(MetadataReader reader)
{
var errStr = FilterMetadata(reader);
if (errStr != null) return errStr;
TypeDefinition? runDef = null;
reader.TypeDefinitions.Select(t => reader.GetTypeDefinition(t)).ToList().ForEach(t =>
{
var typeName = $"{reader.GetString(t.Name)}";
if (typeName == "NetOneTimeScriptRunner") runDef = t;
while (t.IsNested)
{
t = reader.GetTypeDefinition(t.GetDeclaringType());
typeName = $"{reader.GetString(t.Name)}.{typeName}";
}
typeName = $"{reader.GetString(t.Namespace)}.{typeName}";
if (!permitedDefinitions.Contains(typeName)) errStr = "Malformed assembly";
});
if (errStr != null) return errStr;
if (runDef == null) return "runner class not detected";
else
{
var methods = runDef.Value.GetMethods();
if (methods.Count > 2) return "malformed runner class";
methods.Select(m => reader.GetMethodDefinition(m)).ToList().ForEach(m => {
if (!permitedMethods.Contains(reader.GetString(m.Name))) errStr = "malformed runner class";
});
if (errStr != null) return errStr;
}
return null;
}
}
}

View File

@@ -19,7 +19,7 @@ namespace Barotrauma
private List<SyntaxTree> syntaxTrees;
public Assembly Assembly { get; private set; }
public CsScriptLoader(LuaCsSetup setup)
public CsScriptLoader(LuaCsSetup setup) : base(isCollectible: true)
{
this.setup = setup;
@@ -97,7 +97,7 @@ namespace Barotrauma
string errStr = "NET MODS NOT LOADED | Mod cmopilation errors:";
foreach (Diagnostic diagnostic in failures)
errStr = $"\n{diagnostic}";
errStr += $"\n{diagnostic}";
LuaCsSetup.PrintCsError(errStr);
}
else

View File

@@ -17,8 +17,14 @@ namespace Barotrauma
public LuaCsSetup setup;
private List<MetadataReference> defaultReferences;
private CSharpCompilationOptions compileOptions;
private static readonly string[] usings = {
"System",
"Barotrauma",
"System.Collections.Generic",
"System.Linq"
};
public CsScriptRunner(LuaCsSetup setup)
public CsScriptRunner(LuaCsSetup setup) : base(isCollectible: true)
{
this.setup = setup;
@@ -34,17 +40,10 @@ namespace Barotrauma
private static string ToOneTimeScript(string code)
{
var prefix = @"
public class NetOneTimeScriptRunner {
public NetOneTimeScriptRunner() { }
public object Run() {
";
var postfix = @"
return null;
}
}
";
var prefix = "";
foreach (var u in usings) prefix += $"using {u}; ";
prefix += "namespace NetOneTimeScript { public class NetOneTimeScriptRunner { public NetOneTimeScriptRunner() { } public object Run() {\n";
var postfix = "\nreturn null; } } }";
return prefix + code + postfix;
}
@@ -54,7 +53,8 @@ public class NetOneTimeScriptRunner {
try
{
var syntaxTree = SyntaxFactory.ParseSyntaxTree(ToOneTimeScript(code), CSharpParseOptions.Default);
code = ToOneTimeScript(code);
var syntaxTree = SyntaxFactory.ParseSyntaxTree(code, CSharpParseOptions.Default);
var compilation = CSharpCompilation.Create("NetOneTimeScriptAssembly", new[] { syntaxTree }, defaultReferences, compileOptions);
Assembly assembly = null;
@@ -66,43 +66,39 @@ public class NetOneTimeScriptRunner {
IEnumerable<Diagnostic> failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error);
string errStr = "Script cmopilation errors:";
var lineErr = new SortedDictionary<int, (string, string)>();
foreach (Diagnostic diagnostic in failures)
errStr = $"\n{diagnostic}";
{
var line = syntaxTree.GetLineSpan(diagnostic.Location.SourceSpan).StartLinePosition.Line;
lineErr[line] = (diagnostic.Id, diagnostic.ToString());
}
var lines = code.Split('\n');
for (var i = 1; i < lines.Length - 1; i++)
{
errStr += $"\n{i} >> {lines[i]}";
if (lineErr.ContainsKey(i)) errStr += $" <=== {lineErr[i].Item1}";
}
errStr += "\n";
foreach ((var idx, (var id, var err)) in lineErr)
{
errStr += $"\n{idx}: {err}";
}
LuaCsSetup.PrintCsError(errStr);
}
else
{
mem.Seek(0, SeekOrigin.Begin);
var metaReader = new PEReader(mem).GetMetadataReader();
var errStr = CsScriptFilter.FilterMetadata(metaReader);
if (errStr == null)
{
foreach (var handle in metaReader.TypeDefinitions)
{
var typeDef = metaReader.GetTypeDefinition(handle);
var typeName = $"{metaReader.GetString(typeDef.Namespace)}.{metaReader.GetString(typeDef.Name)}";
if (typeName != ".NetOneTimeScriptRunner")
{
errStr = "Script Error - malformed assembly";
break;
}
}
}
var errStr = CsScriptFilter.FilterOneTimeMetadata(new PEReader(mem).GetMetadataReader());
if (errStr == null)
{
mem.Seek(0, SeekOrigin.Begin);
assembly = LoadFromStream(mem);
var runner = assembly.CreateInstance("NetOneTimeScriptRunner");
var runner = assembly.CreateInstance("NetOneTimeScript.NetOneTimeScriptRunner");
if (runner != null)
{
if (runner.GetType().GetMethods().Count() > 1) LuaCsSetup.PrintCsError("Script Error - malformed runner class");
else
{
var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance);
if (method != null) scriptResilt = method.Invoke(runner, null);
else LuaCsSetup.PrintCsError("Script Error - no run method detected");
}
var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance);
if (method != null) scriptResilt = method.Invoke(runner, null);
else LuaCsSetup.PrintCsError("Script Error - no run method detected");
}
else LuaCsSetup.PrintCsError("Script Error - no runner class detected");
}
@@ -111,15 +107,6 @@ public class NetOneTimeScriptRunner {
}
Unload();
}
catch (CompilationErrorException ex)
{
string errStr = "Script Cmopilation Error:";
foreach (var diag in ex.Diagnostics)
{
errStr += "\n" + diag.ToString();
}
LuaCsSetup.PrintCsError(errStr);
}
catch (Exception ex)
{
LuaCsSetup.PrintCsError("Error running script:\n" + ex.Message + "\n" + ex.StackTrace);