File: Backchannel\ExtensionBackchannel.cs
Web Access
Project: src\src\Aspire.Cli\Aspire.Cli.Tool.csproj (aspire)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using Aspire.Cli.Resources;
using Aspire.Cli.Utils;
using Aspire.Hosting;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Spectre.Console;
using StreamJsonRpc;
 
namespace Aspire.Cli.Backchannel;
 
/// <summary>
/// Operations the CLI can request from the Aspire extension over its backchannel connection.
/// </summary>
/// <remarks>
/// Implemented in this file by <see cref="ExtensionBackchannel"/>, which forwards each operation
/// as a JSON-RPC call. Every member accepts a <see cref="CancellationToken"/> that is passed to the
/// underlying call.
/// </remarks>
internal interface IExtensionBackchannel
{
    /// <summary>Establishes the backchannel connection to the extension.</summary>
    Task ConnectAsync(CancellationToken cancellationToken);

    /// <summary>Pings the extension and returns the timestamp value it responds with.</summary>
    Task<long> PingAsync(long timestamp, CancellationToken cancellationToken);

    /// <summary>Asks the extension to display <paramref name="message"/> together with <paramref name="emoji"/>.</summary>
    Task DisplayMessageAsync(string emoji, string message, CancellationToken cancellationToken);

    /// <summary>Asks the extension to display a success message.</summary>
    Task DisplaySuccessAsync(string message, CancellationToken cancellationToken);

    /// <summary>Asks the extension to display a subtle message.</summary>
    Task DisplaySubtleMessageAsync(string message, CancellationToken cancellationToken);

    /// <summary>Asks the extension to display an error message.</summary>
    Task DisplayErrorAsync(string error, CancellationToken cancellationToken);

    /// <summary>Asks the extension to display an empty line.</summary>
    Task DisplayEmptyLineAsync(CancellationToken cancellationToken);

    /// <summary>
    /// Asks the extension to display an incompatible-version error for the given required capability
    /// and app host hosting SDK version.
    /// </summary>
    Task DisplayIncompatibleVersionErrorAsync(string requiredCapability, string appHostHostingSdkVersion, CancellationToken cancellationToken);

    /// <summary>Asks the extension to display a cancellation message.</summary>
    Task DisplayCancellationMessageAsync(CancellationToken cancellationToken);

    /// <summary>Asks the extension to display the supplied lines.</summary>
    Task DisplayLinesAsync(IEnumerable<DisplayLineState> lines, CancellationToken cancellationToken);

    /// <summary>Asks the extension to display the dashboard URLs (base URL and optional Codespaces URL).</summary>
    Task DisplayDashboardUrlsAsync((string BaseUrlWithLoginToken, string? CodespacesUrlWithLoginToken) dashboardUrls, CancellationToken cancellationToken);

    /// <summary>Asks the extension to show the given status; <paramref name="status"/> may be <see langword="null"/>.</summary>
    Task ShowStatusAsync(string? status, CancellationToken cancellationToken);

    /// <summary>
    /// Prompts for a selection among <paramref name="choices"/>, each presented using the text produced by
    /// <paramref name="choiceFormatter"/>.
    /// </summary>
    /// <returns>The selected choice, or the default value of <typeparamref name="T"/> when no selection is returned.</returns>
    Task<T?> PromptForSelectionAsync<T>(string promptText, IEnumerable<T> choices, Func<T, string> choiceFormatter, CancellationToken cancellationToken) where T : notnull;

    /// <summary>Prompts for a yes/no confirmation.</summary>
    /// <returns>The answer, or <see langword="null"/> when none is returned.</returns>
    Task<bool?> ConfirmAsync(string promptText, bool defaultValue, CancellationToken cancellationToken);

    /// <summary>Prompts for a string value, optionally supplying a default value and a validator.</summary>
    /// <returns>The entered value, or <see langword="null"/> when none is returned.</returns>
    Task<string?> PromptForStringAsync(string promptText, string? defaultValue, Func<string, ValidationResult>? validator, bool required, CancellationToken cancellationToken);

    /// <summary>Asks the extension to open the project at <paramref name="projectPath"/>.</summary>
    Task OpenProjectAsync(string projectPath, CancellationToken cancellationToken);
}
 
