From da251099a3e5aea4b0077899c89c6b5c00a9ad42 Mon Sep 17 00:00:00 2001 From: Oiltanker Date: Tue, 19 Apr 2022 22:08:06 +0300 Subject: [PATCH] working serialization + cs mods one-time warning --- .../SharedSource/LuaCs/Cs/CsScriptFilter.cs | 4 +- .../SharedSource/LuaCs/Cs/CsScriptLoader.cs | 65 ++-- .../SharedSource/LuaCs/Cs/CsScriptRunner.cs | 6 +- .../LuaCs/Lua/LuaCustomConverters.cs | 1 + .../SharedSource/LuaCs/LuaCsSetup.cs | 81 ++++- .../SharedSource/LuaCs/LuaCsUtility.cs | 325 +++++++++++++++++- 6 files changed, 418 insertions(+), 64 deletions(-) diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs index faa4851b0..7f65f78c4 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs @@ -26,8 +26,10 @@ namespace Barotrauma { }; private static readonly string[] typesProhibited = { "System.IO", - "Moonsharp", "Barotrauma.IO", + "System.Xml.XmlReader", + "System.Xml.XmlWriter", + "Barotrauma.LuaUserData", }; public static bool IsTypeAllowed(string name) { diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs index 389f19aa1..98e372bc5 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs @@ -16,7 +16,8 @@ namespace Barotrauma { public LuaCsSetup setup; private List defaultReferences; - private List syntaxTrees; + + private Dictionary> sources; public Assembly Assembly { get; private set; } public CsScriptLoader(LuaCsSetup setup) : base(isCollectible: true) @@ -28,7 +29,7 @@ namespace Barotrauma .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) .ToList(); - syntaxTrees = new List(); + sources = new Dictionary>(); Assembly = null; } @@ -41,47 +42,52 @@ namespace Barotrauma } } - private void RunFolder(string folder) + public bool HasSources { get => sources.Count > 0; } + + private void RunFolder(string folder) { - var scriptFiles = new List(); foreach (var str in DirSearch(folder)) { var s = str.Replace("\\", "/"); - if (s.EndsWith(".cs") && LuaCsFile.IsPathAllowedCsException(s)) scriptFiles.Add(s); - } - - try - { - if (scriptFiles.Count <= 0) return; - - // Check file content for prohibited stuff - foreach (var file in scriptFiles) + if (s.EndsWith(".cs") && LuaCsFile.IsPathAllowedCsException(s)) { - var tree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(file), CSharpParseOptions.Default, file); - var error = CsScriptFilter.FilterSyntaxTree(tree as CSharpSyntaxTree); - if (error != null) throw new Exception(error); - - syntaxTrees.Add(tree); + if (sources.ContainsKey(folder)) sources[folder].Add(s); + else sources.Add(folder, new List { s }); } - } - catch (CompilationErrorException ex) - { - string errStr = "Compilation Error in '" + folder + "':"; - foreach (var diag in ex.Diagnostics) - { - errStr += "\n" + diag.ToString(); - } - LuaCsSetup.PrintCsError(errStr); } - catch (Exception ex) + } + + private IEnumerable ParseSources() { + var syntaxTrees = new List(); + + if (sources.Count <= 0) throw new Exception("No Cs sources detected"); + foreach ((var folder, var src) in sources) { - LuaCsSetup.PrintCsError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace); + try + { + foreach (var file in src) + { + var tree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(file), CSharpParseOptions.Default, file); + var error = CsScriptFilter.FilterSyntaxTree(tree as CSharpSyntaxTree); // Check file content for prohibited stuff + if (error != null) throw new Exception(error); + + syntaxTrees.Add(tree); + } + } + catch (Exception ex) + { + LuaCsSetup.PrintCsError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace); + } } + + return syntaxTrees; } public List Compile() { + IEnumerable syntaxTrees = ParseSources(); + var options = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) .WithMetadataImportOptions(MetadataImportOptions.All) .WithOptimizationLevel(OptimizationLevel.Release) @@ -112,7 +118,6 @@ namespace Barotrauma else LuaCsSetup.PrintCsError(errStr); } } - syntaxTrees.Clear(); if (Assembly != null) return Assembly.GetTypes().Where(t => t.IsSubclassOf(typeof(ACsMod))).ToList(); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs index 60bdf559f..da773600e 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs @@ -101,7 +101,11 @@ namespace Barotrauma if (method != null) { scriptResilt = method.Invoke(runner, null); - foreach (var type in assembly.GetTypes()) { UserData.UnregisterType(type, true); } + foreach (var type in assembly.GetTypes()) + { + //UserData.UnregisterType(type, true); + UserData.UnregisterType(type); + } } else LuaCsSetup.PrintCsError("Script Error - no run method detected"); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs index cfb0a9984..5bbf70773 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaCustomConverters.cs @@ -26,6 +26,7 @@ namespace Barotrauma Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(DataType.Function, typeof(LuaCsAction), v => (LuaCsAction)( args => GameMain.LuaCs.CallLuaFunction(v.Function, args) )); Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(DataType.Function, typeof(LuaCsFunc), v => (LuaCsFunc)( args => new LuaResult(GameMain.LuaCs.CallLuaFunction(v.Function, args)) )); Script.GlobalOptions.CustomConverters.SetScriptToClrCustomConversion(DataType.Function, typeof(LuaCsPatch), v => (LuaCsPatch)( (self, args) => new LuaResult(GameMain.LuaCs.CallLuaFunction(v.Function, self, args)) )); + Script.GlobalOptions.CustomConverters.SetClrToScriptCustomConversion(typeof(LuaResult), (Script s, object v) => (v as LuaResult).DynValue()); #if CLIENT RegisterAction(); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index ec0890c7f..0e3baa80b 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -16,11 +16,28 @@ using System.Reflection; [assembly: InternalsVisibleTo("NetOneTimeScriptAssembly", AllInternalsVisible = true)] namespace Barotrauma { + class LuaCsSetupConfig + { + public bool FirstTimeCsWaring = true; + + public LuaCsSetupConfig() { } + } + partial class LuaCsSetup { public const string LUASETUP_FILE = "Lua/LuaSetup.lua"; public const string VERSION_FILE = "luacsversion.txt"; + private const string configFileName = "LuaCsSetupConfig.xml"; + +#if SERVER + public const bool IsServer = true; + public const bool IsClient = false; +#else + public const bool IsServer = false; + public const bool IsClient = true; +#endif + private Script lua; public CsScriptRunner CsScript { get; private set; } public LuaGame Game { get; private set; } @@ -33,6 +50,8 @@ namespace Barotrauma public CsScriptLoader CsScriptLoader { get; private set; } public CsLua Lua { get; private set; } + public LuaCsSetupConfig Config { get; private set; } + public LuaCsSetup() { Hook = new LuaCsHook(); @@ -42,6 +61,15 @@ namespace Barotrauma Networking = new LuaCsNetworking(); } + public void UpdateConfig() + { + FileStream file; + if (!File.Exists(configFileName)) file = File.Create(configFileName); + else file = File.Open(configFileName, FileMode.Truncate, FileAccess.Write); + LuaCsConfig.Save(file, Config); + file.Close(); + } + public static ContentPackage GetPackage(Identifier name) { @@ -280,7 +308,7 @@ namespace Barotrauma { foreach (var type in AppDomain.CurrentDomain.GetAssemblies().Where(a => a.GetName().Name == "NetScriptAssembly").SelectMany(assembly => assembly.GetTypes())) { - UserData.UnregisterType(type, true); + UserData.UnregisterType(type); } foreach (var mod in ACsMod.LoadedMods.ToArray()) mod.Dispose(); ACsMod.LoadedMods.Clear(); @@ -296,6 +324,7 @@ namespace Barotrauma lua = null; Lua = null; CsScript = null; + Config = null; if (CsScriptLoader != null) { @@ -313,6 +342,15 @@ namespace Barotrauma PrintMessage("Lua! Version " + AssemblyInfo.GitRevision); + + if (File.Exists(configFileName)) + { + using (var file = File.Open(configFileName, FileMode.Open, FileAccess.Read)) + Config = LuaCsConfig.Load(file); + } + else Config = new LuaCsSetupConfig(); + + LuaScriptLoader = new LuaScriptLoader(); LuaScriptLoader.ModulePaths = new string[] { }; @@ -329,9 +367,11 @@ namespace Barotrauma Hook.Initialize(); ModStore.Initialize(); + UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); + UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); @@ -360,27 +400,36 @@ namespace Barotrauma lua.Globals["File"] = UserData.CreateStatic(); lua.Globals["Networking"] = Networking; - bool isServer; + lua.Globals["SERVER"] = IsServer; + lua.Globals["CLIENT"] = IsClient; -#if SERVER - isServer = true; -#else - isServer = false; -#endif + CsScriptLoader = new CsScriptLoader(this); + CsScriptLoader.SearchFolders(); + if (CsScriptLoader.HasSources) + { + if (Config.FirstTimeCsWaring) + { + Config.FirstTimeCsWaring = false; + UpdateConfig(); - lua.Globals["SERVER"] = isServer; - lua.Globals["CLIENT"] = !isServer; + LuaCsTimer.Wait((args) => PrintCsError(@" + ----==== ====---- - // LuaDocs.GenerateDocsAll(); + WARNING! + -- -- -- -- -- -- + !Use of Cs Mods detected! - //ContentPackage csPackage = GetPackage("CsForBarotrauma"); + Cs Mods are questionably +sandboxed, as they have +access to reflection, due to +modding needs. + USE ON YOUR OWN RISK! - //if (csPackage != null) - //{ - CsScriptLoader = new CsScriptLoader(this); + ----==== ====---- +"), 200); + } - CsScriptLoader.SearchFolders(); try { var modTypes = CsScriptLoader.Compile(); @@ -395,7 +444,7 @@ namespace Barotrauma } PrintMessage("Cs! Version " + AssemblyInfo.GitRevision); - //} + } ContentPackage luaPackage = GetPackage("LuaForBarotraumaUnstable"); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs index 26d887eb6..5379c204b 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsUtility.cs @@ -1,23 +1,19 @@ -using System; -using System.Collections.Generic; -using System.Text; -using MoonSharp.Interpreter; -using Microsoft.Xna.Framework; -using Barotrauma.Networking; using Barotrauma.Items.Components; -using System.IO; -using System.Net; -using System.Linq; -using System.Xml.Linq; -using FarseerPhysics.Dynamics; -using System.Reflection; -using HarmonyLib; -using MoonSharp.Interpreter.Interop; +using Barotrauma.Networking; +using MoonSharp.Interpreter; +using System; +using System.Collections; +using System.Collections.Generic; using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Reflection; +using System.Xml.Linq; namespace Barotrauma { - public partial class LuaCsTimer + public partial class LuaCsTimer { public static long LastUpdateTime = 0; @@ -29,7 +25,7 @@ namespace Barotrauma } } - public void Wait(LuaCsAction action, int millisecondDelay) + public static void Wait(LuaCsAction action, int millisecondDelay) { GameMain.LuaCs.Hook.EnqueueTimed((float)Timing.TotalTime + (millisecondDelay / 1000f), action); } @@ -155,6 +151,22 @@ namespace Barotrauma File.WriteAllText(path, text); } + public static FileStream OpenRead(string path) + { + if (!IsPathAllowedException(path)) + return null; + + return File.Open(path, FileMode.Open, FileAccess.Read); + } + public static FileStream OpenWrite(string path) + { + if (!IsPathAllowedException(path)) + return null; + + if (File.Exists(path)) return File.Open(path, FileMode.Truncate, FileAccess.Write); + else return File.Open(path, FileMode.Create, FileAccess.Write); + } + public static bool Exists(string path) { if (!IsPathAllowedException(path, false)) @@ -393,4 +405,285 @@ namespace Barotrauma } } + + class LuaCsConfig + { + private enum ValueType + { + None, + Text, + Integer, + Decimal, + Boolean, + Collection, + Object, + Enum + } + + private static Type[] LoadDocTypes(XElement typesElem) + { + var result = new List(); + foreach (var elem in typesElem.Elements()) + { + var type = Type.GetType(elem.Value); + if (type == null) throw new Exception($"Type {elem.Value} not found."); + result.Add(type); + + } + return result.ToArray(); + } + + private static IEnumerable SaveDocTypes(IEnumerable types) + { + return types.Select(t => new XElement("Type", t.ToString())); + } + + private static Type GetTypeAttr(Type[] types, XElement elem) + { + var idx = elem.GetAttributeInt("Type", -1); + if (idx < 0 || idx >= types.Length) throw new Exception($"Type index '{idx}' is outside of saved types bounds"); + return types[idx]; + } + private static ValueType GetValueType(XElement elem) + { + Enum.TryParse(typeof(ValueType), elem.Attribute("Value")?.Value, out object result); + if (result != null) return (ValueType)result; + else return ValueType.None; + } + private static object ParseValue(Type[] types, XElement elem) + { + var type = GetValueType(elem); + + if (elem.IsEmpty) return null; + if (type == ValueType.Enum) + { + var tType = GetTypeAttr(types, elem); + if (tType == null || !tType.IsSubclassOf(typeof(Enum))) return null; + if (Enum.TryParse(tType, elem.Value, out object result)) return result; + else return null; + } + if (type == ValueType.Collection) + { + var tType = GetTypeAttr(types, elem); + var tInt = tType.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + var gArg = tInt.GetGenericArguments()[0]; + if (tType == null || !tType.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))) return null; + + object result = null; + + if (result == null) { + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(c => + { + var param = c.GetParameters(); + return param.Count() == 1 && param.Any(p => p.ParameterType.IsGenericType && p.ParameterType.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + }); + if (ctor != null) + { + var elements = elem.Elements().Select(x => ParseValue(types, x)); + var castElems = typeof(Enumerable).GetMethod("Cast").MakeGenericMethod(gArg).Invoke(elements, new object[] { elements }); + result = ctor.Invoke(new object[] { castElems }); + } + } + + if (result == null) + { + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(c => c.GetParameters().Count() == 0); + var addMethod = tType.GetMethods(BindingFlags.Instance | BindingFlags.Public).FirstOrDefault(m => + { + if (m.Name != "Add") return false; + var param = m.GetParameters(); + return param.Count() == 1 && param[0].ParameterType == gArg; + }); + if (ctor != null && addMethod != null) + { + var elements = elem.Elements().Select(x => ParseValue(types, x)); + result = ctor.Invoke(null); + foreach (var el in elements) addMethod.Invoke(result, new object[] { el }); + } + } + + if (result == null) + { + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(); + var setMethod = tType.GetMethods(BindingFlags.Instance | BindingFlags.Public).FirstOrDefault(m => + { + if (m.Name != "Set") return false; + var param = m.GetParameters(); + return param.Count() == 2 && param[0].ParameterType == typeof(int) && param[1].ParameterType == gArg; + }); + if (ctor != null || setMethod != null) + { + var elements = elem.Elements().Select(x => ParseValue(types, x)); + result = ctor.Invoke(new object[] { elements.Count() }); + int i = 0; + foreach (var el in elements) + { + setMethod.Invoke(result, new object[] { i, el }); + i++; + } + } + } + + return result; + } + else if (type == ValueType.Text) return elem.Value; + else if (type == ValueType.Integer) + { + int.TryParse(elem.Value, out var num); + return num; + } + else if (type == ValueType.Decimal) + { + float.TryParse(elem.Value, out var num); + return num; + } + else if (type == ValueType.Boolean) + { + bool.TryParse(elem.Value, out var boolean); + return boolean; + } + else if (type == ValueType.Object) + { + var tType = GetTypeAttr(types, elem); + if (tType == null) return null; + + IEnumerable fields = tType.GetFields(BindingFlags.Instance | BindingFlags.Public) + .Concat(tType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic)); + IEnumerable properties = tType.GetProperties(BindingFlags.Instance | BindingFlags.Public).Where(p => p.GetSetMethod() != null) + .Concat(tType.GetProperties(BindingFlags.Instance | BindingFlags.NonPublic).Where(p => p.GetSetMethod() != null)); + + object result = null; + var ctor = tType.GetConstructors(BindingFlags.Public | BindingFlags.Instance).FirstOrDefault(c => c.GetParameters().Count() == 0); + if (ctor == null) + { + if (!tType.IsValueType) return null; + result = Activator.CreateInstance(tType); + } + else result = ctor.Invoke(null); + + foreach(var el in elem.Elements()) + { + var value = ParseValue(types, el); + + var field = fields.FirstOrDefault(f => f.Name == el.Name.LocalName); + if (field != null) field.SetValue(result, value); + var property = properties.FirstOrDefault(p => p.Name == el.Name.LocalName); + if (property != null) property.SetValue(result, value); + } + return result; + } + else return elem.Value; + + } + + private static void AddTypeAttr(List types, Type type, XElement elem) + { + if (!types.Contains(type)) types.Add(type); + elem.SetAttributeValue("Type", types.IndexOf(type)); + } + + private static XElement ParseObject(List types, string name, object value) + { + XElement result = new XElement(name); + + if (value != null) + { + var tType = value.GetType(); + + if (tType.IsEnum) + { + result.SetAttributeValue("Value", ValueType.Enum); + AddTypeAttr(types, tType, result); + + result.Value = Enum.GetName(tType, value) ?? ""; + } + else if (value is string str) + { + result.SetAttributeValue("Value", ValueType.Text); + result.Value = str; + } + else if (value is int integer) + { + result.SetAttributeValue("Value", ValueType.Integer); + result.Value = integer.ToString(); + } + else if (value is float || value is double) + { + result.SetAttributeValue("Value", ValueType.Decimal); + result.Value = value.ToString(); + } + else if (value is bool boolean) + { + result.SetAttributeValue("Value", ValueType.Boolean); + result.Value = boolean.ToString(); + } + else if (tType.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))) + { + result.SetAttributeValue("Value", ValueType.Collection); + AddTypeAttr(types, tType, result); + + var enumerator = (IEnumerator)tType.GetMethod("GetEnumerator").Invoke(value, null); + while (enumerator.MoveNext()) + { + var elVal = ParseObject(types, "Item", enumerator.Current); + result.Add(elVal); + } + } + else if (tType.IsClass || tType.IsValueType) + { + result.SetAttributeValue("Value", ValueType.Object); + AddTypeAttr(types, tType, result); + + IEnumerable fields = tType.GetFields(BindingFlags.Instance | BindingFlags.Public) + .Concat(tType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic)); + IEnumerable properties = tType.GetProperties(BindingFlags.Instance | BindingFlags.Public).Where(p => p.GetSetMethod() != null) + .Concat(tType.GetProperties(BindingFlags.Instance | BindingFlags.NonPublic).Where(p => p.GetSetMethod() != null)); + + foreach(var field in fields) result.Add(ParseObject(types, field.Name, field.GetValue(value))); + foreach (var property in properties) result.Add(ParseObject(types, property.Name, property.GetValue(value))); + } + else + { + result.SetAttributeValue("Value", ValueType.None); + result.Value = value.ToString(); + } + } + + return result; + } + + + public static T Load(FileStream file) + { + var doc = XDocument.Load(file); + + var rootElems = doc.Root.Elements().ToArray(); + var types = rootElems[0]; + var elem = rootElems[1]; + + var dict = ParseValue(LoadDocTypes(types), elem); + if (dict.GetType() == typeof(T)) return (T)dict; + else throw new Exception($"Loaded configuration is not of the type '{typeof(T).Name}'"); + } + + public static void Save(FileStream file, object obj) + { + var types = new List(); + var elem = ParseObject(types, "Root", obj); + var root = new XElement("Configuration", new XElement("Types", SaveDocTypes(types)), elem); + + var doc = new XDocument(root); + doc.Save(file); + } + + public static T Load(string path) + { + using (var file = LuaCsFile.OpenRead(path)) return Load(file); + } + + public static void Save(string path, object obj) + { + using (var file = LuaCsFile.OpenWrite(path)) Save(file, obj); + } + } } \ No newline at end of file