// File: ConvertLinq\ConvertForEachToLinqQuery\CSharpConvertForEachToLinqQueryProvider.cs
// Web Access
// Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.ConvertLinq.ConvertForEachToLinqQuery;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
using SyntaxNodeOrTokenExtensions = Microsoft.CodeAnalysis.Shared.Extensions.SyntaxNodeOrTokenExtensions;
 
namespace Microsoft.CodeAnalysis.CSharp.ConvertLinq.ConvertForEachToLinqQuery;
 
[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = PredefinedCodeRefactoringProviderNames.ConvertForEachToLinqQuery), Shared]
internal sealed class CSharpConvertForEachToLinqQueryProvider
    : AbstractConvertForEachToLinqQueryProvider<ForEachStatementSyntax, StatementSyntax>
{
    /// <summary>
    /// Parameterless constructor marked as the importing constructor; it performs no initialization.
    /// </summary>
    [ImportingConstructor]
    [SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
    public CSharpConvertForEachToLinqQueryProvider()
    {
    }
 
    /// <summary>
    /// Creates the fallback converter, a <see cref="DefaultConverter"/> built from <paramref name="forEachInfo"/>.
    /// </summary>
    /// <param name="forEachInfo">The analysis result handed to the new converter.</param>
    /// <returns>A new <see cref="DefaultConverter"/> wrapping <paramref name="forEachInfo"/>.</returns>
    protected override IConverter<ForEachStatementSyntax, StatementSyntax> CreateDefaultConverter(
        ForEachInfo<ForEachStatementSyntax, StatementSyntax> forEachInfo)
    {
        return new DefaultConverter(forEachInfo);
    }
 
    /// <summary>
    /// Walks the body of <paramref name="forEachStatement"/> from the outside in and splits it into the nodes
    /// that are collected for conversion (nested foreach statements, else-less if statements and, optionally,
    /// initialized local declarations) and the statements at which the traversal had to stop.
    /// </summary>
    /// <param name="forEachStatement">The outermost foreach statement whose body is analyzed.</param>
    /// <param name="semanticModel">Passed through unchanged to the resulting info object.</param>
    /// <param name="convertLocalDeclarations">
    /// When <see langword="false"/>, local declaration statements are never collected for conversion; the
    /// traversal stops at the first one and reports it among the statements that cannot be converted.
    /// </param>
    /// <returns>
    /// An info object carrying the collected nodes, the identifiers introduced along the way (starting with the
    /// identifier of <paramref name="forEachStatement"/>), the statements that cannot be converted, and the
    /// brace tokens gathered from the blocks that were entered.
    /// </returns>
    protected override ForEachInfo<ForEachStatementSyntax, StatementSyntax> CreateForEachInfo(
        ForEachStatementSyntax forEachStatement,
        SemanticModel semanticModel,
        bool convertLocalDeclarations)
    {
        var identifiersBuilder = ArrayBuilder<SyntaxToken>.GetInstance();
        identifiersBuilder.Add(forEachStatement.Identifier);
        var convertingNodesBuilder = ArrayBuilder<ExtendedSyntaxNode>.GetInstance();
        // Stays null while the traversal is still making progress; any non-null value (even empty) ends the loop.
        IEnumerable<StatementSyntax>? statementsCannotBeConverted = null;
        var trailingTokensBuilder = ArrayBuilder<SyntaxToken>.GetInstance();
        // Open-brace tokens seen since the last collected node; they are attached to the next node collected.
        var currentLeadingTokens = ArrayBuilder<SyntaxToken>.GetInstance();

        var current = forEachStatement.Statement;
        // Traverse descendants of the forEachStatement.
        // If a statement traversed can be converted into a query clause,
        //  a. Add it to convertingNodesBuilder.
        //  b. Set the current to its nested statement and proceed.
        // Otherwise, set statementsCannotBeConverted and stop processing.
        while (statementsCannotBeConverted == null)
        {
            switch (current.Kind())
            {
                case SyntaxKind.Block:
                    var block = (BlockSyntax)current;
                    // Keep comment trivia from braces to attach them to the query created.
                    currentLeadingTokens.Add(block.OpenBraceToken);
                    trailingTokensBuilder.Add(block.CloseBraceToken);
                    var array = block.Statements.ToArray();
                    if (array.Length > 0)
                    {
                        // All except the last one can be local declaration statements like
                        // {
                        //   var a = 0;
                        //   var b = 0;
                        //   if (x != y) <- this is the last one in the block.
                        // We can support it to be a complex foreach or if or whatever. So, set the current to it.
                        //   ...
                        // }
                        for (var i = 0; i < array.Length - 1; i++)
                        {
                            var statement = array[i];
                            if (!(statement is LocalDeclarationStatementSyntax localDeclarationStatement &&
                                TryProcessLocalDeclarationStatement(localDeclarationStatement)))
                            {
                                // This statement is not a local declaration that could be processed
                                // (see TryProcessLocalDeclarationStatement below). Report it and everything
                                // after it in this block as not convertible and stop processing.
                                statementsCannotBeConverted = array.Skip(i).ToArray();
                                break;
                            }
                        }

                        // Process the last statement separately.
                        // If the loop above bailed out, the while condition ends the traversal before
                        // this value is looked at.
                        current = array.Last();
                    }
                    else
                    {
                        // Silly case: the block is empty. Stop processing.
                        statementsCannotBeConverted = [];
                    }

                    break;

                case SyntaxKind.ForEachStatement:
                    // A nested foreach is always collected for conversion (into a from clause).
                    var currentForEachStatement = (ForEachStatementSyntax)current;
                    identifiersBuilder.Add(currentForEachStatement.Identifier);
                    convertingNodesBuilder.Add(new ExtendedSyntaxNode(currentForEachStatement, currentLeadingTokens.ToImmutableAndFree(), []));
                    // The previous builder was handed off by ToImmutableAndFree; start a new one for the next node.
                    currentLeadingTokens = ArrayBuilder<SyntaxToken>.GetInstance();
                    // Proceed the loop with the nested statement.
                    current = currentForEachStatement.Statement;
                    break;

                case SyntaxKind.IfStatement:
                    // Prepare conversion of 'if (condition)' into where clauses.
                    // Do not support if-else statements in the conversion.
                    var ifStatement = (IfStatementSyntax)current;
                    if (ifStatement.Else == null)
                    {
                        convertingNodesBuilder.Add(new ExtendedSyntaxNode(
                            ifStatement, currentLeadingTokens.ToImmutableAndFree(), []));
                        // The previous builder was handed off by ToImmutableAndFree; start a new one for the next node.
                        currentLeadingTokens = ArrayBuilder<SyntaxToken>.GetInstance();
                        // Proceed the loop with the nested statement.
                        current = ifStatement.Statement;
                        break;
                    }
                    else
                    {
                        // An if with an else branch is reported as not convertible.
                        statementsCannotBeConverted = [current];
                        break;
                    }

                case SyntaxKind.LocalDeclarationStatement:
                    // The innermost statement inside the loop is a declaration such as "var a = something;".
                    var localDeclaration = (LocalDeclarationStatementSyntax)current;
                    if (TryProcessLocalDeclarationStatement(localDeclaration))
                    {
                        // Everything was collected; nothing is left over.
                        statementsCannotBeConverted = [];
                    }
                    else
                    {
                        // The declaration could not be processed (conversion of declarations is disabled or
                        // a variable has no initializer), stop processing.
                        statementsCannotBeConverted = [current];
                    }

                    break;

                case SyntaxKind.EmptyStatement:
                    // The innermost statement is an empty statement, stop processing
                    // Example:
                    // foreach(...)
                    // {
                    //    ;<- empty statement
                    // }
                    statementsCannotBeConverted = [];
                    break;

                default:
                    // If no specific case found, stop processing.
                    statementsCannotBeConverted = [current];
                    break;
            }
        }

        // Close-brace tokens were added while descending, i.e. outermost block first. Reverse them so the
        // innermost block's token comes first.
        trailingTokensBuilder.ReverseContents();

        return new ForEachInfo<ForEachStatementSyntax, StatementSyntax>(
            forEachStatement,
            semanticModel,
            convertingNodesBuilder.ToImmutableAndFree(),
            identifiersBuilder.ToImmutableAndFree(),
            [.. statementsCannotBeConverted],
            currentLeadingTokens.ToImmutableAndFree(),
            trailingTokensBuilder.ToImmutableAndFree());

        // Try to prepare variable declarations to be converted into separate let clauses.
        // Returns true when every declarator was added to convertingNodesBuilder; returns false, without
        // touching any builder, when conversion of declarations is disabled or a declarator has no initializer.
        bool TryProcessLocalDeclarationStatement(LocalDeclarationStatementSyntax localDeclarationStatement)
        {
            if (!convertLocalDeclarations)
            {
                return false;
            }

            // Do not support declarations without initialization.
            // int a = 0, b, c = 0;
            if (localDeclarationStatement.Declaration.Variables.All(variable => variable.Initializer != null))
            {
                // Leading trivia for the first declarator: trivia of the pending brace tokens followed by the
                // trivia around the declared type.
                var localDeclarationLeadingTrivia = new IEnumerable<SyntaxTrivia>[] {
                currentLeadingTokens.ToImmutableAndFree().GetTrivia(),
                localDeclarationStatement.Declaration.Type.GetLeadingTrivia(),
                localDeclarationStatement.Declaration.Type.GetTrailingTrivia() }.Flatten();
                // The previous builder was handed off by ToImmutableAndFree; start a new one for the next node.
                currentLeadingTokens = ArrayBuilder<SyntaxToken>.GetInstance();
                var localDeclarationTrailingTrivia = SyntaxNodeOrTokenExtensions.GetTrivia(localDeclarationStatement.SemicolonToken);
                var separators = localDeclarationStatement.Declaration.Variables.GetSeparators().ToArray();
                for (var i = 0; i < localDeclarationStatement.Declaration.Variables.Count; i++)
                {
                    var variable = localDeclarationStatement.Declaration.Variables[i];
                    // Each declarator becomes its own node. The trivia of the separator before it (after the
                    // comma) leads it and the trivia of the separator after it (before the comma) trails it;
                    // the first and last declarators use the statement-level trivia computed above instead.
                    convertingNodesBuilder.Add(new ExtendedSyntaxNode(
                        variable,
                        i == 0 ? localDeclarationLeadingTrivia : separators[i - 1].TrailingTrivia,
                        i == localDeclarationStatement.Declaration.Variables.Count - 1
                            ? localDeclarationTrailingTrivia
                            : separators[i].LeadingTrivia));
                    identifiersBuilder.Add(variable.Identifier);
                }

                return true;
            }

            return false;
        }
    }
 
    /// <summary>
    /// Tries to choose a specialized converter based on the shape of <paramref name="statementCannotBeConverted"/>.
    /// Three shapes are recognized: a post-increment expression statement (<c>c++;</c>), a single-argument
    /// <c>Add</c> call on a <c>List&lt;T&gt;</c>, and a <c>yield return</c> whose foreach is the last statement of
    /// the enclosing member (optionally followed by a single <c>yield break</c>).
    /// </summary>
    /// <param name="forEachInfo">Analysis result for the foreach statement being converted.</param>
    /// <param name="semanticModel">Used to bind the invoked method and to find enclosing symbols.</param>
    /// <param name="statementCannotBeConverted">The statement to inspect.</param>
    /// <param name="cancellationToken">Flowed into every semantic query made here.</param>
    /// <param name="converter">The specialized converter on success; otherwise <see langword="null"/>.</param>
    /// <returns><see langword="true"/> when a specialized converter was created.</returns>
    protected override bool TryBuildSpecificConverter(
        ForEachInfo<ForEachStatementSyntax, StatementSyntax> forEachInfo,
        SemanticModel semanticModel,
        StatementSyntax statementCannotBeConverted,
        CancellationToken cancellationToken,
        [NotNullWhen(true)] out IConverter<ForEachStatementSyntax, StatementSyntax>? converter)
    {
        switch (statementCannotBeConverted.Kind())
        {
            case SyntaxKind.ExpressionStatement:
                var expressionStatement = (ExpressionStatementSyntax)statementCannotBeConverted;
                var expression = expressionStatement.Expression;
                switch (expression.Kind())
                {
                    case SyntaxKind.PostIncrementExpression:
                        // Input:
                        // foreach (var x in a)
                        // {
                        //     ...
                        //     c++;
                        // }
                        // Output:
                        // (from x in a ... select x).Count();
                        // Here we put SyntaxFactory.IdentifierName(forEachStatement.Identifier) ('x' in the example)
                        // into the select clause.
                        var postfixUnaryExpression = (PostfixUnaryExpressionSyntax)expression;
                        var operand = postfixUnaryExpression.Operand;
                        converter = new ToCountConverter(
                            forEachInfo,
                            selectExpression: SyntaxFactory.IdentifierName(forEachInfo.ForEachStatement.Identifier),
                            modifyingExpression: operand,
                            trivia: SyntaxNodeOrTokenExtensions.GetTrivia(
                                operand, postfixUnaryExpression.OperatorToken, expressionStatement.SemicolonToken));
                        return true;

                    case SyntaxKind.InvocationExpression:
                        var invocationExpression = (InvocationExpressionSyntax)expression;
                        // Check that there is 'list.Add(item)'.
                        if (invocationExpression.Expression is MemberAccessExpressionSyntax memberAccessExpression &&
                            semanticModel.GetSymbolInfo(memberAccessExpression, cancellationToken).Symbol is IMethodSymbol methodSymbol &&
                            TypeSymbolIsList(methodSymbol.ContainingType, semanticModel) &&
                            methodSymbol.Name == nameof(IList.Add) &&
                            methodSymbol.Parameters.Length == 1 &&
                            invocationExpression.ArgumentList.Arguments.Count == 1)
                        {
                            // Input:
                            // foreach (var x in a)
                            // {
                            //     ...
                            //     list.Add(...);
                            // }
                            // Output:
                            // (from x in a ... select x).ToList();
                            var selectExpression = invocationExpression.ArgumentList.Arguments.Single().Expression;
                            converter = new ToToListConverter(
                                forEachInfo,
                                selectExpression,
                                modifyingExpression: memberAccessExpression.Expression,
                                trivia: SyntaxNodeOrTokenExtensions.GetTrivia(
                                    memberAccessExpression,
                                    invocationExpression.ArgumentList.OpenParenToken,
                                    invocationExpression.ArgumentList.CloseParenToken,
                                    expressionStatement.SemicolonToken));
                            return true;
                        }

                        break;
                }

                break;

            case SyntaxKind.YieldReturnStatement:
                var memberDeclarationSymbol = semanticModel.GetEnclosingSymbol(
                    forEachInfo.ForEachStatement.SpanStart, cancellationToken)!;

                // Using Single() is valid even for partial methods.
                var memberDeclarationSyntax = memberDeclarationSymbol.DeclaringSyntaxReferences.Single().GetSyntax(cancellationToken);

                // The foreach must sit directly in the body block of the enclosing member.
                if (forEachInfo.ForEachStatement.Parent is BlockSyntax block &&
                    block.Parent == memberDeclarationSyntax)
                {
                    // Counting yields walks the whole member and queries the semantic model per yield,
                    // so it is only done once the cheap structural check above has passed.
                    var yieldStatementsCount = memberDeclarationSyntax.DescendantNodes().OfType<YieldStatementSyntax>()
                        // Exclude yield statements from nested local functions.
                        .Count(statement => Equals(
                            semanticModel.GetEnclosingSymbol(statement.SpanStart, cancellationToken),
                            memberDeclarationSymbol));

                    // Check that
                    // a. There are either just a single 'yield return' or 'yield return' with 'yield break' just after.
                    // b. Those foreach and 'yield break' (if exists) are last statements in the method (do not count local function declaration statements).
                    var statementsOnBlockWithForEach = block.Statements
                        .Where(statement => statement.Kind() != SyntaxKind.LocalFunctionStatement).ToArray();
                    var lastNonLocalFunctionStatement = statementsOnBlockWithForEach.Last();
                    if (yieldStatementsCount == 1 && lastNonLocalFunctionStatement == forEachInfo.ForEachStatement)
                    {
                        converter = new YieldReturnConverter(
                            forEachInfo,
                            (YieldStatementSyntax)statementCannotBeConverted,
                            yieldBreakStatement: null);
                        return true;
                    }

                    // foreach()
                    // {
                    //    yield return ...;
                    // }
                    // yield break;
                    // end of member
                    if (yieldStatementsCount == 2 &&
                        lastNonLocalFunctionStatement.Kind() == SyntaxKind.YieldBreakStatement &&
                        !lastNonLocalFunctionStatement.ContainsDirectives &&
                        statementsOnBlockWithForEach[statementsOnBlockWithForEach.Length - 2] == forEachInfo.ForEachStatement)
                    {
                        // This removes the yield break.
                        converter = new YieldReturnConverter(
                            forEachInfo,
                            (YieldStatementSyntax)statementCannotBeConverted,
                            yieldBreakStatement: (YieldStatementSyntax)lastNonLocalFunctionStatement);
                        return true;
                    }
                }

                break;
        }

        converter = null;
        return false;
    }
 
    /// <summary>
    /// Adds <c>using System.Linq;</c> to <paramref name="root"/> when the <c>System.Linq</c> namespace is not
    /// already among the using namespaces in scope at the converted foreach statement.
    /// </summary>
    /// <param name="converter">Supplies the foreach statement whose scope is inspected.</param>
    /// <param name="semanticModel">Used to enumerate the using namespaces in scope.</param>
    /// <param name="root">The syntax root to update.</param>
    /// <returns>
    /// The root with the directive added, or <paramref name="root"/> unchanged when the namespace is already
    /// in scope or the root is not a <see cref="CompilationUnitSyntax"/>.
    /// </returns>
    protected override SyntaxNode AddLinqUsing(
        IConverter<ForEachStatementSyntax, StatementSyntax> converter,
        SemanticModel semanticModel,
        SyntaxNode root)
    {
        // The directive can only be attached to a compilation unit.
        if (root is not CompilationUnitSyntax compilationUnit)
        {
            return root;
        }

        var namespaces = semanticModel.GetUsingNamespacesInScope(converter.ForEachInfo.ForEachStatement);

        // Match exactly the top-level 'System.Linq'. Comparing only the last two name segments would also
        // accept a namespace such as 'Foo.System.Linq' and wrongly skip adding the real directive.
        var hasSystemLinq = namespaces.Any(namespaceSymbol =>
            namespaceSymbol.Name == nameof(System.Linq) &&
            namespaceSymbol.ContainingNamespace?.Name == nameof(System) &&
            namespaceSymbol.ContainingNamespace.ContainingNamespace?.IsGlobalNamespace == true);

        if (hasSystemLinq)
        {
            return root;
        }

        return compilationUnit.AddUsings(SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("System.Linq")));
    }
 
    /// <summary>
    /// Determines whether the original definition of <paramref name="typeSymbol"/> equals the type returned by
    /// <c>ListOfTType()</c> for the compilation of <paramref name="semanticModel"/>.
    /// </summary>
    /// <param name="typeSymbol">The type to test; a null value is compared as null.</param>
    /// <param name="semanticModel">Provides the compilation in which the list type is looked up.</param>
    /// <returns><see langword="true"/> when the two symbols are equal.</returns>
    internal static bool TypeSymbolIsList(ITypeSymbol typeSymbol, SemanticModel semanticModel)
    {
        var candidateDefinition = typeSymbol?.OriginalDefinition;
        var listDefinition = semanticModel.Compilation.ListOfTType();
        return Equals(candidateDefinition, listDefinition);
    }
}