File: InvertIf\AbstractInvertIfCodeRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.InvertIf;
 
internal abstract partial class AbstractInvertIfCodeRefactoringProvider<
    TSyntaxKind, TStatementSyntax, TIfStatementSyntax, TEmbeddedStatement> : CodeRefactoringProvider
    where TSyntaxKind : struct, Enum
    where TStatementSyntax : SyntaxNode
    where TIfStatementSyntax : TStatementSyntax
{
    /// <summary>
    /// The rewrite strategies this provider can apply to an if-statement. The strategy is selected by
    /// <c>GetInvertIfStyle</c> and carried out by <c>GetRootWithInvertIfStatement</c>.
    /// </summary>
    private enum InvertIfStyle
    {
        // The if-statement has an else clause: negate the condition and swap the two bodies.
        IfWithElse_SwapIfBodyWithElseBody,
        // Case (6) in GetInvertIfStyle: neither the if-body nor the subsequent statements have a
        // reachable end-point, so the two are swapped.
        IfWithoutElse_SwapIfBodyWithSubsequentStatements,
        // Case (3) in GetInvertIfStyle: the if-body is a single jump matching the nearest parent jump
        // kind, so the subsequent statements become the new if-body.
        IfWithoutElse_MoveSubsequentStatementsToIfBody,
        // Case (4) in GetInvertIfStyle: the subsequent statements become the if-body and the original
        // if-body becomes the else clause.
        IfWithoutElse_WithElseClause,
        // Case (7) in GetInvertIfStyle (last resort): an empty if-body is generated and the original
        // if-body is moved into an else clause.
        IfWithoutElse_MoveIfBodyToElseClause,
        // Case (5) in GetInvertIfStyle: the single subsequent jump statement is used as the new if-body
        // and the original body is placed after the if-statement.
        IfWithoutElse_WithSubsequentExitPointStatement,
        // Case (2) in GetInvertIfStyle: nothing meaningful follows the if-statement, so a jump statement
        // of the nearest parent's kind becomes the new if-body.
        IfWithoutElse_WithNearmostJumpStatement,
        // Case (1) in GetInvertIfStyle: the if-body is empty, so only the condition is negated.
        IfWithoutElse_WithNegatedCondition,
    }
 
    /// <summary>
    /// Text passed to <c>CodeAction.Create</c> as both its first and its third argument
    /// (presumably title and equivalence key -- confirm against the CodeAction API).
    /// </summary>
    protected abstract string GetTitle();

    /// <summary>
    /// Returns the statement list held by <paramref name="node"/>. In this file it is called on the
    /// if-statement's parent and on nodes for which <c>IsStatementContainer</c> returned true.
    /// </summary>
    protected abstract SyntaxList<TStatementSyntax> GetStatements(SyntaxNode node);

    /// <summary>
    /// Returns the statement that follows <paramref name="node"/>, or null
    /// (presumably when nothing follows it -- the exact rule is language-specific).
    /// </summary>
    protected abstract TStatementSyntax? GetNextStatement(TStatementSyntax node);

    /// <summary>
    /// Produces a jump statement for <paramref name="kind"/>; used to synthesize the new if-body
    /// for the <c>IfWithoutElse_WithNearmostJumpStatement</c> style.
    /// </summary>
    protected abstract TStatementSyntax GetJumpStatement(TSyntaxKind kind);

    /// <summary>
    /// Returns the jump-statement kind associated with <paramref name="node"/>, or null when there is none.
    /// <c>GetNearestParentJumpStatementKind</c> walks the ancestors until this returns a value.
    /// </summary>
    protected abstract TSyntaxKind? GetJumpStatementKind(SyntaxNode node);

    // Predicates used by the statement-range analysis:
    //  - IsNoOpSyntaxNode: a range counts as empty when every node in it (descendants and self) satisfies this.
    //  - IsExecutableStatement: when true for an ancestor, that ancestor is cast to TStatementSyntax and
    //    becomes the statement whose successor is looked up next.
    //  - IsStatementContainer: when true, GetStatements is called on the node.
    //  - IsSingleStatementStatementRange: whether the range consists of exactly one statement
    //    (as the name suggests; the implementation is not visible here).
    protected abstract bool IsNoOpSyntaxNode(SyntaxNode node);
    protected abstract bool IsExecutableStatement(SyntaxNode node);
    protected abstract bool IsStatementContainer(SyntaxNode node);
    protected abstract bool IsSingleStatementStatementRange(StatementRange statementRange);

    /// <summary>
    /// When this returns false for an ancestor, <c>GetSubsequentStatementRanges</c> stops walking outward,
    /// treating statements beyond that ancestor as out of reach for the analysis.
    /// </summary>
    protected abstract bool CanControlFlowOut(SyntaxNode node);

    /// <summary>Gate checked before the refactoring is offered for <paramref name="ifNode"/>.</summary>
    protected abstract bool CanInvert(TIfStatementSyntax ifNode);

    /// <summary>
    /// When this returns false, the if-statement is handled with the
    /// <c>IfWithElse_SwapIfBodyWithElseBody</c> style; otherwise one of the <c>IfWithoutElse_*</c> styles is chosen.
    /// </summary>
    protected abstract bool IsElseless(TIfStatementSyntax ifNode);

    /// <summary>Returns the statement range covering the body of <paramref name="ifNode"/>.</summary>
    protected abstract StatementRange GetIfBodyStatementRange(TIfStatementSyntax ifNode);

    /// <summary>Returns the condition of <paramref name="ifNode"/>; this is the node that gets negated.</summary>
    protected abstract SyntaxNode GetCondition(TIfStatementSyntax ifNode);

    /// <summary>
    /// Returns the statements contained in <paramref name="ifBody"/> so they can be spliced into the
    /// parent's statement list (presumably unwrapping a block into its inner statements -- confirm per language).
    /// </summary>
    protected abstract IEnumerable<TStatementSyntax> UnwrapBlock(TEmbeddedStatement ifBody);

    // Accessors for the embedded statements of an if-statement, plus a factory for an empty one
    // (used as the true-branch in the IfWithoutElse_MoveIfBodyToElseClause style).
    protected abstract TEmbeddedStatement GetIfBody(TIfStatementSyntax ifNode);
    protected abstract TEmbeddedStatement GetElseBody(TIfStatementSyntax ifNode);
    protected abstract TEmbeddedStatement GetEmptyEmbeddedStatement();

    /// <summary>
    /// Builds an embedded statement out of <paramref name="statements"/>. Callers in this file always pass
    /// the existing if-body as <paramref name="original"/> (presumably as a template for shape/trivia --
    /// confirm against the language-specific implementations).
    /// </summary>
    protected abstract TEmbeddedStatement AsEmbeddedStatement(
        IEnumerable<TStatementSyntax> statements,
        TEmbeddedStatement original);

    /// <summary>
    /// Produces the replacement if-statement from <paramref name="ifNode"/> with the given condition,
    /// true-branch and optional false-branch. How <paramref name="sourceText"/> is used is not visible here.
    /// </summary>
    protected abstract TIfStatementSyntax UpdateIf(
        SourceText sourceText,
        TIfStatementSyntax ifNode,
        SyntaxNode condition,
        TEmbeddedStatement trueStatement,
        TEmbeddedStatement? falseStatement = default);

    /// <summary>
    /// Returns a node based on <paramref name="node"/> whose statement list is <paramref name="statements"/>.
    /// </summary>
    protected abstract SyntaxNode WithStatements(
        SyntaxNode node,
        IEnumerable<TStatementSyntax> statements);
 
    /// <summary>
    /// Registers the invert-if code action when the context's relevant node is an if-statement
    /// that <c>CanInvert</c> accepts. Nothing is registered otherwise.
    /// </summary>
    public sealed override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, _, _) = context;

        var ifNode = await context.TryGetRelevantNodeAsync<TIfStatementSyntax>().ConfigureAwait(false);
        if (ifNode is null || !CanInvert(ifNode))
            return;

        // The same string is supplied as the first and the third argument of CodeAction.Create.
        var title = GetTitle();
        var invertAction = CodeAction.Create(
            title,
            token => InvertIfAsync(document, ifNode, token),
            title);

        context.RegisterRefactoring(invertAction, ifNode.Span);
    }
 
    /// <summary>
    /// Decides which rewrite strategy applies to <paramref name="ifNode"/>.
    /// </summary>
    /// <param name="syntaxKinds">Used to convert the raw kind of the if-body's single exit point into <c>TSyntaxKind</c>.</param>
    /// <param name="ifNode">The if-statement being inverted.</param>
    /// <param name="semanticModel">Used for control-flow analysis of the if-body and the subsequent statements.</param>
    /// <param name="subsequentSingleExitPoint">
    /// Set to null on the early returns (if-with-else, case (1), case (2)); otherwise set to whatever
    /// <c>AnalyzeSubsequentControlFlow</c> reports for the statements following the if-statement.
    /// </param>
    /// <returns>The selected <c>InvertIfStyle</c>; cases are numbered (1)-(7) in the comments below.</returns>
    private InvertIfStyle GetInvertIfStyle(
        ISyntaxKinds syntaxKinds,
        TIfStatementSyntax ifNode,
        SemanticModel semanticModel,
        out SyntaxNode? subsequentSingleExitPoint)
    {
        subsequentSingleExitPoint = null;

        // An if-statement that already has an else clause needs no flow analysis: the bodies are swapped.
        if (!IsElseless(ifNode))
        {
            return InvertIfStyle.IfWithElse_SwapIfBodyWithElseBody;
        }

        var ifBodyStatementRange = GetIfBodyStatementRange(ifNode);
        if (IsEmptyStatementRange(ifBodyStatementRange))
        {
            // (1) An empty if-statement: just negate the condition
            //
            //  if (condition) { }
            //
            // ->
            //
            //  if (!condition) { }
            //
            return InvertIfStyle.IfWithoutElse_WithNegatedCondition;
        }

        var subsequentStatementRanges = GetSubsequentStatementRanges(ifNode);
        if (subsequentStatementRanges.All(IsEmptyStatementRange))
        {
            // (2) No statements after if-statement, invert with the nearest parent jump-statement
            //
            //  void M() {
            //    if (condition) {
            //      Body();
            //    }
            //  }
            //
            // ->
            //
            //  void M() {
            //    if (!condition) {
            //      return;
            //    }
            //    Body();
            //  }
            //
            return InvertIfStyle.IfWithoutElse_WithNearmostJumpStatement;
        }

        // From here on the decision depends on whether control can fall off the end of the if-body
        // and of the statements that follow the if-statement.
        AnalyzeControlFlow(
            semanticModel, ifBodyStatementRange,
            out var ifBodyEndPointIsReachable,
            out var ifBodySingleExitPointOpt);

        AnalyzeSubsequentControlFlow(
            semanticModel, subsequentStatementRanges,
            out var subsequentEndPointIsReachable,
            out subsequentSingleExitPoint);

        if (subsequentEndPointIsReachable)
        {
            if (!ifBodyEndPointIsReachable)
            {
                if (IsSingleStatementStatementRange(ifBodyStatementRange) &&
                    SubsequentStatementsAreInTheSameBlock(ifNode, subsequentStatementRanges) &&
                    ifBodySingleExitPointOpt != null &&
                    GetNearestParentJumpStatementKind(ifNode).Equals(syntaxKinds.Convert<TSyntaxKind>(ifBodySingleExitPointOpt.RawKind)))
                {
                    // (3) Inverse of the case (2). Safe to move all subsequent statements to if-body.
                    //
                    //  while (condition) {
                    //    if (condition) {
                    //      continue;
                    //    }
                    //    f();
                    //  }
                    //
                    // ->
                    //
                    //  while (condition) {
                    //    if (!condition) {
                    //      f();
                    //    }
                    //  }
                    //
                    return InvertIfStyle.IfWithoutElse_MoveSubsequentStatementsToIfBody;
                }
                else
                {
                    // (4) Otherwise, we generate the else and swap blocks to keep flow intact.
                    //
                    //  while (condition) {
                    //    if (condition) {
                    //      return;
                    //    }
                    //    f();
                    //  }
                    //
                    // ->
                    //
                    //  while (condition) {
                    //    if (!condition) {
                    //      f();
                    //    } else {
                    //      return;
                    //    }
                    //  }
                    //
                    return InvertIfStyle.IfWithoutElse_WithElseClause;
                }
            }

            // Both end-points reachable: falls through to the last-resort case (7) below.
        }
        else if (ifBodyEndPointIsReachable)
        {
            if (subsequentSingleExitPoint != null &&
                SingleSubsequentStatement(subsequentStatementRanges))
            {
                // (5) The if-body end-point is reachable but the only subsequent statement is a jump-statement.
                //     This usually happens in a switch-statement. We invert and use that jump-statement.
                //
                //  case constant:
                //    if (condition) {
                //      f();
                //    }
                //    break;
                //
                // ->
                //
                //  case constant:
                //    if (!condition) {
                //      break;
                //    }
                //    f();
                //    break; // we always keep this so that we don't end up with invalid code.
                //
                return InvertIfStyle.IfWithoutElse_WithSubsequentExitPointStatement;
            }
        }
        else if (SubsequentStatementsAreInTheSameBlock(ifNode, subsequentStatementRanges))
        {
            // (6) If both if-body and subsequent statements have an unreachable end-point,
            //     it would be safe to just swap the two.
            //
            //  case constant:
            //    if (condition) {
            //      return;
            //    }
            //    break;
            //
            // ->
            //
            //  case constant:
            //    if (!condition) {
            //      break;
            //    }
            //    return;
            //
            return InvertIfStyle.IfWithoutElse_SwapIfBodyWithSubsequentStatements;
        }

        // (7) If none of the above worked, as the last resort we invert and generate an empty if-body.
        //
        //  {
        //    if (condition) {
        //      f();
        //    }
        //    f();
        //  }
        //
        // ->
        //
        //  {
        //    if (!condition) {
        //    } else {
        //      f();
        //    }
        //    f();
        //  }
        //
        return InvertIfStyle.IfWithoutElse_MoveIfBodyToElseClause;
    }
 
    /// <summary>
    /// True when there is exactly one subsequent statement range and
    /// <c>IsSingleStatementStatementRange</c> accepts it.
    /// </summary>
    private bool SingleSubsequentStatement(ImmutableArray<StatementRange> subsequentStatementRanges)
    {
        if (subsequentStatementRanges.Length != 1)
            return false;

        var onlyRange = subsequentStatementRanges[0];
        return IsSingleStatementStatementRange(onlyRange);
    }
 
    /// <summary>
    /// Performs the inversion: picks the rewrite style for <paramref name="ifNode"/>, negates its
    /// condition through the document's <c>SyntaxGenerator</c>, and returns the document with the rewritten root.
    /// </summary>
    private async Task<Document> InvertIfAsync(
        Document document,
        TIfStatementSyntax ifNode,
        CancellationToken cancellationToken)
    {
        var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var sourceText = await document.GetValueTextAsync(cancellationToken).ConfigureAwait(false);

        // Step 1: decide how the if-statement should be rewritten.
        var style = GetInvertIfStyle(
            document.GetRequiredLanguageService<ISyntaxKindsService>(),
            ifNode,
            semanticModel,
            out var exitPointAfterIf);

        // Step 2: build the negated condition.
        var generator = document.GetRequiredLanguageService<SyntaxGenerator>();
        var negatedCondition = generator.Negate(
            generator.SyntaxGeneratorInternal,
            GetCondition(ifNode),
            semanticModel,
            cancellationToken);

        // Step 3: apply the rewrite to the tree and hand back the updated document.
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var updatedRoot = GetRootWithInvertIfStatement(
            sourceText,
            root,
            ifNode,
            style,
            exitPointAfterIf,
            negatedCondition,
            syntaxFacts);

        return document.WithSyntaxRoot(updatedRoot);
    }
 
    /// <summary>
    /// Runs <c>AnalyzeControlFlow</c> over the subsequent statement ranges in order, stopping after the
    /// first range whose end-point is not reachable. The out values are those of the last range analyzed;
    /// with no ranges they are <c>true</c> and <c>null</c>.
    /// </summary>
    private static void AnalyzeSubsequentControlFlow(
        SemanticModel semanticModel,
        ImmutableArray<StatementRange> subsequentStatementRanges,
        out bool subsequentEndPointIsReachable,
        out SyntaxNode? subsequentSingleExitPoint)
    {
        subsequentEndPointIsReachable = true;
        subsequentSingleExitPoint = null;

        // Keep going only while control can still fall through to the next range.
        for (var rangeIndex = 0;
             rangeIndex < subsequentStatementRanges.Length && subsequentEndPointIsReachable;
             rangeIndex++)
        {
            AnalyzeControlFlow(
                semanticModel,
                subsequentStatementRanges[rangeIndex],
                out subsequentEndPointIsReachable,
                out subsequentSingleExitPoint);
        }
    }
 
    /// <summary>
    /// Analyzes control flow across <paramref name="statementRange"/> (first through last statement).
    /// </summary>
    /// <param name="endPointIsReachable">Receives the analysis' <c>EndPointIsReachable</c> value.</param>
    /// <param name="singleExitPoint">Receives the exit point when the analysis reports exactly one; otherwise null.</param>
    private static void AnalyzeControlFlow(
        SemanticModel semanticModel,
        StatementRange statementRange,
        out bool endPointIsReachable,
        out SyntaxNode? singleExitPoint)
    {
        var analysis = semanticModel.AnalyzeControlFlow(
            statementRange.FirstStatement,
            statementRange.LastStatement);

        endPointIsReachable = analysis.EndPointIsReachable;

        var exitPoints = analysis.ExitPoints;
        if (exitPoints.Length == 1)
        {
            singleExitPoint = exitPoints[0];
        }
        else
        {
            singleExitPoint = null;
        }
    }
 
    /// <summary>
    /// True when there is exactly one subsequent statement range and its parent is the same node
    /// as the parent of <paramref name="ifNode"/>.
    /// </summary>
    private static bool SubsequentStatementsAreInTheSameBlock(
        TIfStatementSyntax ifNode,
        ImmutableArray<StatementRange> subsequentStatementRanges)
    {
        if (subsequentStatementRanges.Length != 1)
            return false;

        var onlyRange = subsequentStatementRanges[0];
        return ifNode.Parent == onlyRange.Parent;
    }
 
    /// <summary>
    /// Walks the ancestors of <paramref name="ifNode"/> from the innermost outward and returns the first
    /// non-null result of <c>GetJumpStatementKind</c>. Throws via <c>ExceptionUtilities.Unreachable</c>
    /// when no ancestor yields a kind.
    /// </summary>
    private TSyntaxKind GetNearestParentJumpStatementKind(SyntaxNode ifNode)
    {
        foreach (var ancestor in ifNode.Ancestors())
        {
            if (GetJumpStatementKind(ancestor) is { } jumpKind)
                return jumpKind;
        }

        throw ExceptionUtilities.Unreachable();
    }
 
    /// <summary>
    /// Determines whether <paramref name="statementRange"/> contains nothing but no-op syntax: either the
    /// range reports itself as empty, or every node (descendants and self) of every statement in it
    /// satisfies <c>IsNoOpSyntaxNode</c>.
    /// </summary>
    private bool IsEmptyStatementRange(StatementRange statementRange)
    {
        if (statementRange.IsEmpty)
            return true;

        var container = statementRange.Parent;
        if (IsStatementContainer(container))
        {
            // Inspect every statement between the first and last statement of the range, inclusive.
            var siblings = GetStatements(container);
            var startIndex = siblings.IndexOf(statementRange.FirstStatement);
            var endIndex = siblings.IndexOf(statementRange.LastStatement);

            var position = startIndex;
            while (position <= endIndex)
            {
                var statementIsNoOp = siblings[position].DescendantNodesAndSelf().All(IsNoOpSyntaxNode);
                if (!statementIsNoOp)
                    return false;

                position++;
            }

            return true;
        }

        // The parent is not a statement container, so the range is expected to be a single statement.
        Debug.Assert(statementRange.FirstStatement == statementRange.LastStatement);
        return statementRange.FirstStatement.DescendantNodesAndSelf().All(IsNoOpSyntaxNode);
    }
 
    /// <summary>
    /// Collects the statement ranges that follow <paramref name="ifNode"/>, walking outward through its
    /// ancestors. For each ancestor that is a statement container and where the current inner statement
    /// has a next statement, a range from that next statement to the container's last statement is added.
    /// The walk stops at the first ancestor for which <c>CanControlFlowOut</c> is false.
    /// </summary>
    private ImmutableArray<StatementRange> GetSubsequentStatementRanges(TIfStatementSyntax ifNode)
    {
        using var _ = ArrayBuilder<StatementRange>.GetInstance(out var ranges);

        // The statement whose successor we look for; starts at the if-statement and is replaced by
        // each enclosing executable statement as we move outward.
        TStatementSyntax currentStatement = ifNode;
        foreach (var ancestor in ifNode.Ancestors())
        {
            if (GetNextStatement(currentStatement) is { } following && IsStatementContainer(ancestor))
            {
                var lastInContainer = GetStatements(ancestor).Last();
                ranges.Add(new StatementRange(following, lastInContainer));
            }

            // Statements beyond this ancestor are out of reach as far as this analysis is concerned.
            if (!CanControlFlowOut(ancestor))
                break;

            if (IsExecutableStatement(ancestor))
                currentStatement = (TStatementSyntax)ancestor;
        }

        return ranges.ToImmutableAndClear();
    }
 
    /// <summary>
    /// Applies the rewrite selected by <paramref name="invertIfStyle"/> and returns the resulting root.
    /// </summary>
    /// <param name="text">Forwarded unchanged to <c>UpdateIf</c>.</param>
    /// <param name="root">The root in which the if-statement (or its parent) is replaced.</param>
    /// <param name="ifNode">The if-statement being inverted.</param>
    /// <param name="invertIfStyle">The rewrite strategy to apply.</param>
    /// <param name="subsequentSingleExitPoint">
    /// Only read for <c>IfWithoutElse_WithSubsequentExitPointStatement</c>, where it is cast to a statement
    /// and used as the new if-body.
    /// </param>
    /// <param name="negatedExpression">The already-negated condition to put on the rewritten if-statement.</param>
    /// <param name="syntaxFacts">Used by the local trivia helpers to recognize whitespace/end-of-line trivia.</param>
    /// <returns>
    /// A new root. Styles that only touch the if-statement replace <paramref name="ifNode"/> directly;
    /// styles that move statements rebuild the if-statement's parent and tag it with <c>Formatter.Annotation</c>.
    /// </returns>
    private SyntaxNode GetRootWithInvertIfStatement(
        SourceText text,
        SyntaxNode root,
        TIfStatementSyntax ifNode,
        InvertIfStyle invertIfStyle,
        SyntaxNode? subsequentSingleExitPoint,
        SyntaxNode negatedExpression,
        ISyntaxFacts syntaxFacts)
    {
        switch (invertIfStyle)
        {
            // Negate the condition and exchange the if-body with the else-body.
            case InvertIfStyle.IfWithElse_SwapIfBodyWithElseBody:
                {
                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: GetElseBody(ifNode)!,
                        falseStatement: GetIfBody(ifNode));

                    return root.ReplaceNode(ifNode, updatedIf);
                }

            // Negate the condition, use an empty true-branch, and move the original body into the else.
            case InvertIfStyle.IfWithoutElse_MoveIfBodyToElseClause:
                {
                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: GetEmptyEmbeddedStatement(),
                        falseStatement: GetIfBody(ifNode));

                    return root.ReplaceNode(ifNode, updatedIf);
                }

            // Only the condition changes; the body is kept as-is.
            case InvertIfStyle.IfWithoutElse_WithNegatedCondition:
                {
                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: GetIfBody(ifNode));

                    return root.ReplaceNode(ifNode, updatedIf);
                }

            // The statements after the if-statement become the new if-body, and the statements of the
            // original if-body are placed after the if-statement.
            case InvertIfStyle.IfWithoutElse_SwapIfBodyWithSubsequentStatements:
                {
                    var currentParent = ifNode.GetRequiredParent();
                    var statements = GetStatements(currentParent);
                    var index = statements.IndexOf(ifNode);

                    var statementsBeforeIf = statements.Take(index);
                    var statementsAfterIf = statements.Skip(index + 1).ToImmutableArray();

                    var ifBody = GetIfBody(ifNode);

                    var newTrailing = UnwrapBlock(ifBody).ToArray();

                    if (newTrailing.Length > 0)
                    {
                        // Get leading and trailing space of the expressions to preserve for the user
                        // ex:
                        // if (true)
                        // {
                        //    return true;
                        // }
                        //              // <<< preserve this line
                        // // preserve this comment
                        // return false;
                        //
                        // Leading: whitespace that preceded the first statement after the if, followed by the
                        // non-whitespace leading trivia of the first statement from the old if-body.
                        // Trailing: the old body's last statement's trailing trivia up to its final run of
                        // whitespace, followed by the trailing whitespace of the last statement after the if.
                        var leadingTrivia = GetLeadingSpace(statementsAfterIf[0].GetLeadingTrivia()).Concat(GetTriviaAfterSpace(newTrailing[0].GetLeadingTrivia()));
                        var trailingTrivia = GetTriviaUntilSpace(newTrailing[^1].GetTrailingTrivia()).Concat(GetTrailingSpace(statementsAfterIf[^1].GetTrailingTrivia()));
                        newTrailing[0] = newTrailing[0].WithLeadingTrivia(leadingTrivia);
                        newTrailing[^1] = newTrailing[^1].WithTrailingTrivia(trailingTrivia);
                    }

                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: AsEmbeddedStatement(statementsAfterIf, original: ifBody));

                    var updatedParent = WithStatements(
                        currentParent,
                        statementsBeforeIf.Concat(updatedIf).Concat(newTrailing));

                    return root.ReplaceNode(currentParent, updatedParent.WithAdditionalAnnotations(Formatter.Annotation));
                }

            // The new if-body is a jump statement of the nearest parent's jump kind; the statements of the
            // original if-body follow the if-statement. Statements that followed the if-statement in the
            // parent are not carried over (GetInvertIfStyle selects this style only when every subsequent
            // statement range is empty).
            case InvertIfStyle.IfWithoutElse_WithNearmostJumpStatement:
                {
                    var currentParent = ifNode.GetRequiredParent();
                    var statements = GetStatements(currentParent);
                    var index = statements.IndexOf(ifNode);

                    var ifBody = GetIfBody(ifNode);
                    var newIfBody = GetJumpStatement(GetNearestParentJumpStatementKind(ifNode));

                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: AsEmbeddedStatement([newIfBody], original: ifBody));

                    var statementsBeforeIf = statements.Take(index);

                    var updatedParent = WithStatements(
                        currentParent,
                        statementsBeforeIf.Concat(updatedIf).Concat(UnwrapBlock(ifBody)));

                    return root.ReplaceNode(currentParent, updatedParent.WithAdditionalAnnotations(Formatter.Annotation));
                }

            // The subsequent exit-point statement becomes the new if-body, and is also appended again
            // after the statements of the original if-body.
            case InvertIfStyle.IfWithoutElse_WithSubsequentExitPointStatement:
                {
                    Debug.Assert(subsequentSingleExitPoint is TStatementSyntax);

                    var currentParent = ifNode.GetRequiredParent();
                    var statements = GetStatements(currentParent);
                    var index = statements.IndexOf(ifNode);

                    var ifBody = GetIfBody(ifNode);
                    var newIfBody = (TStatementSyntax)subsequentSingleExitPoint;

                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: AsEmbeddedStatement([newIfBody], ifBody));

                    var statementsBeforeIf = statements.Take(index);

                    var updatedParent = WithStatements(
                        currentParent,
                        statementsBeforeIf.Concat(updatedIf).Concat(UnwrapBlock(ifBody)).Concat(newIfBody));

                    return root.ReplaceNode(currentParent, updatedParent.WithAdditionalAnnotations(Formatter.Annotation));
                }

            // The statements after the if-statement become the new if-body; the original if-body is not
            // re-emitted, and the parent ends with the rewritten if-statement.
            case InvertIfStyle.IfWithoutElse_MoveSubsequentStatementsToIfBody:
                {
                    var currentParent = ifNode.GetRequiredParent();
                    var statements = GetStatements(currentParent);
                    var index = statements.IndexOf(ifNode);

                    var statementsBeforeIf = statements.Take(index);
                    var statementsAfterIf = statements.Skip(index + 1);
                    var ifBody = GetIfBody(ifNode);

                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: AsEmbeddedStatement(statementsAfterIf, ifBody));

                    var updatedParent = WithStatements(
                        currentParent,
                        statementsBeforeIf.Concat(updatedIf));

                    return root.ReplaceNode(currentParent, updatedParent.WithAdditionalAnnotations(Formatter.Annotation));
                }

            // The statements after the if-statement become the true-branch and the original if-body
            // becomes the else-branch; the parent ends with the rewritten if-statement.
            case InvertIfStyle.IfWithoutElse_WithElseClause:
                {
                    var currentParent = ifNode.GetRequiredParent();
                    var statements = GetStatements(currentParent);
                    var index = statements.IndexOf(ifNode);

                    var statementsBeforeIf = statements.Take(index);
                    var statementsAfterIf = statements.Skip(index + 1);

                    var ifBody = GetIfBody(ifNode);

                    var updatedIf = UpdateIf(
                        text,
                        ifNode: ifNode,
                        condition: negatedExpression,
                        trueStatement: AsEmbeddedStatement(statementsAfterIf, ifBody),
                        falseStatement: ifBody);

                    var updatedParent = WithStatements(
                        currentParent,
                        statementsBeforeIf.Concat(updatedIf));

                    return root.ReplaceNode(currentParent, updatedParent.WithAdditionalAnnotations(Formatter.Annotation));
                }

            default:
                throw ExceptionUtilities.UnexpectedValue(invertIfStyle);
        }

        //
        // local functions
        //

        // Drops the leading run of whitespace/end-of-line trivia and returns the rest.
        IEnumerable<SyntaxTrivia> GetTriviaAfterSpace(IEnumerable<SyntaxTrivia> syntaxTrivias)
        {
            return syntaxTrivias.SkipWhile(syntaxFacts.IsWhitespaceOrEndOfLineTrivia);
        }

        // Drops the trailing run of whitespace/end-of-line trivia and returns what precedes it.
        IEnumerable<SyntaxTrivia> GetTriviaUntilSpace(IEnumerable<SyntaxTrivia> syntaxTrivias)
        {
            return GetTriviaAfterSpace(syntaxTrivias.Reverse()).Reverse();
        }

        // Returns only the trailing run of whitespace/end-of-line trivia, in original order.
        IEnumerable<SyntaxTrivia> GetTrailingSpace(IEnumerable<SyntaxTrivia> syntaxTrivias)
        {
            return GetLeadingSpace(syntaxTrivias.Reverse()).Reverse();
        }

        // Returns only the leading run of whitespace/end-of-line trivia.
        IEnumerable<SyntaxTrivia> GetLeadingSpace(IEnumerable<SyntaxTrivia> syntaxTrivias)
        {
            return syntaxTrivias.TakeWhile(syntaxFacts.IsWhitespaceOrEndOfLineTrivia);
        }
    }
}