File: Streams\OutputCacheStream.cs
Web Access
Project: src\src\Middleware\OutputCaching\src\Microsoft.AspNetCore.OutputCaching.csproj (Microsoft.AspNetCore.OutputCaching)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
 
namespace Microsoft.AspNetCore.OutputCaching;
 
/// <summary>
/// A write-through <see cref="Stream"/> wrapper: every write is forwarded to the inner stream and,
/// while <see cref="BufferingEnabled"/> is <see langword="true"/>, is also copied into a segment
/// builder so that the written bytes can be retrieved later via <see cref="GetCachedResponseBody"/>.
/// </summary>
/// <remarks>
/// Buffering is turned off (and never turned back on) when the inner stream throws from a write or
/// flush, when the buffered size would exceed the configured maximum, or when the caller seeks,
/// sets the position, or sets the length.
/// </remarks>
internal sealed class OutputCacheStream : Stream
{
    private readonly Stream _innerStream;
    private readonly long _maxBufferSize;
    private readonly int _segmentSize;
    private readonly RecyclableSequenceBuilder _segmentWriteStream;
    private readonly Action _startResponseCallback;

    /// <summary>
    /// Creates a stream that forwards to <paramref name="innerStream"/> while buffering a copy of the written data.
    /// </summary>
    /// <param name="innerStream">The stream that receives every write, flush, read and seek.</param>
    /// <param name="maxBufferSize">The largest number of bytes to keep buffered; exceeding it disables buffering.</param>
    /// <param name="segmentSize">The segment size handed to the underlying segment builder.</param>
    /// <param name="startResponseCallback">Invoked before each write or flush is forwarded to the inner stream.</param>
    internal OutputCacheStream(Stream innerStream, long maxBufferSize, int segmentSize, Action startResponseCallback)
    {
        _innerStream = innerStream;
        _maxBufferSize = maxBufferSize;
        _segmentSize = segmentSize;
        _startResponseCallback = startResponseCallback;
        _segmentWriteStream = new(_segmentSize);
    }

    /// <summary>
    /// Gets whether written data is still being copied into the buffer.
    /// Starts as <see langword="true"/> and only ever transitions to <see langword="false"/>.
    /// </summary>
    internal bool BufferingEnabled { get; private set; } = true;

    public override bool CanRead => _innerStream.CanRead;

    public override bool CanSeek => _innerStream.CanSeek;

    public override bool CanWrite => _innerStream.CanWrite;

    public override long Length => _innerStream.Length;

    /// <summary>
    /// Gets or sets the inner stream's position. Setting it disables buffering first.
    /// </summary>
    public override long Position
    {
        get { return _innerStream.Position; }
        set
        {
            DisableBuffering();
            _innerStream.Position = value;
        }
    }

    /// <summary>
    /// Detaches and returns the data buffered so far.
    /// </summary>
    /// <returns>The sequence produced by detaching the segment builder's contents.</returns>
    /// <exception cref="InvalidOperationException">Buffering has been disabled.</exception>
    internal ReadOnlySequence<byte> GetCachedResponseBody()
    {
        if (!BufferingEnabled)
        {
            throw new InvalidOperationException("Buffer stream cannot be retrieved since buffering is disabled.");
        }
        return _segmentWriteStream.DetachAndReset();
    }

    /// <summary>
    /// Stops buffering and disposes the segment builder. May be called repeatedly.
    /// </summary>
    internal void DisableBuffering()
    {
        BufferingEnabled = false;
        _segmentWriteStream.Dispose();
    }

    /// <summary>
    /// Disables buffering, then sets the length of the inner stream.
    /// </summary>
    public override void SetLength(long value)
    {
        DisableBuffering();
        _innerStream.SetLength(value);
    }

    /// <summary>
    /// Disables buffering, then seeks the inner stream.
    /// </summary>
    public override long Seek(long offset, SeekOrigin origin)
    {
        DisableBuffering();
        return _innerStream.Seek(offset, origin);
    }

    /// <summary>
    /// Invokes the start-response callback and flushes the inner stream.
    /// Any exception disables buffering before it is rethrown.
    /// </summary>
    public override void Flush()
    {
        try
        {
            _startResponseCallback();
            _innerStream.Flush();
        }
        catch
        {
            DisableBuffering();
            throw;
        }
    }

