|
#region Copyright notice and license
// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion
using System.Collections;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Reflection;
using Google.Api;
using Google.Protobuf;
using Google.Protobuf.Reflection;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Primitives;
using Type = System.Type;
namespace Grpc.Shared;
/// <summary>
/// Helpers for resolving protobuf descriptors, HTTP rules, route/query/body bindings and
/// field values used by JSON transcoding.
/// </summary>
internal static class ServiceDescriptorHelpers
{
    private static readonly HashSet<string> WellKnownTypeNames = new HashSet<string>
    {
        "google/protobuf/any.proto",
        "google/protobuf/api.proto",
        "google/protobuf/duration.proto",
        "google/protobuf/empty.proto",
        "google/protobuf/wrappers.proto",
        "google/protobuf/timestamp.proto",
        "google/protobuf/field_mask.proto",
        "google/protobuf/source_context.proto",
        "google/protobuf/struct.proto",
        "google/protobuf/type.proto",
    };

    /// <summary>
    /// Returns <c>true</c> when the message is declared in the <c>google.protobuf</c> package
    /// inside one of the well-known type files listed in <see cref="WellKnownTypeNames"/>.
    /// </summary>
    internal static bool IsWellKnownType(MessageDescriptor messageDescriptor) =>
        messageDescriptor.File.Package == "google.protobuf" &&
        WellKnownTypeNames.Contains(messageDescriptor.File.Name);

    /// <summary>
    /// Returns <c>true</c> when the descriptor is declared in <c>google/protobuf/wrappers.proto</c>.
    /// </summary>
    internal static bool IsWrapperType(DescriptorBase m) =>
        m.File.Package == "google.protobuf" && m.File.Name == "google/protobuf/wrappers.proto";

    /// <summary>
    /// Reads the public static <c>Descriptor</c> property from a generated service reflection type.
    /// </summary>
    /// <param name="serviceReflectionType">The type expected to expose a static <c>Descriptor</c> property.</param>
    /// <returns>The property value cast to <see cref="ServiceDescriptor"/>.</returns>
    /// <exception cref="InvalidOperationException">The type has no public static <c>Descriptor</c> property.</exception>
    public static ServiceDescriptor? GetServiceDescriptor(Type serviceReflectionType)
    {
        var property = serviceReflectionType.GetProperty("Descriptor", BindingFlags.Public | BindingFlags.Static);
        if (property != null)
        {
            return (ServiceDescriptor?)property.GetValue(null);
        }

        throw new InvalidOperationException($"Couldn't find Descriptor property on {serviceReflectionType.Name}.");
    }

    /// <summary>
    /// Resolves a field path (already split into segments) against a message descriptor.
    /// </summary>
    /// <param name="messageDescriptor">The message the first segment is looked up on.</param>
    /// <param name="path">
    /// Field name segments. Each segment after the first is looked up on the message type of the
    /// field resolved for the previous segment.
    /// </param>
    /// <param name="allowJsonName">
    /// When <c>true</c>, a segment may match either a field's JSON name or its proto name
    /// (see <see cref="GetFieldByName"/>); otherwise only the proto name is matched.
    /// </param>
    /// <param name="fieldDescriptors">One descriptor per segment, in path order, when the method returns <c>true</c>.</param>
    /// <returns>
    /// <c>true</c> if every segment resolved. <c>false</c> if a segment didn't match, a segment
    /// follows a non-message field, or <paramref name="path"/> is empty.
    /// </returns>
    public static bool TryResolveDescriptors(MessageDescriptor messageDescriptor, IList<string> path, bool allowJsonName, [NotNullWhen(true)] out List<FieldDescriptor>? fieldDescriptors)
    {
        fieldDescriptors = null;
        MessageDescriptor? currentDescriptor = messageDescriptor;

        foreach (var fieldName in path)
        {
            FieldDescriptor? field = null;
            if (currentDescriptor != null)
            {
                field = allowJsonName
                    ? GetFieldByName(currentDescriptor, fieldName)
                    : currentDescriptor.FindFieldByName(fieldName);
            }

            if (field == null)
            {
                fieldDescriptors = null;
                return false;
            }

            fieldDescriptors ??= new List<FieldDescriptor>();
            fieldDescriptors.Add(field);

            // Only message fields have children. A further segment after a scalar field fails to resolve.
            currentDescriptor = field.FieldType == FieldType.Message ? field.MessageType : null;
        }

        return fieldDescriptors != null;
    }

