From 9f1c3fa8239fef0cd274130799030e3e21db5e15 Mon Sep 17 00:00:00 2001 From: Evil Factory <36804725+evilfactory@users.noreply.github.com> Date: Sat, 31 Jan 2026 17:44:36 -0300 Subject: [PATCH] Move UserData checks out of Lua --- Barotrauma/BarotraumaShared/Lua/LuaSetup.lua | 5 +- .../BarotraumaShared/Lua/LuaUserData.lua | 97 --------- Barotrauma/BarotraumaShared/Lua/PostSetup.lua | 84 +------- .../LuaCs/Lua/LuaClasses/LuaSafeUserData.cs | 196 ++++++++++++++++++ .../LuaCs/Lua/LuaClasses/LuaUserData.cs | 72 ++++++- .../SharedSource/LuaCs/LuaCsSetup.cs | 5 +- 6 files changed, 279 insertions(+), 180 deletions(-) delete mode 100644 Barotrauma/BarotraumaShared/Lua/LuaUserData.lua create mode 100644 Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaSafeUserData.cs diff --git a/Barotrauma/BarotraumaShared/Lua/LuaSetup.lua b/Barotrauma/BarotraumaShared/Lua/LuaSetup.lua index 05163a0dd..fc970c3b6 100644 --- a/Barotrauma/BarotraumaShared/Lua/LuaSetup.lua +++ b/Barotrauma/BarotraumaShared/Lua/LuaSetup.lua @@ -7,7 +7,7 @@ package.path = {path .. "/?.lua"} setmodulepaths(package.path) -- Setup Libraries -require("LuaUserData") +LuaSetup.LuaUserData = LuaUserData require("DefaultRegister/RegisterShared") @@ -35,9 +35,6 @@ AddTableToGlobal(require("CompatibilityLib")) require("DefaultHook") -Descriptors = LuaSetup.LuaUserData.Descriptors -LuaUserData = LuaSetup.LuaUserData - require("DefaultLib/Utils/Math") require("DefaultLib/Utils/String") require("DefaultLib/Utils/Util") diff --git a/Barotrauma/BarotraumaShared/Lua/LuaUserData.lua b/Barotrauma/BarotraumaShared/Lua/LuaUserData.lua deleted file mode 100644 index 387fc7f2b..000000000 --- a/Barotrauma/BarotraumaShared/Lua/LuaUserData.lua +++ /dev/null @@ -1,97 +0,0 @@ -local clrLuaUserData = LuaUserData -local luaUserData = {} - -luaUserData.Descriptors = {} - -LuaSetup.LuaUserData = luaUserData - -luaUserData.IsRegistered = clrLuaUserData.IsRegistered -luaUserData.UnregisterType = clrLuaUserData.UnregisterType -luaUserData.RegisterGenericType = clrLuaUserData.RegisterGenericType -luaUserData.RegisterExtensionType = clrLuaUserData.RegisterExtensionType -luaUserData.UnregisterGenericType = clrLuaUserData.UnregisterGenericType -luaUserData.IsTargetType = clrLuaUserData.IsTargetType -luaUserData.TypeOf = clrLuaUserData.TypeOf -luaUserData.GetType = clrLuaUserData.GetType -luaUserData.CreateEnumTable = clrLuaUserData.CreateEnumTable -luaUserData.MakeFieldAccessible = clrLuaUserData.MakeFieldAccessible -luaUserData.MakeMethodAccessible = clrLuaUserData.MakeMethodAccessible -luaUserData.MakePropertyAccessible = clrLuaUserData.MakePropertyAccessible -luaUserData.AddMethod = clrLuaUserData.AddMethod -luaUserData.AddField = clrLuaUserData.AddField -luaUserData.RemoveMember = clrLuaUserData.RemoveMember -luaUserData.CreateUserDataFromDescriptor = clrLuaUserData.CreateUserDataFromDescriptor -luaUserData.CreateUserDataFromType = clrLuaUserData.CreateUserDataFromType -luaUserData.HasMember = clrLuaUserData.HasMember - -luaUserData.RegisterType = function(typeName) - local success, result = pcall(clrLuaUserData.RegisterType, typeName) - - if not success then - error(result, 2) - end - - luaUserData.Descriptors[typeName] = result - - return result -end - -luaUserData.RegisterTypeBarotrauma = function(typeName) - typeName = "Barotrauma." .. typeName - local success, result = pcall(luaUserData.RegisterType, typeName) - - if not success then - error(result, 2) - end - - return result -end - -luaUserData.AddCallMetaTable = function (userdata) - if userdata == nil then - error("Attempted to add a call metatable to a nil value.", 2) - end - - if not LuaUserData.HasMember(userdata, ".ctor") then - error("Attempted to add a call metatable to a userdata that does not have a constructor.", 2) - end - - debug.setmetatable(userdata, { - __call = function(obj, ...) - if userdata == nil then - error("userdata was nil.", 2) - end - - local success, result = pcall(userdata.__new, ...) - - - if not success then - error(result, 2) - end - - return result - end - }) -end - -luaUserData.CreateStatic = function(typeName) - if type(typeName) ~= "string" then - error("Expected a string for typeName, got " .. type(typeName) .. ".", 2) - end - - local success, result = pcall(clrLuaUserData.CreateStatic, typeName) - - if not success then - error(result, 2) - end - - if result == nil then - return - end - - if LuaUserData.HasMember(result, ".ctor") then - luaUserData.AddCallMetaTable(result) - end - - return result -end \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/Lua/PostSetup.lua b/Barotrauma/BarotraumaShared/Lua/PostSetup.lua index f54417259..d18dbc685 100644 --- a/Barotrauma/BarotraumaShared/Lua/PostSetup.lua +++ b/Barotrauma/BarotraumaShared/Lua/PostSetup.lua @@ -1,75 +1,13 @@ -if CSActive then - return +if not CSActive then + LuaUserDataIUUD = LuaUserData.RegisterType("Barotrauma.LuaSafeUserData") + LuaUserData = LuaUserData.CreateStatic("Barotrauma.LuaSafeUserData"); + + for k, v in pairs(debug) do + if k ~= "getmetatable" and k ~= "setmetatable" and k ~= "traceback" then + debug[k] = nil + end + end end -local function IsAllowed(typeName) - if string.startsWith(typeName, "Barotrauma.Lua") or string.startsWith(typeName, "Barotrauma.Cs") or string.startsWith(typeName, "Barotrauma.LuaCs") then - return false - end - - if string.startsWith(typeName, "System.Collections") then return true end - - if string.startsWith(typeName, "Microsoft.Xna") then return true end - - if string.startsWith(typeName, "Barotrauma.IO") then return false end - - if string.startsWith(typeName, "Barotrauma.ToolBox") then return false end - if string.startsWith(typeName, "Barotrauma.SaveUtil") then return false end - - if string.startsWith(typeName, "Barotrauma.") then return true end - - return false -end - -local function CanBeReRegistered(typeName) - if string.startsWith(typeName, "Barotrauma.Lua") or string.startsWith(typeName, "Barotrauma.Cs") or string.startsWith(typeName, "Barotrauma.LuaCs") then - return false - end - - return true -end - -local originalRegisterType = LuaUserData.RegisterType -LuaUserData.RegisterType = function (typeName) - if not (CanBeReRegistered(typeName) and LuaUserData.IsRegistered(typeName)) and not IsAllowed(typeName) then - error("Couldn't register type " .. typeName .. ".", 2) - end - - local success, result = pcall(originalRegisterType, typeName) - - if not success then - error(result, 2) - end - - return result -end - -local originalRegisterGenericType = LuaUserData.RegisterType -LuaUserData.RegisterGenericType = function (typeName, ...) - if not (CanBeReRegistered(typeName) and LuaUserData.IsRegistered(typeName)) and not IsAllowed(typeName) then - error("Couldn't register generic type " .. typeName .. ".", 2) - end - - local success, result = pcall(originalRegisterGenericType, typeName, ...) - - if not success then - error(result, 2) - end - - return result -end - -local originalCreateStatic = LuaUserData.CreateStatic -LuaUserData.CreateStatic = function (typeName, addCallMethod) - if not (CanBeReRegistered(typeName) and LuaUserData.IsRegistered(typeName)) and not IsAllowed(typeName) then - error("Couldn't create static type " .. typeName .. ".", 2) - end - - local success, result = pcall(originalCreateStatic, typeName, addCallMethod) - - if not success then - error(result, 2) - end - - return result -end \ No newline at end of file +Descriptors = LuaUserData.__new() +LuaUserDataIUUD = nil \ No newline at end of file diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaSafeUserData.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaSafeUserData.cs new file mode 100644 index 000000000..a7ba597fa --- /dev/null +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaSafeUserData.cs @@ -0,0 +1,196 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Reflection; +using MoonSharp.Interpreter; +using MoonSharp.Interpreter.Interop; + +namespace Barotrauma +{ + partial class LuaSafeUserData + { + public IUserDataDescriptor this[string index] + { + get => LuaUserData.Descriptors.GetValueOrDefault(index); + } + + private static bool CanBeRegistered(string typeName) + { + if (typeName.StartsWith("Barotrauma.Lua", StringComparison.Ordinal) || + typeName.StartsWith("Barotrauma.Cs", StringComparison.Ordinal) || + typeName.StartsWith("Barotrauma.LuaCs", StringComparison.Ordinal)) + { + return false; + } + + if (typeName == "System.Single") { return true; } + + if (typeName.StartsWith("System.Collections", StringComparison.Ordinal)) + return true; + + if (typeName.StartsWith("Microsoft.Xna", StringComparison.Ordinal)) + return true; + + if (typeName.StartsWith("Barotrauma.IO", StringComparison.Ordinal)) + return false; + + if (typeName.StartsWith("Barotrauma.ToolBox", StringComparison.Ordinal)) + return false; + + if (typeName.StartsWith("Barotrauma.SaveUtil", StringComparison.Ordinal)) + return false; + + if (typeName.StartsWith("Barotrauma.", StringComparison.Ordinal)) + return true; + + return false; + } + + private static bool CanBeReRegistered(string typeName) + { + if (typeName.StartsWith("Barotrauma.Lua", StringComparison.Ordinal) || + typeName.StartsWith("Barotrauma.Cs", StringComparison.Ordinal) || + typeName.StartsWith("Barotrauma.LuaCs", StringComparison.Ordinal)) + { + return false; + } + + return true; + } + + private static bool IsAllowed(string typeName) + { + if (!CanBeReRegistered(typeName) && LuaUserData.IsRegistered(typeName)) + { + return false; + } + + if (!CanBeRegistered(typeName)) + { + return false; + } + + return true; + } + + private static void CheckAllowed(string typeName) + { + if (!IsAllowed(typeName)) + { + throw new ScriptRuntimeException($"Type {typeName} can't be registered"); + } + } + + public static Type GetType(string typeName) + { + CheckAllowed(typeName); + + return LuaUserData.GetType(typeName); + } + + public static IUserDataDescriptor RegisterType(string typeName) + { + CheckAllowed(typeName); + + return LuaUserData.RegisterType(typeName); + } + + public static IUserDataDescriptor RegisterTypeBarotrauma(string typeName) + { + return RegisterType($"Barotrauma.{typeName}"); + } + + public static void RegisterExtensionType(string typeName) + { + CheckAllowed(typeName); + LuaUserData.RegisterExtensionType(typeName); + } + + public static bool IsRegistered(string typeName) + { + return LuaUserData.IsRegistered(typeName); + } + + public static void UnregisterType(string typeName, bool deleteHistory = false) + { + LuaUserData.UnregisterType(typeName, deleteHistory); + } + public static IUserDataDescriptor RegisterGenericType(string typeName, params string[] typeNameArguements) + { + CheckAllowed(typeName); + return LuaUserData.RegisterGenericType(typeName, typeNameArguements); + } + + public static void UnregisterGenericType(string typeName, params string[] typeNameArguements) + { + LuaUserData.UnregisterGenericType(typeName, typeNameArguements); + } + + public static bool IsTargetType(object obj, string typeName) + { + return LuaUserData.IsTargetType(obj, typeName); + } + + public static string TypeOf(object obj) + { + return LuaUserData.TypeOf(obj); + } + + public static object CreateStatic(string typeName) + { + CheckAllowed(typeName); + return LuaUserData.CreateStatic(typeName); + } + + public static object CreateEnumTable(string typeName) + { + return LuaUserData.CreateEnumTable(typeName); + } + + public static void MakeFieldAccessible(IUserDataDescriptor IUUD, string fieldName) + { + LuaUserData.MakeFieldAccessible(IUUD, fieldName); + } + + public static void MakeMethodAccessible(IUserDataDescriptor IUUD, string methodName, string[] parameters = null) + { + LuaUserData.MakeMethodAccessible(IUUD, methodName, parameters); + } + + public static void MakePropertyAccessible(IUserDataDescriptor IUUD, string propertyName) + { + LuaUserData.MakePropertyAccessible(IUUD, propertyName); + } + + public static void AddMethod(IUserDataDescriptor IUUD, string methodName, object function) + { + LuaUserData.AddMethod(IUUD, methodName, function); + } + + public static void AddField(IUserDataDescriptor IUUD, string fieldName, DynValue value) + { + LuaUserData.AddField(IUUD, fieldName, value); + } + + public static void RemoveMember(IUserDataDescriptor IUUD, string memberName) + { + LuaUserData.RemoveMember(IUUD, memberName); + } + + public static bool HasMember(object obj, string memberName) + { + return LuaUserData.HasMember(obj, memberName); + } + + public static DynValue CreateUserDataFromDescriptor(DynValue scriptObject, IUserDataDescriptor desiredTypeDescriptor) + { + return LuaUserData.CreateUserDataFromDescriptor(scriptObject, desiredTypeDescriptor); + } + + public static DynValue CreateUserDataFromType(DynValue scriptObject, Type desiredType) + { + return LuaUserData.CreateUserDataFromType(scriptObject, desiredType); + } + } +} diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs index a1e53d8eb..fd336066f 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Lua/LuaClasses/LuaUserData.cs @@ -1,14 +1,24 @@ -using System; +using MoonSharp.Interpreter; +using MoonSharp.Interpreter.Interop; +using System; +using System.Collections.Concurrent; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Reflection; -using MoonSharp.Interpreter; -using MoonSharp.Interpreter.Interop; namespace Barotrauma { partial class LuaUserData { + public static ReadOnlyDictionary Descriptors => new ReadOnlyDictionary(descriptors); + private static ConcurrentDictionary descriptors = new ConcurrentDictionary(); + + public IUserDataDescriptor this[string index] + { + get => Descriptors.GetValueOrDefault(index); + } + public static Type GetType(string typeName) => LuaCsSetup.GetType(typeName); public static IUserDataDescriptor RegisterType(string typeName) @@ -20,7 +30,15 @@ namespace Barotrauma throw new ScriptRuntimeException($"tried to register a type that doesn't exist: {typeName}."); } - return UserData.RegisterType(type); + var descriptor = UserData.RegisterType(type); + descriptors.TryAdd(typeName, descriptor); + + return descriptor; + } + + public static IUserDataDescriptor RegisterTypeBarotrauma(string typeName) + { + return RegisterType($"Barotrauma.{typeName}"); } public static void RegisterExtensionType(string typeName) @@ -102,7 +120,9 @@ namespace Barotrauma MethodInfo method = typeof(UserData).GetMethod(nameof(UserData.CreateStatic), 1, new Type[0]); MethodInfo generic = method.MakeGenericMethod(type); - return generic.Invoke(null, null); + var result = generic.Invoke(null, null); + AddCallMetaTable(result); + return result; } public static object CreateEnumTable(string typeName) @@ -359,5 +379,47 @@ namespace Barotrauma descriptor ??= new StandardUserDataDescriptor(desiredType, InteropAccessMode.Default); return CreateUserDataFromDescriptor(scriptObject, descriptor); } + + public static void AddCallMetaTable(object userdata) + { + if (userdata == null) { return; } + + // not sure how to implement this in C# + var function = GameMain.LuaCs.Lua.LoadString(""" + local userdata = ... + if userdata == nil then + error("Attempted to add a call metatable to a nil value.", 2) + end + + if not LuaUserData.HasMember(userdata, ".ctor") then + return + end + + debug.setmetatable(userdata, { + __call = function(obj, ...) + if userdata == nil then + error("userdata was nil.", 2) + end + + local success, result = pcall(userdata.__new, ...) + + + if not success then + error(result, 2) + end + + return result + end + }) + """); + + GameMain.LuaCs.Lua.Call(function, userdata); + } + + + public static void Clear() + { + descriptors.Clear(); + } } } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs index d51193a50..fb2ef00f1 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/LuaCsSetup.cs @@ -347,6 +347,8 @@ namespace Barotrauma DebugServer.Detach(Lua); } + LuaUserData.Clear(); + Game?.Stop(); Hook?.Clear(); @@ -416,7 +418,7 @@ namespace Barotrauma UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); - UserData.RegisterType(); + var uuid = UserData.RegisterType(); UserData.RegisterType(); UserData.RegisterType(); @@ -433,6 +435,7 @@ namespace Barotrauma Lua.Globals["Logger"] = UserData.CreateStatic(); Lua.Globals["LuaUserData"] = UserData.CreateStatic(); + Lua.Globals["LuaUserDataIUUD"] = uuid; Lua.Globals["Game"] = Game; Lua.Globals["Hook"] = Hook; Lua.Globals["ModStore"] = ModStore;