// File: CallbackProxyTests.cs
// Web Access
// Project: src\tests\Aspire.Hosting.RemoteHost.Tests\Aspire.Hosting.RemoteHost.Tests.csproj (Aspire.Hosting.RemoteHost.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Text.Json.Nodes;
using Aspire.Hosting.Ats;
using Aspire.Hosting.RemoteHost.Ats;
using Xunit;
 
namespace Aspire.Hosting.RemoteHost.Tests;
 
/// <summary>
/// Tests for <see cref="AtsCallbackProxyFactory"/>: proxy creation and caching, marshalling of
/// callback arguments into positional JSON keys, result/exception propagation, and DTO writeback.
/// </summary>
public class CallbackProxyTests
{
    [Fact]
    public void CreateProxy_ReturnsNullForNonDelegateType()
    {
        using var factory = CreateFactory();

        var result = factory.CreateProxy("callback1", typeof(string));

        Assert.Null(result);
    }

    [Fact]
    public void CreateProxy_ReturnsDelegateForAnyDelegateType()
    {
        using var factory = CreateFactory();

        // All delegate types are now accepted - no attribute required
        var result = factory.CreateProxy("callback1", typeof(Action));

        Assert.NotNull(result);
        Assert.IsAssignableFrom<Action>(result);
    }

    [Fact]
    public void CreateProxy_ReturnsDelegateForCustomType()
    {
        using var factory = CreateFactory();

        var result = factory.CreateProxy("callback1", typeof(TestCallbackNoArgs));

        Assert.NotNull(result);
        Assert.IsAssignableFrom<TestCallbackNoArgs>(result);
    }

    [Fact]
    public void CreateProxy_CachesDelegate()
    {
        using var factory = CreateFactory();

        var result1 = factory.CreateProxy("callback1", typeof(TestCallbackNoArgs));
        var result2 = factory.CreateProxy("callback1", typeof(TestCallbackNoArgs));

        // Assert.Same(null, null) succeeds, so require a real proxy first to avoid a vacuous pass.
        Assert.NotNull(result1);
        Assert.Same(result1, result2);
    }

    [Fact]
    public async Task InvokedProxy_CallsCallbackInvoker()
    {
        var invoker = new TestCallbackInvoker();
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackNoArgs)factory.CreateProxy("test-callback", typeof(TestCallbackNoArgs))!;

        await proxy();