    private static FieldDescriptor? GetFieldByName(MessageDescriptor messageDescriptor, string fieldName)
    {
        // Search fields by field name and JSON name. Both names can be referenced.
        // JSON name takes precedence. If there are conflicts, then the last field with a name wins.
        // This logic matches how properties are used in JSON serialization's MessageTypeInfoResolver.
        var fields = messageDescriptor.Fields.InFieldNumberOrder();
        FieldDescriptor? fieldNameDescriptorMatch = null;

        for (var i = fields.Count - 1; i >= 0; i--)
        {
            // We're checking JSON name first, in reverse order through fields.
            // That means the method can exit early on match because the match has the highest precedence.
            var field = fields[i];
            if (field.JsonName == fieldName)
            {
                return field;
            }

            // If there is a match on field name then store the first match (i.e. the highest field number).
            if (fieldNameDescriptorMatch is null && field.Name == fieldName)
            {
                fieldNameDescriptorMatch = field;
            }
        }

        // No match with JSON name. If there is a field name match then return it.
        return fieldNameDescriptorMatch;
    }

    /// <summary>
    /// Converts a value to the CLR representation used for a field of the descriptor's type.
    /// </summary>
    /// <remarks>
    /// - Numeric and bool types go through <see cref="Convert"/> with the invariant culture.
    /// - Bytes require a base64 string.
    /// - Enums accept a numeric string or a value name, and return the enum number.
    /// - Wrapper messages convert to their inner value. A <c>null</c> wrapper value stays <c>null</c>.
    /// - FieldMask, Duration and Timestamp messages are parsed from their string forms.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The value can't be converted or the field type is unsupported.</exception>
    private static object? ConvertValue(object? value, FieldDescriptor descriptor)
    {
        switch (descriptor.FieldType)
        {
            case FieldType.Double:
                return Convert.ToDouble(value, CultureInfo.InvariantCulture);
            case FieldType.Float:
                return Convert.ToSingle(value, CultureInfo.InvariantCulture);
            case FieldType.Int64:
            case FieldType.SInt64:
            case FieldType.SFixed64:
                return Convert.ToInt64(value, CultureInfo.InvariantCulture);
            case FieldType.UInt64:
            case FieldType.Fixed64:
                return Convert.ToUInt64(value, CultureInfo.InvariantCulture);
            case FieldType.Int32:
            case FieldType.SInt32:
            case FieldType.SFixed32:
                return Convert.ToInt32(value, CultureInfo.InvariantCulture);
            case FieldType.Bool:
                return Convert.ToBoolean(value, CultureInfo.InvariantCulture);
            case FieldType.String:
                return value;
            case FieldType.Bytes:
                {
                    if (value is string s)
                    {
                        return ByteString.FromBase64(s);
                    }
                    throw new InvalidOperationException("Base64 encoded string required to convert to bytes.");
                }
            case FieldType.UInt32:
            case FieldType.Fixed32:
                return Convert.ToUInt32(value, CultureInfo.InvariantCulture);
            case FieldType.Enum:
                {
                    if (value is string s)
                    {
                        // A numeric string selects by enum number; anything else selects by value name.
                        var enumValueDescriptor = int.TryParse(s, NumberStyles.Integer, CultureInfo.InvariantCulture, out var i)
                            ? descriptor.EnumType.FindValueByNumber(i)
                            : descriptor.EnumType.FindValueByName(s);

                        if (enumValueDescriptor == null)
                        {
                            throw new InvalidOperationException($"Invalid value '{s}' for enum type {descriptor.EnumType.Name}.");
                        }

                        return enumValueDescriptor.Number;
                    }
                    throw new InvalidOperationException("String required to convert to enum.");
                }
            case FieldType.Message:
                if (IsWellKnownType(descriptor.MessageType))
                {
                    if (IsWrapperType(descriptor.MessageType))
                    {
                        if (value == null)
                        {
                            return null;
                        }

                        // Wrapper messages hold a single field named "value"; convert to that field's type.
                        return ConvertValue(value, descriptor.MessageType.FindFieldByName("value"));
                    }
                    else if (descriptor.MessageType.FullName == FieldMask.Descriptor.FullName)
                    {
                        return FieldMask.FromString((string)value!);
                    }
                    else if (descriptor.MessageType.FullName == Duration.Descriptor.FullName)
                    {
                        var (seconds, nanos) = Legacy.ParseDuration((string)value!);
                        var duration = new Duration();
                        duration.Seconds = seconds;
                        duration.Nanos = nanos;
                        return duration;
                    }
                    else if (descriptor.MessageType.FullName == Timestamp.Descriptor.FullName)
                    {
                        var (seconds, nanos) = Legacy.ParseTimestamp((string)value!);
                        var timestamp = new Timestamp();
                        timestamp.Seconds = seconds;
                        timestamp.Nanos = nanos;
                        return timestamp;
                    }
                }
                break;
        }

