File: System\IO\Pipelines\StreamPipeWriter.cs
Web Access
Project: src\src\libraries\System.IO.Pipelines\src\System.IO.Pipelines.csproj (System.IO.Pipelines)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.IO.Pipelines
{
    internal sealed class StreamPipeWriter : PipeWriter
    {
        internal const int InitialSegmentPoolSize = 4; // 16K
        internal const int MaxSegmentPoolSize = 256; // 1MB
 
        private readonly int _minimumBufferSize;
 
        private BufferSegment? _head;
        private BufferSegment? _tail;
        private Memory<byte> _tailMemory;
        private int _tailBytesBuffered;
        private int _bytesBuffered;
 
        private readonly MemoryPool<byte>? _pool;
        private readonly int _maxPooledBufferSize;
 
        private CancellationTokenSource? _internalTokenSource;
        private bool _isCompleted;
        private readonly object _lockObject = new object();
 
        private BufferSegmentStack _bufferSegmentPool;
        private readonly bool _leaveOpen;
 
        private CancellationTokenSource InternalTokenSource
        {
            get
            {
                lock (_lockObject)
                {
                    return _internalTokenSource ??= new CancellationTokenSource();
                }
            }
        }
 
        public StreamPipeWriter(Stream writingStream, StreamPipeWriterOptions options)
        {
            if (writingStream is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.writingStream);
            }
            if (options is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.options);
            }
 
            InnerStream = writingStream;
            _minimumBufferSize = options.MinimumBufferSize;
            _pool = options.Pool == MemoryPool<byte>.Shared ? null : options.Pool;
            _maxPooledBufferSize = _pool?.MaxBufferSize ?? -1;
            _bufferSegmentPool = new BufferSegmentStack(InitialSegmentPoolSize);
            _leaveOpen = options.LeaveOpen;
        }
 
        /// <summary>
        /// Gets the inner stream that is being written to.
        /// </summary>
        public Stream InnerStream { get; }
 
        /// <inheritdoc />
        public override void Advance(int bytes)
        {
            if ((uint)bytes > (uint)_tailMemory.Length)
            {
                ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.bytes);
            }
 
            _tailBytesBuffered += bytes;
            _bytesBuffered += bytes;
            _tailMemory = _tailMemory.Slice(bytes);
        }
 
        /// <inheritdoc />
        public override Memory<byte> GetMemory(int sizeHint = 0)
        {
            if (_isCompleted)
            {
                ThrowHelper.ThrowInvalidOperationException_NoWritingAllowed();
            }
 
            if (sizeHint < 0)
            {
                ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.sizeHint);
            }
 
            AllocateMemory(sizeHint);
 
            return _tailMemory;
        }
 
        /// <inheritdoc />
        public override Span<byte> GetSpan(int sizeHint = 0)
        {
            if (_isCompleted)
            {
                ThrowHelper.ThrowInvalidOperationException_NoWritingAllowed();
            }
 
            if (sizeHint < 0)
            {
                ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.sizeHint);
            }
 
            AllocateMemory(sizeHint);
 
            return _tailMemory.Span;
        }
 
        private void AllocateMemory(int sizeHint)
        {
            if (_head == null)
            {
                // We need to allocate memory to write since nobody has written before
                BufferSegment newSegment = AllocateSegment(sizeHint);
 
                // Set all the pointers
                _head = _tail = newSegment;
                _tailBytesBuffered = 0;
            }
            else
            {
                Debug.Assert(_tail != null);
                int bytesLeftInBuffer = _tailMemory.Length;
 
                if (bytesLeftInBuffer == 0 || bytesLeftInBuffer < sizeHint)
                {
                    if (_tailBytesBuffered > 0)
                    {
                        // Flush buffered data to the segment
                        _tail.End += _tailBytesBuffered;
                        _tailBytesBuffered = 0;
                    }
 
                    BufferSegment newSegment = AllocateSegment(sizeHint);
 
                    _tail.SetNext(newSegment);
                    _tail = newSegment;
                }
            }
        }
 
        private BufferSegment AllocateSegment(int sizeHint)
        {
            Debug.Assert(sizeHint >= 0);
            BufferSegment newSegment = CreateSegmentUnsynchronized();
 
            int maxSize = _maxPooledBufferSize;
            if (sizeHint <= maxSize)
            {
                // Use the specified pool as it fits. Specified pool is not null as maxSize == -1 if _pool is null.
                newSegment.SetOwnedMemory(_pool!.Rent(GetSegmentSize(sizeHint, maxSize)));
            }
            else
            {
                // Use the array pool
                int sizeToRequest = GetSegmentSize(sizeHint);
                newSegment.SetOwnedMemory(ArrayPool<byte>.Shared.Rent(sizeToRequest));
            }
 
            _tailMemory = newSegment.AvailableMemory;
 
            return newSegment;
        }
 
        private int GetSegmentSize(int sizeHint, int maxBufferSize = int.MaxValue)
        {
            // First we need to handle case where hint is smaller than minimum segment size
            sizeHint = Math.Max(_minimumBufferSize, sizeHint);
            // After that adjust it to fit into pools max buffer size
            var adjustedToMaximumSize = Math.Min(maxBufferSize, sizeHint);
            return adjustedToMaximumSize;
        }
 
        private BufferSegment CreateSegmentUnsynchronized()
        {
            if (_bufferSegmentPool.TryPop(out BufferSegment? segment))
            {
                return segment;
            }
 
            return new BufferSegment();
        }
 
        private void ReturnSegmentUnsynchronized(BufferSegment segment)
        {
            segment.Reset();
            if (_bufferSegmentPool.Count < MaxSegmentPoolSize)
            {
                _bufferSegmentPool.Push(segment);
            }
        }
 
        /// <inheritdoc />
        public override void CancelPendingFlush()
        {
            Cancel();
        }
 
        /// <inheritdoc />
        public override bool CanGetUnflushedBytes => true;
 
        /// <inheritdoc />
        public override void Complete(Exception? exception = null)
        {
            if (_isCompleted)
            {
                return;
            }
 
            _isCompleted = true;
 
            try
            {
                FlushInternal(writeToStream: exception == null);
            }
            finally
            {
                _internalTokenSource?.Dispose();
 
                if (!_leaveOpen)
                {
                    InnerStream.Dispose();
                }
            }
        }
 
        public override async ValueTask CompleteAsync(Exception? exception = null)
        {
            if (_isCompleted)
            {
                return;
            }
 
            _isCompleted = true;
 
            try
            {
                await FlushAsyncInternal(writeToStream: exception == null, data: Memory<byte>.Empty).ConfigureAwait(false);
            }
            finally
            {
                _internalTokenSource?.Dispose();
 
                if (!_leaveOpen)
                {
#if (!NETSTANDARD2_0 && !NETFRAMEWORK)
                    await InnerStream.DisposeAsync().ConfigureAwait(false);
#else
                    InnerStream.Dispose();
#endif
                }
            }
        }
 
        /// <inheritdoc />
        public override ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = default)
        {
            if (_bytesBuffered == 0)
            {
                return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: false));
            }
 
            return FlushAsyncInternal(writeToStream: true, data: Memory<byte>.Empty, cancellationToken);
        }
 
        /// <inheritdoc />
        public override long UnflushedBytes => _bytesBuffered;
 
        public override ValueTask<FlushResult> WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
        {
            return FlushAsyncInternal(writeToStream: true, data: source, cancellationToken);
        }
 
        private void Cancel()
        {
            InternalTokenSource.Cancel();
        }
 
