File: Internal\JsonRequestHelpers.cs
Web Access
Project: src\src\Grpc\JsonTranscoding\src\Microsoft.AspNetCore.Grpc.JsonTranscoding\Microsoft.AspNetCore.Grpc.JsonTranscoding.csproj (Microsoft.AspNetCore.Grpc.JsonTranscoding)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using Google.Api;
using Google.Protobuf;
using Google.Protobuf.Reflection;
using Grpc.Core;
using Grpc.Shared;
using Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.Formatters;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
 
namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal;
 
internal static class JsonRequestHelpers
{
    // Media type compared (case-insensitively) against the request Content-Type in HasJsonContentType.
    public const string JsonContentType = "application/json";
    // The JSON media type with an explicit UTF-8 charset parameter.
    public const string JsonContentTypeWithCharset = "application/json; charset=utf-8";

    // Name of the trailer that SendErrorResponse inspects for a binary-encoded Google.Rpc.Status.
    public const string StatusDetailsTrailerName = "grpc-status-details-bin";
 
    /// <summary>
    /// Determines whether the request's Content-Type header describes JSON content.
    /// </summary>
    /// <param name="request">The request whose Content-Type header is inspected.</param>
    /// <param name="charset">
    /// The charset parameter of the parsed media type when the method returns <c>true</c>;
    /// otherwise the default <see cref="StringSegment"/>.
    /// </param>
    /// <returns>
    /// <c>true</c> when the media type is <c>application/json</c> or carries a <c>+json</c>
    /// structured suffix (e.g. <c>application/ld+json</c>); <c>false</c> when the header
    /// cannot be parsed or names any other media type.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
    public static bool HasJsonContentType(HttpRequest request, out StringSegment charset)
    {
        ArgumentNullException.ThrowIfNull(request);

        charset = default;

        if (MediaTypeHeaderValue.TryParse(request.ContentType, out var mediaType))
        {
            // Either an exact application/json match or any media type with a +json suffix.
            var isJson =
                mediaType.MediaType.Equals(JsonContentType, StringComparison.OrdinalIgnoreCase) ||
                mediaType.Suffix.Equals("json", StringComparison.OrdinalIgnoreCase);

            if (isJson)
            {
                charset = mediaType.Charset;
                return true;
            }
        }

        return false;
    }
 
    /// <summary>
    /// Returns a stream that exposes <paramref name="innerStream"/> as UTF-8.
    /// </summary>
    /// <param name="innerStream">The underlying stream.</param>
    /// <param name="encoding">The encoding of the underlying stream, or null when unspecified.</param>
    /// <returns>
    /// The original stream and <c>false</c> when <paramref name="encoding"/> is null or shares the
    /// UTF-8 code page; otherwise a transcoding stream wrapping <paramref name="innerStream"/> and
    /// <c>true</c>. The transcoding stream is created with <c>leaveOpen: true</c>, so disposing it
    /// does not dispose the inner stream.
    /// </returns>
    public static (Stream stream, bool usesTranscodingStream) GetStream(Stream innerStream, Encoding? encoding)
    {
        return encoding is { } sourceEncoding && sourceEncoding.CodePage != Encoding.UTF8.CodePage
            ? (Encoding.CreateTranscodingStream(innerStream, sourceEncoding, Encoding.UTF8, leaveOpen: true), true)
            : (innerStream, false);
    }
 
