File: ConvertLinq\ConvertForEachToLinqQuery\AbstractConverter.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.ConvertLinq.ConvertForEachToLinqQuery;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
using SyntaxNodeOrTokenExtensions = Microsoft.CodeAnalysis.Shared.Extensions.SyntaxNodeOrTokenExtensions;
 
namespace Microsoft.CodeAnalysis.CSharp.ConvertLinq.ConvertForEachToLinqQuery;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
/// <summary>
/// Base class for converters that turn a foreach statement, together with the nested nodes recorded in
/// <see cref="ForEachInfo{ForEachStatementSyntax, StatementSyntax}.ConvertingExtendedNodes"/>, into either
/// a query expression or a chain of linq method invocations.
/// </summary>
/// <param name="forEachInfo">Information about the foreach statement to convert; exposed through <see cref="ForEachInfo"/>.</param>
internal abstract class AbstractConverter(ForEachInfo<ForEachStatementSyntax, StatementSyntax> forEachInfo) : IConverter<ForEachStatementSyntax, StatementSyntax>
{
    /// <summary>
    /// Information about the foreach statement being converted, as passed to the constructor.
    /// </summary>
    public ForEachInfo<ForEachStatementSyntax, StatementSyntax> ForEachInfo { get; } = forEachInfo;

    /// <summary>
    /// Performs the conversion through <paramref name="editor"/>. Implemented by derived converters.
    /// </summary>
    /// <param name="editor">Editor that receives the changes.</param>
    /// <param name="convertToQuery">Presumably selects query expression output over linq invocation output,
    /// matching the flag of <see cref="CreateQueryExpressionOrLinqInvocation"/>; confirm against the implementations.</param>
    /// <param name="cancellationToken">Token available to implementations for cancellation.</param>
    public abstract void Convert(SyntaxEditor editor, bool convertToQuery, CancellationToken cancellationToken);

    /// <summary>
    /// Creates a query expression or a linq invocation expression.
    /// </summary>
    /// <param name="selectExpression">expression to be used into the last Select in the query expression or linq invocation.</param>
    /// <param name="leadingTokensForSelect">extra leading tokens to be added to the select clause</param>
    /// <param name="trailingTokensForSelect">extra trailing tokens to be added to the select clause</param>
    /// <param name="convertToQuery">Flag indicating if a query expression should be generated</param>
    /// <returns>
    /// The query expression when <paramref name="convertToQuery"/> is <see langword="true"/>; otherwise the
    /// expression produced by the linq invocation path.
    /// </returns>
    protected ExpressionSyntax CreateQueryExpressionOrLinqInvocation(
        ExpressionSyntax selectExpression,
        IEnumerable<SyntaxToken> leadingTokensForSelect,
        IEnumerable<SyntaxToken> trailingTokensForSelect,
        bool convertToQuery)
    {
        return convertToQuery
            ? CreateQueryExpression(selectExpression, leadingTokensForSelect, trailingTokensForSelect)
            : CreateLinqInvocationOrSimpleExpression(selectExpression, leadingTokensForSelect, trailingTokensForSelect);
    }

    /// <summary>
    /// Creates a query expression.
    /// </summary>
    /// <param name="selectExpression">expression to be used into the last 'select ...' in the query expression</param>
    /// <param name="leadingTokensForSelect">extra leading tokens to be added to the select clause</param>
    /// <param name="trailingTokensForSelect">extra trailing tokens to be added to the select clause</param>
    /// <returns>
    /// A query expression whose 'from' clause comes from the outer foreach statement, whose body clauses come from
    /// <see cref="ForEachInfo{ForEachStatementSyntax, StatementSyntax}.ConvertingExtendedNodes"/> in order, and which
    /// ends with 'select <paramref name="selectExpression"/>'. The result carries <see cref="Formatter.Annotation"/>.
    /// </returns>
    private QueryExpressionSyntax CreateQueryExpression(
        ExpressionSyntax selectExpression,
        IEnumerable<SyntaxToken> leadingTokensForSelect,
        IEnumerable<SyntaxToken> trailingTokensForSelect)
        => QueryExpression(
            CreateFromClause(ForEachInfo.ForEachStatement, ForEachInfo.LeadingTokens.GetTrivia(), []),
            QueryBody(
                [.. ForEachInfo.ConvertingExtendedNodes.Select(node => CreateQueryClause(node))],
                SelectClause(selectExpression)
                    .WithCommentsFrom(leadingTokensForSelect, ForEachInfo.TrailingTokens.Concat(trailingTokensForSelect)),
                continuation: null)) // The current coverage of foreach statements to support does not need to use query continuations.
        .WithAdditionalAnnotations(Formatter.Annotation);

