diff --git a/Barotrauma/BarotraumaClient/ClientSource/DebugConsole.cs b/Barotrauma/BarotraumaClient/ClientSource/DebugConsole.cs index d82a9869c..e896f8668 100644 --- a/Barotrauma/BarotraumaClient/ClientSource/DebugConsole.cs +++ b/Barotrauma/BarotraumaClient/ClientSource/DebugConsole.cs @@ -3223,7 +3223,17 @@ namespace Barotrauma return; } - GameMain.LuaCs.DoString(string.Join(" ", args)); + GameMain.LuaCs.Lua.DoString(string.Join(" ", args)); + })); + commands.Add(new Command("cl_cs", "cs_cl: runs a string on the client", (string[] args) => + { + if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands)) + { + ThrowError("Command not permitted."); + return; + } + + GameMain.LuaCs.CsScript.Run(string.Join(" ", args)); })); commands.Add(new Command("cl_reloadlua", "reloads lua on the client", (string[] args) => @@ -3236,28 +3246,6 @@ namespace Barotrauma GameMain.LuaCs.Initialize(); })); - - commands.Add(new Command("cl_net", "lua_net: runs a script on the client", (string[] args) => - { - if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands)) - { - ThrowError("Command not permitted."); - return; - } - - GameMain.LuaCs.DoString(string.Join(" ", args)); - })); - - commands.Add(new Command("cl_reloadnet", "reloads lua on the client", (string[] args) => - { - if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands)) - { - ThrowError("Command not permitted."); - return; - } - - GameMain.LuaCs.Initialize(); - })); } private static void ReloadWearables(Character character, int variant = 0) diff --git a/Barotrauma/BarotraumaServer/ServerSource/DebugConsole.cs b/Barotrauma/BarotraumaServer/ServerSource/DebugConsole.cs index 4bcc8f624..a15fb8934 100644 --- a/Barotrauma/BarotraumaServer/ServerSource/DebugConsole.cs +++ b/Barotrauma/BarotraumaServer/ServerSource/DebugConsole.cs @@ -1243,7 +1243,11 @@ namespace Barotrauma commands.Add(new Command("lua", "lua: runs a string", (string[] args) => { - GameMain.LuaCs.DoString(string.Join(" ", args)); + GameMain.LuaCs.Lua.DoString(string.Join(" ", args)); + })); + commands.Add(new Command("cs", "cs: runs a string", (string[] args) => + { + GameMain.LuaCs.CsScript.Run(string.Join(" ", args)); })); commands.Add(new Command("reloadlua", "reloads lua", (string[] args) => diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsLua.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsLua.cs new file mode 100644 index 000000000..12d984339 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsLua.cs @@ -0,0 +1,22 @@ +using Barotrauma; +using MoonSharp.Interpreter; +using System.Collections.Generic; + +namespace Barotrauma +{ + partial class LuaCsSetup { + public class CsLua + { + private LuaCsSetup setup; + public Table Globals { get; private set; } + + public CsLua(LuaCsSetup setup) + { + this.setup = setup; + Globals = setup.lua.Globals; + } + + public DynValue DoString(string code) => setup.DoString(code); + } + } +} \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs index c4ad54880..895d4319d 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptFilter.cs @@ -9,30 +9,36 @@ using System.Reflection.Metadata; namespace Barotrauma { class CsScriptFilter { - private const bool useWhitelist = false; - - private static string[] typesPermited = new string[] { + private static readonly string[] typesPermitted = new string[] { // Basics - "System.Runtime.CompilerServices.CompilationRelaxationsAttribute", - "System.Runtime.CompilerServices.RuntimeCompatibilityAttribute", - "System.Diagnostics.DebuggableAttribute", - "System.Object", - "System.String", - "System.Collections", - // Some roslyn magic - ".DebuggingModes", + "System", // Barotrauma "Barotrauma", + // Lua + "MoonSharp.Interpreter.DynValue", + "MoonSharp.Interpreter.Closure", + "MoonSharp.Interpreter.Coroutine", + "MoonSharp.Interpreter.CoroutineState", + "MoonSharp.Interpreter.Table", + "MoonSharp.Interpreter.YieldRequest", + "MoonSharp.Interpreter.TailCallData", + "MoonSharp.Interpreter.DataType", }; - private static string[] typesProhibited = new string[] { + private static readonly string[] typesProhibited = new string[] { //"System.Reflection", + //"System.Type", "System.IO", + "Moonsharp", + "Barotrauma.IO", }; - public static bool IsTypeAllowed(string usingName) + public static bool IsTypeAllowed(string name) { - if (useWhitelist && !typesPermited.Any(u => u.StartsWith(usingName))) return false; - if (typesProhibited.Any(u => u.StartsWith(usingName))) return false; - return true; + var matchPermitted = typesPermitted.Where(s => name.StartsWith(s)); + var longestPemitted = matchPermitted.Count() > 0 ? matchPermitted.Max(s => s.Length) : 0; + var matchProhibited = typesProhibited.Where(s => name.StartsWith(s)); + var longestProhibited = matchProhibited.Count() > 0 ? matchProhibited.Max(s => s.Length) : 0; + if (longestPemitted == 0 || longestPemitted < longestProhibited) return false; + else return true; } public static string FilterSyntaxTree(CSharpSyntaxTree tree) @@ -62,7 +68,18 @@ namespace Barotrauma { reader.TypeReferences.ToList().ForEach(t => { var tRef = reader.GetTypeReference(t); - var typeName = $"{reader.GetString(tRef.Namespace)}.{reader.GetString(tRef.Name)}"; + + var typeName = $"{reader.GetString(tRef.Name)}"; + EntityHandle handle = tRef.ResolutionScope; + TypeReference tr = tRef; + while (!handle.IsNil && handle.Kind == HandleKind.TypeReference) + { + tr = reader.GetTypeReference((TypeReferenceHandle)handle); + handle = tr.ResolutionScope; + typeName = $"{reader.GetString(tr.Name)}.{typeName}"; + } + typeName = $"{reader.GetString(tr.Namespace)}.{typeName}"; + if (!IsTypeAllowed(typeName)) conflictingTypes.Add(typeName); }); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs index 7c5fef1f9..e65f62f67 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptLoader.cs @@ -72,11 +72,11 @@ namespace Barotrauma { errStr += "\n" + diag.ToString(); } - LuaCsSetup.PrintCsMessage(errStr); + LuaCsSetup.PrintCsError(errStr); } catch (Exception ex) { - LuaCsSetup.PrintCsMessage("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace); + LuaCsSetup.PrintCsError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace); } } @@ -98,7 +98,7 @@ namespace Barotrauma string errStr = "NET MODS NOT LOADED | Mod cmopilation errors:"; foreach (Diagnostic diagnostic in failures) errStr = $"\n{diagnostic}"; - LuaCsSetup.PrintCsMessage(errStr); + LuaCsSetup.PrintCsError(errStr); } else { @@ -109,7 +109,7 @@ namespace Barotrauma mem.Seek(0, SeekOrigin.Begin); Assembly = LoadFromStream(mem); } - else LuaCsSetup.PrintCsMessage(errStr); + else LuaCsSetup.PrintCsError(errStr); } } syntaxTrees.Clear(); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs new file mode 100644 index 000000000..60dcbe22e --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Cs/CsScriptRunner.cs @@ -0,0 +1,132 @@ +using System; +using System.Collections.Generic; +using System.IO; +using Microsoft.CodeAnalysis.Scripting; +using System.Reflection; +using Microsoft.CodeAnalysis.CSharp; +using System.Linq; +using Microsoft.CodeAnalysis; +using System.Runtime.Loader; +using System.Reflection.PortableExecutable; +using System.Reflection.Metadata; + +namespace Barotrauma +{ + class CsScriptRunner : AssemblyLoadContext + { + public LuaCsSetup setup; + private List defaultReferences; + private CSharpCompilationOptions compileOptions; + + public CsScriptRunner(LuaCsSetup setup) + { + this.setup = setup; + + defaultReferences = AppDomain.CurrentDomain.GetAssemblies() + .Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit"))) + .Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference) + .ToList(); + compileOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary) + .WithMetadataImportOptions(MetadataImportOptions.All) + .WithOptimizationLevel(OptimizationLevel.Release) + .WithAllowUnsafe(false); + } + + private static string ToOneTimeScript(string code) + { + var prefix = @" +public class NetOneTimeScriptRunner { + public NetOneTimeScriptRunner() { } + public object Run() { + +"; + var postfix = @" + return null; + } +} +"; + return prefix + code + postfix; + } + + public object Run(string code) + { + object scriptResilt = null; + + try + { + var syntaxTree = SyntaxFactory.ParseSyntaxTree(ToOneTimeScript(code), CSharpParseOptions.Default); + var compilation = CSharpCompilation.Create("NetOneTimeScriptAssembly", new[] { syntaxTree }, defaultReferences, compileOptions); + + Assembly assembly = null; + using (var mem = new MemoryStream()) + { + var result = compilation.Emit(mem); + if (!result.Success) + { + IEnumerable failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error); + + string errStr = "Script cmopilation errors:"; + foreach (Diagnostic diagnostic in failures) + errStr = $"\n{diagnostic}"; + LuaCsSetup.PrintCsError(errStr); + } + else + { + mem.Seek(0, SeekOrigin.Begin); + var metaReader = new PEReader(mem).GetMetadataReader(); + var errStr = CsScriptFilter.FilterMetadata(metaReader); + if (errStr == null) + { + foreach (var handle in metaReader.TypeDefinitions) + { + var typeDef = metaReader.GetTypeDefinition(handle); + var typeName = $"{metaReader.GetString(typeDef.Namespace)}.{metaReader.GetString(typeDef.Name)}"; + if (typeName != ".NetOneTimeScriptRunner") + { + errStr = "Script Error - malformed assembly"; + break; + } + } + } + + if (errStr == null) + { + mem.Seek(0, SeekOrigin.Begin); + assembly = LoadFromStream(mem); + var runner = assembly.CreateInstance("NetOneTimeScriptRunner"); + if (runner != null) + { + if (runner.GetType().GetMethods().Count() > 1) LuaCsSetup.PrintCsError("Script Error - malformed runner class"); + else + { + var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance); + if (method != null) scriptResilt = method.Invoke(runner, null); + else LuaCsSetup.PrintCsError("Script Error - no run method detected"); + } + } + else LuaCsSetup.PrintCsError("Script Error - no runner class detected"); + } + else LuaCsSetup.PrintCsError(errStr); + } + } + Unload(); + } + catch (CompilationErrorException ex) + { + string errStr = "Script Cmopilation Error:"; + foreach (var diag in ex.Diagnostics) + { + errStr += "\n" + diag.ToString(); + } + LuaCsSetup.PrintCsError(errStr); + } + catch (Exception ex) + { + LuaCsSetup.PrintCsError("Error running script:\n" + ex.Message + "\n" + ex.StackTrace); + } + + return scriptResilt; + } + + } +} \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs index 741eb4d29..60059b925 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsHook.cs @@ -6,6 +6,7 @@ using HarmonyLib; using System.Collections.Generic; using System.Text; using MoonSharp.Interpreter.Interop; +using static Barotrauma.LuaCsSetup; namespace Barotrauma { @@ -73,7 +74,20 @@ namespace Barotrauma { harmony = new Harmony("LuaCsForBarotrauma"); - } + var hookType = UserData.RegisterType(); + var hookDesc = (StandardUserDataDescriptor)hookType; + typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => { + if ( + m.Name.Contains("HookMethod") || + m.Name.Contains("UnhookMethod") || + m.Name.Contains("EnqueueFunction") || + m.Name.Contains("EnqueueTimedFunction") + ) + { + hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default)); + } + }); + } private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType) { @@ -202,9 +216,22 @@ namespace Barotrauma return methodInfo; } + private static readonly string[] prohibitedHooks = { + "Barotrauma.Lua", + "Barotrauma.Cs" + }; public void HookMethod(string identifier, MethodInfo method, LuaCsPatch patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null) { - if (identifier == null || method == null || patch == null) throw new ArgumentNullException("Identifier, Method and Patch arguments must not be null."); + if (identifier == null || method == null || patch == null) + { + GameMain.LuaCs.HandleException(new ArgumentNullException("Identifier, Method and Patch arguments must not be null."), exceptionType: ExceptionType.Both); + return; + } + if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h))) + { + GameMain.LuaCs.HandleException(new ArgumentException("Hooks into Modding Environment are prohibited."), exceptionType: ExceptionType.Both); + return; + } var funcAddr = ((long)method.MethodHandle.GetFunctionPointer()); var patches = Harmony.GetPatchInfo(method); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsModStore.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsModStore.cs new file mode 100644 index 000000000..bce1dd3b1 --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsModStore.cs @@ -0,0 +1,110 @@ +using MoonSharp.Interpreter; +using MoonSharp.Interpreter.Interop; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Barotrauma +{ + partial class LuaCsSetup + { + public class LuaCsModStore + { + public abstract class ModStore + { + protected Dictionary store; + + public TStore Set(string name, TStore value) => store[name] = value; + public TStore Get(string name) => store[name]; + + public ModStore(Dictionary store) => this.store = store; + + public abstract bool Equals(T value); + } + public class LuaModStore : ModStore + { + public string Name; + + public LuaModStore(Dictionary store) : base(store) { } + public override bool Equals(string value) => Name == value; + } + public class CsModStore : ModStore + { + public ACsMod Mod; + + public CsModStore(Dictionary store) : base(store) { } + public override bool Equals(ACsMod value) => Mod == value; + } + + private HashSet luaModInterface; + private HashSet csModInterface; + + public LuaCsModStore() + { + luaModInterface = new HashSet(); + csModInterface = new HashSet(); + } + + public void Initialize() + { + UserData.RegisterType(); + UserData.RegisterType(); + var msType = UserData.RegisterType(); + var msDesc = (StandardUserDataDescriptor)msType; + + typeof(StandardUserDataDescriptor).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => + { + if ( + m.Name.Contains("Register") + ) + { + msDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default)); + } + }); + } + public void Clear() + { + luaModInterface.Clear(); + csModInterface.Clear(); + } + + protected LuaModStore Register(string modName) + { + if (luaModInterface.Any(i => i.Equals(modName))) + { + GameMain.LuaCs.HandleException(new ArgumentException($"'{modName}' entry already registered"), exceptionType: ExceptionType.Lua); + return null; + } + + var newHandle = new LuaModStore(new Dictionary()); + if (luaModInterface.Add(newHandle)) return newHandle; + else return null; + } + [MoonSharpHidden] + public CsModStore Register(ACsMod mod) + { + if (csModInterface.Any(i => i.Equals(mod))) + { + GameMain.LuaCs.HandleException(new ArgumentException($"'{mod.GetType().FullName}' entry already registered"), exceptionType: ExceptionType.CSharp); + return null; + } + + var newHandle = new CsModStore(new Dictionary()); + if (csModInterface.Add(newHandle)) return newHandle; + else return null; + } + + public CsModStore GetCsStore(string modName) { + var result = csModInterface.Where(i => i.Mod.GetType().FullName == modName).FirstOrDefault(); + if (!result.Mod.IsDisposed) return result; + else + { + csModInterface.Remove(result); + return null; + } + } + protected LuaModStore GetLuaStore(string modName) => luaModInterface.Where(i => i.Name == modName).FirstOrDefault(); + } + } +} \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index 00e2057d2..c4038c0a2 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -13,6 +13,7 @@ using System.Linq; using System.Reflection; [assembly: InternalsVisibleTo("NetScriptAssembly", AllInternalsVisible = true)] +[assembly: InternalsVisibleTo("NetOneTimeScriptAssembly", AllInternalsVisible = true)] namespace Barotrauma { partial class LuaCsSetup @@ -20,19 +21,22 @@ namespace Barotrauma public const string LUASETUP_FILE = "Lua/LuaSetup.lua"; public const string VERSION_FILE = "luacsversion.txt"; - public Script lua; + private Script lua; + public CsScriptRunner CsScript { get; private set; } + public LuaGame Game { get; private set; } + public LuaScriptLoader LuaScriptLoader { get; private set; } - internal LuaCsHook Hook { get; private set; } + public LuaCsHook Hook { get; private set; } + public LuaCsNetworking Networking { get; private set; } + public LuaCsModStore ModStore { get; private set; } - public LuaGame Game; - public LuaCsNetworking Networking; - - public LuaScriptLoader LuaScriptLoader; - public CsScriptLoader NetScriptLoader; + public CsScriptLoader NetScriptLoader { get; private set; } + public CsLua Lua { get; private set; } public LuaCsSetup() { Hook = new LuaCsHook(); + ModStore = new LuaCsModStore(); Game = new LuaGame(); Networking = new LuaCsNetworking(); @@ -159,7 +163,7 @@ namespace Barotrauma public static void PrintCsMessage(object message) => PrintMessageBase("[CS] ", message, "Null"); public static void PrintLogMessage(object message) => PrintMessageBase("[LuaCs LOG] ", message, "Null"); - public DynValue DoString(string code, Table globalContext = null, string codeStringFriendly = null) + private DynValue DoString(string code, Table globalContext = null, string codeStringFriendly = null) { try { @@ -173,7 +177,7 @@ namespace Barotrauma return null; } - public DynValue DoFile(string file, Table globalContext = null, string codeStringFriendly = null) + private DynValue DoFile(string file, Table globalContext = null, string codeStringFriendly = null) { if (!LuaCsFile.IsPathAllowedLuaException(file, false)) return null; if (!LuaCsFile.Exists(file)) @@ -196,7 +200,7 @@ namespace Barotrauma } - public DynValue LoadString(string file, Table globalContext = null, string codeStringFriendly = null) + private DynValue LoadString(string file, Table globalContext = null, string codeStringFriendly = null) { try { @@ -211,7 +215,7 @@ namespace Barotrauma return null; } - public DynValue LoadFile(string file, Table globalContext = null, string codeStringFriendly = null) + private DynValue LoadFile(string file, Table globalContext = null, string codeStringFriendly = null) { if (!LuaCsFile.IsPathAllowedLuaException(file, false)) return null; if (!LuaCsFile.Exists(file)) @@ -233,7 +237,7 @@ namespace Barotrauma return null; } - public DynValue Require(string modname, Table globalContext) + private DynValue Require(string modname, Table globalContext) { try { @@ -262,7 +266,7 @@ namespace Barotrauma return null; } - public void SetModulePaths(string[] str) + private void SetModulePaths(string[] str) { LuaScriptLoader.ModulePaths = str; } @@ -281,9 +285,12 @@ namespace Barotrauma Game?.Stop(); Hook.Clear(); + ModStore.Clear(); Game = new LuaGame(); Networking = new LuaCsNetworking(); LuaScriptLoader = null; + Lua = null; + CsScript = null; } public void Initialize() @@ -302,31 +309,21 @@ namespace Barotrauma lua = new Script(CoreModules.Preset_SoftSandbox | CoreModules.Debug); lua.Options.DebugPrint = PrintMessage; lua.Options.ScriptLoader = LuaScriptLoader; + Lua = new CsLua(this); + CsScript = new CsScriptRunner(this); - Hook.Initialize(); Game = new LuaGame(); Networking = new LuaCsNetworking(); + Hook.Initialize(); + ModStore.Initialize(); + UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); - - var hookType = UserData.RegisterType(); - var hookDesc = (StandardUserDataDescriptor)hookType; - typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => { - if ( - m.Name.Contains("HookMethod") || - m.Name.Contains("UnhookMethod") || - m.Name.Contains("EnqueueFunction") || - m.Name.Contains("EnqueueTimedFunction") - ) - { - hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default)); - } - }); lua.Globals["printerror"] = (Action)PrintError; @@ -339,9 +336,11 @@ namespace Barotrauma lua.Globals["dostring"] = (Func)DoString; lua.Globals["load"] = (Func)LoadString; + lua.Globals["CsScript"] = CsScript; lua.Globals["LuaUserData"] = UserData.CreateStatic(); lua.Globals["Game"] = Game; lua.Globals["Hook"] = Hook; + lua.Globals["ModStore"] = ModStore; lua.Globals["Timer"] = new LuaCsTimer(); lua.Globals["File"] = UserData.CreateStatic(); lua.Globals["Networking"] = Networking;