File: AzureAIInferenceChatClient.cs
Project: src\Libraries\Microsoft.Extensions.AI.AzureAIInference\Microsoft.Extensions.AI.AzureAIInference.csproj (Microsoft.Extensions.AI.AzureAIInference)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.Inference;
using Microsoft.Shared.Diagnostics;
 
#pragma warning disable S1135 // Track uses of "TODO" tags
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable SA1204 // Static elements should appear before instance elements
 
namespace Microsoft.Extensions.AI;
 
/// <summary>Represents an <see cref="IChatClient"/> for an Azure AI Inference <see cref="ChatCompletionsClient"/>.</summary>
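/// <remarks>
/// Illustrative usage sketch only; the endpoint, credential, and model name below are placeholders:
/// <code>
/// var azureClient = new ChatCompletionsClient(new Uri("https://example.models.ai.azure.com"), new AzureKeyCredential("YOUR_API_KEY"));
/// IChatClient chatClient = new AzureAIInferenceChatClient(azureClient, modelId: "example-model");
/// ChatCompletion completion = await chatClient.CompleteAsync([new ChatMessage(ChatRole.User, "Hello!")]);
/// </code>
/// </remarks>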
public sealed class AzureAIInferenceChatClient : IChatClient
{
    /// <summary>A default schema to use when a parameter lacks a pre-defined schema.</summary>
    private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;
 
    /// <summary>The underlying <see cref="ChatCompletionsClient" />.</summary>
    private readonly ChatCompletionsClient _chatCompletionsClient;
 
    /// <summary>The <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
    private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;
 
    /// <summary>Initializes a new instance of the <see cref="AzureAIInferenceChatClient"/> class for the specified <see cref="ChatCompletionsClient"/>.</summary>
    /// <param name="chatCompletionsClient">The underlying client.</param>
    /// <param name="modelId">The ID of the model to use. If null, it can be provided per request via <see cref="ChatOptions.ModelId"/>.</param>
    public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, string? modelId = null)
    {
        _ = Throw.IfNull(chatCompletionsClient);
        if (modelId is not null)
        {
            _ = Throw.IfNullOrWhitespace(modelId);
        }
 
        _chatCompletionsClient = chatCompletionsClient;
 
        // https://github.com/Azure/azure-sdk-for-net/issues/46278
        // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages
        // implement the abstractions directly rather than providing adapters on top of the public APIs,
        // the package can provide such implementations separate from what's exposed in the public API.
        var providerUrl = typeof(ChatCompletionsClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
            ?.GetValue(chatCompletionsClient) as Uri;
 
        Metadata = new("az.ai.inference", providerUrl, modelId);
    }
 
    /// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
    public JsonSerializerOptions ToolCallJsonSerializerOptions
    {
        get => _toolCallJsonSerializerOptions;
        set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
    }
 
    /// <inheritdoc />
    public ChatClientMetadata Metadata { get; }
 
    /// <inheritdoc />
    public object? GetService(Type serviceType, object? serviceKey = null)
    {
        _ = Throw.IfNull(serviceType);
 
        return
            serviceKey is not null ? null :
            serviceType == typeof(ChatCompletionsClient) ? _chatCompletionsClient :
            serviceType.IsInstanceOfType(this) ? this :
            null;
    }
 
    /// <inheritdoc />
    public async Task<ChatCompletion> CompleteAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(chatMessages);
 
        // Make the call.
        ChatCompletions response = (await _chatCompletionsClient.CompleteAsync(
            ToAzureAIOptions(chatMessages, options),
            cancellationToken: cancellationToken).ConfigureAwait(false)).Value;
 
        // Create the return message.
        List<ChatMessage> returnMessages = [];
 
        // Populate its content from those in the response content.
        ChatMessage message = new()
        {
            RawRepresentation = response,
            Role = ToChatRole(response.Role),
        };
 
        if (response.Content is string content)
        {
            message.Text = content;
        }
 
        if (response.ToolCalls is { Count: > 0 } toolCalls)
        {
            foreach (var toolCall in toolCalls)
            {
                if (toolCall is ChatCompletionsToolCall ftc && !string.IsNullOrWhiteSpace(ftc.Name))
                {
                    FunctionCallContent callContent = ParseCallContentFromJsonString(ftc.Arguments, toolCall.Id, ftc.Name);
                    callContent.RawRepresentation = toolCall;
 
                    message.Contents.Add(callContent);
                }
            }
        }
 
        returnMessages.Add(message);
 
        UsageDetails? usage = null;
        if (response.Usage is CompletionsUsage completionsUsage)
        {
            usage = new()
            {
                InputTokenCount = completionsUsage.PromptTokens,
                OutputTokenCount = completionsUsage.CompletionTokens,
                TotalTokenCount = completionsUsage.TotalTokens,
            };
        }
 
        // Wrap the content in a ChatCompletion to return.
        return new ChatCompletion(returnMessages)
        {
            CompletionId = response.Id,
            CreatedAt = response.Created,
            ModelId = response.Model,
            FinishReason = ToFinishReason(response.FinishReason),
            RawRepresentation = response,
            Usage = usage,
        };
    }
 
    /// <inheritdoc />
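    /// <remarks>
    /// Illustrative consumption sketch; <c>chatClient</c> and <c>messages</c> are placeholders:
    /// <code>
    /// await foreach (var update in chatClient.CompleteStreamingAsync(messages))
    /// {
    ///     Console.Write(update.Text);
    /// }
    /// </code>
    /// </remarks>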
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(chatMessages);
 
        Dictionary<string, FunctionCallInfo>? functionCallInfos = null;
        ChatRole? streamedRole = default;
        ChatFinishReason? finishReason = default;
        string? completionId = null;
        DateTimeOffset? createdAt = null;
        string? modelId = null;
        string lastCallId = string.Empty;
 
        // Process each update as it arrives
        var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false);
        await foreach (StreamingChatCompletionsUpdate chatCompletionUpdate in updates.ConfigureAwait(false))
        {
            // The role and finish reason may arrive during any update, but once they've arrived, the same value should be reported for all subsequent updates.
            streamedRole ??= chatCompletionUpdate.Role is global::Azure.AI.Inference.ChatRole role ? ToChatRole(role) : null;
            finishReason ??= chatCompletionUpdate.FinishReason is CompletionsFinishReason reason ? ToFinishReason(reason) : null;
            completionId ??= chatCompletionUpdate.Id;
            createdAt ??= chatCompletionUpdate.Created;
            modelId ??= chatCompletionUpdate.Model;
 
            // Create the response content object.
            StreamingChatCompletionUpdate completionUpdate = new()
            {
                CompletionId = chatCompletionUpdate.Id,
                CreatedAt = chatCompletionUpdate.Created,
                FinishReason = finishReason,
                ModelId = modelId,
                RawRepresentation = chatCompletionUpdate,
                Role = streamedRole,
            };
 
            // Transfer over content update items.
            if (chatCompletionUpdate.ContentUpdate is string update)
            {
                completionUpdate.Contents.Add(new TextContent(update));
            }
 
            // Transfer over tool call updates.
            if (chatCompletionUpdate.ToolCallUpdate is { } toolCallUpdate)
            {
                // TODO https://github.com/Azure/azure-sdk-for-net/issues/46830: Azure.AI.Inference
                // has removed the Index property from ToolCallUpdate. It's now impossible via the
                // exposed APIs to correctly handle multiple parallel tool calls, as the CallId is
                // often null for anything other than the first update for a given call, and Index
                // isn't available to correlate which updates are for which call. This is a temporary
                // workaround that at least makes a single tool call work, and also handles multiple
                // tool calls when their updates aren't interleaved.
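                // For example (hypothetical update sequence), a single call's data might arrive as:
                //   { Id: "call_abc", Function.Name: "GetWeather", Function.Arguments: "{\"ci" }
                //   { Id: null,       Function.Name: null,         Function.Arguments: "ty\":\"Paris\"}" }
                // so updates with a null Id are attributed to the most recently seen call id (lastCallId).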
                if (toolCallUpdate.Id is not null)
                {
                    lastCallId = toolCallUpdate.Id;
                }
 
                functionCallInfos ??= [];
                if (!functionCallInfos.TryGetValue(lastCallId, out FunctionCallInfo? existing))
                {
                    functionCallInfos[lastCallId] = existing = new();
                }
 
                existing.Name ??= toolCallUpdate.Function.Name;
                if (toolCallUpdate.Function.Arguments is { } arguments)
                {
                    _ = (existing.Arguments ??= new()).Append(arguments);
                }
            }
 
            if (chatCompletionUpdate.Usage is { } usage)
            {
                completionUpdate.Contents.Add(new UsageContent(new()
                {
                    InputTokenCount = usage.PromptTokens,
                    OutputTokenCount = usage.CompletionTokens,
                    TotalTokenCount = usage.TotalTokens,
                }));
            }
 
            // Now yield the item.
            yield return completionUpdate;
        }
 
        // Now that we've received all updates, combine any for function calls into a single item to yield.
        if (functionCallInfos is not null)
        {
            var completionUpdate = new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = createdAt,
                FinishReason = finishReason,
                ModelId = modelId,
                Role = streamedRole,
            };
 
            foreach (var entry in functionCallInfos)
            {
                FunctionCallInfo fci = entry.Value;
                if (!string.IsNullOrWhiteSpace(fci.Name))
                {
                    FunctionCallContent callContent = ParseCallContentFromJsonString(
                        fci.Arguments?.ToString() ?? string.Empty,
                        entry.Key,
                        fci.Name!);
                    completionUpdate.Contents.Add(callContent);
                }
            }
 
            yield return completionUpdate;
        }
    }
 
    /// <inheritdoc />
    void IDisposable.Dispose()
    {
        // Nothing to dispose. Implementation required for the IChatClient interface.
    }
 
    /// <summary>POCO representing function calling info. Used to concatenate information for a single function call across multiple streaming updates.</summary>
    private sealed class FunctionCallInfo
    {
        public string? Name;
        public StringBuilder? Arguments;
    }
 
    /// <summary>Converts an AzureAI role to an Extensions role.</summary>
    private static ChatRole ToChatRole(global::Azure.AI.Inference.ChatRole role) =>
        role.Equals(global::Azure.AI.Inference.ChatRole.System) ? ChatRole.System :
        role.Equals(global::Azure.AI.Inference.ChatRole.User) ? ChatRole.User :
        role.Equals(global::Azure.AI.Inference.ChatRole.Assistant) ? ChatRole.Assistant :
        role.Equals(global::Azure.AI.Inference.ChatRole.Tool) ? ChatRole.Tool :
        new ChatRole(role.ToString());
 
    /// <summary>Converts an AzureAI finish reason to an Extensions finish reason.</summary>
    private static ChatFinishReason? ToFinishReason(CompletionsFinishReason? finishReason) =>
        finishReason?.ToString() is not string s ? null :
        finishReason == CompletionsFinishReason.Stopped ? ChatFinishReason.Stop :
        finishReason == CompletionsFinishReason.TokenLimitReached ? ChatFinishReason.Length :
        finishReason == CompletionsFinishReason.ContentFiltered ? ChatFinishReason.ContentFilter :
        finishReason == CompletionsFinishReason.ToolCalls ? ChatFinishReason.ToolCalls :
        new(s);
 
    /// <summary>Converts an extensions options instance to an AzureAI options instance.</summary>
    private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents, ChatOptions? options)
    {
        ChatCompletionsOptions result = new(ToAzureAIInferenceChatMessages(chatContents))
        {
            Model = options?.ModelId ?? Metadata.ModelId ?? throw new InvalidOperationException("No model id was provided when either constructing the client or in the chat options.")
        };
 
        if (options is not null)
        {
            result.FrequencyPenalty = options.FrequencyPenalty;
            result.MaxTokens = options.MaxOutputTokens;
            result.NucleusSamplingFactor = options.TopP;
            result.PresencePenalty = options.PresencePenalty;
            result.Temperature = options.Temperature;
            result.Seed = options.Seed;
 
            if (options.StopSequences is { Count: > 0 } stopSequences)
            {
                foreach (string stopSequence in stopSequences)
                {
                    result.StopSequences.Add(stopSequence);
                }
            }
 
            // These properties are strongly typed on ChatOptions but not on ChatCompletionsOptions.
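            // For example, ChatOptions.TopK = 40 is serialized into the request as a raw "top_k" value of 40.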
            if (options.TopK is int topK)
            {
                result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(int))));
            }
 
            if (options.AdditionalProperties is { } props)
            {
                foreach (var prop in props)
                {
                    switch (prop.Key)
                    {
                        // Propagate everything else to the ChatCompletionsOptions' AdditionalProperties.
                        default:
                            if (prop.Value is not null)
                            {
                                byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
                                result.AdditionalProperties[prop.Key] = new BinaryData(data);
                            }
 
                            break;
                    }
                }
            }
 
            if (options.Tools is { Count: > 0 } tools)
            {
                foreach (AITool tool in tools)
                {
                    if (tool is AIFunction af)
                    {
                        result.Tools.Add(ToAzureAIChatTool(af));
                    }
                }
 
                switch (options.ToolMode)
                {
                    case AutoChatToolMode:
                        result.ToolChoice = ChatCompletionsToolChoice.Auto;
                        break;
 
                    case RequiredChatToolMode required:
                        result.ToolChoice = required.RequiredFunctionName is null ?
                            ChatCompletionsToolChoice.Required :
                            new ChatCompletionsToolChoice(new FunctionDefinition(required.RequiredFunctionName));
                        break;
                }
            }
 
            if (options.ResponseFormat is ChatResponseFormatText)
            {
                result.ResponseFormat = new ChatCompletionsResponseFormatText();
            }
            else if (options.ResponseFormat is ChatResponseFormatJson)
            {
                result.ResponseFormat = new ChatCompletionsResponseFormatJSON();
            }
        }
 
        return result;
    }
 
    /// <summary>Converts an Extensions function to an AzureAI chat tool.</summary>
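    /// <remarks>
    /// The <see cref="FunctionDefinition.Parameters"/> payload is a JSON schema document built from the function's
    /// parameter metadata, roughly of the form <c>{ "type": "object", "properties": { ... }, "required": [ ... ] }</c>
    /// (the exact shape depends on how <see cref="AzureAIChatToolJson"/> is serialized).
    /// </remarks>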
    private static ChatCompletionsToolDefinition ToAzureAIChatTool(AIFunction aiFunction)
    {
        BinaryData resultParameters = AzureAIChatToolJson.ZeroFunctionParametersSchema;
 
        var parameters = aiFunction.Metadata.Parameters;
        if (parameters is { Count: > 0 })
        {
            AzureAIChatToolJson tool = new();
 
            foreach (AIFunctionParameterMetadata parameter in parameters)
            {
                tool.Properties.Add(
                    parameter.Name,
                    parameter.Schema is JsonElement schema ? schema : _defaultParameterSchema);
 
                if (parameter.IsRequired)
                {
                    tool.Required.Add(parameter.Name);
                }
            }
 
            resultParameters = BinaryData.FromBytes(
                JsonSerializer.SerializeToUtf8Bytes(tool, JsonContext.Default.AzureAIChatToolJson));
        }
 
        return new(new FunctionDefinition(aiFunction.Metadata.Name)
        {
            Description = aiFunction.Metadata.Description,
            Parameters = resultParameters,
        });
    }
 
    /// <summary>Converts an Extensions chat message enumerable to an AzureAI chat message enumerable.</summary>
    private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IList<ChatMessage> inputs)
    {
        // Maps all of the M.E.AI types to the corresponding AzureAI types.
        // Unrecognized or non-processable content is ignored.
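        //   System    -> ChatRequestSystemMessage
        //   Tool      -> ChatRequestToolMessage (one per FunctionResultContent)
        //   User      -> ChatRequestUserMessage (plain text, or content parts for multimodal content)
        //   Assistant -> ChatRequestAssistantMessage (concatenated text plus any function tool calls)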
 
        foreach (ChatMessage input in inputs)
        {
            if (input.Role == ChatRole.System)
            {
                yield return new ChatRequestSystemMessage(input.Text ?? string.Empty);
            }
            else if (input.Role == ChatRole.Tool)
            {
                foreach (AIContent item in input.Contents)
                {
                    if (item is FunctionResultContent resultContent)
                    {
                        string? result = resultContent.Result as string;
                        if (result is null && resultContent.Result is not null)
                        {
                            try
                            {
                                result = JsonSerializer.Serialize(resultContent.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
                            }
                            catch (NotSupportedException)
                            {
                                // If the type can't be serialized, skip it.
                            }
                        }
 
                        yield return new ChatRequestToolMessage(result ?? string.Empty, resultContent.CallId);
                    }
                }
            }
            else if (input.Role == ChatRole.User)
            {
                yield return input.Contents.All(c => c is TextContent) ?
                    new ChatRequestUserMessage(string.Concat(input.Contents)) :
                    new ChatRequestUserMessage(GetContentParts(input.Contents));
            }
            else if (input.Role == ChatRole.Assistant)
            {
                // TODO: ChatRequestAssistantMessage only enables text content currently.
                // Update it with other content types when it supports that.
                ChatRequestAssistantMessage message = new(string.Concat(input.Contents.Where(c => c is TextContent)));
 
                foreach (var content in input.Contents)
                {
                    if (content is FunctionCallContent { CallId: not null } callRequest)
                    {
                        message.ToolCalls.Add(new ChatCompletionsToolCall(
                             callRequest.CallId,
                             new FunctionCall(
                                 callRequest.Name,
                                 JsonSerializer.Serialize(callRequest.Arguments, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object>))))));
                    }
                }
 
                yield return message;
            }
        }
    }
 
    /// <summary>Converts a list of <see cref="AIContent"/> to a list of <see cref="ChatMessageContentItem"/>.</summary>
    private static List<ChatMessageContentItem> GetContentParts(IList<AIContent> contents)
    {
        Debug.Assert(contents is { Count: > 0 }, "Expected non-empty contents");
 
        List<ChatMessageContentItem> parts = [];
        foreach (var content in contents)
        {
            switch (content)
            {
                case TextContent textContent:
                    parts.Add(new ChatMessageTextContentItem(textContent.Text));
                    break;
 
                case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data:
                    parts.Add(new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType));
                    break;
 
                case ImageContent imageContent when imageContent.Uri is string uri:
                    parts.Add(new ChatMessageImageContentItem(new Uri(uri)));
                    break;
            }
        }
 
        return parts;
    }
 
    /// <summary>Parses a JSON arguments payload into a <see cref="FunctionCallContent"/> whose arguments are a dictionary of name/value pairs.</summary>
    private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
        FunctionCallContent.CreateFromParsedArguments(json, callId, name,
            argumentParser: static json => JsonSerializer.Deserialize(json,
                (JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);
}