File: src\nuget-client\build\Shared\TaskResultCache.cs
Web Access
Project: src\src\nuget-client\src\NuGet.Core\NuGet.DependencyResolver.Core\NuGet.DependencyResolver.Core.csproj (NuGet.DependencyResolver.Core)
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;

namespace NuGet
{
    /// <summary>
    /// Provides a caching mechanism for async operations.
    /// </summary>
    /// <typeparam name="TKey">The key to use for storing the async operation.</typeparam>
    /// <typeparam name="TValue">The return type of the async operation.</typeparam>
    /// <remarks>
    /// Each key has its own lock object, so computing the value for one key does not block lookups or additions for other keys.
    /// When a key is missing, the value factory is invoked while that key's lock is held, so concurrent non-refreshing callers
    /// for the same key receive the same cached <see cref="Task{TResult}" />. A cached operation that ended up canceled is not
    /// handed out again by <c>GetOrAddAsync</c>; the next call for that key invokes the value factory again.
    /// </remarks>
    internal sealed class TaskResultCache<TKey, TValue>
        where TKey : notnull
    {
        /// <summary>
        /// Represents the cache of async operations.
        /// </summary>
        private readonly ConcurrentDictionary<TKey, Task<TValue>> _cache;

        /// <summary>
        /// Represents a dictionary of locks to synchronize access to individual async operations in the cache.
        /// A lock object is created the first time a key misses the cache and is never removed.
        /// </summary>
        private readonly ConcurrentDictionary<TKey, object> _perTaskLock;

        /// <summary>
        /// Initializes a new instance of the <see cref="TaskResultCache{TKey, TValue}" /> class with the specified key comparer.
        /// </summary>
        /// <param name="comparer">An <see cref="IEqualityComparer{T}" /> to use when comparing keys.</param>
        public TaskResultCache(IEqualityComparer<TKey> comparer)
        {
            _cache = new(comparer);
            _perTaskLock = new(comparer);
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="TaskResultCache{TKey, TValue}" /> class with the specified initial capacity.
        /// </summary>
        /// <param name="capacity">The default capacity for the cache.</param>
        public TaskResultCache(int capacity)
        {
            _cache = new(concurrencyLevel: Environment.ProcessorCount, capacity);
            _perTaskLock = new(concurrencyLevel: Environment.ProcessorCount, capacity);
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="TaskResultCache{TKey, TValue}" /> class.
        /// </summary>
        public TaskResultCache()
        {
            _cache = new();
            _perTaskLock = new();
        }

        /// <summary>
        /// Gets a collection containing the keys in the cache.
        /// </summary>
        public ICollection<TKey> Keys => _cache.Keys;

        /// <summary>
        /// Gets the cached async operation associated with the specified key, or runs the operation asynchronously and returns <see cref="Task{TValue}" /> that the caller can await.
        /// A cached operation that was canceled is not reused; the value factory is invoked again instead.
        /// </summary>
        /// <param name="key">The key for the async operation to get or store in the cache.</param>
        /// <param name="valueFactory">A <see cref="Func{T, TResult}" /> to execute asynchronously if a reusable cached operation does not exist.</param>
        /// <param name="state">A state object to pass to the value factory.</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken" /> attached to the operation created on a cache miss. It has no effect when a cached operation is returned.</param>
        /// <returns>A <see cref="Task{TResult}" /> for the specified asynchronous operation from the cache if found, otherwise the scheduled asynchronous operation to await.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="valueFactory" /> is <see langword="null" />.</exception>
        public Task<TValue> GetOrAddAsync<TState>(TKey key, Func<TState, Task<TValue>> valueFactory, TState state, CancellationToken cancellationToken)
        {
            return GetOrAddAsync(key, refresh: false, valueFactory, state, cancellationToken);
        }

        /// <summary>
        /// Gets the cached async operation associated with the specified key, or runs the operation asynchronously and returns <see cref="Task{TValue}" /> that the caller can await, and optionally refreshes the cache.
        /// A cached operation that was canceled is not reused; the value factory is invoked again instead.
        /// </summary>
        /// <param name="key">The key for the async operation to get or store in the cache.</param>
        /// <param name="refresh"><see langword="true" /> to force the specified asynchronous operation to be executed and stored in the cache even if a cached operation exists, otherwise <see langword="false" />.</param>
        /// <param name="valueFactory">A <see cref="Func{T, TResult}" /> to execute asynchronously if a reusable cached operation does not exist.</param>
        /// <param name="state">A state object to pass to the value factory.</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken" /> attached to the operation created on a cache miss or refresh. It has no effect when a cached operation is returned.</param>
        /// <returns>A <see cref="Task{TResult}" /> for the specified asynchronous operation from the cache if found, otherwise the scheduled asynchronous operation to await.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="valueFactory" /> is <see langword="null" />.</exception>
        public Task<TValue> GetOrAddAsync<TState>(TKey key, bool refresh, Func<TState, Task<TValue>> valueFactory, TState state, CancellationToken cancellationToken)
        {
            if (valueFactory is null)
            {
                throw new ArgumentNullException(nameof(valueFactory));
            }

            if (!refresh && TryGetReusableValue(key, out Task<TValue>? value))
            {
                return value;
            }

            // Get a lock object for this one single key which allows other asynchronous tasks to be added and retrieved at the same time
            // rather than locking the entire cache.
            // NOTE: Be very careful about which overload of GetOrAdd is called. There was previously a very subtle bug with this call:
            //
            // GetOrAdd(key, static () => new object());
            //
            // Which calls the `GetOrAdd(TKey key, TValue value)` overload rather than the `GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)`
            // overload. The consequence is that the same static delegate is cached and locked on rather than having one lock object per key.
            object lockObject = _perTaskLock.GetOrAdd(key, static (TKey _) => new object());

            lock (lockObject)
            {
                // Re-check under the lock: another caller may have stored a usable entry while this one was waiting.
                if (!refresh && TryGetReusableValue(key, out value))
                {
                    return value;
                }

                // Wrap the factory's task in a continuation so that code awaiting the shared, cached task resumes asynchronously
                // (RunContinuationsAsynchronously on TaskScheduler.Default) rather than inline on the thread that completes the
                // factory's task. GetAwaiter().GetResult() does not block here: the antecedent has already completed when the
                // continuation runs, and it rethrows the original exception instead of an AggregateException.
                // The caller's token only governs this continuation. If it cancels, the entry becomes Canceled, and
                // TryGetReusableValue makes the next GetOrAddAsync call replace it instead of returning it forever.
                return _cache[key] = valueFactory(state)
                    .ContinueWith(
                        static task => task.GetAwaiter().GetResult(),
                        cancellationToken,
                        TaskContinuationOptions.RunContinuationsAsynchronously,
                        TaskScheduler.Default);
            }
        }

        /// <summary>
        /// Gets the async operation associated with the specified key if one exists, otherwise throws a <see cref="KeyNotFoundException" />.
        /// </summary>
        /// <param name="key">The key for the async operation to get the value of.</param>
        /// <returns>The cached <see cref="Task{TResult}" /> for <paramref name="key" />, in whatever state it is currently in (running, completed, faulted, or canceled).</returns>
        /// <exception cref="KeyNotFoundException">The specified key does not exist in the cache.</exception>
        public Task<TValue> GetValueAsync(TKey key)
        {
            if (TryGetValue(key, out Task<TValue>? value))
            {
                return value;
            }

            throw new KeyNotFoundException();
        }

        /// <inheritdoc cref="Dictionary{TKey, TValue}.TryGetValue(TKey, out TValue)" />
        public bool TryGetValue(TKey key, [NotNullWhen(true)] out Task<TValue>? value)
        {
            return _cache.TryGetValue(key, out value);
        }

        /// <summary>
        /// Looks up a cached operation that <c>GetOrAddAsync</c> may hand out again, meaning one that exists and has not been canceled.
        /// </summary>
        /// <param name="key">The key to look up.</param>
        /// <param name="value">The cached operation when this method returns <see langword="true" />; otherwise <see langword="null" />.</param>
        /// <returns><see langword="true" /> if a non-canceled cached operation exists for <paramref name="key" />; otherwise <see langword="false" />.</returns>
        private bool TryGetReusableValue(TKey key, [NotNullWhen(true)] out Task<TValue>? value)
        {
            if (_cache.TryGetValue(key, out value) && !value.IsCanceled)
            {
                return true;
            }

            value = null;
            return false;
        }
    }
}