File: MethodHandlerDetails.cs
Web Access
Project: src\src\LanguageServer\Microsoft.CommonLanguageServerProtocol.Framework\Microsoft.CommonLanguageServerProtocol.Framework.Package.csproj (Microsoft.CommonLanguageServerProtocol.Framework.Package)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
// This is consumed as 'generated' code in a source package and therefore requires an explicit nullable enable
#nullable enable
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Reflection;
using System.Threading;
 
namespace Microsoft.CommonLanguageServerProtocol.Framework;
 
/// <summary>
/// Provides information about an <see cref="IMethodHandler"/>.
/// </summary>
/// <param name="MethodName">The name of the LSP method handled.</param>
/// <param name="Language">The language this <see cref="IMethodHandler"/> targets.</param>
/// <param name="RequestTypeRef">A <see cref="TypeRef"/> representing the request type, if any.</param>
/// <param name="ResponseTypeRef">A <see cref="TypeRef"/> representing the response type, if any.</param>
/// <param name="RequestContextTypeRef">A <see cref="TypeRef"/> representing the context type.</param>
internal sealed record MethodHandlerDetails(
    string MethodName,
    string Language,
    TypeRef? RequestTypeRef,
    TypeRef? ResponseTypeRef,
    TypeRef RequestContextTypeRef)
{
    /// <summary>
    /// Builds the <see cref="MethodHandlerDetails"/> for every handler interface implemented by
    /// <paramref name="handlerType"/>, producing one entry per language declared by the matching
    /// <see cref="LanguageServerEndpointAttribute"/>.
    /// </summary>
    /// <param name="handlerType">The handler type to inspect. It is only reflected over, never instantiated.</param>
    /// <returns>One entry for each (handler interface, language) combination.</returns>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="handlerType"/> implements none of the recognized handler interfaces, no
    /// <see cref="LanguageServerEndpointAttribute"/> could be found for it, or the method implementing
    /// a handler interface could not be located.
    /// </exception>
    public static ImmutableArray<MethodHandlerDetails> From(Type handlerType)
    {
        var allHandlerDetails = GetAllHandlerDetails(handlerType);

        // The capacity is only a starting estimate: a single handler may declare several languages.
        var builder = ImmutableArray.CreateBuilder<MethodHandlerDetails>(initialCapacity: allHandlerDetails.Length);

        foreach (var (requestType, responseType, requestContextType) in allHandlerDetails)
        {
            var (method, languages) = GetRequestHandlerMethod(handlerType, requestType, requestContextType, responseType);

            foreach (var language in languages)
            {
                builder.Add(new(
                    method,
                    language,
                    TypeRef.FromOrNull(requestType),
                    TypeRef.FromOrNull(responseType),
                    TypeRef.From(requestContextType)));
            }
        }

        return builder.DrainToImmutable();
    }

    /// <summary>
    /// Retrieves the generic argument information from the request handler type without instantiating it.
    /// </summary>
    /// <param name="handlerType">The handler type whose implemented interfaces are inspected.</param>
    /// <returns>
    /// One tuple per implemented request/notification handler interface. <c>RequestType</c> and
    /// <c>ResponseType</c> are <see langword="null"/> when the interface shape does not declare them.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="handlerType"/> implements none of the recognized handler interfaces.
    /// </exception>
    private static ImmutableArray<(Type? RequestType, Type? ResponseType, Type RequestContextType)> GetAllHandlerDetails(Type handlerType)
    {
        var builder = ImmutableArray.CreateBuilder<(Type? RequestType, Type? ResponseType, Type RequestContextType)>();

        foreach (var interfaceType in handlerType.GetInterfaces())
        {
            if (!interfaceType.IsGenericType)
            {
                continue;
            }

            var genericDefinition = interfaceType.GetGenericTypeDefinition();

            if (genericDefinition == typeof(IRequestHandler<,,>))
            {
                // Request with parameters and a response: <TRequest, TResponse, TContext>.
                var genericArguments = interfaceType.GetGenericArguments();
                builder.Add((RequestType: genericArguments[0], ResponseType: genericArguments[1], RequestContextType: genericArguments[2]));
            }
            else if (genericDefinition == typeof(IRequestHandler<,>))
            {
                // Parameterless request with a response: <TResponse, TContext>.
                var genericArguments = interfaceType.GetGenericArguments();
                builder.Add((RequestType: null, ResponseType: genericArguments[0], RequestContextType: genericArguments[1]));
            }
            else if (genericDefinition == typeof(INotificationHandler<,>))
            {
                // Notification with parameters: <TRequest, TContext>.
                var genericArguments = interfaceType.GetGenericArguments();
                builder.Add((RequestType: genericArguments[0], ResponseType: null, RequestContextType: genericArguments[1]));
            }
            else if (genericDefinition == typeof(INotificationHandler<>))
            {
                // Parameterless notification: <TContext>.
                var genericArguments = interfaceType.GetGenericArguments();
                builder.Add((RequestType: null, ResponseType: null, RequestContextType: genericArguments[0]));
            }
        }

        if (builder.Count == 0)
        {
            throw new InvalidOperationException($"Provided handler type {handlerType.FullName} does not implement {nameof(IMethodHandler)}");
        }

        return builder.DrainToImmutable();
    }

    /// <summary>
    /// Resolves the LSP method name and the languages for one handler interface implemented by
    /// <paramref name="handlerType"/>.
    /// </summary>
    /// <remarks>
    /// A <see cref="LanguageServerEndpointAttribute"/> on the type, or on any interface it implements,
    /// takes precedence. Only when none is found is the method implementing the handler interface inspected.
    /// </remarks>
    /// <param name="handlerType">The handler type to inspect.</param>
    /// <param name="requestType">The request parameter type, or <see langword="null"/> for parameterless handlers.</param>
    /// <param name="contextType">The request context type.</param>
    /// <param name="responseType">The response type, or <see langword="null"/> for notification handlers.</param>
    /// <returns>The LSP method name and the languages declared by the attribute.</returns>
    /// <exception cref="InvalidOperationException">
    /// No <see cref="LanguageServerEndpointAttribute"/> was found, or no method matching the handler
    /// interface exists on <paramref name="handlerType"/>.
    /// </exception>
    private static (string name, IEnumerable<string> languages) GetRequestHandlerMethod(Type handlerType, Type? requestType, Type contextType, Type? responseType)
    {
        // Get the LSP method name from the handler's method name attribute.
        var methodAttribute = GetMethodAttributeFromClassOrInterface(handlerType);
        if (methodAttribute is null)
        {
            methodAttribute = GetMethodAttributeFromHandlerMethod(handlerType, requestType, contextType, responseType);

            if (methodAttribute is null)
            {
                throw new InvalidOperationException($"{handlerType.FullName} is missing {nameof(LanguageServerEndpointAttribute)}");
            }
        }

        return (methodAttribute.Method, methodAttribute.Languages);

        // Finds the first method on the handler type that implements the handler interface described by the
        // request/context/response types and returns its attribute (null if that method carries none).
        // Throws if no such method exists at all.
        static LanguageServerEndpointAttribute? GetMethodAttributeFromHandlerMethod(Type handlerType, Type? requestType, Type contextType, Type? responseType)
        {
            // The "." + name suffixes match explicit interface implementations, whose metadata names are
            // the interface-qualified method name (for example "Some.Namespace.IRequestHandler<...>.HandleRequestAsync").
            const string HandleRequestName = nameof(IRequestHandler<object, object, object>.HandleRequestAsync);
            const string HandleRequestSuffix = "." + HandleRequestName;
            const string HandleNotificationName = nameof(INotificationHandler<object, object>.HandleNotificationAsync);
            const string HandleNotificationSuffix = "." + HandleNotificationName;

            foreach (var methodInfo in handlerType.GetRuntimeMethods())
            {
                if (MethodInfoMatches(methodInfo))
                {
                    return methodInfo.GetCustomAttribute<LanguageServerEndpointAttribute>();
                }
            }

            throw new InvalidOperationException("Somehow we are missing the method for our registered handler");

            // Selects the expected method name and parameter list from the handler shape:
            // a response type means a request handler, otherwise a notification handler;
            // a request type adds a leading request parameter.
            bool MethodInfoMatches(MethodInfo methodInfo)
            {
                switch (requestType != null, responseType != null)
                {
                    case (true, true):
                        return NameMatches(methodInfo, HandleRequestName, HandleRequestSuffix)
                            && TypesMatch(methodInfo, [requestType, contextType, typeof(CancellationToken)]);

                    case (false, true):
                        // Parameterless request handler: still HandleRequestAsync, so the explicit
                        // implementation suffix must be the request suffix, not the notification one.
                        return NameMatches(methodInfo, HandleRequestName, HandleRequestSuffix)
                            && TypesMatch(methodInfo, [contextType, typeof(CancellationToken)]);

                    case (true, false):
                        return NameMatches(methodInfo, HandleNotificationName, HandleNotificationSuffix)
                            && TypesMatch(methodInfo, [requestType, contextType, typeof(CancellationToken)]);

                    case (false, false):
                        return NameMatches(methodInfo, HandleNotificationName, HandleNotificationSuffix)
                            && TypesMatch(methodInfo, [contextType, typeof(CancellationToken)]);
                }
            }

            // True when the method is named exactly 'name' or its name ends with 'suffix'.
            static bool NameMatches(MethodInfo methodInfo, string name, string suffix)
            {
                var methodName = methodInfo.Name;

                // Metadata names are identifiers, so compare ordinally rather than with the current culture.
                return methodName == name || methodName.EndsWith(suffix, StringComparison.Ordinal);
            }

            // True when the method's parameter types equal 'types', position by position.
            static bool TypesMatch(MethodInfo methodInfo, Type?[] types)
            {
                var parameters = methodInfo.GetParameters();
                if (parameters.Length != types.Length)
                {
                    return false;
                }

                for (int i = 0, n = parameters.Length; i < n; i++)
                {
                    if (!Equals(types[i], parameters[i].ParameterType))
                    {
                        return false;
                    }
                }

                return true;
            }
        }

        // Looks for the attribute on the type itself and, failing that, recursively on each interface
        // the type implements, returning the first one found (or null when there is none).
        static LanguageServerEndpointAttribute? GetMethodAttributeFromClassOrInterface(Type type)
        {
            var attribute = Attribute.GetCustomAttribute(type, typeof(LanguageServerEndpointAttribute)) as LanguageServerEndpointAttribute;

            if (attribute is null)
            {
                var interfaces = type.GetInterfaces();
                foreach (var @interface in interfaces)
                {
                    attribute = GetMethodAttributeFromClassOrInterface(@interface);
                    if (attribute is not null)
                    {
                        break;
                    }
                }
            }

            return attribute;
        }
    }
}