// File: src\Compilers\Core\Portable\InternalUtilities\ConcurrentLruCache.cs
// Web Access
// Project: src\src\VisualStudio\Core\Def\Microsoft.VisualStudio.LanguageServices_pxr0p0dn_wpftmp.csproj (Microsoft.VisualStudio.LanguageServices)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
 
namespace Microsoft.CodeAnalysis.InternalUtilities
{
    /// <summary>
    /// Cache with a fixed size that evicts the least recently used members.
    /// Thread-safe: every public member except <c>UnsafeTryGetValue</c>
    /// runs under a single internal lock.
    /// </summary>
    /// <typeparam name="K">Type of the cache keys.</typeparam>
    /// <typeparam name="V">Type of the cached values.</typeparam>
    internal class ConcurrentLruCache<K, V>
        where K : notnull
        where V : notnull
    {
        // Maximum number of entries held before the least recently used one is evicted.
        private readonly int _capacity;

        /// <summary>
        /// Dictionary payload: the cached value plus the entry's node in the
        /// recency list, so an entry can be re-ordered without searching the list.
        /// </summary>
        private struct CacheValue
        {
            public V Value;
            public LinkedListNode<K> Node;
        }

        // Key -> (value, recency node) lookup.
        private readonly Dictionary<K, CacheValue> _cache;
        // Keys ordered from most recently used (First) to least recently used (Last).
        private readonly LinkedList<K> _nodeList;
        // This is a naive coarse-grained lock, it can probably be optimized
        private readonly object _lockObject = new();

        /// <summary>
        /// Creates an empty cache that holds at most <paramref name="capacity"/> entries.
        /// </summary>
        /// <param name="capacity">Maximum number of entries; must be positive.</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="capacity"/> is zero or negative.
        /// </exception>
        public ConcurrentLruCache(int capacity)
        {
            if (capacity <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(capacity));
            }

            _capacity = capacity;
            _cache = new Dictionary<K, CacheValue>(capacity);
            _nodeList = new LinkedList<K>();
        }

        /// <summary>
        /// Create cache from an array. The cache capacity will be the size
        /// of the array. All elements of the array will be added to the
        /// cache in array order, so the last element ends up as the most
        /// recently used entry. If any duplicate keys are found in the array a
        /// <see cref="ArgumentException"/> will be thrown.
        /// </summary>
        /// <param name="array">Initial contents of the cache.</param>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="array"/> is empty.</exception>
        /// <exception cref="ArgumentException"><paramref name="array"/> contains a duplicate key.</exception>
        public ConcurrentLruCache(KeyValuePair<K, V>[] array)
            : this((array ?? throw new ArgumentNullException(nameof(array))).Length)
        {
            // No lock needed: the instance is not visible to other threads yet.
            foreach (var kvp in array)
            {
                this.UnsafeAdd(kvp.Key, kvp.Value, true);
            }
        }

        /// <summary>
        /// For testing. Very expensive. Returns a snapshot copy of the cache
        /// contents ordered from most recently used to least recently used.
        /// </summary>
        internal IEnumerable<KeyValuePair<K, V>> TestingEnumerable
        {
            get
            {
                lock (_lockObject)
                {
                    var copy = new KeyValuePair<K, V>[_cache.Count];
                    int index = 0;
                    foreach (K key in _nodeList)
                    {
                        copy[index++] = new KeyValuePair<K, V>(key,
                                                               _cache[key].Value);
                    }
                    return copy;
                }
            }
        }

        /// <summary>
        /// Adds a new entry as the most recently used one, evicting the least
        /// recently used entry first if the cache is full.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="key"/> is already cached.</exception>
        public void Add(K key, V value)
        {
            lock (_lockObject)
            {
                UnsafeAdd(key, value, true);
            }
        }

        /// <summary>
        /// Marks <paramref name="node"/> as most recently used. Does not lock.
        /// </summary>
        private void MoveNodeToTop(LinkedListNode<K> node)
        {
            if (!ReferenceEquals(_nodeList.First, node))
            {
                _nodeList.Remove(node);
                _nodeList.AddFirst(node);
            }
        }

        /// <summary>
        /// Removes the least recently used entry from both the recency list and
        /// the dictionary. Expects non-empty cache. Does not lock.
        /// </summary>
        private void UnsafeEvictLastNode()
        {
            Debug.Assert(_capacity > 0);
            var lastNode = _nodeList.Last;
            _nodeList.Remove(lastNode!);
            _cache.Remove(lastNode!.Value);
        }

        /// <summary>
        /// Inserts a key that is not yet cached as the most recently used entry.
        /// Does not lock and does not evict.
        /// </summary>
        private void UnsafeAddNodeToTop(K key, V value)
        {
            var node = new LinkedListNode<K>(key);
            _cache.Add(key, new CacheValue { Node = node, Value = value });
            _nodeList.AddFirst(node);
        }

