File: AtsRustCodeGenerator.cs
Web Access
Project: src\src\Aspire.Hosting.CodeGeneration.Rust\Aspire.Hosting.CodeGeneration.Rust.csproj (Aspire.Hosting.CodeGeneration.Rust)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Globalization;
using System.Reflection;
using System.Text;
using System.Text.Json;
using Aspire.Hosting.ApplicationModel;
using Aspire.Hosting.Ats;
 
namespace Aspire.Hosting.CodeGeneration.Rust;
 
/// <summary>
/// Generates a Rust SDK using the ATS (Aspire Type System) capability-based API.
/// Produces wrapper structs that proxy capabilities via JSON-RPC.
/// </summary>
public sealed class AtsRustCodeGenerator : ICodeGenerator
{
    // Words that must not be emitted as plain Rust identifiers; compared ordinally (case-sensitive).
    // NOTE(review): no consumer is visible in this part of the file - presumably the identifier
    // sanitizing helpers further down use it; confirm.
    private static readonly HashSet<string> s_rustKeywords = new(StringComparer.Ordinal)
    {
        "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else",
        "enum", "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop",
        "match", "mod", "move", "mut", "pub", "ref", "return", "self", "Self",
        "static", "struct", "super", "trait", "true", "type", "unsafe", "use",
        "where", "while", "abstract", "become", "box", "do", "final", "macro",
        "override", "priv", "try", "typeof", "unsized", "virtual", "yield"
    };

    // Sink for the generated Rust text; assigned at the start of GenerateAspireSdk.
    private TextWriter _writer = null!;
    // Handle type id -> Rust wrapper struct name; cleared and refilled by BuildHandleTypes.
    private readonly Dictionary<string, string> _structNames = new(StringComparer.Ordinal);
    // DTO type id -> Rust struct name; cleared and refilled by GenerateAspireSdk.
    private readonly Dictionary<string, string> _dtoNames = new(StringComparer.Ordinal);
    // Enum type id -> Rust enum name; cleared and refilled by GenerateAspireSdk.
    private readonly Dictionary<string, string> _enumNames = new(StringComparer.Ordinal);
 
    /// <inheritdoc />
    public string Language
    {
        // Constant identifier of the target language of this generator.
        get { return "Rust"; }
    }
 
    /// <inheritdoc />
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is <see langword="null"/>.</exception>
    public Dictionary<string, string> GenerateDistributedApplication(AtsContext context)
    {
        // Fail at the public boundary with a named argument instead of a
        // NullReferenceException from inside the generation pipeline.
        ArgumentNullException.ThrowIfNull(context);

        // Keys are the file names of the generated module: a fixed mod.rs, two
        // embedded runtime files, and aspire.rs rendered from the supplied context.
        return new Dictionary<string, string>(StringComparer.Ordinal)
        {
            ["mod.rs"] = """
                //! Aspire Rust SDK
                //! GENERATED CODE - DO NOT EDIT

                pub mod transport;
                pub mod base;
                pub mod aspire;

                pub use transport::*;
                pub use base::*;
                pub use aspire::*;
                """,
            ["transport.rs"] = GetEmbeddedResource("transport.rs"),
            ["base.rs"] = GetEmbeddedResource("base.rs"),
            ["aspire.rs"] = GenerateAspireSdk(context)
        };
    }
 
    /// <summary>
    /// Reads an embedded resource of the executing assembly as text.
    /// </summary>
    /// <param name="name">Resource file name; it is appended to the
    /// <c>Aspire.Hosting.CodeGeneration.Rust.Resources.</c> prefix to form the manifest name.</param>
    /// <returns>The full text content of the resource.</returns>
    /// <exception cref="InvalidOperationException">No resource with the resulting manifest name exists.</exception>
    private static string GetEmbeddedResource(string name)
    {
        var manifestName = "Aspire.Hosting.CodeGeneration.Rust.Resources." + name;
        var resourceStream = Assembly.GetExecutingAssembly().GetManifestResourceStream(manifestName);
        if (resourceStream is null)
        {
            throw new InvalidOperationException($"Embedded resource '{name}' not found.");
        }

        using (resourceStream)
        using (var textReader = new StreamReader(resourceStream))
        {
            return textReader.ReadToEnd();
        }
    }
 
    /// <summary>
    /// Renders the text of <c>aspire.rs</c> for the given context.
    /// </summary>
    /// <param name="context">Source of the capabilities, DTO types, enum types and handle types.</param>
    /// <returns>Everything written to the internal writer during this call.</returns>
    private string GenerateAspireSdk(AtsContext context)
    {
        using var buffer = new StringWriter(CultureInfo.InvariantCulture);
        _writer = buffer;

        var allCapabilities = context.Capabilities;
        var allDtos = context.DtoTypes;
        var allEnums = context.EnumTypes;

        // The name tables are instance fields, so start from a clean slate on every run.
        _enumNames.Clear();
        foreach (var enumInfo in allEnums)
        {
            _enumNames[enumInfo.TypeId] = SanitizeIdentifier(enumInfo.Name);
        }

        _dtoNames.Clear();
        foreach (var dtoInfo in allDtos)
        {
            _dtoNames[dtoInfo.TypeId] = SanitizeIdentifier(dtoInfo.Name);
        }

        // Derived lookups consumed by the emit steps below.
        var wrappers = BuildHandleTypes(context);
        var methodsByTarget = GroupCapabilitiesByTarget(allCapabilities);
        var collectionTypeIds = CollectListAndDictTypeIds(allCapabilities);

        // Emit order defines the layout of the generated file.
        WriteHeader();
        GenerateEnumTypes(allEnums);
        GenerateDtoTypes(allDtos);
        GenerateHandleTypes(wrappers, methodsByTarget);
        GenerateHandleWrapperRegistrations(wrappers, collectionTypeIds);
        GenerateConnectionHelpers();

        return buffer.ToString();
    }
 