        throw new InvalidOperationException("Unsupported type: " + descriptor.FieldType);
    }

    /// <summary>
    /// Walks <paramref name="pathDescriptors"/> from <paramref name="currentValue"/>, creating any
    /// missing intermediate messages, and sets <paramref name="values"/> on the final field.
    /// </summary>
    /// <param name="currentValue">The root message.</param>
    /// <param name="pathDescriptors">Descriptors for each path segment. All but the last must be message fields.</param>
    /// <param name="values">The value(s) passed to <see cref="SetValue"/> for the last field.</param>
    public static void RecursiveSetValue(IMessage currentValue, List<FieldDescriptor> pathDescriptors, object? values)
    {
        for (var i = 0; i < pathDescriptors.Count; i++)
        {
            var isLast = i == pathDescriptors.Count - 1;
            var field = pathDescriptors[i];

            if (isLast)
            {
                SetValue(currentValue, field, values);
            }
            else
            {
                var fieldMessage = (IMessage)field.Accessor.GetValue(currentValue);
                if (fieldMessage == null)
                {
                    // Intermediate message isn't set yet; create it via its generated CLR type.
                    fieldMessage = (IMessage)Activator.CreateInstance(field.MessageType.ClrType)!;
                    field.Accessor.SetValue(currentValue, fieldMessage);
                }

                currentValue = fieldMessage;
            }
        }
    }

    /// <summary>
    /// Sets, or for collections adds, <paramref name="values"/> on a single field of <paramref name="message"/>.
    /// </summary>
    /// <remarks>
    /// - Map fields copy every entry of an <see cref="IDictionary"/> into the map.
    /// - Repeated fields append each converted value. Message elements of an <see cref="IList"/> are added unconverted.
    /// - Singular fields accept a single-item <see cref="StringValues"/>, a message (wrapper messages are unwrapped), or a raw value.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The values don't fit the field's cardinality or can't be converted.</exception>
    public static void SetValue(IMessage message, FieldDescriptor field, object? values)
    {
        if (field.IsMap)
        {
            var map = (IDictionary)field.Accessor.GetValue(message);
            if (values is IDictionary dictionaryValues)
            {
                foreach (DictionaryEntry value in dictionaryValues)
                {
                    map[value.Key] = value.Value;
                }
            }
            else
            {
                throw new InvalidOperationException("Map field requires repeating keys and values.");
            }
        }
        else if (field.IsRepeated)
        {
            var list = (IList)field.Accessor.GetValue(message);
            if (values is StringValues stringValues)
            {
                foreach (var value in stringValues)
                {
                    list.Add(ConvertValue(value, field));
                }
            }
            else if (values is IList listValues)
            {
                foreach (var value in listValues)
                {
                    var v = field.FieldType == FieldType.Message
                        ? value
                        : ConvertValue(value, field);
                    list.Add(v);
                }
            }
            else
            {
                list.Add(ConvertValue(values, field));
            }
        }
        else
        {
            if (values is StringValues stringValues)
            {
                if (stringValues.Count == 1)
                {
                    field.Accessor.SetValue(message, ConvertValue(stringValues[0], field));
                }
                else
                {
                    throw new InvalidOperationException("Can't set multiple values onto a non-repeating field.");
                }
            }
            else if (values is IMessage messageValue)
            {
                if (IsWrapperType(messageValue.Descriptor))
                {
                    // Every wrapper type stores its payload in field number 1 ("value").
                    const int WrapperValueFieldNumber = Int32Value.ValueFieldNumber;
                    var wrappedValue = messageValue.Descriptor.Fields[WrapperValueFieldNumber].Accessor.GetValue(messageValue);
                    field.Accessor.SetValue(message, wrappedValue);
                }
                else
                {
                    field.Accessor.SetValue(message, messageValue);
                }
            }
            else
            {
                field.Accessor.SetValue(message, ConvertValue(values, field));
            }
        }
    }