    /// <summary>
    /// Creates the query clause that corresponds to a single extended node.
    /// </summary>
    /// <param name="node">
    /// Extended node to convert. A variable declarator becomes a 'let' clause, a foreach statement becomes a
    /// 'from' clause and an if statement becomes a 'where' clause. Any other kind is treated as unreachable.
    /// </param>
    /// <returns>The clause, with the node's extra leading and trailing comments attached.</returns>
    private static QueryClauseSyntax CreateQueryClause(ExtendedSyntaxNode node)
    {
        switch (node.Node.Kind())
        {
            case SyntaxKind.VariableDeclarator:
                var variable = (VariableDeclaratorSyntax)node.Node;
                return LetClause(
                            LetKeyword,
                            variable.Identifier,
                            variable.Initializer.EqualsToken,
                            variable.Initializer.Value)
                        .WithCommentsFrom(node.ExtraLeadingComments, node.ExtraTrailingComments);

            case SyntaxKind.ForEachStatement:
                return CreateFromClause((ForEachStatementSyntax)node.Node, node.ExtraLeadingComments, node.ExtraTrailingComments);

            case SyntaxKind.IfStatement:
                var ifStatement = (IfStatementSyntax)node.Node;
                return WhereClause(
                            WhereKeyword
                                .WithCommentsFrom(ifStatement.IfKeyword.LeadingTrivia, ifStatement.IfKeyword.TrailingTrivia),
                            ifStatement.Condition.WithCommentsFrom(ifStatement.OpenParenToken, ifStatement.CloseParenToken))
                        .WithCommentsFrom(node.ExtraLeadingComments, node.ExtraTrailingComments);
        }

        throw ExceptionUtilities.Unreachable();
    }

    /// <summary>
    /// Creates a 'from' clause out of a foreach statement.
    /// </summary>
    /// <param name="forEachStatement">foreach statement supplying the type, identifier, 'in' keyword and collection expression.</param>
    /// <param name="extraLeadingTrivia">extra leading trivia to be added to the clause</param>
    /// <param name="extraTrailingTrivia">extra trailing trivia to be added to the clause</param>
    /// <returns>
    /// The 'from' clause. When the foreach statement is declared with 'var' the type is omitted, and the trivia of
    /// the type's first token (after <c>FilterComments</c>) is prepended to the identifier instead.
    /// </returns>
    private static FromClauseSyntax CreateFromClause(
        ForEachStatementSyntax forEachStatement,
        IEnumerable<SyntaxTrivia> extraLeadingTrivia,
        IEnumerable<SyntaxTrivia> extraTrailingTrivia)
        => FromClause(
                fromKeyword: FromKeyword
                                .WithCommentsFrom(
                                    forEachStatement.ForEachKeyword.LeadingTrivia,
                                    forEachStatement.ForEachKeyword.TrailingTrivia,
                                    forEachStatement.OpenParenToken)
                                .KeepCommentsAndAddElasticMarkers(),
                type: forEachStatement.Type.IsVar ? null : forEachStatement.Type,
                identifier: forEachStatement.Type.IsVar
                            ? forEachStatement.Identifier.WithPrependedLeadingTrivia(
                                SyntaxNodeOrTokenExtensions.GetTrivia(forEachStatement.Type.GetFirstToken())
                                .FilterComments(addElasticMarker: false))
                            : forEachStatement.Identifier,
                inKeyword: forEachStatement.InKeyword.KeepCommentsAndAddElasticMarkers(),
                expression: forEachStatement.Expression)
                    .WithCommentsFrom(extraLeadingTrivia, extraTrailingTrivia, forEachStatement.CloseParenToken);

    /// <summary>
    /// Creates a linq invocation expression.
    /// </summary>
    /// <param name="selectExpression">expression to be used in the last 'Select' invocation</param>
    /// <param name="leadingTokensForSelect">extra leading tokens to be added to the select clause</param>
    /// <param name="trailingTokensForSelect">extra trailing tokens to be added to the select clause</param>
    /// <returns>
    /// The expression built for the outer foreach statement, starting at the first extended node and annotated
    /// with <see cref="Formatter.Annotation"/>.
    /// </returns>
    private ExpressionSyntax CreateLinqInvocationOrSimpleExpression(
        ExpressionSyntax selectExpression,
        IEnumerable<SyntaxToken> leadingTokensForSelect,
        IEnumerable<SyntaxToken> trailingTokensForSelect)
    {
        var foreachStatement = ForEachInfo.ForEachStatement;
        selectExpression = selectExpression.WithCommentsFrom(leadingTokensForSelect, ForEachInfo.TrailingTokens.Concat(trailingTokensForSelect));
        var currentExtendedNodeIndex = 0;

        return CreateLinqInvocationOrSimpleExpression(
            foreachStatement,
            receiverForInvocation: foreachStatement.Expression,
            selectExpression: selectExpression,
            leadingCommentsTrivia: ForEachInfo.LeadingTokens.GetTrivia(),
            trailingCommentsTrivia: [],
            currentExtendedNodeIndex: ref currentExtendedNodeIndex)
            .WithAdditionalAnnotations(Formatter.Annotation);
    }

