// File: PromptBasedFunctionCallingChatClient.cs
// Web Access
// Project: src\test\Libraries\Microsoft.Extensions.AI.Integration.Tests\Microsoft.Extensions.AI.Integration.Tests.csproj (Microsoft.Extensions.AI.Integration.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
 
#pragma warning disable SA1402 // File may only contain a single type
#pragma warning disable S1144 // Unused private types or members should be removed
#pragma warning disable S3459 // Unassigned members should be removed
 
namespace Microsoft.Extensions.AI;
 
// This isn't a feature we're planning to ship, but it demonstrates how custom clients can
// layer in non-trivial functionality. In this case we're able to upgrade non-function-calling models
// so that they behave as if they support function calling.
//
// In practice:
//  - For llama3:8b or mistral:7b, this works fairly reliably, at least when it only needs to
//    make a single function call with a constrained set of args.
//  - For smaller models like phi3:mini, it works only on a more occasional basis (e.g., if there's
//    only one function defined, and it takes no arguments, but is very hit-and-miss beyond that).
 
/// <summary>
/// A delegating chat client that emulates function calling through prompting: the requested tools are
/// described in a prompt message, and tool calls are parsed back out of the text returned by the inner client.
/// </summary>
/// <param name="innerClient">The client that requests are forwarded to.</param>
internal sealed class PromptBasedFunctionCallingChatClient(IChatClient innerClient)
    : DelegatingChatClient(innerClient)
{
    private const string MessageIntro = "You are an AI model with function calling capabilities. Call one or more functions if they are relevant to the user's query.";

    /// <summary>Opening tag that precedes each tool call payload in the model's text output.</summary>
    private const string ToolCallStartTag = "<tool_call_json>";

    /// <summary>Closing tag of the tool call list; registered as a stop sequence.</summary>
    private const string ToolCallsEndTag = "</tool_calls>";

    private static readonly JsonSerializerOptions _jsonOptions = new(JsonSerializerDefaults.Web)
    {
        WriteIndented = true,
        DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
    };

    /// <summary>
    /// Sends the chat to the inner client with tools expressed as a prompt rather than as native tool
    /// definitions, then converts any tool calls found in the first choice's text into
    /// <see cref="FunctionCallContent"/> items.
    /// </summary>
    /// <param name="chatMessages">
    /// The conversation. NOTE: when tools are supplied this list and its messages are modified in place
    /// (a tool prompt is inserted or refreshed, and function call/result content is replaced with text).
    /// </param>
    /// <param name="options">The request options. The caller's instance is not modified; a clone is used when tools are present.</param>
    /// <param name="cancellationToken">Token forwarded to the inner client.</param>
    /// <returns>The inner client's completion, with parsed tool calls added to the first choice.</returns>
    public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        // Our goal is to convert tools into a prompt describing them, then to detect tool calls in the
        // response and convert those into FunctionCallContent.
        if (options?.Tools is { Count: > 0 } tools)
        {
            AddOrUpdateToolPrompt(chatMessages, tools);

            // Clone before editing so the caller's options object keeps its tools.
            options = options.Clone();
            options.Tools = null;

            options.StopSequences ??= [];
            if (!options.StopSequences.Contains(ToolCallsEndTag))
            {
                options.StopSequences.Add(ToolCallsEndTag);
            }

            ConvertToolContentToText(chatMessages);
        }

        var result = await base.CompleteAsync(chatMessages, options, cancellationToken);

        if (result.Choices.FirstOrDefault() is { Text: { } content } message)
        {
            var startPos = content.IndexOf(ToolCallStartTag, StringComparison.Ordinal);
            if (startPos >= 0)
            {
                AddParsedToolCalls(message, content.Substring(startPos));
            }
        }

        return result;
    }

    /// <summary>
    /// Replaces function call and function result content with plain text equivalents.
    /// </summary>
    /// <remarks>
    /// Since the point of this client is to avoid relying on the underlying model having native tool call
    /// support, "tool" and "toolcall" messages are rewritten as "user" and "assistant" ones.
    /// </remarks>
    private static void ConvertToolContentToText(IList<ChatMessage> chatMessages)
    {
        foreach (var message in chatMessages)
        {
            for (var itemIndex = 0; itemIndex < message.Contents.Count; itemIndex++)
            {
                switch (message.Contents[itemIndex])
                {
                    case FunctionResultContent frc:
                    {
                        var toolCallResultJson = JsonSerializer.Serialize(new ToolCallResult { Id = frc.CallId, Result = frc.Result }, _jsonOptions);
                        message.Role = ChatRole.User;
                        message.Contents[itemIndex] = new TextContent(
                            $"<tool_call_result>{toolCallResultJson}</tool_call_result>");
                        break;
                    }

                    case FunctionCallContent fcc:
                    {
                        var toolCallJson = JsonSerializer.Serialize(new { fcc.CallId, fcc.Name, fcc.Arguments }, _jsonOptions);
                        message.Role = ChatRole.Assistant;
                        message.Contents[itemIndex] = new TextContent(
                            $"<tool_call_json>{toolCallJson}</tool_call_json>");
                        break;
                    }
                }
            }
        }
    }

    /// <summary>
    /// Parses every tool call payload in <paramref name="content"/> and appends it to
    /// <paramref name="message"/> as a <see cref="FunctionCallContent"/>. If at least one call was
    /// parsed, the message's text item is removed.
    /// </summary>
    /// <param name="message">The response message to update.</param>
    /// <param name="content">The response text, starting at the first tool call opening tag.</param>
    private static void AddParsedToolCalls(ChatMessage message, string content)
    {
        // Look the text item up by type instead of assuming the message holds exactly one content item;
        // SingleOrDefault would throw if the response carried any additional items.
        // NOTE(review): presumably this is the item the message's Text came from - confirm against ChatMessage.Text.
        var textItem = message.Contents.OfType<TextContent>().FirstOrDefault();
        var anyToolCallAdded = false;

        foreach (var segment in content.Split([ToolCallStartTag], StringSplitOptions.None))
        {
            // The payload normally ends at "</tool_call_json>" (or "</tool_calls>"). When no closing tag is
            // present - for example because generation stopped at the stop sequence - treat the rest of the
            // segment as the payload. Anything that isn't valid JSON is rejected below.
            var endPos = segment.IndexOf("</tool", StringComparison.Ordinal);
            var toolCallJson = (endPos >= 0 ? segment.Substring(0, endPos) : segment).Trim();
            if (toolCallJson.Length == 0)
            {
                continue;
            }

            ToolCall? toolCall;
            try
            {
                toolCall = JsonSerializer.Deserialize<ToolCall>(toolCallJson, _jsonOptions);
            }
            catch (JsonException)
            {
                // Ignore invalid tool calls
                continue;
            }

            if (toolCall is not { Name: { Length: > 0 } name })
            {
                continue;
            }

            ReadOnlyDictionary<string, object?>? arguments = null;
            if (toolCall.Arguments is { } parsedArguments)
            {
                ParseArguments(parsedArguments);
                arguments = new ReadOnlyDictionary<string, object?>(parsedArguments);
            }

            // A fresh short ID is generated for each call.
            var id = Guid.NewGuid().ToString().Substring(0, 6);
            message.Contents.Add(new FunctionCallContent(id, name, arguments));
            anyToolCallAdded = true;
        }

        // The raw text only carried the tool call markup, so drop it once (not per call).
        if (anyToolCallAdded && textItem is not null)
        {
            message.Contents.Remove(textItem);
        }
    }

    /// <summary>
    /// Replaces <see cref="JsonElement"/> argument values with plain .NET values, in place.
    /// </summary>
    /// <param name="arguments">The deserialized arguments to normalize.</param>
    private static void ParseArguments(IDictionary<string, object?> arguments)
    {
        // This is a simple implementation. A more robust answer is to use other schema information given by
        // the AIFunction here, as for example is done in OpenAIChatClient.
        // ToArray takes a snapshot so entries can be reassigned while iterating.
        foreach (var kvp in arguments.ToArray())
        {
            if (kvp.Value is JsonElement jsonElement)
            {
                arguments[kvp.Key] = jsonElement.ValueKind switch
                {
                    JsonValueKind.String => jsonElement.GetString(),
                    JsonValueKind.Number => jsonElement.GetDouble(),
                    JsonValueKind.True => true,
                    JsonValueKind.False => false,
                    _ => jsonElement.ToString() // objects, arrays, etc. are kept as their text form
                };
            }
        }
    }

    /// <summary>
    /// Ensures <paramref name="chatMessages"/> contains a message describing <paramref name="tools"/>.
    /// An existing message whose text starts with <see cref="MessageIntro"/> is overwritten; otherwise a
    /// new system message is inserted at the start of the list.
    /// </summary>
    /// <param name="chatMessages">The conversation to update in place.</param>
    /// <param name="tools">The tools to describe; only <see cref="AIFunction"/> instances are included.</param>
    private static void AddOrUpdateToolPrompt(IList<ChatMessage> chatMessages, IList<AITool> tools)
    {
        var existingToolPrompt = chatMessages.FirstOrDefault(c => c.Text?.StartsWith(MessageIntro, StringComparison.Ordinal) is true);
        if (existingToolPrompt is null)
        {
            existingToolPrompt = new ChatMessage(ChatRole.System, (string?)null);
            chatMessages.Insert(0, existingToolPrompt);
        }

        var toolDescriptorsJson = JsonSerializer.Serialize(tools.OfType<AIFunction>().Select(ToToolDescriptor), _jsonOptions);
        existingToolPrompt.Text = $$"""
            {{MessageIntro}}

            For each function call, return a JSON object with the function name and arguments within <tool_call_json></tool_call_json> XML tags
            as follows:
            <tool_calls>
              <tool_call_json>{"name": "tool_name", "arguments": { argname1: argval1, argname2: argval2, ... } }</tool_call_json>
            </tool_calls>
            Note that the contents of <tool_call_json></tool_call_json> MUST be a valid JSON object, with no other text.

            Once you receive the result as a JSON object within <tool_call_result></tool_call_result> XML tags, use it to
            answer the user's question without repeating the same tool call.

            Here are the available tools:
            <tools>{{toolDescriptorsJson}}</tools>
            """;
    }

    /// <summary>Builds the serializable description of <paramref name="tool"/> that is embedded in the prompt.</summary>
    /// <param name="tool">The function to describe.</param>
    /// <returns>A descriptor holding the function's name, description and parameters keyed by name.</returns>
    private static ToolDescriptor ToToolDescriptor(AIFunction tool) => new()
    {
        Name = tool.Metadata.Name,
        Description = tool.Metadata.Description,
        Arguments = tool.Metadata.Parameters.ToDictionary(
            p => p.Name,
            p =>
            {
                // Describe Nullable<T> parameters by their underlying type, so the prompt shows e.g.
                // "Int32" rather than "Nullable`1" and nullable enums still list their names.
                var parameterType = p.ParameterType;
                if (parameterType is not null)
                {
                    parameterType = Nullable.GetUnderlyingType(parameterType) ?? parameterType;
                }

                return new ToolParameterDescriptor
                {
                    Type = parameterType?.Name,
                    Description = p.Description,
                    Enum = parameterType is { IsEnum: true } ? Enum.GetNames(parameterType) : null,
                    Required = p.IsRequired,
                };
            }),
    };

    /// <summary>JSON shape used to describe one available tool in the prompt.</summary>
    private sealed class ToolDescriptor
    {
        public string? Name { get; set; }
        public string? Description { get; set; }
        public IDictionary<string, ToolParameterDescriptor>? Arguments { get; set; }
    }

    /// <summary>JSON shape used to describe one parameter of a tool in the prompt.</summary>
    private sealed class ToolParameterDescriptor
    {
        public string? Type { get; set; }
        public string? Description { get; set; }
        public bool? Required { get; set; }
        public string[]? Enum { get; set; }
    }

    /// <summary>JSON shape that tool call payloads in the model's output are deserialized into.</summary>
    private sealed class ToolCall
    {
        public string? Id { get; set; }
        public string? Name { get; set; }
        public IDictionary<string, object?>? Arguments { get; set; }
    }

    /// <summary>JSON shape used to report a function result back to the model as text.</summary>
    private sealed class ToolCallResult
    {
        public string? Id { get; set; }
        public object? Result { get; set; }
    }
}
 
/// <summary>
/// Builder extensions for adding <see cref="PromptBasedFunctionCallingChatClient"/> to a chat client pipeline.
/// </summary>
public static class PromptBasedFunctionCallingChatClientExtensions
{
    /// <summary>
    /// Registers a factory that wraps the inner client in a <see cref="PromptBasedFunctionCallingChatClient"/>.
    /// </summary>
    /// <param name="builder">The builder to register the factory with.</param>
    /// <returns>The value returned by the builder's <c>Use</c> call.</returns>
    public static ChatClientBuilder UsePromptBasedFunctionCalling(this ChatClientBuilder builder)
    {
        return builder.Use(inner => new PromptBasedFunctionCallingChatClient(inner));
    }
}