File: ChatCompletion\ChatClientExtensionsTests.cs
Web Access
Project: src\test\Libraries\Microsoft.Extensions.AI.Abstractions.Tests\Microsoft.Extensions.AI.Abstractions.Tests.csproj (Microsoft.Extensions.AI.Abstractions.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
 
namespace Microsoft.Extensions.AI;
 
public class ChatClientExtensionsTests
{
    [Fact]
    public void GetService_InvalidArgs_Throws()
    {
        Assert.Throws<ArgumentNullException>("client", () => ChatClientExtensions.GetService<object>(null!));
    }
 
    [Fact]
    public void GetRequiredService_InvalidArgs_Throws()
    {
        Assert.Throws<ArgumentNullException>("client", () => ChatClientExtensions.GetRequiredService(null!, typeof(string)));
        Assert.Throws<ArgumentNullException>("client", () => ChatClientExtensions.GetRequiredService<object>(null!));
 
        using var client = new TestChatClient();
        Assert.Throws<ArgumentNullException>("serviceType", () => client.GetRequiredService(null!));
    }
 
    [Fact]
    public void GetService_ValidService_Returned()
    {
        using var client = new TestChatClient
        {
            GetServiceCallback = (serviceType, serviceKey) =>
            {
                if (serviceType == typeof(string))
                {
                    return serviceKey == null ? "null key" : "non-null key";
                }
 
                if (serviceType == typeof(IChatClient))
                {
                    return new object();
                }
 
                return null;
            },
        };
 
        Assert.Equal("null key", client.GetService<string>());
        Assert.Equal("null key", client.GetService<string>(null));
        Assert.Equal("non-null key", client.GetService<string>("key"));
 
        Assert.Null(client.GetService<object>());
        Assert.Null(client.GetService<object>("key"));
        Assert.Null(client.GetService<IChatClient>());
 
        Assert.Equal("null key", client.GetRequiredService(typeof(string)));
        Assert.Equal("null key", client.GetRequiredService<string>());
        Assert.Equal("null key", client.GetRequiredService<string>(null));
        Assert.Equal("non-null key", client.GetRequiredService(typeof(string), "key"));
        Assert.Equal("non-null key", client.GetRequiredService<string>("key"));
 
        Assert.Throws<InvalidOperationException>(() => client.GetRequiredService(typeof(object)));
        Assert.Throws<InvalidOperationException>(() => client.GetRequiredService<object>());
        Assert.Throws<InvalidOperationException>(() => client.GetRequiredService(typeof(object), "key"));
        Assert.Throws<InvalidOperationException>(() => client.GetRequiredService<object>("key"));
        Assert.Throws<InvalidOperationException>(() => client.GetRequiredService<IChatClient>());
    }
 
    [Fact]
    public void GetResponseAsync_InvalidArgs_Throws()
    {
        Assert.Throws<ArgumentNullException>("client", () =>
        {
            _ = ChatClientExtensions.GetResponseAsync(null!, "hello");
        });
 
        Assert.Throws<ArgumentNullException>("chatMessage", () =>
        {
            _ = ChatClientExtensions.GetResponseAsync(new TestChatClient(), (ChatMessage)null!);
        });
    }
 
    [Fact]
    public void GetStreamingResponseAsync_InvalidArgs_Throws()
    {
        Assert.Throws<ArgumentNullException>("client", () =>
        {
            _ = ChatClientExtensions.GetStreamingResponseAsync(null!, "hello");
        });
 
        Assert.Throws<ArgumentNullException>("chatMessage", () =>
        {
            _ = ChatClientExtensions.GetStreamingResponseAsync(new TestChatClient(), (ChatMessage)null!);
        });
    }
 
    [Fact]
    public async Task GetResponseAsync_CreatesTextMessageAsync()
    {
        var expectedResponse = new ChatResponse();
        var expectedOptions = new ChatOptions();
        using var cts = new CancellationTokenSource();
 
        using TestChatClient client = new()
        {
            GetResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                ChatMessage m = Assert.Single(messages);
                Assert.Equal(ChatRole.User, m.Role);
                Assert.Equal("hello", m.Text);
 
                Assert.Same(expectedOptions, options);
 
                Assert.Equal(cts.Token, cancellationToken);
 
                return Task.FromResult(expectedResponse);
            },
        };
 
        ChatResponse response = await client.GetResponseAsync("hello", expectedOptions, cts.Token);
 
        Assert.Same(expectedResponse, response);
    }
 
    [Fact]
    public async Task GetStreamingResponseAsync_CreatesTextMessageAsync()
    {
        var expectedOptions = new ChatOptions();
        using var cts = new CancellationTokenSource();
 
        using TestChatClient client = new()
        {
            GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                ChatMessage m = Assert.Single(messages);
                Assert.Equal(ChatRole.User, m.Role);
                Assert.Equal("hello", m.Text);
 
                Assert.Same(expectedOptions, options);
 
                Assert.Equal(cts.Token, cancellationToken);
 
                return YieldAsync([new ChatResponseUpdate(ChatRole.Assistant, "world")]);
            },
        };
 
        int count = 0;
        await foreach (var update in client.GetStreamingResponseAsync("hello", expectedOptions, cts.Token))
        {
            Assert.Equal(0, count);
            Assert.Equal("world", update.Text);
            count++;
        }
 
        Assert.Equal(1, count);
    }
 
    private static async IAsyncEnumerable<ChatResponseUpdate> YieldAsync(params ChatResponseUpdate[] updates)
    {
        await Task.Yield();
        foreach (var update in updates)
        {
            yield return update;
        }
    }
}