    /// <summary>
    /// Resolves the <see cref="Encoding"/> named by a Content-Type charset parameter.
    /// </summary>
    /// <param name="charset">The charset value; may have no value.</param>
    /// <returns>
    /// <see cref="Encoding.UTF8"/> for <c>utf-8</c> (case-insensitive), null when
    /// <paramref name="charset"/> has no value, otherwise the encoding returned by
    /// <see cref="Encoding.GetEncoding(string)"/>.
    /// </returns>
    /// <exception cref="InvalidOperationException">The charset could not be resolved to an encoding.</exception>
    public static Encoding? GetEncodingFromCharset(StringSegment charset)
    {
        // Fast path: comparing the segment directly avoids the Substring performed by charset.Value.
        if (charset.Equals("utf-8", StringComparison.OrdinalIgnoreCase))
        {
            return Encoding.UTF8;
        }

        if (!charset.HasValue)
        {
            return null;
        }

        try
        {
            // The name comes from the request and may be unknown, e.g. charset=invalid.
            return Encoding.GetEncoding(charset.Value);
        }
        catch (Exception ex)
        {
            throw new InvalidOperationException($"Unable to read the request as JSON because the request content type charset '{charset}' is not a known encoding.", ex);
        }
    }
 
    /// <summary>
    /// Writes a gRPC error to the HTTP response as a JSON-serialized <see cref="Google.Rpc.Status"/>.
    /// </summary>
    /// <param name="response">The HTTP response to write to.</param>
    /// <param name="encoding">The encoding used for the response content type and body.</param>
    /// <param name="trailers">Call trailers, checked for a binary status details entry.</param>
    /// <param name="status">The gRPC status used for the HTTP status code and as the fallback body.</param>
    /// <param name="options">Serializer options used to write the body.</param>
    /// <exception cref="InvalidOperationException">The status details trailer is present but cannot be parsed.</exception>
    public static async ValueTask SendErrorResponse(HttpResponse response, Encoding encoding, Metadata trailers, Status status, JsonSerializerOptions options)
    {
        // Status code and headers can only be changed before the response has started.
        if (!response.HasStarted)
        {
            response.StatusCode = MapStatusCodeToHttpStatus(status.StatusCode);
            response.ContentType = MediaType.ReplaceEncoding(JsonContentType, encoding);
        }

        // Prefer the status supplied through the trailer; otherwise build one from the gRPC status.
        var errorBody = ReadStatusDetailsTrailer(trailers);
        errorBody ??= new Google.Rpc.Status
        {
            Message = status.Detail,
            Code = (int)status.StatusCode
        };

        await WriteResponseMessage(response, encoding, errorBody, options, CancellationToken.None);

        static Google.Rpc.Status? ReadStatusDetailsTrailer(Metadata trailers)
        {
            var entry = trailers.Get(StatusDetailsTrailerName);
            if (entry is null || !entry.IsBinary)
            {
                return null;
            }

            try
            {
                return Google.Rpc.Status.Parser.ParseFrom(entry.ValueBytes);
            }
            catch (Exception ex)
            {
                throw new InvalidOperationException($"Error when parsing the '{StatusDetailsTrailerName}' trailer.", ex);
            }
        }
    }
 
    /// <summary>
    /// Maps a gRPC <see cref="StatusCode"/> to the HTTP status code returned to the caller.
    /// </summary>
    /// <param name="statusCode">The gRPC status code.</param>
    /// <returns>The HTTP status code; unrecognized values map to 500 Internal Server Error.</returns>
    public static int MapStatusCodeToHttpStatus(StatusCode statusCode) => statusCode switch
    {
        StatusCode.OK => StatusCodes.Status200OK,

        // Note, FailedPrecondition deliberately doesn't translate to the similarly named
        // '412 Precondition Failed' HTTP response status.
        StatusCode.InvalidArgument or StatusCode.FailedPrecondition or StatusCode.OutOfRange
            => StatusCodes.Status400BadRequest,

        StatusCode.Unauthenticated => StatusCodes.Status401Unauthorized,
        StatusCode.PermissionDenied => StatusCodes.Status403Forbidden,
        StatusCode.NotFound => StatusCodes.Status404NotFound,
        StatusCode.Cancelled => StatusCodes.Status408RequestTimeout,
        StatusCode.AlreadyExists or StatusCode.Aborted => StatusCodes.Status409Conflict,
        StatusCode.ResourceExhausted => StatusCodes.Status429TooManyRequests,
        StatusCode.Unimplemented => StatusCodes.Status501NotImplemented,
        StatusCode.Unavailable => StatusCodes.Status503ServiceUnavailable,
        StatusCode.DeadlineExceeded => StatusCodes.Status504GatewayTimeout,

        // Unknown, Internal, DataLoss and any value not listed above.
        _ => StatusCodes.Status500InternalServerError
    };
 
