File: ChatCompletion\FunctionInvokingChatClientTests.cs
Web Access
Project: src\test\Libraries\Microsoft.Extensions.AI.Tests\Microsoft.Extensions.AI.Tests.csproj (Microsoft.Extensions.AI.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using OpenTelemetry.Trace;
using Xunit;
 
#pragma warning disable SA1118 // Parameter should not span multiple lines
 
namespace Microsoft.Extensions.AI;
 
public class FunctionInvokingChatClientTests
{
    /// <summary>
    /// Verifies that null arguments are rejected by the constructor and by the builder extension.
    /// </summary>
    [Fact]
    public void InvalidArgs_Throws()
    {
        ChatClientBuilder nullBuilder = null!;

        Assert.Throws<ArgumentNullException>("innerClient", () => new FunctionInvokingChatClient(null!));
        Assert.Throws<ArgumentNullException>("builder", () => nullBuilder.UseFunctionInvocation());
    }
 
    /// <summary>
    /// Verifies the default values of the configurable properties on a freshly constructed client.
    /// </summary>
    [Fact]
    public void Ctor_HasExpectedDefaults()
    {
        using var inner = new TestChatClient();
        using var functionInvoker = new FunctionInvokingChatClient(inner);

        Assert.Equal(3, functionInvoker.MaximumConsecutiveErrorsPerRequest);
        Assert.Equal(10, functionInvoker.MaximumIterationsPerRequest);
        Assert.False(functionInvoker.IncludeDetailedErrors);
        Assert.False(functionInvoker.AllowConcurrentInvocation);
    }
 
    /// <summary>
    /// Verifies a conversation where every model turn requests exactly one function call, covering a
    /// parameterless function, a function with an argument, and a function that returns nothing.
    /// </summary>
    [Fact]
    public async Task SupportsSingleFunctionCallPerRequestAsync()
    {
        ChatOptions chatOptions = new()
        {
            Tools =
            [
                AIFunctionFactory.Create(() => "Result 1", "Func1"),
                AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
                AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
            ]
        };

        // Each assistant turn is followed by the tool message the client is expected to produce for it.
        List<ChatMessage> expectedConversation =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
            new(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
            new(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { ["i"] = 42 })]),
            new(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
            new(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { ["i"] = 43 })]),
            new(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
            new(ChatRole.Assistant, "world"),
        ];

        await InvokeAndAssertAsync(chatOptions, expectedConversation);
        await InvokeAndAssertStreamingAsync(chatOptions, expectedConversation);
    }
 
    /// <summary>
    /// Verifies that several function calls requested in a single response are all invoked and their results
    /// returned together in one tool message, both with and without concurrent invocation.
    /// </summary>
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentInvocation)
    {
        ChatOptions chatOptions = new()
        {
            Tools =
            [
                AIFunctionFactory.Create((int? i = 42) => "Result 1", "Func1"),
                AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
            ]
        };

        List<ChatMessage> plan =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant,
            [
                new FunctionCallContent("callId1", "Func1"),
                new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { ["i"] = 34 }),
                new FunctionCallContent("callId3", "Func2", arguments: new Dictionary<string, object?> { ["i"] = 56 }),
            ]),
            new(ChatRole.Tool,
            [
                new FunctionResultContent("callId1", result: "Result 1"),
                new FunctionResultContent("callId2", result: "Result 2: 34"),
                new FunctionResultContent("callId3", result: "Result 2: 56"),
            ]),
            new(ChatRole.Assistant,
            [
                new FunctionCallContent("callId4", "Func2", arguments: new Dictionary<string, object?> { ["i"] = 78 }),
                new FunctionCallContent("callId5", "Func1")
            ]),
            new(ChatRole.Tool,
            [
                new FunctionResultContent("callId4", result: "Result 2: 78"),
                new FunctionResultContent("callId5", result: "Result 1")
            ]),
            new(ChatRole.Assistant, "world"),
        ];

        Func<ChatClientBuilder, ChatClientBuilder> withConcurrencySetting = builder => builder.Use(
            inner => new FunctionInvokingChatClient(inner) { AllowConcurrentInvocation = concurrentInvocation });

        await InvokeAndAssertAsync(chatOptions, plan, configurePipeline: withConcurrencySetting);
        await InvokeAndAssertStreamingAsync(chatOptions, plan, configurePipeline: withConcurrencySetting);
    }
 
    /// <summary>
    /// Verifies that, with <see cref="FunctionInvokingChatClient.AllowConcurrentInvocation"/> enabled, the parallel
    /// function calls of one response run at the same time.
    /// </summary>
    /// <remarks>
    /// Each call blocks until both calls have started, so a sequential invocation would never complete.
    /// </remarks>
    [Fact]
    public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync()
    {
        // Countdown of calls that have not started yet, and the barrier released once all of them have.
        int remaining = 2;
        var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

        var options = new ChatOptions
        {
            Tools =
            [
                AIFunctionFactory.Create(async (string arg) =>
                {
                    if (Interlocked.Decrement(ref remaining) == 0)
                    {
                        tcs.SetResult(true);
                    }

                    await tcs.Task;

                    return arg + arg;
                }, "Func"),
            ]
        };

        List<ChatMessage> plan =
        [
            new ChatMessage(ChatRole.User, "hello"),
            new ChatMessage(ChatRole.Assistant,
            [
                new FunctionCallContent("callId1", "Func", arguments: new Dictionary<string, object?> { { "arg", "hello" } }),
                new FunctionCallContent("callId2", "Func", arguments: new Dictionary<string, object?> { { "arg", "world" } }),
            ]),
            new ChatMessage(ChatRole.Tool,
            [
                new FunctionResultContent("callId1", result: "hellohello"),
                new FunctionResultContent("callId2", result: "worldworld"),
            ]),
            new ChatMessage(ChatRole.Assistant, "done"),
        ];

        Func<ChatClientBuilder, ChatClientBuilder> configure = b => b.Use(
            s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = true });

        await InvokeAndAssertAsync(options, plan, configurePipeline: configure);

        // Re-arm the barrier. The first pass left the task completed and drove the countdown to zero.
        // Without this reset, the streaming pass would succeed even if the calls ran one after another.
        remaining = 2;
        tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

        await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
    }
 
    /// <summary>
    /// Verifies that, unless concurrency is explicitly enabled, parallel function calls requested in one
    /// response are invoked one at a time.
    /// </summary>
    [Fact]
    public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync()
    {
        int inFlight = 0;

        ChatOptions chatOptions = new()
        {
            Tools =
            [
                AIFunctionFactory.Create(async (string arg) =>
                {
                    // If a second invocation started while this one is delayed, the counter would read 2 below.
                    Interlocked.Increment(ref inFlight);
                    await Task.Delay(100);
                    Assert.Equal(1, inFlight);
                    Interlocked.Decrement(ref inFlight);
                    return arg + arg;
                }, "Func"),
            ]
        };

        List<ChatMessage> plan =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant,
            [
                new FunctionCallContent("callId1", "Func", arguments: new Dictionary<string, object?> { ["arg"] = "hello" }),
                new FunctionCallContent("callId2", "Func", arguments: new Dictionary<string, object?> { ["arg"] = "world" }),
            ]),
            new(ChatRole.Tool,
            [
                new FunctionResultContent("callId1", result: "hellohello"),
                new FunctionResultContent("callId2", result: "worldworld"),
            ]),
            new(ChatRole.Assistant, "done"),
        ];

        await InvokeAndAssertAsync(chatOptions, plan);
        await InvokeAndAssertStreamingAsync(chatOptions, plan);
    }
 
    /// <summary>
    /// Verifies that a model which keeps requesting successful function calls is cut off once
    /// <see cref="FunctionInvokingChatClient.MaximumIterationsPerRequest"/> iterations have run.
    /// </summary>
    [Fact]
    public async Task ContinuesWithSuccessfulCallsUntilMaximumIterations()
    {
        int maxIterations = 7;
        Func<ChatClientBuilder, ChatClientBuilder> limitIterations = builder => builder
            .UseFunctionInvocation(configure: client => client.MaximumIterationsPerRequest = maxIterations);

        int invocationCount = 0;
        ChatOptions chatOptions = new()
        {
            Tools =
            [
                AIFunctionFactory.Create(() => { invocationCount++; }, "VoidReturn"),
            ]
        };

        List<ChatMessage> expectedConversation =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant, [new FunctionCallContent("callId0", "VoidReturn")]),
        ];

        // The conversation deliberately ends on an unresolved function call. The client would normally
        // invoke it, but stops because the iteration limit has been reached.
        for (int iteration = 0; iteration < maxIterations; iteration++)
        {
            expectedConversation.Add(new(ChatRole.Tool, [new FunctionResultContent($"callId{iteration}", result: "Success: Function completed.")]));
            expectedConversation.Add(new(ChatRole.Assistant, [new FunctionCallContent($"callId{iteration + 1}", "VoidReturn")]));
        }

        await InvokeAndAssertAsync(chatOptions, expectedConversation, configurePipeline: limitIterations);
        Assert.Equal(maxIterations, invocationCount);

        invocationCount = 0;
        await InvokeAndAssertStreamingAsync(chatOptions, expectedConversation, configurePipeline: limitIterations);
        Assert.Equal(maxIterations, invocationCount);
    }
 
    /// <summary>
    /// Verifies that failing function calls are reported back to the model until the number of consecutive
    /// failed iterations exceeds <see cref="FunctionInvokingChatClient.MaximumConsecutiveErrorsPerRequest"/>.
    /// At that point the failure propagates to the caller.
    /// </summary>
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public async Task ContinuesWithFailingCallsUntilMaximumConsecutiveErrors(bool allowConcurrentInvocation)
    {
        Func<ChatClientBuilder, ChatClientBuilder> withErrorLimit = pipeline => pipeline
            .UseFunctionInvocation(configure: client =>
            {
                client.MaximumConsecutiveErrorsPerRequest = 2;
                client.AllowConcurrentInvocation = allowConcurrentInvocation;
            });

        ChatOptions chatOptions = new()
        {
            Tools =
            [
                AIFunctionFactory.Create((bool shouldThrow, int callIndex) =>
                {
                    if (shouldThrow)
                    {
                        throw new InvalidTimeZoneException($"Exception from call {callIndex}");
                    }
                }, "Func"),
            ]
        };

        int nextCallIndex = 0;
        List<ChatMessage> plan =
        [
            new(ChatRole.User, "hello"),

            // One failed iteration is not enough to stop the loop
            ..CreateFunctionCallIterationPlan(ref nextCallIndex, true, false),

            // Consecutive failed iterations: 1
            // A fully successful iteration resets the consecutive-error count
            ..CreateFunctionCallIterationPlan(ref nextCallIndex, false, false, false),

            // Consecutive failed iterations: 0
            // A single failing call marks the whole iteration as failed
            ..CreateFunctionCallIterationPlan(ref nextCallIndex, false, true, false),

            // Consecutive failed iterations: 1
            // Several failures in one iteration still count as only one failed iteration, so the limit holds
            ..CreateFunctionCallIterationPlan(ref nextCallIndex, true, true, true),

            // Consecutive failed iterations: 2
            // Any further failure now exceeds the limit
            ..CreateFunctionCallIterationPlan(ref nextCallIndex, true, true),
        ];

        if (allowConcurrentInvocation)
        {
            // Concurrent invocation always runs every call in the iteration and combines
            // their exceptions into one AggregateException
            AssertBothFinalFailures(await Assert.ThrowsAsync<AggregateException>(() =>
                InvokeAndAssertAsync(chatOptions, plan, configurePipeline: withErrorLimit)));

            AssertBothFinalFailures(await Assert.ThrowsAsync<AggregateException>(() =>
                InvokeAndAssertStreamingAsync(chatOptions, plan, configurePipeline: withErrorLimit)));
        }
        else
        {
            // Serial invocation lets the threshold-crossing exception escape as-is,
            // ending the iteration at that call
            InvalidTimeZoneException failure = await Assert.ThrowsAsync<InvalidTimeZoneException>(() =>
                InvokeAndAssertAsync(chatOptions, plan, configurePipeline: withErrorLimit));
            Assert.Equal("Exception from call 11", failure.Message);

            failure = await Assert.ThrowsAsync<InvalidTimeZoneException>(() =>
                InvokeAndAssertStreamingAsync(chatOptions, plan, configurePipeline: withErrorLimit));
            Assert.Equal("Exception from call 11", failure.Message);
        }

        static void AssertBothFinalFailures(AggregateException aggregate)
        {
            Assert.Equal(2, aggregate.InnerExceptions.Count);
            Assert.Equal("Exception from call 11", aggregate.InnerExceptions[0].Message);
            Assert.Equal("Exception from call 12", aggregate.InnerExceptions[1].Message);
        }
    }
 
    /// <summary>
    /// Verifies that when <see cref="FunctionInvokingChatClient.MaximumConsecutiveErrorsPerRequest"/> is zero,
    /// the first function failure is rethrown unwrapped, regardless of the concurrency setting.
    /// </summary>
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public async Task CanFailOnFirstException(bool allowConcurrentInvocation)
    {
        Func<ChatClientBuilder, ChatClientBuilder> failFast = pipeline => pipeline
            .UseFunctionInvocation(configure: client =>
            {
                client.MaximumConsecutiveErrorsPerRequest = 0;
                client.AllowConcurrentInvocation = allowConcurrentInvocation;
            });

        ChatOptions chatOptions = new()
        {
            Tools =
            [
                AIFunctionFactory.Create(() =>
                {
                    throw new InvalidTimeZoneException("It failed");
                }, "Func"),
            ]
        };

        int firstCallIndex = 0;
        List<ChatMessage> plan =
        [
            new(ChatRole.User, "hello"),
            ..CreateFunctionCallIterationPlan(ref firstCallIndex, true),
        ];

        // A lone exception is never wrapped in an AggregateException, whatever the concurrency setting
        InvalidTimeZoneException nonStreamingFailure = await Assert.ThrowsAsync<InvalidTimeZoneException>(() =>
            InvokeAndAssertAsync(chatOptions, plan, configurePipeline: failFast));
        Assert.Equal("It failed", nonStreamingFailure.Message);

        InvalidTimeZoneException streamingFailure = await Assert.ThrowsAsync<InvalidTimeZoneException>(() =>
            InvokeAndAssertStreamingAsync(chatOptions, plan, configurePipeline: failFast));
        Assert.Equal("It failed", streamingFailure.Message);
    }
 
    /// <summary>
    /// Builds one function-calling iteration of a plan. The iteration is an assistant message with one "Func"
    /// call per entry in <paramref name="shouldThrow"/>, followed by a tool message with the matching results.
    /// </summary>
    /// <param name="callIndex">The index used for the first call's ID; advanced past every call created here.</param>
    /// <param name="shouldThrow">For each call, whether the function is asked to throw.</param>
    /// <returns>The assistant message followed by the tool message.</returns>
    private static IEnumerable<ChatMessage> CreateFunctionCallIterationPlan(ref int callIndex, params bool[] shouldThrow)
    {
        var callsMessage = new ChatMessage(ChatRole.Assistant, []);
        var resultsMessage = new ChatMessage(ChatRole.Tool, []);

        for (int i = 0; i < shouldThrow.Length; i++)
        {
            bool fails = shouldThrow[i];
            int index = callIndex;
            callIndex++;

            string id = $"callId{index}";
            var arguments = new Dictionary<string, object?>
            {
                ["shouldThrow"] = fails,
                ["callIndex"] = index,
            };

            callsMessage.Contents.Add(new FunctionCallContent(id, "Func", arguments: arguments));
            resultsMessage.Contents.Add(new FunctionResultContent(id, result: fails ? "Error: Function failed." : "Success"));
        }

        return [callsMessage, resultsMessage];
    }
 
    /// <summary>
    /// Verifies that function call and function result content remain in the returned conversation, including
    /// when assistant messages mix text content with function calls.
    /// </summary>
    [Fact]
    public async Task KeepsFunctionCallingContent()
    {
        var options = new ChatOptions
        {
            Tools =
            [
                AIFunctionFactory.Create(() => "Result 1", "Func1"),
                AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
                AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
            ]
        };

        List<ChatMessage> plan =
        [
            new ChatMessage(ChatRole.User, "hello"),
            new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]),

            // The result answers the "callId1" call above, so it must carry that call ID
            new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
            new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
            new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
            new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } }), new TextContent("more")]),
            new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
            new ChatMessage(ChatRole.Assistant, "world"),
        ];

        // NOTE(review): this disable has no matching restore, so it stays in effect for the rest of the file;
        // nothing in this method appears to need it. Confirm before removing.