    /// <summary>
    /// Creates the 'Select' or 'SelectMany' invocation for <paramref name="forEachStatement"/>, or returns the bare
    /// receiver when the invocation would be an identity 'Select'.
    /// </summary>
    /// <param name="forEachStatement">foreach statement whose iteration variable becomes the lambda parameter.</param>
    /// <param name="receiverForInvocation">Receiver of the invocation; replaced by 'Where' invocations while nested if statements are processed.</param>
    /// <param name="leadingCommentsTrivia">extra leading trivia for the lambda; the trivia of the foreach keyword is placed in front of it.</param>
    /// <param name="trailingCommentsTrivia">extra trailing trivia for the lambda</param>
    /// <param name="selectExpression">Innermost select expression</param>
    /// <param name="currentExtendedNodeIndex">Index into <see cref="ForEachInfo{ForEachStatementSyntax, StatementSyntax}.ConvertingExtendedNodes"/> of the next node to process; advanced as nodes are consumed.</param>
    /// <returns>The invocation expression, or the receiver with the dropped lambda's non-whitespace trivia appended.</returns>
    private ExpressionSyntax CreateLinqInvocationOrSimpleExpression(
        ForEachStatementSyntax forEachStatement,
        ExpressionSyntax receiverForInvocation,
        IEnumerable<SyntaxTrivia> leadingCommentsTrivia,
        IEnumerable<SyntaxTrivia> trailingCommentsTrivia,
        ExpressionSyntax selectExpression,
        ref int currentExtendedNodeIndex)
    {
        leadingCommentsTrivia = forEachStatement.ForEachKeyword.GetAllTrivia().Concat(leadingCommentsTrivia);

        // Recursively create linq invocations, possibly updating the receiver (Where invocations), to get the inner expression for
        // the lambda body for the linq invocation to be created for this foreach statement. For example:
        //
        // INPUT:
        //   foreach (var n1 in c1)
        //      foreach (var n2 in c2)
        //          if (n1 > n2)
        //              yield return n1 + n2;
        //
        // OUTPUT:
        //   c1.SelectMany(n1 => c2.Where(n2 => n1 > n2).Select(n2 => n1 + n2))
        //
        var hasForEachChild = false;
        var lambdaBody = CreateLinqInvocationForExtendedNode(selectExpression, ref currentExtendedNodeIndex, ref receiverForInvocation, ref hasForEachChild);
        var lambda = SimpleLambdaExpression(
            Parameter(
                forEachStatement.Identifier.WithPrependedLeadingTrivia(
                SyntaxNodeOrTokenExtensions.GetTrivia(forEachStatement.Type.GetFirstToken())
                    .FilterComments(addElasticMarker: false))),
            lambdaBody)
            .WithCommentsFrom(leadingCommentsTrivia, trailingCommentsTrivia,
                forEachStatement.OpenParenToken, forEachStatement.InKeyword, forEachStatement.CloseParenToken);

        // Create Select or SelectMany linq invocation for this foreach statement. For example:
        //
        // INPUT:
        //   foreach (var n1 in c1)
        //      ...
        //
        // OUTPUT:
        //   c1.Select(n1 => ...
        //      OR
        //   c1.SelectMany(n1 => ...
        //

        // Avoid `.Select(x => x)`. Only a plain Select (no nested foreach) can be an identity projection.
        if (!hasForEachChild &&
            lambdaBody is IdentifierNameSyntax identifier &&
            identifier.Identifier.ValueText == forEachStatement.Identifier.ValueText)
        {
            // Because we're dropping the lambda, any comments associated with it need to be preserved.

            var droppedTrivia = new List<SyntaxTrivia>();
            foreach (var token in lambda.DescendantTokens())
            {
                droppedTrivia.AddRange(token.GetAllTrivia().Where(t => !t.IsWhitespace()));
            }

            return receiverForInvocation.WithAppendedTrailingTrivia(droppedTrivia);
        }

        var invokedMethodName = hasForEachChild ? nameof(Enumerable.SelectMany) : nameof(Enumerable.Select);

        return InvocationExpression(
            MemberAccessExpression(
                SyntaxKind.SimpleMemberAccessExpression,
                receiverForInvocation.Parenthesize(),
                IdentifierName(invokedMethodName)),
            ArgumentList([Argument(lambda)]));
    }