    /// <summary>
    /// Writes the module doc comment and the <c>use</c> declarations at the top of <c>aspire.rs</c>.
    /// </summary>
    private void WriteHeader()
    {
        // An empty entry stands for a blank separator line.
        string[] headerLines =
        [
            "//! aspire.rs - Capability-based Aspire SDK",
            "//! GENERATED CODE - DO NOT EDIT",
            "",
            "use std::collections::HashMap;",
            "use std::sync::Arc;",
            "",
            "use serde::{Deserialize, Serialize};",
            "use serde_json::{json, Value};",
            "",
            "use crate::transport::{",
            "    AspireClient, CancellationToken, Handle,",
            "    register_callback, register_cancellation, serialize_value,",
            "};",
            "use crate::base::{",
            "    HandleWrapperBase, ResourceBuilderBase, ReferenceExpression,",
            "    AspireList, AspireDict, serialize_handle, HasHandle,",
            "};",
            "",
        ];

        foreach (var headerLine in headerLines)
        {
            if (headerLine.Length == 0)
            {
                WriteLine();
            }
            else
            {
                WriteLine(headerLine);
            }
        }
    }
 
    /// <summary>
    /// Writes a Rust enum and a matching <c>Display</c> implementation for every enum type that has a CLR type.
    /// </summary>
    /// <param name="enumTypes">Enum types to emit; nothing is written when the list is empty.</param>
    private void GenerateEnumTypes(IReadOnlyList<AtsEnumTypeInfo> enumTypes)
    {
        if (enumTypes.Count == 0)
        {
            return;
        }

        WriteLine("// ============================================================================");
        WriteLine("// Enums");
        WriteLine("// ============================================================================");
        WriteLine();

        foreach (var enumInfo in enumTypes)
        {
            // Member names are read from the CLR enum, so entries without one cannot be emitted.
            if (enumInfo.ClrType is not { } clrEnum)
            {
                continue;
            }

            var rustEnumName = _enumNames[enumInfo.TypeId];

            // Pair each CLR member name (used on the wire) with its Rust variant name once,
            // so the enum body and the Display impl share the same list.
            var variants = Enum.GetNames(clrEnum)
                .Select(clrName => (ClrName: clrName, RustName: ToPascalCase(clrName)))
                .ToArray();

            WriteLine($"/// {enumInfo.Name}");
            WriteLine("#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]");
            WriteLine($"pub enum {rustEnumName} {{");
            for (var index = 0; index < variants.Length; index++)
            {
                // The first variant is marked as the Default value.
                if (index == 0)
                {
                    WriteLine("    #[default]");
                }

                WriteLine($"    #[serde(rename = \"{variants[index].ClrName}\")]");
                WriteLine($"    {variants[index].RustName},");
            }
            WriteLine("}");
            WriteLine();

            // Display writes the original CLR member name.
            WriteLine($"impl std::fmt::Display for {rustEnumName} {{");
            WriteLine("    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {");
            WriteLine("        match self {");
            foreach (var (clrName, rustName) in variants)
            {
                WriteLine($"            Self::{rustName} => write!(f, \"{clrName}\"),");
            }
            WriteLine("        }");
            WriteLine("    }");
            WriteLine("}");
            WriteLine();
        }
    }
 
    /// <summary>
    /// Writes a serde-enabled Rust struct plus a <c>to_map</c> method for every DTO type.
    /// </summary>
    /// <param name="dtoTypes">DTO types to emit; nothing is written when the list is empty.</param>
    private void GenerateDtoTypes(IReadOnlyList<AtsDtoTypeInfo> dtoTypes)
    {
        if (dtoTypes.Count == 0)
        {
            return;
        }

        WriteLine("// ============================================================================");
        WriteLine("// DTOs");
        WriteLine("// ============================================================================");
        WriteLine();

        foreach (var dto in dtoTypes)
        {
            // ReferenceExpression is not emitted here - it's defined in base.rs.
            if (dto.TypeId == AtsConstants.ReferenceExpressionTypeId)
            {
                continue;
            }

            var rustStructName = _dtoNames[dto.TypeId];

            // Resolve each Rust field name once; the struct body and to_map both use it.
            var fields = dto.Properties
                .Select(p => (Source: p, RustName: ToSnakeCase(p.Name)))
                .ToList();

            WriteLine($"/// {dto.Name}");
            WriteLine("#[derive(Debug, Clone, Default, Serialize, Deserialize)]");
            WriteLine($"pub struct {rustStructName} {{");
            foreach (var (property, fieldName) in fields)
            {
                var fieldType = MapTypeRefToRustForDto(property.Type, property.IsOptional);

                // The serialized key keeps the original property name; optional fields are
                // additionally skipped on serialization when they are None.
                var serdeArguments = property.IsOptional
                    ? $"rename = \"{property.Name}\", skip_serializing_if = \"Option::is_none\""
                    : $"rename = \"{property.Name}\"";
                WriteLine($"    #[serde({serdeArguments})]");
                WriteLine($"    pub {fieldName}: {fieldType},");
            }
            WriteLine("}");
            WriteLine();

            // to_map: one map entry per property, keyed by the original property name.
            WriteLine($"impl {rustStructName} {{");
            WriteLine("    pub fn to_map(&self) -> HashMap<String, Value> {");
            WriteLine("        let mut map = HashMap::new();");
            foreach (var (property, fieldName) in fields)
            {
                if (!property.IsOptional)
                {
                    WriteLine($"        map.insert(\"{property.Name}\".to_string(), serde_json::to_value(&self.{fieldName}).unwrap_or(Value::Null));");
                    continue;
                }

                WriteLine($"        if let Some(ref v) = self.{fieldName} {{");
                WriteLine($"            map.insert(\"{property.Name}\".to_string(), serde_json::to_value(v).unwrap_or(Value::Null));");
                WriteLine("        }");
            }
            WriteLine("        map");
            WriteLine("    }");
            WriteLine("}");
            WriteLine();
        }
    }
 