    /// <summary>
    /// Serializes <paramref name="responseBody"/> as JSON into the response body.
    /// </summary>
    /// <param name="response">The HTTP response whose body receives the JSON.</param>
    /// <param name="encoding">The encoding the client expects; non UTF-8 encodings are written through a transcoding stream.</param>
    /// <param name="responseBody">The value to serialize.</param>
    /// <param name="options">Serializer options.</param>
    /// <param name="cancellationToken">Token passed to the serializer.</param>
    public static async ValueTask WriteResponseMessage(HttpResponse response, Encoding encoding, object responseBody, JsonSerializerOptions options, CancellationToken cancellationToken)
    {
        var target = GetStream(response.Body, encoding);

        try
        {
            await JsonSerializer.SerializeAsync(target.stream, responseBody, options, cancellationToken);
        }
        finally
        {
            // Only a transcoding wrapper is disposed here; the response body itself is never disposed.
            if (target.usesTranscodingStream)
            {
                await target.stream.DisposeAsync();
            }
        }
    }
 
    /// <summary>
    /// Builds the <typeparamref name="TRequest"/> message for a transcoded call from the HTTP
    /// request body, the route values and the query string.
    /// </summary>
    /// <typeparam name="TRequest">The request message type; it is cast to <see cref="IMessage"/>.</typeparam>
    /// <param name="serverCallContext">The call context providing the HTTP request, descriptor info and logger.</param>
    /// <param name="serializerOptions">Options used when deserializing the JSON body.</param>
    /// <returns>The populated request message.</returns>
    /// <exception cref="RpcException">
    /// Thrown with <see cref="StatusCode.InvalidArgument"/> for any failure while reading the
    /// message. A <see cref="JsonException"/> is reported with a fixed detail message; any other
    /// exception is reported with its own message.
    /// </exception>
    public static async ValueTask<TRequest> ReadMessage<TRequest>(JsonTranscodingServerCallContext serverCallContext, JsonSerializerOptions serializerOptions) where TRequest : class
    {
        try
        {
            GrpcServerLog.ReadingMessage(serverCallContext.Logger);

            IMessage requestMessage;
            if (serverCallContext.DescriptorInfo.BodyDescriptor != null)
            {
                Type type;
                object bodyContent;

                if (serverCallContext.DescriptorInfo.BodyDescriptor.FullName == HttpBody.Descriptor.FullName)
                {
                    // HttpBody content is read as raw bytes rather than deserialized from JSON.
                    type = typeof(HttpBody);

                    bodyContent = await ReadHttpBodyAsync(serverCallContext);
                }
                else
                {
                    if (!serverCallContext.IsJsonRequestContent)
                    {
                        GrpcServerLog.UnsupportedRequestContentType(serverCallContext.Logger, serverCallContext.HttpContext.Request.ContentType);
                        throw new InvalidOperationException($"Unable to read the request as JSON because the request content type '{serverCallContext.HttpContext.Request.ContentType}' is not a known JSON content type.");
                    }

                    var (stream, usesTranscodingStream) = GetStream(serverCallContext.HttpContext.Request.Body, serverCallContext.RequestEncoding);

                    try
                    {
                        if (serverCallContext.DescriptorInfo.BodyDescriptorRepeated)
                        {
                            // The request message itself is created further down, once the body content
                            // is available; creating it here as well would only allocate an instance
                            // that is immediately overwritten.

                            // TODO: JsonSerializer currently doesn't support deserializing values onto an existing object or collection.
                            // Either update this to use new functionality in JsonSerializer or improve work-around perf.
                            type = JsonConverterHelper.GetFieldType(serverCallContext.DescriptorInfo.BodyFieldDescriptor);

                            var args = type.GetGenericArguments();
                            if (serverCallContext.DescriptorInfo.BodyFieldDescriptor.IsMap)
                            {
                                type = typeof(Dictionary<,>).MakeGenericType(args[0], args[1]);
                            }
                            else
                            {
                                type = typeof(List<>).MakeGenericType(args[0]);
                            }

                            GrpcServerLog.DeserializingMessage(serverCallContext.Logger, type);

                            bodyContent = await JsonSerializer.DeserializeAsync(stream, type, serializerOptions)
                                ?? throw new InvalidOperationException($"Unable to deserialize null to {type.Name}.");
                        }
                        else
                        {
                            type = serverCallContext.DescriptorInfo.BodyDescriptor.ClrType;

                            GrpcServerLog.DeserializingMessage(serverCallContext.Logger, type);

                            // A JSON null literal yields null here; the whole-message case below rejects it.
                            bodyContent = (IMessage)(await JsonSerializer.DeserializeAsync(stream, type, serializerOptions))!;
                        }
                    }
                    finally
                    {
                        // Only a transcoding wrapper is disposed; the request body stream is left alone.
                        if (usesTranscodingStream)
                        {
                            await stream.DisposeAsync();
                        }
                    }
                }

                if (serverCallContext.DescriptorInfo.BodyFieldDescriptor != null)
                {
                    requestMessage = (IMessage)Activator.CreateInstance<TRequest>();

                    // The spec says that request body must be on the top-level message.
                    // Recursive request body isn't supported.
                    ServiceDescriptorHelpers.SetValue(requestMessage, serverCallContext.DescriptorInfo.BodyFieldDescriptor, bodyContent);
                }
                else
                {
                    if (bodyContent == null)
                    {
                        throw new InvalidOperationException($"Unable to deserialize null to {type.Name}.");
                    }

                    requestMessage = (IMessage)bodyContent;
                }
            }
            else
            {
                requestMessage = (IMessage)Activator.CreateInstance<TRequest>();
            }

            // Apply values captured from the route template.
            foreach (var parameterDescriptor in serverCallContext.DescriptorInfo.RouteParameterDescriptors)
            {
                var routeValue = serverCallContext.HttpContext.Request.RouteValues[parameterDescriptor.Key];
                if (routeValue != null)
                {
                    ServiceDescriptorHelpers.RecursiveSetValue(requestMessage, parameterDescriptor.Value.DescriptorsPath, routeValue);
                }
            }

            // Apply query string values that are allowed to bind and that resolve to a field path.
            foreach (var item in serverCallContext.HttpContext.Request.Query)
            {
                if (CanBindQueryStringVariable(serverCallContext, item.Key))
                {
                    var pathDescriptors = GetPathDescriptors(serverCallContext, requestMessage, item.Key);

                    if (pathDescriptors != null)
                    {
                        // A single value is passed as a string; multiple values are passed as the whole collection.
                        var value = item.Value.Count == 1 ? (object?)item.Value[0] : item.Value;
                        ServiceDescriptorHelpers.RecursiveSetValue(requestMessage, pathDescriptors, value);
                    }
                }
            }

            GrpcServerLog.ReceivedMessage(serverCallContext.Logger);
            return (TRequest)requestMessage;
        }
        catch (JsonException ex)
        {
            GrpcServerLog.ErrorReadingMessage(serverCallContext.Logger, ex);
            throw new RpcException(new Status(StatusCode.InvalidArgument, "Request JSON payload is not correctly formatted.", ex));
        }
        catch (Exception ex)
        {
            GrpcServerLog.ErrorReadingMessage(serverCallContext.Logger, ex);
            throw new RpcException(new Status(StatusCode.InvalidArgument, ex.Message, ex));
        }
    }
 
