File: Utilities\ValueSetFactory.NumericValueSet.cs
Web Access
Project: src\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp
{
    using static BinaryOperatorKind;
 
    internal static partial class ValueSetFactory
    {
        /// <summary>
        /// The implementation of a value set for a numeric type <typeparamref name="T"/>.
        /// </summary>
        private sealed class NumericValueSet<T> : IValueSet<T>
        {
            // The closed intervals [first, last] that make up this set. The DEBUG-only checks in the
            // constructor assert that they are in increasing order with a gap between neighbors.
            private readonly ImmutableArray<(T first, T last)> _intervals;
            // Supplies the operations on T used by this class (Related, Next, Prev, MinValue, MaxValue, Zero, ...).
            private readonly INumericTC<T> _tc;
 
            /// <summary>
            /// Creates the value set consisting of the single interval from
            /// <c>tc.MinValue</c> to <c>tc.MaxValue</c>.
            /// </summary>
            public static NumericValueSet<T> AllValues(INumericTC<T> tc)
            {
                return new NumericValueSet<T>(tc.MinValue, tc.MaxValue, tc);
            }
 
            /// <summary>
            /// Creates the value set that has no intervals at all (the empty set).
            /// </summary>
            public static NumericValueSet<T> NoValues(INumericTC<T> tc)
            {
                var noIntervals = ImmutableArray<(T first, T last)>.Empty;
                return new NumericValueSet<T>(noIntervals, tc);
            }
 
            /// <summary>
            /// Creates a value set made of exactly one interval, from <paramref name="first"/>
            /// to <paramref name="last"/>.
            /// </summary>
            internal NumericValueSet(T first, T last, INumericTC<T> tc)
                : this(ImmutableArray.Create<(T first, T last)>((first, last)), tc)
            {
                // The interval must not be inverted.
                Debug.Assert(tc.Related(LessThanOrEqual, first, last));
            }
 
            /// <summary>
            /// Creates a value set from a list of intervals. In DEBUG builds the list is checked:
            /// each interval must satisfy first &lt;= last, and consecutive intervals must be in
            /// increasing order with a gap between them.
            /// </summary>
            internal NumericValueSet(ImmutableArray<(T first, T last)> intervals, INumericTC<T> tc)
            {
#if DEBUG
                int count = intervals.Length;
                if (count != 0)
                {
                    Debug.Assert(tc.Related(GreaterThanOrEqual, intervals[0].first, tc.MinValue));
                }

                for (int index = 0; index < count; index++)
                {
                    var current = intervals[index];
                    Debug.Assert(tc.Related(LessThanOrEqual, current.first, current.last));
                    if (index > 0)
                    {
                        // the value following the previous interval's end must still be below
                        // this interval's start, i.e. the intervals are ordered and not adjacent
                        Debug.Assert(tc.Related(LessThan, tc.Next(intervals[index - 1].last), current.first));
                    }
                }
#endif
                _intervals = intervals;
                _tc = tc;
            }
 
            /// <summary>
            /// True when this set contains no intervals.
            /// </summary>
            public bool IsEmpty
            {
                get { return _intervals.Length == 0; }
            }
 
            /// <summary>
            /// Produces one value contained in this set, as a constant.
            /// Throws <see cref="ArgumentException"/> when the set is empty.
            /// </summary>
            ConstantValue IValueSet.Sample
            {
                get
                {
                    if (IsEmpty)
                    {
                        throw new ArgumentException();
                    }

                    // Prefer a value near zero: restrict this set to the values related to zero by
                    // GreaterThanOrEqual and, if anything remains, take the lowest of those.
                    var atLeastZero = new NumericValueSetFactory<T>(_tc).Related(GreaterThanOrEqual, _tc.Zero);
                    var restricted = (NumericValueSet<T>)Intersect(atLeastZero);

                    // Otherwise fall back to the end of this set's last interval.
                    T sample = restricted.IsEmpty
                        ? _intervals[_intervals.Length - 1].last
                        : restricted._intervals[0].first;

                    return _tc.ToConstantValue(sample);
                }
            }
 
            /// <summary>
            /// Reports whether at least one member of this set stands in the given
            /// <paramref name="relation"/> to <paramref name="value"/>.
            /// Relations other than the five handled below cause an "unexpected value" exception.
            /// </summary>
            public bool Any(BinaryOperatorKind relation, T value)
            {
                switch (relation)
                {
                    case LessThan:
                    case LessThanOrEqual:
                        if (_intervals.Length == 0)
                        {
                            return false;
                        }

                        // It suffices to test the start of the first interval.
                        return _tc.Related(relation, _intervals[0].first, value);

                    case GreaterThan:
                    case GreaterThanOrEqual:
                        if (_intervals.Length == 0)
                        {
                            return false;
                        }

                        // It suffices to test the end of the last interval.
                        return _tc.Related(relation, _intervals[_intervals.Length - 1].last, value);

                    case Equal:
                        return containedInSomeInterval(value);

                    default:
                        throw ExceptionUtilities.UnexpectedValue(relation);
                }

                // Binary search over the ordered intervals for the one that could hold the candidate.
                bool containedInSomeInterval(T candidate)
                {
                    int low = 0;
                    int high = _intervals.Length - 1;

                    while (low < high)
                    {
                        int middle = low + (high - low) / 2;
                        if (_tc.Related(LessThanOrEqual, candidate, _intervals[middle].last))
                        {
                            high = middle;
                        }
                        else
                        {
                            low = middle + 1;
                        }
                    }

                    if (high < low)
                    {
                        // there were no intervals to search
                        return false;
                    }

                    var interval = _intervals[high];
                    return _tc.Related(GreaterThanOrEqual, candidate, interval.first)
                        && _tc.Related(LessThanOrEqual, candidate, interval.last);
                }
            }
 
            /// <summary>
            /// Untyped form of <c>Any</c>: a bad constant yields true; otherwise the constant is
            /// converted to <typeparamref name="T"/> and the typed overload is consulted.
            /// </summary>
            bool IValueSet.Any(BinaryOperatorKind relation, ConstantValue value)
            {
                if (value.IsBad)
                {
                    return true;
                }

                return Any(relation, _tc.FromConstantValue(value));
            }
 
            /// <summary>
            /// Reports whether every member of this set stands in the given
            /// <paramref name="relation"/> to <paramref name="value"/>. An empty set yields true.
            /// Relations other than the five handled below cause an "unexpected value" exception.
            /// </summary>
            public bool All(BinaryOperatorKind relation, T value)
            {
                int count = _intervals.Length;
                if (count == 0)
                {
                    return true;
                }

                switch (relation)
                {
                    case LessThan:
                    case LessThanOrEqual:
                        {
                            // It suffices to test the end of the last interval.
                            T highest = _intervals[count - 1].last;
                            return _tc.Related(relation, highest, value);
                        }

                    case GreaterThan:
                    case GreaterThanOrEqual:
                        {
                            // It suffices to test the start of the first interval.
                            T lowest = _intervals[0].first;
                            return _tc.Related(relation, lowest, value);
                        }

                    case Equal:
                        {
                            // Only a set consisting of one interval whose two ends both equal the value qualifies.
                            if (count != 1)
                            {
                                return false;
                            }

                            var only = _intervals[0];
                            return _tc.Related(Equal, only.first, value) && _tc.Related(Equal, only.last, value);
                        }

                    default:
                        throw ExceptionUtilities.UnexpectedValue(relation);
                }
            }
 
            /// <summary>
            /// Untyped form of <c>All</c>: a bad constant yields false; otherwise the constant is
            /// converted to <typeparamref name="T"/> and the typed overload is consulted.
            /// </summary>
            bool IValueSet.All(BinaryOperatorKind relation, ConstantValue value)
            {
                if (value.IsBad)
                {
                    return false;
                }

                return All(relation, _tc.FromConstantValue(value));
            }
 
            /// <summary>
            /// Builds the set of values between <c>MinValue</c> and <c>MaxValue</c> that are not in this set.
            /// </summary>
            public IValueSet<T> Complement()
            {
                if (IsEmpty)
                {
                    return AllValues(_tc);
                }

                var gaps = ArrayBuilder<(T first, T last)>.GetInstance();

                // Everything below the first interval, if there is anything there.
                T lowest = _intervals[0].first;
                if (_tc.Related(LessThan, _tc.MinValue, lowest))
                {
                    gaps.Add((_tc.MinValue, _tc.Prev(lowest)));
                }

                // The gap between each pair of neighboring intervals.
                for (int next = 1; next < _intervals.Length; next++)
                {
                    T gapStart = _tc.Next(_intervals[next - 1].last);
                    T gapEnd = _tc.Prev(_intervals[next].first);
                    gaps.Add((gapStart, gapEnd));
                }

                // Everything above the last interval, if there is anything there.
                T highest = _intervals[_intervals.Length - 1].last;
                if (_tc.Related(LessThan, highest, _tc.MaxValue))
                {
                    gaps.Add((_tc.Next(highest), _tc.MaxValue));
                }

                return new NumericValueSet<T>(gaps.ToImmutableAndFree(), _tc);
            }
 
            /// <summary>
            /// Untyped form of <c>Complement</c>; forwards to the typed method.
            /// </summary>
            IValueSet IValueSet.Complement()
            {
                return Complement();
            }
 
            /// <summary>
            /// Computes the intersection of this set with <paramref name="o"/>, which must be a
            /// <c>NumericValueSet</c> of the same element type. Both interval lists are walked in
            /// parallel and the overlapping portions are collected.
            /// </summary>
            public IValueSet<T> Intersect(IValueSet<T> o)
            {
                var other = (NumericValueSet<T>)o;
                Debug.Assert(this._tc.GetType() == other._tc.GetType());

                var overlaps = ArrayBuilder<(T first, T last)>.GetInstance();
                var mine = this._intervals;
                var theirs = other._intervals;
                int mineIndex = 0;
                int theirsIndex = 0;

                while (mineIndex < mine.Length && theirsIndex < theirs.Length)
                {
                    var (mineFirst, mineLast) = mine[mineIndex];
                    var (theirsFirst, theirsLast) = theirs[theirsIndex];

                    if (_tc.Related(LessThan, mineLast, theirsFirst))
                    {
                        // our interval lies entirely before theirs
                        mineIndex++;
                        continue;
                    }

                    if (_tc.Related(LessThan, theirsLast, mineFirst))
                    {
                        // their interval lies entirely before ours
                        theirsIndex++;
                        continue;
                    }

                    // The two intervals overlap: keep the common part.
                    Add(overlaps, Max(mineFirst, theirsFirst, _tc), Min(mineLast, theirsLast, _tc), _tc);

                    // Advance past whichever interval ends first (both when they end together).
                    if (_tc.Related(LessThan, mineLast, theirsLast))
                    {
                        mineIndex++;
                    }
                    else if (_tc.Related(LessThan, theirsLast, mineLast))
                    {
                        theirsIndex++;
                    }
                    else
                    {
                        mineIndex++;
                        theirsIndex++;
                    }
                }

                return new NumericValueSet<T>(overlaps.ToImmutableAndFree(), _tc);
            }
 
            /// <summary>
            /// Appends the interval from <paramref name="first"/> to <paramref name="last"/> to the
            /// end of <paramref name="builder"/>, folding it into the builder's final interval when
            /// the two touch or overlap.
            /// </summary>
            private static void Add(ArrayBuilder<(T first, T last)> builder, T first, T last, INumericTC<T> tc)
            {
                Debug.Assert(tc.Related(LessThanOrEqual, first, last));
                Debug.Assert(tc.Related(GreaterThanOrEqual, first, tc.MinValue));
                Debug.Assert(tc.Related(LessThanOrEqual, last, tc.MaxValue));
                Debug.Assert(builder.Count == 0 || tc.Related(LessThanOrEqual, builder.Last().first, first));

                // The MinValue test comes first, so Prev is never asked for the value before MinValue.
                bool mergeWithPrevious =
                    builder.Count > 0 &&
                    (tc.Related(Equal, tc.MinValue, first) ||
                     tc.Related(GreaterThanOrEqual, builder.Last().last, tc.Prev(first)));

                if (!mergeWithPrevious)
                {
                    builder.Add((first, last));
                    return;
                }

                // Extend the previous interval so that it also covers the new one.
                var previous = builder.Pop();
                previous.last = Max(last, previous.last, tc);
                builder.Push(previous);
            }
            /// <summary>
            /// Returns <paramref name="a"/> when it is less than <paramref name="b"/> according to
            /// <paramref name="tc"/>; otherwise returns <paramref name="b"/>.
            /// </summary>
            private static T Min(T a, T b, INumericTC<T> tc)
            {
                if (tc.Related(LessThan, a, b))
                {
                    return a;
                }

                return b;
            }
 
            /// <summary>
            /// Returns <paramref name="b"/> when <paramref name="a"/> is less than it according to
            /// <paramref name="tc"/>; otherwise returns <paramref name="a"/>.
            /// </summary>
            private static T Max(T a, T b, INumericTC<T> tc)
            {
                if (tc.Related(LessThan, a, b))
                {
                    return b;
                }

                return a;
            }
 
            /// <summary>
            /// Untyped form of <c>Intersect</c>; casts the operand to the typed interface and forwards.
            /// </summary>
            IValueSet IValueSet.Intersect(IValueSet other)
            {
                var typedOther = (IValueSet<T>)other;
                return Intersect(typedOther);
            }
 
            /// <summary>
            /// Computes the union of this set with <paramref name="o"/>, which must be a
            /// <c>NumericValueSet</c> of the same element type. Both interval lists are walked in
            /// parallel; the <c>Add</c> helper folds together intervals that touch or overlap.
            /// </summary>
            public IValueSet<T> Union(IValueSet<T> o)
            {
                var other = (NumericValueSet<T>)o;
                Debug.Assert(this._tc.GetType() == other._tc.GetType());

                var combined = ArrayBuilder<(T first, T last)>.GetInstance();
                var mine = this._intervals;
                var theirs = other._intervals;
                int mineIndex = 0;
                int theirsIndex = 0;

                while (mineIndex < mine.Length && theirsIndex < theirs.Length)
                {
                    var (mineFirst, mineLast) = mine[mineIndex];
                    var (theirsFirst, theirsLast) = theirs[theirsIndex];

                    if (_tc.Related(LessThan, mineLast, theirsFirst))
                    {
                        // our interval lies entirely before theirs
                        Add(combined, mineFirst, mineLast, _tc);
                        mineIndex++;
                        continue;
                    }

                    if (_tc.Related(LessThan, theirsLast, mineFirst))
                    {
                        // their interval lies entirely before ours
                        Add(combined, theirsFirst, theirsLast, _tc);
                        theirsIndex++;
                        continue;
                    }

                    // The two intervals overlap: emit their hull and move past both.
                    Add(combined, Min(mineFirst, theirsFirst, _tc), Max(mineLast, theirsLast, _tc), _tc);
                    mineIndex++;
                    theirsIndex++;
                }

                // At most one of the lists still has intervals left; append them.
                for (; mineIndex < mine.Length; mineIndex++)
                {
                    var (remainingFirst, remainingLast) = mine[mineIndex];
                    Add(combined, remainingFirst, remainingLast, _tc);
                }

                for (; theirsIndex < theirs.Length; theirsIndex++)
                {
                    var (remainingFirst, remainingLast) = theirs[theirsIndex];
                    Add(combined, remainingFirst, remainingLast, _tc);
                }

                return new NumericValueSet<T>(combined.ToImmutableAndFree(), _tc);
            }
 
            /// <summary>
            /// Untyped form of <c>Union</c>; casts the operand to the typed interface and forwards.
            /// </summary>
            IValueSet IValueSet.Union(IValueSet other)
            {
                var typedOther = (IValueSet<T>)other;
                return Union(typedOther);
            }
 
            /// <summary>
            /// Produce a random value set for testing purposes. Twice <paramref name="expectedSize"/>
            /// random values are drawn and sorted, and each consecutive pair becomes an interval.
            /// </summary>
            internal static IValueSet<T> Random(int expectedSize, Random random, INumericTC<T> tc)
            {
                int endpointCount = expectedSize * 2;
                T[] endpoints = new T[endpointCount];
                for (int index = 0; index < endpointCount; index++)
                {
                    endpoints[index] = tc.Random(random);
                }

                Array.Sort(endpoints);

                var intervals = ArrayBuilder<(T first, T last)>.GetInstance();
                for (int index = 0; index < endpoints.Length; index += 2)
                {
                    // Add folds together pairs that touch or overlap, so fewer intervals may result.
                    Add(intervals, endpoints[index], endpoints[index + 1], tc);
                }

                return new NumericValueSet<T>(intervals.ToImmutableAndFree(), tc);
            }
 
            /// <summary>
            /// A string representation for testing purposes: each interval is rendered as
            /// <c>[first..last]</c> and the pieces are joined with commas.
            /// </summary>
            public override string ToString()
            {
                var renderedIntervals = this._intervals.Select(
                    interval => $"[{_tc.ToString(interval.first)}..{_tc.ToString(interval.last)}]");
                return string.Join(",", renderedIntervals);
            }
 
            /// <summary>
            /// Two value sets are equal when the other object is also a <c>NumericValueSet</c> of
            /// the same element type and both hold the same sequence of intervals.
            /// </summary>
            public override bool Equals(object? obj)
            {
                if (!(obj is NumericValueSet<T> other))
                {
                    return false;
                }

                return this._intervals.SequenceEqual(other._intervals);
            }
 
            /// <summary>
            /// Hash code derived from the intervals and from how many of them there are.
            /// </summary>
            public override int GetHashCode()
            {
                int intervalsHash = Hash.CombineValues(_intervals);
                return Hash.Combine(intervalsHash, _intervals.Length);
            }
        }
    }
}