File: HeaderPropagationMiddleware.cs
Web Access
Project: src\src\Middleware\HeaderPropagation\src\Microsoft.AspNetCore.HeaderPropagation.csproj (Microsoft.AspNetCore.HeaderPropagation)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Net.Http;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
 
namespace Microsoft.AspNetCore.HeaderPropagation;
 
/// <summary>
/// Middleware that captures selected inbound request headers so they can later be forwarded
/// on outgoing <see cref="HttpClient"/> requests.
/// </summary>
public class HeaderPropagationMiddleware
{
    private readonly RequestDelegate _next;
    private readonly HeaderPropagationOptions _options;
    private readonly HeaderPropagationValues _values;

    /// <summary>
    /// Creates a new <see cref="HeaderPropagationMiddleware"/>.
    /// </summary>
    /// <param name="next">The delegate representing the remainder of the request pipeline.</param>
    /// <param name="options">
    /// Supplies the <see cref="HeaderPropagationOptions"/> that list which headers to capture.
    /// </param>
    /// <param name="values">
    /// The <see cref="HeaderPropagationValues"/> instance that receives the captured headers and keeps them
    /// in an <see cref="System.Threading.AsyncLocal{T}"/>.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="next"/>, <paramref name="options"/> or <paramref name="values"/> is <see langword="null"/>.
    /// </exception>
    public HeaderPropagationMiddleware(RequestDelegate next, IOptions<HeaderPropagationOptions> options, HeaderPropagationValues values)
    {
        ArgumentNullException.ThrowIfNull(next);
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(values);

        // The options snapshot is read once here and reused for every request.
        _values = values;
        _options = options.Value;
        _next = next;
    }

    /// <summary>
    /// Copies the configured request headers into <see cref="HeaderPropagationValues"/> and then runs the
    /// rest of the pipeline.
    /// </summary>
    /// <param name="context">The <see cref="HttpContext"/> of the request being processed.</param>
    /// <returns>The <see cref="Task"/> returned by the next middleware.</returns>
    public Task Invoke(HttpContext context)
    {
        // The dictionary is created even if nothing ends up being captured. The outgoing message handler
        // checks for its presence to detect a missing middleware registration. Header names are compared
        // case-insensitively.
        var captured = _values.Headers ??= new Dictionary<string, StringValues>(StringComparer.OrdinalIgnoreCase);

        // An indexed loop is used instead of foreach, because the collection does not expose a struct
        // enumerator.
        var configured = _options.Headers;
        for (var index = 0; index < configured.Count; index++)
        {
            var entry = configured[index];

            // Entries are processed in order. When two entries share an output name, the first one wins,
            // and the later one's filter is never invoked.
            if (captured.ContainsKey(entry.CapturedHeaderName))
            {
                continue;
            }

            var resolved = GetValue(context, entry);
            if (StringValues.IsNullOrEmpty(resolved))
            {
                continue;
            }

            captured.Add(entry.CapturedHeaderName, resolved);
        }

        return _next(context);
    }

    /// <summary>
    /// Produces the value to capture for <paramref name="entry"/>. If no value filter is configured, this is the
    /// raw inbound header value (empty when the header is absent). Otherwise it is whatever the filter returns
    /// when given that raw value.
    /// </summary>
    private static StringValues GetValue(HttpContext context, HeaderPropagationEntry entry)
    {
        // On a miss, TryGetValue leaves 'inbound' at its default (empty) value. The filter still runs in that
        // case, so it can supply a value for a missing header.
        context.Request.Headers.TryGetValue(entry.InboundHeaderName, out var inbound);

        var filter = entry.ValueFilter;
        return filter is null
            ? inbound
            : filter(new HeaderPropagationContext(context, entry.InboundHeaderName, inbound));
    }
}