File: InternalUtilities\OneOrMany.cs
Web Access
Project: src\src\Compilers\Core\Portable\Microsoft.CodeAnalysis.csproj (Microsoft.CodeAnalysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
 
namespace Roslyn.Utilities
{
    /// <summary>
    /// Represents a single item or many items (including none).
    /// </summary>
    /// <remarks>
    /// Used when a collection usually contains a single item but sometimes might contain multiple.
    /// A single item is stored inline in <see cref="_one"/> without allocating an array; every other count
    /// (including zero) is stored in <see cref="_many"/>. The constructors unwrap a one-element array into the
    /// inline form, so a collection whose <see cref="Count"/> is 1 always uses the inline representation.
    /// </remarks>
    [DebuggerDisplay("{GetDebuggerDisplay(),nq}")]
    [DebuggerTypeProxy(typeof(OneOrMany<>.DebuggerProxy))]
    internal readonly struct OneOrMany<T>
    {
        /// <summary>
        /// A collection with no items.
        /// </summary>
        public static readonly OneOrMany<T> Empty = new OneOrMany<T>(ImmutableArray<T>.Empty);

        // Holds the item when the collection has exactly one item (i.e. when _many is default).
        private readonly T? _one;

        // Holds the items when the collection does not have exactly one item; default otherwise.
        private readonly ImmutableArray<T> _many;

        /// <summary>
        /// Creates a collection containing exactly <paramref name="one"/>.
        /// </summary>
        public OneOrMany(T one)
        {
            _one = one;
            _many = default;
        }

        /// <summary>
        /// Creates a collection with the contents of <paramref name="many"/>. A one-element array is
        /// unwrapped and its item stored inline.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="many"/> is a default (uninitialized) array.</exception>
        public OneOrMany(ImmutableArray<T> many)
        {
            if (many.IsDefault)
            {
                throw new ArgumentNullException(nameof(many));
            }

            if (many is [var item])
            {
                // Normalize: a single item always lives in _one, which other members rely on.
                _one = item;
                _many = default;
            }
            else
            {
                _one = default;
                _many = many;
            }
        }

        /// <summary>
        /// True if the collection has a single item. This item is stored in <see cref="_one"/>.
        /// </summary>
        [MemberNotNullWhen(true, nameof(_one))]
        private bool HasOneItem
            => _many.IsDefault;

        /// <summary>
        /// True when no item is stored inline and no array is stored. This is the state of
        /// <c>default(OneOrMany&lt;T&gt;)</c> when <typeparamref name="T"/> is a reference type.
        /// </summary>
        /// <remarks>
        /// NOTE(review): for a value-type <typeparamref name="T"/>, a default instance reports <see langword="false"/>
        /// here and behaves as a one-item collection containing <c>default(T)</c>. For a reference type, a collection
        /// created from a single <see langword="null"/> item reports <see langword="true"/>.
        /// </remarks>
        public bool IsDefault
            => _one == null && _many.IsDefault;

        /// <summary>
        /// Gets the item at <paramref name="index"/>.
        /// </summary>
        /// <remarks>
        /// In the single-item form, any index other than 0 throws <see cref="IndexOutOfRangeException"/>.
        /// Otherwise the lookup is delegated to the underlying <see cref="ImmutableArray{T}"/>.
        /// </remarks>
        public T this[int index]
        {
            get
            {
                if (HasOneItem)
                {
                    if (index != 0)
                    {
                        throw new IndexOutOfRangeException();
                    }

                    return _one;
                }
                else
                {
                    return _many[index];
                }
            }
        }

        /// <summary>
        /// The number of items in the collection.
        /// </summary>
        public int Count
            => HasOneItem ? 1 : _many.Length;

        /// <summary>
        /// True if the collection has no items.
        /// </summary>
        public bool IsEmpty
            => Count == 0;

        /// <summary>
        /// Returns a new collection with <paramref name="item"/> appended; this instance is unchanged.
        /// </summary>
        public OneOrMany<T> Add(T item)
            => HasOneItem ? OneOrMany.Create(_one, item) :
               IsEmpty ? OneOrMany.Create(item) :
               OneOrMany.Create(_many.Add(item));

        /// <summary>
        /// Appends every item of this collection, in order, to <paramref name="builder"/>.
        /// </summary>
        public void AddRangeTo(ArrayBuilder<T> builder)
        {
            if (HasOneItem)
            {
                builder.Add(_one);
            }
            else
            {
                builder.AddRange(_many);
            }
        }

        /// <summary>
        /// True if the collection contains <paramref name="item"/> (default equality).
        /// </summary>
        public bool Contains(T item)
            => HasOneItem ? EqualityComparer<T>.Default.Equals(item, _one) : _many.Contains(item);

        /// <summary>
        /// Returns a collection with every item equal to <paramref name="item"/> (by
        /// <see cref="EqualityComparer{T}.Default"/>) removed.
        /// </summary>
        public OneOrMany<T> RemoveAll(T item)
        {
            if (HasOneItem)
            {
                return EqualityComparer<T>.Default.Equals(item, _one) ? Empty : this;
            }

            return OneOrMany.Create(_many.WhereAsArray(static (value, item) => !EqualityComparer<T>.Default.Equals(value, item), item));
        }

        /// <summary>
        /// Projects each item through <paramref name="selector"/>.
        /// </summary>
        public OneOrMany<TResult> Select<TResult>(Func<T, TResult> selector)
        {
            return HasOneItem ?
                OneOrMany.Create(selector(_one)) :
                OneOrMany.Create(_many.SelectAsArray(selector));
        }

        /// <summary>
        /// Projects each item through <paramref name="selector"/>, passing <paramref name="arg"/> along
        /// (avoids a capturing lambda).
        /// </summary>
        public OneOrMany<TResult> Select<TResult, TArg>(Func<T, TArg, TResult> selector, TArg arg)
        {
            return HasOneItem ?
                OneOrMany.Create(selector(_one, arg)) :
                OneOrMany.Create(_many.SelectAsArray(selector, arg));
        }

        /// <summary>
        /// Returns the first item; throws (via the indexer) if the collection is empty.
        /// </summary>
        public T First() => this[0];

        /// <summary>
        /// Returns the first item, or <c>default</c> if the collection is empty.
        /// </summary>
        public T? FirstOrDefault()
            => HasOneItem ? _one : _many.FirstOrDefault();

        /// <summary>
        /// Returns the first item satisfying <paramref name="predicate"/>, or <c>default</c> if none does.
        /// </summary>
        public T? FirstOrDefault(Func<T, bool> predicate)
        {
            if (HasOneItem)
            {
                return predicate(_one) ? _one : default;
            }

            return _many.FirstOrDefault(predicate);
        }

        /// <summary>
        /// Returns the first item satisfying <paramref name="predicate"/> (called with <paramref name="arg"/>),
        /// or <c>default</c> if none does.
        /// </summary>
        public T? FirstOrDefault<TArg>(Func<T, TArg, bool> predicate, TArg arg)
        {
            if (HasOneItem)
            {
                return predicate(_one, arg) ? _one : default;
            }

            return _many.FirstOrDefault(predicate, arg);
        }

        /// <summary>
        /// Reinterprets a collection of a derived reference type as a collection of <typeparamref name="T"/>,
        /// using <see cref="ImmutableArray{T}.CastUp{TDerived}(ImmutableArray{TDerived})"/> for the array form.
        /// </summary>
        public static OneOrMany<T> CastUp<TDerived>(OneOrMany<TDerived> from) where TDerived : class, T
        {
            return from.HasOneItem
                ? new OneOrMany<T>(from._one)
                : new OneOrMany<T>(ImmutableArray<T>.CastUp(from._many));
        }

        /// <summary>
        /// True if <paramref name="predicate"/> holds for every item.
        /// </summary>
        public bool All(Func<T, bool> predicate)
            => HasOneItem ? predicate(_one) : _many.All(predicate);

        /// <summary>
        /// True if <paramref name="predicate"/> (called with <paramref name="arg"/>) holds for every item.
        /// </summary>
        public bool All<TArg>(Func<T, TArg, bool> predicate, TArg arg)
            => HasOneItem ? predicate(_one, arg) : _many.All(predicate, arg);

        /// <summary>
        /// True if the collection has at least one item.
        /// </summary>
        public bool Any()
            => !IsEmpty;

        /// <summary>
        /// True if <paramref name="predicate"/> holds for at least one item.
        /// </summary>
        public bool Any(Func<T, bool> predicate)
            => HasOneItem ? predicate(_one) : _many.Any(predicate);

        /// <summary>
        /// True if <paramref name="predicate"/> (called with <paramref name="arg"/>) holds for at least one item.
        /// </summary>
        public bool Any<TArg>(Func<T, TArg, bool> predicate, TArg arg)
            => HasOneItem ? predicate(_one, arg) : _many.Any(predicate, arg);

        /// <summary>
        /// Returns the items as an <see cref="ImmutableArray{T}"/>. The single-item form allocates a new array;
        /// otherwise the stored array is returned.
        /// </summary>
        public ImmutableArray<T> ToImmutable()
            => HasOneItem ? ImmutableArray.Create(_one) : _many;

        /// <summary>
        /// Copies the items into a new array.
        /// </summary>
        public T[] ToArray()
            => HasOneItem ? new[] { _one } : _many.ToArray();

        /// <summary>
        /// True if both collections contain equal items in the same order.
        /// </summary>
        /// <param name="other">The collection to compare against.</param>
        /// <param name="comparer">Item comparer; <see cref="EqualityComparer{T}.Default"/> when <see langword="null"/>.</param>
        public bool SequenceEqual(OneOrMany<T> other, IEqualityComparer<T>? comparer = null)
        {
            comparer ??= EqualityComparer<T>.Default;

            if (Count != other.Count)
            {
                return false;
            }

            // The constructors normalize a single item into _one, so equal counts imply equal representations.
            Debug.Assert(HasOneItem == other.HasOneItem);

            return HasOneItem ? comparer.Equals(_one, other._one!) :
                   System.Linq.ImmutableArrayExtensions.SequenceEqual(_many, other._many, comparer);
        }

        /// <summary>
        /// True if this collection and <paramref name="other"/> contain equal items in the same order.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="other"/> is a default array.</exception>
        public bool SequenceEqual(ImmutableArray<T> other, IEqualityComparer<T>? comparer = null)
            => SequenceEqual(OneOrMany.Create(other), comparer);

        /// <summary>
        /// True if this collection and <paramref name="other"/> contain equal items in the same order.
        /// </summary>
        /// <param name="other">The sequence to compare against; enumerated at most once.</param>
        /// <param name="comparer">Item comparer; <see cref="EqualityComparer{T}.Default"/> when <see langword="null"/>.</param>
        public bool SequenceEqual(IEnumerable<T> other, IEqualityComparer<T>? comparer = null)
        {
            comparer ??= EqualityComparer<T>.Default;

            if (!HasOneItem)
            {
                return _many.SequenceEqual(other, comparer);
            }

            var sawItem = false;
            foreach (var otherItem in other)
            {
                // A second element, or a first element that differs from _one, means the sequences differ.
                if (sawItem || !comparer.Equals(_one, otherItem))
                {
                    return false;
                }

                sawItem = true;
            }

            // Equal only if `other` had exactly one element. An empty `other` must not match a one-item collection.
            return sawItem;
        }

        /// <summary>
        /// Returns an allocation-free struct enumerator over the items.
        /// </summary>
        public Enumerator GetEnumerator()
            => new(this);

        /// <summary>
        /// Struct enumerator that walks the collection by index.
        /// </summary>
        internal struct Enumerator
        {
            private readonly OneOrMany<T> _collection;
            private int _index;

            internal Enumerator(OneOrMany<T> collection)
            {
                _collection = collection;
                // Positioned before the first item until MoveNext is called.
                _index = -1;
            }

            public bool MoveNext()
            {
                _index++;
                return _index < _collection.Count;
            }

            public T Current => _collection[_index];
        }

        /// <summary>
        /// Debugger view that shows the items as a flat array.
        /// </summary>
        private sealed class DebuggerProxy(OneOrMany<T> instance)
        {
            private readonly OneOrMany<T> _instance = instance;

            [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
            public T[] Items => _instance.ToArray();
        }

        private string GetDebuggerDisplay()
            => "Count = " + Count;
    }
 
    /// <summary>
    /// Factory methods and comparison extensions for <see cref="OneOrMany{T}"/>.
    /// </summary>
    internal static class OneOrMany
    {
        /// <summary>
        /// Creates a collection holding exactly <paramref name="one"/>.
        /// </summary>
        public static OneOrMany<T> Create<T>(T one)
        {
            return new OneOrMany<T>(one);
        }

        /// <summary>
        /// Creates a collection holding <paramref name="one"/> followed by <paramref name="two"/>.
        /// </summary>
        public static OneOrMany<T> Create<T>(T one, T two)
        {
            var pair = ImmutableArray.Create(one, two);
            return new OneOrMany<T>(pair);
        }

        /// <summary>
        /// Returns <see cref="OneOrMany{T}.Empty"/> when <paramref name="one"/> is <see langword="null"/>;
        /// otherwise returns a collection holding just <paramref name="one"/>.
        /// </summary>
        public static OneOrMany<T> OneOrNone<T>(T? one)
        {
            if (one is null)
            {
                return OneOrMany<T>.Empty;
            }

            return new OneOrMany<T>(one);
        }

        /// <summary>
        /// Creates a collection from <paramref name="many"/>. A one-element array is stored inline.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="many"/> is a default array.</exception>
        public static OneOrMany<T> Create<T>(ImmutableArray<T> many)
        {
            return new OneOrMany<T>(many);
        }

        /// <summary>
        /// True if <paramref name="array"/> and <paramref name="other"/> contain equal items in the same order.
        /// </summary>
        public static bool SequenceEqual<T>(this ImmutableArray<T> array, OneOrMany<T> other, IEqualityComparer<T>? comparer = null)
        {
            var wrapped = new OneOrMany<T>(array);
            return wrapped.SequenceEqual(other, comparer);
        }

        /// <summary>
        /// True if <paramref name="array"/> and <paramref name="other"/> contain equal items in the same order.
        /// </summary>
        public static bool SequenceEqual<T>(this IEnumerable<T> array, OneOrMany<T> other, IEqualityComparer<T>? comparer = null)
        {
            // Delegate to the instance overload that accepts an arbitrary sequence.
            return other.SequenceEqual(array, comparer);
        }
    }
}