File: PullMemberUp\MembersPuller.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.  
 
#nullable disable
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.AddImport;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.PullMemberUp;
using Microsoft.CodeAnalysis.RemoveUnnecessaryImports;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
using static Microsoft.CodeAnalysis.CodeActions.CodeAction;
 
namespace Microsoft.CodeAnalysis.CodeRefactorings.PullMemberUp;
 
internal static class MembersPuller
{
    /// <summary>
    /// Annotation used to mark imports that we move over, so that we can remove these imports if they are unnecessary
    /// (and so we don't remove any other unnecessary imports)
    /// </summary>
    private static readonly SyntaxAnnotation s_removableImportAnnotation = new("PullMemberRemovableImport");
    /// <summary>
    /// Annotation attached to the updated destination type declaration so that the declaration can be found again
    /// in the rewritten root when the pulled-up imports are added next to it.
    /// </summary>
    private static readonly SyntaxAnnotation s_destinationNodeAnnotation = new("DestinationNode");
 
    /// <summary>
    /// Builds a code action that pulls <paramref name="selectedMembers"/> up into <paramref name="destination"/>
    /// without making any of them abstract.
    /// </summary>
    /// <param name="document">Document the refactoring was invoked from.</param>
    /// <param name="selectedMembers">Members the user selected to pull up.</param>
    /// <param name="destination">Type that should receive the members.</param>
    /// <returns>
    /// The code action, or <see langword="null"/> when the operation would need extra changes or when one of the
    /// selected members is already declared in the destination.
    /// </returns>
    public static CodeAction TryComputeCodeAction(
        Document document,
        ImmutableArray<ISymbol> selectedMembers,
        INamedTypeSymbol destination)
    {
        var memberSelections = selectedMembers.SelectAsArray(m => (member: m, makeAbstract: false));
        var pullUpOptions = PullMembersUpOptionsBuilder.BuildPullMembersUpOptions(destination, memberSelections);

        // This quick action only covers the simple case; anything needing extra changes is not offered here.
        if (pullUpOptions.PullUpOperationNeedsToDoExtraChanges)
        {
            return null;
        }

        if (selectedMembers.Any(IsSelectedMemberDeclarationAlreadyInDestination, destination))
        {
            return null;
        }

        // The destination name serves as both the title and the equivalence key.
        var destinationName = pullUpOptions.Destination.Name;
        return SolutionChangeAction.Create(
            destinationName,
            cancellationToken => PullMembersUpAsync(document, pullUpOptions, cancellationToken),
            destinationName);
    }
 
    /// <summary>
    /// Dispatches the pull-up operation according to the kind of the destination type.
    /// </summary>
    /// <param name="document">Document the refactoring was invoked from.</param>
    /// <param name="pullMembersUpOptions">Destination and per-member analysis results describing the operation.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>A task producing the solution with the members pulled up.</returns>
    public static Task<Solution> PullMembersUpAsync(
        Document document,
        PullMembersUpOptions pullMembersUpOptions,
        CancellationToken cancellationToken)
    {
        switch (pullMembersUpOptions.Destination.TypeKind)
        {
            case TypeKind.Interface:
                return PullMembersIntoInterfaceAsync(document, pullMembersUpOptions, cancellationToken);

            // A VB module can be handled the same way as a static class.
            case TypeKind.Class:
            case TypeKind.Module:
                return PullMembersIntoClassAsync(document, pullMembersUpOptions, cancellationToken);

            default:
                throw ExceptionUtilities.UnexpectedValue(pullMembersUpOptions.Destination);
        }
    }
 
    /// <summary>
    /// Returns <paramref name="getterOrSetter"/> only when it is a public accessor; otherwise returns
    /// <see langword="null"/> (including when the accessor itself is <see langword="null"/>).
    /// </summary>
    private static IMethodSymbol FilterOutNonPublicAccessor(IMethodSymbol getterOrSetter)
    {
        // A public property may still have one accessor that is not public. Only the public
        // accessor is added to the destination interface, so the other one is dropped here.
        if (getterOrSetter == null || getterOrSetter.DeclaredAccessibility != Accessibility.Public)
        {
            return null;
        }

        return getterOrSetter;
    }
 