    // Transcoding assumes that the app is referencing Google.Api.CommonProtos and HttpRule is from that assembly.
    // However, it's possible the app has compiled http.proto with Grpc.Tools, so the extension value is HttpRule from a different assembly.
    // This custom extension uses the HttpRule field number but has a return type of object.
    // The method always returns the extension value, and the calling code can convert it to the expected type.
    // See https://github.com/protocolbuffers/protobuf/issues/9626 for more details.
    private static readonly Extension<MethodOptions, object> UntypedHttpExtension =
        new Extension<MethodOptions, object>(AnnotationsExtensions.Http.FieldNumber, codec: null);

    /// <summary>
    /// Reads the <c>google.api.http</c> option from a method, if present.
    /// </summary>
    /// <param name="methodDescriptor">The gRPC method.</param>
    /// <param name="httpRule">The rule, converted to <see cref="HttpRule"/> if it came from another assembly.</param>
    /// <returns><c>true</c> if the method has an HTTP rule.</returns>
    public static bool TryGetHttpRule(MethodDescriptor methodDescriptor, [NotNullWhen(true)] out HttpRule? httpRule)
    {
        var options = methodDescriptor.GetOptions();

        // The untyped extension always returns the extension value. If the type is already the expected HttpRule then use it directly.
        // A different message indicates a custom HttpRule was used. Convert the message to bytes and reparse it to the known HttpRule type.
        var extensionValue = options?.GetExtension(UntypedHttpExtension);
        httpRule = extensionValue switch
        {
            HttpRule rule => rule,
            IMessage message => HttpRule.Parser.ParseFrom(message.ToByteArray()),
            _ => null
        };

        return httpRule != null;
    }

    /// <summary>
    /// Extracts the URL pattern and HTTP verb from an <see cref="HttpRule"/>.
    /// </summary>
    /// <returns><c>false</c> when the rule has no pattern set.</returns>
    public static bool TryResolvePattern(HttpRule http, [NotNullWhen(true)] out string? pattern, [NotNullWhen(true)] out string? verb)
    {
        switch (http.PatternCase)
        {
            case HttpRule.PatternOneofCase.Get:
                pattern = http.Get;
                verb = "GET";
                return true;
            case HttpRule.PatternOneofCase.Put:
                pattern = http.Put;
                verb = "PUT";
                return true;
            case HttpRule.PatternOneofCase.Post:
                pattern = http.Post;
                verb = "POST";
                return true;
            case HttpRule.PatternOneofCase.Delete:
                pattern = http.Delete;
                verb = "DELETE";
                return true;
            case HttpRule.PatternOneofCase.Patch:
                pattern = http.Patch;
                verb = "PATCH";
                return true;
            case HttpRule.PatternOneofCase.Custom:
                pattern = http.Custom.Path;
                verb = http.Custom.Kind;
                return true;
            default:
                pattern = null;
                verb = null;
                return false;
        }
    }

    /// <summary>
    /// Resolves each route variable's field path against the request message.
    /// </summary>
    /// <returns>Route parameters keyed by their '.'-joined proto field path (ordinal comparison).</returns>
    /// <exception cref="InvalidOperationException">A variable's field path doesn't match the message.</exception>
    public static Dictionary<string, RouteParameter> ResolveRouteParameterDescriptors(
        List<HttpRouteVariable> variables,
        MessageDescriptor messageDescriptor)
    {
        var routeParameterDescriptors = new Dictionary<string, RouteParameter>(StringComparer.Ordinal);
        foreach (var variable in variables)
        {
            var path = variable.FieldPath;
            if (!TryResolveDescriptors(messageDescriptor, path, allowJsonName: false, out var fieldDescriptors))
            {
                throw new InvalidOperationException($"Couldn't find matching field for route parameter '{string.Join(".", path)}' on {messageDescriptor.Name}.");
            }

            var completeFieldPath = string.Join(".", fieldDescriptors.Select(d => d.Name));
            var completeJsonPath = string.Join(".", fieldDescriptors.Select(d => d.JsonName));
            routeParameterDescriptors.Add(completeFieldPath, new RouteParameter(fieldDescriptors, variable, completeJsonPath));
        }

        return routeParameterDescriptors;
    }

