// File: Backchannel\AppHostCliBackchannel.cs
// Web Access
// Project: src\src\Aspire.Cli\Aspire.Cli.Tool.csproj (aspire)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Globalization;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using Aspire.Cli.Resources;
using Aspire.Cli.Telemetry;
using Microsoft.Extensions.Logging;
using StreamJsonRpc;
 
namespace Aspire.Cli.Backchannel;
 
/// <summary>
/// Client used by the Aspire CLI to communicate with a running AppHost over its backchannel socket.
/// </summary>
internal interface IAppHostCliBackchannel
{
    // ---- Connection -------------------------------------------------------------------------

    /// <summary>Connects to the backchannel socket at <paramref name="socketPath"/> with auto-reconnect disabled.</summary>
    Task ConnectAsync(string socketPath, CancellationToken cancellationToken);

    /// <summary>
    /// Connects to the backchannel socket at <paramref name="socketPath"/>; when <paramref name="autoReconnect"/>
    /// is <see langword="true"/> the client attempts to re-establish the connection after it drops.
    /// </summary>
    Task ConnectAsync(string socketPath, bool autoReconnect, CancellationToken cancellationToken);

    // ---- One-shot requests ------------------------------------------------------------------

    /// <summary>Asks the AppHost to shut down.</summary>
    Task RequestStopAsync(CancellationToken cancellationToken);

    /// <summary>Retrieves the dashboard URLs reported by the AppHost.</summary>
    Task<DashboardUrlsState> GetDashboardUrlsAsync(CancellationToken cancellationToken);

    /// <summary>Retrieves the capability identifiers advertised by the AppHost.</summary>
    Task<string[]> GetCapabilitiesAsync(CancellationToken cancellationToken);

    // ---- Prompt responses -------------------------------------------------------------------

    /// <summary>Sends the final answers for the prompt identified by <paramref name="promptId"/>.</summary>
    Task CompletePromptResponseAsync(string promptId, PublishingPromptInputAnswer[] answers, CancellationToken cancellationToken);

    /// <summary>Sends updated (non-final) answers for the prompt identified by <paramref name="promptId"/>.</summary>
    Task UpdatePromptResponseAsync(string promptId, PublishingPromptInputAnswer[] answers, CancellationToken cancellationToken);

    // ---- Streams ----------------------------------------------------------------------------

    /// <summary>Streams log entries produced by the AppHost.</summary>
    IAsyncEnumerable<BackchannelLogEntry> GetAppHostLogEntriesAsync(CancellationToken cancellationToken);

    /// <summary>Streams resource state updates from the AppHost.</summary>
    IAsyncEnumerable<RpcResourceState> GetResourceStatesAsync(CancellationToken cancellationToken);

    /// <summary>Streams publishing activities from the AppHost.</summary>
    IAsyncEnumerable<PublishingActivity> GetPublishingActivitiesAsync(CancellationToken cancellationToken);

    /// <summary>Starts execution on the AppHost and streams its command output.</summary>
    IAsyncEnumerable<CommandOutput> ExecAsync(CancellationToken cancellationToken);
}
 
/// <summary>
/// CLI-side client for the AppHost backchannel. It holds a StreamJsonRpc connection over a Unix domain socket
/// and, when requested, re-establishes that connection after it is lost.
/// </summary>
internal sealed class AppHostCliBackchannel(ILogger<AppHostCliBackchannel> logger, AspireCliTelemetry telemetry) : IAppHostCliBackchannel
{
    // Capability the AppHost must report from GetCapabilitiesAsync for this CLI to talk to it.
    private const string BaselineCapability = "baseline.v2";

