File: Ats\AtsCallbackProxyFactory.cs
Web Access
Project: src\src\Aspire.Hosting.RemoteHost\Aspire.Hosting.RemoteHost.csproj (Aspire.Hosting.RemoteHost)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Concurrent;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.Json.Nodes;
 
namespace Aspire.Hosting.RemoteHost.Ats;
 
/// <summary>
/// Creates delegate proxies for ATS callbacks that invoke the remote client.
/// Works with any concrete delegate type.
/// </summary>
/// <remarks>
/// Each proxy is compiled from an expression tree that marshals the delegate's arguments into a
/// <see cref="JsonObject"/> and forwards the call to one of the private <c>Invoke*</c> helpers.
/// Compiled proxies are cached per (callback ID, delegate type) pair.
/// </remarks>
internal sealed class AtsCallbackProxyFactory : IDisposable
{
    private const BindingFlags PrivateInstance = BindingFlags.Instance | BindingFlags.NonPublic;

    // Reflection handles used while building expression trees. They never change, so they are
    // resolved once per process instead of on every proxy build (and every parameter).

    // JsonObject doesn't have a true parameterless constructor - it has JsonObject(JsonNodeOptions? options = null).
    // Expression.New can't handle optional parameters, so the constructor is called explicitly with null.
    private static readonly ConstructorInfo s_jsonObjectCtor =
        typeof(JsonObject).GetConstructor([typeof(JsonNodeOptions?)])!;

    private static readonly MethodInfo s_jsonObjectAddMethod =
        typeof(JsonObject).GetMethod("Add", [typeof(string), typeof(JsonNode)])!;

    private static readonly MethodInfo s_marshalArgMethod =
        typeof(AtsCallbackProxyFactory).GetMethod(nameof(MarshalArg), PrivateInstance)!;

    private static readonly MethodInfo s_invokeSyncVoidMethod =
        typeof(AtsCallbackProxyFactory).GetMethod(nameof(InvokeSyncVoid), PrivateInstance)!;

    // Open generic definition; closed over the result type with MakeGenericMethod when a proxy is built.
    private static readonly MethodInfo s_invokeSyncResultMethod =
        typeof(AtsCallbackProxyFactory).GetMethod(nameof(InvokeSyncResult), PrivateInstance)!;

    private static readonly MethodInfo s_invokeAsyncVoidMethod =
        typeof(AtsCallbackProxyFactory).GetMethod(nameof(InvokeAsyncVoid), PrivateInstance)!;

    // Open generic definition; closed over the result type with MakeGenericMethod when a proxy is built.
    private static readonly MethodInfo s_invokeAsyncResultMethod =
        typeof(AtsCallbackProxyFactory).GetMethod(nameof(InvokeAsyncResult), PrivateInstance)!;

    private readonly ICallbackInvoker _invoker;
    private readonly HandleRegistry _handleRegistry;
    private readonly CancellationTokenRegistry _cancellationTokenRegistry;
    private readonly AtsMarshaller _marshaller;
    private readonly ConcurrentDictionary<(string CallbackId, Type DelegateType), Delegate> _cache = new();

    /// <summary>
    /// Creates a new AtsCallbackProxyFactory with the specified invoker and handle registry.
    /// </summary>
    /// <param name="invoker">The callback invoker for making remote calls.</param>
    /// <param name="handleRegistry">The handle registry for marshalling objects.</param>
    /// <param name="cancellationTokenRegistry">The cancellation token registry.</param>
    /// <param name="marshaller">The marshaller for converting objects to/from JSON.</param>
    public AtsCallbackProxyFactory(
        ICallbackInvoker invoker,
        HandleRegistry handleRegistry,
        CancellationTokenRegistry cancellationTokenRegistry,
        AtsMarshaller marshaller)
    {
        _invoker = invoker;
        _handleRegistry = handleRegistry;
        _cancellationTokenRegistry = cancellationTokenRegistry;
        _marshaller = marshaller;
    }

    /// <summary>
    /// Gets the cancellation token registry for external access (e.g., RPC cancel method).
    /// </summary>
    public CancellationTokenRegistry CancellationTokenRegistry => _cancellationTokenRegistry;

    /// <summary>
    /// Creates a delegate proxy that invokes a callback on the remote client.
    /// </summary>
    /// <param name="callbackId">The callback ID registered on the client.</param>
    /// <param name="delegateType">The delegate type to create.</param>
    /// <returns>
    /// A delegate that invokes the remote callback, or null if the type is not valid. A type is
    /// not valid when it is not a delegate type, when it still has unbound generic parameters
    /// (e.g. <c>Func&lt;&gt;</c>), or when it has no <c>Invoke</c> method (the abstract
    /// <see cref="Delegate"/> and <see cref="MulticastDelegate"/> base types).
    /// </returns>
    public Delegate? CreateProxy(string callbackId, Type delegateType)
    {
        if (!IsProxyableDelegateType(delegateType))
        {
            return null;
        }

        // The factory-argument overload lets the lambda stay static, so no closure is allocated per call.
        // GetOrAdd may run the factory more than once under contention, but only one result is stored.
        return _cache.GetOrAdd(
            (callbackId, delegateType),
            static (key, factory) => factory.BuildProxy(key.CallbackId, key.DelegateType),
            this);
    }

