File: AddImport\AbstractAddImportCodeRefactoringProvider.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeRefactorings;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Simplification;
 
namespace Microsoft.CodeAnalysis.AddImport;
 
/// <summary>
/// Offers to add a using/import for the containing namespace of a qualified reference to a top-level type (for
/// example <c>System.Console</c> in <c>System.Console.WriteLine()</c>) and to simplify that reference.  A second
/// action additionally marks other qualified references in the document that go through the same namespace for
/// simplification.
/// </summary>
/// <param name="syntaxFacts">
/// Language syntax facts.  At the type level only its <c>StringComparer</c> is used, to build sets of type names that
/// compare using the language's identifier rules.
/// </param>
internal abstract class AbstractAddImportCodeRefactoringProvider<
    TExpressionSyntax,
    TMemberAccessExpressionSyntax,
    TNameSyntax,
    TSimpleNameSyntax,
    TQualifiedNameSyntax,
    TAliasQualifiedNameSyntax,
    TImportDirectiveSyntax>(ISyntaxFacts syntaxFacts)
    : CodeRefactoringProvider
    where TExpressionSyntax : SyntaxNode
    where TMemberAccessExpressionSyntax : TExpressionSyntax
    where TNameSyntax : TExpressionSyntax
    where TSimpleNameSyntax : TNameSyntax
    where TQualifiedNameSyntax : TNameSyntax
    where TAliasQualifiedNameSyntax : TNameSyntax
    where TImportDirectiveSyntax : SyntaxNode
{
    /// <summary>
    /// Attached to the qualified type reference being rewritten so that it can be located again in the rewritten root
    /// and used as the context location when adding the import.
    /// </summary>
    private static readonly SyntaxAnnotation s_annotation = new();

    /// <summary>
    /// Pool of string sets whose comparer is the language's identifier comparer.
    /// </summary>
    private readonly ObjectPool<PooledHashSet<string>> _hashSetPool = PooledHashSet<string>.CreatePool(syntaxFacts.StringComparer);

    /// <summary>
    /// Title format for the action that adds the import and simplifies the reference under the caret.
    /// <c>{0}</c> is replaced with the display string of the namespace being imported.
    /// </summary>
    protected abstract string AddImportTitle { get; }

    /// <summary>
    /// Title format for the action that adds the import and also simplifies other occurrences in the document.
    /// <c>{0}</c> is replaced with the display string of the namespace being imported.
    /// </summary>
    protected abstract string AddImportAndSimplifyAllOccurrencesTitle { get; }

    /// <summary>
    /// Registers the two add-import actions when the caret (an empty span) is on a name that is part of a qualified
    /// reference to a top-level type in a non-global namespace, the reference is not inside an import directive, does
    /// not go through a user alias, and the namespace is not already imported.
    /// </summary>
    public override async Task ComputeRefactoringsAsync(CodeRefactoringContext context)
    {
        var (document, textSpan, cancellationToken) = context;

        // Only offer when the cursor is at a single point (not a selection)
        if (!textSpan.IsEmpty)
            return;

        // Note: this local shadows the primary-constructor parameter of the same name; the methods of this class use
        // the document's language service rather than capturing that parameter.
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

        var token = root.FindToken(textSpan.Start);
        var node = token.GetRequiredParent();
        if (node is not TNameSyntax name)
            return;

        // Get the qualified type reference - this might be a QualifiedName, AliasQualifiedName, or a member access
        // expression that refers to a type.
        var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var (qualifiedTypeReferenceNode, namedType) = GetQualifiedTypeReference();
        if (namedType == null)
            return;

        // To simplify things, we don't offer this on a naked alias.  In other words, we don't offer
        // `global::TopLevelType`.  This simplifies later processing.
        var qualifiedTypeReference = qualifiedTypeReferenceNode switch
        {
            TQualifiedNameSyntax qualifiedName => qualifiedName,
            TMemberAccessExpressionSyntax memberAccessExpression => memberAccessExpression,
            _ => (TExpressionSyntax?)null,
        };

        if (qualifiedTypeReference == null)
            return;

        // Don't want to offer to add a import/using for a namespace if we're already inside an import directive.
        if (qualifiedTypeReference.AncestorsAndSelf().OfType<TImportDirectiveSyntax>().Any())
            return;

        // Only offer to add imports for top-most types.  We can't add a (normal) using/import to a type to pull in
        // nested types.  And while we can make a static-using, that's niche enough to not support for now.
        if (namedType.ContainingType != null)
            return;

        var namespaceSymbol = namedType.ContainingNamespace;
        if (namespaceSymbol is null || namespaceSymbol.IsGlobalNamespace)
            return;

        // If this is actually a type reference off of an alias, don't offer to add a using/import.  The user
        // has already qualified in the way they want.  The `global` alias (targeting the global namespace) is allowed.
        var namespaceReference = syntaxFacts.GetLeftSideOfDot(qualifiedTypeReference);
        Contract.ThrowIfNull(namespaceReference);
        if (namespaceReference.DescendantNodesAndSelf().Any(n => semanticModel.GetAliasInfo(n, cancellationToken) is { Target: not INamespaceSymbol { IsGlobalNamespace: true } }))
            return;

        // Check if there's already a using directive for this namespace
        var namespaceDisplayString = namespaceSymbol.ToDisplayString();
        var addImportsService = document.GetRequiredLanguageService<IAddImportsService>();
        var generator = SyntaxGenerator.GetGenerator(document);
        var namespaceImport = generator.NamespaceImportDeclaration(namespaceDisplayString);

        if (addImportsService.HasExistingImport(semanticModel, root, qualifiedTypeReference, namespaceImport, generator, cancellationToken))
            return;

        context.RegisterRefactorings([
            CodeAction.Create(
                string.Format(AddImportTitle, namespaceDisplayString),
                cancellationToken => AddImportAndSimplifyAsync(simplifyAllOccurrences: false, cancellationToken)),
            CodeAction.Create(
                string.Format(AddImportAndSimplifyAllOccurrencesTitle, namespaceDisplayString),
                cancellationToken => AddImportAndSimplifyAsync(simplifyAllOccurrences: true, cancellationToken))],
            qualifiedTypeReference.Span);

        // True for the node kinds that can appear as a dotted/qualified parent while walking up from the caret name.
        static bool IsQualified([NotNullWhen(true)] SyntaxNode? node)
            => node is TQualifiedNameSyntax or TAliasQualifiedNameSyntax or TMemberAccessExpressionSyntax;

        // Walks up from the caret name through qualified names/member accesses that bind to namespaces, stopping at
        // the first one that binds to a named type (or to an attribute's constructor).  Returns that node and type,
        // or `default` when the caret name is not part of such a reference.
        (SyntaxNode? qualifiedTypeReference, INamedTypeSymbol? namedType) GetQualifiedTypeReference()
        {
            // Offer on any of the namespace or type names in `global::System.Console.WriteLine()`.  The name under the
            // caret must itself bind to a namespace or type for the walk to start.
            var symbol = semanticModel.GetSymbolInfo(name, cancellationToken).Symbol;
            if (symbol is not INamespaceOrTypeSymbol)
                return default;

            SyntaxNode current = name;
            while (IsQualified(current.Parent))
            {
                var parent = current.Parent;
                var parentSymbol = semanticModel.GetSymbolInfo(parent, cancellationToken).Symbol;

                // Keep climbing through namespaces.  We want to stop on the first named type we see: if we have
                // NS1.NS2.T1.T2, we want to stop on T1.
                if (parentSymbol is INamespaceSymbol)
                {
                    current = parent;
                    continue;
                }

                // Pattern-match rather than cast so an unexpected non-named type symbol simply ends the walk instead
                // of throwing.
                if (parentSymbol is INamedTypeSymbol parentType)
                    return (parent, parentType);

                // `[System.Obsolete]` will bind to the attributes constructor.
                if (parentSymbol is IMethodSymbol { MethodKind: MethodKind.Constructor } constructor &&
                    constructor.ContainingType.IsAttribute())
                {
                    return (parent, constructor.ContainingType);
                }

                break;
            }

            return default;
        }

        // Produces the final document: rewrites the root (qualifying conflicting names and annotating nodes for
        // simplification), then adds the namespace import using the annotated reference as the context location.
        async Task<Document> AddImportAndSimplifyAsync(
           bool simplifyAllOccurrences,
           CancellationToken cancellationToken)
        {
            var options = await document.GetAddImportPlacementOptionsAsync(cancellationToken).ConfigureAwait(false);

            var rewrittenRoot = RewriteRoot(simplifyAllOccurrences, cancellationToken);
            var rewrittenQualifiedTypeReference = rewrittenRoot.GetAnnotatedNodes(s_annotation).Single();

            // NOTE(review): `semanticModel` belongs to the original tree while `rewrittenRoot` is a new tree;
            // presumably AddImport only needs it for compilation-level information - confirm.
            var finalRoot = addImportsService.AddImport(
                semanticModel,
                rewrittenRoot,
                rewrittenQualifiedTypeReference,
                namespaceImport,
                generator,
                options,
                cancellationToken);

            return document.WithSyntaxRoot(finalRoot);
        }

        // Builds the rewritten root: the target reference gets the simplifier annotation plus `s_annotation`; simple
        // names that bind to a type and share a name with a type in the imported namespace are fully qualified; and,
        // when `simplifyAllOccurrences` is true, other qualified references through the namespace are annotated for
        // simplification.
        SyntaxNode RewriteRoot(
           bool simplifyAllOccurrences,
           CancellationToken cancellationToken)
        {
            var editor = new SyntaxEditor(root, document.Project.Solution.Services);

            using var _1 = _hashSetPool.GetPooledObject();
            using var _2 = PooledHashSet<SyntaxNode>.GetInstance(out var qualifiedTypeReferenceNodes);

            // Add all the new type names we know the using/import will be bringing into scope. If we see such a name in
            // the tree, we'll qualify it to ensure it doesn't change meaning.
            var newTypeNamesInScope = _1.Object;

            // The set comes from a pool; clear it defensively so names left over from an earlier use can never be
            // mistaken for names that this import brings into scope.
            newTypeNamesInScope.Clear();
            newTypeNamesInScope.AddRange(namespaceSymbol.GetTypeMembers().Select(t => t.Name));

            qualifiedTypeReferenceNodes.AddRange(qualifiedTypeReference.DescendantNodes());

            Debug.Assert(qualifiedTypeReference is TQualifiedNameSyntax or TMemberAccessExpressionSyntax);

            // Visit nodes from the end of the file toward the start (descending SpanStart).  OrderByDescending is a
            // stable sort, so nodes sharing a start position keep document order (a parent before its first child).
            foreach (var child in root.DescendantNodes().OrderByDescending(n => n.SpanStart))
            {
                // Don't touch any nodes under the `global::System.Console` node.  We handle that specially.
                // This ensures we can always find it and always annotate it properly, without other edits
                // interfering.
                if (qualifiedTypeReferenceNodes.Contains(child))
                    continue;

                if (child == qualifiedTypeReference)
                {
                    // Mark the node to be simplified, and add the appropriate annotation on it so that our caller can
                    // find this node again to use as the context node when adding the using/import. Note: we can use
                    // the simple ReplaceNode that does not take a callback as the above check ensures that no edits
                    // will have happened underneath us.
                    editor.ReplaceNode(
                        qualifiedTypeReference,
                        qualifiedTypeReference.WithAdditionalAnnotations(Simplifier.Annotation, s_annotation));
                    continue;
                }

                // If we run into a name like `Console` and we know we're adding `System`, then qualify this name so
                // that it doesn't change after this point.
                if (child is TSimpleNameSyntax simpleName &&
                    newTypeNamesInScope.Contains(syntaxFacts.GetIdentifierOfSimpleName(simpleName).ValueText))
                {
                    if (syntaxFacts.IsLeftSideOfDot(simpleName) ||
                        syntaxFacts.GetStandaloneExpression(simpleName) == simpleName)
                    {
                        var boundSymbol = semanticModel.GetSymbolInfo(simpleName, cancellationToken).Symbol;
                        if (boundSymbol is INamedTypeSymbol boundType)
                        {
                            var typeContext = syntaxFacts.IsInNamespaceOrTypeContext(simpleName);
                            editor.ReplaceNode(
                                simpleName,
                                (_, _) => generator.SyntaxGeneratorInternal.Type(boundType, typeContext));
                        }
                    }

                    continue;
                }

                // If we're adding `using System.Collections.Generic;` and we're simplifying everything, and we run
                // into `System.Collection.Generic.IList<C>`, attempt to simplify that as well.
                if (simplifyAllOccurrences &&
                    child is TMemberAccessExpressionSyntax or TQualifiedNameSyntax)
                {
                    // Left side could be something like `System` or `System.Collections.Generic`
                    var leftSide = syntaxFacts.GetLeftSideOfDot(child);
                    if (leftSide is TMemberAccessExpressionSyntax or TQualifiedNameSyntax or TSimpleNameSyntax)
                    {
                        // Right side is now something like `System` (in the `System.Console` case or `Generic` in the
                        // `System.Collections.Generic.List<T>` case).  Check if that's the name of the namespace we're
                        // adding. if so, mark it to be simplified if possible.
                        var rightSideName = leftSide is TSimpleNameSyntax ? leftSide : syntaxFacts.GetRightSideOfDot(leftSide);
                        Debug.Assert(rightSideName != null);
                        if (syntaxFacts.StringComparer.Equals(
                                namespaceSymbol.Name,
                                syntaxFacts.GetIdentifierOfSimpleName(rightSideName).ValueText) &&
                            SymbolEquivalenceComparer.IgnoreAssembliesInstance.Equals(namespaceSymbol, semanticModel.GetSymbolInfo(leftSide, cancellationToken).Symbol))
                        {
                            editor.ReplaceNode(
                                child,
                                (currentChild, _) => currentChild.WithAdditionalAnnotations(Simplifier.Annotation));
                        }
                    }
                }
            }

            return editor.GetChangedRoot();
        }
    }
}