    /// <summary>
    /// Writes one Rust wrapper struct per handle type, ordered by struct name, including its
    /// <c>HasHandle</c> implementation, constructor, accessors and capability methods.
    /// </summary>
    /// <param name="handleTypes">Handle types to emit; nothing is written when the list is empty.</param>
    /// <param name="capabilitiesByTarget">Capabilities keyed by target type id; each entry becomes methods on the matching wrapper.</param>
    private void GenerateHandleTypes(
        IReadOnlyList<RustHandleType> handleTypes,
        Dictionary<string, List<AtsCapabilityInfo>> capabilitiesByTarget)
    {
        if (handleTypes.Count == 0)
        {
            return;
        }

        // Writes a run of fixed lines in order.
        void EmitAll(params string[] lines)
        {
            foreach (var line in lines)
            {
                WriteLine(line);
            }
        }

        EmitAll(
            "// ============================================================================",
            "// Handle Wrappers",
            "// ============================================================================");
        WriteLine();

        var orderedWrappers = handleTypes.OrderBy(t => t.StructName, StringComparer.Ordinal);
        foreach (var wrapper in orderedWrappers)
        {
            var rustStruct = wrapper.StructName;

            // Struct definition.
            EmitAll(
                $"/// Wrapper for {wrapper.TypeId}",
                $"pub struct {rustStruct} {{",
                "    handle: Handle,",
                "    client: Arc<AspireClient>,",
                "}");
            WriteLine();

            // HasHandle trait implementation.
            EmitAll(
                $"impl HasHandle for {rustStruct} {{",
                "    fn handle(&self) -> &Handle {",
                "        &self.handle",
                "    }",
                "}");
            WriteLine();

            // Inherent impl: constructor and accessors, followed by the capability methods.
            EmitAll(
                $"impl {rustStruct} {{",
                "    pub fn new(handle: Handle, client: Arc<AspireClient>) -> Self {",
                "        Self { handle, client }",
                "    }");
            WriteLine();
            EmitAll(
                "    pub fn handle(&self) -> &Handle {",
                "        &self.handle",
                "    }");
            WriteLine();
            EmitAll(
                "    pub fn client(&self) -> &Arc<AspireClient> {",
                "        &self.client",
                "    }");

            if (capabilitiesByTarget.TryGetValue(wrapper.TypeId, out var wrapperCapabilities))
            {
                foreach (var wrapperCapability in wrapperCapabilities)
                {
                    GenerateCapabilityMethod(rustStruct, wrapperCapability);
                }
            }

            WriteLine("}");
            WriteLine();
        }
    }
 
    /// <summary>
    /// Writes one Rust method that forwards a capability call to <c>invoke_capability</c>.
    /// </summary>
    /// <remarks>
    /// A capability that has no parameters besides its target and returns a list or dictionary is
    /// delegated to <see cref="GenerateListOrDictProperty"/>. For every other capability the generated
    /// method collects its arguments into a map, invokes the capability and converts the result
    /// according to the return type.
    /// </remarks>
    /// <param name="structName">Name of the wrapper struct being generated; only forwarded to
    /// <see cref="GenerateListOrDictProperty"/>.</param>
    /// <param name="capability">Capability to emit a method for.</param>
    private void GenerateCapabilityMethod(string structName, AtsCapabilityInfo capability)
    {
        var targetParamName = capability.TargetParameterName ?? "builder";
        var methodName = ToSnakeCase(capability.MethodName);

        // The target parameter is supplied from self.handle, so it is not a Rust parameter.
        var parameters = capability.Parameters
            .Where(p => !string.Equals(p.Name, targetParamName, StringComparison.Ordinal))
            .ToList();

        // Check if this is a List/Dict property getter (no parameters, returns List/Dict)
        if (parameters.Count == 0 && IsListOrDictPropertyGetter(capability.ReturnType))
        {
            GenerateListOrDictProperty(structName, capability, methodName);
            return;
        }

        var returnType = MapTypeRefToRust(capability.ReturnType, false);
        var hasReturn = capability.ReturnType.TypeId != AtsConstants.Void;

        // Build parameter list
        var paramList = new StringBuilder();
        paramList.Append("&self");
        foreach (var parameter in parameters)
        {
            var paramName = ToSnakeCase(parameter.Name);
            string paramType;
            if (parameter.IsCallback)
            {
                paramType = "impl Fn(Vec<Value>) -> Value + Send + Sync + 'static";
            }
            else if (IsCancellationToken(parameter))
            {
                paramType = "Option<&CancellationToken>";
            }
            else if (IsHandleType(parameter.Type))
            {
                // Handle wrappers are passed by reference
                var handleTypeName = MapTypeRefToRust(parameter.Type, false);
                paramType = parameter.IsOptional ? $"Option<&{handleTypeName}>" : $"&{handleTypeName}";
            }
            else
            {
                // Use idiomatic Rust parameter types (e.g., &str instead of String)
                paramType = MapParameterTypeToRust(parameter.Type, parameter.IsOptional);
            }
            paramList.Append(CultureInfo.InvariantCulture, $", {paramName}: {paramType}");
        }

        // Generate doc comment. A Rust `///` comment ends at the line break, so every line of a
        // multi-line description needs its own marker; otherwise the continuation lines would be
        // emitted as code. A single-line description produces the same output as before.
        if (!string.IsNullOrEmpty(capability.Description))
        {
            WriteLine();
            foreach (var descriptionLine in capability.Description.ReplaceLineEndings("\n").Split('\n'))
            {
                WriteLine(descriptionLine.Length == 0 ? "    ///" : $"    /// {descriptionLine}");
            }
        }

        var resultType = hasReturn ? $"Result<{returnType}, Box<dyn std::error::Error>>" : "Result<(), Box<dyn std::error::Error>>";
        WriteLine($"    pub fn {methodName}({paramList}) -> {resultType} {{");
        WriteLine("        let mut args: HashMap<String, Value> = HashMap::new();");
        WriteLine($"        args.insert(\"{targetParamName}\".to_string(), self.handle.to_json());");

        // Argument keys use the original (non snake-cased) parameter names.
        foreach (var parameter in parameters)
        {
            var paramName = ToSnakeCase(parameter.Name);
            if (parameter.IsCallback)
            {
                // Callbacks are registered and passed by id.
                WriteLine($"        let callback_id = register_callback({paramName});");
                WriteLine($"        args.insert(\"{parameter.Name}\".to_string(), Value::String(callback_id));");
                continue;
            }

            if (IsCancellationToken(parameter))
            {
                // Cancellation tokens are optional; when present they are registered and passed by id.
                WriteLine($"        if let Some(token) = {paramName} {{");
                WriteLine("            let token_id = register_cancellation(token, self.client.clone());");
                WriteLine($"            args.insert(\"{parameter.Name}\".to_string(), Value::String(token_id));");
                WriteLine("        }");
                continue;
            }

            // Handle wrappers need to be converted to their handle JSON
            if (IsHandleType(parameter.Type))
            {
                if (parameter.IsOptional)
                {
                    WriteLine($"        if let Some(ref v) = {paramName} {{");
                    WriteLine($"            args.insert(\"{parameter.Name}\".to_string(), v.handle().to_json());");
                    WriteLine("        }");
                }
                else
                {
                    WriteLine($"        args.insert(\"{parameter.Name}\".to_string(), {paramName}.handle().to_json());");
                }
                continue;
            }

            // Everything else goes through serde_json; optional values are omitted when None.
            if (parameter.IsOptional)
            {
                WriteLine($"        if let Some(ref v) = {paramName} {{");
                WriteLine($"            args.insert(\"{parameter.Name}\".to_string(), serde_json::to_value(v).unwrap_or(Value::Null));");
                WriteLine("        }");
            }
            else
            {
                WriteLine($"        args.insert(\"{parameter.Name}\".to_string(), serde_json::to_value(&{paramName}).unwrap_or(Value::Null));");
            }
        }

        WriteLine($"        let result = self.client.invoke_capability(\"{capability.CapabilityId}\", args)?;");

        if (hasReturn)
        {
            var returnTypeRef = capability.ReturnType;

            // Generate conversion based on return type
            if (IsHandleType(returnTypeRef))
            {
                var wrappedType = MapHandleType(returnTypeRef.TypeId);
                WriteLine("        let handle: Handle = serde_json::from_value(result)?;");
                WriteLine($"        Ok({wrappedType}::new(handle, self.client.clone()))");
            }
            else if (returnTypeRef?.TypeId == AtsConstants.CancellationToken)
            {
                // CancellationToken needs special handling - create from handle
                WriteLine("        let handle: Handle = serde_json::from_value(result)?;");
                WriteLine("        Ok(CancellationToken::new(handle, self.client.clone()))");
            }
            else if (returnTypeRef?.Category == AtsTypeCategory.Dict && returnTypeRef.IsReadOnly == false)
            {
                // Handle-backed AspireDict
                WriteLine("        let handle: Handle = serde_json::from_value(result)?;");
                WriteLine("        Ok(AspireDict::new(handle, self.client.clone()))");
            }
            else if (returnTypeRef?.Category == AtsTypeCategory.List && returnTypeRef.IsReadOnly == false)
            {
                // Handle-backed AspireList
                WriteLine("        let handle: Handle = serde_json::from_value(result)?;");
                WriteLine("        Ok(AspireList::new(handle, self.client.clone()))");
            }
            else
            {
                // Plain values are deserialized directly into the declared return type.
                WriteLine("        Ok(serde_json::from_value(result)?)");
            }
        }
        else
        {
            WriteLine("        Ok(())");
        }

        WriteLine("    }");
    }
 
