File: StartupHookGenerator.cs
Web Access
Project: src\src\Components\Testing\gen\Microsoft.AspNetCore.Components.Testing.Generators.csproj (Microsoft.AspNetCore.Components.Testing.Generators)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
 
namespace Microsoft.AspNetCore.Components.Testing.Generators;
 
/// <summary>
/// Incremental generator that always emits the <c>StartupHook</c> and hosting-startup
/// attribute sources, and additionally emits a <c>ServiceOverrideResolver</c> that maps
/// every valid <c>ServerStartOptions.ConfigureServices</c> callsite found in the
/// compilation to a direct method-group reference.
/// </summary>
[Generator]
internal class StartupHookGenerator : IIncrementalGenerator
{
    // Display name of the type whose ConfigureServices overloads are tracked.
    private const string ServerStartOptionsTypeName =
        "Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions";

    // Display name of the single parameter type an override method must accept.
    private const string ServiceCollectionTypeName =
        "Microsoft.Extensions.DependencyInjection.IServiceCollection";

    /// <summary>
    /// Registers the static post-initialization sources and the pipeline that collects
    /// ConfigureServices callsites and turns them into the resolver source.
    /// </summary>
    /// <param name="context">The generator initialization context supplied by the compiler.</param>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Static content — always emitted
        context.RegisterPostInitializationOutput(static ctx =>
        {
            ctx.AddSource("StartupHook.g.cs", StartupHookSource);
            ctx.AddSource("HostingStartupAttribute.g.cs", HostingStartupAttributeSource);
        });

        // Detect ConfigureServices callsites and emit a resolver
        var callsites = context.SyntaxProvider.CreateSyntaxProvider(
                predicate: IsConfigureServicesCandidate,
                transform: ExtractCallsite)
            .Where(static x => x is not null);

        var collected = callsites.Collect();

