File: AspireServerServiceTests.cs
Web Access
Project: ..\..\..\test\Microsoft.WebTools.AspireService.Tests\Microsoft.WebTools.AspireService.Tests.csproj (Microsoft.WebTools.AspireService.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#nullable disable
 
using System.Diagnostics;
using System.Net;
using System.Net.Http.Headers;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;
using System.Text.Json;
 
namespace Aspire.Tools.Service.UnitTests;
 
/// <summary>
/// Tests for <see cref="AspireServerService"/>. Each test starts a real server instance, talks to it over
/// HTTPS / WebSockets using the connection values the server publishes, and then verifies the mocked
/// server events.
/// </summary>
public class AspireServerServiceTests(ITestOutputHelper output)
{
    private const string Project1Path = @"c:\test\Projects\project1.csproj";
    private const int ProcessId = 34213;
    private const string DcpId = "myid";
    private const string VersionedSessionUrl = $"{RunSessionRequest.Url}?{RunSessionRequest.VersionQuery}={RunSessionRequest.OurProtocolVersion}";

    private static readonly TestRunSessionRequest Project1SessionRequest = new TestRunSessionRequest(Project1Path, debugging: false, launchProfile: null, disableLaunchProfile: false)
    {
        args = new List<string> { "--project1Arg" },
        env = new List<EnvVar> { new EnvVar { Name = "var1", Value = "value1" } }
    };

    // Same as Project1SessionRequest but with args explicitly null, used to check null args pass through.
    private static readonly TestRunSessionRequest Project2SessionRequest = new TestRunSessionRequest(Project1Path, debugging: false, launchProfile: null, disableLaunchProfile: false)
    {
        args = null,
        env = new List<EnvVar> { new EnvVar { Name = "var1", Value = "value1" } }
    };

    [Fact]
    public async Task SessionStarted_Test()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks);

        // Start listening. RunContinuationsAsynchronously keeps the test body from resuming inline
        // on the WebSocket listener loop when a completion source is set.
        TaskCompletionSource<bool> connected = new(TaskCreationOptions.RunContinuationsAsynchronously);

        TaskCompletionSource<ProcessRestartedNotification> notificationTask = new(TaskCreationOptions.RunContinuationsAsynchronously);
        _ = ListenForSessionUpdatesAsync(server, connected, (sn) =>
        {
            // Filter by type and use TrySetResult: an exception thrown here would be lost in the
            // discarded listener task and leave the test waiting forever.
            if (sn is ProcessRestartedNotification restarted)
            {
                notificationTask.TrySetResult(restarted);
            }
        });

        await connected.Task;

        await server.NotifySessionStartedAsync(DcpId, "1", ProcessId, CancellationToken.None);

        var result = await notificationTask.Task;

        Assert.Equal(ProcessId, result.PID);
        Assert.Equal("1", result.SessionId);

        await server.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task SessionEndedAsync_Test()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks);

        // Start listening
        TaskCompletionSource<bool> connected = new(TaskCreationOptions.RunContinuationsAsynchronously);
        TaskCompletionSource<SessionTerminatedNotification> sessionEndNotificationTask = new(TaskCreationOptions.RunContinuationsAsynchronously);
        _ = ListenForSessionUpdatesAsync(server, connected, (sn) =>
        {
            // The listener materializes SessionTerminated notifications as SessionTerminatedNotification.
            if (sn is SessionTerminatedNotification terminated)
            {
                sessionEndNotificationTask.TrySetResult(terminated);
            }
        });

        await connected.Task;

        await server.NotifySessionEndedAsync(DcpId, "1", ProcessId, 130, CancellationToken.None);

        var result = await sessionEndNotificationTask.Task;
        Assert.Equal(ProcessId, result.Pid);
        Assert.Equal("1", result.SessionId);
        Assert.Equal(130, result.ExitCode);

        await server.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task LaunchProject_Success()
    {
        var mocks = new Mocks();

        mocks.GetOrCreate<IAspireServerEventsMock>()
             .ImplementStartProjectAsync(DcpId, "2");

        var server = await GetAspireServer(mocks);
        var tokens = server.GetServerVariables();

        using HttpClient client = GetHttpClient(tokens);

        var response = await client.PutAsJsonAsync(VersionedSessionUrl, Project1SessionRequest);

        Assert.Equal(HttpStatusCode.Created, response.StatusCode);
        Assert.Equal($"{client.BaseAddress}run_session/2", response.Headers.Location.AbsoluteUri);

        await server.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task LaunchProject_WithNullArgs_PassesThroughNullArgs()
    {
        var mocks = new Mocks();

        mocks.GetOrCreate<IAspireServerEventsMock>()
             .ImplementStartProjectAsync(DcpId, "2", requireNullArguments: true);

        var server = await GetAspireServer(mocks);
        var tokens = server.GetServerVariables();

        using HttpClient client = GetHttpClient(tokens);

        var response = await client.PutAsJsonAsync(VersionedSessionUrl, Project2SessionRequest);

        Assert.Equal(HttpStatusCode.Created, response.StatusCode);
        Assert.Equal($"{client.BaseAddress}run_session/2", response.Headers.Location.AbsoluteUri);

        await server.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task LaunchProject_Success_ThenStopProcessRequest()
    {
        var mocks = new Mocks();

        mocks.GetOrCreate<IAspireServerEventsMock>()
             .ImplementStartProjectAsync(DcpId, "2")
             .ImplementStopSessionAsync(DcpId, "2", exists: true)
             .ImplementStopSessionAsync(DcpId, "3", exists: false);

        var server = await GetAspireServer(mocks);
        var tokens = server.GetServerVariables();

        using HttpClient client = GetHttpClient(tokens);

        var response = await client.PutAsJsonAsync(VersionedSessionUrl, Project1SessionRequest);
        Assert.Equal(HttpStatusCode.Created, response.StatusCode);

        // Now send a stop session
        response = await client.DeleteAsync(RunSessionRequest.Url + "/2");
        Assert.Equal(HttpStatusCode.OK, response.StatusCode);

        // Validate NoContent response if session not found
        response = await client.DeleteAsync(RunSessionRequest.Url + "/3");
        Assert.Equal(HttpStatusCode.NoContent, response.StatusCode);

        await server.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task LaunchProject_FailedToLaunchProject()
    {
        var mocks = new Mocks();

        mocks.GetOrCreate<IAspireServerEventsMock>()
             .ImplementStartProjectAsync(DcpId, "2", new Exception("Launch project failed"));

        var server = await GetAspireServer(mocks);

        var tokens = server.GetServerVariables();
        using HttpClient client = GetHttpClient(tokens);

        var response = await client.PutAsJsonAsync(VersionedSessionUrl, Project1SessionRequest);

        Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
        Assert.Equal("application/json; charset=utf-8", response.Content.Headers.ContentType.ToString());
        Assert.Equal("{\"error\":{\"message\":\"Launch project failed\"}}", await response.Content.ReadAsStringAsync());

        await server.DisposeAsync();
        mocks.Verify();
    }

    [Fact]
    public async Task LaunchProject_FailNoBearerToken()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks);

        var tokens = server.GetServerVariables();
        using HttpClient client = GetHttpClient(tokens);
        // Overwrite the valid token installed by GetHttpClient.
        client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", "badToken");

        var response = await client.PutAsJsonAsync(VersionedSessionUrl, Project1SessionRequest);

        Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode);

        await server.DisposeAsync();
        mocks.Verify();
    }

    [Fact]
    public async Task LaunchProject_FailWrongUrl()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks);

        var tokens = server.GetServerVariables();
        using HttpClient client = GetHttpClient(tokens);

        var response = await client.PutAsJsonAsync("/run_badurl", Project1SessionRequest);

        Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);

        await server.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task LaunchProject_NotAPUTRequest()
    {
        var mocks = new Mocks();

        var aspireServer = await GetAspireServer(mocks);

        var tokens = aspireServer.GetServerVariables();
        using HttpClient client = GetHttpClient(tokens);

        var response = await client.PostAsJsonAsync(VersionedSessionUrl, Project1SessionRequest);

        Assert.Equal(HttpStatusCode.MethodNotAllowed, response.StatusCode);

        await aspireServer.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task StopSession_FailNoBearerToken()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks);

        var tokens = server.GetServerVariables();
        using HttpClient client = GetHttpClient(tokens);
        client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", "badToken");

        var response = await client.DeleteAsync(RunSessionRequest.Url + "/2");

        Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode);

        await server.DisposeAsync();
        mocks.Verify();
    }

    [Fact]
    public async Task Info_Success()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks);

        var tokens = server.GetServerVariables();
        using HttpClient client = GetHttpClient(tokens);

        var response = await client.GetAsync(InfoResponse.Url);

        Assert.Equal(HttpStatusCode.OK, response.StatusCode);

        await server.DisposeAsync();
        mocks.Verify();
    }

    [Fact]
    public async Task Info_FailNoBearerToken()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks);

        var tokens = server.GetServerVariables();
        using HttpClient client = GetHttpClient(tokens);
        client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", "badToken");

        var response = await client.GetAsync(InfoResponse.Url);

        Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode);

        await server.DisposeAsync();
        mocks.Verify();
    }

    [Fact]
    public async Task SendLogMessageAsync_Test()
    {
        var mocks = new Mocks();

        var aspireServer = await GetAspireServer(mocks);

        // Start listening
        TaskCompletionSource<bool> connected = new(TaskCreationOptions.RunContinuationsAsynchronously);
        TaskCompletionSource<ServiceLogsNotification> notificationTask = new(TaskCreationOptions.RunContinuationsAsynchronously);
        _ = ListenForSessionUpdatesAsync(aspireServer, connected, (sn) =>
        {
            // Ignore unrelated notifications; TrySetResult tolerates a second log message.
            if (sn is ServiceLogsNotification logs)
            {
                notificationTask.TrySetResult(logs);
            }
        });

        await connected.Task;

        await aspireServer.NotifyLogMessageAsync(DcpId, "1", isStdErr: false, "My Message", CancellationToken.None);

        var result = await notificationTask.Task;

        Assert.Equal("My Message", result.LogMessage);
        Assert.False(result.IsStdErr);
        await aspireServer.DisposeAsync();

        mocks.Verify();
    }

    [Fact]
    public async Task GetEnvironmentForOrchestrator_Tests()
    {
        var mocks = new Mocks();

        var server = await GetAspireServer(mocks, waitForListening: false);

        // First time should create a key
        var envVars = server.GetServerConnectionEnvironment();

        // Index 1 is the bearer token (see GetServerVariables: address, token, certificate).
        Assert.Equal(3, envVars.Count);
        var token = envVars[1];
        Assert.NotNull(token.Value);

        // Should return the same
        envVars = server.GetServerConnectionEnvironment();
        Assert.Equal(token, envVars[1]);

        // NOTE(review): this server is never disposed; confirm DisposeAsync is safe before the
        // listener has come up and dispose it here.
        mocks.Verify();
    }

    /// <summary>
    /// Connects to the server's session-notification WebSocket endpoint and forwards each notification
    /// it receives to <paramref name="callback"/> until the socket closes.
    /// </summary>
    /// <param name="aspireServer">Server whose published address, bearer token and certificate are used.</param>
    /// <param name="connected">
    /// Completed with <see langword="true"/> once the socket is connected. On connection failure it is
    /// faulted with the error, so a caller awaiting it fails instead of hanging.
    /// </param>
    /// <param name="callback">Invoked with each notification, deserialized to its concrete type.</param>
    private async Task ListenForSessionUpdatesAsync(AspireServerService aspireServer, TaskCompletionSource<bool> connected, Action<SessionNotification> callback)
    {
        var tokens = aspireServer.GetServerVariables();
        using var httpClient = GetHttpClient(tokens);

        using var ws = new ClientWebSocket();
        ws.Options.SetRequestHeader("Authorization", $"Bearer {tokens.bearerToken}");
        try
        {
            await ws.ConnectAsync(new Uri($"wss://{tokens.serverAddress}{RunSessionRequest.Url}{SessionNotification.Url}"), httpClient, CancellationToken.None);
        }
        catch (Exception ex)
        {
            // Callers discard this task and only await 'connected', so the failure must be reported
            // through it. (An Assert.Fail here would throw before 'connected' was completed.)
            connected.TrySetException(new InvalidOperationException("Could not connect to session update endpoint: " + ex.ToString(), ex));
            return;
        }

        connected.TrySetResult(true);

        while (ws.State == WebSocketState.Open)
        {
            string message;
            try
            {
                (message, var messageType) = await GetSocketMsgAsync(ws);

                if (messageType == WebSocketMessageType.Close)
                {
                    await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None);
                    return;
                }
            }
            catch
            {
                // This is expected if the connection is closed
                Assert.Equal(WebSocketState.Closed, ws.State);
                return;
            }

            if (message is null)
            {
                // Non-text frames carry no JSON notification; skip them rather than failing to deserialize.
                continue;
            }

            // Read the common envelope first to learn the concrete notification type, then re-parse.
            var notification = JsonSerializer.Deserialize<SessionNotification>(message, AspireServerService.JsonSerializerOptions);
            Assert.NotNull(notification);

            SessionNotification value = notification.NotificationType switch
            {
                NotificationType.ProcessRestarted => JsonSerializer.Deserialize<ProcessRestartedNotification>(message, AspireServerService.JsonSerializerOptions),
                NotificationType.SessionTerminated => JsonSerializer.Deserialize<SessionTerminatedNotification>(message, AspireServerService.JsonSerializerOptions),
                NotificationType.ServiceLogs => JsonSerializer.Deserialize<ServiceLogsNotification>(message, AspireServerService.JsonSerializerOptions),
                _ => throw new InvalidOperationException($"Unexpected {notification.NotificationType}")
            };

            Assert.NotNull(value);
            callback.Invoke(value);
        }
    }

    /// <summary>
    /// Creates an <see cref="HttpClient"/> for the server. It trusts only the server's published
    /// certificate and sends the bearer token and DCP instance id headers by default.
    /// </summary>
    /// <param name="tokens">Values from <c>GetServerVariables</c>: address, bearer token, base64 certificate.</param>
    private static HttpClient GetHttpClient((string serverAddress, string bearerToken, string certToken) tokens)
    {
        var serverCert = X509CertificateLoader.LoadCertificate(Convert.FromBase64String(tokens.certToken));
        var clientHandler = new HttpClientHandler()
        {
            ClientCertificateOptions = ClientCertificateOption.Manual,
            // None lets the operating system choose the TLS protocol version.
            SslProtocols = System.Security.Authentication.SslProtocols.None,
            // Pin the server's own certificate instead of relying on the machine trust store.
            ServerCertificateCustomValidationCallback = (message, cert, chain, errors) => cert?.Thumbprint == serverCert.Thumbprint
        };

        var client = new HttpClient(clientHandler)
        {
            BaseAddress = new Uri($"https://{tokens.serverAddress}")
        };

        client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.bearerToken);
        client.DefaultRequestHeaders.Add(HttpContextExtensions.DCPInstanceIDHeader, DcpId);

        return client;
    }

    /// <summary>
    /// Receives one complete WebSocket message, reassembling it if it spans several frames or
    /// exceeds the receive buffer.
    /// </summary>
    /// <returns>
    /// The UTF-8 decoded text and <see cref="WebSocketMessageType.Text"/> for text messages; otherwise
    /// <see langword="null"/> and the received message type (e.g. <see cref="WebSocketMessageType.Close"/>).
    /// </returns>
    private static async Task<(string, WebSocketMessageType)> GetSocketMsgAsync(ClientWebSocket client)
    {
        var buffer = new byte[2048];
        using var received = new System.IO.MemoryStream();

        WebSocketReceiveResult rcvResult;
        do
        {
            rcvResult = await client.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
            if (rcvResult.MessageType == WebSocketMessageType.Close)
            {
                return (null, rcvResult.MessageType);
            }

            received.Write(buffer, 0, rcvResult.Count);
        }
        while (!rcvResult.EndOfMessage);

        if (rcvResult.MessageType == WebSocketMessageType.Text)
        {
            return (Encoding.UTF8.GetString(received.ToArray()), rcvResult.MessageType);
        }

        return (null, rcvResult.MessageType);
    }

    /// <summary>
    /// Creates an <see cref="AspireServerService"/> backed by the mocked server events. Server log
    /// lines are written to the xUnit output and to <see cref="Debug"/>.
    /// </summary>
    /// <param name="mocks">Mock container; its <c>IAspireServerEventsMock</c> supplies the events object.</param>
    /// <param name="waitForListening">When true, waits until the server's port accepts connections.</param>
    private async Task<AspireServerService> GetAspireServer(Mocks mocks, bool waitForListening = true)
    {
        var serverEvents = mocks.GetOrCreate<IAspireServerEventsMock>();

        var aspireServer = new AspireServerService(serverEvents.Object, displayName: "Test server",
            line =>
            {
                output.WriteLine(line);
                Debug.WriteLine(line);
            });

        if (waitForListening)
        {
            await aspireServer.WaitForListeningAsync();
        }

        return aspireServer;
    }

