File: ToolReduction\EmbeddingToolReductionStrategy.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI\Microsoft.Extensions.AI.csproj (Microsoft.Extensions.AI)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Numerics.Tensors;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
 
namespace Microsoft.Extensions.AI;
 
#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14
 
/// <summary>
/// A tool reduction strategy that ranks tools by embedding similarity to the current conversation context.
/// </summary>
/// <remarks>
/// The strategy embeds each tool (name + description by default) once (cached) and embeds the current
/// conversation content each request. It then selects the top <c>toolLimit</c> tools by similarity.
/// </remarks>
[Experimental("MEAI001")]
public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy
{
    private readonly ConditionalWeakTable<AITool, Embedding<float>> _toolEmbeddingsCache = new();
    private readonly IEmbeddingGenerator<string, Embedding<float>> _embeddingGenerator;
    private readonly int _toolLimit;
 
    private Func<AITool, string> _toolEmbeddingTextSelector = static t =>
    {
        if (string.IsNullOrWhiteSpace(t.Name))
        {
            return t.Description;
        }
 
        if (string.IsNullOrWhiteSpace(t.Description))
        {
            return t.Name;
        }
 
        return t.Name + Environment.NewLine + t.Description;
    };
 
    private Func<IEnumerable<ChatMessage>, ValueTask<string>> _messagesEmbeddingTextSelector = static messages =>
    {
        var sb = new StringBuilder();
        foreach (var message in messages)
        {
            var contents = message.Contents;
            for (var i = 0; i < contents.Count; i++)
            {
                string text;
                switch (contents[i])
                {
                    case TextContent content:
                        text = content.Text;
                        break;
                    case TextReasoningContent content:
                        text = content.Text;
                        break;
                    default:
                        continue;
                }
 
                _ = sb.AppendLine(text);
            }
        }
 
        return new ValueTask<string>(sb.ToString());
    };
 
    private Func<ReadOnlyMemory<float>, ReadOnlyMemory<float>, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span);
 
    private Func<AITool, bool> _isRequiredTool = static _ => false;
 
    /// <summary>
    /// Initializes a new instance of the <see cref="EmbeddingToolReductionStrategy"/> class.
    /// </summary>
    /// <param name="embeddingGenerator">Embedding generator used to produce embeddings.</param>
    /// <param name="toolLimit">Maximum number of tools to return, excluding required tools. Must be greater than zero.</param>
    public EmbeddingToolReductionStrategy(
        IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator,
        int toolLimit)
    {
        _embeddingGenerator = Throw.IfNull(embeddingGenerator);
        _toolLimit = Throw.IfLessThanOrEqual(toolLimit, min: 0);
    }
 
    /// <summary>
    /// Gets or sets the selector used to generate a single text string from a tool.
    /// </summary>
    /// <remarks>
    /// Defaults to: <c>Name + "\n" + Description</c> (omitting empty parts).
    /// </remarks>
    public Func<AITool, string> ToolEmbeddingTextSelector
    {
        get => _toolEmbeddingTextSelector;
        set => _toolEmbeddingTextSelector = Throw.IfNull(value);
    }
 
    /// <summary>
    /// Gets or sets the selector used to generate a single text string from a collection of chat messages for
    /// embedding purposes.
    /// </summary>
    public Func<IEnumerable<ChatMessage>, ValueTask<string>> MessagesEmbeddingTextSelector
    {
        get => _messagesEmbeddingTextSelector;
        set => _messagesEmbeddingTextSelector = Throw.IfNull(value);
    }
 
    /// <summary>
    /// Gets or sets a similarity function applied to (query, tool) embedding vectors.
    /// </summary>
    /// <remarks>
    /// Defaults to cosine similarity.
    /// </remarks>
    public Func<ReadOnlyMemory<float>, ReadOnlyMemory<float>, float> Similarity
    {
        get => _similarity;
        set => _similarity = Throw.IfNull(value);
    }
 
    /// <summary>
    /// Gets or sets a function that determines whether a tool is required (always included).
    /// </summary>
    /// <remarks>
    /// If this returns <see langword="true"/>, the tool is included regardless of ranking and does not count against
    /// the configured non-required tool limit. A tool explicitly named by <see cref="RequiredChatToolMode"/> (when
    /// <see cref="RequiredChatToolMode.RequiredFunctionName"/> is non-null) is also treated as required, independent
    /// of this delegate's result.
    /// </remarks>
    public Func<AITool, bool> IsRequiredTool
    {
        get => _isRequiredTool;
        set => _isRequiredTool = Throw.IfNull(value);
    }
 
    /// <summary>
    /// Gets or sets a value indicating whether to preserve original ordering of selected tools.
    /// If <see langword="false"/> (default), tools are ordered by descending similarity.
    /// If <see langword="true"/>, the top-N tools by similarity are re-emitted in their original order.
    /// </summary>
    public bool PreserveOriginalOrdering { get; set; }
 
    /// <inheritdoc />
    public async Task<IEnumerable<AITool>> SelectToolsForRequestAsync(
        IEnumerable<ChatMessage> messages,
        ChatOptions? options,
        CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(messages);
 
        if (options?.Tools is not { Count: > 0 } tools)
        {
            // Prefer the original tools list reference if possible.
            // This allows ToolReducingChatClient to avoid unnecessarily copying ChatOptions.
            // When no reduction is performed.
            return options?.Tools ?? [];
        }
 
        Debug.Assert(_toolLimit > 0, "Expected the tool count limit to be greater than zero.");
 
        if (tools.Count <= _toolLimit)
        {
            // Since the total number of tools doesn't exceed the configured tool limit,
            // there's no need to determine which tools are optional, i.e., subject to reduction.
            // We can return the original tools list early.
            return tools;
        }
 
        var toolRankingInfoArray = ArrayPool<AIToolRankingInfo>.Shared.Rent(tools.Count);
        try
        {
            var toolRankingInfoMemory = toolRankingInfoArray.AsMemory(start: 0, length: tools.Count);
 
            // We allocate tool rankings in a contiguous chunk of memory, but partition them such that
            // required tools come first and are immediately followed by optional tools.
            // This allows us to separately rank optional tools by similarity score, but then later re-order
            // the top N tools (including required tools) to preserve their original relative order.
            var (requiredTools, optionalTools) = PartitionToolRankings(toolRankingInfoMemory, tools, options.ToolMode);
 
            if (optionalTools.Length <= _toolLimit)
            {
                // There aren't enough optional tools to require reduction, so we'll return the original
                // tools list.
                return tools;
            }
 
            // Build query text from recent messages.
            var queryText = await MessagesEmbeddingTextSelector(messages).ConfigureAwait(false);
            if (string.IsNullOrWhiteSpace(queryText))
            {
                // We couldn't build a meaningful query, likely because the message list was empty.
                // We'll just return the original tools list.
                return tools;
            }
 
            var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false);
 
            // Compute and populate similarity scores in the tool ranking info.
            await ComputeSimilarityScoresAsync(optionalTools, queryEmbedding, cancellationToken);
 
            var topTools = toolRankingInfoMemory.Slice(start: 0, length: requiredTools.Length + _toolLimit);
#if NET
            optionalTools.Span.Sort(AIToolRankingInfo.CompareByDescendingSimilarityScore);
            if (PreserveOriginalOrdering)
            {
                topTools.Span.Sort(AIToolRankingInfo.CompareByOriginalIndex);
            }
#else
            Array.Sort(toolRankingInfoArray, index: requiredTools.Length, length: optionalTools.Length, AIToolRankingInfo.CompareByDescendingSimilarityScore);
            if (PreserveOriginalOrdering)
            {
                Array.Sort(toolRankingInfoArray, index: 0, length: topTools.Length, AIToolRankingInfo.CompareByOriginalIndex);
            }
#endif
            return ToToolList(topTools.Span);
 
            static List<AITool> ToToolList(ReadOnlySpan<AIToolRankingInfo> toolInfo)
            {
                var result = new List<AITool>(capacity: toolInfo.Length);
                foreach (var info in toolInfo)
                {
                    result.Add(info.Tool);
                }
 
                return result;
            }
        }
        finally
        {
            ArrayPool<AIToolRankingInfo>.Shared.Return(toolRankingInfoArray);
        }
    }
 
    private (Memory<AIToolRankingInfo> RequiredTools, Memory<AIToolRankingInfo> OptionalTools) PartitionToolRankings(
        Memory<AIToolRankingInfo> toolRankingInfo, IList<AITool> tools, ChatToolMode? toolMode)
    {
        // Always include a tool if its name matches the required function name.
        var requiredFunctionName = (toolMode as RequiredChatToolMode)?.RequiredFunctionName;
        var nextRequiredToolIndex = 0;
        var nextOptionalToolIndex = tools.Count - 1;
        for (var i = 0; i < toolRankingInfo.Length; i++)
        {
            var tool = tools[i];
            var isRequiredByToolMode = requiredFunctionName is not null && string.Equals(requiredFunctionName, tool.Name, StringComparison.Ordinal);
            var toolIndex = isRequiredByToolMode || IsRequiredTool(tool)
                ? nextRequiredToolIndex++
                : nextOptionalToolIndex--;
            toolRankingInfo.Span[toolIndex] = new AIToolRankingInfo(tool, originalIndex: i);
        }
 
        return (
            RequiredTools: toolRankingInfo.Slice(0, nextRequiredToolIndex),
            OptionalTools: toolRankingInfo.Slice(nextRequiredToolIndex));
    }
 
    private async Task ComputeSimilarityScoresAsync(Memory<AIToolRankingInfo> toolInfo, Embedding<float> queryEmbedding, CancellationToken cancellationToken)
    {
        var anyCacheMisses = false;
        List<string> cacheMissToolEmbeddingTexts = null!;
        List<int> cacheMissToolInfoIndexes = null!;
        for (var i = 0; i < toolInfo.Length; i++)
        {
            ref var info = ref toolInfo.Span[i];
            if (_toolEmbeddingsCache.TryGetValue(info.Tool, out var toolEmbedding))
            {
                info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector);
            }
            else
            {
                if (!anyCacheMisses)
                {
                    anyCacheMisses = true;
                    cacheMissToolEmbeddingTexts = [];
                    cacheMissToolInfoIndexes = [];
                }
 
                var text = ToolEmbeddingTextSelector(info.Tool);
                cacheMissToolEmbeddingTexts.Add(text);
                cacheMissToolInfoIndexes.Add(i);
            }
        }
 
        if (!anyCacheMisses)
        {
            // There were no cache misses; no more work to do.
            return;
        }
 
        var uncachedEmbeddings = await _embeddingGenerator.GenerateAsync(cacheMissToolEmbeddingTexts, cancellationToken: cancellationToken).ConfigureAwait(false);
        if (uncachedEmbeddings.Count != cacheMissToolEmbeddingTexts.Count)
        {
            throw new InvalidOperationException($"Expected {cacheMissToolEmbeddingTexts.Count} embeddings, got {uncachedEmbeddings.Count}.");
        }
 
        for (var i = 0; i < uncachedEmbeddings.Count; i++)
        {
            var toolInfoIndex = cacheMissToolInfoIndexes[i];
            var toolEmbedding = uncachedEmbeddings[i];
            ref var info = ref toolInfo.Span[toolInfoIndex];
            info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector);
            _toolEmbeddingsCache.Add(info.Tool, toolEmbedding);
        }
    }
 
    private struct AIToolRankingInfo(AITool tool, int originalIndex)
    {
        public static readonly Comparer<AIToolRankingInfo> CompareByDescendingSimilarityScore
            = Comparer<AIToolRankingInfo>.Create(static (a, b) =>
            {
                var result = b.SimilarityScore.CompareTo(a.SimilarityScore);
                return result != 0
                    ? result
                    : a.OriginalIndex.CompareTo(b.OriginalIndex); // Stabilize ties.
            });
 
        public static readonly Comparer<AIToolRankingInfo> CompareByOriginalIndex
            = Comparer<AIToolRankingInfo>.Create(static (a, b) => a.OriginalIndex.CompareTo(b.OriginalIndex));
 
        public AITool Tool { get; } = tool;
        public int OriginalIndex { get; } = originalIndex;
        public float SimilarityScore { get; set; }
    }
}