// File: AsyncStateTests.cs
// Web Access
// Project: src\test\Libraries\Microsoft.Extensions.AsyncState.Tests\Microsoft.Extensions.AsyncState.Tests.csproj (Microsoft.Extensions.AsyncState.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
 
namespace Microsoft.Extensions.AsyncState.Test;
 
/// <summary>
/// Tests for <see cref="AsyncState"/>: registering context tokens, setting and reading values,
/// and how those values are observed from other asynchronous flows.
/// </summary>
/// <remarks>
/// Assertions that run on a background flow are reported through a <see cref="TaskCompletionSource{TResult}"/>
/// (see <see cref="TrySignal"/>) so that a failing assertion fails the test instead of leaving the
/// test method awaiting a signal that never arrives.
/// </remarks>
[SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Testing")]
public class AsyncStateTests
{
    /// <summary>
    /// A value set for a token can be read back through the same token.
    /// </summary>
    [Fact]
    public async Task GettingAsyncContextReturnsAsyncContext()
    {
        var state = new AsyncState();
        var context = new Thing();

        var token = state.RegisterAsyncContext();

        // Deliberately a non-async local function: Initialize and Set execute synchronously
        // in the caller's flow before the returned (already completed) task is awaited.
        static Task SetAsyncContext(AsyncState state, IThing context, AsyncStateToken token)
        {
            state.Initialize();
            state.Set(token, context);

            return Task.CompletedTask;
        }

        await SetAsyncContext(state, context, token);

        Assert.Same(context, state.Get(token));
    }

    /// <summary>
    /// After the value is set to <see langword="null"/>, both the setting flow and a background
    /// flow started earlier read <see langword="null"/>.
    /// </summary>
    [Fact]
    public async Task GettingAsyncContextReturnsNullAsyncContextIfSetToNull()
    {
        var state = new AsyncState();
        var context = new Thing();
        var token = state.RegisterAsyncContext();
        state.Initialize();
        state.Set(token, context);

        var checkAsyncFlowTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
        var waitForNullTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
        var afterNullCheckTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

        var task = Task.Run(async () =>
        {
            // The AsyncContext flows with the execution context.
            // A failure here is forwarded to checkAsyncFlowTcs; there is nothing left to verify afterwards.
            if (!TrySignal(checkAsyncFlowTcs, () => Assert.Same(context, state.Get(token))))
            {
                return;
            }

#pragma warning disable VSTHRD003 // Avoid awaiting foreign Tasks
            await waitForNullTcs.Task;
#pragma warning restore VSTHRD003 // Avoid awaiting foreign Tasks

            // The test method has nulled the value by now; this flow must observe that.
            TrySignal(afterNullCheckTcs, () => Assert.Null(state.Get(token)));
        });

        await checkAsyncFlowTcs.Task;

        // Null out the context
        state.Set(token, null);

        waitForNullTcs.SetResult(null!);

        Assert.Null(state.Get(token));

        await afterNullCheckTcs.Task;
        await task;
    }

    /// <summary>
    /// After the test flow calls <c>Initialize</c> again and sets a new value, the test flow reads
    /// the new value while a background flow started before the re-initialization gets an
    /// <see cref="InvalidOperationException"/> from <c>Get</c>.
    /// </summary>
    /// <remarks>
    /// Despite "ReturnsNull" in the method name, the background assertion expects a throw, not a null result.
    /// </remarks>
    [Fact]
    public async Task GettingAsyncContextReturnsNullAsyncContextIfChanged()
    {
        var state = new AsyncState();
        var context = new Thing();
        var token = state.RegisterAsyncContext();

        state.Initialize();
        state.Set(token, context);

        var checkAsyncFlowTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
        var waitForNullTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
        var afterNullCheckTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

        var task = Task.Run(async () =>
        {
            // The AsyncContext flows with the execution context.
            // A failure here is forwarded to checkAsyncFlowTcs; there is nothing left to verify afterwards.
            if (!TrySignal(checkAsyncFlowTcs, () => Assert.Same(context, state.Get(token))))
            {
                return;
            }

#pragma warning disable VSTHRD003 // Avoid awaiting foreign Tasks
            await waitForNullTcs.Task;
#pragma warning restore VSTHRD003 // Avoid awaiting foreign Tasks

            // The test method has re-initialized the state by now; reading from this flow is expected to throw.
            TrySignal(afterNullCheckTcs, () => Assert.Throws<InvalidOperationException>(() => state.Get(token)));
        });

        await checkAsyncFlowTcs.Task;

        // Set a new Async context
        state.Initialize();
        var context2 = new Thing();
        state.Set(token, context2);

        waitForNullTcs.SetResult(null!);

        Assert.Same(context2, state.Get(token));

        await afterNullCheckTcs.Task;
        await task;
    }

    /// <summary>
    /// A value that was set to <see langword="null"/> before a background task starts is read as
    /// <see langword="null"/> inside that task.
    /// </summary>
    [Fact]
    public async Task GettingAsyncContextDoesNotFlowIfAccessorSetToNull()
    {
        var state = new AsyncState();
        var context = new Thing();
        var token = state.RegisterAsyncContext();
        state.Initialize();
        state.Set(token, context);

        var checkAsyncFlowTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

        state.Set(token, null);

        var task = Task.Run(() =>
        {
            // The value was nulled before this task started, so the original context must not be visible here.
            TrySignal(checkAsyncFlowTcs, () => Assert.Null(state.Get(token)));
        });

        await checkAsyncFlowTcs.Task;
        await task;
    }

    /// <summary>
    /// When work is queued without flowing the execution context, <c>Get</c> throws
    /// <see cref="InvalidOperationException"/> inside that work item.
    /// </summary>
    [Fact]
    public async Task GettingAsyncContextDoesNotFlowIfExecutionContextDoesNotFlow()
    {
        var state = new AsyncState();
        var context = new Thing();
        var token = state.RegisterAsyncContext();
        state.Initialize();
        state.Set(token, context);

        var checkAsyncFlowTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

        ThreadPool.UnsafeQueueUserWorkItem(_ =>
        {
            // UnsafeQueueUserWorkItem does not flow the caller's execution context,
            // so the state initialized above is not available to this work item.
            TrySignal(checkAsyncFlowTcs, () => Assert.Throws<InvalidOperationException>(() => state.Get(token)));
        }, null);

        await checkAsyncFlowTcs.Task;
    }

    /// <summary>
    /// Each registration yields the next consecutive index and increases <c>ContextCount</c> by one.
    /// </summary>
    [Fact]
    public void RegisterContextCorrectly()
    {
        var asyncState = new AsyncState();

        // Indices are compared relative to the starting count rather than assumed to start at zero.
        var initialContextCount = asyncState.ContextCount;

        var c1 = asyncState.RegisterAsyncContext();
        Assert.Equal(0, c1.Index - initialContextCount);
        var c2 = asyncState.RegisterAsyncContext();
        Assert.Equal(1, c2.Index - initialContextCount);
        var c3 = asyncState.RegisterAsyncContext();
        Assert.Equal(2, c3.Index - initialContextCount);

        Assert.Equal(3, asyncState.ContextCount - initialContextCount);
    }

    /// <summary>
    /// An empty list is grown to the requested count.
    /// </summary>
    [Fact]
    public void EnsureCount_IncreasesCountCorrectly()
    {
        var l = new List<object?>();
        AsyncState.EnsureCount(l, 5);
        Assert.Equal(5, l.Count);
    }

    /// <summary>
    /// A list that already holds more items than the requested count keeps its size.
    /// </summary>
    [Fact]
    public void EnsureCount_WhenCountLessThanExpected()
    {
        // The requested count (2) is smaller than the current size (5); the list must not shrink.
        var l = new List<object?>(new object?[5]);
        AsyncState.EnsureCount(l, 2);
        Assert.Equal(5, l.Count);
    }

    /// <summary>
    /// A list whose size already equals the requested count keeps its size.
    /// </summary>
    [Fact]
    public void EnsureCount_WhenCountEqualWithExpected()
    {
        var l = new List<object?>(new object?[5]);
        AsyncState.EnsureCount(l, 5);
        Assert.Equal(5, l.Count);
    }

    /// <summary>
    /// Contexts resolved from two independently built service providers each return the value
    /// that was set through them.
    /// </summary>
    [Fact]
    public void AsyncStateCanBeUsedInDifferentServiceProviders()
    {
        var expectedValue1 = new Tuple<double>(3.14);
        var expectedValue2 = new Tuple<int>(42);
        var spOne = CreateAsyncState<Tuple<double>>();
        var spTwo = CreateAsyncState<Tuple<int>>();
        spOne.Set(expectedValue1);
        spTwo.Set(expectedValue2);
        var value1 = spOne.Get();
        var value2 = spTwo.Get();

        Assert.Equal(expectedValue1, value1);
        Assert.Equal(expectedValue2, value2);

        // Builds a fresh provider per call, initializes its IAsyncState, and resolves the typed context from it.
        static IAsyncContext<T> CreateAsyncState<T>()
            where T : notnull
        {
            var services = new ServiceCollection().AddAsyncState().BuildServiceProvider();
            services.GetRequiredService<IAsyncState>().Initialize();
            return services.GetRequiredService<IAsyncContext<T>>();
        }
    }

    /// <summary>
    /// Runs <paramref name="assertions"/> and reports the outcome through <paramref name="signal"/>:
    /// the signal is completed successfully when the assertions pass and faulted with the thrown
    /// exception otherwise.
    /// </summary>
    /// <param name="signal">The completion source awaited by the test method.</param>
    /// <param name="assertions">The assertions to execute on the current flow.</param>
    /// <returns><see langword="true"/> if the assertions passed; otherwise <see langword="false"/>.</returns>
    private static bool TrySignal(TaskCompletionSource<object> signal, Action assertions)
    {
        try
        {
            assertions();
            signal.SetResult(null!);
            return true;
        }
        catch (Exception ex)
        {
            // Forward the failure to whoever awaits the signal so the test fails instead of hanging.
            signal.SetException(ex);
            return false;
        }
    }
}