    /// <summary>
    /// Resolves the HTTP rule's <c>body</c> setting to the descriptor the request body is parsed into.
    /// </summary>
    /// <param name="body">The rule's body value: empty (no body), <c>*</c> (whole request message) or a top-level field name.</param>
    /// <param name="serviceType">Service implementation type, used to find the method's <c>request</c> parameter for <c>*</c>.</param>
    /// <param name="methodDescriptor">The gRPC method.</param>
    /// <returns><c>null</c> when <paramref name="body"/> is null or empty.</returns>
    /// <exception cref="InvalidOperationException">The body names a nested or unknown field.</exception>
    public static BodyDescriptorInfo? ResolveBodyDescriptor(string body, Type serviceType, MethodDescriptor methodDescriptor)
    {
        if (string.IsNullOrEmpty(body))
        {
            return null;
        }

        if (!string.Equals(body, "*", StringComparison.Ordinal))
        {
            if (body.Contains('.', StringComparison.Ordinal))
            {
                throw new InvalidOperationException($"The body field '{body}' references a nested field. The body field name must be on the top-level request message.");
            }

            var bodyDescriptor = methodDescriptor.InputType.FindFieldByName(body);
            if (bodyDescriptor == null)
            {
                throw new InvalidOperationException($"Couldn't find matching field for body '{body}' on {methodDescriptor.InputType.Name}.");
            }

            var propertyInfo = GetFieldProperty(bodyDescriptor);
            if (bodyDescriptor.IsRepeated)
            {
                // A repeating field isn't a message type. The JSON parser will parse using the containing
                // type to get the repeating collection.
                return new BodyDescriptorInfo(bodyDescriptor.ContainingType, bodyDescriptor, isDescriptorRepeated: true, propertyInfo);
            }

            // NOTE(review): MessageType throws for non-message (scalar) body fields.
            return new BodyDescriptorInfo(bodyDescriptor.MessageType, bodyDescriptor, isDescriptorRepeated: false, propertyInfo);
        }

        // Wildcard body: the whole request message is read from the body.
        ParameterInfo? requestParameter = null;
        var methodInfo = serviceType.GetMethod(methodDescriptor.Name);
        if (methodInfo != null)
        {
            requestParameter = methodInfo.GetParameters().SingleOrDefault(p => p.Name == "request");
        }

        return new BodyDescriptorInfo(methodDescriptor.InputType, fieldDescriptor: null, isDescriptorRepeated: false, parameterInfo: requestParameter);
    }

    /// <summary>
    /// Finds the generated CLR property backing a field, or <c>null</c> if it can't be found.
    /// </summary>
    private static PropertyInfo? GetFieldProperty(FieldDescriptor fieldDescriptor)
    {
        var clrType = fieldDescriptor.ContainingType.ClrType;
        if (clrType is null)
        {
            return null;
        }

        var propertyName = FormatUnderscoreName(fieldDescriptor.Name, pascalCase: true, preservePeriod: false);

        // Generated field properties are public instance properties. Excluding statics stops a field such
        // as "descriptor" or "parser" from resolving to the message's static Descriptor/Parser property.
        const BindingFlags Flags = BindingFlags.Public | BindingFlags.Instance;

        // The C# code generator appends "_" when the PascalCase name collides with the containing type
        // name or a reserved member name, so fall back to the suffixed name.
        return clrType.GetProperty(propertyName, Flags)
            ?? clrType.GetProperty(propertyName + "_", Flags);
    }

    /// <summary>
    /// Resolves the HTTP rule's <c>response_body</c> setting to a top-level field of the response message.
    /// </summary>
    /// <returns><c>null</c> when <paramref name="responseBody"/> is null or empty.</returns>
    /// <exception cref="InvalidOperationException">The response body names a nested or unknown field.</exception>
    public static FieldDescriptor? ResolveResponseBodyDescriptor(string responseBody, MethodDescriptor methodDescriptor)
    {
        if (string.IsNullOrEmpty(responseBody))
        {
            return null;
        }

        if (responseBody.Contains('.', StringComparison.Ordinal))
        {
            throw new InvalidOperationException($"The response body field '{responseBody}' references a nested field. The response body field name must be on the top-level response message.");
        }

        var responseBodyDescriptor = methodDescriptor.OutputType.FindFieldByName(responseBody);
        if (responseBodyDescriptor == null)
        {
            throw new InvalidOperationException($"Couldn't find matching field for response body '{responseBody}' on {methodDescriptor.OutputType.Name}.");
        }

        return responseBodyDescriptor;
    }

