File: ConvertLinq\CSharpConvertLinqQueryToForEachProvider.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Composition;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.ConvertLinq;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Simplification;
 
namespace Microsoft.CodeAnalysis.CSharp.ConvertLinq;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
[ExportCodeRefactoringProvider(LanguageNames.CSharp, Name = PredefinedCodeRefactoringProviderNames.ConvertLinqQueryToForEach), Shared]
internal sealed class CSharpConvertLinqQueryToForEachProvider : AbstractConvertLinqQueryToForEachProvider<QueryExpressionSyntax, StatementSyntax>
{
    private static readonly TypeSyntax VarNameIdentifier = IdentifierName("var");
 
    [ImportingConstructor]
    [SuppressMessage("RoslynDiagnosticsReliability", "RS0033:Importing constructor should be [Obsolete]", Justification = "Used in test code: https://github.com/dotnet/roslyn/issues/42814")]
    public CSharpConvertLinqQueryToForEachProvider()
    {
    }
 
    protected override string Title => CSharpFeaturesResources.Convert_to_foreach;
 
    protected override bool TryConvert(
        QueryExpressionSyntax queryExpression,
        SemanticModel semanticModel,
        ISemanticFactsService semanticFacts,
        CancellationToken cancellationToken,
        out DocumentUpdateInfo documentUpdateInfo)
            => new Converter(semanticModel, semanticFacts, queryExpression, cancellationToken).TryConvert(out documentUpdateInfo);
 
    /// <summary>
    /// Finds a QueryExpressionSyntax node for the span.
    /// </summary>
    protected override Task<QueryExpressionSyntax> FindNodeToRefactorAsync(CodeRefactoringContext context)
        => context.TryGetRelevantNodeAsync<QueryExpressionSyntax>();
 
    private sealed class Converter(SemanticModel semanticModel, ISemanticFactsService semanticFacts, QueryExpressionSyntax source, CancellationToken cancellationToken)
    {
        private readonly SemanticModel _semanticModel = semanticModel;
        private readonly ISemanticFactsService _semanticFacts = semanticFacts;
        private readonly CancellationToken _cancellationToken = cancellationToken;
        private readonly QueryExpressionSyntax _source = source;
        private readonly List<string> _introducedLocalNames = [];
 
        public bool TryConvert(out DocumentUpdateInfo documentUpdateInfo)
        {
            // Do not try refactoring queries with comments or conditional compilation in them.
            // We can consider supporting queries with comments in the future.
            if (_source.DescendantTrivia().Any(trivia => trivia is (kind:
                    SyntaxKind.SingleLineCommentTrivia or
                    SyntaxKind.MultiLineCommentTrivia or
                    SyntaxKind.MultiLineDocumentationCommentTrivia) ||
                _source.ContainsDirectives))
            {
                documentUpdateInfo = null;
                return false;
            }
 
            // Bail out if there is no chance to convert it even with a local function.
            if (!CanTryConvertToLocalFunction() ||
                !TryCreateStackFromQueryExpression(out var queryExpressionProcessingInfo))
            {
                documentUpdateInfo = null;
                return false;
            }
 
            // GetDiagnostics is expensive. Move it to the end if there were no bail outs from the algorithm.
            // TODO likely adding more semantic checks will perform checks we expect from GetDiagnostics
            // We may consider removing GetDiagnostics.
            // https://github.com/dotnet/roslyn/issues/25639
            if ((TryConvertInternal(queryExpressionProcessingInfo, out documentUpdateInfo) ||
                TryReplaceWithLocalFunction(queryExpressionProcessingInfo, out documentUpdateInfo)) &&  // second attempt: at least to a local function
                !_semanticModel.GetDiagnostics(_source.Span, _cancellationToken).Any(static diagnostic => diagnostic.DefaultSeverity == DiagnosticSeverity.Error))
            {
                if (!documentUpdateInfo.Source.IsParentKind(SyntaxKind.Block) &&
                    documentUpdateInfo.Destinations.Length > 1)
                {
                    documentUpdateInfo = new DocumentUpdateInfo(documentUpdateInfo.Source, Block(documentUpdateInfo.Destinations));
                }
 
                return true;
            }
 
            documentUpdateInfo = null;
            return false;
        }
 
