// File: CorsMiddlewareTests.cs
// Web Access
// Project: src\src\Middleware\CORS\test\UnitTests\Microsoft.AspNetCore.Cors.Test.csproj (Microsoft.AspNetCore.Cors.Test)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Net;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
 
namespace Microsoft.AspNetCore.Cors.Infrastructure;
 
public class CorsMiddlewareTests
{
    private const string OriginUrl = "http://api.example.com";
 
    /// <summary>
    /// Sends an actual (non-preflight) request from the allowed origin, spelling the PUT method
    /// with different casing, against a policy that allows <see cref="OriginUrl"/> and "PUT".
    /// Expects a successful response whose only header is the allow-origin header and whose
    /// body is the text written by the terminal middleware.
    /// </summary>
    /// <param name="accessControlRequestMethod">HTTP method string used to send the request.</param>
    [Theory]
    [InlineData("PuT")]
    [InlineData("PUT")]
    public async Task CorsRequest_MatchesPolicy_OnCaseInsensitiveAccessControlRequestMethod(string accessControlRequestMethod)
    {
        // Arrange
        Action<IApplicationBuilder> configurePipeline = app =>
        {
            app.UseCors(policy => policy.WithOrigins(OriginUrl).WithMethods("PUT"));
            app.Run(context => context.Response.WriteAsync("Cross origin response"));
        };

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(configurePipeline)
                .ConfigureServices(services => services.AddCors()))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Actual request.
        var request = server.CreateRequest("/").AddHeader(CorsConstants.Origin, OriginUrl);
        var response = await request.SendAsync(accessControlRequestMethod);

        // Assert
        response.EnsureSuccessStatusCode();
        Assert.Single(response.Headers);

        var body = await response.Content.ReadAsStringAsync();
        Assert.Equal("Cross origin response", body);

