File: Utils\Batching.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.DataIngestion\Microsoft.Extensions.DataIngestion.csproj (Microsoft.Extensions.DataIngestion)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
#if NET10_0_OR_GREATER
using System.Linq;
#endif
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;
 
namespace Microsoft.Extensions.DataIngestion;
 
/// <summary>
/// Shared batching helper for enrichers: groups chunks into batches, asks a chat client for one
/// metadata value per chunk in the batch, and attaches the values to the chunks.
/// </summary>
internal static class Batching
{
    /// <summary>
    /// Streams <paramref name="chunks"/> through a chat-client based enrichment step, one batch at a time.
    /// </summary>
    /// <remarks>
    /// For every batch of up to <c>options.BatchSize</c> chunks, a single request is sent that contains
    /// <paramref name="systemPrompt"/> followed by one user message holding a <see cref="TextContent"/> per chunk.
    /// The response is requested as a <typeparamref name="TMetadata"/> array. When the array has exactly one
    /// element per chunk, element <c>i</c> is stored in the metadata of chunk <c>i</c> under
    /// <paramref name="metadataKey"/>; otherwise the mismatch is logged and the batch is left unmodified.
    /// Enrichment is best-effort: request failures are logged and the chunks are still yielded, in their
    /// original order. The only exception is cancellation requested through <paramref name="cancellationToken"/>,
    /// which is propagated to the caller.
    /// </remarks>
    /// <typeparam name="TMetadata">The type of the metadata value produced for each chunk.</typeparam>
    /// <param name="chunks">The chunks to enrich.</param>
    /// <param name="options">Supplies the chat client, the chat options and the batch size.</param>
    /// <param name="metadataKey">The key under which the produced value is stored in each chunk's metadata.</param>
    /// <param name="systemPrompt">The first message of every request.</param>
    /// <param name="logger">Optional logger for count mismatches and request failures.</param>
    /// <param name="cancellationToken">Token used to cancel the enumeration and the chat requests.</param>
    /// <returns>The input chunks, in order, with metadata attached where enrichment succeeded.</returns>
    internal static async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync<TMetadata>(IAsyncEnumerable<IngestionChunk<string>> chunks,
        EnricherOptions options,
        string metadataKey,
        ChatMessage systemPrompt,
        ILogger? logger,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
        where TMetadata : notnull
    {
        _ = Throw.IfNull(chunks);
        _ = Throw.IfNull(options);

        await foreach (var batch in chunks.Chunk(options.BatchSize).WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            List<AIContent> contents = new(batch.Length);
            foreach (var chunk in batch)
            {
                contents.Add(new TextContent(chunk.Content));
            }

            try
            {
                ChatResponse<TMetadata[]> response = await options.ChatClient.GetResponseAsync<TMetadata[]>(
                [
                    systemPrompt,
                    new(ChatRole.User, contents)
                ], options.ChatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);

                TMetadata[] results = response.Result;

                // Results are matched to chunks purely by position, so they are only usable
                // when there is exactly one result per chunk that was sent.
                if (results.Length == contents.Count)
                {
                    for (int i = 0; i < results.Length; i++)
                    {
                        batch[i].Metadata[metadataKey] = results[i];
                    }
                }
                else
                {
                    logger?.UnexpectedResultsCount(results.Length, contents.Count);
                }
            }
#pragma warning disable CA1031 // Do not catch general exception types
            catch (Exception ex) when (ex is not OperationCanceledException || !cancellationToken.IsCancellationRequested)
#pragma warning restore CA1031 // Do not catch general exception types
            {
                // Enricher failures should not fail the whole ingestion pipeline, as they are best-effort enhancements.
                // Cancellation requested by the caller is excluded by the filter above so that it stops the
                // pipeline instead of being reported as an enricher failure.
                logger?.UnexpectedEnricherFailure(ex);
            }

            foreach (var chunk in batch)
            {
                yield return chunk;
            }
        }
    }

#if !NET10_0_OR_GREATER
    /// <summary>
    /// Splits <paramref name="source"/> into arrays of at most <paramref name="count"/> elements.
    /// Compiled only when <c>NET10_0_OR_GREATER</c> is not defined.
    /// </summary>
    /// <remarks>
    /// Arguments are validated eagerly, before enumeration starts. Every chunk except possibly the last
    /// contains exactly <paramref name="count"/> elements, and every yielded array is a distinct instance
    /// that is not written to again after it has been yielded.
    /// </remarks>
    /// <typeparam name="TSource">The element type of the sequence.</typeparam>
    /// <param name="source">The sequence to split.</param>
    /// <param name="count">The maximum chunk size; values less than or equal to zero are rejected.</param>
    /// <returns>The chunks, in source order.</returns>
#pragma warning disable VSTHRD200 // Use "Async" suffix for async methods
    private static IAsyncEnumerable<TSource[]> Chunk<TSource>(this IAsyncEnumerable<TSource> source, int count)
#pragma warning restore VSTHRD200 // Use "Async" suffix for async methods
    {
        _ = Throw.IfNull(source);
        _ = Throw.IfLessThanOrEqual(count, 0);

        return CoreAsync(source, count);

        static async IAsyncEnumerable<TSource[]> CoreAsync(IAsyncEnumerable<TSource> source, int count,
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var buffer = new TSource[count];
            int index = 0;

            await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                buffer[index++] = item;

                if (index == count)
                {
                    yield return buffer;

                    // The yielded array now belongs to the consumer; continue filling a new one
                    // so that a chunk the consumer retained is never overwritten.
                    buffer = new TSource[count];
                    index = 0;
                }
            }

            if (index > 0)
            {
                // Trailing partial chunk: trim the buffer to the number of elements actually read.
                Array.Resize(ref buffer, index);
                yield return buffer;
            }
        }
    }
#endif
}