Assembly and Script Loading Overhauled.

This commit is contained in:
MapleWheels
2023-09-17 01:04:51 -04:00
committed by Evil Factory
parent a58cb8251f
commit 414d46b33e
17 changed files with 2699 additions and 477 deletions

View File

@@ -24,11 +24,8 @@ RegisterBarotrauma("Media.Video")
RegisterBarotrauma("SoundsFile")
RegisterBarotrauma("SoundPrefab")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.SoundPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.BackgroundMusic]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.GUISound]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.DamageSound]]")
RegisterBarotrauma("PrefabSelector`1[[Barotrauma.SoundPrefab]]")
RegisterBarotrauma("PrefabCollection`1")
RegisterBarotrauma("PrefabSelector`1")
RegisterBarotrauma("BackgroundMusic")
RegisterBarotrauma("GUISound")
RegisterBarotrauma("DamageSound")
@@ -57,7 +54,6 @@ RegisterBarotrauma("Particles.Particle")
RegisterBarotrauma("Particles.ParticleEmitterProperties")
RegisterBarotrauma("Particles.ParticleEmitter")
RegisterBarotrauma("Particles.ParticlePrefab")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.Particles.ParticlePrefab]]")
RegisterBarotrauma("Lights.LightManager")
RegisterBarotrauma("Lights.LightSource")
@@ -146,4 +142,4 @@ RegisterBarotrauma("UISprite")
RegisterBarotrauma("ParamsEditor")
RegisterBarotrauma("Inventory+SlotReference")
RegisterBarotrauma("VisualSlot")
RegisterBarotrauma("VisualSlot")

View File

@@ -6,8 +6,8 @@ Register("System.Exception")
Register("System.Console")
Register("System.Exception")
RegisterBarotrauma("Success`2[[Barotrauma.ContentPackage],[System.Exception]]")
RegisterBarotrauma("Failure`2[[Barotrauma.ContentPackage],[System.Exception]]")
RegisterBarotrauma("Success`2")
RegisterBarotrauma("Failure`2")
RegisterBarotrauma("LuaSByte")
RegisterBarotrauma("LuaByte")
@@ -24,8 +24,7 @@ RegisterBarotrauma("GameMain")
RegisterBarotrauma("Networking.BanList")
RegisterBarotrauma("Networking.BannedPlayer")
RegisterBarotrauma("Range`1[System.Single]")
RegisterBarotrauma("Range`1[System.Int32]")
RegisterBarotrauma("Range`1")
RegisterBarotrauma("RichString")
RegisterBarotrauma("Identifier")
@@ -399,27 +398,11 @@ end
RegisterBarotrauma("Camera")
RegisterBarotrauma("Key")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.ItemPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.JobPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.CharacterPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.AfflictionPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.TalentPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.TalentTree]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.OrderPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LevelGenerationParams]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.OutpostGenerationParams]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.RuinGeneration.RuinGenerationParams]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LevelGenerationParams]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.LocationType]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventPrefab]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventSet]]")
RegisterBarotrauma("PrefabCollection`1[[Barotrauma.EventManagerSettings]]")
RegisterBarotrauma("PrefabCollection`1")
RegisterBarotrauma("PrefabSelector`1[[Barotrauma.SkillSettings]]")
RegisterBarotrauma("PrefabSelector`1")
RegisterBarotrauma("Pair`2[[Barotrauma.JobPrefab],[System.Int32]]")
RegisterBarotrauma("Range`1[System.Single]")
RegisterBarotrauma("Pair`2")
RegisterBarotrauma("Items.Components.Signal")
RegisterBarotrauma("SubmarineInfo")
@@ -461,4 +444,4 @@ LuaUserData.RemoveMember(workshopItem, "AddFavorite")
LuaUserData.RemoveMember(workshopItem, "RemoveFavorite")
LuaUserData.RemoveMember(workshopItem, "Vote")
LuaUserData.RemoveMember(workshopItem, "GetUserVote")
LuaUserData.RemoveMember(workshopItem, "Edit")
LuaUserData.RemoveMember(workshopItem, "Edit")

View File

@@ -1,50 +0,0 @@
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.CodeAnalysis.Scripting;
using System.Reflection;
using Microsoft.CodeAnalysis.CSharp;
using System.Linq;
using Microsoft.CodeAnalysis;
using System.Runtime.Loader;
using System.Reflection.PortableExecutable;
using System.Reflection.Metadata;
using System.Text;
using System.Runtime.CompilerServices;
namespace Barotrauma
{
class CsScriptBase : AssemblyLoadContext
{
public const string CsScriptAssembly = "NetScriptAssembly";
public static readonly string[] LoadedAssemblyName = {
CsScriptBase.CsScriptAssembly
};
public static Dictionary<string, object> Revision = new Dictionary<string, object>()
{
{ CsScriptAssembly, 0}
};
public CSharpParseOptions ParseOptions { get; protected set; }
public CsScriptBase() : base(isCollectible: true) {
ParseOptions = CSharpParseOptions.Default
.WithPreprocessorSymbols(new[] { LuaCsSetup.IsServer ? "SERVER" : (LuaCsSetup.IsClient ? "CLIENT" : "UNDEFINED") });
}
public static SyntaxTree AssemblyInfoSyntaxTree(string asmName = null)
{
Revision[asmName] = (int)Revision[asmName] + 1;
var asmInfo = new StringBuilder();
asmInfo.AppendLine("using System.Reflection;");
asmInfo.AppendLine($"[assembly: AssemblyMetadata(\"Revision\", \"{Revision[asmName]}\")]");
asmInfo.AppendLine($"[assembly: AssemblyVersion(\"0.0.0.{Revision[asmName]}\")]");
return CSharpSyntaxTree.ParseText(asmInfo.ToString(), CSharpParseOptions.Default);
}
~CsScriptBase() { }
}
}

View File

@@ -1,286 +0,0 @@
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.CodeAnalysis.Scripting;
using System.Reflection;
using Microsoft.CodeAnalysis.CSharp;
using System.Linq;
using Microsoft.CodeAnalysis;
using System.Runtime.Loader;
using System.Reflection.PortableExecutable;
using System.Reflection.Metadata;
using System.Text.RegularExpressions;
using System.Xml.Linq;
namespace Barotrauma
{
class CsScriptLoader : CsScriptBase
{
private List<MetadataReference> defaultReferences;
private Dictionary<string, List<string>> sources;
public Assembly Assembly { get; private set; }
public CsScriptLoader()
{
defaultReferences = AppDomain.CurrentDomain.GetAssemblies()
.Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit")))
.Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference)
.ToList();
sources = new Dictionary<string, List<string>>();
Assembly = null;
}
private enum RunType { Standard, Forced, None };
private bool ShouldRun(ContentPackage cp, string path)
{
if (!Directory.Exists(path + "CSharp"))
{
return false;
}
var isEnabled = ContentPackageManager.EnabledPackages.All.Contains(cp);
if (File.Exists(path + "CSharp/RunConfig.xml"))
{
Stream stream = File.Open(path + "CSharp/RunConfig.xml", FileMode.Open, FileAccess.Read, FileShare.ReadWrite);
var doc = XDocument.Load(stream);
var elems = doc.Root.Elements().ToArray();
var elem = elems.FirstOrDefault(e => e.Name.LocalName.Equals(LuaCsSetup.IsServer ? "Server" : (LuaCsSetup.IsClient ? "Client" : "None"), StringComparison.OrdinalIgnoreCase));
if (elem != null && Enum.TryParse(elem.Value, true, out RunType rtValue))
{
if (rtValue == RunType.Standard && isEnabled)
{
LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Standard)");
return true;
}
else if (rtValue == RunType.Forced && (isEnabled || !GameMain.LuaCs.Config.TreatForcedModsAsNormal))
{
LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Forced)");
return true;
}
else if (rtValue == RunType.None)
{
return false;
}
}
stream.Close();
}
if (isEnabled)
{
LuaCsLogger.LogMessage($"Added {cp.Name} {cp.ModVersion} to Cs compilation. (Assumed)");
return true;
}
else
{
return false;
}
}
public void SearchFolders()
{
var packagesAdded = new HashSet<ContentPackage>();
var paths = new Dictionary<string, string>();
foreach (var cp in ContentPackageManager.AllPackages.Concat(ContentPackageManager.EnabledPackages.All))
{
if (packagesAdded.Contains(cp)) { continue; }
var path = $"{Path.GetFullPath(Path.GetDirectoryName(cp.Path)).Replace('\\', '/')}/";
if (ShouldRun(cp, path))
{
if (paths.ContainsKey(cp.Name))
{
if (ContentPackageManager.EnabledPackages.All.Contains(cp))
{
paths[cp.Name] = path;
}
}
else
{
paths.Add(cp.Name, path);
}
packagesAdded.Add(cp);
}
}
foreach ((var _, var path) in paths)
{
RunFolder(path);
}
}
public bool HasSources { get => sources.Count > 0; }
private void AddSources(string folder)
{
foreach (var str in DirSearch(folder))
{
string s = str.Replace("\\", "/");
if (sources.ContainsKey(folder))
{
sources[folder].Add(s);
}
else
{
sources.Add(folder, new List<string> { s });
}
}
}
private void RunFolder(string folder)
{
AddSources(folder + "/CSharp/Shared");
#if SERVER
AddSources(folder + "/CSharp/Server");
#else
AddSources(folder + "/CSharp/Client");
#endif
}
private IEnumerable<SyntaxTree> ParseSources() {
var syntaxTrees = new List<SyntaxTree>();
if (sources.Count <= 0) throw new Exception("No Cs sources detected");
syntaxTrees.Add(AssemblyInfoSyntaxTree(CsScriptAssembly));
foreach ((var folder, var src) in sources)
{
try
{
foreach (var file in src)
{
var tree = SyntaxFactory.ParseSyntaxTree(File.ReadAllText(file), ParseOptions, file);
syntaxTrees.Add(tree);
}
}
catch (Exception ex)
{
LuaCsLogger.LogError("Error loading '" + folder + "':\n" + ex.Message + "\n" + ex.StackTrace, LuaCsMessageOrigin.CSharpMod);
}
}
return syntaxTrees;
}
private ContentPackage FindSourcePackage(Diagnostic diagnostic)
{
if (diagnostic.Location.SourceTree == null)
{
return null;
}
string path = diagnostic.Location.SourceTree.FilePath;
foreach (var package in ContentPackageManager.AllPackages)
{
if (Path.GetFullPath(path).StartsWith(Path.GetFullPath(package.Dir)))
{
return package;
}
}
return null;
}
public List<Type> Compile()
{
IEnumerable<SyntaxTree> syntaxTrees = ParseSources();
var options = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
.WithMetadataImportOptions(MetadataImportOptions.All)
.WithOptimizationLevel(OptimizationLevel.Release)
.WithAllowUnsafe(true);
var topLevelBinderFlagsProperty = typeof(CSharpCompilationOptions).GetProperty("TopLevelBinderFlags", BindingFlags.Instance | BindingFlags.NonPublic);
topLevelBinderFlagsProperty.SetValue(options, (uint)1 << 22);
var compilation = CSharpCompilation.Create(CsScriptAssembly, syntaxTrees, defaultReferences, options);
using (var mem = new MemoryStream())
{
var result = compilation.Emit(mem);
if (!result.Success)
{
IEnumerable<Diagnostic> failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error);
string errStr = "CS MODS NOT LOADED | Compilation errors:";
foreach (Diagnostic diagnostic in failures)
{
errStr += $"\n{diagnostic}";
#if CLIENT
ContentPackage package = FindSourcePackage(diagnostic);
if (package != null)
{
LuaCsLogger.ShowErrorOverlay($"{package.Name} {package.ModVersion} is causing compilation errors. Check debug console for more details.", 7f, 7f);
}
#endif
}
LuaCsLogger.LogError(errStr, LuaCsMessageOrigin.CSharpMod);
}
else
{
mem.Seek(0, SeekOrigin.Begin);
Assembly = LoadFromStream(mem);
}
}
if (Assembly != null)
{
RegisterAssemblyWithNativeGame(Assembly);
try
{
return Assembly.GetTypes().Where(t => t.IsSubclassOf(typeof(ACsMod))).ToList();
}
catch (ReflectionTypeLoadException re)
{
LuaCsLogger.LogError($"Unable to load CsMod Types. {re.Message}", LuaCsMessageOrigin.CSharpMod);
throw re;
}
}
else
{
throw new Exception("Unable to create cs mods assembly.");
}
}
/// <summary>
/// This function should be used whenever a new assembly is created. Wrapper to allow more complicated setup later if need be.
/// </summary>
private static void RegisterAssemblyWithNativeGame(Assembly assembly)
{
Barotrauma.ReflectionUtils.AddNonAbstractAssemblyTypes(assembly);
}
/// <summary>
/// This function should be used whenever a new assembly is about to be destroyed/unloaded. Wrapper to allow more complicated setup later if need be.
/// </summary>
/// <param name="assembly">Assembly to remove</param>
private static void UnregisterAssemblyFromNativeGame(Assembly assembly)
{
Barotrauma.ReflectionUtils.RemoveAssemblyFromCache(assembly);
}
private static string[] DirSearch(string sDir)
{
if (!Directory.Exists(sDir))
{
return new string[] {};
}
return Directory.GetFiles(sDir, "*.cs", SearchOption.AllDirectories);
}
public void Clear()
{
if (Assembly != null)
{
UnregisterAssemblyFromNativeGame(Assembly);
Assembly = null;
}
}
}
}

