cs script runner + mod store + cs-lua interface

This commit is contained in:
Oiltanker
2022-04-17 16:05:00 +03:00
parent 6aeac06112
commit a7b4004058
9 changed files with 336 additions and 65 deletions

View File

@@ -3223,7 +3223,17 @@ namespace Barotrauma
return;
}
GameMain.LuaCs.DoString(string.Join(" ", args));
GameMain.LuaCs.Lua.DoString(string.Join(" ", args));
}));
commands.Add(new Command("cl_cs", "cs_cl: runs a string on the client", (string[] args) =>
{
if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands))
{
ThrowError("Command not permitted.");
return;
}
GameMain.LuaCs.CsScript.Run(string.Join(" ", args));
}));
commands.Add(new Command("cl_reloadlua", "reloads lua on the client", (string[] args) =>
@@ -3236,28 +3246,6 @@ namespace Barotrauma
GameMain.LuaCs.Initialize();
}));
commands.Add(new Command("cl_net", "lua_net: runs a script on the client", (string[] args) =>
{
if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands))
{
ThrowError("Command not permitted.");
return;
}
GameMain.LuaCs.DoString(string.Join(" ", args));
}));
commands.Add(new Command("cl_reloadnet", "reloads lua on the client", (string[] args) =>
{
if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands))
{
ThrowError("Command not permitted.");
return;
}
GameMain.LuaCs.Initialize();
}));
}
private static void ReloadWearables(Character character, int variant = 0)

View File

