File: AspireServerService.cs
Web Access
Project: ..\..\..\src\BuiltInTools\dotnet-watch\dotnet-watch.csproj (dotnet-watch)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#nullable enable
 
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.WebSockets;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using IAsyncDisposable = System.IAsyncDisposable;
 
namespace Aspire.Tools.Service;
 
/// <summary>
/// Implementation of the AspireServerService. A new instance of this service will be created for each
/// each call to IServiceBroker.CreateProxy()
/// </summary>
internal partial class AspireServerService : IAsyncDisposable
{
    public const string DebugSessionPortEnvVar = "DEBUG_SESSION_PORT";
    public const string DebugSessionTokenEnvVar = "DEBUG_SESSION_TOKEN";
    public const string DebugSessionServerCertEnvVar = "DEBUG_SESSION_SERVER_CERTIFICATE";
 
    public const int PingIntervalInSeconds = 5;
 
    private readonly IAspireServerEvents _aspireServerEvents;
 
    private readonly Action<string>? _reporter;
 
    private readonly string _currentSecret;
    private readonly string _displayName;
 
    private readonly CancellationTokenSource _shutdownCancellationTokenSource = new();
    private readonly int _port;
    private readonly X509Certificate2 _certificate;
    private readonly string _certificateEncodedBytes;
 
    private readonly SemaphoreSlim _webSocketAccess = new(1);
 
    private readonly SocketConnectionManager _socketConnectionManager = new();
 
    private volatile bool _isDisposed;
 
    private static readonly char[] s_charSeparator = { ' ' };
 
    private readonly Task _requestListener;
 
    public static readonly JsonSerializerOptions JsonSerializerOptions = new()
    {
        DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        Converters =
        {
            new JsonStringEnumConverter(JsonNamingPolicy.CamelCase, allowIntegerValues: false)
        }
    };
 
    public AspireServerService(IAspireServerEvents aspireServerEvents, string displayName, Action<string>? reporter)
    {
        _aspireServerEvents = aspireServerEvents;
        _reporter = reporter;
        _displayName = displayName;
 
        _port = SocketUtilities.GetNextAvailablePort();
 
        // Set up the encryption so we can use it to generate our secret. 
        var aes = Aes.Create();
        aes.Mode = CipherMode.CBC;
        aes.KeySize = 128;
        aes.Padding = PaddingMode.PKCS7;
        aes.GenerateKey();
        _currentSecret = Convert.ToBase64String(aes.Key);
 
        _certificate = CertGenerator.GenerateCert();
        var certBytes = _certificate.Export(X509ContentType.Cert);
        _certificateEncodedBytes = Convert.ToBase64String(certBytes);
 
        // Kick of the web server.
        _requestListener = StartListeningAsync();
    }
 
    public async ValueTask DisposeAsync()
    {
        // Shutdown the service:
        _shutdownCancellationTokenSource.Cancel();
 
        Log("Waiting for server to shutdown ...");
 
        try
        {
            await _requestListener;
        }
        catch (OperationCanceledException)
        {
            // nop
        }
 
        _isDisposed = true;
 
        _socketConnectionManager.Dispose();
        _certificate.Dispose();
        _shutdownCancellationTokenSource.Dispose();
    }
 
    /// <inheritdoc/>
    public List<KeyValuePair<string, string>> GetServerConnectionEnvironment()
        => 
        [
            new(DebugSessionPortEnvVar, $"localhost:{_port}"),
            new(DebugSessionTokenEnvVar, _currentSecret),
            new(DebugSessionServerCertEnvVar, _certificateEncodedBytes),
        ];
 
    public ValueTask NotifySessionEndedAsync(string dcpId, string sessionId, int processId, int? exitCode, CancellationToken cancelationToken)
        => SendNotificationAsync(
            new SessionTerminatedNotification()
            {
                NotificationType = NotificationType.SessionTerminated,
                SessionId = sessionId,
                Pid = processId,
                ExitCode = exitCode
            },
            dcpId,
            sessionId,
            cancelationToken);
 
    public ValueTask NotifySessionStartedAsync(string dcpId, string sessionId, int processId, CancellationToken cancelationToken)
        => SendNotificationAsync(
            new ProcessRestartedNotification()
            {
                NotificationType = NotificationType.ProcessRestarted,
                SessionId = sessionId,
                PID = processId
            },
            dcpId,
            sessionId,
            cancelationToken);
 
    public ValueTask NotifyLogMessageAsync(string dcpId, string sessionId, bool isStdErr, string data, CancellationToken cancelationToken)
        => SendNotificationAsync(
            new ServiceLogsNotification()
            {
                NotificationType = NotificationType.ServiceLogs,
                SessionId = sessionId,
                IsStdErr = isStdErr,
                LogMessage = data
            },
            dcpId,
            sessionId,
            cancelationToken);
 