View File

@@ -1,8 +0,0 @@
using System;
using MoonSharp.Interpreter;
using Barotrauma.Networking;
namespace Barotrauma
{
}

View File

@@ -9,8 +9,8 @@ using LuaCsCompatPatchFunc = Barotrauma.LuaCsPatch;
using System.Diagnostics;
using MoonSharp.VsCodeDebugger;
using System.Reflection;
using System.Runtime.Loader;
[assembly: InternalsVisibleTo(Barotrauma.CsScriptBase.CsScriptAssembly, AllInternalsVisible = true)]
namespace Barotrauma
{
class LuaCsSetupConfig
@@ -66,12 +66,18 @@ namespace Barotrauma
public LuaCsSteam Steam { get; private set; }
public LuaCsPerformanceCounter PerformanceCounter { get; private set; }
// must be available at anytime
private static AssemblyManager _assemblyManager;
public static AssemblyManager AssemblyManager => _assemblyManager ??= new AssemblyManager();
private CsPackageManager _pluginPackageManager;
public CsPackageManager PluginPackageManager => _pluginPackageManager ??= new CsPackageManager(AssemblyManager, this);
public LuaCsModStore ModStore { get; private set; }
private LuaRequire require { get; set; }
public CsScriptLoader CsScriptLoader { get; private set; }
public LuaCsSetupConfig Config { get; private set; }
public MoonSharpVsCodeDebugServer DebugServer { get; private set; }
public bool IsInitialized { get; private set; }
private bool ShouldRunCs
{
@@ -90,7 +96,6 @@ namespace Barotrauma
Game = new LuaGame();
Networking = new LuaCsNetworking();
DebugServer = new MoonSharpVsCodeDebugServer();
if (File.Exists(configFileName))
@@ -105,35 +110,11 @@ namespace Barotrauma
Config = new LuaCsSetupConfig();
}
}
[Obsolete("Use AssemblyManager::GetTypesByName()")]
public static Type GetType(string typeName, bool throwOnError = false, bool ignoreCase = false)
{
if (typeName == null || typeName.Length == 0) { return null; }
var byRef = false;
if (typeName.StartsWith("out ") || typeName.StartsWith("ref "))
{
typeName = typeName.Remove(0, 4);
byRef = true;
}
var type = Type.GetType(typeName, throwOnError, ignoreCase);
if (type != null) { return byRef ? type.MakeByRefType() : type; }
foreach (var a in AppDomain.CurrentDomain.GetAssemblies())
{
if (CsScriptBase.LoadedAssemblyName.Contains(a.GetName().Name))
{
var attrs = a.GetCustomAttributes<AssemblyMetadataAttribute>();
var revision = attrs.FirstOrDefault(attr => attr.Key == "Revision")?.Value;
if (revision != null && int.Parse(revision) != (int)CsScriptBase.Revision[a.GetName().Name]) { continue; }
}
type = a.GetType(typeName, throwOnError, ignoreCase);
if (type != null)
{
return byRef ? type.MakeByRefType() : type;
}
}
return null;
return AssemblyManager.GetTypesByName(typeName).FirstOrDefault((Type)null);
}
public void ToggleDebugger(int port = 41912)
@@ -293,17 +274,21 @@ namespace Barotrauma
public void Stop()
{
foreach (var type in AppDomain.CurrentDomain.GetAssemblies().Where(a => a.GetName().Name == CsScriptBase.CsScriptAssembly).SelectMany(assembly => assembly.GetTypes()))
// unregister types
foreach (Type type in AssemblyManager.GetAllLoadedACLs().SelectMany(
acl => acl.AssembliesTypes.Select(kvp => kvp.Value)))
{
UserData.UnregisterType(type, true);
}
foreach (var mod in ACsMod.LoadedMods.ToArray())
{
mod.Dispose();
}
ACsMod.LoadedMods.Clear();
PluginPackageManager.UnloadPlugins(); // stop plugin code execution
if (Lua?.Globals is not null)
{
Lua.Globals.Remove("CsPackageManager");
Lua.Globals.Remove("AssemblyManager");
}
if (Thread.CurrentThread == GameMain.MainThread)
{
@@ -317,27 +302,27 @@ namespace Barotrauma
Game?.Stop();
Hook.Clear();
Hook?.Clear();
ModStore.Clear();
LuaScriptLoader = null;
Lua = null;
// we can only unload assemblies after clearing ModStore/references.
PluginPackageManager.Dispose();
Game = new LuaGame();
Networking = new LuaCsNetworking();
Timer = new LuaCsTimer();
Steam = new LuaCsSteam();
PerformanceCounter = new LuaCsPerformanceCounter();
LuaScriptLoader = null;
Lua = null;
if (CsScriptLoader != null)
{
CsScriptLoader.Clear();
CsScriptLoader.Unload();
CsScriptLoader = null;
}
IsInitialized = false;
}
public void Initialize(bool forceEnableCs = false)
{
Stop();
if (IsInitialized)
Stop();
LuaCsLogger.LogMessage("Lua! Version " + AssemblyInfo.GitRevision);
@@ -380,6 +365,9 @@ namespace Barotrauma
UserData.RegisterType<LuaUserData>();
UserData.RegisterType<LuaCsPerformanceCounter>();
UserData.RegisterType<IUserDataDescriptor>();
UserData.RegisterType<CsPackageManager>();
UserData.RegisterType<AssemblyManager>();
UserData.RegisterType<IAssemblyPlugin>();
UserData.RegisterExtensionType(typeof(MathUtils));
UserData.RegisterExtensionType(typeof(XMLExtensions));
@@ -430,65 +418,80 @@ namespace Barotrauma
DebugConsole.AddWarning("Cs package active! Cs mods are NOT sandboxed, use it at your own risk!");
}
CsScriptLoader = new CsScriptLoader();
CsScriptLoader.SearchFolders();
if (CsScriptLoader.HasSources)
Lua.Globals["PluginPackageManager"] = PluginPackageManager;
Lua.Globals["AssemblyManager"] = AssemblyManager;
try
{
try
Stopwatch taskTimer = new();
taskTimer.Start();
ModStore.Clear();
var state = PluginPackageManager.LoadAssemblyPackages();
if (state is AssemblyLoadingSuccessState.Success or AssemblyLoadingSuccessState.AlreadyLoaded)
{
Stopwatch compilationTime = new Stopwatch();
compilationTime.Start();
var modTypes = CsScriptLoader.Compile();
modTypes.ForEach(t =>
{
t.GetConstructor(new Type[] { })?.Invoke(null);
});
compilationTime.Stop();
LuaCsLogger.LogMessage($"Took {compilationTime.ElapsedMilliseconds}ms to compile and run Cs Scripts.");
if(!PluginPackageManager.PluginsInitialized)
PluginPackageManager.InstantiatePlugins(true);
if(!PluginPackageManager.PluginsPreInit)
PluginPackageManager.RunPluginsPreInit(); // this is intended to be called at startup in the future
if(!PluginPackageManager.PluginsLoaded)
PluginPackageManager.RunPluginsInit();
state = AssemblyLoadingSuccessState.Success;
taskTimer.Stop();
ModUtils.Logging.PrintMessage($"{nameof(LuaCsSetup)}: Completed assembly loading. Total time {taskTimer.ElapsedMilliseconds}ms.");
}
catch (Exception ex)
else
{
LuaCsLogger.HandleException(ex, LuaCsMessageOrigin.CSharpMod);
PluginPackageManager.Dispose(); // cleanup if there's an error
}
if(state is not AssemblyLoadingSuccessState.Success)
{
ModUtils.Logging.PrintError($"{nameof(LuaCsSetup)}: Error while loading Cs-Assembly Mods | Err: {state}");
taskTimer.Stop();
}
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"{nameof(LuaCsSetup)}::{nameof(Initialize)}() | Error while loading assemblies! Details: {e.Message} | {e.StackTrace}");
}
IsInitialized = true;
}
ContentPackage luaPackage = GetPackage(LuaForBarotraumaId);
void runLocal()
void RunLocal()
{
LuaCsLogger.LogMessage("Using LuaSetup.lua from the Barotrauma Lua/ folder.");
string luaPath = LuaSetupFile;
CallLuaFunction(Lua.LoadFile(luaPath), Path.GetDirectoryName(Path.GetFullPath(luaPath)));
}
void runWorkshop()
void RunWorkshop()
{
LuaCsLogger.LogMessage("Using LuaSetup.lua from the content package.");
string luaPath = Path.Combine(Path.GetDirectoryName(luaPackage.Path), "Binary/Lua/LuaSetup.lua");
CallLuaFunction(Lua.LoadFile(luaPath), Path.GetDirectoryName(Path.GetFullPath(luaPath)));
}
void runNone()
void RunNone()
{
LuaCsLogger.LogError("LuaSetup.lua not found! Lua/LuaSetup.lua, no Lua scripts will be executed or work.", LuaCsMessageOrigin.LuaMod);
}
if (Config.PreferToUseWorkshopLuaSetup)
{
if (luaPackage != null) { runWorkshop(); }
else if (File.Exists(LuaSetupFile)) { runLocal(); }
else { runNone(); }
if (luaPackage != null) { RunWorkshop(); }
else if (File.Exists(LuaSetupFile)) { RunLocal(); }
else { RunNone(); }
}
else
{
if (File.Exists(LuaSetupFile)) { runLocal(); }
else if (luaPackage != null) { runWorkshop(); }
else { runNone(); }
if (File.Exists(LuaSetupFile)) { RunLocal(); }
else if (luaPackage != null) { RunWorkshop(); }
else { RunNone(); }
}
executionNumber++;

View File

@@ -4,6 +4,7 @@ using MoonSharp.Interpreter;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Linq;
@@ -259,14 +260,22 @@ namespace Barotrauma
private static Type[] LoadDocTypes(XElement typesElem)
{
var result = new List<Type>();
var loadedTypes = LuaCsSetup.AssemblyManager
.GetAllTypesInLoadedAssemblies()
.ToImmutableHashSet();
foreach (var elem in typesElem.Elements())
{
var type = Type.GetType(elem.Value);
if (type == null && GameMain.LuaCs?.CsScriptLoader?.Assembly != null) type = GameMain.LuaCs.CsScriptLoader.Assembly.GetType(elem.Value);
if (type == null) throw new Exception($"Type {elem.Value} not found.");
result.Add(type);
var typesFound = loadedTypes.Where(t => t.FullName?.EndsWith(elem.Value) ?? false).ToImmutableList();
if (!typesFound.Any())
{
ModUtils.Logging.PrintError(
$"{nameof(LuaCsConfig)}::{nameof(LoadDocTypes)}() | Unable to find a matching type for {elem.Value}");
continue;
}
result.AddRange(typesFound);
}
return result.ToArray();
}

View File