    /// <summary>
    /// Tells whether a return type is a list or a dictionary.
    /// </summary>
    /// <param name="returnType">Return type to inspect; may be <see langword="null"/>.</param>
    /// <returns><see langword="true"/> when the category is List or Dict; <see langword="false"/>
    /// otherwise, including for <see langword="null"/>.</returns>
    private static bool IsListOrDictPropertyGetter(AtsTypeRef? returnType) =>
        returnType?.Category is AtsTypeCategory.List or AtsTypeCategory.Dict;
 
#pragma warning disable IDE0060 // Remove unused parameter - structName kept for API consistency
    /// <summary>
    /// Writes a Rust getter that returns an <c>AspireList</c> or <c>AspireDict</c> built with
    /// <c>with_getter</c> from the wrapper's handle, its client and the capability id.
    /// </summary>
    /// <param name="structName">Unused; kept for API consistency.</param>
    /// <param name="capability">Capability whose return type is a list or dictionary.</param>
    /// <param name="methodName">Rust name of the generated method.</param>
    private void GenerateListOrDictProperty(string structName, AtsCapabilityInfo capability, string methodName)
#pragma warning restore IDE0060
    {
        var returnType = capability.ReturnType!;
        var isDict = returnType.Category == AtsTypeCategory.Dict;
        var wrapperType = isDict ? "AspireDict" : "AspireList";

        // Determine type arguments
        string typeArgs;
        if (isDict)
        {
            var keyType = MapTypeRefToRust(returnType.KeyType, false);
            var valueType = MapTypeRefToRust(returnType.ValueType, false);
            typeArgs = $"<{keyType}, {valueType}>";
        }
        else
        {
            var elementType = MapTypeRefToRust(returnType.ElementType, false);
            typeArgs = $"<{elementType}>";
        }

        var fullType = $"{wrapperType}{typeArgs}";

        // Generate doc comment. A Rust `///` comment ends at the line break, so every line of a
        // multi-line description needs its own marker; otherwise the continuation lines would be
        // emitted as code. A single-line description produces the same output as before.
        if (!string.IsNullOrEmpty(capability.Description))
        {
            WriteLine();
            foreach (var descriptionLine in capability.Description.ReplaceLineEndings("\n").Split('\n'))
            {
                WriteLine(descriptionLine.Length == 0 ? "    ///" : $"    /// {descriptionLine}");
            }
        }

        // Generate getter method that creates AspireList/AspireDict with lazy getter
        WriteLine($"    pub fn {methodName}(&self) -> {fullType} {{");
        WriteLine($"        {wrapperType}::with_getter(self.handle.clone(), self.client.clone(), \"{capability.CapabilityId}\")");
        WriteLine("    }");
    }
 
#pragma warning disable IDE0060 // Remove unused parameter - keeping for API consistency with other generators
    /// <summary>
    /// Writes the <c>register_all_wrappers</c> function, whose generated body contains only comments.
    /// </summary>
    /// <param name="handleTypes">Unused; kept for API consistency with other generators.</param>
    /// <param name="listTypeIds">Unused; kept for API consistency with other generators.</param>
    private void GenerateHandleWrapperRegistrations(
        IReadOnlyList<RustHandleType> handleTypes,
        HashSet<string> listTypeIds)
#pragma warning restore IDE0060
    {
        // An empty entry stands for a blank separator line.
        string[] sectionLines =
        [
            "// ============================================================================",
            "// Handle wrapper registrations",
            "// ============================================================================",
            "",
            "pub fn register_all_wrappers() {",
            "    // Handle wrappers are created inline in generated code",
            "    // This function is provided for API compatibility",
            "}",
            "",
        ];

        foreach (var sectionLine in sectionLines)
        {
            if (sectionLine.Length == 0)
            {
                WriteLine();
            }
            else
            {
                WriteLine(sectionLine);
            }
        }
    }
 
