// File: Internal\HttpContextStreamWriter.cs
// Web Access
// Project: src\src\Grpc\JsonTranscoding\src\Microsoft.AspNetCore.Grpc.JsonTranscoding\Microsoft.AspNetCore.Grpc.JsonTranscoding.csproj (Microsoft.AspNetCore.Grpc.JsonTranscoding)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Text.Json;
using Google.Api;
using Grpc.Core;
 
namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal;
 
/// <summary>
/// An <see cref="IServerStreamWriter{T}"/> that writes server-streamed responses to the HTTP response body.
/// <see cref="HttpBody"/> messages are written as raw data with their content type; other messages are
/// sent via <c>JsonRequestHelpers.SendMessage</c> using the configured <see cref="JsonSerializerOptions"/>.
/// Every message is followed by <c>GrpcProtocolConstants.StreamingDelimiter</c>.
/// </summary>
/// <typeparam name="TResponse">The response message type.</typeparam>
internal sealed class HttpContextStreamWriter<TResponse> : IServerStreamWriter<TResponse>
    where TResponse : class
{
    private readonly JsonTranscodingServerCallContext _context;
    private readonly JsonSerializerOptions _serializerOptions;

    // Guards reads/writes of _writeTask so overlapping writes can be detected and rejected.
    private readonly object _writeLock;

    // The most recently started write. Only read or replaced while holding _writeLock.
    private Task? _writeTask;

    // Set by Complete(); once true, further writes are rejected.
    private bool _completed;

    /// <summary>
    /// Creates a writer bound to a single transcoded call.
    /// </summary>
    /// <param name="context">The server call context whose HTTP response is written to.</param>
    /// <param name="serializerOptions">Options passed to the JSON message serializer.</param>
    public HttpContextStreamWriter(JsonTranscodingServerCallContext context, JsonSerializerOptions serializerOptions)
    {
        _context = context;
        _serializerOptions = serializerOptions;
        _writeLock = new object();
    }

    /// <summary>
    /// Gets or sets write options. Both the getter and setter forward to the call context's
    /// <c>WriteOptions</c>.
    /// </summary>
    public WriteOptions? WriteOptions
    {
        get => _context.WriteOptions;
        set => _context.WriteOptions = value;
    }

    /// <inheritdoc />
    Task IAsyncStreamWriter<TResponse>.WriteAsync(TResponse message, CancellationToken cancellationToken)
    {
        return WriteAsyncCore(message, cancellationToken);
    }

    /// <summary>
    /// Writes <paramref name="message"/> to the response without a caller-supplied cancellation token.
    /// </summary>
    /// <param name="message">The message to write. Must not be null.</param>
    public Task WriteAsync(TResponse message)
    {
        return WriteAsyncCore(message, CancellationToken.None);
    }

    /// <summary>
    /// Validates writer state, starts writing <paramref name="message"/>, and awaits that write.
    /// </summary>
    /// <param name="message">The message to write. Must not be null.</param>
    /// <param name="cancellationToken">
    /// If cancelable, canceling it aborts the underlying HTTP request (not just this write).
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="message"/> is null.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is already canceled.</exception>
    /// <exception cref="InvalidOperationException">
    /// The writer was completed, the call's cancellation token is canceled, or a previous write is still in progress.
    /// </exception>
    private async Task WriteAsyncCore(TResponse message, CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(message);

        // Register cancellation token early to ensure request is canceled if cancellation is requested.
        CancellationTokenRegistration? registration = null;
        if (cancellationToken.CanBeCanceled)
        {
            registration = cancellationToken.Register(
                static (state) => ((JsonTranscodingServerCallContext)state!).HttpContext.Abort(),
                _context);
        }

        try
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (_completed || _context.CancellationToken.IsCancellationRequested)
            {
                throw new InvalidOperationException("Can't write the message because the request is complete.");
            }

            Task writeTask;
            lock (_writeLock)
            {
                // Pending writes need to be awaited first
                if (IsWriteInProgressUnsynchronized)
                {
                    throw new InvalidOperationException("Can't write the message because the previous write is in progress.");
                }

                // Save write task to track whether it is complete. Must be set inside lock.
                writeTask = WriteMessageAndDelimiter(message, cancellationToken);
                _writeTask = writeTask;
            }

            // Await the local, not the field: once this write completes, another caller may take the
            // lock and replace _writeTask. Awaiting the field could then observe that caller's write
            // instead of this one, losing this write's exception.
            await writeTask;
        }
        finally
        {
            registration?.Dispose();
        }
    }

    /// <summary>
    /// Writes one message followed by the streaming delimiter to the response body.
    /// </summary>
    /// <param name="message">The message to write.</param>
    /// <param name="cancellationToken">Token passed to the underlying body writes.</param>
    private async Task WriteMessageAndDelimiter(TResponse message, CancellationToken cancellationToken)
    {
        if (message is HttpBody httpBody)
        {
            // HttpBody carries pre-encoded bytes, so they are written as-is with the message's content type.
            _context.EnsureResponseHeaders(httpBody.ContentType);
            await _context.HttpContext.Response.Body.WriteAsync(httpBody.Data.Memory, cancellationToken);
        }
        else
        {
            _context.EnsureResponseHeaders();
            await JsonRequestHelpers.SendMessage(_context, _serializerOptions, message, cancellationToken);
        }

        await _context.HttpContext.Response.Body.WriteAsync(GrpcProtocolConstants.StreamingDelimiter, cancellationToken);
    }

    /// <summary>
    /// Marks the writer as complete; subsequent writes throw <see cref="InvalidOperationException"/>.
    /// </summary>
    public void Complete()
    {
        _completed = true;
    }

    /// <summary>
    /// A value indicating whether there is an async write already in progress.
    /// Should only check this property when holding the write lock.
    /// </summary>
    private bool IsWriteInProgressUnsynchronized
    {
        get
        {
            var writeTask = _writeTask;
            return writeTask != null && !writeTask.IsCompleted;
        }
    }
}