| 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.VectorData;
using Microsoft.Shared.Diagnostics;
 
namespace Microsoft.Extensions.DataIngestion;
 
/// <summary>
/// Writes chunks to the <see cref="VectorStore"/> using the default schema. 
/// </summary>
/// <typeparam name="T">The type of the chunk content.</typeparam>
public sealed class VectorStoreWriter<T> : IngestionChunkWriter<T>
{
    // The names are lowercase with no special characters to ensure compatibility with various vector stores.
    private const string KeyName = "key";
    private const string EmbeddingName = "embedding";
    private const string ContentName = "content";
    private const string ContextName = "context";
    private const string DocumentIdName = "documentid";
 
    private readonly VectorStore _vectorStore;
    private readonly int _dimensionCount;
    private readonly VectorStoreWriterOptions _options;
    private readonly bool _keysAreStrings;
 
    private VectorStoreCollection<object, Dictionary<string, object?>>? _vectorStoreCollection;
 
    /// <summary>
    /// Initializes a new instance of the <see cref="VectorStoreWriter{T}"/> class.
    /// </summary>
    /// <param name="vectorStore">The <see cref="VectorStore"/> to use to store the <see cref="IngestionChunk{T}"/> instances.</param>
    /// <param name="dimensionCount">The number of dimensions that the vector has. This value is required when creating collections.</param>
    /// <param name="options">The options for the vector store writer.</param>
    /// <exception cref="ArgumentNullException">When <paramref name="vectorStore"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">When <paramref name="dimensionCount"/> is less or equal zero.</exception>
    public VectorStoreWriter(VectorStore vectorStore, int dimensionCount, VectorStoreWriterOptions? options = default)
    {
        _vectorStore = Throw.IfNull(vectorStore);
        _dimensionCount = Throw.IfLessThanOrEqual(dimensionCount, 0);
        _options = options ?? new VectorStoreWriterOptions();
 
        // Not all vector store support string as the key type, examples:
        // Qdrant: https://github.com/microsoft/semantic-kernel/blob/28ea2f4df872e8fd03ef0792ebc9e1989b4be0ee/dotnet/src/VectorData/Qdrant/QdrantCollection.cs#L104
        // When https://github.com/microsoft/semantic-kernel/issues/13141 gets released,
        // we need to remove this workaround.
        _keysAreStrings = vectorStore.GetType().Name != "QdrantVectorStore";
    }
 
    /// <summary>
    /// Gets the underlying <see cref="VectorStoreCollection{TKey,TRecord}"/> used to store the chunks.
    /// </summary>
    /// <remarks>
    /// The collection is initialized when <see cref="WriteAsync(IAsyncEnumerable{IngestionChunk{T}}, CancellationToken)"/> is called for the first time.
    /// </remarks>
    /// <exception cref="InvalidOperationException">The collection has not been initialized yet.
    /// Call <see cref="WriteAsync(IAsyncEnumerable{IngestionChunk{T}}, CancellationToken)"/> first.</exception>
    public VectorStoreCollection<object, Dictionary<string, object?>> VectorStoreCollection
        => _vectorStoreCollection ?? throw new InvalidOperationException("The collection has not been initialized yet. Call WriteAsync first.");
 