        private StatementSyntax ProcessClause(
            CSharpSyntaxNode node,
            StatementSyntax statement,
            bool isLastClause,
            bool hasExtraDeclarations,
            out StatementSyntax extraStatementToAddAbove)
        {
            extraStatementToAddAbove = null;
            switch (node.Kind())
            {
                case SyntaxKind.WhereClause:
                    return Block(IfStatement(((WhereClauseSyntax)node).Condition.WithAdditionalAnnotations(Simplifier.Annotation).WithoutTrivia(), statement));
                case SyntaxKind.FromClause:
                    var fromClause = (FromClauseSyntax)node;
 
                    // If we are processing the first from and
                    // there were joins and some evaluations were moved into declarations above the foreach
                    // Check if the declaration on the first fromclause should be moved for the evaluation above declarations already moved upfront.
                    ExpressionSyntax expression1;
                    if (isLastClause && hasExtraDeclarations && !IsLocalOrParameterSymbol(_source.FromClause.Expression))
                    {
                        var expressionName = _semanticFacts.GenerateNameForExpression(
                            _semanticModel,
                            fromClause.Expression,
                            capitalize: false,
                            _cancellationToken);
                        var variable = GetFreeSymbolNameAndMarkUsed(expressionName);
                        extraStatementToAddAbove = CreateLocalDeclarationStatement(variable, fromClause.Expression, generateTypeFromExpression: false);
                        expression1 = IdentifierName(variable);
                    }
                    else
                    {
                        expression1 = fromClause.Expression.WithoutTrivia();
                    }
 
                    return ForEachStatement(
                        fromClause.Type ?? VarNameIdentifier,
                        fromClause.Identifier,
                        expression1,
                        WrapWithBlock(statement));
                case SyntaxKind.LetClause:
                    var letClause = (LetClauseSyntax)node;
                    return AddToBlockTop(CreateLocalDeclarationStatement(letClause.Identifier, letClause.Expression, generateTypeFromExpression: false), statement);
                case SyntaxKind.JoinClause:
                    var joinClause = (JoinClauseSyntax)node;
                    if (joinClause.Into != null)
                    {
                        // This must be caught on the validation step. Therefore, here is an exception.
                        throw new ArgumentException("GroupJoin is not supported");
                    }
                    else
                    {
                        ExpressionSyntax expression2;
                        if (IsLocalOrParameterSymbol(joinClause.InExpression))
                        {
                            expression2 = joinClause.InExpression;
                        }
                        else
                        {
                            // Input: var q = from x in XX() from z in ZZ() join y in YY() on x equals y select x + y;
                            // Add 
                            // var xx = XX();
                            // var yy = YY();
                            // Do not add for ZZ()
                            var expressionName = _semanticFacts.GenerateNameForExpression(
                                _semanticModel,
                                joinClause.InExpression,
                                capitalize: false,
                                _cancellationToken);
                            var variable = GetFreeSymbolNameAndMarkUsed(expressionName);
                            extraStatementToAddAbove = CreateLocalDeclarationStatement(variable, joinClause.InExpression, generateTypeFromExpression: false);
 
                            // Replace YY() with yy declared above.
                            expression2 = IdentifierName(variable);
                        }
 
                        // Output for the join
                        // var yy == YY(); this goes to extraStatementToAddAbove
                        // ...
                        // foreach (var y in yy)
                        // {
                        //  if (object.Equals(x, y))
                        //  {
                        //      ...
                        return Block(
                            ForEachStatement(
                                joinClause.Type ?? VarNameIdentifier,
                                joinClause.Identifier,
                                expression2,
                                Block(
                                    IfStatement(
                                        InvocationExpression(
                                            MemberAccessExpression(
                                                SyntaxKind.SimpleMemberAccessExpression,
                                                PredefinedType(ObjectKeyword),
                                                IdentifierName(nameof(object.Equals))),
                                            ArgumentList([
                                                Argument(joinClause.LeftExpression),
                                                Argument(joinClause.RightExpression.WithoutTrailingTrivia())])),
                                        statement)))).WithAdditionalAnnotations(Simplifier.Annotation);
                    }
                case SyntaxKind.SelectClause:
                    // This is not the latest Select in the Query Expression
                    // There is a QueryBody with the Continuation as a parent.
                    var selectClause = (SelectClauseSyntax)node;
                    var identifier = ((QueryBodySyntax)selectClause.Parent).Continuation.Identifier;
                    return AddToBlockTop(CreateLocalDeclarationStatement(identifier, selectClause.Expression, generateTypeFromExpression: true), statement);
                default:
                    throw new ArgumentException($"Unexpected node kind {node.Kind()}");
            }
        }
 
