Merge pull request #58 from oiltanker/unstable-tests

cs script runner + mod store + cs-lua interface
This commit is contained in:
Evil Factory
2022-04-17 14:31:44 -03:00
committed by GitHub
9 changed files with 374 additions and 75 deletions

View File

@@ -3223,7 +3223,17 @@ namespace Barotrauma
return;
}
GameMain.LuaCs.DoString(string.Join(" ", args));
GameMain.LuaCs.Lua.DoString(string.Join(" ", args));
}));
commands.Add(new Command("cl_cs", "cs_cl: runs a string on the client", (string[] args) =>
{
if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands))
{
ThrowError("Command not permitted.");
return;
}
GameMain.LuaCs.CsScript.Run(string.Join(" ", args));
}));
commands.Add(new Command("cl_reloadlua", "reloads lua on the client", (string[] args) =>
@@ -3236,28 +3246,6 @@ namespace Barotrauma
GameMain.LuaCs.Initialize();
}));
commands.Add(new Command("cl_net", "lua_net: runs a script on the client", (string[] args) =>
{
if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands))
{
ThrowError("Command not permitted.");
return;
}
GameMain.LuaCs.DoString(string.Join(" ", args));
}));
commands.Add(new Command("cl_reloadnet", "reloads lua on the client", (string[] args) =>
{
if (GameMain.Client != null && !GameMain.Client.HasPermission(ClientPermissions.ConsoleCommands))
{
ThrowError("Command not permitted.");
return;
}
GameMain.LuaCs.Initialize();
}));
}
private static void ReloadWearables(Character character, int variant = 0)

View File

