// File: OpenAIRealtimeExtensions.cs
// Web Access
// Project: src\src\Libraries\Microsoft.Extensions.AI.OpenAI\Microsoft.Extensions.AI.OpenAI.csproj (Microsoft.Extensions.AI.OpenAI)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
using OpenAI.RealtimeConversation;
 
namespace Microsoft.Extensions.AI;
 
/// <summary>
/// Provides extension methods for working with <see cref="RealtimeConversationSession"/> and related types.
/// </summary>
public static class OpenAIRealtimeExtensions
{
    /// <summary>
    /// Empty JSON object (<c>{}</c>) used as the schema for any parameter whose metadata
    /// does not carry a <see cref="JsonElement"/> schema.
    /// </summary>
    private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;

    /// <summary>
    /// Converts a <see cref="AIFunction"/> into a <see cref="ConversationFunctionTool"/> so that
    /// it can be used with <see cref="RealtimeConversationClient"/>.
    /// </summary>
    /// <param name="aiFunction">
    /// The function to expose as a tool. Its metadata supplies the tool's name, description, and parameters.
    /// </param>
    /// <returns>A <see cref="ConversationFunctionTool"/> that can be used with <see cref="RealtimeConversationClient"/>.</returns>
    public static ConversationFunctionTool ToConversationFunctionTool(this AIFunction aiFunction)
    {
        _ = Throw.IfNull(aiFunction);

        // Describe the parameters as an "object" schema: one property per parameter,
        // plus the names of the parameters that are marked as required.
        var parametersSchema = new ConversationFunctionToolParametersSchema
        {
            Type = "object",
            Properties = aiFunction.Metadata.Parameters
                .ToDictionary(p => p.Name, GetParameterSchema),
            Required = aiFunction.Metadata.Parameters
                .Where(p => p.IsRequired)
                .Select(p => p.Name),
        };

        return new ConversationFunctionTool
        {
            Name = aiFunction.Metadata.Name,
            Description = aiFunction.Metadata.Description,

            // The tool expects its parameter schema as serialized JSON bytes.
            Parameters = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(
                parametersSchema, OpenAIJsonContext.Default.ConversationFunctionToolParametersSchema))
        };
    }

    /// <summary>
    /// Handles tool calls.
    /// <para>
    /// If the <paramref name="update"/> represents a tool call, calls the corresponding tool and
    /// adds the result to the <paramref name="session"/>.
    /// </para>
    /// <para>
    /// If the <paramref name="update"/> represents the end of a response, checks if this was due
    /// to a tool call and if so, instructs the <paramref name="session"/> to begin responding to it.
    /// </para>
    /// </summary>
    /// <param name="session">The <see cref="RealtimeConversationSession"/>.</param>
    /// <param name="update">The <see cref="ConversationUpdate"/> being processed.</param>
    /// <param name="tools">The available tools.</param>
    /// <param name="detailedErrors">An optional flag specifying whether to disclose detailed exception information to the model. The default value is <see langword="false"/>.</param>
    /// <param name="jsonSerializerOptions">An optional <see cref="JsonSerializerOptions"/> that controls JSON handling.</param>
    /// <param name="cancellationToken">An optional <see cref="CancellationToken"/>.</param>
    /// <returns>A <see cref="Task"/> that represents the completion of processing, including invoking any asynchronous tools.</returns>
    public static async Task HandleToolCallsAsync(
        this RealtimeConversationSession session,
        ConversationUpdate update,
        IReadOnlyList<AIFunction> tools,
        bool? detailedErrors = false,
        JsonSerializerOptions? jsonSerializerOptions = null,
        CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(session);
        _ = Throw.IfNull(update);
        _ = Throw.IfNull(tools);

        if (update is ConversationItemStreamingFinishedUpdate itemFinished)
        {
            // If we need to call a tool to update the model, do so.
            // A null output means no matching tool was found, in which case nothing is added.
            if (!string.IsNullOrEmpty(itemFinished.FunctionName)
                && await itemFinished.GetFunctionCallOutputAsync(tools, detailedErrors, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) is { } output)
            {
                await session.AddItemAsync(output, cancellationToken).ConfigureAwait(false);
            }
        }
        else if (update is ConversationResponseFinishedUpdate responseFinished)
        {
            // If the finished response created one or more function call items, instruct the model to respond to them.
            if (responseFinished.CreatedItems.Any(item => !string.IsNullOrEmpty(item.FunctionName)))
            {
                await session.StartResponseAsync(cancellationToken).ConfigureAwait(false);
            }
        }
    }

    /// <summary>
    /// Gets the schema for a single function parameter.
    /// </summary>
    /// <param name="parameterMetadata">The parameter metadata to inspect.</param>
    /// <returns>
    /// The parameter's own schema when it is a <see cref="JsonElement"/>; otherwise the shared empty-object schema.
    /// </returns>
    private static JsonElement GetParameterSchema(AIFunctionParameterMetadata parameterMetadata)
    {
        return parameterMetadata switch
        {
            { Schema: JsonElement jsonElement } => jsonElement,
            _ => _defaultParameterSchema,
        };
    }

    /// <summary>
    /// Invokes the tool named by a finished function-call item and packages the outcome as a
    /// function call output item.
    /// </summary>
    /// <param name="update">The finished streaming item that may describe a function call.</param>
    /// <param name="tools">The available tools; the first one whose name equals the requested function name is used.</param>
    /// <param name="detailedErrors">When <see langword="true"/>, the exception message is appended to the error text returned to the model.</param>
    /// <param name="jsonSerializerOptions">Options used to parse the arguments and serialize the result; <see cref="AIJsonUtilities.DefaultOptions"/> is used when <see langword="null"/>.</param>
    /// <param name="cancellationToken">Token passed to the tool invocation.</param>
    /// <returns>
    /// A function call output item holding the serialized result, or an error message when the
    /// invocation or serialization fails; <see langword="null"/> when the update names no function
    /// or no tool with that name is available.
    /// </returns>
    private static async Task<ConversationItem?> GetFunctionCallOutputAsync(
        this ConversationItemStreamingFinishedUpdate update,
        IReadOnlyList<AIFunction> tools,
        bool? detailedErrors = false,
        JsonSerializerOptions? jsonSerializerOptions = null,
        CancellationToken cancellationToken = default)
    {
        // Nothing to do unless the update names a function that one of the supplied tools implements.
        if (string.IsNullOrEmpty(update.FunctionName)
            || tools.FirstOrDefault(t => t.Metadata.Name == update.FunctionName) is not AIFunction aiFunction)
        {
            return null;
        }

        var jsonOptions = jsonSerializerOptions ?? AIJsonUtilities.DefaultOptions;

        // NOTE(review): the argument parser runs inside CreateFromParsedArguments, outside the try block
        // below. Confirm how parse failures surface (thrown vs. recorded on the returned content) and
        // whether they are meant to produce the "Invalid JSON" output.
        var functionCallContent = FunctionCallContent.CreateFromParsedArguments(
            update.FunctionCallArguments, update.FunctionCallId, update.FunctionName,
                argumentParser: json => JsonSerializer.Deserialize(json,
                (JsonTypeInfo<IDictionary<string, object>>)jsonOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);

        try
        {
            var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false);
            var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object)));
            return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson);
        }
        catch (JsonException)
        {
            // A JSON failure during invocation or result serialization is reported to the model as a fixed message.
            return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, "Invalid JSON");
        }
        catch (Exception e) when (!cancellationToken.IsCancellationRequested)
        {
            // Exceptions raised once cancellation has been requested are not caught here and propagate to the caller.
            // Exception details are only disclosed to the model when explicitly requested.
            var message = detailedErrors == true
                ? $"Error calling tool: {e.Message}"
                : "Error calling tool";

            return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, message);
        }
    }

    /// <summary>
    /// Serializable shape of the parameters schema attached to a <see cref="ConversationFunctionTool"/>.
    /// </summary>
    internal sealed class ConversationFunctionToolParametersSchema
    {
        /// <summary>Gets or sets the schema type; set to <c>"object"</c> by <see cref="ToConversationFunctionTool"/>.</summary>
        public string? Type { get; set; }

        /// <summary>Gets or sets the per-parameter schemas, keyed by parameter name.</summary>
        public IDictionary<string, JsonElement>? Properties { get; set; }

        /// <summary>Gets or sets the names of the required parameters.</summary>
        public IEnumerable<string>? Required { get; set; }
    }
}