File: FindSymbols\SymbolFinder_Callers.cs
Web Access
Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Shared.Utilities;
 
#if !NETCOREAPP
using Roslyn.Utilities;
#endif
 
namespace Microsoft.CodeAnalysis.FindSymbols;
 
public static partial class SymbolFinder
{
    /// <summary>
    /// Finds all the callers of a specified symbol.
    /// </summary>
    /// <param name="symbol">The symbol whose callers should be found.</param>
    /// <param name="solution">The solution to search.</param>
    /// <param name="cancellationToken">Token used to cancel the search.</param>
    /// <returns>
    /// One <see cref="SymbolCallerInfo"/> per calling symbol found. This is the same as calling the overload that
    /// takes a document set with a <see langword="null"/> set of documents.
    /// </returns>
    /// <exception cref="System.ArgumentNullException">
    /// <paramref name="symbol"/> or <paramref name="solution"/> is <see langword="null"/>.
    /// </exception>
    public static Task<IEnumerable<SymbolCallerInfo>> FindCallersAsync(
        ISymbol symbol, Solution solution, CancellationToken cancellationToken = default)
    {
        // Argument validation happens in the overload below.
        return FindCallersAsync(symbol, solution, documents: null, cancellationToken: cancellationToken);
    }

    /// <summary>
    /// Finds all the callers of a specified symbol, optionally restricting which documents are searched.
    /// </summary>
    /// <param name="symbol">
    /// The symbol whose callers should be found. The search is performed for its
    /// <see cref="ISymbol.OriginalDefinition"/>, replaced by the source definition located in
    /// <paramref name="solution"/> when one is found.
    /// </param>
    /// <param name="solution">The solution to search.</param>
    /// <param name="documents">
    /// Optional set of documents forwarded to the find-references search. The overload without this parameter
    /// passes <see langword="null"/>.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel the search.</param>
    /// <returns>
    /// One <see cref="SymbolCallerInfo"/> per calling symbol. Callers of the reference whose definition is
    /// equivalent to the searched symbol are reported first and flagged as direct. Callers found through any other
    /// definition returned by the find-references search are flagged as indirect. Only events, methods, properties
    /// and fields are searched; any other symbol kind yields an empty result.
    /// </returns>
    /// <exception cref="System.ArgumentNullException">
    /// <paramref name="symbol"/> or <paramref name="solution"/> is <see langword="null"/>.
    /// </exception>
    public static async Task<IEnumerable<SymbolCallerInfo>> FindCallersAsync(
        ISymbol symbol, Solution solution, IImmutableSet<Document>? documents, CancellationToken cancellationToken = default)
    {
        if (symbol is null)
            throw new System.ArgumentNullException(nameof(symbol));
        if (solution is null)
            throw new System.ArgumentNullException(nameof(solution));

        // Search for the unconstructed definition, preferring whatever FindSourceDefinition locates in the solution.
        symbol = symbol.OriginalDefinition;
        symbol = FindSourceDefinition(symbol, solution, cancellationToken) ?? symbol;

        var references = await FindCallReferencesAsync(solution, symbol, documents, cancellationToken).ConfigureAwait(false);

        // FirstOrDefault(predicate) binds to ImmutableArrayExtensions, avoiding the boxed Enumerable.Where iterator.
        var directReference = references.FirstOrDefault(
            r => SymbolEquivalenceComparer.Instance.Equals(symbol, r.Definition));

        var results = new List<SymbolCallerInfo>();

        if (directReference != null)
        {
            await AddReferencingSymbolsAsync(directReference, isDirect: true).ConfigureAwait(false);
        }

        foreach (var reference in references)
        {
            // Reference comparison: skip exactly the entry already reported as direct; everything else is indirect.
            if (reference != directReference)
            {
                await AddReferencingSymbolsAsync(reference, isDirect: false).ConfigureAwait(false);
            }
        }

        return results;

        // Resolves the symbols owning the reference's locations (via FindReferencingSymbolsAsync) and records one
        // SymbolCallerInfo per (calling symbol, locations) pair it returns.
        async Task AddReferencingSymbolsAsync(ReferencedSymbol reference, bool isDirect)
        {
            var result = await reference.Locations.FindReferencingSymbolsAsync(cancellationToken).ConfigureAwait(false);
            foreach (var (callingSymbol, locations) in result)
            {
                results.Add(new SymbolCallerInfo(callingSymbol, reference.Definition, locations, isDirect));
            }
        }
    }

    /// <summary>
    /// Runs find-references for <paramref name="symbol"/>, using the feature options computed for it as the starting
    /// symbol. The search only runs when the symbol is an event, method, property or field.
    /// </summary>
    /// <param name="solution">The solution to search.</param>
    /// <param name="symbol">The symbol to find references to.</param>
    /// <param name="documents">Optional set of documents forwarded to the find-references search.</param>
    /// <param name="cancellationToken">Token used to cancel the search.</param>
    /// <returns>
    /// The referenced symbols collected by the search, or an empty array for any other symbol kind.
    /// </returns>
    private static async Task<ImmutableArray<ReferencedSymbol>> FindCallReferencesAsync(
        Solution solution,
        ISymbol symbol,
        IImmutableSet<Document>? documents,
        CancellationToken cancellationToken = default)
    {
        // Callers are only searched for these kinds; every other kind yields no results.
        if (symbol.Kind is not (SymbolKind.Event or SymbolKind.Method or SymbolKind.Property or SymbolKind.Field))
            return [];

        var collector = new StreamingProgressCollector();
        var options = FindReferencesSearchOptions.GetFeatureOptionsForStartingSymbol(symbol);
        await FindReferencesAsync(
            symbol, solution, collector, documents,
            options, cancellationToken).ConfigureAwait(false);
        return collector.GetReferencedSymbols();
    }
}