@@ -1243,7 +1243,11 @@ namespace Barotrauma
commands.Add(new Command("lua", "lua: runs a string", (string[] args) =>
{
GameMain.LuaCs.DoString(string.Join(" ", args));
GameMain.LuaCs.Lua.DoString(string.Join(" ", args));
}));
commands.Add(new Command("cs", "cs: runs a string", (string[] args) =>
{
GameMain.LuaCs.CsScript.Run(string.Join(" ", args));
}));
commands.Add(new Command("reloadlua", "reloads lua", (string[] args) =>

View File

@@ -0,0 +1,22 @@
using Barotrauma;
using MoonSharp.Interpreter;
using System.Collections.Generic;
namespace Barotrauma
{
partial class LuaCsSetup {
public class CsLua
{
private LuaCsSetup setup;
public Table Globals { get; private set; }
public CsLua(LuaCsSetup setup)
{
this.setup = setup;
Globals = setup.lua.Globals;
}
public DynValue DoString(string code) => setup.DoString(code);
}
}
}

View File

@@ -9,30 +9,36 @@ using System.Reflection.Metadata;
namespace Barotrauma {
class CsScriptFilter
{
private const bool useWhitelist = false;
private static string[] typesPermited = new string[] {
private static readonly string[] typesPermitted = new string[] {
// Basics
"System.Runtime.CompilerServices.CompilationRelaxationsAttribute",
"System.Runtime.CompilerServices.RuntimeCompatibilityAttribute",
"System.Diagnostics.DebuggableAttribute",
"System.Object",
"System.String",
"System.Collections",
// Some roslyn magic
".DebuggingModes",
"System",
// Barotrauma
"Barotrauma",
// Lua
"MoonSharp.Interpreter.DynValue",
"MoonSharp.Interpreter.Closure",
"MoonSharp.Interpreter.Coroutine",
"MoonSharp.Interpreter.CoroutineState",
"MoonSharp.Interpreter.Table",
"MoonSharp.Interpreter.YieldRequest",
"MoonSharp.Interpreter.TailCallData",
"MoonSharp.Interpreter.DataType",
};
private static string[] typesProhibited = new string[] {
private static readonly string[] typesProhibited = new string[] {
//"System.Reflection",
//"System.Type",
"System.IO",
"Moonsharp",
"Barotrauma.IO",
};
public static bool IsTypeAllowed(string usingName)
public static bool IsTypeAllowed(string name)
{
if (useWhitelist && !typesPermited.Any(u => u.StartsWith(usingName))) return false;
if (typesProhibited.Any(u => u.StartsWith(usingName))) return false;
return true;
var matchPermitted = typesPermitted.Where(s => name.StartsWith(s));
var longestPemitted = matchPermitted.Count() > 0 ? matchPermitted.Max(s => s.Length) : 0;
var matchProhibited = typesProhibited.Where(s => name.StartsWith(s));
var longestProhibited = matchProhibited.Count() > 0 ? matchProhibited.Max(s => s.Length) : 0;
if (longestPemitted == 0 || longestPemitted < longestProhibited) return false;
else return true;
}
public static string FilterSyntaxTree(CSharpSyntaxTree tree)
@@ -62,7 +68,18 @@ namespace Barotrauma {
reader.TypeReferences.ToList().ForEach(t =>
{
var tRef = reader.GetTypeReference(t);
var typeName = $"{reader.GetString(tRef.Namespace)}.{reader.GetString(tRef.Name)}";
var typeName = $"{reader.GetString(tRef.Name)}";
EntityHandle handle = tRef.ResolutionScope;
TypeReference tr = tRef;
while (!handle.IsNil && handle.Kind == HandleKind.TypeReference)
{
tr = reader.GetTypeReference((TypeReferenceHandle)handle);
handle = tr.ResolutionScope;
typeName = $"{reader.GetString(tr.Name)}.{typeName}";
}
typeName = $"{reader.GetString(tr.Namespace)}.{typeName}";
if (!IsTypeAllowed(typeName)) conflictingTypes.Add(typeName);
});

View File

@@ -72,11 +72,11 @@ namespace Barotrauma
{
errStr += "\n" + diag.ToString();
}
LuaCsSetup.PrintCsMessage(errStr);
LuaCsSetup.PrintCsError(errStr);
}
catch (Exception ex)
{
LuaCsSetup.PrintCsMessage("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace);
LuaCsSetup.PrintCsError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace);
}
}
@@ -98,7 +98,7 @@ namespace Barotrauma
string errStr = "NET MODS NOT LOADED | Mod cmopilation errors:";
foreach (Diagnostic diagnostic in failures)
errStr = $"\n{diagnostic}";
LuaCsSetup.PrintCsMessage(errStr);
LuaCsSetup.PrintCsError(errStr);
}
else
{
@@ -109,7 +109,7 @@ namespace Barotrauma
mem.Seek(0, SeekOrigin.Begin);
Assembly = LoadFromStream(mem);
}
else LuaCsSetup.PrintCsMessage(errStr);
else LuaCsSetup.PrintCsError(errStr);
}
}
syntaxTrees.Clear();

View File

@@ -0,0 +1,132 @@
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.CodeAnalysis.Scripting;
using System.Reflection;
using Microsoft.CodeAnalysis.CSharp;
using System.Linq;
using Microsoft.CodeAnalysis;
using System.Runtime.Loader;
using System.Reflection.PortableExecutable;
using System.Reflection.Metadata;
namespace Barotrauma
{
class CsScriptRunner : AssemblyLoadContext
{
public LuaCsSetup setup;
private List<MetadataReference> defaultReferences;
private CSharpCompilationOptions compileOptions;
public CsScriptRunner(LuaCsSetup setup)
{
this.setup = setup;
defaultReferences = AppDomain.CurrentDomain.GetAssemblies()
.Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit")))
.Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference)
.ToList();
compileOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
.WithMetadataImportOptions(MetadataImportOptions.All)
.WithOptimizationLevel(OptimizationLevel.Release)
.WithAllowUnsafe(false);
}
private static string ToOneTimeScript(string code)
{
var prefix = @"
public class NetOneTimeScriptRunner {
public NetOneTimeScriptRunner() { }
public object Run() {
";
var postfix = @"
return null;
}
}
";
return prefix + code + postfix;
}
public object Run(string code)
{
object scriptResilt = null;
try
{
var syntaxTree = SyntaxFactory.ParseSyntaxTree(ToOneTimeScript(code), CSharpParseOptions.Default);
var compilation = CSharpCompilation.Create("NetOneTimeScriptAssembly", new[] { syntaxTree }, defaultReferences, compileOptions);
Assembly assembly = null;
using (var mem = new MemoryStream())
{
var result = compilation.Emit(mem);
if (!result.Success)
{
IEnumerable<Diagnostic> failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error);
string errStr = "Script cmopilation errors:";
foreach (Diagnostic diagnostic in failures)
errStr = $"\n{diagnostic}";
LuaCsSetup.PrintCsError(errStr);
}
else
{
mem.Seek(0, SeekOrigin.Begin);
var metaReader = new PEReader(mem).GetMetadataReader();
var errStr = CsScriptFilter.FilterMetadata(metaReader);
if (errStr == null)
{
foreach (var handle in metaReader.TypeDefinitions)
{
var typeDef = metaReader.GetTypeDefinition(handle);
var typeName = $"{metaReader.GetString(typeDef.Namespace)}.{metaReader.GetString(typeDef.Name)}";
if (typeName != ".NetOneTimeScriptRunner")
{
errStr = "Script Error - malformed assembly";
break;
}
}
}
if (errStr == null)
{
mem.Seek(0, SeekOrigin.Begin);
assembly = LoadFromStream(mem);
var runner = assembly.CreateInstance("NetOneTimeScriptRunner");
if (runner != null)
{
if (runner.GetType().GetMethods().Count() > 1) LuaCsSetup.PrintCsError("Script Error - malformed runner class");
else
{
var method = runner.GetType().GetMethod("Run", BindingFlags.Public | BindingFlags.Instance);
if (method != null) scriptResilt = method.Invoke(runner, null);
else LuaCsSetup.PrintCsError("Script Error - no run method detected");
}
}
else LuaCsSetup.PrintCsError("Script Error - no runner class detected");
}
else LuaCsSetup.PrintCsError(errStr);
}
}
Unload();
}
catch (CompilationErrorException ex)
{
string errStr = "Script Cmopilation Error:";
foreach (var diag in ex.Diagnostics)
{
errStr += "\n" + diag.ToString();
}
LuaCsSetup.PrintCsError(errStr);
}
catch (Exception ex)
{
LuaCsSetup.PrintCsError("Error running script:\n" + ex.Message + "\n" + ex.StackTrace);
}
return scriptResilt;
}
}
}