    /// <summary>
    /// Invokes the start-response callback and asynchronously flushes the inner stream.
    /// Any exception disables buffering before it is rethrown.
    /// </summary>
    public override async Task FlushAsync(CancellationToken cancellationToken)
    {
        try
        {
            _startResponseCallback();
            await _innerStream.FlushAsync(cancellationToken);
        }
        catch
        {
            DisableBuffering();
            throw;
        }
    }

    // Underlying stream is write-only, no need to override other read related methods
    public override int Read(byte[] buffer, int offset, int count)
        => _innerStream.Read(buffer, offset, count);

    /// <summary>
    /// Writes to the inner stream and, if buffering is still enabled and the data fits, to the buffer.
    /// </summary>
    public override void Write(byte[] buffer, int offset, int count)
    {
        try
        {
            _startResponseCallback();
            _innerStream.Write(buffer, offset, count);
        }
        catch
        {
            DisableBuffering();
            throw;
        }

        if (EnsureBufferCapacity(count))
        {
            _segmentWriteStream.Write(buffer.AsSpan(offset, count));
        }
    }

    /// <summary>
    /// Writes to the inner stream and, if buffering is still enabled and the data fits, to the buffer.
    /// </summary>
    public override void Write(ReadOnlySpan<byte> buffer)
    {
        try
        {
            _startResponseCallback();
            _innerStream.Write(buffer);
        }
        catch
        {
            DisableBuffering();
            throw;
        }

        if (EnsureBufferCapacity(buffer.Length))
        {
            _segmentWriteStream.Write(buffer);
        }
    }

    /// <summary>
    /// Array-based overload; forwards to the <see cref="ReadOnlyMemory{T}"/> overload.
    /// </summary>
    public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
        await WriteAsync(buffer.AsMemory(offset, count), cancellationToken);

    /// <summary>
    /// Asynchronously writes to the inner stream and, if buffering is still enabled and the data fits,
    /// copies the data to the buffer once the inner write has completed.
    /// </summary>
    public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
    {
        try
        {
            _startResponseCallback();
            await _innerStream.WriteAsync(buffer, cancellationToken);
        }
        catch
        {
            DisableBuffering();
            throw;
        }

        if (EnsureBufferCapacity(buffer.Length))
        {
            _segmentWriteStream.Write(buffer.Span);
        }
    }

    /// <summary>
    /// Writes a single byte to the inner stream and, if buffering is still enabled and it fits, to the buffer.
    /// </summary>
    public override void WriteByte(byte value)
    {
        try
        {
            // Signal the start of the response like every other write path does. The base
            // Stream.WriteByte would have gone through Write(byte[], int, int), which invokes
            // the callback; this override must not lose that behavior.
            _startResponseCallback();
            _innerStream.WriteByte(value);
        }
        catch
        {
            DisableBuffering();
            throw;
        }

        if (EnsureBufferCapacity(1))
        {
            _segmentWriteStream.WriteByte(value);
        }
    }

    /// <summary>
    /// Disposes the segment builder when called from <see cref="Stream.Dispose()"/>.
    /// The inner stream is not disposed here.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            _segmentWriteStream?.Dispose();
        }
        base.Dispose(disposing);
    }

    public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
        => TaskToApm.Begin(WriteAsync(buffer, offset, count, CancellationToken.None), callback, state);

    public override void EndWrite(IAsyncResult asyncResult)
        => TaskToApm.End(asyncResult);

    /// <summary>
    /// Checks whether <paramref name="count"/> additional bytes may be copied into the buffer.
    /// If buffering is enabled but the bytes would push the buffered length past the configured
    /// maximum, buffering is disabled as a side effect.
    /// </summary>
    /// <param name="count">The number of bytes about to be buffered.</param>
    /// <returns><see langword="true"/> if the caller should copy the bytes into the buffer; otherwise <see langword="false"/>.</returns>
    private bool EnsureBufferCapacity(int count)
    {
        if (!BufferingEnabled)
        {
            return false;
        }

        if (_segmentWriteStream.Length + count > _maxBufferSize)
        {
            DisableBuffering();
            return false;
        }

        return true;
    }
}