@@ -0,0 +1,331 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Xml.Serialization;
using Barotrauma;
using Barotrauma.Items.Components;
using Barotrauma.Networking;
using Microsoft.CodeAnalysis;
namespace Barotrauma;
public static class ModUtils
{
#region LOGGING
public static class Logging
{
public static void PrintMessage(string s)
{
#if SERVER
LuaCsLogger.LogMessage($"[Server] {s}");
#else
LuaCsLogger.LogMessage($"[Client] {s}");
#endif
}
public static void PrintError(string s)
{
#if SERVER
LuaCsLogger.LogError($"[Server] {s}");
#else
LuaCsLogger.LogError($"[Client] {s}");
#endif
}
}
#endregion
#region FILE_IO
// ReSharper disable once InconsistentNaming
public static class IO
{
public static IEnumerable<string> FindAllFilesInDirectory(string folder, string pattern,
SearchOption option)
{
try
{
return Directory.GetFiles(folder, pattern, option);
}
catch (DirectoryNotFoundException e)
{
return new string[] { };
}
}
public static string PrepareFilePathString(string filePath) =>
PrepareFilePathString(Path.GetDirectoryName(filePath)!, Path.GetFileName(filePath));
public static string PrepareFilePathString(string path, string fileName) =>
Path.Combine(SanitizePath(path), SanitizeFileName(fileName));
public static string SanitizeFileName(string fileName)
{
foreach (char c in Barotrauma.IO.Path.GetInvalidFileNameCharsCrossPlatform())
fileName = fileName.Replace(c, '_');
return fileName;
}
/// <summary>
/// Gets the sanitized path for the top-level directory for a given content package.
/// </summary>
/// <param name="package"></param>
/// <returns></returns>
public static string GetContentPackageDir(ContentPackage package)
{
return SanitizePath(Path.GetFullPath(package.Dir));
}
public static string SanitizePath(string path)
{
foreach (char c in Path.GetInvalidPathChars())
path = path.Replace(c.ToString(), "_");
return path.CleanUpPath();
}
public static IOActionResultState GetOrCreateFileText(string filePath, out string fileText, Func<string> fileDataFactory = null, bool createFile = true)
{
fileText = null;
string fp = Path.GetFullPath(SanitizePath(filePath));
IOActionResultState ioActionResultState = IOActionResultState.Success;
if (createFile)
{
ioActionResultState = CreateFilePath(SanitizePath(filePath), out fp, fileDataFactory);
}
else if (!File.Exists(fp))
{
return IOActionResultState.FileNotFound;
}
if (ioActionResultState == IOActionResultState.Success)
{
try
{
fileText = File.ReadAllText(fp!);
return IOActionResultState.Success;
}
catch (ArgumentNullException ane)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is null. path: {fp ?? "null"} | Exception Details: {ane.Message}");
return IOActionResultState.FilePathNull;
}
catch (ArgumentException ae)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is invalid. path: {fp ?? "null"} | Exception Details: {ae.Message}");
return IOActionResultState.FilePathInvalid;
}
catch (DirectoryNotFoundException dnfe)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Cannot find directory. path: {fp ?? "null"} | Exception Details: {dnfe.Message}");
return IOActionResultState.DirectoryMissing;
}
catch (PathTooLongException ptle)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: path length is over 200 characters. path: {fp ?? "null"} | Exception Details: {ptle.Message}");
return IOActionResultState.PathTooLong;
}
catch (NotSupportedException nse)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Operation not supported on your platform/environment (permissions?). path: {fp ?? "null"} | Exception Details: {nse.Message}");
return IOActionResultState.InvalidOperation;
}
catch (IOException ioe)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: IO tasks failed (Operation not supported). path: {fp ?? "null"} | Exception Details: {ioe.Message}");
return IOActionResultState.IOFailure;
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Unknown/Other Exception. path: {fp ?? "null"} | ExceptionMessage: {e.Message}");
return IOActionResultState.UnknownError;
}
}
return ioActionResultState;
}
public static IOActionResultState CreateFilePath(string filePath, out string formattedFilePath, Func<string> fileDataFactory = null)
{
string file = Path.GetFileName(filePath);
string path = Path.GetDirectoryName(filePath)!;
formattedFilePath = IO.PrepareFilePathString(path, file);
try
{
if (!Directory.Exists(path))
Directory.CreateDirectory(path);
if (!File.Exists(formattedFilePath))
File.WriteAllText(formattedFilePath, fileDataFactory is null ? "" : fileDataFactory.Invoke());
return IOActionResultState.Success;
}
catch (ArgumentNullException ane)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is null. path: {formattedFilePath ?? "null"} | Exception Details: {ane.Message}");
return IOActionResultState.FilePathNull;
}
catch (ArgumentException ae)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: An argument is invalid. path: {formattedFilePath ?? "null"} | Exception Details: {ae.Message}");
return IOActionResultState.FilePathInvalid;
}
catch (DirectoryNotFoundException dnfe)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Cannot find directory. path: {path ?? "null"} | Exception Details: {dnfe.Message}");
return IOActionResultState.DirectoryMissing;
}
catch (PathTooLongException ptle)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: path length is over 200 characters. path: {formattedFilePath ?? "null"} | Exception Details: {ptle.Message}");
return IOActionResultState.PathTooLong;
}
catch (NotSupportedException nse)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Operation not supported on your platform/environment (permissions?). path: {formattedFilePath ?? "null"} | Exception Details: {nse.Message}");
return IOActionResultState.InvalidOperation;
}
catch (IOException ioe)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: IO tasks failed (Operation not supported). path: {formattedFilePath ?? "null"} | Exception Details: {ioe.Message}");
return IOActionResultState.IOFailure;
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"ModUtils::CreateFilePath() | Exception: Unknown/Other Exception. path: {path ?? "null"} | Exception Details: {e.Message}");
return IOActionResultState.UnknownError;
}
}
public static IOActionResultState WriteFileText(string filePath, string fileText)
{
IOActionResultState ioActionResultState = CreateFilePath(filePath, out var fp);
if (ioActionResultState == IOActionResultState.Success)
{
try
{
File.WriteAllText(fp!, fileText);
return IOActionResultState.Success;
}
catch (ArgumentNullException ane)
{
ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: An argument is null. path: {fp ?? "null"} | Exception Details: {ane.Message}");
return IOActionResultState.FilePathNull;
}
catch (ArgumentException ae)
{
ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: An argument is invalid. path: {fp ?? "null"} | Exception Details: {ae.Message}");
return IOActionResultState.FilePathInvalid;
}
catch (DirectoryNotFoundException dnfe)
{
ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Cannot find directory. path: {fp ?? "null"} | Exception Details: {dnfe.Message}");
return IOActionResultState.DirectoryMissing;
}
catch (PathTooLongException ptle)
{
ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: path length is over 200 characters. path: {fp ?? "null"} | Exception Details: {ptle.Message}");
return IOActionResultState.PathTooLong;
}
catch (NotSupportedException nse)
{
ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Operation not supported on your platform/environment (permissions?). path: {fp ?? "null"} | Exception Details: {nse.Message}");
return IOActionResultState.InvalidOperation;
}
catch (IOException ioe)
{
ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: IO tasks failed (Operation not supported). path: {fp ?? "null"} | Exception Details: {ioe.Message}");
return IOActionResultState.IOFailure;
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"ModUtils::WriteFileText() | Exception: Unknown/Other Exception. path: {fp ?? "null"} | ExceptionMessage: {e.Message}");
return IOActionResultState.UnknownError;
}
}
return ioActionResultState;
}
/// <summary>
///
/// </summary>
/// <param name="instance"></param>
/// <param name="filepath"></param>
/// <param name="typeFactory"></param>
/// <param name="createFile"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public static bool LoadOrCreateTypeXml<T>(out T instance,
string filepath, Func<T> typeFactory = null, bool createFile = true) where T : class, new()
{
instance = null;
filepath = filepath.CleanUpPath();
if (IOActionResultState.Success == GetOrCreateFileText(
filepath, out string fileText, typeFactory is not null ? () =>
{
using StringWriter sw = new StringWriter();
T t = typeFactory?.Invoke();
if (t is not null)
{
XmlSerializer s = new XmlSerializer(typeof(T));
s.Serialize(sw, t);
return sw.ToString();
}
return "";
} : null, createFile))
{
XmlSerializer s = new XmlSerializer(typeof(T));
try
{
using TextReader tr = new StringReader(fileText);
instance = (T)s.Deserialize(tr);
return true;
}
catch(InvalidOperationException ioe)
{
ModUtils.Logging.PrintError($"Error while parsing type data for {typeof(T)}.");
#if DEBUG
ModUtils.Logging.PrintError($"Exception: {ioe.Message}. Details: {ioe.InnerException?.Message}");
#endif
instance = null;
return false;
}
}
return false;
}
public enum IOActionResultState
{
Success, FileNotFound, FilePathNull, FilePathInvalid, DirectoryMissing, PathTooLong, InvalidOperation, IOFailure, UnknownError
}
}
#endregion
#region GAME
public static class Game
{
/// <summary>
/// Returns whether or not there is a round running.
/// </summary>
/// <returns></returns>
public static bool IsRoundInProgress()
{
#if CLIENT
if (Screen.Selected is not null
&& Screen.Selected.IsEditor)
return false;
#endif
return GameMain.GameSession is not null && Level.Loaded is not null;
}
}
#endregion
}

View File

@@ -1,11 +1,11 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
namespace Barotrauma
{
public abstract class ACsMod : IDisposable
[Obsolete("Make your class implement IAssemblyPlugin instead.")]
public abstract class ACsMod : IAssemblyPlugin
{
private static List<ACsMod> mods = new List<ACsMod>();
public static List<ACsMod> LoadedMods { get => mods; }
@@ -18,7 +18,6 @@ namespace Barotrauma
if (!Directory.Exists(modFolder)) Directory.CreateDirectory(modFolder);
return modFolder;
}
public static string GetSoreFolder<T>() where T : ACsMod => GetStoreFolder<T>();
public bool IsDisposed { get; private set; }
@@ -29,7 +28,23 @@ namespace Barotrauma
LoadedMods.Add(this);
}
public void Dispose()
/// <summary>
/// Called as soon as plugin loading begins, use this for internal setup only.
/// </summary>
public virtual void Initialize() { }
/// <summary>
/// Called once all plugins have completed Initialization. Put cross-mod code here.
/// </summary>
public virtual void OnLoadCompleted() { }
/// <summary>
/// [NotImplemented] Called before vanilla content is loaded. Use to patch Barotrauma classes before they're
/// instantiated.
/// </summary>
public void PreInitPatching() { }
public virtual void Dispose()
{
try
{
@@ -43,8 +58,7 @@ namespace Barotrauma
LoadedMods.Remove(this);
IsDisposed = true;
}
/// Error or client exit
public abstract void Stop();
}
}

View File

@@ -0,0 +1,6 @@
namespace Barotrauma;
public enum ApplicationMode
{
Client, Server
}

View File

@@ -0,0 +1,15 @@
namespace Barotrauma;
public enum AssemblyLoadingSuccessState
{
ACLLoadFailure,
AlreadyLoaded,
BadFilePath,
CannotLoadFile,
InvalidAssembly,
NoAssemblyFound,
PluginInstanceFailure,
BadName,
CannotLoadFromStream,
Success
}

View File