View File

@@ -6,6 +6,7 @@ using HarmonyLib;
using System.Collections.Generic;
using System.Text;
using MoonSharp.Interpreter.Interop;
using static Barotrauma.LuaCsSetup;
namespace Barotrauma
{
@@ -73,7 +74,20 @@ namespace Barotrauma
{
harmony = new Harmony("LuaCsForBarotrauma");
}
var hookType = UserData.RegisterType<LuaCsHook>();
var hookDesc = (StandardUserDataDescriptor)hookType;
typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
if (
m.Name.Contains("HookMethod") ||
m.Name.Contains("UnhookMethod") ||
m.Name.Contains("EnqueueFunction") ||
m.Name.Contains("EnqueueTimedFunction")
)
{
hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
}
});
}
private static void _hookLuaCsPatch(MethodBase __originalMethod, object[] __args, object __instance, out object result, HookMethodType hookMethodType)
{
@@ -202,9 +216,22 @@ namespace Barotrauma
return methodInfo;
}
private static readonly string[] prohibitedHooks = {
"Barotrauma.Lua",
"Barotrauma.Cs"
};
public void HookMethod(string identifier, MethodInfo method, LuaCsPatch patch, HookMethodType hookType = HookMethodType.Before, ACsMod owner = null)
{
if (identifier == null || method == null || patch == null) throw new ArgumentNullException("Identifier, Method and Patch arguments must not be null.");
if (identifier == null || method == null || patch == null)
{
GameMain.LuaCs.HandleException(new ArgumentNullException("Identifier, Method and Patch arguments must not be null."), exceptionType: ExceptionType.Both);
return;
}
if (prohibitedHooks.Any(h => method.DeclaringType.FullName.StartsWith(h)))
{
GameMain.LuaCs.HandleException(new ArgumentException("Hooks into Modding Environment are prohibited."), exceptionType: ExceptionType.Both);
return;
}
var funcAddr = ((long)method.MethodHandle.GetFunctionPointer());
var patches = Harmony.GetPatchInfo(method);

View File

@@ -0,0 +1,110 @@
using MoonSharp.Interpreter;
using MoonSharp.Interpreter.Interop;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
namespace Barotrauma
{
partial class LuaCsSetup
{
public class LuaCsModStore
{
public abstract class ModStore<T, TStore>
{
protected Dictionary<string, TStore> store;
public TStore Set(string name, TStore value) => store[name] = value;
public TStore Get(string name) => store[name];
public ModStore(Dictionary<string, TStore> store) => this.store = store;
public abstract bool Equals(T value);
}
public class LuaModStore : ModStore<string, DynValue>
{
public string Name;
public LuaModStore(Dictionary<string, DynValue> store) : base(store) { }
public override bool Equals(string value) => Name == value;
}
public class CsModStore : ModStore<ACsMod, object>
{
public ACsMod Mod;
public CsModStore(Dictionary<string, object> store) : base(store) { }
public override bool Equals(ACsMod value) => Mod == value;
}
private HashSet<LuaModStore> luaModInterface;
private HashSet<CsModStore> csModInterface;
public LuaCsModStore()
{
luaModInterface = new HashSet<LuaModStore>();
csModInterface = new HashSet<CsModStore>();
}
public void Initialize()
{
UserData.RegisterType<LuaModStore>();
UserData.RegisterType<CsModStore>();
var msType = UserData.RegisterType<LuaCsModStore>();
var msDesc = (StandardUserDataDescriptor)msType;
typeof(StandardUserDataDescriptor).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m =>
{
if (
m.Name.Contains("Register")
)
{
msDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
}
});
}
public void Clear()
{
luaModInterface.Clear();
csModInterface.Clear();
}
protected LuaModStore Register(string modName)
{
if (luaModInterface.Any(i => i.Equals(modName)))
{
GameMain.LuaCs.HandleException(new ArgumentException($"'{modName}' entry already registered"), exceptionType: ExceptionType.Lua);
return null;
}
var newHandle = new LuaModStore(new Dictionary<string, DynValue>());
if (luaModInterface.Add(newHandle)) return newHandle;
else return null;
}
[MoonSharpHidden]
public CsModStore Register(ACsMod mod)
{
if (csModInterface.Any(i => i.Equals(mod)))
{
GameMain.LuaCs.HandleException(new ArgumentException($"'{mod.GetType().FullName}' entry already registered"), exceptionType: ExceptionType.CSharp);
return null;
}
var newHandle = new CsModStore(new Dictionary<string, object>());
if (csModInterface.Add(newHandle)) return newHandle;
else return null;
}
public CsModStore GetCsStore(string modName) {
var result = csModInterface.Where(i => i.Mod.GetType().FullName == modName).FirstOrDefault();
if (!result.Mod.IsDisposed) return result;
else
{
csModInterface.Remove(result);
return null;
}
}
protected LuaModStore GetLuaStore(string modName) => luaModInterface.Where(i => i.Name == modName).FirstOrDefault();
}
}
}