        var invocation = Assert.Single(invoker.Invocations);
        Assert.Equal("test-callback", invocation.CallbackId);
    }

    [Fact]
    public async Task InvokedProxy_ReturnsResultFromInvoker()
    {
        var invoker = new TestCallbackInvoker { ResultToReturn = JsonValue.Create(42) };
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackWithIntResult)factory.CreateProxy("test-callback", typeof(TestCallbackWithIntResult))!;

        var result = await proxy();

        Assert.Equal(42, result);
    }

    [Fact]
    public void CancellationTokenRegistry_IsExposed()
    {
        using var factory = CreateFactory();

        Assert.NotNull(factory.CancellationTokenRegistry);
    }

    [Fact]
    public void Dispose_DisposesResources()
    {
        var factory = CreateFactory();

        // Make the "does not throw when disposed" expectation an explicit assertion.
        var exception = Record.Exception(() => factory.Dispose());

        Assert.Null(exception);
    }

    // Tests for callbacks with parameters (bug fix verification)
    [Fact]
    public void CreateProxy_ReturnsDelegateForCallbackWithStringParameter()
    {
        using var factory = CreateFactory();

        var result = factory.CreateProxy("callback1", typeof(TestCallbackWithString));

        Assert.NotNull(result);
        Assert.IsAssignableFrom<TestCallbackWithString>(result);
    }

    [Fact]
    public async Task InvokedProxy_PassesStringArgumentAsJson()
    {
        var invoker = new TestCallbackInvoker();
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackWithString)factory.CreateProxy("test-callback", typeof(TestCallbackWithString))!;

        await proxy("hello-world");

        var invocation = Assert.Single(invoker.Invocations);
        var args = Assert.IsType<JsonObject>(invocation.Args);
        // Arguments are passed with positional keys (p0, p1, p2, ...)
        Assert.Equal("hello-world", args["p0"]?.GetValue<string>());
    }

    [Fact]
    public async Task InvokedProxy_PassesMultipleArgumentsAsJson()
    {
        var invoker = new TestCallbackInvoker();
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackWithMultipleParams)factory.CreateProxy("test-callback", typeof(TestCallbackWithMultipleParams))!;

        await proxy("test-name", 42);

        var invocation = Assert.Single(invoker.Invocations);
        var args = Assert.IsType<JsonObject>(invocation.Args);
        // Arguments are passed with positional keys (p0, p1, p2, ...)
        Assert.Equal("test-name", args["p0"]?.GetValue<string>());
        Assert.Equal(42, args["p1"]?.GetValue<int>());
    }

    [Fact]
    public async Task InvokedProxy_WithResultReturnsCorrectValue()
    {
        var invoker = new TestCallbackInvoker { ResultToReturn = JsonValue.Create("result-value") };
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackWithStringResult)factory.CreateProxy("test-callback", typeof(TestCallbackWithStringResult))!;

        var result = await proxy("input");

        Assert.Equal("result-value", result);
    }

    [Fact]
    public async Task InvokedProxy_WithCancellationToken_IncludesTokenInArgs()
    {
        var invoker = new TestCallbackInvoker();
        using var factory = CreateFactory(invoker);
        using var cts = new CancellationTokenSource();

        var proxy = (TestCallbackWithCancellation)factory.CreateProxy("test-callback", typeof(TestCallbackWithCancellation))!;

        await proxy("test", cts.Token);

        var invocation = Assert.Single(invoker.Invocations);
        var args = Assert.IsType<JsonObject>(invocation.Args);
        // Arguments are passed with positional keys (p0, p1, p2, ...)
        // CancellationToken is not included in positional args, but added as $cancellationToken if not None
        // NOTE(review): the test name implies the token itself should be asserted in args, but only
        // p0 is checked here - confirm the "$cancellationToken" key contract before asserting it.
        Assert.Equal("test", args["p0"]?.GetValue<string>());
    }

    // Callback error handling tests
    [Fact]
    public async Task InvokedProxy_PropagatesExceptionFromInvoker()
    {
        var invoker = new TestCallbackInvoker
        {
            ExceptionToThrow = new InvalidOperationException("Callback failed")
        };
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackNoArgs)factory.CreateProxy("test-callback", typeof(TestCallbackNoArgs))!;

        var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => proxy());
        Assert.Equal("Callback failed", ex.Message);
    }

    [Fact]
    public async Task InvokedProxy_WithResult_PropagatesExceptionFromInvoker()
    {
        var invoker = new TestCallbackInvoker
        {
            ExceptionToThrow = new InvalidOperationException("Callback with result failed")
        };
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackWithIntResult)factory.CreateProxy("test-callback", typeof(TestCallbackWithIntResult))!;

        var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => proxy());
        Assert.Equal("Callback with result failed", ex.Message);
    }

    [Fact]
    public async Task InvokedProxy_PropagatesOperationCanceledException()
    {
        var invoker = new TestCallbackInvoker
        {
            ExceptionToThrow = new OperationCanceledException("Operation was cancelled")
        };
        using var factory = CreateFactory(invoker);

        var proxy = (TestCallbackNoArgs)factory.CreateProxy("test-callback", typeof(TestCallbackNoArgs))!;

        await Assert.ThrowsAsync<OperationCanceledException>(() => proxy());
    }

    [Fact]
    public void InvokedSyncVoidProxy_AppliesDtoWritebackFromResult()
    {
        var dto = new TestCallbackDto { Name = "original", Count = 0 };

        // The invoker returns the modified args (simulating TypeScript returning mutated DTO)
        var invoker = new TestCallbackInvoker
        {
            ResultToReturn = new JsonObject
            {
                ["p0"] = new JsonObject { ["name"] = "modified", ["count"] = 42 }
            }
        };
        using var factory = CreateFactory(invoker, registerDtoTypes: true);

        var proxy = (TestSyncVoidCallbackWithDto)factory.CreateProxy("test-callback", typeof(TestSyncVoidCallbackWithDto))!;

        proxy(dto);

        Assert.Equal("modified", dto.Name);
        Assert.Equal(42, dto.Count);
    }

    [Fact]
    public async Task InvokedAsyncVoidProxy_AppliesDtoWritebackFromResult()
    {
        var dto = new TestCallbackDto { Name = "original", Count = 0 };

        var invoker = new TestCallbackInvoker
        {
            ResultToReturn = new JsonObject
            {
                ["p0"] = new JsonObject { ["name"] = "async-modified", ["count"] = 99 }
            }
        };
        using var factory = CreateFactory(invoker, registerDtoTypes: true);

        var proxy = (TestAsyncVoidCallbackWithDto)factory.CreateProxy("test-callback", typeof(TestAsyncVoidCallbackWithDto))!;

        await proxy(dto);

        Assert.Equal("async-modified", dto.Name);
        Assert.Equal(99, dto.Count);
    }

    [Fact]
    public void InvokedSyncVoidProxy_DtoWritebackIgnoresNonDtoArgs()
    {
        // Use a delegate with both a non-DTO param (string) and a DTO param to exercise
        // the writeback path. The non-DTO arg at p0 should be safely skipped, while
        // the DTO arg at p1 should be written back.
        var dto = new TestCallbackDto { Name = "original", Count = 0 };
        var invoker = new TestCallbackInvoker
        {
            ResultToReturn = new JsonObject
            {
                ["p0"] = JsonValue.Create("some-string"),
                ["p1"] = new JsonObject { ["name"] = "mixed-modified", ["count"] = 77 }
            }
        };
        using var factory = CreateFactory(invoker, registerDtoTypes: true);

        var proxy = (TestSyncVoidCallbackWithMixedArgs)factory.CreateProxy("test-callback", typeof(TestSyncVoidCallbackWithMixedArgs))!;

        // Should not throw - non-DTO arg at p0 is skipped, DTO arg at p1 is written back
        proxy("hello", dto);

        Assert.Equal("mixed-modified", dto.Name);
        Assert.Equal(77, dto.Count);
    }

    [Fact]
    public void InvokedSyncVoidProxy_DtoWritebackHandlesNullResult()
    {
        var dto = new TestCallbackDto { Name = "original", Count = 0 };

        // Invoker returns null (TypeScript callback returned undefined and no args were sent back)
        var invoker = new TestCallbackInvoker { ResultToReturn = null };
        using var factory = CreateFactory(invoker, registerDtoTypes: true);

        var proxy = (TestSyncVoidCallbackWithDto)factory.CreateProxy("test-callback", typeof(TestSyncVoidCallbackWithDto))!;

        proxy(dto);

        // Original values should be unchanged
        Assert.Equal("original", dto.Name);
        Assert.Equal(0, dto.Count);
    }

    [Fact]
    public void InvokedSyncVoidProxy_AppliesWritebackToMultipleDtos()
    {
        var dto1 = new TestCallbackDto { Name = "first", Count = 1 };
        var dto2 = new TestCallbackDto { Name = "second", Count = 2 };

        var invoker = new TestCallbackInvoker
        {
            ResultToReturn = new JsonObject
            {
                ["p0"] = new JsonObject { ["name"] = "first-updated", ["count"] = 10 },
                ["p1"] = new JsonObject { ["name"] = "second-updated", ["count"] = 20 }
            }
        };
        using var factory = CreateFactory(invoker, registerDtoTypes: true);

        var proxy = (TestSyncVoidCallbackWithMultipleDtos)factory.CreateProxy("test-callback", typeof(TestSyncVoidCallbackWithMultipleDtos))!;

        proxy(dto1, dto2);

        Assert.Equal("first-updated", dto1.Name);
        Assert.Equal(10, dto1.Count);
        Assert.Equal("second-updated", dto2.Name);
        Assert.Equal(20, dto2.Count);
    }

    /// <summary>
    /// Builds an <see cref="AtsCallbackProxyFactory"/> over a fresh handle registry, cancellation
    /// token registry and marshaller.
    /// </summary>
    /// <param name="invoker">Invoker used by the proxies; a new <see cref="TestCallbackInvoker"/> when null.</param>
    /// <param name="registerDtoTypes">When true, registers <see cref="TestCallbackDto"/> as a DTO type in the context.</param>
    private static AtsCallbackProxyFactory CreateFactory(ICallbackInvoker? invoker = null, bool registerDtoTypes = false)
    {
        var handles = new HandleRegistry();
        var ctRegistry = new CancellationTokenRegistry();
        var dtoTypes = registerDtoTypes
            ? new List<AtsDtoTypeInfo>
            {
                new() { TypeId = "test/TestCallbackDto", Name = "TestCallbackDto", ClrType = typeof(TestCallbackDto), Properties = [] }
            }
            : [];
        var context = new AtsContext { Capabilities = [], HandleTypes = [], DtoTypes = dtoTypes, EnumTypes = [] };
        // The marshaller's lazy factory reference throws if resolved; these tests do not expect it to be used.
        var marshaller = new AtsMarshaller(handles, context, ctRegistry, new Lazy<AtsCallbackProxyFactory>(() => throw new NotImplementedException()));
        return new AtsCallbackProxyFactory(invoker ?? new TestCallbackInvoker(), handles, ctRegistry, marshaller);
    }

    // Test delegates - any delegate type is now treated as a callback
    // (TestCallback is not referenced by the tests in this class.)
    public delegate Task TestCallback(string value);

    public delegate Task TestCallbackNoArgs();

    public delegate Task<int> TestCallbackWithIntResult();

    public delegate Task TestCallbackWithString(string value);

    public delegate Task TestCallbackWithMultipleParams(string name, int count);

    public delegate Task<string> TestCallbackWithStringResult(string input);

    public delegate Task TestCallbackWithCancellation(string value, CancellationToken cancellationToken);

    public delegate void TestSyncVoidCallbackWithDto(TestCallbackDto dto);

    public delegate Task TestAsyncVoidCallbackWithDto(TestCallbackDto dto);

    public delegate void TestSyncVoidCallbackWithMixedArgs(string label, TestCallbackDto dto);

    public delegate void TestSyncVoidCallbackWithMultipleDtos(TestCallbackDto dto1, TestCallbackDto dto2);

    /// <summary>Mutable DTO used to verify that callback results are written back onto arguments.</summary>
    [AspireDto]
    public sealed class TestCallbackDto
    {
        public string? Name { get; set; }
        public int Count { get; set; }
    }
}
 