        private bool TryConvertInternal(QueryExpressionProcessingInfo queryExpressionProcessingInfo, out DocumentUpdateInfo documentUpdateInfo)
        {
            // (from a in b select a); 
            var parent = _source.WalkUpParentheses().Parent;
 
            switch (parent.Kind())
            {
                // return from a in b select a;
                case SyntaxKind.ReturnStatement:
                    return TryConvertIfInReturnStatement((ReturnStatementSyntax)parent, queryExpressionProcessingInfo, out documentUpdateInfo);
                // foreach(var x in from a in b select a)
                case SyntaxKind.ForEachStatement:
                    return TryConvertIfInForEach((ForEachStatementSyntax)parent, queryExpressionProcessingInfo, out documentUpdateInfo);
                // (from a in b select a).ToList(), (from a in b select a).Count(), etc.
                case SyntaxKind.SimpleMemberAccessExpression:
                    return TryConvertIfInMemberAccessExpression((MemberAccessExpressionSyntax)parent, queryExpressionProcessingInfo, out documentUpdateInfo);
            }
 
            documentUpdateInfo = null;
            return false;
        }
 
        /// <summary>
        /// Checks if the location of the query expression allows to convert it at least to a local function.
        /// It still does not guarantees that the conversion can be performed. There can be bail outs of later stages.
        /// </summary>
        /// <returns></returns>
        private bool CanTryConvertToLocalFunction()
        {
            SyntaxNode currentNode = _source;
            while (currentNode != null)
            {
                if (currentNode is StatementSyntax)
                {
                    return true;
                }
 
                if (currentNode is ExpressionSyntax or
                    ArgumentSyntax or
                    ArgumentListSyntax or
                    EqualsValueClauseSyntax or
                    VariableDeclaratorSyntax or
                    VariableDeclarationSyntax)
                {
                    currentNode = currentNode.Parent;
                }
                else
                {
                    return false;
                }
            }
 
            return false;
        }
 
        private bool TryConvertIfInMemberAccessExpression(
           MemberAccessExpressionSyntax memberAccessExpression,
           QueryExpressionProcessingInfo queryExpressionProcessingInfo,
           out DocumentUpdateInfo documentUpdateInfo)
        {
            if (memberAccessExpression.Parent is InvocationExpressionSyntax invocationExpression)
            {
                // This also covers generic names (i.e. with type arguments) like 'ToList<int>'. 
                // The ValueText is still just 'ToList'. 
                switch (memberAccessExpression.Name.Identifier.ValueText)
                {
                    case nameof(Enumerable.ToList):
                        return TryConvertIfInToListInvocation(invocationExpression, queryExpressionProcessingInfo, out documentUpdateInfo);
                    case nameof(Enumerable.Count):
                        return TryConvertIfInCountInvocation(invocationExpression, queryExpressionProcessingInfo, out documentUpdateInfo);
                }
            }
 
            documentUpdateInfo = null;
            return false;
        }
 
        private bool TryConvertIfInCountInvocation(
            InvocationExpressionSyntax invocationExpression,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo,
            out DocumentUpdateInfo documentUpdateInfo)
        {
            if (_semanticModel.GetSymbolInfo(invocationExpression, _cancellationToken).Symbol is IMethodSymbol methodSymbol &&
                methodSymbol.Parameters.Length == 0 &&
                methodSymbol.ReturnType?.SpecialType == SpecialType.System_Int32 &&
                methodSymbol.RefKind == RefKind.None)
            {
                // before var count = (from a in b select a).Count();
                // after
                // var count = 0;
                // foreach (var a in b)
                // {
                //     count++;
                //  }
                return TryConvertIfInInvocation(
                        invocationExpression,
                        queryExpressionProcessingInfo,
                        IsInt,
                        (variableIdentifier, expression) => ExpressionStatement(
                            PostfixUnaryExpression(SyntaxKind.PostIncrementExpression, variableIdentifier)), // Generating 'count++'
                        LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)), // count = 0
                        variableName: "count",
                        out documentUpdateInfo);
            }
 
