File: ConvertForEachToFor\AbstractConvertForEachToForCodeRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Simplification;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.ConvertForEachToFor;
 
internal abstract class AbstractConvertForEachToForCodeRefactoringProvider<
    TStatementSyntax,
    TForEachStatement> : CodeRefactoringProvider
    where TStatementSyntax : SyntaxNode
    where TForEachStatement : TStatementSyntax
{
    // Metadata names of the accessor methods that GetInterfaceInfo looks up on candidate
    // collection interfaces (nameof yields the literal strings "get_Count" and "get_Item").
    private const string get_Count = nameof(get_Count);
    private const string get_Item = nameof(get_Item);
 
    // Member names reported as the "count" member of the collection: Length is used for arrays,
    // strings and ImmutableArray, Count for the known list interfaces.
    private const string Length = nameof(Array.Length);
    private const string Count = nameof(IList.Count);
 
    // Full metadata names of the list interfaces this refactoring recognizes. They are resolved
    // against the current compilation via GetTypeByMetadataName in GetInterfaceInfo.
    private static readonly ImmutableArray<string> s_KnownInterfaceNames =
        [typeof(IList<>).FullName!, typeof(IReadOnlyList<>).FullName!, typeof(IList).FullName!];
 
    /// <summary>
    /// Whether the foreach variable is written inside the loop body. Assigned by GetForeachInfo each
    /// time a foreach statement is analyzed; it is true when data flow analysis does not succeed.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this is instance state overwritten by every analysis, and it is read later when the
    /// code action runs; confirm a provider instance is never used for two foreach statements concurrently.
    /// </remarks>
    protected bool IsForEachVariableWrittenInside { get; private set; }

    /// <summary>
    /// Text passed to CodeAction.Create both as the action title and as its third argument.
    /// </summary>
    protected abstract string Title { get; }

    /// <summary>
    /// Language specific check on the analyzed foreach; when it returns false no refactoring is registered.
    /// </summary>
    protected abstract bool ValidLocation(ForEachInfo foreachInfo);

    /// <summary>
    /// Returns the first and last nodes of the loop body; they are handed to data flow analysis.
    /// Callers treat a null start or end as an empty body.
    /// </summary>
    protected abstract (SyntaxNode start, SyntaxNode end) GetForEachBody(TForEachStatement foreachStatement);

    /// <summary>
    /// Language specific rewrite of the foreach described by <paramref name="info"/>. The changed root
    /// is read back from <paramref name="editor"/> after this call returns.
    /// </summary>
    protected abstract void ConvertToForStatement(
        SemanticModel model, ForEachInfo info, SyntaxEditor editor, CancellationToken cancellationToken);

    /// <summary>
    /// Language specific pre-check that runs before the semantic model is requested; when it returns
    /// false no refactoring is registered.
    /// </summary>
    protected abstract bool IsValid(TForEachStatement foreachNode);
 
    /// <summary>
    /// Performs language specific checks to decide whether the conversion is supported.
    /// C#: currently nothing blocks a conversion.
    /// VB: nested foreach loops sharing a single Next statement, Next statements with multiple variables,
    /// and Next statements not using the loop variable are not supported.
    /// </summary>
    protected abstract bool IsSupported(ILocalSymbol foreachVariable, IForEachLoopOperation forEachOperation, TForEachStatement foreachStatement);
 
    /// <summary>
    /// Creates a warning annotation whose message is the
    /// Warning_colon_semantics_may_change_when_converting_statement resource string.
    /// </summary>
    protected static SyntaxAnnotation CreateWarningAnnotation()
    {
        var message = FeaturesResources.Warning_colon_semantics_may_change_when_converting_statement;
        return WarningAnnotation.Create(message);
    }
 
    /// <summary>
    /// Offers the "convert foreach to for" code action when the relevant node is a foreach statement
    /// that passes the syntactic check, the semantic analysis and the location check.
    /// </summary>
    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var document = context.Document;
        var cancellationToken = context.CancellationToken;

        var loopStatement = await context.TryGetRelevantNodeAsync<TForEachStatement>().ConfigureAwait(false);
        if (loopStatement is null)
        {
            return;
        }

        // cheap language specific check first, before paying for the semantic model
        if (!IsValid(loopStatement))
        {
            return;
        }

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var semanticFacts = document.GetRequiredLanguageService<ISemanticFactsService>();

        var info = GetForeachInfo(semanticFacts, semanticModel, loopStatement, cancellationToken);
        if (info is null)
        {
            return;
        }

        if (!ValidLocation(info))
        {
            return;
        }

        var action = CodeAction.Create(
            Title,
            token => ConvertForeachToForAsync(document, info, token),
            Title);

        context.RegisterRefactoring(action, loopStatement.Span);
    }
 
    /// <summary>
    /// Asks the semantic facts service for a local name based on <paramref name="baseName"/> that is
    /// unique at <paramref name="location"/>; no container node is supplied.
    /// </summary>
    protected static SyntaxToken CreateUniqueName(
        ISemanticFactsService semanticFacts, SemanticModel model, SyntaxNode location, string baseName, CancellationToken cancellationToken)
    {
        return semanticFacts.GenerateUniqueLocalName(model, location, container: null, baseName, cancellationToken);
    }
 
    /// <summary>
    /// Returns the expression used to refer to the collection inside the generated loop: the original
    /// collection expression (trivia removed, formatter annotation added) when no collection statement
    /// is required, otherwise an identifier with a freshly generated unique name.
    /// </summary>
    protected static SyntaxNode GetCollectionVariableName(
        SemanticModel model, SyntaxGenerator generator,
        ForEachInfo foreachInfo, SyntaxNode foreachCollectionExpression, CancellationToken cancellationToken)
    {
        if (!foreachInfo.RequireCollectionStatement)
        {
            // the user's expression can be indexed directly
            return foreachCollectionExpression.WithoutTrivia().WithAdditionalAnnotations(Formatter.Annotation);
        }

        var uniqueName = CreateUniqueName(
            foreachInfo.SemanticFacts,
            model,
            foreachInfo.ForEachStatement,
            foreachInfo.CollectionNameSuggestion,
            cancellationToken);

        return generator.IdentifierName(uniqueName);
    }
 
    /// <summary>
    /// When the foreach needs its collection captured in a local, inserts a local declaration for
    /// <paramref name="collectionVariable"/> directly before the foreach statement. The initializer is the
    /// collection expression, cast to the explicit interface when one is recorded. Does nothing otherwise.
    /// </summary>
    protected static void IntroduceCollectionStatement(
        ForEachInfo foreachInfo, SyntaxEditor editor,
        SyntaxNode type, SyntaxNode foreachCollectionExpression, SyntaxNode collectionVariable)
    {
        if (!foreachInfo.RequireCollectionStatement)
        {
            return;
        }

        // TODO: refactor introduce variable refactoring to real service and use that service here to introduce local variable
        var generator = editor.Generator;

        // the new local carries a rename annotation
        var renamableName = generator
            .Identifier(collectionVariable.ToString())
            .WithAdditionalAnnotations(RenameAnnotation.Create());

        // The expression comes from user code and is not meant to be simplified.
        // NOTE(review): the call below removes the DoNotSimplify annotation rather than adding it — confirm intent.
        var userExpression = foreachCollectionExpression.WithoutAnnotations(SimplificationHelpers.DoNotSimplifyAnnotation);

        var castInterface = foreachInfo.ExplicitCastInterface;
        var initializer = castInterface is null
            ? userExpression
            : generator.CastExpression(castInterface, userExpression);

        // the declaration takes over the leading trivia of the foreach statement's first token
        var declaration = generator
            .LocalDeclarationStatement(type, renamableName, initializer)
            .WithLeadingTrivia(foreachInfo.ForEachStatement.GetFirstToken().LeadingTrivia);

        editor.InsertBefore(foreachInfo.ForEachStatement, declaration);
    }
 
    /// <summary>
    /// Builds the statement that declares the former foreach variable inside the for loop, initialized
    /// with <c>collectionVariable[indexVariable]</c> and cast to <paramref name="castType"/> when it is
    /// not null. The result carries the formatter annotation.
    /// </summary>
    protected static TStatementSyntax AddItemVariableDeclaration(
        SyntaxGenerator generator, SyntaxNode type, SyntaxToken foreachVariable,
        ITypeSymbol castType, SyntaxNode collectionVariable, SyntaxToken indexVariable)
    {
        var indexExpression = generator.IdentifierName(indexVariable);
        var elementAccess = generator.ElementAccessExpression(collectionVariable, indexExpression);

        var initializer = castType is null
            ? elementAccess
            : generator.CastExpression(castType, elementAccess);

        var declaration = generator.LocalDeclarationStatement(type, foreachVariable, initializer);
        return (TStatementSyntax)declaration.WithAdditionalAnnotations(Formatter.Annotation);
    }
 
    /// <summary>
    /// Analyzes <paramref name="foreachStatement"/> and returns the data needed for the conversion, or
    /// null when the statement cannot be converted. As a side effect it updates
    /// IsForEachVariableWrittenInside once the language specific support check has passed.
    /// </summary>
    private ForEachInfo? GetForeachInfo(
        ISemanticFactsService semanticFact, SemanticModel model,
        TForEachStatement foreachStatement, CancellationToken cancellationToken)
    {
        // only foreach loops that declare exactly one local are handled
        if (model.GetOperation(foreachStatement, cancellationToken) is not IForEachLoopOperation { Locals.Length: 1 } loopOperation)
        {
            return null;
        }

        var loopVariable = loopOperation.Locals[0];
        if (loopVariable is null)
        {
            return null;
        }

        // language specific veto for constructs the rewrite cannot express
        if (!IsSupported(loopVariable, loopOperation, foreachStatement))
        {
            return null;
        }

        IsForEachVariableWrittenInside = CheckIfForEachVariableIsWrittenInside(model, loopVariable, foreachStatement);

        var collection = RemoveImplicitConversion(loopOperation.Collection);
        if (collection is null)
        {
            return null;
        }

        GetInterfaceInfo(
            model, loopVariable, collection,
            out var castInterface, out var nameSuggestion, out var countMemberName);

        if (nameSuggestion is null || countMemberName is null)
        {
            // the collection type offers no supported way to count and index its elements
            return null;
        }

        return new ForEachInfo(
            semanticFact, nameSuggestion, countMemberName, castInterface,
            loopVariable.Type, CheckRequireCollectionStatement(collection), foreachStatement);
    }
 
    /// <summary>
    /// Works out how the foreach collection can be counted and indexed.
    /// </summary>
    /// <param name="model">Semantic model whose compilation is used to resolve types and conversions.</param>
    /// <param name="foreachVariable">The loop variable; its type must be convertible to or from the element type.</param>
    /// <param name="foreachCollection">The collection operation being enumerated.</param>
    /// <param name="explicitCastInterface">
    /// Set to the first matching known interface when the collection type only implements the count and
    /// indexer members explicitly; null in every other case.
    /// </param>
    /// <param name="collectionNameSuggestion">
    /// "array" for arrays and ImmutableArray, "str" for strings, "list" once the known-interface checks are
    /// reached. It may be "list" even when nothing matched, so callers must also check
    /// <paramref name="countName"/>.
    /// </param>
    /// <param name="countName">
    /// Name of the member giving the element count (Length or Count), or null when the collection is not supported.
    /// </param>
    private static void GetInterfaceInfo(
        SemanticModel model, ILocalSymbol foreachVariable, IOperation foreachCollection,
        out ITypeSymbol? explicitCastInterface, out string? collectionNameSuggestion, out string? countName)
    {
        explicitCastInterface = null;
        collectionNameSuggestion = null;
        countName = null;

        // go through list of types and interfaces to find out right set;
        var foreachType = foreachVariable.Type;
        if (IsNullOrErrorType(foreachType))
        {
            return;
        }

        var collectionType = foreachCollection.Type;
        if (IsNullOrErrorType(collectionType))
        {
            return;
        }

        // go through explicit types first.

        // check array case
        if (collectionType is IArrayTypeSymbol array)
        {
            if (array.Rank != 1)
            {
                // array type supports IList and other interfaces, but implementation
                // only supports Rank == 1 case. other case, it will throw on runtime
                // even if there is no error on compile time.
                // we explicitly mark that we only support Rank == 1 case
                return;
            }

            if (!IsExchangable(array.ElementType, foreachType, model.Compilation))
            {
                return;
            }

            collectionNameSuggestion = "array";
            explicitCastInterface = null;
            countName = Length;
            return;
        }

        // check string case
        if (collectionType.SpecialType == SpecialType.System_String)
        {
            var charType = model.Compilation.GetSpecialType(SpecialType.System_Char);
            if (!IsExchangable(charType, foreachType, model.Compilation))
            {
                return;
            }

            collectionNameSuggestion = "str";
            explicitCastInterface = null;
            countName = Length;
            return;
        }

        // check ImmutableArray case
        if (collectionType.OriginalDefinition.Equals(model.Compilation.GetTypeByMetadataName(typeof(ImmutableArray<>).FullName!)))
        {
            var indexer = GetInterfaceMember(collectionType, get_Item);
            if (indexer != null)
            {
                if (!IsExchangable(indexer.ReturnType, foreachType, model.Compilation))
                {
                    return;
                }

                collectionNameSuggestion = "array";
                explicitCastInterface = null;
                countName = Length;
                return;
            }
        }

        // go through all known interfaces we support next.
        // Resolve them once up front: the set is consulted for the collection type itself and again
        // for every interface it implements, so a deferred query here would repeat the metadata
        // lookups on each membership check.
        var knownCollectionInterfaces = new List<ITypeSymbol>(s_KnownInterfaceNames.Length);
        foreach (var interfaceName in s_KnownInterfaceNames)
        {
            var interfaceType = model.Compilation.GetTypeByMetadataName(interfaceName);
            if (!IsNullOrErrorType(interfaceType))
            {
                knownCollectionInterfaces.Add(interfaceType);
            }
        }

        // for all interfaces, we suggest collection name as "list"
        collectionNameSuggestion = "list";

        // check type itself is interface case
        if (collectionType.TypeKind == TypeKind.Interface && knownCollectionInterfaces.Contains(collectionType.OriginalDefinition))
        {
            var indexer = GetInterfaceMember(collectionType, get_Item);
            if (indexer != null &&
                IsExchangable(indexer.ReturnType, foreachType, model.Compilation))
            {
                explicitCastInterface = null;
                countName = Count;
                return;
            }
        }

        // check regular cases (implicitly implemented)
        ITypeSymbol? explicitInterface = null;
        foreach (var current in collectionType.AllInterfaces)
        {
            if (!knownCollectionInterfaces.Contains(current.OriginalDefinition))
            {
                continue;
            }

            // see how the type implements the interface
            var countSymbol = GetInterfaceMember(current, get_Count);
            var indexerSymbol = GetInterfaceMember(current, get_Item);
            if (countSymbol == null || indexerSymbol == null)
            {
                continue;
            }

            if (collectionType.FindImplementationForInterfaceMember(countSymbol) is not IMethodSymbol countImpl ||
                collectionType.FindImplementationForInterfaceMember(indexerSymbol) is not IMethodSymbol indexerImpl)
            {
                continue;
            }

            if (!IsExchangable(indexerImpl.ReturnType, foreachType, model.Compilation))
            {
                continue;
            }

            // implicitly implemented!
            if (countImpl.ExplicitInterfaceImplementations.IsEmpty &&
                indexerImpl.ExplicitInterfaceImplementations.IsEmpty)
            {
                explicitCastInterface = null;
                countName = Count;
                return;
            }

            // remember only the first explicitly implemented candidate; keep looking for an implicit one
            explicitInterface ??= current;
        }

        // okay, we don't have implicitly implemented one, but we do have explicitly implemented one
        if (explicitInterface != null)
        {
            explicitCastInterface = explicitInterface;
            countName = Count;
        }
    }
 
    /// <summary>
    /// True when an implicit conversion exists between the two types in either direction.
    /// </summary>
    private static bool IsExchangable(
        ITypeSymbol type1, ITypeSymbol type2, Compilation compilation)
    {
        if (compilation.HasImplicitConversion(type1, type2))
        {
            return true;
        }

        return compilation.HasImplicitConversion(type2, type1);
    }
 
    /// <summary>
    /// True when <paramref name="type"/> is null or is an error type symbol.
    /// </summary>
    private static bool IsNullOrErrorType([NotNullWhen(false)] ITypeSymbol? type)
    {
        return type is null || type is IErrorTypeSymbol;
    }
 
    /// <summary>
    /// Searches <paramref name="interfaceType"/> and all of its interfaces for members named
    /// <paramref name="memberName"/> and returns the first hit whose leading member is a method.
    /// Returns null when there is none.
    /// </summary>
    private static IMethodSymbol? GetInterfaceMember(ITypeSymbol interfaceType, string memberName)
    {
        foreach (var candidate in interfaceType.GetAllInterfacesIncludingThis())
        {
            var named = candidate.GetMembers(memberName);

            // only the first member with that name is inspected
            if (named.Length > 0 && named[0] is IMethodSymbol method)
            {
                return method;
            }
        }

        return null;
    }
 
    /// <summary>
    /// Decides whether the collection expression must first be stored in a new local variable.
    /// </summary>
    private static bool CheckRequireCollectionStatement(IOperation operation)
    {
        // The reference kinds listed below are used as they are in
        //    var element = reference[indexer];
        //
        // For any other kind of expression a local variable is introduced first and the
        // "foreach to for" refactoring is applied to that local, e.g.
        //
        // foreach(var a in new int[] {....})
        // becomes
        // var array = new int[] { ... }
        // foreach(var a in array)
        return operation.Kind switch
        {
            OperationKind.LocalReference or
            OperationKind.FieldReference or
            OperationKind.ParameterReference or
            OperationKind.PropertyReference or
            OperationKind.ArrayElementReference => false,
            _ => true,
        };
    }
 
    /// <summary>
    /// Unwraps every outer conversion operation whose IsImplicit flag is set and returns the first
    /// operation that is not such a conversion.
    /// </summary>
    private static IOperation RemoveImplicitConversion(IOperation collection)
    {
        var current = collection;
        while (current is IConversionOperation { IsImplicit: true } conversion)
        {
            current = conversion.Operand;
        }

        return current;
    }
 
    /// <summary>
    /// Reports whether <paramref name="foreachVariable"/> is written inside the loop body, based on
    /// data flow analysis over the body returned by GetForEachBody.
    /// </summary>
    private bool CheckIfForEachVariableIsWrittenInside(SemanticModel semanticModel, ISymbol foreachVariable, TForEachStatement foreachStatement)
    {
        var (bodyStart, bodyEnd) = GetForEachBody(foreachStatement);
        if (bodyStart is null || bodyEnd is null)
        {
            // empty body (this can happen in VB): nothing can be written
            return false;
        }

        var analysis = semanticModel.AnalyzeDataFlow(bodyStart, bodyEnd);

        // without a usable analysis, assume the variable is written
        return !analysis.Succeeded || analysis.WrittenInside.Contains(foreachVariable);
    }
 
    /// <summary>
    /// Runs the language specific conversion through a syntax editor over the document's root and
    /// returns the document with the changed root.
    /// </summary>
    private async Task<Document> ConvertForeachToForAsync(
        Document document,
        ForEachInfo foreachInfo,
        CancellationToken cancellationToken)
    {
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        var root = semanticModel.SyntaxTree.GetRoot(cancellationToken);
        var editor = new SyntaxEditor(root, document.Project.Solution.Services);

        ConvertToForStatement(semanticModel, foreachInfo, editor, cancellationToken);

        return document.WithSyntaxRoot(editor.GetChangedRoot());
    }
 
    /// <summary>
    /// Immutable result of analyzing a foreach statement; it is handed to the language specific conversion.
    /// </summary>
    protected sealed class ForEachInfo
    {
        public ForEachInfo(
            ISemanticFactsService semanticFacts, string collectionNameSuggestion, string countName,
            ITypeSymbol? explicitCastInterface, ITypeSymbol forEachElementType,
            bool requireCollectionStatement, TForEachStatement forEachStatement)
        {
            SemanticFacts = semanticFacts;
            CollectionNameSuggestion = collectionNameSuggestion;
            CountName = countName;
            ExplicitCastInterface = explicitCastInterface;
            ForEachElementType = forEachElementType;

            // an explicit interface cast always forces a separate collection statement
            RequireCollectionStatement = requireCollectionStatement || explicitCastInterface is not null;
            ForEachStatement = forEachStatement;
        }

        public ISemanticFactsService SemanticFacts { get; }

        // base name used when a new local has to be generated for the collection
        public string CollectionNameSuggestion { get; }

        // name of the member that yields the number of elements
        public string CountName { get; }

        // interface the collection is cast to, or null when no cast is needed
        public ITypeSymbol? ExplicitCastInterface { get; }

        public ITypeSymbol ForEachElementType { get; }

        public bool RequireCollectionStatement { get; }

        public TForEachStatement ForEachStatement { get; }
    }
}