|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using OpenTelemetry.Trace;
using Xunit;
#pragma warning disable SA1118 // Parameter should not span multiple lines
namespace Microsoft.Extensions.AI;
public class FunctionInvokingChatClientTests
{
/// <summary>
/// Verifies that null arguments are rejected both by the constructor and by the builder extension method.
/// </summary>
[Fact]
public void InvalidArgs_Throws()
{
    ChatClientBuilder? nullBuilder = null;

    Assert.Throws<ArgumentNullException>("innerClient", () => new FunctionInvokingChatClient(null!));
    Assert.Throws<ArgumentNullException>("builder", () => nullBuilder!.UseFunctionInvocation());
}
/// <summary>
/// Verifies the configuration defaults of a freshly constructed <see cref="FunctionInvokingChatClient"/>.
/// </summary>
[Fact]
public void Ctor_HasExpectedDefaults()
{
    using var inner = new TestChatClient();
    using var sut = new FunctionInvokingChatClient(inner);

    // Of the boolean switches, only KeepFunctionCallingMessages is enabled by default.
    Assert.True(sut.KeepFunctionCallingMessages);
    Assert.False(sut.ConcurrentInvocation);
    Assert.False(sut.DetailedErrors);
    Assert.False(sut.RetryOnError);

    // No iteration cap is applied unless one is configured.
    Assert.Null(sut.MaximumIterationsPerRequest);
}
/// <summary>
/// Verifies that a conversation in which each response requests a single function call is driven to completion,
/// including a call to a function with a void return, in both non-streaming and streaming modes.
/// </summary>
[Fact]
public async Task SupportsSingleFunctionCallPerRequestAsync()
{
    ChatOptions options = new()
    {
        Tools =
        [
            AIFunctionFactory.Create(() => "Result 1", "Func1"),
            AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
            AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
        ],
    };

    // Assistant entries are replayed by the inner test client; tool entries are the results the
    // function-invoking client is expected to add after running each requested function.
    List<ChatMessage> plan =
    [
        new(ChatRole.User, "hello"),
        new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
        new(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]),
        new(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { ["i"] = 42 })]),
        new(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]),
        new(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { ["i"] = 43 })]),
        new(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]),
        new(ChatRole.Assistant, "world"),
    ];

    await InvokeAndAssertAsync(options, plan);
    await InvokeAndAssertStreamingAsync(options, plan);
}
/// <summary>
/// Verifies that several function calls returned in a single response are all invoked, with their results
/// returned together in one tool message, across multiple rounds and with
/// <see cref="FunctionInvokingChatClient.ConcurrentInvocation"/> both disabled and enabled.
/// </summary>
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentInvocation)
{
    ChatOptions options = new()
    {
        Tools =
        [
            AIFunctionFactory.Create((int i) => "Result 1", "Func1"),
            AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
        ],
    };

    // Builds the single-argument dictionary used by the Func2 calls below.
    static Dictionary<string, object?> Args(int i) => new() { ["i"] = i };

    List<ChatMessage> plan =
    [
        new(ChatRole.User, "hello"),
        new(ChatRole.Assistant,
        [
            new FunctionCallContent("callId1", "Func1"),
            new FunctionCallContent("callId2", "Func2", arguments: Args(34)),
            new FunctionCallContent("callId3", "Func2", arguments: Args(56)),
        ]),
        new(ChatRole.Tool,
        [
            new FunctionResultContent("callId1", "Func1", result: "Result 1"),
            new FunctionResultContent("callId2", "Func2", result: "Result 2: 34"),
            new FunctionResultContent("callId3", "Func2", result: "Result 2: 56"),
        ]),
        new(ChatRole.Assistant,
        [
            new FunctionCallContent("callId4", "Func2", arguments: Args(78)),
            new FunctionCallContent("callId5", "Func1"),
        ]),
        new(ChatRole.Tool,
        [
            new FunctionResultContent("callId4", "Func2", result: "Result 2: 78"),
            new FunctionResultContent("callId5", "Func1", result: "Result 1"),
        ]),
        new(ChatRole.Assistant, "world"),
    ];

    ChatClientBuilder Configure(ChatClientBuilder builder) =>
        builder.Use(inner => new FunctionInvokingChatClient(inner) { ConcurrentInvocation = concurrentInvocation });

    await InvokeAndAssertAsync(options, plan, configurePipeline: Configure);
    await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: Configure);
}
/// <summary>
/// Verifies that, with <see cref="FunctionInvokingChatClient.ConcurrentInvocation"/> enabled, two function calls from
/// the same response run at the same time. Each invocation waits on a two-participant barrier, which can only be
/// released when both invocations are in flight simultaneously.
/// </summary>
[Fact]
public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync()
{
    // Bound the wait so that a regression to sequential invocation fails the test instead of hanging it forever.
    TimeSpan barrierTimeout = TimeSpan.FromSeconds(30);
    using var barrier = new Barrier(2);

    var options = new ChatOptions
    {
        Tools =
        [
            AIFunctionFactory.Create((string arg) =>
            {
                if (!barrier.SignalAndWait(barrierTimeout))
                {
                    // The client reports a throwing function as an error result, which then no longer
                    // matches the expected "hellohello"/"worldworld" results in the plan.
                    throw new TimeoutException("The other function invocation never reached the barrier; the calls were not invoked concurrently.");
                }

                return arg + arg;
            }, "Func"),
        ]
    };

    List<ChatMessage> plan =
    [
        new ChatMessage(ChatRole.User, "hello"),
        new ChatMessage(ChatRole.Assistant,
        [
            new FunctionCallContent("callId1", "Func", arguments: new Dictionary<string, object?> { { "arg", "hello" } }),
            new FunctionCallContent("callId2", "Func", arguments: new Dictionary<string, object?> { { "arg", "world" } }),
        ]),
        new ChatMessage(ChatRole.Tool,
        [
            new FunctionResultContent("callId1", "Func", result: "hellohello"),
            new FunctionResultContent("callId2", "Func", result: "worldworld"),
        ]),
        new ChatMessage(ChatRole.Assistant, "done"),
    ];

    Func<ChatClientBuilder, ChatClientBuilder> configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true });

    await InvokeAndAssertAsync(options, plan, configurePipeline: configure);
    await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
}
/// <summary>
/// Verifies that concurrent invocation is off by default. Two calls from the same response are invoked one at a
/// time; each invocation checks this by asserting that it is the only one in flight while it is suspended.
/// </summary>
[Fact]
public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync()
{
    int inFlight = 0;

    ChatOptions options = new()
    {
        Tools =
        [
            AIFunctionFactory.Create(async (string arg) =>
            {
                Interlocked.Increment(ref inFlight);

                // Stay suspended long enough for an overlapping invocation, if any, to start.
                await Task.Delay(100);
                Assert.Equal(1, inFlight);

                Interlocked.Decrement(ref inFlight);
                return arg + arg;
            }, "Func"),
        ],
    };

    // Builds the single-argument dictionary used by both calls below.
    static Dictionary<string, object?> Arg(string value) => new() { ["arg"] = value };

    List<ChatMessage> plan =
    [
        new(ChatRole.User, "hello"),
        new(ChatRole.Assistant,
        [
            new FunctionCallContent("callId1", "Func", arguments: Arg("hello")),
            new FunctionCallContent("callId2", "Func", arguments: Arg("world")),
        ]),
        new(ChatRole.Tool,
        [
            new FunctionResultContent("callId1", "Func", result: "hellohello"),
            new FunctionResultContent("callId2", "Func", result: "worldworld"),
        ]),
        new(ChatRole.Assistant, "done"),
    ];

    // No pipeline customization: the default UseFunctionInvocation() configuration is what is under test.
    await InvokeAndAssertAsync(options, plan);
    await InvokeAndAssertStreamingAsync(options, plan);
}
/// <summary>
/// Verifies that <see cref="FunctionInvokingChatClient.KeepFunctionCallingMessages"/> controls whether the
/// function call and function result messages remain in the returned chat history.
/// </summary>
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunctionCallingMessages)
{
    ChatOptions options = new()
    {
        Tools =
        [
            AIFunctionFactory.Create(() => "Result 1", "Func1"),
            AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
            AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
        ],
    };

    List<ChatMessage> plan =
    [
        new(ChatRole.User, "hello"),
        new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
        new(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]),
        new(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { ["i"] = 42 })]),
        new(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]),
        new(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { ["i"] = 43 })]),
        new(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]),
        new(ChatRole.Assistant, "world"),
    ];

    // When the messages are kept, the plan itself is the expected history (null selects that default).
    List<ChatMessage>? expected = null;
    if (!keepFunctionCallingMessages)
    {
        expected =
        [
            new(ChatRole.User, "hello"),
            new(ChatRole.Assistant, "world"),
        ];
    }

    ChatClientBuilder Configure(ChatClientBuilder builder) =>
        builder.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages });

    AssertFunctionContentPresence(await InvokeAndAssertAsync(options, plan, expected, Configure));
    AssertFunctionContentPresence(await InvokeAndAssertStreamingAsync(options, plan, expected, Configure));

    static bool IsFunctionContent(AIContent item) => item is FunctionCallContent or FunctionResultContent;

    // Function content must appear somewhere when kept, and nowhere when removed.
    void AssertFunctionContentPresence(List<ChatMessage> finalChat)
    {
        IEnumerable<AIContent> allContent = finalChat.SelectMany(message => message.Contents);

        if (!keepFunctionCallingMessages)
        {
            Assert.All(allContent, item => Assert.False(IsFunctionContent(item)));
            return;
        }

        Assert.Contains(allContent, item => IsFunctionContent(item));
    }
}
/// <summary>
/// Verifies how <see cref="FunctionInvokingChatClient.KeepFunctionCallingMessages"/> treats messages that mix text
/// with function calls. When it is disabled, function call and result content is stripped from the returned
/// history but the surrounding text survives; when it is enabled, the function calling content is retained.
/// </summary>
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages)
{
    var options = new ChatOptions
    {
        Tools =
        [
            AIFunctionFactory.Create(() => "Result 1", "Func1"),
            AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
            AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
        ]
    };

    // Each function result carries the call id of the function call it answers.
    List<ChatMessage> plan =
    [
        new ChatMessage(ChatRole.User, "hello"),
        new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]),
        new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]),
        new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
        new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]),
        new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } }), new TextContent("more")]),
        new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]),
        new ChatMessage(ChatRole.Assistant, "world"),
    ];

    Func<ChatClientBuilder, ChatClientBuilder> configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages });

    // Non-streaming with removal: messages holding only function content disappear, and messages that mixed
    // text with a function call keep just their text. When the content is kept, the plan itself is expected.
    List<ChatMessage>? expected = keepFunctionCallingMessages ? null :
    [
        new ChatMessage(ChatRole.User, "hello"),
        new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]),
        new ChatMessage(ChatRole.Assistant, "more"),
        new ChatMessage(ChatRole.Assistant, "world"),
    ];

    // Streaming: the text from every assistant turn is expected to end up coalesced in the final assistant message.
    List<ChatMessage> expectedStreaming = keepFunctionCallingMessages ?
    [
        new ChatMessage(ChatRole.User, "hello"),
        new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
        new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]),
        new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
        new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]),
        new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
        new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]),
        new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"),
    ] :
    [
        new ChatMessage(ChatRole.User, "hello"),
        new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"),
    ];

    Validate(await InvokeAndAssertAsync(options, plan, expected, configure));
    Validate(await InvokeAndAssertStreamingAsync(options, plan, expectedStreaming, configure));

    // Function content must appear somewhere when kept, and nowhere when removed.
    void Validate(List<ChatMessage> finalChat)
    {
        IEnumerable<AIContent> content = finalChat.SelectMany(m => m.Contents);
        if (keepFunctionCallingMessages)
        {
            Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent);
        }
        else
        {
            Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent));
        }
    }
}
/// <summary>
/// Verifies that an exception thrown by a function is reported as an error result. The exception's message is
/// included only when <see cref="FunctionInvokingChatClient.DetailedErrors"/> is enabled.
/// </summary>
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedErrors)
{
    ChatOptions options = new()
    {
        Tools = [AIFunctionFactory.Create(string () => throw new InvalidOperationException("Oh no!"), "Func1")],
    };

    string expectedResult = detailedErrors
        ? "Error: Function failed. Exception: Oh no!"
        : "Error: Function failed.";

    List<ChatMessage> plan =
    [
        new(ChatRole.User, "hello"),
        new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
        new(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: expectedResult)]),
        new(ChatRole.Assistant, "world"),
    ];

    ChatClientBuilder Configure(ChatClientBuilder builder) =>
        builder.Use(inner => new FunctionInvokingChatClient(inner) { DetailedErrors = detailedErrors });

    await InvokeAndAssertAsync(options, plan, configurePipeline: Configure);
    await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: Configure);
}
/// <summary>
/// Verifies that a response containing more than one choice is rejected with an
/// <see cref="InvalidOperationException"/>, in both non-streaming and streaming modes, and that the caller's
/// chat history is left untouched.
/// </summary>
[Fact]
public async Task RejectsMultipleChoicesAsync()
{
    var func1 = AIFunctionFactory.Create(() => "Some result 1", "Func1");
    var func2 = AIFunctionFactory.Create(() => "Some result 2", "Func2");

    // A single completion carrying two assistant messages, i.e. two choices.
    var twoChoices = new ChatCompletion(
    [
        new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Metadata.Name)]),
        new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Metadata.Name)]),
    ]);

    using var innerClient = new TestChatClient
    {
        CompleteAsyncCallback = async (_, _, _) =>
        {
            await Task.Yield();
            return twoChoices;
        },
        CompleteStreamingAsyncCallback = (_, _, _) => YieldAsync(twoChoices.ToStreamingChatCompletionUpdates()),
    };

    IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build();

    List<ChatMessage> chat = [new ChatMessage(ChatRole.User, "hello")];
    ChatOptions options = new() { Tools = [func1, func2] };

    AssertRejected(await Assert.ThrowsAsync<InvalidOperationException>(() => service.CompleteAsync(chat, options)));
    AssertRejected(await Assert.ThrowsAsync<InvalidOperationException>(() => service.CompleteStreamingAsync(chat, options).ToChatCompletionAsync()));

    void AssertRejected(Exception ex)
    {
        Assert.Contains("only accepts a single choice", ex.Message);
        Assert.Single(chat); // Nothing was added to the chat history.
    }
}
/// <summary>
/// Verifies what <see cref="FunctionInvokingChatClient"/> logs about a function invocation at each minimum log level.
/// At Trace, the log includes the arguments and the result. At Debug, it includes only the invocation and its
/// completion. At Information, nothing is logged.
/// </summary>
[Theory]
[InlineData(LogLevel.Trace)]
[InlineData(LogLevel.Debug)]
[InlineData(LogLevel.Information)]
public async Task FunctionInvocationsLogged(LogLevel level)
{
    List<ChatMessage> plan =
    [
        new ChatMessage(ChatRole.User, "hello"),
        new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["arg1"] = "value1" })]),
        new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]),
        new ChatMessage(ChatRole.Assistant, "world"),
    ];

    var options = new ChatOptions
    {
        Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
    };

    Func<ChatClientBuilder, ChatClientBuilder> configure = b =>
        b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService<ILogger<FunctionInvokingChatClient>>()));

    await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services));
    await InvokeAsync(services => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure, services: services));

    // Runs one invocation against a fresh log collector and checks the captured entries for the current level.
    async Task InvokeAsync(Func<IServiceProvider, Task> work)
    {
        var collector = new FakeLogCollector();
        ServiceCollection serviceCollection = new();
        serviceCollection.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(level));

        // Dispose the container (and the logging infrastructure it created) once this invocation has been verified.
        using var serviceProvider = serviceCollection.BuildServiceProvider();
        await work(serviceProvider);

        var logs = collector.GetSnapshot();
        if (level is LogLevel.Trace)
        {
            Assert.Collection(logs,
                entry =>
                {
                    Assert.Contains("Invoking Func1({", entry.Message);
                    Assert.Contains("\"arg1\": \"value1\"", entry.Message);
                },
                entry =>
                {
                    Assert.Contains("Func1 invocation completed. Duration:", entry.Message);
                    Assert.Contains("Result: \"Result 1\"", entry.Message);
                });
        }
        else if (level is LogLevel.Debug)
        {
            Assert.Collection(logs,
                entry =>
                {
                    Assert.Contains("Invoking Func1", entry.Message);
                    Assert.DoesNotContain("arg1", entry.Message);
                },
                entry =>
                {
                    Assert.Contains("Func1 invocation completed. Duration:", entry.Message);
                    Assert.DoesNotContain("Result", entry.Message);
                });
        }
        else
        {
            Assert.Empty(logs);
        }
    }
}
/// <summary>
/// Verifies the activities recorded when an OpenTelemetry tracer listens to the pipeline's activity source.
/// The chat and function invocations are recorded as children of an outer activity named after
/// <see cref="FunctionInvokingChatClient"/>. When no tracer listens, nothing is recorded.
/// </summary>
[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry)
{
    string sourceName = Guid.NewGuid().ToString();

    List<ChatMessage> plan =
    [
        new(ChatRole.User, "hello"),
        new(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["arg1"] = "value1" })]),
        new(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]),
        new(ChatRole.Assistant, "world"),
    ];

    ChatOptions options = new()
    {
        Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")],
    };

    ChatClientBuilder Configure(ChatClientBuilder builder) =>
        builder.Use(inner => new FunctionInvokingChatClient(new OpenTelemetryChatClient(inner, sourceName: sourceName)));

    await RunAndVerifyAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: Configure));
    await RunAndVerifyAsync(() => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: Configure));

    // Runs one invocation with (or without) an in-memory tracer and checks the exported activities.
    async Task RunAndVerifyAsync(Func<Task> work)
    {
        List<Activity> exported = [];
        using TracerProvider? tracerProvider = enableTelemetry
            ? OpenTelemetry.Sdk.CreateTracerProviderBuilder().AddSource(sourceName).AddInMemoryExporter(exported).Build()
            : null;

        await work();

        if (!enableTelemetry)
        {
            Assert.Empty(exported);
            return;
        }

        Assert.Collection(exported,
            activity => Assert.Equal("chat", activity.DisplayName),
            activity => Assert.Equal("Func1", activity.DisplayName),
            activity => Assert.Equal("chat", activity.DisplayName),
            activity => Assert.Equal(nameof(FunctionInvokingChatClient), activity.DisplayName));

        // Activities are exported in completion order, so the outer activity is last and every earlier one is its child.
        Activity outer = exported[exported.Count - 1];
        foreach (Activity child in exported.Take(exported.Count - 1))
        {
            Assert.Same(outer, child.Parent);
        }
    }
}
/// <summary>
/// Verifies how function calls spread across consecutive streaming updates are handled. All of them are invoked,
/// the calls and their results are appended to the caller's message history, and neither is surfaced to the
/// streaming consumer.
/// </summary>
[Fact]
public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls()
{
    var options = new ChatOptions
    {
        Tools = [AIFunctionFactory.Create((string text) => $"Result for {text}", "Func1")]
    };

    var messages = new List<ChatMessage>
    {
        new(ChatRole.User, "Hello"),
    };

    using var innerClient = new TestChatClient
    {
        CompleteStreamingAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
        {
            // If the conversation is just starting, issue two consecutive updates with function calls.
            // Otherwise just end the conversation.
            return chatContents.Last().Text == "Hello"
                ? YieldAsync(
                    new StreamingChatCompletionUpdate { Contents = [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["text"] = "Input 1" })] },
                    new StreamingChatCompletionUpdate { Contents = [new FunctionCallContent("callId2", "Func1", new Dictionary<string, object?> { ["text"] = "Input 2" })] })
                : YieldAsync(
                    new StreamingChatCompletionUpdate { Contents = [new TextContent("OK bye")] });
        }
    };

    using var client = new FunctionInvokingChatClient(innerClient);

    var updates = new List<StreamingChatCompletionUpdate>();
    await foreach (var update in client.CompleteStreamingAsync(messages, options, CancellationToken.None))
    {
        updates.Add(update);
    }

    // Message history should now include the FCCs (merged into one message) and the FRCs.
    Assert.Collection(messages,
        m => Assert.Equal("Hello", Assert.IsType<TextContent>(Assert.Single(m.Contents)).Text),
        m => Assert.Collection(m.Contents,
            c => Assert.Equal("Input 1", Assert.IsType<FunctionCallContent>(c).Arguments!["text"]),
            c => Assert.Equal("Input 2", Assert.IsType<FunctionCallContent>(c).Arguments!["text"])),
        m => Assert.Collection(m.Contents,
            c => Assert.Equal("Result for Input 1", Assert.IsType<FunctionResultContent>(c).Result?.ToString()),
            c => Assert.Equal("Result for Input 2", Assert.IsType<FunctionResultContent>(c).Result?.ToString())));

    // The returned updates should *not* include the FCCs and FRCs.
    var allUpdateContents = updates.SelectMany(u => u.Contents).ToList();
    var singleUpdateContent = Assert.IsType<TextContent>(Assert.Single(allUpdateContents));
    Assert.Equal("OK bye", singleUpdateContent.Text);
}
/// <summary>
/// Drives <paramref name="plan"/> through a non-streaming pipeline and verifies the resulting chat history
/// and the aggregated usage.
/// </summary>
/// <param name="options">The chat options (including tools) passed to the pipeline.</param>
/// <param name="plan">
/// The scripted conversation. Element 0 seeds the chat. Whenever the inner test client is called with <c>n</c>
/// messages, it replies with a copy of the contents of <c>plan[n]</c>.
/// </param>
/// <param name="expected">The expected final chat history, or <see langword="null"/> to expect <paramref name="plan"/> itself.</param>
/// <param name="configurePipeline">Builds the pipeline over the inner client; defaults to <c>UseFunctionInvocation()</c>.</param>
/// <param name="services">Optional services made available while building the pipeline.</param>
/// <returns>The final chat history, including the returned response message.</returns>
private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
    ChatOptions options,
    List<ChatMessage> plan,
    List<ChatMessage>? expected = null,
    Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
    IServiceProvider? services = null)
{
    Assert.NotEmpty(plan);

    configurePipeline ??= static b => b.UseFunctionInvocation();

    using CancellationTokenSource cts = new();
    List<ChatMessage> chat = [plan[0]];

    // CreateRandomUsage puts the same value in every counter, so one running total is enough to check
    // each aggregated counter at the end.
    var expectedTotalTokenCounts = 0;

    using var innerClient = new TestChatClient
    {
        CompleteAsyncCallback = async (contents, actualOptions, actualCancellationToken) =>
        {
            // The pipeline must operate on the caller's list and flow the caller's token.
            Assert.Same(chat, contents);
            Assert.Equal(cts.Token, actualCancellationToken);

            await Task.Yield();

            var usage = CreateRandomUsage();
            expectedTotalTokenCounts += usage.InputTokenCount!.Value;
            return new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])) { Usage = usage };
        }
    };

    IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services);

    var result = await service.CompleteAsync(chat, options, cts.Token);

    // Check the result before dereferencing it.
    Assert.NotNull(result);
    chat.Add(result.Message);

    expected ??= plan;
    Assert.Equal(expected.Count, chat.Count);
    for (int i = 0; i < expected.Count; i++)
    {
        var expectedMessage = expected[i];
        var chatMessage = chat[i];
        Assert.Equal(expectedMessage.Role, chatMessage.Role);
        Assert.Equal(expectedMessage.Text, chatMessage.Text);
        Assert.Equal(expectedMessage.GetType(), chatMessage.GetType());
        Assert.Equal(expectedMessage.Contents.Count, chatMessage.Contents.Count);

        for (int j = 0; j < expectedMessage.Contents.Count; j++)
        {
            var expectedItem = expectedMessage.Contents[j];
            var chatItem = chatMessage.Contents[j];
            Assert.Equal(expectedItem.GetType(), chatItem.GetType());
            Assert.Equal(expectedItem.ToString(), chatItem.ToString());

            if (expectedItem is FunctionCallContent expectedFunctionCall)
            {
                var chatFunctionCall = (FunctionCallContent)chatItem;
                Assert.Equal(expectedFunctionCall.Name, chatFunctionCall.Name);
                AssertExtensions.EqualFunctionCallParameters(expectedFunctionCall.Arguments, chatFunctionCall.Arguments);
            }
            else if (expectedItem is FunctionResultContent expectedFunctionResult)
            {
                var chatFunctionResult = (FunctionResultContent)chatItem;
                AssertExtensions.EqualFunctionCallResults(expectedFunctionResult.Result, chatFunctionResult.Result);
            }
        }
    }

    // Usage should be aggregated over all responses, including AdditionalCounts.
    Assert.NotNull(result.Usage);
    var actualUsage = result.Usage!;
    Assert.Equal(expectedTotalTokenCounts, actualUsage.InputTokenCount);
    Assert.Equal(expectedTotalTokenCounts, actualUsage.OutputTokenCount);
    Assert.Equal(expectedTotalTokenCounts, actualUsage.TotalTokenCount);

    Assert.NotNull(actualUsage.AdditionalCounts);
    Assert.Equal(2, actualUsage.AdditionalCounts!.Count);
    Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["firstValue"]);
    Assert.Equal(expectedTotalTokenCounts, actualUsage.AdditionalCounts["secondValue"]);

    return chat;
}
/// <summary>
/// Creates usage details in which every counter holds the same random value in the range [0, 100).
/// </summary>
/// <remarks>
/// Sharing one value across all counters lets a caller keep a single running total and compare every
/// aggregated counter against it.
/// </remarks>
private static UsageDetails CreateRandomUsage()
{
    int tokens = new Random().Next(100);

    UsageDetails usage = new()
    {
        InputTokenCount = tokens,
        OutputTokenCount = tokens,
        TotalTokenCount = tokens,
        AdditionalCounts = new() { ["firstValue"] = tokens, ["secondValue"] = tokens },
    };

    return usage;
}
/// <summary>
/// Drives <paramref name="plan"/> through a streaming pipeline, combines the streamed updates into a single
/// completion, and verifies the resulting chat history.
/// </summary>
/// <param name="options">The chat options (including tools) passed to the pipeline.</param>
/// <param name="plan">
/// The scripted conversation. Element 0 seeds the chat. Whenever the inner test client is called with <c>n</c>
/// messages, it streams the updates for a copy of the contents of <c>plan[n]</c>.
/// </param>
/// <param name="expected">The expected final chat history, or <see langword="null"/> to expect <paramref name="plan"/> itself.</param>
/// <param name="configurePipeline">Builds the pipeline over the inner client; defaults to <c>UseFunctionInvocation()</c>.</param>
/// <param name="services">Optional services made available while building the pipeline.</param>
/// <returns>The final chat history, including the combined response message.</returns>
private static async Task<List<ChatMessage>> InvokeAndAssertStreamingAsync(
    ChatOptions options,
    List<ChatMessage> plan,
    List<ChatMessage>? expected = null,
    Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
    IServiceProvider? services = null)
{
    Assert.NotEmpty(plan);

    configurePipeline ??= static b => b.UseFunctionInvocation();

    using CancellationTokenSource cts = new();
    List<ChatMessage> chat = [plan[0]];

    using var innerClient = new TestChatClient
    {
        CompleteStreamingAsyncCallback = (contents, actualOptions, actualCancellationToken) =>
        {
            // The pipeline must operate on the caller's list and flow the caller's token.
            Assert.Same(chat, contents);
            Assert.Equal(cts.Token, actualCancellationToken);

            return YieldAsync(new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])).ToStreamingChatCompletionUpdates());
        }
    };

    IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services);

    var result = await service.CompleteStreamingAsync(chat, options, cts.Token).ToChatCompletionAsync();

    // Check the result before dereferencing it.
    Assert.NotNull(result);
    chat.Add(result.Message);

    expected ??= plan;
    Assert.Equal(expected.Count, chat.Count);
    for (int i = 0; i < expected.Count; i++)
    {
        var expectedMessage = expected[i];
        var chatMessage = chat[i];
        Assert.Equal(expectedMessage.Role, chatMessage.Role);
        Assert.Equal(expectedMessage.Text, chatMessage.Text);
        Assert.Equal(expectedMessage.GetType(), chatMessage.GetType());
        Assert.Equal(expectedMessage.Contents.Count, chatMessage.Contents.Count);

        for (int j = 0; j < expectedMessage.Contents.Count; j++)
        {
            var expectedItem = expectedMessage.Contents[j];
            var chatItem = chatMessage.Contents[j];
            Assert.Equal(expectedItem.GetType(), chatItem.GetType());
            Assert.Equal(expectedItem.ToString(), chatItem.ToString());

            if (expectedItem is FunctionCallContent expectedFunctionCall)
            {
                var chatFunctionCall = (FunctionCallContent)chatItem;
                Assert.Equal(expectedFunctionCall.Name, chatFunctionCall.Name);
                AssertExtensions.EqualFunctionCallParameters(expectedFunctionCall.Arguments, chatFunctionCall.Arguments);
            }
            else if (expectedItem is FunctionResultContent expectedFunctionResult)
            {
                var chatFunctionResult = (FunctionResultContent)chatItem;
                AssertExtensions.EqualFunctionCallResults(expectedFunctionResult.Result, chatFunctionResult.Result);
            }
        }
    }

    return chat;
}
/// <summary>
/// Exposes the given items, in order, as an asynchronous stream. The method yields asynchronously once before
/// producing the first item.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
/// <param name="items">The items to stream; may be empty.</param>
/// <returns>An asynchronous sequence over <paramref name="items"/>.</returns>
private static async IAsyncEnumerable<T> YieldAsync<T>(params T[] items)
{
    await Task.Yield();

    for (int index = 0; index < items.Length; index++)
    {
        yield return items[index];
    }
}
}
|