@@ -1243,7 +1243,11 @@ namespace Barotrauma
commands.Add(new Command("lua", "lua: runs a string", (string[] args) =>
{
GameMain.LuaCs.DoString(string.Join(" ", args));
GameMain.LuaCs.Lua.DoString(string.Join(" ", args));
}));
commands.Add(new Command("cs", "cs: runs a string", (string[] args) =>
{
GameMain.LuaCs.CsScript.Run(string.Join(" ", args));
}));
commands.Add(new Command("reloadlua", "reloads lua", (string[] args) =>

View File

@@ -0,0 +1,22 @@
using Barotrauma;
using MoonSharp.Interpreter;
using System.Collections.Generic;
namespace Barotrauma
{
partial class LuaCsSetup {
public class CsLua
{
private LuaCsSetup setup;
public Table Globals { get; private set; }
public CsLua(LuaCsSetup setup)
{
this.setup = setup;
Globals = setup.lua.Globals;
}
public DynValue DoString(string code) => setup.DoString(code);
}
}
}

View File

@@ -9,9 +9,7 @@ using System.Reflection.Metadata;
namespace Barotrauma {
class CsScriptFilter
{
private const bool useWhitelist = false;
private static string[] typesPermited = new string[] {
private static readonly string[] typesPermitted = new string[] {
// Basics
"System.Runtime.CompilerServices.CompilationRelaxationsAttribute",
"System.Runtime.CompilerServices.RuntimeCompatibilityAttribute",
@@ -19,20 +17,25 @@ namespace Barotrauma {
"System.Object",
"System.String",
"System.Collections",
"System",
// Some roslyn magic
".DebuggingModes",
// Barotrauma
"Barotrauma",
// Lua
"MoonSharp.Interpreter"
};
private static string[] typesProhibited = new string[] {
private static readonly string[] typesProhibited = new string[] {
//"System.Reflection",
"System.IO",
"Moonsharp.Interpreter.UserData"
};
public static bool IsTypeAllowed(string usingName)
public static bool IsTypeAllowed(string name)
{
if (useWhitelist && !typesPermited.Any(u => u.StartsWith(usingName))) return false;
if (typesProhibited.Any(u => u.StartsWith(usingName))) return false;
return true;
var longestPemited = typesPermitted.Where(s => s.StartsWith(name)).Max(s => s.Length);
var longestProhibitted = typesProhibited.Where(s => s.StartsWith(name)).Max(s => s.Length);
if (longestPemited == 0 || longestPemited < longestProhibitted) return false;
else return true;
}
public static string FilterSyntaxTree(CSharpSyntaxTree tree)

View File

@@ -72,11 +72,11 @@ namespace Barotrauma
{
errStr += "\n" + diag.ToString();
}
LuaCsSetup.PrintCsMessage(errStr);
LuaCsSetup.PrintCsError(errStr);
}
catch (Exception ex)
{
LuaCsSetup.PrintCsMessage("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace);
LuaCsSetup.PrintCsError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace);
}
}
@@ -98,7 +98,7 @@ namespace Barotrauma
string errStr = "NET MODS NOT LOADED | Mod cmopilation errors:";
foreach (Diagnostic diagnostic in failures)
errStr = $"\n{diagnostic}";
LuaCsSetup.PrintCsMessage(errStr);
LuaCsSetup.PrintCsError(errStr);
}
else
{
@@ -109,7 +109,7 @@ namespace Barotrauma
mem.Seek(0, SeekOrigin.Begin);
Assembly = LoadFromStream(mem);
}
else LuaCsSetup.PrintCsMessage(errStr);
else LuaCsSetup.PrintCsError(errStr);
}
}
syntaxTrees.Clear();

View File

@@ -0,0 +1,132 @@
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.CodeAnalysis.Scripting;
using System.Reflection;
using Microsoft.CodeAnalysis.CSharp;
using System.Linq;
using Microsoft.CodeAnalysis;
using System.Runtime.Loader;
using System.Reflection.PortableExecutable;
using System.Reflection.Metadata;
namespace Barotrauma
{
class CsScriptRunner : AssemblyLoadContext
{
public LuaCsSetup setup;
private List<MetadataReference> defaultReferences;
private CSharpCompilationOptions compileOptions;
public CsScriptRunner(LuaCsSetup setup)
{
this.setup = setup;
defaultReferences = AppDomain.CurrentDomain.GetAssemblies()
.Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit")))
.Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference)
.ToList();
compileOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
.WithMetadataImportOptions(MetadataImportOptions.All)
.WithOptimizationLevel(OptimizationLevel.Release)
.WithAllowUnsafe(false);
}
private static string ToOneTimeScript(string code)
{
var prefix = @"
public class NetOneTimeScriptRunner {
public NetOneTimeScriptRunner() { }
public object Run() {
";
var postfix = @"
return null;
}
}
";
return prefix + code + postfix;
}
public object Run(string code)
{
object scriptResilt = null;
try
{
var syntaxTree = SyntaxFactory.ParseSyntaxTree(ToOneTimeScript(code), CSharpParseOptions.Default);
var compilation = CSharpCompilation.Create("NetOneTimeScriptAssembly", new[] { syntaxTree }, defaultReferences, compileOptions);
Assembly assembly = null;
using (var mem = new MemoryStream())
{
var result = compilation.Emit(mem);
if (!result.Success)
{
IEnumerable<Diagnostic> failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error);
string errStr = "Script cmopilation errors:";
foreach (Diagnostic diagnostic in failures)
errStr = $"\n{diagnostic}";
LuaCsSetup.PrintCsError(errStr);
}
else
{
mem.Seek(0, SeekOrigin.Begin);
var metaReader = new PEReader(mem).GetMetadataReader();
var errStr = CsScriptFilter.FilterMetadata(metaReader);
if (errStr == null)
{
foreach (var handle in metaReader.TypeDefinitions)
{
var typeDef = metaReader.GetTypeDefinition(handle);
var typeName = $"{metaReader.GetString(typeDef.Namespace)}.{metaReader.GetString(typeDef.Name)}";
if (typeName != ".NetOneTimeScriptRunner")
{
errStr = "Script Error - malformed assembly";
break;
}
}
}
if (errStr == null)
{
mem.Seek(0, SeekOrigin.Begin);
assembly = LoadFromStream(mem);
var runner = assembly.CreateInstance("NetOneTimeScriptRunner");
if (runner != null)
{
if (runner.GetType().GetMethods().Count() > 1) LuaCsSetup.PrintCsError("Script Error - malformed runner class");
else
{
var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance);
if (method != null) scriptResilt = method.Invoke(runner, null);
else LuaCsSetup.PrintCsError("Script Error - no run method detected");
}
}
else LuaCsSetup.PrintCsError("Script Error - no runner class detected");
}
else LuaCsSetup.PrintCsError(errStr);
}
}
Unload();
}
catch (CompilationErrorException ex)
{
string errStr = "Script Cmopilation Error:";
foreach (var diag in ex.Diagnostics)
{
errStr += "\n" + diag.ToString();
}
LuaCsSetup.PrintCsError(errStr);
}
catch (Exception ex)
{
LuaCsSetup.PrintCsError("Error running script:\n" + ex.Message + "\n" + ex.StackTrace);
}
return scriptResilt;
}
}
}

View File

