File: HostFilteringMiddleware.cs
Web Access
Project: src\src\Middleware\HostFiltering\src\Microsoft.AspNetCore.HostFiltering.csproj (Microsoft.AspNetCore.HostFiltering)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
 
namespace Microsoft.AspNetCore.HostFiltering;
 
/// <summary>
/// Middleware that accepts or rejects a request based on the value of its Host header.
/// </summary>
/// <remarks>
/// Accepted requests are handed to the next delegate in the pipeline. Rejected requests get a
/// 400 status code and, when the current configuration asks for it, an HTML error body.
/// </remarks>
public class HostFilteringMiddleware
{
    // Body written for rejected requests. It is meant to match Http.Sys, so the markup
    // is kept exactly as-is (do not "tidy" the HTML).
    private static readonly byte[] DefaultResponse = (
        "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.01//EN\"\"http://www.w3.org/TR/html4/strict.dtd\">\r\n"u8 +
        "<HTML><HEAD><TITLE>Bad Request</TITLE>\r\n"u8 +
        "<META HTTP-EQUIV=\"Content-Type\" Content=\"text/html; charset=us-ascii\"></ HEAD >\r\n"u8 +
        "<BODY><h2>Bad Request - Invalid Hostname</h2>\r\n"u8 +
        "<hr><p>HTTP Error 400. The request hostname is invalid.</p>\r\n"u8 +
        "</BODY></HTML>"u8
    ).ToArray();

    private readonly RequestDelegate _next;
    private readonly ILogger<HostFilteringMiddleware> _logger;
    private readonly MiddlewareConfigurationManager _middlewareConfigurationManager;

    /// <summary>
    /// Creates the host filtering middleware.
    /// </summary>
    /// <param name="next">The delegate invoked when a request passes the host check.</param>
    /// <param name="logger">The logger used to record filtering decisions; it is also passed to the configuration manager.</param>
    /// <param name="optionsMonitor">The options monitor handed to the configuration manager that supplies the current settings.</param>
    /// <exception cref="ArgumentNullException">Any of the arguments is <see langword="null"/>.</exception>
    public HostFilteringMiddleware(
        RequestDelegate next,
        ILogger<HostFilteringMiddleware> logger,
        IOptionsMonitor<HostFilteringOptions> optionsMonitor)
    {
        // Validate in declaration order so the first null argument is the one reported.
        ArgumentNullException.ThrowIfNull(next);
        ArgumentNullException.ThrowIfNull(logger);
        ArgumentNullException.ThrowIfNull(optionsMonitor);

        _next = next;
        _logger = logger;
        _middlewareConfigurationManager = new MiddlewareConfigurationManager(optionsMonitor, logger);
    }

    /// <summary>
    /// Checks the request's Host header against the latest configuration and either
    /// continues the pipeline or responds with a 400.
    /// </summary>
    /// <param name="context">The context of the request being processed.</param>
    /// <returns>
    /// The task returned by the next delegate when the host is accepted; otherwise the task
    /// that completes once the failure response has been produced.
    /// </returns>
    public Task Invoke(HttpContext context)
    {
        // Fetch the configuration once so the check and the failure response agree on it.
        var configuration = _middlewareConfigurationManager.GetLatestMiddlewareConfiguration();

        return CheckHost(context, configuration)
            ? _next(context)
            : HostValidationFailed(context, configuration);
    }

    // Produces the rejection response: always a 400, plus the HTML body when
    // IncludeFailureMessage is set on the configuration.
    [MethodImpl(MethodImplOptions.NoInlining)]
    private static Task HostValidationFailed(HttpContext context, MiddlewareConfiguration middlewareConfiguration)
    {
        var response = context.Response;
        response.StatusCode = 400;

        if (!middlewareConfiguration.IncludeFailureMessage)
        {
            // Status code only; nothing to write.
            return Task.CompletedTask;
        }

        var payload = DefaultResponse;
        response.ContentLength = payload.Length;
        response.ContentType = "text/html";
        return response.Body.WriteAsync(payload, 0, payload.Length);
    }

    // Decides whether the request's Host header is acceptable.
    // Format validation is expected to be done by the host and is not repeated here.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private bool CheckHost(HttpContext context, MiddlewareConfiguration middlewareConfiguration)
    {
        var requestHost = context.Request.Headers.Host.ToString();

        if (requestHost.Length == 0)
        {
            return IsEmptyHostAllowed(context, middlewareConfiguration);
        }

        if (!middlewareConfiguration.AllowAnyNonEmptyHost)
        {
            return CheckHostInAllowList(middlewareConfiguration.AllowedHosts, requestHost);
        }

        _logger.AllHostsAllowed();
        return true;
    }

    // Matches a non-empty host against the configured allow list and logs the outcome.
    // A null allow list matches nothing.
    [MethodImpl(MethodImplOptions.NoInlining)]
    private bool CheckHostInAllowList(IList<StringSegment>? allowedHosts, string host)
    {
        var isAllowed = allowedHosts is not null
            && HostString.MatchesAny(new StringSegment(host), allowedHosts);

        if (isAllowed)
        {
            _logger.AllowedHostMatched(host);
        }
        else
        {
            _logger.NoAllowedHostMatched(host);
        }

        return isAllowed;
    }

    // Handles requests whose Host header is missing or empty, logging the request protocol
    // alongside the decision.
    [MethodImpl(MethodImplOptions.NoInlining)]
    private bool IsEmptyHostAllowed(HttpContext context, MiddlewareConfiguration middlewareConfiguration)
    {
        // Http/1.0 does not require the host header.
        // Http/1.1 requires the header but the value may be empty.
        if (middlewareConfiguration.AllowEmptyHosts)
        {
            _logger.RequestAllowedMissingHost(context.Request.Protocol);
            return true;
        }
        else
        {
            _logger.RequestRejectedMissingHost(context.Request.Protocol);
            return false;
        }
    }
}