    /// <summary>
    /// Creates the body message and fills its content type and data fields from the raw HTTP request.
    /// </summary>
    /// <param name="serverCallContext">The call context providing the HTTP request and body descriptor.</param>
    /// <returns>The populated body message.</returns>
    private static async ValueTask<IMessage> ReadHttpBodyAsync(JsonTranscodingServerCallContext serverCallContext)
    {
        var bodyClrType = serverCallContext.DescriptorInfo.BodyDescriptor!.ClrType;
        var message = (IMessage)Activator.CreateInstance(bodyClrType)!;
        var fields = message.Descriptor.Fields;

        // The content type field is only set when the request carries a Content-Type header.
        if (serverCallContext.HttpContext.Request.ContentType is { } requestContentType)
        {
            fields[HttpBody.ContentTypeFieldNumber].Accessor.SetValue(message, requestContentType);
        }

        var payload = await ReadDataAsync(serverCallContext);
        fields[HttpBody.DataFieldNumber].Accessor.SetValue(message, UnsafeByteOperations.UnsafeWrap(payload));

        return message;
    }
 
    /// <summary>
    /// Reads the entire request body into a byte array.
    /// </summary>
    /// <param name="serverCallContext">The call context providing the HTTP request.</param>
    /// <returns>The complete request body.</returns>
    /// <exception cref="EndOfStreamException">
    /// The buffered stream ended before the number of bytes it reported could be read.
    /// </exception>
    private static async ValueTask<byte[]> ReadDataAsync(JsonTranscodingServerCallContext serverCallContext)
    {
        // Buffer to disk if content is larger than 30Kb.
        // Based on value in XmlSerializer and NewtonsoftJson input formatters.
        const int DefaultMemoryThreshold = 1024 * 30;

        var request = serverCallContext.HttpContext.Request;

        var memoryThreshold = DefaultMemoryThreshold;
        var contentLength = request.ContentLength.GetValueOrDefault();
        if (contentLength > 0 && contentLength < memoryThreshold)
        {
            // If the Content-Length is known and is smaller than the default buffer size, use it.
            memoryThreshold = (int)contentLength;
        }

        await using var fs = new FileBufferingReadStream(request.Body, memoryThreshold);

        // Read the request body into buffer.
        // No explicit cancellation token. Request body uses underlying request aborted token.
        await fs.DrainAsync(CancellationToken.None);
        fs.Seek(0, SeekOrigin.Begin);

        var data = new byte[fs.Length];

        // A single Read call may legally return fewer bytes than requested, which would leave the
        // tail of the array zero-filled. ReadExactlyAsync keeps reading until the buffer is full
        // and throws if the stream ends early.
        await fs.ReadExactlyAsync(data, CancellationToken.None);
        Debug.Assert(fs.Position == data.Length);

        return data;
    }
 