    /// <summary>
    /// Produces a public copy of the given accessor, or <see langword="null"/> when the property has no such accessor.
    /// </summary>
    private static IMethodSymbol MakePublicAccessor(IMethodSymbol getterOrSetter)
    {
        // A missing accessor stays missing: nothing is generated for it.
        if (getterOrSetter == null)
        {
            return getterOrSetter;
        }

        // The user is pulling a non-public property into an interface, so the accessor must become public.
        return CodeGenerationSymbolFactory.CreateMethodSymbol(getterOrSetter, accessibility: Accessibility.Public);
    }
 
    /// <summary>
    /// Adds the selected members to the destination interface and then adjusts every original declaration:
    /// members that came from an interface are removed, other members are made public and/or non-static when the
    /// analysis result asks for it.
    /// </summary>
    /// <param name="document">Document the refactoring was invoked from.</param>
    /// <param name="pullMemberUpOptions">Destination and per-member analysis results.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The changed solution.</returns>
    private static async Task<Solution> PullMembersIntoInterfaceAsync(
        Document document,
        PullMembersUpOptions pullMemberUpOptions,
        CancellationToken cancellationToken)
    {
        var originalSolution = document.Project.Solution;
        var editor = new SolutionEditor(originalSolution);
        var generationService = document.Project.Services.GetRequiredService<ICodeGenerationService>();

        var interfaceDeclaration = await generationService.FindMostRelevantNameSpaceOrTypeDeclarationAsync(
            originalSolution, pullMemberUpOptions.Destination, location: null, cancellationToken).ConfigureAwait(false);
        var declarationsBySymbol = await InitializeSymbolToDeclarationsMapAsync(pullMemberUpOptions, cancellationToken).ConfigureAwait(false);
        var membersToGenerate = pullMemberUpOptions.MemberAnalysisResults.SelectAsArray(GetSymbolsToPullUp);

        var interfaceDocumentId = originalSolution.GetDocumentId(interfaceDeclaration.SyntaxTree);
        var interfaceEditor = await editor.GetDocumentEditorAsync(interfaceDocumentId, cancellationToken).ConfigureAwait(false);

        // Step 1: generate the new members inside the interface declaration.
        var generationContext = new CodeGenerationContext(
            generateMethodBodies: false,
            generateMembers: false);

        var info = await interfaceEditor.OriginalDocument.GetCodeGenerationInfoAsync(generationContext, cancellationToken).ConfigureAwait(false);
        var updatedInterface = info.Service.AddMembers(interfaceDeclaration, membersToGenerate, info, cancellationToken);

        interfaceEditor.ReplaceNode(interfaceDeclaration, (original, generator) => updatedInterface);

        // Step 2: update each declaration of the original members.
        foreach (var analysisResult in pullMemberUpOptions.MemberAnalysisResults)
        {
            var member = analysisResult.Member;
            foreach (var declaration in declarationsBySymbol[member])
            {
                var memberDocumentId = originalSolution.GetDocumentId(declaration.SyntaxTree);
                var memberEditor = await editor.GetDocumentEditorAsync(memberDocumentId, cancellationToken).ConfigureAwait(false);

                if (member.ContainingType.TypeKind == TypeKind.Interface)
                {
                    // Interface-to-interface: the original member is removed, so no further
                    // adjustment of it is necessary.
                    memberEditor.RemoveNode(memberEditor.Generator.GetDeclaration(declaration));
                }
                else if (analysisResult.ChangeOriginalToNonStatic || analysisResult.ChangeOriginalToPublic)
                {
                    ChangeMemberToPublicAndNonStatic(
                        generationService, memberEditor,
                        declaration, member,
                        info, cancellationToken);
                }
            }
        }

        return editor.GetChangedSolution();
    }
 
