|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Concurrent;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
using Aspire.Hosting.Ats;
using Microsoft.Extensions.Logging;
namespace Aspire.Hosting.RemoteHost.Ats;
/// <summary>
/// Delegate for capability implementations.
/// </summary>
/// <param name="args">The arguments as a JSON object.</param>
/// <param name="handles">The handle registry for resolving/registering handles.</param>
/// <returns>The result as JSON, or null for void operations.</returns>
internal delegate Task<JsonNode?> CapabilityHandler(
JsonObject? args,
HandleRegistry handles);
/// <summary>
/// Dispatches capability invocations to their implementations.
/// Scans provided assemblies for [AspireExport] attributes.
/// </summary>
internal sealed class CapabilityDispatcher
{
private readonly ConcurrentDictionary<string, CapabilityRegistration> _capabilities = new();
private readonly HandleRegistry _handles;
private readonly AtsMarshaller _marshaller;
private readonly ILogger _logger;
private Hosting.Ats.AtsContext? _atsContext;
/// <summary>
/// Represents a registered capability.
/// </summary>
private sealed class CapabilityRegistration
{
public required string CapabilityId { get; init; }
public required CapabilityHandler Handler { get; init; }
public string? Description { get; init; }
}
/// <summary>
/// Creates a new CapabilityDispatcher for DI.
/// </summary>
/// <param name="handles">The handle registry for resolving handle references.</param>
/// <param name="assemblyLoader">The assembly loader to get assemblies from.</param>
/// <param name="marshaller">The marshaller for converting objects to/from JSON.</param>
/// <param name="logger">The logger.</param>
public CapabilityDispatcher(
HandleRegistry handles,
AssemblyLoader assemblyLoader,
AtsMarshaller marshaller,
ILogger<CapabilityDispatcher> logger)
{
_handles = handles;
_marshaller = marshaller;
_logger = logger;
// Scan for capabilities on initialization
ScanAssemblies(assemblyLoader.GetAssemblies());
}
/// <summary>
/// Creates a new CapabilityDispatcher for testing purposes.
/// </summary>
/// <param name="handles">The handle registry for resolving handle references.</param>
/// <param name="marshaller">The marshaller for converting objects to/from JSON.</param>
/// <param name="assemblies">The assemblies to scan for capabilities.</param>
internal CapabilityDispatcher(
HandleRegistry handles,
AtsMarshaller marshaller,
IReadOnlyList<Assembly> assemblies)
{
_handles = handles;
_marshaller = marshaller;
_logger = Microsoft.Extensions.Logging.Abstractions.NullLogger<CapabilityDispatcher>.Instance;
ScanAssemblies(assemblies);
}
/// <summary>
/// Scans the provided assemblies for [AspireExport] and [AspireContextType] attributes.
/// Uses the shared AtsCapabilityScanner for discovery.
/// </summary>
private void ScanAssemblies(IEnumerable<Assembly> assemblies)
{
var assemblyList = assemblies.ToList();
_logger.LogDebug("Scanning {AssemblyCount} assemblies for capabilities...", assemblyList.Count);
// Scan all assemblies at once to get combined result with AtsContext
var result = AtsCapabilityScanner.ScanAssemblies(assemblyList);
// Store the AtsContext for capability registration
_atsContext = result.ToAtsContext();
// Log diagnostics from the scanner
foreach (var diagnostic in result.Diagnostics)
{
if (diagnostic.Severity == AtsDiagnosticSeverity.Error)
{
_logger.LogError("{Message} at {Location}", diagnostic.Message, diagnostic.Location);
}
else
{
_logger.LogWarning("{Message} at {Location}", diagnostic.Message, diagnostic.Location);
}
}
// Register all capabilities
foreach (var capability in result.Capabilities)
{
if ((capability.CapabilityKind == AtsCapabilityKind.PropertyGetter || capability.CapabilityKind == AtsCapabilityKind.PropertySetter)
&& result.Properties.TryGetValue(capability.CapabilityId, out var property))
{
// Context type property capability
RegisterContextTypeProperty(capability, property);
}
else if (capability.CapabilityKind == AtsCapabilityKind.InstanceMethod
&& result.Methods.TryGetValue(capability.CapabilityId, out var instanceMethod))
{
// Context type method capability (instance method)
RegisterContextTypeMethod(capability, instanceMethod);
}
else if (result.Methods.TryGetValue(capability.CapabilityId, out var method))
{
// Static method capability
RegisterFromCapability(capability, method);
}
}
// Log summary of all registered capabilities
_logger.LogDebug("Registered {CapabilityCount} capabilities", _capabilities.Count);
foreach (var capabilityId in _capabilities.Keys.OrderBy(k => k))
{
_logger.LogTrace(" - {CapabilityId}", capabilityId);
}
}
/// <summary>
/// Registers a context type property capability.
/// </summary>
private void RegisterContextTypeProperty(AtsCapabilityInfo capability, PropertyInfo property)
{
var capabilityId = capability.CapabilityId;
var prop = property; // Capture for closure
if (capability.CapabilityKind == AtsCapabilityKind.PropertyGetter)
{
// Getter capability
CapabilityHandler getterHandler = (args, handles) =>
{
if (args == null || !args.TryGetPropertyValue("context", out var contextNode))
{
throw CapabilityException.InvalidArgument(capabilityId, "context", "Missing required argument 'context'");
}
var handleRef = HandleRef.FromJsonNode(contextNode);
if (handleRef == null)
{
throw CapabilityException.InvalidArgument(capabilityId, "context", "Argument 'context' must be a handle reference");
}
if (!handles.TryGet(handleRef.HandleId, out var contextObj, out _))
{
throw CapabilityException.HandleNotFound(handleRef.HandleId, capabilityId);
}
var value = prop.GetValue(contextObj);
return Task.FromResult(_marshaller.MarshalToJson(value, capability.ReturnType));
};
_capabilities[capabilityId] = new CapabilityRegistration
{
CapabilityId = capabilityId,
Handler = getterHandler,
Description = capability.Description ?? $"Gets the {property.Name} property"
};
}
else if (capability.CapabilityKind == AtsCapabilityKind.PropertySetter)
{
// Setter capability - returns the context handle for fluent chaining
CapabilityHandler setterHandler = (args, handles) =>
{
if (args == null || !args.TryGetPropertyValue("context", out var contextNode))
{
throw CapabilityException.InvalidArgument(capabilityId, "context", "Missing required argument 'context'");
}
var handleRef = HandleRef.FromJsonNode(contextNode);
if (handleRef == null)
{
throw CapabilityException.InvalidArgument(capabilityId, "context", "Argument 'context' must be a handle reference");
}
if (!handles.TryGet(handleRef.HandleId, out var contextObj, out var typeId))
{
throw CapabilityException.HandleNotFound(handleRef.HandleId, capabilityId);
}
if (!args.TryGetPropertyValue("value", out var valueNode))
{
throw CapabilityException.InvalidArgument(capabilityId, "value", "Missing required argument 'value'");
}
var unmarshalContext = new AtsMarshaller.UnmarshalContext
{
CapabilityId = capabilityId,
ParameterName = "value"
};
var value = _marshaller.UnmarshalFromJson(valueNode, prop.PropertyType, unmarshalContext);
prop.SetValue(contextObj, value);
// Return the context handle for fluent chaining
return Task.FromResult<JsonNode?>(new JsonObject
{
["$handle"] = handleRef.HandleId,
["$type"] = typeId
});
};
_capabilities[capabilityId] = new CapabilityRegistration
{
CapabilityId = capabilityId,
Handler = setterHandler,
Description = capability.Description ?? $"Sets the {property.Name} property"
};
}
}
/// <summary>
/// Registers a context type method capability (instance method).
/// </summary>
private void RegisterContextTypeMethod(AtsCapabilityInfo capability, MethodInfo method)
{
var capabilityId = capability.CapabilityId;
var parameters = method.GetParameters();
CapabilityHandler handler = async (args, handles) =>
{
// First parameter is always "context" - the instance to invoke on
if (args == null || !args.TryGetPropertyValue("context", out var contextNode))
{
throw CapabilityException.InvalidArgument(capabilityId, "context", "Missing required argument 'context'");
}
var handleRef = HandleRef.FromJsonNode(contextNode);
if (handleRef == null)
{
throw CapabilityException.InvalidArgument(capabilityId, "context", "Argument 'context' must be a handle reference");
}
if (!handles.TryGet(handleRef.HandleId, out var contextObj, out _))
{
throw CapabilityException.HandleNotFound(handleRef.HandleId, capabilityId);
}
// Build method arguments from the remaining parameters
var methodArgs = new object?[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
var param = parameters[i];
var paramName = param.Name ?? $"arg{i}";
if (args.TryGetPropertyValue(paramName, out var argNode))
{
var context = new AtsMarshaller.UnmarshalContext
{
CapabilityId = capabilityId,
ParameterName = paramName
};
methodArgs[i] = _marshaller.UnmarshalFromJson(argNode, param.ParameterType, context);
}
else if (param.HasDefaultValue)
{
methodArgs[i] = param.DefaultValue;
}
else
{
throw CapabilityException.InvalidArgument(
capabilityId, paramName, $"Missing required argument '{paramName}'");
}
}
// Handle generic methods - resolve type parameters from actual arguments
var methodToInvoke = method;
if (method.ContainsGenericParameters)
{
methodToInvoke = GenericMethodResolver.MakeGenericMethodFromArgs(method, methodArgs);
}
object? result;
try
{
// Invoke instance method on the context object
result = methodToInvoke.Invoke(contextObj, methodArgs);
}
catch (TargetInvocationException tie) when (tie.InnerException is not null)
{
throw tie.InnerException;
}
// Handle async methods - await instead of blocking
if (result is Task task)
{
try
{
await task.ConfigureAwait(false);
}
catch (Exception ex)
{
throw new InvalidOperationException(ex.Message, ex);
}
var taskType = task.GetType();
if (taskType.IsGenericType)
{
var resultProperty = taskType.GetProperty("Result");
result = resultProperty?.GetValue(task);
}
else
{
result = null;
}
}
return _marshaller.MarshalToJson(result, capability.ReturnType);
};
_capabilities[capabilityId] = new CapabilityRegistration
{
CapabilityId = capabilityId,
Handler = handler,
Description = capability.Description ?? $"Invokes the {method.Name} method"
};
}
/// <summary>
/// Registers a capability from its info and method.
/// Uses metadata from the shared scanner, creates runtime handler for invocation.
/// </summary>
private void RegisterFromCapability(AtsCapabilityInfo capability, MethodInfo method)
{
var capabilityId = capability.CapabilityId;
var parameters = method.GetParameters();
// Create a handler that invokes the method via reflection
CapabilityHandler handler = async (args, handles) =>
{
var methodArgs = new object?[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
var param = parameters[i];
var paramName = param.Name ?? $"arg{i}";
if (args != null && args.TryGetPropertyValue(paramName, out var argNode))
{
var context = new AtsMarshaller.UnmarshalContext
{
CapabilityId = capabilityId,
ParameterName = paramName
};
methodArgs[i] = _marshaller.UnmarshalFromJson(argNode, param.ParameterType, context);
}
else if (param.HasDefaultValue)
{
methodArgs[i] = param.DefaultValue;
}
else
{
throw CapabilityException.InvalidArgument(
capabilityId, paramName, $"Missing required argument '{paramName}'");
}
}
// Handle generic methods - resolve type parameters from actual arguments
var methodToInvoke = method;
if (method.ContainsGenericParameters)
{
methodToInvoke = GenericMethodResolver.MakeGenericMethodFromArgs(method, methodArgs);
}
object? result;
try
{
result = methodToInvoke.Invoke(null, methodArgs);
}
catch (TargetInvocationException tie) when (tie.InnerException is not null)
{
// Unwrap the TargetInvocationException to get the actual exception
throw tie.InnerException;
}
// Handle async methods - await instead of blocking
if (result is Task task)
{
try
{
await task.ConfigureAwait(false);
}
catch (Exception ex)
{
// Rethrow the exception - it will be caught by the outer handler
// and converted to a CapabilityException
throw new InvalidOperationException(ex.Message, ex);
}
var taskType = task.GetType();
if (taskType.IsGenericType)
{
var resultProperty = taskType.GetProperty("Result");
result = resultProperty?.GetValue(task);
}
else
{
result = null;
}
}
return _marshaller.MarshalToJson(result, capability.ReturnType);
};
_capabilities[capabilityId] = new CapabilityRegistration
{
CapabilityId = capabilityId,
Handler = handler,
Description = capability.Description
};
}
/// <summary>
/// Registers a capability with its handler.
/// </summary>
/// <param name="capabilityId">The capability ID (e.g., "Aspire.Hosting.Redis/addRedis").</param>
/// <param name="handler">The handler that implements the capability.</param>
/// <param name="description">Optional description of the capability.</param>
public void Register(
string capabilityId,
CapabilityHandler handler,
string? description = null)
{
_capabilities[capabilityId] = new CapabilityRegistration
{
CapabilityId = capabilityId,
Handler = handler,
Description = description
};
}
/// <summary>
/// Invokes a capability by ID with the given arguments.
/// Type validation is performed by the CLR at runtime.
/// </summary>
/// <param name="capabilityId">The capability ID.</param>
/// <param name="args">The arguments as a JSON object.</param>
/// <returns>The result as JSON, or null for void methods.</returns>
public async Task<JsonNode?> InvokeAsync(string capabilityId, JsonObject? args)
{
// Look up the capability
if (!_capabilities.TryGetValue(capabilityId, out var registration))
{
throw CapabilityException.CapabilityNotFound(capabilityId);
}
args ??= new JsonObject();
try
{
return await registration.Handler(args, _handles).ConfigureAwait(false);
}
catch (CapabilityException)
{
throw;
}
catch (ArgumentException ex) when (IsTypeMismatchException(ex))
{
// Convert CLR type mismatch to ATS error
throw CapabilityException.TypeMismatch(capabilityId, "argument", "expected type", ex.Message);
}
catch (InvalidCastException ex)
{
// Convert CLR cast failures to ATS error
throw CapabilityException.TypeMismatch(capabilityId, "argument", "expected type", ex.Message);
}
catch (Exception ex)
{
throw CapabilityException.InternalError(capabilityId, ex.Message, ex);
}
}
/// <summary>
/// Invokes a capability by ID with the given arguments synchronously.
/// This is a convenience method that blocks until the async operation completes.
/// For production use, prefer InvokeAsync.
/// </summary>
/// <param name="capabilityId">The capability ID.</param>
/// <param name="args">The arguments as a JSON object.</param>
/// <returns>The result as JSON, or null for void methods.</returns>
public JsonNode? Invoke(string capabilityId, JsonObject? args)
{
return InvokeAsync(capabilityId, args).GetAwaiter().GetResult();
}
/// <summary>
/// Checks if an exception indicates a type mismatch.
/// </summary>
private static bool IsTypeMismatchException(ArgumentException ex)
{
// Check for common type mismatch patterns in exception messages
var message = ex.Message;
return message.Contains("cannot be converted") ||
message.Contains("is not assignable") ||
message.Contains("type mismatch", StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// Gets all registered capability IDs.
/// </summary>
public IEnumerable<string> GetCapabilityIds() => _capabilities.Keys;
/// <summary>
/// Checks if a capability is registered.
/// </summary>
public bool HasCapability(string capabilityId) => _capabilities.ContainsKey(capabilityId);
}
/// <summary>
/// Extension methods for working with JSON in capability handlers.
/// </summary>
internal static class CapabilityJsonExtensions
{
private static readonly JsonSerializerOptions s_jsonOptions = new()
{
PropertyNameCaseInsensitive = true,
PropertyNamingPolicy = JsonNamingPolicy.CamelCase
};
/// <summary>
/// Gets a required string argument.
/// </summary>
public static string GetRequiredString(this JsonObject args, string name, string capabilityId)
{
if (!args.TryGetPropertyValue(name, out var node) || node is not JsonValue value)
{
throw CapabilityException.InvalidArgument(capabilityId, name, $"Missing required argument '{name}'");
}
return value.GetValue<string>() ??
throw CapabilityException.InvalidArgument(capabilityId, name, $"Argument '{name}' cannot be null");
}
/// <summary>
/// Gets an optional string argument.
/// </summary>
public static string? GetOptionalString(this JsonObject args, string name)
{
if (args.TryGetPropertyValue(name, out var node) && node is JsonValue value)
{
return value.GetValue<string>();
}
return null;
}
/// <summary>
/// Gets an optional int argument.
/// </summary>
public static int? GetOptionalInt(this JsonObject args, string name)
{
if (args.TryGetPropertyValue(name, out var node) && node is JsonValue value)
{
return value.GetValue<int>();
}
return null;
}
/// <summary>
/// Gets a required handle reference.
/// </summary>
public static T GetRequiredHandle<T>(
this JsonObject args,
string name,
string capabilityId,
HandleRegistry handles) where T : class
{
if (!args.TryGetPropertyValue(name, out var node))
{
throw CapabilityException.InvalidArgument(capabilityId, name, $"Missing required argument '{name}'");
}
var handleRef = HandleRef.FromJsonNode(node) ??
throw CapabilityException.InvalidArgument(capabilityId, name, $"Argument '{name}' must be a handle reference");
if (!handles.TryGet(handleRef.HandleId, out var obj, out _))
{
throw CapabilityException.HandleNotFound(handleRef.HandleId, capabilityId);
}
if (obj is not T typed)
{
throw CapabilityException.TypeMismatch(
capabilityId, name, typeof(T).Name, obj?.GetType().Name ?? "null");
}
return typed;
}
/// <summary>
/// Deserializes a DTO from a JSON argument.
/// </summary>
public static T? GetDto<T>(this JsonObject args, string name) where T : class
{
if (args.TryGetPropertyValue(name, out var node) && node is JsonObject obj)
{
return JsonSerializer.Deserialize<T>(obj.ToJsonString(), s_jsonOptions);
}
return null;
}
/// <summary>
/// Creates a handle result for returning from a capability.
/// </summary>
public static JsonObject CreateHandleResult(this HandleRegistry handles, object obj, string typeId)
{
return handles.Marshal(obj, typeId);
}
}
|