File: Embeddings\EmbeddingGeneratorExtensionsTests.cs
Web Access
Project: src\test\Libraries\Microsoft.Extensions.AI.Abstractions.Tests\Microsoft.Extensions.AI.Abstractions.Tests.csproj (Microsoft.Extensions.AI.Abstractions.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
 
namespace Microsoft.Extensions.AI;
 
/// <summary>
/// Tests for the <c>EmbeddingGeneratorExtensions</c> helpers: service lookup
/// (<c>GetService</c> / <c>GetRequiredService</c>) and the convenience generation
/// methods (<c>GenerateAsync</c>, <c>GenerateVectorAsync</c>, <c>GenerateAndZipAsync</c>).
/// </summary>
public class EmbeddingGeneratorExtensionsTests
{
    /// <summary>Verifies that <c>GetService</c> rejects a null generator.</summary>
    [Fact]
    public void GetService_InvalidArgs_Throws()
    {
        Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetService<object>(null!));
    }

    /// <summary>Verifies that <c>GetRequiredService</c> rejects a null generator and a null service type.</summary>
    [Fact]
    public void GetRequiredService_InvalidArgs_Throws()
    {
        Assert.Throws<ArgumentNullException>("generator", () => EmbeddingGeneratorExtensions.GetRequiredService<object>(null!));

        using var generator = new TestEmbeddingGenerator();
        Assert.Throws<ArgumentNullException>("serviceType", () => generator.GetRequiredService(null!));
    }

    /// <summary>
    /// Verifies that the typed and untyped service lookups, keyed and unkeyed, return what the
    /// generator's callback provides, and that the "required" variants throw when nothing is provided.
    /// </summary>
    [Fact]
    public void GetService_ValidService_Returned()
    {
        using IEmbeddingGenerator<string, Embedding<float>> generator = new TestEmbeddingGenerator
        {
            GetServiceCallback = (serviceType, serviceKey) =>
            {
                // Strings are served with a value that reveals whether a key was supplied.
                if (serviceType == typeof(string))
                {
                    return serviceKey == null ? "null key" : "non-null key";
                }

                // NOTE(review): this branch is not queried by any assertion below.
                if (serviceType == typeof(IEmbeddingGenerator<string, Embedding<float>>))
                {
                    return new object();
                }

                // Every other type (object, int?, ...) is reported as unavailable.
                return null;
            },
        };

        Assert.Equal("null key", generator.GetService(typeof(string)));
        Assert.Equal("null key", generator.GetService<string>());

        Assert.Equal("non-null key", generator.GetService(typeof(string), "key"));
        Assert.Equal("non-null key", generator.GetService<string>("key"));

        Assert.Null(generator.GetService(typeof(object)));
        Assert.Null(generator.GetService<object>());

        Assert.Null(generator.GetService(typeof(object), "key"));
        Assert.Null(generator.GetService<object>("key"));

        // A nullable value type exercises the generic path where TService is not a reference type.
        Assert.Null(generator.GetService<int?>());

        Assert.Equal("null key", generator.GetRequiredService(typeof(string)));
        Assert.Equal("null key", generator.GetRequiredService<string>());

        Assert.Equal("non-null key", generator.GetRequiredService(typeof(string), "key"));
        Assert.Equal("non-null key", generator.GetRequiredService<string>("key"));

        Assert.Throws<InvalidOperationException>(() => generator.GetRequiredService(typeof(object)));
        Assert.Throws<InvalidOperationException>(() => generator.GetRequiredService<object>());

        Assert.Throws<InvalidOperationException>(() => generator.GetRequiredService(typeof(object), "key"));
        Assert.Throws<InvalidOperationException>(() => generator.GetRequiredService<object>("key"));

        Assert.Throws<InvalidOperationException>(() => generator.GetRequiredService<int?>());
    }

    /// <summary>Verifies that each generation helper rejects a null generator.</summary>
    [Fact]
    public async Task GenerateAsync_InvalidArgs_ThrowsAsync()
    {
        await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAsync("hello"));
        await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateVectorAsync("hello"));
        await Assert.ThrowsAsync<ArgumentNullException>("generator", () => ((TestEmbeddingGenerator)null!).GenerateAndZipAsync(["hello"]));
    }

    /// <summary>
    /// Verifies that the single-value helpers forward exactly that value to the generator and
    /// unwrap the single embedding (or its vector) it returns.
    /// </summary>
    [Fact]
    public async Task GenerateAsync_ReturnsSingleEmbeddingAsync()
    {
        Embedding<float> result = new(new float[] { 1f, 2f, 3f });

        using TestEmbeddingGenerator service = new()
        {
            GenerateAsyncCallback = (values, options, cancellationToken) =>
            {
                // The single input must be forwarded as a one-element batch.
                Assert.Equal("hello", Assert.Single(values));
                return Task.FromResult<GeneratedEmbeddings<Embedding<float>>>([result]);
            },
        };

        Assert.Same(result, await service.GenerateAsync("hello"));
        Assert.Equal(result.Vector, await service.GenerateVectorAsync("hello"));
    }

    /// <summary>
    /// Verifies that <c>GenerateAndZipAsync</c> pairs each input with the embedding produced at the
    /// same position, for empty, single-element, and multi-element inputs.
    /// </summary>
    /// <param name="count">The number of inputs (and embeddings) to generate.</param>
    [Theory]
    [InlineData(0)]
    [InlineData(1)]
    [InlineData(10)]
    public async Task GenerateAndZipEmbeddingsAsync_ReturnsExpectedList(int count)
    {
        string[] inputs = Enumerable.Range(0, count).Select(i => $"hello {i}").ToArray();

        // Embedding i holds [i, i + 1, i + 2, i + 3], so every embedding is distinct.
        Embedding<float>[] embeddings = Enumerable
            .Range(0, count)
            .Select(i => new Embedding<float>(Enumerable.Range(i, 4).Select(v => (float)v).ToArray()))
            .ToArray();

        using TestEmbeddingGenerator service = new()
        {
            GenerateAsyncCallback = (values, options, cancellationToken) =>
            {
                // The zip is only meaningful if the generator sees exactly the inputs, in order.
                Assert.Equal(inputs, values);
                return Task.FromResult<GeneratedEmbeddings<Embedding<float>>>(new(embeddings));
            },
        };

        var results = await service.GenerateAndZipAsync(inputs);
        Assert.NotNull(results);
        Assert.Equal(count, results.Length);
        for (int i = 0; i < count; i++)
        {
            Assert.Equal(inputs[i], results[i].Value);
            Assert.Same(embeddings[i], results[i].Embedding);
        }
    }
}