File: Matching\HostMatcherPolicy.cs
Web Access
Project: src\src\Http\Routing\src\Microsoft.AspNetCore.Routing.csproj (Microsoft.AspNetCore.Routing)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Globalization;
using System.Linq;
using Microsoft.AspNetCore.Http;
 
namespace Microsoft.AspNetCore.Routing.Matching;
 
/// <summary>
/// A <see cref="MatcherPolicy"/> that implements filtering and selection by
/// the host header of a request.
/// </summary>
/// <remarks>
/// Hosts come from <see cref="IHostMetadata"/> and may be written as <c>host</c>, <c>host:port</c>,
/// <c>*.suffix</c> (a suffix wildcard; the leading <c>.</c> is part of the match) or <c>*</c>.
/// The port may also be <c>*</c>. A request without an explicit port is treated as port 443
/// for <c>https</c> and port 80 for <c>http</c>.
/// </remarks>
public sealed class HostMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy, IEndpointSelectorPolicy
{
    private const string WildcardHost = "*";
    private const string WildcardPrefix = "*.";

    // Run after HTTP methods, but before 'default'.
    /// <inheritdoc />
    public override int Order { get; } = -100;

    /// <inheritdoc />
    public IComparer<Endpoint> Comparer { get; } = new HostMetadataEndpointComparer();

    bool INodeBuilderPolicy.AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
    {
        ArgumentNullException.ThrowIfNull(endpoints);

        // Nodes containing dynamic endpoints are handled at request time by the
        // IEndpointSelectorPolicy path (ApplyAsync) instead of by building edges.
        return !ContainsDynamicEndpoints(endpoints) && AppliesToEndpointsCore(endpoints);
    }

    bool IEndpointSelectorPolicy.AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
    {
        ArgumentNullException.ThrowIfNull(endpoints);

        // When the node contains dynamic endpoints we can't make any assumptions.
        var applies = ContainsDynamicEndpoints(endpoints);
        if (applies)
        {
            // Run for the side-effect of validating metadata: CreateEdgeKey throws on
            // hosts it cannot parse.
            AppliesToEndpointsCore(endpoints);
        }

        return applies;
    }

    /// <summary>
    /// Returns <c>true</c> when at least one endpoint declares a host that actually restricts
    /// matching (anything that is not equivalent to <c>*:*</c>).
    /// </summary>
    /// <remarks>
    /// Hosts are parsed with <see cref="CreateEdgeKey"/> while searching, so an unparsable host
    /// throws <see cref="InvalidOperationException"/>. The search stops at the first restricting
    /// host, so hosts after that point are not parsed here.
    /// </remarks>
    private static bool AppliesToEndpointsCore(IReadOnlyList<Endpoint> endpoints)
    {
        return endpoints.Any(static endpoint =>
        {
            var hosts = endpoint.Metadata.GetMetadata<IHostMetadata>()?.Hosts;
            if (hosts is null)
            {
                return false;
            }

            foreach (var host in hosts)
            {
                // Don't run policy on endpoints that match everything.
                if (!CreateEdgeKey(host).MatchesAll)
                {
                    return true;
                }
            }

            return false;
        });
    }

