#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Xna.Framework;

namespace Barotrauma.Networking;

/*
 * What are segment tables for?
 *
 * Segment tables help make our networking packet reading code more robust by
 * clearly stating where each part of a message begins. Previously we would've done
 * something like:
 *
 *     msg.WriteByte(SegmentType.A);
 *     ...
 *     msg.WriteByte(SegmentType.B);
 *     ...
 *     msg.WriteByte(SegmentType.EndOfMessage);
 *
 * The problem with this design is that it's hard to debug when the writing and reading
 * code do not align for whatever reason. INetSerializableStruct is an awesome way
 * of avoiding that problem, but deploying it on a broad scale means rewriting most
 * of the netcode. That isn't going to happen any time soon, so this exists as an easier
 * way of increasing robustness.
 *
 * A segment table is laid out as follows:
 *
 *     [TablePointer: Int32]
 *     [Segment: arbitrary]
 *     ...
 *     [Segment: arbitrary]
 *     [NumberOfSegments: UInt16]
 *     [(Identifier, SegmentPointer): (T, Int32)]
 *     ...
 *     [(Identifier, SegmentPointer): (T, Int32)]
 *
 * A pointer in this context is an offset relative to the BitPosition where the TablePointer is written.
 *
 * It is used as follows:
 *
 *     using (var segmentTable = SegmentTableWriter<T>.StartWriting(outMsg))
 *     {
 *         segmentTable.StartNewSegment(T.A);
 *         ... write segment to outMsg ...
 *         segmentTable.StartNewSegment(T.B);
 *         ... write segment to outMsg ...
 *     }
 *     peer.SendMessage(outMsg);
 *
 *     ...
 *
 *     SegmentTableReader<T>.Read(inc,
 *         segmentDataReader: (segment, inc) =>
 *         {
 *             switch (segment)
 *             {
 *                 ... read segments ...
 *             }
 *             return SegmentTableReader<T>.BreakSegmentReading.No;
 *         });
 *
 * The advantages of this approach are:
 * - If a message is truncated or corrupted near the end, it becomes far more obvious because the table
 *   would not be read properly and look like garbage when printed to the console.
 * - If the reading and writing code for a segment disagree on something, issues will be isolated to that
 *   one segment.
* - The code no longer has to fiddle with padding and temporary buffers because the segment table is able
 *   to handle content that is not byte-aligned just fine.
 * - Exception handling is far easier when using a segment table: when combined with a using statement,
 *   any uncaught exception will result in the entire table being skipped, allowing the remainder of the
 *   message to still be read.
 * - It's harder to make mistakes in the implementation of segments themselves with this approach. By using
 *   the SegmentTableWriter and SegmentTableReader types, you get a type-safe way of delimiting segments
 *   and it's harder to forget to finalize a packet.
 */

/// <summary>
/// One entry of a segment table: the identifier of a segment and where it starts.
/// </summary>
/// <typeparam name="T">Value type used to identify segments.</typeparam>
/// <param name="Identifier">Identifies which segment this entry describes.</param>
/// <param name="Pointer">Start of the segment, in bits, relative to the position of the table pointer.</param>
[NetworkSerialize]
public readonly record struct Segment<T>(T Identifier, int Pointer) : INetSerializableStruct where T : struct;