@@ -0,0 +1,772 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
// ReSharper disable EventNeverSubscribedTo.Global
// ReSharper disable InconsistentNaming
namespace Barotrauma;
/***
* Note: This class was written to be thread-safe in order to allow parallelization in loading in the future if the need
* becomes necessary as there is almost no serial performance overhead for adding threading protection.
*/
/// <summary>
/// Provides functionality for the loading, unloading and management of plugins implementing IAssemblyPlugin.
/// All plugins are loaded into their own AssemblyLoadContext along with their dependencies.
/// </summary>
public partial class AssemblyManager
{
#region ExternalAPI
/// <summary>
/// Called when an assembly is loaded.
/// </summary>
public event Action<Assembly> OnAssemblyLoaded;
/// <summary>
/// Called when an assembly is marked for unloading, before unloading begins. You should use this to cleanup
/// any references that you have to this assembly.
/// </summary>
public event Action<Assembly> OnAssemblyUnloading;
/// <summary>
/// Called whenever an exception is thrown. First arg is a formatted message, Second arg is the Exception.
/// </summary>
public event Action<string, Exception> OnException;
/// <summary>
/// For unloading issue debugging. Called whenever MemoryFileAssemblyContextLoader [load context] is unloaded.
/// </summary>
public event Action<Guid> OnACLUnload;
#if DEBUG
/// <summary>
/// [DEBUG ONLY]
/// Returns a list of the current unloading ACLs.
/// </summary>
public ImmutableList<WeakReference<MemoryFileAssemblyContextLoader>> StillUnloadingACLs
{
get
{
OpsLockUnloaded.EnterReadLock();
try
{
return UnloadingACLs.ToImmutableList();
}
finally
{
OpsLockUnloaded.ExitReadLock();
}
}
}
#endif
// ReSharper disable once MemberCanBePrivate.Global
/// <summary>
/// Checks if there are any AssemblyLoadContexts still in the process of unloading.
/// </summary>
public bool IsCurrentlyUnloading
{
get
{
OpsLockUnloaded.EnterReadLock();
try
{
return UnloadingACLs.Any();
}
catch (Exception)
{
return false;
}
finally
{
OpsLockUnloaded.ExitReadLock();
}
}
}
// Old API compatibility
public IEnumerable<Type> GetSubTypesInLoadedAssemblies<T>()
{
return GetSubTypesInLoadedAssemblies<T>(false);
}
/// <summary>
/// Allows iteration over all non-interface types in all loaded assemblies in the AsmMgr that are assignable to the given type (IsAssignableFrom).
/// Warning: care should be used when using this method in hot paths as performance may be affected.
/// </summary>
/// <typeparam name="T">The type to compare against</typeparam>
/// <param name="rebuildList">Forces caches to clear and for the lists of types to be rebuilt.</param>
/// <returns>An Enumerator for matching types.</returns>
public IEnumerable<Type> GetSubTypesInLoadedAssemblies<T>(bool rebuildList)
{
Type targetType = typeof(T);
string typeName = targetType.FullName ?? targetType.Name;
// rebuild
if (rebuildList)
RebuildTypesList();
// check cache
if (_subTypesLookupCache.TryGetValue(typeName, out var subTypeList))
{
return subTypeList;
}
// build from scratch
OpsLockLoaded.EnterReadLock();
try
{
// build list
var list1 = _defaultContextTypes
.Where(kvp1 => targetType.IsAssignableFrom(kvp1.Value) && !kvp1.Value.IsInterface)
.Concat(LoadedACLs
.SelectMany(kvp => kvp.Value.AssembliesTypes)
.Where(kvp2 => targetType.IsAssignableFrom(kvp2.Value) && !kvp2.Value.IsInterface))
.Select(kvp3 => kvp3.Value)
.ToImmutableList();
// only add if we find something
if (list1.Count > 0)
{
if (!_subTypesLookupCache.TryAdd(typeName, list1))
{
ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to add subtypes to cache of type {typeName}!");
}
}
else
{
ModUtils.Logging.PrintMessage($"{nameof(AssemblyManager)}: Warning: No types found during search for subtypes of {typeName}");
}
return list1;
}
finally
{
OpsLockLoaded.ExitReadLock();
}
}
/// <summary>
/// Tries to get types assignable to type from the ACL given the Guid.
/// </summary>
/// <param name="id"></param>
/// <param name="types"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public bool TryGetSubTypesFromACL<T>(Guid id, out IEnumerable<Type> types)
{
Type targetType = typeof(T);
if (TryGetACL(id, out var acl))
{
types = acl.AssembliesTypes
.Where(kvp => targetType.IsAssignableFrom(kvp.Value) && !kvp.Value.IsInterface)
.Select(kvp => kvp.Value);
return true;
}
types = null;
return false;
}
/// <summary>
/// Tries to get types from the ACL given the Guid.
/// </summary>
/// <param name="id"></param>
/// <param name="types"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public bool TryGetSubTypesFromACL(Guid id, out IEnumerable<Type> types)
{
if (TryGetACL(id, out var acl))
{
types = acl.AssembliesTypes.Select(kvp => kvp.Value);
return true;
}
types = null;
return false;
}
/// <summary>
/// Allows iteration over all types, including interfaces, in all loaded assemblies in the AsmMgr who's names match the string.
/// Note: Will return the by-reference equivalent type if the type name is prefixed with "out " or "ref ".
/// </summary>
/// <param name="name">The string name of the type to search for.</param>
/// <returns>An Enumerator for matching types.</returns>
public IEnumerable<Type> GetTypesByName(string typeName)
{
bool byRef = false;
if (typeName.StartsWith("out ") || typeName.StartsWith("ref "))
{
typeName = typeName.Remove(0, 4);
byRef = true;
}
List<Type> types = new();
TypesListHelper();
if (types.Count > 0)
return types;
// we couldn't find it, rebuild and try one more time
RebuildTypesList();
TypesListHelper();
return types;
void TypesListHelper()
{
if (_defaultContextTypes.TryGetValue(typeName, out var type1))
{
if (type1 is not null)
types.Add(byRef ? type1.MakeByRefType() : type1);
}
OpsLockLoaded.EnterReadLock();
try
{
foreach (KeyValuePair<Guid,LoadedACL> loadedAcl in LoadedACLs)
{
var at = loadedAcl.Value.AssembliesTypes;
if (at.TryGetValue(typeName, out var type2))
{
if (type2 is not null)
types.Add(byRef ? type2.MakeByRefType() : type2);
}
}
}
finally
{
OpsLockLoaded.ExitReadLock();
}
}
}
/// <summary>
/// Allows iteration over all types (including interfaces) in all loaded assemblies managed by the AsmMgr.
/// Warning: High usage may result in performance issues.
/// </summary>
/// <returns>An Enumerator for iteration.</returns>
public IEnumerable<Type> GetAllTypesInLoadedAssemblies()
{
OpsLockLoaded.EnterReadLock();
try
{
return AssemblyLoadContext.Default.Assemblies
.SelectMany(a => a.GetSafeTypes())
.Concat(LoadedACLs
.SelectMany(kvp => kvp.Value.AssembliesTypes.Select(kv => kv.Value)))
.ToImmutableList();
}
finally
{
OpsLockLoaded.ExitReadLock();
}
}
/// <summary>
/// Returns a list of all loaded ACLs.
/// WARNING: References to these ACLs outside of the AssemblyManager should be kept in a WeakReference in order
/// to avoid causing issues with unloading/disposal.
/// </summary>
/// <returns></returns>
public IEnumerable<LoadedACL> GetAllLoadedACLs()
{
try
{
OpsLockLoaded.EnterReadLock();
return LoadedACLs.Select(kvp => kvp.Value).ToImmutableList();
}
finally
{
OpsLockLoaded.ExitReadLock();
}
}
#endregion
#region InternalAPI
/// <summary>
/// Used by content package and plugin management to stop unloading of a given ACL until all plugins have gracefully closed.
/// </summary>
public event System.Func<LoadedACL, bool> IsReadyToUnloadACL;
public AssemblyLoadingSuccessState LoadAssemblyFromMemory([NotNull] string compiledAssemblyName,
[NotNull] IEnumerable<SyntaxTree> syntaxTree,
IEnumerable<MetadataReference> externalMetadataReferences,
[NotNull] CSharpCompilationOptions compilationOptions,
ref Guid id,
IEnumerable<Assembly> externFileAssemblyRefs = null)
{
// validation
if (compiledAssemblyName.IsNullOrWhiteSpace())
return AssemblyLoadingSuccessState.BadName;
if (!GetOrCreateACL(id, out var acl))
return AssemblyLoadingSuccessState.ACLLoadFailure;
id = acl.Id; // pass on true id returned
// this acl is already hosting an in-memory assembly
if (acl.Acl.CompiledAssembly is not null)
return AssemblyLoadingSuccessState.AlreadyLoaded;
// compile
var state = acl.Acl.CompileAndLoadScriptAssembly(compiledAssemblyName, syntaxTree, externalMetadataReferences,
compilationOptions, out var messages, externFileAssemblyRefs);
// get types
if (state is AssemblyLoadingSuccessState.Success)
{
_subTypesLookupCache.Clear();
acl.RebuildTypesList();
OnAssemblyLoaded?.Invoke(acl.Acl.CompiledAssembly);
}
else
{
ModUtils.Logging.PrintError($"Unable to compile assembly '{compiledAssemblyName}' due to errors: {messages}");
}
return state;
}
/// <summary>
/// Switches the ACL with the given Guid to Template Mode, which disables assembly name resolution for any assemblies loaded in it.
/// These ACLs are intended to be used to host Assemblies for information only and not for code execution.
/// WARNING: This process is irreversible.
/// </summary>
/// <param name="guid">Guid of the ACL.</param>
/// <returns>Whether or not an ACL was found with the given ID.</returns>
public bool SetACLToTemplateMode(Guid guid)
{
if (!TryGetACL(guid, out var acl))
return false;
acl.Acl.IsTemplateMode = true;
return true;
}
/// <summary>
/// Tries to load all assemblies at the supplied file paths list into the ACl with the given Guid.
/// If the supplied Guid is Empty, then a new ACl will be created and the Guid will be assigned to it.
/// </summary>
/// <param name="filePaths">List of assemblies to try and load.</param>
/// <param name="id">Guid of the ACL or Empty if none specified. Guid of ACL will be assigned to this var.</param>
/// <returns>Operation success messages.</returns>
/// <exception cref="ArgumentNullException"></exception>
public AssemblyLoadingSuccessState LoadAssembliesFromLocations([NotNull] IEnumerable<string> filePaths,
ref Guid id)
{
if (filePaths is null)
{
throw new ArgumentNullException(
$"{nameof(AssemblyManager)}::{nameof(LoadAssembliesFromLocations)}() | file paths supplied is null!");
}
ImmutableList<string> assemblyFilePaths = filePaths.ToImmutableList(); // copy the list before loading
if (!assemblyFilePaths.Any())
{
return AssemblyLoadingSuccessState.NoAssemblyFound;
}
if (GetOrCreateACL(id, out var loadedAcl))
{
var state = loadedAcl.Acl.LoadFromFiles(assemblyFilePaths);
// if failure, we dispose of the acl
if (state != AssemblyLoadingSuccessState.Success)
{
DisposeACL(loadedAcl.Id);
ModUtils.Logging.PrintError($"ACL failed, unloading...");
return state;
}
// build types list
_subTypesLookupCache.Clear();
loadedAcl.RebuildTypesList();
id = loadedAcl.Id;
foreach (Assembly assembly in loadedAcl.Acl.Assemblies)
{
OnAssemblyLoaded?.Invoke(assembly);
}
return state;
}
return AssemblyLoadingSuccessState.ACLLoadFailure;
}
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Synchronized)]
public bool TryBeginDispose()
{
OpsLockLoaded.EnterWriteLock();
OpsLockUnloaded.EnterWriteLock();
try
{
_subTypesLookupCache.Clear();
foreach (KeyValuePair<Guid, LoadedACL> loadedAcl in LoadedACLs)
{
if (loadedAcl.Value.Acl is not null)
{
foreach (Delegate del in IsReadyToUnloadACL.GetInvocationList())
{
if (del is System.Func<LoadedACL, bool> { } func)
{
if (!func.Invoke(loadedAcl.Value))
return false; // Not ready, exit
}
}
foreach (Assembly assembly in loadedAcl.Value.Acl.Assemblies)
{
OnAssemblyUnloading?.Invoke(assembly);
}
UnloadingACLs.Add(new WeakReference<MemoryFileAssemblyContextLoader>(loadedAcl.Value.Acl, true));
loadedAcl.Value.ClearTypesList();
loadedAcl.Value.Acl.Unload();
OnACLUnload?.Invoke(loadedAcl.Value.Id);
}
}
LoadedACLs.Clear();
return true;
}
catch
{
// should never happen
return false;
}
finally
{
OpsLockUnloaded.ExitWriteLock();
OpsLockLoaded.ExitWriteLock();
}
}
[MethodImpl(MethodImplOptions.NoInlining)]
public bool FinalizeDispose()
{
bool isUnloaded;
OpsLockUnloaded.EnterUpgradeableReadLock();
try
{
List<WeakReference<MemoryFileAssemblyContextLoader>> toRemove = new();
foreach (WeakReference<MemoryFileAssemblyContextLoader> weakReference in UnloadingACLs)
{
if (!weakReference.TryGetTarget(out _))
{
toRemove.Add(weakReference);
}
}
if (toRemove.Any())
{
OpsLockUnloaded.EnterWriteLock();
try
{
foreach (WeakReference<MemoryFileAssemblyContextLoader> reference in toRemove)
{
UnloadingACLs.Remove(reference);
}
}
finally
{
OpsLockUnloaded.ExitWriteLock();
}
}
isUnloaded = !UnloadingACLs.Any();
}
finally
{
OpsLockUnloaded.ExitUpgradeableReadLock();
}
return isUnloaded;
}
/// <summary>
/// Tries to retrieve the LoadedACL with the given ID or null if none is found.
/// WARNING: External references to this ACL with long lifespans should be kept in a WeakReference
/// to avoid causing unloading/disposal issues.
/// </summary>
/// <param name="id">GUID of the ACL.</param>
/// <param name="acl">The found ACL or null if none was found.</param>
/// <returns>Whether or not an ACL was found.</returns>
[MethodImpl(MethodImplOptions.NoInlining)]
public bool TryGetACL(Guid id, out LoadedACL acl)
{
acl = null;
OpsLockLoaded.EnterReadLock();
try
{
if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id))
return false;
acl = LoadedACLs[id];
return true;
}
finally
{
OpsLockLoaded.ExitReadLock();
}
}
/// <summary>
/// Gets or creates an AssemblyCtxLoader for the given ID. Creates if the ID is empty or no ACL can be found.
/// [IMPORTANT] After calling this method, the id you use should be taken from the acl container (acl.Id).
/// </summary>
/// <param name="id"></param>
/// <param name="acl"></param>
/// <returns>Should only return false if an error occurs.</returns>
[MethodImpl(MethodImplOptions.NoInlining)]
private bool GetOrCreateACL(Guid id, out LoadedACL acl)
{
OpsLockLoaded.EnterUpgradeableReadLock();
try
{
if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id) || LoadedACLs[id] is null)
{
OpsLockLoaded.EnterWriteLock();
try
{
id = Guid.NewGuid();
acl = new LoadedACL(id, this);
LoadedACLs[id] = acl;
return true;
}
finally
{
OpsLockLoaded.ExitWriteLock();
}
}
else
{
acl = LoadedACLs[id];
return true;
}
}
catch
{
// should never happen but in-case
acl = null;
return false;
}
finally
{
OpsLockLoaded.ExitUpgradeableReadLock();
}
}
[MethodImpl(MethodImplOptions.NoInlining)]
private bool DisposeACL(Guid id)
{
OpsLockLoaded.EnterWriteLock();
OpsLockUnloaded.EnterWriteLock();
try
{
if (id.Equals(Guid.Empty) || !LoadedACLs.ContainsKey(id) || LoadedACLs[id] is null)
{
return false; // nothing to dispose of
}
var acl = LoadedACLs[id];
foreach (Assembly assembly in acl.Acl.Assemblies)
{
OnAssemblyUnloading?.Invoke(assembly);
}
_subTypesLookupCache.Clear();
UnloadingACLs.Add(new WeakReference<MemoryFileAssemblyContextLoader>(acl.Acl, true));
acl.Acl.Unload();
OnACLUnload?.Invoke(acl.Id);
return true;
}
catch
{
// should never happen
return false;
}
finally
{
OpsLockLoaded.ExitWriteLock();
OpsLockUnloaded.ExitWriteLock();
}
}
internal AssemblyManager()
{
RebuildTypesList();
}
/// <summary>
/// Rebuilds the list of types in the default assembly load context.
/// </summary>
private void RebuildTypesList()
{
try
{
_defaultContextTypes = AssemblyLoadContext.Default.Assemblies
.SelectMany(a => a.GetSafeTypes())
.ToImmutableDictionary(t => t.FullName ?? t.Name, t => t);
_subTypesLookupCache.Clear();
}
catch(ArgumentException _)
{
try
{
// some types must've had duplicate type names, build the list while filtering
Dictionary<string, Type> types = new();
foreach (var type in AssemblyLoadContext.Default.Assemblies.SelectMany(a => a.GetSafeTypes()))
{
try
{
types.TryAdd(type.FullName ?? type.Name, type);
}
catch
{
// ignore, null key exception
}
}
_defaultContextTypes = types.ToImmutableDictionary();
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Unable to create list of default assembly types! Default AssemblyLoadContext types searching not available.");
#if DEBUG
ModUtils.Logging.PrintError($"{nameof(AssemblyManager)}: Exception Details :{e.Message} | {e.InnerException}");
#endif
_defaultContextTypes = ImmutableDictionary<string, Type>.Empty;
}
}
}
#endregion
#region Data
private readonly ConcurrentDictionary<string, ImmutableList<Type>> _subTypesLookupCache = new();
private ImmutableDictionary<string, Type> _defaultContextTypes;
private readonly ConcurrentDictionary<Guid, LoadedACL> LoadedACLs = new();
private readonly List<WeakReference<MemoryFileAssemblyContextLoader>> UnloadingACLs= new();
private readonly ReaderWriterLockSlim OpsLockLoaded = new ReaderWriterLockSlim();
private readonly ReaderWriterLockSlim OpsLockUnloaded = new ReaderWriterLockSlim();
#endregion
#region TypeDefs
public sealed class LoadedACL
{
public readonly Guid Id;
private ImmutableDictionary<string, Type> _assembliesTypes = ImmutableDictionary<string, Type>.Empty;
public readonly MemoryFileAssemblyContextLoader Acl;
private readonly AssemblyManager _manager;
internal LoadedACL(Guid id, AssemblyManager manager)
{
this.Id = id;
this.Acl = new(manager);
this._manager = manager;
}
public ImmutableDictionary<string, Type> AssembliesTypes => _assembliesTypes;
/// <summary>
/// Rebuild the list of types from assemblies loaded in the AsmCtxLoader.
/// </summary>
internal void RebuildTypesList()
{
ClearTypesList();
try
{
_assembliesTypes = this.Acl.Assemblies
.SelectMany(a => a.GetSafeTypes())
.ToImmutableDictionary(t => t.FullName ?? t.Name, t => t);
}
catch(ArgumentException _)
{
// some types must've had duplicate type names, build the list while filtering
Dictionary<string, Type> types = new();
foreach (var type in this.Acl.Assemblies.SelectMany(a => a.GetSafeTypes()))
{
try
{
types.TryAdd(type.FullName ?? type.Name, type);
}
catch
{
// ignore, null key exception
}
}
_assembliesTypes = types.ToImmutableDictionary();
}
}
internal void ClearTypesList()
{
_assembliesTypes.Clear();
}
}
#endregion
}
public static class AssemblyExtensions
{
/// <summary>
/// Gets all types in the given assembly. Handles invalid type scenarios.
/// </summary>
/// <param name="assembly">The assembly to scan</param>
/// <returns>An enumerable collection of types.</returns>
public static IEnumerable<Type> GetSafeTypes(this Assembly assembly)
{
// Based on https://github.com/Qkrisi/ktanemodkit/blob/master/Assets/Scripts/ReflectionHelper.cs#L53-L67
try
{
return assembly.GetTypes();
}
catch (ReflectionTypeLoadException re)
{
try
{
return re.Types.Where(x => x != null)!;
}
catch (InvalidOperationException ioe)
{
return new List<Type>();
}
}
catch (Exception e)
{
return new List<Type>();
}
}
}

