// File: XunitParallelTheoryTestCaseRunner.cs
// Web Access
// Project: src\src\Microsoft.DotNet.XUnitExtensions\src\Microsoft.DotNet.XUnitExtensions.csproj (Microsoft.DotNet.XUnitExtensions)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit.Abstractions;
using Xunit.Sdk;
 
namespace Microsoft.DotNet.XUnitExtensions
{
    /// <summary>
    /// Runs the data rows of a theory concurrently instead of one after another.
    /// This class uses the code from <see cref="XunitTheoryTestCaseRunner"/> with some slight modifications:
    /// <see cref="RunTestAsync"/> starts the <see cref="XunitTestRunner"/> of every data row before awaiting any of them.
    /// </summary>
    internal sealed class XunitParallelTheoryTestCaseRunner : XunitTestCaseRunner
    {
        // The test case itself is constructed without method arguments; each data row supplies its own.
        private static readonly object[] s_noArguments = Array.Empty<object>();

        // Collects exceptions thrown while disposing data-row values. It is merged into Aggregator in
        // BeforeTestCaseFinishedAsync so those exceptions are reported as test case cleanup failures.
        private readonly ExceptionAggregator _cleanupAggregator = new();

        // Exception thrown while enumerating theory data. When set, RunTestAsync reports a single failed
        // test instead of running the data rows.
        private Exception _dataDiscoveryException;

        // One runner per discovered data row, built in AfterTestCaseStartingAsync.
        private readonly List<XunitTestRunner> _testRunners = new();

        // Disposable values found in data rows; disposed by RunTestAsync once the rows are done.
        private readonly List<IDisposable> _toDispose = new();

        /// <summary>
        /// Initializes a new instance of the <see cref="XunitParallelTheoryTestCaseRunner"/> class.
        /// </summary>
        /// <param name="testCase">The test case to be run.</param>
        /// <param name="displayName">The display name of the test case.</param>
        /// <param name="skipReason">The skip reason, if the test is to be skipped.</param>
        /// <param name="constructorArguments">The arguments to be passed to the test class constructor.</param>
        /// <param name="diagnosticMessageSink">The message sink used to send diagnostic messages</param>
        /// <param name="messageBus">The message bus to report run status to.</param>
        /// <param name="aggregator">The exception aggregator used to run code and collect exceptions.</param>
        /// <param name="cancellationTokenSource">The task cancellation token source, used to cancel the test run.</param>
        public XunitParallelTheoryTestCaseRunner(IXunitTestCase testCase,
                                         string displayName,
                                         string skipReason,
                                         object[] constructorArguments,
                                         IMessageSink diagnosticMessageSink,
                                         IMessageBus messageBus,
                                         ExceptionAggregator aggregator,
                                         CancellationTokenSource cancellationTokenSource)
            : base(testCase, displayName, skipReason, constructorArguments, s_noArguments, messageBus, aggregator, cancellationTokenSource)
        {
            DiagnosticMessageSink = diagnosticMessageSink;
        }

        // Sink passed to ExtensibilityPointFactory when creating data discoverers.
        private IMessageSink DiagnosticMessageSink { get; }

        /// <inheritdoc/>
        /// <remarks>
        /// Enumerates every <see cref="DataAttribute"/> on the test method and creates one test runner per data row.
        /// A discoverer that does not implement IDataDiscoverer, and a data source that returns null, are reported
        /// through <c>Aggregator</c> and skipped. Any other exception is stashed and reported by <see cref="RunTestAsync"/>.
        /// </remarks>
        protected override async Task AfterTestCaseStartingAsync()
        {
            await base.AfterTestCaseStartingAsync();

            try
            {
                IEnumerable<IAttributeInfo> dataAttributes = TestCase.TestMethod.Method.GetCustomAttributes(typeof(DataAttribute));

                foreach (IAttributeInfo dataAttribute in dataAttributes)
                {
                    IAttributeInfo discovererAttribute = dataAttribute.GetCustomAttributes(typeof(DataDiscovererAttribute)).First();

                    IDataDiscoverer discoverer;
                    try
                    {
                        discoverer = ExtensibilityPointFactory.GetDataDiscoverer(DiagnosticMessageSink, discovererAttribute);
                    }
                    catch (InvalidCastException)
                    {
                        if (dataAttribute is IReflectionAttributeInfo reflectionAttribute)
                            Aggregator.Add(new InvalidOperationException($"Data discoverer specified for {reflectionAttribute.Attribute.GetType()} on {TestCase.TestMethod.TestClass.Class.Name}.{TestCase.TestMethod.Method.Name} does not implement IDataDiscoverer."));
                        else
                            Aggregator.Add(new InvalidOperationException($"A data discoverer specified on {TestCase.TestMethod.TestClass.Class.Name}.{TestCase.TestMethod.Method.Name} does not implement IDataDiscoverer."));

                        continue;
                    }

                    IEnumerable<object[]> data = discoverer.GetData(dataAttribute, TestCase.TestMethod.Method);
                    if (data == null)
                    {
                        Aggregator.Add(new InvalidOperationException($"Test data returned null for {TestCase.TestMethod.TestClass.Class.Name}.{TestCase.TestMethod.Method.Name}. Make sure it is statically initialized before this test method is called."));
                        continue;
                    }

                    foreach (object[] dataRow in data)
                        AddTestRunner(dataAttribute, dataRow);
                }
            }
            catch (Exception ex)
            {
                // Stash the exception so we can surface it during RunTestAsync
                _dataDiscoveryException = ex;
            }
        }