        context.RegisterSourceOutput(collected, EmitServiceOverrideResolver);
    }

    /// <summary>
    /// Cheap syntactic filter: accepts only invocations whose invoked simple name is
    /// <c>ConfigureServices</c>. Semantic verification happens in <see cref="ExtractCallsite"/>.
    /// </summary>
    static bool IsConfigureServicesCandidate(SyntaxNode node, CancellationToken ct)
    {
        return node is InvocationExpressionSyntax invocation
            && GetInvokedMethodName(invocation) == "ConfigureServices";
    }

    /// <summary>
    /// Returns the identifier text of the invoked method for member-access, generic-name
    /// and plain-identifier invocations; returns an empty string for any other shape.
    /// </summary>
    static string GetInvokedMethodName(InvocationExpressionSyntax invocation)
    {
        switch (invocation.Expression)
        {
            case MemberAccessExpressionSyntax memberAccess:
                return memberAccess.Name switch
                {
                    GenericNameSyntax generic => generic.Identifier.Text,
                    IdentifierNameSyntax identifier => identifier.Identifier.Text,
                    _ => ""
                };
            case GenericNameSyntax generic:
                return generic.Identifier.Text;
            case IdentifierNameSyntax identifier:
                return identifier.Identifier.Text;
            default:
                return "";
        }
    }

    /// <summary>
    /// Semantically validates a candidate invocation and, when it is a well-formed
    /// <c>ServerStartOptions.ConfigureServices</c> call with a constant method name that
    /// resolves to a usable static method, returns the data needed to emit a resolver entry.
    /// </summary>
    /// <returns>
    /// The callsite description, or <c>null</c> when the invocation is not a tracked call or
    /// is invalid (invalid callsites are left for the analyzer to report).
    /// </returns>
    static ServiceOverrideCallsite? ExtractCallsite(
        GeneratorSyntaxContext context, CancellationToken ct)
    {
        var invocation = (InvocationExpressionSyntax)context.Node;
        var semanticModel = context.SemanticModel;

        if (semanticModel.GetSymbolInfo(invocation, ct).Symbol is not IMethodSymbol method)
        {
            return null;
        }

        // Verify it's ServerStartOptions.ConfigureServices
        if (method.Name != "ConfigureServices" ||
            method.ContainingType?.ToDisplayString() != ServerStartOptionsTypeName)
        {
            return null;
        }

        // Skip calls within ServerStartOptions itself (the generic overload
        // forwards to the non-generic overload with a non-constant parameter)
        var containingSymbol = semanticModel.GetEnclosingSymbol(invocation.SpanStart, ct);
        if (containingSymbol?.ContainingType?.ToDisplayString() == ServerStartOptionsTypeName)
        {
            return null;
        }

        ITypeSymbol? overrideType = null;
        ExpressionSyntax? methodNameExpr = null;

        if (method.IsGenericMethod && method.TypeArguments.Length == 1)
        {
            // Generic: ConfigureServices<T>("methodName")
            overrideType = method.TypeArguments[0];
            methodNameExpr = FindArgumentExpression(invocation, method, 0);
        }
        else if (!method.IsGenericMethod && method.Parameters.Length == 2)
        {
            // Non-generic: ConfigureServices(typeof(T), "methodName")
            // Arguments are looked up by bound parameter so that named arguments
            // supplied out of order are still interpreted correctly.
            if (FindArgumentExpression(invocation, method, 0) is TypeOfExpressionSyntax typeOfExpr)
            {
                overrideType = semanticModel.GetTypeInfo(typeOfExpr.Type, ct).Type;
            }

            methodNameExpr = FindArgumentExpression(invocation, method, 1);
        }

        if (overrideType is null || methodNameExpr is null)
        {
            return null;
        }

        // Resolve the method name constant (string literal or nameof)
        var constantValue = semanticModel.GetConstantValue(methodNameExpr, ct);
        if (!constantValue.HasValue || constantValue.Value is not string methodName)
        {
            // Non-constant — the analyzer will report E2E002, skip in generator
            return null;
        }

        // Validate the target method exists with the correct signature. The emitted code
        // converts the method group to Action<IServiceCollection>, so the method must be a
        // non-generic static method that returns void; anything else would make the
        // generated source fail to compile.
        var targetMethod = overrideType.GetMembers(methodName)
            .OfType<IMethodSymbol>()
            .FirstOrDefault(m =>
                m.IsStatic &&
                m.ReturnsVoid &&
                !m.IsGenericMethod &&
                m.Parameters.Length == 1 &&
                m.Parameters[0].Type.ToDisplayString() == ServiceCollectionTypeName);

        if (targetMethod is null)
        {
            // Invalid callsite — the analyzer will report E2E001/E2E003, skip in generator
            return null;
        }

        var fullyQualified = overrideType.ToDisplayString(
            SymbolDisplayFormat.FullyQualifiedFormat);
        var fullName = GetReflectionFullName(overrideType);
        var assemblyName = overrideType.ContainingAssembly?.Name ?? "";

        return new ServiceOverrideCallsite(fullyQualified, fullName, assemblyName, methodName);
    }

    /// <summary>
    /// Finds the argument expression bound to the parameter at <paramref name="parameterIndex"/>.
    /// An unnamed argument binds to the parameter at its own position; a named argument binds
    /// to the parameter with the matching name. Returns <c>null</c> when no argument is supplied.
    /// </summary>
    static ExpressionSyntax? FindArgumentExpression(
        InvocationExpressionSyntax invocation, IMethodSymbol method, int parameterIndex)
    {
        if (parameterIndex >= method.Parameters.Length)
        {
            return null;
        }

        var parameterName = method.Parameters[parameterIndex].Name;
        var arguments = invocation.ArgumentList.Arguments;

        for (var i = 0; i < arguments.Count; i++)
        {
            var argument = arguments[i];
            if (argument.NameColon is null)
            {
                if (i == parameterIndex)
                {
                    return argument.Expression;
                }
            }
            else if (argument.NameColon.Name.Identifier.ValueText == parameterName)
            {
                return argument.Expression;
            }
        }

        return null;
    }

    /// <summary>
    /// Builds the type name in reflection style: namespace-qualified, with nested types joined
    /// by <c>+</c> rather than <c>.</c>, so that it can prefix an assembly-qualified type name.
    /// </summary>
    /// <remarks>
    /// NOTE(review): assumes the resolver is called with <c>Type.AssemblyQualifiedName</c> —
    /// confirm against the runtime caller. Constructed generic types carry bracketed type
    /// arguments in their assembly-qualified name and are not specially handled here.
    /// </remarks>
    static string GetReflectionFullName(ITypeSymbol type)
    {
        var name = type.MetadataName;
        for (var outer = type.ContainingType; outer is not null; outer = outer.ContainingType)
        {
            name = outer.MetadataName + "+" + name;
        }

        var containingNamespace = type.ContainingNamespace;
        return containingNamespace is null || containingNamespace.IsGlobalNamespace
            ? name
            : containingNamespace.ToDisplayString() + "." + name;
    }

    /// <summary>
    /// Emits <c>ServiceOverrideResolver.g.cs</c>. The resolver contains one prefix check per
    /// override type and, inside it, a switch mapping each known method name to its method
    /// group. The file is emitted even when no callsites were found; it then always returns null.
    /// </summary>
    static void EmitServiceOverrideResolver(
        SourceProductionContext context,
        ImmutableArray<ServiceOverrideCallsite?> callsites)
    {
        var distinct = callsites
            .Where(c => c is not null)
            .Cast<ServiceOverrideCallsite>()
            .Distinct()
            .ToList();

        var sb = new StringBuilder();
        sb.AppendLine("// <auto-generated/>");
        sb.AppendLine("#nullable enable");
        sb.AppendLine();
        sb.AppendLine("namespace Microsoft.AspNetCore.Components.Testing.Generated;");
        sb.AppendLine();
        sb.AppendLine("internal sealed class ServiceOverrideResolver : global::Microsoft.AspNetCore.Components.Testing.Infrastructure.IE2EServiceOverrideResolver");
        sb.AppendLine("{");
        sb.AppendLine("    public global::System.Action<global::Microsoft.Extensions.DependencyInjection.IServiceCollection>? TryResolve(");
        sb.AppendLine("        string assemblyQualifiedTypeName, string methodName)");
        sb.AppendLine("    {");

        // Group callsites by the override type; an empty list simply yields no groups.
        var groups = distinct
            .GroupBy(c => new { c.TypeFullyQualifiedName, c.TypeFullName, c.AssemblyName })
            .ToList();

        foreach (var group in groups)
        {
            // The trailing comma keeps "Foo, Asm" from matching "Foo, AsmOther, ...".
            var typePrefix = group.Key.TypeFullName + ", " + group.Key.AssemblyName + ",";
            sb.AppendLine($"        if (assemblyQualifiedTypeName.StartsWith(\"{typePrefix}\", global::System.StringComparison.Ordinal))");
            sb.AppendLine("        {");
            sb.AppendLine("            return methodName switch");
            sb.AppendLine("            {");

            foreach (var callsite in group)
            {
                sb.AppendLine($"                \"{callsite.MethodName}\" => {callsite.TypeFullyQualifiedName}.{callsite.MethodName},");
            }

            sb.AppendLine("                _ => null,");
            sb.AppendLine("            };");
            sb.AppendLine("        }");
            sb.AppendLine();
        }

        sb.AppendLine("        return null;");
        sb.AppendLine("    }");
        sb.AppendLine("}");

        context.AddSource("ServiceOverrideResolver.g.cs", sb.ToString());
    }

    // Source of the StartupHook type added to every compilation. Its Initialize method
    // registers an assembly resolver that probes the directory containing the hook assembly.
    private const string StartupHookSource = """
        // <auto-generated/>
        using System;
        using System.IO;
        using System.Reflection;
        using System.Runtime.Loader;

        internal class StartupHook
        {
            public static void Initialize()
            {
                var startupHookAssembly = Assembly.GetExecutingAssembly();
                var testBinDirectory = Path.GetDirectoryName(startupHookAssembly.Location);

                AssemblyLoadContext.Default.Resolving += (context, assemblyName) =>
                {
                    var candidatePath = Path.Combine(testBinDirectory, assemblyName.Name + ".dll");
                    if (File.Exists(candidatePath))
                    {
                        return context.LoadFromAssemblyPath(candidatePath);
                    }
                    return null;
                };
            }
        }
        """;

    // Assembly-level attribute source that registers TestReadinessHostingStartup as a hosting startup.
    private const string HostingStartupAttributeSource = """
        // <auto-generated/>
        [assembly: Microsoft.AspNetCore.Hosting.HostingStartup(typeof(Microsoft.AspNetCore.Components.Testing.Infrastructure.TestReadinessHostingStartup))]
        """;
}