    /// <summary>
    /// Creates a linq invocation expression for the <see cref="ForEachInfo{ForEachStatementSyntax, StatementSyntax}.ConvertingExtendedNodes"/> node at the given index <paramref name="extendedNodeIndex"/>
    /// or returns the <paramref name="selectExpression"/> if all extended nodes have been processed.
    /// </summary>
    /// <param name="selectExpression">Innermost select expression</param>
    /// <param name="extendedNodeIndex">Index into <see cref="ForEachInfo{ForEachStatementSyntax, StatementSyntax}.ConvertingExtendedNodes"/> to be processed and updated.</param>
    /// <param name="receiver">Receiver for the generated linq invocation. Updated when processing an if statement.</param>
    /// <param name="hasForEachChild">Flag indicating if any of the processed <see cref="ForEachInfo{ForEachStatementSyntax, StatementSyntax}.ConvertingExtendedNodes"/> is a <see cref="ForEachStatementSyntax"/>.</param>
    /// <returns>
    /// The expression to use as the lambda body of the enclosing invocation. Only foreach and if statements are
    /// handled here; any other node kind is treated as unreachable.
    /// </returns>
    private ExpressionSyntax CreateLinqInvocationForExtendedNode(
        ExpressionSyntax selectExpression,
        ref int extendedNodeIndex,
        ref ExpressionSyntax receiver,
        ref bool hasForEachChild)
    {
        // Check if we have converted all the descendant foreach/if statements.
        // If so, we return the select expression.
        if (extendedNodeIndex == ForEachInfo.ConvertingExtendedNodes.Length)
        {
            return selectExpression;
        }

        // Otherwise, convert the next foreach/if statement into a linq invocation.
        var node = ForEachInfo.ConvertingExtendedNodes[extendedNodeIndex];
        switch (node.Node.Kind())
        {
            // Nested ForEach statement is converted into a nested Select or SelectMany linq invocation. For example:
            //
            // INPUT:
            //   foreach (var n1 in c1)
            //      foreach (var n2 in c2)
            //          ...
            //
            // OUTPUT:
            //   c1.SelectMany(n1 => c2.Select(n2 => ...
            //
            case SyntaxKind.ForEachStatement:
                hasForEachChild = true;
                var foreachStatement = (ForEachStatementSyntax)node.Node;
                ++extendedNodeIndex;
                return CreateLinqInvocationOrSimpleExpression(
                    foreachStatement,
                    receiverForInvocation: foreachStatement.Expression,
                    selectExpression: selectExpression,
                    leadingCommentsTrivia: node.ExtraLeadingComments,
                    trailingCommentsTrivia: node.ExtraTrailingComments,
                    currentExtendedNodeIndex: ref extendedNodeIndex);

            // Nested If statement is converted into a Where method invocation on the current receiver. For example:
            //
            // INPUT:
            //   foreach (var n1 in c1)
            //      if (n1 > 0)
            //          ...
            //
            // OUTPUT:
            //   c1.Where(n1 => n1 > 0).Select(n1 => ...
            //
            case SyntaxKind.IfStatement:
                var ifStatement = (IfStatementSyntax)node.Node;
                var parentForEachStatement = ifStatement.GetAncestor<ForEachStatementSyntax>();

                // The Where lambda reuses the name of the enclosing foreach variable; a fresh identifier is
                // created from its text so that none of the original token's trivia is carried over.
                var lambdaParameter = Parameter(Identifier(parentForEachStatement.Identifier.ValueText));
                var lambda = SimpleLambdaExpression(
                    lambdaParameter,
                    ifStatement.Condition.WithCommentsFrom(ifStatement.OpenParenToken, ifStatement.CloseParenToken))
                    .WithCommentsFrom(ifStatement.IfKeyword.GetAllTrivia().Concat(node.ExtraLeadingComments), node.ExtraTrailingComments);

                receiver = InvocationExpression(
                    MemberAccessExpression(
                        SyntaxKind.SimpleMemberAccessExpression,
                        receiver.Parenthesize(),
                        IdentifierName(nameof(Enumerable.Where))),
                    ArgumentList([Argument(lambda)]));

                ++extendedNodeIndex;
                return CreateLinqInvocationForExtendedNode(selectExpression, ref extendedNodeIndex, ref receiver, ref hasForEachChild);
        }

        throw ExceptionUtilities.Unreachable();
    }
}