    /// <summary>
    /// Resolves a dotted field path (e.g. a query string key) to the field descriptors on the
    /// request message, using the per-method cache for paths that have resolved before.
    /// </summary>
    /// <param name="serverCallContext">The call context that owns the path descriptor cache.</param>
    /// <param name="requestMessage">The message whose descriptor the path is resolved against.</param>
    /// <param name="path">The dotted field path.</param>
    /// <returns>The resolved descriptors, or null when the path does not resolve.</returns>
    private static List<FieldDescriptor>? GetPathDescriptors(JsonTranscodingServerCallContext serverCallContext, IMessage requestMessage, string path)
    {
        // NOTE(review): assumes PathDescriptorsCache is a ConcurrentDictionary (it is used with
        // GetOrAdd(key, factory) elsewhere) - confirm against its declaration.
        var cache = serverCallContext.DescriptorInfo.PathDescriptorsCache;

        if (cache.TryGetValue(path, out var cachedDescriptors))
        {
            return cachedDescriptors;
        }

        ServiceDescriptorHelpers.TryResolveDescriptors(requestMessage.Descriptor, path.Split('.'), allowJsonName: true, out var pathDescriptors);

        // The path comes from the client. Caching only successful resolutions stops arbitrary,
        // unresolvable keys from growing the cache without bound.
        if (pathDescriptors != null)
        {
            cache.TryAdd(path, pathDescriptors);
        }

        return pathDescriptors;
    }
 
