diff --git a/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptFilter.cs b/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptFilter.cs index e70adb182..4d2b46b29 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptFilter.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptFilter.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Reflection.Metadata; partial class NetScript { @@ -13,18 +14,34 @@ partial class NetScript { private const bool useWhitelist = false; - private static string[] classesPermited = new string[] {}; - private static string[] classesProhibited = new string[] { }; - public static bool IsClassAllowed(string usingName) + private static string[] typesPermited = new string[] { + // Basics + "System.Runtime.CompilerServices.CompilationRelaxationsAttribute", + "System.Runtime.CompilerServices.RuntimeCompatibilityAttribute", + "System.Diagnostics.DebuggableAttribute", + "System.Object", + "System.String", + "System.Collections", + // Some roslyn magic + ".DebuggingModes", + // Barotrauma + "Barotrauma", + }; + private static string[] typessProhibited = new string[] { + //"System.Reflection", + "System.IO.File", + }; + public static bool IsTypeAllowed(string usingName) { - if (useWhitelist && !classesPermited.Any(u => u.Equals(usingName))) return false; - if (classesProhibited.Any(u => u.Equals(usingName))) return false; + if (useWhitelist && !typesPermited.Any(u => u.StartsWith(usingName))) return false; + if (typessProhibited.Any(u => u.StartsWith(usingName))) return false; return true; } public static string FilterSyntaxTree(CSharpSyntaxTree tree) { if (tree == null) throw new ArgumentNullException("Syntax tree must not be null."); + { // Disallow top-level statements var nodeCheck = tree.GetRoot().DescendantNodes(); @@ -32,65 +49,34 @@ partial class NetScript if (tlStatements.Count > 0) { string errStr = "Cmopilation Error:"; - foreach (var tls in tlStatements) tls.GetDiagnostics().ToList().ForEach(d => errStr += "\n" + d.ToString()); + foreach (var tls in tlStatements) tls.GetDiagnostics().ToList().ForEach(d => errStr += $"\n {d.ToString()}"); return errStr; } } - var compRoot = tree.GetCompilationUnitRoot(); - var refDirs = compRoot.GetReferenceDirectives().ToList(); - Console.WriteLine($"Reference Directives [{refDirs.Count}]:"); - refDirs.ForEach(d => Console.WriteLine(d.ToFullString())); - - List allUsedTypes = new List(); - { // Find all used types - } - - List allResolvedTypes = new List(); - { // Resolve all types - } - - { // Check all used types - } - - if (!Directory.Exists("./SyntaxTrees")) Directory.CreateDirectory("./SyntaxTrees"); - string fileName = "./SyntaxTrees/" + tree.FilePath.Replace("/", "--") + ".txt"; - if (File.Exists(fileName)) File.Delete(fileName); - var fileWriter = File.CreateText(fileName); - - var nodes = new Queue<(SyntaxNode, int)>(); - nodes.Enqueue((tree.GetRoot(), 0)); - while (nodes.Count > 0) - { - var nodeElem = nodes.Dequeue(); - var node = nodeElem.Item1; - var indent = nodeElem.Item2; - - node.ChildNodes().ToList().ForEach(n => { - if (n.ChildNodes().Count() > 0) nodes.Enqueue((n, indent + 1)); - if (!( - n is MemberAccessExpressionSyntax || - n is UsingDirectiveSyntax || - n is BaseTypeSyntax || - n is TypeSyntax - )) return; - //Console.WriteLine(new String(' ', indent * 2) + n.GetType().Name + " | " + n.GetText()?.ToString() ?? "null"); - fileWriter.WriteLine(new String(' ', indent * 2) + n.GetType().Name + " | " + n.GetText()?.ToString() ?? "null"); - }); - node.DescendantNodes().ToList().ForEach(n => { - if (n.DescendantNodes().Count() > 0 && !nodes.Contains((n, indent + 1))) nodes.Enqueue((n, indent + 1)); - if (!( - n is MemberAccessExpressionSyntax || - n is UsingDirectiveSyntax || - n is BaseTypeSyntax || - n is TypeSyntax - )) return; - //Console.WriteLine(new String(' ', indent * 2) + n.GetType().Name + " | " + n.GetText()?.ToString() ?? "null"); - fileWriter.WriteLine(new String(' ', indent * 2) + n.GetType().Name + " | " + n.GetText()?.ToString() ?? "null"); - }); - } - fileWriter.Close(); return null; } + + public static string FilterMetadata(MetadataReader reader) + { + if (reader == null) throw new ArgumentNullException("Metadata Reader must not be null."); + + var conflictingTypes = new List(); + reader.TypeReferences.ToList().ForEach(t => + { + var tRef = reader.GetTypeReference(t); + var typeName = $"{reader.GetString(tRef.Namespace)}.{reader.GetString(tRef.Name)}"; + if (!IsTypeAllowed(typeName)) conflictingTypes.Add(typeName); + }); + + if (conflictingTypes.Count > 0) + { + string errStr = "Metadata Error:"; + conflictingTypes.ForEach(t => errStr += $"\n Usage of type '{t}' in mods is prohibited."); + return errStr; + } + + return null; + } } } \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptLoader.cs index 8dd616986..c40758309 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/Net/NetScriptLoader.cs @@ -5,14 +5,11 @@ using Microsoft.CodeAnalysis.Scripting; using System.Reflection; using Microsoft.CodeAnalysis.CSharp; using System.Linq; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis; -using System.Collections.Immutable; using System.Runtime.Loader; -using static NetScript; -using Microsoft.CodeAnalysis.Emit; using System.Reflection.PortableExecutable; using System.Reflection.Metadata; +using static NetScript; namespace Barotrauma { @@ -113,29 +110,19 @@ namespace Barotrauma string errStr = "NET MODS NOT LOADED | Mod cmopilation errors:"; foreach (Diagnostic diagnostic in failures) - { errStr = $"\n{diagnostic}"; - } NetSetup.PrintMessage(errStr); } else { mem.Seek(0, SeekOrigin.Begin); - var reader = new PEReader(mem); - var mdReader = reader.GetMetadataReader(); - mdReader.AssemblyReferences.ToList().ForEach(a => - { - var aRef = mdReader.GetAssemblyReference(a); - Console.WriteLine(aRef.GetAssemblyName() + " " + aRef.Version); - }); - Console.WriteLine(); - mdReader.TypeReferences.ToList().ForEach(t => - { - var tRef = mdReader.GetTypeReference(t); - Console.WriteLine(mdReader.GetString(tRef.Namespace) + " - " + mdReader.GetString(tRef.Name)); - }); - mem.Seek(0, SeekOrigin.Begin); - Assembly = LoadFromStream(mem); + var errStr = NetScriptFilter.FilterMetadata(new PEReader(mem).GetMetadataReader()); + if (errStr == null) + { + mem.Seek(0, SeekOrigin.Begin); + Assembly = LoadFromStream(mem); + } + else NetSetup.PrintMessage(errStr); } } syntaxTrees.Clear();