View File

@@ -13,6 +13,7 @@ using System.Linq;
using System.Reflection;
[assembly: InternalsVisibleTo("NetScriptAssembly", AllInternalsVisible = true)]
[assembly: InternalsVisibleTo("NetOneTimeScriptAssembly", AllInternalsVisible = true)]
namespace Barotrauma
{
partial class LuaCsSetup
@@ -20,19 +21,22 @@ namespace Barotrauma
public const string LUASETUP_FILE = "Lua/LuaSetup.lua";
public const string VERSION_FILE = "luacsversion.txt";
public Script lua;
private Script lua;
public CsScriptRunner CsScript { get; private set; }
public LuaGame Game { get; private set; }
public LuaScriptLoader LuaScriptLoader { get; private set; }
internal LuaCsHook Hook { get; private set; }
public LuaCsHook Hook { get; private set; }
public LuaCsNetworking Networking { get; private set; }
public LuaCsModStore ModStore { get; private set; }
public LuaGame Game;
public LuaCsNetworking Networking;
public LuaScriptLoader LuaScriptLoader;
public CsScriptLoader NetScriptLoader;
public CsScriptLoader NetScriptLoader { get; private set; }
public CsLua Lua { get; private set; }
public LuaCsSetup()
{
Hook = new LuaCsHook();
ModStore = new LuaCsModStore();
Game = new LuaGame();
Networking = new LuaCsNetworking();
@@ -159,7 +163,7 @@ namespace Barotrauma
public static void PrintCsMessage(object message) => PrintMessageBase("[CS] ", message, "Null");
public static void PrintLogMessage(object message) => PrintMessageBase("[LuaCs LOG] ", message, "Null");
public DynValue DoString(string code, Table globalContext = null, string codeStringFriendly = null)
private DynValue DoString(string code, Table globalContext = null, string codeStringFriendly = null)
{
try
{
@@ -173,7 +177,7 @@ namespace Barotrauma
return null;
}
public DynValue DoFile(string file, Table globalContext = null, string codeStringFriendly = null)
private DynValue DoFile(string file, Table globalContext = null, string codeStringFriendly = null)
{
if (!LuaCsFile.IsPathAllowedLuaException(file, false)) return null;
if (!LuaCsFile.Exists(file))
@@ -196,7 +200,7 @@ namespace Barotrauma
}
public DynValue LoadString(string file, Table globalContext = null, string codeStringFriendly = null)
private DynValue LoadString(string file, Table globalContext = null, string codeStringFriendly = null)
{
try
{
@@ -211,7 +215,7 @@ namespace Barotrauma
return null;
}
public DynValue LoadFile(string file, Table globalContext = null, string codeStringFriendly = null)
private DynValue LoadFile(string file, Table globalContext = null, string codeStringFriendly = null)
{
if (!LuaCsFile.IsPathAllowedLuaException(file, false)) return null;
if (!LuaCsFile.Exists(file))
@@ -233,7 +237,7 @@ namespace Barotrauma
return null;
}
public DynValue Require(string modname, Table globalContext)
private DynValue Require(string modname, Table globalContext)
{
try
{
@@ -262,7 +266,7 @@ namespace Barotrauma
return null;
}
public void SetModulePaths(string[] str)
private void SetModulePaths(string[] str)
{
LuaScriptLoader.ModulePaths = str;
}
@@ -281,9 +285,12 @@ namespace Barotrauma
Game?.Stop();
Hook.Clear();
ModStore.Clear();
Game = new LuaGame();
Networking = new LuaCsNetworking();
LuaScriptLoader = null;
Lua = null;
CsScript = null;
}
public void Initialize()
@@ -302,31 +309,21 @@ namespace Barotrauma
lua = new Script(CoreModules.Preset_SoftSandbox | CoreModules.Debug);
lua.Options.DebugPrint = PrintMessage;
lua.Options.ScriptLoader = LuaScriptLoader;
Lua = new CsLua(this);
CsScript = new CsScriptRunner(this);
Hook.Initialize();
Game = new LuaGame();
Networking = new LuaCsNetworking();
Hook.Initialize();
ModStore.Initialize();
UserData.RegisterType<CsScriptRunner>();
UserData.RegisterType<LuaGame>();
UserData.RegisterType<LuaCsTimer>();
UserData.RegisterType<LuaCsFile>();
UserData.RegisterType<LuaCsNetworking>();
UserData.RegisterType<LuaUserData>();
UserData.RegisterType<IUserDataDescriptor>();
var hookType = UserData.RegisterType<LuaCsHook>();
var hookDesc = (StandardUserDataDescriptor)hookType;
typeof(LuaCsHook).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).ToList().ForEach(m => {
if (
m.Name.Contains("HookMethod") ||
m.Name.Contains("UnhookMethod") ||
m.Name.Contains("EnqueueFunction") ||
m.Name.Contains("EnqueueTimedFunction")
)
{
hookDesc.AddMember(m.Name, new MethodMemberDescriptor(m, InteropAccessMode.Default));
}
});
lua.Globals["printerror"] = (Action<object>)PrintError;
@@ -339,9 +336,11 @@ namespace Barotrauma
lua.Globals["dostring"] = (Func<string, Table, string, DynValue>)DoString;
lua.Globals["load"] = (Func<string, Table, string, DynValue>)LoadString;
lua.Globals["CsScript"] = CsScript;
lua.Globals["LuaUserData"] = UserData.CreateStatic<LuaUserData>();
lua.Globals["Game"] = Game;
lua.Globals["Hook"] = Hook;
lua.Globals["ModStore"] = ModStore;
lua.Globals["Timer"] = new LuaCsTimer();
lua.Globals["File"] = UserData.CreateStatic<LuaCsFile>();
lua.Globals["Networking"] = Networking;