    /// <summary>
    /// Writes a response message, or the configured response body field of it, to the HTTP response as JSON.
    /// </summary>
    /// <typeparam name="TResponse">The response message type.</typeparam>
    /// <param name="serverCallContext">The call context providing the HTTP response, descriptor info and logger.</param>
    /// <param name="serializerOptions">Options used when serializing the response.</param>
    /// <param name="message">The response message.</param>
    /// <param name="cancellationToken">Token passed to the serializer.</param>
    /// <remarks>Any exception is logged and then rethrown unchanged.</remarks>
    public static async ValueTask SendMessage<TResponse>(JsonTranscodingServerCallContext serverCallContext, JsonSerializerOptions serializerOptions, TResponse message, CancellationToken cancellationToken) where TResponse : class
    {
        var httpResponse = serverCallContext.HttpContext.Response;

        try
        {
            GrpcServerLog.SendingMessage(serverCallContext.Logger);

            var responseBodyField = serverCallContext.DescriptorInfo.ResponseBodyDescriptor;

            object payload;
            Type payloadType;

            if (responseBodyField is null)
            {
                // No response body field configured: the whole message is the payload.
                payload = message;
                payloadType = message.GetType();
            }
            else
            {
                // The spec says that response body must be on the top-level message.
                // Recursive response body isn't supported.
                payload = responseBodyField.Accessor.GetValue((IMessage)message);
                payloadType = JsonConverterHelper.GetFieldType(responseBodyField);
            }

            await WriteResponseMessage(httpResponse, serverCallContext.RequestEncoding, payload, serializerOptions, cancellationToken);

            GrpcServerLog.SerializedMessage(serverCallContext.Logger, payloadType);
            GrpcServerLog.MessageSent(serverCallContext.Logger);
        }
        catch (Exception ex)
        {
            GrpcServerLog.ErrorSendingMessage(serverCallContext.Logger, ex);
            throw;
        }
    }
 
    /// <summary>
    /// Determines whether a query string key may be bound onto the request message.
    /// </summary>
    /// <param name="serverCallContext">The call context providing the descriptor info.</param>
    /// <param name="variable">The query string key (a dotted field path).</param>
    /// <returns>
    /// <c>false</c> when the whole message is bound to the body, when the key names the body field
    /// or one of its nested fields, or when the key is a route parameter; otherwise <c>true</c>.
    /// </returns>
    private static bool CanBindQueryStringVariable(JsonTranscodingServerCallContext serverCallContext, string variable)
    {
        if (serverCallContext.DescriptorInfo.BodyDescriptor != null)
        {
            var bodyField = serverCallContext.DescriptorInfo.BodyFieldDescriptor;

            // Null field indicates "*" which means the entire message is bound to the body.
            if (bodyField?.Name == null)
            {
                return false;
            }

            // Query keys are resolved with allowJsonName: true, so the body field has to be
            // excluded under its JSON name as well as its proto name. Otherwise a key such as
            // "myField" could overwrite content already bound from the body into "my_field".
            // NOTE(review): relies on TryResolveDescriptors matching JsonName - confirm.
            if (IsFieldOrNestedField(variable, bodyField.Name) ||
                IsFieldOrNestedField(variable, bodyField.JsonName))
            {
                return false;
            }
        }

        if (serverCallContext.DescriptorInfo.RouteParameterDescriptors.ContainsKey(variable))
        {
            return false;
        }

        return true;

        // True when path is exactly fieldName or starts with "fieldName." followed by at least one character.
        static bool IsFieldOrNestedField(string path, string fieldName)
        {
            // Exact match
            if (path == fieldName)
            {
                return true;
            }

            // Nested field of field name.
            return fieldName.Length + 1 < path.Length &&
                path.StartsWith(fieldName, StringComparison.Ordinal) &&
                path[fieldName.Length] == '.';
        }
    }
}