diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs index 895d4319d..7f65f78c4 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs @@ -9,7 +9,7 @@ using System.Reflection.Metadata; namespace Barotrauma { class CsScriptFilter { - private static readonly string[] typesPermitted = new string[] { + private static readonly string[] typesPermitted = { // Basics "System", // Barotrauma @@ -24,12 +24,12 @@ namespace Barotrauma { "MoonSharp.Interpreter.TailCallData", "MoonSharp.Interpreter.DataType", }; - private static readonly string[] typesProhibited = new string[] { - //"System.Reflection", - //"System.Type", + private static readonly string[] typesProhibited = { "System.IO", - "Moonsharp", "Barotrauma.IO", + "System.Xml.XmlReader", + "System.Xml.XmlWriter", + "Barotrauma.LuaUserData", }; public static bool IsTypeAllowed(string name) { @@ -60,6 +60,24 @@ namespace Barotrauma { return null; } + private static string ResolveTypeRef(MetadataReader reader, TypeReferenceHandle t) + { + var tRef = reader.GetTypeReference(t); + + var typeName = $"{reader.GetString(tRef.Name)}"; + EntityHandle handle = tRef.ResolutionScope; + TypeReference tr = tRef; + while (!handle.IsNil && handle.Kind == HandleKind.TypeReference) + { + tr = reader.GetTypeReference((TypeReferenceHandle)handle); + handle = tr.ResolutionScope; + typeName = $"{reader.GetString(tr.Name)}.{typeName}"; + } + typeName = $"{reader.GetString(tr.Namespace)}.{typeName}"; + + return typeName; + } + public static string FilterMetadata(MetadataReader reader) { if (reader == null) throw new ArgumentNullException("Metadata Reader must not be null."); @@ -67,18 +85,7 @@ namespace Barotrauma { var conflictingTypes = new List(); reader.TypeReferences.ToList().ForEach(t => { - var tRef = reader.GetTypeReference(t); - - var typeName = $"{reader.GetString(tRef.Name)}"; - EntityHandle handle = tRef.ResolutionScope; - TypeReference tr = tRef; - while (!handle.IsNil && handle.Kind == HandleKind.TypeReference) - { - tr = reader.GetTypeReference((TypeReferenceHandle)handle); - handle = tr.ResolutionScope; - typeName = $"{reader.GetString(tr.Name)}.{typeName}"; - } - typeName = $"{reader.GetString(tr.Namespace)}.{typeName}"; + var typeName = ResolveTypeRef(reader, t); if (!IsTypeAllowed(typeName)) conflictingTypes.Add(typeName); }); @@ -92,5 +99,48 @@ namespace Barotrauma { return null; } + + private static readonly string[] permitedDefinitions = { + ".", + "NetOneTimeScript.NetOneTimeScriptRunner" + }; + private static readonly string[] permitedMethods = { + ".ctor", + "Run" + }; + public static string FilterOneTimeMetadata(MetadataReader reader) + { + var errStr = FilterMetadata(reader); + if (errStr != null) return errStr; + + TypeDefinition? runDef = null; + reader.TypeDefinitions.Select(t => reader.GetTypeDefinition(t)).ToList().ForEach(t => + { + var typeName = $"{reader.GetString(t.Name)}"; + if (typeName == "NetOneTimeScriptRunner") runDef = t; + while (t.IsNested) + { + t = reader.GetTypeDefinition(t.GetDeclaringType()); + typeName = $"{reader.GetString(t.Name)}.{typeName}"; + } + typeName = $"{reader.GetString(t.Namespace)}.{typeName}"; + if (!permitedDefinitions.Contains(typeName)) errStr = "Malformed assembly"; + }); + if (errStr != null) return errStr; + + if (runDef == null) return "runner class not detected"; + else + { + var methods = runDef.Value.GetMethods(); + if (methods.Count > 2) return "malformed runner class"; + + methods.Select(m => reader.GetMethodDefinition(m)).ToList().ForEach(m => { + if (!permitedMethods.Contains(reader.GetString(m.Name))) errStr = "malformed runner class"; + }); + if (errStr != null) return errStr; + } + + return null; + } } } \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs index e65f62f67..98e372bc5 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs @@ -16,10 +16,11 @@ namespace Barotrauma { public LuaCsSetup setup; private List defaultReferences; - private List syntaxTrees; + + private Dictionary> sources; public Assembly Assembly { get; private set; } - public CsScriptLoader(LuaCsSetup setup) + public CsScriptLoader(LuaCsSetup setup) : base(isCollectible: true) { this.setup = setup; @@ -28,7 +29,7 @@ namespace Barotrauma .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) .ToList(); - syntaxTrees = new List(); + sources = new Dictionary>(); Assembly = null; } @@ -41,47 +42,52 @@ namespace Barotrauma } } - private void RunFolder(string folder) + public bool HasSources { get => sources.Count > 0; } + + private void RunFolder(string folder) { - var scriptFiles = new List(); foreach (var str in DirSearch(folder)) { var s = str.Replace("\\", "/"); - if (s.EndsWith(".cs") && LuaCsFile.IsPathAllowedCsException(s)) scriptFiles.Add(s); - } - - try - { - if (scriptFiles.Count <= 0) return; - - // Check file content for prohibited stuff - foreach (var file in scriptFiles) + if (s.EndsWith(".cs") && LuaCsFile.IsPathAllowedCsException(s)) { - var tree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(file), CSharpParseOptions.Default, file); - var error = CsScriptFilter.FilterSyntaxTree(tree as CSharpSyntaxTree); - if (error != null) throw new Exception(error); - - syntaxTrees.Add(tree); + if (sources.ContainsKey(folder)) sources[folder].Add(s); + else sources.Add(folder, new List { s }); } - } - catch (CompilationErrorException ex) - { - string errStr = "Cmopilation Error in '" + folder + "':"; - foreach (var diag in ex.Diagnostics) - { - errStr += "\n" + diag.ToString(); - } - LuaCsSetup.PrintCsError(errStr); } - catch (Exception ex) + } + + private IEnumerable ParseSources() { + var syntaxTrees = new List(); + + if (sources.Count <= 0) throw new Exception("No Cs sources detected"); + foreach ((var folder, var src) in sources) { - LuaCsSetup.PrintCsError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace); + try + { + foreach (var file in src) + { + var tree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(file), CSharpParseOptions.Default, file); + var error = CsScriptFilter.FilterSyntaxTree(tree as CSharpSyntaxTree); // Check file content for prohibited stuff + if (error != null) throw new Exception(error); + + syntaxTrees.Add(tree); + } + } + catch (Exception ex) + { + LuaCsSetup.PrintCsError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace); + } } + + return syntaxTrees; } public List Compile() { + IEnumerable syntaxTrees = ParseSources(); + var options = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) .WithMetadataImportOptions(MetadataImportOptions.All) .WithOptimizationLevel(OptimizationLevel.Release) @@ -95,9 +101,9 @@ namespace Barotrauma { IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); - string errStr = "NET MODS NOT LOADED | Mod cmopilation errors:"; + string errStr = "NET MODS NOT LOADED | Mod compilation errors:"; foreach (Diagnostic diagnostic in failures) - errStr = $"\n{diagnostic}"; + errStr += $"\n{diagnostic}"; LuaCsSetup.PrintCsError(errStr); } else @@ -112,7 +118,6 @@ namespace Barotrauma else LuaCsSetup.PrintCsError(errStr); } } - syntaxTrees.Clear(); if (Assembly != null) return Assembly.GetTypes().Where(t => t.IsSubclassOf(typeof(ACsMod))).ToList(); @@ -144,5 +149,14 @@ namespace Barotrauma return files.ToArray(); } + public void Clear() + { + Assembly = null; + } + + ~CsScriptLoader() + { + + } } } \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs index 60dcbe22e..da773600e 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs @@ -9,6 +9,7 @@ using Microsoft.CodeAnalysis; using System.Runtime.Loader; using System.Reflection.PortableExecutable; using System.Reflection.Metadata; +using MoonSharp.Interpreter; namespace Barotrauma { @@ -17,8 +18,14 @@ namespace Barotrauma public LuaCsSetup setup; private List defaultReferences; private CSharpCompilationOptions compileOptions; + private static readonly string[] usings = { + "System", + "Barotrauma", + "System.Collections.Generic", + "System.Linq" + }; - public CsScriptRunner(LuaCsSetup setup) + public CsScriptRunner(LuaCsSetup setup) : base(isCollectible: true) { this.setup = setup; @@ -34,17 +41,10 @@ namespace Barotrauma private static string ToOneTimeScript(string code) { - var prefix = @" -public class NetOneTimeScriptRunner { - public NetOneTimeScriptRunner() { } - public object Run() { - -"; - var postfix = @" - return null; - } -} -"; + var prefix = ""; + foreach (var u in usings) prefix += $"using {u}; "; + prefix += "namespace NetOneTimeScript { public class NetOneTimeScriptRunner { public NetOneTimeScriptRunner() { } public object Run() {\n"; + var postfix = "\nreturn null; } } }"; return prefix + code + postfix; } @@ -54,7 +54,8 @@ public class NetOneTimeScriptRunner { try { - var syntaxTree = SyntaxFactory.ParseSyntaxTree(ToOneTimeScript(code), CSharpParseOptions.Default); + code = ToOneTimeScript(code); + var syntaxTree = SyntaxFactory.ParseSyntaxTree(code, CSharpParseOptions.Default); var compilation = CSharpCompilation.Create("NetOneTimeScriptAssembly", new[] { syntaxTree }, defaultReferences, compileOptions); Assembly assembly = null; @@ -65,44 +66,48 @@ public class NetOneTimeScriptRunner { { IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); - string errStr = "Script cmopilation errors:"; + string errStr = "Script compilation errors:"; + var lineErr = new SortedDictionary(); foreach (Diagnostic diagnostic in failures) - errStr = $"\n{diagnostic}"; + { + var line = syntaxTree.GetLineSpan(diagnostic.Location.SourceSpan).StartLinePosition.Line; + lineErr[line] = (diagnostic.Id, diagnostic.ToString()); + } + var lines = code.Split('\n'); + for (var i = 1; i < lines.Length - 1; i++) + { + errStr += $"\n{i} >> {lines[i]}"; + if (lineErr.ContainsKey(i)) errStr += $" <=== {lineErr[i].Item1}"; + } + errStr += "\n"; + foreach ((var idx, (var id, var err)) in lineErr) + { + errStr += $"\n{idx}: {err}"; + } LuaCsSetup.PrintCsError(errStr); } else { mem.Seek(0, SeekOrigin.Begin); - var metaReader = new PEReader(mem).GetMetadataReader(); - var errStr = CsScriptFilter.FilterMetadata(metaReader); - if (errStr == null) - { - foreach (var handle in metaReader.TypeDefinitions) - { - var typeDef = metaReader.GetTypeDefinition(handle); - var typeName = $"{metaReader.GetString(typeDef.Namespace)}.{metaReader.GetString(typeDef.Name)}"; - if (typeName != ".NetOneTimeScriptRunner") - { - errStr = "Script Error - malformed assembly"; - break; - } - } - } - + var errStr = CsScriptFilter.FilterOneTimeMetadata(new PEReader(mem).GetMetadataReader()); if (errStr == null) { mem.Seek(0, SeekOrigin.Begin); assembly = LoadFromStream(mem); - var runner = assembly.CreateInstance("NetOneTimeScriptRunner"); + var runner = assembly.CreateInstance("NetOneTimeScript.NetOneTimeScriptRunner"); if (runner != null) { - if (runner.GetType().GetMethods().Count() > 1) LuaCsSetup.PrintCsError("Script Error - malformed runner class"); - else - { - var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance); - if (method != null) scriptResilt = method.Invoke(runner, null); - else LuaCsSetup.PrintCsError("Script Error - no run method detected"); + var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance); + if (method != null) + { + scriptResilt = method.Invoke(runner, null); + foreach (var type in assembly.GetTypes()) + { + //UserData.UnregisterType(type, true); + UserData.UnregisterType(type); + } } + else LuaCsSetup.PrintCsError("Script Error - no run method detected"); } else LuaCsSetup.PrintCsError("Script Error - no runner class detected"); } @@ -111,15 +116,6 @@ public class NetOneTimeScriptRunner { } Unload(); } - catch (CompilationErrorException ex) - { - string errStr = "Script Cmopilation Error:"; - foreach (var diag in ex.Diagnostics) - { - errStr += "\n" + diag.ToString(); - } - LuaCsSetup.PrintCsError(errStr); - } catch (Exception ex) { LuaCsSetup.PrintCsError("Error running script:\n" + ex.Message + "\n" + ex.StackTrace); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs index 48393ab6b..122428209 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs @@ -13,7 +13,7 @@ namespace Barotrauma { var type = Type.GetType(typeName); if (type != null) return type; - foreach (var a in AppDomain.CurrentDomain.GetAssemblies()) + foreach (var a in AppDomain.CurrentDomain.GetAssemblies().Reverse()) { type = a.GetType(typeName); if (type != null) diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs index cfb0a9984..5bbf70773 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs @@ -26,6 +26,7 @@ namespace Barotrauma Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(DataType.Function, typeof(LuaCsAction), v => (LuaCsAction)( args => GameMain.LuaCs.CallLuaFunction(v.Function, args) )); Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(DataType.Function, typeof(LuaCsFunc), v => (LuaCsFunc)( args => new LuaResult(GameMain.LuaCs.CallLuaFunction(v.Function, args)) )); Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(DataType.Function, typeof(LuaCsPatch), v => (LuaCsPatch)( (self, args) => new LuaResult(GameMain.LuaCs.CallLuaFunction(v.Function, self, args)) )); + Script.GlobalOptions.CustomConverters.SetClrToScriptCustomConversion(typeof(LuaResult), (Script s, object v) => (v as LuaResult).DynValue()); #if CLIENT RegisterAction(); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs index 60059b925..7fcec9aa2 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs @@ -435,7 +435,14 @@ namespace Barotrauma } else if (result is T cRes && cRes != null) lastResult = cRes; } - else if (result is T res && res != null) lastResult = res; + else + { + if (result is LuaResult lRes) + { + if (!lRes.IsNull()) lastResult = (T)(object)lRes.DynValue(); + } + else lastResult = (T)result; + } } } catch (Exception e) @@ -443,8 +450,7 @@ namespace Barotrauma StringBuilder argsSb = new StringBuilder(); foreach (var arg in args) argsSb.Append(arg + " "); GameMain.LuaCs.HandleException( - e, $"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", - LuaCsSetup.ExceptionType.Both); + e, $"Error in Hook '{name}'->'{key}', with args '{argsSb}':\n{e}", ExceptionType.Both); } } } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsModStore.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsModStore.cs index bce1dd3b1..d539cb5e1 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsModStore.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsModStore.cs @@ -97,12 +97,16 @@ namespace Barotrauma public CsModStore GetCsStore(string modName) { var result = csModInterface.Where(i => i.Mod.GetType().FullName == modName).FirstOrDefault(); - if (!result.Mod.IsDisposed) return result; - else + if (result != null) { - csModInterface.Remove(result); - return null; + if (!result.Mod.IsDisposed) return result; + else + { + csModInterface.Remove(result); + return null; + } } + else return null; } protected LuaModStore GetLuaStore(string modName) => luaModInterface.Where(i => i.Name == modName).FirstOrDefault(); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index 209fec874..fd2c0c57d 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -16,11 +16,28 @@ using System.Reflection; [assembly: InternalsVisibleTo("NetOneTimeScriptAssembly", AllInternalsVisible = true)] namespace Barotrauma { + class LuaCsSetupConfig + { + public bool FirstTimeCsWaring = true; + + public LuaCsSetupConfig() { } + } + partial class LuaCsSetup { public const string LUASETUP_FILE = "Lua/LuaSetup.lua"; public const string VERSION_FILE = "luacsversion.txt"; + private const string configFileName = "LuaCsSetupConfig.xml"; + +#if SERVER + public const bool IsServer = true; + public const bool IsClient = false; +#else + public const bool IsServer = false; + public const bool IsClient = true; +#endif + private Script lua; public CsScriptRunner CsScript { get; private set; } public LuaGame Game { get; private set; } @@ -31,9 +48,11 @@ namespace Barotrauma public LuaCsModStore ModStore { get; private set; } private LuaRequire require { get; set; } - public CsScriptLoader NetScriptLoader { get; private set; } + public CsScriptLoader CsScriptLoader { get; private set; } public CsLua Lua { get; private set; } + public LuaCsSetupConfig Config { get; private set; } + public LuaCsSetup() { Hook = new LuaCsHook(); @@ -43,6 +62,15 @@ namespace Barotrauma Networking = new LuaCsNetworking(); } + public void UpdateConfig() + { + FileStream file; + if (!File.Exists(configFileName)) file = File.Create(configFileName); + else file = File.Open(configFileName, FileMode.Truncate, FileAccess.Write); + LuaCsConfig.Save(file, Config); + file.Close(); + } + public static ContentPackage GetPackage(Identifier name, bool fallbackToAll = true) { @@ -286,6 +314,10 @@ namespace Barotrauma public void Stop() { + foreach (var type in AppDomain.CurrentDomain.GetAssemblies().Where(a => a.GetName().Name == "NetScriptAssembly").SelectMany(assembly => assembly.GetTypes())) + { + UserData.UnregisterType(type); + } foreach (var mod in ACsMod.LoadedMods.ToArray()) mod.Dispose(); ACsMod.LoadedMods.Clear(); Hook?.Call("stop"); @@ -297,8 +329,19 @@ namespace Barotrauma Game = new LuaGame(); Networking = new LuaCsNetworking(); LuaScriptLoader = null; + lua = null; Lua = null; CsScript = null; + Config = null; + + if (CsScriptLoader != null) + { + CsScriptLoader.Clear(); + CsScriptLoader.Unload(); + CsScriptLoader = null; + GC.Collect(); + GC.WaitForPendingFinalizers(); + } } public void Initialize() @@ -307,6 +350,15 @@ namespace Barotrauma PrintMessage("Lua! Version " + AssemblyInfo.GitRevision); + + if (File.Exists(configFileName)) + { + using (var file = File.Open(configFileName, FileMode.Open, FileAccess.Read)) + Config = LuaCsConfig.Load(file); + } + else Config = new LuaCsSetupConfig(); + + LuaScriptLoader = new LuaScriptLoader(); LuaScriptLoader.ModulePaths = new string[] { }; @@ -325,6 +377,11 @@ namespace Barotrauma Hook.Initialize(); ModStore.Initialize(); + UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); + UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); @@ -353,31 +410,43 @@ namespace Barotrauma lua.Globals["File"] = UserData.CreateStatic(); lua.Globals["Networking"] = Networking; - bool isServer; + lua.Globals["SERVER"] = IsServer; + lua.Globals["CLIENT"] = IsClient; -#if SERVER - isServer = true; -#else - isServer = false; -#endif - - lua.Globals["SERVER"] = isServer; - lua.Globals["CLIENT"] = !isServer; - - // LuaDocs.GenerateDocsAll(); - - ContentPackage csPackage = GetPackage("CsForBarotrauma", false); - - - if (csPackage != null) + CsScriptLoader = new CsScriptLoader(this); + CsScriptLoader.SearchFolders(); + if (CsScriptLoader.HasSources) { - NetScriptLoader = new CsScriptLoader(this); + if (Config.FirstTimeCsWaring) + { + Config.FirstTimeCsWaring = false; + UpdateConfig(); + + LuaCsTimer.Wait((args) => PrintCsError(@" + ----==== ====---- + + WARNING! + -- -- -- -- -- -- + !Use of Cs Mods detected! + + Cs Mods are questionably +sandboxed, as they have +access to reflection, due to +modding needs. + + USE ON YOUR OWN RISK! + + ----==== ====---- +"), 200); + } - NetScriptLoader.SearchFolders(); try { - var modTypes = NetScriptLoader.Compile(); - modTypes.ForEach(t => t.GetConstructor(new Type[] { })?.Invoke(null)); + var modTypes = CsScriptLoader.Compile(); + modTypes.ForEach(t => { + UserData.RegisterType(t); + t.GetConstructor(new Type[] { })?.Invoke(null); + }); } catch (Exception ex) { diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs index 26d887eb6..5379c204b 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs @@ -1,23 +1,19 @@ -using System; -using System.Collections.Generic; -using System.Text; -using MoonSharp.Interpreter; -using Microsoft.Xna.Framework; -using Barotrauma.Networking; using Barotrauma.Items.Components; -using System.IO; -using System.Net; -using System.Linq; -using System.Xml.Linq; -using FarseerPhysics.Dynamics; -using System.Reflection; -using HarmonyLib; -using MoonSharp.Interpreter.Interop; +using Barotrauma.Networking; +using MoonSharp.Interpreter; +using System; +using System.Collections; +using System.Collections.Generic; using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Reflection; +using System.Xml.Linq; namespace Barotrauma { - public partial class LuaCsTimer + public partial class LuaCsTimer { public static long LastUpdateTime = 0; @@ -29,7 +25,7 @@ namespace Barotrauma } } - public void Wait(LuaCsAction action, int millisecondDelay) + public static void Wait(LuaCsAction action, int millisecondDelay) { GameMain.LuaCs.Hook.EnqueueTimed((float)Timing.TotalTime + (millisecondDelay / 1000f), action); } @@ -155,6 +151,22 @@ namespace Barotrauma File.WriteAllText(path, text); } + public static FileStream OpenRead(string path) + { + if (!IsPathAllowedException(path)) + return null; + + return File.Open(path, FileMode.Open, FileAccess.Read); + } + public static FileStream OpenWrite(string path) + { + if (!IsPathAllowedException(path)) + return null; + + if (File.Exists(path)) return File.Open(path, FileMode.Truncate, FileAccess.Write); + else return File.Open(path, FileMode.Create, FileAccess.Write); + } + public static bool Exists(string path) { if (!IsPathAllowedException(path, false)) @@ -393,4 +405,285 @@ namespace Barotrauma } } + + class LuaCsConfig + { + private enum ValueType + { + None, + Text, + Integer, + Decimal, + Boolean, + Collection, + Object, + Enum + } + + private static Type[] LoadDocTypes(XElement typesElem) + { + var result = new List(); + foreach (var elem in typesElem.Elements()) + { + var type = Type.GetType(elem.Value); + if (type == null) throw new Exception($"Type {elem.Value} not found."); + result.Add(type); + + } + return result.ToArray(); + } + + private static IEnumerable SaveDocTypes(IEnumerable types) + { + return types.Select(t => new XElement("Type", t.ToString())); + } + + private static Type GetTypeAttr(Type[] types, XElement elem) + { + var idx = elem.GetAttributeInt("Type", -1); + if (idx < 0 || idx >= types.Length) throw new Exception($"Type index '{idx}' is outside of saved types bounds"); + return types[idx]; + } + private static ValueType GetValueType(XElement elem) + { + Enum.TryParse(typeof(ValueType), elem.Attribute("Value")?.Value, out object result); + if (result != null) return (ValueType)result; + else return ValueType.None; + } + private static object ParseValue(Type[] types, XElement elem) + { + var type = GetValueType(elem); + + if (elem.IsEmpty) return null; + if (type == ValueType.Enum) + { + var tType = GetTypeAttr(types, elem); + if (tType == null || !tType.IsSubclassOf(typeof(Enum))) return null; + if (Enum.TryParse(tType, elem.Value, out object result)) return result; + else return null; + } + if (type == ValueType.Collection) + { + var tType = GetTypeAttr(types, elem); + var tInt = tType.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + var gArg = tInt.GetGenericArguments()[0]; + if (tType == null || !tType.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))) return null; + + object result = null; + + if (result == null) { + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(c => + { + var param = c.GetParameters(); + return param.Count() == 1 && param.Any(p => p.ParameterType.IsGenericType && p.ParameterType.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + }); + if (ctor != null) + { + var elements = elem.Elements().Select(x => ParseValue(types, x)); + var castElems = typeof(Enumerable).GetMethod("Cast").MakeGenericMethod(gArg).Invoke(elements, new object[] { elements }); + result = ctor.Invoke(new object[] { castElems }); + } + } + + if (result == null) + { + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(c => c.GetParameters().Count() == 0); + var addMethod = tType.GetMethods(BindingFlags.Instance | BindingFlags.Public).FirstOrDefault(m => + { + if (m.Name != "Add") return false; + var param = m.GetParameters(); + return param.Count() == 1 && param[0].ParameterType == gArg; + }); + if (ctor != null && addMethod != null) + { + var elements = elem.Elements().Select(x => ParseValue(types, x)); + result = ctor.Invoke(null); + foreach (var el in elements) addMethod.Invoke(result, new object[] { el }); + } + } + + if (result == null) + { + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(); + var setMethod = tType.GetMethods(BindingFlags.Instance | BindingFlags.Public).FirstOrDefault(m => + { + if (m.Name != "Set") return false; + var param = m.GetParameters(); + return param.Count() == 2 && param[0].ParameterType == typeof(int) && param[1].ParameterType == gArg; + }); + if (ctor != null || setMethod != null) + { + var elements = elem.Elements().Select(x => ParseValue(types, x)); + result = ctor.Invoke(new object[] { elements.Count() }); + int i = 0; + foreach (var el in elements) + { + setMethod.Invoke(result, new object[] { i, el }); + i++; + } + } + } + + return result; + } + else if (type == ValueType.Text) return elem.Value; + else if (type == ValueType.Integer) + { + int.TryParse(elem.Value, out var num); + return num; + } + else if (type == ValueType.Decimal) + { + float.TryParse(elem.Value, out var num); + return num; + } + else if (type == ValueType.Boolean) + { + bool.TryParse(elem.Value, out var boolean); + return boolean; + } + else if (type == ValueType.Object) + { + var tType = GetTypeAttr(types, elem); + if (tType == null) return null; + + IEnumerable fields = tType.GetFields(BindingFlags.Instance | BindingFlags.Public) + .Concat(tType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic)); + IEnumerable properties = tType.GetProperties(BindingFlags.Instance | BindingFlags.Public).Where(p => p.GetSetMethod() != null) + .Concat(tType.GetProperties(BindingFlags.Instance | BindingFlags.NonPublic).Where(p => p.GetSetMethod() != null)); + + object result = null; + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(c => c.GetParameters().Count() == 0); + if (ctor == null) + { + if (!tType.IsValueType) return null; + result = Activator.CreateInstance(tType); + } + else result = ctor.Invoke(null); + + foreach(var el in elem.Elements()) + { + var value = ParseValue(types, el); + + var field = fields.FirstOrDefault(f => f.Name == el.Name.LocalName); + if (field != null) field.SetValue(result, value); + var property = properties.FirstOrDefault(p => p.Name == el.Name.LocalName); + if (property != null) property.SetValue(result, value); + } + return result; + } + else return elem.Value; + + } + + private static void AddTypeAttr(List types, Type type, XElement elem) + { + if (!types.Contains(type)) types.Add(type); + elem.SetAttributeValue("Type", types.IndexOf(type)); + } + + private static XElement ParseObject(List types, string name, object value) + { + XElement result = new XElement(name); + + if (value != null) + { + var tType = value.GetType(); + + if (tType.IsEnum) + { + result.SetAttributeValue("Value", ValueType.Enum); + AddTypeAttr(types, tType, result); + + result.Value = Enum.GetName(tType, value) ?? ""; + } + else if (value is string str) + { + result.SetAttributeValue("Value", ValueType.Text); + result.Value = str; + } + else if (value is int integer) + { + result.SetAttributeValue("Value", ValueType.Integer); + result.Value = integer.ToString(); + } + else if (value is float || value is double) + { + result.SetAttributeValue("Value", ValueType.Decimal); + result.Value = value.ToString(); + } + else if (value is bool boolean) + { + result.SetAttributeValue("Value", ValueType.Boolean); + result.Value = boolean.ToString(); + } + else if (tType.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))) + { + result.SetAttributeValue("Value", ValueType.Collection); + AddTypeAttr(types, tType, result); + + var enumerator = (IEnumerator)tType.GetMethod("GetEnumerator").Invoke(value, null); + while (enumerator.MoveNext()) + { + var elVal = ParseObject(types, "Item", enumerator.Current); + result.Add(elVal); + } + } + else if (tType.IsClass || tType.IsValueType) + { + result.SetAttributeValue("Value", ValueType.Object); + AddTypeAttr(types, tType, result); + + IEnumerable fields = tType.GetFields(BindingFlags.Instance | BindingFlags.Public) + .Concat(tType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic)); + IEnumerable properties = tType.GetProperties(BindingFlags.Instance | BindingFlags.Public).Where(p => p.GetSetMethod() != null) + .Concat(tType.GetProperties(BindingFlags.Instance | BindingFlags.NonPublic).Where(p => p.GetSetMethod() != null)); + + foreach(var field in fields) result.Add(ParseObject(types, field.Name, field.GetValue(value))); + foreach (var property in properties) result.Add(ParseObject(types, property.Name, property.GetValue(value))); + } + else + { + result.SetAttributeValue("Value", ValueType.None); + result.Value = value.ToString(); + } + } + + return result; + } + + + public static T Load(FileStream file) + { + var doc = XDocument.Load(file); + + var rootElems = doc.Root.Elements().ToArray(); + var types = rootElems[0]; + var elem = rootElems[1]; + + var dict = ParseValue(LoadDocTypes(types), elem); + if (dict.GetType() == typeof(T)) return (T)dict; + else throw new Exception($"Loaded configuration is not of the type '{typeof(T).Name}'"); + } + + public static void Save(FileStream file, object obj) + { + var types = new List(); + var elem = ParseObject(types, "Root", obj); + var root = new XElement("Configuration", new XElement("Types", SaveDocTypes(types)), elem); + + var doc = new XDocument(root); + doc.Save(file); + } + + public static T Load(string path) + { + using (var file = LuaCsFile.OpenRead(path)) return Load(file); + } + + public static void Save(string path, object obj) + { + using (var file = LuaCsFile.OpenWrite(path)) Save(file, obj); + } + } } \ No newline at end of file