File: src\Analyzers\CSharp\CodeFixes\ConvertSwitchStatementToExpression\ConvertSwitchStatementToExpressionCodeFixProvider.Rewriter.cs
Web Access
Project: src\src\CodeStyle\CSharp\CodeFixes\Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes.csproj (Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Diagnostics;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.ConvertSwitchStatementToExpression;
 
using static ConvertSwitchStatementToExpressionHelpers;
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
internal sealed partial class ConvertSwitchStatementToExpressionCodeFixProvider
{
    /// <summary>
    /// Rewrites a <see cref="SwitchStatementSyntax"/> into a switch expression and wraps that expression in the
    /// statement that consumes it: a <c>return</c>, a <c>throw</c>, an assignment, or a local declaration.
    /// </summary>
    /// <remarks>
    /// Each visited statement is mapped to the expression that becomes the body of a switch-expression arm. Any node
    /// kind without a dedicated <c>Visit*</c> override below falls through to <see cref="DefaultVisit"/>, which throws.
    /// </remarks>
    private sealed class Rewriter : CSharpSyntaxVisitor<ExpressionSyntax>
    {
        private readonly SemanticModel _semanticModel;
        private readonly bool _isAllThrowStatements;
        private readonly CancellationToken _cancellationToken;

        // Left-hand side of the first assignment visited. It is used to build the final assignment or declaration.
        private ExpressionSyntax? _assignmentTarget;

        private Rewriter(SemanticModel semanticModel, bool isAllThrowStatements, CancellationToken cancellationToken)
        {
            _semanticModel = semanticModel;
            _isAllThrowStatements = isAllThrowStatements;
            _cancellationToken = cancellationToken;
        }

        /// <summary>
        /// Converts <paramref name="switchStatement"/> into a statement built around an equivalent switch expression.
        /// </summary>
        /// <param name="switchStatement">The switch statement to convert.</param>
        /// <param name="model">Semantic model used to decide where explicit casts are required.</param>
        /// <param name="declaratorToRemoveType">
        /// Type used for the generated local declaration when <c>var</c> may not be usable; may be <see langword="null"/>.
        /// </param>
        /// <param name="nodeToGenerate">
        /// Kind of the final statement: <see cref="SyntaxKind.ReturnStatement"/>, <see cref="SyntaxKind.ThrowStatement"/>,
        /// or an assignment-expression kind.
        /// </param>
        /// <param name="shouldMoveNextStatementToSwitchExpression">
        /// Whether a <c>return</c>/<c>throw</c> directly following the switch is folded in as a trailing discard arm.
        /// </param>
        /// <param name="generateDeclaration">For assignment kinds, emit a local declaration instead of an assignment.</param>
        /// <param name="cancellationToken">Cancellation token flowed to all semantic queries.</param>
        /// <returns>
        /// The replacement statement, or <paramref name="switchStatement"/> unchanged when it contains preprocessor directives.
        /// </returns>
        public static StatementSyntax Rewrite(
            SwitchStatementSyntax switchStatement,
            SemanticModel model,
            ITypeSymbol? declaratorToRemoveType,
            SyntaxKind nodeToGenerate,
            bool shouldMoveNextStatementToSwitchExpression,
            bool generateDeclaration,
            CancellationToken cancellationToken)
        {
            if (switchStatement.ContainsDirectives)
            {
                // Statements containing preprocessor directives are not rewritten.
                return switchStatement;
            }

            var rewriter = new Rewriter(model, isAllThrowStatements: nodeToGenerate == SyntaxKind.ThrowStatement, cancellationToken);

            // Rewrite the switch statement as a switch expression.
            var switchExpression = rewriter.RewriteSwitchStatement(
                switchStatement, topLevel: true,
                allowMoveNextStatementToSwitchExpression: shouldMoveNextStatementToSwitchExpression);

            // Generate the final statement to wrap the switch expression, e.g. a "return" or an assignment.
            return rewriter.GetFinalStatement(
                switchExpression, switchStatement.SwitchKeyword.LeadingTrivia, declaratorToRemoveType, nodeToGenerate, generateDeclaration);
        }

        /// <summary>
        /// Wraps <paramref name="switchExpression"/> in the statement selected by <paramref name="nodeToGenerate"/>.
        /// For return and throw statements, <paramref name="leadingTrivia"/> is placed on the keyword.
        /// </summary>
        private StatementSyntax GetFinalStatement(
            ExpressionSyntax switchExpression,
            SyntaxTriviaList leadingTrivia,
            ITypeSymbol? declaratorToRemoveType,
            SyntaxKind nodeToGenerate,
            bool generateDeclaration)
        {
            switch (nodeToGenerate)
            {
                case SyntaxKind.ReturnStatement:
                    return ReturnStatement(
                        Token(leadingTrivia, SyntaxKind.ReturnKeyword, trailing: default),
                        switchExpression,
                        SemicolonToken);
                case SyntaxKind.ThrowStatement:
                    return ThrowStatement(
                        Token(leadingTrivia, SyntaxKind.ThrowKeyword, trailing: default),
                        switchExpression,
                        SemicolonToken);
            }

            // Every remaining kind is an assignment; visiting the arms must have recorded its target.
            Debug.Assert(SyntaxFacts.IsAssignmentExpression(nodeToGenerate));
            Debug.Assert(_assignmentTarget != null);

            return generateDeclaration
                ? GenerateVariableDeclaration(switchExpression, declaratorToRemoveType)
                : GenerateAssignment(switchExpression, nodeToGenerate, leadingTrivia);
        }

        /// <summary>
        /// Builds <c>target op= switchExpression;</c> with the recorded assignment target and the given assignment kind.
        /// </summary>
        private ExpressionStatementSyntax GenerateAssignment(ExpressionSyntax switchExpression, SyntaxKind assignmentKind, SyntaxTriviaList leadingTrivia)
        {
            Contract.ThrowIfNull(_assignmentTarget);

            return ExpressionStatement(
                AssignmentExpression(assignmentKind,
                    left: _assignmentTarget,
                    right: switchExpression))
                .WithLeadingTrivia(leadingTrivia);
        }

        /// <summary>
        /// Builds <c>T name = switchExpression;</c>. The assignment target must be a simple identifier.
        /// </summary>
        private StatementSyntax GenerateVariableDeclaration(ExpressionSyntax switchExpression, ITypeSymbol? declaratorToRemoveType)
        {
            Contract.ThrowIfFalse(_assignmentTarget is IdentifierNameSyntax);

            // 'var' might not preserve the declared type when it is a reference type or a nullable type. In those cases
            // emit the explicit type for now; whether 'var' can still be used is decided later.
            TypeSyntax type = declaratorToRemoveType is { } declaredType && (declaredType.IsReferenceType || declaredType.IsNullable())
                ? declaredType.GenerateTypeSyntax()
                : IdentifierName("var");

            return LocalDeclarationStatement(
                VariableDeclaration(
                    type,
                    variables: [VariableDeclarator(
                        identifier: ((IdentifierNameSyntax)_assignmentTarget).Identifier,
                        argumentList: null,
                        initializer: EqualsValueClause(switchExpression))]));
        }

        /// <summary>
        /// Converts one switch section into a switch-expression arm: its labels become the pattern (plus optional
        /// <c>when</c> clause), and its statements become the arm expression.
        /// </summary>
        private SwitchExpressionArmSyntax GetSwitchExpressionArm(SwitchSectionSyntax node)
        {
            return SwitchExpressionArm(
                pattern: GetPattern(node.Labels, out var whenClauseOpt),
                whenClause: whenClauseOpt,
                expression: RewriteStatements(node.Statements));
        }

        /// <summary>
        /// Combines all labels of one section into a single pattern. A section containing a default label becomes a
        /// discard; otherwise multiple labels are joined with <c>or</c>.
        /// </summary>
        private static PatternSyntax GetPattern(SyntaxList<SwitchLabelSyntax> switchLabels, out WhenClauseSyntax? whenClause)
        {
            if (switchLabels.Count == 1)
                return GetPattern(switchLabels[0], out whenClause);

            if (switchLabels.Any(label => IsDefaultSwitchLabel(label)))
            {
                // The original group had a catch-all label; a discard '_' expresses the same thing.
                whenClause = null;
                return DiscardPattern();
            }

            // Multiple labels and no catch-all: merge them with an 'or' pattern.
            var totalPattern = GetPattern(switchLabels[0], out var whenClauseUnused);
            Debug.Assert(whenClauseUnused == null, "We should not have offered to convert multiple cases if any have a when clause");

            for (var i = 1; i < switchLabels.Count; i++)
            {
                var nextPatternPart = GetPattern(switchLabels[i], out whenClauseUnused);
                Debug.Assert(whenClauseUnused == null, "We should not have offered to convert multiple cases if any have a when clause");

                totalPattern = BinaryPattern(SyntaxKind.OrPattern, totalPattern.Parenthesize(), nextPatternPart.Parenthesize());
            }

            whenClause = null;
            return totalPattern;
        }

        /// <summary>
        /// Maps a single switch label to a pattern. Pattern labels keep their pattern and <c>when</c> clause, constant
        /// labels become constant patterns, and <c>default</c> becomes a discard.
        /// </summary>
        private static PatternSyntax GetPattern(SwitchLabelSyntax switchLabel, out WhenClauseSyntax? whenClause)
        {
            switch (switchLabel.Kind())
            {
                case SyntaxKind.CasePatternSwitchLabel:
                    var node = (CasePatternSwitchLabelSyntax)switchLabel;
                    whenClause = node.WhenClause;
                    return node.Pattern;

                case SyntaxKind.CaseSwitchLabel:
                    whenClause = null;
                    return ConstantPattern(((CaseSwitchLabelSyntax)switchLabel).Value);

                case SyntaxKind.DefaultSwitchLabel:
                    whenClause = null;
                    return DiscardPattern();

                case var value:
                    throw ExceptionUtilities.UnexpectedValue(value);
            }
        }

        /// <summary>
        /// Records the first assignment target seen and returns the assigned value, adding a cast if required.
        /// </summary>
        public override ExpressionSyntax VisitAssignmentExpression(AssignmentExpressionSyntax node)
        {
            _assignmentTarget ??= node.Left;
            return CastIfChangeInRuntimeRepresentation(node.Right);
        }

        /// <summary>
        /// Adds an explicit cast to <paramref name="node"/> when its type differs from its converted type and the
        /// conversion between them is neither an identity nor an implicit reference conversion.
        /// </summary>
        private ExpressionSyntax CastIfChangeInRuntimeRepresentation(ExpressionSyntax node)
        {
            // If the existing return/assignment involved a conversion that changed the runtime representation of the
            // type (or value), insert that same cast explicitly in the final result. Switch statements do not use
            // best-common-type, but switch expressions can, and the original conversion must not be lost because a
            // new best-common-type conversion is added.
            var typeInfo = _semanticModel.GetTypeInfo(node, _cancellationToken);
            if (typeInfo.ConvertedType is not null &&
                typeInfo.Type is not null &&
                !SymbolEqualityComparer.Default.Equals(typeInfo.ConvertedType, typeInfo.Type))
            {
                var conversion = _semanticModel.Compilation.ClassifyConversion(typeInfo.Type, typeInfo.ConvertedType);
                if (!conversion.IsIdentityOrImplicitReference())
                    return node.Cast(typeInfo.ConvertedType);
            }

            return node;
        }

        /// <summary>
        /// Produces the arm expression for a section's statements by visiting the first statement. When there is a
        /// second statement (asserted to be at most two), it is not visited.
        /// </summary>
        private ExpressionSyntax RewriteStatements(SyntaxList<StatementSyntax> statements)
        {
            Debug.Assert(statements.Count is 1 or 2);
            Debug.Assert(!statements[0].IsKind(SyntaxKind.BreakStatement));
            var result = Visit(statements[0]);
            Contract.ThrowIfNull(result);
            return result;
        }

        /// <summary>Nested switch statements become nested switch expressions (without the top-level cast handling).</summary>
        public override ExpressionSyntax VisitSwitchStatement(SwitchStatementSyntax node)
            => RewriteSwitchStatement(node, topLevel: false);

        /// <summary>
        /// Builds the switch expression for <paramref name="node"/>. Sections with a default label are ordered last,
        /// comments from the original labels and statements are carried onto the new arms and separators, and,
        /// when allowed, a following <c>return</c>/<c>throw</c> is appended as a discard arm.
        /// </summary>
        private ExpressionSyntax RewriteSwitchStatement(
            SwitchStatementSyntax node,
            bool topLevel,
            bool allowMoveNextStatementToSwitchExpression = true)
        {
            var switchArms = node.Sections
                // The default label must come last in the switch expression.
                .OrderBy(section => section.Labels.Any(label => IsDefaultSwitchLabel(label)))
                .Select(s =>
                    (tokensForLeadingTrivia: new[] { s.Labels[0].GetFirstToken(), s.Labels[0].GetLastToken() },
                     tokensForTrailingTrivia: new[] { s.Statements[0].GetFirstToken(), s.Statements[0].GetLastToken() },
                     armExpression: GetSwitchExpressionArm(s)))
                .ToList();

            if (allowMoveNextStatementToSwitchExpression)
            {
                var nextStatement = node.GetNextStatement();
                if (nextStatement is (kind: SyntaxKind.ThrowStatement or SyntaxKind.ReturnStatement))
                {
                    var armExpression = Visit(nextStatement);
                    Contract.ThrowIfNull(armExpression);

                    switchArms.Add(
                        (tokensForLeadingTrivia: new[] { nextStatement.GetFirstToken() },
                         tokensForTrailingTrivia: new[] { nextStatement.GetLastToken() },
                         SwitchExpressionArm(DiscardPattern(), armExpression)));
                }
            }

            // Add an explicit cast to the governing expression if necessary (top level only).
            var switchStatement = topLevel ? AddCastIfNecessary(node) : node;

            return SwitchExpression(
                switchStatement.Expression.Parenthesize(),
                Token(leading: default, SyntaxKind.SwitchKeyword, node.CloseParenToken.TrailingTrivia),
                Token(leading: default, SyntaxKind.OpenBraceToken, node.OpenBraceToken.TrailingTrivia),
                SeparatedList(
                    switchArms.Select(t => t.armExpression.WithLeadingTrivia(t.tokensForLeadingTrivia.GetTrivia().FilterComments(addElasticMarker: false))),
                    switchArms.Select(t => CommaToken.WithTrailingTrivia(t.tokensForTrailingTrivia.GetTrivia().FilterComments(addElasticMarker: true)))),
                CloseBraceToken);
        }

        /// <summary>
        /// Returns <paramref name="node"/> with its governing expression cast to the converted type when that differs
        /// from the type of the symbol the expression binds to; otherwise returns <paramref name="node"/> unchanged.
        /// </summary>
        private SwitchStatementSyntax AddCastIfNecessary(SwitchStatementSyntax node)
        {
            // If the switch statement expression is being implicitly converted then we need to explicitly cast the
            // expression before rewriting as a switch expression.
            var expressionType = _semanticModel.GetSymbolInfo(node.Expression, _cancellationToken).Symbol.GetSymbolType();
            var expressionConvertedType = _semanticModel.GetTypeInfo(node.Expression, _cancellationToken).ConvertedType;

            if (expressionConvertedType != null &&
                !SymbolEqualityComparer.Default.Equals(expressionConvertedType, expressionType))
            {
                return node.Update(node.SwitchKeyword, node.OpenParenToken,
                    node.Expression.Cast(expressionConvertedType).WithAdditionalAnnotations(Formatter.Annotation),
                    node.CloseParenToken, node.OpenBraceToken,
                    node.Sections, node.CloseBraceToken);
            }

            return node;
        }

        /// <summary>A <c>return</c> becomes its (possibly cast) returned expression.</summary>
        public override ExpressionSyntax VisitReturnStatement(ReturnStatementSyntax node)
        {
            Contract.ThrowIfNull(node.Expression);
            return CastIfChangeInRuntimeRepresentation(node.Expression);
        }

        /// <summary>A <c>throw</c> becomes a throw expression, or just its operand when every arm throws.</summary>
        public override ExpressionSyntax VisitThrowStatement(ThrowStatementSyntax node)
        {
            Contract.ThrowIfNull(node.Expression);
            // In an all-throw switch statement, return the bare expression instead of a throw expression, so the whole
            // switch expression can be wrapped in a single throw statement.
            return _isAllThrowStatements ? node.Expression : ThrowExpression(node.Expression);
        }

        /// <summary>An expression statement is rewritten by visiting its expression, e.g. an assignment.</summary>
        public override ExpressionSyntax VisitExpressionStatement(ExpressionStatementSyntax node)
        {
            var result = Visit(node.Expression);
            Contract.ThrowIfNull(result);
            return result;
        }

        /// <summary>Any other node kind is unexpected here.</summary>
        public override ExpressionSyntax DefaultVisit(SyntaxNode node)
            => throw ExceptionUtilities.UnexpectedValue(node.Kind());
    }
}