File: Infrastructure\ServerFactory.cs
Web Access
Project: src\src\Identity\test\Identity.FunctionalTests\Microsoft.AspNetCore.Identity.FunctionalTests.csproj (Microsoft.AspNetCore.Identity.FunctionalTests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Identity.DefaultUI.WebSite;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Identity.UI;
using Microsoft.AspNetCore.Mvc.ApplicationParts;
using Microsoft.AspNetCore.Mvc.Testing;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
 
namespace Microsoft.AspNetCore.Identity.FunctionalTests;
 
/// <summary>
/// A <see cref="WebApplicationFactory{TEntryPoint}"/> used by the Identity functional tests.
/// It hands a single, already-open in-memory SQLite connection to the test database setup for
/// <typeparamref name="TContext"/> and restricts the Identity UI application parts to the
/// views assembly matching <see cref="BootstrapFrameworkVersion"/>.
/// </summary>
/// <typeparam name="TStartup">The startup class used to configure the web host.</typeparam>
/// <typeparam name="TContext">The <see cref="DbContext"/> whose database is created for the tests.</typeparam>
public class ServerFactory<TStartup, TContext> : WebApplicationFactory<TStartup>
    where TStartup : class
    where TContext : DbContext
{
    // Opened in the constructor and disposed with the factory. An in-memory SQLite database
    // only exists while its connection is open, so the connection is held for the factory's lifetime.
    private readonly SqliteConnection _connection = new SqliteConnection("DataSource=:memory:");

    /// <summary>
    /// Opens the shared SQLite connection and configures created clients to not follow
    /// redirects and to use <c>https://localhost</c> as their base address.
    /// </summary>
    public ServerFactory()
    {
        _connection.Open();

        ClientOptions.AllowAutoRedirect = false;
        ClientOptions.BaseAddress = new Uri("https://localhost");
    }

    /// <summary>
    /// The Bootstrap flavour of the Identity UI to load. <c>"V4"</c> selects the Bootstrap 4
    /// views; any other value (including the default <c>"V5"</c>) selects Bootstrap 5.
    /// </summary>
    public string BootstrapFrameworkVersion { get; set; } = "V5";

    /// <inheritdoc />
    protected override IHostBuilder CreateHostBuilder()
    {
        // NOTE(review): the flag presumably stops the site's Program from wiring its own startup,
        // since ConfigureWebHost registers TStartup explicitly - confirm against Program.
        return Program.CreateHostBuilder(new[] { "--use-startup=false" });
    }

    /// <inheritdoc />
    protected override void ConfigureWebHost(IWebHostBuilder builder)
    {
        base.ConfigureWebHost(builder);
        builder.UseStartup<TStartup>();

        builder.ConfigureServices(sc =>
        {
            sc.SetupTestDatabase<TContext>(_connection)
                .AddMvc()
                // Mark the cookie as essential for right now, as Identity uses it in
                // several places to pass important data in post-redirect-get flows.
                .AddCookieTempDataProvider(o => o.Cookie.IsEssential = true);
        });

        UpdateApplicationParts(builder);
    }

    // Registered as a service-configuration callback so BootstrapFrameworkVersion is read when
    // the callback runs rather than when ConfigureWebHost is called.
    private void UpdateApplicationParts(IWebHostBuilder builder) =>
        builder.ConfigureServices(services => AddRelatedParts(services, BootstrapFrameworkVersion));

    /// <inheritdoc />
    protected override IHost CreateHost(IHostBuilder builder)
    {
        var host = base.CreateHost(builder);
        EnsureDatabaseCreated(host.Services);
        return host;
    }

    /// <inheritdoc />
    protected override TestServer CreateServer(IWebHostBuilder builder)
    {
        var server = base.CreateServer(builder);
        EnsureDatabaseCreated(server.Host.Services);
        return server;
    }

    /// <summary>
    /// Resolves <typeparamref name="TContext"/> from a new scope and ensures its database exists.
    /// Does nothing when the context is not registered.
    /// </summary>
    /// <param name="services">The application's root service provider.</param>
    public void EnsureDatabaseCreated(IServiceProvider services)
    {
        using var scope = services.CreateScope();
        scope.ServiceProvider.GetService<TContext>()?.Database?.EnsureCreated();
    }

    /// <inheritdoc />
    protected override void Dispose(bool disposing)
    {
        // Only touch the managed connection on an explicit dispose; on a finalizer-driven call
        // (disposing == false) other managed objects must not be accessed.
        if (disposing)
        {
            _connection.Dispose();
        }

        base.Dispose(disposing);
    }

    /// <summary>
    /// Adjusts the MVC application parts so that, of the assemblies related to the Identity UI
    /// assembly, only the parts of the one matching <paramref name="framework"/> remain registered.
    /// </summary>
    /// <param name="services">The service collection being configured.</param>
    /// <param name="framework"><c>"V4"</c> for Bootstrap 4; any other value selects Bootstrap 5.</param>
    private static void AddRelatedParts(IServiceCollection services, string framework)
    {
        var assemblyMap =
            new Dictionary<UIFramework, string>()
            {
                [UIFramework.Bootstrap5] = "Microsoft.AspNetCore.Identity.UI.Views.V5",
                [UIFramework.Bootstrap4] = "Microsoft.AspNetCore.Identity.UI.Views.V4",
            };

        services
            .AddMvc()
            .ConfigureApplicationPartManager(partManager =>
            {
                var thisAssembly = typeof(IdentityBuilderUIExtensions).Assembly;
                var relatedAssemblies = RelatedAssemblyAttribute.GetRelatedAssemblies(thisAssembly, throwOnError: true);

                // Materialize each assembly's parts once; the sequences are enumerated
                // repeatedly below.
                var relatedParts = relatedAssemblies.ToDictionary(
                    ra => ra,
                    ra => CompiledRazorAssemblyApplicationPartFactory.GetDefaultApplicationParts(ra).ToArray());

                var selectedFrameworkAssembly = assemblyMap[framework == "V4" ? UIFramework.Bootstrap4 : UIFramework.Bootstrap5];

                foreach (var kvp in relatedParts)
                {
                    var assemblyName = kvp.Key.GetName().Name;
                    if (!IsAssemblyForFramework(selectedFrameworkAssembly, assemblyName))
                    {
                        RemoveParts(partManager, kvp.Value);
                    }
                    else
                    {
                        AddParts(partManager, kvp.Value);
                    }
                }

                bool IsAssemblyForFramework(string frameworkAssembly, string assemblyName) =>
                    string.Equals(assemblyName, frameworkAssembly, StringComparison.OrdinalIgnoreCase);

                // Removes every registered part whose name matches (ignoring case) one of partsToRemove.
                void RemoveParts(
                    ApplicationPartManager manager,
                    IEnumerable<ApplicationPart> partsToRemove)
                {
                    // Walk backwards so removing an entry does not shift a not-yet-visited
                    // element into the current slot and cause it to be skipped.
                    for (var i = manager.ApplicationParts.Count - 1; i >= 0; i--)
                    {
                        var part = manager.ApplicationParts[i];
                        if (partsToRemove.Any(p => string.Equals(
                                p.Name,
                                part.Name,
                                StringComparison.OrdinalIgnoreCase)))
                        {
                            manager.ApplicationParts.RemoveAt(i);
                        }
                    }
                }

                // Adds each part unless a part of the same type and name (ignoring case) is already registered.
                void AddParts(
                    ApplicationPartManager manager,
                    IEnumerable<ApplicationPart> partsToAdd)
                {
                    foreach (var part in partsToAdd)
                    {
                        if (!manager.ApplicationParts.Any(p => p.GetType() == part.GetType() &&
                            string.Equals(p.Name, part.Name, StringComparison.OrdinalIgnoreCase)))
                        {
                            manager.ApplicationParts.Add(part);
                        }
                    }
                }
            });
    }
}