// File: AsyncContextTests.cs
// Web Access
// Project: src\test\Libraries\Microsoft.Extensions.AsyncState.Tests\Microsoft.Extensions.AsyncState.Tests.csproj (Microsoft.Extensions.AsyncState.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Threading.Tasks;
using Moq;
using Xunit;
 
namespace Microsoft.Extensions.AsyncState.Test;
/// <summary>
/// Tests for <c>AsyncContext&lt;T&gt;</c>: registration of contexts with an <c>AsyncState</c>,
/// visibility of stored values across <c>Task.Run</c> boundaries, and the behavior of
/// <c>Get</c>, <c>TryGet</c> and <c>Set</c> when the state has not been initialized.
/// </summary>
public class AsyncContextTests
{
    /// <summary>
    /// Creates two contexts before <c>Initialize</c> and checks that both are counted, that a value
    /// set in one task is returned by the same context in a later task, and that the context
    /// which was never set returns null.
    /// </summary>
    [Fact]
    public async Task CreateAsyncContext_BeforeInitialize()
    {
        var state = new AsyncState();
        var initialContextCount = state.ContextCount;
        var context1 = new AsyncContext<IThing>(state);
        var context2 = new AsyncContext<IThing>(state);
        var obj1 = new Thing();

        // Both contexts are expected to be counted even though the state is not initialized yet.
        Assert.Equal(2, state.ContextCount - initialContextCount);
        state.Initialize();

        await Task.Run(() => context1.Set(obj1));

        await Task.Run(() =>
        {
            Assert.Same(obj1, context1.Get());

            // context2 was never set, so it must not pick up context1's value.
            Assert.Null(context2.Get());
        });
    }

    /// <summary>
    /// Same scenario as <c>CreateAsyncContext_BeforeInitialize</c>, but both contexts are created
    /// after <c>Initialize</c> has been called.
    /// </summary>
    [Fact]
    public async Task CreateAsyncContext_AfterInitialize()
    {
        var state = new AsyncState();
        state.Initialize();
        var initialContextCount = state.ContextCount;

        var context1 = new AsyncContext<IThing>(state);
        var context2 = new AsyncContext<IThing>(state);

        Assert.Equal(2, state.ContextCount - initialContextCount);

        var obj1 = new Thing();

        await Task.Run(() => context1.Set(obj1));

        await Task.Run(() =>
        {
            Assert.Same(obj1, context1.Get());

            // context2 was never set, so it must not pick up context1's value.
            Assert.Null(context2.Get());
        });
    }

    /// <summary>
    /// Creates one context before and one after <c>Initialize</c>, sets both in a single task,
    /// and checks that each context returns its own value in a later task.
    /// </summary>
    [Fact]
    public async Task CreateAsyncContext_BeforeAndAfterInitialize()
    {
        var state = new AsyncState();
        var initialContextCount = state.ContextCount;
        var context1 = new AsyncContext<IThing>(state);
        state.Initialize();
        var context2 = new AsyncContext<IThing>(state);
        var obj1 = new Thing();
        var obj2 = new Thing();

        Assert.Equal(2, state.ContextCount - initialContextCount);

        await Task.Run(() =>
        {
            context1.Set(obj1);
            context2.Set(obj2);
        });

        await Task.Run(() =>
        {
            Assert.Same(obj1, context1.Get());

            Assert.Same(obj2, context2.Get());
        });
    }

    /// <summary>
    /// Checks that <c>TryGet</c> returns false before <c>Initialize</c>, returns true with a null
    /// value after <c>Initialize</c> but before <c>Set</c>, and returns true with the stored
    /// object once <c>Set</c> has been called.
    /// </summary>
    [Fact]
    public async Task Tryget_BeforeAndAfterInitialize()
    {
        var state = new AsyncState();
        var initialContextCount = state.ContextCount;
        var context1 = new AsyncContext<IThing>(state);

        // Not initialized yet: TryGet is expected to report failure instead of throwing.
        Assert.False(context1.TryGet(out _));
        state.Initialize();
        var obj1 = new Thing();

        Assert.Equal(1, state.ContextCount - initialContextCount);

        await Task.Run(() =>
        {
            // Initialized but nothing stored: success with a null value.
            Assert.True(context1.TryGet(out var ctx1));
            Assert.Null(ctx1);
            context1.Set(obj1);
        });

        await Task.Run(() =>
        {
            Assert.True(context1.TryGet(out var ctx1));
            Assert.Same(obj1, ctx1);
        });
    }

