File: Collections\ConcurrentDictionary.cs
Web Access
Project: ..\..\..\src\MSBuildTaskHost\MSBuildTaskHost.csproj (MSBuildTaskHost)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.Build.TaskHost.Utilities;
 
namespace Microsoft.Build.TaskHost.Collections;
 
// The following class is back-ported from the .NET Framework 4.x (CoreFX) ConcurrentDictionary because
// MSBuildTaskHost targets .NET Framework 3.5, which does not provide it. Only the GetOrAdd method is kept.
/// <summary>
/// A minimal thread-safe hash map back-ported from the .NET Framework 4.x / CoreFX
/// <c>ConcurrentDictionary</c>; only <see cref="GetOrAdd"/> is exposed.
/// </summary>
/// <remarks>
/// Lookups take no lock. Inserts take a single lock from a striped lock array, where lock <c>i</c>
/// guards every bucket whose index modulo the lock count equals <c>i</c>. Resizing takes every lock.
/// </remarks>
internal class ConcurrentDictionary<TKey, TValue>
    where TKey : notnull
{
    /// <summary>
    /// Tables that hold the internal state of the ConcurrentDictionary
    ///
    /// Wrapping the three tables in a single object allows us to atomically
    /// replace all tables at once.
    /// </summary>
    private sealed class Tables
    {
        internal readonly Node[] _buckets; // A singly-linked list for each bucket.
        internal readonly object[] _locks; // A set of locks, each guarding a section of the table.
        internal volatile int[] _countPerLock; // The number of elements guarded by each lock.

        internal Tables(Node[] buckets, object[] locks, int[] countPerLock)
        {
            _buckets = buckets;
            _locks = locks;
            _countPerLock = countPerLock;
        }
    }

    private volatile Tables _tables; // Internal tables of the dictionary
    private readonly IEqualityComparer<TKey> _comparer; // Key equality comparer
    private readonly bool _growLockArray; // Whether to dynamically increase the size of the striped lock
    private int _budget; // The maximum number of elements per lock before a resize operation is triggered

    // The default capacity, i.e. the initial # of buckets. When choosing this value, we are making
    // a trade-off between the size of a very small dictionary, and the number of resizes when
    // constructing a large dictionary. Also, the capacity should not be divisible by a small prime.
    private const int DefaultCapacity = 31;

    // The maximum size of the striped lock that will not be exceeded when locks are automatically
    // added as the dictionary grows. The initial lock count (the logical core count) may already
    // exceed this limit, in which case no further locks are added.
    private const int MaxLockNumber = 1024;

    // Whether TValue is a type that can be written atomically (i.e., with no danger of torn reads)
    private static readonly bool s_isValueWriteAtomic = IsValueWriteAtomic();

    /// <summary>
    /// Determines whether type TValue can be written atomically
    /// </summary>
    private static bool IsValueWriteAtomic()
    {
        // Section 12.6.6 of ECMA CLI explains which types can be read and written atomically without
        // the risk of tearing.
        //
        // See http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-335.pdf
        Type valueType = typeof(TValue);
        if (!valueType.IsValueType)
        {
            return true;
        }

        switch (Type.GetTypeCode(valueType))
        {
            case TypeCode.Boolean:
            case TypeCode.Byte:
            case TypeCode.Char:
            case TypeCode.Int16:
            case TypeCode.Int32:
            case TypeCode.SByte:
            case TypeCode.Single:
            case TypeCode.UInt16:
            case TypeCode.UInt32:
                return true;
            case TypeCode.Int64:
            case TypeCode.Double:
            case TypeCode.UInt64:
                // 64-bit primitives are only written atomically on a 64-bit process.
                return IntPtr.Size == 8;
            default:
                return false;
        }
    }

    /// <summary>
    /// Initializes a new, empty instance of the <see cref="ConcurrentDictionary{TKey,TValue}"/> class.
    /// The number of lock stripes equals the logical core count (clamped to at least one), and the
    /// initial number of buckets is <see cref="DefaultCapacity"/> or the lock count, whichever is larger.
    /// </summary>
    /// <param name="comparer">
    /// The comparer used for key hashing and equality, or <see langword="null"/> to use
    /// <see cref="EqualityComparer{T}.Default"/>.
    /// </param>
    public ConcurrentDictionary(IEqualityComparer<TKey>? comparer = null)
    {
        // Clamp to at least one lock: with zero locks the budget computation below would divide by zero.
        int concurrencyLevel = Math.Max(1, NativeMethods.GetLogicalCoreCount());
        int capacity = DefaultCapacity;

        // The capacity should be at least as large as the concurrency level. Otherwise, we would have locks that don't guard
        // any buckets.
        if (capacity < concurrencyLevel)
        {
            capacity = concurrencyLevel;
        }

        object[] locks = new object[concurrencyLevel];
        for (int i = 0; i < locks.Length; i++)
        {
            locks[i] = new object();
        }

        int[] countPerLock = new int[locks.Length];
        Node[] buckets = new Node[capacity];
        _tables = new Tables(buckets, locks, countPerLock);

        _comparer = comparer ?? EqualityComparer<TKey>.Default;
        _growLockArray = true;
        _budget = buckets.Length / locks.Length;
    }

    /// <summary>
    /// Looks up <paramref name="key"/> without taking any lock.
    /// </summary>
    /// <param name="key">The key to look up.</param>
    /// <param name="hashcode">The key's hash code as computed by the dictionary's comparer.</param>
    /// <param name="value">The stored value when found; otherwise <see langword="default"/>.</param>
    /// <returns><see langword="true"/> if the key was found.</returns>
    private bool TryGetValueInternal(TKey key, int hashcode, [MaybeNullWhen(false)] out TValue value)
    {
        Debug.Assert(_comparer.GetHashCode(key) == hashcode);

        // We must capture the _tables field in a local variable. It is set to a new instance on each table resize.
        Tables tables = _tables;

        int bucketNo = GetBucket(hashcode, tables._buckets.Length);

        // We can get away w/out a lock here. Read the bucket head exactly once into a local, then issue a
        // full fence so the following reads of the node's fields ('_hashcode', '_key', '_value', '_next')
        // cannot observe values older than the head reference itself. This emulates Volatile.Read (acquire
        // semantics: read, then fence), which is not available on .NET Framework 3.5.
        Node? n = tables._buckets[bucketNo];
        Thread.MemoryBarrier();

        while (n != null)
        {
            if (hashcode == n._hashcode && _comparer.Equals(n._key, key))
            {
                value = n._value;
                return true;
            }

            n = n._next;
        }

        value = default;
        return false;
    }

    /// <summary>
    /// Shared internal implementation for inserts and updates.
    /// If key exists, we always return false; and if updateIfExists == true we force update with value;
    /// If key doesn't exist, we always add value and return true;
    /// </summary>
    /// <param name="key">The key to insert.</param>
    /// <param name="hashcode">The key's hash code as computed by the dictionary's comparer.</param>
    /// <param name="value">The value to insert (or to store, when updating).</param>
    /// <param name="updateIfExists">Whether an existing entry's value is overwritten.</param>
    /// <param name="acquireLock">
    /// Whether to take the stripe lock for the key's bucket. Passing <see langword="false"/> is only safe when
    /// the caller already holds that lock; every call site in this class passes <see langword="true"/>.
    /// </param>
    /// <param name="resultingValue">The value now associated with the key.</param>
    /// <returns><see langword="true"/> if a new entry was added.</returns>
    private bool TryAddInternal(TKey key, int hashcode, TValue value, bool updateIfExists, bool acquireLock, out TValue resultingValue)
    {
        Debug.Assert(_comparer.GetHashCode(key) == hashcode);

        while (true)
        {
            int bucketNo, lockNo;

            Tables tables = _tables;
            GetBucketAndLockNo(hashcode, out bucketNo, out lockNo, tables._buckets.Length, tables._locks.Length);

            bool resizeDesired = false;
            bool lockTaken = false;
            try
            {
                if (acquireLock)
                {
                    // Block until this stripe's lock is held. A non-blocking TryEnter would let this thread
                    // mutate the bucket unsynchronized under contention. Monitor.Enter(object, ref bool) does
                    // not exist on .NET Framework 3.5, so ownership is recorded immediately after acquisition.
                    Monitor.Enter(tables._locks[lockNo]);
                    lockTaken = true;
                }

                // If the table just got resized, we may not be holding the right lock, and must retry.
                // This should be a rare occurrence.
                if (tables != _tables)
                {
                    continue;
                }

                // Try to find this key in the bucket
                Node? prev = null;
                for (Node? node = tables._buckets[bucketNo]; node != null; node = node._next)
                {
                    Debug.Assert((prev == null && node == tables._buckets[bucketNo]) || prev!._next == node);
                    if (hashcode == node._hashcode && _comparer.Equals(node._key, key))
                    {
                        // The key was found in the dictionary. If updates are allowed, update the value for that key.
                        // We need to create a new node for the update, in order to support TValue types that cannot
                        // be written atomically, since lock-free reads may be happening concurrently.
                        if (updateIfExists)
                        {
                            if (s_isValueWriteAtomic)
                            {
                                node._value = value;
                            }
                            else
                            {
                                Node newNode = new Node(node._key, value, hashcode, node._next);
                                if (prev == null)
                                {
                                    Interlocked.Exchange(ref tables._buckets[bucketNo], newNode);
                                }
                                else
                                {
                                    prev._next = newNode;
                                }
                            }
                            resultingValue = value;
                        }
                        else
                        {
                            resultingValue = node._value;
                        }
                        return false;
                    }
                    prev = node;
                }

                // The key was not found in the bucket. Insert the key-value pair at the head of the list.
                // Interlocked.Exchange publishes the fully constructed node to concurrent lock-free readers.
                Interlocked.Exchange(ref tables._buckets[bucketNo], new Node(key, value, hashcode, tables._buckets[bucketNo]));
                checked
                {
                    tables._countPerLock[lockNo]++;
                }

                // If the number of elements guarded by this lock has exceeded the budget, resize the bucket table.
                // It is also possible that GrowTable will increase the budget but won't resize the bucket table.
                // That happens if the bucket table is found to be poorly utilized due to a bad hash function.
                if (tables._countPerLock[lockNo] > _budget)
                {
                    resizeDesired = true;
                }
            }
            finally
            {
                if (lockTaken)
                {
                    Monitor.Exit(tables._locks[lockNo]);
                }
            }

            // The fact that we got here means that we just performed an insertion. If necessary, we will grow the table.
            //
            // Concurrency notes:
            // - Notice that we are not holding any locks when calling GrowTable. This is necessary to prevent deadlocks.
            // - As a result, it is possible that GrowTable will be called unnecessarily. But, GrowTable will obtain lock 0
            //   and then verify that the table we passed to it as the argument is still the current table.
            if (resizeDesired)
            {
                GrowTable(tables);
            }

            resultingValue = value;
            return true;
        }
    }

    [DoesNotReturn]
    private static void ThrowKeyNullException()
        => throw new ArgumentNullException("key");

    /// <summary>
    /// Adds a key/value pair to the <see cref="ConcurrentDictionary{TKey,TValue}"/>
    /// if the key does not already exist.
    /// </summary>
    /// <remarks>
    /// <paramref name="valueFactory"/> is invoked outside of any lock. If another thread adds the same key
    /// between the lookup and the insert, the factory's result is discarded and the value already stored
    /// is returned, so the factory may run for a key that ends up holding a different value.
    /// </remarks>
    /// <param name="key">The key of the element to add.</param>
    /// <param name="valueFactory">The function used to generate a value for the key.</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is a null reference
    /// (Nothing in Visual Basic).</exception>
    /// <exception cref="ArgumentNullException"><paramref name="valueFactory"/> is a null reference
    /// (Nothing in Visual Basic).</exception>
    /// <exception cref="OverflowException">The dictionary contains too many
    /// elements.</exception>
    /// <returns>The value for the key. This will be either the existing value for the key if the
    /// key is already in the dictionary, or the new value for the key as returned by valueFactory
    /// if the key was not in the dictionary.</returns>
    public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)
    {
        if (key == null)
        {
            ThrowKeyNullException();
        }

        if (valueFactory == null)
        {
            throw new ArgumentNullException(nameof(valueFactory));
        }

        int hashcode = _comparer.GetHashCode(key);

        if (!TryGetValueInternal(key, hashcode, out TValue? resultingValue))
        {
            TryAddInternal(key, hashcode, valueFactory(key), updateIfExists: false, acquireLock: true, out resultingValue);
        }

        return resultingValue;
    }

    /// <summary>
    /// Replaces the bucket table with a larger one. To prevent multiple threads from resizing the
    /// table as a result of races, the Tables instance that holds the table of buckets deemed too
    /// small is passed in as an argument to GrowTable(). GrowTable() obtains a lock, and then checks
    /// whether the Tables instance has been replaced in the meantime.
    /// </summary>
    /// <param name="tables">The tables instance the caller observed as too full.</param>
    private void GrowTable(Tables tables)
    {
        const int MaxArrayLength = 0X7FEFFFFF;
        int locksAcquired = 0;
        try
        {
            // The thread that first obtains _locks[0] will be the one doing the resize operation
            AcquireLocks(0, 1, ref locksAcquired);

            // Make sure nobody resized the table while we were waiting for lock 0:
            if (tables != _tables)
            {
                // We assume that since the table reference is different, it was already resized (or the budget
                // was adjusted). If we ever decide to do table shrinking, or replace the table for other reasons,
                // we will have to revisit this logic.
                return;
            }

            // Compute the (approx.) total size. Use an Int64 accumulation variable to avoid an overflow.
            long approxCount = 0;
            for (int i = 0; i < tables._countPerLock.Length; i++)
            {
                approxCount += tables._countPerLock[i];
            }

            // If the bucket array is too empty, double the budget instead of resizing the table
            if (approxCount < tables._buckets.Length / 4)
            {
                _budget = 2 * _budget;

                // Doubling overflowed; saturate instead.
                if (_budget < 0)
                {
                    _budget = int.MaxValue;
                }

                return;
            }

            // Compute the new table size. We find the smallest integer larger than twice the previous table size, and not divisible by
            // 2,3,5 or 7. We can consider a different table-sizing policy in the future.
            int newLength = 0;
            bool maximizeTableSize = false;
            try
            {
                checked
                {
                    // Double the size of the buckets table and add one, so that we have an odd integer.
                    newLength = (tables._buckets.Length * 2) + 1;

                    // Now, we only need to check odd integers, and find the first that is not divisible
                    // by 3, 5 or 7.
                    while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0)
                    {
                        newLength += 2;
                    }

                    Debug.Assert(newLength % 2 != 0);

                    if (newLength > MaxArrayLength)
                    {
                        maximizeTableSize = true;
                    }
                }
            }
            catch (OverflowException)
            {
                maximizeTableSize = true;
            }

            if (maximizeTableSize)
            {
                newLength = MaxArrayLength;

                // The table is at its maximum size, so make sure GrowTable is not triggered again by
                // setting the budget to int.MaxValue. (This back-port has no operation that shrinks the
                // table or lowers the budget afterwards.)
                _budget = int.MaxValue;
            }

            // Now acquire all other locks for the table
            AcquireLocks(1, tables._locks.Length, ref locksAcquired);

            object[] newLocks = tables._locks;

            // Add more locks. The existing lock objects are copied to the front of the new array, which is what
            // lets the finally block below release them through the *new* _tables._locks.
            if (_growLockArray && tables._locks.Length < MaxLockNumber)
            {
                newLocks = new object[tables._locks.Length * 2];
                Array.Copy(tables._locks, 0, newLocks, 0, tables._locks.Length);
                for (int i = tables._locks.Length; i < newLocks.Length; i++)
                {
                    newLocks[i] = new object();
                }
            }

            Node[] newBuckets = new Node[newLength];
            int[] newCountPerLock = new int[newLocks.Length];

            // Copy all data into a new table, creating new nodes for all elements
            for (int i = 0; i < tables._buckets.Length; i++)
            {
                Node? current = tables._buckets[i];
                while (current != null)
                {
                    Node? next = current._next;
                    int newBucketNo, newLockNo;
                    GetBucketAndLockNo(current._hashcode, out newBucketNo, out newLockNo, newBuckets.Length, newLocks.Length);

                    newBuckets[newBucketNo] = new Node(current._key, current._value, current._hashcode, newBuckets[newBucketNo]);

                    checked
                    {
                        newCountPerLock[newLockNo]++;
                    }

                    current = next;
                }
            }

            // Adjust the budget
            _budget = Math.Max(1, newBuckets.Length / newLocks.Length);

            // Replace tables with the new versions
            _tables = new Tables(newBuckets, newLocks, newCountPerLock);
        }
        finally
        {
            // Release all locks that we took earlier
            ReleaseLocks(0, locksAcquired);
        }
    }

    /// <summary>
    /// Computes the bucket for a particular key.
    /// </summary>
    private static int GetBucket(int hashcode, int bucketCount)
    {
        // Clear the sign bit so the modulo result is non-negative.
        int bucketNo = (hashcode & 0x7fffffff) % bucketCount;
        Debug.Assert(bucketNo >= 0 && bucketNo < bucketCount);

        return bucketNo;
    }

    /// <summary>
    /// Computes the bucket and lock number for a particular key.
    /// </summary>
    private static void GetBucketAndLockNo(int hashcode, out int bucketNo, out int lockNo, int bucketCount, int lockCount)
    {
        bucketNo = (hashcode & 0x7fffffff) % bucketCount;
        lockNo = bucketNo % lockCount;

        Debug.Assert(bucketNo >= 0 && bucketNo < bucketCount);
        Debug.Assert(lockNo >= 0 && lockNo < lockCount);
    }

    /// <summary>
    /// Acquires a contiguous range of locks for this hash table, blocking on each one, and increments
    /// locksAcquired by the number of locks that were acquired. The locks are acquired in an
    /// increasing order.
    /// </summary>
    private void AcquireLocks(int fromInclusive, int toExclusive, ref int locksAcquired)
    {
        Debug.Assert(fromInclusive <= toExclusive);
        object[] locks = _tables._locks;

        for (int i = fromInclusive; i < toExclusive; i++)
        {
            // Must block: ReleaseLocks(0, locksAcquired) assumes locksAcquired counts exactly the contiguous
            // prefix of locks owned by this thread. A non-blocking TryEnter could skip a contended lock and
            // later cause Monitor.Exit on a lock this thread never entered.
            Monitor.Enter(locks[i]);
            locksAcquired++;
        }
    }

    /// <summary>
    /// Releases a contiguous range of locks of the current tables instance.
    /// </summary>
    /// <remarks>
    /// GrowTable may have replaced _tables before this runs; that is safe because the new lock array
    /// begins with the same lock objects as the old one.
    /// </remarks>
    private void ReleaseLocks(int fromInclusive, int toExclusive)
    {
        Debug.Assert(fromInclusive <= toExclusive);

        for (int i = fromInclusive; i < toExclusive; i++)
        {
            Monitor.Exit(_tables._locks[i]);
        }
    }

    /// <summary>
    /// A node in a singly-linked list representing a particular hash table bucket.
    /// </summary>
    private sealed class Node
    {
        internal readonly TKey _key;
        internal TValue _value;
        internal volatile Node? _next;
        internal readonly int _hashcode;

        internal Node(TKey key, TValue value, int hashcode, Node? next)
        {
            _key = key;
            _value = value;
            _next = next;
            _hashcode = hashcode;
        }
    }
}