File: Commands\Test\MTP\TestApplication.cs
Web Access
Project: src\sdk\src\Cli\dotnet\dotnet.csproj (dotnet)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.IO.Pipes;
using System.Threading;
using Microsoft.DotNet.Cli.Commands.Test.IPC;
using Microsoft.DotNet.Cli.Commands.Test.IPC.Models;
using Microsoft.DotNet.Cli.Commands.Test.IPC.Serializers;
using Microsoft.DotNet.Cli.Commands.Test.Terminal;
using Microsoft.DotNet.Cli.Utils;
using Microsoft.DotNet.ProjectTools;

namespace Microsoft.DotNet.Cli.Commands.Test;

/// <summary>
/// Launches one Microsoft.Testing.Platform test application process, hosts the named pipe the test
/// application connects back to, and forwards every message it sends to a <c>TestApplicationHandler</c>.
/// </summary>
/// <remarks>
/// <c>RunAsync</c> may be called at most once per instance. <c>Dispose</c> releases the pipe connections
/// accepted during the run and records (in <c>HasFailureDuringDispose</c>) whether releasing any of them failed.
/// </remarks>
internal sealed class TestApplication(
    TestModule module,
    BuildOptions buildOptions,
    TestOptions testOptions,
    TerminalTestReporter output,
    Action<CommandLineOptionMessages> onHelpRequested) : IDisposable
{
    // Once a protocol version >= 1.1 is negotiated, stdout/stderr of the test app is streamed live.
    private static readonly Version ProtocolVersion_1_1 = new(1, 1, 0);

    // Number of most recent lines each output collector keeps once live streaming is enabled.
    private const int LiveOutputTailLineCount = 200;

    private readonly Lock _requestLock = new();
    private readonly BuildOptions _buildOptions = buildOptions;
    private readonly Action<CommandLineOptionMessages> _onHelpRequested = onHelpRequested;
    private readonly TestApplicationHandler _handler = new(output, module, testOptions);

    private readonly string _pipeName = NamedPipeServer.GetPipeName(Guid.NewGuid().ToString("N"));

    private readonly List<NamedPipeServer> _testAppPipeConnections = [];
    private readonly Dictionary<NamedPipeServer, HandshakeMessage> _handshakes = new();

    // 0/1 flag guarding RunAsync against being invoked more than once.
    private int _hasRun;

    // 0 until the first handshake has been answered, then 1 (accessed with Volatile from the output reader threads).
    private int _protocolNegotiated;

    // Highest protocol version accepted across all handshakes received so far; null if none was accepted.
    private Version? _negotiatedProtocolVersion;

    // Collectors for the currently running process; null outside of RunAsync.
    private ProcessOutputCollector? _standardOutputCollector;
    private ProcessOutputCollector? _standardErrorCollector;

    public TestModule Module { get; } = module;
    public TestOptions TestOptions { get; } = testOptions;

    /// <summary>
    /// Set by <c>Dispose</c> when disposing one of the accepted pipe connections threw.
    /// </summary>
    public bool HasFailureDuringDispose { get; private set; }

    /// <summary>
    /// True when a protocol version of at least 1.1 has been negotiated with the test application.
    /// </summary>
    internal bool IsProtocol_1_1_OrHigher =>
        _negotiatedProtocolVersion is { } negotiatedProtocolVersion &&
        negotiatedProtocolVersion.CompareTo(ProtocolVersion_1_1) >= 0;

    /// <summary>
    /// Starts the test application, services its pipe requests until it exits, and returns its exit code.
    /// </summary>
    /// <param name="ctrlC">Manager the process is registered with so that a forced exit can kill it.</param>
    /// <returns>The exit code of the test application process.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown when called more than once, or when the process exits with code 0 but the handler reports a
    /// mismatching count of test session events.
    /// </exception>
    public async Task<int> RunAsync(CtrlCCancellationManager ctrlC)
    {
        if (Interlocked.Exchange(ref _hasRun, 1) != 0)
        {
            throw new InvalidOperationException(CliCommandStrings.RunAsyncCalledMoreThanOnce);
        }

        var processStartInfo = CreateProcessStartInfo();

        // NOTE(review): this source is not disposed; the pipe servers keep a reference to its token until
        // Dispose runs, so disposing it at the end of this method would not be safe as-is.
        var cancellationTokenSource = new CancellationTokenSource();
        var cancellationToken = cancellationTokenSource.Token;
        var testAppPipeConnectionLoop = Task.Run(() => WaitConnectionAsync(cancellationToken));

        Process? process = null;
        try
        {
            Logger.LogTrace($"Starting test process with command '{processStartInfo.FileName}' and arguments '{processStartInfo.Arguments}'.");

            process = Process.Start(processStartInfo)!;

            // Register with the Ctrl+C manager so a force-exit (second Ctrl+C) kills this process
            // tree even if the child's own cooperative cancellation hangs.
            ctrlC.Register(process);

            // Reading from process stdout/stderr is done on separate threads to avoid blocking IO on the threadpool.
            // Note: even with 'process.StandardOutput.ReadToEndAsync()' or 'process.BeginOutputReadLine()', we ended up with
            // many TP threads just doing synchronous IO, slowing down the progress of the test run.
            // We want to read requests coming through the pipe and sending responses back to the test app as fast as possible.
            // The collector is thread-safe for the timeout case.
            // In the timeout case, we leave stdOutTask and stdErrTask running, just we stop observing them.
            var stdOutBuilder = new ProcessOutputCollector(LiveOutputTailLineCount, _handler.WriteMessage);
            var stdErrBuilder = new ProcessOutputCollector(LiveOutputTailLineCount, _handler.WriteMessage);

            // Published so that a handshake arriving on the pipe thread can flush the buffered output.
            Volatile.Write(ref _standardOutputCollector, stdOutBuilder);
            Volatile.Write(ref _standardErrorCollector, stdErrBuilder);

            var stdOutTask = Task.Factory.StartNew(() =>
            {
                var stdOut = process.StandardOutput;
                string? currentLine;
                while ((currentLine = stdOut.ReadLine()) is not null)
                {
                    stdOutBuilder.AddLine(currentLine, GetLiveOutputStreamingState());
                }
            }, TaskCreationOptions.LongRunning);

            var stdErrTask = Task.Factory.StartNew(() =>
            {
                var stdErr = process.StandardError;
                string? currentLine;
                while ((currentLine = stdErr.ReadLine()) is not null)
                {
                    stdErrBuilder.AddLine(currentLine, GetLiveOutputStreamingState());
                }
            }, TaskCreationOptions.LongRunning);

            // WaitForExitAsync only waits for process exit (and doesn't wait for output) for our usage here.
            // If we use BeginOutputReadLine/BeginErrorReadLine, it will also wait for output which can deadlock.
            await process.WaitForExitAsync();

            // At this point, process already exited. Allow for 5 seconds to consume stdout/stderr.
            // We might not be able to consume all the output if the test app has exited but left a child process alive.
            try
            {
                await Task.WhenAll(stdOutTask, stdErrTask).WaitAsync(TimeSpan.FromSeconds(5));
            }
            catch (TimeoutException)
            {
                // Deliberately ignored: report whatever output was collected so far.
            }

            var exitCode = process.ExitCode;
            _handler.OnTestProcessExited(exitCode, stdOutBuilder.GetOutput(), stdErrBuilder.GetOutput());

            // This condition is to prevent considering the test app as successful when we didn't receive test session end.
            // We don't produce the exception if the exit code is already non-zero to avoid surfacing this exception when there is already a known failure.
            // For example, if hangdump timeout is reached, the process will be killed and we will have mismatching count.
            // Or if there is a crash (e.g, Environment.FailFast), etc.
            // So this is only a safe guard to avoid passing the test run if Environment.Exit(0) is called in one of the tests for example.
            if (exitCode == 0 && _handler.HasMismatchingTestSessionEventCount())
            {
                throw new InvalidOperationException(CliCommandStrings.MissingTestSessionEnd);
            }

            return exitCode;
        }
        finally
        {
            if (process is not null)
            {
                ctrlC.Unregister(process);
                process.Dispose();
            }

            Volatile.Write(ref _standardOutputCollector, null);
            Volatile.Write(ref _standardErrorCollector, null);

            // Stop accepting new pipe connections and wait for the accept loop to observe it.
            cancellationTokenSource.Cancel();
            await testAppPipeConnectionLoop;
        }
    }

    /// <summary>
    /// Builds the start info for the test process. It applies the module's run properties, then the
    /// launch profile, then the command-line environment variables (which take precedence).
    /// </summary>
    private ProcessStartInfo CreateProcessStartInfo()
    {
        var processStartInfo = new ProcessStartInfo
        {
            // We should get correct RunProperties right away.
            // For the case of dotnet test --test-modules path/to/dll, the TestModulesFilterHandler is responsible
            // for providing the dotnet muxer as RunCommand, and `exec "path/to/dll"` as RunArguments.
            FileName = Module.RunProperties.Command,
            Arguments = GetArguments(),
            RedirectStandardOutput = true,
            RedirectStandardError = true,
            // False is already the default on .NET Core, but prefer to be explicit.
            UseShellExecute = false,
        };

        if (!string.IsNullOrEmpty(Module.RunProperties.WorkingDirectory))
        {
            processStartInfo.WorkingDirectory = Module.RunProperties.WorkingDirectory;
        }

        if (Module.LaunchSettings is ProjectLaunchProfile launchProfile)
        {
            foreach (var entry in launchProfile.EnvironmentVariables)
            {
                processStartInfo.Environment[entry.Key] = entry.Value;
            }

            if (!_buildOptions.NoLaunchProfileArguments &&
                !string.IsNullOrEmpty(launchProfile.CommandLineArgs))
            {
                processStartInfo.Arguments = $"{processStartInfo.Arguments} {launchProfile.CommandLineArgs}";
            }
        }

        // Env variables specified on command line override those specified in launch profile:
        foreach (var (name, value) in TestOptions.EnvironmentVariables)
        {
            processStartInfo.Environment[name] = value;
        }

        if (Module.DotnetRootArchVariableName is not null)
        {
            processStartInfo.Environment[Module.DotnetRootArchVariableName] = Path.GetDirectoryName(new Muxer().MuxerPath);
        }

        processStartInfo.Environment["DOTNET_CLI_TEST_COMMAND_WORKING_DIRECTORY"] = Directory.GetCurrentDirectory();
        return processStartInfo;
    }

    /// <summary>
    /// Builds the test application's command line: RunArguments first, then the MTP options derived from
    /// the test/build options, the user-supplied arguments, and finally the server/pipe options.
    /// </summary>
    private string GetArguments()
    {
        // Keep RunArguments first.
        // In the case of UseAppHost=false, RunArguments is set to `exec $(TargetPath)`:
        // https://github.com/dotnet/sdk/blob/333388c31d811701e3b6be74b5434359151424dc/src/Tasks/Microsoft.NET.Build.Tasks/targets/Microsoft.NET.Sdk.targets#L1411
        // So, we keep that first always.
        // RunArguments is intentionally not escaped. It can contain multiple arguments and spaces there shouldn't cause the whole
        // value to be wrapped in double quotes. This matches dotnet run behavior.
        // In short, it's expected to already be escaped properly.
        StringBuilder builder = new(Module.RunProperties.Arguments);

        if (TestOptions.IsHelp)
        {
            builder.Append($" {CliConstants.HelpOptionKey}");
        }

        if (TestOptions.IsDiscovery)
        {
            builder.Append($" {TestCommandDefinition.MicrosoftTestingPlatform.ListTestsOptionName}");
        }

        if (_buildOptions.PathOptions.ResultsDirectoryPath is { } resultsDirectoryPath)
        {
            builder.Append($" {TestCommandDefinition.MicrosoftTestingPlatform.ResultsDirectoryOptionName} {ArgumentEscaper.EscapeSingleArg(resultsDirectoryPath)}");
        }

        if (_buildOptions.PathOptions.ConfigFilePath is { } configFilePath)
        {
            builder.Append($" {TestCommandDefinition.MicrosoftTestingPlatform.ConfigFileOptionName} {ArgumentEscaper.EscapeSingleArg(configFilePath)}");
        }

        if (_buildOptions.PathOptions.DiagnosticOutputDirectoryPath is { } diagnosticOutputDirectoryPath)
        {
            builder.Append($" {TestCommandDefinition.MicrosoftTestingPlatform.DiagnosticOutputDirectoryOptionName} {ArgumentEscaper.EscapeSingleArg(diagnosticOutputDirectoryPath)}");
        }

        foreach (var arg in _buildOptions.TestApplicationArguments)
        {
            builder.Append($" {ArgumentEscaper.EscapeSingleArg(arg)}");
        }

        builder.Append($" {CliConstants.ServerOptionKey} {CliConstants.ServerOptionValue} {CliConstants.DotNetTestPipeOptionKey} {ArgumentEscaper.EscapeSingleArg(_pipeName)}");

        return builder.ToString();
    }

    /// <summary>
    /// Accept loop: repeatedly creates a pipe server instance on the shared pipe name and records every
    /// instance that gets a connection, until <paramref name="token"/> is cancelled.
    /// Any exception other than cancellation terminates the process via FailFast.
    /// </summary>
    private async Task WaitConnectionAsync(CancellationToken token)
    {
        try
        {
            while (!token.IsCancellationRequested)
            {
                var pipeConnection = new NamedPipeServer(_pipeName, OnRequest, NamedPipeServerStream.MaxAllowedServerInstances, token, skipUnknownMessages: true);
                pipeConnection.RegisterAllSerializers();

                await pipeConnection.WaitConnectionAsync(token);
                _testAppPipeConnections.Add(pipeConnection);
            }
        }
        catch (OperationCanceledException ex)
        {
            // We are exiting
            Logger.LogTrace($"WaitConnectionAsync() throws OperationCanceledException with {(ex.CancellationToken == token ? "internal token" : "external token")}");
        }
        catch (Exception ex)
        {
            var exAsString = ex.ToString();
            Logger.LogTrace(exAsString);
            Environment.FailFast(exAsString);
        }
    }

    /// <summary>
    /// Dispatches one request received on <paramref name="server"/> to the matching handler and returns
    /// the response to send back (a handshake for handshakes, a void response otherwise).
    /// </summary>
    private Task<IResponse> OnRequest(NamedPipeServer server, IRequest request)
    {
        // We need to lock as we might be called concurrently when test app child processes all communicate with us.
        // For example, in a case of a sharding extension, we could get test result messages concurrently.
        // To be the most safe, we lock the whole OnRequest.
        lock (_requestLock)
        {
            try
            {
                switch (request)
                {
                    case HandshakeMessage handshakeMessage:
                        if (!_handshakes.TryAdd(server, handshakeMessage))
                        {
                            throw new InvalidOperationException(CliCommandStrings.DotnetTestDuplicateHandshakeOnConnection);
                        }
                        string negotiatedVersion = GetSupportedProtocolVersion(handshakeMessage);
                        // If the handler rejects the handshake (unsupported version, missing required
                        // properties, mismatching info, ...) respond with an empty negotiated version so
                        // Microsoft.Testing.Platform stops sending further messages on this connection.
                        bool handshakeAccepted = OnHandshakeMessage(handshakeMessage, negotiatedVersion.Length > 0);
                        SetNegotiatedProtocolVersion(handshakeAccepted ? negotiatedVersion : string.Empty);
                        return Task.FromResult((IResponse)CreateHandshakeMessage(handshakeAccepted ? negotiatedVersion : string.Empty));

                    case CommandLineOptionMessages commandLineOptionMessages:
                        OnCommandLineOptionMessages(commandLineOptionMessages);
                        break;

                    case DiscoveredTestMessages discoveredTestMessages:
                        OnDiscoveredTestMessages(discoveredTestMessages);
                        break;

                    case TestResultMessages testResultMessages:
                        OnTestResultMessages(testResultMessages);
                        break;

                    case FileArtifactMessages fileArtifactMessages:
                        OnFileArtifactMessages(fileArtifactMessages);
                        break;

                    case TestInProgressMessages testInProgressMessages:
                        OnTestInProgressMessages(testInProgressMessages);
                        break;

                    case TestSessionEvent sessionEvent:
                        OnSessionEvent(sessionEvent);
                        break;

                    // If we don't recognize the message, log and skip it
                    case UnknownMessage unknownMessage:
                        Logger.LogTrace($"Request '{request.GetType()}' with Serializer ID = {unknownMessage.SerializerId} is unsupported.");
                        return Task.FromResult((IResponse)VoidResponse.CachedInstance);

                    default:
                        // If it doesn't match any of the above, throw an exception
                        throw new NotSupportedException(string.Format(CliCommandStrings.CmdUnsupportedMessageRequestTypeException, request.GetType()));
                }
            }
            catch (Exception ex)
            {
                // BE CAREFUL:
                // When handling some of the messages, we may throw an exception in unexpected state.
                // (e.g, OnSessionEvent may throw if we receive TestSessionEnd without TestSessionStart).
                // (or if we receive help-related messages when not in help mode)
                // In that case, we FailFast.
                // The lack of FailFast *might* have unintended consequences, such as breaking the internal loop of pipe server.
                // In that case, maybe MTP app will continue waiting for response, but we don't send the response and are waiting for
                // MTP app process exit (which doesn't happen).
                // So, we explicitly FailFast here.
                string exAsString = ex.ToString();
                Logger.LogTrace(exAsString);
                Environment.FailFast(exAsString);
            }

            return Task.FromResult((IResponse)VoidResponse.CachedInstance);
        }
    }

    /// <summary>
    /// Returns the highest protocol version that is both advertised by the handshake and listed in
    /// <c>ProtocolConstants.SupportedVersions</c>, using the SDK's own spelling of that version.
    /// Returns an empty string when the handshake advertises nothing or there is no common version.
    /// </summary>
    internal static string GetSupportedProtocolVersion(HandshakeMessage handshakeMessage)
    {
        if (!handshakeMessage.Properties.TryGetValue(HandshakeMessagePropertyNames.SupportedProtocolVersions, out string? protocolVersions) ||
            string.IsNullOrWhiteSpace(protocolVersions))
        {
            // The handshake didn't advertise any supported protocol versions. Return empty so the
            // handler can surface a dedicated "missing protocol versions" failure to the user via
            // 'HandshakeFailure' (rather than throwing here, which would route to 'FailFast').
            return string.Empty;
        }

        List<(Version Version, string Text)> sdkSupportedVersions = [];
        foreach (string supportedVersion in ProtocolConstants.SupportedVersions.Split(';'))
        {
            string trimmedSupportedVersion = supportedVersion.Trim();
            if (Version.TryParse(trimmedSupportedVersion, out Version? parsedSupportedVersion))
            {
                sdkSupportedVersions.Add((parsedSupportedVersion, trimmedSupportedVersion));
            }
        }

        Version? highestCommonVersion = null;
        string highestCommonVersionText = string.Empty;
        foreach (string advertisedVersion in protocolVersions.Split(';'))
        {
            // Unparseable advertised entries are ignored rather than failing the negotiation.
            if (!Version.TryParse(advertisedVersion.Trim(), out Version? parsedAdvertisedVersion))
            {
                continue;
            }

            foreach ((Version sdkSupportedVersion, string sdkSupportedVersionText) in sdkSupportedVersions)
            {
                if (parsedAdvertisedVersion.Equals(sdkSupportedVersion) &&
                    (highestCommonVersion is null || sdkSupportedVersion.CompareTo(highestCommonVersion) > 0))
                {
                    highestCommonVersion = sdkSupportedVersion;
                    highestCommonVersionText = sdkSupportedVersionText;
                }
            }
        }

        return highestCommonVersionText;
    }

    /// <summary>
    /// Builds the handshake response describing this process, carrying <paramref name="version"/>
    /// (empty to reject) as the negotiated protocol version.
    /// </summary>
    private static HandshakeMessage CreateHandshakeMessage(string version) =>
        new HandshakeMessage(new Dictionary<byte, string>(capacity: 5)
        {
            { HandshakeMessagePropertyNames.PID, Environment.ProcessId.ToString(CultureInfo.InvariantCulture) },
            { HandshakeMessagePropertyNames.Architecture, RuntimeInformation.ProcessArchitecture.ToString() },
            { HandshakeMessagePropertyNames.Framework, RuntimeInformation.FrameworkDescription },
            { HandshakeMessagePropertyNames.OS, RuntimeInformation.OSDescription },
            { HandshakeMessagePropertyNames.SupportedProtocolVersions, version }
        });

    /// <summary>
    /// Records the outcome of a handshake. It keeps the highest parseable version seen, marks negotiation
    /// as done, and flushes buffered process output if live streaming is now enabled.
    /// </summary>
    private void SetNegotiatedProtocolVersion(string negotiatedVersion)
    {
        if (Version.TryParse(negotiatedVersion, out Version? parsedNegotiatedVersion) &&
            (_negotiatedProtocolVersion is null || parsedNegotiatedVersion.CompareTo(_negotiatedProtocolVersion) > 0))
        {
            _negotiatedProtocolVersion = parsedNegotiatedVersion;
        }

        // Written after _negotiatedProtocolVersion so readers that observe 1 also observe the version.
        Volatile.Write(ref _protocolNegotiated, 1);
        FlushBufferedOutputIfLiveStreamingEnabled();
    }

    /// <summary>
    /// Returns null while no handshake has been answered yet, otherwise whether live output streaming is supported.
    /// </summary>
    private bool? GetLiveOutputStreamingState() =>
        Volatile.Read(ref _protocolNegotiated) == 0 ? null : IsProtocol_1_1_OrHigher;

    private void FlushBufferedOutputIfLiveStreamingEnabled()
    {
        bool? liveOutputStreamingState = GetLiveOutputStreamingState();
        Volatile.Read(ref _standardOutputCollector)?.FlushBufferedOutputIfLiveStreamingEnabled(liveOutputStreamingState);
        Volatile.Read(ref _standardErrorCollector)?.FlushBufferedOutputIfLiveStreamingEnabled(liveOutputStreamingState);
    }

    /// <summary>
    /// Forwards a handshake to the handler and returns whether the handler accepted it.
    /// </summary>
    public bool OnHandshakeMessage(HandshakeMessage handshakeMessage, bool gotSupportedVersion)
        => _handler.OnHandshakeReceived(handshakeMessage, gotSupportedVersion);

    private void OnCommandLineOptionMessages(CommandLineOptionMessages commandLineOptionMessages)
    {
        // Help output is only expected when dotnet test was invoked in help mode.
        if (!TestOptions.IsHelp)
        {
            throw new InvalidOperationException(CliCommandStrings.UnexpectedHelpMessage);
        }

        _onHelpRequested(commandLineOptionMessages);
    }

    private void OnDiscoveredTestMessages(DiscoveredTestMessages discoveredTestMessages)
        => _handler.OnDiscoveredTestsReceived(discoveredTestMessages);

    private void OnTestResultMessages(TestResultMessages testResultMessage)
        => _handler.OnTestResultsReceived(testResultMessage);

    private void OnFileArtifactMessages(FileArtifactMessages fileArtifactMessages)
        => _handler.OnFileArtifactsReceived(fileArtifactMessages);

    private void OnTestInProgressMessages(TestInProgressMessages testInProgressMessages)
        => _handler.OnTestInProgressReceived(testInProgressMessages);

    private void OnSessionEvent(TestSessionEvent sessionEvent)
        => _handler.OnSessionEventReceived(sessionEvent);

    /// <summary>
    /// Thread-safe collector for one output stream (stdout or stderr) of the test process.
    /// </summary>
    /// <remarks>
    /// Until live streaming is enabled, every line is buffered so the complete output can be reported when
    /// the process exits. Once enabled, the buffered lines are written out once, each subsequent line is
    /// written as it arrives, and only the most recent <c>liveOutputTailLineCount</c> lines are retained.
    /// Live streaming is only ever switched on, never off.
    /// </remarks>
    private sealed class ProcessOutputCollector(int liveOutputTailLineCount, Action<string> writeOutput)
    {
        private readonly object _lock = new();
        private readonly Queue<string> _lines = [];
        private bool _liveStreamingEnabled;

        /// <summary>
        /// Records <paramref name="line"/> and writes it (plus any buffered lines when streaming is first
        /// enabled) if live streaming is on.
        /// </summary>
        /// <param name="line">The line read from the process, without its newline.</param>
        /// <param name="liveOutputStreamingState">
        /// Caller's snapshot of the streaming state: null before negotiation, otherwise whether streaming is supported.
        /// </param>
        public void AddLine(string line, bool? liveOutputStreamingState)
        {
            string? outputToWrite = null;
            lock (_lock)
            {
                _lines.Enqueue(line);
                if (_liveStreamingEnabled)
                {
                    // Check our own flag and not just the caller's snapshot. The snapshot may have been taken
                    // right before a concurrent flush enabled streaming, and this line must not be skipped.
                    outputToWrite = line + Environment.NewLine;
                    TrimToBoundedTail();
                }
                else if (liveOutputStreamingState == true)
                {
                    _liveStreamingEnabled = true;
                    outputToWrite = JoinLinesWithTrailingNewLine(_lines);
                    TrimToBoundedTail();
                }
            }

            // NOTE(review): writeOutput runs outside the lock, so text from a concurrent flush and AddLine may
            // reach the reporter out of order; confirm whether holding the lock here is safe for the reporter.
            if (outputToWrite is not null)
            {
                writeOutput(outputToWrite);
            }
        }

        /// <summary>
        /// Enables live streaming and writes all buffered lines if <paramref name="liveOutputStreamingState"/>
        /// is true and streaming has not been enabled yet; otherwise does nothing.
        /// </summary>
        public void FlushBufferedOutputIfLiveStreamingEnabled(bool? liveOutputStreamingState)
        {
            string? outputToWrite = null;
            lock (_lock)
            {
                if (liveOutputStreamingState == true && !_liveStreamingEnabled)
                {
                    _liveStreamingEnabled = true;
                    outputToWrite = JoinLinesWithTrailingNewLine(_lines);
                    TrimToBoundedTail();
                }
            }

            if (!string.IsNullOrEmpty(outputToWrite))
            {
                writeOutput(outputToWrite);
            }
        }

        /// <summary>
        /// Returns the retained lines joined by newlines (all lines, or only the tail once streaming is on).
        /// </summary>
        public string GetOutput()
        {
            lock (_lock)
            {
                return string.Join(Environment.NewLine, _lines);
            }
        }

        // Must be called under _lock.
        private void TrimToBoundedTail()
        {
            while (_lines.Count > liveOutputTailLineCount)
            {
                _lines.Dequeue();
            }
        }

        private static string JoinLinesWithTrailingNewLine(IEnumerable<string> lines)
        {
            StringBuilder builder = new();
            foreach (string line in lines)
            {
                builder.AppendLine(line);
            }

            return builder.ToString();
        }
    }

    /// <summary>
    /// Describes the module's non-empty run properties as comma-separated "name: value" entries.
    /// </summary>
    public override string ToString()
    {
        List<string> entries = [];

        if (!string.IsNullOrEmpty(Module.RunProperties.Command))
        {
            entries.Add($"{ProjectProperties.RunCommand}: {Module.RunProperties.Command}");
        }

        if (!string.IsNullOrEmpty(Module.RunProperties.Arguments))
        {
            entries.Add($"{ProjectProperties.RunArguments}: {Module.RunProperties.Arguments}");
        }

        if (!string.IsNullOrEmpty(Module.RunProperties.WorkingDirectory))
        {
            entries.Add($"{ProjectProperties.RunWorkingDirectory}: {Module.RunProperties.WorkingDirectory}");
        }

        if (!string.IsNullOrEmpty(Module.ProjectFullPath))
        {
            entries.Add($"{ProjectProperties.ProjectFullPath}: {Module.ProjectFullPath}");
        }

        if (!string.IsNullOrEmpty(Module.TargetFramework))
        {
            entries.Add($"{ProjectProperties.TargetFramework}: {Module.TargetFramework}");
        }

        // Separate the entries so each "name: value" pair stays distinguishable.
        return string.Join(", ", entries);
    }

    /// <summary>
    /// Disposes every accepted pipe connection. A failure is reported to stderr, together with the
    /// connection's handshake (if any) and the run command, and sets <c>HasFailureDuringDispose</c>;
    /// it is not rethrown.
    /// </summary>
    public void Dispose()
    {
        foreach (var namedPipeServer in _testAppPipeConnections)
        {
            try
            {
                namedPipeServer.Dispose();
            }
            catch (Exception ex)
            {
                StringBuilder messageBuilder;
                if (_handshakes.TryGetValue(namedPipeServer, out var handshake))
                {
                    messageBuilder = new StringBuilder(CliCommandStrings.DotnetTestPipeFailureHasHandshake);
                    messageBuilder.AppendLine();
                    foreach (var kvp in handshake.Properties)
                    {
                        messageBuilder.AppendLine($"{kvp.Key}: {kvp.Value}");
                    }
                }
                else
                {
                    messageBuilder = new StringBuilder(CliCommandStrings.DotnetTestPipeFailureWithoutHandshake);
                    messageBuilder.AppendLine();
                }

                messageBuilder.AppendLine($"RunCommand: {Module.RunProperties.Command}");
                messageBuilder.AppendLine($"RunArguments: {Module.RunProperties.Arguments}");
                messageBuilder.AppendLine(ex.ToString());

                HasFailureDuringDispose = true;
                Reporter.Error.WriteLine(messageBuilder.ToString());
            }
        }
    }
}