View File

@@ -0,0 +1,978 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading;
using Barotrauma.Steam;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using MonoMod.Utils;
namespace Barotrauma;
public sealed class CsPackageManager : IDisposable
{
#region PRIVATE_FUNCDATA
private static readonly CSharpParseOptions ScriptParseOptions = CSharpParseOptions.Default
.WithPreprocessorSymbols(new[]
{
#if SERVER
"SERVER"
#elif CLIENT
"CLIENT"
#else
"UNDEFINED"
#endif
#if DEBUG
,"DEBUG"
#endif
});
#if WINDOWS
private const string PLATFORM_TARGET = "Windows";
#elif OSX
private const string PLATFORM_TARGET = "OSX";
#elif LINUX
private const string PLATFORM_TARGET = "Linux";
#endif
#if CLIENT
private const string ARCHITECTURE_TARGET = "Client";
#elif SERVER
private const string ARCHITECTURE_TARGET = "Server";
#endif
private static readonly CSharpCompilationOptions CompilationOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
.WithMetadataImportOptions(MetadataImportOptions.All)
#if DEBUG
.WithOptimizationLevel(OptimizationLevel.Debug)
#else
.WithOptimizationLevel(OptimizationLevel.Release)
#endif
.WithAllowUnsafe(true);
private static readonly SyntaxTree BaseAssemblyImports = CSharpSyntaxTree.ParseText(
new StringBuilder()
.AppendLine("using System.Reflection;")
.AppendLine("using Barotrauma;")
.AppendLine("using System.Runtime.CompilerServices;")
#if CLIENT
.AppendLine("[assembly: IgnoresAccessChecksTo(\"Barotrauma\")]")
#elif SERVER
.AppendLine("[assembly: IgnoresAccessChecksTo(\"DedicatedServer\")]")
#endif
.ToString(),
ScriptParseOptions);
private const string SCRIPT_FILE_REGEX = "*.cs";
private const string ASSEMBLY_FILE_REGEX = "*.dll";
private readonly float _assemblyUnloadTimeoutSeconds = 4f;
private Guid _publicizedAssemblyLoader;
private readonly List<ContentPackage> _currentPackagesByLoadOrder = new();
private readonly Dictionary<ContentPackage, ImmutableList<ContentPackage>> _packagesDependencies = new();
private readonly Dictionary<ContentPackage, Guid> _loadedCompiledPackageAssemblies = new();
private readonly Dictionary<Guid, ContentPackage> _reverseLookupGuidList = new();
private readonly Dictionary<Guid, HashSet<IAssemblyPlugin>> _loadedPlugins = new ();
private readonly Dictionary<Guid, ImmutableHashSet<Type>> _pluginTypes = new(); // where Type : IAssemblyPlugin
private readonly Dictionary<ContentPackage, RunConfig> _packageRunConfigs = new();
private readonly Dictionary<Guid, ImmutableList<Type>> _luaRegisteredTypes = new();
private readonly AssemblyManager _assemblyManager;
private readonly LuaCsSetup _luaCsSetup;
private DateTime _assemblyUnloadStartTime;
#endregion
#region PUBLIC_API
#region LUA_EXTENSIONS
/// <summary>
/// Searches for all types in all loaded assemblies from content packages who's names contain the name string and registers them with the Lua Interpreter.
/// </summary>
/// <param name="name"></param>
/// <param name="caseSensitive"></param>
/// <returns></returns>
public bool LuaTryRegisterPackageTypes(string name, bool caseSensitive = false)
{
if (!AssembliesLoaded)
return false;
var matchingPacks = _loadedCompiledPackageAssemblies
.Where(kvp => kvp.Key.Name.ToLowerInvariant().Contains(name.ToLowerInvariant()))
.Select(kvp => kvp.Value)
.ToImmutableList();
if (!matchingPacks.Any())
return false;
var types = matchingPacks
.Where(guid => !_luaRegisteredTypes.ContainsKey(guid))
.Select(guid => new KeyValuePair<Guid, ImmutableList<Type>>(
guid,
_assemblyManager.TryGetSubTypesFromACL(guid, out var types)
? types.ToImmutableList()
: ImmutableList<Type>.Empty))
.ToImmutableList();
if (!types.Any())
return false;
foreach (var kvp in types)
{
_luaRegisteredTypes[kvp.Key] = kvp.Value;
foreach (Type type in kvp.Value)
{
MoonSharp.Interpreter.UserData.RegisterType(type);
}
}
return true;
}
#endregion
/// <summary>
/// Whether or not assemblies have been loaded.
/// </summary>
public bool AssembliesLoaded { get; private set; }
/// <summary>
/// Whether or not loaded plugins had their preloader run.
/// </summary>
public bool PluginsPreInit { get; private set; }
/// <summary>
/// Whether or not plugins' types have been instantiated.
/// </summary>
public bool PluginsInitialized { get; private set; } = false;
/// <summary>
/// Whether or not plugins are fully loaded.
/// </summary>
public bool PluginsLoaded { get; private set; } = false;
public IEnumerable<ContentPackage> GetCurrentPackagesByLoadOrder() => _currentPackagesByLoadOrder;
/// <summary>
/// Tries to find the content package that a given plugin belongs to.
/// </summary>
/// <param name="package">Package if found, null otherwise.</param>
/// <typeparam name="T">The IAssemblyPlugin type to find.</typeparam>
/// <returns></returns>
public bool TryGetPackageForPlugin<T>(out ContentPackage package) where T : IAssemblyPlugin
{
package = null;
var t = typeof(T);
var guid = _pluginTypes
.Where(kvp => kvp.Value.Contains(t))
.Select(kvp => kvp.Key)
.FirstOrDefault(Guid.Empty);
if (guid.Equals(Guid.Empty) || !_reverseLookupGuidList.ContainsKey(guid) || _reverseLookupGuidList[guid] is null)
return false;
package = _reverseLookupGuidList[guid];
return true;
}
/// <summary>
/// Tries to get the loaded plugins for a given package.
/// </summary>
/// <param name="package">Package to find.</param>
/// <param name="loadedPlugins">The collection of loaded plugins.</param>
/// <returns></returns>
public bool TryGetLoadedPluginsForPackage(ContentPackage package, out IEnumerable<IAssemblyPlugin> loadedPlugins)
{
loadedPlugins = null;
if (package is null || !_loadedCompiledPackageAssemblies.ContainsKey(package))
return false;
var guid = _loadedCompiledPackageAssemblies[package];
if (guid.Equals(Guid.Empty) || !_loadedPlugins.ContainsKey(guid))
return false;
loadedPlugins = _loadedPlugins[guid];
return true;
}
/// <summary>
/// Called when clean up is being performed. Use when relying on or making use of references from this manager.
/// </summary>
public event Action OnDispose;
public void Dispose()
{
// send events for cleanup
OnDispose?.Invoke();
// cleanup events
if (OnDispose is not null)
{
foreach (Delegate del in OnDispose.GetInvocationList())
{
OnDispose -= (del as System.Action);
}
}
// cleanup plugins and assemblies
ReflectionUtils.ResetCache();
UnloadPlugins();
// try cleaning up the assemblies
_pluginTypes.Clear(); // remove assembly references
_loadedPlugins.Clear();
// lua cleanup
foreach (var kvp in _luaRegisteredTypes)
{
foreach (Type type in kvp.Value)
{
MoonSharp.Interpreter.UserData.UnregisterType(type);
}
}
_luaRegisteredTypes.Clear();
_assemblyUnloadStartTime = DateTime.Now;
_publicizedAssemblyLoader = Guid.Empty;
// we can't wait forever or app dies but we can try to be graceful
while (!_assemblyManager.TryBeginDispose())
{
if (_assemblyUnloadStartTime.AddSeconds(_assemblyUnloadTimeoutSeconds) > DateTime.Now)
{
break;
}
}
_assemblyUnloadStartTime = DateTime.Now;
while (!_assemblyManager.FinalizeDispose())
{
if (_assemblyUnloadStartTime.AddSeconds(_assemblyUnloadTimeoutSeconds) > DateTime.Now)
{
break;
}
}
_assemblyManager.OnAssemblyLoaded -= AssemblyManagerOnAssemblyLoaded;
_assemblyManager.OnAssemblyUnloading -= AssemblyManagerOnAssemblyUnloading;
_publicizedAssemblyLoader = Guid.Empty;
// clear lists after cleaning up
_packagesDependencies.Clear();
_loadedCompiledPackageAssemblies.Clear();
_reverseLookupGuidList.Clear();
_packageRunConfigs.Clear();
_currentPackagesByLoadOrder.Clear();
AssembliesLoaded = false;
GC.SuppressFinalize(this);
}
/// <summary>
/// Begins the loading process of scanning packages for scripts and binary assemblies, compiling and executing them.
/// </summary>
/// <returns></returns>
public AssemblyLoadingSuccessState LoadAssemblyPackages()
{
if (AssembliesLoaded)
{
return AssemblyLoadingSuccessState.AlreadyLoaded;
}
_assemblyManager.OnAssemblyLoaded += AssemblyManagerOnAssemblyLoaded;
_assemblyManager.OnAssemblyUnloading += AssemblyManagerOnAssemblyUnloading;
// load publicized assemblies
var publicizedDir = Path.Combine(Environment.CurrentDirectory, "Publicized");
ImmutableList<Assembly> publicizedAssemblies = ImmutableList<Assembly>.Empty;
if (Directory.Exists(publicizedDir))
{
// search for assemblies
var list = Directory.GetFiles(publicizedDir, "*.dll")
#if CLIENT
.Where(s => !s.ToLowerInvariant().EndsWith("dedicatedserver.dll"));
#elif SERVER
.Where(s => !s.ToLowerInvariant().EndsWith("barotrauma.dll"));
#endif
// try load them into an acl
var loadState = _assemblyManager.LoadAssembliesFromLocations(list, ref _publicizedAssemblyLoader);
// loaded
if (loadState is AssemblyLoadingSuccessState.Success)
{
if (_assemblyManager.TryGetACL(_publicizedAssemblyLoader, out var acl))
{
publicizedAssemblies = acl.Acl.Assemblies.ToImmutableList();
_assemblyManager.SetACLToTemplateMode(_publicizedAssemblyLoader);
}
}
}
// get packages
IEnumerable<ContentPackage> packages = BuildPackagesList();
// check and load config
_packageRunConfigs.AddRange(packages
.Select(p => new KeyValuePair<ContentPackage, RunConfig>(p, GetRunConfigForPackage(p)))
.ToDictionary(p => p.Key, p=> p.Value));
// filter not to be loaded
var cpToRun = _packageRunConfigs
.Where(kvp => ShouldRunPackage(kvp.Key, kvp.Value))
.Select(kvp => kvp.Key)
.ToImmutableList();
// build dependencies map
bool reliableMap = TryBuildDependenciesMap(cpToRun, out var packDeps);
if (!reliableMap)
{
ModUtils.Logging.PrintMessage($"{nameof(CsPackageManager)}: Unable to create reliable dependencies map.");
}
_packagesDependencies.AddRange(packDeps.ToDictionary(
kvp => kvp.Key,
kvp => kvp.Value.ToImmutableList())
);
List<ContentPackage> packagesToLoadInOrder = new();
// build load order
if (reliableMap && OrderAndFilterPackagesByDependencies(
_packagesDependencies,
out var readyToLoad,
out var cannotLoadPackages,
null))
{
packagesToLoadInOrder.AddRange(readyToLoad);
if (cannotLoadPackages is not null)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the following mods due to dependency errors:");
foreach (var pair in cannotLoadPackages)
{
ModUtils.Logging.PrintError($"Package: {pair.Key.Name} | Reason: {pair.Value}");
}
}
}
else
{
// use unsorted list on failure and send error message.
packagesToLoadInOrder.AddRange(_packagesDependencies.Select( p=> p.Key));
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to create a reliable load order. Defaulting to unordered loading!");
}
// get assemblies and scripts' filepaths from packages
var toLoad = packagesToLoadInOrder
.Select(cp => new KeyValuePair<ContentPackage, LoadableData>(
cp,
new LoadableData(
TryScanPackagesForAssemblies(cp, out var list1) ? list1 : null,
TryScanPackageForScripts(cp, out var list2) ? list2 : null)))
.ToImmutableDictionary();
HashSet<ContentPackage> badPackages = new();
foreach (var pair in toLoad)
{
// check if unloadable
if (badPackages.Contains(pair.Key))
continue;
// try load binary assemblies
var id = Guid.Empty; // id for the ACL for this package defined by AssemblyManager.
AssemblyLoadingSuccessState successState;
if (pair.Value.AssembliesFilePaths is not null && pair.Value.AssembliesFilePaths.Any())
{
ModUtils.Logging.PrintMessage($"Loading assemblies for CPackage {pair.Key.Name}");
#if DEBUG
foreach (string assembliesFilePath in pair.Value.AssembliesFilePaths)
{
ModUtils.Logging.PrintMessage($"Found assemblies located at {Path.GetFullPath(ModUtils.IO.SanitizePath(assembliesFilePath))}");
}
#endif
successState = _assemblyManager.LoadAssembliesFromLocations(pair.Value.AssembliesFilePaths, ref id);
// error handling
if (successState is not AssemblyLoadingSuccessState.Success)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the binary assemblies for package {pair.Key.Name}. Error: {successState.ToString()}");
UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies);
continue;
}
}
// try compile scripts to assemblies
if (pair.Value.ScriptsFilePaths is not null && pair.Value.ScriptsFilePaths.Any())
{
ModUtils.Logging.PrintMessage($"Loading scripts for CPackage {pair.Key.Name}");
List<SyntaxTree> syntaxTrees = new();
syntaxTrees.Add(GetPackageScriptImports());
bool abortPackage = false;
// load scripts data from files
foreach (string scriptPath in pair.Value.ScriptsFilePaths)
{
var state = ModUtils.IO.GetOrCreateFileText(scriptPath, out string fileText, null, false);
// could not load file data
if (state is not ModUtils.IO.IOActionResultState.Success)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: {state.ToString()}");
UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies);
abortPackage = true;
break;
}
try
{
CancellationToken token = new();
syntaxTrees.Add(SyntaxFactory.ParseSyntaxTree(fileText, ScriptParseOptions, scriptPath, Encoding.Default, token));
// cancel if parsing failed
if (token.IsCancellationRequested)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: Syntax Parse Error.");
UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies);
abortPackage = true;
break;
}
}
catch (Exception e)
{
// unknown error
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to load the script files for package {pair.Key.Name}. Error: {e.Message}");
UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies);
abortPackage = true;
break;
}
}
if (abortPackage)
continue;
// try compile
successState = _assemblyManager.LoadAssemblyFromMemory(
pair.Key.Name.Replace(" ",""),
syntaxTrees,
null,
CompilationOptions,
ref id, publicizedAssemblies);
if (successState is not AssemblyLoadingSuccessState.Success)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Unable to compile script assembly for package {pair.Key.Name}. Error: {successState.ToString()}");
UpdatePackagesToDisable(ref badPackages, pair.Key, _packagesDependencies);
continue;
}
}
// something was loaded, add to index
if (id != Guid.Empty)
{
ModUtils.Logging.PrintMessage($"Assemblies from CPackage {pair.Key.Name} loaded with Guid {id}.");
_loadedCompiledPackageAssemblies.Add(pair.Key, id);
_reverseLookupGuidList.Add(id, pair.Key);
}
}
// update loaded packages to exclude bad packages
_currentPackagesByLoadOrder.AddRange(toLoad
.Where(p => !badPackages.Contains(p.Key))
.Select(p => p.Key));
// build list of plugins
foreach (var pair in _loadedCompiledPackageAssemblies)
{
if (_assemblyManager.TryGetSubTypesFromACL<IAssemblyPlugin>(pair.Value, out var types))
{
_pluginTypes[pair.Value] = types.ToImmutableHashSet();
foreach (var type in _pluginTypes[pair.Value])
{
ModUtils.Logging.PrintMessage($"Loading type: {type.Name}");
}
}
}
this.AssembliesLoaded = true;
return AssemblyLoadingSuccessState.Success;
bool ShouldRunPackage(ContentPackage package, RunConfig config)
{
if (config.AutoGenerated)
return false;
return (!_luaCsSetup.Config.TreatForcedModsAsNormal && config.IsForced())
|| (ContentPackageManager.EnabledPackages.All.Contains(package) && config.IsForcedOrStandard());
}
void UpdatePackagesToDisable(ref HashSet<ContentPackage> list,
ContentPackage newDisabledPackage,
IEnumerable<KeyValuePair<ContentPackage, ImmutableList<ContentPackage>>> dependenciesMap)
{
list.Add(newDisabledPackage);
foreach (var package in dependenciesMap)
{
if (package.Value.Contains(newDisabledPackage))
list.Add(newDisabledPackage);
}
}
}
/// <summary>
/// Executes instantiated plugins' Initialize() and OnLoadCompleted() methods.
/// </summary>
public void RunPluginsInit()
{
if (!AssembliesLoaded)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' Initialize() without any loaded assemblies!");
return;
}
if (!PluginsInitialized)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' Initialize() without type instantiation!");
return;
}
if (PluginsLoaded)
return;
foreach (var contentPlugins in _loadedPlugins)
{
// init
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.Initialize(), $"{nameof(IAssemblyPlugin.Initialize)}", plugin.GetType().Name);
}
}
foreach (var contentPlugins in _loadedPlugins)
{
// load complete
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.OnLoadCompleted(), $"{nameof(IAssemblyPlugin.OnLoadCompleted)}", plugin.GetType().Name);
}
}
PluginsLoaded = true;
}
/// <summary>
/// Executes instantiated plugins' PreInitPatching() method.
/// </summary>
public void RunPluginsPreInit()
{
if (!AssembliesLoaded)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' PreInitPatching() without any loaded assemblies!");
return;
}
if (!PluginsInitialized)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to call plugins' PreInitPatching() without type initialization!");
return;
}
if (PluginsPreInit)
{
return;
}
foreach (var contentPlugins in _loadedPlugins)
{
// init
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.PreInitPatching(), $"{nameof(IAssemblyPlugin.PreInitPatching)}", plugin.GetType().Name);
}
}
PluginsPreInit = true;
}
/// <summary>
/// Initializes plugin types that are registered.
/// </summary>
/// <param name="force"></param>
public void InstantiatePlugins(bool force = false)
{
if (!AssembliesLoaded)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to instantiate plugins without any loaded assemblies!");
return;
}
if (PluginsInitialized)
{
if (force)
UnloadPlugins();
else
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Attempted to load plugins when they were already loaded!");
return;
}
}
foreach (var pair in _pluginTypes)
{
// instantiate
foreach (Type type in pair.Value)
{
if (!_loadedPlugins.ContainsKey(pair.Key))
_loadedPlugins.Add(pair.Key, new());
else if (_loadedPlugins[pair.Key] is null)
_loadedPlugins[pair.Key] = new();
IAssemblyPlugin plugin = null;
try
{
plugin = (IAssemblyPlugin)Activator.CreateInstance(type);
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}. Now disposing...");
#if DEBUG
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}");
#endif
TryRun(() => plugin?.Dispose(), "Dispose", type.FullName ?? type.Name);
plugin = null;
}
if (plugin is not null)
_loadedPlugins[pair.Key].Add(plugin);
else
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while instantiating plugin of type {type}");
}
}
PluginsInitialized = true;
}
/// <summary>
/// Unloads all plugins by calling Dispose() on them. Note: This does not remove their external references nor
/// unregister their types.
/// </summary>
public void UnloadPlugins()
{
foreach (var contentPlugins in _loadedPlugins)
{
foreach (var plugin in contentPlugins.Value)
{
TryRun(() => plugin.Dispose(), $"{nameof(IAssemblyPlugin.Dispose)}", plugin.GetType().Name);
}
contentPlugins.Value.Clear();
}
_loadedPlugins.Clear();
PluginsInitialized = false;
PluginsPreInit = false;
PluginsLoaded = false;
}
/// <summary>
/// Gets the RunConfig.xml for the given package located at [cp_root]/CSharp/RunConfig.xml.
/// Generates a default config if one is not found.
/// </summary>
/// <param name="package">The package to search for.</param>
/// <param name="config">RunConfig data.</param>
/// <returns>True if a config is loaded, false if one was created.</returns>
public static bool GetOrCreateRunConfig(ContentPackage package, out RunConfig config)
{
var path = System.IO.Path.Combine(Path.GetFullPath(package.Dir), "CSharp", "RunConfig.xml");
if (!File.Exists(path))
{
config = new RunConfig(true).Sanitize();
return false;
}
return ModUtils.IO.LoadOrCreateTypeXml(out config, path, () => new RunConfig(true).Sanitize(), false);
}
#endregion
#region INTERNALS
private void TryRun(Action action, string messageMethodName, string messageTypeName)
{
try
{
action?.Invoke();
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Error while running {messageMethodName}() on plugin of type {messageTypeName}");
#if DEBUG
ModUtils.Logging.PrintError($"{nameof(CsPackageManager)}: Details: {e.Message} | {e.InnerException}");
#endif
}
}
private void AssemblyManagerOnAssemblyUnloading(Assembly assembly)
{
ReflectionUtils.RemoveAssemblyFromCache(assembly);
}
private void AssemblyManagerOnAssemblyLoaded(Assembly assembly)
{
//ReflectionUtils.AddNonAbstractAssemblyTypes(assembly);
// As ReflectionUtils.GetDerivedNonAbstract is only used for Prefabs & Barotrauma-specific implementing types,
// we can safely not register System/Core assemblies.
if (assembly.FullName is not null && assembly.FullName.StartsWith("System."))
return;
ReflectionUtils.AddNonAbstractAssemblyTypes(assembly, true);
}
internal CsPackageManager([NotNull] AssemblyManager assemblyManager, [NotNull] LuaCsSetup luaCsSetup)
{
this._assemblyManager = assemblyManager;
this._luaCsSetup = luaCsSetup;
}
~CsPackageManager()
{
this.Dispose();
}
private static bool TryScanPackageForScripts(ContentPackage package, out ImmutableList<string> scriptFilePaths)
{
string pathShared = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "CSharp", "Shared");
string pathArch = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "CSharp", ARCHITECTURE_TARGET);
List<string> files = new();
if (Directory.Exists(pathShared))
files.AddRange(Directory.GetFiles(pathShared, SCRIPT_FILE_REGEX, SearchOption.AllDirectories));
if (Directory.Exists(pathArch))
files.AddRange(Directory.GetFiles(pathArch, SCRIPT_FILE_REGEX, SearchOption.AllDirectories));
if (files.Count > 0)
{
scriptFilePaths = files.ToImmutableList();
return true;
}
scriptFilePaths = ImmutableList<string>.Empty;
return false;
}
private static bool TryScanPackagesForAssemblies(ContentPackage package, out ImmutableList<string> assemblyFilePaths)
{
string path = Path.Combine(ModUtils.IO.GetContentPackageDir(package), "bin", ARCHITECTURE_TARGET, PLATFORM_TARGET);
if (!Directory.Exists(path))
{
assemblyFilePaths = ImmutableList<string>.Empty;
return false;
}
assemblyFilePaths = System.IO.Directory.GetFiles(path, ASSEMBLY_FILE_REGEX, SearchOption.AllDirectories)
.ToImmutableList();
return assemblyFilePaths.Count > 0;
}
private static RunConfig GetRunConfigForPackage(ContentPackage package)
{
if (!GetOrCreateRunConfig(package, out var config))
config.AutoGenerated = true;
return config;
}
private IEnumerable<ContentPackage> BuildPackagesList()
{
// get unique list of content packages.
// Note: there is an old issue where the AllPackages group
// would sometimes not contain packages downloaded from the host, so we union enabled.
return ContentPackageManager.AllPackages.Union(ContentPackageManager.EnabledPackages.All).Where(pack => !pack.Name.ToLowerInvariant().Equals("vanilla"));
}
private static SyntaxTree GetPackageScriptImports() => BaseAssemblyImports;
/// <summary>
/// Builds a list of ContentPackage dependencies for each of the packages in the list. Note: All dependencies must be included in the provided list of packages.
/// </summary>
/// <param name="packages">List of packages to check</param>
/// <param name="dependenciesMap">Dependencies by package</param>
/// <returns>True if all dependencies were found.</returns>
private static bool TryBuildDependenciesMap(ImmutableList<ContentPackage> packages, out Dictionary<ContentPackage, List<ContentPackage>> dependenciesMap)
{
bool reliableMap = true; // remains true if all deps were found.
dependenciesMap = new();
foreach (var package in packages)
{
dependenciesMap.Add(package, new());
if (GetOrCreateRunConfig(package, out var config))
{
if (config.Dependencies is null || !config.Dependencies.Any())
continue;
foreach (RunConfig.Dependency dependency in config.Dependencies)
{
ContentPackage dep = packages.FirstOrDefault(p =>
(dependency.SteamWorkshopId != 0 && p.TryExtractSteamWorkshopId(out var steamWorkshopId)
&& steamWorkshopId.Value == dependency.SteamWorkshopId)
|| (!dependency.PackageName.IsNullOrWhiteSpace() && p.Name.ToLowerInvariant().Contains(dependency.PackageName.ToLowerInvariant())), null);
if (dep is not null)
{
dependenciesMap[package].Add(dep);
}
else
{
ModUtils.Logging.PrintError($"Warning! The ContentPackage {package.Name} lists a dependency of (STEAMID: {dependency.SteamWorkshopId}, PackageName: {dependency.PackageName}) but it could not be found in the to-be-loaded CSharp packages list!");
reliableMap = false;
}
}
}
else
{
ModUtils.Logging.PrintMessage($"Warning! Could not retrieve RunConfig for ContentPackage {package.Name}!");
}
}
return reliableMap;
}
/// <summary>
/// Given a table of packages and dependent packages, will sort them by dependency loading order along with packages
/// that cannot be loaded due to errors or failing the predicate checks.
/// </summary>
/// <param name="packages">A dictionary/map with key as the package and the elements as it's dependencies.</param>
/// <param name="readyToLoad">List of packages that are ready to load and in the correct order.</param>
/// <param name="cannotLoadPackages">Packages with errors or cyclic dependencies. Element is error message. Null if empty.</param>
/// <param name="packageChecksPredicate">Optional: Allows for a custom checks to be performed on each package.
/// Returns a bool indicating if the package is ready to load.</param>
/// <returns>Whether or not the process produces a usable list.</returns>
private static bool OrderAndFilterPackagesByDependencies(
Dictionary<ContentPackage, ImmutableList<ContentPackage>> packages,
out IEnumerable<ContentPackage> readyToLoad,
out IEnumerable<KeyValuePair<ContentPackage, string>> cannotLoadPackages,
Func<ContentPackage, bool> packageChecksPredicate = null)
{
HashSet<ContentPackage> completedPackages = new();
List<ContentPackage> readyPackages = new();
Dictionary<ContentPackage, string> unableToLoad = new();
HashSet<ContentPackage> currentNodeChain = new();
readyToLoad = readyPackages;
try
{
foreach (var toProcessPack in packages)
{
ProcessPackage(toProcessPack.Key, toProcessPack.Value);
}
PackageProcRet ProcessPackage(ContentPackage packageToProcess, IEnumerable<ContentPackage> dependencies)
{
//cyclic handling
if (unableToLoad.ContainsKey(packageToProcess))
{
return PackageProcRet.BadPackage;
}
// already processed
if (completedPackages.Contains(packageToProcess))
{
return PackageProcRet.AlreadyCompleted;
}
// cyclic check
if (currentNodeChain.Contains(packageToProcess))
{
StringBuilder sb = new();
sb.AppendLine("Error: Cyclic Dependency. ")
.Append(
"The following ContentPackages rely on eachother in a way that makes it impossible to know which to load first! ")
.Append(
"Note: the package listed twice shows where the cycle starts/ends and is not necessarily the problematic package.");
int i = 0;
foreach (var package in currentNodeChain)
{
i++;
sb.AppendLine($"{i}. {package.Name}");
}
sb.AppendLine($"{i}. {packageToProcess.Name}");
unableToLoad.Add(packageToProcess, sb.ToString());
completedPackages.Add(packageToProcess);
return PackageProcRet.BadPackage;
}
if (packageChecksPredicate is not null && !packageChecksPredicate.Invoke(packageToProcess))
{
unableToLoad.Add(packageToProcess, $"Unable to load package {packageToProcess.Name} due to failing checks.");
completedPackages.Add(packageToProcess);
return PackageProcRet.BadPackage;
}
currentNodeChain.Add(packageToProcess);
foreach (ContentPackage dependency in dependencies)
{
// The mod lists a dependent that was not found during the discovery phase.
if (!packages.ContainsKey(dependency))
{
// search to see if it's enabled
if (!ContentPackageManager.EnabledPackages.All.Contains(dependency))
{
// present warning but allow loading anyways, better to let the user just disable the package if it's really an issue.
ModUtils.Logging.PrintError(
$"Warning: the ContentPackage of {packageToProcess.Name} requires the Dependency {dependency.Name} but this package wasn't found in the enabled mods list!");
}
continue;
}
var ret = ProcessPackage(dependency, packages[dependency]);
if (ret is PackageProcRet.BadPackage)
{
if (!unableToLoad.ContainsKey(packageToProcess))
{
unableToLoad.Add(packageToProcess, $"Error: Dependency failure. Failed to load {dependency.Name}");
}
currentNodeChain.Remove(packageToProcess);
if (!completedPackages.Contains(packageToProcess))
{
completedPackages.Add(packageToProcess);
}
return PackageProcRet.BadPackage;
}
}
currentNodeChain.Remove(packageToProcess);
completedPackages.Add(packageToProcess);
readyPackages.Add(packageToProcess);
return PackageProcRet.Completed;
}
}
catch (Exception e)
{
ModUtils.Logging.PrintError($"Error while generating dependency loading order! Exception: {e.Message}");
#if DEBUG
ModUtils.Logging.PrintError($"Stack Trace: {e.StackTrace}");
#endif
cannotLoadPackages = unableToLoad.Any() ? unableToLoad : null;
return false;
}
cannotLoadPackages = unableToLoad.Any() ? unableToLoad : null;
return true;
}
private enum PackageProcRet : byte
{
AlreadyCompleted,
Completed,
BadPackage
}
private record LoadableData(ImmutableList<string> AssembliesFilePaths, ImmutableList<string> ScriptsFilePaths);
#endregion
}

