File: OpenAIModelMappers.StreamingChatCompletion.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI.OpenAI\Microsoft.Extensions.AI.OpenAI.csproj (Microsoft.Extensions.AI.OpenAI)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#pragma warning disable SA1204 // Static elements should appear before instance elements
#pragma warning disable S103 // Lines should not be too long
#pragma warning disable CA1859 // Use concrete types when possible for improved performance
 
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using OpenAI.Chat;
 
namespace Microsoft.Extensions.AI;
 
internal static partial class OpenAIModelMappers
{
    /// <summary>
    /// Converts a stream of <see cref="StreamingChatCompletionUpdate"/> instances into OpenAI
    /// <see cref="OpenAI.Chat.StreamingChatCompletionUpdate"/> instances, yielding one OpenAI update per input update.
    /// </summary>
    /// <param name="chatCompletions">The updates to convert.</param>
    /// <param name="options">Serializer options used to resolve the type info for serializing function call arguments.</param>
    /// <param name="cancellationToken">Token observed while enumerating <paramref name="chatCompletions"/>.</param>
    /// <returns>The converted updates, in the same order as the input.</returns>
    public static async IAsyncEnumerable<OpenAI.Chat.StreamingChatCompletionUpdate> ToOpenAIStreamingChatCompletionAsync(
        IAsyncEnumerable<StreamingChatCompletionUpdate> chatCompletions,
        JsonSerializerOptions options,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Running tool call index for the whole stream. Consumers of OpenAI streaming updates (including
        // FromOpenAIStreamingChatCompletionAsync in this class) treat tool call updates that share an index as
        // fragments of the same call, so restarting at zero for every update would make function calls that
        // arrive in separate updates collide with each other.
        int nextToolCallIndex = 0;

        await foreach (var chatCompletionUpdate in chatCompletions.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            List<StreamingChatToolCallUpdate>? toolCallUpdates = null;
            ChatTokenUsage? chatTokenUsage = null;

            foreach (var content in chatCompletionUpdate.Contents)
            {
                if (content is FunctionCallContent functionCallContent)
                {
                    // Each function call content is emitted as a single tool call update carrying its
                    // arguments serialized to UTF-8 JSON.
                    toolCallUpdates ??= [];
                    toolCallUpdates.Add(OpenAIChatModelFactory.StreamingChatToolCallUpdate(
                        index: nextToolCallIndex++,
                        toolCallId: functionCallContent.CallId,
                        functionName: functionCallContent.Name,
                        functionArgumentsUpdate: new(JsonSerializer.SerializeToUtf8Bytes(functionCallContent.Arguments, options.GetTypeInfo(typeof(IDictionary<string, object?>))))));
                }
                else if (content is UsageContent usageContent)
                {
                    // If an update carries several usage items, the last one wins.
                    chatTokenUsage = ToOpenAIUsage(usageContent.Details);
                }
            }

            // OpenAI-specific values are looked up in AdditionalProperties under the name of the
            // corresponding OpenAI property.
            yield return OpenAIChatModelFactory.StreamingChatCompletionUpdate(
                completionId: chatCompletionUpdate.CompletionId,
                model: chatCompletionUpdate.ModelId,
                createdAt: chatCompletionUpdate.CreatedAt ?? default,
                role: ToOpenAIChatRole(chatCompletionUpdate.Role),
                finishReason: ToOpenAIFinishReason(chatCompletionUpdate.FinishReason),
                contentUpdate: [.. ToOpenAIChatContent(chatCompletionUpdate.Contents)],
                toolCallUpdates: toolCallUpdates,
                refusalUpdate: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault<string>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.RefusalUpdate)),
                contentTokenLogProbabilities: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault<IReadOnlyList<ChatTokenLogProbabilityDetails>>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.ContentTokenLogProbabilities)),
                refusalTokenLogProbabilities: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault<IReadOnlyList<ChatTokenLogProbabilityDetails>>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.RefusalTokenLogProbabilities)),
                systemFingerprint: chatCompletionUpdate.AdditionalProperties.GetValueOrDefault<string>(nameof(OpenAI.Chat.StreamingChatCompletionUpdate.SystemFingerprint)),
                usage: chatTokenUsage);
        }
    }
 
    /// <summary>
    /// Converts a stream of OpenAI <see cref="OpenAI.Chat.StreamingChatCompletionUpdate"/> instances into
    /// <see cref="StreamingChatCompletionUpdate"/> instances.
    /// </summary>
    /// <remarks>
    /// One update is yielded for every incoming OpenAI update. Tool call fragments are not surfaced as they arrive;
    /// they are buffered per tool call index and, if any were seen, emitted as function call contents on one
    /// additional update after the source stream has ended.
    /// </remarks>
    /// <param name="chatCompletionUpdates">The OpenAI updates to convert.</param>
    /// <param name="cancellationToken">Token observed while enumerating <paramref name="chatCompletionUpdates"/>.</param>
    /// <returns>The converted updates, optionally followed by a final update carrying the assembled function calls.</returns>
    public static async IAsyncEnumerable<StreamingChatCompletionUpdate> FromOpenAIStreamingChatCompletionAsync(
        IAsyncEnumerable<OpenAI.Chat.StreamingChatCompletionUpdate> chatCompletionUpdates,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // State carried across the whole stream. Each "first*" value latches the first non-null value observed.
        string? firstCompletionId = null;
        DateTimeOffset? firstCreatedAt = null;
        string? firstModelId = null;
        string? firstFingerprint = null;
        ChatRole? firstRole = null;
        ChatFinishReason? firstFinishReason = null;
        StringBuilder? refusalText = null;
        Dictionary<int, FunctionCallInfo>? pendingCalls = null;

        await foreach (OpenAI.Chat.StreamingChatCompletionUpdate source in chatCompletionUpdates.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            // The role and finish reason may show up on any update; once seen, the latched value is
            // reused for every update that follows.
            if (firstRole is null && source.Role is ChatMessageRole openAIRole)
            {
                firstRole = FromOpenAIChatRole(openAIRole);
            }

            if (firstFinishReason is null && source.FinishReason is OpenAI.Chat.ChatFinishReason openAIFinishReason)
            {
                firstFinishReason = FromOpenAIFinishReason(openAIFinishReason);
            }

            firstCompletionId ??= source.CompletionId;
            firstCreatedAt ??= source.CreatedAt;
            firstModelId ??= source.Model;
            firstFingerprint ??= source.SystemFingerprint;

            // Build the converted update. The id and timestamp come from this specific OpenAI update,
            // while model, role and finish reason come from the latched stream-level values.
            StreamingChatCompletionUpdate converted = new()
            {
                RawRepresentation = source,
                CompletionId = source.CompletionId,
                CreatedAt = source.CreatedAt,
                ModelId = firstModelId,
                Role = firstRole,
                FinishReason = firstFinishReason,
            };

            // Copy OpenAI-specific metadata into AdditionalProperties, keyed by the OpenAI property name.
            if (source.ContentTokenLogProbabilities is { Count: > 0 } contentLogProbs)
            {
                (converted.AdditionalProperties ??= [])[nameof(source.ContentTokenLogProbabilities)] = contentLogProbs;
            }

            if (source.RefusalTokenLogProbabilities is { Count: > 0 } refusalLogProbs)
            {
                (converted.AdditionalProperties ??= [])[nameof(source.RefusalTokenLogProbabilities)] = refusalLogProbs;
            }

            if (firstFingerprint is not null)
            {
                (converted.AdditionalProperties ??= [])[nameof(source.SystemFingerprint)] = firstFingerprint;
            }

            // Map each content part; parts for which ToAIContent yields nothing are skipped.
            if (source.ContentUpdate is { Count: > 0 } contentParts)
            {
                foreach (ChatMessageContentPart part in contentParts)
                {
                    if (ToAIContent(part) is AIContent mappedContent)
                    {
                        converted.Contents.Add(mappedContent);
                    }
                }
            }

            // Refusal text is only accumulated here; it is attached to the trailing update, if one is produced.
            if (source.RefusalUpdate is string refusalFragment)
            {
                refusalText ??= new();
                _ = refusalText.Append(refusalFragment);
            }

            // Buffer tool call fragments, grouped by their index.
            if (source.ToolCallUpdates is { Count: > 0 } toolCallFragments)
            {
                pendingCalls ??= [];

                foreach (StreamingChatToolCallUpdate fragment in toolCallFragments)
                {
                    if (!pendingCalls.TryGetValue(fragment.Index, out FunctionCallInfo? call))
                    {
                        call = new();
                        pendingCalls[fragment.Index] = call;
                    }

                    // The first non-null id and name win; argument fragments are concatenated in arrival order.
                    call.CallId ??= fragment.ToolCallId;
                    call.Name ??= fragment.FunctionName;
                    if (fragment.FunctionArgumentsUpdate is { } argumentBytes && !argumentBytes.ToMemory().IsEmpty)
                    {
                        _ = (call.Arguments ??= new()).Append(argumentBytes.ToString());
                    }
                }
            }

            // Surface token usage, when present, as a content item on this update.
            if (source.Usage is ChatTokenUsage openAIUsage)
            {
                converted.Contents.Add(new UsageContent(FromOpenAIUsage(openAIUsage)));
            }

            yield return converted;
        }

        // Without any buffered tool calls there is no trailing update to produce.
        if (pendingCalls is null)
        {
            yield break;
        }

        // All updates have been received: emit the buffered function calls together on one final update.
        StreamingChatCompletionUpdate trailing = new()
        {
            CompletionId = firstCompletionId,
            CreatedAt = firstCreatedAt,
            ModelId = firstModelId,
            Role = firstRole,
            FinishReason = firstFinishReason,
        };

        foreach (FunctionCallInfo buffered in pendingCalls.Values)
        {
            // Entries that never received a usable function name are dropped.
            if (string.IsNullOrWhiteSpace(buffered.Name))
            {
                continue;
            }

            var parsedCall = ParseCallContentFromJsonString(
                buffered.Arguments?.ToString() ?? string.Empty,
                buffered.CallId!,
                buffered.Name!);
            trailing.Contents.Add(parsedCall);
        }

        // Any refusal text gathered during the stream is attached to this function calling update.
        if (refusalText is not null)
        {
            (trailing.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusalText.ToString();
        }

        // Carry the system fingerprint over as well.
        if (firstFingerprint is not null)
        {
            (trailing.AdditionalProperties ??= [])[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)] = firstFingerprint;
        }

        yield return trailing;
    }
}