File: src\SignalR\common\Shared\MemoryBufferWriter.cs
Web Access
Project: src\src\SignalR\server\StackExchangeRedis\src\Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj (Microsoft.AspNetCore.SignalR.StackExchangeRedis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#nullable enable
 
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
 
namespace Microsoft.AspNetCore.Internal;
 
/// <summary>
/// A write-only <see cref="Stream"/> and <see cref="IBufferWriter{T}"/> that accumulates bytes in a chain of
/// arrays rented from <see cref="ArrayPool{T}.Shared"/>. Completed arrays are kept in a list together with the
/// number of bytes used in each; the array currently being filled is tracked separately.
/// </summary>
internal sealed class MemoryBufferWriter : Stream, IBufferWriter<byte>
{
    // One reusable instance per thread; handed out by Get() and stored again by Return().
    [ThreadStatic]
    private static MemoryBufferWriter? _cachedInstance;

#if DEBUG
    // Debug-only guard that detects an instance being handed out twice without being returned.
    private bool _inUse;
#endif

    private readonly int _minimumSegmentSize;

    // Total number of bytes written across all segments.
    private int _bytesWritten;

    private List<CompletedBuffer>? _completedSegments;
    private byte[]? _currentSegment;

    // Write offset inside _currentSegment.
    private int _position;

    /// <summary>
    /// Creates a writer whose rented segments are at least <paramref name="minimumSegmentSize"/> bytes long.
    /// </summary>
    /// <param name="minimumSegmentSize">Smallest size requested from the pool for a new segment.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="minimumSegmentSize"/> is zero or negative.</exception>
    public MemoryBufferWriter(int minimumSegmentSize = 4096)
    {
        // A zero-sized segment can never accept a byte and a negative size is rejected by the pool on first use,
        // so fail early at construction instead of on the first write.
        if (minimumSegmentSize <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(minimumSegmentSize));
        }

        _minimumSegmentSize = minimumSegmentSize;
    }

    /// <summary>Total number of bytes written so far.</summary>
    public override long Length => _bytesWritten;
    public override bool CanRead => false;
    public override bool CanSeek => false;
    public override bool CanWrite => true;

    /// <summary>Not supported; the stream is not seekable.</summary>
    public override long Position
    {
        get => throw new NotSupportedException();
        set => throw new NotSupportedException();
    }

    /// <summary>
    /// Takes the current thread's cached writer if there is one, otherwise creates a new writer.
    /// </summary>
    /// <returns>A writer to be handed back with <see cref="Return"/>.</returns>
    public static MemoryBufferWriter Get()
    {
        var writer = _cachedInstance;
        if (writer == null)
        {
            writer = new MemoryBufferWriter();
        }
        else
        {
            // Taken off the thread static
            _cachedInstance = null;
        }
#if DEBUG
        if (writer._inUse)
        {
            throw new InvalidOperationException("The writer wasn't returned!");
        }

        writer._inUse = true;
#endif

        return writer;
    }

    /// <summary>
    /// Stores <paramref name="writer"/> as the current thread's cached instance and resets it.
    /// </summary>
    /// <param name="writer">The writer to cache.</param>
    public static void Return(MemoryBufferWriter writer)
    {
        _cachedInstance = writer;
#if DEBUG
        writer._inUse = false;
#endif
        writer.Reset();
    }

    /// <summary>
    /// Returns every rented array to the pool and clears the written-byte counters.
    /// The completed-segment list itself is kept (emptied) for reuse.
    /// </summary>
    public void Reset()
    {
        if (_completedSegments != null)
        {
            for (var i = 0; i < _completedSegments.Count; i++)
            {
                _completedSegments[i].Return();
            }

            _completedSegments.Clear();
        }

        if (_currentSegment != null)
        {
            ArrayPool<byte>.Shared.Return(_currentSegment);
            _currentSegment = null;
        }

        _bytesWritten = 0;
        _position = 0;
    }

    /// <summary>
    /// Records that <paramref name="count"/> bytes were written into the buffer last obtained from
    /// <see cref="GetMemory"/> or <see cref="GetSpan"/>.
    /// </summary>
    /// <param name="count">Number of bytes written.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="count"/> is larger than the space left in the current segment.
    /// </exception>
    public void Advance(int count)
    {
        if (count < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(count));
        }

        // With no current segment there is nothing to advance into, so only a count of 0 is acceptable.
        var remaining = (_currentSegment?.Length ?? 0) - _position;
        if (count > remaining)
        {
            throw new InvalidOperationException("Cannot advance past the end of the buffer.");
        }

        _bytesWritten += count;
        _position += count;
    }

    /// <summary>
    /// Returns the unused tail of the current segment, first adding a new segment if the tail cannot
    /// satisfy <paramref name="sizeHint"/>.
    /// </summary>
    /// <param name="sizeHint">Minimum number of bytes required; 0 means any non-empty amount.</param>
    public Memory<byte> GetMemory(int sizeHint = 0)
    {
        EnsureCapacity(sizeHint);

        return _currentSegment.AsMemory(_position, _currentSegment.Length - _position);
    }

    /// <summary>
    /// Returns the unused tail of the current segment, first adding a new segment if the tail cannot
    /// satisfy <paramref name="sizeHint"/>.
    /// </summary>
    /// <param name="sizeHint">Minimum number of bytes required; 0 means any non-empty amount.</param>
    public Span<byte> GetSpan(int sizeHint = 0)
    {
        EnsureCapacity(sizeHint);

        return _currentSegment.AsSpan(_position, _currentSegment.Length - _position);
    }

    /// <summary>
    /// Writes everything buffered so far into <paramref name="destination"/>, in write order.
    /// </summary>
    /// <param name="destination">The buffer writer receiving the bytes.</param>
    public void CopyTo(IBufferWriter<byte> destination)
    {
        if (_completedSegments != null)
        {
            // Copy completed segments
            var count = _completedSegments.Count;
            for (var i = 0; i < count; i++)
            {
                destination.Write(_completedSegments[i].Span);
            }
        }

        // When there is no current segment _position is 0, so this is an empty span.
        destination.Write(_currentSegment.AsSpan(0, _position));
    }

    /// <summary>
    /// Asynchronously writes everything buffered so far into <paramref name="destination"/>.
    /// </summary>
    /// <param name="destination">The stream receiving the bytes.</param>
    /// <param name="bufferSize">Not used by this implementation.</param>
    /// <param name="cancellationToken">Token passed to the destination's write calls.</param>
    public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
    {
        // Reset() clears the completed list but keeps the instance, so a reused writer has an empty (non-null)
        // list. Treat that the same as null, otherwise the fast path would never be taken again after reuse.
        var hasCompletedSegments = _completedSegments != null && _completedSegments.Count > 0;

        if (!hasCompletedSegments && _currentSegment is not null)
        {
            // There is only one segment so write without awaiting.
            return destination.WriteAsync(_currentSegment, 0, _position, cancellationToken);
        }

        return CopyToSlowAsync(destination, cancellationToken);
    }

    /// <summary>
    /// Makes sure the current segment exists and has room: any free byte when <paramref name="sizeHint"/> is 0,
    /// or at least <paramref name="sizeHint"/> bytes when it is positive. Otherwise a new segment is added.
    /// </summary>
    [MemberNotNull(nameof(_currentSegment))]
    private void EnsureCapacity(int sizeHint)
    {
        // Checking for null first lets the compiler prove _currentSegment is non-null on the early return,
        // so no warning suppression is required.
        if (_currentSegment is not null)
        {
            var remainingSize = _currentSegment.Length - _position;

            // If the sizeHint is 0, any capacity will do
            // Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment.
            if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint))
            {
                // We have capacity in the current segment
                return;
            }
        }

        AddSegment(sizeHint);
    }

    /// <summary>
    /// Moves the current segment (if any) to the completed list and rents a fresh one of at least
    /// max(minimum segment size, <paramref name="sizeHint"/>) bytes.
    /// </summary>
    [MemberNotNull(nameof(_currentSegment))]
    private void AddSegment(int sizeHint = 0)
    {
        if (_currentSegment != null)
        {
            // We're adding a segment to the list
            _completedSegments ??= new List<CompletedBuffer>();

            // Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when
            // GetMemory was called. In that case we'll take the current segment and call it "completed", but need to
            // ignore any empty space in it.
            _completedSegments.Add(new CompletedBuffer(_currentSegment, _position));
        }

        // Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment.
        _currentSegment = ArrayPool<byte>.Shared.Rent(Math.Max(_minimumSegmentSize, sizeHint));
        _position = 0;
    }

    /// <summary>
    /// Awaited copy used when the data spans more than one segment (or there is nothing to copy).
    /// </summary>
    private async Task CopyToSlowAsync(Stream destination, CancellationToken cancellationToken)
    {
        if (_completedSegments != null)
        {
            // Copy full segments
            var count = _completedSegments.Count;
            for (var i = 0; i < count; i++)
            {
                var segment = _completedSegments[i];
#if NETCOREAPP
                await destination.WriteAsync(segment.Buffer.AsMemory(0, segment.Length), cancellationToken).ConfigureAwait(false);
#else
                await destination.WriteAsync(segment.Buffer, 0, segment.Length, cancellationToken).ConfigureAwait(false);
#endif
            }
        }

        if (_currentSegment is not null)
        {
#if NETCOREAPP
            await destination.WriteAsync(_currentSegment.AsMemory(0, _position), cancellationToken).ConfigureAwait(false);
#else
            await destination.WriteAsync(_currentSegment, 0, _position, cancellationToken).ConfigureAwait(false);
#endif
        }
    }

    /// <summary>
    /// Copies everything buffered so far into a newly allocated array.
    /// </summary>
    /// <returns>
    /// An array of <see cref="Length"/> bytes, or <see cref="Array.Empty{T}"/> when there is no current segment.
    /// </returns>
    public byte[] ToArray()
    {
        if (_currentSegment == null)
        {
            return Array.Empty<byte>();
        }

        var result = new byte[_bytesWritten];

        // Same segment walk as the span overload; reuse it rather than duplicating the loop.
        CopyTo(result.AsSpan());

        return result;
    }

    /// <summary>
    /// Copies everything buffered so far into <paramref name="span"/>, in write order.
    /// </summary>
    /// <param name="span">Destination; expected to be at least <see cref="Length"/> bytes long.</param>
    public void CopyTo(Span<byte> span)
    {
        Debug.Assert(span.Length >= _bytesWritten);

        if (_currentSegment == null)
        {
            return;
        }

        var totalWritten = 0;

        if (_completedSegments != null)
        {
            // Copy full segments
            var count = _completedSegments.Count;
            for (var i = 0; i < count; i++)
            {
                var segment = _completedSegments[i];
                segment.Span.CopyTo(span.Slice(totalWritten));
                totalWritten += segment.Span.Length;
            }
        }

        // Copy current incomplete segment
        _currentSegment.AsSpan(0, _position).CopyTo(span.Slice(totalWritten));

        Debug.Assert(_bytesWritten == totalWritten + _position);
    }

    public override void Flush() { }
    public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
    public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
    public override void SetLength(long value) => throw new NotSupportedException();

    /// <summary>
    /// Appends a single byte, adding a new segment when the current one is missing or full.
    /// </summary>
    public override void WriteByte(byte value)
    {
        if (_currentSegment != null && (uint)_position < (uint)_currentSegment.Length)
        {
            _currentSegment[_position] = value;
        }
        else
        {
            // AddSegment resets _position to 0, so the byte goes at index 0 of the new segment.
            AddSegment();
            _currentSegment[0] = value;
        }

        _position++;
        _bytesWritten++;
    }

    /// <summary>
    /// Appends <paramref name="count"/> bytes of <paramref name="buffer"/> starting at <paramref name="offset"/>.
    /// </summary>
    public override void Write(byte[] buffer, int offset, int count)
    {
        var position = _position;
        if (_currentSegment != null && position < _currentSegment.Length - count)
        {
            // Fast path: the data fits inside the current segment with room to spare.
            Buffer.BlockCopy(buffer, offset, _currentSegment, position, count);

            _position = position + count;
            _bytesWritten += count;
        }
        else
        {
            // Otherwise go through the IBufferWriter path, which can add segments as needed.
            BuffersExtensions.Write(this, buffer.AsSpan(offset, count));
        }
    }