            documentUpdateInfo = null;
            return false;
        }
 
        private bool TryConvertIfInToListInvocation(
            InvocationExpressionSyntax invocationExpression,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo,
            out DocumentUpdateInfo documentUpdateInfo)
        {
            // before var list = (from a in b select a).ToList();
            // after
            // var list = new List<T>();
            // foreach (var a in b)
            // {
            //     list.Add(a)
            // }
            if (_semanticModel.GetSymbolInfo(invocationExpression, _cancellationToken).Symbol is IMethodSymbol methodSymbol &&
                methodSymbol.RefKind == RefKind.None &&
                IsList(methodSymbol.ReturnType) &&
                methodSymbol.Parameters.Length == 0)
            {
                return TryConvertIfInInvocation(
                          invocationExpression,
                          queryExpressionProcessingInfo,
                          IsList,
                          (listIdentifier, expression) => ExpressionStatement(InvocationExpression(
                                MemberAccessExpression(
                                    SyntaxKind.SimpleMemberAccessExpression,
                                    listIdentifier,
                                    IdentifierName(nameof(IList.Add))),
                                ArgumentList([Argument(expression)]))),
                          ObjectCreationExpression(
                              methodSymbol.GenerateReturnTypeSyntax().WithAdditionalAnnotations(Simplifier.Annotation),
                              ArgumentList(),
                              initializer: null),
                           variableName: "list",
                          out documentUpdateInfo);
            }
 
            documentUpdateInfo = null;
            return false;
        }
 
        private bool IsInt(ITypeSymbol typeSymbol)
            => typeSymbol.SpecialType == SpecialType.System_Int32;
 
        private bool IsList(ITypeSymbol typeSymbol)
            => Equals(typeSymbol.OriginalDefinition, _semanticModel.Compilation.GetTypeByMetadataName(typeof(List<>).FullName));
 
        private bool TryConvertIfInInvocation(
            InvocationExpressionSyntax invocationExpression,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo,
            Func<ITypeSymbol, bool> typeCheckMethod,
            Func<ExpressionSyntax, ExpressionSyntax, StatementSyntax> leafExpressionCreationMethod,
            ExpressionSyntax initializer,
            string variableName,
            out DocumentUpdateInfo documentUpdateInfo)
        {
            var parentStatement = invocationExpression.GetAncestorOrThis<StatementSyntax>();
            if (parentStatement != null)
            {
                if (TryConvertIfInInvocationInternal(
                    invocationExpression,
                    typeCheckMethod,
                    parentStatement,
                    initializer,
                    variableName,
                    out var variable,
                    out var nodesBefore,
                    out var nodesAfter))
                {
                    var statements = GenerateStatements(expression => leafExpressionCreationMethod(variable, expression), queryExpressionProcessingInfo);
                    var list = new List<StatementSyntax>();
                    list.AddRange(nodesBefore);
                    list.AddRange(statements);
                    list.AddRange(nodesAfter);
                    documentUpdateInfo = new DocumentUpdateInfo(parentStatement, list);
                    return true;
                }
            }
 
            documentUpdateInfo = null;
            return false;
        }
 
        private bool TryConvertIfInInvocationInternal(
            InvocationExpressionSyntax invocationExpression,
            Func<ITypeSymbol, bool> typeCheckMethod,
            StatementSyntax parentStatement,
            ExpressionSyntax initializer,
            string variableName,
            out ExpressionSyntax variable,
            out StatementSyntax[] nodesBefore,
            out StatementSyntax[] nodesAfter)
        {
            var invocationParent = invocationExpression.WalkUpParentheses().Parent;
            var symbolName = GetFreeSymbolNameAndMarkUsed(variableName);
 
            void Convert(
                ExpressionSyntax variableExpression,
                ExpressionSyntax expressionToVerifyType,
                bool checkForLocalOrParameter,
                out ExpressionSyntax variableLocal,
                out StatementSyntax[] nodesBeforeLocal,
                out StatementSyntax[] nodesAfterLocal)
            {
                // Check that we can re-use the local variable or parameter
                if (typeCheckMethod(_semanticModel.GetTypeInfo(expressionToVerifyType, _cancellationToken).Type) &&
                    (!checkForLocalOrParameter || IsLocalOrParameterSymbol(variableExpression)))
                {
                    // before
                    // a = (from a in b select a).ToList(); or var a = (from a in b select a).ToList()
                    // after 
                    // a = new List<T>(); or var a = new List<T>();
                    // foreach(...)
                    variableLocal = variableExpression;
                    nodesBeforeLocal = [parentStatement.ReplaceNode(invocationExpression, initializer.WithAdditionalAnnotations(Simplifier.Annotation))];
                    nodesAfterLocal = [];
                }
                else
                {
                    // before 
                    // IReadOnlyList<int> a = (from a in b select a).ToList(); or an assignment
                    // after 
                    // var list = new List<T>(); or assignment
                    // foreach(...)
                    // IReadOnlyList<int> a = list;
                    variableLocal = IdentifierName(symbolName);
                    nodesBeforeLocal = new[] { CreateLocalDeclarationStatement(symbolName, initializer, generateTypeFromExpression: false) };
                    nodesAfterLocal = [parentStatement.ReplaceNode(invocationExpression, variableLocal.WithAdditionalAnnotations(Simplifier.Annotation))];
                }
            }
 
            switch (invocationParent.Kind())
            {
                case SyntaxKind.EqualsValueClause:
                    // Avoid for(int i = (from x in a select x).Count(); i < 10; i++)
                    if (invocationParent?.Parent.Kind() is
                            SyntaxKind.VariableDeclarator or
                            SyntaxKind.VariableDeclaration or
                            SyntaxKind.LocalDeclarationStatement &&
                        // Avoid int i = (from x in a select x).Count(), j = i;
                        ((VariableDeclarationSyntax)invocationParent.Parent.Parent).Variables.Count == 1)
                    {
                        var variableDeclarator = ((VariableDeclaratorSyntax)invocationParent.Parent);
                        Convert(
                            IdentifierName(variableDeclarator.Identifier),
                            ((VariableDeclarationSyntax)variableDeclarator.Parent).Type,
                            checkForLocalOrParameter: false,
                            out variable,
                            out nodesBefore,
                            out nodesAfter);
                        return true;
                    }
 
                    break;
                case SyntaxKind.SimpleAssignmentExpression:
                    var assignmentExpression = (AssignmentExpressionSyntax)invocationParent;
                    if (assignmentExpression.Right.WalkDownParentheses() == invocationExpression)
                    {
                        Convert(
                            assignmentExpression.Left,
                            assignmentExpression.Left,
                            checkForLocalOrParameter: true,
                            out variable,
                            out nodesBefore,
                            out nodesAfter);
                        return true;
                    }
 
                    break;
                case SyntaxKind.ReturnStatement:
                    // before return (from a in b select a).ToList();
                    // after var list = new List<T>();
                    // foreach(...)
                    // return list;
                    variable = IdentifierName(symbolName);
                    nodesBefore = new[] { CreateLocalDeclarationStatement(symbolName, initializer, generateTypeFromExpression: false) };
                    nodesAfter = new[] { ReturnStatement(variable).WithAdditionalAnnotations(Simplifier.Annotation) };
                    return true;
                    // SyntaxKind.Argument:
                    // SyntaxKind.ArrowExpressionClause is not supported
            }
 
            // Will still try to replace with a local function above.
            nodesBefore = null;
            nodesAfter = null;
            variable = null;
            return false;
        }
 
        private LocalDeclarationStatementSyntax CreateLocalDeclarationStatement(
            SyntaxToken identifier,
            ExpressionSyntax expression,
            bool generateTypeFromExpression)
        {
            var typeSyntax = generateTypeFromExpression
                ? _semanticModel.GetTypeInfo(expression, _cancellationToken).ConvertedType.GenerateTypeSyntax()
                : VarNameIdentifier;
            return LocalDeclarationStatement(
                        VariableDeclaration(
                            typeSyntax,
                            [VariableDeclarator(
                                identifier,
                                argumentList: null,
                                EqualsValueClause(expression))])).WithAdditionalAnnotations(Simplifier.Annotation);
        }
 
        private bool TryReplaceWithLocalFunction(QueryExpressionProcessingInfo queryExpressionProcessingInfo, out DocumentUpdateInfo documentUpdateInfo)
        {
            var parentStatement = _source.GetAncestorOrThis<StatementSyntax>();
            if (parentStatement == null)
            {
                documentUpdateInfo = null;
                return false;
            }
 
            // before statement ... from a in select b ...
            // after
            // IEnumerable<T> localFunction()
            // {
            //   foreach(var a in b)
            //   {
            //       yield return a;
            //   }
            // }
            //  statement ... localFunction();
            var returnTypeInfo = _semanticModel.GetTypeInfo(_source, _cancellationToken);
            ITypeSymbol returnedType;
 
            if (returnTypeInfo.Type.OriginalDefinition?.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T)
            {
                returnedType = returnTypeInfo.Type;
            }
            else
            {
                if (returnTypeInfo.ConvertedType.OriginalDefinition?.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T)
                {
                    returnedType = returnTypeInfo.ConvertedType;
                }
                else
                {
                    documentUpdateInfo = null;
                    return false;
                }
            }
 
            static StatementSyntax internalNodeMethod(ExpressionSyntax expression)
                => YieldStatement(SyntaxKind.YieldReturnStatement, expression);
 
            var statements = GenerateStatements(internalNodeMethod, queryExpressionProcessingInfo);
            var localFunctionNamePrefix = _semanticFacts.GenerateNameForExpression(
                _semanticModel,
                _source,
                capitalize: false,
                _cancellationToken);
            var localFunctionToken = GetFreeSymbolNameAndMarkUsed(localFunctionNamePrefix);
            var localFunctionDeclaration = LocalFunctionStatement(
                modifiers: default,
                returnType: returnedType.GenerateTypeSyntax().WithAdditionalAnnotations(Simplifier.Annotation),
                identifier: localFunctionToken,
                typeParameterList: null,
                parameterList: ParameterList(),
                constraintClauses: default,
                body: Block(
                    Token(
                        [],
                        SyntaxKind.OpenBraceToken,
                        [EndOfLine(Environment.NewLine)]),
                    [.. statements],
                    CloseBraceToken),
                expressionBody: null);
 
            var localFunctionInvocation = InvocationExpression(IdentifierName(localFunctionToken)).WithAdditionalAnnotations(Simplifier.Annotation);
            var newParentExpressionStatement = parentStatement.ReplaceNode(_source.WalkUpParentheses(), localFunctionInvocation.WithAdditionalAnnotations(Simplifier.Annotation));
            documentUpdateInfo = new DocumentUpdateInfo(parentStatement, [localFunctionDeclaration, newParentExpressionStatement]);
            return true;
        }
 
        private SyntaxToken GetFreeSymbolNameAndMarkUsed(string prefix)
        {
            var freeToken = _semanticFacts.GenerateUniqueName(_semanticModel, _source, container: null, baseName: prefix, _introducedLocalNames, _cancellationToken);
            _introducedLocalNames.Add(freeToken.ValueText);
            return freeToken;
        }
 
        private bool TryConvertIfInForEach(
            ForEachStatementSyntax forEachStatement,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo,
            out DocumentUpdateInfo documentUpdateInfo)
        {
            // before foreach(var x in from a in b select a)
 
            if (forEachStatement.Expression.WalkDownParentheses() != _source)
            {
                documentUpdateInfo = null;
                return false;
            }
 
            // check that the body of the forEach does not contain any identifiers from the query
            foreach (var identifierName in queryExpressionProcessingInfo.IdentifierNames)
            {
                // Identifier from the foreach can already be in scope of the foreach statement.
                if (forEachStatement.Identifier.ValueText != identifierName)
                {
                    if (_semanticFacts.GenerateUniqueName(
                            _semanticModel,
                            location: forEachStatement.Statement,
                            container: forEachStatement.Statement,
                            baseName: identifierName,
                            usedNames: [],
                            _cancellationToken).ValueText != identifierName)
                    {
                        documentUpdateInfo = null;
                        return false;
                    }
                }
            }
 
            // If query does not contains identifier with the same name as declared in the foreach,
            // declare this identifier in the body.
            if (!queryExpressionProcessingInfo.ContainsIdentifier(forEachStatement.Identifier))
            {
                documentUpdateInfo = ConvertIfInToForeachWithExtraVariableDeclaration(forEachStatement, queryExpressionProcessingInfo);
                return true;
            }
            else
            {
                // The last select expression in the query returns this identifier:
                // foreach(var thisIdentifier in from ....... select thisIdentifier)
                // if thisIdentifier in foreach is var, it is OK
                // if a type is specified for thisIdentifier in forEach, check that the type is the same as in the select expression
                // foreach(MyType thisIdentifier in from ....... from MyType thisIdentifier ... select thisIdentifier)
 
                // The last clause in query stack must be SelectClauseSyntax.
                var lastSelectExpression = ((SelectClauseSyntax)queryExpressionProcessingInfo.Stack.Peek()).Expression;
                if (lastSelectExpression is IdentifierNameSyntax identifierName &&
                    forEachStatement.Identifier.ValueText == identifierName.Identifier.ValueText &&
                    queryExpressionProcessingInfo.IdentifierNames.Contains(identifierName.Identifier.ValueText))
                {
                    var forEachStatementTypeSymbolType = _semanticModel.GetTypeInfo(forEachStatement.Type, _cancellationToken).Type;
                    var lastSelectExpressionTypeInfo = _semanticModel.GetTypeInfo(lastSelectExpression, _cancellationToken);
                    if (Equals(lastSelectExpressionTypeInfo.ConvertedType, lastSelectExpressionTypeInfo.Type) &&
                        Equals(lastSelectExpressionTypeInfo.ConvertedType, forEachStatementTypeSymbolType))
                    {
                        documentUpdateInfo = ConvertIfInToForeachWithoutExtraVariableDeclaration(forEachStatement, queryExpressionProcessingInfo);
                        return true;
                    }
                }
            }
 
            // in all other cases try to replace with a local function - this is called above.
            documentUpdateInfo = null;
            return false;
        }
 
        private DocumentUpdateInfo ConvertIfInToForeachWithExtraVariableDeclaration(
            ForEachStatementSyntax forEachStatement,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo)
        {
            // before foreach(var x in from ... a) { dosomething(x); }
            // after 
            // foreach (var a in ...)
            // ...
            // {
            //      var x = a;
            //      dosomething(x); 
            //  }
            var statements = GenerateStatements(
                expression => AddToBlockTop(LocalDeclarationStatement(
                    VariableDeclaration(
                        forEachStatement.Type,
                        [VariableDeclarator(
                            forEachStatement.Identifier,
                            argumentList: null,
                            EqualsValueClause(expression))])),
                                forEachStatement.Statement).WithAdditionalAnnotations(Formatter.Annotation), queryExpressionProcessingInfo);
            return new DocumentUpdateInfo(forEachStatement, statements);
        }
 
        private DocumentUpdateInfo ConvertIfInToForeachWithoutExtraVariableDeclaration(
            ForEachStatementSyntax forEachStatement,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo)
        {
            // before 
            //  foreach (var a in from a in b where a > 5 select a) 
            //  { 
            //      dosomething(a); 
            //  }
            // after 
            //  foreach (var a in b)
            //  {
            //      if (a > 5)
            //      {
            //          dosomething(a); 
            //      }
            //  }
            var statements = GenerateStatements(
                expression => forEachStatement.Statement.WithAdditionalAnnotations(Formatter.Annotation),
                queryExpressionProcessingInfo);
            return new DocumentUpdateInfo(forEachStatement, statements);
        }
 
        private bool TryConvertIfInReturnStatement(
            ReturnStatementSyntax returnStatement,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo,
            out DocumentUpdateInfo documentUpdateInfo)
        {
            // The conversion requires yield return which cannot be added to lambdas and anonymous method declarations.
            if (IsWithinImmediateLambdaOrAnonymousMethod(returnStatement))
            {
                documentUpdateInfo = null;
                return false;
            }
 
            var memberDeclarationNode = FindParentMemberDeclarationNode(returnStatement, out var declaredSymbol);
            if (declaredSymbol is not IMethodSymbol methodSymbol)
            {
                documentUpdateInfo = null;
                return false;
            }
 
            if (methodSymbol.ReturnType.OriginalDefinition?.SpecialType != SpecialType.System_Collections_Generic_IEnumerable_T)
            {
                documentUpdateInfo = null;
                return false;
            }
 
            // if there are more than one return in the method, convert to local function.
            if (memberDeclarationNode.DescendantNodes().OfType<ReturnStatementSyntax>().Count() == 1)
            {
                // before: return from a in b select a;
                // after: 
                // foreach(var a in b)
                // {
                //      yield return a;
                // }
                //
                // yield break;
                var statements = GenerateStatements((ExpressionSyntax expression)
                    => YieldStatement(SyntaxKind.YieldReturnStatement, expression), queryExpressionProcessingInfo);
 
                // add an yield break to avoid throws after the return.
                var yieldBreakStatement = YieldStatement(SyntaxKind.YieldBreakStatement);
                documentUpdateInfo = new DocumentUpdateInfo(returnStatement, statements.Concat([yieldBreakStatement]));
                return true;
            }
 
            documentUpdateInfo = null;
            return false;
        }
 
        // We may assume that the query is defined within a method, field, property and so on and it is declare just once.
        private SyntaxNode FindParentMemberDeclarationNode(SyntaxNode node, out ISymbol declaredSymbol)
        {
            declaredSymbol = _semanticModel.GetEnclosingSymbol(node.SpanStart, _cancellationToken);
            return declaredSymbol.DeclaringSyntaxReferences.Single().GetSyntax();
        }
 
        private bool TryCreateStackFromQueryExpression(out QueryExpressionProcessingInfo queryExpressionProcessingInfo)
        {
            queryExpressionProcessingInfo = new QueryExpressionProcessingInfo(_source.FromClause);
            return TryProcessQueryBody(_source.Body, queryExpressionProcessingInfo);
        }
 
        private StatementSyntax[] GenerateStatements(
            Func<ExpressionSyntax, StatementSyntax> leafExpressionCreationMethod,
            QueryExpressionProcessingInfo queryExpressionProcessingInfo)
        {
            StatementSyntax statement = null;
            var stack = queryExpressionProcessingInfo.Stack;
            // Executes syntax building methods from bottom to the top of the tree.
            // Process last clause
            if (stack.Any())
            {
                var node = stack.Pop();
                if (node is SelectClauseSyntax selectClause)
                {
                    statement = WrapWithBlock(leafExpressionCreationMethod(selectClause.Expression));
                }
                else
                {
                    throw new ArgumentException("Last node must me the select clause");
                }
            }
 
            // Process all other clauses
            var statements = new List<StatementSyntax>();
            while (stack.Any())
            {
                statement = ProcessClause(
                    stack.Pop(),
                    statement,
                    isLastClause: !stack.Any(),
                    hasExtraDeclarations: statements.Any(),
                    out var extraStatement);
                if (extraStatement != null)
                {
                    statements.Add(extraStatement);
                }
            }
 
            // The stack was processed in the reverse order, but the extra statements should be provided in the direct order.
            statements.Reverse();
            statements.Add(statement.WithAdditionalAnnotations(Simplifier.Annotation));
            return [.. statements];
        }
 
        private bool TryProcessQueryBody(QueryBodySyntax queryBody, QueryExpressionProcessingInfo queryExpressionProcessingInfo)
        {
            do
            {
                foreach (var queryClause in queryBody.Clauses)
                {
                    switch (queryClause.Kind())
                    {
                        case SyntaxKind.WhereClause:
                            queryExpressionProcessingInfo.Add(queryClause);
                            break;
                        case SyntaxKind.LetClause:
                            if (!queryExpressionProcessingInfo.TryAdd(queryClause, ((LetClauseSyntax)queryClause).Identifier))
                            {
                                return false;
                            }
 
                            break;
                        case SyntaxKind.FromClause:
                            var fromClause = (FromClauseSyntax)queryClause;
                            if (!queryExpressionProcessingInfo.TryAdd(queryClause, fromClause.Identifier))
                            {
                                return false;
                            }
 
                            break;
                        case SyntaxKind.JoinClause:
                            var joinClause = (JoinClauseSyntax)queryClause;
                            if (joinClause.Into == null) // GroupJoin is not supported
                            {
                                if (queryExpressionProcessingInfo.TryAdd(joinClause, joinClause.Identifier))
                                {
                                    break;
                                }
                            }
 
                            return false;
                        // OrderBy is not supported by foreach.
                        default:
                            return false;
                    }
                }
 
                // GroupClause is not supported by the conversion
                if (queryBody.SelectOrGroup is not SelectClauseSyntax selectClause)
                {
                    return false;
                }
 
                if (_semanticModel.GetTypeInfo(selectClause.Expression, _cancellationToken).Type.ContainsAnonymousType())
                {
                    return false;
                }
 
                queryExpressionProcessingInfo.Add(selectClause);
                queryBody = queryBody.Continuation?.Body;
            } while (queryBody != null);
 
            return true;
        }
 
        private static BlockSyntax AddToBlockTop(StatementSyntax newStatement, StatementSyntax statement)
        {
            if (statement is BlockSyntax block)
            {
                return block.WithStatements(block.Statements.Insert(0, newStatement));
            }
            else
            {
                return Block(newStatement, statement);
            }
        }
 
        private bool IsLocalOrParameterSymbol(ExpressionSyntax expression)
            => IsLocalOrParameterSymbol(_semanticModel.GetOperation(expression, _cancellationToken));
 
        private static bool IsLocalOrParameterSymbol(IOperation operation)
        {
            if (operation is IConversionOperation conversion && conversion.IsImplicit)
            {
                return IsLocalOrParameterSymbol(conversion.Operand);
            }
 
            return operation.Kind is OperationKind.LocalReference or OperationKind.ParameterReference;
        }
 
        private static BlockSyntax WrapWithBlock(StatementSyntax statement)
            => statement is BlockSyntax block ? block : Block(statement);
 
        // Checks if the node is within an immediate lambda or within an immediate anonymous method.
        // 'lambda => node' returns true
        // 'lambda => localfunction => node' returns false
        // 'member => node' returns false
        private static bool IsWithinImmediateLambdaOrAnonymousMethod(SyntaxNode node)
        {
            while (node != null)
            {
                switch (node.Kind())
                {
                    case SyntaxKind.AnonymousMethodExpression:
                    case SyntaxKind.ParenthesizedLambdaExpression:
                    case SyntaxKind.SimpleLambdaExpression:
                        return true;
                    case SyntaxKind.LocalFunctionStatement:
                        return false;
                    default:
                        if (node is MemberDeclarationSyntax)
                        {
                            return false;
                        }
 
                        break;
                }
 
                node = node.Parent;
            }
 
            return false;
        }
 
        private class QueryExpressionProcessingInfo
        {
            public Stack<CSharpSyntaxNode> Stack { get; private set; }
 
            public HashSet<string> IdentifierNames { get; private set; }
 
            public QueryExpressionProcessingInfo(FromClauseSyntax fromClause)
            {
                Stack = new Stack<CSharpSyntaxNode>();
                Stack.Push(fromClause);
                IdentifierNames = [fromClause.Identifier.ValueText];
            }
 
            public bool TryAdd(CSharpSyntaxNode node, SyntaxToken identifier)
            {
                // Duplicate identifiers are not allowed.
                // var q = from x in new[] { 1 } select x + 2 into x where x > 0 select 7 into y let x = ""aaa"" select x;
                if (!IdentifierNames.Add(identifier.ValueText))
                {
                    return false;
                }
 
                Stack.Push(node);
                return true;
            }
 
            public void Add(CSharpSyntaxNode node) => Stack.Push(node);
 
            public bool ContainsIdentifier(SyntaxToken identifier)
                => IdentifierNames.Contains(identifier.ValueText);
        }
    }
}