    // Completed with the live JsonRpc once connected; swapped for a fresh, uncompleted source when a
    // reconnection starts. RunContinuationsAsynchronously keeps awaiters' continuations from running
    // inline on the thread that calls SetResult, which does so while holding _lock.
    private TaskCompletionSource<JsonRpc> _rpcTaskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously);

    // Connection parameters remembered by ConnectAsync so ReconnectInternalAsync can reuse them.
    private string? _socketPath;
    private bool _autoReconnect;
    private CancellationToken _cancellationToken;

    private readonly object _lock = new();

    // True while a background reconnection started by OnDisconnected is in progress.
    private volatile bool _isReconnecting;

    /// <summary>
    /// Gets the current RPC task in a thread-safe manner.
    /// </summary>
    /// <returns>The task of the current completion source; it completes once a connection is established.</returns>
    private Task<JsonRpc> GetRpcTaskAsync()
    {
        lock (_lock)
        {
            return _rpcTaskCompletionSource.Task;
        }
    }

    /// <summary>
    /// Asks the AppHost to shut down.
    /// </summary>
    public async Task RequestStopAsync(CancellationToken cancellationToken)
    {
        // This RPC call is required to allow the CLI to trigger a clean shutdown
        // of the AppHost process. The AppHost process will then trigger the shutdown
        // which will allow the CLI to await the pending run.

        using var activity = telemetry.ActivitySource.StartActivity();
        var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Requesting stop");

        await rpc.InvokeWithCancellationAsync(
            "RequestStopAsync",
            [],
            cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Retrieves the dashboard URLs from the AppHost.
    /// </summary>
    public async Task<DashboardUrlsState> GetDashboardUrlsAsync(CancellationToken cancellationToken)
    {
        using var activity = telemetry.ActivitySource.StartActivity();
        var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Requesting dashboard URL");

        var state = await rpc.InvokeWithCancellationAsync<DashboardUrlsState>(
            "GetDashboardUrlsAsync",
            [],
            cancellationToken).ConfigureAwait(false);
        return state;
    }

    /// <summary>
    /// Streams AppHost log entries.
    /// </summary>
    /// <remarks>
    /// With auto-reconnect enabled, a lost connection causes the stream to be re-requested once the connection
    /// has been re-established. Without auto-reconnect, the sequence ends when the AppHost's stream ends and
    /// connection failures propagate to the caller.
    /// </remarks>
    public async IAsyncEnumerable<BackchannelLogEntry> GetAppHostLogEntriesAsync([EnumeratorCancellation] CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            IAsyncEnumerable<BackchannelLogEntry>? logEntries = null;
            try
            {
                using var activity = telemetry.ActivitySource.StartActivity();
                var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

                logger.LogDebug("Requesting AppHost log entries");

                logEntries = await rpc.InvokeWithCancellationAsync<IAsyncEnumerable<BackchannelLogEntry>>(
                    "GetAppHostLogEntriesAsync",
                    [],
                    cancellationToken).ConfigureAwait(false);

                logger.LogDebug("Received AppHost log entries async enumerable");
            }
            catch (Exception ex) when (_autoReconnect && !cancellationToken.IsCancellationRequested && IsConnectionLostException(ex))
            {
                logger.LogDebug("Connection lost while getting log entries, waiting for reconnect...");
                await WaitForReconnectionAsync(cancellationToken).ConfigureAwait(false);
                continue;
            }

            if (logEntries is not null)
            {
                await foreach (var entry in EnumerateWithReconnect(logEntries, cancellationToken).ConfigureAwait(false))
                {
                    yield return entry;
                }
            }

            // Only re-request the stream when reconnection is enabled; otherwise the end of the server's
            // stream is the end of this sequence (re-invoking would just spin on the same connection).
            if (!_autoReconnect)
            {
                yield break;
            }
        }
    }

    /// <summary>
    /// Streams resource state updates from the AppHost.
    /// </summary>
    /// <remarks>
    /// With auto-reconnect enabled, a lost connection causes the stream to be re-requested once the connection
    /// has been re-established. Without auto-reconnect, the sequence ends when the AppHost's stream ends and
    /// connection failures propagate to the caller.
    /// </remarks>
    public async IAsyncEnumerable<RpcResourceState> GetResourceStatesAsync([EnumeratorCancellation] CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            IAsyncEnumerable<RpcResourceState>? resourceStates = null;
            try
            {
                using var activity = telemetry.ActivitySource.StartActivity();
                var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

                logger.LogDebug("Requesting resource states");

                resourceStates = await rpc.InvokeWithCancellationAsync<IAsyncEnumerable<RpcResourceState>>(
                    "GetResourceStatesAsync",
                    [],
                    cancellationToken).ConfigureAwait(false);

                logger.LogDebug("Received resource states async enumerable");
            }
            catch (Exception ex) when (_autoReconnect && !cancellationToken.IsCancellationRequested && IsConnectionLostException(ex))
            {
                logger.LogDebug("Connection lost while getting resource states, waiting for reconnect...");
                await WaitForReconnectionAsync(cancellationToken).ConfigureAwait(false);
                continue;
            }

            if (resourceStates is not null)
            {
                await foreach (var state in EnumerateWithReconnect(resourceStates, cancellationToken).ConfigureAwait(false))
                {
                    yield return state;
                }
            }

            // Only re-request the stream when reconnection is enabled; otherwise the end of the server's
            // stream is the end of this sequence (re-invoking would just spin on the same connection).
            if (!_autoReconnect)
            {
                yield break;
            }
        }
    }

    /// <summary>
    /// Yields the items of <paramref name="source"/>. If a connection-lost exception surfaces while
    /// auto-reconnect is enabled, the enumeration ends quietly so the caller can re-request the stream.
    /// </summary>
    private async IAsyncEnumerable<T> EnumerateWithReconnect<T>(IAsyncEnumerable<T> source, [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        var enumerator = source.GetAsyncEnumerator(cancellationToken);
        try
        {
            while (true)
            {
                bool hasNext;
                T current;
                try
                {
                    hasNext = await enumerator.MoveNextAsync().ConfigureAwait(false);
                    if (!hasNext)
                    {
                        yield break;
                    }
                    current = enumerator.Current;
                }
                catch (Exception ex) when (_autoReconnect && !cancellationToken.IsCancellationRequested && IsConnectionLostException(ex))
                {
                    logger.LogDebug("Connection lost during enumeration, will restart after reconnect");
                    yield break; // Exit this enumeration, outer loop will restart
                }

                // yield return is not permitted inside a try block that has a catch clause, hence it sits here.
                yield return current;
            }
        }
        finally
        {
            // Disposing a dead connection's enumerator may throw - suppress it
            try
            {
                await enumerator.DisposeAsync().ConfigureAwait(false);
            }
            catch (Exception ex) when (IsConnectionLostException(ex))
            {
                logger.LogDebug("Ignoring connection lost exception during enumerator disposal");
            }
        }
    }

    /// <summary>
    /// Returns <see langword="true"/> for the exception shapes treated here as a dropped connection:
    /// <see cref="ConnectionLostException"/>, <see cref="ObjectDisposedException"/>, or an
    /// <see cref="OperationCanceledException"/> wrapping a <see cref="ConnectionLostException"/>.
    /// </summary>
    private static bool IsConnectionLostException(Exception ex)
    {
        return ex is ConnectionLostException
            || ex is ObjectDisposedException
            || (ex is OperationCanceledException && ex.InnerException is ConnectionLostException);
    }

    /// <summary>
    /// Waits (up to 60 seconds) for a reconnection to start and then complete. Returns without throwing on
    /// timeout; cancellation observed by <see cref="Task.Delay(int, CancellationToken)"/> propagates.
    /// </summary>
    private async Task WaitForReconnectionAsync(CancellationToken cancellationToken)
    {
        // Wait for the TCS to be reset and then completed again
        var startTime = DateTime.UtcNow;
        var maxWait = TimeSpan.FromSeconds(60);

        // First, wait for the reconnection to start (TCS to be reset)
        // This handles the race where we catch the exception before OnDisconnected fires
        Task<JsonRpc>? initialTask = null;
        while (!cancellationToken.IsCancellationRequested && DateTime.UtcNow - startTime < maxWait)
        {
            var currentTask = GetRpcTaskAsync();

            // If this is a new TCS (different from what we had), reconnection has started
            if (initialTask is not null && !ReferenceEquals(currentTask, initialTask))
            {
                break;
            }

            // If we haven't captured the initial task yet, do so
            initialTask ??= currentTask;

            // If the current task is not completed, reconnection has started (TCS was reset)
            if (!currentTask.IsCompleted)
            {
                break;
            }

            await Task.Delay(100, cancellationToken).ConfigureAwait(false);
        }

        // Now wait for the reconnection to complete
        while (!cancellationToken.IsCancellationRequested && DateTime.UtcNow - startTime < maxWait)
        {
            var rpcTask = GetRpcTaskAsync();
            if (rpcTask.IsCompletedSuccessfully)
            {
                logger.LogDebug("Reconnection completed successfully");
                return;
            }

            await Task.Delay(500, cancellationToken).ConfigureAwait(false);
        }

        // The loops above also exit on cancellation; only report a timeout when one actually happened.
        if (!cancellationToken.IsCancellationRequested)
        {
            logger.LogWarning("Timed out waiting for backchannel reconnection");
        }
    }

    /// <summary>
    /// Connects to the backchannel at <paramref name="socketPath"/> with auto-reconnect disabled.
    /// </summary>
    public Task ConnectAsync(string socketPath, CancellationToken cancellationToken)
        => ConnectAsync(socketPath, autoReconnect: false, cancellationToken);

    /// <summary>
    /// Connects to the backchannel at <paramref name="socketPath"/>, verifies that the AppHost advertises
    /// <see cref="BaselineCapability"/>, and publishes the connection to pending and future callers.
    /// </summary>
    /// <param name="socketPath">Path of the AppHost's Unix domain socket.</param>
    /// <param name="autoReconnect">When <see langword="true"/>, a dropped connection triggers a background reconnect.</param>
    /// <param name="cancellationToken">Cancels the connect; it is also reused for later reconnection attempts.</param>
    /// <exception cref="InvalidOperationException">A connection has already been established.</exception>
    /// <exception cref="AppHostIncompatibleException">The AppHost does not support the baseline capability.</exception>
    public async Task ConnectAsync(string socketPath, bool autoReconnect, CancellationToken cancellationToken)
    {
        Socket? socket = null;
        JsonRpc? rpc = null;
        var connected = false;

        try
        {
            using var activity = telemetry.ActivitySource.StartActivity();

            lock (_lock)
            {
                if (_rpcTaskCompletionSource.Task.IsCompleted && !_rpcTaskCompletionSource.Task.IsFaulted)
                {
                    throw new InvalidOperationException(ErrorStrings.AlreadyConnectedToBackchannel);
                }
            }

            _socketPath = socketPath;
            _autoReconnect = autoReconnect;
            _cancellationToken = cancellationToken;

            logger.LogDebug("Connecting to AppHost backchannel at {SocketPath} (autoReconnect={AutoReconnect})", socketPath, autoReconnect);
            socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
            var endpoint = new UnixDomainSocketEndPoint(socketPath);
            await socket.ConnectAsync(endpoint, cancellationToken).ConfigureAwait(false);
            logger.LogDebug("Connected to AppHost backchannel at {SocketPath}", socketPath);

            // The stream owns the socket, and the JsonRpc (via its message handler) owns the stream.
            var stream = new NetworkStream(socket, true);
            rpc = new JsonRpc(new HeaderDelimitedMessageHandler(stream, stream, BackchannelJsonSerializerContext.CreateRpcMessageFormatter()));
            rpc.StartListening();

            var capabilities = await rpc.InvokeWithCancellationAsync<string[]>(
                "GetCapabilitiesAsync",
                [],
                cancellationToken).ConfigureAwait(false);

            if (!capabilities.Contains(BaselineCapability))
            {
                throw new AppHostIncompatibleException(
                    string.Format(CultureInfo.CurrentCulture, ErrorStrings.AppHostIncompatibleWithCli, BaselineCapability),
                    BaselineCapability
                    );
            }

            // Set up auto-reconnect if enabled
            if (autoReconnect)
            {
                rpc.Disconnected += OnDisconnected;
            }

            lock (_lock)
            {
                _rpcTaskCompletionSource.SetResult(rpc);
            }

            connected = true;
        }
        catch (RemoteMethodNotFoundException ex)
        {
            logger.LogError(ex, "Failed to connect to AppHost backchannel. The AppHost must be updated to a version that supports the {BaselineCapability} capability.", BaselineCapability);
            throw new AppHostIncompatibleException(
                string.Format(CultureInfo.CurrentCulture, ErrorStrings.AppHostIncompatibleWithCli, BaselineCapability),
                BaselineCapability
                );
        }
        finally
        {
            if (!connected)
            {
                // Release the half-built connection on any failure. Otherwise every failed attempt
                // (e.g. each retry in ReconnectInternalAsync) leaks a socket until finalization.
                if (rpc is not null)
                {
                    // Unsubscribe first so disposing this connection cannot kick off a reconnect.
                    rpc.Disconnected -= OnDisconnected;
                    rpc.Dispose();
                }

                socket?.Dispose();
            }
        }
    }

    /// <summary>
    /// Handles <see cref="JsonRpc.Disconnected"/> (subscribed only when auto-reconnect is enabled) by starting a
    /// single background reconnection; overlapping disconnect notifications are ignored while one is running.
    /// </summary>
    private void OnDisconnected(object? sender, JsonRpcDisconnectedEventArgs args)
    {
        // Prevent concurrent reconnection attempts
        lock (_lock)
        {
            if (_isReconnecting)
            {
                logger.LogDebug("Backchannel disconnected but reconnection already in progress, ignoring.");
                return;
            }
            _isReconnecting = true;
        }

        logger.LogInformation("Backchannel disconnected: {Reason}. Attempting to reconnect...", args.Reason);
        _ = Task.Run(async () =>
        {
            try
            {
                await ReconnectInternalAsync().ConfigureAwait(false);
            }
            catch (OperationCanceledException) when (_cancellationToken.IsCancellationRequested)
            {
                // The CLI is shutting down; this is not a reconnection failure.
                logger.LogDebug("Backchannel reconnection cancelled");
            }
            catch (Exception ex)
            {
                logger.LogWarning(ex, "Failed to reconnect backchannel");
            }
            finally
            {
                lock (_lock)
                {
                    _isReconnecting = false;
                }
            }
        });
    }

    /// <summary>
    /// Replaces the completion source with a fresh, uncompleted one so new callers wait for the next connection.
    /// </summary>
    private void ResetForReconnection()
    {
        lock (_lock)
        {
            logger.LogDebug("Resetting backchannel for reconnection");
            _rpcTaskCompletionSource = new TaskCompletionSource<JsonRpc>(TaskCreationOptions.RunContinuationsAsynchronously);
        }
    }

    /// <summary>
    /// Re-runs <see cref="ConnectAsync(string, bool, CancellationToken)"/> against the remembered socket path,
    /// retrying on <see cref="SocketException"/> every 500 ms for up to 30 seconds.
    /// </summary>
    /// <exception cref="InvalidOperationException">No previous connection was made.</exception>
    private async Task ReconnectInternalAsync()
    {
        var socketPath = _socketPath ?? throw new InvalidOperationException("Cannot reconnect: no previous connection.");

        ResetForReconnection();

        // Wait for the new socket to appear (the new DistributedApplication needs to start)
        var startTime = DateTime.UtcNow;
        var maxWait = TimeSpan.FromSeconds(30);

        while (!_cancellationToken.IsCancellationRequested)
        {
            try
            {
                await ConnectAsync(socketPath, _autoReconnect, _cancellationToken).ConfigureAwait(false);
                logger.LogInformation("Successfully reconnected to backchannel");
                return;
            }
            catch (SocketException ex)
            {
                if (DateTime.UtcNow - startTime >= maxWait)
                {
                    logger.LogWarning(ex, "Timed out waiting for backchannel reconnection");
                    return;
                }

                // Socket not ready yet, wait and retry
                await Task.Delay(500, _cancellationToken).ConfigureAwait(false);
            }
        }

        // The loop only exits on cancellation; the timeout case returns from the catch above.
        logger.LogDebug("Backchannel reconnection cancelled");
    }

    /// <summary>
    /// Streams publishing activities from the AppHost.
    /// </summary>
    public async IAsyncEnumerable<PublishingActivity> GetPublishingActivitiesAsync([EnumeratorCancellation] CancellationToken cancellationToken)
    {
        using var activity = telemetry.ActivitySource.StartActivity();
        var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Requesting publishing activities.");

        var publishingActivities = await rpc.InvokeWithCancellationAsync<IAsyncEnumerable<PublishingActivity>>(
            "GetPublishingActivitiesAsync",
            [],
            cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Received publishing activities.");

        await foreach (var state in publishingActivities.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            yield return state;
        }
    }

    /// <summary>
    /// Retrieves the capability identifiers advertised by the AppHost.
    /// </summary>
    public async Task<string[]> GetCapabilitiesAsync(CancellationToken cancellationToken)
    {
        using var activity = telemetry.ActivitySource.StartActivity();
        var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Requesting capabilities");

        var capabilities = await rpc.InvokeWithCancellationAsync<string[]>(
            "GetCapabilitiesAsync",
            [],
            cancellationToken).ConfigureAwait(false);

        return capabilities;
    }

    /// <summary>
    /// Sends the final answers for a prompt.
    /// </summary>
    public async Task CompletePromptResponseAsync(string promptId, PublishingPromptInputAnswer[] answers, CancellationToken cancellationToken)
    {
        using var activity = telemetry.ActivitySource.StartActivity();
        var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Providing prompt responses for prompt ID {PromptId}", promptId);

        await rpc.InvokeWithCancellationAsync(
            "CompletePromptResponseAsync",
            [promptId, answers],
            cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Sends updated (non-final) answers for a prompt.
    /// </summary>
    public async Task UpdatePromptResponseAsync(string promptId, PublishingPromptInputAnswer[] answers, CancellationToken cancellationToken)
    {
        using var activity = telemetry.ActivitySource.StartActivity();
        var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Providing prompt responses for prompt ID {PromptId}", promptId);

        await rpc.InvokeWithCancellationAsync(
            "UpdatePromptResponseAsync",
            [promptId, answers],
            cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Starts execution on the AppHost and streams its command output.
    /// </summary>
    public async IAsyncEnumerable<CommandOutput> ExecAsync([EnumeratorCancellation] CancellationToken cancellationToken)
    {
        using var activity = telemetry.ActivitySource.StartActivity();
        var rpc = await GetRpcTaskAsync().WaitAsync(cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Requesting execution.");
        var commandOutputs = await rpc.InvokeWithCancellationAsync<IAsyncEnumerable<CommandOutput>>(
            "ExecAsync",
            [],
            cancellationToken).ConfigureAwait(false);

        logger.LogDebug("Requested execution.");
        await foreach (var commandOutput in commandOutputs.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            yield return commandOutput;
        }
    }

}