        /// <summary>
        /// Creates the test runner for a single theory data row and appends it to <see cref="_testRunners"/>.
        /// </summary>
        /// <param name="dataAttribute">The data attribute that produced the row; its <c>Skip</c> named argument is used when the test case has no skip reason.</param>
        /// <param name="dataRow">The raw argument values for one invocation of the test method.</param>
        private void AddTestRunner(IAttributeInfo dataAttribute, object[] dataRow)
        {
            // Track disposable arguments first, so they are still disposed if anything below throws.
            _toDispose.AddRange(dataRow.OfType<IDisposable>());

            ITypeInfo[] resolvedTypes = null;
            MethodInfo methodToRun = TestMethod;
            object[] convertedDataRow = methodToRun.ResolveMethodArguments(dataRow);

            // Open generic test methods are closed over the types resolved from this row's arguments.
            if (methodToRun.IsGenericMethodDefinition)
            {
                resolvedTypes = TestCase.TestMethod.Method.ResolveGenericTypes(convertedDataRow);
                methodToRun = methodToRun.MakeGenericMethod(resolvedTypes.Select(t => ((IReflectionTypeInfo)t).Type).ToArray());
            }

            Type[] parameterTypes = methodToRun.GetParameters().Select(p => p.ParameterType).ToArray();
            convertedDataRow = Reflector.ConvertArguments(convertedDataRow, parameterTypes);

            string theoryDisplayName = TestCase.TestMethod.Method.GetDisplayNameWithArguments(DisplayName, convertedDataRow, resolvedTypes);
            ITest test = CreateTest(TestCase, theoryDisplayName);
            string skipReason = SkipReason ?? dataAttribute.GetNamedArgument<string>("Skip");
            _testRunners.Add(CreateTestRunner(test, MessageBus, TestClass, ConstructorArguments, methodToRun, convertedDataRow, skipReason, BeforeAfterAttributes, Aggregator, CancellationTokenSource));
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Folds exceptions collected while disposing data-row values into <c>Aggregator</c>.
        /// </remarks>
        protected override Task BeforeTestCaseFinishedAsync()
        {
            Aggregator.Aggregate(_cleanupAggregator);

            return base.BeforeTestCaseFinishedAsync();
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Every data-row runner is started before any of them is awaited, so the rows of the theory run concurrently.
        /// Disposable data-row values are disposed afterwards (also when data discovery failed), and the disposal
        /// time is added to the returned summary.
        /// </remarks>
        protected override async Task<RunSummary> RunTestAsync()
        {
            if (_dataDiscoveryException != null)
            {
                // Rows enumerated before discovery failed may still hold disposable values; release them too.
                RunSummary failureSummary = RunTest_DataDiscoveryException();
                DisposeTestData(failureSummary);
                return failureSummary;
            }

            var runningTests = new List<Task<RunSummary>>(_testRunners.Count);
            foreach (XunitTestRunner testRunner in _testRunners)
                runningTests.Add(testRunner.RunAsync());

            var runSummary = new RunSummary();
            try
            {
                // Task.WhenAll completes only after every runner has finished, so nothing is disposed while in use.
                foreach (RunSummary result in await Task.WhenAll(runningTests))
                    runSummary.Aggregate(result);
            }
            finally
            {
                // Run the cleanup here so we can include cleanup time in the run summary,
                // but save any exceptions so we can surface them during the cleanup phase,
                // so they get properly reported as test case cleanup failures.
                DisposeTestData(runSummary);
            }

            return runSummary;
        }

        /// <summary>
        /// Disposes every disposable value collected from the theory data rows and adds the time spent to <paramref name="summary"/>.
        /// Exceptions thrown by <see cref="IDisposable.Dispose"/> are captured in <see cref="_cleanupAggregator"/> rather than propagated.
        /// </summary>
        /// <param name="summary">The summary whose <c>Time</c> is increased by the disposal time.</param>
        private void DisposeTestData(RunSummary summary)
        {
            var timer = new ExecutionTimer();
            foreach (IDisposable disposable in _toDispose)
                timer.Aggregate(() => _cleanupAggregator.Run(disposable.Dispose));

            summary.Time += timer.Total;
        }

        /// <summary>
        /// Reports the stashed data discovery exception as a single failed test that uses the test case's display name.
        /// Cancellation is requested whenever the message bus refuses a message.
        /// </summary>
        /// <returns>A summary with one total and one failed test.</returns>
        private RunSummary RunTest_DataDiscoveryException()
        {
            var test = new XunitTest(TestCase, DisplayName);

            if (!MessageBus.QueueMessage(new TestStarting(test)))
                CancellationTokenSource.Cancel();
            else if (!MessageBus.QueueMessage(new TestFailed(test, 0, null, Unwrap(_dataDiscoveryException))))
                CancellationTokenSource.Cancel();
            if (!MessageBus.QueueMessage(new TestFinished(test, 0, null)))
                CancellationTokenSource.Cancel();

            return new RunSummary { Total = 1, Failed = 1 };
        }

        /// <summary>
        /// Strips <see cref="TargetInvocationException"/> wrappers (added when data is produced via reflection)
        /// so the underlying failure is reported.
        /// </summary>
        /// <param name="ex">The exception to unwrap.</param>
        /// <returns>
        /// The innermost exception that is not a <see cref="TargetInvocationException"/>. If a wrapper has no
        /// inner exception, that wrapper is returned instead of <see langword="null"/>.
        /// </returns>
        private static Exception Unwrap(Exception ex)
        {
            while (ex is TargetInvocationException { InnerException: Exception inner })
                ex = inner;

            return ex;
        }
    }
}