View File

@@ -0,0 +1,22 @@
using System;
namespace Barotrauma;
public interface IAssemblyPlugin : IDisposable
{
/// <summary>
/// Called on plugin normal, use this for basic/core loading that does not rely on any other modded content.
/// </summary>
void Initialize();
/// <summary>
/// Called once all plugins have been loaded. if you have integrations with any other mod, put that code here.
/// </summary>
void OnLoadCompleted();
/// <summary>
/// Called before Barotrauma initializes vanilla content. WARNING: This method may be called before Initialize()!
/// </summary>
void PreInitPatching();
}

View File

@@ -0,0 +1,289 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;
using System.Threading;
using Barotrauma;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Emit;
namespace Barotrauma;
/// <summary>
/// AssemblyLoadContext to compile from syntax trees in memory and to load from disk/file. Provides dependency resolution.
/// [IMPORTANT] Only supports 1 in-memory compiled assembly at a time. Use more instances if you need more.
/// [IMPORTANT] All file assemblies required for the compilation of syntax trees should be loaded first.
/// </summary>
public class MemoryFileAssemblyContextLoader : AssemblyLoadContext
{
// public
// ReSharper disable MemberCanBePrivate.Global
public Assembly CompiledAssembly { get; private set; } = null;
public byte[] CompiledAssemblyImage { get; private set; } = null;
// ReSharper restore MemberCanBePrivate.Global
// internal
private readonly Dictionary<string, AssemblyDependencyResolver> _dependencyResolvers = new(); // path-folder, resolver
protected bool IsResolving; //this is to avoid circular dependency lookup.
private AssemblyManager _assemblyManager;
public bool IsTemplateMode { get; set; } = false;
public MemoryFileAssemblyContextLoader(AssemblyManager assemblyManager) : base(isCollectible: true)
{
this._assemblyManager = assemblyManager;
}
/// <summary>
/// Try to load the list of disk-file assemblies.
/// </summary>
/// <param name="assemblyFilePaths">Operation success or failure reason.</param>
public AssemblyLoadingSuccessState LoadFromFiles([NotNull] IEnumerable<string> assemblyFilePaths)
{
if (assemblyFilePaths is null)
throw new ArgumentNullException(
$"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(LoadFromFiles)}() | The supplied filepath list is null.");
foreach (string filepath in assemblyFilePaths)
{
// path verification
if (filepath.IsNullOrWhiteSpace())
continue;
string sanitizedFilePath = System.IO.Path.GetFullPath(filepath.CleanUpPath());
string directoryKey = System.IO.Path.GetDirectoryName(sanitizedFilePath);
if (directoryKey is null)
return AssemblyLoadingSuccessState.BadFilePath;
// setup dep resolver if not available
if (!_dependencyResolvers.ContainsKey(directoryKey) || _dependencyResolvers[directoryKey] is null)
{
_dependencyResolvers[directoryKey] = new AssemblyDependencyResolver(sanitizedFilePath); // supply the first assembly to be loaded
}
// try loading the assemblies
try
{
LoadFromAssemblyPath(sanitizedFilePath);
}
// on fail of any we're done because we assume that loaded files are related. This ACL needs to be unloaded and collected.
catch (ArgumentNullException ane)
{
return AssemblyLoadingSuccessState.BadFilePath;
}
catch (ArgumentException ae)
{
return AssemblyLoadingSuccessState.BadFilePath;
}
catch (FileLoadException fle)
{
return AssemblyLoadingSuccessState.CannotLoadFile;
}
catch (FileNotFoundException fne)
{
return AssemblyLoadingSuccessState.NoAssemblyFound;
}
catch (BadImageFormatException bfe)
{
return AssemblyLoadingSuccessState.InvalidAssembly;
}
catch (Exception e)
{
#if SERVER
LuaCsLogger.LogError($"Unable to load dependency assembly file at {filepath.CleanUpPath()} for the assembly named {CompiledAssembly?.FullName}. | Data: {e.Message} | InnerException: {e.InnerException}");
#elif CLIENT
LuaCsLogger.ShowErrorOverlay($"Unable to load dependency assembly file at {filepath} for the assembly named {CompiledAssembly?.FullName}. | Data: {e.Message} | InnerException: {e.InnerException}");
#endif
return AssemblyLoadingSuccessState.ACLLoadFailure;
}
}
return AssemblyLoadingSuccessState.Success;
}
/// <summary>
/// Compiles the supplied syntaxtrees and options into an in-memory assembly image.
/// Builds metadata from loaded assemblies, only supply your own if you have in-memory images not managed by the
/// AssemblyManager class.
/// </summary>
/// <param name="assemblyName">Name of the assembly. Must be supplied for in-memory assemblies.</param>
/// <param name="syntaxTrees">Syntax trees to compile into the assembly.</param>
/// <param name="externMetadataReferences">Metadata to be used for compilation.
/// [IMPORTANT] This method builds metadata from loaded assemblies, only supply your own if you have in-memory
/// images not managed by the AssemblyManager class.</param>
/// <param name="compilationOptions">CSharp compilation options. This method automatically adds the 'IgnoreAccessChecks' property for compilation.</param>
/// <param name="compilationMessages">Will contain any diagnostic messages for compilation failure.</param>
/// <param name="externFileAssemblyReferences">Additional assemblies located in the FileSystem to build metadata references from.
/// Assemblies here will have duplicates by the same name that are currently loaded filtered out.</param>
/// <returns>Success state of the operation.</returns>
/// <exception cref="ArgumentNullException">Throws exception if any of the required arguments are null.</exception>
public AssemblyLoadingSuccessState CompileAndLoadScriptAssembly(
[NotNull] string assemblyName,
[NotNull] IEnumerable<SyntaxTree> syntaxTrees,
IEnumerable<MetadataReference> externMetadataReferences,
[NotNull] CSharpCompilationOptions compilationOptions,
out string compilationMessages,
IEnumerable<Assembly> externFileAssemblyReferences = null)
{
compilationMessages = "";
if (this.CompiledAssembly is not null)
{
return AssemblyLoadingSuccessState.AlreadyLoaded;
}
var externAssemblyRefs = externFileAssemblyReferences is not null ? externFileAssemblyReferences.ToImmutableList() : ImmutableList<Assembly>.Empty;
var externAssemblyNames = externAssemblyRefs.Any() ? externAssemblyRefs
.Where(a => a.FullName is not null)
.Select(a => a.FullName).ToImmutableHashSet()
: ImmutableHashSet<string>.Empty;
// verifications
if (assemblyName.IsNullOrWhiteSpace())
throw new ArgumentNullException(
$"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(CompileAndLoadScriptAssembly)}() | The supplied assembly name is null!");
if (syntaxTrees is null)
throw new ArgumentNullException(
$"{nameof(MemoryFileAssemblyContextLoader)}::{nameof(CompileAndLoadScriptAssembly)}() | The supplied syntax tree is null!");
// add external references
List<MetadataReference> metadataReferences = new();
if (externMetadataReferences is not null)
metadataReferences.AddRange(externMetadataReferences);
// build metadata refs from global where not an in-memory compiled assembly and not the same assembly as supplied.
metadataReferences.AddRange(AppDomain.CurrentDomain.GetAssemblies()
.Where(a =>
{
if (a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit"))
return false;
if (a.FullName is null)
return true;
return !externAssemblyNames.Contains(a.FullName); // exclude duplicates
})
.Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference)
.Union(externAssemblyRefs // add custom supplied assemblies
.Where(a => !(a.IsDynamic || string.IsNullOrEmpty(a.Location) || a.Location.Contains("xunit")))
.Select(a => MetadataReference.CreateFromFile(a.Location) as MetadataReference)
).ToList());
// build metadata refs from in-memory images
foreach (var loadedAcl in _assemblyManager.GetAllLoadedACLs())
{
if (loadedAcl.Acl.CompiledAssemblyImage is null || loadedAcl.Acl.CompiledAssemblyImage.Length == 0)
continue;
metadataReferences.Add(MetadataReference.CreateFromImage(loadedAcl.Acl.CompiledAssemblyImage));
}
// Change inaccessible options to allow public access to restricted members
var topLevelBinderFlagsProperty = typeof(CSharpCompilationOptions).GetProperty("TopLevelBinderFlags", BindingFlags.Instance | BindingFlags.NonPublic);
topLevelBinderFlagsProperty?.SetValue(compilationOptions, (uint)1 << 22);
// begin compilation
using var memoryCompilation = new MemoryStream();
// compile, emit
var result = CSharpCompilation.Create(assemblyName, syntaxTrees, metadataReferences, compilationOptions).Emit(memoryCompilation);
// check for errors
if (!result.Success)
{
IEnumerable<Diagnostic> failures = result.Diagnostics.Where(d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error);
foreach (Diagnostic diagnostic in failures)
{
compilationMessages += $"\n{diagnostic}";
}
return AssemblyLoadingSuccessState.InvalidAssembly;
}
// read compiled assembly from memory stream into an in-memory assembly & image
memoryCompilation.Seek(0, SeekOrigin.Begin); // reset
try
{
CompiledAssembly = LoadFromStream(memoryCompilation);
CompiledAssemblyImage = memoryCompilation.ToArray();
}
catch (Exception e)
{
#if SERVER
LuaCsLogger.LogError($"Unable to load memory assembly from stream. | Data: {e.Message} | InnerException: {e.InnerException}");
#elif CLIENT
LuaCsLogger.ShowErrorOverlay($"Unable to load memory assembly from stream. | Data: {e.Message} | InnerException: {e.InnerException}");
#endif
return AssemblyLoadingSuccessState.CannotLoadFromStream;
}
return AssemblyLoadingSuccessState.Success;
}
[SuppressMessage("ReSharper", "ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract")]
protected override Assembly Load(AssemblyName assemblyName)
{
if (IsResolving)
return null; //circular resolution fast exit.
try
{
IsResolving = true;
// resolve self collection
Assembly ass = this.Assemblies.FirstOrDefault(a =>
a.FullName is not null && a.FullName.Equals(assemblyName.FullName), null);
if (ass is not null)
return ass;
// resolve to local folders
foreach (KeyValuePair<string,AssemblyDependencyResolver> pair in _dependencyResolvers)
{
var asspath = pair.Value.ResolveAssemblyToPath(assemblyName);
if (asspath is null)
continue;
ass = LoadFromAssemblyPath(asspath);
// ReSharper disable once ConditionIsAlwaysTrueOrFalse
if (ass is not null)
return ass;
}
//try resolve against other loaded alcs
foreach (var loadedAcL in _assemblyManager.GetAllLoadedACLs())
{
if (loadedAcL.Acl is null || loadedAcL.Acl.IsTemplateMode) continue;
try
{
ass = loadedAcL.Acl.LoadFromAssemblyName(assemblyName);
if (ass is not null)
return ass;
}
catch
{
// LoadFromAssemblyName throws, no need to propagate
}
}
ass = AssemblyLoadContext.Default.LoadFromAssemblyName(assemblyName);
if (ass is not null)
return ass;
}
finally
{
IsResolving = false;
}
return null;
}
private new void Unload()
{
CompiledAssembly = null;
CompiledAssemblyImage = null;
base.Unload();
}
}