    /// <summary>
    /// Creates the symbol that should be generated in the destination interface for the analyzed member.
    /// Properties, methods, events and fields are re-created as non-static symbols; any other kind of member is
    /// returned unchanged.
    /// </summary>
    private static ISymbol GetSymbolsToPullUp(MemberAnalysisResult analysisResult)
    {
        var member = analysisResult.Member;

        // Generating static interface members is not supported, so the static modifier is always cleared first.
        var nonStaticModifiers = DeclarationModifiers.From(member).WithIsStatic(false);

        switch (member)
        {
            // Properties need explicit accessor symbols, otherwise ICodeGenerationService
            // would not produce the expected declaration.
            case IPropertySymbol propertySymbol when analysisResult.ChangeOriginalToPublic:
                // Non-public property: the property and both of its accessors become public.
                return CodeGenerationSymbolFactory.CreatePropertySymbol(
                    propertySymbol,
                    accessibility: Accessibility.Public,
                    modifiers: nonStaticModifiers,
                    getMethod: MakePublicAccessor(propertySymbol.GetMethod),
                    setMethod: MakePublicAccessor(propertySymbol.SetMethod));

            case IPropertySymbol propertySymbol:
                // Public property: keep only the accessors that are public themselves.
                return CodeGenerationSymbolFactory.CreatePropertySymbol(
                    propertySymbol,
                    modifiers: nonStaticModifiers,
                    getMethod: FilterOutNonPublicAccessor(propertySymbol.GetMethod),
                    setMethod: FilterOutNonPublicAccessor(propertySymbol.SetMethod));

            case IMethodSymbol methodSymbol:
                return CodeGenerationSymbolFactory.CreateMethodSymbol(methodSymbol, modifiers: nonStaticModifiers);

            case IEventSymbol eventSymbol:
                return CodeGenerationSymbolFactory.CreateEventSymbol(eventSymbol, modifiers: nonStaticModifiers);

            case IFieldSymbol fieldSymbol:
                return CodeGenerationSymbolFactory.CreateFieldSymbol(fieldSymbol, modifiers: nonStaticModifiers);

            default:
                // Any other kind of member is handed to the code generation service as-is.
                return member;
        }
    }
 
    /// <summary>
    /// Rewrites the original declaration of <paramref name="member"/> so that it is public and not static.
    /// Events are delegated to <see cref="ChangeEventToPublicAndNonStatic"/>; every other member is updated in place.
    /// </summary>
    private static void ChangeMemberToPublicAndNonStatic(
        ICodeGenerationService codeGenerationService,
        DocumentEditor editor,
        SyntaxNode memberDeclaration,
        ISymbol member,
        CodeGenerationContextInfo info,
        CancellationToken cancellationToken)
    {
        var nonStaticModifiers = DeclarationModifiers.From(member).WithIsStatic(false);

        if (member is not IEventSymbol eventSymbol)
        {
            editor.SetAccessibility(memberDeclaration, Accessibility.Public);
            editor.SetModifiers(memberDeclaration, nonStaticModifiers);
            return;
        }

        // Events need special treatment because one declaration may introduce several events.
        ChangeEventToPublicAndNonStatic(
            codeGenerationService,
            editor,
            eventSymbol,
            memberDeclaration,
            nonStaticModifiers,
            info,
            cancellationToken);
    }
 
    /// <summary>
    /// Makes the original declaration of <paramref name="eventSymbol"/> public and applies
    /// <paramref name="modifiers"/> to it.
    /// </summary>
    /// <remarks>
    /// Two strategies are used. When the generator reports <paramref name="eventDeclaration"/> as its own
    /// declaration and the event has no user-declared add/remove accessor, a fresh public declaration is generated,
    /// inserted after the original, and the original node is removed (per the original author this is the
    /// "several events declared on one line" case). In every other case (a single field-like event, or an event
    /// with explicit accessors) the existing declaration is modified in place.
    /// </remarks>
    private static void ChangeEventToPublicAndNonStatic(
        ICodeGenerationService codeGenerationService,
        DocumentEditor editor,
        IEventSymbol eventSymbol,
        SyntaxNode eventDeclaration,
        DeclarationModifiers modifiers,
        CodeGenerationContextInfo info,
        CancellationToken cancellationToken)
    {
        var declaration = editor.Generator.GetDeclaration(eventDeclaration);
        var hasUserDeclaredAccessor =
            eventSymbol.AddMethod is { IsImplicitlyDeclared: false } ||
            eventSymbol.RemoveMethod is { IsImplicitlyDeclared: false };

        if (!declaration.Equals(eventDeclaration) || hasUserDeclaredAccessor)
        {
            // Single event, or event with explicit add/remove: update the declaration directly.
            editor.SetAccessibility(declaration, Accessibility.Public);
            editor.SetModifiers(declaration, modifiers);
            return;
        }

        // The event shares its declaration with others: emit a standalone replacement for it.
        var publicNonStaticEvent = CodeGenerationSymbolFactory.CreateEventSymbol(
            eventSymbol,
            accessibility: Accessibility.Public,
            modifiers: modifiers);

        var generationInfo = info.WithContext(new CodeGenerationContext(generateMethodBodies: false));

        var replacementSyntax = codeGenerationService.CreateEventDeclaration(
            publicNonStaticEvent, CodeGenerationDestination.ClassType, generationInfo, cancellationToken);

        // Insert the new declaration, then drop the original one.
        editor.InsertAfter(declaration, replacementSyntax);
        editor.RemoveNode(eventDeclaration);
    }
 
