|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Globalization;
using System.IO.Pipelines;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
/// <summary>
/// Base class for request body readers. Tracks per-request state shared by derived bodies:
/// "100 Continue" production, request-body size limits, start/stop logging, and
/// min-data-rate timing through <c>_context.TimeoutControl</c>.
/// </summary>
internal abstract class MessageBody
{
    private static readonly MessageBody _zeroContentLengthClose = new ZeroContentLengthMessageBody(keepAlive: false);
    private static readonly MessageBody _zeroContentLengthKeepAlive = new ZeroContentLengthMessageBody(keepAlive: true);

    private readonly HttpProtocol _context;

    // Cleared by TryProduceContinueAsync so at most one "100 Continue" is attempted per request.
    private bool _send100Continue = true;

    // Running total of body bytes reported through AddAndCheckObservedBytes.
    private long _observedBytes;

    // Makes TryStop idempotent.
    private bool _stopped;

    // Set when TryStartAsync has called TimeoutControl.StartRequestBody.
    protected bool _timingEnabled;

    // True between TimeoutControl.StartTimingRead and the matching TimeoutControl.StopTimingRead.
    protected bool _backpressure;

    // Unconsumed bytes left in the last ReadResult (set by TrackConsumedAndExaminedBytes).
    // CountBytesRead subtracts these so they are not reported to TimeoutControl twice.
    protected long _alreadyTimedBytes;

    // Examined-but-unconsumed bytes from the most recent TrackConsumedAndExaminedBytes call.
    protected long _examinedUnconsumedBytes;

    protected MessageBody(HttpProtocol context)
    {
        _context = context;
    }

    /// <summary>Shared zero-length body instance constructed with <c>keepAlive: false</c>.</summary>
    public static MessageBody ZeroContentLengthClose => _zeroContentLengthClose;

    /// <summary>Shared zero-length body instance constructed with <c>keepAlive: true</c>.</summary>
    public static MessageBody ZeroContentLengthKeepAlive => _zeroContentLengthKeepAlive;

    /// <summary>Set by derived classes.</summary>
    public bool RequestKeepAlive { get; protected set; }

    /// <summary>Set by derived classes. When true, TryStartAsync/TryStop skip logging and body timing.</summary>
    public bool RequestUpgrade { get; protected set; }

    /// <summary>Set by derived classes. When true, TryStartAsync/TryStop skip logging and body timing.</summary>
    public bool ExtendedConnect { get; protected set; }

    /// <summary>The protocol instance this body was created for.</summary>
    public HttpProtocol Context => _context;

    /// <summary>False unless overridden by a derived class.</summary>
    public virtual bool IsEmpty => false;

    protected KestrelTrace Log => _context.ServiceContext.Log;

