File: ConvertLinq\ConvertForEachToLinqQuery\AbstractConvertForEachToLinqQueryProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Shared.Extensions;
 
namespace Microsoft.CodeAnalysis.ConvertLinq.ConvertForEachToLinqQuery;
 
internal abstract class AbstractConvertForEachToLinqQueryProvider<TForEachStatement, TStatement> : CodeRefactoringProvider
    where TForEachStatement : TStatement
    where TStatement : SyntaxNode
{
    /// <summary>
    /// Parses the forEachStatement until a child node cannot be converted into a query clause.
    /// </summary>
    /// <param name="forEachStatement"></param>
    /// <returns>
    /// 1) extended nodes that can be converted into clauses
    /// 2) identifiers introduced in nodes that can be converted
    /// 3) statements that cannot be converted into clauses
    /// 4) trailing comments to be added to the end
    /// </returns>
    protected abstract ForEachInfo<TForEachStatement, TStatement> CreateForEachInfo(
        TForEachStatement forEachStatement, SemanticModel semanticModel, bool convertLocalDeclarations);
 
    /// <summary>
    /// Tries to build a specific converter that covers e.g. Count, ToList, yield return or other similar cases.
    /// </summary>
    protected abstract bool TryBuildSpecificConverter(
        ForEachInfo<TForEachStatement, TStatement> forEachInfo,
        SemanticModel semanticModel,
        TStatement statementCannotBeConverted,
        CancellationToken cancellationToken,
        [NotNullWhen(true)] out IConverter<TForEachStatement, TStatement>? converter);
 
    /// <summary>
    /// Creates a default converter where foreach is joined with some children statements but other children statements are kept unmodified.
    /// Example:
    /// foreach(...)
    /// {
    ///     if (condition)
    ///     {
    ///        doSomething();
    ///     }
    /// }
    /// is converted to
    /// foreach(... where condition)
    /// {
    ///        doSomething(); 
    /// }
    /// </summary>
    protected abstract IConverter<TForEachStatement, TStatement> CreateDefaultConverter(
        ForEachInfo<TForEachStatement, TStatement> forEachInfo);
 
    protected abstract SyntaxNode AddLinqUsing(
        IConverter<TForEachStatement, TStatement> converter, SemanticModel semanticModel, SyntaxNode root);
 
    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, _, cancellationToken) = context;
 
        var forEachStatement = await context.TryGetRelevantNodeAsync<TForEachStatement>().ConfigureAwait(false);
        if (forEachStatement == null)
        {
            return;
        }
 
        if (forEachStatement.ContainsDiagnostics)
        {
            return;
        }
 
        // Do not try to refactor queries with conditional compilation in them.
        if (forEachStatement.ContainsDirectives)
        {
            return;
        }
 
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
 
        if (!TryBuildConverter(forEachStatement, semanticModel, convertLocalDeclarations: true, cancellationToken, out var queryConverter))
        {
            return;
        }
 
        if (semanticModel.GetDiagnostics(forEachStatement.Span, cancellationToken)
            .Any(static diagnostic => diagnostic.DefaultSeverity == DiagnosticSeverity.Error))
        {
            return;
        }
 
        // Offer refactoring to convert foreach to LINQ query expression. For example:
        //
        // INPUT:
        //  foreach (var n1 in c1)
        //      foreach (var n2 in c2)
        //          if (n1 > n2)
        //              yield return n1 + n2;
        //
        // OUTPUT:
        //  from n1 in c1
        //  from n2 in c2
        //  where n1 > n2
        //  select n1 + n2
        //
        context.RegisterRefactoring(
            CodeAction.Create(
                FeaturesResources.Convert_to_linq,
                c => ApplyConversionAsync(queryConverter, document, convertToQuery: true, c),
                nameof(FeaturesResources.Convert_to_linq)),
            forEachStatement.Span);
 
        // Offer refactoring to convert foreach to LINQ invocation expression. For example:
        //
        // INPUT:
        //   foreach (var n1 in c1)
        //      foreach (var n2 in c2)
        //          if (n1 > n2)
        //              yield return n1 + n2;
        //
        // OUTPUT:
        //   c1.SelectMany(n1 => c2.Where(n2 => n1 > n2).Select(n2 => n1 + n2))
        //
        // CONSIDER: Currently, we do not handle foreach statements with a local declaration statement for Convert_to_linq refactoring,
        //           though we do handle them for the Convert_to_query refactoring above.
        //           In future, we can consider supporting them by creating anonymous objects/tuples wrapping the local declarations.
        if (TryBuildConverter(forEachStatement, semanticModel, convertLocalDeclarations: false, cancellationToken, out var linqConverter))
        {
            context.RegisterRefactoring(
                CodeAction.Create(
                    FeaturesResources.Convert_to_linq_call_form,
                    c => ApplyConversionAsync(linqConverter, document, convertToQuery: false, c),
                    nameof(FeaturesResources.Convert_to_linq_call_form)),
                forEachStatement.Span);
        }
    }
 
    private Task<Document> ApplyConversionAsync(
        IConverter<TForEachStatement, TStatement> converter,
        Document document,
        bool convertToQuery,
        CancellationToken cancellationToken)
    {
        var editor = new SyntaxEditor(converter.ForEachInfo.SemanticModel.SyntaxTree.GetRoot(cancellationToken), document.Project.Solution.Services);
        converter.Convert(editor, convertToQuery, cancellationToken);
        var newRoot = editor.GetChangedRoot();
        var rootWithLinqUsing = AddLinqUsing(converter, converter.ForEachInfo.SemanticModel, newRoot);
        return Task.FromResult(document.WithSyntaxRoot(rootWithLinqUsing));
    }
 
    /// <summary>
    /// Tries to build either a specific or a default converter.
    /// </summary>
    private bool TryBuildConverter(
        TForEachStatement forEachStatement,
        SemanticModel semanticModel,
        bool convertLocalDeclarations,
        CancellationToken cancellationToken,
        [NotNullWhen(true)] out IConverter<TForEachStatement, TStatement>? converter)
    {
        var forEachInfo = CreateForEachInfo(forEachStatement, semanticModel, convertLocalDeclarations);
 
        if (forEachInfo.Statements.Length == 1 &&
            TryBuildSpecificConverter(forEachInfo, semanticModel, forEachInfo.Statements.Single(), cancellationToken, out converter))
        {
            return true;
        }
 
        // No sense to convert a single foreach to foreach over the same collection.
        if (forEachInfo.ConvertingExtendedNodes.Length >= 1)
        {
            converter = CreateDefaultConverter(forEachInfo);
            return true;
        }
 
        converter = null;
        return false;
    }
}