File: Utility\Cache.cs
Web Access
Project: src\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj (Microsoft.ML.GenAI.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
using static TorchSharp.torch;
 
namespace Microsoft.ML.GenAI.Core;
 
/// <summary>
/// A cache of key/value tensor pairs indexed by layer number.
/// </summary>
public interface IKVCache : IDictionary<int, (Tensor, Tensor)>, IDisposable
{
    /// <summary>
    /// Stores <paramref name="key"/> and <paramref name="value"/> for layer <paramref name="layerIndex"/>
    /// and returns the key/value pair the cache holds for that layer afterwards.
    /// </summary>
    /// <param name="key">New key tensor for the layer.</param>
    /// <param name="value">New value tensor for the layer.</param>
    /// <param name="layerIndex">Index of the layer being updated.</param>
    /// <returns>The cached (key, value) pair for <paramref name="layerIndex"/> after the update.</returns>
    (Tensor, Tensor) UpdateKVCache(Tensor key, Tensor value, int layerIndex);

    /// <summary>
    /// Returns the length of the sequence currently cached for <paramref name="layerIndex"/>.
    /// </summary>
    /// <param name="layerIndex">Layer to query; defaults to the first layer.</param>
    int GetSeqLen(int layerIndex = 0);

    /// <summary>
    /// Returns the maximum sequence length the cache can hold, or <see langword="null"/> when it is unbounded.
    /// </summary>
    int? GetMaxLength();

    /// <summary>
    /// Given the length of an incoming sequence, returns how much of the existing cache for
    /// <paramref name="layerIndex"/> can be used alongside it.
    /// </summary>
    /// <param name="newSeqLen">Length of the sequence about to be added.</param>
    /// <param name="layerIndex">Layer to query; defaults to the first layer.</param>
    int GetUsableLength(int newSeqLen, int layerIndex = 0);
}
 
/// <summary>
/// A key/value cache with no length limit. Each call to <see cref="UpdateKVCache"/> concatenates the
/// new key/value tensors onto the ones already stored for that layer along dimension -2.
/// </summary>
/// <remarks>
/// Cached tensors are moved into a dispose scope owned by this instance and are released together
/// when <see cref="Dispose"/> is called.
/// </remarks>
public class DynamicKVCache : Dictionary<int, (Tensor, Tensor)>, IKVCache
{
    // Owns every tensor stored in this cache.
    // NOTE(review): NewDisposeScope() also makes this scope the current one on the creating thread;
    // confirm that is intended when it is created from a field initializer.
    private readonly DisposeScope _disposeScope = NewDisposeScope();

    public DynamicKVCache()
    {
    }

    /// <summary>
    /// Appends <paramref name="key"/> and <paramref name="value"/> to the cache for
    /// <paramref name="layerIndex"/> (concatenating along dimension -2), or stores them as-is if the layer
    /// has no entry yet.
    /// </summary>
    /// <param name="key">New key tensor; ownership moves into this cache's dispose scope.</param>
    /// <param name="value">New value tensor; ownership moves into this cache's dispose scope.</param>
    /// <param name="layerIndex">Index of the layer being updated.</param>
    /// <returns>The (key, value) pair now cached for <paramref name="layerIndex"/>.</returns>
    public (Tensor, Tensor) UpdateKVCache(Tensor key, Tensor value, int layerIndex)
    {
        if (this.TryGetValue(layerIndex, out var cached))
        {
            var (oldKey, oldValue) = cached;

            // Take the previous tensors out of _disposeScope before disposing them explicitly below.
            oldKey.DetachFromDisposeScope();
            oldValue.DetachFromDisposeScope();

            // Dimension -2 is presumably the sequence axis — confirm against callers.
            var newKey = torch.cat([oldKey, key], -2).MoveToOtherDisposeScope(this._disposeScope);
            var newValue = torch.cat([oldValue, value], -2).MoveToOtherDisposeScope(this._disposeScope);

            oldKey.Dispose();
            oldValue.Dispose();

            var updated = (newKey, newValue);
            this[layerIndex] = updated;
            return updated;
        }

        var entry = (key.MoveToOtherDisposeScope(this._disposeScope), value.MoveToOtherDisposeScope(this._disposeScope));
        this.Add(layerIndex, entry);
        return entry;
    }

    /// <summary>
    /// Returns the size of dimension -2 of the cached key tensor for <paramref name="layerIndex"/>,
    /// or 0 if that layer has nothing cached.
    /// </summary>
    public int GetSeqLen(int layerIndex = 0)
    {
        if (this.TryGetValue(layerIndex, out var kv))
        {
            return kv.Item1.IntShape()[^2];
        }

        return 0;
    }

    /// <summary>
    /// This cache is unbounded, so it always returns <see langword="null"/>.
    /// </summary>
    public int? GetMaxLength()
    {
        return null;
    }

    /// <summary>
    /// Returns how many of the cached positions for <paramref name="layerIndex"/> can be used together
    /// with <paramref name="newSeqLen"/> new positions.
    /// </summary>
    /// <remarks>
    /// If a maximum length exists and cached plus new positions would exceed it, only
    /// <c>maxLength - newSeqLen</c> cached positions fit. Otherwise the whole cache is usable.
    /// </remarks>
    public int GetUsableLength(int newSeqLen, int layerIndex = 0)
    {
        var maxLength = this.GetMaxLength();
        var previousSeqLen = this.GetSeqLen(layerIndex);

        if (maxLength.HasValue && previousSeqLen + newSeqLen > maxLength.Value)
        {
            // Leave room for the incoming tokens: usable cache + new tokens must not exceed the max.
            return maxLength.Value - newSeqLen;
        }

        return previousSeqLen;
    }

    /// <summary>
    /// Disposes every tensor owned by this cache and removes all entries.
    /// </summary>
    public void Dispose()
    {
        this._disposeScope.Dispose();

        // The stored tensors were owned by _disposeScope and have just been disposed; drop the stale references
        // so later lookups cannot hand out disposed tensors.
        this.Clear();
    }
}