File: ConvertForToForEach\AbstractConvertForToForEachCodeRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.ConvertForToForEach;
 
internal abstract class AbstractConvertForToForEachCodeRefactoringProvider<
    TStatementSyntax,
    TForStatementSyntax,
    TExpressionSyntax,
    TMemberAccessExpressionSyntax,
    TTypeNode,
    TVariableDeclaratorSyntax> : CodeRefactoringProvider
    where TStatementSyntax : SyntaxNode
    where TForStatementSyntax : TStatementSyntax
    where TExpressionSyntax : SyntaxNode
    where TMemberAccessExpressionSyntax : SyntaxNode
    where TTypeNode : SyntaxNode
    where TVariableDeclaratorSyntax : SyntaxNode
{
    protected abstract string GetTitle();
 
    protected abstract SyntaxList<TStatementSyntax> GetBodyStatements(TForStatementSyntax forStatement);
    protected abstract bool IsValidVariableDeclarator(TVariableDeclaratorSyntax firstVariable);
 
    protected abstract bool TryGetForStatementComponents(
        TForStatementSyntax forStatement,
        out SyntaxToken iterationVariable,
        [NotNullWhen(true)] out TExpressionSyntax? initializer,
        [NotNullWhen(true)] out TMemberAccessExpressionSyntax? memberAccess,
        out TExpressionSyntax? stepValueExpressionOpt,
        CancellationToken cancellationToken);
 
    protected abstract SyntaxNode ConvertForNode(
        TForStatementSyntax currentFor, TTypeNode? typeNode, SyntaxToken foreachIdentifier,
        TExpressionSyntax collectionExpression, ITypeSymbol iterationVariableType);
 
    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, textSpan, cancellationToken) = context;
        var forStatement = await context.TryGetRelevantNodeAsync<TForStatementSyntax>().ConfigureAwait(false);
        if (forStatement == null)
            return;
 
        if (!TryGetForStatementComponents(forStatement,
                out var iterationVariable, out var initializer, out var memberAccess, out var stepValueExpressionOpt, cancellationToken))
        {
            return;
        }
 
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        syntaxFacts.GetPartsOfMemberAccessExpression(memberAccess,
            out var collectionExpressionNode, out var memberAccessNameNode);
 
        var collectionExpression = (TExpressionSyntax)collectionExpressionNode;
        syntaxFacts.GetNameAndArityOfSimpleName(memberAccessNameNode, out var memberAccessName, out _);
        if (memberAccessName is not nameof(Array.Length) and not nameof(IList.Count))
            return;
 
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
 
        // Make sure it's a single-variable for loop and that we're not a loop where we're
        // referencing some previously declared symbol.  i.e
        // VB allows:
        //
        //      dim i as integer
        //      for i = 0 to ...
        //
        // We can't convert this as it would change important semantics.
        // NOTE: we could potentially update this if we saw that the variable was not used
        // after the for-loop.  But, for now, we'll just be conservative and assume this means
        // the user wanted the 'i' for some other purpose and we should keep things as is.
        if (semanticModel.GetOperation(forStatement, cancellationToken) is not ILoopOperation { Locals.Length: 1 })
            return;
 
        // Make sure we're starting at 0.
        var initializerValue = semanticModel.GetConstantValue(initializer, cancellationToken);
        if (initializerValue is not { HasValue: true, Value: 0 })
            return;
 
        // Make sure we're incrementing by 1.
        if (stepValueExpressionOpt != null)
        {
            var stepValue = semanticModel.GetConstantValue(stepValueExpressionOpt);
            if (stepValue is not { HasValue: true, Value: 1 })
                return;
        }
 
        var collectionType = semanticModel.GetTypeInfo(collectionExpression, cancellationToken);
        if (collectionType.Type is null or IErrorTypeSymbol)
            return;
 
        var containingType = semanticModel.GetEnclosingNamedType(textSpan.Start, cancellationToken);
        if (containingType == null)
            return;
 
        var ienumerableType = semanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T);
        var ienumeratorType = semanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerator_T);
 
        // make sure the collection can be iterated.
        if (!TryGetIterationElementType(
                containingType, collectionType.Type,
                ienumerableType, ienumeratorType,
                out var iterationType))
        {
            return;
        }
 
        // If the user uses the iteration variable for any other reason, we can't convert this.
        var bodyStatements = GetBodyStatements(forStatement);
        foreach (var statement in bodyStatements)
        {
            if (IterationVariableIsUsedForMoreThanCollectionIndex(statement))
                return;
        }
 
        // Looks good.  We can convert this.
        var title = GetTitle();
        context.RegisterRefactoring(
            CodeAction.Create(
                title,
                cancellationToken => ConvertForToForEachAsync(
                    document, forStatement, iterationVariable, collectionExpression,
                    containingType, collectionType.Type, iterationType, cancellationToken),
                title),
            forStatement.Span);
 
        return;
 
        // local functions
        bool IterationVariableIsUsedForMoreThanCollectionIndex(SyntaxNode current)
        {
            if (syntaxFacts.IsIdentifierName(current))
            {
                syntaxFacts.GetNameAndArityOfSimpleName(current, out var name, out _);
                if (name == iterationVariable.ValueText)
                {
                    // found a reference.  make sure it's only used inside something like
                    // list[i]
 
                    var argument = current.Parent;
                    if (!syntaxFacts.IsSimpleArgument(argument))
                        return true;
 
                    // we support `list[i]` or `list.ElementAt(i)`
                    var argumentList = argument?.Parent;
                    if (argumentList is null)
                        return true;
 
                    var arguments = syntaxFacts.GetArgumentsOfArgumentList(argumentList);
                    // was used in a multi-dimensional indexing, or multiple argument method call.  Can't convert this.
                    if (arguments.Count != 1)
                        return true;
 
                    if (!IsGoodElementAccessExpression(argumentList) &&
                        !IsGoodInvocationExpression(argumentList))
                    {
                        // used in something other than accessing into a collection.
                        // can't convert this for-loop.
                        return true;
                    }
                }
 
                // this usage of the for-variable is fine.
            }
 
            foreach (var child in current.ChildNodesAndTokens())
            {
                if (child.AsNode(out var childNode) &&
                    IterationVariableIsUsedForMoreThanCollectionIndex(childNode))
                {
                    return true;
                }
            }
 
            return false;
        }
 
        bool IsGoodElementAccessExpression(SyntaxNode argumentList)
        {
            if (syntaxFacts.IsElementAccessExpression(argumentList.Parent))
            {
                var expr = syntaxFacts.GetExpressionOfElementAccessExpression(argumentList.Parent);
 
                // Have to be indexing into the collection.
                if (syntaxFacts.AreEquivalent(expr, collectionExpression))
                    return true;
            }
 
            return false;
        }
 
        bool IsGoodInvocationExpression(SyntaxNode argumentList)
        {
            if (syntaxFacts.IsInvocationExpression(argumentList.Parent))
            {
                var invokedExpression = syntaxFacts.GetExpressionOfInvocationExpression(argumentList.Parent);
                if (syntaxFacts.IsMemberAccessExpression(invokedExpression))
                {
                    syntaxFacts.GetPartsOfMemberAccessExpression(invokedExpression, out var accessedExpression, out var accessedName);
                    syntaxFacts.GetNameAndArityOfSimpleName(accessedName, out var memberName, out _);
 
                    // Have to be indexing into the collection.
                    if (memberName == nameof(Enumerable.ElementAt) &&
                        syntaxFacts.AreEquivalent(accessedExpression, collectionExpression))
                    {
                        return true;
                    }
                }
            }
 
            return false;
        }
    }
 
    private static IEnumerable<TSymbol> TryFindMembersInThisOrBaseTypes<TSymbol>(
        INamedTypeSymbol containingType, ITypeSymbol type, string memberName) where TSymbol : class, ISymbol
    {
        var methods = type.GetAccessibleMembersInThisAndBaseTypes<TSymbol>(containingType);
        return methods.Where(m => m.Name == memberName);
    }
 
    private static TSymbol? TryFindMemberInThisOrBaseTypes<TSymbol>(
        INamedTypeSymbol containingType, ITypeSymbol type, string memberName) where TSymbol : class, ISymbol
    {
        return TryFindMembersInThisOrBaseTypes<TSymbol>(containingType, type, memberName).FirstOrDefault();
    }
 
    private static bool TryGetIterationElementType(
        INamedTypeSymbol containingType, ITypeSymbol collectionType,
        INamedTypeSymbol ienumerableType, INamedTypeSymbol ienumeratorType,
        [NotNullWhen(true)] out ITypeSymbol? iterationType)
    {
        if (collectionType is IArrayTypeSymbol arrayType)
        {
            iterationType = arrayType.ElementType;
 
            // We only support single-dimensional array iteration.
            return arrayType.Rank == 1;
        }
 
        // Check in the class/struct hierarchy first.
        var getEnumeratorMethod = TryFindMemberInThisOrBaseTypes<IMethodSymbol>(
            containingType, collectionType, WellKnownMemberNames.GetEnumeratorMethodName);
        if (getEnumeratorMethod != null)
        {
            return TryGetIterationElementTypeFromGetEnumerator(
                containingType, getEnumeratorMethod, ienumeratorType, out iterationType);
        }
 
        // couldn't find .GetEnumerator on the class/struct.  Check the interface hierarchy.
        var instantiatedIEnumerableType = collectionType.GetAllInterfacesIncludingThis().FirstOrDefault(
            t => Equals(t.OriginalDefinition, ienumerableType));
 
        if (instantiatedIEnumerableType != null)
        {
            iterationType = instantiatedIEnumerableType.TypeArguments[0];
            return true;
        }
 
        iterationType = null;
        return false;
    }
 
    private static bool TryGetIterationElementTypeFromGetEnumerator(
        INamedTypeSymbol containingType, IMethodSymbol getEnumeratorMethod,
        INamedTypeSymbol ienumeratorType, [NotNullWhen(true)] out ITypeSymbol? iterationType)
    {
        var getEnumeratorReturnType = getEnumeratorMethod.ReturnType;
 
        // Check in the class/struct hierarchy first.
        var currentProperty = TryFindMemberInThisOrBaseTypes<IPropertySymbol>(
            containingType, getEnumeratorReturnType, WellKnownMemberNames.CurrentPropertyName);
        if (currentProperty != null)
        {
            iterationType = currentProperty.Type;
            return true;
        }
 
        // couldn't find .Current on the class/struct.  Check the interface hierarchy.
        var instantiatedIEnumeratorType = getEnumeratorReturnType.GetAllInterfacesIncludingThis().FirstOrDefault(
            t => Equals(t.OriginalDefinition, ienumeratorType));
 
        if (instantiatedIEnumeratorType != null)
        {
            iterationType = instantiatedIEnumeratorType.TypeArguments[0];
            return true;
        }
 
        iterationType = null;
        return false;
    }
 
    private async Task<Document> ConvertForToForEachAsync(
        Document document,
        TForStatementSyntax forStatement,
        SyntaxToken iterationVariable,
        TExpressionSyntax collectionExpression,
        INamedTypeSymbol containingType,
        ITypeSymbol collectionType,
        ITypeSymbol iterationType,
        CancellationToken cancellationToken)
    {
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var semanticFacts = document.GetRequiredLanguageService<ISemanticFactsService>();
        var generator = SyntaxGenerator.GetGenerator(document);
 
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var editor = new SyntaxEditor(root, generator);
 
        // create dummy "list[i]" and "list.ElementAt(i)" expressions.  We'll use this to find all places to replace
        // in the current for statement.
        var indexExpression = generator.ElementAccessExpression(collectionExpression, generator.IdentifierName(iterationVariable));
        var elementAtExpression = generator.InvocationExpression(
            generator.MemberAccessExpression(collectionExpression, generator.IdentifierName(nameof(Enumerable.ElementAt))),
            generator.IdentifierName(iterationVariable));
 
        // See if the first statement in the for loop is of the form:
        //      var x = list[i]   or
        //
        // If so, we'll use those as the iteration variables for the new foreach statement.
        var (typeNode, foreachIdentifier, declarationStatement) = TryDeconstructInitialDeclaration();
 
        if (typeNode == null)
        {
            // user didn't provide an explicit type.  Check if the index-type of the collection
            // is different from than .Current type of the enumerator.  If so, add an explicit
            // type so that the foreach will coerce the types accordingly.
            var indexerType = GetIndexerType(containingType, collectionType, semanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T));
            if (!Equals(indexerType, iterationType))
            {
                typeNode = (TTypeNode)generator.TypeExpression(
                    indexerType ?? semanticModel.Compilation.GetSpecialType(SpecialType.System_Object));
            }
        }
 
        // If we couldn't find an appropriate existing variable to use as the foreach
        // variable, then generate one automatically.
        if (foreachIdentifier.RawKind == 0)
        {
            foreachIdentifier = semanticFacts.GenerateUniqueName(
                semanticModel, forStatement, container: null, baseName: "v", usedNames: [], cancellationToken);
            foreachIdentifier = foreachIdentifier.WithAdditionalAnnotations(RenameAnnotation.Create());
        }
 
        // Create the expression we'll use to replace all matches in the for-body.
        var foreachIdentifierReference = foreachIdentifier.WithoutAnnotations(RenameAnnotation.Kind).WithoutTrivia();
 
        // Walk the for statement, replacing any matches we find.
        FindAndReplaceMatches(forStatement);
 
        // Finally, remove the declaration statement if we found one.  Move all its leading
        // trivia to the next statement.
        if (declarationStatement != null)
        {
            editor.RemoveNode(declarationStatement,
                SyntaxGenerator.DefaultRemoveOptions | SyntaxRemoveOptions.KeepLeadingTrivia);
        }
 
        editor.ReplaceNode(
            forStatement,
            (currentFor, _) => ConvertForNode(
                (TForStatementSyntax)currentFor, typeNode, foreachIdentifier,
                collectionExpression, iterationType));
 
        return document.WithSyntaxRoot(editor.GetChangedRoot());
 
        // local functions
        (TTypeNode?, SyntaxToken, TStatementSyntax) TryDeconstructInitialDeclaration()
        {
            var bodyStatements = GetBodyStatements(forStatement);
 
            if (bodyStatements.Count >= 1)
            {
                var firstStatement = bodyStatements[0];
                if (syntaxFacts.IsLocalDeclarationStatement(firstStatement))
                {
                    var variables = syntaxFacts.GetVariablesOfLocalDeclarationStatement(firstStatement);
                    if (variables.Count == 1)
                    {
                        var firstVariable = (TVariableDeclaratorSyntax)variables[0];
                        if (IsValidVariableDeclarator(firstVariable))
                        {
                            var initializer = syntaxFacts.GetInitializerOfVariableDeclarator(firstVariable);
                            if (initializer != null)
                            {
                                var firstVariableInitializer = syntaxFacts.GetValueOfEqualsValueClause(initializer);
                                if (syntaxFacts.AreEquivalent(firstVariableInitializer, indexExpression))
                                {
                                    var type = (TTypeNode?)syntaxFacts.GetTypeOfVariableDeclarator(firstVariable)?.WithoutLeadingTrivia();
                                    var identifier = syntaxFacts.GetIdentifierOfVariableDeclarator(firstVariable);
                                    var statement = firstStatement;
                                    return (type, identifier, statement);
                                }
                            }
                        }
                    }
                }
            }
 
            return default;
        }
 
        void FindAndReplaceMatches(SyntaxNode current)
        {
            if (SemanticEquivalence.AreEquivalent(semanticModel, current, collectionExpression))
            {
                if (syntaxFacts.AreEquivalent(current.Parent, indexExpression))
                {
                    // Found a match.  replace with iteration variable.
                    var indexMatch = current.GetRequiredParent();
                    Replace(indexMatch);
                }
                else if (syntaxFacts.AreEquivalent(current.Parent?.Parent, elementAtExpression))
                {
                    // Found a match.  replace with iteration variable.
                    var indexMatch = current.GetRequiredParent().GetRequiredParent();
                    Replace(indexMatch);
                }
                else
                {
                    // Collection was used for some other purpose.  If it's passed as an argument
                    // to something, or is written to, or has a method invoked on it, we'll warn
                    // that it's potentially changing and may break if you switch to a foreach loop.
                    var shouldWarn = syntaxFacts.IsArgument(current.Parent);
                    shouldWarn |= semanticFacts.IsWrittenTo(semanticModel, current, cancellationToken);
                    shouldWarn |=
                        syntaxFacts.IsMemberAccessExpression(current.Parent) &&
                        syntaxFacts.IsInvocationExpression(current.Parent.Parent);
 
                    if (shouldWarn)
                    {
                        editor.ReplaceNode(
                            current,
                            (node, _) => node.WithAdditionalAnnotations(
                                WarningAnnotation.Create(FeaturesResources.Warning_colon_Iteration_variable_crossed_function_boundary)));
                    }
                }
 
                return;
            }
 
            foreach (var child in current.ChildNodesAndTokens())
            {
                if (child.AsNode(out var childNode))
                    FindAndReplaceMatches(childNode);
            }
        }
 
        bool CrossesFunctionBoundary(SyntaxNode node)
        {
            var containingFunction = node.AncestorsAndSelf().FirstOrDefault(
                n => syntaxFacts.IsLocalFunctionStatement(n) || syntaxFacts.IsAnonymousFunctionExpression(n));
 
            if (containingFunction == null)
                return false;
 
            return containingFunction.AncestorsAndSelf().Contains(forStatement);
        }
 
        void Replace(SyntaxNode indexMatch)
        {
            var replacementToken = foreachIdentifierReference;
 
            if (semanticFacts.IsWrittenTo(semanticModel, indexMatch, cancellationToken))
            {
                replacementToken = replacementToken.WithAdditionalAnnotations(
                    WarningAnnotation.Create(FeaturesResources.Warning_colon_Collection_was_modified_during_iteration));
            }
 
            if (CrossesFunctionBoundary(indexMatch))
            {
                replacementToken = replacementToken.WithAdditionalAnnotations(
                    WarningAnnotation.Create(FeaturesResources.Warning_colon_Iteration_variable_crossed_function_boundary));
            }
 
            editor.ReplaceNode(
                indexMatch,
                generator.IdentifierName(replacementToken).WithTriviaFrom(indexMatch));
        }
    }
 
    private static ITypeSymbol? GetIndexerType(
        INamedTypeSymbol containingType,
        ITypeSymbol collectionType,
        INamedTypeSymbol ienumerableType)
    {
        if (collectionType is IArrayTypeSymbol arrayType)
            return arrayType.Rank == 1 ? arrayType.ElementType : null;
 
        var indexer = collectionType
            .GetAccessibleMembersInThisAndBaseTypes<IPropertySymbol>(containingType)
            .Where(IsViableIndexer)
            .FirstOrDefault();
 
        if (indexer?.Type != null)
            return indexer.Type;
 
        if (collectionType.IsInterfaceType())
        {
            var interfaces = collectionType.GetAllInterfacesIncludingThis();
            indexer = interfaces.SelectMany(i => i.GetMembers().OfType<IPropertySymbol>().Where(IsViableIndexer)).FirstOrDefault();
 
            if (indexer?.Type != null)
                return indexer.Type;
        }
 
        foreach (var interfaceType in collectionType.GetAllInterfacesIncludingThis())
        {
            if (Equals(interfaceType.OriginalDefinition, ienumerableType))
                return interfaceType.TypeArguments[0];
        }
 
        return null;
    }
 
    private static bool IsViableIndexer(IPropertySymbol property)
        => property is { IsIndexer: true, Parameters: [{ Type.SpecialType: SpecialType.System_Int32 }] };
}