// File: src\VisualStudio\IntegrationTest\Harness\XUnitShared\Threading\InProcessIdeTestInvoker.cs
// Project: src\src\VisualStudio\IntegrationTest\Harness\XUnit\Microsoft.VisualStudio.Extensibility.Testing.Xunit.csproj (Microsoft.VisualStudio.Extensibility.Testing.Xunit)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
namespace Xunit.Threading
{
    using System;
    using System.Collections.Generic;
    using System.Reflection;
    using System.Threading;
    using System.Threading.Tasks;
    using Xunit.Abstractions;
    using Xunit.Harness;
    using Xunit.Sdk;
 
    /// <summary>
    /// Test invoker used for integration tests that execute inside the IDE process.
    /// </summary>
    /// <remarks>
    /// <para>
    /// In the normal case the sequence is:
    /// <list type="number">
    /// <item>construct the test class;</item>
    /// <item>run <see cref="IAsyncLifetime.InitializeAsync"/> when the instance implements it;</item>
    /// <item>run <see cref="BeforeAfterTestAttribute.Before"/> for each attribute;</item>
    /// <item>call the test method;</item>
    /// <item>run <see cref="BeforeAfterTestAttribute.After"/> for every attribute whose <c>Before</c> succeeded, most recent first;</item>
    /// <item>run <see cref="IAsyncLifetime.DisposeAsync"/>;</item>
    /// <item>dispose the test class instance.</item>
    /// </list>
    /// </para>
    /// <para>
    /// Failures at those stages are passed to <c>DataCollectionService</c> from <c>catch ... when</c> exception
    /// filters. C# evaluates a filter before the stack unwinds, so the service observes the exception while the
    /// throwing frames are still live. A <c>throw ExceptionUtilities.Unreachable;</c> body only runs if
    /// <c>LogAndPropagate</c> returns <see langword="true"/>. The naming suggests it never does.
    /// NOTE(review): confirm against <c>DataCollectionService</c>.
    /// </para>
    /// </remarks>
    public class InProcessIdeTestInvoker : XunitTestInvoker
    {
        // Attributes whose Before() completed without throwing.
        // Enumerating a Stack yields the most recently pushed item first, so After() runs in reverse order.
        private readonly Stack<BeforeAfterTestAttribute> _attributesAwaitingAfter = new();

        // Every before/after attribute supplied for this test, in the order Before() is attempted.
        private readonly IReadOnlyList<BeforeAfterTestAttribute> _allBeforeAfterAttributes;

        /// <summary>
        /// Creates the invoker. All arguments are forwarded to the base <see cref="XunitTestInvoker"/> constructor.
        /// <paramref name="beforeAfterAttributes"/> is also kept locally so the before/after hooks can enumerate it.
        /// </summary>
        public InProcessIdeTestInvoker(ITest test, IMessageBus messageBus, Type testClass, object[] constructorArguments, MethodInfo testMethod, object[] testMethodArguments, IReadOnlyList<BeforeAfterTestAttribute> beforeAfterAttributes, ExceptionAggregator aggregator, CancellationTokenSource cancellationTokenSource)
            : base(test, messageBus, testClass, constructorArguments, testMethod, testMethodArguments, beforeAfterAttributes, aggregator, cancellationTokenSource)
            => _allBeforeAfterAttributes = beforeAfterAttributes;

        /// <summary>
        /// Runs the complete lifecycle of one test inside <c>Aggregator.RunAsync</c>.
        /// Exceptions that escape the lifecycle are presumably recorded by the aggregator rather than rethrown.
        /// </summary>
        /// <remarks>
        /// This member hides (it does not override) the base <c>RunAsync</c>. Only callers holding an
        /// <see cref="InProcessIdeTestInvoker"/>-typed reference reach this implementation.
        /// </remarks>
        /// <returns>The value of <c>Timer.Total</c> after the lifecycle has finished.</returns>
        public new Task<decimal> RunAsync()
        {
            return Aggregator.RunAsync(async () =>
            {
                // Nothing is constructed once cancellation has been requested.
                if (CancellationTokenSource.IsCancellationRequested)
                {
                    return Timer.Total;
                }

                var instance = CreateTestClass();
                try
                {
                    var lifetime = instance as IAsyncLifetime;
                    if (lifetime != null)
                    {
                        try
                        {
                            await lifetime.InitializeAsync().ConfigureAwait(true);
                        }
                        catch (Exception ex) when (DataCollectionService.LogAndPropagate(ex))
                        {
                            throw ExceptionUtilities.Unreachable;
                        }
                    }

                    if (!CancellationTokenSource.IsCancellationRequested)
                    {
                        await BeforeTestMethodInvokedAsync().ConfigureAwait(true);

                        // Skip the test body if a Before() hook failed or cancellation was requested meanwhile.
                        var shouldInvokeTest = !CancellationTokenSource.IsCancellationRequested && !Aggregator.HasExceptions;
                        if (shouldInvokeTest)
                        {
                            await InvokeTestMethodAsync(instance).ConfigureAwait(true);
                        }

                        await AfterTestMethodInvokedAsync().ConfigureAwait(true);
                    }

                    if (lifetime != null)
                    {
                        await Aggregator.RunAsync(() => DisposeAsyncLifetimeAsync(lifetime)).ConfigureAwait(true);
                    }
                }
                finally
                {
                    Aggregator.Run(() =>
                    {
                        Test.DisposeTestClass(instance, MessageBus, Timer, CancellationTokenSource);
                    });
                }

                return Timer.Total;
            });
        }

