File: ConvertIfToSwitch\AbstractConvertIfToSwitchCodeRefactoringProvider.Analyzer.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Immutable;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.ConvertIfToSwitch;
 
using static BinaryOperatorKind;
 
internal abstract partial class AbstractConvertIfToSwitchCodeRefactoringProvider<
    TIfStatementSyntax, TExpressionSyntax, TIsExpressionSyntax, TPatternSyntax>
{
    // Match the following pattern which can be safely converted to switch statement
    //
    //    <if-statement-sequence>
    //        : if (<section-expr>) { <unreachable-end-point> }, <if-statement-sequence>
    //        : if (<section-expr>) { <unreachable-end-point> }, ( return | throw )
    //        | <if-statement>
    //
    //    <if-statement>
    //        : if (<section-expr>) { _ } else <if-statement>
    //        | if (<section-expr>) { _ } else { <if-statement-sequence> }
    //        | if (<section-expr>) { _ } else { _ }
    //        | if (<section-expr>) { _ }
    //
    //    <section-expr>
    //        : <section-expr> || <pattern-expr>
    //        | <pattern-expr>
    //
    //    <pattern-expr>
    //        : <pattern-expr> && <expr>                         // C#
    //        | <expr0> is <pattern>                             // C#
    //        | <expr0> is <type>                                // C#
    //        | <expr0> == <const-expr>                          // C#, VB
    //        | <expr0> <comparison-op> <const>                  // C#, VB
    //        | ( <expr0> >= <const> | <const> <= <expr0> )
    //           && ( <expr0> <= <const> | <const> >= <expr0> )  // C#, VB
    //        | ( <expr0> <= <const> | <const> >= <expr0> )
    //           && ( <expr0> >= <const> | <const> <= <expr0> )  // C#, VB
    //
    internal abstract class Analyzer(ISyntaxFacts syntaxFacts, AbstractConvertIfToSwitchCodeRefactoringProvider<TIfStatementSyntax, TExpressionSyntax, TIsExpressionSyntax, TPatternSyntax>.Feature features)
    {
        public abstract bool CanConvert(IConditionalOperation operation);
        public abstract bool HasUnreachableEndPoint(IOperation operation);
        public abstract bool CanImplicitlyConvert(SemanticModel semanticModel, SyntaxNode syntax, ITypeSymbol targetType);
 
        /// <summary>
        /// Holds the expression determined to be used as the target expression of the switch
        /// </summary>
        /// <remarks>
        /// Note that this is initially unset until we find a non-constant expression.
        /// </remarks>
        private SyntaxNode _switchTargetExpression = null!;
        /// <summary>
        /// Holds the type of the <see cref="_switchTargetExpression"/>
        /// </summary>
        private ITypeSymbol? _switchTargetType = null!;
        private readonly ISyntaxFacts _syntaxFacts = syntaxFacts;
 
        public Feature Features { get; } = features;
 
        public bool Supports(Feature feature)
            => (Features & feature) != 0;
 
        public (ImmutableArray<AnalyzedSwitchSection>, SyntaxNode TargetExpression) AnalyzeIfStatementSequence(ReadOnlySpan<IOperation> operations)
        {
            using var _ = ArrayBuilder<AnalyzedSwitchSection>.GetInstance(out var sections);
            if (!ParseIfStatementSequence(operations, sections, topLevel: true, out var defaultBodyOpt))
            {
                return default;
            }
 
            if (defaultBodyOpt is object)
            {
                sections.Add(new AnalyzedSwitchSection(labels: default, defaultBodyOpt, defaultBodyOpt.Syntax));
            }
 
            RoslynDebug.Assert(_switchTargetExpression is object);
            return (sections.ToImmutable(), _switchTargetExpression);
        }
 
        // Tree to parse:
        //
        //    <if-statement-sequence>
        //        : if (<section-expr>) { <unreachable-end-point> }, <if-statement-sequence>
        //        : if (<section-expr>) { <unreachable-end-point> }, ( return | throw )
        //        | <if-statement>
        //
        private bool ParseIfStatementSequence(
            ReadOnlySpan<IOperation> operations,
            ArrayBuilder<AnalyzedSwitchSection> sections,
            bool topLevel,
            out IOperation? defaultBodyOpt)
        {
            var current = 0;
            while (current < operations.Length &&
                operations[current] is IConditionalOperation { WhenFalse: null } op &&
                HasUnreachableEndPoint(op.WhenTrue) &&
                ParseIfStatement(op, sections, out _))
            {
                current++;
            }
 
            defaultBodyOpt = null;
            if (current == 0)
            {
                // didn't consume a sequence of if-statements with unreachable ends.  Check for the last case.
                if (operations.Length == 0)
                    return false;
 
                // If we're in the initial state, it's fine for there to be many operations that follow.  We're just
                // trying to check if the first one completes our analysis (and we'll not touch the ones that
                // follow). However, if we're actually in one of the recursive calls, it's *not* ok to ignore the 
                // following ops as those may impact if the higher call into us is ok.  So in that case, we do not
                // allow the parsing to succeed if we have more than one operation left.
                return topLevel
                    ? operations is [var op1, ..] && ParseIfStatement(op1, sections, out defaultBodyOpt)
                    : operations is [var op2] && ParseIfStatement(op2, sections, out defaultBodyOpt);
            }
            else
            {
                if (current < operations.Length)
                {
                    // consumed a sequence of if-statements with unreachable-ends.  If we end with a normal
                    // if-statement, we're done.  Otherwise, we end with whatever last return/throw we see.
                    var nextStatement = operations[current];
                    if (!ParseIfStatement(nextStatement, sections, out defaultBodyOpt) &&
                        nextStatement is IReturnOperation { ReturnedValue: not null } or IThrowOperation { Exception: not null })
                    {
                        defaultBodyOpt = nextStatement;
                    }
                }
 
                return true;
            }
        }
 
        // Tree to parse:
        //
        //    <if-statement>
        //        : if (<section-expr>) { _ } else <if-statement>
        //        | if (<section-expr>) { _ } else { <if-statement-sequence> }
        //        | if (<section-expr>) { _ } else { _ }
        //        | if (<section-expr>) { _ }
        //
        private bool ParseIfStatement(IOperation operation, ArrayBuilder<AnalyzedSwitchSection> sections, out IOperation? defaultBodyOpt)
        {
            switch (operation)
            {
                case IConditionalOperation op when CanConvert(op):
                    var section = ParseSwitchSection(op);
                    if (section is null)
                    {
                        break;
                    }
 
                    sections.Add(section);
 
                    if (op.WhenFalse is null)
                    {
                        defaultBodyOpt = null;
                    }
                    else if (!ParseIfStatementOrBlock(op.WhenFalse, sections, out defaultBodyOpt))
                    {
                        defaultBodyOpt = op.WhenFalse;
                    }
 
                    return true;
            }
 
            defaultBodyOpt = null;
            return false;
        }
 
        private bool ParseIfStatementOrBlock(IOperation op, ArrayBuilder<AnalyzedSwitchSection> sections, out IOperation? defaultBodyOpt)
        {
            return op is IBlockOperation block
                ? ParseIfStatementSequence(block.Operations.AsSpan(), sections, topLevel: false, out defaultBodyOpt)
                : ParseIfStatement(op, sections, out defaultBodyOpt);
        }
 
        private AnalyzedSwitchSection? ParseSwitchSection(IConditionalOperation operation)
        {
            using var _ = ArrayBuilder<AnalyzedSwitchLabel>.GetInstance(out var labels);
            if (!ParseSwitchLabels(operation.Condition, labels))
            {
                return null;
            }
 
            return new AnalyzedSwitchSection(labels.ToImmutable(), operation.WhenTrue, operation.Syntax);
        }
 
        // Tree to parse:
        //
        //    <section-expr>
        //        : <section-expr> || <pattern-expr>
        //        | <pattern-expr>
        //
        private bool ParseSwitchLabels(IOperation operation, ArrayBuilder<AnalyzedSwitchLabel> labels)
        {
            if (operation is IBinaryOperation { OperatorKind: ConditionalOr } op)
            {
                if (!ParseSwitchLabels(op.LeftOperand, labels))
                {
                    return false;
                }
 
                operation = op.RightOperand;
            }
 
            var label = ParseSwitchLabel(operation);
            if (label is null)
            {
                return false;
            }
 
            labels.Add(label);
            return true;
        }
 
        private AnalyzedSwitchLabel? ParseSwitchLabel(IOperation operation)
        {
            using var _ = ArrayBuilder<TExpressionSyntax>.GetInstance(out var guards);
            var pattern = ParsePattern(operation, guards);
            if (pattern is null)
                return null;
 
            return new AnalyzedSwitchLabel(pattern, guards.ToImmutable());
        }
 
        private enum ConstantResult
        {
            /// <summary>
            /// None of operands were constant.
            /// </summary>
            None,
            /// <summary>
            /// Signifies that the left operand is the constant.
            /// </summary>
            Left,
            /// <summary>
            /// Signifies that the right operand is the constant.
            /// </summary>
            Right,
        }
 
        private ConstantResult DetermineConstant(IBinaryOperation op)
        {
            return (op.LeftOperand, op.RightOperand) switch
            {
                var (e, v) when IsConstant(v) && CheckTargetExpression(e) && CheckConstantType(v) => ConstantResult.Right,
                var (v, e) when IsConstant(v) && CheckTargetExpression(e) && CheckConstantType(v) => ConstantResult.Left,
                _ => ConstantResult.None,
            };
        }
 
        // Tree to parse:
        //
        //    <pattern-expr>
        //        : <pattern-expr> && <expr>                         // C#
        //        | <expr0> is <pattern>                             // C#
        //        | <expr0> is <type>                                // C#
        //        | <expr0> == <const-expr>                          // C#, VB
        //        | <expr0> <comparison-op> <const>                  //     VB
        //        | ( <expr0> >= <const> | <const> <= <expr0> )
        //           && ( <expr0> <= <const> | <const> >= <expr0> )  //     VB
        //        | ( <expr0> <= <const> | <const> >= <expr0> )
        //           && ( <expr0> >= <const> | <const> <= <expr0> )  //     VB
        //
        private AnalyzedPattern? ParsePattern(IOperation operation, ArrayBuilder<TExpressionSyntax> guards)
        {
            switch (operation)
            {
                case IBinaryOperation { OperatorKind: ConditionalAnd } op
                    when Supports(Feature.RangePattern) && GetRangeBounds(op) is (TExpressionSyntax lower, TExpressionSyntax higher):
                    return new AnalyzedPattern.Range(lower, higher);
 
                case IBinaryOperation { OperatorKind: BinaryOperatorKind.Equals } op:
                    return DetermineConstant(op) switch
                    {
                        ConstantResult.Left when op.LeftOperand.Syntax is TExpressionSyntax left
                            => new AnalyzedPattern.Constant(left),
                        ConstantResult.Right when op.RightOperand.Syntax is TExpressionSyntax right
                            => new AnalyzedPattern.Constant(right),
                        _ => null
                    };
 
                case IBinaryOperation { OperatorKind: NotEquals } op
                    when Supports(Feature.InequalityPattern):
                    return ParseRelationalPattern(op);
 
                case IBinaryOperation op
                    when Supports(Feature.RelationalPattern) && IsRelationalOperator(op.OperatorKind):
                    return ParseRelationalPattern(op);
 
                // Check this below the cases that produce Relational/Ranges.  We would prefer to use those if
                // available before utilizing a CaseGuard.
                case IBinaryOperation { OperatorKind: ConditionalAnd } op
                    when Supports(Feature.AndPattern | Feature.CaseGuard):
                    {
                        var leftPattern = ParsePattern(op.LeftOperand, guards);
                        if (leftPattern == null)
                            return null;
 
                        if (Supports(Feature.AndPattern))
                        {
                            var guardCount = guards.Count;
                            var rightPattern = ParsePattern(op.RightOperand, guards);
                            if (rightPattern != null)
                                return new AnalyzedPattern.And(leftPattern, rightPattern);
 
                            // Making a pattern out of the RHS didn't work.  Reset the guards back to where we started.
                            guards.Count = guardCount;
                        }
 
                        if (Supports(Feature.CaseGuard) && op.RightOperand.Syntax is TExpressionSyntax node)
                        {
                            guards.Add(node);
                            return leftPattern;
                        }
 
                        return null;
                    }
 
                case IIsTypeOperation op
                    when Supports(Feature.IsTypePattern) && CheckTargetExpression(op.ValueOperand) && op.Syntax is TIsExpressionSyntax node:
                    return new AnalyzedPattern.Type(node);
 
                case IIsPatternOperation op
                    when Supports(Feature.SourcePattern) && CheckTargetExpression(op.Value) && op.Pattern.Syntax is TPatternSyntax pattern:
                    return new AnalyzedPattern.Source(pattern);
 
                case IParenthesizedOperation op:
                    return ParsePattern(op.Operand, guards);
            }
 
            return null;
        }
 
        private AnalyzedPattern? ParseRelationalPattern(IBinaryOperation op)
        {
            return DetermineConstant(op) switch
            {
                ConstantResult.Left when op.LeftOperand.Syntax is TExpressionSyntax left
                    => new AnalyzedPattern.Relational(Flip(op.OperatorKind), left),
                ConstantResult.Right when op.RightOperand.Syntax is TExpressionSyntax right
                    => new AnalyzedPattern.Relational(op.OperatorKind, right),
                _ => null
            };
        }
 
        private enum BoundKind
        {
            /// <summary>
            /// Not a range bound.
            /// </summary>
            None,
            /// <summary>
            /// Signifies that the lower-bound of a range pattern
            /// </summary>
            Lower,
            /// <summary>
            /// Signifies that the higher-bound of a range pattern
            /// </summary>
            Higher,
        }
 
        private (SyntaxNode Lower, SyntaxNode Higher) GetRangeBounds(IBinaryOperation op)
        {
            if (op is not
                { LeftOperand: IBinaryOperation left, RightOperand: IBinaryOperation right })
            {
                return default;
            }
 
            return (GetRangeBound(left), GetRangeBound(right)) switch
            {
                ({ Kind: BoundKind.Lower } low, { Kind: BoundKind.Higher } high)
                    when CheckTargetExpression(low.Expression, high.Expression) => (low.Value.Syntax, high.Value.Syntax),
                ({ Kind: BoundKind.Higher } high, { Kind: BoundKind.Lower } low)
                    when CheckTargetExpression(low.Expression, high.Expression) => (low.Value.Syntax, high.Value.Syntax),
                _ => default
            };
 
            bool CheckTargetExpression(IOperation left, IOperation right)
                => _syntaxFacts.AreEquivalent(left.Syntax, right.Syntax) && this.CheckTargetExpression(left);
        }
 
        private static (BoundKind Kind, IOperation Expression, IOperation Value) GetRangeBound(IBinaryOperation op)
        {
            return op.OperatorKind switch
            {
                // 5 <= i
                LessThanOrEqual when IsConstant(op.LeftOperand) => (BoundKind.Lower, op.RightOperand, op.LeftOperand),
                // i <= 5
                LessThanOrEqual when IsConstant(op.RightOperand) => (BoundKind.Higher, op.LeftOperand, op.RightOperand),
                // 5 >= i
                GreaterThanOrEqual when IsConstant(op.LeftOperand) => (BoundKind.Higher, op.RightOperand, op.LeftOperand),
                // i >= 5
                GreaterThanOrEqual when IsConstant(op.RightOperand) => (BoundKind.Lower, op.LeftOperand, op.RightOperand),
                _ => default
            };
        }
 
        /// <summary>
        /// Changes the direction the operator is pointing
        /// </summary>
        private static BinaryOperatorKind Flip(BinaryOperatorKind operatorKind)
        {
            return operatorKind switch
            {
                LessThan => GreaterThan,
                LessThanOrEqual => GreaterThanOrEqual,
                GreaterThanOrEqual => LessThanOrEqual,
                GreaterThan => LessThan,
                NotEquals => NotEquals,
                var v => throw ExceptionUtilities.UnexpectedValue(v)
            };
        }
 
        private static bool IsRelationalOperator(BinaryOperatorKind operatorKind)
        {
            switch (operatorKind)
            {
                case LessThan:
                case LessThanOrEqual:
                case GreaterThanOrEqual:
                case GreaterThan:
                    return true;
                default:
                    return false;
            }
        }
 
        private static bool IsConstant(IOperation operation)
        {
            // Constants do not propagate to conversions
            return operation is IConversionOperation { Conversion.IsUserDefined: false } op
                ? IsConstant(op.Operand)
                : operation.ConstantValue.HasValue;
        }
 
        private bool CheckTargetExpression(IOperation operation)
        {
            operation = operation.WalkDownConversion();
 
            var expression = operation.Syntax;
            // If we have not figured the switch expression yet,
            // we will assume that the first expression is the one.
            if (_switchTargetExpression is null)
            {
                RoslynDebug.Assert(_switchTargetType is null);
 
                _switchTargetExpression = expression;
                _switchTargetType = operation.Type;
                return true;
            }
 
            return _syntaxFacts.AreEquivalent(expression, _switchTargetExpression);
        }
 
        private bool CheckConstantType(IOperation operation)
        {
            RoslynDebug.AssertNotNull(operation.SemanticModel);
            RoslynDebug.AssertNotNull(_switchTargetType);
 
            return CanImplicitlyConvert(operation.SemanticModel, operation.Syntax, _switchTargetType);
        }
    }
 
    [Flags]
    internal enum Feature
    {
        None = 0,
        // VB/C# 9.0 features
        RelationalPattern = 1,
        // VB features
        InequalityPattern = 1 << 1,
        RangePattern = 1 << 2,
        // C# 7.0 features
        SourcePattern = 1 << 3,
        IsTypePattern = 1 << 4,
        CaseGuard = 1 << 5,
        // C# 8.0 features
        SwitchExpression = 1 << 6,
        // C# 9.0 features
        OrPattern = 1 << 7,
        AndPattern = 1 << 8,
        TypePattern = 1 << 9,
    }
}