File: OpenAIAssistantClient.cs
Project: src\src\Libraries\Microsoft.Extensions.AI.OpenAI\Microsoft.Extensions.AI.OpenAI.csproj (Microsoft.Extensions.AI.OpenAI)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
using OpenAI;
using OpenAI.Assistants;
using OpenAI.Chat;
 
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
#pragma warning disable CA1031 // Do not catch general exception types
#pragma warning disable S1067 // Expressions should not be too complex
#pragma warning disable S1751 // Loops with at most one iteration should be refactored
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable SA1204 // Static elements should appear before instance elements
#pragma warning disable SA1108 // Block statements should not contain embedded comments
 
namespace Microsoft.Extensions.AI;
 
/// <summary>Represents an <see cref="IChatClient"/> for an OpenAI <see cref="AssistantClient"/>.</summary>
internal sealed class OpenAIAssistantClient : IChatClient
{
    /// <summary>Metadata for the client.</summary>
    private readonly ChatClientMetadata _metadata;
 
    /// <summary>The underlying <see cref="AssistantClient" />.</summary>
    private readonly AssistantClient _assistantClient;
 
    /// <summary>The ID of the assistant to use.</summary>
    private readonly string _assistantId;
 
    /// <summary>The thread ID to use if none is supplied in <see cref="ChatOptions.ChatThreadId"/>.</summary>
    private readonly string? _threadId;
 
    /// <summary>Initializes a new instance of the <see cref="OpenAIAssistantClient"/> class for the specified <see cref="AssistantClient"/>.</summary>
    /// <param name="assistantClient">The underlying client.</param>
    /// <param name="assistantId">The ID of the assistant to use.</param>
    /// <param name="threadId">
    /// The ID of the thread to use. If not supplied here, it should be supplied per request in <see cref="ChatOptions.ChatThreadId"/>.
    /// If neither is supplied, a new thread will be created for the request.
    /// </param>
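    /// <example>
    /// Illustrative usage from within this assembly (the assistant ID and message text are placeholders):
    /// <code>
    /// IChatClient client = new OpenAIAssistantClient(assistantClient, "asst_abc123", threadId: null);
    /// ChatCompletion completion = await client.CompleteAsync([new ChatMessage(ChatRole.User, "Hello!")]);
    /// </code>
    /// </example>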
    public OpenAIAssistantClient(AssistantClient assistantClient, string assistantId, string? threadId)
    {
        _assistantClient = Throw.IfNull(assistantClient);
        _assistantId = Throw.IfNull(assistantId);
        _threadId = threadId;
 
        _metadata = new("openai");
    }
 
    /// <inheritdoc />
    public object? GetService(Type serviceType, object? serviceKey = null)
    {
        _ = Throw.IfNull(serviceType);
 
        return
            serviceKey is not null ? null :
            serviceType == typeof(ChatClientMetadata) ? _metadata :
            serviceType == typeof(AssistantClient) ? _assistantClient :
            serviceType.IsInstanceOfType(this) ? this :
            null;
    }
 
    /// <inheritdoc />
    public Task<ChatCompletion> CompleteAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
        CompleteStreamingAsync(chatMessages, options, cancellationToken).ToChatCompletionAsync(coalesceContent: true, cancellationToken);
 
    /// <inheritdoc />
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Extract necessary state from chatMessages and options.
        (RunCreationOptions runOptions, List<FunctionResultContent>? toolResults) = CreateRunOptions(chatMessages, options);
 
        // Get the thread ID.
        string? threadId = options?.ChatThreadId ?? _threadId;
        if (threadId is null && toolResults is not null)
        {
            Throw.ArgumentException(nameof(chatMessages), "No thread ID was provided, but the chat messages include tool results.");
        }
 
        // Get the updates to process from the assistant. If we have any tool results, submit those to the existing
        // run and ignore runOptions. Otherwise, create a run (and, if we don't already have one, a thread).
        IAsyncEnumerable<StreamingUpdate> updates;
        if (GetRunId(toolResults, out List<ToolOutput>? toolOutputs) is string existingRunId)
        {
            updates = _assistantClient.SubmitToolOutputsToRunStreamingAsync(threadId, existingRunId, toolOutputs, cancellationToken);
        }
        else if (threadId is null)
        {
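            // No thread exists yet. Move the gathered messages into the new thread's initial messages
            // so that the thread and the run can be created in a single request.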
            ThreadCreationOptions creationOptions = new();
            foreach (var message in runOptions.AdditionalMessages)
            {
                creationOptions.InitialMessages.Add(message);
            }
 
            runOptions.AdditionalMessages.Clear();
 
            updates = _assistantClient.CreateThreadAndRunStreamingAsync(_assistantId, creationOptions, runOptions, cancellationToken: cancellationToken);
        }
        else
        {
            updates = _assistantClient.CreateRunStreamingAsync(threadId, _assistantId, runOptions, cancellationToken);
        }
 
        // Process each update.
        await foreach (var update in updates.ConfigureAwait(false))
        {
            switch (update)
            {
                case MessageContentUpdate mcu:
                    yield return new()
                    {
                        ChatThreadId = threadId,
                        RawRepresentation = mcu,
                        Role = mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant,
                        Text = mcu.Text,
                    };
                    break;
 
                case ThreadUpdate tu when options is not null:
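                    // Capture the ID of the newly created thread so it flows back to the caller via ChatThreadId.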
                    threadId ??= tu.Value.Id;
                    break;
 
                case RunUpdate ru:
                    threadId ??= ru.Value.ThreadId;
 
                    StreamingChatCompletionUpdate ruUpdate = new()
                    {
                        AuthorName = ru.Value.AssistantId,
                        ChatThreadId = threadId,
                        CompletionId = ru.Value.Id,
                        CreatedAt = ru.Value.CreatedAt,
                        ModelId = ru.Value.Model,
                        RawRepresentation = ru,
                        Role = ChatRole.Assistant,
                    };
 
                    if (ru.Value.Usage is { } usage)
                    {
                        ruUpdate.Contents.Add(new UsageContent(new()
                        {
                            InputTokenCount = usage.InputTokenCount,
                            OutputTokenCount = usage.OutputTokenCount,
                            TotalTokenCount = usage.TotalTokenCount,
                        }));
                    }
 
                    if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName)
                    {
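                        // Pack both the run ID and the tool call ID into the FunctionCallContent's CallId so that
                        // GetRunId can later determine which run the function result belongs to.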
                        ruUpdate.Contents.Add(
                            new FunctionCallContent(
                                JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!),
                                functionName,
                                JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!));
                    }
 
                    yield return ruUpdate;
                    break;
            }
        }
    }
 
    /// <inheritdoc />
    void IDisposable.Dispose()
    {
        // Nothing to dispose. Implementation required for the IChatClient interface.
    }
 
    /// <summary>Converts the provided chat messages and options into the <see cref="RunCreationOptions"/> to use for the request, separating out any function results to be submitted as tool outputs.</summary>
    private static (RunCreationOptions RunOptions, List<FunctionResultContent>? ToolResults) CreateRunOptions(IList<ChatMessage> chatMessages, ChatOptions? options)
    {
        _ = Throw.IfNull(chatMessages);
 
        RunCreationOptions runOptions = new();
 
        // Handle ChatOptions.
        if (options is not null)
        {
            // Propagate the simple properties that have a 1:1 correspondence.
            runOptions.MaxOutputTokenCount = options.MaxOutputTokens;
            runOptions.ModelOverride = options.ModelId;
            runOptions.NucleusSamplingFactor = options.TopP;
            runOptions.Temperature = options.Temperature;
 
            // Propagate additional properties from AdditionalProperties.
            if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.AllowParallelToolCalls), out bool allowParallelToolCalls) is true)
            {
                runOptions.AllowParallelToolCalls = allowParallelToolCalls;
            }
 
            if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.MaxInputTokenCount), out int maxInputTokenCount) is true)
            {
                runOptions.MaxInputTokenCount = maxInputTokenCount;
            }
 
            if (options.AdditionalProperties?.TryGetValue(nameof(RunCreationOptions.TruncationStrategy), out RunTruncationStrategy? truncationStrategy) is true)
            {
                runOptions.TruncationStrategy = truncationStrategy;
            }
 
            // Store all the tools to use.
            if (options.Tools is { Count: > 0 } tools)
            {
                foreach (AITool tool in tools)
                {
                    if (tool is AIFunction aiFunction)
                    {
                        bool? strict =
                            aiFunction.Metadata.AdditionalProperties.TryGetValue("Strict", out object? strictObj) &&
                            strictObj is bool strictValue ?
                            strictValue : null;
 
                        // Reshape the function's JSON schema into the OpenAIChatToolJson form used for the function definition's parameters.
                        var functionParameters = BinaryData.FromBytes(
                            JsonSerializer.SerializeToUtf8Bytes(
                                JsonSerializer.Deserialize(aiFunction.Metadata.Schema, OpenAIJsonContext.Default.OpenAIChatToolJson)!,
                                OpenAIJsonContext.Default.OpenAIChatToolJson));
 
                        runOptions.ToolsOverride.Add(ToolDefinition.CreateFunction(aiFunction.Metadata.Name, aiFunction.Metadata.Description, functionParameters, strict));
                    }
                }
            }
 
            // Store the tool mode.
            switch (options.ToolMode)
            {
                case AutoChatToolMode:
                    runOptions.ToolConstraint = ToolConstraint.Auto;
                    break;
 
                case RequiredChatToolMode required:
                    runOptions.ToolConstraint = required.RequiredFunctionName is null ?
                        ToolConstraint.Required :
                        new ToolConstraint(ToolDefinition.CreateFunction(required.RequiredFunctionName));
                    break;
            }
 
            // Store the response format.
            if (options.ResponseFormat is ChatResponseFormatText)
            {
                runOptions.ResponseFormat = AssistantResponseFormat.Text;
            }
            else if (options.ResponseFormat is ChatResponseFormatJson jsonFormat)
            {
                runOptions.ResponseFormat = jsonFormat.Schema is { } jsonSchema ?
                    AssistantResponseFormat.CreateJsonSchemaFormat(
                        jsonFormat.SchemaName ?? "json_schema",
                        BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(jsonSchema, OpenAIJsonContext.Default.JsonElement)),
                        jsonFormat.SchemaDescription) :
                    AssistantResponseFormat.JsonObject;
            }
        }
 
        // Handle ChatMessages. System and developer messages are turned into additional instructions.
        StringBuilder? instructions = null;
        List<FunctionResultContent>? functionResults = null;
        foreach (var chatMessage in chatMessages)
        {
            List<MessageContent> messageContents = [];
 
            if (chatMessage.Role == ChatRole.System ||
                chatMessage.Role == OpenAIModelMappers.ChatRoleDeveloper)
            {
                instructions ??= new();
                foreach (var textContent in chatMessage.Contents.OfType<TextContent>())
                {
                    _ = instructions.Append(textContent);
                }
 
                continue;
            }
 
            foreach (AIContent content in chatMessage.Contents)
            {
                switch (content)
                {
                    case TextContent tc:
                        messageContents.Add(MessageContent.FromText(tc.Text));
                        break;
 
                    case DataContent dc when dc.MediaTypeStartsWith("image/"):
                        messageContents.Add(MessageContent.FromImageUri(new(dc.Uri)));
                        break;
 
                    case FunctionResultContent frc:
                        (functionResults ??= []).Add(frc);
                        break;
                }
            }
 
            if (messageContents.Count > 0)
            {
                runOptions.AdditionalMessages.Add(new(
                    chatMessage.Role == ChatRole.Assistant ? MessageRole.Assistant : MessageRole.User,
                    messageContents));
            }
        }
 
        if (instructions is not null)
        {
            runOptions.AdditionalInstructions = instructions.ToString();
        }
 
        return (runOptions, functionResults);
    }
 
    /// <summary>Extracts the run ID encoded in the function results' call IDs and gathers the corresponding <see cref="ToolOutput"/>s to submit.</summary>
    private static string? GetRunId(List<FunctionResultContent>? toolResults, out List<ToolOutput>? toolOutputs)
    {
        string? runId = null;
        toolOutputs = null;
        if (toolResults?.Count > 0)
        {
            foreach (var frc in toolResults)
            {
                // When creating the FunctionCallContent, we created it with a CallId == [runId, callId].
                // We need to extract the run ID and ensure that the ToolOutput we send back to OpenAI
                // is only the call ID.
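                // For example, a CallId might look like "[\"run_abc123\",\"call_def456\"]" (illustrative IDs).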
                string[]? runAndCallIDs;
                try
                {
                    runAndCallIDs = JsonSerializer.Deserialize(frc.CallId, OpenAIJsonContext.Default.StringArray);
                }
                catch
                {
                    continue;
                }
 
                if (runAndCallIDs is null ||
                    runAndCallIDs.Length != 2 ||
                    string.IsNullOrWhiteSpace(runAndCallIDs[0]) || // run ID
                    string.IsNullOrWhiteSpace(runAndCallIDs[1]) || // call ID
                    (runId is not null && runId != runAndCallIDs[0]))
                {
                    continue;
                }
 
                runId = runAndCallIDs[0];
                (toolOutputs ??= []).Add(new(runAndCallIDs[1], frc.Result?.ToString() ?? string.Empty));
            }
        }
 
        return runId;
    }
}