// File: HttpRequestStreamReader.cs
// Web Access
// Project: src\src\Http\WebUtilities\src\Microsoft.AspNetCore.WebUtilities.csproj (Microsoft.AspNetCore.WebUtilities)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text;
 
namespace Microsoft.AspNetCore.WebUtilities;
 
/// <summary>
/// A <see cref="TextReader"/> to read the HTTP request stream.
/// </summary>
public class HttpRequestStreamReader : TextReader
{
    // Buffer size passed along by the constructor overload that takes no explicit bufferSize.
    private const int DefaultBufferSize = 1024;

    // The stream bytes are read from.
    private readonly Stream _stream;
    // The encoding supplied at construction.
    private readonly Encoding _encoding;
    // Decoder obtained from the encoding; every buffer fill converts bytes to chars through it.
    private readonly Decoder _decoder;

    // Pools the two buffers below are rented from in the constructor and returned to in Dispose.
    private readonly ArrayPool<byte> _bytePool;
    private readonly ArrayPool<char> _charPool;

    // Number of bytes requested from the stream on each read (the bufferSize constructor argument).
    private readonly int _byteBufferSize;
    // Rented scratch buffer that receives raw bytes from the stream.
    private readonly byte[] _byteBuffer;
    // Rented buffer holding the chars decoded from _byteBuffer.
    private readonly char[] _charBuffer;

    // Index of the next unread char in _charBuffer.
    private int _charBufferIndex;
    // Number of decoded chars currently held in _charBuffer.
    private int _charsRead;
    // Number of bytes returned by the most recent stream read.
    private int _bytesRead;

    // True when the most recent stream read returned fewer bytes than were requested.
    private bool _isBlocked;
    // Set once Dispose(true) has returned the buffers to their pools.
    private bool _disposed;
 
    /// <summary>
    /// Creates a <see cref="HttpRequestStreamReader"/> over <paramref name="stream"/> using the default
    /// buffer size and the shared byte and char array pools.
    /// </summary>
    /// <param name="stream">The HTTP request <see cref="Stream"/> to read from.</param>
    /// <param name="encoding">The character encoding used to decode the stream.</param>
    public HttpRequestStreamReader(Stream stream, Encoding encoding)
        : this(
            stream,
            encoding,
            bufferSize: DefaultBufferSize,
            bytePool: ArrayPool<byte>.Shared,
            charPool: ArrayPool<char>.Shared)
    {
    }
 
    /// <summary>
    /// Creates a <see cref="HttpRequestStreamReader"/> over <paramref name="stream"/> with a caller-chosen
    /// buffer size, renting its buffers from the shared byte and char array pools.
    /// </summary>
    /// <param name="stream">The HTTP request <see cref="Stream"/> to read from.</param>
    /// <param name="encoding">The character encoding used to decode the stream.</param>
    /// <param name="bufferSize">The minimum buffer size.</param>
    public HttpRequestStreamReader(Stream stream, Encoding encoding, int bufferSize)
        : this(
            stream,
            encoding,
            bufferSize,
            bytePool: ArrayPool<byte>.Shared,
            charPool: ArrayPool<char>.Shared)
    {
    }
 
