// File: Embeddings\UseDelegateEmbeddingGeneratorTests.cs
// Web Access
// Project: src\test\Libraries\Microsoft.Extensions.AI.Tests\Microsoft.Extensions.AI.Tests.csproj (Microsoft.Extensions.AI.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
 
namespace Microsoft.Extensions.AI;
 
// Tests for plugging a generation delegate into an embedding generator pipeline, both through
// EmbeddingGeneratorBuilder.Use and by constructing AnonymousDelegatingEmbeddingGenerator directly.
public class UseDelegateEmbeddingGeneratorTests
{
    // Null arguments must be rejected with an ArgumentNullException naming the offending parameter.
    [Fact]
    public void InvalidArgs_Throws()
    {
        using var generator = new TestEmbeddingGenerator();
        var builder = new EmbeddingGeneratorBuilder<string, Embedding<float>>(generator);

        // Typed local so that overload resolution selects the delegate-taking Use overload.
        Func<IEnumerable<string>, EmbeddingGenerationOptions?, IEmbeddingGenerator<string, Embedding<float>>, CancellationToken, Task<GeneratedEmbeddings<Embedding<float>>>> missingDelegate = null!;

        // Builder given a null delegate.
        Assert.Throws<ArgumentNullException>("generateFunc", () => builder.Use(missingDelegate));

        // Constructor given a null inner generator (the delegate itself is valid and never invoked).
        Assert.Throws<ArgumentNullException>(
            "innerGenerator",
            () => new AnonymousDelegatingEmbeddingGenerator<string, Embedding<float>>(
                null!,
                (_, _, _, _) => Task.FromResult(new GeneratedEmbeddings<Embedding<float>>(Array.Empty<Embedding<float>>()))));

        // Constructor given a null delegate.
        Assert.Throws<ArgumentNullException>(
            "generateFunc",
            () => new AnonymousDelegatingEmbeddingGenerator<string, Embedding<float>>(generator, null!));
    }

    // The delegate must receive the caller's exact arguments, and whatever it forwards
    // (including an AsyncLocal value it sets) must be observed by the inner generator.
    [Fact]
    public async Task GenerateFunc_ContextPropagated()
    {
        IList<string> expectedValues = ["hello"];
        var expectedOptions = new EmbeddingGenerationOptions();
        var expectedEmbeddings = new GeneratedEmbeddings<Embedding<float>>();
        using var expectedCts = new CancellationTokenSource();
        var asyncLocal = new AsyncLocal<int>();

        // Innermost generator: verifies what reaches it and hands back the shared result instance.
        using IEmbeddingGenerator<string, Embedding<float>> innerGenerator = new TestEmbeddingGenerator
        {
            GenerateAsyncCallback = (receivedValues, receivedOptions, receivedToken) =>
            {
                Assert.Same(expectedValues, receivedValues);
                Assert.Same(expectedOptions, receivedOptions);
                Assert.Equal(expectedCts.Token, receivedToken);
                Assert.Equal(42, asyncLocal.Value);
                return Task.FromResult(expectedEmbeddings);
            },
        };

        // Pipeline with the delegate under test wrapped around the inner generator.
        var pipeline = new EmbeddingGeneratorBuilder<string, Embedding<float>>(innerGenerator)
            .Use(async (receivedValues, receivedOptions, next, receivedToken) =>
            {
                Assert.Same(expectedValues, receivedValues);
                Assert.Same(expectedOptions, receivedOptions);
                Assert.Equal(expectedCts.Token, receivedToken);

                // Set before delegating so the inner callback can check that the value flowed.
                asyncLocal.Value = 42;

                var produced = await next.GenerateAsync(receivedValues, receivedOptions, receivedToken);
                produced.Add(new Embedding<float>(default));
                return produced;
            });

        using IEmbeddingGenerator<string, Embedding<float>> generator = pipeline.Build();

        // Nothing has run yet, so the AsyncLocal still holds its default.
        Assert.Equal(0, asyncLocal.Value);

        GeneratedEmbeddings<Embedding<float>> actual = await generator.GenerateAsync(expectedValues, expectedOptions, expectedCts.Token);

        // Same instance the inner generator returned, now holding the one embedding the delegate added.
        Assert.Same(expectedEmbeddings, actual);
        Assert.Single(actual);
    }
}