#if NETCOREAPP
    /// <summary>
    /// Appends <paramref name="span"/>, copying directly when it fits in the current segment.
    /// </summary>
    public override void Write(ReadOnlySpan<byte> span)
    {
        if (_currentSegment != null && span.TryCopyTo(_currentSegment.AsSpan(_position)))
        {
            _position += span.Length;
            _bytesWritten += span.Length;
        }
        else
        {
            BuffersExtensions.Write(this, span);
        }
    }
#endif

    /// <summary>
    /// Hands all written segments to the caller and leaves this writer empty, without returning the
    /// arrays to the pool. The caller disposes the result to return them.
    /// </summary>
    /// <returns>The detached segments together with the total byte count.</returns>
    public WrittenBuffers DetachAndReset()
    {
        _completedSegments ??= new List<CompletedBuffer>();

        if (_currentSegment is not null)
        {
            _completedSegments.Add(new CompletedBuffer(_currentSegment, _position));
        }

        var written = new WrittenBuffers(_completedSegments, _bytesWritten);

        // Ownership of the list and its arrays has moved to 'written'; drop our references.
        _currentSegment = null;
        _completedSegments = null;
        _bytesWritten = 0;
        _position = 0;

        return written;
    }

    /// <summary>Resets the writer when disposed explicitly, returning rented arrays to the pool.</summary>
    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            Reset();
        }
    }

    /// <summary>
    /// Holds the written segments from a MemoryBufferWriter and is no longer attached to a MemoryBufferWriter.
    /// You are now responsible for calling Dispose on this type to return the memory to the pool.
    /// </summary>
    internal readonly ref struct WrittenBuffers
    {
        public readonly List<CompletedBuffer> Segments;
        private readonly int _bytesWritten;

        public WrittenBuffers(List<CompletedBuffer> segments, int bytesWritten)
        {
            Segments = segments;
            _bytesWritten = bytesWritten;
        }

        /// <summary>Total number of bytes held across all segments.</summary>
        public int ByteLength => _bytesWritten;

        /// <summary>Returns every segment's array to the pool and empties the segment list.</summary>
        public void Dispose()
        {
            for (var i = 0; i < Segments.Count; i++)
            {
                Segments[i].Return();
            }
            Segments.Clear();
        }
    }

    /// <summary>
    /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it.
    /// </summary>
    internal readonly struct CompletedBuffer
    {
        public byte[] Buffer { get; }
        public int Length { get; }

        /// <summary>The used portion of <see cref="Buffer"/>: its first <see cref="Length"/> bytes.</summary>
        public ReadOnlySpan<byte> Span => Buffer.AsSpan(0, Length);

        public CompletedBuffer(byte[] buffer, int length)
        {
            Buffer = buffer;
            Length = length;
        }

        /// <summary>Returns <see cref="Buffer"/> to the shared array pool.</summary>
        public void Return()
        {
            ArrayPool<byte>.Shared.Return(Buffer);
        }
    }
}