File: Generators\StartupHookGeneratorTests.cs
Web Access
Project: src\src\Components\Testing\test\Microsoft.AspNetCore.Components.Testing.Tests.csproj (Microsoft.AspNetCore.Components.Testing.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.AspNetCore.Components.Testing.Generators;
 
namespace Microsoft.AspNetCore.Components.Testing.Tests.Generators;
 
/// <summary>
/// Tests for <see cref="StartupHookGenerator"/>: the three files it always emits
/// (StartupHook.g.cs, HostingStartupAttribute.g.cs, ServiceOverrideResolver.g.cs)
/// and how <c>ConfigureServices</c> callsites become entries in the generated resolver.
/// </summary>
public class StartupHookGeneratorTests
{
    private const string StartupHookFileName = "StartupHook.g.cs";
    private const string HostingStartupAttributeFileName = "HostingStartupAttribute.g.cs";
    private const string ServiceOverrideResolverFileName = "ServiceOverrideResolver.g.cs";

    [Fact]
    public void Generator_EmitsStartupHookSource()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        var source = GetGeneratedSource(result, StartupHookFileName);

        Assert.Contains("internal class StartupHook", source);
        Assert.Contains("public static void Initialize()", source);
        Assert.Contains("AssemblyLoadContext.Default.Resolving", source);
        Assert.Contains("Assembly.GetExecutingAssembly()", source);
    }

    [Fact]
    public void Generator_EmitsHostingStartupAttribute()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        var source = GetGeneratedSource(result, HostingStartupAttributeFileName);

        Assert.Contains("HostingStartup", source);
        Assert.Contains("Microsoft.AspNetCore.Components.Testing.Infrastructure.TestReadinessHostingStartup", source);
    }

    [Fact]
    public void Generator_EmitsThreeFiles()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        // StartupHook.g.cs, HostingStartupAttribute.g.cs, ServiceOverrideResolver.g.cs
        Assert.Equal(3, result.GeneratedTrees.Length);
    }

    [Fact]
    public void Generator_ProducesNoDiagnostics_WhenNoCallsites()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        Assert.Empty(result.Diagnostics);
    }

    [Fact]
    public void Generator_StartupHook_HasAutoGeneratedComment()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        var source = GetGeneratedSource(result, StartupHookFileName);

        Assert.Contains("<auto-generated/>", source);
    }

    [Fact]
    public void Generator_StartupHook_ProbesTestBinDirectory()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        var source = GetGeneratedSource(result, StartupHookFileName);

        Assert.Contains("GetDirectoryName", source);
        Assert.Contains("assemblyName.Name + \".dll\"", source);
        Assert.Contains("LoadFromAssemblyPath", source);
    }

    [Fact]
    public void Generator_HostingStartupAttribute_HasAutoGeneratedComment()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        var source = GetGeneratedSource(result, HostingStartupAttributeFileName);

        Assert.Contains("<auto-generated/>", source);
    }

    [Fact]
    public void Generator_ResultsPerGenerator_HasSingleEntry()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        // Assert.Single returns the sole element, so no separate Results[0] indexing is needed.
        var generatorResult = Assert.Single(result.Results);
        Assert.Equal(3, generatorResult.GeneratedSources.Length);
        Assert.Empty(generatorResult.Diagnostics);
    }

    [Fact]
    public void Generator_EmitsEmptyResolver_WhenNoCallsites()
    {
        var result = RunGenerator(CreateMinimalCompilation());

        var source = GetGeneratedSource(result, ServiceOverrideResolverFileName);

        Assert.Contains("class ServiceOverrideResolver", source);
        Assert.Contains("IE2EServiceOverrideResolver", source);
        Assert.Contains("return null;", source);
        // No switch cases
        Assert.DoesNotContain("return methodName switch", source);
    }

    [Fact]
    public void Generator_DetectsGenericConfigureServices_WithNameof()
    {
        var source = """
            namespace TestApp
            {
                class TestOverrides
                {
                    public static void FakeWeather(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                }

                class MyTest
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>(nameof(TestOverrides.FakeWeather));
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);

        Assert.Contains("\"FakeWeather\" => global::TestApp.TestOverrides.FakeWeather", resolverSource);
        Assert.Empty(result.Diagnostics);
    }

    [Fact]
    public void Generator_DetectsGenericConfigureServices_WithStringLiteral()
    {
        var source = """
            namespace TestApp
            {
                class TestOverrides
                {
                    public static void FakeWeather(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                }

                class MyTest
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>("FakeWeather");
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);

        Assert.Contains("\"FakeWeather\" => global::TestApp.TestOverrides.FakeWeather", resolverSource);
        Assert.Empty(result.Diagnostics);
    }

    [Fact]
    public void Generator_DetectsNonGenericConfigureServices()
    {
        var source = """
            namespace TestApp
            {
                class TestOverrides
                {
                    public static void Configure(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                }

                class MyTest
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices(typeof(TestOverrides), "Configure");
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);

        Assert.Contains("\"Configure\" => global::TestApp.TestOverrides.Configure", resolverSource);
        Assert.Empty(result.Diagnostics);
    }

    [Fact]
    public void Generator_DetectsMultipleOverridesOnSameType()
    {
        var source = """
            namespace TestApp
            {
                class TestOverrides
                {
                    public static void FakeWeather(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                    public static void LockableWeather(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                }

                class MyTest
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>(nameof(TestOverrides.FakeWeather));
                        options.ConfigureServices<TestOverrides>(nameof(TestOverrides.LockableWeather));
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);

        Assert.Contains("\"FakeWeather\" => global::TestApp.TestOverrides.FakeWeather", resolverSource);
        Assert.Contains("\"LockableWeather\" => global::TestApp.TestOverrides.LockableWeather", resolverSource);
        Assert.Empty(result.Diagnostics);
    }

    [Fact]
    public void Generator_SkipsInvalidCallsite_WhenMethodNotFound()
    {
        var source = """
            namespace TestApp
            {
                class TestOverrides
                {
                    // No FakeWeather method
                }

                class MyTest
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>("FakeWeather");
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        // Generator should NOT report diagnostics — that's the analyzer's job
        Assert.Empty(result.Diagnostics);

        // Resolver should be emitted but with no entries
        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);
        Assert.DoesNotContain("return methodName switch", resolverSource);
    }

    [Fact]
    public void Generator_SkipsCallsite_WhenMethodNotStatic()
    {
        var source = """
            namespace TestApp
            {
                class TestOverrides
                {
                    public void FakeWeather(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                }

                class MyTest
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>("FakeWeather");
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        // Generator should NOT report diagnostics
        Assert.Empty(result.Diagnostics);

        // Resolver should be emitted but with no entries
        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);
        Assert.DoesNotContain("return methodName switch", resolverSource);
    }

    [Fact]
    public void Generator_Resolver_MatchesOnTypePrefixAndAssembly()
    {
        var source = """
            namespace TestApp.Overrides
            {
                class TestOverrides
                {
                    public static void FakeWeather(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                }

                class MyTest
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>(nameof(TestOverrides.FakeWeather));
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);

        // Should contain the full type name + assembly for prefix matching
        Assert.Contains("TestApp.Overrides.TestOverrides, TestAssembly,", resolverSource);
    }

    [Fact]
    public void Generator_DeduplicatesIdenticalCallsites()
    {
        var source = """
            namespace TestApp
            {
                class TestOverrides
                {
                    public static void FakeWeather(Microsoft.Extensions.DependencyInjection.IServiceCollection services) { }
                }

                class Test1
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>("FakeWeather");
                    }
                }

                class Test2
                {
                    void Setup()
                    {
                        var options = new Microsoft.AspNetCore.Components.Testing.Infrastructure.ServerStartOptions();
                        options.ConfigureServices<TestOverrides>("FakeWeather");
                    }
                }
            }
            """;

        var result = RunGenerator(CreateCompilationWithInfrastructure(source));

        var resolverSource = GetGeneratedSource(result, ServiceOverrideResolverFileName);

        // Should appear only once despite two callsites
        var count = resolverSource.Split("\"FakeWeather\" =>").Length - 1;
        Assert.Equal(1, count);
    }

    /// <summary>
    /// Runs <see cref="StartupHookGenerator"/> once over <paramref name="compilation"/>
    /// and returns the driver's run result (generated trees, diagnostics, per-generator results).
    /// </summary>
    internal static GeneratorDriverRunResult RunGenerator(CSharpCompilation compilation)
    {
        var generator = new StartupHookGenerator();
        var driver = CSharpGeneratorDriver.Create(generator);
        driver = (CSharpGeneratorDriver)driver.RunGenerators(compilation);
        return driver.GetRunResult();
    }

    /// <summary>
    /// Creates a library compilation named "TestAssembly" containing only an empty
    /// file-scoped namespace and a reference to the core library, i.e. with no
    /// <c>ConfigureServices</c> callsites.
    /// </summary>
    internal static CSharpCompilation CreateMinimalCompilation()
    {
        var syntaxTree = CSharpSyntaxTree.ParseText("namespace TestAssembly;");

        return CSharpCompilation.Create(
            "TestAssembly",
            [syntaxTree],
            [MetadataReference.CreateFromFile(typeof(object).Assembly.Location)],
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
    }

    /// <summary>
    /// Creates a library compilation named "TestAssembly" from <paramref name="userSource"/>
    /// plus stub declarations of the infrastructure types the generator binds against
    /// (<c>IServiceCollection</c>, <c>ServerStartOptions</c>, <c>IE2EServiceOverrideResolver</c>).
    /// </summary>
    /// <param name="userSource">C# source containing the callsites under test.</param>
    internal static CSharpCompilation CreateCompilationWithInfrastructure(string userSource)
    {
        // Provide minimal mock types so the generator can resolve symbols.
        // NOTE(review): the System.Action<T> stub has the same name as the delegate in the
        // referenced core library; the compiler prefers the source definition (warning CS0436).
        var infraSource = """
            namespace System
            {
                public class Action<T> { }
            }
            namespace Microsoft.Extensions.DependencyInjection
            {
                public interface IServiceCollection { }
            }
            namespace Microsoft.AspNetCore.Components.Testing.Infrastructure
            {
                public class ServerStartOptions
                {
                    public void ConfigureServices<T>(string methodName) { }
                    public void ConfigureServices(System.Type type, string methodName) { }
                }
                public interface IE2EServiceOverrideResolver
                {
                    System.Action<Microsoft.Extensions.DependencyInjection.IServiceCollection> TryResolve(
                        string assemblyQualifiedTypeName, string methodName);
                }
            }
            """;

        var trees = new[]
        {
            CSharpSyntaxTree.ParseText(infraSource),
            CSharpSyntaxTree.ParseText(userSource),
        };

        var references = new[]
        {
            MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
        };

        return CSharpCompilation.Create(
            "TestAssembly",
            trees,
            references,
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
    }

    /// <summary>
    /// Returns the text of the single generated tree whose file name is exactly
    /// <paramref name="fileName"/>. Fails the test if no such tree exists, and throws
    /// if more than one tree matches.
    /// </summary>
    private static string GetGeneratedSource(GeneratorDriverRunResult result, string fileName)
    {
        var tree = result.GeneratedTrees.SingleOrDefault(t => IsGeneratedFile(t, fileName));
        Assert.NotNull(tree);
        return tree!.GetText(TestContext.Current.CancellationToken).ToString();
    }

    // Compare the bare file name exactly and ordinally. A culture-sensitive EndsWith suffix
    // match (CA1310) would also accept a differently named file such as "MyStartupHook.g.cs".
    private static bool IsGeneratedFile(SyntaxTree tree, string fileName)
        => string.Equals(System.IO.Path.GetFileName(tree.FilePath), fileName, System.StringComparison.Ordinal);
}