File: ConvertLinq\ConvertForEachToLinqQuery\AbstractToMethodConverter.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis.ConvertLinq.ConvertForEachToLinqQuery;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
using SyntaxNodeOrTokenExtensions = Microsoft.CodeAnalysis.Shared.Extensions.SyntaxNodeOrTokenExtensions;
 
namespace Microsoft.CodeAnalysis.CSharp.ConvertLinq.ConvertForEachToLinqQuery;
 
/// <summary>
/// Provides a conversion to query.Method() like query.ToList(), query.Count().
/// </summary>
/// <param name="forEachInfo">Information about the foreach statement being converted; forwarded to the base converter.</param>
/// <param name="selectExpression">
/// The expression selected by the generated query: "item" for "list.Add(item)".
/// It can be anything for "counter++", where it is ignored.
/// </param>
/// <param name="modifyingExpression">
/// The expression modified inside the loop: "list" for "list.Add(item)", "counter" for "counter++".
/// </param>
/// <param name="trivia">Trivia around "counter++;" or "list.Add(item);". Required to keep comments.</param>
internal abstract class AbstractToMethodConverter(
    ForEachInfo<ForEachStatementSyntax, StatementSyntax> forEachInfo,
    ExpressionSyntax selectExpression,
    ExpressionSyntax modifyingExpression,
    SyntaxTrivia[] trivia) : AbstractConverter(forEachInfo)
{
    // It is "item" for "list.Add(item)"
    // It can be anything for "counter++". It will be ignored in that case.
    private readonly ExpressionSyntax _selectExpression = selectExpression;

    // It is "list" for "list.Add(item)"
    // It is "counter" for "counter++"
    private readonly ExpressionSyntax _modifyingExpression = modifyingExpression;

    // Trivia around "counter++;" or "list.Add(item);". Required to keep comments.
    private readonly SyntaxTrivia[] _trivia = trivia;

    /// <summary>
    /// Name of the method appended to the generated query, like "ToList" or "Count".
    /// </summary>
    protected abstract string MethodName { get; }

    /// <summary>
    /// Determines whether the initialization expression found just before the foreach
    /// (the initializer of a declaration or the right side of an assignment)
    /// can be replaced with the generated query.Method() invocation.
    /// </summary>
    protected abstract bool CanReplaceInitialization(ExpressionSyntax expressionSyntax, CancellationToken cancellationToken);

    /// <summary>
    /// Creates the statement that replaces the foreach when the query cannot be merged
    /// with a preceding initialization, like list.AddRange(query) or counter += query.Count().
    /// </summary>
    protected abstract StatementSyntax CreateDefaultStatement(ExpressionSyntax queryOrLinqInvocationExpression, ExpressionSyntax expression);

    /// <summary>
    /// Replaces the foreach statement with a query (or a linq invocation) followed by <see cref="MethodName"/>.
    /// If the statement just before the foreach initializes the modified variable, the initialization
    /// is merged with the query; otherwise the foreach is replaced with the default statement.
    /// </summary>
    public override void Convert(SyntaxEditor editor, bool convertToQuery, CancellationToken cancellationToken)
    {
        var queryOrLinqInvocationExpression = CreateQueryExpressionOrLinqInvocation(
            _selectExpression, [], [], convertToQuery);

        var previous = ForEachInfo.ForEachStatement.GetPreviousStatement();

        if (previous != null && !previous.ContainsDirectives)
        {
            switch (previous.Kind())
            {
                case SyntaxKind.LocalDeclarationStatement:
                    var variables = ((LocalDeclarationStatementSyntax)previous).Declaration.Variables;
                    var lastDeclaration = variables.Last();
                    // Check if 
                    // var ...., list = new List<T>(); or var ..., counter = 0;
                    // is just before the foreach.
                    // If so, join the declaration with the query.
                    // A declarator without an initializer (e.g. "List<T> list;") has nothing to
                    // replace, so it must fall through to the default conversion instead of
                    // dereferencing a null initializer.
                    if (_modifyingExpression is IdentifierNameSyntax identifierName &&
                        lastDeclaration.Initializer is { Value: var initializerValue } &&
                        lastDeclaration.Identifier.ValueText.Equals(identifierName.Identifier.ValueText) &&
                        CanReplaceInitialization(initializerValue, cancellationToken))
                    {
                        // Only the whole statement can be dropped when it declares a single variable;
                        // otherwise only the last declarator is a candidate for removal.
                        Convert(initializerValue, variables.Count == 1 ? previous : lastDeclaration);
                        return;
                    }

                    break;

                case SyntaxKind.ExpressionStatement:
                    // Check if 
                    // list = new List<T>(); or counter = 0;
                    // is just before the foreach.
                    // If so, join the assignment with the query.
                    if (((ExpressionStatementSyntax)previous).Expression is AssignmentExpressionSyntax assignmentExpression &&
                        SymbolEquivalenceComparer.Instance.Equals(
                            ForEachInfo.SemanticModel.GetSymbolInfo(assignmentExpression.Left, cancellationToken).Symbol,
                            ForEachInfo.SemanticModel.GetSymbolInfo(_modifyingExpression, cancellationToken).Symbol) &&
                        CanReplaceInitialization(assignmentExpression.Right, cancellationToken))
                    {
                        Convert(assignmentExpression.Right, previous);
                        return;
                    }

                    break;
            }
        }

        // At least, we already can convert to 
        // list.AddRange(query) or counter += query.Count();
        editor.ReplaceNode(
            ForEachInfo.ForEachStatement,
            CreateDefaultStatement(queryOrLinqInvocationExpression, _modifyingExpression).WithAdditionalAnnotations(Formatter.Annotation));

        return;

        // Replaces the given initialization expression (or the expression of a following return
        // statement) with query.Method() and removes the foreach statement.
        void Convert(ExpressionSyntax replacingExpression, SyntaxNode nodeToRemoveIfFollowedByReturn)
        {
            SyntaxTrivia[] leadingTrivia;

            // Check if the foreach is followed by a return statement returning the modified local.
            // A bare "return;" has no expression to compare or replace, so it is excluded up front.
            var expressionSymbol = ForEachInfo.SemanticModel.GetSymbolInfo(_modifyingExpression, cancellationToken).Symbol;
            if (expressionSymbol is ILocalSymbol &&
                ForEachInfo.ForEachStatement.GetNextStatement() is ReturnStatementSyntax { Expression: not null } returnStatement &&
                !returnStatement.ContainsDirectives &&
                SymbolEquivalenceComparer.Instance.Equals(
                    expressionSymbol, ForEachInfo.SemanticModel.GetSymbolInfo(returnStatement.Expression, cancellationToken).Symbol))
            {
                // Input:
                // var list = new List<T>(); or var counter = 0;
                // foreach(...)
                // {
                //     ...
                //     ...
                //     list.Add(item); or counter++;
                //  }
                //  return list; or return counter;
                //
                //  Output:
                //  return queryGenerated.ToList(); or return queryGenerated.Count();
                replacingExpression = returnStatement.Expression;
                leadingTrivia = [.. GetTriviaFromNode(nodeToRemoveIfFollowedByReturn), .. SyntaxNodeOrTokenExtensions.GetTrivia(replacingExpression)];
                editor.RemoveNode(nodeToRemoveIfFollowedByReturn);
            }
            else
            {
                leadingTrivia = SyntaxNodeOrTokenExtensions.GetTrivia(replacingExpression);
            }

            editor.ReplaceNode(
                replacingExpression,
                CreateInvocationExpression(queryOrLinqInvocationExpression)
                    .WithCommentsFrom(leadingTrivia, _trivia).KeepCommentsAndAddElasticMarkers());
            editor.RemoveNode(ForEachInfo.ForEachStatement);
        }

        // Collects trivia from the identifier, the equals token and the initializer value.
        // Callers only pass declarators whose initializer has already been verified to exist.
        static SyntaxTrivia[] GetTriviaFromVariableDeclarator(VariableDeclaratorSyntax variableDeclarator)
            => SyntaxNodeOrTokenExtensions.GetTrivia(variableDeclarator.Identifier, variableDeclarator.Initializer.EqualsToken, variableDeclarator.Initializer.Value);

        // Collects trivia from a node that is about to be removed.
        // Only the node shapes produced by the callers above are expected here.
        static SyntaxTrivia[] GetTriviaFromNode(SyntaxNode node)
        {
            switch (node.Kind())
            {
                case SyntaxKind.LocalDeclarationStatement:

                    var localDeclaration = (LocalDeclarationStatementSyntax)node;
                    if (localDeclaration.Declaration.Variables.Count != 1)
                    {
                        // A whole declaration statement is only passed when it declares one variable.
                        throw ExceptionUtilities.Unreachable();
                    }

                    return new IEnumerable<SyntaxTrivia>[] {
                        SyntaxNodeOrTokenExtensions.GetTrivia(localDeclaration.Declaration.Type),
                        GetTriviaFromVariableDeclarator(localDeclaration.Declaration.Variables[0]),
                        SyntaxNodeOrTokenExtensions.GetTrivia(localDeclaration.SemicolonToken)}.Flatten().ToArray();

                case SyntaxKind.VariableDeclarator:
                    return GetTriviaFromVariableDeclarator((VariableDeclaratorSyntax)node);

                case SyntaxKind.ExpressionStatement:
                    if (((ExpressionStatementSyntax)node).Expression is AssignmentExpressionSyntax assignmentExpression)
                    {
                        return SyntaxNodeOrTokenExtensions.GetTrivia(
                            assignmentExpression.Left, assignmentExpression.OperatorToken, assignmentExpression.Right);
                    }

                    break;
            }

            throw ExceptionUtilities.Unreachable();
        }
    }

    /// <summary>
    /// Wraps the given expression into an invocation of <see cref="MethodName"/>:
    /// query => query.Method(), like query.Count() or query.ToList().
    /// </summary>
    protected InvocationExpressionSyntax CreateInvocationExpression(ExpressionSyntax queryOrLinqInvocationExpression)
        => SyntaxFactory.InvocationExpression(
                SyntaxFactory.MemberAccessExpression(
                    SyntaxKind.SimpleMemberAccessExpression,
                    queryOrLinqInvocationExpression.Parenthesize(),
                    SyntaxFactory.IdentifierName(MethodName))).WithAdditionalAnnotations(Formatter.Annotation);
}