    /// <summary>
    /// Writes the free functions <c>connect</c> and <c>create_builder</c> at the end of the file.
    /// </summary>
    /// <remarks>
    /// <c>create_builder</c> returns the wrapper struct registered for the builder type id, or
    /// <c>DistributedApplicationBuilder</c> when no such wrapper is registered.
    /// </remarks>
    private void GenerateConnectionHelpers()
    {
        if (!_structNames.TryGetValue(AtsConstants.BuilderTypeId, out var builderStructName))
        {
            builderStructName = "DistributedApplicationBuilder";
        }

        // An empty entry stands for a blank separator line.
        string[] sectionLines =
        [
            "// ============================================================================",
            "// Connection Helpers",
            "// ============================================================================",
            "",
            "/// Establishes a connection to the AppHost server.",
            "pub fn connect() -> Result<Arc<AspireClient>, Box<dyn std::error::Error>> {",
            "    let socket_path = std::env::var(\"REMOTE_APP_HOST_SOCKET_PATH\")",
            "        .map_err(|_| \"REMOTE_APP_HOST_SOCKET_PATH environment variable not set. Run this application using `aspire run`\")?;",
            "    let client = Arc::new(AspireClient::new(&socket_path));",
            "    client.connect()?;",
            "    Ok(client)",
            "}",
            "",
            "/// Creates a new distributed application builder.",
            $"pub fn create_builder(options: Option<CreateBuilderOptions>) -> Result<{builderStructName}, Box<dyn std::error::Error>> {{",
            "    let client = connect()?;",
            "    let mut resolved_options: HashMap<String, Value> = HashMap::new();",
            "    if let Some(opts) = options {",
            "        for (k, v) in opts.to_map() {",
            "            resolved_options.insert(k, v);",
            "        }",
            "    }",
            "    if !resolved_options.contains_key(\"Args\") {",
            "        let args: Vec<String> = std::env::args().skip(1).collect();",
            "        resolved_options.insert(\"Args\".to_string(), serde_json::to_value(args).unwrap_or(Value::Null));",
            "    }",
            "    if !resolved_options.contains_key(\"ProjectDirectory\") {",
            "        if let Ok(pwd) = std::env::current_dir() {",
            "            resolved_options.insert(\"ProjectDirectory\".to_string(), Value::String(pwd.to_string_lossy().to_string()));",
            "        }",
            "    }",
            "    let mut args: HashMap<String, Value> = HashMap::new();",
            "    args.insert(\"options\".to_string(), serde_json::to_value(resolved_options).unwrap_or(Value::Null));",
            "    let result = client.invoke_capability(\"Aspire.Hosting/createBuilderWithOptions\", args)?;",
            "    let handle: Handle = serde_json::from_value(result)?;",
            $"    Ok({builderStructName}::new(handle, client))",
            "}",
            "",
        ];

        foreach (var sectionLine in sectionLines)
        {
            if (sectionLine.Length == 0)
            {
                WriteLine();
            }
            else
            {
                WriteLine(sectionLine);
            }
        }
    }
 
    /// <summary>
    /// Collects every handle type id to generate a wrapper for and assigns each a struct name.
    /// </summary>
    /// <remarks>
    /// Ids come from the context's declared handle types (ReferenceExpression excluded) and from the
    /// target, return, parameter and callback-parameter types of every capability.
    /// <c>_structNames</c> is cleared and refilled as a side effect.
    /// </remarks>
    /// <param name="context">Source of the declared handle types and capabilities.</param>
    /// <returns>One entry per collected type id with its struct name and resource-builder flag.</returns>
    private IReadOnlyList<RustHandleType> BuildHandleTypes(AtsContext context)
    {
        var collectedIds = new HashSet<string>(StringComparer.Ordinal);

        foreach (var declaredType in context.HandleTypes)
        {
            // ReferenceExpression is excluded - it's defined in base.rs.
            if (declaredType.AtsTypeId != AtsConstants.ReferenceExpressionTypeId)
            {
                collectedIds.Add(declaredType.AtsTypeId);
            }
        }

        foreach (var capability in context.Capabilities)
        {
            AddHandleTypeIfNeeded(collectedIds, capability.TargetType);
            AddHandleTypeIfNeeded(collectedIds, capability.ReturnType);

            foreach (var parameter in capability.Parameters)
            {
                AddHandleTypeIfNeeded(collectedIds, parameter.Type);

                if (!parameter.IsCallback || parameter.CallbackParameters is null)
                {
                    continue;
                }

                foreach (var callbackArgument in parameter.CallbackParameters)
                {
                    AddHandleTypeIfNeeded(collectedIds, callbackArgument.Type);
                }
            }
        }

        // Name every collected id before any result entry is built.
        _structNames.Clear();
        foreach (var collectedId in collectedIds)
        {
            _structNames[collectedId] = CreateStructName(collectedId);
        }

        var declaredById = context.HandleTypes.ToDictionary(t => t.AtsTypeId, StringComparer.Ordinal);
        var wrappers = new List<RustHandleType>();
        foreach (var collectedId in collectedIds)
        {
            // Flag is set only for declared types whose CLR type is assignable to IResource.
            var isResourceBuilder =
                declaredById.TryGetValue(collectedId, out var declaredInfo)
                && declaredInfo.ClrType is not null
                && typeof(IResource).IsAssignableFrom(declaredInfo.ClrType);

            wrappers.Add(new RustHandleType(collectedId, _structNames[collectedId], isResourceBuilder));
        }

        return wrappers;
    }
 
    /// <summary>
    /// Groups capabilities by the type id of each of their target types.
    /// </summary>
    /// <remarks>
    /// Capabilities without a target type id are ignored. A capability with expanded target types is
    /// listed under each of them; otherwise it is listed under its single target type, if it has one.
    /// </remarks>
    /// <param name="capabilities">Capabilities to group.</param>
    /// <returns>Target type id mapped to its capabilities, in input order.</returns>
    private static Dictionary<string, List<AtsCapabilityInfo>> GroupCapabilitiesByTarget(
        IReadOnlyList<AtsCapabilityInfo> capabilities)
    {
        var grouped = new Dictionary<string, List<AtsCapabilityInfo>>(StringComparer.Ordinal);

        foreach (var capability in capabilities)
        {
            if (string.IsNullOrEmpty(capability.TargetTypeId))
            {
                continue;
            }

            // Prefer the expanded targets; fall back to the single target type when there are none.
            var targets = capability.ExpandedTargetTypes;
            if (targets.Count == 0)
            {
                targets = capability.TargetType is not null ? [capability.TargetType] : [];
            }

            foreach (var target in targets)
            {
                var targetId = target.TypeId;
                if (targetId is null)
                {
                    continue;
                }

                if (!grouped.TryGetValue(targetId, out var bucket))
                {
                    bucket = [];
                    grouped.Add(targetId, bucket);
                }

                bucket.Add(capability);
            }
        }

        return grouped;
    }
 