#pragma warning disable IDE1006 // Naming Styles
    // Wire-format request DTOs: property names intentionally mirror the snake_case JSON fields.

    // Flat, single-project request shape.
    // NOTE(review): presumably an older protocol version's payload — confirm.
    internal class TestRunSessionRequestP4
    {
        public string project_path { get; set; } = string.Empty;
        public bool debug { get; set; }
        public List<EnvVar> env { get; set; } = new List<EnvVar>();
        public List<string> args { get; set; } = new List<string>();
        public string launch_profile { get; set; }
        public bool disable_launch_profile { get; set; }
    }

    // Run-session request containing a single project launch configuration.
    internal class TestRunSessionRequest
    {
        public TestRunSessionRequest(string projectPath, bool debugging, string launchProfile, bool disableLaunchProfile)
        {
            launch_configurations = new TestLaunchConfiguration[]
            {
                new() {
                    project_path = projectPath,
                    type = RunSessionRequest.ProjectLaunchConfigurationType,
                    mode= debugging? RunSessionRequest.DebugLaunchMode : RunSessionRequest.NoDebugLaunchMode,
                    launch_profile = launchProfile,
                    disable_launch_profile = disableLaunchProfile
                }
            };
        }
        public TestLaunchConfiguration[] launch_configurations { get; set; }
        public List<EnvVar> env { get; set; } = new List<EnvVar>();
        public List<string> args { get; set; } = new List<string>();

        // Flattens the first launch configuration into the TestRunSessionRequestP4 shape.
        public TestRunSessionRequestP4 ToTestRunSessionRequestP4()
        {
            var launchConfig = launch_configurations[0];
            return new TestRunSessionRequestP4()
            {
                project_path = launchConfig.project_path,
                debug = string.Equals(launchConfig.mode, RunSessionRequest.DebugLaunchMode, StringComparison.OrdinalIgnoreCase),
                args = args,
                env = env,
                launch_profile = launchConfig.launch_profile,
                disable_launch_profile = launchConfig.disable_launch_profile
            };
        }
    }

    internal class TestLaunchConfiguration
    {
        public string type { get; set; } = string.Empty;
        public string project_path { get; set; } = string.Empty;
        public string launch_profile { get; set; }
        public bool disable_launch_profile { get; set; }
        public string mode { get; set; } = string.Empty;
    }

    internal class TestStopSessionRequest
    {
        public string session_id { get; set; } = string.Empty;
    }