    /// <summary>
    /// Computes the query string parameters a method accepts: every reachable request field not
    /// already bound by a route parameter or the body field.
    /// </summary>
    /// <returns>
    /// Fields keyed by their '.'-joined JSON path. The result is empty when the body is the wildcard
    /// (<paramref name="bodyDescriptor"/> set and <paramref name="bodyFieldDescriptor"/> null).
    /// </returns>
    public static Dictionary<string, FieldDescriptor> ResolveQueryParameterDescriptors(
        Dictionary<string, RouteParameter> routeParameters,
        MethodDescriptor methodDescriptor,
        MessageDescriptor? bodyDescriptor,
        FieldDescriptor? bodyFieldDescriptor)
    {
        var existingParameters = new HashSet<FieldDescriptor>();
        foreach (var routeParameter in routeParameters.Values)
        {
            // Each route field descriptors collection contains all the descriptors in the path.
            // We only care about the final place the route value is set and so add only the last
            // descriptor to the existing parameters collection.
            existingParameters.Add(routeParameter.DescriptorsPath.Last());
        }

        if (bodyDescriptor != null)
        {
            if (bodyFieldDescriptor != null)
            {
                // Body with field name.
                existingParameters.Add(bodyFieldDescriptor);
            }
            else
            {
                // Body with wildcard. All parameters are in the body so no query parameters.
                return new Dictionary<string, FieldDescriptor>();
            }
        }

        var queryParameters = new Dictionary<string, FieldDescriptor>();
        RecursiveVisitMessages(queryParameters, existingParameters, methodDescriptor.InputType, new List<FieldDescriptor>());
        return queryParameters;

        static void RecursiveVisitMessages(Dictionary<string, FieldDescriptor> queryParameters, HashSet<FieldDescriptor> existingParameters, MessageDescriptor messageDescriptor, List<FieldDescriptor> path)
        {
            var messageFields = messageDescriptor.Fields.InFieldNumberOrder();

            foreach (var fieldDescriptor in messageFields)
            {
                // If a field is set via route parameter or body then don't add query parameter.
                if (existingParameters.Contains(fieldDescriptor))
                {
                    continue;
                }

                // Add current field descriptor. It should be included in the path.
                path.Add(fieldDescriptor);

                switch (fieldDescriptor.FieldType)
                {
                    case FieldType.Double:
                    case FieldType.Float:
                    case FieldType.Int64:
                    case FieldType.UInt64:
                    case FieldType.Int32:
                    case FieldType.Fixed64:
                    case FieldType.Fixed32:
                    case FieldType.Bool:
                    case FieldType.String:
                    case FieldType.Bytes:
                    case FieldType.UInt32:
                    case FieldType.SFixed32:
                    case FieldType.SFixed64:
                    case FieldType.SInt32:
                    case FieldType.SInt64:
                    case FieldType.Enum:
                        {
                            var joinedPath = string.Join(".", path.Select(d => d.JsonName));
                            queryParameters[joinedPath] = fieldDescriptor;
                        }
                        break;
                    case FieldType.Group:
                    case FieldType.Message:
                    default:
                        if (IsCustomType(fieldDescriptor.MessageType))
                        {
                            // Well-known types with a string/JSON form are bound as a single parameter.
                            var joinedPath = string.Join(".", path.Select(d => d.JsonName));
                            queryParameters[joinedPath] = fieldDescriptor;
                        }
                        else if (!fieldDescriptor.IsRepeated)
                        {
                            // Complex repeated fields aren't valid query parameters, so only recurse into singular messages.
                            RecursiveVisitMessages(queryParameters, existingParameters, fieldDescriptor.MessageType, path);
                        }
                        break;
                }

                // Remove current field descriptor.
                path.RemoveAt(path.Count - 1);
            }
        }
    }

    private static bool IsCustomType(MessageDescriptor messageDescriptor)
    {
        // The messages flags here should be kept in sync with GrpcDataContractResolver.TryCustomizeMessage.
        return IsWrapperType(messageDescriptor) ||
            messageDescriptor.FullName == Timestamp.Descriptor.FullName ||
            messageDescriptor.FullName == Duration.Descriptor.FullName ||
            messageDescriptor.FullName == FieldMask.Descriptor.FullName ||
            messageDescriptor.FullName == Struct.Descriptor.FullName ||
            messageDescriptor.FullName == ListValue.Descriptor.FullName ||
            messageDescriptor.FullName == Value.Descriptor.FullName ||
            messageDescriptor.FullName == Any.Descriptor.FullName;
    }

