diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/IEvents.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/IEvents.cs index d3c6ac779..4f98edf13 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/IEvents.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/IEvents.cs @@ -234,4 +234,9 @@ public interface IEventAssemblyContextUnloading : IEvent loaderService); } +public interface IEventAssemblyUnloading : IEvent +{ + void OnAssemblyUnloading(Assembly assembly); +} + #endregion diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PackageManagementService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PackageManagementService.cs index 09cdb0a32..608256034 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PackageManagementService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PackageManagementService.cs @@ -246,6 +246,8 @@ public sealed class PackageManagementService : IPackageManagementService if (!plugins.IsDefaultOrEmpty) { result.WithReasons(_pluginManagementService.LoadAssemblyResources(plugins).Reasons); + result.WithReasons(_pluginManagementService.ActivatePluginInstances( + plugins.Select(p => p.OwnerPackage).ToImmutableArray(), false).Reasons); } } @@ -271,10 +273,11 @@ public sealed class PackageManagementService : IPackageManagementService .Where(r => r.SupportedTargets.HasFlag(ModUtils.Environment.CurrentTarget)) .Where(r => !r.Optional || ( (r.RequiredPackages.IsDefaultOrEmpty || enabledPackagesIdents.Intersect(r.RequiredPackages).Any()) - && (r.IncompatiblePackages.IsDefaultOrEmpty || enabledPackagesIdents.Intersect(r.IncompatiblePackages).None())) - ).OrderBy(r => loadingOrder.IndexOf(r.OwnerPackage)) - .ThenBy(r => r.LoadPriority) - .ToImmutableArray(); + && (r.IncompatiblePackages.IsDefaultOrEmpty || enabledPackagesIdents.Intersect(r.IncompatiblePackages).None()))) + .OrderBy(r => r.Optional ? 1 : 0) // optional content last + .ThenBy(r => loadingOrder.IndexOf(r.OwnerPackage)) + .ThenBy(r => r.LoadPriority) + .ToImmutableArray(); } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs index 8293805c6..393d15705 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/PluginManagementService.cs @@ -16,6 +16,7 @@ using Barotrauma.LuaCs.Events; using FluentResults; using FluentResults.LuaCs; using ImpromptuInterface.Build; +using LightInject; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Text; @@ -102,20 +103,49 @@ public class PluginManagementService : IAssemblyManagementService private IAssemblyLoaderService.IFactory _assemblyLoaderFactory; private IStorageService _storageService; private ILoggerService _logger; + private IEventService _eventService; + private IConfigService _configService; + private ILuaScriptManagementService _luaScriptManagementService; private readonly ConcurrentDictionary _assemblyLoaders = new(); + private readonly ConcurrentDictionary> _pluginInstances = new(); + private readonly ConditionalWeakTable _unloadingAssemblyLoaders = new(); private readonly AsyncReaderWriterLock _operationsLock = new(); + private ServiceContainer _pluginInjectorContainer; public PluginManagementService( IServicesProvider serviceProvider, IAssemblyLoaderService.IFactory assemblyLoaderFactory, IStorageService storageService, - ILoggerService logger) + ILoggerService logger, + IEventService eventService, + ILuaScriptManagementService luaScriptManagementService, + IConfigService configService) { Guard.IsNotNull(serviceProvider, nameof(serviceProvider)); _serviceProvider = serviceProvider; _assemblyLoaderFactory = assemblyLoaderFactory; _storageService = storageService; _logger = logger; + _eventService = eventService; + _luaScriptManagementService = luaScriptManagementService; + _configService = configService; + } + + private ServiceContainer CreatePluginServiceContainer() + { + var container = new ServiceContainer(new ContainerOptions() + { + EnablePropertyInjection = true, + + }); + + container.Register(fac => _logger, new PerContainerLifetime()); + container.Register(fac => _storageService, new PerContainerLifetime()); + container.Register(fac => _eventService, new PerContainerLifetime()); + container.Register(fac => _luaScriptManagementService, new PerContainerLifetime()); + container.Register(fac => _configService, new PerContainerLifetime()); + + return container; } public Result> GetImplementingTypes(bool includeInterfaces = false, bool includeAbstractTypes = false, @@ -160,12 +190,118 @@ public class PluginManagementService : IAssemblyManagementService return null; } - public ImmutableArray> ActivateTypeInstances(ImmutableArray types, bool serviceInjection = true, - bool hostInstanceReference = false) where T : IDisposable + public FluentResults.Result ActivatePluginInstances(ImmutableArray executionOrder, bool excludeAlreadyRunningPackages = true) { - throw new NotImplementedException(); + if (executionOrder.IsDefaultOrEmpty) + { + ThrowHelper.ThrowArgumentNullException($"{nameof(ActivatePluginInstances)}: The ececution list provided is empty."); + } + using var lck = _operationsLock.AcquireWriterLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); + + if (_assemblyLoaders.IsEmpty) + { + return FluentResults.Result.Ok(); + } + + var results = new FluentResults.Result(); + + var toLoad = _assemblyLoaders + .Where(al => executionOrder.Contains(al.Key)) + .Where(al => !excludeAlreadyRunningPackages || !_pluginInstances.ContainsKey(al.Key)) + .SelectMany(al => al.Value.Assemblies.Select(ass => (al.Key, ass))) + .SelectMany(kvp => kvp.ass.GetSafeTypes() + .Where(type => + type is { IsInterface: false, IsAbstract: false, IsGenericType: false } + && type.IsAssignableTo(typeof(IAssemblyPlugin))) + .Select(type => (kvp.Key, type))) + .GroupBy(kvp => kvp.Key, kvp => kvp.type) + .OrderBy(exeGrp => executionOrder.IndexOf(exeGrp.Key)) + .ToImmutableArray(); + + if (toLoad.Length == 0) + { + return FluentResults.Result.Ok(); + } + + var loadedPackagePlugins = + ImmutableArray.CreateBuilder<(ContentPackage Package, ImmutableArray Plugins)>(); + _pluginInjectorContainer ??= CreatePluginServiceContainer(); + + foreach (var packageTypes in toLoad) + { + var loadedTypes = ImmutableArray.CreateBuilder(); + foreach (var pluginType in packageTypes) + { + try + { + var plugin = (IAssemblyPlugin)Activator.CreateInstance(pluginType); + _pluginInjectorContainer.InjectProperties(plugin); + _pluginInjectorContainer.Register(pluginType, fac => plugin); + loadedTypes.Add(plugin); + } + catch (Exception e) + { + results.WithError(new ExceptionalError(e)); + continue; + } + } + loadedPackagePlugins.Add((packageTypes.Key, loadedTypes.ToImmutable())); + } + + var packPluginGroups = loadedPackagePlugins.ToImmutable(); + foreach (var packagePluginGrp in packPluginGroups) + { + if (_pluginInstances.TryGetValue(packagePluginGrp.Package, out var plugins)) + { + _pluginInstances[packagePluginGrp.Package] = plugins.Concat(packagePluginGrp.Plugins).ToImmutableArray(); + continue; + } + + _pluginInstances[packagePluginGrp.Package] = packagePluginGrp.Plugins; + } + + var pluginsToInit = packPluginGroups.SelectMany(ppg => ppg.Plugins).ToImmutableArray(); + + foreach (var plugin in pluginsToInit) + { + results.WithReasons(PluginInitRunner(plugin, p => p.PreInitPatching()).Reasons); + } + + _eventService.PublishEvent(sub => sub.PreInitPatching()); + + foreach (var plugin in pluginsToInit) + { + results.WithReasons(PluginInitRunner(plugin, p => p.Initialize()).Reasons); + } + + _eventService.PublishEvent(sub => sub.Initialize()); + + foreach (var plugin in pluginsToInit) + { + results.WithReasons(PluginInitRunner(plugin, p => p.OnLoadCompleted()).Reasons); + } + + _eventService.PublishEvent(sub => sub.OnLoadCompleted()); + + return results; + + // helper + FluentResults.Result PluginInitRunner(IAssemblyPlugin plugin, Action action) + { + try + { + action(plugin); + return FluentResults.Result.Ok(); + } + catch (Exception e) + { + return FluentResults.Result.Fail(new ExceptionalError(e)); + } + } } - + + public FluentResults.Result LoadAssemblyResources(ImmutableArray resources) { if (resources.IsDefaultOrEmpty) @@ -185,11 +321,15 @@ public class PluginManagementService : IAssemblyManagementService { LoadBinaries(contentPack); LoadAndCompileScriptAssemblies(contentPack); + foreach (var ass in _assemblyLoaders[contentPack.Key].Assemblies) + { + ReflectionUtils.AddNonAbstractAssemblyTypes(ass); + } } return result; - // helper methods + // --- helper methods void LoadBinaries(IGrouping contentPackRes) { var binaries = contentPackRes.Where(cRes => !cRes.IsScript) @@ -336,6 +476,9 @@ public class PluginManagementService : IAssemblyManagementService IEnumerable GetMetadataReferences() { +#if !DEBUG + throw new NotImplementedException($"Needs to use publicized barotrauma assemblies."); +#endif return Basic.Reference.Assemblies.Net80.References.All .Union(AppDomain.CurrentDomain.GetAssemblies() .Where(ass => !ass.Location.IsNullOrWhiteSpace()) @@ -348,14 +491,44 @@ public class PluginManagementService : IAssemblyManagementService throw new NotImplementedException(); } - private Assembly OnAssemblyLoaderResolvingManaged(IAssemblyLoaderService arg1, AssemblyName arg2) + private Assembly OnAssemblyLoaderResolvingManaged(IAssemblyLoaderService requestingLoader, AssemblyName searchName) { - throw new NotImplementedException(); + using var lck = _operationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); + + foreach (var loader in _assemblyLoaders.Where(kvp => kvp.Value != requestingLoader) + .Select(kvp => kvp.Value).ToImmutableArray()) + { + if (loader.IsReferenceOnlyMode || !loader.Assemblies.Any()) + { + continue; + } + + foreach (var assembly in loader.Assemblies) + { + if (assembly.GetName().Equals(searchName)) + { + return assembly; + } + } + } + + return null; } private void OnAssemblyLoaderUnloading(IAssemblyLoaderService loader) { - throw new NotImplementedException(); + using var lck = _operationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + + if (!loader.Assemblies.Any()) + { + return; + } + + foreach (var assembly in loader.Assemblies) + { + _eventService.PublishEvent(sub => sub.OnAssemblyUnloading(assembly)); + } } public FluentResults.Result UnloadManagedAssemblies() @@ -367,13 +540,72 @@ public class PluginManagementService : IAssemblyManagementService { return FluentResults.Result.Ok(); } + + var results = new FluentResults.Result(); + + results.WithReasons(UnsafeDisposeManagedTypeInstances().Reasons); + + ReflectionUtils.ResetCache(); + bool[] targetGcGeneration = new bool[GC.MaxGeneration]; + + for (int i = 0; i < targetGcGeneration.Length; i++) + { + targetGcGeneration[i] = false; + } + foreach (var loaderService in _assemblyLoaders) { - + try + { + loaderService.Value.Dispose(); + targetGcGeneration[GC.GetGeneration(loaderService.Value)] = true; + _unloadingAssemblyLoaders.Add(loaderService.Value, loaderService.Key); + } + catch (Exception e) + { + results.WithError(new ExceptionalError(e)); + } } - throw new NotImplementedException(); + _assemblyLoaders.Clear(); + + for (int i = 0; i < targetGcGeneration.Length; i++) + { + if (!targetGcGeneration[i]) + { + GC.Collect(i, GCCollectionMode.Aggressive, true); + } + } + + return results; + } + + private FluentResults.Result UnsafeDisposeManagedTypeInstances() + { + var results = new FluentResults.Result(); + _pluginInjectorContainer = null; + if (_pluginInstances.IsEmpty) + { + return FluentResults.Result.Ok(); + } + + foreach (var instance in _pluginInstances.SelectMany(kvp => kvp.Value)) + { + try + { + instance.Dispose(); + } + catch (Exception e) + { + results.WithError(new ExceptionalError(e)); + continue; + } + } + + _pluginInstances.Clear(); + + return results; } public Result GetLoadedAssembly(OneOf assemblyName, in Guid[] excludedContexts) diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/_Interfaces/IPluginManagementService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/_Interfaces/IPluginManagementService.cs index 83cd59981..711e69580 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/_Interfaces/IPluginManagementService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/_Interfaces/IPluginManagementService.cs @@ -33,15 +33,12 @@ public interface IPluginManagementService : IReusableService Type GetType(string typeName, bool isByRefType = false, bool includeInterfaces = false, bool includeDefaultContext = true); /// - /// Creates instances of the given type and provides Property Injection and instance reference caching. Disposes of - /// all references that throw errors on + /// /// - /// List of Types - /// - /// + /// + /// /// - ImmutableArray> ActivateTypeInstances(ImmutableArray types, bool serviceInjection = true, - bool hostInstanceReference = false) where T : IDisposable; + FluentResults.Result ActivatePluginInstances(ImmutableArray executionOrder, bool excludeAlreadyRunningPackages = true); /// /// Loads the provided assembly resources in the order of their dependencies and intra-mod priority load order. diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/_Plugins/AssemblyLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/_Plugins/AssemblyLoader.cs index 7a82c8494..87f920d70 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/_Plugins/AssemblyLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/_Plugins/AssemblyLoader.cs @@ -551,6 +551,7 @@ public sealed class AssemblyLoader : AssemblyLoadContext, IAssemblyLoaderService IsDisposed = true; this.Unload(); this.DisposeInternal(); + GC.SuppressFinalize(this); } ~AssemblyLoader() @@ -582,7 +583,6 @@ public sealed class AssemblyLoader : AssemblyLoadContext, IAssemblyLoaderService base.Unloading -= OnUnload; this._dependencyResolvers.Clear(); this._loadedAssemblyData.Clear(); - GC.SuppressFinalize(this); } protected override Assembly Load(AssemblyName assemblyName)