// File: ForwardedHeadersMiddleware.cs
// Project: src\src\Middleware\HttpOverrides\src\Microsoft.AspNetCore.HttpOverrides.csproj (Microsoft.AspNetCore.HttpOverrides)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Linq;
using System.Net;
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
 
namespace Microsoft.AspNetCore.HttpOverrides;
 
/// <summary>
/// A middleware for forwarding proxied headers onto the current request.
/// </summary>
/// <remarks>
/// For each header type enabled in <see cref="ForwardedHeadersOptions.ForwardedHeaders"/>, the configured
/// forwarded-for, -proto, -host and -prefix headers are read and validated. The accepted values then replace the
/// connection's remote IP address and port, <see cref="HttpRequest.Scheme"/>, <see cref="HttpRequest.Host"/> and
/// <see cref="HttpRequest.PathBase"/> respectively. The replaced values are saved in the configured "original"
/// headers (the remote address and path base only when they had a value). The consumed entries are removed
/// from the forwarded headers.
/// </remarks>
public class ForwardedHeadersMiddleware
{
    private readonly ForwardedHeadersOptions _options;
    private readonly RequestDelegate _next;
    private readonly ILogger _logger;
    // True when any forwarded host is accepted without consulting an allow-list. This happens when no
    // AllowedHosts were configured or one of them is a top-level wildcard. Set once by PreProcessHosts.
    private bool _allowAllHosts;
    // De-duplicated (ordinal, case-insensitive) allowed hosts in URI-component (punycode) form.
    // Left null when _allowAllHosts is true.
    private IList<StringSegment>? _allowedHosts;

    // RFC 3986: scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
    // NOTE(review): this set is only used to check that every character of a forwarded scheme is allowed.
    // The leading-ALPHA rule is not enforced, so e.g. "1http" is accepted.
    private static readonly SearchValues<char> SchemeChars =
        SearchValues.Create("+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");

    // Characters allowed in a forwarded host name; intended to match what Http.Sys and Kestrel accept.
    // This is the RFC 3986 reg-name set minus "*" / "+" / "," / ";" / "=" and percent-encoding ("%" HEXDIG HEXDIG),
    // which Http.Sys does not allow. ':' is not included, so the first character outside this set is where a
    // port must begin (see TryValidateHost).
    private static readonly SearchValues<char> HostChars =
        SearchValues.Create("!$&'()-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~");

    // Characters allowed between the brackets of an IPv6 literal host: 0-9 / A-F / a-f / ":" / "."
    // ("." allows an embedded dotted IPv4 tail). '%' is not included, so zone IDs are rejected.
    private static readonly SearchValues<char> Ipv6HostChars =
        SearchValues.Create(".0123456789:ABCDEFabcdef");

