File: Utils\LruCache.cs
Web Access
Project: src\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj (Microsoft.ML.Tokenizers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
 
namespace Microsoft.ML.Tokenizers
{
    /// <summary>
    /// A bounded cache keyed by string that evicts the least recently used entry once it is full.
    /// </summary>
    /// <remarks>
    /// Entries are kept in a linked list ordered from most to least recently used, and a dictionary maps each
    /// key to its list node. Every public operation runs while holding a single private lock.
    /// </remarks>
    /// <typeparam name="TValue">The type of the cached values.</typeparam>
    internal sealed class LruCache<TValue>
    {
        /// <summary>
        /// The default LRU cache size.
        /// </summary>
        public const int DefaultCacheSize = 8192;

        // Maps each key to the list node that stores its key/value pair.
        private readonly Dictionary<StringSpanOrdinalKey, LinkedListNode<KeyValuePair<string, TValue>>> _cache = new();

        // Usage order: the first node is the most recently used entry, the last node is the next one to be evicted.
        private readonly LinkedList<KeyValuePair<string, TValue>> _lruList = new();

        private readonly int _cacheSize;

        // The dictionary instance is private to this class, so it doubles as the lock object.
        private object SyncObj => _cache;

        /// <summary>
        /// Constructs an <see cref="LruCache{TValue}" /> object.
        /// </summary>
        /// <param name="cacheSize">
        /// The maximum number of mappings that can be cached. This defaults to <see cref="DefaultCacheSize" />, which is set to <c>8192</c>.
        /// </param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="cacheSize" /> is zero or negative.</exception>
        public LruCache(int cacheSize = DefaultCacheSize)
        {
            if (cacheSize <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(cacheSize), "Cache size must be a positive number.");
            }

            _cacheSize = cacheSize;
        }

        /// <summary>
        /// Retrieves the value associated with the specified key and marks the entry as the most recently used.
        /// </summary>
        /// <param name="key">The string to be used as a key.</param>
        /// <param name="value">
        /// When this method returns true, the cached value for <paramref name="key" />; otherwise the default value of <typeparamref name="TValue" />.
        /// </param>
        /// <returns>
        /// true if the cache contains a mapping for <paramref name="key" />, false otherwise.
        /// </returns>
        public bool TryGetValue(string key, out TValue value)
        {
            lock (SyncObj)
            {
                return TryGetValueAndPromote(new StringSpanOrdinalKey(key), out value);
            }
        }

        /// <summary>
        /// Retrieves the value associated with the specified key, given as a span of characters, and marks the
        /// entry as the most recently used.
        /// </summary>
        /// <param name="key">The characters to be used as a key.</param>
        /// <param name="value">
        /// When this method returns true, the cached value for <paramref name="key" />; otherwise the default value of <typeparamref name="TValue" />.
        /// </param>
        /// <returns>
        /// true if the cache contains a mapping for <paramref name="key" />, false otherwise.
        /// </returns>
        public unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
        {
            lock (SyncObj)
            {
                // The lookup key wraps a raw pointer into the span, so the span must stay pinned
                // for the whole duration of the dictionary lookup.
                fixed (char* ptr = key)
                {
                    return TryGetValueAndPromote(new StringSpanOrdinalKey(ptr, key.Length), out value);
                }
            }
        }

        /// <summary>
        /// Adds or replaces a mapping in the cache.
        /// </summary>
        /// <remarks>
        /// The added or replaced entry becomes the most recently used one. When a new key is added to a full
        /// cache, the least recently used entries are evicted first to make room.
        /// </remarks>
        /// <param name="key">The key whose mapped <paramref name="value" /> is to be created or replaced.</param>
        /// <param name="value">The new value to be mapped to the <paramref name="key" />.</param>
        public void Add(string key, TValue value)
        {
            lock (SyncObj)
            {
                var entry = new KeyValuePair<string, TValue>(key, value);
                var lookupKey = new StringSpanOrdinalKey(key);

                if (_cache.TryGetValue(lookupKey, out LinkedListNode<KeyValuePair<string, TValue>>? existing))
                {
                    // Known key: overwrite the stored pair in place and refresh its position.
                    existing.Value = entry;
                    MoveToFront(existing);
                    return;
                }

                // New key: evict from the tail (least recently used) until there is room for one more entry.
                while (_cache.Count >= _cacheSize)
                {
                    // The dictionary and the list always hold the same entries, so a non-empty
                    // dictionary implies the list has a last node.
                    LinkedListNode<KeyValuePair<string, TValue>> oldest = _lruList.Last!;
                    _lruList.RemoveLast();
                    _cache.Remove(new StringSpanOrdinalKey(oldest.Value.Key));
                }

                var node = new LinkedListNode<KeyValuePair<string, TValue>>(entry);
                _cache[lookupKey] = node;
                _lruList.AddFirst(node);
            }
        }

        /// <summary>
        /// Looks up <paramref name="key" /> and, on a hit, marks the entry as the most recently used.
        /// The caller must hold the lock on <see cref="SyncObj" />.
        /// </summary>
        /// <param name="key">The lookup key.</param>
        /// <param name="value">The cached value on a hit; otherwise the default value of <typeparamref name="TValue" />.</param>
        /// <returns>true if the key was found, false otherwise.</returns>
        private bool TryGetValueAndPromote(StringSpanOrdinalKey key, out TValue value)
        {
            if (_cache.TryGetValue(key, out LinkedListNode<KeyValuePair<string, TValue>>? cached))
            {
                MoveToFront(cached);
                value = cached.Value.Value;
                return true;
            }

            value = default!;
            return false;
        }

        /// <summary>
        /// Moves <paramref name="node" /> to the head of the usage list, making it the most recently used entry.
        /// The caller must hold the lock on <see cref="SyncObj" />.
        /// </summary>
        /// <param name="node">A node that currently belongs to the usage list.</param>
        private void MoveToFront(LinkedListNode<KeyValuePair<string, TValue>> node)
        {
            // Unlinking and relinking a node that is already at the head would leave the list unchanged.
            if (!ReferenceEquals(_lruList.First, node))
            {
                _lruList.Remove(node);
                _lruList.AddFirst(node);
            }
        }
    }
}