    private async ValueTask SendNotificationAsync<TNotification>(TNotification notification, string dcpId, string sessionId, CancellationToken cancelationToken)
        where TNotification : SessionNotification
    {
        try
        {
            Log($"[#{sessionId}] Sending '{notification.NotificationType}'");
            var jsonSerialized = JsonSerializer.SerializeToUtf8Bytes(notification, JsonSerializerOptions);
            await SendMessageAsync(dcpId, jsonSerialized, cancelationToken);
        }
        catch (Exception e) when (e is not OperationCanceledException && LogAndPropagate(e))
        {
        }
 
        bool LogAndPropagate(Exception e)
        {
            Log($"[#{sessionId}] Sending '{notification.NotificationType}' failed: {e.Message}");
            return false;
        }
    }
 
    /// <summary>
    /// Waits for a connection so that it can get the WebSocket that will be used to send messages tio the client. It accepts messages via Restful http
    /// calls.
    /// </summary>
    private Task StartListeningAsync()
    {
        var builder = WebApplication.CreateSlimBuilder();
 
        builder.WebHost.ConfigureKestrel(kestrelOptions =>
        {
            kestrelOptions.ListenLocalhost(_port, listenOptions =>
            {
                listenOptions.UseHttps(_certificate);
            });
        });
 
        if (_reporter != null)
        {
            builder.Logging.ClearProviders();
            builder.Logging.AddProvider(new LoggerProvider(_reporter));
        }
 
        var app = builder.Build();
 
        app.MapGet("/", () => _displayName);
        app.MapGet(InfoResponse.Url, GetInfoAsync);
 
        // Set up the run session endpoints
        var runSessionApi = app.MapGroup(RunSessionRequest.Url);
 
        runSessionApi.MapPut("/", RunSessionPutAsync);
        runSessionApi.MapDelete("/{sessionId}", RunSessionDeleteAsync);
        runSessionApi.Map(SessionNotification.Url, RunSessionNotifyAsync);
 
        app.UseWebSockets(new WebSocketOptions
        {
             KeepAliveInterval = TimeSpan.FromSeconds(PingIntervalInSeconds)
        });
 
        // Run the application async. It will shutdown when the cancel token is signaled
        return app.RunAsync(_shutdownCancellationTokenSource.Token);
    }
 
    private async Task RunSessionPutAsync(HttpContext context)
    {
        // Check the authentication header
        if (!IsValidAuthentication(context))
        {
            Log("Authorization failure");
            context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
        }
        else
        {
            await HandleStartSessionRequestAsync(context);
        }
    }
 
    private async Task RunSessionDeleteAsync(HttpContext context, string sessionId)
    {
        // Check the authentication header
        if (!IsValidAuthentication(context))
        {
            Log("Authorization failure");
            context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
        }
        else
        {
            await HandleStopSessionRequestAsync(context, sessionId);
        }
    }
 
    private async Task GetInfoAsync(HttpContext context)
    {
        // Check the authentication header
        if (!IsValidAuthentication(context))
        {
            Log("Authorization failure");
            context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
        }
        else
        {
            context.Response.StatusCode = (int)HttpStatusCode.OK;
            await context.Response.WriteAsJsonAsync(InfoResponse.Instance, JsonSerializerOptions, _shutdownCancellationTokenSource.Token);
        }
    }
 
    private async Task RunSessionNotifyAsync(HttpContext context)
    {
        // Check the authentication header
        if (!IsValidAuthentication(context))
        {
            Log("Authorization failure");
            context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
            return;
        }
        else if (!context.WebSockets.IsWebSocketRequest)
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }
 
        var webSocket = await context.WebSockets.AcceptWebSocketAsync();
        var socketTcs = new TaskCompletionSource();
 
        // Track this connection.
        _socketConnectionManager.AddSocketConnection(webSocket, socketTcs,  context.GetDcpId(), context.RequestAborted);
 