    public abstract ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default);

    public abstract bool TryRead(out ReadResult readResult);

    /// <summary>Equivalent to <c>AdvanceTo(consumed, consumed)</c>.</summary>
    public void AdvanceTo(SequencePosition consumed)
    {
        AdvanceTo(consumed, consumed);
    }

    public abstract void AdvanceTo(SequencePosition consumed, SequencePosition examined);

    public abstract void CancelPendingRead();

    public abstract void Complete(Exception? exception);

    /// <summary>Default implementation calls <see cref="Complete"/> synchronously and returns a completed task.</summary>
    public virtual ValueTask CompleteAsync(Exception? exception)
    {
        Complete(exception);
        return default;
    }

    /// <summary>
    /// Ensures the body has been started (<see cref="TryStartAsync"/>), then runs <see cref="OnConsumeAsync"/>.
    /// </summary>
    public virtual Task ConsumeAsync()
    {
        Task startTask = TryStartAsync();

        // Fast path: skip the async state machine when starting completed synchronously.
        if (!startTask.IsCompletedSuccessfully)
        {
            return ConsumeAwaited(startTask);
        }

        return OnConsumeAsync();
    }

    private async Task ConsumeAwaited(Task startTask)
    {
        await startTask;
        await OnConsumeAsync();
    }

    /// <summary>Runs <see cref="TryStop"/>, then <see cref="OnStopAsync"/>.</summary>
    public virtual ValueTask StopAsync()
    {
        TryStop();
        return OnStopAsync();
    }

    protected virtual Task OnConsumeAsync() => Task.CompletedTask;

    protected virtual ValueTask OnStopAsync() => default;

    /// <summary>Restores every per-request field declared in this class to its initial value.</summary>
    public virtual void Reset()
    {
        _send100Continue = true;
        _observedBytes = 0;
        _stopped = false;
        _timingEnabled = false;
        _backpressure = false;
        _alreadyTimedBytes = 0;
        _examinedUnconsumedBytes = 0;
    }

    /// <summary>
    /// On the first call, asks <c>HttpResponseControl</c> to produce a "100 Continue" response.
    /// Subsequent calls return a default (already completed) result.
    /// </summary>
    protected ValueTask<FlushResult> TryProduceContinueAsync()
    {
        if (!_send100Continue)
        {
            return default;
        }

        _send100Continue = false;
        return _context.HttpResponseControl.ProduceContinueAsync();
    }

    /// <summary>
    /// Starts consuming the request body once per request.
    /// On the first call it runs <see cref="OnReadStarting"/> and marks the context as consuming.
    /// For non-upgrade, non-extended-CONNECT requests it also logs the start and begins
    /// min-data-rate timing when <c>MinRequestBodyDataRate</c> is set.
    /// Finally it returns <see cref="OnReadStartedAsync"/>.
    /// </summary>
    /// <returns>A completed task if the body was already started; otherwise the result of <see cref="OnReadStartedAsync"/>.</returns>
    protected Task TryStartAsync()
    {
        if (_context.HasStartedConsumingRequestBody)
        {
            return Task.CompletedTask;
        }

        OnReadStarting();
        _context.HasStartedConsumingRequestBody = true;

        if (!RequestUpgrade && !ExtendedConnect)
        {
            // Accessing TraceIdentifier will lazy-allocate a string ID.
            // Don't access TraceIdentifier unless logging is enabled.
            if (Log.IsEnabled(LogLevel.Debug))
            {
                Log.RequestBodyStart(_context.ConnectionIdFeature, _context.TraceIdentifier);
            }

            if (_context.MinRequestBodyDataRate != null)
            {
                _timingEnabled = true;
                _context.TimeoutControl.StartRequestBody(_context.MinRequestBodyDataRate);
            }
        }

        return OnReadStartedAsync();
    }

    /// <summary>
    /// Stops the request body at most once.
    /// For non-upgrade, non-extended-CONNECT requests it logs completion.
    /// If timing was enabled, it stops any in-progress read timing and the request-body timer.
    /// </summary>
    protected void TryStop()
    {
        if (_stopped)
        {
            return;
        }

        _stopped = true;

        if (RequestUpgrade || ExtendedConnect)
        {
            return;
        }

        // Accessing TraceIdentifier will lazy-allocate a string ID.
        // Don't access TraceIdentifier unless logging is enabled.
        if (Log.IsEnabled(LogLevel.Debug))
        {
            Log.RequestBodyDone(_context.ConnectionIdFeature, _context.TraceIdentifier);
        }

        if (_timingEnabled)
        {
            if (_backpressure)
            {
                _context.TimeoutControl.StopTimingRead();
            }

            _context.TimeoutControl.StopRequestBody();
        }
    }

    protected virtual void OnReadStarting()
    {
    }

    protected virtual Task OnReadStartedAsync()
    {
        return Task.CompletedTask;
    }

    /// <summary>
    /// Adds <paramref name="observedBytes"/> to the running total.
    /// If the total exceeds <c>MaxRequestBodySize</c>, it calls <see cref="OnObservedBytesExceedMaxRequestBodySize"/>.
    /// A null limit disables the check.
    /// </summary>
    protected void AddAndCheckObservedBytes(long observedBytes)
    {
        _observedBytes += observedBytes;

        if (_context.MaxRequestBodySize is long maxRequestBodySize && _observedBytes > maxRequestBodySize)
        {
            OnObservedBytesExceedMaxRequestBodySize(maxRequestBodySize);
        }
    }

    /// <summary>Default behavior throws a <c>RequestBodyTooLarge</c> bad-request exception.</summary>
    protected virtual void OnObservedBytesExceedMaxRequestBodySize(long maxRequestBodySize)
    {
        KestrelBadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge, maxRequestBodySize.ToString(CultureInfo.InvariantCulture));
    }

    /// <summary>
    /// Called before waiting on a read. If <paramref name="readAwaitable"/> has not already completed,
    /// it first produces a pending "100 Continue".
    /// It then starts read timing when timing is enabled.
    /// </summary>
    /// <param name="readAwaitable">The pending read.</param>
    /// <param name="cancellationToken">Forwarded to <see cref="StartTimingReadAwaited"/>; not observed here.</param>
    /// <returns>
    /// <paramref name="readAwaitable"/> itself, or a wrapper that awaits the "100 Continue" flush first.
    /// </returns>
    protected ValueTask<ReadResult> StartTimingReadAsync(ValueTask<ReadResult> readAwaitable, CancellationToken cancellationToken)
    {
        if (readAwaitable.IsCompleted)
        {
            // Data is already available, so there is no wait to time.
            return readAwaitable;
        }

        ValueTask<FlushResult> continueTask = TryProduceContinueAsync();
        if (!continueTask.IsCompletedSuccessfully)
        {
            return StartTimingReadAwaited(continueTask, readAwaitable, cancellationToken);
        }

        // Consume the synchronously completed ValueTask's result; a ValueTask must be consumed exactly once.
        continueTask.GetAwaiter().GetResult();

        StartTimingReadIfEnabled();
        return readAwaitable;
    }

    /// <summary>
    /// Slow path of <see cref="StartTimingReadAsync"/>.
    /// It awaits the "100 Continue" flush, starts read timing when enabled, then awaits the read.
    /// </summary>
    /// <param name="cancellationToken">Not observed by this implementation.</param>
    protected async ValueTask<ReadResult> StartTimingReadAwaited(ValueTask<FlushResult> continueTask, ValueTask<ReadResult> readAwaitable, CancellationToken cancellationToken)
    {
        await continueTask;
        StartTimingReadIfEnabled();
        return await readAwaitable;
    }

    // Marks the body as under backpressure and starts read timing, but only when body timing is enabled.
    private void StartTimingReadIfEnabled()
    {
        if (_timingEnabled)
        {
            _backpressure = true;
            _context.TimeoutControl.StartTimingRead();
        }
    }

    /// <summary>
    /// Reports to <c>TimeoutControl.BytesRead</c> only the bytes of a ReadResult not already reported,
    /// i.e. those beyond <c>_alreadyTimedBytes</c>.
    /// </summary>
    protected void CountBytesRead(long bytesInReadResult)
    {
        var numFirstSeenBytes = bytesInReadResult - _alreadyTimedBytes;

        if (numFirstSeenBytes > 0)
        {
            _context.TimeoutControl.BytesRead(numFirstSeenBytes);
        }
    }

    /// <summary>
    /// Counts newly seen bytes. If a read was being timed, it ends the backpressure period
    /// and stops timing the read.
    /// </summary>
    protected void StopTimingRead(long bytesInReadResult)
    {
        CountBytesRead(bytesInReadResult);

        if (_backpressure)
        {
            _backpressure = false;
            _context.TimeoutControl.StopTimingRead();
        }
    }

    /// <summary>
    /// Converts an <c>AdvanceTo(consumed, examined)</c> call on <paramref name="readResult"/> into byte counts.
    /// It updates <c>_examinedUnconsumedBytes</c> and <c>_alreadyTimedBytes</c>.
    /// </summary>
    /// <returns>
    /// The number of bytes examined for the first time: this call's examined length minus the bytes
    /// examined-but-unconsumed by the previous call. This equals the change in the absolute examined position,
    /// so it is negative when the caller examines less than it had previously examined.
    /// </returns>
    protected long TrackConsumedAndExaminedBytes(ReadResult readResult, SequencePosition consumed, SequencePosition examined)
    {
        // Worked example (lengths are relative to the start of each ReadResult):
        //   1. ReadAsync returns 50 bytes; _examinedUnconsumedBytes == 0.
        //      AdvanceTo(consumed: 25, examined: 40)
        //        returns 40 - 0 = 40; _examinedUnconsumedBytes = 40 - 25 = 15; _alreadyTimedBytes = 50 - 25 = 25.
        //   2. ReadAsync returns 50 bytes (presumably the 25 unconsumed bytes plus 25 new ones).
        //      AdvanceTo(consumed: 5, examined: 5)
        //        returns 5 - 15 = -10; _examinedUnconsumedBytes = 5 - 5 = 0; _alreadyTimedBytes = 50 - 5 = 45.
        //   3. ReadAsync returns the remaining 45 bytes.
        //      AdvanceTo(consumed: 45, examined: 45)
        //        returns 45 - 0 = 45; _examinedUnconsumedBytes = 0; _alreadyTimedBytes = 45 - 45 = 0.
        //   The returned values sum to 40 - 10 + 45 = 75, the total number of bytes consumed.
        long examinedLength, consumedLength, totalLength;

        if (consumed.Equals(examined))
        {
            examinedLength = readResult.Buffer.Slice(readResult.Buffer.Start, examined).Length;
            consumedLength = examinedLength;
        }
        else
        {
            consumedLength = readResult.Buffer.Slice(readResult.Buffer.Start, consumed).Length;
            examinedLength = consumedLength + readResult.Buffer.Slice(consumed, examined).Length;
        }

        if (examined.Equals(readResult.Buffer.End))
        {
            totalLength = examinedLength;
        }
        else
        {
            totalLength = readResult.Buffer.Length;
        }

        var newlyExaminedBytes = examinedLength - _examinedUnconsumedBytes;

        // Same value as the previous "+= newlyExaminedBytes - consumedLength" form, which reduces algebraically to this.
        _examinedUnconsumedBytes = examinedLength - consumedLength;
        _alreadyTimedBytes = totalLength - consumedLength;

        return newlyExaminedBytes;
    }
}
|