@@ -73,7 +73,20 @@ namespace Barotrauma
{
harmony = new Harmony("LuaCsForBarotrauma");
}
var hookType = UserData.RegisterType<LuaCsHook>();
var hookDesc = (StandardUserDataDescriptor)hookType;
typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
if (
m.Name.Contains("HookMethod") ||
m.Name.Contains("UnhookMethod") ||
m.Name.Contains("EnqueueFunction") ||
m.Name.Contains("EnqueueTimedFunction")
)
{
hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
}
});
}
private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType)
{

View File

@@ -0,0 +1,110 @@
using MoonSharp.Interpreter;
using MoonSharp.Interpreter.Interop;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
namespace Barotrauma
{
partial class LuaCsSetup
{
public class LuaCsModStore
{
public abstract class ModStore<T, TStore>
{
protected Dictionary<string, TStore> store;
public TStore Set(string name, TStore value) => store[name] = value;
public TStore Get(string name) => store[name];
public ModStore(Dictionary<string, TStore> store) => this.store = store;
public abstract bool Equals(T value);
}
public class LuaModStore : ModStore<string, DynValue>
{
public string Name;
public LuaModStore(Dictionary<string, DynValue> store) : base(store) { }
public override bool Equals(string value) => Name == value;
}
public class CsModStore : ModStore<ACsMod, object>
{
public ACsMod Mod;
public CsModStore(Dictionary<string, object> store) : base(store) { }
public override bool Equals(ACsMod value) => Mod == value;
}
private HashSet<LuaModStore> luaModInterface;
private HashSet<CsModStore> csModInterface;
public LuaCsModStore()
{
luaModInterface = new HashSet<LuaModStore>();
csModInterface = new HashSet<CsModStore>();
}
public void Initialize()
{
UserData.RegisterType<LuaModStore>();
UserData.RegisterType<CsModStore>();
var msType = UserData.RegisterType<LuaCsModStore>();
var msDesc = (StandardUserDataDescriptor)msType;
typeof(StandardUserDataDescriptor).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m =>
{
if (
m.Name.Contains("Register")
)
{
msDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
}
});
}
public void Clear()
{
luaModInterface.Clear();
csModInterface.Clear();
}
protected LuaModStore Register(string modName)
{
if (luaModInterface.Any(i => i.Equals(modName)))
{
GameMain.LuaCs.HandleException(new ArgumentException($"'{modName}' entry already registered"), exceptionType: ExceptionType.Lua);
return null;
}
var newHandle = new LuaModStore(new Dictionary<string, DynValue>());
if (luaModInterface.Add(newHandle)) return newHandle;
else return null;
}
[MoonSharpHidden]
public CsModStore Register(ACsMod mod)
{
if (csModInterface.Any(i => i.Equals(mod)))
{
GameMain.LuaCs.HandleException(new ArgumentException($"'{mod.GetType().FullName}' entry already registered"), exceptionType: ExceptionType.CSharp);
return null;
}
var newHandle = new CsModStore(new Dictionary<string, object>());
if (csModInterface.Add(newHandle)) return newHandle;
else return null;
}
public CsModStore GetCsStore(string modName) {
var result = csModInterface.Where(i => i.Mod.GetType().FullName == modName).FirstOrDefault();
if (!result.Mod.IsDisposed) return result;
else
{
csModInterface.Remove(result);
return null;
}
}
protected LuaModStore GetLuaStore(string modName) => luaModInterface.Where(i => i.Name == modName).FirstOrDefault();
}
}
}

View File

