|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Diagnostics;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp.ConvertSwitchStatementToExpression;
using static ConvertSwitchStatementToExpressionHelpers;
using static CSharpSyntaxTokens;
using static SyntaxFactory;
internal sealed partial class ConvertSwitchStatementToExpressionCodeFixProvider
{
private sealed class Rewriter : CSharpSyntaxVisitor<ExpressionSyntax>
{
/// <summary>Semantic model queried for type and conversion information while rewriting.</summary>
private readonly SemanticModel _semanticModel;

/// <summary>
/// When <see langword="true"/>, visited throw statements contribute their bare expression
/// rather than a throw expression.
/// </summary>
private readonly bool _isAllThrowStatements;

/// <summary>Cancellation token handed to semantic model queries.</summary>
private readonly CancellationToken _cancellationToken;

/// <summary>
/// Left-hand side of the first assignment expression visited; <see langword="null"/> until one is seen.
/// </summary>
private ExpressionSyntax? _assignmentTarget;

/// <summary>Creates a rewriter; instances are only created from within this class.</summary>
private Rewriter(SemanticModel semanticModel, bool isAllThrowStatements, CancellationToken cancellationToken)
    => (_semanticModel, _isAllThrowStatements, _cancellationToken) = (semanticModel, isAllThrowStatements, cancellationToken);
/// <summary>
/// Converts <paramref name="switchStatement"/> into a single statement built around an equivalent
/// switch expression.
/// </summary>
/// <param name="switchStatement">The switch statement to convert.</param>
/// <param name="model">Semantic model used for type and conversion queries.</param>
/// <param name="declaratorToRemoveType">Declared type to use when a local declaration is generated; may be null.</param>
/// <param name="nodeToGenerate">Kind of the resulting statement: return, throw, or an assignment expression kind.</param>
/// <param name="shouldMoveNextStatementToSwitchExpression">Whether a following return/throw statement may be folded in as a discard arm.</param>
/// <param name="generateDeclaration">Whether to emit a local declaration instead of an assignment.</param>
/// <param name="cancellationToken">Cancellation token for semantic queries.</param>
/// <returns>The replacement statement, or the original statement when it contains preprocessor directives.</returns>
public static StatementSyntax Rewrite(
    SwitchStatementSyntax switchStatement,
    SemanticModel model,
    ITypeSymbol? declaratorToRemoveType,
    SyntaxKind nodeToGenerate,
    bool shouldMoveNextStatementToSwitchExpression,
    bool generateDeclaration,
    CancellationToken cancellationToken)
{
    // Statements containing preprocessor directives are handed back untouched.
    if (switchStatement.ContainsDirectives)
        return switchStatement;

    var everyArmThrows = nodeToGenerate == SyntaxKind.ThrowStatement;
    var rewriter = new Rewriter(model, isAllThrowStatements: everyArmThrows, cancellationToken);

    // First turn the statement into a switch expression...
    var switchExpression = rewriter.RewriteSwitchStatement(
        switchStatement,
        topLevel: true,
        allowMoveNextStatementToSwitchExpression: shouldMoveNextStatementToSwitchExpression);

    // ...then wrap it in the final statement (e.g. a "return" or an assignment), carrying over the
    // leading trivia of the original 'switch' keyword.
    var originalLeadingTrivia = switchStatement.SwitchKeyword.LeadingTrivia;
    return rewriter.GetFinalStatement(
        switchExpression, originalLeadingTrivia, declaratorToRemoveType, nodeToGenerate, generateDeclaration);
}
/// <summary>
/// Wraps <paramref name="switchExpression"/> in the statement described by <paramref name="nodeToGenerate"/>:
/// a return statement, a throw statement, or otherwise a local declaration or assignment.
/// </summary>
private StatementSyntax GetFinalStatement(
    ExpressionSyntax switchExpression,
    SyntaxTriviaList leadingTrivia,
    ITypeSymbol? declaratorToRemoveType,
    SyntaxKind nodeToGenerate,
    bool generateDeclaration)
{
    if (nodeToGenerate == SyntaxKind.ReturnStatement)
    {
        var returnKeyword = Token(leadingTrivia, SyntaxKind.ReturnKeyword, trailing: default);
        return ReturnStatement(returnKeyword, switchExpression, SemicolonToken);
    }

    if (nodeToGenerate == SyntaxKind.ThrowStatement)
    {
        var throwKeyword = Token(leadingTrivia, SyntaxKind.ThrowKeyword, trailing: default);
        return ThrowStatement(throwKeyword, switchExpression, SemicolonToken);
    }

    // Anything else is expected to be an assignment kind, with the target already captured.
    Debug.Assert(SyntaxFacts.IsAssignmentExpression(nodeToGenerate));
    Debug.Assert(_assignmentTarget != null);

    if (generateDeclaration)
        return GenerateVariableDeclaration(switchExpression, declaratorToRemoveType);

    return GenerateAssignment(switchExpression, nodeToGenerate, leadingTrivia);
}
/// <summary>
/// Builds the expression statement that assigns <paramref name="switchExpression"/> to the captured
/// assignment target using <paramref name="assignmentKind"/>, prefixed with <paramref name="leadingTrivia"/>.
/// </summary>
private ExpressionStatementSyntax GenerateAssignment(ExpressionSyntax switchExpression, SyntaxKind assignmentKind, SyntaxTriviaList leadingTrivia)
{
    Contract.ThrowIfNull(_assignmentTarget);

    var assignment = AssignmentExpression(assignmentKind, _assignmentTarget, switchExpression);
    return ExpressionStatement(assignment).WithLeadingTrivia(leadingTrivia);
}
/// <summary>
/// Builds a local declaration that names the captured assignment target (which must be a simple
/// identifier) and initializes it with <paramref name="switchExpression"/>.
/// </summary>
private StatementSyntax GenerateVariableDeclaration(ExpressionSyntax switchExpression, ITypeSymbol? declaratorToRemoveType)
{
    Contract.ThrowIfFalse(_assignmentTarget is IdentifierNameSyntax);
    var targetName = (IdentifierNameSyntax)_assignmentTarget;

    // 'var' may not be usable when the declared type is a reference type or a nullable type. In those
    // cases spell out the explicit type for now; whether 'var' can be used is decided later.
    var needsExplicitType = declaratorToRemoveType is not null &&
        (declaratorToRemoveType.IsReferenceType || declaratorToRemoveType.IsNullable());
    var declaredType = needsExplicitType ? declaratorToRemoveType!.GenerateTypeSyntax() : IdentifierName("var");

    var declarator = VariableDeclarator(
        identifier: targetName.Identifier,
        argumentList: null,
        initializer: EqualsValueClause(switchExpression));

    return LocalDeclarationStatement(VariableDeclaration(declaredType, variables: [declarator]));
}
/// <summary>
/// Converts one switch section into a switch expression arm: its labels become the pattern (plus an
/// optional when clause) and its statements become the arm's expression.
/// </summary>
private SwitchExpressionArmSyntax GetSwitchExpressionArm(SwitchSectionSyntax node)
{
    var armPattern = GetPattern(node.Labels, out var armWhenClause);
    var armBody = RewriteStatements(node.Statements);
    return SwitchExpressionArm(armPattern, armWhenClause, armBody);
}
/// <summary>
/// Produces the single pattern representing all labels of a switch section. A lone label is converted
/// directly; a group containing a default label collapses to a discard; otherwise the labels are
/// combined left-to-right with 'or' patterns.
/// </summary>
/// <param name="switchLabels">The labels of one switch section.</param>
/// <param name="whenClause">The when clause of a lone label, if it has one; otherwise null.</param>
private static PatternSyntax GetPattern(SyntaxList<SwitchLabelSyntax> switchLabels, out WhenClauseSyntax? whenClause)
{
    if (switchLabels.Count == 1)
        return GetPattern(switchLabels[0], out whenClause);

    // Every multi-label outcome below carries no when clause.
    whenClause = null;

    // A catch-all label anywhere in the group makes the whole group a catch-all: emit a discard '_'.
    if (switchLabels.Any(label => IsDefaultSwitchLabel(label)))
        return DiscardPattern();

    // Several labels and no catch-all: fold them together with 'or' patterns.
    var combined = PatternWithoutWhenClause(switchLabels[0]);
    for (var index = 1; index < switchLabels.Count; index++)
    {
        var next = PatternWithoutWhenClause(switchLabels[index]);
        combined = BinaryPattern(SyntaxKind.OrPattern, combined.Parenthesize(), next.Parenthesize());
    }

    return combined;

    // Converts a single label and asserts that it did not carry a when clause.
    static PatternSyntax PatternWithoutWhenClause(SwitchLabelSyntax label)
    {
        var pattern = GetPattern(label, out var labelWhenClause);
        Debug.Assert(labelWhenClause == null, "We should not have offered to convert multiple cases if any have a when clause");
        return pattern;
    }
}
/// <summary>
/// Converts a single switch label into a pattern: a pattern label yields its own pattern and when
/// clause, a constant case label yields a constant pattern, and a default label yields a discard.
/// </summary>
/// <param name="switchLabel">The label to convert.</param>
/// <param name="whenClause">The label's when clause for pattern labels; null for the other label kinds.</param>
private static PatternSyntax GetPattern(SwitchLabelSyntax switchLabel, out WhenClauseSyntax? whenClause)
{
    switch (switchLabel)
    {
        case CasePatternSwitchLabelSyntax patternLabel:
            whenClause = patternLabel.WhenClause;
            return patternLabel.Pattern;

        case CaseSwitchLabelSyntax constantLabel:
            whenClause = null;
            return ConstantPattern(constantLabel.Value);

        case DefaultSwitchLabelSyntax:
            whenClause = null;
            return DiscardPattern();

        default:
            throw ExceptionUtilities.UnexpectedValue(switchLabel.Kind());
    }
}
/// <summary>
/// Records the left-hand side of the first assignment visited as the assignment target and returns
/// the right-hand side, with an explicit cast added when its conversion requires one.
/// </summary>
public override ExpressionSyntax VisitAssignmentExpression(AssignmentExpressionSyntax node)
{
    // Only the first assignment seen sets the target; later ones leave it unchanged.
    _assignmentTarget ??= node.Left;

    var assignedValue = node.Right;
    return CastIfChangeInRuntimeRepresentation(assignedValue);
}
/// <summary>
/// Returns <paramref name="node"/> wrapped in an explicit cast to its converted type when the original
/// code applied a conversion that is neither an identity nor an implicit reference conversion;
/// otherwise returns <paramref name="node"/> unchanged.
/// </summary>
private ExpressionSyntax CastIfChangeInRuntimeRepresentation(ExpressionSyntax node)
{
    // If the existing return/assign had a conversion involved that changed the runtime representation of
    // the type (or value), then insert that same cast explicitly in the final result. This is needed as
    // switch statements do not use best-common-type, but switch expressions can use it. We don't want the
    // original conversion to be lost because a new best-common-type conversion is added.
    var typeInfo = _semanticModel.GetTypeInfo(node, _cancellationToken);

    // Compare symbols through SymbolEqualityComparer (as AddCastIfNecessary does) rather than object.Equals.
    if (typeInfo.ConvertedType is not null &&
        typeInfo.Type is not null &&
        !SymbolEqualityComparer.Default.Equals(typeInfo.ConvertedType, typeInfo.Type))
    {
        var conversion = _semanticModel.Compilation.ClassifyConversion(typeInfo.Type, typeInfo.ConvertedType);
        if (!conversion.IsIdentityOrImplicitReference())
            return node.Cast(typeInfo.ConvertedType);
    }

    return node;
}
/// <summary>
/// Converts the statements of a switch section into the arm expression by visiting the first
/// statement. The section is expected to hold one or two statements, the first of which is not a break.
/// </summary>
private ExpressionSyntax RewriteStatements(SyntaxList<StatementSyntax> statements)
{
    Debug.Assert(statements.Count is 1 or 2);

    var leadingStatement = statements[0];
    Debug.Assert(!leadingStatement.IsKind(SyntaxKind.BreakStatement));

    var armExpression = Visit(leadingStatement);
    Contract.ThrowIfNull(armExpression);
    return armExpression;
}
/// <summary>
/// Rewrites a switch statement reached through visiting (i.e. not the top-level one) as a switch expression.
/// </summary>
public override ExpressionSyntax VisitSwitchStatement(SwitchStatementSyntax node)
{
    return RewriteSwitchStatement(node, topLevel: false);
}
// Builds the switch expression equivalent to 'node'.
//
// Each switch section becomes one arm. When 'allowMoveNextStatementToSwitchExpression' is set and the
// statement following 'node' is a throw or return statement, that statement is appended as a final
// discard ('_') arm. When 'topLevel' is set, the governing expression is first passed through
// AddCastIfNecessary.
private ExpressionSyntax RewriteSwitchStatement(
SwitchStatementSyntax node,
bool topLevel,
bool allowMoveNextStatementToSwitchExpression = true)
{
// For every section, keep the arm together with the tokens whose trivia will later be attached to it:
// the first/last token of the first label (leading) and of the first statement (trailing).
var switchArms = node.Sections
// The default label must come last in the switch expression.
// (Ordering by a bool key places 'false' before 'true'; OrderBy is a stable sort, so the remaining
// sections keep their original relative order.)
.OrderBy(section => section.Labels.Any(label => IsDefaultSwitchLabel(label)))
.Select(s =>
(tokensForLeadingTrivia: new[] { s.Labels[0].GetFirstToken(), s.Labels[0].GetLastToken() },
tokensForTrailingTrivia: new[] { s.Statements[0].GetFirstToken(), s.Statements[0].GetLastToken() },
armExpression: GetSwitchExpressionArm(s)))
.ToList();
if (allowMoveNextStatementToSwitchExpression)
{
// A throw/return directly after the switch acts as the fall-through result, so it becomes the
// trailing discard arm.
var nextStatement = node.GetNextStatement();
if (nextStatement is (kind: SyntaxKind.ThrowStatement or SyntaxKind.ReturnStatement))
{
var armExpression = Visit(nextStatement);
Contract.ThrowIfNull(armExpression);
switchArms.Add(
(tokensForLeadingTrivia: new[] { nextStatement.GetFirstToken() },
tokensForTrailingTrivia: new[] { nextStatement.GetLastToken() },
SwitchExpressionArm(DiscardPattern(), armExpression)));
}
}
// add explicit cast if necessary
var switchStatement = topLevel ? AddCastIfNecessary(node) : node;
// The 'switch' keyword takes over the trailing trivia of the original close paren, and the open brace
// keeps its own trailing trivia. Each arm receives the trivia collected from its label tokens, and each
// separating comma receives the trivia collected from its statement tokens, both passed through
// FilterComments (presumably keeping only comment trivia — confirm against that helper).
return SwitchExpression(
switchStatement.Expression.Parenthesize(),
Token(leading: default, SyntaxKind.SwitchKeyword, node.CloseParenToken.TrailingTrivia),
Token(leading: default, SyntaxKind.OpenBraceToken, node.OpenBraceToken.TrailingTrivia),
SeparatedList(
switchArms.Select(t => t.armExpression.WithLeadingTrivia(t.tokensForLeadingTrivia.GetTrivia().FilterComments(addElasticMarker: false))),
switchArms.Select(t => CommaToken.WithTrailingTrivia(t.tokensForTrailingTrivia.GetTrivia().FilterComments(addElasticMarker: true)))),
CloseBraceToken);
}
/// <summary>
/// Returns <paramref name="node"/> with its governing expression wrapped in an explicit cast when that
/// expression is being implicitly converted to a different type; otherwise returns
/// <paramref name="node"/> unchanged.
/// </summary>
private SwitchStatementSyntax AddCastIfNecessary(SwitchStatementSyntax node)
{
    // If the switch statement expression is being implicitly converted then we need to explicitly cast the
    // expression before rewriting as a switch expression.
    // The cancellation token is flowed into both semantic queries so this work can be cancelled like the
    // other semantic lookups in this rewriter.
    var expressionType = _semanticModel.GetSymbolInfo(node.Expression, _cancellationToken).Symbol.GetSymbolType();
    var expressionConvertedType = _semanticModel.GetTypeInfo(node.Expression, _cancellationToken).ConvertedType;

    if (expressionConvertedType != null &&
        !SymbolEqualityComparer.Default.Equals(expressionConvertedType, expressionType))
    {
        // Mark the inserted cast for formatting.
        return node.Update(node.SwitchKeyword, node.OpenParenToken,
            node.Expression.Cast(expressionConvertedType).WithAdditionalAnnotations(Formatter.Annotation),
            node.CloseParenToken, node.OpenBraceToken,
            node.Sections, node.CloseBraceToken);
    }

    return node;
}
/// <summary>
/// Returns the expression of a return statement, with an explicit cast added when its conversion
/// requires one. The statement must have an expression.
/// </summary>
public override ExpressionSyntax VisitReturnStatement(ReturnStatementSyntax node)
{
    var returnedValue = node.Expression;
    Contract.ThrowIfNull(returnedValue);
    return CastIfChangeInRuntimeRepresentation(returnedValue);
}
/// <summary>
/// Converts a throw statement into the expression used for its arm. The statement must have an expression.
/// </summary>
public override ExpressionSyntax VisitThrowStatement(ThrowStatementSyntax node)
{
    var thrownValue = node.Expression;
    Contract.ThrowIfNull(thrownValue);

    // In an all-throw switch the bare expression is returned, so that the whole switch expression can be
    // wrapped inside a single throw instead of every arm getting its own throw expression.
    if (_isAllThrowStatements)
        return thrownValue;

    return ThrowExpression(thrownValue);
}
/// <summary>
/// Unwraps an expression statement by visiting its inner expression; the visit must produce a result.
/// </summary>
public override ExpressionSyntax VisitExpressionStatement(ExpressionStatementSyntax node)
{
    var visited = Visit(node.Expression);
    Contract.ThrowIfNull(visited);
    return visited;
}
/// <summary>
/// Fallback for any node kind without a dedicated visit override: always throws, reporting the node's kind.
/// </summary>
public override ExpressionSyntax DefaultVisit(SyntaxNode node)
{
    throw ExceptionUtilities.UnexpectedValue(node.Kind());
}
}
}
|