From 5c714ce1241c9d3f6fb1f6e36bb61c39a9a443ee Mon Sep 17 00:00:00 2001 From: Oiltanker Date: Sun, 17 Apr 2022 22:59:12 +0300 Subject: [PATCH] fix CsScriptRunner --- .../SharedSource/LuaCs/Cs/CsScriptFilter.cs | 78 ++++++++++++++---- .../SharedSource/LuaCs/Cs/CsScriptLoader.cs | 4 +- .../SharedSource/LuaCs/Cs/CsScriptRunner.cs | 81 ++++++++----------- 3 files changed, 100 insertions(+), 63 deletions(-) diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs index a735ba69e..faa4851b0 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs @@ -9,7 +9,7 @@ using System.Reflection.Metadata; namespace Barotrauma { class CsScriptFilter { - private static readonly string[] typesPermitted = new string[] { + private static readonly string[] typesPermitted = { // Basics "System", // Barotrauma @@ -24,7 +24,7 @@ namespace Barotrauma { "MoonSharp.Interpreter.TailCallData", "MoonSharp.Interpreter.DataType", }; - private static readonly string[] typesProhibited = new string[] { + private static readonly string[] typesProhibited = { "System.IO", "Moonsharp", "Barotrauma.IO", @@ -58,6 +58,24 @@ namespace Barotrauma { return null; } + private static string ResolveTypeRef(MetadataReader reader, TypeReferenceHandle t) + { + var tRef = reader.GetTypeReference(t); + + var typeName = $"{reader.GetString(tRef.Name)}"; + EntityHandle handle = tRef.ResolutionScope; + TypeReference tr = tRef; + while (!handle.IsNil && handle.Kind == HandleKind.TypeReference) + { + tr = reader.GetTypeReference((TypeReferenceHandle)handle); + handle = tr.ResolutionScope; + typeName = $"{reader.GetString(tr.Name)}.{typeName}"; + } + typeName = $"{reader.GetString(tr.Namespace)}.{typeName}"; + + return typeName; + } + public static string FilterMetadata(MetadataReader reader) { if (reader == null) throw new ArgumentNullException("Metadata Reader must not be null."); @@ -65,18 +83,7 @@ namespace Barotrauma { var conflictingTypes = new List(); reader.TypeReferences.ToList().ForEach(t => { - var tRef = reader.GetTypeReference(t); - - var typeName = $"{reader.GetString(tRef.Name)}"; - EntityHandle handle = tRef.ResolutionScope; - TypeReference tr = tRef; - while (!handle.IsNil && handle.Kind == HandleKind.TypeReference) - { - tr = reader.GetTypeReference((TypeReferenceHandle)handle); - handle = tr.ResolutionScope; - typeName = $"{reader.GetString(tr.Name)}.{typeName}"; - } - typeName = $"{reader.GetString(tr.Namespace)}.{typeName}"; + var typeName = ResolveTypeRef(reader, t); if (!IsTypeAllowed(typeName)) conflictingTypes.Add(typeName); }); @@ -90,5 +97,48 @@ namespace Barotrauma { return null; } + + private static readonly string[] permitedDefinitions = { + ".", + "NetOneTimeScript.NetOneTimeScriptRunner" + }; + private static readonly string[] permitedMethods = { + ".ctor", + "Run" + }; + public static string FilterOneTimeMetadata(MetadataReader reader) + { + var errStr = FilterMetadata(reader); + if (errStr != null) return errStr; + + TypeDefinition? runDef = null; + reader.TypeDefinitions.Select(t => reader.GetTypeDefinition(t)).ToList().ForEach(t => + { + var typeName = $"{reader.GetString(t.Name)}"; + if (typeName == "NetOneTimeScriptRunner") runDef = t; + while (t.IsNested) + { + t = reader.GetTypeDefinition(t.GetDeclaringType()); + typeName = $"{reader.GetString(t.Name)}.{typeName}"; + } + typeName = $"{reader.GetString(t.Namespace)}.{typeName}"; + if (!permitedDefinitions.Contains(typeName)) errStr = "Malformed assembly"; + }); + if (errStr != null) return errStr; + + if (runDef == null) return "runner class not detected"; + else + { + var methods = runDef.Value.GetMethods(); + if (methods.Count > 2) return "malformed runner class"; + + methods.Select(m => reader.GetMethodDefinition(m)).ToList().ForEach(m => { + if (!permitedMethods.Contains(reader.GetString(m.Name))) errStr = "malformed runner class"; + }); + if (errStr != null) return errStr; + } + + return null; + } } } \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs index e65f62f67..49906e2cd 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs @@ -19,7 +19,7 @@ namespace Barotrauma private List syntaxTrees; public Assembly Assembly { get; private set; } - public CsScriptLoader(LuaCsSetup setup) + public CsScriptLoader(LuaCsSetup setup) : base(isCollectible: true) { this.setup = setup; @@ -97,7 +97,7 @@ namespace Barotrauma string errStr = "NET MODS NOT LOADED | Mod cmopilation errors:"; foreach (Diagnostic diagnostic in failures) - errStr = $"\n{diagnostic}"; + errStr += $"\n{diagnostic}"; LuaCsSetup.PrintCsError(errStr); } else diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs index 60dcbe22e..d8338d815 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs @@ -17,8 +17,14 @@ namespace Barotrauma public LuaCsSetup setup; private List defaultReferences; private CSharpCompilationOptions compileOptions; + private static readonly string[] usings = { + "System", + "Barotrauma", + "System.Collections.Generic", + "System.Linq" + }; - public CsScriptRunner(LuaCsSetup setup) + public CsScriptRunner(LuaCsSetup setup) : base(isCollectible: true) { this.setup = setup; @@ -34,17 +40,10 @@ namespace Barotrauma private static string ToOneTimeScript(string code) { - var prefix = @" -public class NetOneTimeScriptRunner { - public NetOneTimeScriptRunner() { } - public object Run() { - -"; - var postfix = @" - return null; - } -} -"; + var prefix = ""; + foreach (var u in usings) prefix += $"using {u}; "; + prefix += "namespace NetOneTimeScript { public class NetOneTimeScriptRunner { public NetOneTimeScriptRunner() { } public object Run() {\n"; + var postfix = "\nreturn null; } } }"; return prefix + code + postfix; } @@ -54,7 +53,8 @@ public class NetOneTimeScriptRunner { try { - var syntaxTree = SyntaxFactory.ParseSyntaxTree(ToOneTimeScript(code), CSharpParseOptions.Default); + code = ToOneTimeScript(code); + var syntaxTree = SyntaxFactory.ParseSyntaxTree(code, CSharpParseOptions.Default); var compilation = CSharpCompilation.Create("NetOneTimeScriptAssembly", new[] { syntaxTree }, defaultReferences, compileOptions); Assembly assembly = null; @@ -66,43 +66,39 @@ public class NetOneTimeScriptRunner { IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); string errStr = "Script cmopilation errors:"; + var lineErr = new SortedDictionary(); foreach (Diagnostic diagnostic in failures) - errStr = $"\n{diagnostic}"; + { + var line = syntaxTree.GetLineSpan(diagnostic.Location.SourceSpan).StartLinePosition.Line; + lineErr[line] = (diagnostic.Id, diagnostic.ToString()); + } + var lines = code.Split('\n'); + for (var i = 1; i < lines.Length - 1; i++) + { + errStr += $"\n{i} >> {lines[i]}"; + if (lineErr.ContainsKey(i)) errStr += $" <=== {lineErr[i].Item1}"; + } + errStr += "\n"; + foreach ((var idx, (var id, var err)) in lineErr) + { + errStr += $"\n{idx}: {err}"; + } LuaCsSetup.PrintCsError(errStr); } else { mem.Seek(0, SeekOrigin.Begin); - var metaReader = new PEReader(mem).GetMetadataReader(); - var errStr = CsScriptFilter.FilterMetadata(metaReader); - if (errStr == null) - { - foreach (var handle in metaReader.TypeDefinitions) - { - var typeDef = metaReader.GetTypeDefinition(handle); - var typeName = $"{metaReader.GetString(typeDef.Namespace)}.{metaReader.GetString(typeDef.Name)}"; - if (typeName != ".NetOneTimeScriptRunner") - { - errStr = "Script Error - malformed assembly"; - break; - } - } - } - + var errStr = CsScriptFilter.FilterOneTimeMetadata(new PEReader(mem).GetMetadataReader()); if (errStr == null) { mem.Seek(0, SeekOrigin.Begin); assembly = LoadFromStream(mem); - var runner = assembly.CreateInstance("NetOneTimeScriptRunner"); + var runner = assembly.CreateInstance("NetOneTimeScript.NetOneTimeScriptRunner"); if (runner != null) { - if (runner.GetType().GetMethods().Count() > 1) LuaCsSetup.PrintCsError("Script Error - malformed runner class"); - else - { - var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance); - if (method != null) scriptResilt = method.Invoke(runner, null); - else LuaCsSetup.PrintCsError("Script Error - no run method detected"); - } + var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance); + if (method != null) scriptResilt = method.Invoke(runner, null); + else LuaCsSetup.PrintCsError("Script Error - no run method detected"); } else LuaCsSetup.PrintCsError("Script Error - no runner class detected"); } @@ -111,15 +107,6 @@ public class NetOneTimeScriptRunner { } Unload(); } - catch (CompilationErrorException ex) - { - string errStr = "Script Cmopilation Error:"; - foreach (var diag in ex.Diagnostics) - { - errStr += "\n" + diag.ToString(); - } - LuaCsSetup.PrintCsError(errStr); - } catch (Exception ex) { LuaCsSetup.PrintCsError("Error running script:\n" + ex.Message + "\n" + ex.StackTrace);