    /// <summary>
    /// Create a new <see cref="ForwardedHeadersMiddleware"/>.
    /// </summary>
    /// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
    /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> used for logging.</param>
    /// <param name="options">The <see cref="ForwardedHeadersOptions"/> for configuring the middleware.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="next"/>, <paramref name="loggerFactory"/> or <paramref name="options"/> is null,
    /// or one of the configured header names is null.
    /// </exception>
    /// <exception cref="ArgumentException">One of the configured header names is empty or white space.</exception>
    public ForwardedHeadersMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IOptions<ForwardedHeadersOptions> options)
    {
        ArgumentNullException.ThrowIfNull(next);
        ArgumentNullException.ThrowIfNull(loggerFactory);
        ArgumentNullException.ThrowIfNull(options);

        // Make sure every configured header-name option is not null, empty or white space; each is later
        // used as a request header key.
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.ForwardedForHeaderName);
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.ForwardedHostHeaderName);
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.ForwardedProtoHeaderName);
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.ForwardedPrefixHeaderName);
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.OriginalForHeaderName);
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.OriginalHostHeaderName);
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.OriginalProtoHeaderName);
        ArgumentException.ThrowIfNullOrWhiteSpace(options.Value.OriginalPrefixHeaderName);

        _options = options.Value;
        _logger = loggerFactory.CreateLogger<ForwardedHeadersMiddleware>();
        _next = next;

        // Prepare the host allow-list once, so per-request filtering only matches against it.
        PreProcessHosts();
    }

    // Builds _allowAllHosts / _allowedHosts from ForwardedHeadersOptions.AllowedHosts.
    // Filtering is disabled when the list is null or empty, or when any entry is a top-level wildcard.
    private void PreProcessHosts()
    {
        if (_options.AllowedHosts == null || _options.AllowedHosts.Count == 0)
        {
            // No allow-list configured: any syntactically valid forwarded host is accepted.
            _allowAllHosts = true;
            return;
        }

        var allowedHosts = new List<StringSegment>();
        foreach (var entry in _options.AllowedHosts)
        {
            // Punycode. Http.Sys requires you to register Unicode hosts, but the headers contain punycode.
            var host = new HostString(entry).ToUriComponent();

            if (IsTopLevelWildcard(host))
            {
                // Disable filtering
                _allowAllHosts = true;
                return;
            }

            // Skip duplicates, compared ordinally and case-insensitively.
            if (!allowedHosts.Contains(host, StringSegmentComparer.OrdinalIgnoreCase))
            {
                allowedHosts.Add(host);
            }
        }

        _allowedHosts = allowedHosts;
    }

    // Returns true for host entries that mean "any host": "*" (Http.Sys), "[::]" (Kestrel, IPv6 Any)
    // or "0.0.0.0" (IPv4 Any). The comparison is ordinal, so only these exact spellings are recognized.
    private static bool IsTopLevelWildcard(string host)
    {
        return (string.Equals("*", host, StringComparison.Ordinal) // HttpSys wildcard
                       || string.Equals("[::]", host, StringComparison.Ordinal) // Kestrel wildcard, IPv6 Any
                       || string.Equals("0.0.0.0", host, StringComparison.Ordinal)); // IPv4 Any
    }

    /// <summary>
    /// Executes the middleware.
    /// </summary>
    /// <remarks>
    /// Applies the forwarded headers synchronously via <see cref="ApplyForwarders(HttpContext)"/>, then invokes
    /// the next middleware.
    /// </remarks>
    /// <param name="context">The <see cref="HttpContext"/> for the current request.</param>
    /// <returns>The <see cref="Task"/> returned by the next middleware in the pipeline.</returns>
    public Task Invoke(HttpContext context)
    {
        ApplyForwarders(context);
        return _next(context);
    }

    /// <summary>
    /// Forward the proxied headers to the given <see cref="HttpContext"/>.
    /// </summary>
    /// <remarks>
    /// Header values are processed right to left (last value first), up to
    /// <see cref="ForwardedHeadersOptions.ForwardLimit"/> entries when a limit is set.
    /// <para>
    /// Processing stops at the first unparsable forwarded-for entry. It also stops at the first unknown
    /// proxy when <see cref="ForwardedHeadersOptions.KnownProxies"/> or
    /// <see cref="ForwardedHeadersOptions.KnownNetworks"/> are configured. In both cases, entries accepted
    /// before that point are still applied.
    /// </para>
    /// <para>
    /// When <see cref="ForwardedHeadersOptions.RequireHeaderSymmetry"/> is set, the following cause a warning
    /// to be logged and the request to be left unmodified: mismatched header value counts, a missing
    /// forwarded-for entry, or a missing/rejected proto, host or prefix entry.
    /// </para>
    /// </remarks>
    /// <param name="context">The <see cref="HttpContext"/>.</param>
    public void ApplyForwarders(HttpContext context)
    {
        // Gather expected headers.
        // Each enabled header is read as a list of comma-separated values; entryCount becomes the longest list.
        string[]? forwardedFor = null, forwardedProto = null, forwardedHost = null, forwardedPrefix = null;
        bool checkFor = false, checkProto = false, checkHost = false, checkPrefix = false;
        int entryCount = 0;

        var request = context.Request;
        var requestHeaders = context.Request.Headers;
        if (_options.ForwardedHeaders.HasFlag(ForwardedHeaders.XForwardedFor))
        {
            checkFor = true;
            forwardedFor = requestHeaders.GetCommaSeparatedValues(_options.ForwardedForHeaderName);
            entryCount = Math.Max(forwardedFor.Length, entryCount);
        }

        if (_options.ForwardedHeaders.HasFlag(ForwardedHeaders.XForwardedProto))
        {
            checkProto = true;
            forwardedProto = requestHeaders.GetCommaSeparatedValues(_options.ForwardedProtoHeaderName);
            // With symmetry required, every enabled header must carry the same number of values.
            // Returning here happens before any request state is modified.
            if (_options.RequireHeaderSymmetry && checkFor && forwardedFor!.Length != forwardedProto.Length)
            {
                _logger.LogWarning(1, "Parameter count mismatch between X-Forwarded-For and X-Forwarded-Proto.");
                return;
            }
            entryCount = Math.Max(forwardedProto.Length, entryCount);
        }

        if (_options.ForwardedHeaders.HasFlag(ForwardedHeaders.XForwardedHost))
        {
            checkHost = true;
            forwardedHost = requestHeaders.GetCommaSeparatedValues(_options.ForwardedHostHeaderName);
            if (_options.RequireHeaderSymmetry
                && ((checkFor && forwardedFor!.Length != forwardedHost.Length)
                    || (checkProto && forwardedProto!.Length != forwardedHost.Length)))
            {
                _logger.LogWarning(1, "Parameter count mismatch between X-Forwarded-Host and X-Forwarded-For or X-Forwarded-Proto.");
                return;
            }
            entryCount = Math.Max(forwardedHost.Length, entryCount);
        }

        if (_options.ForwardedHeaders.HasFlag(ForwardedHeaders.XForwardedPrefix))
        {
            checkPrefix = true;
            forwardedPrefix = requestHeaders.GetCommaSeparatedValues(_options.ForwardedPrefixHeaderName);
            if (_options.RequireHeaderSymmetry
                && ((checkFor && forwardedFor!.Length != forwardedPrefix.Length)
                    || (checkProto && forwardedProto!.Length != forwardedPrefix.Length)
                    || (checkHost && forwardedHost!.Length != forwardedPrefix.Length)))
            {
                _logger.LogWarning(1, "Parameter count mismatch between X-Forwarded-Prefix and X-Forwarded-Host and X-Forwarded-For or X-Forwarded-Proto.");
                return;
            }
            entryCount = Math.Max(forwardedPrefix.Length, entryCount);
        }

        // Apply ForwardLimit, if any
        // (a null ForwardLimit makes the lifted comparison false, i.e. no limit).
        if (_options.ForwardLimit.HasValue && entryCount > _options.ForwardLimit)
        {
            entryCount = _options.ForwardLimit.Value;
        }

        // Group the data together.
        // sets[i] holds the i-th value counted from the right of each enabled header. A header with fewer than
        // i + 1 values leaves the corresponding field null.
        var sets = new SetOfForwarders[entryCount];
        for (int i = 0; i < sets.Length; i++)
        {
            // They get processed in reverse order, right to left.
            var set = new SetOfForwarders();
            if (checkFor && i < forwardedFor!.Length)
            {
                set.IpAndPortText = forwardedFor[forwardedFor.Length - i - 1];
            }
            if (checkProto && i < forwardedProto!.Length)
            {
                set.Scheme = forwardedProto[forwardedProto.Length - i - 1];
            }
            if (checkHost && i < forwardedHost!.Length)
            {
                set.Host = forwardedHost[forwardedHost.Length - i - 1];
            }
            if (checkPrefix && i < forwardedPrefix!.Length)
            {
                set.Prefix = forwardedPrefix[forwardedPrefix.Length - i - 1];
            }
            sets[i] = set;
        }

        // Gather initial values
        var connection = context.Connection;
        var currentValues = new SetOfForwarders()
        {
            RemoteIpAndPort = connection.RemoteIpAddress != null ? new IPEndPoint(connection.RemoteIpAddress, connection.RemotePort) : null,
            // Host and Scheme initial values are never inspected, no need to set them here.
            // Scheme, Host and Prefix stay null until a forwarded value is accepted. The apply phase below
            // treats a non-null value as the signal to replace the corresponding request property.
        };

        // Known proxies/networks are only enforced when at least one is configured, and only inside the
        // checkFor branch below.
        // NOTE(review): with forwarded-for processing disabled, they are not consulted at all.
        var checkKnownIps = _options.KnownNetworks.Count > 0 || _options.KnownProxies.Count > 0;
        bool applyChanges = false;
        int entriesConsumed = 0;

        // entriesConsumed ends as the number of sets processed. A 'break' leaves it at the index of the set that
        // stopped processing, and that set is not treated as consumed when headers are truncated below.
        for (; entriesConsumed < sets.Length; entriesConsumed++)
        {
            var set = sets[entriesConsumed];
            if (checkFor)
            {
                // currentValues.RemoteIpAndPort is the endpoint that sent this set. For the first set it is the
                // connection's remote endpoint; afterwards it is the most recently accepted forwarded-for entry.
                // For the first instance, allow remoteIp to be null for servers that don't support it natively.
                if (currentValues.RemoteIpAndPort != null && checkKnownIps && !CheckKnownAddress(currentValues.RemoteIpAndPort.Address))
                {
                    // Stop at the first unknown remote IP, but still apply changes processed so far.
                    if (_logger.IsEnabled(LogLevel.Debug))
                    {
                        _logger.LogDebug(1, "Unknown proxy: {RemoteIpAndPort}", currentValues.RemoteIpAndPort);
                    }
                    break;
                }

                if (IPEndPoint.TryParse(set.IpAndPortText, out var parsedEndPoint))
                {
                    applyChanges = true;
                    set.RemoteIpAndPort = parsedEndPoint;
                    // NOTE(review): currentValues.IpAndPortText is not read after this point.
                    currentValues.IpAndPortText = set.IpAndPortText;
                    currentValues.RemoteIpAndPort = set.RemoteIpAndPort;
                }
                else if (!string.IsNullOrEmpty(set.IpAndPortText))
                {
                    // Stop at the first unparsable IP, but still apply changes processed so far.
                    if (_logger.IsEnabled(LogLevel.Debug))
                    {
                        _logger.LogDebug(1, "Unparsable IP: {IpAndPortText}", set.IpAndPortText);
                    }
                    break;
                }
                else if (_options.RequireHeaderSymmetry)
                {
                    // A missing entry is only tolerated without symmetry. In that case the current
                    // remote endpoint is kept and processing continues with the next set.
                    _logger.LogWarning(2, "Missing forwarded IPAddress.");
                    return;
                }
            }

            if (checkProto)
            {
                // Accept only non-empty schemes made entirely of SchemeChars.
                if (!string.IsNullOrEmpty(set.Scheme) && set.Scheme.AsSpan().IndexOfAnyExcept(SchemeChars) < 0)
                {
                    applyChanges = true;
                    currentValues.Scheme = set.Scheme;
                }
                else if (_options.RequireHeaderSymmetry)
                {
                    _logger.LogWarning(3, $"Forwarded scheme is not present, this is required by {nameof(_options.RequireHeaderSymmetry)}");
                    return;
                }
            }

            if (checkHost)
            {
                // Accept only syntactically valid hosts that also pass the allow-list (when one is configured).
                if (!string.IsNullOrEmpty(set.Host) && TryValidateHost(set.Host)
                    && (_allowAllHosts || HostString.MatchesAny(set.Host, _allowedHosts!)))
                {
                    applyChanges = true;
                    currentValues.Host = set.Host;
                }
                else if (_options.RequireHeaderSymmetry)
                {
                    // NOTE(review): this branch is also reached for an invalid or disallowed host, not only for
                    // a missing value, so the "Incorrect number" wording can be misleading.
                    _logger.LogWarning(4, $"Incorrect number of x-forwarded-host header values, see {nameof(_options.RequireHeaderSymmetry)}.");
                    return;
                }
            }

            if (checkPrefix)
            {
                // A forwarded prefix is accepted only when it starts with '/'.
                if (!string.IsNullOrEmpty(set.Prefix) && set.Prefix[0] == '/')
                {
                    applyChanges = true;
                    currentValues.Prefix = set.Prefix;
                }
                else if (_options.RequireHeaderSymmetry)
                {
                    _logger.LogWarning(5, $"Incorrect number of x-forwarded-prefix header values, see {nameof(_options.RequireHeaderSymmetry)}");
                    return;
                }
            }
        }

        // Apply phase: for each enabled header with an accepted value, the steps are:
        //   1. save the original value;
        //   2. drop the consumed (rightmost) entries from the forwarded header, or remove the header entirely
        //      when nothing remains;
        //   3. overwrite the request/connection property.
        if (applyChanges)
        {
            // NOTE(review): currentValues.RemoteIpAndPort is seeded from the connection. This branch therefore
            // also runs when another header set applyChanges but no forwarded-for entry was accepted. In that
            // case the original-for header is written and the address is re-assigned to its existing value.
            if (checkFor && currentValues.RemoteIpAndPort != null)
            {
                if (connection.RemoteIpAddress != null)
                {
                    // Save the original
                    requestHeaders[_options.OriginalForHeaderName] = new IPEndPoint(connection.RemoteIpAddress, connection.RemotePort).ToString();
                }
                if (forwardedFor!.Length > entriesConsumed)
                {
                    // Truncate the consumed header values
                    requestHeaders[_options.ForwardedForHeaderName] =
                        TruncateConsumedHeaderValues(forwardedFor, entriesConsumed);
                }
                else
                {
                    // All values were consumed
                    requestHeaders.Remove(_options.ForwardedForHeaderName);
                }
                connection.RemoteIpAddress = currentValues.RemoteIpAndPort.Address;
                connection.RemotePort = currentValues.RemoteIpAndPort.Port;
            }

            if (checkProto && currentValues.Scheme != null)
            {
                // Save the original
                requestHeaders[_options.OriginalProtoHeaderName] = request.Scheme;
                if (forwardedProto!.Length > entriesConsumed)
                {
                    // Truncate the consumed header values
                    requestHeaders[_options.ForwardedProtoHeaderName] =
                        TruncateConsumedHeaderValues(forwardedProto, entriesConsumed);
                }
                else
                {
                    // All values were consumed
                    requestHeaders.Remove(_options.ForwardedProtoHeaderName);
                }
                request.Scheme = currentValues.Scheme;
            }

            if (checkHost && currentValues.Host != null)
            {
                // Save the original
                requestHeaders[_options.OriginalHostHeaderName] = request.Host.ToString();
                if (forwardedHost!.Length > entriesConsumed)
                {
                    // Truncate the consumed header values
                    requestHeaders[_options.ForwardedHostHeaderName] =
                        TruncateConsumedHeaderValues(forwardedHost, entriesConsumed);
                }
                else
                {
                    // All values were consumed
                    requestHeaders.Remove(_options.ForwardedHostHeaderName);
                }
                request.Host = HostString.FromUriComponent(currentValues.Host);
            }

            if (checkPrefix && currentValues.Prefix != null)
            {
                if (request.PathBase.HasValue)
                {
                    // Save the original (only when there was a non-empty PathBase to save)
                    requestHeaders[_options.OriginalPrefixHeaderName] = request.PathBase.ToString();
                }

                if (forwardedPrefix!.Length > entriesConsumed)
                {
                    // Truncate the consumed header values
                    requestHeaders[_options.ForwardedPrefixHeaderName] =
                        TruncateConsumedHeaderValues(forwardedPrefix, entriesConsumed);
                }
                else
                {
                    // All values were consumed
                    requestHeaders.Remove(_options.ForwardedPrefixHeaderName);
                }

                request.PathBase = PathString.FromUriComponent(currentValues.Prefix);
            }
        }
    }

    // Returns true when the address is in KnownProxies or inside any of KnownNetworks.
    // IPv4-mapped IPv6 addresses are first checked in their IPv4 form, then as-is.
    private bool CheckKnownAddress(IPAddress address)
    {
        if (address.IsIPv4MappedToIPv6)
        {
            var ipv4Address = address.MapToIPv4();
            if (CheckKnownAddress(ipv4Address))
            {
                return true;
            }
        }
        if (_options.KnownProxies.Contains(address))
        {
            return true;
        }
        foreach (var network in _options.KnownNetworks)
        {
            if (network.Contains(address))
            {
                return true;
            }
        }
        return false;
    }

    // One hop's worth of forwarded values (one entry per header, taken at the same right-to-left position).
    // Also reused as the running "current values" while walking the hops.
    // NOTE(review): the string fields are declared non-nullable but hold null when a header was not enabled
    // or has fewer entries than the hop index.
    private struct SetOfForwarders
    {
        // Raw forwarded-for entry text; RemoteIpAndPort holds its parsed form once accepted.
        public string IpAndPortText;
        public IPEndPoint? RemoteIpAndPort;
        public string Host;
        public string Scheme;
        public string Prefix;
    }

    // Validates a forwarded host as either a bracketed IPv6 literal or a HostChars-only name. In both forms
    // the host may be followed by ':' and a numeric port. A value consisting only of a port is rejected.
    // Empty was checked for by the caller
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool TryValidateHost(string host)
    {
        if (host[0] == '[')
        {
            return TryValidateIPv6Host(host);
        }

        if (host[0] == ':')
        {
            // Only a port
            return false;
        }

        var firstNonHostCharIdx = host.AsSpan().IndexOfAnyExcept(HostChars);
        if (firstNonHostCharIdx == -1)
        {
            // no port
            return true;
        }
        else
        {
            // Anything outside HostChars must be the start of a ":port" suffix.
            return TryValidateHostPort(host, firstNonHostCharIdx);
        }
    }

    // Validates "[ipv6]" or "[ipv6]:port". Only the character set and a minimum length are checked here;
    // the address itself is not parsed.
    // The lead '[' was already checked
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool TryValidateIPv6Host(string hostText)
    {
        var host = hostText.AsSpan(1);

        var hostEndIdx = host.IndexOfAnyExcept(Ipv6HostChars);
        if ((uint)hostEndIdx >= (uint)host.Length || // No ']'. The uint cast is there to eliminate the
                                                     // bounds check on the 'host[hostEndIdx]' access below.
            host[hostEndIdx] != ']' || // We found an invalid host character
            hostEndIdx < 3) // [::1] is the shortest valid IPv6 host
        {
            return false;
        }

        // If there's nothing left, we're good. If there's more, validate it as a port.
        // +2 to skip the '[' and ']' (the '[' wasn't included in hostEndIdx because we
        // cut it off in the AsSpan above).
        return (hostEndIdx + 2 == hostText.Length) || TryValidateHostPort(hostText, hostEndIdx + 2);
    }

    // Returns true when hostText[offset] is ':' followed by one or more ASCII digits through the end of the
    // string. The numeric range of the port is not checked.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool TryValidateHostPort(string hostText, int offset)
    {
        if (hostText[offset] != ':' || hostText.Length == offset + 1)
        {
            // Must have at least one number after the colon if present.
            return false;
        }

        return hostText.AsSpan(offset + 1).IndexOfAnyExceptInRange('0', '9') < 0;
    }

    // Returns the header values that were not consumed. Values are consumed right to left, so the leftmost
    // (forwarded.Length - entriesConsumed) entries are kept. Callers only invoke this when
    // forwarded.Length > entriesConsumed.
    private static string[] TruncateConsumedHeaderValues(string[] forwarded, int entriesConsumed)
    {
        var newLength = forwarded.Length - entriesConsumed;
        var remaining = new string[newLength];
        Array.Copy(forwarded, remaining, newLength);
        return remaining;
    }
}