File: AspireOpenAIClientBuilderEmbeddingGeneratorExtensionsTests.cs
Web Access
Project: src\tests\Aspire.OpenAI.Tests\Aspire.OpenAI.Tests.csproj (Aspire.OpenAI.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Xunit;
 
namespace Aspire.OpenAI.Tests;
 
public class AspireOpenAIClientBuilderEmbeddingGeneratorExtensionsTests
{
    [Theory]
    [InlineData(true)]
    [InlineData(false)]
    public void CanReadDeploymentNameFromConfig(bool useKeyed)
    {
        var builder = Host.CreateEmptyApplicationBuilder(null);
        builder.Configuration.AddInMemoryCollection([
            new("Aspire:OpenAI:Endpoint", "https://aspireopenaitests.openai.azure.com/"),
            new("Aspire:OpenAI:Deployment", "testdeployment1"),
            new("Aspire:OpenAI:Key", "fake"),
        ]);
 
        if (useKeyed)
        {
            builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator");
        }
        else
        {
            builder.AddOpenAIClient("openai").AddEmbeddingGenerator();
        }
 
        using var host = builder.Build();
        var generator = useKeyed ?
            host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
            host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
 
        Assert.NotNull(generator);
        Assert.Equal("testdeployment1", generator.Metadata.ModelId);
    }
 
    [Theory]
    [InlineData(true, "Model")]
    [InlineData(false, "Model")]
    [InlineData(true, "Deployment")]
    [InlineData(false, "Deployment")]
    public void CanReadDeploymentNameFromConnectionString(bool useKeyed, string connectionStringKey)
    {
        var builder = Host.CreateEmptyApplicationBuilder(null);
        builder.Configuration.AddInMemoryCollection([
            new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;{connectionStringKey}=testdeployment1")
        ]);
 
        if (useKeyed)
        {
            builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator");
        }
        else
        {
            builder.AddOpenAIClient("openai").AddEmbeddingGenerator();
        }
 
        using var host = builder.Build();
        var generator = useKeyed ?
            host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
            host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
 
        Assert.NotNull(generator);
        Assert.Equal("testdeployment1", generator.Metadata.ModelId);
    }
 
    [Theory]
    [InlineData(true)]
    [InlineData(false)]
    public void CanAcceptDeploymentNameAsArgument(bool useKeyed)
    {
        var builder = Host.CreateEmptyApplicationBuilder(null);
        builder.Configuration.AddInMemoryCollection([
            new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake")
        ]);
 
        if (useKeyed)
        {
            builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1");
        }
        else
        {
            builder.AddOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1");
        }
 
        using var host = builder.Build();
        var generator = useKeyed ?
            host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
            host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
 
        Assert.NotNull(generator);
        Assert.Equal("testdeployment1", generator.Metadata.ModelId);
    }
 
    [Theory]
    [InlineData(true)]
    [InlineData(false)]
    public void RejectsConnectionStringWithBothModelAndDeployment(bool useKeyed)
    {
        var builder = Host.CreateEmptyApplicationBuilder(null);
        builder.Configuration.AddInMemoryCollection([
            new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake;Deployment=testdeployment1;Model=something")
        ]);
 
        if (useKeyed)
        {
            builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator");
        }
        else
        {
            builder.AddOpenAIClient("openai").AddEmbeddingGenerator();
        }
 
        using var host = builder.Build();
        
        var ex = Assert.Throws<InvalidOperationException>(() =>
        {
            _ = useKeyed ?
                host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
                host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
        });
 
        Assert.StartsWith("The connection string 'openai' contains both 'Deployment' and 'Model' keys.", ex.Message);
    }
 
    [Theory]
    [InlineData(true)]
    [InlineData(false)]
    public void RejectsDeploymentNameNotSpecified(bool useKeyed)
    {
        var builder = Host.CreateEmptyApplicationBuilder(null);
        builder.Configuration.AddInMemoryCollection([
            new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake")
        ]);
 
        if (useKeyed)
        {
            builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator");
        }
        else
        {
            builder.AddOpenAIClient("openai").AddEmbeddingGenerator();
        }
 
        using var host = builder.Build();
 
        var ex = Assert.Throws<InvalidOperationException>(() =>
        {
            _ = useKeyed ?
                host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
                host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
        });
 
        Assert.StartsWith("The deployment could not be determined", ex.Message);
    }
 
    [Theory]
    [InlineData(true, false)]
    [InlineData(false, false)]
    [InlineData(true, true)]
    [InlineData(false, true)]
    public void AddsOpenTelemetry(bool useKeyed, bool disableOpenTelemetry)
    {
        var builder = Host.CreateEmptyApplicationBuilder(null);
        builder.Configuration.AddInMemoryCollection([
            new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake"),
            new("Aspire:OpenAI:DisableTracing", disableOpenTelemetry.ToString()),
        ]);
 
        if (useKeyed)
        {
            builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1");
        }
        else
        {
            builder.AddOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1");
        }
 
        using var host = builder.Build();
        var generator = useKeyed ?
            host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
            host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
 
        Assert.Equal(disableOpenTelemetry, generator.GetService<OpenTelemetryEmbeddingGenerator<string, Embedding<float>>>() is null);
    }
 
    [Theory]
    [InlineData(true)]
    [InlineData(false)]
    public async Task CanConfigurePipelineAsync(bool useKeyed)
    {
        var builder = Host.CreateEmptyApplicationBuilder(null);
        builder.Configuration.AddInMemoryCollection([
            new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake")
        ]);
 
        if (useKeyed)
        {
            builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1").Use(TestMiddleware);
        }
        else
        {
            builder.AddOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1").Use(TestMiddleware);
        }
 
        using var host = builder.Build();
        var generator = useKeyed ?
            host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
            host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
 
        var vector = await generator.GenerateEmbeddingVectorAsync("Hello");
        Assert.Equal(1.23f, vector.ToArray().Single());
    }
 
    private Task<GeneratedEmbeddings<Embedding<float>>> TestMiddleware(IEnumerable<string> inputs, EmbeddingGenerationOptions? options, IEmbeddingGenerator<string, Embedding<float>> nextAsync, CancellationToken cancellationToken)
    {
        float[] floats = [1.23f];
        return Task.FromResult(new GeneratedEmbeddings<Embedding<float>>(inputs.Select(i => new Embedding<float>(floats))));
    }
}