File: Logging\Internal\HttpRequestBodyReader.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.Http.Diagnostics\Microsoft.Extensions.Http.Diagnostics.csproj (Microsoft.Extensions.Http.Diagnostics)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Frozen;
using System.IO;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
#if NETCOREAPP3_1_OR_GREATER
using Microsoft.Extensions.ObjectPool;
#endif
using Microsoft.Shared.Diagnostics;
#if NETCOREAPP3_1_OR_GREATER
using Microsoft.Shared.Pools;
#else
using System.Buffers;
#endif
 
namespace Microsoft.Extensions.Http.Logging.Internal;
 
/// <summary>
/// Reads the body of an outgoing <see cref="HttpRequestMessage"/> as a string so it can be logged,
/// honoring the configured readable content types, body size limit and body read timeout.
/// </summary>
internal sealed class HttpRequestBodyReader
{
    /// <summary>
    /// Effective timeout applied to reading the request body; <see cref="Timeout.InfiniteTimeSpan"/>
    /// when a debugger is attached. Exposed for testing purposes.
    /// </summary>
    internal readonly TimeSpan RequestReadTimeout;

#if NETCOREAPP3_1_OR_GREATER
    private static readonly ObjectPool<BufferWriter<byte>> _bufferWriterPool = BufferWriterPool.SharedBufferWriterPool;
#endif
    private readonly FrozenSet<string> _readableRequestContentTypes;
    private readonly int _requestReadLimit;

    /// <summary>
    /// Initializes a new instance of the <see cref="HttpRequestBodyReader"/> class.
    /// </summary>
    /// <param name="requestOptions">
    /// Source of the readable content types (compared case-insensitively), the body size limit
    /// and the body read timeout.
    /// </param>
    /// <param name="debugger">
    /// Debugger state probe; defaults to <c>DebuggerState.System</c>. When a debugger is attached
    /// the read timeout is disabled so stepping through code does not cancel the read.
    /// </param>
    public HttpRequestBodyReader(LoggingOptions requestOptions, IDebuggerState? debugger = null)
    {
        _readableRequestContentTypes = requestOptions.RequestBodyContentTypes.ToFrozenSet(StringComparer.OrdinalIgnoreCase);
        debugger ??= DebuggerState.System;
        _requestReadLimit = requestOptions.BodySizeLimit;

        RequestReadTimeout = debugger.IsAttached
            ? Timeout.InfiniteTimeSpan
            : requestOptions.BodyReadTimeout;
    }

    /// <summary>
    /// Reads the request body for logging.
    /// </summary>
    /// <param name="request">The outgoing request whose content is read.</param>
    /// <param name="cancellationToken">Caller cancellation; propagates as <see cref="OperationCanceledException"/>.</param>
    /// <returns>
    /// <see cref="string.Empty"/> when there is no content or the method is GET;
    /// <c>Constants.NoContent</c> when the content has no Content-Type header;
    /// <c>Constants.UnreadableContent</c> when the media type is not covered by the configured
    /// content types or the content stream cannot be rewound;
    /// <c>Constants.ReadCancelled</c> when the read timeout elapses;
    /// otherwise up to the configured size limit of the body, decoded as UTF-8.
    /// </returns>
    public ValueTask<string> ReadAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        if (request.Content is null || request.Method == HttpMethod.Get)
        {
            return new(string.Empty);
        }

        var contentType = request.Content.Headers.ContentType;
        if (contentType is null)
        {
            return new(Constants.NoContent);
        }

        if (!_readableRequestContentTypes.Covers(contentType.MediaType))
        {
            return new(Constants.UnreadableContent);
        }

        // Preserve() makes the returned ValueTask safe to consume more than once.
        return ReadFromStreamWithTimeoutAsync(request, RequestReadTimeout, _requestReadLimit, cancellationToken).Preserve();
    }

    /// <summary>
    /// Reads the body while enforcing <paramref name="readTimeout"/> on top of the caller's token.
    /// </summary>
    /// <returns>
    /// The body text, or <c>Constants.ReadCancelled</c> when the timeout (not the caller) cancelled the read.
    /// </returns>
    private static async ValueTask<string> ReadFromStreamWithTimeoutAsync(HttpRequestMessage request,
        TimeSpan readTimeout, int readSizeLimit, CancellationToken cancellationToken)
    {
        using var joinedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        joinedTokenSource.CancelAfter(readTimeout);

        try
        {
            return await ReadFromStreamAsync(request, readSizeLimit, joinedTokenSource.Token).ConfigureAwait(false);
        }

        // The read timeout fired; caller-initiated cancellation is left to propagate.
        catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
        {
            return Constants.ReadCancelled;
        }
    }

    /// <summary>
    /// Reads at most <paramref name="readSizeLimit"/> bytes of the request content, decodes them as
    /// UTF-8 and rewinds the content stream to its beginning so the request can still be sent.
    /// </summary>
    /// <remarks>
    /// A size limit that falls inside a multi-byte UTF-8 sequence yields a replacement character
    /// for the truncated sequence.
    /// </remarks>
    private static async ValueTask<string> ReadFromStreamAsync(HttpRequestMessage request, int readSizeLimit,
        CancellationToken cancellationToken)
    {
#if NET5_0_OR_GREATER
        var streamToReadFrom = await request.Content!.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
#else
        var streamToReadFrom = await request.Content.ReadAsStreamAsync().WaitAsync(cancellationToken).ConfigureAwait(false);
#endif

        // A non-seekable stream supports neither Length nor rewinding; consuming a forward-only body
        // here would leave nothing for the handler that actually sends the request.
        if (!streamToReadFrom.CanSeek)
        {
            return Constants.UnreadableContent;
        }

        // Clamp in the long domain: narrowing Length to int first overflows for bodies larger than 2 GB.
        var readLimit = (int)Math.Max(0L, Math.Min(readSizeLimit, streamToReadFrom.Length));
#if NETCOREAPP3_1_OR_GREATER
        var bufferWriter = _bufferWriterPool.Get();
        try
        {
            var memory = bufferWriter.GetMemory(readLimit).Slice(0, readLimit);
            var bytesRead = 0;

            // Stream.ReadAsync may return fewer bytes than requested before the end of the stream,
            // so keep reading until the limit is reached or the stream is exhausted.
            while (bytesRead < readLimit)
            {
                var read = await streamToReadFrom.ReadAsync(memory[bytesRead..], cancellationToken).ConfigureAwait(false);
                if (read == 0)
                {
                    break;
                }

                bytesRead += read;
            }

            // Decode before the finally block hands the buffer back to the pool.
            return Encoding.UTF8.GetString(memory[..bytesRead].Span);
        }
        finally
        {
            _bufferWriterPool.Return(bufferWriter);
            streamToReadFrom.Seek(0, SeekOrigin.Begin);
        }

#else
        var buffer = ArrayPool<byte>.Shared.Rent(readLimit);
        try
        {
            var bytesRead = 0;

            // The rented array may be larger than readLimit: bound every read by the limit, and loop
            // because a single ReadAsync call may return fewer bytes than requested.
            while (bytesRead < readLimit)
            {
                var read = await streamToReadFrom.ReadAsync(buffer, bytesRead, readLimit - bytesRead, cancellationToken).ConfigureAwait(false);
                if (read == 0)
                {
                    break;
                }

                bytesRead += read;
            }

            // Decode only the bytes actually read; the rest of the pooled array holds stale data.
            return Encoding.UTF8.GetString(buffer, 0, bytesRead);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(buffer);
            streamToReadFrom.Seek(0, SeekOrigin.Begin);
        }
#endif
    }
}