View File

@@ -0,0 +1,111 @@
using System;
using System.Xml.Serialization;
namespace Barotrauma;
[Serializable]
public sealed class RunConfig
{
/// <summary>
/// How should scripts be run on the server.
/// </summary>
[XmlElement(ElementName = "Server")] public string Server;
/// <summary>
/// How should scripts be run on the client.
/// </summary>
[XmlElement(ElementName = "Client")] public string Client;
/// <summary>
/// List of dependencies by either Steam Workshop ID or by Partial Inclusive Name (ie. "ModDep" will match a mod named "A ModDependency").
/// PIN Dependency checks if ContentPackage names contains the dependency string.
/// </summary>
[XmlArrayItem(ElementName = "Dependency", IsNullable = true, Type = typeof(Dependency))]
[XmlArray]
public Dependency[] Dependencies { get; set; }
[XmlElement(ElementName = "AutoGenerated")]
public bool AutoGenerated { get; set; }
public RunConfig(bool autoGenerated)
{
this.AutoGenerated = autoGenerated;
if (autoGenerated)
{
(Client, Server) = ("None", "None");
}
}
public RunConfig() { } // For serialization use
[Serializable]
public sealed class Dependency
{
/// <summary>
/// Steam Workshop ID of the dependency.
/// </summary>
[XmlElement(ElementName = "SteamWorkshopId")]
public ulong SteamWorkshopId;
/// <summary>
/// Package Name of the dependency. Not needed if SteamWorkshopId is set.
/// </summary>
[XmlElement(ElementName = "PackageName")]
public string PackageName;
}
public RunConfig Sanitize()
{
try
{
Client = SanitizeRunSetting(Client);
}
catch (Exception e)
{
Client = "None";
}
try
{
Server = SanitizeRunSetting(Server);
}
catch (Exception e)
{
Server = "None";
}
Dependencies ??= new RunConfig.Dependency[] { };
static string SanitizeRunSetting(string str) =>
str switch
{
null => "None",
"" => "None",
" " => "None",
_ => str[0].ToString().ToUpper() + str.Substring(1).ToLower()
};
return this;
}
public bool IsForced()
{
#if CLIENT
return this.Client.Equals("Forced");
#elif SERVER
return this.Server.Equals("Forced");
#endif
}
public bool IsStandard()
{
#if CLIENT
return this.Client.Equals("Standard");
#elif SERVER
return this.Server.Equals("Standard");
#endif
}
public bool IsForcedOrStandard() => this.IsForced() || this.IsStandard();
}