/// <summary>
/// JSON-RPC client for the backchannel between the Aspire CLI and the Aspire extension.
/// </summary>
/// <remarks>
/// The connection is opened lazily: every operation first calls <see cref="ConnectAsync"/>, which
/// connects at most once per instance. The endpoint, token and certificate are read from
/// <see cref="IConfiguration"/>; the token is sent as the first argument of every RPC call.
/// </remarks>
internal sealed class ExtensionBackchannel(ILogger<ExtensionBackchannel> logger, ExtensionRpcTarget target, IConfiguration configuration) : IExtensionBackchannel
{
    private const string Name = "Aspire Extension";
    private const string BaselineCapability = "baseline.v1";

    private readonly ActivitySource _activitySource = new(nameof(ExtensionBackchannel));

    // Completed with the live JsonRpc once the connection and the capability handshake succeed.
    // Continuations run asynchronously so awaiting callers never execute inline inside SetResult.
    private readonly TaskCompletionSource<JsonRpc> _rpcTaskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously);
    private readonly string _token = configuration[KnownConfigNames.ExtensionToken]
        ?? throw new InvalidOperationException(ErrorStrings.ExtensionTokenMustBeSet);

    // Non-null once a caller has claimed responsibility for connecting; its task reports the outcome
    // (success, failure or cancellation) of that single connection setup to every other caller.
    private TaskCompletionSource? _connectionSetupTcs;

    /// <summary>
    /// Invokes the extension's <c>PingAsync</c> RPC method and returns the timestamp it responds with.
    /// </summary>
    /// <param name="timestamp">Timestamp associated with this ping; currently only written to the debug log.</param>
    /// <param name="cancellationToken">Cancels connecting and the RPC call.</param>
    public async Task<long> PingAsync(long timestamp, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent ping with timestamp {Timestamp}", timestamp);

        // NOTE(review): only the token is transmitted; `timestamp` is logged but never sent.
        // Confirm against the extension's handler before changing the payload.
        return await rpc.InvokeWithCancellationAsync<long>(
            "PingAsync",
            [_token],
            cancellationToken);
    }

    /// <summary>
    /// Connects to the extension endpoint, retrying while the socket connection fails, and performs
    /// the capability handshake. Only the first caller connects; later callers wait for that outcome.
    /// </summary>
    /// <param name="cancellationToken">Cancels the connection attempt, or the wait for an attempt owned by another caller.</param>
    /// <exception cref="ExtensionIncompatibleException">The extension does not report the baseline capability.</exception>
    /// <exception cref="OperationCanceledException">The token was cancelled, or the owning connection attempt was cancelled.</exception>
    public async Task ConnectAsync(CancellationToken cancellationToken)
    {
        if (_connectionSetupTcs is { } pendingSetup)
        {
            // WaitAsync rethrows the setup failure/cancellation instead of silently returning,
            // so callers never go on to wait for an RPC instance that will never arrive.
            await pendingSetup.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
            return;
        }

        var setupTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        if (Interlocked.CompareExchange(ref _connectionSetupTcs, setupTcs, null) is { } racingSetup)
        {
            // Another caller claimed the connection between the check above and the exchange.
            await racingSetup.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
            return;
        }

        var endpoint = configuration[KnownConfigNames.ExtensionEndpoint];
        Debug.Assert(endpoint is not null);

        using var timer = new PeriodicTimer(TimeSpan.FromMilliseconds(50));
        var connectionAttempts = 0;
        logger.LogDebug("Starting backchannel connection to Aspire extension at {Endpoint}", endpoint);

        var startTime = DateTimeOffset.UtcNow;

        try
        {
            do
            {
                connectionAttempts++;

                try
                {
                    await ConnectCoreAsync().ConfigureAwait(false);
                    logger.LogDebug("Connected to ExtensionBackchannel at {Endpoint}", endpoint);
                    setupTcs.SetResult();
                    return;
                }
                catch (SocketException ex)
                {
                    // The extension may not be listening yet: keep polling. After ten seconds start
                    // logging and add a one second delay; earlier attempts stay silent so the logs
                    // are not spammed.
                    var waitingFor = DateTimeOffset.UtcNow - startTime;
                    if (waitingFor > TimeSpan.FromSeconds(10))
                    {
                        logger.LogDebug("Slow polling for backchannel connection (attempt {ConnectionAttempts}), {SocketException}", connectionAttempts, ex);
                        await Task.Delay(1000, cancellationToken).ConfigureAwait(false);
                    }
                }
            } while (await timer.WaitForNextTickAsync(cancellationToken));
        }
        catch (ExtensionIncompatibleException ex)
        {
            logger.LogError(
                "The Aspire extension is incompatible with the CLI and must be updated to a version that supports the {RequiredCapability} capability.",
                ex.RequiredCapability
                );

            // If the extension is incompatible then there is no point trying to reconnect; propagate
            // the exception to the code that needs the backchannel so it can display an error
            // message to the user.
            setupTcs.TrySetException(ex);
            throw;
        }
        catch (OperationCanceledException)
        {
            // Complete the setup task so that callers waiting on it are released rather than left
            // waiting for a connection nobody is establishing any more.
            setupTcs.TrySetCanceled(cancellationToken);
            throw;
        }
        catch (Exception ex)
        {
            logger.LogError(ex, "An unexpected error occurred while trying to connect to the backchannel.");
            setupTcs.TrySetException(ex);
            throw;
        }

        return;

        // Performs a single connection attempt: TCP connect, TLS handshake, JSON-RPC start-up and the
        // capability check. Everything created here is disposed again if the attempt fails.
        async Task ConnectCoreAsync()
        {
            Socket? socket = null;
            SslStream? stream = null;
            JsonRpc? rpc = null;
            var connected = false;

            try
            {
                using var activity = _activitySource.StartActivity();

                if (_rpcTaskCompletionSource.Task.IsCompleted)
                {
                    throw new InvalidOperationException($"Already connected to {Name} backchannel.");
                }

                logger.LogDebug("Connecting to {Name} backchannel at {SocketPath}", Name, endpoint);

                // Validate "host:port" before allocating a socket so a bad endpoint leaks nothing.
                var addressParts = endpoint.Split(':');
                if (addressParts.Length != 2 ||
                    !int.TryParse(addressParts[1], NumberStyles.Integer, CultureInfo.InvariantCulture, out var port) ||
                    port is <= 0 or > 65535)
                {
                    throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, ErrorStrings.InvalidSocketPath, endpoint));
                }

                socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                await socket.ConnectAsync(addressParts[0], port, cancellationToken);
                logger.LogDebug("Connected to {Name} backchannel at {SocketPath}", Name, endpoint);

                // Parse the configured certificate once per attempt; it is used both to validate the
                // server certificate and as the client certificate.
                var expectedCertificate = GetCertificate();

                stream = new SslStream(new NetworkStream(socket, true),
                    leaveInnerStreamOpen: true,
                    userCertificateValidationCallback: (_, serverCertificate, _, policyErrors) =>
                    {
                        // Server certificate is already considered valid.
                        if (policyErrors == SslPolicyErrors.None)
                        {
                            return true;
                        }

                        if (serverCertificate is null)
                        {
                            return false;
                        }

                        // Certificate isn't immediately valid. Check if it is the same as the one we expect.
                        // It's ok that comparison isn't time constant because this is public information.
                        return expectedCertificate.RawData.SequenceEqual(serverCertificate.GetRawCertData());
                    });

                await stream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
                {
                    ClientCertificates = [expectedCertificate],
                }, cancellationToken);

                [UnconditionalSuppressMessage("AotAnalysis", "IL3050:RequiresDynamicCode",
                    Justification = "AddLocalRpcTarget closes on generic types if there are events on the target, which is explicitly disabled.")]
                static void AddLocalRpcTarget(JsonRpc rpc, ExtensionRpcTarget target)
                {
                    // We don't want to notify the client of events because we are not using the
                    // event system in the extension.
                    rpc.AddLocalRpcTarget(target, new JsonRpcTargetOptions() { NotifyClientOfEvents = false });
                }

                rpc = new JsonRpc(new HeaderDelimitedMessageHandler(stream, stream, BackchannelJsonSerializerContext.CreateRpcMessageFormatter()));
                AddLocalRpcTarget(rpc, target);
                rpc.StartListening();

                var capabilities = await rpc.InvokeWithCancellationAsync<string[]>(
                    "getCapabilities",
                    [_token],
                    cancellationToken);

                // A missing (null) capability list is treated the same as one lacking the baseline.
                if (capabilities is null || !capabilities.Contains(BaselineCapability, StringComparer.Ordinal))
                {
                    throw CreateIncompatibleException();
                }

                _rpcTaskCompletionSource.SetResult(rpc);
                connected = true;
            }
            catch (RemoteMethodNotFoundException ex)
            {
                logger.LogError(ex,
                    "Failed to connect to {Name} backchannel. The connection must be updated to a version that supports the {BaselineCapability} capability.",
                    Name,
                    BaselineCapability);

                throw CreateIncompatibleException();
            }
            finally
            {
                if (!connected)
                {
                    // The SslStream leaves its inner stream open, so the socket is disposed explicitly.
                    rpc?.Dispose();
                    stream?.Dispose();
                    socket?.Dispose();
                }
            }
        }
    }

    /// <summary>Invokes the extension's <c>displayMessage</c> RPC method with the emoji and message.</summary>
    public async Task DisplayMessageAsync(string emoji, string message, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent message {Message}", message);

        await rpc.InvokeWithCancellationAsync(
            "displayMessage",
            [_token, emoji, message],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>displaySuccess</c> RPC method with the message.</summary>
    public async Task DisplaySuccessAsync(string message, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent success message {Message}", message);

        await rpc.InvokeWithCancellationAsync(
            "displaySuccess",
            [_token, message],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>displaySubtleMessage</c> RPC method with the message.</summary>
    public async Task DisplaySubtleMessageAsync(string message, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent subtle message {Message}", message);

        await rpc.InvokeWithCancellationAsync(
            "displaySubtleMessage",
            [_token, message],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>displayError</c> RPC method with the error text.</summary>
    public async Task DisplayErrorAsync(string error, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent error message {Error}", error);

        await rpc.InvokeWithCancellationAsync(
            "displayError",
            [_token, error],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>displayEmptyLine</c> RPC method.</summary>
    public async Task DisplayEmptyLineAsync(CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent empty line");

        await rpc.InvokeWithCancellationAsync(
            "displayEmptyLine",
            [_token],
            cancellationToken);
    }

    /// <summary>
    /// Invokes the extension's <c>displayIncompatibleVersionError</c> RPC method with the required
    /// capability and the app host hosting SDK version.
    /// </summary>
    public async Task DisplayIncompatibleVersionErrorAsync(string requiredCapability, string appHostHostingSdkVersion, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent incompatible version error for capability {RequiredCapability} with hosting SDK version {AppHostHostingSdkVersion}",
            requiredCapability, appHostHostingSdkVersion);

        await rpc.InvokeWithCancellationAsync(
            "displayIncompatibleVersionError",
            [_token, requiredCapability, appHostHostingSdkVersion],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>displayCancellationMessage</c> RPC method.</summary>
    public async Task DisplayCancellationMessageAsync(CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent cancellation message");

        await rpc.InvokeWithCancellationAsync(
            "displayCancellationMessage",
            [_token],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>displayLines</c> RPC method with the supplied lines.</summary>
    public async Task DisplayLinesAsync(IEnumerable<DisplayLineState> lines, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent lines for display");

        await rpc.InvokeWithCancellationAsync(
            "displayLines",
            [_token, lines],
            cancellationToken);
    }

    /// <summary>
    /// Invokes the extension's <c>displayDashboardUrls</c> RPC method, sending the URLs wrapped in a
    /// <see cref="DashboardUrlsState"/>.
    /// </summary>
    public async Task DisplayDashboardUrlsAsync((string BaseUrlWithLoginToken, string? CodespacesUrlWithLoginToken) dashboardUrls, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent dashboard URLs for display");

        var dashboardUrlsState = new DashboardUrlsState()
        {
            BaseUrlWithLoginToken = dashboardUrls.BaseUrlWithLoginToken,
            CodespacesUrlWithLoginToken = dashboardUrls.CodespacesUrlWithLoginToken
        };

        await rpc.InvokeWithCancellationAsync(
            "displayDashboardUrls",
            [_token, dashboardUrlsState],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>showStatus</c> RPC method with the (possibly null) status.</summary>
    public async Task ShowStatusAsync(string? status, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Sent status update: {Status}", status);

        await rpc.InvokeWithCancellationAsync(
            "showStatus",
            [_token, status],
            cancellationToken);
    }

    /// <summary>
    /// Invokes the extension's <c>promptForSelection</c> RPC method with the formatted choices and maps
    /// the returned text back to the corresponding choice.
    /// </summary>
    /// <returns>The selected choice, or the default value of <typeparamref name="T"/> when the extension returns null.</returns>
    /// <exception cref="ArgumentException">Two choices format to the same text.</exception>
    /// <exception cref="KeyNotFoundException">The extension returned text that matches none of the choices.</exception>
    public async Task<T?> PromptForSelectionAsync<T>(string promptText, IEnumerable<T> choices, Func<T, string> choiceFormatter,
        CancellationToken cancellationToken) where T : notnull
    {
        var rpc = await GetRpcAsync(cancellationToken);

        // Only the formatted text can be sent over the wire (the formatter itself cannot), so the text
        // must identify a choice uniquely: Add throws when two choices format to the same text.
        // The list preserves the caller's ordering for the choices sent to the extension.
        var choicesByFormattedValue = new Dictionary<string, T>();
        var formattedChoices = new List<string>();
        foreach (var choice in choices)
        {
            var formattedChoice = choiceFormatter(choice).RemoveSpectreFormatting();
            choicesByFormattedValue.Add(formattedChoice, choice);
            formattedChoices.Add(formattedChoice);
        }

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Prompting for selection with text: {PromptText}, choices: {Choices}", promptText, formattedChoices);

        var choicesArray = formattedChoices.ToArray();
        var result = await rpc.InvokeWithCancellationAsync<string?>(
            "promptForSelection",
            [_token, promptText, choicesArray],
            cancellationToken);

        return result is null
            ? default
            : choicesByFormattedValue[result];
    }

    /// <summary>Invokes the extension's <c>confirm</c> RPC method and returns its (possibly null) answer.</summary>
    public async Task<bool?> ConfirmAsync(string promptText, bool defaultValue, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Prompting for confirmation with text: {PromptText}, default value: {DefaultValue}", promptText, defaultValue);

        return await rpc.InvokeWithCancellationAsync<bool?>(
            "confirm",
            [_token, promptText, defaultValue],
            cancellationToken);
    }

    /// <summary>
    /// Invokes the extension's <c>promptForString</c> RPC method and returns its (possibly null) answer.
    /// </summary>
    /// <remarks>
    /// The validator is not sent over the wire; it is stored on the RPC target's
    /// <c>ValidationFunction</c> property before the call is made.
    /// </remarks>
    public async Task<string?> PromptForStringAsync(string promptText, string? defaultValue, Func<string, ValidationResult>? validator,
        bool required, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        target.ValidationFunction = validator;

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Prompting for string with text: {PromptText}, default value: {DefaultValue}, required: {Required}", promptText, defaultValue, required);

        return await rpc.InvokeWithCancellationAsync<string?>(
            "promptForString",
            [_token, promptText, defaultValue, required],
            cancellationToken);
    }

    /// <summary>Invokes the extension's <c>openProject</c> RPC method with the project path.</summary>
    public async Task OpenProjectAsync(string projectPath, CancellationToken cancellationToken)
    {
        var rpc = await GetRpcAsync(cancellationToken);

        using var activity = _activitySource.StartActivity();

        logger.LogDebug("Opening project at path: {ProjectPath}", projectPath);

        await rpc.InvokeWithCancellationAsync(
            "openProject",
            [_token, projectPath],
            cancellationToken);
    }

    // Ensures the backchannel is connected and returns the JsonRpc instance used for all calls.
    private async Task<JsonRpc> GetRpcAsync(CancellationToken cancellationToken)
    {
        await ConnectAsync(cancellationToken);

        return await _rpcTaskCompletionSource.Task.WaitAsync(cancellationToken);
    }

    // Builds the exception reported when the extension lacks the baseline capability.
    private static ExtensionIncompatibleException CreateIncompatibleException()
    {
        return new ExtensionIncompatibleException(
            string.Format(CultureInfo.CurrentCulture, ErrorStrings.ExtensionIncompatibleWithCli,
                BaselineCapability),
            BaselineCapability
        );
    }

    // Decodes the base64 certificate stored in configuration under KnownConfigNames.ExtensionCert.
    private X509Certificate2 GetCertificate()
    {
        var serverCertificate = configuration[KnownConfigNames.ExtensionCert];
        Debug.Assert(!string.IsNullOrEmpty(serverCertificate));
        var data = Convert.FromBase64String(serverCertificate);
        return new X509Certificate2(data);
    }
}