File: IntroduceParameter\AbstractIntroduceParameterCodeRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.IntroduceParameter;
 
internal abstract partial class AbstractIntroduceParameterCodeRefactoringProvider<
    TExpressionSyntax,
    TInvocationExpressionSyntax,
    TObjectCreationExpressionSyntax,
    TIdentifierNameSyntax,
    TArgumentSyntax> : CodeRefactoringProvider
    where TExpressionSyntax : SyntaxNode
    where TInvocationExpressionSyntax : TExpressionSyntax
    where TObjectCreationExpressionSyntax : TExpressionSyntax
    where TIdentifierNameSyntax : TExpressionSyntax
    where TArgumentSyntax : SyntaxNode
{
    /// <summary>
    /// Identifies which variant of the introduce-parameter refactoring a code action performs.
    /// The selected value is handed to the document rewriter when the action is applied.
    /// </summary>
    private enum IntroduceParameterCodeActionKind
    {
        // Used by the actions titled with FeaturesResources.and_update_call_sites_directly.
        Refactor,
        // Used by the actions titled with FeaturesResources.into_extracted_method_to_invoke_at_call_sites.
        Trampoline,
        // Used by the actions titled with FeaturesResources.into_new_overload.
        Overload
    }
 
    // Language-specific hooks. Their implementations and most of their callers live outside this
    // part of the partial class, so the notes below are hedged where the contract is not visible here.
    // NOTE(review): presumably builds an expression standing in for an optional parameter (e.g. its
    // default value) - confirm against the language-specific implementations.
    protected abstract SyntaxNode GenerateExpressionFromOptionalParameter(IParameterSymbol parameterSymbol);
    // NOTE(review): presumably returns argumentList with its arguments replaced by the supplied
    // list - confirm against the language-specific implementations.
    protected abstract SyntaxNode UpdateArgumentListSyntax(SyntaxNode argumentList, SeparatedSyntaxList<TArgumentSyntax> arguments);
    // NOTE(review): presumably maps a variable declarator to its enclosing local declaration; the
    // nullable return type indicates there may be none - confirm against implementations.
    protected abstract SyntaxNode? GetLocalDeclarationFromDeclarator(SyntaxNode variableDecl);
    // Consulted by ComputeRefactoringsAsync: when this returns true no refactoring is offered for the method.
    protected abstract bool IsDestructor(IMethodSymbol methodSymbol);
 
    /// <summary>
    /// Registers the "introduce parameter" refactorings for the expression under the selection.
    /// Nothing is registered unless the expression is valid for this refactoring, has a non-error type,
    /// and is contained in an ordinary method, local function, or constructor.
    /// </summary>
    /// <param name="context">Supplies the document, the selected span, and the cancellation token.</param>
    public sealed override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, textSpan, cancellationToken) = context;

        // Documents in the miscellaneous-files workspace are never offered this refactoring.
        if (document.Project.Solution.WorkspaceKind == WorkspaceKind.MiscellaneousFiles)
            return;

        var selectedExpression = await document.TryGetRelevantNodeAsync<TExpressionSyntax>(textSpan, cancellationToken).ConfigureAwait(false);
        if (selectedExpression == null)
            return;

        if (CodeRefactoringHelpers.IsNodeUnderselected(selectedExpression, textSpan))
            return;

        var generator = SyntaxGenerator.GetGenerator(document);
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        if (!IsValidExpression(selectedExpression, syntaxFacts))
            return;

        // The closest enclosing node that owns a parameter list is where the new parameter will go.
        var containingMethod = selectedExpression.FirstAncestorOrSelf<SyntaxNode>(
            candidate => generator.GetParameterListNode(candidate) is not null);
        if (containingMethod is null)
            return;

        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        var expressionType = semanticModel.GetTypeInfo(selectedExpression, cancellationToken).Type;
        if (expressionType is null)
            return;

        if (expressionType is IErrorTypeSymbol)
            return;

        var containingSymbol = semanticModel.GetDeclaredSymbol(containingMethod, cancellationToken);
        if (containingSymbol is not IMethodSymbol methodSymbol)
            return;

        // An expression that is already a parameter of the containing method needs no new parameter.
        if (semanticModel.GetSymbolInfo(selectedExpression, cancellationToken).Symbol is IParameterSymbol existingParameter &&
            existingParameter.ContainingSymbol.Equals(containingSymbol))
        {
            return;
        }

        // Code actions for trampolines and overloads will not be offered if the method is a constructor.
        // Code actions for overloads will not be offered if the method is a local function.
        switch (methodSymbol.MethodKind)
        {
            case MethodKind.Ordinary:
            case MethodKind.LocalFunction:
            case MethodKind.Constructor:
                break;
            default:
                return;
        }

        if (IsDestructor(methodSymbol))
            return;

        var actionSets = await GetActionsAsync(document, selectedExpression, methodSymbol, containingMethod, cancellationToken).ConfigureAwait(false);
        if (actionSets is null)
            return;

        var expressionText = syntaxFacts.ConvertToSingleLine(selectedExpression).ToString();
        var (singleOccurrenceActions, allOccurrencesActions) = actionSets.Value;

        // Each non-empty group becomes one low-priority, non-inlinable parent action spanning the method.
        if (!singleOccurrenceActions.IsEmpty)
        {
            var parentAction = CodeAction.Create(
                string.Format(FeaturesResources.Introduce_parameter_for_0, expressionText),
                singleOccurrenceActions, isInlinable: false, priority: CodeActionPriority.Low);
            context.RegisterRefactoring(parentAction, containingMethod.FullSpan);
        }

        if (!allOccurrencesActions.IsEmpty)
        {
            var parentAction = CodeAction.Create(
                string.Format(FeaturesResources.Introduce_parameter_for_all_occurrences_of_0, expressionText),
                allOccurrencesActions, isInlinable: false, priority: CodeActionPriority.Low);
            context.RegisterRefactoring(parentAction, containingMethod.FullSpan);
        }
    }
 
    /// <summary>
    /// Filters out expressions that are syntactically inside a method-like declaration but for which
    /// introducing a parameter makes no sense.
    /// </summary>
    /// <returns><see langword="true"/> when the expression is a candidate for the refactoring.</returns>
    private static bool IsValidExpression(SyntaxNode expression, ISyntaxFactsService syntaxFacts)
    {
        // Type names (for example a method's return or parameter types) are "contained" in the method,
        // and the name side of a member access never yields a sensible parameter on its own.
        var isTypeOrMemberName =
            syntaxFacts.IsInNamespaceOrTypeContext(expression) ||
            syntaxFacts.IsNameOfAnyMemberAccessExpression(expression);
        if (isTypeOrMemberName)
            return false;

        // Anything located within a parameter declaration or an attribute argument is likewise
        // technically inside the method but not something to introduce a parameter for.
        return !expression.AncestorsAndSelf().Any(
            ancestor => syntaxFacts.IsAttributeArgument(ancestor) || syntaxFacts.IsParameter(ancestor));
    }
 
    /// <summary>
    /// Creates the nested code actions for each introduce-parameter possibility, once for the single
    /// selected occurrence and once for all occurrences.
    /// <list type="bullet">
    /// <item>"Update call sites directly" is omitted when the expression contains an instance reference.</item>
    /// <item>The trampoline action is omitted for constructors and when any call site is an object creation.</item>
    /// <item>The overload action is omitted for constructors and local functions.</item>
    /// </list>
    /// </summary>
    /// <returns>
    /// <see langword="null"/> when the expression should not be offered the refactoring at all; otherwise
    /// the two (possibly empty) groups of actions.
    /// </returns>
    private async Task<(ImmutableArray<CodeAction> actions, ImmutableArray<CodeAction> actionsAllOccurrences)?> GetActionsAsync(Document document,
        TExpressionSyntax expression, IMethodSymbol methodSymbol, SyntaxNode containingMethod,
        CancellationToken cancellationToken)
    {
        var (shouldDisplay, containsClassExpression) = await ShouldExpressionDisplayCodeActionAsync(
            document, expression, cancellationToken).ConfigureAwait(false);
        if (!shouldDisplay)
            return null;

        using var actionsBuilder = TemporaryArray<CodeAction>.Empty;
        using var actionsBuilderAllOccurrences = TemporaryArray<CodeAction>.Empty;

        // The call sites are computed once here and shared by every action created below.
        var methodCallSites = await FindCallSitesAsync(document, methodSymbol, cancellationToken).ConfigureAwait(false);

        if (!containsClassExpression)
        {
            actionsBuilder.Add(CreateNewCodeAction(FeaturesResources.and_update_call_sites_directly, allOccurrences: false, IntroduceParameterCodeActionKind.Refactor));
            actionsBuilderAllOccurrences.Add(CreateNewCodeAction(FeaturesResources.and_update_call_sites_directly, allOccurrences: true, IntroduceParameterCodeActionKind.Refactor));
        }

        if (methodSymbol.MethodKind is not MethodKind.Constructor)
        {
            var containsObjectCreationReferences = methodCallSites.Values.Flatten().OfType<TObjectCreationExpressionSyntax>().Any();
            if (!containsObjectCreationReferences)
            {
                actionsBuilder.Add(CreateNewCodeAction(
                    FeaturesResources.into_extracted_method_to_invoke_at_call_sites, allOccurrences: false, IntroduceParameterCodeActionKind.Trampoline));
                actionsBuilderAllOccurrences.Add(CreateNewCodeAction(
                    FeaturesResources.into_extracted_method_to_invoke_at_call_sites, allOccurrences: true, IntroduceParameterCodeActionKind.Trampoline));
            }

            if (methodSymbol.MethodKind is not MethodKind.LocalFunction)
            {
                actionsBuilder.Add(CreateNewCodeAction(
                    FeaturesResources.into_new_overload, allOccurrences: false, IntroduceParameterCodeActionKind.Overload));
                actionsBuilderAllOccurrences.Add(CreateNewCodeAction(
                    FeaturesResources.into_new_overload, allOccurrences: true, IntroduceParameterCodeActionKind.Overload));
            }
        }

        return (actionsBuilder.ToImmutableAndClear(), actionsBuilderAllOccurrences.ToImmutableAndClear());

        // Builds one leaf action; the action name doubles as its equivalence key.
        CodeAction CreateNewCodeAction(string actionName, bool allOccurrences, IntroduceParameterCodeActionKind selectedCodeAction)
        {
            return CodeAction.Create(
                actionName,
                cancellationToken => IntroduceParameterAsync(document, expression, methodSymbol, containingMethod, methodCallSites, allOccurrences, selectedCodeAction, cancellationToken),
                actionName);
        }
    }
 
    /// <summary>
    /// Determines if the expression is something that should have code actions displayed for it.
    /// The expression is rejected when it, or anything nested inside it, binds to a local, a range
    /// variable, or a params parameter.
    /// </summary>
    /// <returns>
    /// <c>shouldDisplay</c> is <see langword="false"/> when no code action should be offered.
    /// <c>containsClassExpression</c> is <see langword="true"/> when the expression is, or contains, an
    /// instance reference (explicit or implicit this/base).
    /// </returns>
    private static async Task<(bool shouldDisplay, bool containsClassExpression)> ShouldExpressionDisplayCodeActionAsync(
        Document document, TExpressionSyntax expression, CancellationToken cancellationToken)
    {
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        // The selected expression itself is inspected too: a selection that is nothing but a single
        // identifier has no descendant nodes, yet it can still bind to a local or range variable.
        foreach (var node in expression.DescendantNodesAndSelf())
        {
            var symbol = semanticModel.GetSymbolInfo(node, cancellationToken).Symbol;

            // If the expression contains locals or range variables then we do not want to offer
            // code actions since there will be errors at call sites.
            if (symbol is IRangeVariableSymbol or ILocalSymbol)
                return default;

            // We do not want to offer code actions if the expression contains references
            // to params parameters because it is difficult to know what is being referenced
            // at the call sites.
            if (symbol is IParameterSymbol { IsParams: true })
                return default;
        }

        // If expression contains this or base keywords, implicitly or explicitly,
        // then we do not want to refactor call sites that are not overloads/trampolines
        // because we do not know if the class specific information is available in other documents.
        // The root operation is included so that a bare this/base selection is detected as well.
        var operation = semanticModel.GetOperation(expression, cancellationToken);
        var containsClassSpecificStatement = operation is not null &&
            operation.DescendantsAndSelf().Any(op => op.Kind == OperationKind.InstanceReference);

        return (true, containsClassSpecificStatement);
    }
 
    /// <summary>
    /// Applies the chosen introduce-parameter variant: every document holding a call site (the original
    /// document is always among them) is rewritten, and the new roots are applied to the solution.
    /// </summary>
    /// <returns>The solution of <paramref name="originalDocument"/> with the rewritten syntax roots.</returns>
    private async Task<Solution> IntroduceParameterAsync(Document originalDocument, TExpressionSyntax expression,
        IMethodSymbol methodSymbol, SyntaxNode containingMethod, Dictionary<Document, List<TExpressionSyntax>> methodCallSites, bool allOccurrences, IntroduceParameterCodeActionKind selectedCodeAction,
        CancellationToken cancellationToken)
    {
        var documentRewriter = new IntroduceParameterDocumentRewriter(
            this, originalDocument, expression, methodSymbol, containingMethod, selectedCodeAction, allOccurrences);

        // Documents are grouped by project so that each project's compilation is requested only once.
        var rewrittenRoots = await ProducerConsumer<(DocumentId documentId, SyntaxNode newRoot)>.RunParallelAsync(
            source: methodCallSites.GroupBy(entry => entry.Key.Project),
            produceItems: static async (projectGroup, reportRoot, state, groupToken) =>
            {
                var (project, documentsInProject) = projectGroup;
                var compilation = await project.GetRequiredCompilationAsync(groupToken).ConfigureAwait(false);
                await RoslynParallel.ForEachAsync(
                    documentsInProject,
                    groupToken,
                    async (entry, documentToken) =>
                    {
                        var callSiteDocument = entry.Key;
                        var rewrittenRoot = await state.RewriteDocumentAsync(compilation, callSiteDocument, entry.Value, documentToken).ConfigureAwait(false);
                        reportRoot((callSiteDocument.Id, rewrittenRoot));
                    }).ConfigureAwait(false);
            },
            args: documentRewriter,
            cancellationToken).ConfigureAwait(false);

        return originalDocument.Project.Solution.WithDocumentSyntaxRoots(rewrittenRoots);
    }
 
    /// <summary>
    /// Locates all the call sites of the method that introduced the parameter.
    /// </summary>
    /// <param name="document">The document containing the method; it always gets an entry in the result.</param>
    /// <param name="methodSymbol">The method whose references are searched for across the solution.</param>
    /// <param name="cancellationToken">Token used to cancel the reference search.</param>
    /// <returns>
    /// A map from document to the invocation / object-creation expressions in it that reference the method.
    /// References in projects of a different language, implicit references, and references that are not
    /// part of an invocation or object creation are left out.
    /// </returns>
    protected static async Task<Dictionary<Document, List<TExpressionSyntax>>> FindCallSitesAsync(
        Document document, IMethodSymbol methodSymbol, CancellationToken cancellationToken)
    {
        var methodCallSites = new Dictionary<Document, List<TExpressionSyntax>>();
        var progress = new StreamingProgressCollector();
        await SymbolFinder.FindReferencesAsync(
            methodSymbol, document.Project.Solution, progress,
            documents: null, FindReferencesSearchOptions.Default, cancellationToken).ConfigureAwait(false);
        var referencedSymbols = progress.GetReferencedSymbols();

        // Ordering by descending to sort invocations by starting span to account for nested invocations
        var referencedLocations = referencedSymbols
            .SelectMany(referencedSymbol => referencedSymbol.Locations)
            .Distinct()
            .Where(reference => !reference.IsImplicit)
            .OrderByDescending(reference => reference.Location.SourceSpan.Start);

        // Adding the original document to ensure that it will be seen again when processing the call sites
        // in order to update the original expression and containing method.
        methodCallSites.Add(document, []);

        foreach (var refLocation in referencedLocations)
        {
            // Does not support cross-language references currently
            if (refLocation.Document.Project.Language != document.Project.Language)
                continue;

            // The call expression is expected to be the parent or the grandparent of the reference node.
            // Walk up null-tolerantly: a reference with fewer than two ancestors cannot be a call site
            // and is skipped below instead of raising an exception.
            var reference = refLocation.Location.FindNode(cancellationToken).Parent;
            if (reference is not (TObjectCreationExpressionSyntax or TInvocationExpressionSyntax))
                reference = reference?.Parent;

            // Only adding items that are of type TInvocationExpressionSyntax or TObjectCreationExpressionSyntax
            var invocationOrCreation = reference as TObjectCreationExpressionSyntax ?? (TExpressionSyntax?)(reference as TInvocationExpressionSyntax);
            if (invocationOrCreation is null)
                continue;

            if (!methodCallSites.TryGetValue(refLocation.Document, out var list))
            {
                list = [];
                methodCallSites.Add(refLocation.Document, list);
            }

            list.Add(invocationOrCreation);
        }

        return methodCallSites;
    }
}