File: ConcurrentPipeWriterTests.cs
Web Access
Project: src\src\Servers\Kestrel\Core\test\Microsoft.AspNetCore.Server.Kestrel.Core.Tests.csproj (Microsoft.AspNetCore.Server.Kestrel.Core.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Buffers;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure.PipeWriterHelpers;
using Microsoft.AspNetCore.InternalTesting;
using Xunit;
 
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests;
 
public class ConcurrentPipeWriterTests
{
    // Scenario: every flush is awaited before the next write starts.
    // Expectation: each GetMemory/Advance/FlushAsync call made on the ConcurrentPipeWriter is
    // immediately visible on the inner writer's call counters, and the exception passed to
    // CompleteAsync is the one the inner writer receives.
    [Fact]
    public async Task PassthroughIfAllFlushesAreAwaited()
    {
        using var memoryPool = new PinnedBlockMemoryPool();
        using var diagnosticPool = new DiagnosticMemoryPool(memoryPool);

        // One completion source per inner flush the test expects to happen.
        var innerFlushes = new TaskCompletionSource<FlushResult>[]
        {
            new(TaskCreationOptions.RunContinuationsAsynchronously),
            new(TaskCreationOptions.RunContinuationsAsynchronously),
        };

        var gate = new object();
        var innerWriter = new MockPipeWriter(innerFlushes);
        var writer = new ConcurrentPipeWriter(innerWriter, diagnosticPool, gate);

        // Two identical write/flush rounds. In round N the inner writer must have seen exactly N
        // calls of each kind, i.e. nothing was held back in a queue.
        for (var round = 1; round <= innerFlushes.Length; round++)
        {
            ValueTask<FlushResult> pendingFlush;

            lock (gate)
            {
                var block = writer.GetMemory();
                Assert.Equal(round, innerWriter.GetMemoryCallCount);

                writer.Advance(block.Length);
                Assert.Equal(round, innerWriter.AdvanceCallCount);

                pendingFlush = writer.FlushAsync();
                Assert.Equal(round, innerWriter.FlushCallCount);

                // Let the inner flush of this round finish.
                innerFlushes[round - 1].SetResult(default);
            }

            // The outer flush is awaited outside the lock before the next round begins.
            await pendingFlush.DefaultTimeout();
        }

        var completeEx = new Exception();
        ValueTask completion;

        lock (gate)
        {
            completion = writer.CompleteAsync(completeEx);
        }

        await completion.DefaultTimeout();

        Assert.Same(completeEx, innerWriter.CompleteException);
    }
 
    // Scenario: further writes and flushes are issued while an earlier flush is still pending.
    // Expectation: those calls do not reach the inner writer until the pending inner flush
    // completes, the outer flush tasks stay incomplete until the last inner flush is done, and the
    // inner writer is only completed after that.
    [Fact]
    public async Task QueuesIfFlushIsNotAwaited()
    {
        using var memoryPool = new PinnedBlockMemoryPool();
        using var diagnosticPool = new DiagnosticMemoryPool(memoryPool);

        // One completion source per inner flush the test expects to happen.
        var innerFlushes = new TaskCompletionSource<FlushResult>[]
        {
            new(TaskCreationOptions.RunContinuationsAsynchronously),
            new(TaskCreationOptions.RunContinuationsAsynchronously),
            new(TaskCreationOptions.RunContinuationsAsynchronously),
        };

        var gate = new object();
        var innerWriter = new MockPipeWriter(innerFlushes);
        var writer = new ConcurrentPipeWriter(innerWriter, diagnosticPool, gate);

        ValueTask<FlushResult> firstFlush;
        ValueTask<FlushResult> secondFlush;
        ValueTask completion;

        // Checks the three inner call counters in one go.
        void AssertInnerCallCounts(int getMemory, int advance, int flush)
        {
            Assert.Equal(getMemory, innerWriter.GetMemoryCallCount);
            Assert.Equal(advance, innerWriter.AdvanceCallCount);
            Assert.Equal(flush, innerWriter.FlushCallCount);
        }

        // Completes the inner flush at the given index and waits until the mock reports that
        // FlushAsync has been invoked on it again.
        async Task ReleaseInnerFlushAndAwaitNextAsync(int index)
        {
            var nextFlushStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
            innerWriter.FlushTcs = nextFlushStarted;
            innerFlushes[index].SetResult(default);

            await nextFlushStarted.Task.DefaultTimeout();
        }

        lock (gate)
        {
            var block = writer.GetMemory();
            Assert.Equal(1, innerWriter.GetMemoryCallCount);

            writer.Advance(block.Length);
            Assert.Equal(1, innerWriter.AdvanceCallCount);

            firstFlush = writer.FlushAsync();
            Assert.Equal(1, innerWriter.FlushCallCount);

            Assert.False(firstFlush.IsCompleted);

            // The first flush has not been awaited, so these calls are expected to be queued
            // instead of reaching the inner writer.
            block = writer.GetMemory();
            writer.Advance(block.Length);
            secondFlush = writer.FlushAsync();
        }

        AssertInnerCallCounts(getMemory: 1, advance: 1, flush: 1);

        Assert.False(firstFlush.IsCompleted);
        Assert.False(secondFlush.IsCompleted);

        await ReleaseInnerFlushAndAwaitNextAsync(0);

        lock (gate)
        {
            // A flush is still outstanding, so this write is expected to be queued as well.
            var block = writer.GetMemory();
            writer.Advance(block.Length);
        }

        // No explicit flush is issued for these last bytes; the still-incomplete flush is
        // expected to pick them up.
        AssertInnerCallCounts(getMemory: 2, advance: 2, flush: 2);

        Assert.False(firstFlush.IsCompleted);
        Assert.False(secondFlush.IsCompleted);

        await ReleaseInnerFlushAndAwaitNextAsync(1);

        // FlushAsync was called only twice on the ConcurrentPipeWriter, yet the inner writer has
        // now been flushed three times.
        AssertInnerCallCounts(getMemory: 3, advance: 3, flush: 3);

        Assert.False(firstFlush.IsCompleted);
        Assert.False(secondFlush.IsCompleted);

        var completeEx = new Exception();
        lock (gate)
        {
            completion = writer.CompleteAsync(completeEx);
        }

        await completion.DefaultTimeout();

        // The inner writer must not be completed while an inner flush is still outstanding.
        Assert.Null(innerWriter.CompleteException);

        innerFlushes[2].SetResult(default);

        await firstFlush.DefaultTimeout();
        await secondFlush.DefaultTimeout();

        Assert.Same(completeEx, innerWriter.CompleteException);
    }
 
    // Scenario: the inner flush completes after GetMemory() was called on the
    // ConcurrentPipeWriter but before the matching Advance().
    // Expectation: the outer flush completes, the half-finished write stays buffered, and the
    // next outer flush forwards the buffered data to the inner writer.
    [Fact]
    public async Task KeepsQueueIfInnerFlushFinishesBetweenGetMemoryAndAdvance()
    {
        using var memoryPool = new PinnedBlockMemoryPool();
        using var diagnosticPool = new DiagnosticMemoryPool(memoryPool);

        // One completion source per inner flush the test expects to happen.
        var innerFlushes = new TaskCompletionSource<FlushResult>[]
        {
            new(TaskCreationOptions.RunContinuationsAsynchronously),
            new(TaskCreationOptions.RunContinuationsAsynchronously),
        };

        var gate = new object();
        var innerWriter = new MockPipeWriter(innerFlushes);
        var writer = new ConcurrentPipeWriter(innerWriter, diagnosticPool, gate);

        Memory<byte> block;
        ValueTask<FlushResult> firstFlush;
        ValueTask<FlushResult> secondFlush;
        ValueTask completion;

        // Checks the three inner call counters in one go.
        void AssertInnerCallCounts(int getMemory, int advance, int flush)
        {
            Assert.Equal(getMemory, innerWriter.GetMemoryCallCount);
            Assert.Equal(advance, innerWriter.AdvanceCallCount);
            Assert.Equal(flush, innerWriter.FlushCallCount);
        }

        lock (gate)
        {
            block = writer.GetMemory();
            Assert.Equal(1, innerWriter.GetMemoryCallCount);

            writer.Advance(block.Length);
            Assert.Equal(1, innerWriter.AdvanceCallCount);

            firstFlush = writer.FlushAsync();
            Assert.Equal(1, innerWriter.FlushCallCount);
            Assert.False(firstFlush.IsCompleted);

            // GetMemory() has been called, but Advance() has not been called yet at the point
            // where the first inner flush completes.
            block = writer.GetMemory();
        }

        // The inner flush finishes between GetMemory() and Advance(): the outer flush is expected
        // to complete and the buffered data is left for the next flush.
        innerFlushes[0].SetResult(default);

        await firstFlush.DefaultTimeout();

        AssertInnerCallCounts(getMemory: 1, advance: 1, flush: 1);

        lock (gate)
        {
            writer.Advance(block.Length);
            block = writer.GetMemory();
            writer.Advance(block.Length);

            secondFlush = writer.FlushAsync();
        }

        // Flushing the ConcurrentPipeWriter again replays the buffered GetMemory()/Advance() calls.
        // MockPipeWriter.PinnedBlockMemoryPoolBlockSize has to match PinnedBlockMemoryPool._blockSize,
        // otherwise copying the data may take a different number of inner GetMemory() calls.
        AssertInnerCallCounts(getMemory: 3, advance: 3, flush: 2);
        Assert.False(secondFlush.IsCompleted);

        innerFlushes[1].SetResult(default);

        await secondFlush.DefaultTimeout();

        // Completing the second inner flush causes no further inner calls: two outer flushes
        // resulted in exactly two inner flushes.
        AssertInnerCallCounts(getMemory: 3, advance: 3, flush: 2);

        var completeEx = new Exception();

        lock (gate)
        {
            completion = writer.CompleteAsync(completeEx);
        }

        await completion.DefaultTimeout();

        Assert.Same(completeEx, innerWriter.CompleteException);
    }
 
    // Scenario: data is still buffered in the ConcurrentPipeWriter when CompleteAsync is called,
    // and no further FlushAsync is issued.
    // Expectation: completing replays the buffered GetMemory()/Advance() calls on the inner writer
    // without an additional inner flush, and forwards the completion exception.
    [Fact]
    public async Task CompleteFlushesQueuedBytes()
    {
        using var memoryPool = new PinnedBlockMemoryPool();
        using var diagnosticPool = new DiagnosticMemoryPool(memoryPool);

        // Only the first completion source is used; the test expects a single inner flush.
        var innerFlushes = new TaskCompletionSource<FlushResult>[]
        {
            new(TaskCreationOptions.RunContinuationsAsynchronously),
            new(TaskCreationOptions.RunContinuationsAsynchronously),
        };

        var gate = new object();
        var innerWriter = new MockPipeWriter(innerFlushes);
        var writer = new ConcurrentPipeWriter(innerWriter, diagnosticPool, gate);

        Memory<byte> block;
        ValueTask<FlushResult> firstFlush;
        ValueTask completion;

        // Checks the three inner call counters in one go.
        void AssertInnerCallCounts(int getMemory, int advance, int flush)
        {
            Assert.Equal(getMemory, innerWriter.GetMemoryCallCount);
            Assert.Equal(advance, innerWriter.AdvanceCallCount);
            Assert.Equal(flush, innerWriter.FlushCallCount);
        }

        lock (gate)
        {
            block = writer.GetMemory();
            Assert.Equal(1, innerWriter.GetMemoryCallCount);

            writer.Advance(block.Length);
            Assert.Equal(1, innerWriter.AdvanceCallCount);

            firstFlush = writer.FlushAsync();
            Assert.Equal(1, innerWriter.FlushCallCount);
            Assert.False(firstFlush.IsCompleted);

            // GetMemory() has been called, but Advance() has not been called yet at the point
            // where the first inner flush completes.
            block = writer.GetMemory();
        }

        // The inner flush finishes between GetMemory() and Advance(): the outer flush is expected
        // to complete and the buffered data is left for a later flush or completion.
        innerFlushes[0].SetResult(default);

        await firstFlush.DefaultTimeout();

        AssertInnerCallCounts(getMemory: 1, advance: 1, flush: 1);

        var completeEx = new Exception();

        lock (gate)
        {
            writer.Advance(block.Length);
            block = writer.GetMemory();
            writer.Advance(block.Length);

            // Complete the ConcurrentPipeWriter without flushing any of the buffered data.
            completion = writer.CompleteAsync(completeEx);
        }

        await completion.DefaultTimeout();

        // Completing the ConcurrentPipeWriter replays the buffered GetMemory()/Advance() calls.
        // MockPipeWriter.PinnedBlockMemoryPoolBlockSize has to match PinnedBlockMemoryPool._blockSize,
        // otherwise copying the data may take a different number of inner GetMemory() calls.
        AssertInnerCallCounts(getMemory: 3, advance: 3, flush: 1);
        Assert.Same(completeEx, innerWriter.CompleteException);
    }
 
    // Scenario: CancelPendingFlush() is called while one flush is pending on the inner writer and
    // a second write/flush is queued behind it.
    // Expectation: the cancel call reaches the inner writer immediately, both outstanding outer
    // flushes report IsCanceled once the inner flush returns a canceled result, and a later flush
    // forwards the write that was queued.
    [Fact]
    public async Task CancelPendingFlushInterruptsFlushLoop()
    {
        using var memoryPool = new PinnedBlockMemoryPool();
        using var diagnosticPool = new DiagnosticMemoryPool(memoryPool);

        // One completion source per inner flush the test expects to happen.
        var innerFlushes = new TaskCompletionSource<FlushResult>[]
        {
            new(TaskCreationOptions.RunContinuationsAsynchronously),
            new(TaskCreationOptions.RunContinuationsAsynchronously),
        };

        var gate = new object();
        var innerWriter = new MockPipeWriter(innerFlushes);
        var writer = new ConcurrentPipeWriter(innerWriter, diagnosticPool, gate);

        ValueTask<FlushResult> firstFlush;
        ValueTask<FlushResult> secondFlush;
        ValueTask<FlushResult> thirdFlush;
        ValueTask completion;

        // Checks the three inner call counters in one go.
        void AssertInnerCallCounts(int getMemory, int advance, int flush)
        {
            Assert.Equal(getMemory, innerWriter.GetMemoryCallCount);
            Assert.Equal(advance, innerWriter.AdvanceCallCount);
            Assert.Equal(flush, innerWriter.FlushCallCount);
        }

        lock (gate)
        {
            var block = writer.GetMemory();
            Assert.Equal(1, innerWriter.GetMemoryCallCount);

            writer.Advance(block.Length);
            Assert.Equal(1, innerWriter.AdvanceCallCount);

            firstFlush = writer.FlushAsync();
            Assert.Equal(1, innerWriter.FlushCallCount);

            Assert.False(firstFlush.IsCompleted);

            // The first flush has not been awaited, so these calls are expected to be queued
            // instead of reaching the inner writer.
            block = writer.GetMemory();
            writer.Advance(block.Length);
            secondFlush = writer.FlushAsync();

            AssertInnerCallCounts(getMemory: 1, advance: 1, flush: 1);

            Assert.False(firstFlush.IsCompleted);
            Assert.False(secondFlush.IsCompleted);

            // Unlike the write and flush above, CancelPendingFlush() is not queued: the inner
            // writer sees it right away.
            writer.CancelPendingFlush();
            Assert.Equal(1, innerWriter.CancelPendingFlushCallCount);
        }

        // The inner flush reports cancellation; both outer flushes must surface it.
        innerFlushes[0].SetResult(new FlushResult(isCanceled: true, isCompleted: false));

        var firstResult = await firstFlush.DefaultTimeout();
        Assert.True(firstResult.IsCanceled);

        var secondResult = await secondFlush.DefaultTimeout();
        Assert.True(secondResult.IsCanceled);

        lock (gate)
        {
            thirdFlush = writer.FlushAsync();
        }

        Assert.False(thirdFlush.IsCompleted);

        innerFlushes[1].SetResult(default);

        await thirdFlush.DefaultTimeout();

        // After the third outer flush each inner counter is at 2: the write that was queued
        // behind the canceled flush has now been forwarded with a single additional inner flush.
        AssertInnerCallCounts(getMemory: 2, advance: 2, flush: 2);

        var completeEx = new Exception();

        lock (gate)
        {
            completion = writer.CompleteAsync(completeEx);
        }

        await completion.DefaultTimeout();

        Assert.Same(completeEx, innerWriter.CompleteException);
    }
 
    // PipeWriter test double. It counts how often each member is called, hands out freshly
    // allocated buffers, and lets the test decide when each flush completes through the
    // TaskCompletionSource array passed to the constructor.
    private class MockPipeWriter : PipeWriter
    {
        // It's important that this matches PinnedBlockMemoryPool._blockSize for all the tests to pass.
        private const int PinnedBlockMemoryPoolBlockSize = 4096;

        // The Nth call to FlushAsync returns the task of the Nth entry.
        private readonly TaskCompletionSource<FlushResult>[] _flushResults;

        public MockPipeWriter(TaskCompletionSource<FlushResult>[] flushResults)
            => _flushResults = flushResults;

        // Number of calls observed per member.
        public int GetMemoryCallCount { get; set; }
        public int AdvanceCallCount { get; set; }
        public int FlushCallCount { get; set; }
        public int CancelPendingFlushCallCount { get; set; }

        // Optional signal: when set, it is completed every time FlushAsync is invoked.
        public TaskCompletionSource FlushTcs { get; set; }

        // The exception most recently passed to Complete (null if none was passed).
        public Exception CompleteException { get; set; }

        // Counts the call, signals FlushTcs if one is set, and returns the task that belongs to
        // this call's position in the flush-result array.
        public override ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = default)
        {
            var flushIndex = FlushCallCount;
            FlushCallCount = flushIndex + 1;

            FlushTcs?.TrySetResult();

            return new ValueTask<FlushResult>(_flushResults[flushIndex].Task);
        }

        // Counts the call and returns a new array-backed buffer; a size hint of 0 yields a buffer
        // of PinnedBlockMemoryPoolBlockSize bytes.
        public override Memory<byte> GetMemory(int sizeHint = 0)
        {
            GetMemoryCallCount++;

            int length;
            if (sizeHint == 0)
            {
                length = PinnedBlockMemoryPoolBlockSize;
            }
            else
            {
                length = sizeHint;
            }

            return new Memory<byte>(new byte[length]);
        }

        // Delegates to GetMemory, so it also increments GetMemoryCallCount.
        public override Span<byte> GetSpan(int sizeHint = 0) => GetMemory(sizeHint).Span;

        // Only counts the call; the byte count is ignored.
        public override void Advance(int bytes) => AdvanceCallCount++;

        // Records the exception so tests can assert on it.
        public override void Complete(Exception exception = null) => CompleteException = exception;

        // Only counts the call.
        public override void CancelPendingFlush() => CancelPendingFlushCallCount++;
    }
}