    /// <summary>
    /// Passes the target, return, parameter and callback-parameter types of every capability to
    /// <c>AddListOrDictTypeIfNeeded</c> and returns the resulting set of type ids.
    /// </summary>
    /// <param name="capabilities">Capabilities to scan.</param>
    /// <returns>The ordinal set filled by <c>AddListOrDictTypeIfNeeded</c>.</returns>
    private static HashSet<string> CollectListAndDictTypeIds(IReadOnlyList<AtsCapabilityInfo> capabilities)
    {
        var collected = new HashSet<string>(StringComparer.Ordinal);

        foreach (var capability in capabilities)
        {
            AddListOrDictTypeIfNeeded(collected, capability.TargetType);
            AddListOrDictTypeIfNeeded(collected, capability.ReturnType);

            foreach (var parameter in capability.Parameters)
            {
                AddListOrDictTypeIfNeeded(collected, parameter.Type);

                // Only callbacks that declare parameters contribute further types.
                if (!parameter.IsCallback || parameter.CallbackParameters is null)
                {
                    continue;
                }

                foreach (var callbackArgument in parameter.CallbackParameters)
                {
                    AddListOrDictTypeIfNeeded(collected, callbackArgument.Type);
                }
            }
        }

        return collected;
    }
 
    /// <summary>
    /// Maps a type reference to the Rust type used in generated method signatures.
    /// </summary>
    /// <param name="typeRef">Type to map; <see langword="null"/> maps to <c>Value</c> regardless of
    /// <paramref name="isOptional"/>.</param>
    /// <param name="isOptional">When <see langword="true"/>, the mapped type is wrapped in <c>Option</c>.</param>
    /// <returns>The Rust type expression.</returns>
    private string MapTypeRefToRust(AtsTypeRef? typeRef, bool isOptional)
    {
        if (typeRef is null)
        {
            return "Value";
        }

        string rustType;
        if (typeRef.TypeId == AtsConstants.ReferenceExpressionTypeId)
        {
            // ReferenceExpression is matched by type id, before the category is looked at.
            rustType = "ReferenceExpression";
        }
        else
        {
            switch (typeRef.Category)
            {
                case AtsTypeCategory.Primitive:
                    rustType = MapPrimitiveType(typeRef.TypeId);
                    break;
                case AtsTypeCategory.Enum:
                    rustType = MapEnumType(typeRef.TypeId);
                    break;
                case AtsTypeCategory.Handle:
                    rustType = MapHandleType(typeRef.TypeId);
                    break;
                case AtsTypeCategory.Dto:
                    rustType = MapDtoType(typeRef.TypeId);
                    break;
                case AtsTypeCategory.Callback:
                    rustType = "Box<dyn Fn(Vec<Value>) -> Value + Send + Sync>";
                    break;
                case AtsTypeCategory.Array:
                    rustType = $"Vec<{MapTypeRefToRust(typeRef.ElementType, false)}>";
                    break;
                case AtsTypeCategory.List:
                {
                    // Read-only lists map to Vec, other lists to AspireList.
                    var elementType = MapTypeRefToRust(typeRef.ElementType, false);
                    rustType = typeRef.IsReadOnly ? $"Vec<{elementType}>" : $"AspireList<{elementType}>";
                    break;
                }
                case AtsTypeCategory.Dict:
                {
                    // Read-only dictionaries map to HashMap, other dictionaries to AspireDict.
                    var keyType = MapTypeRefToRust(typeRef.KeyType, false);
                    var valueType = MapTypeRefToRust(typeRef.ValueType, false);
                    rustType = typeRef.IsReadOnly
                        ? $"HashMap<{keyType}, {valueType}>"
                        : $"AspireDict<{keyType}, {valueType}>";
                    break;
                }
                default:
                    // Union, Unknown and any other category fall back to an untyped JSON value.
                    rustType = "Value";
                    break;
            }
        }

        return isOptional ? $"Option<{rustType}>" : rustType;
    }
 
    /// <summary>
    /// Maps a type reference to the Rust type used for a DTO field.
    /// Handle types map to <c>Handle</c> (the serializable type) instead of a wrapper struct,
    /// callbacks map to <c>Value</c>, and lists and dictionaries map to <c>Vec</c> and <c>HashMap</c>.
    /// </summary>
    /// <param name="typeRef">Type to map; <see langword="null"/> maps to <c>Value</c> regardless of
    /// <paramref name="isOptional"/>.</param>
    /// <param name="isOptional">When <see langword="true"/>, the mapped type is wrapped in <c>Option</c>.</param>
    /// <returns>The Rust type expression.</returns>
    private string MapTypeRefToRustForDto(AtsTypeRef? typeRef, bool isOptional)
    {
        if (typeRef is null)
        {
            return "Value";
        }

        string rustType;
        if (typeRef.TypeId == AtsConstants.ReferenceExpressionTypeId)
        {
            // ReferenceExpression is matched by type id, before the category is looked at.
            rustType = "ReferenceExpression";
        }
        else
        {
            switch (typeRef.Category)
            {
                case AtsTypeCategory.Primitive:
                    rustType = MapPrimitiveType(typeRef.TypeId);
                    break;
                case AtsTypeCategory.Enum:
                    rustType = MapEnumType(typeRef.TypeId);
                    break;
                case AtsTypeCategory.Handle:
                    // Use Handle directly for handle types in DTOs since Handle implements Serialize/Deserialize
                    rustType = "Handle";
                    break;
                case AtsTypeCategory.Dto:
                    rustType = MapDtoType(typeRef.TypeId);
                    break;
                case AtsTypeCategory.Array:
                case AtsTypeCategory.List:
                    rustType = $"Vec<{MapTypeRefToRustForDto(typeRef.ElementType, false)}>";
                    break;
                case AtsTypeCategory.Dict:
                {
                    var keyType = MapTypeRefToRustForDto(typeRef.KeyType, false);
                    var valueType = MapTypeRefToRustForDto(typeRef.ValueType, false);
                    rustType = $"HashMap<{keyType}, {valueType}>";
                    break;
                }
                default:
                    // Callbacks can't be serialized in DTOs; Union, Unknown and any other
                    // category also fall back to an untyped JSON value.
                    rustType = "Value";
                    break;
            }
        }

        return isOptional ? $"Option<{rustType}>" : rustType;
    }
 