/// <summary>
/// Writes a segment table into an <see cref="IWriteMessage"/>. Create it with
/// <see cref="StartWriting"/>, call <see cref="StartNewSegment"/> before writing each segment,
/// and dispose it to append the table and patch the table pointer.
/// </summary>
/// <typeparam name="T">Value type used to identify segments.</typeparam>
readonly ref struct SegmentTableWriter<T> where T : struct
{
    private readonly IWriteMessage message;
    private readonly List<Segment<T>> segments;

    /// <summary>Bit position in the message where the table pointer is stored.</summary>
    public readonly int PointerLocation;

    private SegmentTableWriter(IWriteMessage message, int pointerLocation)
    {
        this.message = message;
        this.PointerLocation = pointerLocation;
        this.segments = new List<Segment<T>>();
    }

    /// <summary>
    /// Begins a segment table at the current bit position of <paramref name="msg"/>.
    /// </summary>
    /// <param name="msg">Message the segments and the table are written to.</param>
    /// <returns>A writer that must be disposed to finalize the table.</returns>
    public static SegmentTableWriter<T> StartWriting(IWriteMessage msg)
    {
        var retVal = new SegmentTableWriter<T>(msg, msg.BitPosition);
        //reserve space for the table pointer; the real value is written in Dispose
        msg.WriteInt32(0);
        return retVal;
    }

    /// <summary>
    /// Throws if no further segment can be added, i.e. the segment count
    /// would no longer fit the UInt16 count field of the table.
    /// </summary>
    /// <exception cref="InvalidOperationException">The table already holds UInt16.MaxValue segments.</exception>
    private void ThrowOnInvalidState()
    {
        if (segments.Count >= UInt16.MaxValue)
        {
            throw new InvalidOperationException($"Too many segments in SegmentTable<{typeof(T).Name}>");
        }
    }

    /// <summary>
    /// Marks the current bit position of the message as the start of a new segment.
    /// </summary>
    /// <param name="value">Identifier of the segment that is about to be written.</param>
    /// <exception cref="InvalidOperationException">The table already holds UInt16.MaxValue segments.</exception>
    public void StartNewSegment(T value)
    {
        ThrowOnInvalidState();
        segments.Add(new Segment<T>(value, message.BitPosition - PointerLocation));
    }

    /// <summary>
    /// Finalizes the table: patches the reserved table pointer, then appends the
    /// segment count followed by every segment entry at the end of the written data.
    /// </summary>
    /// <exception cref="InvalidOperationException">The segment count does not fit in a UInt16.</exception>
    public void Dispose()
    {
        //Exactly UInt16.MaxValue segments still fit in the count field, so only a larger
        //count is an error here. StartNewSegment never lets the count grow past that,
        //so this guard only protects the cast below from silently truncating.
        if (segments.Count > UInt16.MaxValue)
        {
            throw new InvalidOperationException($"Too many segments in SegmentTable<{typeof(T).Name}>");
        }

        int tablePosition = message.BitPosition;

        //rewrite the table pointer now that we know where the table ends
        message.BitPosition = PointerLocation;
        message.WriteInt32(tablePosition - PointerLocation);

        //write the table
        message.BitPosition = tablePosition;
        message.WriteUInt16((UInt16)segments.Count);
        foreach (var segment in segments)
        {
            message.WriteNetSerializableStruct(segment);
        }
    }
}