    /// <summary>
    /// Initializes a new instance of <see cref="HttpRequestStreamReader"/>.
    /// </summary>
    /// <param name="stream">The HTTP request <see cref="Stream"/>.</param>
    /// <param name="encoding">The character encoding to use.</param>
    /// <param name="bufferSize">The minimum buffer size.</param>
    /// <param name="bytePool">The byte array pool to use.</param>
    /// <param name="charPool">The char array pool to use.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="stream"/>, <paramref name="encoding"/>, <paramref name="bytePool"/> or
    /// <paramref name="charPool"/> is <see langword="null"/>.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="bufferSize"/> is zero or negative.</exception>
    /// <exception cref="ArgumentException"><paramref name="stream"/> does not support reading.</exception>
    public HttpRequestStreamReader(
        Stream stream,
        Encoding encoding,
        int bufferSize,
        ArrayPool<byte> bytePool,
        ArrayPool<char> charPool)
    {
        _stream = stream ?? throw new ArgumentNullException(nameof(stream));
        _encoding = encoding ?? throw new ArgumentNullException(nameof(encoding));
        _bytePool = bytePool ?? throw new ArgumentNullException(nameof(bytePool));
        _charPool = charPool ?? throw new ArgumentNullException(nameof(charPool));

        ArgumentOutOfRangeException.ThrowIfNegativeOrZero(bufferSize);
        if (!stream.CanRead)
        {
            throw new ArgumentException(Resources.HttpRequestStreamReader_StreamNotReadable, nameof(stream));
        }

        _byteBufferSize = bufferSize;

        _decoder = encoding.GetDecoder();
        _byteBuffer = _bytePool.Rent(bufferSize);

        try
        {
            // Size the char buffer for the most chars a full byte buffer can decode to.
            var requiredLength = encoding.GetMaxCharCount(bufferSize);
            _charBuffer = _charPool.Rent(requiredLength);
        }
        catch
        {
            // Only the byte buffer has been rented at this point: the assignment to _charBuffer is the
            // last statement of the try block, so a failure always leaves it unassigned.
            _bytePool.Return(_byteBuffer);

            throw;
        }
    }
 
    /// <inheritdoc />
    protected override void Dispose(bool disposing)
    {
        // The rented buffers go back to their pools exactly once, and only on an explicit dispose.
        var releaseBuffers = disposing && !_disposed;
        if (releaseBuffers)
        {
            _disposed = true;

            _bytePool.Return(_byteBuffer);
            _charPool.Return(_charBuffer);
        }

        // The base implementation is invoked on every call, whatever happened above.
        base.Dispose(disposing);
    }
 
    /// <inheritdoc />
    public override int Peek()
    {
        ThrowIfDisposed();

        var bufferDrained = _charBufferIndex == _charsRead;

        // With nothing buffered, report -1 if the last stream read came back short (so the stream is
        // not touched again here) or if a refill produces no chars.
        if (bufferDrained && (_isBlocked || ReadIntoBuffer() == 0))
        {
            return -1;
        }

        // Return the next char without advancing past it.
        return _charBuffer[_charBufferIndex];
    }
 
    /// <inheritdoc />
    public override int Read()
    {
        ThrowIfDisposed();

        var bufferDrained = _charBufferIndex == _charsRead;

        // Refill when every buffered char has been consumed; an empty refill means end of input.
        if (bufferDrained && ReadIntoBuffer() == 0)
        {
            return -1;
        }

        // Hand out the next char and step past it.
        var position = _charBufferIndex++;
        return _charBuffer[position];
    }
 
    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="index"/> or <paramref name="count"/> is negative, or the range they describe
    /// does not fit inside <paramref name="buffer"/>.
    /// </exception>
    public override int Read(char[] buffer, int index, int count)
    {
        ArgumentNullException.ThrowIfNull(buffer);
        ArgumentOutOfRangeException.ThrowIfNegative(index);

        // Compare against the space left after index instead of computing index + count: the sum can
        // wrap around to a negative number for large inputs and would then slip past the check.
        if (count < 0 || count > buffer.Length - index)
        {
            throw new ArgumentOutOfRangeException(nameof(count));
        }

        var span = new Span<char>(buffer, index, count);
        return Read(span);
    }
 