    /// <summary>
    /// Determines whether a proxy can be compiled for <paramref name="delegateType"/>.
    /// </summary>
    private static bool IsProxyableDelegateType(Type delegateType)
    {
        // IsAssignableFrom returns false for null, so the later member accesses are safe.
        return typeof(Delegate).IsAssignableFrom(delegateType)
            && !delegateType.ContainsGenericParameters
            && delegateType.GetMethod("Invoke") is not null;
    }

    /// <summary>
    /// Builds a delegate proxy using expression trees.
    /// </summary>
    /// <param name="callbackId">The callback ID embedded as a constant in the compiled proxy.</param>
    /// <param name="delegateType">A delegate type already accepted by <see cref="IsProxyableDelegateType"/>.</param>
    /// <returns>The compiled delegate of type <paramref name="delegateType"/>.</returns>
    private Delegate BuildProxy(string callbackId, Type delegateType)
    {
        var signature = delegateType.GetMethod("Invoke")!;
        var parameters = signature.GetParameters();
        var returnType = signature.ReturnType;

        // One lambda parameter per delegate parameter, in declaration order.
        var paramExprs = Array.ConvertAll(parameters, p => Expression.Parameter(p.ParameterType, p.Name));

        // Async means the delegate returns Task or Task<T>. The result type is null for
        // void and non-generic Task, the T of Task<T>, or the plain return type otherwise.
        var isAsync = typeof(Task).IsAssignableFrom(returnType);
        Type? resultType;
        if (isAsync)
        {
            resultType = returnType.IsGenericType ? returnType.GetGenericArguments()[0] : null;
        }
        else
        {
            resultType = returnType == typeof(void) ? null : returnType;
        }

        // The first CancellationToken parameter (if any) is forwarded separately from the JSON arguments.
        var ctParamIndex = Array.FindIndex(parameters, p => p.ParameterType == typeof(CancellationToken));
        Expression ctExpr;
        if (ctParamIndex >= 0)
        {
            ctExpr = paramExprs[ctParamIndex];
        }
        else
        {
            ctExpr = Expression.Constant(CancellationToken.None, typeof(CancellationToken));
        }

        // No JSON payload is built when there are no parameters or the sole parameter is the CancellationToken.
        var hasNoPayload = parameters.Length == 0 || (parameters.Length == 1 && ctParamIndex == 0);
        Expression argsExpr = hasNoPayload
            ? Expression.Constant(null, typeof(JsonObject))
            : BuildMarshalArgs(paramExprs, parameters);

        // Pick the Invoke* helper matching the delegate's sync/async and void/result shape.
        MethodInfo target;
        if (resultType is null)
        {
            target = isAsync ? s_invokeAsyncVoidMethod : s_invokeSyncVoidMethod;
        }
        else
        {
            var openMethod = isAsync ? s_invokeAsyncResultMethod : s_invokeSyncResultMethod;
            target = openMethod.MakeGenericMethod(resultType);
        }

        // Build: this.InvokeXxx(callbackId, args, cancellationToken)
        var body = Expression.Call(
            Expression.Constant(this),
            target,
            Expression.Constant(callbackId),
            argsExpr,
            ctExpr);

        return Expression.Lambda(delegateType, body, paramExprs).Compile();
    }

    /// <summary>
    /// Builds an expression that packs the delegate's arguments into a new <see cref="JsonObject"/>.
    /// </summary>
    /// <param name="paramExprs">The lambda parameter expressions, parallel to <paramref name="parameters"/>.</param>
    /// <param name="parameters">The delegate's parameter metadata.</param>
    /// <returns>A block expression whose value is the populated <see cref="JsonObject"/>.</returns>
    private Expression BuildMarshalArgs(ParameterExpression[] paramExprs, ParameterInfo[] parameters)
    {
        // Build: new JsonObject { { "p0", MarshalArg(arg1) }, { "p1", MarshalArg(arg2) } }
        // Uses positional keys (p0, p1, p2, ...) instead of parameter names for predictable unpacking on TypeScript side
        var argsVar = Expression.Variable(typeof(JsonObject), "args");
        var self = Expression.Constant(this);

        var statements = new List<Expression>(parameters.Length + 2)
        {
            Expression.Assign(
                argsVar,
                Expression.New(s_jsonObjectCtor, Expression.Constant(null, typeof(JsonNodeOptions?))))
        };

        var position = 0; // Positional index; CancellationToken parameters do not consume a slot.
        for (var i = 0; i < parameters.Length; i++)
        {
            // CancellationToken parameters are not marshalled here (handled separately by the Invoke* helpers).
            if (parameters[i].ParameterType == typeof(CancellationToken))
            {
                continue;
            }

            // Box the argument to object and convert it to a JsonNode via MarshalArg.
            var marshalled = Expression.Call(
                self,
                s_marshalArgMethod,
                Expression.Convert(paramExprs[i], typeof(object)));

            statements.Add(Expression.Call(
                argsVar,
                s_jsonObjectAddMethod,
                Expression.Constant($"p{position}"),
                marshalled));
            position++;
        }

        // The last expression in a block is its value.
        statements.Add(argsVar);
        return Expression.Block(new[] { argsVar }, statements);
    }

