File: Resilience\HttpClientBuilderExtensionsTests.Resilience.cs
Web Access
Project: src\test\Libraries\Microsoft.Extensions.Http.Resilience.Tests\Microsoft.Extensions.Http.Resilience.Tests.csproj (Microsoft.Extensions.Http.Resilience.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Diagnostics.Metrics.Testing;
using Microsoft.Extensions.Http.Resilience.Internal;
using Microsoft.Extensions.Http.Resilience.Test.Helpers;
using Microsoft.Extensions.Options;
using Moq;
using Polly;
using Polly.Registry;
using Polly.Telemetry;
using Xunit;
 
namespace Microsoft.Extensions.Http.Resilience.Test;
 
public sealed partial class HttpClientBuilderExtensionsTests
{
    [Fact]
    public void AddResilienceHandler_ArgumentValidation()
    {
        var services = new ServiceCollection();
        IHttpClientBuilder? builder = services.AddHttpClient("client");
 
        Assert.Throws<ArgumentNullException>(() => builder.AddResilienceHandler(null!, _ => { }));
        Assert.Throws<ArgumentException>(() => builder.AddResilienceHandler(string.Empty, _ => { }));
        Assert.Throws<ArgumentNullException>(() => builder.AddResilienceHandler(null!, (_, _) => { }));
        Assert.Throws<ArgumentException>(() => builder.AddResilienceHandler(string.Empty, (_, _) => { }));
        Assert.Throws<ArgumentNullException>(() => builder.AddResilienceHandler("dummy", (Action<ResiliencePipelineBuilder<HttpResponseMessage>>)null!));
        Assert.Throws<ArgumentNullException>(() => builder.AddResilienceHandler("dummy", (Action<ResiliencePipelineBuilder<HttpResponseMessage>, ResilienceHandlerContext>)null!));
 
        builder = null;
        Assert.Throws<ArgumentNullException>(() => builder!.AddResilienceHandler("pipeline-name", _ => { }));
        Assert.Throws<ArgumentNullException>(() => builder!.AddResilienceHandler("pipeline-name", (_, _) => { }));
    }
 
    [Fact]
    public void AddResilienceHandler_EnsureCorrectServicesRegistered()
    {
        var services = new ServiceCollection();
        IHttpClientBuilder? builder = services.AddHttpClient("client");
 
        builder.AddResilienceHandler("test", ConfigureBuilder);
 
        // add twice intentionally
        builder.AddResilienceHandler("test", ConfigureBuilder);
 
        Assert.Contains(services, s => s.ServiceType == typeof(ResiliencePipelineProvider<HttpKey>));
    }
 
    [Theory]
#if NET6_0_OR_GREATER
    [CombinatorialData]
#else
    [InlineData(true)]
#endif
    public async Task AddResilienceHandler_OnPipelineDisposed_EnsureCalled(bool asynchronous = true)
    {
        var onPipelineDisposedCalled = false;
        var services = new ServiceCollection();
        IHttpClientBuilder? builder = services
            .AddHttpClient("client")
            .ConfigurePrimaryHttpMessageHandler(() => new TestHandlerStub(HttpStatusCode.OK));
 
        builder.AddResilienceHandler("test", (builder, context) =>
        {
            builder.AddTimeout(TimeSpan.FromSeconds(1));
            context.OnPipelineDisposed(() => onPipelineDisposedCalled = true);
        });
 
        using (var serviceProvider = services.BuildServiceProvider())
        {
            var client = serviceProvider.GetRequiredService<IHttpClientFactory>().CreateClient("client");
            await SendRequest(client, "https://dummy", asynchronous);
        }
 
        onPipelineDisposedCalled.Should().BeTrue();
    }
 
    [Fact]
    public void AddResilienceHandler_EnsureServicesNotAddedTwice()
    {
        var services = new ServiceCollection();
        IHttpClientBuilder? builder = services.AddHttpClient("client");
 
        builder.AddResilienceHandler("test", ConfigureBuilder);
        var count = builder.Services.Count;
 
        // add twice intentionally
        builder.AddResilienceHandler("test", ConfigureBuilder);
 
        // We check that the count of existing services is not unnecessary increased.
        //
        // The additional 3 services that are registered are related to:
        // - Configuration of HTTP client options
        // - Configuration of resilience pipeline
        // - Registration of keyed service for resilience pipeline
        // UPDATE NOTE: Starting from .NET 8.0.2, the count of additional services is 2 instead of 3. This is due to the fact that the registration of the resilience
        // pipeline is now done in the `AddResilienceHandler` method.
        builder.Services.Should().HaveCount(count + 2);
    }
 
    [Theory]
#if NET6_0_OR_GREATER
    [CombinatorialData]
#else
    [InlineData(true)]
#endif
    public async Task AddResilienceHandler_EnsureErrorType(bool asynchronous = true)
    {
        using var metricCollector = new MetricCollector<int>(null, "Polly", "resilience.polly.strategy.events");
        var enricher = new TestMetricsEnricher();
        var clientBuilder = new ServiceCollection()
            .AddHttpClient("client")
            .ConfigurePrimaryHttpMessageHandler(() => new TestHandlerStub(HttpStatusCode.InternalServerError))
            .AddStandardResilienceHandler()
            .Configure(options =>
            {
                options.Retry.ShouldHandle = _ => PredicateResult.True();
                options.Retry.MaxRetryAttempts = 1;
                options.Retry.Delay = TimeSpan.Zero;
            });
 
        clientBuilder.Services.Configure<TelemetryOptions>(o => o.MeteringEnrichers.Add(enricher));
 
        var client = clientBuilder.Services.BuildServiceProvider().GetRequiredService<IHttpClientFactory>().CreateClient("client");
 
        using var response = await SendRequest(client, "https://dummy", asynchronous);
 
        enricher.Tags["error.type"].Should().BeOfType<string>().Subject.Should().Be("500");
    }
 
    [Theory]
#if NET6_0_OR_GREATER
    [CombinatorialData]
#else
    [InlineData(true)]
#endif
    public async Task AddResilienceHandler_EnsureResilienceHandlerContext(bool asynchronous = true)
    {
        var verified = false;
        _builder
            .AddResilienceHandler("test", (_, context) =>
            {
                context.ServiceProvider.Should().NotBeNull();
                context.BuilderName.Should().Be($"{BuilderName}-test");
                context.InstanceName.Should().Be("dummy-key");
                verified = true;
            })
            .SelectPipelineBy(_ => _ => "dummy-key");
 
        _builder.AddHttpMessageHandler(() => new TestHandlerStub(HttpStatusCode.InternalServerError));
 
        var client = CreateClient(BuilderName);
        await SendRequest(client, "https://dummy", asynchronous);
 
        verified.Should().BeTrue();
    }
 
    [Fact]
    public void AddResilienceHandler_EnsureCorrectRegistryOptions()
    {
        var services = new ServiceCollection();
        IHttpClientBuilder? builder = services.AddHttpClient("client");
        builder.AddResilienceHandler("test", ConfigureBuilder);
 
        using var serviceProvider = builder.Services.BuildServiceProvider();
        var registryOptions = serviceProvider.GetRequiredService<IOptions<ResiliencePipelineRegistryOptions<HttpKey>>>().Value;
        registryOptions.BuilderComparer.Equals(new HttpKey("A", "1"), new HttpKey("A", "2")).Should().BeTrue();
        registryOptions.BuilderComparer.Equals(new HttpKey("A", "1"), new HttpKey("B", "1")).Should().BeFalse();
 
        registryOptions.PipelineComparer.Equals(new HttpKey("A", "1"), new HttpKey("A", "1")).Should().BeTrue();
        registryOptions.PipelineComparer.Equals(new HttpKey("A", "1"), new HttpKey("A", "2")).Should().BeFalse();
 
        registryOptions.BuilderNameFormatter(new HttpKey("A", "1")).Should().Be("A");
        registryOptions.InstanceNameFormatter!(new HttpKey("A", "1")).Should().Be("1");
    }
 
    public enum PolicyType
    {
        Fallback,
        Retry,
        CircuitBreaker,
    }
 
    [Theory]
#if NET6_0_OR_GREATER
    [CombinatorialData]
#else
    [InlineData(true)]
    [InlineData(false)]
#endif
    public async Task AddResilienceHandler_EnsureProperPipelineInstanceRetrieved(bool bySelector, bool asynchronous = true)
    {
        // arrange
        var resilienceProvider = new Mock<ResiliencePipelineProvider<HttpKey>>(MockBehavior.Strict);
        var services = new ServiceCollection().AddLogging().AddMetrics().AddFakeRedaction();
        services.AddSingleton(resilienceProvider.Object);
        var builder = services.AddHttpClient("client");
        var pipelineBuilder = builder.AddResilienceHandler("dummy", ConfigureBuilder);
        var expectedPipelineKey = "client-dummy";
        if (bySelector)
        {
            pipelineBuilder.SelectPipelineByAuthority();
        }
 
        builder.AddHttpMessageHandler(() => new TestHandlerStub(HttpStatusCode.OK));
 
        using var provider = services.BuildServiceProvider();
        if (bySelector)
        {
            resilienceProvider
                .Setup(v => v.GetPipeline<HttpResponseMessage>(new HttpKey(expectedPipelineKey, "https://dummy1")))
                .Returns(ResiliencePipeline<HttpResponseMessage>.Empty);
        }
        else
        {
            resilienceProvider
                .Setup(v => v.GetPipeline<HttpResponseMessage>(new HttpKey(expectedPipelineKey, string.Empty)))
                .Returns(ResiliencePipeline<HttpResponseMessage>.Empty);
        }
 
        var client = provider.GetRequiredService<IHttpClientFactory>().CreateClient("client");
 
        // act
        await SendRequest(client, "https://dummy1", asynchronous);
 
        // assert
        resilienceProvider.VerifyAll();
    }
 
    [Theory]
#if NET6_0_OR_GREATER
    [CombinatorialData]
#else
    [InlineData(true)]
#endif
    public async Task AddResilienceHandlerBySelector_EnsureResiliencePipelineProviderCalled(bool asynchronous = true)
    {
        // arrange
        var services = new ServiceCollection().AddLogging().AddMetrics();
        var providerMock = new Mock<ResiliencePipelineProvider<HttpKey>>(MockBehavior.Strict);
 
        services.AddSingleton(providerMock.Object);
        var pipelineName = string.Empty;
 
        pipelineName = "client-my-pipeline";
        var clientBuilder = services.AddHttpClient("client");
        clientBuilder.AddResilienceHandler("my-pipeline", ConfigureBuilder);
        clientBuilder.AddHttpMessageHandler(() => new TestHandlerStub(HttpStatusCode.OK));
 
        providerMock
            .Setup(v => v.GetPipeline<HttpResponseMessage>(new HttpKey(pipelineName, string.Empty)))
            .Returns(ResiliencePipeline<HttpResponseMessage>.Empty)
            .Verifiable();
 
        using var provider = services.BuildServiceProvider();
        var client = provider.GetRequiredService<IHttpClientFactory>().CreateClient("client");
        var pipelineProvider = provider.GetRequiredService<ResiliencePipelineProvider<HttpKey>>();
 
        // act
        await SendRequest(client, "https://dummy1", asynchronous);
 
        // assert
        providerMock.VerifyAll();
    }
 
    [Fact]
    public void AddResilienceHandler_AuthorityByCustomSelector_NotValidated()
    {
        // arrange
        var clientBuilder = new ServiceCollection().AddLogging().AddMetrics().AddRedaction()
            .AddHttpClient("my-client")
            .AddResilienceHandler("my-pipeline", ConfigureBuilder)
            .SelectPipelineBy(_ => _ => string.Empty);
 
        using var serviceProvider = clientBuilder.Services.BuildServiceProvider();
        var factory = serviceProvider.GetRequiredService<IHttpClientFactory>();
 
        Assert.NotNull(factory.CreateClient("my-client"));
    }
 
    private void ConfigureBuilder(ResiliencePipelineBuilder<HttpResponseMessage> builder) => builder.AddTimeout(TimeSpan.FromSeconds(1));
 
    private class TestMetricsEnricher : MeteringEnricher
    {
        public Dictionary<string, object?> Tags { get; } = [];
 
        public override void Enrich<TResult, TArgs>(in EnrichmentContext<TResult, TArgs> context)
        {
            foreach (var tag in context.Tags)
            {
                Tags[tag.Key] = tag.Value;
            }
        }
    }
}