        /// <summary>
        /// Runs <see cref="IAsyncLifetime.DisposeAsync"/> on the test instance.
        /// A failure is reported to <c>DataCollectionService</c> from an exception filter before it propagates.
        /// </summary>
        private async Task DisposeAsyncLifetimeAsync(IAsyncLifetime lifetime)
        {
            try
            {
                await lifetime.DisposeAsync().ConfigureAwait(true);
            }
            catch (Exception ex) when (DataCollectionService.LogAndPropagate(ex))
            {
                throw ExceptionUtilities.Unreachable;
            }
        }

        /// <summary>
        /// Creates the test class instance through the base implementation.
        /// A construction failure is reported to <c>DataCollectionService</c> before it propagates.
        /// </summary>
        protected override object CreateTestClass()
        {
            try
            {
                return base.CreateTestClass();
            }
            catch (Exception ex) when (DataCollectionService.LogAndPropagate(ex))
            {
                throw ExceptionUtilities.Unreachable;
            }
        }

        /// <summary>
        /// Runs <see cref="BeforeAfterTestAttribute.Before"/> for each attribute, in order.
        /// For each attempt, a <c>BeforeTestStarting</c>/<c>BeforeTestFinished</c> message pair is queued around it.
        /// An attribute is remembered for <see cref="AfterTestMethodInvokedAsync"/> only when its <c>Before</c> returns normally.
        /// </summary>
        /// <remarks>
        /// The loop stops early in three cases:
        /// <list type="bullet">
        /// <item>a <c>Before</c> call throws and <c>DataCollectionService.LogAndCatch</c> accepts the exception,
        /// in which case the exception is added to the aggregator;</item>
        /// <item>the message bus rejects a message, which requests cancellation;</item>
        /// <item>cancellation is observed after an iteration.</item>
        /// </list>
        /// </remarks>
        /// <returns>An already-completed task.</returns>
        protected override Task BeforeTestMethodInvokedAsync()
        {
            foreach (var attribute in _allBeforeAfterAttributes)
            {
                var attributeName = attribute.GetType().Name;

                if (MessageBus.QueueMessage(new BeforeTestStarting(Test, attributeName)))
                {
                    try
                    {
                        Timer.Aggregate(() => attribute.Before(TestMethod));
                        _attributesAwaitingAfter.Push(attribute);
                    }
                    catch (Exception ex) when (DataCollectionService.LogAndCatch(ex))
                    {
                        Aggregator.Add(ex);
                        break;
                    }
                    finally
                    {
                        // The "finished" message is queued even when Before() threw.
                        if (!MessageBus.QueueMessage(new BeforeTestFinished(Test, attributeName)))
                        {
                            CancellationTokenSource.Cancel();
                        }
                    }
                }
                else
                {
                    CancellationTokenSource.Cancel();
                }

                if (CancellationTokenSource.IsCancellationRequested)
                {
                    break;
                }
            }

            return CreateCompletedTask();
        }

        /// <summary>
        /// Calls the test method under a fresh <see cref="AsyncTestSyncContext"/> that wraps the current
        /// synchronization context. The previous context is restored afterwards, even if the call fails.
        /// </summary>
        /// <param name="testClassInstance">The instance the test method is invoked on.</param>
        /// <returns>The value of <c>Timer.Total</c> after the test body has run.</returns>
        protected override async Task<decimal> InvokeTestMethodAsync(object testClassInstance)
        {
            var previousContext = SynchronizationContext.Current;

            try
            {
                var testSyncContext = new AsyncTestSyncContext(previousContext);
                SynchronizationContext.SetSynchronizationContext(testSyncContext);

                await Aggregator.RunAsync(
                    () => Timer.AggregateAsync(
                        () => ExecuteTestBodyAsync(testClassInstance, testSyncContext))).ConfigureAwait(true);
            }
            finally
            {
                SynchronizationContext.SetSynchronizationContext(previousContext);
            }

            return Timer.Total;
        }