    /// <inheritdoc/>
    public override async Task WriteAsync(IAsyncEnumerable<IngestionChunk<T>> chunks, CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(chunks);
 
        IReadOnlyList<object>? preExistingKeys = null;
        await foreach (IngestionChunk<T> chunk in chunks.WithCancellation(cancellationToken))
        {
            if (_vectorStoreCollection is null)
            {
                _vectorStoreCollection = _vectorStore.GetDynamicCollection(_options.CollectionName, GetVectorStoreRecordDefinition(chunk));
 
                await _vectorStoreCollection.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false);
            }
 
            // We obtain the IDs of the pre-existing chunks for given document,
            // and delete them after we finish inserting the new chunks,
            // to avoid a situation where we delete the chunks and then fail to insert the new ones.
            preExistingKeys ??= await GetPreExistingChunksIdsAsync(chunk.Document, cancellationToken).ConfigureAwait(false);
 
            var key = Guid.NewGuid();
            Dictionary<string, object?> record = new()
            {
                [KeyName] = _keysAreStrings ? key.ToString() : key,
                [ContentName] = chunk.Content,
                [EmbeddingName] = chunk.Content,
                [ContextName] = chunk.Context,
                [DocumentIdName] = chunk.Document.Identifier,
            };
 
            if (chunk.HasMetadata)
            {
                foreach (var metadata in chunk.Metadata)
                {
                    record[metadata.Key] = metadata.Value;
                }
            }
 
            await _vectorStoreCollection.UpsertAsync(record, cancellationToken).ConfigureAwait(false);
        }
 
        if (preExistingKeys?.Count > 0)
        {
            await _vectorStoreCollection!.DeleteAsync(preExistingKeys, cancellationToken).ConfigureAwait(false);
        }
    }
 
    /// <inheritdoc/>
    protected override void Dispose(bool disposing)
    {
        try
        {
            _vectorStoreCollection?.Dispose();
        }
        finally
        {
            _vectorStore.Dispose();
            base.Dispose(disposing);
        }
    }
 
    private VectorStoreCollectionDefinition GetVectorStoreRecordDefinition(IngestionChunk<T> representativeChunk)
    {
        VectorStoreCollectionDefinition definition = new()
        {
            Properties =
            {
                new VectorStoreKeyProperty(KeyName, _keysAreStrings ? typeof(string) : typeof(Guid)),
 
                // By using T as the type here we allow the vector store
                // to handle the conversion from T to the actual vector type it supports.
                new VectorStoreVectorProperty(EmbeddingName, typeof(T), _dimensionCount)
                {
                    DistanceFunction = _options.DistanceFunction,
                    IndexKind = _options.IndexKind
                },
                new VectorStoreDataProperty(ContentName, typeof(T)),
                new VectorStoreDataProperty(ContextName, typeof(string)),
                new VectorStoreDataProperty(DocumentIdName, typeof(string))
                {
                    IsIndexed = true
                }
            }
        };
 
        if (representativeChunk.HasMetadata)
        {
            foreach (var metadata in representativeChunk.Metadata)
            {
                Type propertyType = metadata.Value.GetType();
                definition.Properties.Add(new VectorStoreDataProperty(metadata.Key, propertyType)
                {
                    // We use lowercase storage names to ensure compatibility with various vector stores.
#pragma warning disable CA1308 // Normalize strings to uppercase
                    StorageName = metadata.Key.ToLowerInvariant()
#pragma warning restore CA1308 // Normalize strings to uppercase
 
                    // We could consider indexing for certain keys like classification etc. but for now we leave it as non-indexed.
                    // The reason is that not every DB supports it, moreover we would need to expose the ability to configure it.
                });
            }
        }
 
        return definition;
    }
 
    private async Task<IReadOnlyList<object>> GetPreExistingChunksIdsAsync(IngestionDocument document, CancellationToken cancellationToken)
    {
        if (!_options.IncrementalIngestion)
        {
            return [];
        }
 
        // Each Vector Store has a different max top count limit, so we use low value and loop.
        const int MaxTopCount = 1_000;
 
        List<object> keys = [];
        int insertedCount;
        do
        {
            insertedCount = 0;
 
            await foreach (var record in _vectorStoreCollection!.GetAsync(
                filter: record => (string)record[DocumentIdName]! == document.Identifier,
                top: MaxTopCount,
                cancellationToken: cancellationToken).ConfigureAwait(false))
            {
                keys.Add(record[KeyName]!);
                insertedCount++;
            }
        }
        while (insertedCount == MaxTopCount);
 
        return keys;
    }
}
 |