// File: ChatCompletion\ImageGeneratingChatClientTests.cs
// Web Access
// Project: src\test\Libraries\Microsoft.Extensions.AI.Tests\Microsoft.Extensions.AI.Tests.csproj (Microsoft.Extensions.AI.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
 
namespace Microsoft.Extensions.AI;
 
/// <summary>
/// Tests for <see cref="ImageGeneratingChatClient"/> and the <c>UseImageGeneration</c> builder extension:
/// argument validation, replacement of <see cref="HostedImageGenerationTool"/> with image functions,
/// pass-through when no image tool is present, and rewriting of image function calls in responses.
/// </summary>
public class ImageGeneratingChatClientTests
{
    [Fact]
    public void ImageGeneratingChatClient_InvalidArgs_Throws()
    {
        using var innerClient = new TestChatClient();
        using var imageGenerator = new TestImageGenerator();

        Assert.Throws<ArgumentNullException>("innerClient", () => new ImageGeneratingChatClient(null!, imageGenerator));
        Assert.Throws<ArgumentNullException>("imageGenerator", () => new ImageGeneratingChatClient(innerClient, null!));
    }

    [Fact]
    public void UseImageGeneration_WithNullBuilder_Throws()
    {
        Assert.Throws<ArgumentNullException>("builder", () => ((ChatClientBuilder)null!).UseImageGeneration());
    }

