File: DefaultHttpContextTests.cs
Web Access
Project: src\src\Http\Http\test\Microsoft.AspNetCore.Http.Tests.csproj (Microsoft.AspNetCore.Http.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Net.WebSockets;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Security.Claims;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Shared;
using Microsoft.Extensions.DependencyInjection;
 
namespace Microsoft.AspNetCore.Http;
 
/// <summary>
/// Unit tests for <see cref="DefaultHttpContext"/>. Covered areas: session access, the default user and
/// items, trace identifiers, feature-cache resets across Initialize/Uninitialize, request-services
/// scope disposal, and the debugger display string.
/// </summary>
public class DefaultHttpContextTests
{
    [Fact]
    public void GetOnSessionProperty_ThrowsOnMissingSessionFeature()
    {
        // Arrange
        var context = new DefaultHttpContext();

        // Act & Assert
        var exception = Assert.Throws<InvalidOperationException>(() => context.Session);
        Assert.Equal("Session has not been configured for this application or request.", exception.Message);
    }

    [Fact]
    public void GetOnSessionProperty_ReturnsAvailableSession()
    {
        // Arrange
        var context = new DefaultHttpContext();
        var session = new TestSession();
        session.Set("key1", null);
        session.Set("key2", null);
        context.Features.Set<ISessionFeature>(new BlahSessionFeature { Session = session });

        // Act & Assert
        Assert.Same(session, context.Session);
        context.Session.Set("key3", null);
        Assert.Equal(3, context.Session.Keys.Count());
    }

    [Fact]
    public void AllowsSettingSession_WithoutSettingUpSessionFeature_Upfront()
    {
        // Arrange
        var session = new TestSession();
        var context = new DefaultHttpContext();

        // Act
        context.Session = session;

        // Assert
        Assert.Same(session, context.Session);
    }

    [Fact]
    public void SettingSession_OverridesAvailableSession()
    {
        // Arrange
        var context = new DefaultHttpContext();
        var session = new TestSession();
        session.Set("key1", null);
        session.Set("key2", null);
        context.Features.Set<ISessionFeature>(new BlahSessionFeature { Session = session });

        // Act
        context.Session = new TestSession();

        // Assert
        Assert.NotSame(session, context.Session);
        Assert.Empty(context.Session.Keys);
    }

    [Fact]
    public void EmptyUserIsNeverNull()
    {
        var context = new DefaultHttpContext(new FeatureCollection());
        Assert.NotNull(context.User);
        Assert.Single(context.User.Identities);
        // Repeated reads must return the same principal instance, not a fresh one per call.
        Assert.Same(context.User, context.User);
        Assert.False(context.User.Identity.IsAuthenticated);
        Assert.True(string.IsNullOrEmpty(context.User.Identity.AuthenticationType));

        // Assigning null falls back to a default (unauthenticated) principal.
        context.User = null;
        Assert.NotNull(context.User);
        Assert.Single(context.User.Identities);
        Assert.Same(context.User, context.User);
        Assert.False(context.User.Identity.IsAuthenticated);
        Assert.True(string.IsNullOrEmpty(context.User.Identity.AuthenticationType));

        // An explicitly assigned identity-less principal is kept as-is.
        context.User = new ClaimsPrincipal();
        Assert.NotNull(context.User);
        Assert.Empty(context.User.Identities);
        Assert.Same(context.User, context.User);
        Assert.Null(context.User.Identity);

        context.User = new ClaimsPrincipal(new ClaimsIdentity("SomeAuthType"));
        Assert.Equal("SomeAuthType", context.User.Identity.AuthenticationType);
        Assert.True(context.User.Identity.IsAuthenticated);
    }

    [Fact]
    public void GetItems_DefaultCollectionProvided()
    {
        var context = new DefaultHttpContext(new FeatureCollection());
        Assert.Null(context.Features.Get<IItemsFeature>());

        // Reading Items registers an IItemsFeature in the collection.
        var items = context.Items;
        Assert.NotNull(context.Features.Get<IItemsFeature>());
        Assert.NotNull(items);
        Assert.Same(items, context.Items);

        var item = new object();
        context.Items["foo"] = item;
        Assert.Same(item, context.Items["foo"]);
    }

    [Fact]
    public void GetItems_DefaultRequestIdentifierAvailable()
    {
        var context = new DefaultHttpContext(new FeatureCollection());
        Assert.Null(context.Features.Get<IHttpRequestIdentifierFeature>());

        // Reading TraceIdentifier registers an IHttpRequestIdentifierFeature in the collection.
        var traceIdentifier = context.TraceIdentifier;
        Assert.NotNull(context.Features.Get<IHttpRequestIdentifierFeature>());
        Assert.NotNull(traceIdentifier);
        Assert.Same(traceIdentifier, context.TraceIdentifier);

        context.TraceIdentifier = "Hello";
        Assert.Same("Hello", context.TraceIdentifier);
    }

    [Fact]
    public void SetItems_NewCollectionUsed()
    {
        var context = new DefaultHttpContext(new FeatureCollection());
        Assert.Null(context.Features.Get<IItemsFeature>());

        var items = new Dictionary<object, object>();
        context.Items = items;
        Assert.NotNull(context.Features.Get<IItemsFeature>());
        Assert.Same(items, context.Items);

        var item = new object();
        items["foo"] = item;
        Assert.Same(item, context.Items["foo"]);
    }

    [Fact]
    public void UpdateFeatures_ClearsCachedFeatures()
    {
        var features = CreateFeatures();

        // FeatureCollection is set. all cached interfaces are null.
        var context = new DefaultHttpContext(features);
        TestAllCachedFeaturesAreNull(context, features);
        Assert.Equal(4, features.Count());

        // getting feature properties populates feature collection with defaults
        TestAllCachedFeaturesAreSet(context, features);
        Assert.NotEqual(4, features.Count());

        // FeatureCollection is null. and all cached interfaces are null.
        // only top level is tested because child objects are inaccessible.
        context.Uninitialize();
        TestCachedFeaturesAreNull(context, null);

        var newFeatures = CreateFeatures();

        // FeatureCollection is set to newFeatures. all cached interfaces are null.
        context.Initialize(newFeatures);
        TestAllCachedFeaturesAreNull(context, newFeatures);
        Assert.Equal(4, newFeatures.Count());

        // getting feature properties populates new feature collection with defaults
        TestAllCachedFeaturesAreSet(context, newFeatures);
        Assert.NotEqual(4, newFeatures.Count());

        // Builds a collection containing exactly the four features the count assertions above expect.
        static FeatureCollection CreateFeatures()
        {
            var collection = new FeatureCollection();
            collection.Set<IHttpRequestFeature>(new HttpRequestFeature());
            collection.Set<IHttpResponseFeature>(new HttpResponseFeature());
            collection.Set<IHttpResponseBodyFeature>(new StreamResponseBodyFeature(Stream.Null));
            collection.Set<IHttpWebSocketFeature>(new TestHttpWebSocketFeature());
            return collection;
        }
    }

    [Fact]
    public void RequestServicesAreNotOverwrittenIfAlreadySet()
    {
        using var serviceProvider = new ServiceCollection()
            .BuildServiceProvider();
        var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();

        var context = new DefaultHttpContext();
        context.ServiceScopeFactory = scopeFactory;
        context.RequestServices = serviceProvider;

        Assert.Same(serviceProvider, context.RequestServices);
    }

    [Fact]
    public async Task RequestServicesAreDisposedOnCompleted()
    {
        using var serviceProvider = new ServiceCollection()
            .AddTransient<DisposableThing>()
            .BuildServiceProvider();
        var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();

        var context = new DefaultHttpContext();
        context.ServiceScopeFactory = scopeFactory;
        var responseFeature = new TestHttpResponseFeature();
        context.Features.Set<IHttpResponseFeature>(responseFeature);

        // Reading RequestServices must yield a provider and register exactly one OnCompleted callback.
        Assert.NotNull(context.RequestServices);
        Assert.Single(responseFeature.CompletedCallbacks);

        var instance = context.RequestServices.GetRequiredService<DisposableThing>();

        // Simulate the server firing the response-completed callback.
        var (callback, state) = responseFeature.CompletedCallbacks[0];
        await callback(state);

        Assert.Null(context.RequestServices);
        Assert.True(instance.Disposed);
    }

    [Fact]
    public async Task RequestServicesAreDisposedAsynOnCompleted()
    {
        using var serviceProvider = new AsyncDisposableServiceProvider(new ServiceCollection()
            .AddTransient<DisposableThing>()
            .BuildServiceProvider());
        var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();

        var context = new DefaultHttpContext();
        context.ServiceScopeFactory = scopeFactory;
        var responseFeature = new TestHttpResponseFeature();
        context.Features.Set<IHttpResponseFeature>(responseFeature);

        Assert.NotNull(context.RequestServices);
        Assert.Single(responseFeature.CompletedCallbacks);

        var instance = context.RequestServices.GetRequiredService<DisposableThing>();

        var (callback, state) = responseFeature.CompletedCallbacks[0];
        await callback(state);

        Assert.Null(context.RequestServices);
        Assert.True(instance.Disposed);

        // The scope implements IAsyncDisposable, so only the async disposal path should have run.
        var scope = Assert.Single(serviceProvider.Scopes);
        Assert.True(scope.DisposeAsyncCalled);
        Assert.False(scope.DisposeCalled);
    }

    [Fact]
    public void InternalActiveFlagIsSetAndUnset()
    {
        var context = new DefaultHttpContext();

        Assert.False(context._active);

        context.Initialize(new FeatureCollection());

        Assert.True(context._active);

        context.Uninitialize();

        Assert.False(context._active);
    }

    [Fact]
    public void DebuggerToString_EmptyRequest()
    {
        var context = new DefaultHttpContext();

        var debugText = HttpContextDebugFormatter.ContextToString(context, reasonPhrase: null);
        Assert.Equal("(unspecified) 200", debugText);
    }

    [Fact]
    public void DebuggerToString_HasReason()
    {
        var context = new DefaultHttpContext();

        var debugText = HttpContextDebugFormatter.ContextToString(context, reasonPhrase: "OK");
        Assert.Equal("(unspecified) 200 OK", debugText);
    }

    [Fact]
    public void DebuggerToString_HasMethod()
    {
        var context = new DefaultHttpContext();
        context.Request.Method = "GET";

        var debugText = HttpContextDebugFormatter.ContextToString(context, reasonPhrase: null);
        Assert.Equal("GET (unspecified) 200", debugText);
    }

    [Fact]
    public void DebuggerToString_HasProtocol()
    {
        var context = new DefaultHttpContext();
        context.Request.Protocol = "HTTP/2";

        var debugText = HttpContextDebugFormatter.ContextToString(context, reasonPhrase: null);
        Assert.Equal("(unspecified) HTTP/2 200", debugText);
    }

    [Fact]
    public void DebuggerToString_HasContentType()
    {
        var context = new DefaultHttpContext();
        context.Request.ContentType = "application/json";

        var debugText = HttpContextDebugFormatter.ContextToString(context, reasonPhrase: null);
        Assert.Equal("(unspecified) application/json 200", debugText);
    }

    /// <summary>
    /// Checks the feature cache on the context and on each of its Request, Response, Connection
    /// and WebSockets wrappers. Each cache must equal a newly constructed cache bound to
    /// <paramref name="features"/>.
    /// </summary>
    private static void TestAllCachedFeaturesAreNull(HttpContext context, IFeatureCollection features)
    {
        TestCachedFeaturesAreNull(context, features);
        TestCachedFeaturesAreNull(context.Request, features);
        TestCachedFeaturesAreNull(context.Response, features);
        TestCachedFeaturesAreNull(context.Connection, features);
        TestCachedFeaturesAreNull(context.WebSockets, features);
    }

    /// <summary>
    /// Locates the single <c>FeatureReferences&lt;&gt;</c> field on <paramref name="value"/> via reflection.
    /// Asserts it equals a fresh instance: the default struct when <paramref name="features"/> is null,
    /// otherwise one constructed with <paramref name="features"/>.
    /// </summary>
    private static void TestCachedFeaturesAreNull(object value, IFeatureCollection features)
    {
        var type = value.GetType();

        var field = type
            .GetFields(BindingFlags.NonPublic | BindingFlags.Instance)
            .Single(f =>
                f.FieldType.GetTypeInfo().IsGenericType &&
                f.FieldType.GetGenericTypeDefinition() == typeof(FeatureReferences<>));

        var boxedExpectedStruct = features is null
            ? Activator.CreateInstance(field.FieldType)
            : Activator.CreateInstance(field.FieldType, features);

        var boxedActualStruct = field.GetValue(value);

        // Boxed struct equality compares field-by-field, so any populated cache slot fails here.
        Assert.Equal(boxedExpectedStruct, boxedActualStruct);
    }

    /// <summary>
    /// Runs <see cref="TestCachedFeaturesAreSet"/> on the context and on each of its
    /// Request, Response, Connection and WebSockets wrappers.
    /// </summary>
    private static void TestAllCachedFeaturesAreSet(HttpContext context, IFeatureCollection features)
    {
        TestCachedFeaturesAreSet(context, features);
        TestCachedFeaturesAreSet(context.Request, features);
        TestCachedFeaturesAreSet(context.Response, features);
        TestCachedFeaturesAreSet(context.Connection, features);
        TestCachedFeaturesAreSet(context.WebSockets, features);
    }

    /// <summary>
    /// Reads the interface-typed properties of <paramref name="value"/> first, then its
    /// interface-typed private fields. Each one that refers to a feature must be non-null and the
    /// same instance stored in <paramref name="features"/>. IFeatureCollection-typed members must
    /// be <paramref name="features"/> itself.
    /// </summary>
    private static void TestCachedFeaturesAreSet(object value, IFeatureCollection features)
    {
        var type = value.GetType();

        var interfaceProperties = type
            .GetProperties(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)
            .Where(p => p.PropertyType.GetTypeInfo().IsInterface);

        // Properties are read before fields: per the caller's comments, reading feature properties
        // populates the collection with defaults.
        TestFeatureProperties(value, features, interfaceProperties);

        // Compiler-generated fields (such as auto-property backing fields) are excluded.
        var interfaceFields = type
            .GetFields(BindingFlags.NonPublic | BindingFlags.Instance)
            .Where(f => f.FieldType.GetTypeInfo().IsInterface && f.GetCustomAttribute<CompilerGeneratedAttribute>() is null);

        foreach (var field in interfaceFields)
        {
            var fieldValue = field.GetValue(value);
            if (field.FieldType == typeof(IFeatureCollection))
            {
                Assert.Same(features, fieldValue);
            }
            else
            {
                Assert.Same(features[field.FieldType], fieldValue);
                Assert.NotNull(fieldValue);
            }
        }
    }

    /// <summary>
    /// Handles each property in <paramref name="properties"/> by type and name:
    /// <list type="bullet">
    /// <item>IFeatureCollection-typed properties must return <paramref name="features"/>.</item>
    /// <item>Properties whose name contains "Feature" must return the non-null instance stored
    /// in <paramref name="features"/> under the property's type.</item>
    /// <item>All other properties are never read.</item>
    /// </list>
    /// </summary>
    private static void TestFeatureProperties(object value, IFeatureCollection features, IEnumerable<PropertyInfo> properties)
    {
        foreach (var property in properties)
        {
            if (property.PropertyType == typeof(IFeatureCollection))
            {
                Assert.Same(features, property.GetValue(value));
            }
            else if (property.Name.Contains("Feature", StringComparison.Ordinal))
            {
                // The getter is invoked only after the name filter so other getters are never run.
                var propertyValue = property.GetValue(value);
                Assert.Same(features[property.PropertyType], propertyValue);
                Assert.NotNull(propertyValue);
            }
        }
    }

    /// <summary>Transient service that records whether it has been disposed.</summary>
    private class DisposableThing : IDisposable
    {
        public bool Disposed { get; set; }

        public void Dispose() => Disposed = true;
    }

    /// <summary>Response feature that captures OnCompleted registrations so tests can invoke them.</summary>
    private class TestHttpResponseFeature : IHttpResponseFeature
    {
        public List<(Func<object, Task> callback, object state)> CompletedCallbacks = new();

        public int StatusCode { get; set; }
        public string ReasonPhrase { get; set; }
        public IHeaderDictionary Headers { get; set; } = new HeaderDictionary();
        public Stream Body { get; set; }

        public bool HasStarted => false;

        public void OnCompleted(Func<object, Task> callback, object state)
        {
            CompletedCallbacks.Add((callback, state));
        }

        public void OnStarting(Func<object, Task> callback, object state)
        {
        }
    }

    /// <summary>In-memory <see cref="ISession"/> with case-insensitive keys and no-op load/commit.</summary>
    private class TestSession : ISession
    {
        private readonly Dictionary<string, byte[]> _store = new(StringComparer.OrdinalIgnoreCase);

        public string Id { get; set; }

        public bool IsAvailable { get; } = true;

        public IEnumerable<string> Keys => _store.Keys;

        public void Clear() => _store.Clear();

        public Task CommitAsync(CancellationToken cancellationToken) => Task.CompletedTask;

        public Task LoadAsync(CancellationToken cancellationToken) => Task.CompletedTask;

        public void Remove(string key) => _store.Remove(key);

        public void Set(string key, byte[] value) => _store[key] = value;

        public bool TryGetValue(string key, out byte[] value) => _store.TryGetValue(key, out value);
    }

    /// <summary>Minimal settable <see cref="ISessionFeature"/>.</summary>
    private class BlahSessionFeature : ISessionFeature
    {
        public ISession Session { get; set; }
    }

    /// <summary>Placeholder WebSocket feature. Used only as a collection entry; all members throw.</summary>
    private class TestHttpWebSocketFeature : IHttpWebSocketFeature
    {
        public bool IsWebSocketRequest => throw new NotImplementedException();

        public Task<WebSocket> AcceptAsync(WebSocketAcceptContext context)
        {
            throw new NotImplementedException();
        }
    }

    /// <summary>
    /// Wraps a real <see cref="ServiceProvider"/> and acts as its own <see cref="IServiceScopeFactory"/>.
    /// Every created scope is an <see cref="AsyncServiceScope"/> that records which dispose method ran.
    /// </summary>
    private class AsyncDisposableServiceProvider : IServiceProvider, IDisposable, IServiceScopeFactory
    {
        private readonly ServiceProvider _serviceProvider;

        public AsyncDisposableServiceProvider(ServiceProvider serviceProvider)
        {
            _serviceProvider = serviceProvider;
        }

        /// <summary>Every scope handed out by <see cref="CreateScope"/>, in creation order.</summary>
        public List<AsyncServiceScope> Scopes { get; } = new();

        public object GetService(Type serviceType)
        {
            // Intercept scope-factory requests so scopes are created through this wrapper.
            if (serviceType == typeof(IServiceScopeFactory))
            {
                return this;
            }

            return _serviceProvider.GetService(serviceType);
        }

        public void Dispose()
        {
            _serviceProvider.Dispose();
        }

        public IServiceScope CreateScope()
        {
            var scope = new AsyncServiceScope(_serviceProvider.GetService<IServiceScopeFactory>().CreateScope());
            Scopes.Add(scope);
            return scope;
        }

        /// <summary>Scope wrapper that flags whether Dispose or DisposeAsync was invoked.</summary>
        internal class AsyncServiceScope : IServiceScope, IAsyncDisposable
        {
            private readonly IServiceScope _scope;

            public AsyncServiceScope(IServiceScope scope)
            {
                _scope = scope;
            }

            public bool DisposeCalled { get; set; }

            public bool DisposeAsyncCalled { get; set; }

            public void Dispose()
            {
                DisposeCalled = true;
                _scope.Dispose();
            }

            public ValueTask DisposeAsync()
            {
                DisposeAsyncCalled = true;
                _scope.Dispose();
                return default;
            }

            public IServiceProvider ServiceProvider => _scope.ServiceProvider;
        }
    }
}