/// <summary>
/// Reads a message that was written with a segment table. See <see cref="Read"/>.
/// </summary>
/// <typeparam name="T">Value type used to identify segments.</typeparam>
readonly ref struct SegmentTableReader<T> where T : struct
{
    /// <summary>
    /// View over the underlying message that is limited to a single segment.
    /// Positions are reported relative to the start of the segment, and every read
    /// is followed by a check that the segment's length has not been exceeded.
    /// </summary>
    private class SegmentReadMsg : IReadMessage
    {
        private readonly IReadMessage underlyingMsg;
        private readonly IReadOnlyList<Segment<T>> segments;
        private readonly int segmentIndex;
        private readonly int offset;
        private readonly int lengthBits;

        /// <param name="underlyingMsg">Message the segment lives in.</param>
        /// <param name="segments">All segments of the table.</param>
        /// <param name="segmentIndex">Index of this segment in <paramref name="segments"/>.</param>
        /// <param name="offset">Absolute bit position in the underlying message where the segment starts.</param>
        /// <param name="lengthBits">Length of the segment in bits.</param>
        /// <exception cref="Exception">The segment has a negative length or extends past the end of the message.</exception>
        public SegmentReadMsg(IReadMessage underlyingMsg, IReadOnlyList<Segment<T>> segments, int segmentIndex, int offset, int lengthBits)
        {
            this.underlyingMsg = underlyingMsg;
            this.segments = segments;
            this.segmentIndex = segmentIndex;
            this.offset = offset;
            this.lengthBits = lengthBits;

            //a negative length means the pointers in the table are out of order
            if (lengthBits < 0)
            {
                throw new Exception(
                    $"Segment table is corrupt, segment length is negative: {lengthBits}");
            }
            if (offset + lengthBits >= underlyingMsg.LengthBits)
            {
                throw new Exception(
                    $"Segment table is corrupt, segment length is invalid: {offset} + {lengthBits} >= {underlyingMsg.LengthBits}");
            }
        }

        /// <summary>Throws if the read position has moved past the end of the segment.</summary>
        private void Check()
        {
            if (BitPosition > lengthBits)
            {
                throw new Exception($"Tried to read too much data from segment.");
            }
        }

        /// <summary>
        /// Runs <see cref="Check()"/> and passes <paramref name="v"/> through. The argument is
        /// evaluated by the caller first, so the check happens after the read has taken place.
        /// </summary>
        private TRead Check<TRead>(TRead v)
        {
            Check();
            return v;
        }

        public bool ReadBoolean() => Check(underlyingMsg.ReadBoolean());

        public void ReadPadBits()
        {
            Check();
            underlyingMsg.ReadPadBits();
        }

        public byte ReadByte() => Check(underlyingMsg.ReadByte());

        public byte PeekByte() => Check(underlyingMsg.PeekByte());

        public ushort ReadUInt16() => Check(underlyingMsg.ReadUInt16());

        public short ReadInt16() => Check(underlyingMsg.ReadInt16());

        public uint ReadUInt32() => Check(underlyingMsg.ReadUInt32());

        public int ReadInt32() => Check(underlyingMsg.ReadInt32());

        public ulong ReadUInt64() => Check(underlyingMsg.ReadUInt64());

        public long ReadInt64() => Check(underlyingMsg.ReadInt64());

        public float ReadSingle() => Check(underlyingMsg.ReadSingle());

        public double ReadDouble() => Check(underlyingMsg.ReadDouble());

        public uint ReadVariableUInt32() => Check(underlyingMsg.ReadVariableUInt32());

        public string ReadString() => Check(underlyingMsg.ReadString());

        public Identifier ReadIdentifier() => Check(underlyingMsg.ReadIdentifier());

        public Color ReadColorR8G8B8() => Check(underlyingMsg.ReadColorR8G8B8());

        public Color ReadColorR8G8B8A8() => Check(underlyingMsg.ReadColorR8G8B8A8());

        public int ReadRangedInteger(int min, int max) => Check(underlyingMsg.ReadRangedInteger(min, max));

        public float ReadRangedSingle(float min, float max, int bitCount) => Check(underlyingMsg.ReadRangedSingle(min, max, bitCount));

        public byte[] ReadBytes(int numberOfBytes) => Check(underlyingMsg.ReadBytes(numberOfBytes));

        /// <summary>Bit position relative to the start of the segment.</summary>
        public int BitPosition
        {
            get => underlyingMsg.BitPosition - offset;
            set => Check(underlyingMsg.BitPosition = value + offset);
        }

        public int BytePosition => BitPosition / 8;

        public byte[] Buffer => underlyingMsg.Buffer;

        /// <summary>Length of the segment in bits; a segment cannot be resized.</summary>
        public int LengthBits
        {
            get => lengthBits;
            set => throw new InvalidOperationException($"Cannot resize {nameof(SegmentReadMsg)}");
        }

        public int LengthBytes => lengthBits / 8;

        public NetworkConnection Sender => underlyingMsg.Sender;
    }

    private readonly IReadMessage message;
    private readonly List<Segment<T>> segments;
    private readonly int exitLocation;

    /// <summary>Bit position in the message where the table pointer is stored.</summary>
    public readonly int PointerLocation;

    private SegmentTableReader(IReadMessage message, List<Segment<T>> segments, int pointerLocation, int exitLocation)
    {
        this.message = message;
        this.segments = segments;
        this.PointerLocation = pointerLocation;
        this.exitLocation = exitLocation;
    }

    /// <summary>The segment entries read from the table, in table order.</summary>
    public IReadOnlyList<Segment<T>> Segments => segments;

    /// <summary>Returned by a <see cref="SegmentDataReader"/> to stop reading further segments.</summary>
    public enum BreakSegmentReading
    {
        No,
        Yes
    }

    /// <summary>Callback that reads the contents of one segment.</summary>
    /// <param name="segmentHeader">Identifier of the segment being read.</param>
    /// <param name="incMsg">Message view restricted to the segment.</param>
    public delegate BreakSegmentReading SegmentDataReader(
        T segmentHeader,
        IReadMessage incMsg);

    /// <summary>Callback invoked when reading a segment throws.</summary>
    /// <param name="segmentWithError">The segment whose reader threw.</param>
    /// <param name="previousSegments">The segments that precede it in the table.</param>
    /// <param name="exceptionThrown">The exception that was thrown.</param>
    public delegate void ExceptionHandler(
        Segment<T> segmentWithError,
        Segment<T>[] previousSegments,
        Exception exceptionThrown);

    /// <summary>
    /// Reads a segment table starting at the current bit position of <paramref name="msg"/> and
    /// invokes <paramref name="segmentDataReader"/> once per segment, in table order, until all
    /// segments have been visited or the callback returns <see cref="BreakSegmentReading.Yes"/>.
    /// When this method exits, normally or through an exception, the message's bit position is
    /// set to just past the end of the table.
    /// </summary>
    /// <param name="msg">Message to read from.</param>
    /// <param name="segmentDataReader">Reads the contents of a single segment.</param>
    /// <param name="exceptionHandler">
    /// Receives exceptions thrown while reading a segment; reading then continues with the next
    /// segment. If null, the exception is wrapped in a new <see cref="Exception"/> and rethrown.
    /// </param>
    public static void Read(
        IReadMessage msg,
        SegmentDataReader segmentDataReader,
        ExceptionHandler? exceptionHandler = null)
    {
        int pointerLocation = msg.BitPosition;
        int tablePointer = msg.ReadInt32();
        int tableLocation = pointerLocation + tablePointer;
        int returnPosition = msg.BitPosition;

        //read the table
        var segments = new List<Segment<T>>();
        msg.BitPosition = tableLocation;
        int numSegments = msg.ReadUInt16();
        for (int i = 0; i < numSegments; i++)
        {
            segments.Add(INetSerializableStruct.Read<Segment<T>>(msg));
        }

        //store the exit location and go back to the top
        int exitLocation = msg.BitPosition;
        msg.BitPosition = returnPosition;

        //disposing the reader moves the message to exitLocation, even if an exception escapes
        using var segmentTable = new SegmentTableReader<T>(msg, segments, pointerLocation, exitLocation);
        for (int i = 0; i < segmentTable.Segments.Count; i++)
        {
            var segment = segmentTable.Segments[i];
            msg.BitPosition = segmentTable.PointerLocation + segment.Pointer;
            try
            {
                //a segment ends where the next one starts; the last one ends where the table starts
                int segmentEnd = i < segmentTable.Segments.Count - 1
                    ? segments[i + 1].Pointer
                    : tablePointer;

                var segmentMsg = new SegmentReadMsg(
                    msg, segments, i,
                    offset: segmentTable.PointerLocation + segment.Pointer,
                    lengthBits: segmentEnd - segment.Pointer);

                if (segmentDataReader(segment.Identifier, segmentMsg) is BreakSegmentReading.Yes)
                {
                    break;
                }
            }
            catch (Exception e)
            {
                var prevSegments = segments.Take(i).ToArray();
                if (exceptionHandler is not null)
                {
                    exceptionHandler(segment, prevSegments, e);
                }
                else
                {
                    throw new Exception(
                        $"Exception thrown while reading segment {segment.Identifier} at position {segment.Pointer}." +
                        (prevSegments.Any() ? $" Previous segments: {string.Join(", ", prevSegments)}." : ""),
                        e);
                }
            }
        }
    }

    /// <summary>Moves the message's bit position to just past the end of the segment table.</summary>
    public void Dispose()
    {
        message.BitPosition = exitLocation;
    }
}