    /// <summary>
    /// Pulls the selected members into a destination class (or VB module): adds them to the destination
    /// declaration, removes the original declarations or turns them into overrides, marks the destination abstract
    /// when required, and copies over the imports that were in scope at the original declarations, removing those
    /// that turn out to be unnecessary.
    /// </summary>
    /// <param name="document">Document the refactoring was invoked from.</param>
    /// <param name="result">Destination and per-member analysis results.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The changed solution.</returns>
    private static async Task<Solution> PullMembersIntoClassAsync(
        Document document,
        PullMembersUpOptions result,
        CancellationToken cancellationToken)
    {
        var solution = document.Project.Solution;
        var solutionEditor = new SolutionEditor(solution);
        var codeGenerationService = document.Project.Services.GetRequiredService<ICodeGenerationService>();

        var destinationSyntaxNode = await codeGenerationService.FindMostRelevantNameSpaceOrTypeDeclarationAsync(
            solution, result.Destination, location: null, cancellationToken).ConfigureAwait(false);

        var destinationEditor = await solutionEditor.GetDocumentEditorAsync(
            solution.GetDocumentId(destinationSyntaxNode.SyntaxTree),
            cancellationToken).ConfigureAwait(false);

        var symbolToDeclarations = await InitializeSymbolToDeclarationsMapAsync(result, cancellationToken).ConfigureAwait(false);

        // Add members to destination
        var pullUpMembersSymbols = result.MemberAnalysisResults.SelectAsArray(
            memberResult =>
            {
                // Fields are never made abstract, even when the option is set.
                if (memberResult.MakeMemberDeclarationAbstract && !memberResult.Member.IsKind(SymbolKind.Field))
                {
                    // Change the member to abstract if user choose to make them abstract
                    return MakeAbstractVersion(memberResult.Member);
                }
                else
                {
                    return memberResult.Member;
                }
            });

        var context = new CodeGenerationContext(
            reuseSyntax: true,
            generateMethodBodies: false);

        var options = await destinationEditor.OriginalDocument.GetCleanCodeGenerationOptionsAsync(cancellationToken).ConfigureAwait(false);
        var info = codeGenerationService.GetInfo(context, options.GenerationOptions, destinationEditor.OriginalDocument.Project.ParseOptions);

        // Annotate the updated destination so it can be located again after the root has been rewritten.
        var newDestination = codeGenerationService
            .AddMembers(destinationSyntaxNode, pullUpMembersSymbols, info, cancellationToken)
            .WithAdditionalAnnotations(s_destinationNodeAnnotation);
        // Import nodes gathered from the source side; every one carries s_removableImportAnnotation.
        using var _ = PooledHashSet<SyntaxNode>.GetInstance(out var sourceImports);

        var syntaxFacts = destinationEditor.OriginalDocument.GetRequiredLanguageService<ISyntaxFactsService>();

        // Remove some original members since we are pulling members into class.
        // Note: If the user chooses to make the member abstract, then the original member will be changed to an override,
        // and it will pull an abstract declaration up to the destination.
        // But if the member is abstract itself, it will still be removed.
        foreach (var analysisResult in result.MemberAnalysisResults)
        {
            // Also import the namespace the member was declared in (unless it is the global namespace).
            var resultNamespace = analysisResult.Member.ContainingNamespace;
            if (!resultNamespace.IsGlobalNamespace)
            {
                sourceImports.Add(
                    destinationEditor.Generator.NamespaceImportDeclaration(
                        resultNamespace.ToDisplayString(SymbolDisplayFormats.NameFormat))
                    .WithAdditionalAnnotations(s_removableImportAnnotation));
            }

            foreach (var syntax in symbolToDeclarations[analysisResult.Member])
            {
                var originalMemberEditor = await solutionEditor.GetDocumentEditorAsync(
                    solution.GetDocumentId(syntax.SyntaxTree),
                    cancellationToken).ConfigureAwait(false);

                // Collect every import in scope at the original declaration, with normalized trivia.
                sourceImports.AddRange(GetImports(syntax, syntaxFacts)
                    .Select(import => import
                        .WithoutLeadingTrivia()
                        .WithTrailingTrivia(originalMemberEditor.Generator.ElasticCarriageReturnLineFeed)
                        .WithAdditionalAnnotations(s_removableImportAnnotation)));

                if (!analysisResult.MakeMemberDeclarationAbstract || analysisResult.Member.IsAbstract)
                {
                    originalMemberEditor.RemoveNode(originalMemberEditor.Generator.GetDeclaration(syntax));
                }
                else
                {
                    // The destination receives an abstract declaration; the original stays behind as an override.
                    var declarationSyntax = originalMemberEditor.Generator.GetDeclaration(syntax);
                    originalMemberEditor.ReplaceNode(declarationSyntax, (node, generator) => generator.WithModifiers(node, DeclarationModifiers.Override));
                }
            }
        }

        // Change the destination to abstract class if needed.
        if (!result.Destination.IsAbstract &&
            result.MemberAnalysisResults.Any(static analysis => analysis.Member.IsAbstract || analysis.MakeMemberDeclarationAbstract))
        {
            var modifiers = DeclarationModifiers.From(result.Destination).WithIsAbstract(true);
            newDestination = destinationEditor.Generator.WithModifiers(newDestination, modifiers);
        }

        destinationEditor.ReplaceNode(destinationSyntaxNode, newDestination);

        // Add imports by moving all source imports to the destination container, then taking out the unnecessary
        // imports that we just added (marked by our annotation).
        var addImportsService = destinationEditor.OriginalDocument.GetRequiredLanguageService<IAddImportsService>();

        // Cache the leading trivia of the first member and strip it before adding imports; it is restored at the
        // end so that it does not end up above (or get lost with) the added imports.
        var destinationTrivia = GetLeadingTriviaBeforeFirstMember(destinationEditor.OriginalRoot, syntaxFacts);

        destinationEditor.ReplaceNode(destinationEditor.OriginalRoot, (root, _) =>
            RemoveLeadingTriviaBeforeFirstMember(root, syntaxFacts));

        destinationEditor.ReplaceNode(destinationEditor.OriginalRoot, (node, generator) => addImportsService.AddImports(
            destinationEditor.SemanticModel.Compilation,
            node,
            node.GetAnnotatedNodes(s_destinationNodeAnnotation).FirstOrDefault(),
            sourceImports,
            generator,
            options.CleanupOptions.AddImportOptions,
            cancellationToken));

        // Only imports carrying our annotation are candidates for removal.
        var removeImportsService = destinationEditor.OriginalDocument.GetRequiredLanguageService<IRemoveUnnecessaryImportsService>();
        var destinationDocument = await removeImportsService.RemoveUnnecessaryImportsAsync(
            destinationEditor.GetChangedDocument(),
            node => node.HasAnnotation(s_removableImportAnnotation),
            cancellationToken).ConfigureAwait(false);

        // Format whitespace trivia within the import statements we pull up
        destinationDocument = await Formatter.FormatAsync(destinationDocument, s_removableImportAnnotation, options.CleanupOptions.FormattingOptions, cancellationToken).ConfigureAwait(false);

        // Put the cached leading trivia back on the first member of the cleaned-up root.
        var destinationRoot = AddLeadingTriviaBeforeFirstMember(
            await destinationDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false),
            syntaxFacts,
            destinationTrivia);

