// File: FileBufferingWriteStream.cs
// Web Access
// Project: src\src\Http\WebUtilities\src\Microsoft.AspNetCore.WebUtilities.csproj (Microsoft.AspNetCore.WebUtilities)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO.Pipelines;
using System.Runtime.InteropServices;
using Microsoft.AspNetCore.Internal;
 
namespace Microsoft.AspNetCore.WebUtilities;
 
/// <summary>
/// A write-only <see cref="Stream"/> that buffers written content in memory and, once
/// <see cref="MemoryThreshold"/> would be exceeded, spools it to a temporary file on disk.
/// Use <see cref="DrainBufferAsync(Stream, CancellationToken)" /> to write buffered content to a target <see cref="Stream" />.
/// </summary>
public sealed class FileBufferingWriteStream : Stream
{
    private const int DefaultMemoryThreshold = 32 * 1024; // 32k

    private readonly int _memoryThreshold;
    private readonly long? _bufferLimit;
    private readonly Func<string> _tempFileDirectoryAccessor;

    /// <summary>
    /// Initializes a new instance of <see cref="FileBufferingWriteStream"/>.
    /// </summary>
    /// <param name="memoryThreshold">
    /// The maximum amount of memory in bytes to allocate before switching to a file on disk.
    /// Defaults to 32kb.
    /// </param>
    /// <param name="bufferLimit">
    /// The maximum amount of bytes that the <see cref="FileBufferingWriteStream"/> is allowed to buffer.
    /// </param>
    /// <param name="tempFileDirectoryAccessor">Provides the location of the directory to write buffered contents to.
    /// When unspecified, uses the value specified by the environment variable <c>ASPNETCORE_TEMP</c> if available, otherwise
    /// uses the value returned by <see cref="Path.GetTempPath"/>.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="memoryThreshold"/> is negative, or <paramref name="bufferLimit"/> is specified and is
    /// smaller than <paramref name="memoryThreshold"/>.
    /// </exception>
    public FileBufferingWriteStream(
        int memoryThreshold = DefaultMemoryThreshold,
        long? bufferLimit = null,
        Func<string>? tempFileDirectoryAccessor = null)
    {
        ArgumentOutOfRangeException.ThrowIfNegative(memoryThreshold);

        if (bufferLimit != null && bufferLimit < memoryThreshold)
        {
            // We would expect a limit at least as much as memoryThreshold
            throw new ArgumentOutOfRangeException(nameof(bufferLimit), $"{nameof(bufferLimit)} must be larger than {nameof(memoryThreshold)}.");
        }

        _memoryThreshold = memoryThreshold;
        _bufferLimit = bufferLimit;
        _tempFileDirectoryAccessor = tempFileDirectoryAccessor ?? AspNetCoreTempDirectory.TempDirectoryFactory;
        PagedByteBuffer = new PagedByteBuffer(ArrayPool<byte>.Shared);
    }

    /// <summary>
    /// The maximum amount of memory in bytes to allocate before switching to a file on disk.
    /// </summary>
    /// <remarks>
    /// Defaults to 32kb.
    /// </remarks>
    public int MemoryThreshold => _memoryThreshold;

    /// <inheritdoc />
    public override bool CanRead => false;

    /// <inheritdoc />
    public override bool CanSeek => false;

    /// <inheritdoc />
    public override bool CanWrite => true;

    /// <inheritdoc />
    /// <remarks>The total number of bytes currently buffered, in memory plus in the spooled temp file.</remarks>
    public override long Length => PagedByteBuffer.Length + (FileStream?.Length ?? 0);

    /// <inheritdoc />
    public override long Position
    {
        get => throw new NotSupportedException();
        set => throw new NotSupportedException();
    }

    internal PagedByteBuffer PagedByteBuffer { get; }

    internal FileStream? FileStream { get; private set; }

    internal bool Disposed { get; private set; }

    /// <inheritdoc />
    public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();

    /// <inheritdoc />
    public override int Read(byte[] buffer, int offset, int count)
        => throw new NotSupportedException();

    /// <inheritdoc />
    public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        => throw new NotSupportedException();