@@ -13,6 +13,7 @@ using System.Linq;
using System.Reflection;
[assembly: InternalsVisibleTo("NetScriptAssembly", AllInternalsVisible = true)]
[assembly: InternalsVisibleTo("NetOneTimeScriptAssembly", AllInternalsVisible = true)]
namespace Barotrauma
{
partial class LuaCsSetup
@@ -20,19 +21,22 @@ namespace Barotrauma
public const string LUASETUP_FILE = "Lua/LuaSetup.lua";
public const string VERSION_FILE = "luacsversion.txt";
public Script lua;
private Script lua;
public CsScriptRunner CsScript { get; private set; }
public LuaGame Game { get; private set; }
public LuaScriptLoader LuaScriptLoader { get; private set; }
internal LuaCsHook Hook { get; private set; }
public LuaCsHook Hook { get; private set; }
public LuaCsNetworking Networking { get; private set; }
public LuaCsModStore ModStore { get; private set; }
public LuaGame Game;
public LuaCsNetworking Networking;
public LuaScriptLoader LuaScriptLoader;
public CsScriptLoader NetScriptLoader;
public CsScriptLoader NetScriptLoader { get; private set; }
public CsLua Lua { get; private set; }
public LuaCsSetup()
{
Hook = new LuaCsHook();
ModStore = new LuaCsModStore();
Game = new LuaGame();
Networking = new LuaCsNetworking();
@@ -159,7 +163,7 @@ namespace Barotrauma
public static void PrintCsMessage(object message) => PrintMessageBase("[CS] ", message, "Null");
public static void PrintLogMessage(object message) => PrintMessageBase("[LuaCs LOG] ", message, "Null");
public DynValue DoString(string code, Table globalContext = null, string codeStringFriendly = null)
private DynValue DoString(string code, Table globalContext = null, string codeStringFriendly = null)
{
try
{
@@ -173,7 +177,7 @@ namespace Barotrauma
return null;
}
public DynValue DoFile(string file, Table globalContext = null, string codeStringFriendly = null)
private DynValue DoFile(string file, Table globalContext = null, string codeStringFriendly = null)
{
if (!LuaCsFile.IsPathAllowedLuaException(file, false)) return null;
if (!LuaCsFile.Exists(file))
@@ -196,7 +200,7 @@ namespace Barotrauma
}
public DynValue LoadString(string file, Table globalContext = null, string codeStringFriendly = null)
private DynValue LoadString(string file, Table globalContext = null, string codeStringFriendly = null)
{
try
{
@@ -211,7 +215,7 @@ namespace Barotrauma
return null;
}
public DynValue LoadFile(string file, Table globalContext = null, string codeStringFriendly = null)
private DynValue LoadFile(string file, Table globalContext = null, string codeStringFriendly = null)
{
if (!LuaCsFile.IsPathAllowedLuaException(file, false)) return null;
if (!LuaCsFile.Exists(file))
@@ -233,7 +237,7 @@ namespace Barotrauma
return null;
}
public DynValue Require(string modname, Table globalContext)
private DynValue Require(string modname, Table globalContext)
{
try
{
@@ -262,7 +266,7 @@ namespace Barotrauma
return null;
}
public void SetModulePaths(string[] str)
private void SetModulePaths(string[] str)
{
LuaScriptLoader.ModulePaths = str;
}
@@ -281,9 +285,12 @@ namespace Barotrauma
Game?.Stop();
Hook.Clear();
ModStore.Clear();
Game = new LuaGame();
Networking = new LuaCsNetworking();
LuaScriptLoader = null;
Lua = null;
CsScript = null;
}
public void Initialize()
@@ -302,31 +309,21 @@ namespace Barotrauma
lua = new Script(CoreModules.Preset_SoftSandbox | CoreModules.Debug);
lua.Options.DebugPrint = PrintMessage;
lua.Options.ScriptLoader = LuaScriptLoader;
Lua = new CsLua(this);
CsScript = new CsScriptRunner(this);
Hook.Initialize();
Game = new LuaGame();
Networking = new LuaCsNetworking();
Hook.Initialize();
ModStore.Initialize();
UserData.RegisterType<CsScriptRunner>();
UserData.RegisterType<LuaGame>();
UserData.RegisterType<LuaCsTimer>();
UserData.RegisterType<LuaCsFile>();
UserData.RegisterType<LuaCsNetworking>();
UserData.RegisterType<LuaUserData>();
UserData.RegisterType<IUserDataDescriptor>();
var hookType = UserData.RegisterType<LuaCsHook>();
var hookDesc = (StandardUserDataDescriptor)hookType;
typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
if (
m.Name.Contains("HookMethod") ||
m.Name.Contains("UnhookMethod") ||
m.Name.Contains("EnqueueFunction") ||
m.Name.Contains("EnqueueTimedFunction")
)
{
hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
}
});
lua.Globals["printerror"] = (Action<object>)PrintError;
@@ -339,9 +336,11 @@ namespace Barotrauma
lua.Globals["dostring"] = (Func<string, Table, string, DynValue>)DoString;
lua.Globals["load"] = (Func<string, Table, string, DynValue>)LoadString;
lua.Globals["CsScript"] = CsScript;
lua.Globals["LuaUserData"] = UserData.CreateStatic<LuaUserData>();
lua.Globals["Game"] = Game;
lua.Globals["Hook"] = Hook;
lua.Globals["ModStore"] = ModStore;
lua.Globals["Timer"] = new LuaCsTimer();
lua.Globals["File"] = UserData.CreateStatic<LuaCsFile>();
lua.Globals["Networking"] = Networking;