    /// <inheritdoc />
    public override int Read(Span<char> buffer)
    {
        ThrowIfDisposed();

        // 'destination' always denotes the part of the caller's span that is still unfilled.
        var destination = buffer;
        var totalCopied = 0;

        while (!destination.IsEmpty)
        {
            var available = _charsRead - _charBufferIndex;
            if (available == 0)
            {
                available = ReadIntoBuffer();
                if (available == 0)
                {
                    // Nothing more could be decoded: end of input.
                    break;
                }
            }

            // Never copy more than the caller still has room for.
            var chunk = Math.Min(available, destination.Length);
            _charBuffer.AsSpan(_charBufferIndex, chunk).CopyTo(destination);

            _charBufferIndex += chunk;
            totalCopied += chunk;
            destination = destination.Slice(chunk);

            // A short stream read suggests the underlying stream has no more data ready, so hand
            // back what has been gathered instead of reading again.
            if (_isBlocked)
            {
                break;
            }
        }

        return totalCopied;
    }
 
    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="index"/> or <paramref name="count"/> is negative, or the range they describe
    /// does not fit inside <paramref name="buffer"/>.
    /// </exception>
    public override Task<int> ReadAsync(char[] buffer, int index, int count)
    {
        ArgumentNullException.ThrowIfNull(buffer);
        ArgumentOutOfRangeException.ThrowIfNegative(index);

        // Compare against the space left after index instead of computing index + count: the sum can
        // wrap around to a negative number for large inputs and would then slip past the check.
        if (count < 0 || count > buffer.Length - index)
        {
            throw new ArgumentOutOfRangeException(nameof(count));
        }

        var memory = new Memory<char>(buffer, index, count);
        return ReadAsync(memory).AsTask();
    }
 
    /// <inheritdoc />
    [SuppressMessage("ApiDesign", "RS0027:Public API with optional parameter(s) should have the most parameters amongst its public overloads.", Justification = "Required to maintain compatibility")]
    public override async ValueTask<int> ReadAsync(Memory<char> buffer, CancellationToken cancellationToken = default)
    {
        ThrowIfDisposed();

        // NOTE: ReadIntoBufferAsync takes no cancellation token, so this initial fill does not observe
        // cancellationToken; only the stream reads issued inside the loop below do.
        if (_charBufferIndex == _charsRead && await ReadIntoBufferAsync().ConfigureAwait(false) == 0)
        {
            return 0;
        }

        var count = buffer.Length;

        var charsRead = 0;
        while (count > 0)
        {
            // charsRemaining is the number of unread chars available in _charBuffer.
            var charsRemaining = _charsRead - _charBufferIndex;

            // _charBuffer is empty, so read from the stream.
            if (charsRemaining == 0)
            {
                _charsRead = 0;
                _charBufferIndex = 0;
                _bytesRead = 0;

                // Loop until the bytes read decode to at least one char: a read can end in the middle
                // of a multi-byte sequence, in which case the decoder yields no chars for that read.
                // The loop is also left when the stream reports EOF.
                do
                {
                    Debug.Assert(charsRemaining == 0);
                    _bytesRead = await _stream.ReadAsync(_byteBuffer.AsMemory(0, _byteBufferSize), cancellationToken).ConfigureAwait(false);
                    if (_bytesRead == 0)  // EOF
                    {
                        _isBlocked = true;
                        break;
                    }

                    // _isBlocked == whether we read fewer bytes than we asked for.
                    _isBlocked = (_bytesRead < _byteBufferSize);

                    _charBufferIndex = 0;
                    charsRemaining = _decoder.GetChars(
                        _byteBuffer,
                        0,
                        _bytesRead,
                        _charBuffer,
                        0);

                    _charsRead += charsRemaining; // Number of chars now held in _charBuffer.
                }
                while (charsRemaining == 0);

                if (charsRemaining == 0)
                {
                    break; // We're at EOF
                }
            }

            // Got more chars in _charBuffer than the caller still has room for.
            if (charsRemaining > count)
            {
                charsRemaining = count;
            }

            var source = new Memory<char>(_charBuffer, _charBufferIndex, charsRemaining);
            source.CopyTo(buffer);

            _charBufferIndex += charsRemaining;

            charsRead += charsRemaining;
            count -= charsRemaining;

            buffer = buffer.Slice(charsRemaining, count);

            // This function shouldn't block for an indefinite amount of time,
            // or reading from a network stream won't work right.  If we got
            // fewer bytes than we requested, then we want to break right here.
            if (_isBlocked)
            {
                break;
            }
        }

        return charsRead;
    }
 