    /// <inheritdoc/>
    public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken)
        => throw new NotSupportedException();

    /// <inheritdoc />
    /// <exception cref="IOException">
    /// Writing <paramref name="count"/> more bytes would exceed the configured buffer limit.
    /// The stream is disposed before this exception is thrown.
    /// </exception>
    public override void Write(byte[] buffer, int offset, int count)
    {
        ValidateBufferArguments(buffer, offset, count);
        ThrowIfDisposed();

        if (_bufferLimit.HasValue && _bufferLimit - Length < count)
        {
            Dispose();
            throw new IOException("Buffer limit exceeded.");
        }

        // Allow buffering in memory if we're below the memory threshold once the current buffer is written.
        var allowMemoryBuffer = (_memoryThreshold - count) >= PagedByteBuffer.Length;
        if (allowMemoryBuffer)
        {
            // Buffer content in the MemoryStream if it has capacity.
            PagedByteBuffer.Add(buffer, offset, count);
            Debug.Assert(PagedByteBuffer.Length <= _memoryThreshold);
        }
        else
        {
            // If the MemoryStream is incapable of accommodating the content to be written
            // spool to disk.
            EnsureFileStream();

            // Spool memory content to disk first so the file keeps the original write order.
            PagedByteBuffer.MoveTo(FileStream);

            FileStream.Write(buffer, offset, count);
        }
    }

    /// <inheritdoc />
    public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        // Validate up front, matching the synchronous Write overload, so callers get the standard
        // Stream argument exceptions (e.g. ArgumentNullException for a null buffer).
        ValidateBufferArguments(buffer, offset, count);

        await WriteAsync(buffer.AsMemory(offset, count), cancellationToken);
    }

    /// <inheritdoc />
    /// <exception cref="IOException">
    /// Writing <paramref name="buffer"/> would exceed the configured buffer limit.
    /// The stream is disposed before this exception is thrown.
    /// </exception>
    [SuppressMessage("ApiDesign", "RS0027:Public API with optional parameter(s) should have the most parameters amongst its public overloads", Justification = "This is a method overload.")]
    public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
    {
        ThrowIfDisposed();

        if (_bufferLimit.HasValue && _bufferLimit - Length < buffer.Length)
        {
            Dispose();
            throw new IOException("Buffer limit exceeded.");
        }

        // Allow buffering in memory if we're below the memory threshold once the current buffer is written.
        var allowMemoryBuffer = (_memoryThreshold - buffer.Length) >= PagedByteBuffer.Length;
        if (allowMemoryBuffer)
        {
            // Buffer content in the MemoryStream if it has capacity.
            PagedByteBuffer.Add(buffer);
            Debug.Assert(PagedByteBuffer.Length <= _memoryThreshold);
        }
        else
        {
            // If the MemoryStream is incapable of accommodating the content to be written
            // spool to disk.
            EnsureFileStream();

            // Spool memory content to disk first so the file keeps the original write order.
            await PagedByteBuffer.MoveToAsync(FileStream, cancellationToken);
            await FileStream.WriteAsync(buffer, cancellationToken);
        }
    }

    /// <inheritdoc />
    public override void Flush()
    {
        // Do nothing: content is only emitted by DrainBufferAsync.
    }

    /// <inheritdoc />
    public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;

    /// <inheritdoc />
    public override void SetLength(long value) => throw new NotSupportedException();

    /// <summary>
    /// Drains buffered content to <paramref name="destination"/>.
    /// </summary>
    /// <remarks>
    /// Content spooled to disk is copied first, followed by content still held in memory.
    /// On completion the temp file is closed; it was opened with <see cref="FileOptions.DeleteOnClose"/>.
    /// </remarks>
    /// <param name="destination">The <see cref="Stream" /> to drain buffered contents to.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken" />.</param>
    /// <returns>A <see cref="Task" /> that represents the asynchronous drain operation.</returns>
    /// <exception cref="ObjectDisposedException">This stream has been disposed.</exception>
    public async Task DrainBufferAsync(Stream destination, CancellationToken cancellationToken = default)
    {
        // Fail fast with a clear error instead of failing to reopen a temp file that was deleted on close.
        ThrowIfDisposed();

        // When not null, FileStream always has "older" spooled content. The PagedByteBuffer always has "newer"
        // unspooled content. Copy the FileStream content first when available.
        if (FileStream != null)
        {
            // We make a new stream for async reads from disk and async writes to the destination
            await using var readStream = new FileStream(FileStream.Name, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.ReadWrite, bufferSize: 1, useAsync: true);

            await readStream.CopyToAsync(destination, cancellationToken);

            // This is created with delete on close
            await FileStream.DisposeAsync();
            FileStream = null;
        }

        await PagedByteBuffer.MoveToAsync(destination, cancellationToken);
    }

    /// <summary>
    /// Drains buffered content to <paramref name="destination"/>.
    /// </summary>
    /// <remarks>
    /// Content spooled to disk is copied first, followed by content still held in memory.
    /// On completion the temp file is closed; it was opened with <see cref="FileOptions.DeleteOnClose"/>.
    /// </remarks>
    /// <param name="destination">The <see cref="PipeWriter" /> to drain buffered contents to.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken" />.</param>
    /// <returns>A <see cref="Task" /> that represents the asynchronous drain operation.</returns>
    /// <exception cref="ObjectDisposedException">This stream has been disposed.</exception>
    [SuppressMessage("ApiDesign", "RS0026:Do not add multiple public overloads with optional parameters", Justification = "Required to maintain compatibility")]
    public async Task DrainBufferAsync(PipeWriter destination, CancellationToken cancellationToken = default)
    {
        // Fail fast with a clear error instead of failing to reopen a temp file that was deleted on close.
        ThrowIfDisposed();

        // When not null, FileStream always has "older" spooled content. The PagedByteBuffer always has "newer"
        // unspooled content. Copy the FileStream content first when available.
        if (FileStream != null)
        {
            // We make a new stream for async reads from disk and async writes to the destination
            await using var readStream = new FileStream(FileStream.Name, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.ReadWrite, bufferSize: 1, useAsync: true);

            await readStream.CopyToAsync(destination, cancellationToken);

            // This is created with delete on close
            await FileStream.DisposeAsync();
            FileStream = null;
        }

        await PagedByteBuffer.MoveToAsync(destination, cancellationToken);
    }

    /// <inheritdoc />
    protected override void Dispose(bool disposing)
    {
        if (!Disposed)
        {
            Disposed = true;

            PagedByteBuffer.Dispose();
            FileStream?.Dispose();
        }

        base.Dispose(disposing);
    }

    /// <inheritdoc />
    public override async ValueTask DisposeAsync()
    {
        if (!Disposed)
        {
            Disposed = true;

            PagedByteBuffer.Dispose();
            await (FileStream?.DisposeAsync() ?? default);
        }
    }

    /// <summary>
    /// Lazily creates the temp file that receives spooled content. The file is opened with
    /// <see cref="FileOptions.DeleteOnClose"/>. If creation fails, any file left behind is removed
    /// (best effort) before the original exception propagates.
    /// </summary>
    [MemberNotNull(nameof(FileStream))]
    private void EnsureFileStream()
    {
        if (FileStream != null)
        {
            return;
        }

        var tempFileDirectory = _tempFileDirectoryAccessor();
        var tempFileName = Path.Combine(tempFileDirectory, "ASPNETCORE_" + Guid.NewGuid() + ".tmp");

        // Create a temp file with the correct Unix file mode before moving it to the assigned tempFileName in the _tempFileDirectory.
        if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            var tempTempFileName = Path.GetTempFileName();
            try
            {
                File.Move(tempTempFileName, tempFileName);
            }
            catch
            {
                // Without this, the file created by GetTempFileName would be orphaned in the system temp directory.
                TryDeleteFile(tempTempFileName);
                throw;
            }
        }

        try
        {
            FileStream = new FileStream(
                tempFileName,
                FileMode.Create,
                FileAccess.Write,
                FileShare.Delete | FileShare.ReadWrite,
                bufferSize: 1,
                FileOptions.SequentialScan | FileOptions.DeleteOnClose);
        }
        catch
        {
            // DeleteOnClose only applies once the stream is open; clean up a pre-created/moved file ourselves.
            TryDeleteFile(tempFileName);
            throw;
        }
    }

    /// <summary>
    /// Deletes <paramref name="path"/>, ignoring I/O and permission failures so that cleanup
    /// never masks the exception that triggered it.
    /// </summary>
    private static void TryDeleteFile(string path)
    {
        try
        {
            File.Delete(path);
        }
        catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
        {
            // Best effort: the caller is already rethrowing a more relevant exception.
        }
    }

    private void ThrowIfDisposed()
    {
        ObjectDisposedException.ThrowIf(Disposed, this);
    }
}