// File: ChatCompletion\UseDelegateChatClientTests.cs
// Web Access
// Project: src\test\Libraries\Microsoft.Extensions.AI.Tests\Microsoft.Extensions.AI.Tests.csproj (Microsoft.Extensions.AI.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
 
namespace Microsoft.Extensions.AI;
 
/// <summary>
/// Tests for the delegate-based <c>Use</c> overloads of <see cref="ChatClientBuilder"/>.
/// Each test checks that the messages, options and cancellation token handed to the built
/// client reach both the user delegate and the inner client unchanged, and that an
/// <see cref="AsyncLocal{T}"/> value assigned inside the delegate is observed by the inner client.
/// </summary>
public class UseDelegateChatClientTests
{
    /// <summary>
    /// Calling the two-delegate <c>Use</c> overload with both delegates null is expected to throw
    /// an <see cref="ArgumentNullException"/> whose parameter name is <c>getResponseFunc</c>.
    /// </summary>
    [Fact]
    public void InvalidArgs_Throws()
    {
        using var inner = new TestChatClient();
        var builder = new ChatClientBuilder(inner);

        Assert.Throws<ArgumentNullException>("getResponseFunc", () => builder.Use(null!, null!));
    }

    /// <summary>
    /// Registers a single shared delegate and invokes the built client through both the
    /// non-streaming and the streaming entry points. The async-local value must still be 0 in
    /// the test method before each call, and 42 when the inner client is reached.
    /// </summary>
    [Fact]
    public async Task Shared_ContextPropagated()
    {
        IList<ChatMessage> messages = [];
        var chatOptions = new ChatOptions();
        using var cts = new CancellationTokenSource();
        var ambient = new AsyncLocal<int>();

        using IChatClient inner = new TestChatClient
        {
            GetResponseAsyncCallback = (msgs, opts, token) =>
            {
                VerifyArguments(msgs, opts, token);
                Assert.Equal(42, ambient.Value);
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "hello")));
            },

            GetStreamingResponseAsyncCallback = (msgs, opts, token) =>
            {
                VerifyArguments(msgs, opts, token);
                Assert.Equal(42, ambient.Value);
                return YieldUpdates(new ChatResponseUpdate { Text = "world" });
            },
        };

        using IChatClient pipeline = new ChatClientBuilder(inner)
            .Use(async (msgs, opts, next, token) =>
            {
                VerifyArguments(msgs, opts, token);
                ambient.Value = 42;
                await next(msgs, opts, token);
            })
            .Build();

        // Non-streaming path.
        Assert.Equal(0, ambient.Value);
        ChatResponse direct = await pipeline.GetResponseAsync(messages, chatOptions, cts.Token);
        Assert.Equal("hello", direct.Message.Text);

        // Streaming path; the value set by the delegate during the first call must not be visible here.
        Assert.Equal(0, ambient.Value);
        ChatResponse streamed = await pipeline.GetStreamingResponseAsync(messages, chatOptions, cts.Token).ToChatResponseAsync();
        Assert.Equal("world", streamed.Message.Text);

        // Checks that a callee received exactly the instances/token the test passed in.
        void VerifyArguments(IList<ChatMessage> actualMessages, ChatOptions? actualOptions, CancellationToken actualToken)
        {
            Assert.Same(messages, actualMessages);
            Assert.Same(chatOptions, actualOptions);
            Assert.Equal(cts.Token, actualToken);
        }
    }

    /// <summary>
    /// Supplies only the non-streaming delegate (the streaming delegate is null) and an inner
    /// client that only defines the non-streaming callback. Both entry points of the built
    /// client are expected to produce "hello world".
    /// </summary>
    [Fact]
    public async Task GetResponseFunc_ContextPropagated()
    {
        IList<ChatMessage> messages = [];
        var chatOptions = new ChatOptions();
        using var cts = new CancellationTokenSource();
        var ambient = new AsyncLocal<int>();

        using IChatClient inner = new TestChatClient
        {
            GetResponseAsyncCallback = (msgs, opts, token) =>
            {
                VerifyArguments(msgs, opts, token);
                Assert.Equal(42, ambient.Value);
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "hello")));
            },
        };

        using IChatClient pipeline = new ChatClientBuilder(inner)
            .Use(async (msgs, opts, downstream, token) =>
            {
                VerifyArguments(msgs, opts, token);
                ambient.Value = 42;

                // Forward to the inner client, then append to the first choice's text.
                ChatResponse innerResponse = await downstream.GetResponseAsync(msgs, opts, token);
                innerResponse.Choices[0].Text += " world";
                return innerResponse;
            }, null)
            .Build();

        Assert.Equal(0, ambient.Value);

        ChatResponse direct = await pipeline.GetResponseAsync(messages, chatOptions, cts.Token);
        Assert.Equal("hello world", direct.Message.Text);

        ChatResponse streamed = await pipeline.GetStreamingResponseAsync(messages, chatOptions, cts.Token).ToChatResponseAsync();
        Assert.Equal("hello world", streamed.Message.Text);

        // Checks that a callee received exactly the instances/token the test passed in.
        void VerifyArguments(IList<ChatMessage> actualMessages, ChatOptions? actualOptions, CancellationToken actualToken)
        {
            Assert.Same(messages, actualMessages);
            Assert.Same(chatOptions, actualOptions);
            Assert.Equal(cts.Token, actualToken);
        }
    }

    /// <summary>
    /// Supplies only the streaming delegate (the non-streaming delegate is null) and an inner
    /// client that only defines the streaming callback. Both entry points of the built client
    /// are expected to produce "hello world".
    /// </summary>
    [Fact]
    public async Task GetStreamingResponseFunc_ContextPropagated()
    {
        IList<ChatMessage> messages = [];
        var chatOptions = new ChatOptions();
        using var cts = new CancellationTokenSource();
        var ambient = new AsyncLocal<int>();

        using IChatClient inner = new TestChatClient
        {
            GetStreamingResponseAsyncCallback = (msgs, opts, token) =>
            {
                VerifyArguments(msgs, opts, token);
                Assert.Equal(42, ambient.Value);
                return YieldUpdates(new ChatResponseUpdate { Text = "hello" });
            },
        };

        using IChatClient pipeline = new ChatClientBuilder(inner)
            .Use(null, (msgs, opts, downstream, token) =>
            {
                VerifyArguments(msgs, opts, token);
                ambient.Value = 42;
                return AppendWorld(msgs, opts, downstream, token);

                // Re-yields every inner update, then emits one trailing " world" update.
                static async IAsyncEnumerable<ChatResponseUpdate> AppendWorld(
                    IList<ChatMessage> forwardedMessages, ChatOptions? forwardedOptions, IChatClient target, [EnumeratorCancellation] CancellationToken ct)
                {
                    await foreach (var item in target.GetStreamingResponseAsync(forwardedMessages, forwardedOptions, ct))
                    {
                        yield return item;
                    }

                    yield return new() { Text = " world" };
                }
            })
            .Build();

        Assert.Equal(0, ambient.Value);

        ChatResponse direct = await pipeline.GetResponseAsync(messages, chatOptions, cts.Token);
        Assert.Equal("hello world", direct.Message.Text);

        ChatResponse streamed = await pipeline.GetStreamingResponseAsync(messages, chatOptions, cts.Token).ToChatResponseAsync();
        Assert.Equal("hello world", streamed.Message.Text);

        // Checks that a callee received exactly the instances/token the test passed in.
        void VerifyArguments(IList<ChatMessage> actualMessages, ChatOptions? actualOptions, CancellationToken actualToken)
        {
            Assert.Same(messages, actualMessages);
            Assert.Same(chatOptions, actualOptions);
            Assert.Equal(cts.Token, actualToken);
        }
    }

    /// <summary>
    /// Supplies both a non-streaming and a streaming delegate. The distinct texts on each path
    /// let the test confirm that each entry point of the built client is served by its own
    /// delegate and its own inner-client callback.
    /// </summary>
    [Fact]
    public async Task BothGetResponseAndGetStreamingResponseFuncs_ContextPropagated()
    {
        IList<ChatMessage> messages = [];
        var chatOptions = new ChatOptions();
        using var cts = new CancellationTokenSource();
        var ambient = new AsyncLocal<int>();

        using IChatClient inner = new TestChatClient
        {
            GetResponseAsyncCallback = (msgs, opts, token) =>
            {
                VerifyArguments(msgs, opts, token);
                Assert.Equal(42, ambient.Value);
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "non-streaming hello")));
            },

            GetStreamingResponseAsyncCallback = (msgs, opts, token) =>
            {
                VerifyArguments(msgs, opts, token);
                Assert.Equal(42, ambient.Value);
                return YieldUpdates(new ChatResponseUpdate { Text = "streaming hello" });
            },
        };

        using IChatClient pipeline = new ChatClientBuilder(inner)
            .Use(
                async (msgs, opts, downstream, token) =>
                {
                    VerifyArguments(msgs, opts, token);
                    ambient.Value = 42;

                    // Forward to the inner client, then append to the first choice's text.
                    ChatResponse innerResponse = await downstream.GetResponseAsync(msgs, opts, token);
                    innerResponse.Choices[0].Text += " world (non-streaming)";
                    return innerResponse;
                },
                (msgs, opts, downstream, token) =>
                {
                    VerifyArguments(msgs, opts, token);
                    ambient.Value = 42;
                    return AppendWorld(msgs, opts, downstream, token);

                    // Re-yields every inner update, then emits one trailing streaming-specific update.
                    static async IAsyncEnumerable<ChatResponseUpdate> AppendWorld(
                        IList<ChatMessage> forwardedMessages, ChatOptions? forwardedOptions, IChatClient target, [EnumeratorCancellation] CancellationToken ct)
                    {
                        await foreach (var item in target.GetStreamingResponseAsync(forwardedMessages, forwardedOptions, ct))
                        {
                            yield return item;
                        }

                        yield return new() { Text = " world (streaming)" };
                    }
                })
            .Build();

        Assert.Equal(0, ambient.Value);

        ChatResponse direct = await pipeline.GetResponseAsync(messages, chatOptions, cts.Token);
        Assert.Equal("non-streaming hello world (non-streaming)", direct.Message.Text);

        ChatResponse streamed = await pipeline.GetStreamingResponseAsync(messages, chatOptions, cts.Token).ToChatResponseAsync();
        Assert.Equal("streaming hello world (streaming)", streamed.Message.Text);

        // Checks that a callee received exactly the instances/token the test passed in.
        void VerifyArguments(IList<ChatMessage> actualMessages, ChatOptions? actualOptions, CancellationToken actualToken)
        {
            Assert.Same(messages, actualMessages);
            Assert.Same(chatOptions, actualOptions);
            Assert.Equal(cts.Token, actualToken);
        }
    }

    /// <summary>
    /// Produces the given updates as an async sequence, awaiting <see cref="Task.Yield"/> before
    /// each element is yielded.
    /// </summary>
    /// <param name="updates">The updates to emit, in order.</param>
    /// <returns>An async sequence that yields each element of <paramref name="updates"/>.</returns>
    private static async IAsyncEnumerable<ChatResponseUpdate> YieldUpdates(params ChatResponseUpdate[] updates)
    {
        for (int index = 0; index < updates.Length; index++)
        {
            // Read the element before yielding control, matching foreach's read order.
            ChatResponseUpdate current = updates[index];
            await Task.Yield();
            yield return current;
        }
    }
}