    /// <summary>
    /// Converts a single callback argument to JSON using the marshaller.
    /// Called from compiled proxies via reflection-resolved <see cref="MethodInfo"/>.
    /// </summary>
    private JsonNode? MarshalArg(object? value)
    {
        return _marshaller.MarshalToJson(value);
    }

    /// <summary>
    /// Invokes a callback for a synchronous delegate with no return value,
    /// blocking the calling thread until the invoker's task completes.
    /// </summary>
    private void InvokeSyncVoid(string callbackId, JsonObject? args, CancellationToken cancellationToken)
    {
        AddCancellationTokenToArgs(ref args, cancellationToken);
        _invoker.InvokeAsync<JsonNode?>(callbackId, args, cancellationToken).GetAwaiter().GetResult();
    }

    /// <summary>
    /// Invokes a callback for a synchronous delegate that returns <typeparamref name="T"/>,
    /// blocking the calling thread until the invoker's task completes.
    /// </summary>
    private T? InvokeSyncResult<T>(string callbackId, JsonObject? args, CancellationToken cancellationToken)
    {
        AddCancellationTokenToArgs(ref args, cancellationToken);
        var result = _invoker.InvokeAsync<JsonNode?>(callbackId, args, cancellationToken).GetAwaiter().GetResult();
        return UnmarshalResult<T>(result, callbackId);
    }

    /// <summary>
    /// Invokes a callback for a delegate returning <see cref="Task"/>; the JSON result is ignored.
    /// </summary>
    private async Task InvokeAsyncVoid(string callbackId, JsonObject? args, CancellationToken cancellationToken)
    {
        AddCancellationTokenToArgs(ref args, cancellationToken);
        await _invoker.InvokeAsync<JsonNode?>(callbackId, args, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Invokes a callback for a delegate returning <see cref="Task{TResult}"/> and unmarshals the JSON result.
    /// </summary>
    private async Task<T?> InvokeAsyncResult<T>(string callbackId, JsonObject? args, CancellationToken cancellationToken)
    {
        AddCancellationTokenToArgs(ref args, cancellationToken);
        var result = await _invoker.InvokeAsync<JsonNode?>(callbackId, args, cancellationToken).ConfigureAwait(false);
        return UnmarshalResult<T>(result, callbackId);
    }

    /// <summary>
    /// Converts a callback's JSON result to <typeparamref name="T"/>.
    /// </summary>
    /// <param name="result">The JSON returned by the invoker; null yields <c>default</c>.</param>
    /// <param name="callbackId">The callback ID, used to label the unmarshal context.</param>
    private T? UnmarshalResult<T>(JsonNode? result, string callbackId)
    {
        if (result == null)
        {
            return default;
        }

        var context = new AtsMarshaller.UnmarshalContext
        {
            CapabilityId = $"callback:{callbackId}",
            ParameterName = "$result"
        };

        return (T?)_marshaller.UnmarshalFromJson(result, typeof(T), context);
    }

    /// <summary>
    /// Registers a cancellable token with the registry and stores the resulting token ID
    /// under the <c>$cancellationToken</c> key, creating the argument object if needed.
    /// Does nothing when the token equals <see cref="CancellationToken.None"/>.
    /// </summary>
    private void AddCancellationTokenToArgs(ref JsonObject? args, CancellationToken cancellationToken)
    {
        if (cancellationToken != CancellationToken.None)
        {
            // NOTE(review): the second element returned by CreateLinked is discarded here;
            // presumably the registry owns its lifetime - confirm against CancellationTokenRegistry.
            var (tokenId, _) = _cancellationTokenRegistry.CreateLinked(cancellationToken);
            args ??= new JsonObject();
            args["$cancellationToken"] = tokenId;
        }
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No-op - CancellationTokenRegistry is managed by DI
    }
}