    /// <inheritdoc />
    public override async Task<string?> ReadLineAsync()
    {
        ThrowIfDisposed();

        StringBuilder? sb = null;
        var consumeLineFeed = false;

        while (true)
        {
            if (_charBufferIndex == _charsRead)
            {
                // Nothing after the await depends on the caller's context, so do not capture it.
                if (await ReadIntoBufferAsync().ConfigureAwait(false) == 0)
                {
                    // Reached EOF: sb is still null if no chars were gathered, which yields null;
                    // otherwise return the partial line collected so far.
                    return sb?.ToString();
                }
            }

            var stepResult = ReadLineStep(ref sb, ref consumeLineFeed);

            if (stepResult.Completed)
            {
                // A line found in a single pass arrives in Result; otherwise it was accumulated in sb.
                return stepResult.Result ?? sb?.ToString();
            }
        }
    }
 
    // Reads one line of text. A line ends at a carriage return ('\r'), a line feed ('\n'),
    // or a carriage return immediately followed by a line feed; the terminator is not part
    // of the returned string. Returns null when the input was already exhausted before any
    // character could be read.
    /// <inheritdoc />
    public override string? ReadLine()
    {
        ThrowIfDisposed();

        StringBuilder? lineBuilder = null;
        var pendingLineFeed = false;

        // Keep stepping for as long as buffered chars remain or a refill produces more.
        while (_charBufferIndex != _charsRead || ReadIntoBuffer() != 0)
        {
            var step = ReadLineStep(ref lineBuilder, ref pendingLineFeed);
            if (step.Completed)
            {
                // A line found in a single pass arrives in Result; otherwise it is in the builder.
                return step.Result ?? lineBuilder?.ToString();
            }
        }

        // End of input: null if nothing was gathered, otherwise the partial line.
        return lineBuilder?.ToString();
    }
 
    // Performs one step of a line read over the chars currently buffered.
    // Returns a completed result once a line terminator has been consumed; returns Continue when
    // the buffer ran out first, in which case the chars seen so far have been appended to sb.
    // consumeLineFeed is set when a '\r' was the last buffered char, so that the next step can
    // swallow a '\n' that directly follows it.
    private ReadLineStepResult ReadLineStep(ref StringBuilder? sb, ref bool consumeLineFeed)
    {
        const char carriageReturn = '\r';
        const char lineFeed = '\n';

        if (consumeLineFeed)
        {
            // The previous step ended on '\r'; drop a directly following '\n' and finish the line.
            if (_charBuffer[_charBufferIndex] == lineFeed)
            {
                _charBufferIndex++;
            }

            return ReadLineStepResult.Done;
        }

        var pending = _charBuffer.AsSpan(_charBufferIndex, _charsRead - _charBufferIndex);
        var terminatorAt = pending.IndexOfAny(carriageReturn, lineFeed);

        if (terminatorAt == -1)
        {
            // No terminator buffered: stash everything and ask the caller for more input.
            sb ??= new StringBuilder();
            sb.Append(pending);
            _charBufferIndex = _charsRead;

            return ReadLineStepResult.Continue;
        }

        var endsWithCarriageReturn = pending[terminatorAt] == carriageReturn;
        var line = pending.Slice(0, terminatorAt);

        // Step past the line content and its terminator char.
        _charBufferIndex += terminatorAt + 1;

        if (endsWithCarriageReturn)
        {
            if (_charBufferIndex >= _charsRead)
            {
                // The '\r' was the last buffered char; more input is needed to learn whether a
                // '\n' follows it.
                sb ??= new StringBuilder();
                sb.Append(line);
                consumeLineFeed = true;

                return ReadLineStepResult.Continue;
            }

            // Treat "\r\n" as a single terminator.
            if (_charBuffer[_charBufferIndex] == lineFeed)
            {
                _charBufferIndex++;
            }
        }

        if (sb is null)
        {
            // perf: the whole line was found in one pass, so no StringBuilder is needed.
            return ReadLineStepResult.FromResult(line.ToString());
        }

        sb.Append(line);
        return ReadLineStepResult.Done;
    }
 