    /// <summary>
    /// Returns the wrapper struct name registered for a handle type id, or <c>Handle</c> when none is registered.
    /// </summary>
    private string MapHandleType(string typeId)
    {
        if (_structNames.TryGetValue(typeId, out var structName))
        {
            return structName;
        }

        return "Handle";
    }
 
    /// <summary>
    /// Returns the struct name registered for a DTO type id, or <c>HashMap&lt;String, Value&gt;</c> when none is registered.
    /// </summary>
    private string MapDtoType(string typeId)
    {
        if (_dtoNames.TryGetValue(typeId, out var dtoName))
        {
            return dtoName;
        }

        return "HashMap<String, Value>";
    }
 
    /// <summary>
    /// Returns the enum name registered for an enum type id, or <c>String</c> when none is registered.
    /// </summary>
    private string MapEnumType(string typeId)
    {
        if (_enumNames.TryGetValue(typeId, out var enumName))
        {
            return enumName;
        }

        return "String";
    }
 
    /// <summary>
    /// Maps a primitive type id to the owned Rust type used for return values and fields.
    /// </summary>
    /// <param name="typeId">Primitive type id.</param>
    /// <returns>The Rust type name; <c>Value</c> for ids that are not listed.</returns>
    private static string MapPrimitiveType(string typeId)
    {
        switch (typeId)
        {
            // Text, date/time, Guid and Uri ids all map to String.
            case AtsConstants.String:
            case AtsConstants.Char:
            case AtsConstants.DateTime:
            case AtsConstants.DateTimeOffset:
            case AtsConstants.DateOnly:
            case AtsConstants.TimeOnly:
            case AtsConstants.Guid:
            case AtsConstants.Uri:
                return "String";

            case AtsConstants.Number:
            case AtsConstants.TimeSpan:
                return "f64";

            case AtsConstants.Boolean:
                return "bool";

            case AtsConstants.Void:
                return "()";

            case AtsConstants.CancellationToken:
                return "CancellationToken";

            // AtsConstants.Any and every unlisted id.
            default:
                return "Value";
        }
    }
 
    /// <summary>
    /// Maps a parameter type to a Rust type, choosing borrowed forms for primitives
    /// (for example <c>&amp;str</c> rather than <c>String</c>).
    /// </summary>
    /// <param name="typeRef">The ATS type reference of the parameter; may be null.</param>
    /// <param name="isOptional">Whether the resulting type is wrapped in <c>Option&lt;...&gt;</c>.</param>
    /// <returns>
    /// <c>Value</c> for a null reference, the borrowed primitive mapping for primitives,
    /// and the result of <c>MapTypeRefToRust</c> for every other category.
    /// </returns>
    private string MapParameterTypeToRust(AtsTypeRef? typeRef, bool isOptional)
    {
        if (typeRef is null)
        {
            return "Value";
        }

        // Only primitives get the parameter-specific treatment; everything else
        // (including arrays/lists, which keep owned element types) uses the standard mapping.
        if (typeRef.Category != AtsTypeCategory.Primitive)
        {
            return MapTypeRefToRust(typeRef, isOptional);
        }

        var parameterType = typeRef.TypeId switch
        {
            AtsConstants.String or AtsConstants.Char
                or AtsConstants.DateTime or AtsConstants.DateTimeOffset
                or AtsConstants.DateOnly or AtsConstants.TimeOnly
                or AtsConstants.Guid or AtsConstants.Uri => "&str",
            AtsConstants.Number or AtsConstants.TimeSpan => "f64",
            AtsConstants.Boolean => "bool",
            AtsConstants.Void => "()",
            AtsConstants.CancellationToken => "&CancellationToken",
            AtsConstants.Any => "&Value",
            _ => "&Value"
        };

        if (isOptional)
        {
            return $"Option<{parameterType}>";
        }

        return parameterType;
    }
 
    /// <summary>
    /// Tells whether a type reference is in the handle category and is not the
    /// reference-expression type id.
    /// </summary>
    /// <param name="typeRef">The type reference to inspect; null yields <c>false</c>.</param>
    private static bool IsHandleType(AtsTypeRef? typeRef)
    {
        return typeRef is { Category: AtsTypeCategory.Handle }
            && typeRef.TypeId != AtsConstants.ReferenceExpressionTypeId;
    }
 
    /// <summary>
    /// Tells whether a parameter's type id is the cancellation-token type id.
    /// </summary>
    /// <param name="parameter">The parameter to inspect; a parameter without a type yields <c>false</c>.</param>
    private static bool IsCancellationToken(AtsParameterInfo parameter)
    {
        var parameterType = parameter.Type;
        return parameterType is not null
            && parameterType.TypeId == AtsConstants.CancellationToken;
    }
 
    /// <summary>
    /// Adds the type id of <paramref name="typeRef"/> to <paramref name="handleTypeIds"/>
    /// when it is a handle type other than ReferenceExpression.
    /// </summary>
    /// <param name="handleTypeIds">The set that collects handle type ids.</param>
    /// <param name="typeRef">The type reference to inspect; null is ignored.</param>
    private static void AddHandleTypeIfNeeded(HashSet<string> handleTypeIds, AtsTypeRef? typeRef)
    {
        // ReferenceExpression is excluded because it is defined in base.rs.
        var isCollectibleHandle =
            typeRef is { Category: AtsTypeCategory.Handle }
            && typeRef.TypeId != AtsConstants.ReferenceExpressionTypeId;

        if (isCollectibleHandle)
        {
            handleTypeIds.Add(typeRef!.TypeId);
        }
    }
 