#pragma warning disable SA1005, S125
        Validate(await InvokeAndAssertAsync(options, plan));

        Validate(await InvokeAndAssertStreamingAsync(options, plan));

        static void Validate(List<ChatMessage> finalChat)
        {
            List<AIContent> content = finalChat.SelectMany(m => m.Contents).ToList();

            // Both the call requests and their results must survive, not just one kind of them
            Assert.Contains(content, c => c is FunctionCallContent);
            Assert.Contains(content, c => c is FunctionResultContent);
        }
    }
 
    /// <summary>
    /// Verifies that the exception message of a failing function is surfaced to the model only when
    /// <see cref="FunctionInvokingChatClient.IncludeDetailedErrors"/> is enabled.
    /// </summary>
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedErrors)
    {
        ChatOptions chatOptions = new()
        {
            Tools =
            [
                AIFunctionFactory.Create(string () => throw new InvalidOperationException("Oh no!"), "Func1"),
            ]
        };

        string expectedToolResult = detailedErrors
            ? "Error: Function failed. Exception: Oh no!"
            : "Error: Function failed.";

        List<ChatMessage> plan =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
            new(ChatRole.Tool, [new FunctionResultContent("callId1", result: expectedToolResult)]),
            new(ChatRole.Assistant, "world"),
        ];

        Func<ChatClientBuilder, ChatClientBuilder> withDetailedErrorsSetting = builder => builder.Use(
            inner => new FunctionInvokingChatClient(inner) { IncludeDetailedErrors = detailedErrors });

        await InvokeAndAssertAsync(chatOptions, plan, configurePipeline: withDetailedErrorsSetting);
        await InvokeAndAssertStreamingAsync(chatOptions, plan, configurePipeline: withDetailedErrorsSetting);
    }
 
    /// <summary>
    /// Verifies what is logged for a function invocation at each minimum log level.
    /// At Trace, arguments and results are logged; at Debug, only names and durations; at Information, nothing.
    /// </summary>
    [Theory]
    [InlineData(LogLevel.Trace)]
    [InlineData(LogLevel.Debug)]
    [InlineData(LogLevel.Information)]
    public async Task FunctionInvocationsLogged(LogLevel level)
    {
        List<ChatMessage> plan =
        [
            new ChatMessage(ChatRole.User, "hello"),
            new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["arg1"] = "value1" })]),
            new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
            new ChatMessage(ChatRole.Assistant, "world"),
        ];

        var options = new ChatOptions
        {
            Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
        };

        Func<ChatClientBuilder, ChatClientBuilder> configure = b =>
            b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService<ILoggerFactory>()));

        await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services));

        await InvokeAsync(services => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure, services: services));

        // Runs one pass against a fresh logging setup and checks the captured log entries
        async Task InvokeAsync(Func<IServiceProvider, Task> work)
        {
            var collector = new FakeLogCollector();

            ServiceCollection c = new();
            c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(level));

            // The built provider is disposable; release it (and the logging services it created) after verification
            using ServiceProvider serviceProvider = c.BuildServiceProvider();

            await work(serviceProvider);

            var logs = collector.GetSnapshot();
            if (level is LogLevel.Trace)
            {
                Assert.Collection(logs,
                    entry =>
                    {
                        Assert.Contains("Invoking Func1({", entry.Message);
                        Assert.Contains("\"arg1\": \"value1\"", entry.Message);
                    },
                    entry =>
                    {
                        Assert.Contains("Func1 invocation completed. Duration:", entry.Message);
                        Assert.Contains("Result: \"Result 1\"", entry.Message);
                    });
            }
            else if (level is LogLevel.Debug)
            {
                Assert.Collection(logs,
                    entry =>
                    {
                        Assert.Contains("Invoking Func1", entry.Message);
                        Assert.DoesNotContain("arg1", entry.Message);
                    },
                    entry =>
                    {
                        Assert.Contains("Func1 invocation completed. Duration:", entry.Message);
                        Assert.DoesNotContain("Result", entry.Message);
                    });
            }
            else
            {
                Assert.Empty(logs);
            }
        }
    }
 
    /// <summary>
    /// Verifies that, when tracing is enabled, the function invocation and the inner chat calls each produce an
    /// activity parented to the outer function-invoking activity. When tracing is disabled, nothing is exported.
    /// </summary>
    [Theory]
    [InlineData(false)]
    [InlineData(true)]
    public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry)
    {
        string sourceName = Guid.NewGuid().ToString();

        List<ChatMessage> plan =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["arg1"] = "value1" })]),
            new(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
            new(ChatRole.Assistant, "world"),
        ];

        ChatOptions options = new()
        {
            Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
        };

        Func<ChatClientBuilder, ChatClientBuilder> tracedPipeline = b => b.Use(c =>
            new FunctionInvokingChatClient(new OpenTelemetryChatClient(c, sourceName: sourceName)));

        await RunAndVerifyAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: tracedPipeline));
        await RunAndVerifyAsync(() => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: tracedPipeline));

        async Task RunAndVerifyAsync(Func<Task> work)
        {
            var exported = new List<Activity>();
            using TracerProvider? tracerProvider = enableTelemetry
                ? OpenTelemetry.Sdk.CreateTracerProviderBuilder().AddSource(sourceName).AddInMemoryExporter(exported).Build()
                : null;

            await work();

            if (!enableTelemetry)
            {
                Assert.Empty(exported);
                return;
            }

            Assert.Collection(exported,
                activity => Assert.Equal("chat", activity.DisplayName),
                activity => Assert.Equal("Func1", activity.DisplayName),
                activity => Assert.Equal("chat", activity.DisplayName),
                activity => Assert.Equal(nameof(FunctionInvokingChatClient), activity.DisplayName));

            // Activities are exported as they complete, so the outer activity is last and all earlier ones are its children
            Activity outer = exported[exported.Count - 1];
            for (int index = 0; index < exported.Count - 1; index++)
            {
                Assert.Same(outer, exported[index].Parent);
            }
        }
    }
 
    /// <summary>
    /// Verifies that function calls arriving in consecutive streaming updates of the same message are all
    /// invoked. It also checks that the coalesced response keeps the calls, their results, and the final text.
    /// </summary>
    [Fact]
    public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls()
    {
        var options = new ChatOptions
        {
            Tools = [AIFunctionFactory.Create((string text) => $"Result for {text}", "Func1")]
        };

        List<ChatMessage> history = [new(ChatRole.User, "Hello")];

        using var innerClient = new TestChatClient
        {
            GetStreamingResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
            {
                // All updates of a single response share one message ID
                string messageId = Guid.NewGuid().ToString("N");
                List<ChatResponseUpdate> updates = [];

                if (chatContents.Last().Text == "Hello")
                {
                    // The conversation is just starting: emit two consecutive updates, each with a function call
                    updates.Add(new() { MessageId = messageId, Contents = [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["text"] = "Input 1" })] });
                    updates.Add(new() { MessageId = messageId, Contents = [new FunctionCallContent("callId2", "Func1", new Dictionary<string, object?> { ["text"] = "Input 2" })] });
                }
                else
                {
                    // Any later turn just ends the conversation
                    updates.Add(new() { MessageId = messageId, Contents = [new TextContent("OK bye")] });
                }

                return YieldAsync(updates);
            }
        };

        using var client = new FunctionInvokingChatClient(innerClient);

        var response = await client.GetStreamingResponseAsync(history, options, CancellationToken.None).ToChatResponseAsync();

        // The returned messages should include the function calls and their results
        Assert.Collection(response.Messages,
            m => Assert.Collection(m.Contents,
                c => Assert.Equal("Input 1", Assert.IsType<FunctionCallContent>(c).Arguments!["text"]),
                c => Assert.Equal("Input 2", Assert.IsType<FunctionCallContent>(c).Arguments!["text"])),
            m => Assert.Collection(m.Contents,
                c => Assert.Equal("Result for Input 1", Assert.IsType<FunctionResultContent>(c).Result?.ToString()),
                c => Assert.Equal("Result for Input 2", Assert.IsType<FunctionResultContent>(c).Result?.ToString())),
            m => Assert.Equal("OK bye", Assert.IsType<TextContent>(Assert.Single(m.Contents)).Text));
    }
 
    /// <summary>
    /// Verifies that the response contains every message produced during the function-calling loop, not only
    /// the final answer. It also checks that each function result correlates with the call it answers.
    /// </summary>
    [Fact]
    public async Task AllResponseMessagesReturned()
    {
        var options = new ChatOptions
        {
            Tools = [AIFunctionFactory.Create(() => "doesn't matter", "Func1")]
        };

        var messages = new List<ChatMessage>
        {
            new(ChatRole.User, "Hello"),
        };

        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = async (chatContents, chatOptions, cancellationToken) =>
            {
                await Task.Yield();

                // Enumerate the incoming history once rather than once per use
                int historyCount = chatContents.Count();

                // Request a function call when the history holds 1 or 3 messages; otherwise give the final answer
                ChatMessage message = historyCount is 1 or 3 ?
                    new(ChatRole.Assistant, [new FunctionCallContent($"callId{historyCount}", "Func1")]) :
                    new(ChatRole.Assistant, "The answer is 42.");

                return new(message);
            }
        };

        using var client = new FunctionInvokingChatClient(innerClient);

        ChatResponse response = await client.GetResponseAsync(messages, options);

        Assert.Equal(5, response.Messages.Count);
        Assert.Equal("The answer is 42.", response.Text);

        var firstCall = Assert.IsType<FunctionCallContent>(Assert.Single(response.Messages[0].Contents));
        var firstResult = Assert.IsType<FunctionResultContent>(Assert.Single(response.Messages[1].Contents));
        var secondCall = Assert.IsType<FunctionCallContent>(Assert.Single(response.Messages[2].Contents));
        var secondResult = Assert.IsType<FunctionResultContent>(Assert.Single(response.Messages[3].Contents));
        Assert.IsType<TextContent>(Assert.Single(response.Messages[4].Contents));

        // Each result must answer the call that immediately precedes it
        Assert.Equal("callId1", firstCall.CallId);
        Assert.Equal(firstCall.CallId, firstResult.CallId);
        Assert.Equal("callId3", secondCall.CallId);
        Assert.Equal(secondCall.CallId, secondResult.CallId);
    }
 
    /// <summary>
    /// Verifies that <see cref="FunctionInvokingChatClient.CurrentContext"/> is available inside a function
    /// invocation and describes that invocation. It also checks that the context can terminate the function-calling loop.
    /// </summary>
    [Fact]
    public async Task CanAccesssFunctionInvocationContextFromFunctionCall()
    {
        var observedContexts = new List<FunctionInvocationContext>();
        var function = AIFunctionFactory.Create(async (int i) =>
        {
            // The context should propagate across async calls
            await Task.Yield();

            FunctionInvocationContext current = FunctionInvokingChatClient.CurrentContext!;
            observedContexts.Add(current);

            // Requesting termination on this call stops the loop before the model is asked again
            if (i == 42)
            {
                current.Terminate = true;
            }

            return $"Result {i}";
        }, "Func1");

        var options = new ChatOptions { Tools = [function] };

        // The invocation loop should terminate right after the second function call
        List<ChatMessage> planBeforeTermination =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["i"] = 41 })]),
            new(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 41")]),
            new(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary<string, object?> { ["i"] = 42 })]),
            new(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 42")]),
        ];

        // Because of the termination, the full plan is never fulfilled
        List<ChatMessage> plan =
        [
            .. planBeforeTermination,
            new(ChatRole.Assistant, [new FunctionCallContent("callId3", "Func1", new Dictionary<string, object?> { ["i"] = 43 })]),
            new(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Result 43")]),
            new(ChatRole.Assistant, "world"),
        ];

        await VerifyContextsAsync(() => InvokeAndAssertAsync(options, plan, planBeforeTermination));
        await VerifyContextsAsync(() => InvokeAndAssertStreamingAsync(options, plan, planBeforeTermination));

        // Outside the function's async flow there is no current context
        Assert.Null(FunctionInvokingChatClient.CurrentContext);

        async Task VerifyContextsAsync(Func<Task<List<ChatMessage>>> work)
        {
            observedContexts.Clear();

            List<ChatMessage> finalMessages = await work();

            Assert.Collection(observedContexts,
                c => AssertContext(c, finalMessages, iteration: 0, terminate: false),
                c => AssertContext(c, finalMessages, iteration: 1, terminate: true));
        }

        void AssertContext(FunctionInvocationContext context, List<ChatMessage> finalMessages, int iteration, bool terminate)
        {
            Assert.NotNull(context);
            Assert.Equal(finalMessages.Count, context.Messages.Count);
            Assert.Equal(string.Concat(finalMessages), string.Concat(context.Messages));
            Assert.Same(function, context.Function);
            Assert.Equal("Func1", context.CallContent.Name);
            Assert.Equal(0, context.FunctionCallIndex);
            Assert.Equal(1, context.FunctionCount);
            Assert.Equal(iteration, context.Iteration);
            Assert.Equal(terminate, context.Terminate);
        }
    }
 
    [Fact]
    public async Task CanResumeFunctionCallingAfterTermination()
    {
        // "Search" echoes a supplied result; without one it asks the client to stop the
        // function-calling loop so that application code can fill in the result later.
        var function = AIFunctionFactory.Create((string? result = null) =>
        {
            if (string.IsNullOrEmpty(result))
            {
                FunctionInvokingChatClient.CurrentContext!.Terminate = true;
                return (object?)null;
            }

            return result;
        }, "Search");

        using var innerChatClient = new TestChatClient
        {
            GetResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
            {
                // With no tool message yet, request three searches (only the middle one carries an argument),
                // giving a mixture of terminating and non-terminating calls. Once a tool message exists,
                // summarize every result it holds.
                ChatMessage? toolMessage = chatContents.SingleOrDefault(m => m.Role == ChatRole.Tool);

                AIContent[] replyContents;
                if (toolMessage is null)
                {
                    replyContents =
                    [
                        new FunctionCallContent("callId1", "Search"),
                        new FunctionCallContent("callId2", "Search", new Dictionary<string, object?> { { "result", "birds" } }),
                        new FunctionCallContent("callId3", "Search"),
                    ];
                }
                else
                {
                    IEnumerable<object?> searchResults = toolMessage.Contents.OfType<FunctionResultContent>().Select(frc => frc.Result);
                    replyContents = [new TextContent($"The search results were '{string.Join(", ", searchResults)}'")];
                }

                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, replyContents)));
            }
        };
        using var chatClient = new FunctionInvokingChatClient(innerChatClient);
        var chatOptions = new ChatOptions { Tools = [function] };

        // First pass: the tool terminates the loop, so the inner client is not asked for a final answer,
        // yet every function call of that iteration is still executed.
        List<ChatMessage> messages = [new(ChatRole.User, "hello")];
        ChatResponse response = await chatClient.GetResponseAsync(messages, chatOptions);
        messages.AddMessages(response);

        // The conversation ends with the tool results; the application supplies the missing ones.
        ChatMessage toolResults = messages.Last();
        Assert.Equal(ChatRole.Tool, toolResults.Role);
        List<FunctionResultContent> results = toolResults.Contents.OfType<FunctionResultContent>().ToList();
        Assert.Equal(3, results.Count);
        Assert.Equal("birds", results[1].Result!.ToString());
        results[0].Result = "dogs";
        results[2].Result = "cats";

        // Second pass: re-entering the function-calling client yields a final answer built from all results.
        response = await chatClient.GetResponseAsync(messages, chatOptions);
        Assert.Equal("The search results were 'dogs, birds, cats'", response.Text);
    }
 
    [Fact]
    public async Task PropagatesResponseChatThreadIdToOptions()
    {
        var options = new ChatOptions
        {
            Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")],
        };

        // Counts inner-client calls within one request; reset before the second request below.
        int iteration = 0;

        // Call 1 sees no thread id and answers with a function call plus thread id "12345".
        // Call 2 must observe that thread id propagated into the options it receives.
        ChatResponse Respond(ChatOptions? chatOptions)
        {
            switch (++iteration)
            {
                case 1:
                    Assert.Null(chatOptions?.ChatThreadId);
                    return new ChatResponse(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId-abc", "Func1")]))
                    {
                        ChatThreadId = "12345",
                    };

                case 2:
                    Assert.Equal("12345", chatOptions?.ChatThreadId);
                    return new ChatResponse(new ChatMessage(ChatRole.Assistant, "done!"));

                default:
                    throw new InvalidOperationException("Unexpected iteration");
            }
        }

        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
                Task.FromResult(Respond(chatOptions)),
            GetStreamingResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
                YieldAsync(Respond(chatOptions).ToChatResponseUpdates()),
        };

        using IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build();

        ChatResponse response = await service.GetResponseAsync("hey", options);
        Assert.Equal("done!", response.ToString());

        iteration = 0;
        ChatResponse streamedResponse = await service.GetStreamingResponseAsync("hey", options).ToChatResponseAsync();
        Assert.Equal("done!", streamedResponse.ToString());
    }
 
    [Fact]
    public async Task FunctionInvocations_PassesServices()
    {
        List<ChatMessage> plan =
        [
            new ChatMessage(ChatRole.User, "hello"),
            new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["arg1"] = "value1" })]),
            new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
            new ChatMessage(ChatRole.Assistant, "world"),
        ];

        // The provider passed to Build must be the exact instance bound to the function's IServiceProvider
        // parameter. ServiceProvider is IDisposable, so dispose it when the test completes.
        using ServiceProvider expected = new ServiceCollection().BuildServiceProvider();

        var options = new ChatOptions
        {
            Tools = [AIFunctionFactory.Create((IServiceProvider actual) =>
            {
                Assert.Same(expected, actual);
                return "Result 1";
            }, "Func1")]
        };

        await InvokeAndAssertAsync(options, plan, services: expected);
    }
 
    [Fact]
    public async Task FunctionInvocations_InvokedOnOriginalSynchronizationContext()
    {
        // CustomSynchronizationContext installs itself on every thread-pool thread it posts to, so
        // continuations that flow through it keep observing it as SynchronizationContext.Current.
        SynchronizationContext ctx = new CustomSynchronizationContext();
        SynchronizationContext.SetSynchronizationContext(ctx);

        List<ChatMessage> plan =
        [
            new ChatMessage(ChatRole.User, "hello"),
            new ChatMessage(ChatRole.Assistant, [
                new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["arg"] = "value1" }),
                new FunctionCallContent("callId2", "Func1", new Dictionary<string, object?> { ["arg"] = "value2" }),
            ]),
            new ChatMessage(ChatRole.Tool,
            [
                // Each result carries the id of the call that produced it.
                new FunctionResultContent("callId1", result: "value1"),
                new FunctionResultContent("callId2", result: "value2")
            ]),
            new ChatMessage(ChatRole.Assistant, "world"),
        ];

        var options = new ChatOptions
        {
            Tools = [AIFunctionFactory.Create(async (string arg, CancellationToken cancellationToken) =>
            {
                // After an asynchronous hop the function must still be running under the original context.
                await Task.Delay(1, cancellationToken);
                Assert.Same(ctx, SynchronizationContext.Current);
                return arg;
            }, "Func1")]
        };

        Func<ChatClientBuilder, ChatClientBuilder> configurePipeline = builder => builder
            .Use(async (messages, chatOptions, next, cancellationToken) =>
            {
                // Add an asynchronous hop in front of the rest of the pipeline.
                await Task.Delay(1, cancellationToken);
                await next(messages, chatOptions, cancellationToken);
            })
            .UseOpenTelemetry()
            .UseFunctionInvocation(configure: c => { c.AllowConcurrentInvocation = true; c.IncludeDetailedErrors = true; });

        await InvokeAndAssertAsync(options, plan, configurePipeline: configurePipeline);
        await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configurePipeline);
    }
 
    /// <summary>
    /// A synchronization context that runs posted callbacks on the thread pool, installing itself as
    /// <see cref="SynchronizationContext.Current"/> on the pool thread before invoking each callback.
    /// </summary>
    private sealed class CustomSynchronizationContext : SynchronizationContext
    {
        public override void Post(SendOrPostCallback d, object? state) =>
            ThreadPool.QueueUserWorkItem(
                callbackState =>
                {
                    SetSynchronizationContext(this);
                    d(callbackState);
                },
                state);
    }
 
    /// <summary>
    /// Drives a non-streaming pipeline through <paramref name="plan"/> and verifies the resulting conversation.
    /// </summary>
    /// <remarks>
    /// <c>plan[0]</c> is the initial message; whenever the inner client is called with <c>n</c> messages it replies
    /// with the contents of <c>plan[n]</c>. The final conversation is compared with <paramref name="expected"/>
    /// (or <paramref name="plan"/> when that is <see langword="null"/>), and the aggregated usage is checked.
    /// </remarks>
    /// <returns>The conversation: <c>plan[0]</c> followed by every message in the response.</returns>
    private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
        ChatOptions options,
        List<ChatMessage> plan,
        List<ChatMessage>? expected = null,
        Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
        IServiceProvider? services = null)
    {
        Assert.NotEmpty(plan);

        configurePipeline ??= static b => b.UseFunctionInvocation();

        using CancellationTokenSource cts = new();
        List<ChatMessage> chat = [plan[0]];
        long expectedTotalTokenCounts = 0;

        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = async (contents, actualOptions, actualCancellationToken) =>
            {
                // The caller's token must flow all the way down to the inner client.
                Assert.Equal(cts.Token, actualCancellationToken);

                await Task.Yield();

                // CreateRandomUsage puts one value on every property, so a single running total covers them all.
                var usage = CreateRandomUsage();
                expectedTotalTokenCounts += usage.InputTokenCount!.Value;

                var message = new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count()].Contents])
                {
                    MessageId = Guid.NewGuid().ToString("N")
                };
                return new ChatResponse(message) { Usage = usage };
            }
        };

        IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services);

        // EnumeratedOnceEnumerable throws if the pipeline enumerates the input messages more than once.
        var result = await service.GetResponseAsync(new EnumeratedOnceEnumerable<ChatMessage>(chat), options, cts.Token);
        Assert.NotNull(result);

        chat.AddRange(result.Messages);

        AssertChatMatchesPlan(expected ?? plan, chat);

        // Usage should be aggregated over all responses, including AdditionalCounts
        var actualUsage = result.Usage!;
        Assert.Equal(expectedTotalTokenCounts, actualUsage.InputTokenCount);
        Assert.Equal(expectedTotalTokenCounts, actualUsage.OutputTokenCount);
        Assert.Equal(expectedTotalTokenCounts, actualUsage.TotalTokenCount);
        Assert.Equal(2, actualUsage.AdditionalCounts!.Count);
        Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["firstValue"]);
        Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["secondValue"]);

        return chat;
    }

    private static UsageDetails CreateRandomUsage()
    {
        // We'll set the same random number on all the properties so that, when determining the
        // correct sum in tests, we only have to total the values once
        var value = new Random().Next(100);
        return new UsageDetails
        {
            InputTokenCount = value,
            OutputTokenCount = value,
            TotalTokenCount = value,
            AdditionalCounts = new() { ["firstValue"] = value, ["secondValue"] = value },
        };
    }

    /// <summary>
    /// Streaming counterpart of <c>InvokeAndAssertAsync</c>. It drives the pipeline through <paramref name="plan"/>
    /// via <c>GetStreamingResponseAsync</c>, coalesces the updates, and verifies the resulting conversation.
    /// Usage is not checked here.
    /// </summary>
    /// <returns>The conversation: <c>plan[0]</c> followed by every message in the coalesced response.</returns>
    private static async Task<List<ChatMessage>> InvokeAndAssertStreamingAsync(
        ChatOptions options,
        List<ChatMessage> plan,
        List<ChatMessage>? expected = null,
        Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
        IServiceProvider? services = null)
    {
        Assert.NotEmpty(plan);

        configurePipeline ??= static b => b.UseFunctionInvocation();

        using CancellationTokenSource cts = new();
        List<ChatMessage> chat = [plan[0]];

        using var innerClient = new TestChatClient
        {
            GetStreamingResponseAsyncCallback = (contents, actualOptions, actualCancellationToken) =>
            {
                // The caller's token must flow all the way down to the inner client.
                Assert.Equal(cts.Token, actualCancellationToken);

                ChatMessage message = new(ChatRole.Assistant, [.. plan[contents.Count()].Contents])
                {
                    MessageId = Guid.NewGuid().ToString("N"),
                };
                return YieldAsync(new ChatResponse(message).ToChatResponseUpdates());
            }
        };

        IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services);

        // EnumeratedOnceEnumerable throws if the pipeline enumerates the input messages more than once.
        var result = await service.GetStreamingResponseAsync(new EnumeratedOnceEnumerable<ChatMessage>(chat), options, cts.Token).ToChatResponseAsync();
        Assert.NotNull(result);

        chat.AddRange(result.Messages);

        AssertChatMatchesPlan(expected ?? plan, chat);

        return chat;
    }

    /// <summary>
    /// Asserts that <paramref name="chat"/> matches <paramref name="expected"/> message by message. It compares each
    /// message's role, text, and runtime type, and for every content item its runtime type and <c>ToString()</c>.
    /// For function calls it also compares names and arguments; for function results it compares result values.
    /// </summary>
    private static void AssertChatMatchesPlan(List<ChatMessage> expected, List<ChatMessage> chat)
    {
        Assert.Equal(expected.Count, chat.Count);
        for (int i = 0; i < expected.Count; i++)
        {
            var expectedMessage = expected[i];
            var chatMessage = chat[i];

            Assert.Equal(expectedMessage.Role, chatMessage.Role);
            Assert.Equal(expectedMessage.Text, chatMessage.Text);
            Assert.Equal(expectedMessage.GetType(), chatMessage.GetType());

            Assert.Equal(expectedMessage.Contents.Count, chatMessage.Contents.Count);
            for (int j = 0; j < expectedMessage.Contents.Count; j++)
            {
                var expectedItem = expectedMessage.Contents[j];
                var chatItem = chatMessage.Contents[j];

                Assert.Equal(expectedItem.GetType(), chatItem.GetType());
                Assert.Equal(expectedItem.ToString(), chatItem.ToString());
                if (expectedItem is FunctionCallContent expectedFunctionCall)
                {
                    var chatFunctionCall = (FunctionCallContent)chatItem;
                    Assert.Equal(expectedFunctionCall.Name, chatFunctionCall.Name);
                    AssertExtensions.EqualFunctionCallParameters(expectedFunctionCall.Arguments, chatFunctionCall.Arguments);
                }
                else if (expectedItem is FunctionResultContent expectedFunctionResult)
                {
                    var chatFunctionResult = (FunctionResultContent)chatItem;
                    AssertExtensions.EqualFunctionCallResults(expectedFunctionResult.Result, chatFunctionResult.Result);
                }
            }
        }
    }
 
    /// <summary>
    /// Produces <paramref name="items"/> as an async sequence. It yields once before the first item,
    /// so consumers always go through a genuinely asynchronous continuation.
    /// </summary>
    private static async IAsyncEnumerable<T> YieldAsync<T>(params IEnumerable<T> items)
    {
        await Task.Yield();

        using IEnumerator<T> source = items.GetEnumerator();
        while (source.MoveNext())
        {
            yield return source.Current;
        }
    }
 
    /// <summary>
    /// Wraps a sequence so that it can be enumerated at most once; any later enumeration throws
    /// <see cref="InvalidOperationException"/>. Because <c>GetEnumerator</c> is an iterator, the check
    /// runs when enumeration starts (on the first <c>MoveNext</c>), not when the enumerator is created.
    /// </summary>
    private sealed class EnumeratedOnceEnumerable<T>(IEnumerable<T> items) : IEnumerable<T>
    {
        // 0 until an enumeration has started, then 1; flipped atomically.
        private int _enumerationStarted;

        public IEnumerator<T> GetEnumerator()
        {
            bool alreadyEnumerated = Interlocked.CompareExchange(ref _enumerationStarted, 1, 0) == 1;
            if (alreadyEnumerated)
            {
                throw new InvalidOperationException("This enumerable can only be enumerated once.");
            }

            using IEnumerator<T> source = items.GetEnumerator();
            while (source.MoveNext())
            {
                yield return source.Current;
            }
        }

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    }
}