    /// <summary>
    /// Describes what the HTTP request body is parsed into.
    /// </summary>
    public sealed class BodyDescriptorInfo
    {
        /// <summary>The message type the body is parsed as.</summary>
        public MessageDescriptor Descriptor { get; }

        /// <summary>The named body field, or <c>null</c> for a wildcard body.</summary>
        public FieldDescriptor? FieldDescriptor { get; }

        /// <summary><c>true</c> when the body field is repeated and <see cref="Descriptor"/> is its containing type.</summary>
        public bool IsDescriptorRepeated { get; }

        /// <summary>The CLR property for the body field, when one was found.</summary>
        public PropertyInfo? PropertyInfo { get; }

        /// <summary>The service method's <c>request</c> parameter for a wildcard body, when one was found.</summary>
        public ParameterInfo? ParameterInfo { get; }

        public BodyDescriptorInfo(
            MessageDescriptor descriptor,
            FieldDescriptor? fieldDescriptor,
            bool isDescriptorRepeated,
            PropertyInfo? propertyInfo = null,
            ParameterInfo? parameterInfo = null)
        {
            Descriptor = descriptor;
            FieldDescriptor = fieldDescriptor;
            IsDescriptorRepeated = isDescriptorRepeated;
            PropertyInfo = propertyInfo;
            ParameterInfo = parameterInfo;
        }
    }

    /// <summary>
    /// Converts an underscore-separated proto name to camelCase or PascalCase.
    /// </summary>
    /// <param name="input">The name to convert.</param>
    /// <param name="pascalCase">When <c>true</c>, the first letter is capitalized; otherwise a leading capital is lower-cased.</param>
    /// <param name="preservePeriod">When <c>true</c>, '.' characters are kept (and still capitalize the next letter).</param>
    /// <returns>
    /// The converted name. Non-alphanumeric characters are dropped and capitalize the following letter,
    /// letters following a digit are capitalized, and a trailing "_" is appended when the input ends in '#'.
    /// </returns>
    public static string FormatUnderscoreName(string input, bool pascalCase, bool preservePeriod)
    {
        var capitalizeNext = pascalCase;
        var result = new System.Text.StringBuilder(input.Length + 1);

        for (var i = 0; i < input.Length; i++)
        {
            var c = input[i];
            if (char.IsLower(c))
            {
                result.Append(capitalizeNext ? char.ToUpper(c, CultureInfo.InvariantCulture) : c);
                capitalizeNext = false;
            }
            else if (char.IsUpper(c))
            {
                // Force first letter to lower-case unless explicitly told to capitalize it.
                // Capital letters after the first are left as-is.
                result.Append(i == 0 && !capitalizeNext ? char.ToLower(c, CultureInfo.InvariantCulture) : c);
                capitalizeNext = false;
            }
            else if (char.IsDigit(c))
            {
                result.Append(c);
                capitalizeNext = true;
            }
            else
            {
                // Separators are dropped (except a preserved '.') and capitalize the next letter.
                capitalizeNext = true;
                if (c == '.' && preservePeriod)
                {
                    result.Append('.');
                }
            }
        }

        // Add a trailing "_" if the name should be altered.
        if (input.Length > 0 && input[input.Length - 1] == '#')
        {
            result.Append('_');
        }

        return result.ToString();
    }
}
/// <summary>
/// Pairs a route variable from an HTTP rule template with the field descriptors its path resolves to.
/// </summary>
internal sealed class RouteParameter
{
    /// <summary>Descriptors for each segment of the variable's field path, in path order.</summary>
    public List<FieldDescriptor> DescriptorsPath { get; }

    /// <summary>The route variable this parameter was created from.</summary>
    public HttpRouteVariable RouteVariable { get; }

    /// <summary>The variable's field path expressed with JSON names.</summary>
    public string JsonPath { get; }

    public RouteParameter(
        List<FieldDescriptor> descriptorsPath,
        HttpRouteVariable routeVariable,
        string jsonPath)
        => (DescriptorsPath, RouteVariable, JsonPath) = (descriptorsPath, routeVariable, jsonPath);
}
|