    /// <summary>
    /// Adds the type id of <paramref name="typeRef"/> to <paramref name="typeIds"/>
    /// when it is a list or dict type that is not read-only.
    /// </summary>
    /// <param name="typeIds">The set that collects list/dict type ids.</param>
    /// <param name="typeRef">The type reference to inspect; null is ignored.</param>
    private static void AddListOrDictTypeIfNeeded(HashSet<string> typeIds, AtsTypeRef? typeRef)
    {
        if (typeRef is { Category: AtsTypeCategory.List or AtsTypeCategory.Dict, IsReadOnly: false })
        {
            typeIds.Add(typeRef.TypeId);
        }
    }
 
    /// <summary>
    /// Produces a struct name for a type id that does not clash with any value
    /// already stored in <c>_structNames</c>.
    /// </summary>
    /// <param name="typeId">The type identifier, expected in <c>assembly/type</c> form.</param>
    /// <returns>
    /// The sanitized short type name; if that is taken, the name prefixed with the sanitized
    /// assembly segment; if that is also taken, the prefixed name followed by 2, 3, ... until unused.
    /// </returns>
    private string CreateStructName(string typeId)
    {
        bool IsTaken(string value) => _structNames.Values.Contains(value, StringComparer.Ordinal);

        var name = SanitizeIdentifier(ExtractTypeName(typeId));
        if (IsTaken(name))
        {
            // Disambiguate with the part of the type id that precedes the first '/'.
            var assemblyPrefix = SanitizeIdentifier(typeId.Split('/')[0]);
            name = $"{assemblyPrefix}{name}";
        }

        // Numeric suffixes start at 2: the unsuffixed name is the first candidate.
        var candidate = name;
        for (var suffix = 2; IsTaken(candidate); suffix++)
        {
            candidate = $"{name}{suffix}";
        }

        return candidate;
    }
 
    /// <summary>
    /// Extracts the short type name from a type id.
    /// </summary>
    /// <param name="typeId">The type identifier, optionally prefixed with <c>assembly/</c>.</param>
    /// <returns>
    /// The text after the last <c>.</c> or <c>+</c> that follows the first <c>/</c>;
    /// when there is no such delimiter, everything after the first <c>/</c> (or the whole id).
    /// </returns>
    private static string ExtractTypeName(string typeId)
    {
        // IndexOf yields -1 when there is no '/', so the scan then covers the whole string.
        var start = typeId.IndexOf('/', StringComparison.Ordinal) + 1;

        // Walk backwards so the first hit is the last namespace ('.') or nesting ('+') delimiter.
        for (var i = typeId.Length - 1; i >= start; i--)
        {
            if (typeId[i] is '.' or '+')
            {
                return typeId[(i + 1)..];
            }
        }

        return typeId[start..];
    }
 
    /// <summary>
    /// Turns an arbitrary name into text usable as a Rust identifier.
    /// </summary>
    /// <remarks>
    /// Every character that is not a letter, digit or underscore is replaced with <c>_</c>, and a
    /// leading <c>_</c> is inserted when the first character is neither a letter nor an underscore.
    /// A result found in <c>s_rustKeywords</c> is emitted as a raw identifier (<c>r#name</c>),
    /// except for <c>crate</c>, <c>self</c>, <c>Self</c> and <c>super</c>: Rust does not accept
    /// those as raw identifiers, so they get a trailing underscore instead.
    /// </remarks>
    /// <param name="name">The name to sanitize.</param>
    /// <returns>The sanitized identifier; <c>_</c> when <paramref name="name"/> is null, empty or whitespace.</returns>
    private static string SanitizeIdentifier(string name)
    {
        if (string.IsNullOrWhiteSpace(name))
        {
            return "_";
        }

        // +1 leaves room for the underscore that may be inserted in front.
        var builder = new StringBuilder(name.Length + 1);
        foreach (var ch in name)
        {
            builder.Append(char.IsLetterOrDigit(ch) || ch == '_' ? ch : '_');
        }

        if (!char.IsLetter(builder[0]) && builder[0] != '_')
        {
            builder.Insert(0, '_');
        }

        var sanitized = builder.ToString();
        if (!s_rustKeywords.Contains(sanitized))
        {
            return sanitized;
        }

        // `r#crate`, `r#self`, `r#Self` and `r#super` are rejected by the Rust compiler,
        // so these four keywords are escaped with a suffix rather than the raw-identifier prefix.
        return sanitized is "crate" or "self" or "Self" or "super"
            ? $"{sanitized}_"
            : $"r#{sanitized}";
    }
 
    /// <summary>
    /// Upper-cases the first character of a name for use as a Rust type name.
    /// The remaining characters are left unchanged.
    /// </summary>
    /// <param name="name">The name to convert.</param>
    /// <returns>
    /// <paramref name="name"/> itself when it is null, empty or already starts with an
    /// upper-case character; otherwise the name with its first character upper-cased.
    /// </returns>
    private static string ToPascalCase(string name)
    {
        var needsNoChange = string.IsNullOrEmpty(name) || char.IsUpper(name[0]);
        if (needsNoChange)
        {
            return name;
        }

        var head = char.ToUpperInvariant(name[0]).ToString();
        return string.Concat(head, name.Substring(1));
    }
 
    /// <summary>
    /// Converts a name to snake_case for Rust identifiers using
    /// <see cref="JsonNamingPolicy.SnakeCaseLower"/>.
    /// </summary>
    /// <param name="name">The name to convert.</param>
    /// <returns>The converted name; null or empty input is returned unchanged.</returns>
    private static string ToSnakeCase(string name) =>
        string.IsNullOrEmpty(name)
            ? name
            : JsonNamingPolicy.SnakeCaseLower.ConvertName(name);
 
    /// <summary>
    /// Writes <paramref name="value"/> followed by a line terminator to <c>_writer</c>.
    /// </summary>
    /// <param name="value">The text to write; defaults to an empty string, which yields a blank line.</param>
    private void WriteLine(string value = "") => _writer.WriteLine(value);
 
    /// <summary>
    /// Pairs a handle type id with the Rust struct name used for it.
    /// </summary>
    /// <param name="TypeId">The handle type identifier.</param>
    /// <param name="StructName">The Rust struct name associated with the type id.</param>
    /// <param name="IsResourceBuilder">Flag set by the creator of the record; presumably marks resource-builder handles (confirm against the call sites).</param>
    private sealed record RustHandleType(
        string TypeId,
        string StructName,
        bool IsResourceBuilder);
}