        /// <summary>
        /// Checks the argument count, calls the test method, and waits for it to finish.
        /// If <c>GetTaskFromResult</c> yields a task, that task is awaited.
        /// Otherwise the method waits on <paramref name="testSyncContext"/> and records any exception it reports.
        /// </summary>
        /// <exception cref="InvalidOperationException">The test method returned a task that was never started.</exception>
        private async Task ExecuteTestBodyAsync(object testClassInstance, AsyncTestSyncContext testSyncContext)
        {
            var parameterCount = TestMethod.GetParameters().Length;
            var valueCount = TestMethodArguments?.Length ?? 0;
            if (parameterCount != valueCount)
            {
                Aggregator.Add(new InvalidOperationException(DescribeArgumentCountMismatch(parameterCount, valueCount)));
                return;
            }

            var result = CallTestMethod(testClassInstance);
            var task = GetTaskFromResult(result);

            if (task == null)
            {
                // No task came back from the test, so completion is tracked through the sync context instead.
                var completionException = await testSyncContext.WaitForCompletionAsync().ConfigureAwait(true);
                if (completionException != null)
                {
                    DataCollectionService.TryLog(completionException);
                    Aggregator.Add(completionException);
                }

                return;
            }

            if (task.Status == TaskStatus.Created)
            {
                throw new InvalidOperationException("Test method returned a non-started Task (tasks must be started before being returned)");
            }

            try
            {
                await task.ConfigureAwait(true);
            }
            catch (Exception ex) when (DataCollectionService.LogAndPropagate(ex))
            {
                throw ExceptionUtilities.Unreachable;
            }
        }

        /// <summary>
        /// Builds the error message for a mismatch between declared parameters and supplied arguments.
        /// </summary>
        private static string DescribeArgumentCountMismatch(int parameterCount, int valueCount)
        {
            var parameterPlural = parameterCount == 1 ? string.Empty : "s";
            var valuePlural = valueCount == 1 ? string.Empty : "s";
            var verb = valueCount == 1 ? "was" : "were";
            return $"The test method expected {parameterCount} parameter value{parameterPlural}, but {valueCount} parameter value{valuePlural} {verb} provided.";
        }

        /// <summary>
        /// Calls the test method through the base implementation.
        /// A synchronous failure is reported to <c>DataCollectionService</c> before it propagates.
        /// </summary>
        protected override object CallTestMethod(object testClassInstance)
        {
            try
            {
                return base.CallTestMethod(testClassInstance);
            }
            catch (Exception ex) when (DataCollectionService.LogAndPropagate(ex))
            {
                throw ExceptionUtilities.Unreachable;
            }
        }

        /// <summary>
        /// Runs <see cref="BeforeAfterTestAttribute.After"/> for every attribute whose <c>Before</c> succeeded, most recent first.
        /// Each call is wrapped in an <c>AfterTestStarting</c>/<c>AfterTestFinished</c> message pair.
        /// A failing <c>After</c> is captured by the aggregator, so the remaining attributes still run.
        /// </summary>
        /// <returns>An already-completed task.</returns>
        protected override Task AfterTestMethodInvokedAsync()
        {
            foreach (var attribute in _attributesAwaitingAfter)
            {
                var attributeName = attribute.GetType().Name;

                if (!MessageBus.QueueMessage(new AfterTestStarting(Test, attributeName)))
                {
                    CancellationTokenSource.Cancel();
                }

                Aggregator.Run(() => Timer.Aggregate(() => InvokeAfterAttribute(attribute)));

                if (!MessageBus.QueueMessage(new AfterTestFinished(Test, attributeName)))
                {
                    CancellationTokenSource.Cancel();
                }
            }

            return CreateCompletedTask();
        }

        /// <summary>
        /// Calls <see cref="BeforeAfterTestAttribute.After"/> for a single attribute.
        /// A failure is reported to <c>DataCollectionService</c> before it propagates.
        /// </summary>
        private void InvokeAfterAttribute(BeforeAfterTestAttribute attribute)
        {
            try
            {
                attribute.After(TestMethod);
            }
            catch (Exception ex) when (DataCollectionService.LogAndPropagate(ex))
            {
                throw ExceptionUtilities.Unreachable;
            }
        }

        /// <summary>
        /// Returns a task that has already run to completion.
        /// The NET472 build uses the cached <c>Task.CompletedTask</c>.
        /// Other targets build one with <c>Task.FromResult</c>.
        /// </summary>
        private static Task CreateCompletedTask()
        {
#if NET472
            return Task.CompletedTask;
#else
            return Task.FromResult(true);
#endif
        }
    }
}