        destinationEditor.ReplaceNode(destinationEditor.OriginalRoot, destinationRoot);

        return solutionEditor.GetChangedSolution();
    }
 
    /// <summary>
    /// Returns the leading trivia of the first member of the compilation unit <paramref name="root"/>.
    /// </summary>
    /// <remarks>
    /// When the first member has leading whitespace and the file has no imports, adding imports moves that trivia
    /// above the import (and it can be lost entirely if the import is removed again). Callers therefore cache the
    /// trivia with this method, strip it, and re-attach it once the imports have been processed.
    /// </remarks>
    private static SyntaxTriviaList GetLeadingTriviaBeforeFirstMember(SyntaxNode root, ISyntaxFactsService syntaxFacts)
    {
        // The root is expected to contain at least one member: the base type we are pulling into.
        return syntaxFacts.GetMembersOfCompilationUnit(root).First().GetLeadingTrivia();
    }
 
    /// <summary>
    /// Returns a copy of <paramref name="root"/> in which the first member of the compilation unit has no leading trivia.
    /// </summary>
    private static SyntaxNode RemoveLeadingTriviaBeforeFirstMember(SyntaxNode root, ISyntaxFactsService syntaxFacts)
    {
        // The root is expected to contain at least one member: the base type we are pulling into.
        var first = syntaxFacts.GetMembersOfCompilationUnit(root).First();
        var withoutTrivia = first.WithoutLeadingTrivia();
        return root.ReplaceNode(first, withoutTrivia);
    }
 
    /// <summary>
    /// Returns a copy of <paramref name="root"/> in which the first member of the compilation unit has
    /// <paramref name="trivia"/> as its leading trivia.
    /// </summary>
    private static SyntaxNode AddLeadingTriviaBeforeFirstMember(SyntaxNode root, ISyntaxFactsService syntaxFacts, SyntaxTriviaList trivia)
    {
        // The root is expected to contain at least one member: the base type we are pulling into.
        var first = syntaxFacts.GetMembersOfCompilationUnit(root).First();
        var withTrivia = first.WithLeadingTrivia(trivia);
        return root.ReplaceNode(first, withTrivia);
    }
 
    /// <summary>
    /// Collects every import statement in scope for a node by walking from the node up through its ancestors and
    /// reading the imports of each enclosing namespace declaration and of the compilation unit.
    /// </summary>
    /// <param name="start">The node to start traversing up from</param>
    /// <param name="syntaxFacts">Language service used to recognize namespace declarations and read their imports</param>
    /// <returns>All the import/using directives found along the traversal, innermost container first</returns>
    private static ImmutableArray<SyntaxNode> GetImports(SyntaxNode start, ISyntaxFactsService syntaxFacts)
    {
        var imports = ImmutableArray.CreateBuilder<SyntaxNode>();

        foreach (var ancestor in start.AncestorsAndSelf())
        {
            if (ancestor is ICompilationUnitSyntax)
            {
                imports.AddRange(syntaxFacts.GetImportsOfCompilationUnit(ancestor));
            }
            else if (syntaxFacts.IsBaseNamespaceDeclaration(ancestor))
            {
                imports.AddRange(syntaxFacts.GetImportsOfBaseNamespaceDeclaration(ancestor));
            }
        }

        return imports.ToImmutable();
    }
 
    /// <summary>
    /// Returns an abstract counterpart of <paramref name="member"/>. A member that is already abstract is returned
    /// unchanged; methods, properties and events are re-created with the abstract modifier set. Any other kind of
    /// member is unexpected and causes an exception.
    /// </summary>
    private static ISymbol MakeAbstractVersion(ISymbol member)
    {
        if (member.IsAbstract)
        {
            return member;
        }

        var abstractModifiers = DeclarationModifiers.From(member).WithIsAbstract(true);

        switch (member)
        {
            case IMethodSymbol methodSymbol:
                return CodeGenerationSymbolFactory.CreateMethodSymbol(methodSymbol, modifiers: abstractModifiers);

            case IPropertySymbol propertySymbol:
                return CodeGenerationSymbolFactory.CreatePropertySymbol(
                    propertySymbol,
                    modifiers: abstractModifiers,
                    getMethod: propertySymbol.GetMethod,
                    setMethod: propertySymbol.SetMethod);

            case IEventSymbol eventSymbol:
                return CodeGenerationSymbolFactory.CreateEventSymbol(eventSymbol, modifiers: abstractModifiers);

            default:
                throw ExceptionUtilities.UnexpectedValue(member);
        }
    }
 
    /// <summary>
    /// Builds a lookup from every member being pulled up to the syntax nodes that declare it.
    /// </summary>
    /// <param name="result">Options whose analysis results list the members being pulled up.</param>
    /// <param name="cancellationToken">Token used to cancel the syntax lookups.</param>
    /// <returns>A map from each member to all of its declaring syntax nodes.</returns>
    private static async Task<ImmutableDictionary<ISymbol, ImmutableArray<SyntaxNode>>> InitializeSymbolToDeclarationsMapAsync(
        PullMembersUpOptions result,
        CancellationToken cancellationToken)
    {
        // A single member can be declared by more than one syntax node (e.g. a partial method),
        // so each symbol maps to an array of nodes to make them easy to find later.
        var declarationsBySymbol = ImmutableDictionary.CreateBuilder<ISymbol, ImmutableArray<SyntaxNode>>();

        foreach (var analysis in result.MemberAnalysisResults)
        {
            var member = analysis.Member;
            var pendingSyntaxes = member.DeclaringSyntaxReferences.SelectAsArray(
                reference => reference.GetSyntaxAsync(cancellationToken));
            var declarations = await Task.WhenAll(pendingSyntaxes).ConfigureAwait(false);
            declarationsBySymbol.Add(member, [.. declarations]);
        }

        return declarationsBySymbol.ToImmutableDictionary();
    }
 
    /// <summary>
    ///  Determines whether the selected member already has a corresponding declaration in the destination
    ///  (for example because it overrides or implements a member of the destination).
    ///  Only members declared directly in the destination are considered.
    /// </summary>
    private static bool IsSelectedMemberDeclarationAlreadyInDestination(ISymbol selectedMember, INamedTypeSymbol destination)
    {
        if (destination.TypeKind == TypeKind.Interface)
        {
            return IsSelectedMemberDeclarationAlreadyInDestinationInterface(selectedMember, destination);
        }

        return IsSelectedMemberDeclarationAlreadyInDestinationClass(selectedMember, destination);
    }
 
    /// <summary>
    /// Checks whether pulling <paramref name="selectedMember"/> into the class <paramref name="destination"/> would
    /// collide with a member the destination already declares.
    /// </summary>
    private static bool IsSelectedMemberDeclarationAlreadyInDestinationClass(ISymbol selectedMember, INamedTypeSymbol destination)
    {
        if (selectedMember is IFieldSymbol fieldSymbol)
        {
            // Any destination member with the same name would make the pulled-up field an error,
            // so the refactoring is not offered in that case.
            return destination.GetMembers(fieldSymbol.Name).Any();
        }

        // Gather the selected member together with everything it (transitively) overrides.
        var overrideChain = new HashSet<ISymbol>();
        var current = selectedMember;
        while (current != null)
        {
            overrideChain.Add(current);
            current = current.GetOverriddenMember();
        }

        // The destination and the selected member may belong to different languages,
        // so the comparison uses SymbolEquivalenceComparer.
        var destinationMembers = destination.GetMembers();
        return overrideChain.Intersect(destinationMembers, SymbolEquivalenceComparer.Instance).Any();
    }
 
    /// <summary>
    /// Checks whether <paramref name="selectedMember"/> is the implementation, within its containing type, of a
    /// member declared directly in the interface <paramref name="destination"/>.
    /// </summary>
    private static bool IsSelectedMemberDeclarationAlreadyInDestinationInterface(
        ISymbol selectedMember, INamedTypeSymbol destination)
    {
        return destination.GetMembers().Any(interfaceMember =>
        {
            var implementation = selectedMember.ContainingType.FindImplementationForInterfaceMember(interfaceMember);
            return SymbolEquivalenceComparer.Instance.Equals(selectedMember, implementation?.OriginalDefinition);
        });
    }
}