File: ConvertIfToSwitch\CSharpConvertIfToSwitchCodeRefactoringProvider.Rewriting.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.ConvertIfToSwitch;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
internal sealed partial class CSharpConvertIfToSwitchCodeRefactoringProvider
{
    private static readonly Dictionary<BinaryOperatorKind, SyntaxKind> s_operatorMap = new Dictionary<BinaryOperatorKind, SyntaxKind>
    {
        { BinaryOperatorKind.LessThan, SyntaxKind.LessThanToken },
        { BinaryOperatorKind.GreaterThan, SyntaxKind.GreaterThanToken },
        { BinaryOperatorKind.LessThanOrEqual, SyntaxKind.LessThanEqualsToken },
        { BinaryOperatorKind.GreaterThanOrEqual, SyntaxKind.GreaterThanEqualsToken },
    };
 
    public override SyntaxNode CreateSwitchExpressionStatement(SyntaxNode target, ImmutableArray<AnalyzedSwitchSection> sections, Feature feature)
    {
        return ReturnStatement(
            SwitchExpression(
                (ExpressionSyntax)target,
                [.. sections.Select(section => AsSwitchExpressionArmSyntax(section, feature))]));
    }
 
    private static SwitchExpressionArmSyntax AsSwitchExpressionArmSyntax(AnalyzedSwitchSection section, Feature feature)
    {
        if (section.Labels.IsDefault)
            return SwitchExpressionArm(DiscardPattern(), AsExpressionSyntax(section.Body));
 
        var pattern = AsPatternSyntax(section.Labels[0].Pattern, feature);
        var whenClause = AsWhenClause(section.Labels[0]);
 
        Debug.Assert(whenClause == null || section.Labels.Length == 1, "We shouldn't have guards when we're combining multiple cases into a single arm");
 
        for (var i = 1; i < section.Labels.Length; i++)
        {
            var label = section.Labels[i];
            Debug.Assert(label.Guards.Length == 0, "We shouldn't have guards when we're combining multiple cases into a single arm");
            var nextPattern = AsPatternSyntax(label.Pattern, feature);
            pattern = BinaryPattern(SyntaxKind.OrPattern, pattern.Parenthesize(), nextPattern.Parenthesize());
        }
 
        return SwitchExpressionArm(pattern, whenClause, AsExpressionSyntax(section.Body));
    }
 
    private static ExpressionSyntax AsExpressionSyntax(IOperation operation)
        => operation switch
        {
            IReturnOperation { ReturnedValue: { } value } => (ExpressionSyntax)value.Syntax,
            IThrowOperation { Exception: { } exception } => ThrowExpression((ExpressionSyntax)exception.Syntax),
            IBlockOperation op => AsExpressionSyntax(op.Operations.Single()),
            var v => throw ExceptionUtilities.UnexpectedValue(v.Kind)
        };
 
    public override SyntaxNode CreateSwitchStatement(IfStatementSyntax ifStatement, SyntaxNode expression, IEnumerable<SyntaxNode> sectionList)
    {
        var block = ifStatement.Statement as BlockSyntax;
        return SwitchStatement(
            switchKeyword: SwitchKeyword.WithTriviaFrom(ifStatement.IfKeyword),
            openParenToken: ifStatement.OpenParenToken,
            expression: (ExpressionSyntax)expression,
            closeParenToken: ifStatement.CloseParenToken.WithPrependedLeadingTrivia(ElasticMarker),
            openBraceToken: block?.OpenBraceToken ?? OpenBraceToken,
            sections: [.. sectionList.Cast<SwitchSectionSyntax>()],
            closeBraceToken: block?.CloseBraceToken.WithoutLeadingTrivia() ?? CloseBraceToken);
    }
 
    private static WhenClauseSyntax? AsWhenClause(AnalyzedSwitchLabel label)
        => AsWhenClause(label.Guards
            .Select(e => e.WalkUpParentheses())
            .AggregateOrDefault((prev, current) => BinaryExpression(SyntaxKind.LogicalAndExpression, prev, current)));
 
    private static WhenClauseSyntax? AsWhenClause(ExpressionSyntax? expression)
        => expression is null ? null : WhenClause(expression);
 
    public override SyntaxNode AsSwitchLabelSyntax(AnalyzedSwitchLabel label, Feature feature)
        => CasePatternSwitchLabel(
            AsPatternSyntax(label.Pattern, feature),
            AsWhenClause(label),
            ColonToken);
 
    private static PatternSyntax AsPatternSyntax(AnalyzedPattern pattern, Feature feature)
        => pattern switch
        {
            AnalyzedPattern.And p => BinaryPattern(SyntaxKind.AndPattern, AsPatternSyntax(p.LeftPattern, feature).Parenthesize(), AsPatternSyntax(p.RightPattern, feature).Parenthesize()),
            AnalyzedPattern.Constant p => ConstantPattern(p.ExpressionSyntax),
            AnalyzedPattern.Source p => p.PatternSyntax,
            AnalyzedPattern.Type p when feature.HasFlag(Feature.TypePattern) => TypePattern((TypeSyntax)p.IsExpressionSyntax.Right),
            AnalyzedPattern.Type p => DeclarationPattern((TypeSyntax)p.IsExpressionSyntax.Right, DiscardDesignation()),
            AnalyzedPattern.Relational p => RelationalPattern(Token(s_operatorMap[p.OperatorKind]), p.Value),
            var p => throw ExceptionUtilities.UnexpectedValue(p)
        };
 
    public override IEnumerable<SyntaxNode> AsSwitchSectionStatements(IOperation operation)
    {
        var node = operation.Syntax;
        Debug.Assert(operation.SemanticModel is not null);
        var requiresBreak = operation.SemanticModel.AnalyzeControlFlow(node).EndPointIsReachable;
        var requiresBlock = !operation.SemanticModel.AnalyzeDataFlow(node).VariablesDeclared.IsDefaultOrEmpty;
 
        var statements = ArrayBuilder<SyntaxNode>.GetInstance();
        if (node is BlockSyntax block)
        {
            if (block.Statements.Count == 0)
            {
                statements.Add(BreakStatement());
            }
            else if (requiresBlock)
            {
                statements.Add(requiresBreak ? block.AddStatements(BreakStatement()) : block);
            }
            else
            {
                statements.AddRange(block.Statements);
                if (requiresBreak)
                {
                    statements.Add(BreakStatement().WithLeadingTrivia(block.CloseBraceToken.LeadingTrivia));
                }
            }
        }
        else
        {
            statements.Add(node);
            if (requiresBreak)
            {
                statements.Add(BreakStatement());
            }
        }
 
        return statements.ToArrayAndFree();
    }
}