        // We must keep the middleware pipeline alive for the duration of the socket
        await socketTcs.Task;
    }
 
    private void Log(string message)
    {
        _reporter?.Invoke(message);
    }
 
    private bool IsValidAuthentication(HttpContext context)
    {
        // Check the authentication header
        var authHeader = context.Request.Headers.Authorization;
        if (authHeader.Count == 1)
        {
            var authTokens = authHeader[0]!.Split(s_charSeparator, StringSplitOptions.RemoveEmptyEntries);
 
            return authTokens.Length == 2 &&
                   string.Equals(authTokens[0], "Bearer", StringComparison.Ordinal) &&
                   string.Equals(authTokens[1], _currentSecret, StringComparison.Ordinal);
        }
 
        return false;
    }
 
    private async Task HandleStartSessionRequestAsync(HttpContext context)
    {
        string? projectPath = null;
 
        try
        {
            if (_isDisposed)
            {
                throw new ObjectDisposedException(nameof(AspireServerService), "Received 'PUT /run_session' request after the service has been disposed.");
            }
 
            // Get the project launch request data
            var projectLaunchRequest = await context.GetProjectLaunchInformationAsync(_shutdownCancellationTokenSource.Token);
            if (projectLaunchRequest == null)
            {
                // Unknown or unsupported version
                context.Response.StatusCode = (int)HttpStatusCode.BadRequest;
                return;
            }
 
            projectPath = projectLaunchRequest.ProjectPath;
 
            var sessionId = await _aspireServerEvents.StartProjectAsync(context.GetDcpId(), projectLaunchRequest, _shutdownCancellationTokenSource.Token);
 
            context.Response.StatusCode = (int)HttpStatusCode.Created;
            context.Response.Headers.Location = $"{context.Request.Scheme}://{context.Request.Host}{context.Request.Path}/{sessionId}";
        }
        catch (Exception e) when (e is not OperationCanceledException)
        {
            Log($"Failed to start project{(projectPath == null ? "" : $" '{projectPath}'")}: {e}");
 
            context.Response.StatusCode = (int)HttpStatusCode.InternalServerError;
            await WriteResponseTextAsync(context.Response, e, context.GetApiVersion() is not null);
        }
    }
 
    private async Task WriteResponseTextAsync(HttpResponse response, Exception ex, bool useRichErrorResponse)
    {
        byte[] errorResponse;
        if (useRichErrorResponse)
        {
            // If the exception is a webtools one, use the failure bucket strings as the error Code
            string? errorCode = null;
 
            var error = new ErrorResponse()
            {
                Error = new ErrorDetail { ErrorCode = errorCode, Message = ex.GetMessageFromException() }
            };
 
            await response.WriteAsJsonAsync(error, JsonSerializerOptions, _shutdownCancellationTokenSource.Token);
        }
        else
        {
            errorResponse = Encoding.UTF8.GetBytes(ex.GetMessageFromException());
            response.ContentType = "text/plain";
            response.ContentLength = errorResponse.Length;
            await response.WriteAsync(ex.GetMessageFromException(), _shutdownCancellationTokenSource.Token);
        }
    }
 
    private async Task SendMessageAsync(string dcpId, byte[] messageBytes, CancellationToken cancellationToken)
    {
        // Find the connection for the passed in dcpId
        WebSocketConnection? connection = _socketConnectionManager.GetSocketConnection(dcpId);
        if (connection is null)
        {
            // Most likely the connection has already gone away
            Log($"Send message failure: Connection with the following dcpId was not found {dcpId}");
            return;
        }
 
        var success = false;
        try
        {
            using var cancelTokenSource = CancellationTokenSource.CreateLinkedTokenSource(
                cancellationToken, _shutdownCancellationTokenSource.Token, connection.HttpRequestAborted);
 
            await _webSocketAccess.WaitAsync(cancelTokenSource.Token);
            await connection.Socket.SendAsync(new ArraySegment<byte>(messageBytes), WebSocketMessageType.Text, endOfMessage: true, cancelTokenSource.Token);
 
            success = true;
        }
        finally
        {
            if (!success)
            {
                // If the connection throws it almost certainly means the client has gone away, so clean up that connection
                _socketConnectionManager.RemoveSocketConnection(connection);
            }
 
            _webSocketAccess.Release();
        }
    }
 
    private async ValueTask HandleStopSessionRequestAsync(HttpContext context, string sessionId)
    {
        try
        {
            if (_isDisposed)
            {
                throw new ObjectDisposedException(nameof(AspireServerService), "Received 'DELETE /run_session' request after the service has been disposed.");
            }
 
            var sessionExists = await _aspireServerEvents.StopSessionAsync(context.GetDcpId(), sessionId, _shutdownCancellationTokenSource.Token);
            context.Response.StatusCode = (int)(sessionExists ? HttpStatusCode.OK : HttpStatusCode.NoContent);
        }
        catch (Exception e) when (e is not OperationCanceledException)
        {
            Log($"[#{sessionId}] Failed to stop: {e}");
 
            context.Response.StatusCode = (int)HttpStatusCode.InternalServerError;
            await WriteResponseTextAsync(context.Response, e, context.GetApiVersion() is not null);
        }
    }
}