    [Fact]
    public async Task GetResponseAsync_WithoutImageGenerationTool_PassesThrough()
    {
        // Arrange
        ChatOptions? capturedOptions = null;

        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                capturedOptions = options;
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "test response")));
            },
        };

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        var dummyFunction = AIFunctionFactory.Create(() => "dummy function", name: "DummyFunction");
        var chatOptions = new ChatOptions
        {
            Tools = [dummyFunction]
        };

        // Act
        var response = await client.GetResponseAsync([new(ChatRole.User, "test")], chatOptions);

        // Assert
        Assert.NotNull(response);
        Assert.Equal("test response", response.Messages[0].Text);

        // The caller's tools collection still holds only the original function, not a replacement.
        Assert.Same(dummyFunction, Assert.Single(chatOptions.Tools));

        // The inner client must also see the original function and no injected image-generation functions;
        // checking only the caller's options would not prove anything was actually passed through.
        Assert.NotNull(capturedOptions);
        Assert.NotNull(capturedOptions.Tools);
        Assert.Same(dummyFunction, Assert.Single(capturedOptions.Tools));
    }

    [Fact]
    public async Task GetResponseAsync_WithImageGenerationTool_ReplacesTool()
    {
        // Arrange
        bool innerClientCalled = false;
        ChatOptions? capturedOptions = null;

        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                innerClientCalled = true;
                capturedOptions = options;
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "test response")));
            },
        };

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        var chatOptions = new ChatOptions
        {
            Tools = [new HostedImageGenerationTool()]
        };

        // Act
        await client.GetResponseAsync([new(ChatRole.User, "test")], chatOptions);

        // Assert
        Assert.True(innerClientCalled);
        Assert.NotNull(capturedOptions);
        Assert.NotNull(capturedOptions.Tools);
        Assert.Equal(3, capturedOptions.Tools.Count);

        // The single hosted tool is replaced by the three image functions, in order.
        AssertImageGenerationFunctions(capturedOptions.Tools, startIndex: 0);
    }

    [Fact]
    public async Task GetResponseAsync_WithMixedTools_ReplacesOnlyImageGenerationTool()
    {
        // Arrange
        bool innerClientCalled = false;
        ChatOptions? capturedOptions = null;

        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                innerClientCalled = true;
                capturedOptions = options;
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "test response")));
            },
        };

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        var dummyFunction = AIFunctionFactory.Create(() => "dummy", name: "DummyFunction");
        var chatOptions = new ChatOptions
        {
            Tools = [dummyFunction, new HostedImageGenerationTool()]
        };

        // Act
        await client.GetResponseAsync([new(ChatRole.User, "test")], chatOptions);

        // Assert
        Assert.True(innerClientCalled);
        Assert.NotNull(capturedOptions);
        Assert.NotNull(capturedOptions.Tools);
        Assert.Equal(4, capturedOptions.Tools.Count); // DummyFunction + GenerateImage + EditImage + GetImagesForEdit

        Assert.Same(dummyFunction, capturedOptions.Tools[0]); // Original function preserved

        // Every injected function is verified by name, including GetImagesForEdit at the last index.
        AssertImageGenerationFunctions(capturedOptions.Tools, startIndex: 1);
    }

    [Fact]
    public void UseImageGeneration_ServiceProviderIntegration_Works()
    {
        // Arrange
        var services = new ServiceCollection();
        services.AddSingleton<IImageGenerator, TestImageGenerator>();

        using var serviceProvider = services.BuildServiceProvider();
        using var innerClient = new TestChatClient();

        // Act
        using var client = innerClient
            .AsBuilder()
            .UseImageGeneration()
            .Build(serviceProvider);

        // Assert
        Assert.IsType<ImageGeneratingChatClient>(client);
    }

    [Fact]
    public void UseImageGeneration_WithProvidedImageGenerator_Works()
    {
        // Arrange
        using var innerClient = new TestChatClient();
        using var imageGenerator = new TestImageGenerator();

        // Act
        using var client = innerClient
            .AsBuilder()
            .UseImageGeneration(imageGenerator)
            .Build();

        // Assert
        Assert.IsType<ImageGeneratingChatClient>(client);
    }

    [Fact]
    public void UseImageGeneration_WithConfigureCallback_CallsCallback()
    {
        // Arrange
        using var innerClient = new TestChatClient();
        using var imageGenerator = new TestImageGenerator();
        bool configureCallbackInvoked = false;

        // Act
        using var client = innerClient
            .AsBuilder()
            .UseImageGeneration(imageGenerator, configure: c =>
            {
                Assert.NotNull(c);
                configureCallbackInvoked = true;
            })
            .Build();

        // Assert
        Assert.True(configureCallbackInvoked);
    }

    [Fact]
    public async Task GetStreamingResponseAsync_WithImageGenerationTool_ReplacesTool()
    {
        // Arrange
        bool innerClientCalled = false;
        ChatOptions? capturedOptions = null;

        using var innerClient = new TestChatClient
        {
            GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                innerClientCalled = true;
                capturedOptions = options;
                return GetUpdatesAsync();
            }
        };

        static async IAsyncEnumerable<ChatResponseUpdate> GetUpdatesAsync()
        {
            await Task.Yield();
            yield return new(ChatRole.Assistant, "test");
        }

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        var chatOptions = new ChatOptions
        {
            Tools = [new HostedImageGenerationTool()]
        };

        // Act
        await foreach (var update in client.GetStreamingResponseAsync([new(ChatRole.User, "test")], chatOptions))
        {
            // Drain the stream so the inner client is actually invoked.
        }

        // Assert
        Assert.True(innerClientCalled);
        Assert.NotNull(capturedOptions);
        Assert.NotNull(capturedOptions.Tools);
        Assert.Equal(3, capturedOptions.Tools.Count);

        // Streaming must inject the same functions as the non-streaming path.
        AssertImageGenerationFunctions(capturedOptions.Tools, startIndex: 0);
    }

    [Fact]
    public async Task GetResponseAsync_WithNullOptions_DoesNotThrow()
    {
        // Arrange
        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "test response")));
            },
        };

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        // Act & Assert
        var response = await client.GetResponseAsync([new(ChatRole.User, "test")], null);
        Assert.NotNull(response);
    }

    [Fact]
    public async Task GetResponseAsync_WithEmptyTools_DoesNotModify()
    {
        // Arrange
        ChatOptions? capturedOptions = null;

        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                capturedOptions = options;
                return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "test response")));
            },
        };

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        var chatOptions = new ChatOptions
        {
            Tools = []
        };

        // Act
        await client.GetResponseAsync([new(ChatRole.User, "test")], chatOptions);

        // Assert
        Assert.Same(chatOptions, capturedOptions);

        // An explicit NotNull narrows the captured options for the nullable checker without using a
        // conditional access (?.), so the CA1508 suppression is no longer needed.
        Assert.NotNull(capturedOptions);
        Assert.NotNull(capturedOptions.Tools);
        Assert.Empty(capturedOptions.Tools);
    }

    [Fact]
    public async Task GetResponseAsync_WithFunctionCallContent_ReplacesWithImageGenerationToolCallContent()
    {
        // Arrange
        var callId = "test-call-id";
        using var innerClient = new TestChatClient
        {
            GetResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                var responseMessage = new ChatMessage(ChatRole.Assistant,
                    [new FunctionCallContent(callId, "GenerateImage", new Dictionary<string, object?> { ["prompt"] = "a cat" })]);
                return Task.FromResult(new ChatResponse(responseMessage));
            },
        };

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        var chatOptions = new ChatOptions
        {
            Tools = [new HostedImageGenerationTool()]
        };

        // Act
        var response = await client.GetResponseAsync([new(ChatRole.User, "test")], chatOptions);

        // Assert
        Assert.NotNull(response);
        Assert.Single(response.Messages);
        var message = response.Messages[0];
        Assert.Single(message.Contents);

        var imageToolCallContent = Assert.IsType<ImageGenerationToolCallContent>(message.Contents[0]);
        Assert.Equal(callId, imageToolCallContent.ImageId);
    }

    [Fact]
    public async Task GetStreamingResponseAsync_WithFunctionCallContent_ReplacesWithImageGenerationToolCallContent()
    {
        // Arrange
        var callId = "test-call-id";
        using var innerClient = new TestChatClient
        {
            GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) =>
            {
                return GetUpdatesAsync();
            }
        };

        // Not static: captures callId.
        async IAsyncEnumerable<ChatResponseUpdate> GetUpdatesAsync()
        {
            await Task.Yield();
            yield return new ChatResponseUpdate(ChatRole.Assistant,
                [new FunctionCallContent(callId, "GenerateImage", new Dictionary<string, object?> { ["prompt"] = "a cat" })]);
        }

        using var imageGenerator = new TestImageGenerator();
        using var client = new ImageGeneratingChatClient(innerClient, imageGenerator);

        var chatOptions = new ChatOptions
        {
            Tools = [new HostedImageGenerationTool()]
        };

        // Act
        var updates = new List<ChatResponseUpdate>();
        await foreach (var responseUpdate in client.GetStreamingResponseAsync([new(ChatRole.User, "test")], chatOptions))
        {
            updates.Add(responseUpdate);
        }

        // Assert
        Assert.Single(updates);
        var update = updates[0];
        Assert.Single(update.Contents);

        var imageToolCallContent = Assert.IsType<ImageGenerationToolCallContent>(update.Contents[0]);
        Assert.Equal(callId, imageToolCallContent.ImageId);
    }

    /// <summary>
    /// Asserts that <paramref name="tools"/> contains, starting at <paramref name="startIndex"/>, the three
    /// <see cref="AIFunction"/> instances injected in place of a <see cref="HostedImageGenerationTool"/>:
    /// GenerateImage, EditImage and GetImagesForEdit, in that order.
    /// </summary>
    /// <param name="tools">The tools the inner client received.</param>
    /// <param name="startIndex">Index of the first injected function within <paramref name="tools"/>.</param>
    private static void AssertImageGenerationFunctions(IList<AITool> tools, int startIndex)
    {
        Assert.Equal("GenerateImage", Assert.IsAssignableFrom<AIFunction>(tools[startIndex]).Name);
        Assert.Equal("EditImage", Assert.IsAssignableFrom<AIFunction>(tools[startIndex + 1]).Name);
        Assert.Equal("GetImagesForEdit", Assert.IsAssignableFrom<AIFunction>(tools[startIndex + 2]).Name);
    }
}