// File: src\Workspaces\SharedUtilitiesAndExtensions\Workspace\CSharp\Editing\CSharpImportAdder.cs
// Web Access
// Project: src\src\CodeStyle\CSharp\CodeFixes\Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes.csproj (Microsoft.CodeAnalysis.CSharp.CodeStyle.Fixes)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp.Editing;
 
/// <summary>
/// C# implementation of <see cref="ImportAdderService"/>.  Supplies the language-specific pieces the base service
/// needs: recognizing names that are explicitly qualified with a namespace, and finding which candidate imports
/// could conflict with code that already exists in a container.
/// </summary>
[ExportLanguageService(typeof(ImportAdderService), LanguageNames.CSharp), Shared]
[method: ImportingConstructor]
[method: Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
internal sealed class CSharpImportAdder() : ImportAdderService
{
    /// <summary>
    /// Returns the namespace that <paramref name="node"/> is explicitly qualified with, if any.  Only qualified
    /// names (<c>A.B</c>) and member accesses (<c>a.b</c>) are considered; any other node yields
    /// <see langword="null"/>.
    /// </summary>
    protected override INamespaceSymbol? GetExplicitNamespaceSymbol(SyntaxNode node, SemanticModel model)
        => node switch
        {
            QualifiedNameSyntax name => GetExplicitNamespaceSymbol(name, name.Left, model),
            MemberAccessExpressionSyntax memberAccess => GetExplicitNamespaceSymbol(memberAccess, memberAccess.Expression, model),
            _ => null,
        };

    /// <summary>
    /// Adds to <paramref name="conflicts"/> each namespace from <paramref name="namespaceSymbols"/> that
    /// <see cref="ConflictFinder"/> determines could conflict with existing code inside <paramref name="container"/>.
    /// </summary>
    protected override Task AddPotentiallyConflictingImportsAsync(
        SemanticModel model,
        SyntaxNode container,
        ImmutableArray<INamespaceSymbol> namespaceSymbols,
        HashSet<INamespaceSymbol> conflicts,
        CancellationToken cancellationToken)
    {
        var conflictFinder = new ConflictFinder(model, namespaceSymbols);
        return conflictFinder.AddPotentiallyConflictingImportsAsync(container, conflicts, cancellationToken);
    }

    /// <summary>
    /// Shared implementation for both qualified-name and member-access forms.
    /// </summary>
    /// <param name="fullName">The whole qualified expression (e.g. <c>A.B</c>).</param>
    /// <param name="namespacePart">The left-hand side that must bind to a namespace (e.g. <c>A</c>).</param>
    /// <param name="model">Semantic model used to bind both expressions.</param>
    /// <returns>
    /// The compilation-level namespace containing the symbol <paramref name="fullName"/> binds to, or
    /// <see langword="null"/> if the pattern doesn't match.
    /// </returns>
    private static INamespaceSymbol? GetExplicitNamespaceSymbol(ExpressionSyntax fullName, ExpressionSyntax namespacePart, SemanticModel model)
    {
        // The name must refer to something that is not a namespace, but be qualified with a namespace.
        var symbol = model.GetSymbolInfo(fullName).Symbol;
        if (symbol != null && symbol.Kind != SymbolKind.Namespace && model.GetSymbolInfo(namespacePart).Symbol is INamespaceSymbol)
        {
            // Use the symbol's containing namespace, and not the potentially less than fully qualified namespace in
            // the full name expression.
            var ns = symbol.ContainingNamespace;
            if (ns != null)
            {
                return model.Compilation.GetCompilationNamespace(ns);
            }
        }

        return null;
    }

    /// <summary>
    /// Walks the portion of the tree we're adding imports to looking to see if those imports could likely cause
    /// conflicts with existing code.  Note: this is a best-effort basis, and the goal is to catch reasonable
    /// conflicts effectively.  There may be cases that do slip through that we can adjust for in the future.  Those
    /// cases should be assessed to see how reasonable/likely they are.  I.e. if it's just a hypothetical case with
    /// no users being hit, then that's far less important than if we have a reasonable coding pattern that would be
    /// impacted by adding an import to a normal namespace.
    /// </summary>
    private sealed class ConflictFinder
    {
        private readonly SemanticModel _model;

        /// <summary>
        /// A mapping containing the simple names and arity of all imported types, mapped to the import that they're
        /// brought in by.
        /// </summary>
        private readonly MultiDictionary<(string name, int arity), INamespaceSymbol> _importedTypes = [];

        /// <summary>
        /// A mapping containing the simple names of all imported extension methods, mapped to the import that
        /// they're brought in by.  This doesn't keep track of arity because methods can be called with type
        /// arguments.
        /// </summary>
        /// <remarks>
        /// We could consider adding more information here (for example the min/max number of args that this can be
        /// called with).  That could then be used to check if there could be a conflict. However, that's likely
        /// more complexity than we need currently.  But it is always something we can do in the future.
        /// </remarks>
        private readonly MultiDictionary<string, INamespaceSymbol> _importedExtensionMethods = [];

        /// <summary>
        /// Creates a finder that checks code bound with <paramref name="model"/> against the members that importing
        /// <paramref name="namespaceSymbols"/> would bring into scope.
        /// </summary>
        public ConflictFinder(
            SemanticModel model,
            ImmutableArray<INamespaceSymbol> namespaceSymbols)
        {
            _model = model;

            AddImportedMembers(namespaceSymbols);
        }

        /// <summary>
        /// Indexes each type directly contained in the given namespaces by (name, arity).  It also indexes, by name,
        /// every extension method declared in those types that might contain extension methods.
        /// </summary>
        private void AddImportedMembers(ImmutableArray<INamespaceSymbol> namespaceSymbols)
        {
            foreach (var ns in namespaceSymbols)
            {
                foreach (var type in ns.GetTypeMembers())
                {
                    _importedTypes.Add((type.Name, type.Arity), ns);

                    // Only types flagged as possibly containing extension methods are scanned for them.
                    if (type.MightContainExtensionMethods)
                    {
                        foreach (var member in type.GetMembers())
                        {
                            if (member is IMethodSymbol method && method.IsExtensionMethod)
                                _importedExtensionMethods.Add(method.Name, ns);
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Finds the namespaces whose import could conflict with code in <paramref name="container"/> and adds them
        /// to <paramref name="conflicts"/>.
        /// </summary>
        /// <remarks>
        /// Candidate nodes are first filtered syntactically.  They are then bound in parallel by the
        /// <c>produceItems</c> callback.  <paramref name="conflicts"/> is only mutated inside the
        /// <c>consumeItems</c> callback, never from the producers.
        /// </remarks>
        public async Task AddPotentiallyConflictingImportsAsync(SyntaxNode container, HashSet<INamespaceSymbol> conflicts, CancellationToken cancellationToken)
        {
            using var _ = ArrayBuilder<SyntaxNode>.GetInstance(out var nodes);

            CollectInfoFromContainer(container, nodes, out var containsAnonymousMethods);

            await ProducerConsumer<INamespaceSymbol>.RunParallelAsync(
                source: nodes,
                produceItems: static (node, onItemsFound, args, cancellationToken) =>
                {
                    var (self, containsAnonymousMethods, _) = args;
                    if (node is SimpleNameSyntax nameSyntaxNode)
                        self.ProduceConflicts(nameSyntaxNode, onItemsFound, cancellationToken);
                    else if (node is MemberAccessExpressionSyntax memberAccessExpressionNode)
                        self.ProduceConflicts(memberAccessExpressionNode, containsAnonymousMethods, onItemsFound, cancellationToken);
                    else
                        throw ExceptionUtilities.Unreachable();

                    return Task.CompletedTask;
                },
                consumeItems: static async (items, args, cancellationToken) =>
                {
                    var (_, _, conflicts) = args;
                    await foreach (var conflict in items)
                        conflicts.Add(conflict);
                },
                args: (self: this, containsAnonymousMethods, conflicts),
                cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Adds to <paramref name="nodes"/> every simple name and member access in <paramref name="container"/>
        /// that passes the cheap, name-only conflict filters.  <paramref name="containsAnonymousMethods"/> is
        /// set if any lambda or anonymous method is encountered.
        /// </summary>
        private void CollectInfoFromContainer(SyntaxNode container, ArrayBuilder<SyntaxNode> nodes, out bool containsAnonymousMethods)
        {
            containsAnonymousMethods = false;

            foreach (var node in container.DescendantNodesAndSelf())
            {
                switch (node.Kind())
                {
                    case SyntaxKind.IdentifierName:
                    case SyntaxKind.GenericName:
                        if (IsPotentialConflictWithImportedType((SimpleNameSyntax)node))
                            nodes.Add(node);
                        break;
                    case SyntaxKind.SimpleMemberAccessExpression:
                    case SyntaxKind.PointerMemberAccessExpression:
                        if (IsPotentialConflictWithImportedExtensionMethod((MemberAccessExpressionSyntax)node))
                            nodes.Add(node);
                        break;
                    case SyntaxKind.SimpleLambdaExpression:
                    case SyntaxKind.ParenthesizedLambdaExpression:
                    case SyntaxKind.AnonymousMethodExpression:
                        // Track if we've seen an anonymous method or not.  If so, because of how the language binds lambdas and
                        // overloads, we'll assume any method access we see inside (instance or otherwise) could end up conflicting
                        // with an extension method we might pull in.
                        containsAnonymousMethods = true;
                        break;
                }
            }
        }

        /// <summary>
        /// Syntactic pre-filter: returns <see langword="true"/> if <paramref name="node"/> is a standalone simple
        /// name (not <c>var</c>) whose identifier text and arity match a type that would be imported.
        /// </summary>
        private bool IsPotentialConflictWithImportedType(SimpleNameSyntax node)
        {
            // Check to see if we have a standalone identifier (or identifier on the left of a dot). If so, if that
            // identifier binds to a type, then we don't want to bring in any imports that would bring in the same
            // name and could then potentially conflict here.

            if (node.IsRightSideOfDotOrArrowOrColonColon())
                return false;

            // `var` is never treated as a potential conflict with an imported type.
            if (node.IsVar)
                return false;

            // Drastically reduce the number of nodes that need to be inspected by filtering
            // out nodes whose identifier isn't a potential conflict.
            if (!_importedTypes.ContainsKey((node.Identifier.Text, node.Arity)))
                return false;

            return true;
        }

        /// <summary>
        /// Syntactic pre-filter: returns <see langword="true"/> if the accessed member's name matches the name of
        /// any extension method that would be imported.
        /// </summary>
        private bool IsPotentialConflictWithImportedExtensionMethod(MemberAccessExpressionSyntax node)
            => _importedExtensionMethods.ContainsKey(node.Name.Identifier.Text);

        /// <summary>
        /// If <paramref name="node"/> binds to a named type, reports every import that brings in a type with the
        /// same name (the bound symbol's name) and arity (the node's arity).
        /// </summary>
        private void ProduceConflicts(SimpleNameSyntax node, Action<INamespaceSymbol> addConflict, CancellationToken cancellationToken)
        {
            var symbol = _model.GetSymbolInfo(node, cancellationToken).GetAnySymbol();
            if (symbol?.Kind == SymbolKind.NamedType)
            {
                foreach (var conflictingSymbol in _importedTypes[(symbol.Name, node.Arity)])
                    addConflict(conflictingSymbol);
            }
        }

        /// <summary>
        /// If <paramref name="node"/> binds to a method that is either a reduced extension method, or (when the
        /// container has lambdas) any method accessed inside an anonymous function, reports every import that
        /// brings in an extension method with the same name.
        /// </summary>
        private void ProduceConflicts(MemberAccessExpressionSyntax node, bool containsAnonymousMethods, Action<INamespaceSymbol> addConflict, CancellationToken cancellationToken)
        {
            // Check to see if we have a reference to an extension method.  If so, then pulling in an import could
            // bring in an extension that conflicts with that.

            var symbol = _model.GetSymbolInfo(node.Name, cancellationToken).GetAnySymbol();
            if (symbol is IMethodSymbol method)
            {
                var isConflicting = method.IsReducedExtension();

                if (!isConflicting && containsAnonymousMethods)
                {
                    // lambdas are interesting.  Say you have:
                    //
                    //      Goo(x => x.M());
                    //
                    //      void Goo(Action<C> act) { }
                    //      void Goo(Action<int> act) { }
                    //
                    //      class C { public void M() { } }
                    //
                    // This is legal code where the lambda body is calling the instance method.  However, if we introduce a
                    // using that brings in an extension method 'M' on 'int', then the above will become ambiguous.  This is
                    // because lambda binding will try each interpretation separately and eliminate the ones that fail.
                    // Adding the import will make the int form succeed, causing ambiguity.
                    //
                    // To deal with that, we keep track of if we're in a lambda, and we conservatively assume that a method
                    // access (even to a non-extension method) could conflict with an extension method brought in.
                    isConflicting = node.HasAncestor<AnonymousFunctionExpressionSyntax>();
                }

                if (isConflicting)
                {
                    foreach (var conflictingSymbol in _importedExtensionMethods[method.Name])
                        addConflict(conflictingSymbol);
                }
            }
        }
    }
}