diff --git a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs index 328705184..77e9a2aee 100644 --- a/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs +++ b/Barotrauma/BarotraumaShared/SharedSource/LuaCs/Services/Safe/SafeStorageService.cs @@ -3,20 +3,24 @@ using System.Collections.Concurrent; using System.Collections.Immutable; using System.IO; using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Xml.Linq; using Barotrauma.IO; using Barotrauma.LuaCs.Data; using FarseerPhysics.Common; using FluentResults; +using FluentResults.LuaCs; +using Path = System.IO.Path; namespace Barotrauma.LuaCs.Services.Safe; public class SafeStorageService : StorageService, ISafeStorageService { - private ConcurrentDictionary _fileListRead = new (), _fileListReadWrite = new(); + private ConcurrentDictionary _fileListRead = new (), _fileListWrite = new(); public SafeStorageService(IStorageServiceConfig configData) : base(configData) { - } private string GetFullPath(string path) => System.IO.Path.GetFullPath(path).CleanUpPathCrossPlatform(); @@ -28,25 +32,15 @@ public class SafeStorageService : StorageService, ISafeStorageService try { path = GetFullPath(path); - if (!readOnly && IsReadOnlyMode) + if (!_fileListRead.ContainsKey(path)) + return false; + if (!readOnly && !_fileListWrite.ContainsKey(path)) return false; - if (readOnly) - { - if (!_fileListRead.ContainsKey(path)) - return false; - } - else - { - if (!_fileListReadWrite.ContainsKey(path)) - return false; - } if (checkWhitelistOnly) return true; - using var fs = System.IO.File.Open( path, FileMode.Open, readOnly ? FileAccess.Read : FileAccess.ReadWrite, FileShare.ReadWrite); - - return true; + return readOnly ? fs.CanRead : fs.CanWrite; } catch { @@ -61,8 +55,8 @@ public class SafeStorageService : StorageService, ISafeStorageService { path = GetFullPath(path); _fileListRead.AddOrUpdate(path, s => 0, (s, b) => 0); - if (!readOnly && !IsReadOnlyMode) - _fileListRead.AddOrUpdate(path, s => 0, (s, b) => 0); + if (!readOnly) + _fileListWrite.AddOrUpdate(path, s => 0, (s, b) => 0); } catch { @@ -77,7 +71,7 @@ public class SafeStorageService : StorageService, ISafeStorageService { path = GetFullPath(path); _fileListRead.TryRemove(path, out _); - _fileListReadWrite.TryRemove(path, out _); + _fileListWrite.TryRemove(path, out _); } catch { @@ -90,34 +84,198 @@ public class SafeStorageService : StorageService, ISafeStorageService ((IService)this).CheckDisposed(); if (filePaths.IsDefaultOrEmpty) return FluentResults.Result.Fail($"{nameof(SetReadOnlyWhitelist)}: FilePaths cannot be empty."); + _fileListRead.Clear(); var res = new FluentResults.Result(); foreach (var path in filePaths) { - // TODO: Cleanup path and add it. + try + { + var p = Path.GetFullPath(path.CleanUpPathCrossPlatform()); + if (_fileListRead.ContainsKey(p)) + { + res = res.WithReason(new Success($"Path already in whitelist: {p}")); + continue; + } + + if (_fileListRead.TryAdd(p, 0)) + { + res = res.WithSuccess($"Added path successfully: {p}"); + continue; + } + + res = res.WithError(new Error($"Failed to add path to list: {p}")); + } + catch (Exception e) + { + res = res.WithError(new ExceptionalError(e) + .WithMetadata(MetadataType.ExceptionObject, this) + .WithMetadata(MetadataType.ExceptionDetails, e.Message) + .WithMetadata(MetadataType.RootObject, path) + ); + continue; + } } - throw new NotImplementedException(); + return res; } public FluentResults.Result SetReadWriteWhitelist(ImmutableArray filePaths) { ((IService)this).CheckDisposed(); - throw new System.NotImplementedException(); + if (filePaths.IsDefaultOrEmpty) + return FluentResults.Result.Fail($"{nameof(SetReadOnlyWhitelist)}: FilePaths cannot be empty."); + _fileListRead.Clear(); + _fileListWrite.Clear(); + var res = new FluentResults.Result(); + foreach (var path in filePaths) + { + try + { + var p = Path.GetFullPath(path.CleanUpPathCrossPlatform()); + TryAddToList(_fileListRead, p); + TryAddToList(_fileListWrite, p); + res = res.WithError(new Error($"Failed to add path to list: {p}")); + } + catch (Exception e) + { + res = res.WithError(new ExceptionalError(e) + .WithMetadata(MetadataType.ExceptionObject, this) + .WithMetadata(MetadataType.ExceptionDetails, e.Message) + .WithMetadata(MetadataType.RootObject, path) + ); + continue; + } + } + + void TryAddToList(ConcurrentDictionary dict, string p) + { + if (dict.ContainsKey(p)) + { + res = res.WithReason(new Success($"Path already in whitelist: {p}")); + return; + } + + if (dict.TryAdd(p, 0)) + { + res = res.WithSuccess($"Added path successfully: {p}"); + return; + } + } + + return res; } public void ClearAllWhitelists() { - throw new System.NotImplementedException(); + ((IService)this).CheckDisposed(); + _fileListRead.Clear(); + _fileListWrite.Clear(); } - private int _isReadOnlyMode = 0; - public bool IsReadOnlyMode => ModUtils.Threading.GetBool(ref _isReadOnlyMode); - - public bool EnableReadOnlyMode() + #region Base_Overrides + + private bool ReadCheck(string path) { - ModUtils.Threading.SetBool(ref _isReadOnlyMode, true); - return ModUtils.Threading.GetBool(ref _isReadOnlyMode); + return IsFileAccessible(path, true, true); } + + private bool WriteCheck(string path) + { + return IsFileAccessible(path, false, true); + } + + public override Result FileExists(string filePath) + { + if (!ReadCheck(filePath)) + return FluentResults.Result.Fail("Cannot access file."); + return base.FileExists(filePath); + } + + public override Result TryLoadBinary(string filePath) + { + if (!ReadCheck(filePath)) + return FluentResults.Result.Fail("Cannot access file."); + return base.TryLoadBinary(filePath); + } + + public override async Task> TryLoadBinaryAsync(string filePath) + { + if (!ReadCheck(filePath)) + return FluentResults.Result.Fail("Cannot access file."); + return await base.TryLoadBinaryAsync(filePath); + } + + public override Result TryLoadText(string filePath, Encoding encoding = null) + { + if (!ReadCheck(filePath)) + return FluentResults.Result.Fail("Cannot access file."); + return base.TryLoadText(filePath, encoding); + } + + public override async Task> TryLoadTextAsync(string filePath, Encoding encoding = null) + { + if (!ReadCheck(filePath)) + return FluentResults.Result.Fail("Cannot access file."); + return await base.TryLoadTextAsync(filePath, encoding); + } + + public override Result TryLoadXml(string filePath, Encoding encoding = null) + { + if (!ReadCheck(filePath)) + return FluentResults.Result.Fail("Cannot access file."); + return base.TryLoadXml(filePath, encoding); + } + + public override async Task> TryLoadXmlAsync(string filePath, Encoding encoding = null) + { + if (!ReadCheck(filePath)) + return FluentResults.Result.Fail("Cannot access file."); + return await base.TryLoadXmlAsync(filePath, encoding); + } + + public override FluentResults.Result TrySaveBinary(string filePath, in byte[] bytes) + { + if (!WriteCheck(filePath)) + return FluentResults.Result.Fail("Cannot write to file."); + return base.TrySaveBinary(filePath, in bytes); + } + + public override async Task TrySaveBinaryAsync(string filePath, byte[] bytes) + { + if (!WriteCheck(filePath)) + return FluentResults.Result.Fail("Cannot write to file."); + return await base.TrySaveBinaryAsync(filePath, bytes); + } + + public override FluentResults.Result TrySaveText(string filePath, in string text, Encoding encoding = null) + { + if (!WriteCheck(filePath)) + return FluentResults.Result.Fail("Cannot write to file."); + return base.TrySaveText(filePath, in text, encoding); + } + + public override async Task TrySaveTextAsync(string filePath, string text, Encoding encoding = null) + { + if (!WriteCheck(filePath)) + return FluentResults.Result.Fail("Cannot write to file."); + return await base.TrySaveTextAsync(filePath, text, encoding); + } + + public override FluentResults.Result TrySaveXml(string filePath, in XDocument document, Encoding encoding = null) + { + if (!WriteCheck(filePath)) + return FluentResults.Result.Fail("Cannot write to file."); + return base.TrySaveXml(filePath, in document, encoding); + } + + public override async Task TrySaveXmlAsync(string filePath, XDocument document, Encoding encoding = null) + { + if (!WriteCheck(filePath)) + return FluentResults.Result.Fail("Cannot write to file."); + return await base.TrySaveXmlAsync(filePath, document, encoding); + } + + #endregion }