#pragma warning restore IDE1006 // Naming Styles
}
 
/// <summary>
/// Test-only helpers for reading connection details from, and waiting on, an <see cref="AspireServerService"/>.
/// </summary>
internal static class AspireServerServiceExtensions
{
    /// <summary>
    /// Waits for the server's port to become reachable by calling Helpers.CanConnectToPortAsync against
    /// the published server address.
    /// </summary>
    public static async Task WaitForListeningAsync(this AspireServerService aspireServer)
    {
        var (address, _, _) = aspireServer.GetServerVariables();

        // We need to wait on the port being available before issuing requests.
        // NOTE(review): any result returned by CanConnectToPortAsync is ignored; confirm whether a
        // failed wait should fail the test instead of continuing.
        var probeUri = new Uri($"http://{address}");
        await Helpers.CanConnectToPortAsync(probeUri, 5000, CancellationToken.None);
    }

    /// <summary>
    /// Returns the first three entries of the server's connection environment as
    /// (server address, bearer token, base64 certificate), in that order.
    /// </summary>
    public static (string serverAddress, string bearerToken, string certToken) GetServerVariables(this AspireServerService aspireServer)
    {
        var environment = aspireServer.GetServerConnectionEnvironment();

        string address = environment[0].Value;
        string bearer = environment[1].Value;
        string certificate = environment[2].Value;

        return (address, bearer, certificate);
    }
}