// File: HeaderPropagationIntegrationTest.cs
// Web Access
// Project: src\src\Middleware\HeaderPropagation\test\Microsoft.AspNetCore.HeaderPropagation.Tests.csproj (Microsoft.AspNetCore.HeaderPropagation.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
 
namespace Microsoft.AspNetCore.HeaderPropagation.Tests;
 
/// <summary>
/// End-to-end tests for header propagation. Each test either starts a TestServer host whose
/// request pipeline issues an outgoing HttpClient call (and then asserts on the headers that call
/// carried), or verifies the error reported when propagation is used without its prerequisites.
/// </summary>
public class HeaderPropagationIntegrationTest
{
    // Message expected when a header-propagating HttpClient is used while
    // HeaderPropagationValues.Headers has not been initialized (no middleware in the pipeline,
    // or no incoming HTTP request at all). Shared so both tests assert the exact same text.
    private const string HeadersNotInitializedMessage =
        "The HeaderPropagationValues.Headers property has not been initialized. Register the header propagation middleware " +
        "by adding 'app.UseHeaderPropagation()' in the 'Configure(...)' method. Header propagation can only be used within " +
        "the context of an HTTP request.";

    // Registers propagation services and a propagating client, but omits app.UseHeaderPropagation();
    // the outgoing call made while handling a request must then fail with InvalidOperationException.
    [Fact]
    public async Task HeaderPropagation_WithoutMiddleware_Throws()
    {
        // Arrange
        Exception captured = null;

        using var host = new HostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                .UseTestServer()
                .ConfigureServices(services =>
                {
                    services.AddHttpClient("test").AddHeaderPropagation();
                    services.AddHeaderPropagation(options =>
                    {
                        options.Headers.Add("X-TraceId");
                    });
                })
                .Configure(app =>
                {
                    // note: no header propagation middleware

                    app.Run(async context =>
                    {
                        try
                        {
                            var client = context.RequestServices.GetRequiredService<IHttpClientFactory>().CreateClient("test");
                            await client.GetAsync("http://localhost/"); // will throw
                        }
                        catch (Exception ex)
                        {
                            // Record the failure instead of letting it escape, so the test can inspect
                            // it while the incoming request itself still completes with 200 OK.
                            captured = ex;
                        }
                    });
                });
            }).Build();

        await host.StartAsync();

        var server = host.GetTestServer();
        using var client = server.CreateClient();

        using var request = new HttpRequestMessage();

        // Act
        using var response = await client.SendAsync(request);

        // Assert
        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        var exception = Assert.IsType<InvalidOperationException>(captured);
        Assert.Equal(HeadersNotInitializedMessage, exception.Message);
    }

    // Uses a propagating client straight from a service provider, with no incoming request at all.
    [Fact]
    public async Task HeaderPropagation_OutsideOfIncomingRequest_Throws()
    {
        // Arrange
        var services = new ServiceCollection();
        services.AddHttpClient("test").AddHeaderPropagation();
        services.AddHeaderPropagation(options =>
        {
            options.Headers.Add("X-TraceId");
        });
        await using var serviceProvider = services.BuildServiceProvider();

        // Act & Assert
        var client = serviceProvider.GetRequiredService<IHttpClientFactory>().CreateClient("test");
        var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => client.GetAsync("http://localhost/"));
        Assert.Equal(HeadersNotInitializedMessage, exception.Message);
    }

    // Incoming header "in" is configured to be propagated under the outbound name "out".
    [Fact]
    public async Task HeaderInRequest_AddCorrectValue()
    {
        // Arrange
        var handler = new SimpleHandler();
        using var host = await CreateHost(c =>
            c.Headers.Add("in", "out"),
            handler);
        var server = host.GetTestServer();
        using var client = server.CreateClient();

        using var request = new HttpRequestMessage();
        request.Headers.Add("in", "test");

        // Act
        using var response = await client.SendAsync(request);

        // Assert
        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        Assert.True(handler.Headers.Contains("out"));
        Assert.Equal(new[] { "test" }, handler.Headers.GetValues("out"));
    }

    // Two configured headers are both copied, by their own names, onto the outgoing request.
    [Fact]
    public async Task MultipleHeaders_HeadersInRequest_AddAllHeaders()
    {
        // Arrange
        var handler = new SimpleHandler();
        using var host = await CreateHost(c =>
            {
                c.Headers.Add("first");
                c.Headers.Add("second");
            },
            handler);
        var server = host.GetTestServer();
        using var client = server.CreateClient();

        using var request = new HttpRequestMessage();
        request.Headers.Add("first", "value");
        request.Headers.Add("second", "other");

        // Act
        using var response = await client.SendAsync(request);

        // Assert
        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        Assert.True(handler.Headers.Contains("first"));
        Assert.Equal(new[] { "value" }, handler.Headers.GetValues("first"));
        Assert.True(handler.Headers.Contains("second"));
        Assert.Equal(new[] { "other" }, handler.Headers.GetValues("second"));
    }

    // app.UseHeaderPropagation() without services.AddHeaderPropagation() must fail at host startup.
    [Fact]
    public async Task Builder_UseHeaderPropagation_Without_AddHeaderPropagation_Throws()
    {
        using var host = new HostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                .UseTestServer()
                .Configure(app =>
                {
                    app.UseHeaderPropagation();
                });
            }).Build();

        var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => host.StartAsync());
        Assert.Equal(
            "Unable to find the required services. Please add all the required services by calling 'IServiceCollection.AddHeaderPropagation' inside the call to 'ConfigureServices(...)' in the application startup code.",
            exception.Message);
    }

    // The per-client options map the captured header "out" to a different outbound name, "different".
    [Fact]
    public async Task HeaderInRequest_OverrideHeaderPerClient_AddCorrectValue()
    {
        // Arrange
        var handler = new SimpleHandler();
        using var host = await CreateHost(
            c => c.Headers.Add("in", "out"),
            handler,
            c => c.Headers.Add("out", "different"));
        var server = host.GetTestServer();
        using var client = server.CreateClient();

        using var request = new HttpRequestMessage();
        request.Headers.Add("in", "test");

        // Act
        using var response = await client.SendAsync(request);

        // Assert
        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        Assert.True(handler.Headers.Contains("different"));
        Assert.Equal(new[] { "test" }, handler.Headers.GetValues("different"));
    }

    /// <summary>
    /// Builds and starts a TestServer host whose pipeline is header-propagation middleware followed
    /// by <see cref="SimpleMiddleware"/>, which calls the "example.com" named client.
    /// </summary>
    /// <param name="configure">Global propagation options passed to AddHeaderPropagation.</param>
    /// <param name="primaryHandler">Primary handler of the "example.com" client; receives the outgoing request.</param>
    /// <param name="configureClient">Optional per-client propagation options; when null the client uses the global options.</param>
    /// <returns>The started host. The caller owns it and must dispose it.</returns>
    private async Task<IHost> CreateHost(Action<HeaderPropagationOptions> configure, HttpMessageHandler primaryHandler, Action<HeaderPropagationMessageHandlerOptions> configureClient = null)
    {
        var host = new HostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                .UseTestServer()
                .Configure(app =>
                {
                    app.UseHeaderPropagation();
                    app.UseMiddleware<SimpleMiddleware>();
                })
                .ConfigureServices(services =>
                {
                    services.AddHeaderPropagation(configure);
                    var clientBuilder = services.AddHttpClient("example.com", c => c.BaseAddress = new Uri("http://example.com"))
                        .ConfigurePrimaryHttpMessageHandler(() => primaryHandler);

                    if (configureClient != null)
                    {
                        clientBuilder.AddHeaderPropagation(configureClient);
                    }
                    else
                    {
                        clientBuilder.AddHeaderPropagation();
                    }
                });
            }).Build();

        try
        {
            await host.StartAsync();
        }
        catch
        {
            // The caller never receives the host on failure, so release it here before rethrowing.
            host.Dispose();
            throw;
        }

        return host;
    }

    /// <summary>
    /// Terminal handler that records the headers of the last request it received and answers with
    /// a default <see cref="HttpResponseMessage"/> without performing any network I/O.
    /// </summary>
    private sealed class SimpleHandler : DelegatingHandler
    {
        /// <summary>Headers of the most recent outgoing request, or null if none was sent yet.</summary>
        public HttpHeaders Headers { get; private set; }

        protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            Headers = request.Headers;
            return Task.FromResult(new HttpResponseMessage());
        }
    }

    /// <summary>
    /// Terminal middleware that, for every incoming request, issues a GET through the "example.com"
    /// named client so header propagation runs inside a real request context.
    /// </summary>
    private sealed class SimpleMiddleware
    {
        private readonly IHttpClientFactory _httpClientFactory;

        // 'next' follows the conventional middleware constructor shape but is never invoked:
        // this middleware ends the pipeline.
        public SimpleMiddleware(RequestDelegate next, IHttpClientFactory httpClientFactory)
        {
            _httpClientFactory = httpClientFactory;
        }

        public async Task InvokeAsync(HttpContext _)
        {
            var client = _httpClientFactory.CreateClient("example.com");

            // Only the side effect (the outgoing request observed by the primary handler) matters;
            // the response is disposed rather than leaked.
            using var response = await client.GetAsync("");
        }
    }
}