/// <summary>
/// In-memory <see cref="ICallbackInvoker"/> fake: logs each call and then either throws the
/// configured exception or completes with the configured result.
/// </summary>
internal sealed class TestCallbackInvoker : ICallbackInvoker
{
    /// <summary>Every (callback id, args) pair passed to either InvokeAsync overload, in call order.</summary>
    public List<(string CallbackId, JsonNode? Args)> Invocations { get; } = [];

    /// <summary>Returned by the generic overload when it is a <c>TResult</c>; otherwise <c>default</c> is returned.</summary>
    public JsonNode? ResultToReturn { get; set; }

    /// <summary>When set, thrown synchronously by both overloads after the call has been logged.</summary>
    public Exception? ExceptionToThrow { get; set; }

    /// <summary>Always reports a connected client.</summary>
    public bool IsConnected => true;

    public Task<TResult> InvokeAsync<TResult>(string callbackId, JsonNode? args, CancellationToken cancellationToken = default)
    {
        LogThenFailIfConfigured(callbackId, args);

        TResult value = ResultToReturn is TResult typed ? typed : default!;
        return Task.FromResult(value);
    }

    public Task InvokeAsync(string callbackId, JsonNode? args, CancellationToken cancellationToken = default)
    {
        LogThenFailIfConfigured(callbackId, args);

        return Task.CompletedTask;
    }

    // Shared by both overloads: the call is logged before any failure so tests can inspect it either way.
    private void LogThenFailIfConfigured(string callbackId, JsonNode? args)
    {
        Invocations.Add((callbackId, args));

        if (ExceptionToThrow is { } failure)
        {
            throw failure;
        }
    }
}