View File

@@ -1,5 +1,6 @@
#nullable enable
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
@@ -9,26 +10,43 @@ namespace Barotrauma
{
public static class ReflectionUtils
{
private static readonly Dictionary<Assembly, ImmutableArray<Type>> cachedNonAbstractTypes
= new Dictionary<Assembly, ImmutableArray<Type>>();
private static readonly Dictionary<Assembly, Dictionary<Type, ImmutableArray<Type>>> cachedDerivedNonAbstract
= new Dictionary<Assembly, Dictionary<Type, ImmutableArray<Type>>>();
private static readonly ConcurrentDictionary<Assembly, ImmutableArray<Type>> CachedNonAbstractTypes
= new ConcurrentDictionary<Assembly, ImmutableArray<Type>>();
private static readonly ConcurrentDictionary<string, ImmutableArray<Type>> TypeSearchCache = new();
public static IEnumerable<Type> GetDerivedNonAbstract<T>()
{
Type t = typeof(T);
Assembly assembly = typeof(T).Assembly;
lock (cachedNonAbstractTypes)
string typeName = t.FullName ?? t.Name;
// search quick lookup cache
if (TypeSearchCache.TryGetValue(typeName, out var value))
{
if (!cachedNonAbstractTypes.ContainsKey(assembly))
{
AddNonAbstractAssemblyTypes(assembly);
}
return value;
}
// doesn't exist so let's add it.
Assembly assembly = typeof(T).Assembly;
if (!CachedNonAbstractTypes.ContainsKey(assembly))
{
AddNonAbstractAssemblyTypes(assembly);
}
// build cache from registered assemblies' types.
var list = CachedNonAbstractTypes.Values
.SelectMany(arr => arr.Where(type => type.IsSubclassOf(t)))
.ToImmutableArray();
#warning TODO: Add safety checks in case an assembly is unloaded without being removed from the cache.
return cachedNonAbstractTypes.Values.SelectMany(s => s.Where(t => t.IsSubclassOf(typeof(T))));
if (list.Length == 0)
{
return ImmutableArray<Type>.Empty; // No types, don't add to cache
}
if (!TypeSearchCache.TryAdd(typeName, list))
{
DebugConsole.LogError($"ReflectionUtils::AddNonAbstractAssemblyTypes() | Error while adding to quick lookup cache.");
}
return list;
}
/// <summary>
@@ -38,7 +56,7 @@ namespace Barotrauma
/// <param name="overwrite">Whether or not to overwrite an entry if the assembly already exists within it.</param>
public static void AddNonAbstractAssemblyTypes(Assembly assembly, bool overwrite = false)
{
if (cachedNonAbstractTypes.ContainsKey(assembly))
if (CachedNonAbstractTypes.ContainsKey(assembly))
{
if (!overwrite)
{
@@ -46,15 +64,20 @@ namespace Barotrauma
$"ReflectionUtils::AddNonAbstractAssemblyTypes() | The assembly [{assembly.GetName()}] already exists in the cache.");
return;
}
cachedNonAbstractTypes.Remove(assembly);
CachedNonAbstractTypes.Remove(assembly, out _);
}
try
{
if (!cachedNonAbstractTypes.TryAdd(assembly, assembly.GetTypes().Where(t => !t.IsAbstract).ToImmutableArray()))
if (!CachedNonAbstractTypes.TryAdd(assembly, assembly.GetSafeTypes().Where(t => !t.IsAbstract).ToImmutableArray()))
{
DebugConsole.LogError($"ReflectionUtils::AddNonAbstractAssemblyTypes() | Unable to add types from Assembly to cache.");
}
else
{
TypeSearchCache.Clear(); // Needs to be rebuilt to include potential new types
}
}
catch (ReflectionTypeLoadException e)
{
@@ -66,8 +89,22 @@ namespace Barotrauma
/// Removes an assembly from the cache for Barotrauma's Type lookup.
/// </summary>
/// <param name="assembly">Assembly to remove.</param>
public static void RemoveAssemblyFromCache(Assembly assembly) => cachedNonAbstractTypes.Remove(assembly);
public static void RemoveAssemblyFromCache(Assembly assembly)
{
CachedNonAbstractTypes.Remove(assembly, out _);
TypeSearchCache.Clear();
}
/// <summary>
/// Clears all cached assembly data and rebuilds types list only to include base Barotrauma types.
/// </summary>
internal static void ResetCache()
{
CachedNonAbstractTypes.Clear();
CachedNonAbstractTypes.TryAdd(typeof(ReflectionUtils).Assembly, typeof(ReflectionUtils).Assembly.GetSafeTypes().ToImmutableArray());
TypeSearchCache.Clear();
}
public static Option<TBase> ParseDerived<TBase, TInput>(TInput input) where TInput : notnull where TBase : notnull
{
static Option<TBase> none() => Option<TBase>.None();