    /// <inheritdoc />
    public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates)
    {
        ArgumentNullException.ThrowIfNull(httpContext);
        ArgumentNullException.ThrowIfNull(candidates);

        // The request host/port is identical for every candidate. Read it lazily (only once a
        // candidate actually carries host metadata) and at most once, since reading HostString
        // can allocate.
        (string host, int? port)? request = null;

        for (var i = 0; i < candidates.Count; i++)
        {
            if (!candidates.IsValidCandidate(i))
            {
                continue;
            }

            var hosts = candidates[i].Endpoint.Metadata.GetMetadata<IHostMetadata>()?.Hosts;
            if (hosts is null || hosts.Count == 0)
            {
                // Can match any host.
                continue;
            }

            request ??= GetHostAndPort(httpContext);
            var (requestHost, requestPort) = request.Value;

            if (!HostsMatchRequest(hosts, requestHost, requestPort))
            {
                candidates.SetValidity(i, false);
            }
        }

        return Task.CompletedTask;
    }

    // Returns true when any entry of 'hosts' accepts the given request host and port.
    private static bool HostsMatchRequest(IReadOnlyList<string> hosts, string requestHost, int? requestPort)
    {
        for (var i = 0; i < hosts.Count; i++)
        {
            var host = hosts[i].AsSpan();
            var port = ReadOnlySpan<char>.Empty;

            // Split into host and port at the first ':'.
            var pivot = host.IndexOf(':');
            if (pivot >= 0)
            {
                port = host.Slice(pivot + 1);
                host = host.Slice(0, pivot);
            }

            if (HostSegmentMatches(host, requestHost) && PortSegmentMatches(port, requestPort))
            {
                return true;
            }
        }

        return false;
    }

    // Returns true when the host part of a metadata entry accepts the request host.
    private static bool HostSegmentMatches(ReadOnlySpan<char> host, string requestHost)
    {
        if (host.IsEmpty || host.Equals(WildcardHost, StringComparison.OrdinalIgnoreCase))
        {
            // Can match any host.
            return true;
        }

        // Suffix wildcard. Note that we only slice off the `*`. We want to match the leading `.` also.
        if (host.StartsWith(WildcardPrefix, StringComparison.Ordinal) &&
            requestHost.AsSpan().EndsWith(host.Slice(WildcardHost.Length), StringComparison.OrdinalIgnoreCase))
        {
            return true;
        }

        // Otherwise the host must match exactly (case-insensitively).
        return requestHost.AsSpan().Equals(host, StringComparison.OrdinalIgnoreCase);
    }

    // Returns true when the port part of a metadata entry accepts the request port.
    // An empty port or '*' allows any port. Otherwise the port must parse and equal the
    // request port; a null request port never matches an explicit port.
    private static bool PortSegmentMatches(ReadOnlySpan<char> port, int? requestPort)
    {
        if (port.IsEmpty || port.Equals(WildcardHost, StringComparison.Ordinal))
        {
            return true;
        }

        return int.TryParse(port, NumberStyles.Integer, CultureInfo.InvariantCulture, out var parsed)
            && parsed == requestPort;
    }

    /// <summary>
    /// Parses a metadata host (<c>host</c>, <c>host:port</c> or <c>host:*</c>) into an <see cref="EdgeKey"/>.
    /// A <c>null</c> host yields <see cref="EdgeKey.WildcardEdgeKey"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The host part is empty, the port is neither an integer nor <c>*</c>, or there is more than one <c>:</c>.
    /// </exception>
    private static EdgeKey CreateEdgeKey(string host)
    {
        if (host is null)
        {
            return EdgeKey.WildcardEdgeKey;
        }

        var hostSpan = host.AsSpan();
        Span<Range> parts = stackalloc Range[3];
        var count = hostSpan.Split(parts, ':');

        if (count is 1 or 2)
        {
            var hostPart = hostSpan[parts[0]];
            if (!hostPart.IsEmpty)
            {
                if (count == 1)
                {
                    return new EdgeKey(hostPart.ToString(), null);
                }

                var portPart = hostSpan[parts[1]];
                if (int.TryParse(portPart, NumberStyles.Integer, CultureInfo.InvariantCulture, out var port))
                {
                    return new EdgeKey(hostPart.ToString(), port);
                }

                if (portPart.Equals(WildcardHost, StringComparison.Ordinal))
                {
                    return new EdgeKey(hostPart.ToString(), null);
                }
            }
        }

        throw new InvalidOperationException($"Could not parse host: {host}");
    }

    /// <inheritdoc />
    public IReadOnlyList<PolicyNodeEdge> GetEdges(IReadOnlyList<Endpoint> endpoints)
    {
        ArgumentNullException.ThrowIfNull(endpoints);

        // The algorithm here is designed to preserve the order of the endpoints
        // while also being relatively simple. Preserving order is important.

        // Parse each endpoint's hosts exactly once; both passes below reuse the result.
        // An empty array means "no host metadata".
        var keysByEndpoint = new EdgeKey[endpoints.Count][];

        // First, build a dictionary of all of the hosts that are included at this node.
        //
        // For now we're just building up the set of keys. We don't add any endpoints
        // to lists now because we don't want ordering problems.
        var edges = new Dictionary<EdgeKey, List<Endpoint>>();
        for (var i = 0; i < endpoints.Count; i++)
        {
            var endpointKeys = GetEdgeKeys(endpoints[i]) ?? Array.Empty<EdgeKey>();
            keysByEndpoint[i] = endpointKeys;

            if (endpointKeys.Length == 0)
            {
                // No host metadata: this endpoint contributes the catch-all edge.
                edges.TryAdd(EdgeKey.WildcardEdgeKey, new List<Endpoint>());
                continue;
            }

            foreach (var key in endpointKeys)
            {
                edges.TryAdd(key, new List<Endpoint>());
            }
        }

        // Now in a second loop, add endpoints to these lists. We've enumerated all of
        // the states, so we want to see which states this endpoint matches.
        for (var i = 0; i < endpoints.Count; i++)
        {
            var endpoint = endpoints[i];
            var endpointKeys = keysByEndpoint[i];

            foreach (var (edgeKey, edgeEndpoints) in edges)
            {
                // An endpoint without host metadata matches *all* hosts, so it joins every edge.
                if (endpointKeys.Length == 0 || EdgeAcceptsEndpoint(edgeKey, endpointKeys))
                {
                    edgeEndpoints.Add(endpoint);
                }
            }
        }

        var result = new PolicyNodeEdge[edges.Count];
        var index = 0;
        foreach (var (key, edgeEndpoints) in edges)
        {
            result[index++] = new PolicyNodeEdge(key, edgeEndpoints);
        }

        return result;
    }

    // Returns true when any of an endpoint's keys belongs on the edge identified by 'edgeKey'.
    // That holds when the key is equal to the edge key. It also holds when both are suffix
    // wildcards on the same port and the endpoint's wildcard host ends with the edge's suffix
    // (e.g. endpoint "*.sub.contoso.com" on edge "*.contoso.com").
    // NOTE(review): the reverse (a broader wildcard endpoint on a narrower wildcard edge) is
    // not added; existing behavior is preserved here.
    private static bool EdgeAcceptsEndpoint(EdgeKey edgeKey, EdgeKey[] endpointKeys)
    {
        foreach (var endpointKey in endpointKeys)
        {
            if (edgeKey.Equals(endpointKey))
            {
                return true;
            }

            if (edgeKey.HasHostWildcard && endpointKey.HasHostWildcard &&
                edgeKey.Port == endpointKey.Port && edgeKey.MatchHost(endpointKey.Host))
            {
                return true;
            }
        }

        return false;
    }

    // Parses every host of the endpoint's IHostMetadata; returns null when there are none.
    private static EdgeKey[]? GetEdgeKeys(Endpoint endpoint)
    {
        var hosts = endpoint.Metadata.GetMetadata<IHostMetadata>()?.Hosts;
        if (hosts is null || hosts.Count == 0)
        {
            return null;
        }

        var keys = new EdgeKey[hosts.Count];
        for (var i = 0; i < keys.Length; i++)
        {
            keys[i] = CreateEdgeKey(hosts[i]);
        }

        return keys;
    }

    /// <inheritdoc />
    public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList<PolicyJumpTableEdge> edges)
    {
        ArgumentNullException.ThrowIfNull(edges);

        // Since our 'edges' can have wildcards, we order them by how wildcard-ey they are and
        // then execute them in linear order. OrderBy is a stable sort, so edges with the same
        // score keep their original relative order. Array.Sort with a comparison does not
        // guarantee that.
        var ordered = edges
            .Select(static e => (host: (EdgeKey)e.State, destination: e.Destination))
            .OrderBy(static e => GetScore(e.host))
            .ToArray();

        return new HostPolicyJumpTable(exitDestination, ordered);
    }

    // Higher score == lower priority.
    private static int GetScore(in EdgeKey key)
    {
        return (key.MatchesHost, key.HasHostWildcard, key.MatchesPort) switch
        {
            (true, false, true) => 1,  // Has host AND port, e.g. www.contoso.com:8080
            (true, false, false) => 2, // Has host, e.g. www.contoso.com
            (true, true, true) => 3,   // Has wildcard host AND port, e.g. *.contoso.com:8080
            (true, true, false) => 4,  // Has wildcard host, e.g. *.contoso.com
            (false, _, true) => 5,     // Has port, e.g. *:8080
            _ => 6,                    // Has neither, e.g. *:* (or no metadata)
        };
    }

    // Reads the request host and port. When no port is present it is inferred from the
    // scheme (443 for https, 80 for http); any other scheme leaves it null.
    private static (string host, int? port) GetHostAndPort(HttpContext httpContext)
    {
        var hostString = httpContext.Request.Host;
        var port = hostString.Port;

        if (port is null)
        {
            var scheme = httpContext.Request.Scheme;
            if (string.Equals("https", scheme, StringComparison.OrdinalIgnoreCase))
            {
                port = 443;
            }
            else if (string.Equals("http", scheme, StringComparison.OrdinalIgnoreCase))
            {
                port = 80;
            }
        }

        return (hostString.Host, port);
    }

    private sealed class HostMetadataEndpointComparer : EndpointMetadataComparer<IHostMetadata>
    {
        protected override int CompareMetadata(IHostMetadata? x, IHostMetadata? y)
        {
            // Ignore the metadata if it has an empty list of hosts.
            return base.CompareMetadata(
                x?.Hosts.Count > 0 ? x : null,
                y?.Hosts.Count > 0 ? y : null);
        }
    }

    private sealed class HostPolicyJumpTable : PolicyJumpTable
    {
        // Ordered by GetScore: the first entry whose host and port accept the request wins.
        private readonly (EdgeKey host, int destination)[] _destinations;
        private readonly int _exitDestination;

        public HostPolicyJumpTable(int exitDestination, (EdgeKey host, int destination)[] destinations)
        {
            _exitDestination = exitDestination;
            _destinations = destinations;
        }

        public override int GetDestination(HttpContext httpContext)
        {
            // HostString can allocate when accessing the host or port.
            // Store host and port locally and reuse.
            var (host, port) = GetHostAndPort(httpContext);

            foreach (var (edge, destination) in _destinations)
            {
                if ((!edge.MatchesPort || edge.Port == port) && edge.MatchHost(host))
                {
                    return destination;
                }
            }

            return _exitDestination;
        }
    }

    private readonly struct EdgeKey : IEquatable<EdgeKey>, IComparable<EdgeKey>, IComparable
    {
        internal static readonly EdgeKey WildcardEdgeKey = new EdgeKey(null, null);

        public readonly int? Port;
        public readonly string Host;

        // For "*.suffix" hosts: ".suffix" (the '*' removed, the leading '.' kept).
        private readonly string? _wildcardEndsWith;

        public EdgeKey(string? host, int? port)
        {
            Host = host ?? WildcardHost;
            Port = port;

            HasHostWildcard = Host.StartsWith(WildcardPrefix, StringComparison.Ordinal);
            _wildcardEndsWith = HasHostWildcard ? Host.Substring(1) : null;
        }

        public bool HasHostWildcard { get; }

        public bool MatchesHost => !string.Equals(Host, WildcardHost, StringComparison.Ordinal);

        public bool MatchesPort => Port != null;

        public bool MatchesAll => !MatchesHost && !MatchesPort;

        public int CompareTo(EdgeKey other)
        {
            // Ordinal, to stay consistent with Equals.
            var result = string.CompareOrdinal(Host, other.Host);
            if (result != 0)
            {
                return result;
            }

            return Comparer<int?>.Default.Compare(Port, other.Port);
        }

        public int CompareTo(object? obj)
        {
            // IComparable contract: every instance follows null.
            if (obj is null)
            {
                return 1;
            }

            if (obj is EdgeKey other)
            {
                return CompareTo(other);
            }

            throw new ArgumentException($"Object must be of type {nameof(EdgeKey)}.", nameof(obj));
        }

        public bool Equals(EdgeKey other)
        {
            return string.Equals(Host, other.Host, StringComparison.Ordinal) && Port == other.Port;
        }

        // Case-insensitive match of a request host against this key: any host for "*",
        // a suffix match for "*.suffix", otherwise an exact match.
        public bool MatchHost(string host)
        {
            if (MatchesHost)
            {
                if (HasHostWildcard)
                {
                    return host.EndsWith(_wildcardEndsWith!, StringComparison.OrdinalIgnoreCase);
                }
                else
                {
                    return string.Equals(host, Host, StringComparison.OrdinalIgnoreCase);
                }
            }

            return true;
        }

        public override int GetHashCode()
        {
            return (Host?.GetHashCode() ?? 0) ^ (Port?.GetHashCode() ?? 0);
        }

        public override bool Equals(object? obj)
        {
            if (obj is EdgeKey key)
            {
                return Equals(key);
            }

            return false;
        }

        public override string ToString()
        {
            return $"{Host}:{Port?.ToString(CultureInfo.InvariantCulture) ?? WildcardHost}";
        }
    }
}