File: UseWhenExtensionsTests.cs
Web Access
Project: src\src\Http\Http.Abstractions\test\Microsoft.AspNetCore.Http.Abstractions.Tests.csproj (Microsoft.AspNetCore.Http.Abstractions.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Microsoft.AspNetCore.Http;
 
namespace Microsoft.AspNetCore.Builder.Extensions;
 
/// <summary>
/// Tests for the <c>UseWhen</c> extension on <see cref="ApplicationBuilder"/>.
/// Each test builds a pipeline, runs it once against a fresh context, and checks the
/// per-middleware counters recorded in <see cref="HttpContext.Items"/>.
/// </summary>
public class UseWhenExtensionsTests
{
    /// <summary>
    /// Verifies that <c>UseWhen</c> throws <see cref="ArgumentNullException"/> for a null
    /// predicate and for a null configuration delegate.
    /// </summary>
    [Fact]
    public void NullArguments_ArgumentNullException()
    {
        // Arrange
        var builder = CreateBuilder();

        // Act
        Action nullPredicate = () => builder.UseWhen(null!, app => { });
        Action nullConfiguration = () => builder.UseWhen(TruePredicate, null!);

        // Assert
        Assert.Throws<ArgumentNullException>(nullPredicate);
        Assert.Throws<ArgumentNullException>(nullConfiguration);
    }

    /// <summary>
    /// When every predicate is true, the nested branches each run once. Control then
    /// returns to the enclosing pipeline, so the parent's later middleware also runs once.
    /// </summary>
    [Fact]
    public async Task PredicateTrue_BranchTaken_WillRejoin()
    {
        // Arrange
        var context = CreateContext();
        var parent = CreateBuilder();

        parent.UseWhen(TruePredicate, child =>
        {
            child.UseWhen(TruePredicate, grandchild =>
            {
                grandchild.Use(Increment("grandchild"));
            });

            child.Use(Increment("child"));
        });

        parent.Use(Increment("parent"));

        // Act
        await parent.Build().Invoke(context);

        // Assert
        Assert.Equal(1, Count(context, "parent"));
        Assert.Equal(1, Count(context, "child"));
        Assert.Equal(1, Count(context, "grandchild"));
    }

    /// <summary>
    /// A middleware inside a nested branch that does not call <c>next</c> short-circuits
    /// the pipeline. Neither the child's nor the parent's later middleware runs.
    /// </summary>
    [Fact]
    public async Task PredicateTrue_BranchTaken_CanTerminate()
    {
        // Arrange
        var context = CreateContext();
        var parent = CreateBuilder();

        parent.UseWhen(TruePredicate, child =>
        {
            child.UseWhen(TruePredicate, grandchild =>
            {
                grandchild.Use(Increment("grandchild", terminate: true));
            });

            child.Use(Increment("child"));
        });

        parent.Use(Increment("parent"));

        // Act
        await parent.Build().Invoke(context);

        // Assert
        Assert.Equal(0, Count(context, "parent"));
        Assert.Equal(0, Count(context, "child"));
        Assert.Equal(1, Count(context, "grandchild"));
    }

    /// <summary>
    /// When the predicate is false, the branch's middleware is skipped and the parent
    /// pipeline continues as normal.
    /// </summary>
    [Fact]
    public async Task PredicateFalse_PassThrough()
    {
        // Arrange
        var context = CreateContext();
        var parent = CreateBuilder();

        parent.UseWhen(FalsePredicate, child =>
        {
            child.Use(Increment("child"));
        });

        parent.Use(Increment("parent"));

        // Act
        await parent.Build().Invoke(context);

        // Assert
        Assert.Equal(1, Count(context, "parent"));
        Assert.Equal(0, Count(context, "child"));
    }

    /// <summary>
    /// Creates a fresh request context. Its <see cref="HttpContext.Items"/> bag holds the
    /// counters written by <see cref="Increment"/>.
    /// </summary>
    private static HttpContext CreateContext()
    {
        return new DefaultHttpContext();
    }

    /// <summary>
    /// Creates an application builder without a service provider. None of the middleware
    /// registered by these tests resolves services.
    /// </summary>
    private static ApplicationBuilder CreateBuilder()
    {
        return new ApplicationBuilder(serviceProvider: null!);
    }

    /// <summary>A predicate that always selects the branch.</summary>
    private static bool TruePredicate(HttpContext context)
    {
        return true;
    }

    /// <summary>A predicate that never selects the branch.</summary>
    private static bool FalsePredicate(HttpContext context)
    {
        return false;
    }

    /// <summary>
    /// Creates inline middleware that increments the integer counter stored under
    /// <paramref name="key"/> in <see cref="HttpContext.Items"/>.
    /// </summary>
    /// <param name="key">The <see cref="HttpContext.Items"/> key used as the counter.</param>
    /// <param name="terminate">
    /// When <see langword="true"/>, the middleware completes without invoking <c>next</c>,
    /// which short-circuits the rest of the pipeline.
    /// </param>
    /// <returns>A middleware delegate suitable for <c>IApplicationBuilder.Use</c>.</returns>
    private static Func<HttpContext, Func<Task>, Task> Increment(string key, bool terminate = false)
    {
        return (context, next) =>
        {
            // A missing entry, or one holding a non-int value, restarts the count at 1.
            // TryGetValue avoids the ContainsKey + indexer double lookup.
            context.Items[key] = context.Items.TryGetValue(key, out var existing) && existing is int current
                ? current + 1
                : 1;

            return terminate ? Task.CompletedTask : next();
        };
    }

    /// <summary>
    /// Reads the counter written by <see cref="Increment"/>.
    /// </summary>
    /// <returns>
    /// The stored count, or 0 when <paramref name="key"/> is absent or holds a
    /// non-<see cref="int"/> value.
    /// </returns>
    private static int Count(HttpContext context, string key)
    {
        return context.Items.TryGetValue(key, out var item) && item is int count ? count : 0;
    }
}