File: AddImport\CSharpAddImportFeatureService.cs
Web Access
Project: src\src\Features\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.Features.csproj (Microsoft.CodeAnalysis.CSharp.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.AddImport;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Host;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;
using static Microsoft.CodeAnalysis.CSharp.AddImport.AddImportDiagnosticIds;
 
namespace Microsoft.CodeAnalysis.CSharp.AddImport;
 
using static CSharpSyntaxTokens;
using static SyntaxFactory;
 
[ExportLanguageService(typeof(IAddImportFeatureService), LanguageNames.CSharp), Shared]
[method: ImportingConstructor]
[method: Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
internal sealed class CSharpAddImportFeatureService() : AbstractAddImportFeatureService<SimpleNameSyntax>
{
    /// <summary>
    /// Returns <see langword="true"/> when the <see cref="UsingDirectiveSyntax"/> located by
    /// <c>GetAncestor</c> for <paramref name="node"/> exists and sits directly under a
    /// <see cref="CompilationUnitSyntax"/>.
    /// </summary>
    protected override bool IsWithinImport(SyntaxNode node)
        => node.GetAncestor<UsingDirectiveSyntax>() is { Parent: CompilationUnitSyntax };
 
    /// <summary>
    /// Reports whether a using directive may be added for <paramref name="node"/>; the decision is
    /// delegated entirely to <c>CanAddUsingDirectives</c>.
    /// </summary>
    protected override bool CanAddImport(SyntaxNode node, bool allowInHiddenRegions, CancellationToken cancellationToken)
    {
        return node.CanAddUsingDirectives(allowInHiddenRegions, cancellationToken);
    }
 
    /// <summary>
    /// Decides whether an import can be offered for a method-related diagnostic reported on
    /// <paramref name="node"/>.
    /// </summary>
    /// <param name="diagnosticId">Id of the diagnostic being fixed; unrecognized ids yield <see langword="false"/>.</param>
    /// <param name="syntaxFacts">Used for the final "is this the name part of a member access/binding" check.</param>
    /// <param name="node">Node the diagnostic was reported on.</param>
    /// <param name="nameNode">
    /// Receives the simple name the import would resolve. It stays <see langword="null"/> for the
    /// collection-initializer shortcut (which returns <see langword="true"/>), and may be non-null even
    /// when the method returns <see langword="false"/> because a later shape check failed.
    /// </param>
    protected override bool CanAddImportForMethod(
        string diagnosticId, ISyntaxFacts syntaxFacts, SyntaxNode node, out SimpleNameSyntax nameNode)
    {
        nameNode = null;

        // Phase 1: depending on the diagnostic, retarget 'node' at the syntax that should hold the method name.
        switch (diagnosticId)
        {
            case CS7036:
            case CS0308:
            case CS0428:
            case CS1061:
                switch (node)
                {
                    case ConditionalAccessExpressionSyntax conditionalNode:
                        node = conditionalNode.WhenNotNull;
                        break;

                    case MemberBindingExpressionSyntax bindingNode:
                        node = bindingNode.Name;
                        break;

                    default:
                        // An element of a collection initializer is accepted as-is, with no name node.
                        if (node.Parent.IsKind(SyntaxKind.CollectionInitializerExpression))
                        {
                            return true;
                        }

                        break;
                }

                break;

            case CS0122:
            case CS1501:
                // Only a member binding needs unwrapping here; any other node is used unchanged.
                if (node is MemberBindingExpressionSyntax inaccessibleBinding)
                {
                    node = inaccessibleBinding.Name;
                }

                break;

            case CS1929:
                // The name is taken from the parent: either 'x.Name' or the 'x?.Name(...)' form.
                var extensionName = node.Parent switch
                {
                    MemberAccessExpressionSyntax parentAccess => parentAccess.Name,
                    ConditionalAccessExpressionSyntax
                    {
                        WhenNotNull: InvocationExpressionSyntax { Expression: MemberBindingExpressionSyntax invokedBinding }
                    } => invokedBinding.Name,
                    _ => null,
                };

                if (extensionName == null)
                {
                    return false;
                }

                node = extensionName;
                break;

            case CS1503:
                // Look up the name of the method being invoked around this node.
                if (node.GetAncestor<InvocationExpressionSyntax>() is not { } enclosingInvocation)
                {
                    return false;
                }

                if (enclosingInvocation.Expression is MemberAccessExpressionSyntax invokedMember)
                {
                    node = invokedMember.Name;
                }

                break;

            case CS1955:
                break;

            default:
                return false;
        }

        // Phase 2: the retargeted node must be a simple name sitting inside 'a.Name' or '?.Name'.
        nameNode = node as SimpleNameSyntax;
        var nameContainer = nameNode?.Parent;
        if (nameContainer?.Kind() is not SyntaxKind.SimpleMemberAccessExpression and not SyntaxKind.MemberBindingExpression)
        {
            return false;
        }

        // Reject names whose member access/binding is itself nested in a further member or element access.
        if (nameContainer.Parent?.Kind() is SyntaxKind.SimpleMemberAccessExpression or SyntaxKind.ElementAccessExpression)
        {
            return false;
        }

        return syntaxFacts.IsNameOfSimpleMemberAccessExpression(nameNode)
            || syntaxFacts.IsNameOfMemberBindingExpression(nameNode);
    }
 
    /// <summary>
    /// An import for a Deconstruct method is considered only for diagnostic CS8129;
    /// <paramref name="node"/> is not inspected.
    /// </summary>
    protected override bool CanAddImportForDeconstruct(string diagnosticId, SyntaxNode node)
    {
        return string.Equals(diagnosticId, CS8129, StringComparison.Ordinal);
    }
 
    /// <summary>
    /// Accepts the diagnostics that can point at a missing <c>GetAwaiter</c>, provided the node is
    /// (or is inside) an await expression.
    /// </summary>
    protected override bool CanAddImportForGetAwaiter(string diagnosticId, ISyntaxFacts syntaxFactsService, SyntaxNode node)
    {
        var isAwaiterDiagnostic =
            diagnosticId == CS1061 || // Regular cases
            diagnosticId == CS4036 || // WinRT async interfaces
            diagnosticId == CS1929;   // An extension `GetAwaiter()` is in scope, but for another type

        return isAwaiterDiagnostic && AncestorOrSelfIsAwaitExpression(syntaxFactsService, node);
    }
 
    /// <summary>
    /// An import for <c>GetEnumerator</c> is considered for diagnostics CS1579 and CS8414 only;
    /// the other arguments are not inspected.
    /// </summary>
    protected override bool CanAddImportForGetEnumerator(string diagnosticId, ISyntaxFacts syntaxFactsService, SyntaxNode node)
        => diagnosticId == CS1579 || diagnosticId == CS8414;
 
    /// <summary>
    /// An import for <c>GetAsyncEnumerator</c> is considered for diagnostics CS8411 and CS8415 only;
    /// the other arguments are not inspected.
    /// </summary>
    protected override bool CanAddImportForGetAsyncEnumerator(string diagnosticId, ISyntaxFacts syntaxFactsService, SyntaxNode node)
        => diagnosticId == CS8411 || diagnosticId == CS8415;
 
    /// <summary>
    /// Always declines: <paramref name="nameNode"/> is set to <see langword="null"/> and
    /// <see langword="false"/> is returned regardless of the inputs.
    /// </summary>
    protected override bool CanAddImportForNamespace(string diagnosticId, SyntaxNode node, out SimpleNameSyntax nameNode)
    {
        nameNode = default;
        return false;
    }
 
    /// <summary>
    /// Accepts query-related diagnostics when <paramref name="node"/> is, or lies within, a query
    /// expression whose parent is not a query continuation.
    /// </summary>
    protected override bool CanAddImportForQuery(string diagnosticId, SyntaxNode node)
    {
        // CS1935: regular cases. CS1929: an extension method is in scope, but for another type.
        if (diagnosticId != CS1935 && diagnosticId != CS1929)
        {
            return false;
        }

        return node.AncestorsAndSelf().Any(
            candidate => candidate is QueryExpressionSyntax { Parent: not QueryContinuationSyntax });
    }
 
    /// <summary>
    /// Decides whether an import can be offered for a type-or-namespace diagnostic reported on
    /// <paramref name="node"/>.
    /// </summary>
    /// <param name="diagnosticId">Id of the diagnostic being fixed; unrecognized ids yield <see langword="false"/>.</param>
    /// <param name="node">Node the diagnostic was reported on.</param>
    /// <param name="nameNode">
    /// Receives the simple name to resolve. Whenever this method returns <see langword="true"/> the
    /// value is non-null.
    /// </param>
    protected override bool CanAddImportForTypeOrNamespace(
        string diagnosticId, SyntaxNode node, out SimpleNameSyntax nameNode)
    {
        nameNode = null;
        switch (diagnosticId)
        {
            case CS0103:
            case IDEDiagnosticIds.UnboundIdentifierId:
            case CS0246:
            case CS0305:
            case CS0308:
            case CS0122:
            case CS0307:
            case CS0616:
            case CS1580:
            case CS1581:
            case CS1955:
            case CS0281:
                break;

            case CS1574:
            case CS1584:
                // For a qualified cref, the container part is the candidate to import.
                if (node is QualifiedCrefSyntax cref)
                {
                    node = cref.Container;
                }

                break;

            case CS0234:
                // The type or namespace name 'X' does not exist in the namespace 'Y'.
                //
                // We support this within a using, on any part of the using name that doesn't bind.
                if (!this.IsWithinImport(node))
                    return false;

                // Only report success when there really is a simple name to hand back; returning
                // true with a null name would leave callers without anything to look up.
                nameNode = node as SimpleNameSyntax;
                return nameNode != null;

            default:
                return false;
        }

        return TryFindStandaloneType(node, out nameNode);
    }
 
    /// <summary>
    /// Extracts the simple name to test from <paramref name="node"/> (the left-most simple name when
    /// the node is a qualified name) and reports whether <c>LooksLikeStandaloneTypeName</c> accepts it.
    /// </summary>
    private static bool TryFindStandaloneType(SyntaxNode node, out SimpleNameSyntax nameNode)
    {
        var candidate = node is QualifiedNameSyntax qualified
            ? GetLeftMostSimpleName(qualified)
            : node;

        nameNode = candidate as SimpleNameSyntax;
        return nameNode.LooksLikeStandaloneTypeName();
    }
 
    /// <summary>
    /// Walks down the <c>Left</c> side of nested qualified names and returns the first simple name
    /// encountered, or <see langword="null"/> when the chain ends in something else.
    /// </summary>
    private static SimpleNameSyntax GetLeftMostSimpleName(QualifiedNameSyntax qn)
    {
        for (var current = qn; current != null; current = current.Left as QualifiedNameSyntax)
        {
            if (current.Left is SimpleNameSyntax leftMost)
            {
                return leftMost;
            }
        }

        return null;
    }
 
    /// <summary>
    /// Returns the result of <c>GetUsingNamespacesInScope</c> for <paramref name="node"/>;
    /// <paramref name="cancellationToken"/> is not consulted.
    /// </summary>
    protected override ISet<INamespaceSymbol> GetImportNamespacesInScope(
        SemanticModel semanticModel,
        SyntaxNode node,
        CancellationToken cancellationToken)
        => semanticModel.GetUsingNamespacesInScope(node);
 
    /// <summary>
    /// Returns the type the semantic model reports for <paramref name="node"/>.
    /// </summary>
    protected override ITypeSymbol GetDeconstructInfo(
        SemanticModel semanticModel, SyntaxNode node, CancellationToken cancellationToken)
        => semanticModel.GetTypeInfo(node, cancellationToken).Type;
 
    /// <summary>
    /// Finds the query expression enclosing <paramref name="node"/> and, if no part of it bound
    /// successfully, returns the type of the expression in its <c>from</c> clause.
    /// </summary>
    /// <returns>
    /// <see langword="null"/> as soon as the from clause, any body clause, or the final select/group
    /// clause is found to have bound; otherwise the type of the from-clause expression.
    /// </returns>
    /// <remarks>Throws if <paramref name="node"/> has no enclosing query expression (<c>First()</c>).</remarks>
    protected override ITypeSymbol GetQueryClauseInfo(
        SemanticModel semanticModel,
        SyntaxNode node,
        CancellationToken cancellationToken)
    {
        var enclosingQuery = node.AncestorsAndSelf().OfType<QueryExpressionSyntax>().First();
        var from = enclosingQuery.FromClause;
        var queryBody = enclosingQuery.Body;

        // Checked in source order, stopping at the first part that bound.
        var somePartBound =
            InfoBoundSuccessfully(semanticModel.GetQueryClauseInfo(from, cancellationToken)) ||
            queryBody.Clauses.Any(clause => InfoBoundSuccessfully(semanticModel.GetQueryClauseInfo(clause, cancellationToken))) ||
            InfoBoundSuccessfully(semanticModel.GetSymbolInfo(queryBody.SelectOrGroup, cancellationToken));

        if (somePartBound)
        {
            return null;
        }

        return semanticModel.GetTypeInfo(from.Expression, cancellationToken).Type;
    }
 
    /// <summary>Checks the symbol carried by <paramref name="symbolInfo"/>.</summary>
    private static bool InfoBoundSuccessfully(SymbolInfo symbolInfo)
    {
        return InfoBoundSuccessfully(symbolInfo.Symbol);
    }

    /// <summary>Checks the operation info carried by <paramref name="semanticInfo"/>.</summary>
    private static bool InfoBoundSuccessfully(QueryClauseInfo semanticInfo)
    {
        return InfoBoundSuccessfully(semanticInfo.OperationInfo);
    }

    /// <summary>
    /// A symbol counts as bound when <c>GetOriginalUnreducedDefinition</c> yields a non-null result for it.
    /// </summary>
    private static bool InfoBoundSuccessfully(ISymbol operation)
        => operation.GetOriginalUnreducedDefinition() != null;
 
    /// <summary>
    /// Builds the display text <c>using A.B.C;</c> from the dot-joined <paramref name="nameParts"/>.
    /// </summary>
    protected override string GetDescription(IReadOnlyList<string> nameParts)
    {
        var dottedName = string.Join(".", nameParts);
        return "using " + dottedName + ";";
    }
 
    /// <summary>
    /// Produces the text describing the directive that would be added for
    /// <paramref name="namespaceOrTypeSymbol"/>, together with whether that directive already exists.
    /// </summary>
    /// <returns>
    /// <c>(null, false)</c> when neither an extern alias nor a using directive applies. Otherwise the
    /// first directive that is still missing (extern alias before using) paired with
    /// <see langword="false"/>; if nothing is missing, the extern alias text (when there is one) or
    /// the using text, paired with its "already exists" flag.
    /// </returns>
    protected override (string description, bool hasExistingImport) GetDescription(
        Document document,
        AddImportPlacementOptions options,
        INamespaceOrTypeSymbol namespaceOrTypeSymbol,
        SemanticModel semanticModel,
        SyntaxNode contextNode,
        CancellationToken cancellationToken)
    {
        var root = GetCompilationUnitSyntaxNode(contextNode, cancellationToken);

        // The symbol may come from a reference with a specific alias; that extern alias may or may
        // not already be in scope.
        var (externAlias, hasExistingExtern) = GetExternAliasDirective(
            namespaceOrTypeSymbol, semanticModel, contextNode);

        var (usingDirective, hasExistingUsing) = GetUsingDirective(
            document, options, namespaceOrTypeSymbol, semanticModel, root, contextNode);

        var needsExtern = externAlias != null;
        var needsUsing = usingDirective != null;
        if (!needsExtern && !needsUsing)
        {
            return (null, false);
        }

        var externText = needsExtern ? $"extern alias {externAlias.Identifier.ValueText};" : null;
        var usingText = needsUsing ? GetUsingDirectiveString(namespaceOrTypeSymbol) : null;

        // Describe whichever directive is still missing, extern alias first.
        if (needsExtern && !hasExistingExtern)
        {
            return (externText, false);
        }

        if (needsUsing && !hasExistingUsing)
        {
            return (usingText, false);
        }

        // Nothing is missing: describe what is already there.
        if (needsExtern)
        {
            return (externText, hasExistingExtern);
        }

        return (usingText, hasExistingUsing);
    }
 
    /// <summary>
    /// Formats the symbol's display string as <c>using X;</c> for a namespace and
    /// <c>using static X;</c> for anything else.
    /// </summary>
    private static string GetUsingDirectiveString(INamespaceOrTypeSymbol namespaceOrTypeSymbol)
    {
        var symbolText = namespaceOrTypeSymbol.ToDisplayString();
        var keyword = namespaceOrTypeSymbol.IsKind(SymbolKind.Namespace) ? "using " : "using static ";
        return keyword + symbolText + ";";
    }
 
    /// <summary>
    /// Adds whatever directives <paramref name="namespaceOrTypeSymbol"/> requires to the syntax tree
    /// containing <paramref name="contextNode"/> and returns <paramref name="document"/> with the
    /// resulting root.
    /// </summary>
    protected override async Task<Document> AddImportAsync(
        SyntaxNode contextNode,
        INamespaceOrTypeSymbol namespaceOrTypeSymbol,
        Document document,
        AddImportPlacementOptions options,
        CancellationToken cancellationToken)
    {
        var compilationUnit = GetCompilationUnitSyntaxNode(contextNode, cancellationToken);

        var updatedUnit = await AddImportWorkerAsync(
            document, compilationUnit, contextNode, namespaceOrTypeSymbol, options, cancellationToken).ConfigureAwait(false);

        return document.WithSyntaxRoot(updatedUnit);
    }
 
    /// <summary>
    /// Computes the extern alias and using directives needed for
    /// <paramref name="namespaceOrTypeSymbol"/> and adds those that are not present yet.
    /// </summary>
    /// <returns>
    /// <paramref name="root"/> itself when nothing has to be added; otherwise the root returned by
    /// the document's <c>IAddImportsService</c>.
    /// </returns>
    private static async Task<CompilationUnitSyntax> AddImportWorkerAsync(
        Document document, CompilationUnitSyntax root, SyntaxNode contextNode, INamespaceOrTypeSymbol namespaceOrTypeSymbol,
        AddImportPlacementOptions options, CancellationToken cancellationToken)
    {
        var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        var (externDirective, externAlreadyPresent) = GetExternAliasDirective(
            namespaceOrTypeSymbol, semanticModel, contextNode);

        var (usingDirective, usingAlreadyPresent) = GetUsingDirective(
            document, options, namespaceOrTypeSymbol, semanticModel, root, contextNode);

        using var _ = ArrayBuilder<SyntaxNode>.GetInstance(out var pendingImports);

        // The extern alias (if any) is queued ahead of the using directive.
        if (externDirective != null && !externAlreadyPresent)
        {
            pendingImports.Add(externDirective);
        }

        if (usingDirective != null && !usingAlreadyPresent)
        {
            pendingImports.Add(usingDirective);
        }

        if (pendingImports.Count == 0)
        {
            return root;
        }

        var importsService = document.GetLanguageService<IAddImportsService>();
        var syntaxGenerator = SyntaxGenerator.GetGenerator(document);
        return (CompilationUnitSyntax)importsService.AddImports(
            semanticModel.Compilation, root, contextNode, pendingImports, syntaxGenerator, options, cancellationToken);
    }
 
    /// <summary>
    /// Adds <c>using {namespaceParts joined by '.'};</c> to the syntax tree containing
    /// <paramref name="contextNode"/> and returns <paramref name="document"/> with the resulting root.
    /// </summary>
    protected override async Task<Document> AddImportAsync(
        SyntaxNode contextNode, IReadOnlyList<string> namespaceParts,
        Document document, AddImportPlacementOptions options, CancellationToken cancellationToken)
    {
        var compilationUnit = GetCompilationUnitSyntaxNode(contextNode, cancellationToken);

        // Build the full dotted name, ending at the last part.
        var lastPartIndex = namespaceParts.Count - 1;
        var newUsing = UsingDirective(CreateNameSyntax(namespaceParts, lastPartIndex));

        var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
        var importsService = document.GetLanguageService<IAddImportsService>();
        var syntaxGenerator = SyntaxGenerator.GetGenerator(document);

        var updatedRoot = importsService.AddImport(
            compilation, compilationUnit, contextNode, newUsing, syntaxGenerator, options, cancellationToken);

        return document.WithSyntaxRoot(updatedRoot);
    }
 
    /// <summary>
    /// Builds the name syntax for parts <c>0..index</c> of <paramref name="namespaceParts"/>:
    /// a single identifier when <paramref name="index"/> is 0, otherwise a left-nested qualified name.
    /// Parts that are C# keywords are prefixed with <c>@</c>.
    /// </summary>
    private static NameSyntax CreateNameSyntax(IReadOnlyList<string> namespaceParts, int index)
    {
        // Read the right-most part first so that an out-of-range index fails before anything is built.
        var rightMost = ToIdentifierName(namespaceParts[index]);
        if (index == 0)
        {
            return rightMost;
        }

        NameSyntax qualifier = ToIdentifierName(namespaceParts[0]);
        for (var i = 1; i < index; i++)
        {
            qualifier = QualifiedName(qualifier, ToIdentifierName(namespaceParts[i]));
        }

        return QualifiedName(qualifier, rightMost);

        // Escapes keywords so they can be used as identifiers.
        static IdentifierNameSyntax ToIdentifierName(string part)
            => IdentifierName(SyntaxFacts.GetKeywordKind(part) != SyntaxKind.None ? "@" + part : part);
    }
 
    /// <summary>
    /// Creates the <c>extern alias</c> directive needed to reach <paramref name="namespaceSymbol"/>,
    /// if any, annotated for formatting.
    /// </summary>
    /// <returns>
    /// <c>(null, false)</c> when no extern alias applies; otherwise the directive and whether an
    /// equivalent one is already in scope at <paramref name="contextNode"/>.
    /// </returns>
    private static (ExternAliasDirectiveSyntax, bool hasExistingImport) GetExternAliasDirective(
        INamespaceOrTypeSymbol namespaceSymbol,
        SemanticModel semanticModel,
        SyntaxNode contextNode)
    {
        var (aliasName, alreadyDeclared) = GetExternAliasString(namespaceSymbol, semanticModel, contextNode);
        if (aliasName == null)
        {
            return (null, false);
        }

        var directive = ExternAliasDirective(Identifier(aliasName))
            .WithAdditionalAnnotations(Formatter.Annotation);

        return (directive, alreadyDeclared);
    }
 
    /// <summary>
    /// Creates the using directive (<c>using</c> or <c>using static</c>) for
    /// <paramref name="namespaceOrTypeSymbol"/>, annotated for formatting, and reports whether the
    /// imports service already sees an equivalent import.
    /// </summary>
    private static (UsingDirectiveSyntax, bool hasExistingImport) GetUsingDirective(
        Document document,
        AddImportPlacementOptions options,
        INamespaceOrTypeSymbol namespaceOrTypeSymbol,
        SemanticModel semanticModel,
        CompilationUnitSyntax root,
        SyntaxNode contextNode)
    {
        var importsService = document.GetLanguageService<IAddImportsService>();
        var syntaxGenerator = SyntaxGenerator.GetGenerator(document);

        var generatedName = namespaceOrTypeSymbol.GenerateNameSyntax();

        // The using is built in two passes. A using directive is needed to ask where it would be
        // placed, but the final form of the name (e.g. whether 'global' can be dropped) depends on
        // that placement. So a throwaway directive is used first, purely to find the container.
        var placeholderUsing = UsingDirective(generatedName);
        var targetContainer = importsService.GetImportContainer(root, contextNode, placeholderUsing, options);
        var targetNamespace = targetContainer as BaseNamespaceDeclarationSyntax;

        // If the symbol has to be reached through an extern alias, that alias replaces the one on
        // the generated name.
        var (externDirective, _) = GetExternAliasDirective(
            namespaceOrTypeSymbol, semanticModel, contextNode);
        var externAliasName = externDirective?.Identifier.ValueText;

        // Otherwise the generated name carries the global:: alias, which is only needed when the
        // name is actually ambiguous in this context. Dropping it when possible makes it easy to
        // detect an existing using with the same syntax and to sort usings properly.
        var finalName = externAliasName != null
            ? AddOrReplaceAlias(generatedName, IdentifierName(externAliasName))
            : RemoveGlobalAliasIfUnnecessary(semanticModel, generatedName, targetNamespace);

        var directive = UsingDirective(finalName)
            .WithAdditionalAnnotations(Formatter.Annotation);

        // Anything that is not a namespace is imported with 'using static'.
        if (!namespaceOrTypeSymbol.IsKind(SymbolKind.Namespace))
        {
            directive = directive.WithStaticKeyword(StaticKeyword);
        }

        var alreadyImported = importsService.HasExistingImport(
            semanticModel.Compilation, root, contextNode, directive, syntaxGenerator);

        return (directive, alreadyImported);
    }
 
    /// <summary>
    /// Drops the alias qualifier from the first alias-qualified name inside
    /// <paramref name="nameSyntax"/>, unless the name to the right of the alias conflicts with a
    /// member of <paramref name="namespaceToAddTo"/> or one of its containing namespaces.
    /// </summary>
    /// <returns>The rewritten name, or <paramref name="nameSyntax"/> unchanged.</returns>
    private static NameSyntax RemoveGlobalAliasIfUnnecessary(
        SemanticModel semanticModel,
        NameSyntax nameSyntax,
        BaseNamespaceDeclarationSyntax namespaceToAddTo)
    {
        var aliasQualified = nameSyntax.DescendantNodesAndSelf()
            .OfType<AliasQualifiedNameSyntax>()
            .FirstOrDefault();

        // No alias present: nothing to strip.
        if (aliasQualified == null)
        {
            return nameSyntax;
        }

        var nameAfterAlias = aliasQualified.Name.Identifier.ValueText;
        return ConflictsWithExistingMember(semanticModel, namespaceToAddTo, nameAfterAlias)
            ? nameSyntax
            : nameSyntax.ReplaceNode(aliasQualified, aliasQualified.Name);
    }
 
    /// <summary>
    /// Reports whether <paramref name="namespaceToAddTo"/>, or any non-global namespace containing
    /// it, has a member named <paramref name="rightOfAliasName"/>.
    /// </summary>
    /// <returns>
    /// <see langword="false"/> when <paramref name="namespaceToAddTo"/> is null or no such member
    /// is found, meaning the global:: alias can be removed.
    /// </returns>
    private static bool ConflictsWithExistingMember(
        SemanticModel semanticModel,
        BaseNamespaceDeclarationSyntax namespaceToAddTo,
        string rightOfAliasName)
    {
        if (namespaceToAddTo == null)
        {
            return false;
        }

        // Walk outwards from the declared namespace, stopping before the global namespace.
        for (var current = semanticModel.GetDeclaredSymbol(namespaceToAddTo);
             current != null && !current.IsGlobalNamespace;
             current = current.ContainingNamespace)
        {
            if (current.GetMembers(rightOfAliasName).Any())
            {
                // A containing namespace has this name in it, so the name must stay globally qualified.
                return true;
            }
        }

        return false;
    }
 
    /// <summary>
    /// Returns <paramref name="nameSyntax"/> with its left-most part qualified by
    /// <paramref name="alias"/>: a simple name gains the alias, a qualified name is rewritten
    /// recursively on its left side, and an already alias-qualified name has its alias replaced.
    /// </summary>
    /// <exception cref="System.InvalidCastException">
    /// <paramref name="nameSyntax"/> is none of the three name forms handled here.
    /// </exception>
    private static NameSyntax AddOrReplaceAlias(
        NameSyntax nameSyntax, IdentifierNameSyntax alias)
    {
        if (nameSyntax is SimpleNameSyntax simpleName)
        {
            return AliasQualifiedName(alias, simpleName);
        }

        if (nameSyntax is QualifiedNameSyntax qualifiedName)
        {
            return qualifiedName.WithLeft(AddOrReplaceAlias(qualifiedName.Left, alias));
        }

        // Direct cast rather than 'as': an unexpected node kind then fails with a cast error naming
        // the offending type instead of a null dereference on the next call.
        return ((AliasQualifiedNameSyntax)nameSyntax).WithAlias(alias);
    }
 
    /// <summary>
    /// Looks up the metadata reference for the assembly containing
    /// <paramref name="namespaceSymbol"/> and picks the first of its aliases that is not the
    /// global alias.
    /// </summary>
    /// <returns>
    /// <c>(null, false)</c> when there is no metadata reference or no non-global alias; otherwise
    /// the alias and whether an extern alias directive for it is already in scope at
    /// <paramref name="contextNode"/>.
    /// </returns>
    private static (string, bool hasExistingImport) GetExternAliasString(
        INamespaceOrTypeSymbol namespaceSymbol,
        SemanticModel semanticModel,
        SyntaxNode contextNode)
    {
        var reference = semanticModel.Compilation.GetMetadataReference(namespaceSymbol.ContainingAssembly);
        if (reference == null)
        {
            return (null, false);
        }

        // Just default to the first usable alias. FirstOrDefault yields null both when there are
        // no aliases at all and when only the global alias is present.
        var firstAlias = reference.Properties.Aliases
            .FirstOrDefault(candidate => candidate != MetadataReferenceProperties.GlobalAlias);
        if (firstAlias == null)
        {
            return (null, false);
        }

        return (firstAlias, HasExistingExternAlias(firstAlias, contextNode));
    }
 
    /// <summary>
    /// Reports whether one of the extern alias directives enclosing <paramref name="contextNode"/>
    /// already declares <paramref name="alias"/>, in which case it need not be added again.
    /// </summary>
    private static bool HasExistingExternAlias(string alias, SyntaxNode contextNode)
        => contextNode.GetEnclosingExternAliasDirectives()
                      .Any(directive => directive.Identifier.ValueText == alias);
 
    /// <summary>
    /// Returns the root of the syntax tree containing <paramref name="contextNode"/>, cast to a
    /// compilation unit.
    /// </summary>
    private static CompilationUnitSyntax GetCompilationUnitSyntaxNode(
        SyntaxNode contextNode, CancellationToken cancellationToken)
        => (CompilationUnitSyntax)contextNode.SyntaxTree.GetRoot(cancellationToken);
 
    /// <summary>
    /// Determines whether <paramref name="method"/> is a viable extension method for the receiver
    /// of <paramref name="expression"/>.
    /// </summary>
    /// <remarks>
    /// The receiver is the left side of a member access or the target of a member binding. When
    /// neither exists and <paramref name="expression"/> is a collection initializer, the receiver
    /// is the nearest enclosing object creation, explicit (<c>new T { ... }</c>) or target-typed
    /// (<c>new() { ... }</c>).
    /// </remarks>
    protected override bool IsViableExtensionMethod(IMethodSymbol method, SyntaxNode expression, SemanticModel semanticModel, ISyntaxFacts syntaxFacts, CancellationToken cancellationToken)
    {
        var leftExpression = syntaxFacts.IsMemberAccessExpression(expression)
            ? syntaxFacts.GetExpressionOfMemberAccessExpression(expression)
            : syntaxFacts.GetTargetOfMemberBinding(expression);
        if (leftExpression == null)
        {
            if (!expression.IsKind(SyntaxKind.CollectionInitializerExpression))
            {
                return false;
            }

            // Use the base creation type so that a target-typed 'new()' is found too. Looking only
            // for the explicit form would skip it and could pick up an unrelated outer 'new T { ... }'.
            leftExpression = expression.GetAncestor<BaseObjectCreationExpressionSyntax>();
            if (leftExpression == null)
            {
                // No creation expression to take a receiver type from.
                return false;
            }
        }

        var semanticInfo = semanticModel.GetTypeInfo(leftExpression, cancellationToken);
        var leftExpressionType = semanticInfo.Type;

        return IsViableExtensionMethod(method, leftExpressionType);
    }

    /// <summary>
    /// Reports whether <paramref name="node"/> is an element of a collection initializer that
    /// belongs to an object creation expression, explicit or target-typed.
    /// <paramref name="semanticModel"/> is not consulted.
    /// </summary>
    protected override bool IsAddMethodContext(SyntaxNode node, SemanticModel semanticModel)
    {
        if (!node.Parent.IsKind(SyntaxKind.CollectionInitializerExpression))
        {
            return false;
        }

        // Kept in step with the receiver lookup in IsViableExtensionMethod.
        return node.GetAncestor<BaseObjectCreationExpressionSyntax>() != null;
    }
}