|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Numerics.Tensors;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Extensions.DataIngestion.Chunkers;
/// <summary>
/// Splits an <see cref="IngestionDocument"/> into chunks based on the semantic similarity between its elements,
/// measured as the cosine distance of their embeddings.
/// </summary>
/// <remarks>
/// An embedding is generated for every content element that has non-empty text. The cosine distance between
/// each pair of consecutive elements is computed, and a chunk boundary is placed after every element whose
/// distance to its successor exceeds the configured percentile of all distances.
/// </remarks>
public sealed class SemanticSimilarityChunker : IngestionChunker<string>
{
    // Percentile used when the caller does not specify one.
    private const float DefaultThresholdPercentile = 95.0f;

    private readonly ElementsChunker _elementsChunker;
    private readonly IEmbeddingGenerator<string, Embedding<float>> _embeddingGenerator;
    private readonly float _thresholdPercentile;

    /// <summary>
    /// Initializes a new instance of the <see cref="SemanticSimilarityChunker"/> class.
    /// </summary>
    /// <param name="embeddingGenerator">Embedding generator.</param>
    /// <param name="options">The options for the chunker.</param>
    /// <param name="thresholdPercentile">Threshold percentile to consider the chunks to be sufficiently similar. 95th percentile will be used if not specified.</param>
    /// <exception cref="ArgumentNullException"><paramref name="embeddingGenerator"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="thresholdPercentile"/> is not a number, or is outside the range 0 to 100.</exception>
    public SemanticSimilarityChunker(
        IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator,
        IngestionChunkerOptions options,
        float? thresholdPercentile = null)
    {
        // Fail fast here instead of with a NullReferenceException on the first ProcessAsync call.
        _embeddingGenerator = Throw.IfNull(embeddingGenerator);
        _elementsChunker = new(options);

        // NaN compares false against both bounds, so it has to be rejected explicitly.
        if (thresholdPercentile is float percentile
            && (float.IsNaN(percentile) || percentile < 0f || percentile > 100f))
        {
            Throw.ArgumentOutOfRangeException(nameof(thresholdPercentile), "Threshold percentile must be between 0 and 100.");
        }

        _thresholdPercentile = thresholdPercentile ?? DefaultThresholdPercentile;
    }

    /// <inheritdoc/>
    public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IngestionDocument document,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(document);

        // All embeddings are generated up front; chunks are then produced lazily from the distances.
        List<(IngestionDocumentElement, float)> distances = await CalculateDistancesAsync(document, cancellationToken).ConfigureAwait(false);

        foreach (var chunk in MakeChunks(document, distances))
        {
            yield return chunk;
        }
    }

    /// <summary>
    /// Pairs every content element that has non-empty text with the cosine distance to the next such element.
    /// </summary>
    /// <param name="document">The document whose content elements are enumerated.</param>
    /// <param name="cancellationToken">Token forwarded to the embedding generator.</param>
    /// <returns>
    /// The elements in enumeration order. The distance of the last element is left at its default of 0
    /// because it has no successor.
    /// </returns>
    private async Task<List<(IngestionDocumentElement element, float distance)>> CalculateDistancesAsync(IngestionDocument document, CancellationToken cancellationToken)
    {
        List<(IngestionDocumentElement element, float distance)> elementDistances = [];
        List<string> semanticContents = [];

        foreach (IngestionDocumentElement element in document.EnumerateContent())
        {
            // Images are represented by their alternative text, falling back to their text.
            string? semanticContent = element is IngestionDocumentImage img
                ? img.AlternativeText ?? img.Text
                : element.GetMarkdown();

            // Elements without any text are skipped: they get no embedding and do not appear in the result.
            if (!string.IsNullOrEmpty(semanticContent))
            {
                elementDistances.Add((element, default));
                semanticContents.Add(semanticContent!);
            }
        }

        if (elementDistances.Count > 0)
        {
            var embeddings = await _embeddingGenerator.GenerateAsync(semanticContents, cancellationToken: cancellationToken).ConfigureAwait(false);

            // The two lists are matched by index, so a count mismatch would misalign elements and embeddings.
            if (embeddings.Count != elementDistances.Count)
            {
                Throw.InvalidOperationException("The number of embeddings returned does not match the number of document elements.");
            }

            // Stop one short of the end: the last element has no successor to compare against.
            for (int i = 0; i < elementDistances.Count - 1; i++)
            {
                float distance = 1 - TensorPrimitives.CosineSimilarity(embeddings[i].Vector.Span, embeddings[i + 1].Vector.Span);
                elementDistances[i] = (elementDistances[i].element, distance);
            }
        }

        return elementDistances;
    }

    /// <summary>
    /// Groups consecutive elements and hands each group to the elements chunker, starting a new group
    /// after every element whose distance to its successor exceeds the percentile threshold.
    /// </summary>
    /// <param name="document">The document the elements belong to.</param>
    /// <param name="elementDistances">Elements paired with the distance to their successor.</param>
    /// <returns>The chunks produced by the elements chunker for each group, in order.</returns>
    private IEnumerable<IngestionChunk<string>> MakeChunks(IngestionDocument document, List<(IngestionDocumentElement element, float distance)> elementDistances)
    {
        float distanceThreshold = Percentile(elementDistances);

        List<IngestionDocumentElement> elementAccumulator = [];
        string context = string.Empty;

        for (int i = 0; i < elementDistances.Count; i++)
        {
            var (element, distance) = elementDistances[i];

            elementAccumulator.Add(element);

            // Flush on a semantic break, and always on the last element so no trailing group is lost.
            if (distance > distanceThreshold || i == elementDistances.Count - 1)
            {
                foreach (var chunk in _elementsChunker.Process(document, context, elementAccumulator))
                {
                    yield return chunk;
                }

                elementAccumulator.Clear();
            }
        }
    }

    /// <summary>
    /// Computes the configured percentile of the distances, linearly interpolating between the two nearest ranks.
    /// </summary>
    /// <param name="elementDistances">Elements paired with the distance to their successor.</param>
    /// <returns>
    /// 0 for an empty list, the only distance for a single-element list, and the interpolated percentile otherwise.
    /// </returns>
    /// <remarks>
    /// Every entry takes part in the calculation, including the last element's placeholder distance of 0.
    /// </remarks>
    private float Percentile(List<(IngestionDocumentElement element, float distance)> elementDistances)
    {
        if (elementDistances.Count == 0)
        {
            return 0f;
        }
        else if (elementDistances.Count == 1)
        {
            return elementDistances[0].distance;
        }

        float[] sorted = new float[elementDistances.Count];
        for (int elementIndex = 0; elementIndex < elementDistances.Count; elementIndex++)
        {
            sorted[elementIndex] = elementDistances[elementIndex].distance;
        }

        Array.Sort(sorted);

        // Fractional rank into the sorted array; 0 maps to the smallest value and 100 to the largest.
        float rank = (_thresholdPercentile / 100f) * (sorted.Length - 1);
        int lower = (int)rank;

        // Clamp so that a rank landing exactly on the last index does not read past the end.
        int upper = Math.Min(lower + 1, sorted.Length - 1);

        return sorted[lower] + ((rank - lower) * (sorted[upper] - sorted[lower]));
    }
}
|