// File: AzureAIInferenceImageEmbeddingGenerator.cs
// Web Access
// Project: src\src\Libraries\Microsoft.Extensions.AI.AzureAIInference\Microsoft.Extensions.AI.AzureAIInference.csproj (Microsoft.Extensions.AI.AzureAIInference)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.Inference;
using Microsoft.Shared.Diagnostics;
 
#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
#pragma warning disable S109 // Magic numbers should not be used
 
namespace Microsoft.Extensions.AI;
 
/// <summary>An <see cref="IEmbeddingGenerator{DataContent, Embedding}"/> implementation backed by an Azure.AI.Inference <see cref="ImageEmbeddingsClient"/>.</summary>
internal sealed class AzureAIInferenceImageEmbeddingGenerator :
    IEmbeddingGenerator<DataContent, Embedding<float>>
{
    /// <summary>Describes this generator (provider name, endpoint, default model ID, and default dimensions).</summary>
    private readonly EmbeddingGeneratorMetadata _metadata;

    /// <summary>The wrapped <see cref="ImageEmbeddingsClient" /> that every request is sent through.</summary>
    private readonly ImageEmbeddingsClient _imageEmbeddingsClient;

    /// <summary>The default number of dimensions, used when a request does not specify its own.</summary>
    private readonly int? _dimensions;

    /// <summary>Initializes a new instance of the <see cref="AzureAIInferenceImageEmbeddingGenerator"/> class.</summary>
    /// <param name="imageEmbeddingsClient">The client that embedding requests are forwarded to.</param>
    /// <param name="defaultModelId">
    /// The model ID used when a request does not supply <see cref="EmbeddingGenerationOptions.ModelId"/>.
    /// May be <see langword="null"/>, in which case the model ID must come from <see cref="EmbeddingGenerationOptions.ModelId"/>.
    /// </param>
    /// <param name="defaultModelDimensions">
    /// The number of dimensions requested for each embedding when a request does not supply
    /// <see cref="EmbeddingGenerationOptions.Dimensions"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="imageEmbeddingsClient"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentException"><paramref name="defaultModelId"/> is empty or composed entirely of whitespace.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="defaultModelDimensions"/> is not positive.</exception>
    public AzureAIInferenceImageEmbeddingGenerator(
        ImageEmbeddingsClient imageEmbeddingsClient, string? defaultModelId = null, int? defaultModelDimensions = null)
    {
        _imageEmbeddingsClient = Throw.IfNull(imageEmbeddingsClient);

        // A null model ID is acceptable (it can be provided per request); a blank one is rejected.
        if (defaultModelId is not null)
        {
            _ = Throw.IfNullOrWhitespace(defaultModelId);
        }

        if (defaultModelDimensions.HasValue && defaultModelDimensions.Value < 1)
        {
            Throw.ArgumentOutOfRangeException(nameof(defaultModelDimensions), "Value must be greater than 0.");
        }

        _dimensions = defaultModelDimensions;

        // The client does not publicly expose its endpoint (https://github.com/Azure/azure-sdk-for-net/issues/46278),
        // so as a temporary measure it is read from the "_endpoint" field via reflection. If the field is missing or
        // is not a Uri, the provider URL is simply left null.
        FieldInfo? endpointField = typeof(ImageEmbeddingsClient).GetField(
            "_endpoint",
            BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
        Uri? providerUrl = endpointField?.GetValue(imageEmbeddingsClient) as Uri;

        _metadata = new EmbeddingGeneratorMetadata("az.ai.inference", providerUrl, defaultModelId, defaultModelDimensions);
    }

    /// <inheritdoc />
    object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey)
    {
        _ = Throw.IfNull(serviceType);

        // Keyed lookups are never satisfied by this generator.
        if (serviceKey is not null)
        {
            return null;
        }

        if (serviceType == typeof(ImageEmbeddingsClient))
        {
            return _imageEmbeddingsClient;
        }

        if (serviceType == typeof(EmbeddingGeneratorMetadata))
        {
            return _metadata;
        }

        if (serviceType.IsInstanceOfType(this))
        {
            return this;
        }

        return null;
    }

    /// <inheritdoc />
    public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
        IEnumerable<DataContent> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(values);

        // Embeddings are always requested as base64 and decoded into floats below.
        var request = ToAzureAIOptions(values, options, EmbeddingEncodingFormat.Base64);

        var response = await _imageEmbeddingsClient.EmbedAsync(request, cancellationToken).ConfigureAwait(false);
        var embeddingsResult = response.Value;

        // Prefer the model reported by the service; fall back to the one that was requested.
        IEnumerable<Embedding<float>> vectors =
            from item in embeddingsResult.Data
            select new Embedding<float>(AzureAIInferenceEmbeddingGenerator.ParseBase64Floats(item.Embedding))
            {
                CreatedAt = DateTimeOffset.UtcNow,
                ModelId = embeddingsResult.Model ?? request.Model,
            };

        GeneratedEmbeddings<Embedding<float>> result = new(vectors);

        if (embeddingsResult.Usage is { } usage)
        {
            result.Usage = new()
            {
                InputTokenCount = usage.PromptTokens,
                TotalTokenCount = usage.TotalTokens
            };
        }

        return result;
    }

    /// <inheritdoc />
    void IDisposable.Dispose()
    {
        // Intentionally empty: the interface requires Dispose, but this type has nothing to release.
    }

    /// <summary>Builds the Azure.AI.Inference request options from the inputs and the extensions options.</summary>
    /// <param name="inputs">The image contents to embed; each one is passed to the service by its <see cref="DataContent.Uri"/>.</param>
    /// <param name="options">Optional per-request settings; when absent, the generator's defaults are used.</param>
    /// <param name="format">The encoding format to request for the returned embeddings.</param>
    /// <returns>The populated <see cref="ImageEmbeddingsOptions"/>.</returns>
    private ImageEmbeddingsOptions ToAzureAIOptions(IEnumerable<DataContent> inputs, EmbeddingGenerationOptions? options, EmbeddingEncodingFormat format)
    {
        var imageInputs = inputs.Select(content => new ImageEmbeddingInput(content.Uri));

        ImageEmbeddingsOptions request = new(imageInputs);

        // Per-request values win over the defaults captured at construction time.
        request.Dimensions = options?.Dimensions ?? _dimensions;
        request.Model = options?.ModelId ?? _metadata.DefaultModelId;
        request.EncodingFormat = format;

        if (options?.AdditionalProperties is not { } extras)
        {
            return request;
        }

        foreach (var entry in extras)
        {
            // Null-valued entries are skipped rather than being sent to the service.
            if (entry.Value is null)
            {
                continue;
            }

            byte[] utf8Json = JsonSerializer.SerializeToUtf8Bytes(entry.Value, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
            request.AdditionalProperties[entry.Key] = new BinaryData(utf8Json);
        }

        return request;
    }
}