        var allowedOrigin = response.Headers.GetValues(CorsConstants.AccessControlAllowOrigin).FirstOrDefault();
        Assert.Equal(OriginUrl, allowedOrigin);
    }
 
    /// <summary>
    /// Sends a PUT request from the allowed origin through a pipeline whose inline policy also
    /// declares an allowed header and an exposed header. Expects a successful response with
    /// exactly two headers: the allow-origin value and the expose-headers value.
    /// </summary>
    [Fact]
    public async Task CorsRequest_MatchPolicy_SetsResponseHeaders()
    {
        // Arrange
        Action<IApplicationBuilder> configurePipeline = app =>
        {
            app.UseCors(policy => policy
                .WithOrigins(OriginUrl)
                .WithMethods("PUT")
                .WithHeaders("Header1")
                .WithExposedHeaders("AllowedHeader"));
            app.Run(context => context.Response.WriteAsync("Cross origin response"));
        };

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(configurePipeline)
                .ConfigureServices(services => services.AddCors()))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Actual request.
        var request = server.CreateRequest("/").AddHeader(CorsConstants.Origin, OriginUrl);
        var response = await request.SendAsync("PUT");

        // Assert
        response.EnsureSuccessStatusCode();
        Assert.Equal(2, response.Headers.Count());

        var body = await response.Content.ReadAsStringAsync();
        Assert.Equal("Cross origin response", body);

        var allowedOrigin = response.Headers.GetValues(CorsConstants.AccessControlAllowOrigin).FirstOrDefault();
        Assert.Equal(OriginUrl, allowedOrigin);

        var exposedHeader = response.Headers.GetValues(CorsConstants.AccessControlExposeHeaders).FirstOrDefault();
        Assert.Equal("AllowedHeader", exposedHeader);
    }
 
    /// <summary>
    /// Sends a request from the allowed origin using differently-cased spellings of OPTIONS
    /// (with only the Origin header added) against the named policy "customPolicy".
    /// Expects a successful response whose single header is the allow-origin header.
    /// </summary>
    /// <param name="preflightMethod">HTTP method string used to send the request.</param>
    [Theory]
    [InlineData("OpTions")]
    [InlineData("OPTIONS")]
    public async Task PreFlight_MatchesPolicy_OnCaseInsensitiveOptionsMethod(string preflightMethod)
    {
        // Arrange
        var corsPolicy = new CorsPolicy
        {
            Origins = { OriginUrl },
            Methods = { "PUT" },
        };

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(app =>
                {
                    app.UseCors("customPolicy");
                    app.Run(context => context.Response.WriteAsync("Cross origin response"));
                })
                .ConfigureServices(services =>
                    services.AddCors(options => options.AddPolicy("customPolicy", corsPolicy))))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Preflight request.
        var request = server.CreateRequest("/").AddHeader(CorsConstants.Origin, OriginUrl);
        var response = await request.SendAsync(preflightMethod);

        // Assert
        response.EnsureSuccessStatusCode();
        Assert.Single(response.Headers);

        var allowedOrigin = response.Headers.GetValues(CorsConstants.AccessControlAllowOrigin).FirstOrDefault();
        Assert.Equal(OriginUrl, allowedOrigin);
    }
 
    /// <summary>
    /// Sends a preflight request (Origin plus Access-Control-Request-Method "PUT") that matches
    /// the named policy "customPolicy". Expects a successful response whose headers, ordered by
    /// name, are exactly allow-headers ("Header1"), allow-methods ("PUT") and allow-origin.
    /// </summary>
    [Fact]
    public async Task PreFlight_MatchesPolicy_SetsResponseHeaders()
    {
        // Arrange
        var corsPolicy = new CorsPolicy
        {
            Origins = { OriginUrl },
            Methods = { "PUT" },
            Headers = { "Header1" },
            ExposedHeaders = { "AllowedHeader" },
        };

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(app =>
                {
                    app.UseCors("customPolicy");
                    app.Run(context => context.Response.WriteAsync("Cross origin response"));
                })
                .ConfigureServices(services =>
                    services.AddCors(options => options.AddPolicy("customPolicy", corsPolicy))))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Preflight request.
        var request = server.CreateRequest("/")
            .AddHeader(CorsConstants.Origin, OriginUrl)
            .AddHeader(CorsConstants.AccessControlRequestMethod, "PUT");
        var response = await request.SendAsync(CorsConstants.PreflightHttpMethod);

        // Assert
        response.EnsureSuccessStatusCode();

        // Sorting by header name makes the expected sequence deterministic.
        var orderedHeaders = response.Headers.OrderBy(header => header.Key);
        Assert.Collection(
            orderedHeaders,
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowHeaders, header.Key);
                Assert.Equal(new[] { "Header1" }, header.Value);
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowMethods, header.Key);
                Assert.Equal(new[] { "PUT" }, header.Value);
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowOrigin, header.Key);
                Assert.Equal(new[] { OriginUrl }, header.Value);
            });
    }
 
    /// <summary>
    /// Sends a preflight request carrying Access-Control-Request-Headers "X-Test1,X-Test2"
    /// against a policy built with AllowAnyHeader, AllowAnyMethod and AllowCredentials.
    /// Expects the response headers, ordered by name, to be allow-credentials ("true"),
    /// allow-headers (the requested header list), allow-methods ("PUT") and allow-origin.
    /// </summary>
    [Fact]
    public async Task PreFlight_WithCredentialsAllowed_ReflectsRequestHeaders()
    {
        // Arrange
        var policyBuilder = new CorsPolicyBuilder(OriginUrl);
        policyBuilder.AllowAnyHeader();
        policyBuilder.AllowAnyMethod();
        policyBuilder.AllowCredentials();
        var corsPolicy = policyBuilder.Build();

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(app =>
                {
                    app.UseCors("customPolicy");
                    app.Run(context => context.Response.WriteAsync("Cross origin response"));
                })
                .ConfigureServices(services =>
                    services.AddCors(options => options.AddPolicy("customPolicy", corsPolicy))))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Preflight request.
        var request = server.CreateRequest("/")
            .AddHeader(CorsConstants.Origin, OriginUrl)
            .AddHeader(CorsConstants.AccessControlRequestMethod, "PUT")
            .AddHeader(CorsConstants.AccessControlRequestHeaders, "X-Test1,X-Test2");
        var response = await request.SendAsync(CorsConstants.PreflightHttpMethod);

        // Assert
        response.EnsureSuccessStatusCode();

        // Sorting by header name makes the expected sequence deterministic.
        var orderedHeaders = response.Headers.OrderBy(header => header.Key);
        Assert.Collection(
            orderedHeaders,
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowCredentials, header.Key);
                Assert.Equal(new[] { "true" }, header.Value);
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowHeaders, header.Key);
                Assert.Equal(new[] { "X-Test1,X-Test2" }, header.Value);
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowMethods, header.Key);
                Assert.Equal(new[] { "PUT" }, header.Value);
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowOrigin, header.Key);
                Assert.Equal(new[] { OriginUrl }, header.Value);
            });
    }
 
    /// <summary>
    /// Sends a preflight request from an origin ("http://test.example.com") that is not in the
    /// policy's allowed origins. Expects a 204 No Content response; note that, despite the
    /// method name, the assertions require the response to carry no headers at all.
    /// </summary>
    [Fact]
    public async Task PreFlightRequest_DoesNotMatchPolicy_SetsResponseHeadersAndReturnsNoContent()
    {
        // Arrange
        using var host = new HostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                .UseTestServer()
                .Configure(app =>
                {
                    app.UseCors(builder =>
                        builder.WithOrigins(OriginUrl)
                               .WithMethods("PUT")
                               .WithHeaders("Header1")
                               .WithExposedHeaders("AllowedHeader"));
                    app.Run(async context =>
                    {
                        await context.Response.WriteAsync("Cross origin response");
                    });
                })
                .ConfigureServices(services => services.AddCors());
            }).Build();

        // The host is started exactly once; it is disposed by the using declaration above.
        await host.StartAsync();

        using (var server = host.GetTestServer())
        {
            // Act
            // Preflight request from an origin the policy does not allow.
            var response = await server.CreateRequest("/")
                .AddHeader(CorsConstants.Origin, "http://test.example.com")
                .AddHeader(CorsConstants.AccessControlRequestMethod, "PUT")
                .SendAsync(CorsConstants.PreflightHttpMethod);

            // Assert
            Assert.Equal(HttpStatusCode.NoContent, response.StatusCode);
            Assert.Empty(response.Headers);
        }
    }
 
    /// <summary>
    /// Sends an actual PUT request from an origin ("http://test.example.com") that is not in
    /// the policy's allowed origins. Expects a 200 OK response with no response headers.
    /// </summary>
    [Fact]
    public async Task CorsRequest_DoesNotMatchPolicy_DoesNotSetHeaders()
    {
        // Arrange
        Action<IApplicationBuilder> configurePipeline = app =>
        {
            app.UseCors(policy => policy
                .WithOrigins(OriginUrl)
                .WithMethods("PUT")
                .WithHeaders("Header1")
                .WithExposedHeaders("AllowedHeader"));
            app.Run(context => context.Response.WriteAsync("Cross origin response"));
        };

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(configurePipeline)
                .ConfigureServices(services => services.AddCors()))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Actual request.
        var request = server.CreateRequest("/").AddHeader(CorsConstants.Origin, "http://test.example.com");
        var response = await request.SendAsync("PUT");

        // Assert
        Assert.Equal(HttpStatusCode.OK, response.StatusCode);
        Assert.Empty(response.Headers);
    }
 
    /// <summary>
    /// Invokes a middleware created with a null policy name for a request carrying an Origin
    /// header, and verifies that the supplied <see cref="ICorsPolicyProvider"/> is queried
    /// exactly once.
    /// </summary>
    [Fact]
    public async Task Uses_PolicyProvider_AsFallback()
    {
        // Arrange
        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            Mock.Of<ICorsService>(),
            NullLoggerFactory.Instance,
            policyName: null);

        var context = new DefaultHttpContext();
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });

        // Act
        await middleware.Invoke(context, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
            Times.Once());
    }
 
    /// <summary>
    /// Invokes the middleware with a policy provider that yields a null policy. Expects the
    /// response to keep status code 200 and have no headers, and the provider to have been
    /// queried exactly once.
    /// </summary>
    [Fact]
    public async Task DoesNotSetHeaders_ForNoPolicy()
    {
        // Arrange
        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            Mock.Of<ICorsService>(),
            NullLoggerFactory.Instance,
            policyName: null);

        var context = new DefaultHttpContext();
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });

        // Act
        await middleware.Invoke(context, policyProvider.Object);

        // Assert
        var response = context.Response;
        Assert.Equal(200, response.StatusCode);
        Assert.Empty(response.Headers);
        policyProvider.Verify(
            provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
            Times.Once());
    }
 
    /// <summary>
    /// Registers a default policy plus a second named policy ("policy2"), enables CORS without
    /// naming a policy, and sends a preflight request that matches the default policy.
    /// Expects the response headers, ordered by name, to be allow-headers ("Header1"),
    /// allow-methods ("PUT") and allow-origin (<see cref="OriginUrl"/>).
    /// </summary>
    [Fact]
    public async Task PreFlight_MatchesDefaultPolicy_SetsResponseHeaders()
    {
        // Arrange
        Action<CorsOptions> configureCors = options =>
        {
            // The Build() return values are not used here; the calls are kept as in the original setup.
            options.AddDefaultPolicy(policyBuilder => policyBuilder
                .WithOrigins(OriginUrl)
                .WithMethods("PUT")
                .WithHeaders("Header1")
                .WithExposedHeaders("AllowedHeader")
                .Build());
            options.AddPolicy("policy2", policyBuilder => policyBuilder
                .WithOrigins("http://test.example.com")
                .Build());
        };

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(app =>
                {
                    app.UseCors();
                    app.Run(context => context.Response.WriteAsync("Cross origin response"));
                })
                .ConfigureServices(services => services.AddCors(configureCors)))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Preflight request.
        var request = server.CreateRequest("/")
            .AddHeader(CorsConstants.Origin, OriginUrl)
            .AddHeader(CorsConstants.AccessControlRequestMethod, "PUT");
        var response = await request.SendAsync(CorsConstants.PreflightHttpMethod);

        // Assert
        response.EnsureSuccessStatusCode();

        // Sorting by header name makes the expected sequence deterministic.
        var orderedHeaders = response.Headers.OrderBy(header => header.Key);
        Assert.Collection(
            orderedHeaders,
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowHeaders, header.Key);
                Assert.Equal(new[] { "Header1" }, header.Value);
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowMethods, header.Key);
                Assert.Equal(new[] { "PUT" }, header.Value);
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowOrigin, header.Key);
                Assert.Equal(new[] { OriginUrl }, header.Value);
            });
    }
 
    /// <summary>
    /// Sends an actual PUT request from the allowed origin to a pipeline whose terminal
    /// middleware adds its own "Test" header. Expects the response headers, ordered by name,
    /// to be allow-origin, expose-headers and the application's "Test" header, and the body
    /// to be the text written by the terminal middleware.
    /// </summary>
    [Fact]
    public async Task CorsRequest_SetsResponseHeaders()
    {
        // Arrange
        Action<IApplicationBuilder> configurePipeline = app =>
        {
            app.UseCors(policy => policy
                .WithOrigins(OriginUrl)
                .WithMethods("PUT")
                .WithHeaders("Header1")
                .WithExposedHeaders("AllowedHeader"));
            app.Run(context =>
            {
                context.Response.Headers.Add("Test", "Should-Appear");
                return context.Response.WriteAsync("Cross origin response");
            });
        };

        using var host = new HostBuilder()
            .ConfigureWebHost(web => web
                .UseTestServer()
                .Configure(configurePipeline)
                .ConfigureServices(services => services.AddCors()))
            .Build();

        await host.StartAsync();
        using var server = host.GetTestServer();

        // Act
        // Actual request.
        var request = server.CreateRequest("/").AddHeader(CorsConstants.Origin, OriginUrl);
        var response = await request.SendAsync("PUT");

        // Assert
        response.EnsureSuccessStatusCode();

        // Sorting by header name makes the expected sequence deterministic.
        var orderedHeaders = response.Headers.OrderBy(header => header.Key);
        Assert.Collection(
            orderedHeaders,
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowOrigin, header.Key);
                Assert.Equal(OriginUrl, Assert.Single(header.Value));
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlExposeHeaders, header.Key);
                Assert.Equal("AllowedHeader", Assert.Single(header.Value));
            },
            header =>
            {
                Assert.Equal("Test", header.Key);
                Assert.Equal("Should-Appear", Assert.Single(header.Value));
            });

        var body = await response.Content.ReadAsStringAsync();
        Assert.Equal("Cross origin response", body);
    }
 
    /// <summary>
    /// Places a hand-written exception-handling middleware in front of the CORS middleware and
    /// makes the terminal middleware throw after adding a "Test" header. The handler clears the
    /// response and sets status 500. Expects the 500 response to carry exactly the allow-origin
    /// and expose-headers headers (and not the "Test" header), and expects the handler's catch
    /// block to have run.
    /// </summary>
    [Fact]
    public async Task CorsRequest_SetsResponseHeader_IfExceptionHandlerClearsResponse()
    {
        // Arrange
        // Must start as false: it is flipped only inside the catch block below, so the
        // Assert.True further down actually proves the exception path was taken.
        var exceptionSeen = false;
        using var host = new HostBuilder()
            .ConfigureWebHost(webHostBuilder =>
            {
                webHostBuilder
                .UseTestServer()
                .Configure(app =>
                {
                    // Simulate ExceptionHandler middleware
                    app.Use(async (context, next) =>
                    {
                        try
                        {
                            await next(context);
                        }
                        catch (Exception)
                        {
                            exceptionSeen = true;
                            context.Response.Clear();
                            context.Response.StatusCode = 500;
                        }
                    });

                    app.UseCors(builder =>
                        builder.WithOrigins(OriginUrl)
                            .WithMethods("PUT")
                            .WithHeaders("Header1")
                            .WithExposedHeaders("AllowedHeader"));

                    app.Run(context =>
                    {
                        // This header is added before the throw; the assertions expect it to be absent.
                        context.Response.Headers.Add("Test", "Should-Not-Exist");
                        throw new Exception("Runtime error");
                    });
                })
                .ConfigureServices(services => services.AddCors());
            }).Build();

        await host.StartAsync();

        using (var server = host.GetTestServer())
        {
            // Act
            // Actual request.
            var response = await server.CreateRequest("/")
                .AddHeader(CorsConstants.Origin, OriginUrl)
                .SendAsync("PUT");

            // Assert
            Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
            Assert.True(exceptionSeen, "We expect exception middleware to have executed");

            Assert.Collection(
                response.Headers.OrderBy(o => o.Key),
                kvp =>
                {
                    Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key);
                    Assert.Equal(OriginUrl, Assert.Single(kvp.Value));
                },
                kvp =>
                {
                    Assert.Equal(CorsConstants.AccessControlExposeHeaders, kvp.Key);
                    Assert.Equal("AllowedHeader", Assert.Single(kvp.Value));
                });
        }
    }
 
    /// <summary>
    /// Invokes the middleware for a preflight request with a mocked policy provider that
    /// returns its policy after a 10 ms delay, using a real <see cref="CorsService"/>.
    /// Expects the response headers, ordered by name, to be allow-headers ("AllowedHeader")
    /// and allow-origin (<see cref="OriginUrl"/>).
    /// </summary>
    [Fact]
    public async Task Invoke_WithCustomPolicyProviderThatReturnsAsynchronously_Works()
    {
        // Arrange
        var corsPolicy = new CorsPolicyBuilder()
            .WithOrigins(OriginUrl)
            .WithHeaders("AllowedHeader")
            .Build();

        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .ReturnsAsync(corsPolicy, TimeSpan.FromMilliseconds(10));

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            new CorsService(Options.Create(new CorsOptions()), NullLoggerFactory.Instance),
            NullLoggerFactory.Instance,
            policyName: "DefaultPolicyName");

        var context = new DefaultHttpContext();
        context.Request.Method = "OPTIONS";
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { OriginUrl });
        requestHeaders.Add(CorsConstants.AccessControlRequestMethod, new[] { "PUT" });

        // Act
        await middleware.Invoke(context, policyProvider.Object);

        // Assert
        var orderedHeaders = context.Response.Headers.OrderBy(header => header.Key);
        Assert.Collection(
            orderedHeaders,
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowHeaders, header.Key);
                Assert.Equal("AllowedHeader", Assert.Single(header.Value));
            },
            header =>
            {
                Assert.Equal(CorsConstants.AccessControlAllowOrigin, header.Key);
                Assert.Equal(OriginUrl, Assert.Single(header.Value));
            });
    }
 
    /// <summary>
    /// Invokes a middleware configured with policy name "DefaultPolicyName" for a request
    /// whose endpoint has an empty metadata collection, and verifies that the policy provider
    /// is asked exactly once for "DefaultPolicyName".
    /// </summary>
    [Fact]
    public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
    {
        // Arrange
        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            Mock.Of<ICorsService>(),
            NullLoggerFactory.Instance,
            policyName: "DefaultPolicyName");

        var endpoint = new Endpoint(_ => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint");
        var context = new DefaultHttpContext();
        context.SetEndpoint(endpoint);
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });

        // Act
        await middleware.Invoke(context, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), "DefaultPolicyName"),
            Times.Once());
    }
 
    /// <summary>
    /// Invokes a middleware configured with policy name "DefaultPolicyName" for a request
    /// whose endpoint carries <c>EnableCorsAttribute("MetadataPolicyName")</c>, and verifies
    /// that the policy provider is asked exactly once for "MetadataPolicyName".
    /// </summary>
    [Fact]
    public async Task Invoke_HasEndpointWithEnableMetadata_MiddlewareHasPolicyName_RunsCorsWithPolicyName()
    {
        // Arrange
        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            Mock.Of<ICorsService>(),
            NullLoggerFactory.Instance,
            policyName: "DefaultPolicyName");

        var metadata = new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName"));
        var endpoint = new Endpoint(_ => Task.CompletedTask, metadata, "Test endpoint");
        var context = new DefaultHttpContext();
        context.SetEndpoint(endpoint);
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });

        // Act
        await middleware.Invoke(context, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), "MetadataPolicyName"),
            Times.Once());
    }
 
    /// <summary>
    /// Invokes the middleware for a preflight request (OPTIONS with Origin and
    /// Access-Control-Request-Method) whose endpoint metadata contains an EnableCorsAttribute
    /// followed by a DisableCorsAttribute. The next delegate throws if it is called; the test
    /// expects the response status code to be 204.
    /// </summary>
    [Fact]
    public async Task Invoke_HasEndpointWithEnableMetadata_HasSignificantDisableCors_ReturnsNoContentForPreflightRequest()
    {
        // Arrange
        var provider = Mock.Of<ICorsPolicyProvider>();

        var middleware = new CorsMiddleware(
            _ => throw new Exception("Should not be called."),
            Mock.Of<ICorsService>(),
            NullLoggerFactory.Instance,
            policyName: "DefaultPolicyName");

        var metadata = new EndpointMetadataCollection(new EnableCorsAttribute(), new DisableCorsAttribute());
        var endpoint = new Endpoint(_ => Task.CompletedTask, metadata, "Test endpoint");
        var context = new DefaultHttpContext();
        context.SetEndpoint(endpoint);
        context.Request.Method = "OPTIONS";
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });
        requestHeaders.Add(CorsConstants.AccessControlRequestMethod, new[] { "GET" });

        // Act
        await middleware.Invoke(context, provider);

        // Assert
        Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
    }
 
    /// <summary>
    /// Invokes the middleware for a GET request whose endpoint metadata contains an
    /// EnableCorsAttribute followed by a DisableCorsAttribute. Expects the next delegate to run
    /// and neither the policy provider nor the CORS service to be called.
    /// </summary>
    [Fact]
    public async Task Invoke_HasEndpointWithEnableMetadata_HasSignificantDisableCors_ExecutesNextMiddleware()
    {
        // Arrange
        var nextInvoked = false;
        var service = Mock.Of<ICorsService>();
        var provider = Mock.Of<ICorsPolicyProvider>();

        RequestDelegate next = _ =>
        {
            nextInvoked = true;
            return Task.CompletedTask;
        };

        var middleware = new CorsMiddleware(
            next,
            service,
            NullLoggerFactory.Instance,
            policyName: "DefaultPolicyName");

        var metadata = new EndpointMetadataCollection(new EnableCorsAttribute(), new DisableCorsAttribute());
        var endpoint = new Endpoint(_ => Task.CompletedTask, metadata, "Test endpoint");
        var context = new DefaultHttpContext();
        context.SetEndpoint(endpoint);
        context.Request.Method = "GET";
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });
        requestHeaders.Add(CorsConstants.AccessControlRequestMethod, new[] { "GET" });

        // Act
        await middleware.Invoke(context, provider);

        // Assert
        Assert.True(nextInvoked);
        Mock.Get(provider).Verify(
            p => p.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
            Times.Never());
        Mock.Get(service).Verify(
            s => s.EvaluatePolicy(It.IsAny<HttpContext>(), It.IsAny<CorsPolicy>()),
            Times.Never());
    }
 
    /// <summary>
    /// Invokes a middleware constructed with a <see cref="CorsPolicy"/> instance for a request
    /// whose endpoint carries <c>EnableCorsAttribute("MetadataPolicyName")</c>, and verifies
    /// that the policy provider is asked exactly once for "MetadataPolicyName".
    /// </summary>
    [Fact]
    public async Task Invoke_HasEndpointWithEnableMetadata_MiddlewareHasPolicy_RunsCorsWithPolicyName()
    {
        // Arrange
        var middlewarePolicy = new CorsPolicyBuilder().Build();

        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            Mock.Of<ICorsService>(),
            middlewarePolicy,
            NullLoggerFactory.Instance);

        var metadata = new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName"));
        var endpoint = new Endpoint(_ => Task.CompletedTask, metadata, "Test endpoint");
        var context = new DefaultHttpContext();
        context.SetEndpoint(endpoint);
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });

        // Act
        await middleware.Invoke(context, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), "MetadataPolicyName"),
            Times.Once());
    }
 
    /// <summary>
    /// Invokes a middleware constructed with one <see cref="CorsPolicy"/> for a request whose
    /// endpoint carries a <c>CorsPolicyMetadata</c> wrapping a different policy instance.
    /// Expects the policy provider never to be queried and the CORS service to evaluate the
    /// metadata's policy exactly once.
    /// </summary>
    [Fact]
    public async Task Invoke_HasEndpointRequireCorsMetadata_MiddlewareHasPolicy_RunsCorsWithPolicyName()
    {
        // Arrange
        var middlewarePolicy = new CorsPolicyBuilder().Build();
        var endpointPolicy = new CorsPolicyBuilder().Build();

        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var corsService = new Mock<ICorsService>();
        corsService
            .Setup(service => service.EvaluatePolicy(It.IsAny<HttpContext>(), It.IsAny<CorsPolicy>()))
            .Returns(new CorsResult())
            .Verifiable();

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            corsService.Object,
            middlewarePolicy,
            NullLoggerFactory.Instance);

        var metadata = new EndpointMetadataCollection(new CorsPolicyMetadata(endpointPolicy));
        var endpoint = new Endpoint(_ => Task.CompletedTask, metadata, "Test endpoint");
        var context = new DefaultHttpContext();
        context.SetEndpoint(endpoint);
        var requestHeaders = context.Request.Headers;
        requestHeaders.Add(CorsConstants.Origin, new[] { "http://example.com" });

        // Act
        await middleware.Invoke(context, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            provider => provider.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
            Times.Never());
        corsService.Verify(
            service => service.EvaluatePolicy(It.IsAny<HttpContext>(), endpointPolicy),
            Times.Once());
    }
 
    // The endpoint carries an EnableCorsAttribute created without a policy name, and the
    // middleware is constructed with a concrete policy instance. The test expects that the
    // policy provider is never queried and that the CORS service evaluates the instance
    // handed to the middleware constructor.
    [Fact]
    public async Task Invoke_HasEndpointWithEnableMetadataWithNoName_RunsCorsWithStaticPolicy()
    {
        // Arrange
        var loggerFactory = NullLoggerFactory.Instance;
        var staticPolicy = new CorsPolicyBuilder().Build();

        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(p => p.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var corsEvaluator = new Mock<ICorsService>();
        corsEvaluator
            .Setup(s => s.EvaluatePolicy(It.IsAny<HttpContext>(), It.IsAny<CorsPolicy>()))
            .Returns(new CorsResult())
            .Verifiable();

        var endpoint = new Endpoint(
            c => Task.CompletedTask,
            new EndpointMetadataCollection(new EnableCorsAttribute()),
            "Test endpoint");

        var httpContext = new DefaultHttpContext();
        httpContext.SetEndpoint(endpoint);
        httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            corsEvaluator.Object,
            staticPolicy,
            loggerFactory);

        // Act
        await middleware.Invoke(httpContext, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            p => p.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
            Times.Never());
        corsEvaluator.Verify(
            s => s.EvaluatePolicy(It.IsAny<HttpContext>(), staticPolicy),
            Times.Once());
    }
 
    // The endpoint carries a DisableCorsAttribute and the middleware is constructed with a
    // policy name. The test expects that the policy provider is never asked for a policy,
    // even though the request includes an Origin header.
    [Fact]
    public async Task Invoke_HasEndpointWithDisableMetadata_SkipCors()
    {
        // Arrange
        var loggerFactory = NullLoggerFactory.Instance;
        var corsService = Mock.Of<ICorsService>();

        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(p => p.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        var endpoint = new Endpoint(
            c => Task.CompletedTask,
            new EndpointMetadataCollection(new DisableCorsAttribute()),
            "Test endpoint");

        var httpContext = new DefaultHttpContext();
        httpContext.SetEndpoint(endpoint);
        httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            corsService,
            loggerFactory,
            "DefaultPolicyName");

        // Act
        await middleware.Invoke(httpContext, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            p => p.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
            Times.Never());
    }
 
    // The endpoint metadata lists an EnableCorsAttribute("MetadataPolicyName") followed by a
    // DisableCorsAttribute. The test expects that, with the metadata in this order, the
    // policy provider is never asked for a policy.
    [Fact]
    public async Task Invoke_HasEndpointWithMutlipleMetadata_SkipCorsBecauseOfMetadataOrder()
    {
        // Arrange
        var loggerFactory = NullLoggerFactory.Instance;
        var corsService = Mock.Of<ICorsService>();

        var policyProvider = new Mock<ICorsPolicyProvider>();
        policyProvider
            .Setup(p => p.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
            .Returns(Task.FromResult<CorsPolicy>(null))
            .Verifiable();

        // Order matters for this scenario: enable first, disable last.
        var metadata = new EndpointMetadataCollection(
            new EnableCorsAttribute("MetadataPolicyName"),
            new DisableCorsAttribute());
        var endpoint = new Endpoint(c => Task.CompletedTask, metadata, "Test endpoint");

        var httpContext = new DefaultHttpContext();
        httpContext.SetEndpoint(endpoint);
        httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            corsService,
            loggerFactory,
            "DefaultPolicyName");

        // Act
        await middleware.Invoke(httpContext, policyProvider.Object);

        // Assert
        policyProvider.Verify(
            p => p.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()),
            Times.Never());
    }
 
    // With an endpoint set on the context and an Origin header on the request, the test
    // expects HttpContext.Items to contain the "__CorsMiddlewareWithEndpointInvoked" key
    // after the middleware has run.
    [Fact]
    public async Task Invoke_InvokeFlagSet()
    {
        // Arrange
        var invokedFlagKey = "__CorsMiddlewareWithEndpointInvoked";
        var loggerFactory = NullLoggerFactory.Instance;
        var corsService = Mock.Of<ICorsService>();
        var policyProvider = Mock.Of<ICorsPolicyProvider>();

        var metadata = new EndpointMetadataCollection(
            new EnableCorsAttribute("MetadataPolicyName"),
            new DisableCorsAttribute());
        var endpoint = new Endpoint(c => Task.CompletedTask, metadata, "Test endpoint");

        var httpContext = new DefaultHttpContext();
        httpContext.SetEndpoint(endpoint);
        httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            corsService,
            loggerFactory,
            "DefaultPolicyName");

        // Act
        await middleware.Invoke(httpContext, policyProvider);

        // Assert
        Assert.Contains(httpContext.Items, entry => string.Equals(entry.Key as string, invokedFlagKey));
    }
 
    // Same setup as the flag test with an Origin header, except that no Origin header is
    // added to the request. The test expects HttpContext.Items to still contain the
    // "__CorsMiddlewareWithEndpointInvoked" key after the middleware has run.
    [Fact]
    public async Task Invoke_WithoutOrigin_InvokeFlagSet()
    {
        // Arrange
        var invokedFlagKey = "__CorsMiddlewareWithEndpointInvoked";
        var loggerFactory = NullLoggerFactory.Instance;
        var corsService = Mock.Of<ICorsService>();
        var policyProvider = Mock.Of<ICorsPolicyProvider>();

        var metadata = new EndpointMetadataCollection(
            new EnableCorsAttribute("MetadataPolicyName"),
            new DisableCorsAttribute());
        var endpoint = new Endpoint(c => Task.CompletedTask, metadata, "Test endpoint");

        // Deliberately no Origin header on this request.
        var httpContext = new DefaultHttpContext();
        httpContext.SetEndpoint(endpoint);

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            corsService,
            loggerFactory,
            "DefaultPolicyName");

        // Act
        await middleware.Invoke(httpContext, policyProvider);

        // Assert
        Assert.Contains(httpContext.Items, entry => string.Equals(entry.Key as string, invokedFlagKey));
    }
 
    // No endpoint is set on the context; only an Origin header is added to the request.
    // Despite the "InvokeFlagSet" suffix in the method name, the assertion checks that
    // HttpContext.Items does NOT contain the "__CorsMiddlewareWithEndpointInvoked" key.
    [Fact]
    public async Task Invoke_WithoutEndpoint_InvokeFlagSet()
    {
        // Arrange
        var invokedFlagKey = "__CorsMiddlewareWithEndpointInvoked";
        var loggerFactory = NullLoggerFactory.Instance;
        var corsService = Mock.Of<ICorsService>();
        var policyProvider = Mock.Of<ICorsPolicyProvider>();

        // Deliberately no SetEndpoint call on this context.
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" });

        var middleware = new CorsMiddleware(
            Mock.Of<RequestDelegate>(),
            corsService,
            loggerFactory,
            "DefaultPolicyName");

        // Act
        await middleware.Invoke(httpContext, policyProvider);

        // Assert
        Assert.DoesNotContain(httpContext.Items, entry => string.Equals(entry.Key as string, invokedFlagKey));
    }
}