diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ILuaScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ILuaScriptLoader.cs index f86dcd965..5dde8dfbf 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ILuaScriptLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ILuaScriptLoader.cs @@ -6,7 +6,7 @@ using MoonSharp.Interpreter.Loaders; namespace Barotrauma.LuaCs.Services.Safe; -public interface ILuaScriptLoader : IService, IScriptLoader +public interface ILuaScriptLoader : IService, IScriptLoader, ISafeStorageValidation { void ClearCaches(); Task)>>> CacheResourcesAsync(ImmutableArray resourceInfos); diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ISafeStorageService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ISafeStorageService.cs index 93247a023..773bbbc28 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ISafeStorageService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/ISafeStorageService.cs @@ -2,7 +2,9 @@ namespace Barotrauma.LuaCs.Services.Safe; -public interface ISafeStorageService : IStorageService +public interface ISafeStorageService : IStorageService, ISafeStorageValidation { } + +public interface ISafeStorageValidation { /// /// Checks the given file path to see if it can be read. This includes any permissions, whitelists and OS checks. @@ -16,26 +18,33 @@ public interface ISafeStorageService : IStorageService /// /// Adds the given path to the specified whitelists. /// - /// Either the fully-qualified or local reference path to the given file. - /// + /// The path to the file, exactly as it will be passed to the Try(Load|Save) methods in . + /// Whether to add it to the read whitelist only, or Read+Write whitelists. void AddFileToWhitelist(string path, bool readOnly = true); /// - /// Removes the given path from all whitelists (Read|Write). + /// Adds the given collection of file paths to whitelists (Read|+Write) + /// + /// The paths to the files, formatted exactly as it will be passed to the Try(Load|Save) methods in . + /// Whether to add it to the read whitelist only, or Read+Write whitelists. + void AddFilesToWhitelist(ImmutableArray paths, bool readOnly = true); + + /// + /// Removes the given path from all whitelists (Read|+Write). /// /// void RemoveFileFromAllWhitelists(string path); /// - /// Sets the whitelist filtering for read-only file permissions for the instance. + /// Sets the whitelist filtering for read-only file permissions for the instance. Overwrites previous list. /// - /// List of absolute file paths allowed. + /// List of file paths allowed, as will be passed to the Try(Load|Save) methods. FluentResults.Result SetReadOnlyWhitelist(ImmutableArray filePaths); /// - /// Sets the whitelist filtering for read & write file permissions for the instance. + /// Sets the whitelist filtering for read & write file permissions for the instance. Overwrites previous lists. /// - /// List of absolute file paths allowed. + /// List of file paths allowed, as will be passed to the Try(Load|Save) methods. FluentResults.Result SetReadWriteWhitelist(ImmutableArray filePaths); /// diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaScriptLoader.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaScriptLoader.cs index a2e23af47..c78f8c221 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaScriptLoader.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/LuaScriptLoader.cs @@ -15,99 +15,77 @@ namespace Barotrauma.LuaCs.Services.Safe { public class LuaScriptLoader : ScriptLoaderBase, ILuaScriptLoader { - public LuaScriptLoader(IStorageService storageService, Lazy loggerService, ILuaScriptServicesConfig luaScriptServicesConfig) + public LuaScriptLoader(ISafeStorageService storageService, Lazy loggerService) { this._storageService = storageService; this._loggerService = loggerService; - this._luaScriptServicesConfig = luaScriptServicesConfig; - _storageService.UseCaching = _luaScriptServicesConfig.UseCaching; - if (_luaScriptServicesConfig.SafeLuaIOEnabled) - { - //_storageService.EnableWhitelistOnly(); - } } - private readonly IStorageService _storageService; + private readonly ISafeStorageService _storageService; private readonly Lazy _loggerService; - private readonly ILuaScriptServicesConfig _luaScriptServicesConfig; public override object LoadFile(string file, Table globalContext) { - ((IService)this).CheckDisposed(); - - if (!CanReadFromPath(file)) + IService.CheckDisposed(this); + if (file.IsNullOrWhiteSpace()) { - LogErrors($"File access to \"{file}\" is not allowed."); + return null; + } + + var res = _storageService.TryLoadText(file); + + if (res.IsFailed || res is not { Value: { } script}) + { + UnsafeLogErrors($"Failed to load file '{file}'.", res.ToResult()); return null; } - if (_storageService.TryLoadText(file) is not { IsSuccess: true, Value: not null } script) + if (script.IsNullOrWhiteSpace()) { - LogErrors($"Failed to load file \"{file}\"."); + UnsafeLogErrors($"The file '{file}' is empty. ", res.ToResult()); return null; } - if (script.Value.IsNullOrWhiteSpace()) - { - LogErrors($"The file \"{file}\" was empty."); - return null; - } - - return script.Value; + return script; } public void ClearCaches() { - ((IService)this).CheckDisposed(); + IService.CheckDisposed(this); _storageService?.PurgeCache(); } public async Task)>>> CacheResourcesAsync(ImmutableArray resourceInfos) { - // TODO: Needs an async lock? IService.CheckDisposed(this); if (!_storageService.UseCaching) + { return FluentResults.Result.Fail($"Caching is not enabled."); + } + return await this._storageService.LoadPackageTextFilesAsync([..resourceInfos.SelectMany(ri => ri.FilePaths)]); } public override bool ScriptFileExists(string file) { - ((IService)this).CheckDisposed(); - - if (!CanReadFromPath(file)) - { - LogErrors($"File access to \"{file}\" is not allowed."); - return false; - } - + IService.CheckDisposed(this); var result = _storageService.FileExists(file); - if (result is { IsFailed: true }) { - LogErrors($"Unable to find and load file \"{file}\"."); + UnsafeLogErrors($"Unable to find and load file \"{file}\".", result.ToResult()); return false; } - return result.IsSuccess; + return true; } - private bool CanReadFromPath(string file) - { - throw new NotImplementedException(); - } - - private bool CanWriteToPath(string file) - { - throw new NotImplementedException(); - } - - private void LogErrors(string message, FluentResults.Result result = null) + private void UnsafeLogErrors(string message, FluentResults.Result result = null) { _loggerService.Value.LogError($"{nameof(LuaScriptLoader)}: {message}"); - - if (result is null || result.Errors.Count <= 0) + if (result is null || result.Errors.Count <= 0) + { return; + } foreach (var error in result.Errors) { @@ -117,14 +95,58 @@ namespace Barotrauma.LuaCs.Services.Safe public void Dispose() { - if (IsDisposed) + if (!ModUtils.Threading.CheckIfClearAndSetBool(ref _isDisposed)) + { return; - IsDisposed = true; + } _storageService?.Dispose(); _loggerService?.Value.Dispose(); } - public bool IsDisposed { get; private set; } + private int _isDisposed = 0; + public bool IsDisposed => ModUtils.Threading.GetBool(ref _isDisposed); + + public bool IsFileAccessible(string path, bool readOnly, bool checkWhitelistOnly = true) + { + IService.CheckDisposed(this); + return _storageService.IsFileAccessible(path, readOnly, checkWhitelistOnly); + } + + public void AddFileToWhitelist(string path, bool readOnly = true) + { + IService.CheckDisposed(this); + _storageService.AddFileToWhitelist(path, readOnly); + } + + public void AddFilesToWhitelist(ImmutableArray paths, bool readOnly = true) + { + IService.CheckDisposed(this); + _storageService.AddFilesToWhitelist(paths, readOnly); + } + + public void RemoveFileFromAllWhitelists(string path) + { + IService.CheckDisposed(this); + _storageService.RemoveFileFromAllWhitelists(path); + } + + public FluentResults.Result SetReadOnlyWhitelist(ImmutableArray filePaths) + { + IService.CheckDisposed(this); + return _storageService.SetReadOnlyWhitelist(filePaths); + } + + public FluentResults.Result SetReadWriteWhitelist(ImmutableArray filePaths) + { + IService.CheckDisposed(this); + return _storageService.SetReadWriteWhitelist(filePaths); + } + + public void ClearAllWhitelists() + { + IService.CheckDisposed(this); + _storageService.ClearAllWhitelists(); + } } } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs index 77e9a2aee..d6063366f 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs @@ -11,33 +11,47 @@ using Barotrauma.LuaCs.Data; using FarseerPhysics.Common; using FluentResults; using FluentResults.LuaCs; +using Microsoft.Toolkit.Diagnostics; using Path = System.IO.Path; namespace Barotrauma.LuaCs.Services.Safe; public class SafeStorageService : StorageService, ISafeStorageService { - private ConcurrentDictionary _fileListRead = new (), _fileListWrite = new(); + private ConcurrentDictionary + _fileListRead = new (), + _fileListWrite = new(); + private readonly AsyncReaderWriterLock _higherOperationsLock = new(); public SafeStorageService(IStorageServiceConfig configData) : base(configData) { + IsReadOperationAllowedEval = async Task (fp) => IsFileAccessible(fp, true, true); + IsWriteOperationAllowedEval = async Task (fp) => IsFileAccessible(fp, false, true); } private string GetFullPath(string path) => System.IO.Path.GetFullPath(path).CleanUpPathCrossPlatform(); public bool IsFileAccessible(string path, bool readOnly, bool checkWhitelistOnly = true) { - ((IService)this).CheckDisposed(); + Guard.IsNotNullOrWhiteSpace(path, nameof(path)); + using var lck = _higherOperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); try { path = GetFullPath(path); if (!_fileListRead.ContainsKey(path)) + { return false; + } if (!readOnly && !_fileListWrite.ContainsKey(path)) + { return false; + } if (checkWhitelistOnly) + { return true; + } using var fs = System.IO.File.Open( path, FileMode.Open, readOnly ? FileAccess.Read : FileAccess.ReadWrite, FileShare.ReadWrite); return readOnly ? fs.CanRead : fs.CanWrite; @@ -50,13 +64,18 @@ public class SafeStorageService : StorageService, ISafeStorageService public void AddFileToWhitelist(string path, bool readOnly = true) { - ((IService)this).CheckDisposed(); + Guard.IsNotNullOrWhiteSpace(path, nameof(path)); + using var lck = _higherOperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); + try { path = GetFullPath(path); _fileListRead.AddOrUpdate(path, s => 0, (s, b) => 0); if (!readOnly) + { _fileListWrite.AddOrUpdate(path, s => 0, (s, b) => 0); + } } catch { @@ -64,9 +83,23 @@ public class SafeStorageService : StorageService, ISafeStorageService } } + public void AddFilesToWhitelist(ImmutableArray paths, bool readOnly = true) + { + if (paths.IsDefaultOrEmpty) + ThrowHelper.ThrowArgumentNullException(nameof(paths)); + foreach (var path in paths) + { + AddFileToWhitelist(path, readOnly); + } + } + + public void RemoveFileFromAllWhitelists(string path) { - ((IService)this).CheckDisposed(); + Guard.IsNotNullOrWhiteSpace(path, nameof(path)); + using var lck = _higherOperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); + try { path = GetFullPath(path); @@ -81,13 +114,18 @@ public class SafeStorageService : StorageService, ISafeStorageService public FluentResults.Result SetReadOnlyWhitelist(ImmutableArray filePaths) { - ((IService)this).CheckDisposed(); + using var lck = _higherOperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); if (filePaths.IsDefaultOrEmpty) + { return FluentResults.Result.Fail($"{nameof(SetReadOnlyWhitelist)}: FilePaths cannot be empty."); + } + _fileListRead.Clear(); var res = new FluentResults.Result(); foreach (var path in filePaths) { + Guard.IsNotNullOrWhiteSpace(path, nameof(path)); try { var p = Path.GetFullPath(path.CleanUpPathCrossPlatform()); @@ -121,14 +159,19 @@ public class SafeStorageService : StorageService, ISafeStorageService public FluentResults.Result SetReadWriteWhitelist(ImmutableArray filePaths) { - ((IService)this).CheckDisposed(); if (filePaths.IsDefaultOrEmpty) + { return FluentResults.Result.Fail($"{nameof(SetReadOnlyWhitelist)}: FilePaths cannot be empty."); + } + using var lck = _higherOperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); + _fileListRead.Clear(); _fileListWrite.Clear(); var res = new FluentResults.Result(); foreach (var path in filePaths) { + Guard.IsNotNullOrWhiteSpace(path, nameof(path)); try { var p = Path.GetFullPath(path.CleanUpPathCrossPlatform()); @@ -167,115 +210,9 @@ public class SafeStorageService : StorageService, ISafeStorageService public void ClearAllWhitelists() { - ((IService)this).CheckDisposed(); + using var lck = _higherOperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); + IService.CheckDisposed(this); _fileListRead.Clear(); _fileListWrite.Clear(); } - - #region Base_Overrides - - private bool ReadCheck(string path) - { - return IsFileAccessible(path, true, true); - } - - private bool WriteCheck(string path) - { - return IsFileAccessible(path, false, true); - } - - public override Result FileExists(string filePath) - { - if (!ReadCheck(filePath)) - return FluentResults.Result.Fail("Cannot access file."); - return base.FileExists(filePath); - } - - public override Result TryLoadBinary(string filePath) - { - if (!ReadCheck(filePath)) - return FluentResults.Result.Fail("Cannot access file."); - return base.TryLoadBinary(filePath); - } - - public override async Task> TryLoadBinaryAsync(string filePath) - { - if (!ReadCheck(filePath)) - return FluentResults.Result.Fail("Cannot access file."); - return await base.TryLoadBinaryAsync(filePath); - } - - public override Result TryLoadText(string filePath, Encoding encoding = null) - { - if (!ReadCheck(filePath)) - return FluentResults.Result.Fail("Cannot access file."); - return base.TryLoadText(filePath, encoding); - } - - public override async Task> TryLoadTextAsync(string filePath, Encoding encoding = null) - { - if (!ReadCheck(filePath)) - return FluentResults.Result.Fail("Cannot access file."); - return await base.TryLoadTextAsync(filePath, encoding); - } - - public override Result TryLoadXml(string filePath, Encoding encoding = null) - { - if (!ReadCheck(filePath)) - return FluentResults.Result.Fail("Cannot access file."); - return base.TryLoadXml(filePath, encoding); - } - - public override async Task> TryLoadXmlAsync(string filePath, Encoding encoding = null) - { - if (!ReadCheck(filePath)) - return FluentResults.Result.Fail("Cannot access file."); - return await base.TryLoadXmlAsync(filePath, encoding); - } - - public override FluentResults.Result TrySaveBinary(string filePath, in byte[] bytes) - { - if (!WriteCheck(filePath)) - return FluentResults.Result.Fail("Cannot write to file."); - return base.TrySaveBinary(filePath, in bytes); - } - - public override async Task TrySaveBinaryAsync(string filePath, byte[] bytes) - { - if (!WriteCheck(filePath)) - return FluentResults.Result.Fail("Cannot write to file."); - return await base.TrySaveBinaryAsync(filePath, bytes); - } - - public override FluentResults.Result TrySaveText(string filePath, in string text, Encoding encoding = null) - { - if (!WriteCheck(filePath)) - return FluentResults.Result.Fail("Cannot write to file."); - return base.TrySaveText(filePath, in text, encoding); - } - - public override async Task TrySaveTextAsync(string filePath, string text, Encoding encoding = null) - { - if (!WriteCheck(filePath)) - return FluentResults.Result.Fail("Cannot write to file."); - return await base.TrySaveTextAsync(filePath, text, encoding); - } - - public override FluentResults.Result TrySaveXml(string filePath, in XDocument document, Encoding encoding = null) - { - if (!WriteCheck(filePath)) - return FluentResults.Result.Fail("Cannot write to file."); - return base.TrySaveXml(filePath, in document, encoding); - } - - public override async Task TrySaveXmlAsync(string filePath, XDocument document, Encoding encoding = null) - { - if (!WriteCheck(filePath)) - return FluentResults.Result.Fail("Cannot write to file."); - return await base.TrySaveXmlAsync(filePath, document, encoding); - } - - #endregion - - } diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/StorageService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/StorageService.cs index 724b2c4dc..220a6570f 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/StorageService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/StorageService.cs @@ -21,19 +21,54 @@ public class StorageService : IStorageService public StorageService(IStorageServiceConfig configData) { ConfigData = configData; + IsReadOperationAllowedEval = async Task (str) => true; + IsWriteOperationAllowedEval = async Task (str) => true; + } + + public StorageService(IStorageServiceConfig configData, + Func> isReadOperationAllowedEval, + Func> isWriteOperationAllowedEval) + { + Guard.IsNotNull(isReadOperationAllowedEval, nameof(isReadOperationAllowedEval)); + Guard.IsNotNull(isWriteOperationAllowedEval, nameof(isWriteOperationAllowedEval)); + ConfigData = configData; + IsReadOperationAllowedEval = isReadOperationAllowedEval; + IsWriteOperationAllowedEval = isWriteOperationAllowedEval; } private readonly ConcurrentDictionary> _fsCache = new(); protected readonly IStorageServiceConfig ConfigData; protected readonly AsyncReaderWriterLock OperationsLock = new(); + + private Func> _isReadOperationAllowedEval; + protected Func> IsReadOperationAllowedEval + { + get => _isReadOperationAllowedEval; + set + { + if (value is not null) + _isReadOperationAllowedEval = value; + } + } + + private Func> _isWriteOperationAllowedEval; + protected Func> IsWriteOperationAllowedEval + { + get => _isWriteOperationAllowedEval; + set + { + if (value is not null) + _isWriteOperationAllowedEval = value; + } + } public bool IsDisposed => ModUtils.Threading.GetBool(ref _isDisposed); private int _isDisposed = 0; - public void Dispose() + public virtual void Dispose() { + using var lck = OperationsLock.AcquireWriterLock().ConfigureAwait(false).GetAwaiter().GetResult(); if (!ModUtils.Threading.CheckIfClearAndSetBool(ref _isDisposed)) return; - using var lck = OperationsLock.AcquireWriterLock().ConfigureAwait(false).GetAwaiter().GetResult(); _fsCache.Clear(); } @@ -254,7 +289,10 @@ public class StorageService : IStorageService using var lck = await OperationsLock.AcquireReaderLock(); IService.CheckDisposed(this); if (!filePath.FullPath.StartsWith(ConfigData.WorkshopModsDirectory) && !filePath.FullPath.StartsWith(ConfigData.LocalModsDirectory)) - ThrowHelper.ThrowUnauthorizedAccessException($"{nameof(LoadPackageData)}: The filepath of `{filePath.FullPath}' is not in a package directory!"); + { + ThrowHelper.ThrowUnauthorizedAccessException( + $"{nameof(LoadPackageData)}: The filepath of `{filePath.FullPath}' is not in a package directory!"); + } return await dataLoader(filePath.FullPath); } @@ -269,7 +307,9 @@ public class StorageService : IStorageService ImmutableArray filePaths, Func>> dataLoader) { if (filePaths.IsDefaultOrEmpty) + { ThrowHelper.ThrowArgumentNullException($"{nameof(LoadPackageData)}: File paths is empty!"); + } using var lck = await OperationsLock.AcquireReaderLock(); var builder = ImmutableArray.CreateBuilder<(ContentPath, Result)>(); foreach (var path in filePaths) @@ -299,6 +339,7 @@ public class StorageService : IStorageService public virtual FluentResults.Result TryLoadXml(string filePath, Encoding encoding) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); using var lck = OperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); IService.CheckDisposed(this); @@ -316,9 +357,15 @@ public class StorageService : IStorageService private FluentResults.Result TryLoadText(string filePath) => TryLoadText(filePath, null); public virtual FluentResults.Result TryLoadText(string filePath, Encoding encoding) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); using var lck = OperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); IService.CheckDisposed(this); + if (IsReadOperationAllowedEval?.Invoke(filePath).ConfigureAwait(false).GetAwaiter().GetResult() is not true) + { + return FluentResults.Result.Fail($"{nameof(TryLoadText)}: File '{filePath}' is not allowed."); + } + if (UseCaching && _fsCache.TryGetValue(filePath, out var result) && result.TryPickT1(out var cachedVal, out _)) { @@ -338,9 +385,15 @@ public class StorageService : IStorageService public virtual FluentResults.Result TryLoadBinary(string filePath) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); using var lck = OperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); IService.CheckDisposed(this); + if (IsReadOperationAllowedEval?.Invoke(filePath).ConfigureAwait(false).GetAwaiter().GetResult() is not true) + { + return FluentResults.Result.Fail($"{nameof(TryLoadBinary)}: File '{filePath}' is not allowed."); + } + if (UseCaching && _fsCache.TryGetValue(filePath, out var result) && result.TryPickT0(out var cachedVal, out _)) { @@ -353,7 +406,9 @@ public class StorageService : IStorageService fp = System.IO.Path.IsPathRooted(fp) ? fp : System.IO.Path.GetFullPath(fp); var fileData = System.IO.File.ReadAllBytes(fp); if (UseCaching) + { _fsCache[filePath] = fileData; + } return new FluentResults.Result().WithSuccess($"Loaded file successfully").WithValue(fileData); }); } @@ -365,6 +420,11 @@ public class StorageService : IStorageService using var lck = OperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); IService.CheckDisposed(this); + if (IsWriteOperationAllowedEval?.Invoke(filePath).ConfigureAwait(false).GetAwaiter().GetResult() is not true) + { + return FluentResults.Result.Fail($"{nameof(TrySaveText)}: File '{filePath}' is not allowed."); + } + string t = text; //copy return IOExceptionsOperationRunner(nameof(TrySaveText), filePath, () => { @@ -380,11 +440,16 @@ public class StorageService : IStorageService public virtual FluentResults.Result TrySaveBinary(string filePath, in byte[] bytes) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); Guard.IsNotNull(bytes, nameof(bytes)); Guard.HasSizeGreaterThanOrEqualTo(bytes, 1, nameof(bytes)); using var lck = OperationsLock.AcquireReaderLock().ConfigureAwait(false).GetAwaiter().GetResult(); IService.CheckDisposed(this); + if (IsWriteOperationAllowedEval?.Invoke(filePath).ConfigureAwait(false).GetAwaiter().GetResult() is not true) + { + return FluentResults.Result.Fail($"{nameof(TrySaveBinary)}: File '{filePath}' is not allowed."); + } byte[] b = new byte[bytes.Length]; System.Buffer.BlockCopy(bytes, 0, b, 0, bytes.Length); @@ -401,7 +466,14 @@ public class StorageService : IStorageService public virtual FluentResults.Result FileExists(string filePath) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); IService.CheckDisposed(this); + // lock not needed + if (IsReadOperationAllowedEval?.Invoke(filePath).ConfigureAwait(false).GetAwaiter().GetResult() is not true) + { + return FluentResults.Result.Fail($"{nameof(FileExists)}: File '{filePath}' is not allowed."); + } + return IOExceptionsOperationRunner(nameof(FileExists), filePath, () => { var fp = filePath.CleanUpPath(); @@ -412,7 +484,14 @@ public class StorageService : IStorageService public virtual FluentResults.Result DirectoryExists(string directoryPath) { + Guard.IsNotNullOrWhiteSpace(directoryPath, nameof(directoryPath)); IService.CheckDisposed(this); + // lock not needed + if (IsReadOperationAllowedEval?.Invoke(directoryPath).ConfigureAwait(false).GetAwaiter().GetResult() is not true) + { + return FluentResults.Result.Fail($"{nameof(DirectoryExists)}: File '{directoryPath}' is not allowed."); + } + try { var di = new DirectoryInfo(directoryPath); @@ -426,10 +505,20 @@ public class StorageService : IStorageService public virtual async Task> TryLoadXmlAsync(string filePath, Encoding encoding = null) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); + using var lck = await OperationsLock.AcquireReaderLock(); IService.CheckDisposed(this); + if (await IsReadOperationAllowedEval.Invoke(filePath) is not true) + { + return FluentResults.Result.Fail($"{nameof(TryLoadXmlAsync)}: File '{filePath}' is not allowed."); + } + if (UseCaching && _fsCache.TryGetValue(filePath, out var cachedVal) && cachedVal.TryPickT2(out var cachedDoc, out _)) + { return FluentResults.Result.Ok(cachedDoc); + } + try { await using var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read); @@ -446,10 +535,19 @@ public class StorageService : IStorageService public virtual async Task> TryLoadTextAsync(string filePath, Encoding encoding = null) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); + using var lck = await OperationsLock.AcquireReaderLock(); IService.CheckDisposed(this); + if (await IsReadOperationAllowedEval.Invoke(filePath) is not true) + { + return FluentResults.Result.Fail($"{nameof(TryLoadTextAsync)}: File '{filePath}' is not allowed."); + } + if (UseCaching && _fsCache.TryGetValue(filePath, out var cachedVal) && cachedVal.TryPickT1(out var cachedTxt, out _)) + { return FluentResults.Result.Ok(cachedTxt); + } return await IOExceptionsOperationRunnerAsync(nameof(TryLoadTextAsync), filePath, async () => { @@ -464,7 +562,14 @@ public class StorageService : IStorageService public virtual async Task> TryLoadBinaryAsync(string filePath) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); + using var lck = await OperationsLock.AcquireReaderLock(); IService.CheckDisposed(this); + if (await IsReadOperationAllowedEval.Invoke(filePath) is not true) + { + return FluentResults.Result.Fail($"{nameof(TryLoadBinaryAsync)}: File '{filePath}' is not allowed."); + } + if (UseCaching && _fsCache.TryGetValue(filePath, out var cachedVal) && cachedVal.TryPickT0(out var cachedBin, out _)) { @@ -479,13 +584,18 @@ public class StorageService : IStorageService }); } + // method group overload public virtual async Task TrySaveXmlAsync(string filePath, XDocument document, Encoding encoding = null) => await TrySaveTextAsync(filePath, document.ToString(), encoding); public virtual async Task TrySaveTextAsync(string filePath, string text, Encoding encoding = null) { Guard.IsNotNullOrWhiteSpace(text, nameof(text)); using var lck = await OperationsLock.AcquireReaderLock(); IService.CheckDisposed(this); - + if (await IsWriteOperationAllowedEval.Invoke(filePath) is not true) + { + return FluentResults.Result.Fail($"{nameof(TrySaveTextAsync)}: File '{filePath}' is not allowed."); + } + string t = text.ToString(); //copy return await IOExceptionsOperationRunnerAsync(nameof(TrySaveText), filePath, async () => { @@ -500,11 +610,15 @@ public class StorageService : IStorageService public virtual async Task TrySaveBinaryAsync(string filePath, byte[] bytes) { + Guard.IsNotNullOrWhiteSpace(filePath, nameof(filePath)); Guard.IsNotNull(bytes, nameof(bytes)); Guard.HasSizeGreaterThanOrEqualTo(bytes, 1, nameof(bytes)); using var lck = await OperationsLock.AcquireReaderLock(); IService.CheckDisposed(this); - + if (await IsWriteOperationAllowedEval.Invoke(filePath) is not true) + { + return FluentResults.Result.Fail($"{nameof(TrySaveBinaryAsync)}: File '{filePath}' is not allowed."); + } byte[] b = new byte[bytes.Length]; System.Buffer.BlockCopy(bytes, 0, b, 0, bytes.Length);