    /// <summary>
    /// Creates a second context inside a running task after <c>Initialize</c> and checks that
    /// values set on both contexts are returned from a nested task.
    /// </summary>
    [Fact]
    public async Task CreateAsyncContextInAsync_AfterInitialize()
    {
        var state = new AsyncState();
        var initialContextCount = state.ContextCount;
        var context1 = new AsyncContext<IThing>(state);
        var obj1 = new Thing();

        state.Initialize();
        Assert.Equal(1, state.ContextCount - initialContextCount);

        await Task.Run(async () =>
        {
            context1.Set(obj1);

            // This context is created from inside the task, after the state was initialized.
            var context2 = new AsyncContext<IThing>(state);
            var obj2 = new Thing();
            context2.Set(obj2);

            await Task.Run(() =>
            {
                Assert.Same(obj1, context1.Get());

                Assert.Same(obj2, context2.Get());
            });
        });
    }

    /// <summary>
    /// Expects <c>Get</c> to throw <see cref="InvalidOperationException"/> when
    /// <c>Initialize</c> has not been called on the state.
    /// </summary>
    [Fact]
    public void Get_WhenNotInitialized()
    {
        var state = new AsyncState();
        var context1 = new AsyncContext<IThing>(state);

        Assert.Throws<InvalidOperationException>(() => _ = context1.Get());
    }

    /// <summary>
    /// Expects <c>TryGet</c> to return false and produce a null value when
    /// <c>Initialize</c> has not been called on the state.
    /// </summary>
    [Fact]
    public void TryGet_WhenNotInitialized()
    {
        var state = new AsyncState();
        var context1 = new AsyncContext<IThing>(state);

        Assert.False(context1.TryGet(out var context));
        Assert.Null(context);
    }

    /// <summary>
    /// Expects <c>Set</c> to throw <see cref="InvalidOperationException"/> when
    /// <c>Initialize</c> has not been called on the state.
    /// </summary>
    [Fact]
    public void Set_WhenNotInitialized()
    {
        var state = new AsyncState();
        var context1 = new AsyncContext<IThing>(state);

        Assert.Throws<InvalidOperationException>(() => context1.Set(new Thing()));
    }

    /// <summary>
    /// Stores a mocked <see cref="IDisposable"/> in a context and checks that <c>Reset</c>
    /// on the state never calls <c>Dispose</c> on it.
    /// </summary>
    [Fact]
    public void Reset_DoesNotDisposeObjects()
    {
        var state = new AsyncState();
        var context = new AsyncContext<IDisposable>(state);
        var obj = new Mock<IDisposable>();
        obj.Setup(m => m.Dispose());

        state.Initialize();
        context.Set(obj.Object);
        state.Reset();

        obj.Verify(m => m.Dispose(), Times.Never);
    }

    /// <summary>
    /// Runs two concurrent tasks against one shared <c>AsyncState</c> instance. Each task creates
    /// its own context (of a different type), calls <c>Initialize</c>, sets and reads back a value
    /// through nested tasks, and calls <c>Reset</c>. Afterwards exactly two contexts are expected
    /// to have been counted.
    /// </summary>
    [Fact]
    public async Task TwoAsyncFlows_WithDiffrentAsyncStates()
    {
        var state = new AsyncState();
        var initialContextCount = state.ContextCount;

        var task1 = Task.Run(async () =>
        {
            var context = new AsyncContext<IThing>(state);
            state.Initialize();
            var obj = new Thing();

            await Task.Run(async () =>
            {
                context.Set(obj);

                await Task.Run(() => Assert.Same(obj, context.Get()));
            });

            state.Reset();
        });

        var task2 = Task.Run(async () =>
        {
            var context = new AsyncContext<string>(state);
            state.Initialize();
            var obj = string.Empty;

            await Task.Run(async () =>
            {
                context.Set(obj);

                await Task.Run(() => Assert.Same(obj, context.Get()));
            });

            state.Reset();
        });

        await Task.WhenAll(task1, task2);
        Assert.Equal(2, state.ContextCount - initialContextCount);
    }

    /// <summary>
    /// Starts ten tasks that share one <c>AsyncState</c> and one context. Each task calls
    /// <c>Initialize</c>, sets its own object, checks from a nested task that it reads back that
    /// same object, and then calls <c>Reset</c>. Only one context is expected to have been counted.
    /// </summary>
    [Fact]
    public async Task IndependentAsyncFlows_WithSameAsyncState()
    {
        var state = new AsyncState();
        var initialContextCount = state.ContextCount;
        var context = new AsyncContext<IThing>(state);

        Func<Task?> setAsyncState = async () =>
        {
            state.Initialize();

            await Task.Run(async () =>
            {
                // Each flow stores its own instance; Assert.Same below fails if another
                // flow's object is observed instead.
                var obj2 = new Thing();
                context.Set(obj2);

                await Task.Run(() =>
                {
                    Assert.Same(obj2, context.Get());
                });
            });

            state.Reset();
        };

        var tasks = new Task[10];

        for (int i = 0; i < tasks.Length; i++)
        {
            tasks[i] = Task.Run(setAsyncState);
        }

        await Task.WhenAll(tasks);

        // The single context created above is the only one registered, regardless of flow count.
        Assert.Equal(1, state.ContextCount - initialContextCount);
    }
}