    // Discards the current char buffer contents and refills it from the stream.
    // Keeps reading until at least one char has been decoded or the stream returns no bytes.
    // Returns the number of chars now available in _charBuffer (0 at end of stream).
    private int ReadIntoBuffer()
    {
        _bytesRead = 0;
        _charBufferIndex = 0;
        _charsRead = 0;

        while (true)
        {
            _bytesRead = _stream.Read(_byteBuffer, 0, _byteBufferSize);
            if (_bytesRead == 0)
            {
                // We're at EOF
                return _charsRead;
            }

            // _isBlocked == whether the stream returned fewer bytes than were asked for.
            _isBlocked = _bytesRead < _byteBufferSize;

            var decoded = _decoder.GetChars(_byteBuffer, 0, _bytesRead, _charBuffer, _charsRead);
            _charsRead += decoded;

            if (_charsRead != 0)
            {
                return _charsRead;
            }
        }
    }
 
    // Asynchronous counterpart of ReadIntoBuffer: discards the current char buffer contents and
    // refills it from the stream, reading until at least one char has been decoded or the stream
    // returns no bytes. Returns the number of chars now available in _charBuffer (0 at end of stream).
    private async Task<int> ReadIntoBufferAsync()
    {
        _bytesRead = 0;
        _charBufferIndex = 0;
        _charsRead = 0;

        while (true)
        {
            var target = _byteBuffer.AsMemory(0, _byteBufferSize);
            _bytesRead = await _stream.ReadAsync(target).ConfigureAwait(false);
            if (_bytesRead == 0)
            {
                // We're at EOF
                return _charsRead;
            }

            // _isBlocked == whether the stream returned fewer bytes than were asked for.
            _isBlocked = _bytesRead < _byteBufferSize;

            var decoded = _decoder.GetChars(_byteBuffer, 0, _bytesRead, _charBuffer, _charsRead);
            _charsRead += decoded;

            if (_charsRead != 0)
            {
                return _charsRead;
            }
        }
    }
 
    /// <inheritdoc />
    public override async Task<string> ReadToEndAsync()
    {
        // Dispose returns both buffers to their pools, so they must not be touched afterwards.
        ThrowIfDisposed();

        // Start with room for whatever is already decoded and waiting in the char buffer.
        var sb = new StringBuilder(_charsRead - _charBufferIndex);
        do
        {
            var tmpCharPos = _charBufferIndex;
            sb.Append(_charBuffer, tmpCharPos, _charsRead - tmpCharPos);
            _charBufferIndex = _charsRead;  // We consumed these characters
            await ReadIntoBufferAsync().ConfigureAwait(false);
        } while (_charsRead > 0);

        return sb.ToString();
    }
 
    // Outcome of a single ReadLineStep call: whether the line is complete and, when the line was
    // found without needing a StringBuilder, the line text itself.
    private readonly struct ReadLineStepResult
    {
        // The line is complete; its text is not carried in Result.
        public static readonly ReadLineStepResult Done = new(completed: true, result: null);

        // More input is required before the line can be completed.
        public static readonly ReadLineStepResult Continue = new(completed: false, result: null);

        private ReadLineStepResult(bool completed, string? result)
            => (Completed, Result) = (completed, result);

        public bool Completed { get; }

        public string? Result { get; }

        // The line is complete and its text is carried in Result.
        public static ReadLineStepResult FromResult(string value)
        {
            return new ReadLineStepResult(completed: true, result: value);
        }
    }
 
    // Throws ObjectDisposedException once the reader has been disposed.
    private void ThrowIfDisposed() => ObjectDisposedException.ThrowIf(_disposed, this);
}