        /// <summary>
        /// Adds or overwrites an entry and makes it the most recently used one.
        /// Doesn't lock.
        /// </summary>
        /// <param name="key">Key to add or overwrite.</param>
        /// <param name="value">Value to store.</param>
        /// <param name="throwExceptionIfKeyExists">
        /// When true an existing key is an error; when false its value is overwritten.
        /// </param>
        /// <exception cref="ArgumentException">
        /// The key is already cached and <paramref name="throwExceptionIfKeyExists"/> is true.
        /// </exception>
        private void UnsafeAdd(K key, V value, bool throwExceptionIfKeyExists)
        {
            if (_cache.TryGetValue(key, out var existing))
            {
                if (throwExceptionIfKeyExists)
                {
                    throw new ArgumentException("Key already exists", nameof(key));
                }

                // CacheValue is a struct, so the modified copy has to be written
                // back; skip that write when the value did not actually change.
                if (!existing.Value.Equals(value))
                {
                    existing.Value = value;
                    _cache[key] = existing;
                }

                // A write is a use of the entry even when the value is unchanged,
                // so always refresh recency; otherwise an entry that was just
                // written could be the next one evicted.
                MoveNodeToTop(existing.Node);
                return;
            }

            if (_cache.Count == _capacity)
            {
                UnsafeEvictLastNode();
            }

            UnsafeAddNodeToTop(key, value);
        }

        /// <summary>
        /// Gets or sets the value for <paramref name="key"/>. Both reading and
        /// writing make the entry the most recently used one; writing a key that
        /// is not cached adds it, evicting the least recently used entry if full.
        /// </summary>
        /// <exception cref="KeyNotFoundException">
        /// The getter was called for a key that is not cached.
        /// </exception>
        public V this[K key]
        {
            get
            {
                if (TryGetValue(key, out var value))
                {
                    return value;
                }
                else
                {
                    throw new KeyNotFoundException();
                }
            }
            set
            {
                lock (_lockObject)
                {
                    UnsafeAdd(key, value, false);
                }
            }
        }

        /// <summary>
        /// Looks up <paramref name="key"/> under the lock and, when found, marks
        /// the entry as most recently used.
        /// </summary>
        /// <returns>True if the key was cached; otherwise false.</returns>
        public bool TryGetValue(K key, [MaybeNullWhen(returnValue: false)] out V value)
        {
            lock (_lockObject)
            {
                return UnsafeTryGetValue(key, out value);
            }
        }

        /// <summary>
        /// Same as <see cref="TryGetValue"/> but doesn't lock. A hit re-orders
        /// the recency list, so this mutates shared state without synchronization.
        /// </summary>
        /// <returns>True if the key was cached; otherwise false.</returns>
        public bool UnsafeTryGetValue(K key, [MaybeNullWhen(returnValue: false)] out V value)
        {
            if (_cache.TryGetValue(key, out var result))
            {
                MoveNodeToTop(result.Node);
                value = result.Value;
                return true;
            }
            else
            {
                value = default!;
                return false;
            }
        }

        /// <summary>
        /// Returns the cached value for <paramref name="key"/>, or caches and
        /// returns <paramref name="value"/> when the key is not present.
        /// </summary>
        public V GetOrAdd(K key, V value)
        {
            lock (_lockObject)
            {
                if (UnsafeTryGetValue(key, out var result))
                {
                    return result;
                }
                else
                {
                    UnsafeAdd(key, value, true);
                    return value;
                }
            }
        }

        /// <summary>
        /// Returns the cached value for <paramref name="key"/>, or caches and
        /// returns the result of <paramref name="creator"/> when the key is not
        /// present. <paramref name="creator"/> is only invoked on a miss and runs
        /// while the cache lock is held.
        /// </summary>
        public V GetOrAdd(K key, Func<V> creator)
        {
            lock (_lockObject)
            {
                if (UnsafeTryGetValue(key, out var result))
                {
                    return result;
                }
                else
                {
                    var value = creator();
                    UnsafeAdd(key, value, true);
                    return value;
                }
            }
        }

        /// <summary>
        /// Returns the cached value for <paramref name="key"/>, or caches and
        /// returns <c>creator(arg)</c> when the key is not present. Passing
        /// <paramref name="arg"/> explicitly lets callers supply a delegate that
        /// does not capture state. <paramref name="creator"/> is only invoked on
        /// a miss and runs while the cache lock is held.
        /// </summary>
        public V GetOrAdd<T>(K key, T arg, Func<T, V> creator)
        {
            lock (_lockObject)
            {
                if (UnsafeTryGetValue(key, out var result))
                {
                    return result;
                }
                else
                {
                    var value = creator(arg);
                    UnsafeAdd(key, value, true);
                    return value;
                }
            }
        }
    }
}