#if NET
        [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
#endif
        private async ValueTask<FlushResult> FlushAsyncInternal(bool writeToStream, ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default)
        {
            // Write all completed segments and whatever remains in the current segment
            // and flush the result.
            CancellationTokenRegistration reg = default;
            if (cancellationToken.CanBeCanceled)
            {
                reg = cancellationToken.UnsafeRegister(state => ((StreamPipeWriter)state!).Cancel(), this);
            }
 
            if (_tailBytesBuffered > 0)
            {
                Debug.Assert(_tail != null);
 
                // Update any buffered data
                _tail.End += _tailBytesBuffered;
                _tailBytesBuffered = 0;
            }
 
            using (reg)
            {
                CancellationToken localToken = InternalTokenSource.Token;
                try
                {
                    BufferSegment? segment = _head;
                    while (segment != null)
                    {
                        BufferSegment returnSegment = segment;
                        segment = segment.NextSegment;
 
                        if (returnSegment.Length > 0 && writeToStream)
                        {
                            await InnerStream.WriteAsync(returnSegment.Memory, localToken).ConfigureAwait(false);
                        }
 
                        ReturnSegmentUnsynchronized(returnSegment);
 
                        // Update the head segment after we return the current segment
                        _head = segment;
                    }
 
                    if (writeToStream)
                    {
                        // Write data after the buffered data
                        if (data.Length > 0)
                        {
                            await InnerStream.WriteAsync(data, localToken).ConfigureAwait(false);
                        }
 
                        if (_bytesBuffered > 0 || data.Length > 0)
                        {
                            await InnerStream.FlushAsync(localToken).ConfigureAwait(false);
                        }
                    }
 
                    // Mark bytes as written *after* flushing
                    _head = null;
                    _tail = null;
                    _tailMemory = default;
                    _bytesBuffered = 0;
 
                    return new FlushResult(isCanceled: false, isCompleted: false);
                }
                catch (OperationCanceledException)
                {
                    // Remove the cancellation token such that the next time Flush is called
                    // A new CTS is created.
                    lock (_lockObject)
                    {
                        _internalTokenSource = null;
                    }
 
                    if (localToken.IsCancellationRequested && !cancellationToken.IsCancellationRequested)
                    {
                        // Catch cancellation and translate it into setting isCanceled = true
                        return new FlushResult(isCanceled: true, isCompleted: false);
                    }
 
                    throw;
                }
            }
        }
 
        private void FlushInternal(bool writeToStream)
        {
            // Write all completed segments and whatever remains in the current segment
            // and flush the result.
            if (_tailBytesBuffered > 0)
            {
                Debug.Assert(_tail != null);
 
                // Update any buffered data
                _tail.End += _tailBytesBuffered;
                _tailBytesBuffered = 0;
            }
 
            BufferSegment? segment = _head;
            while (segment != null)
            {
                BufferSegment returnSegment = segment;
                segment = segment.NextSegment;
 
                if (returnSegment.Length > 0 && writeToStream)
                {
#if (!NETSTANDARD2_0 && !NETFRAMEWORK)
                    InnerStream.Write(returnSegment.Memory.Span);
#else
                    InnerStream.Write(returnSegment.Memory);
#endif
                }
 
                ReturnSegmentUnsynchronized(returnSegment);
 
                // Update the head segment after we return the current segment
                _head = segment;
            }
 
            if (_bytesBuffered > 0 && writeToStream)
            {
                InnerStream.Flush();
            }
 
            // Mark bytes as written *after* flushing
            _head = null;
            _tail = null;
            _tailMemory = default;
            _bytesBuffered = 0;
        }
    }
}