File: CodeLens\CodeLensReferencesService.cs
Web Access
Project: src\src\Features\Core\Portable\Microsoft.CodeAnalysis.Features.csproj (Microsoft.CodeAnalysis.Features)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#nullable disable
 
using System;
using System.Collections.Immutable;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.Host;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Text;
 
namespace Microsoft.CodeAnalysis.CodeLens;
 
internal sealed class CodeLensReferencesService : ICodeLensReferencesService
{
    /// <summary>
    /// Display format applied to symbols rendered by this service: type names are qualified with their containing
    /// types and namespaces, and members are shown together with their containing type.
    /// </summary>
    private static readonly SymbolDisplayFormat MethodDisplayFormat = new SymbolDisplayFormat(
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
        memberOptions: SymbolDisplayMemberOptions.IncludeContainingType);
 
    /// <summary>
    /// Options shared by every reference search started from this service.  The search is marked as an implicit
    /// (non-explicit) invocation of FindReferences, which causes the finding operation to run serially rather than
    /// in parallel: we run ephemerally in the background and do not want to saturate the system with work that
    /// then slows the user down.  The inheritance hierarchy is also only cascaded in one direction, so that we find
    /// references that could actually call into a particular member, not references to other members that could
    /// never actually call into it.
    /// </summary>
    private static readonly FindReferencesSearchOptions s_nonParallelSearch =
        FindReferencesSearchOptions.Default with { UnidirectionalHierarchyCascade = true, Explicit = false };
 
    /// <summary>
    /// Finds the references to the symbol declared by <paramref name="syntaxNode"/> and converts the collected
    /// search progress into a result through one of the supplied callbacks.
    /// </summary>
    /// <typeparam name="T">Value type produced by the callbacks.</typeparam>
    /// <param name="solution">Solution that is searched.</param>
    /// <param name="documentId">Document that contains <paramref name="syntaxNode"/>; source generated documents are included in the lookup.</param>
    /// <param name="syntaxNode">Node whose declared symbol is searched for.</param>
    /// <param name="onResults">Invoked with the progress tracker after the search ran to completion.</param>
    /// <param name="onCapped">
    /// Invoked with the progress tracker when the search was cancelled because the search cap was reached.  May be
    /// <see langword="null"/>, in which case the cancellation propagates to the caller.
    /// </param>
    /// <param name="searchCap">Search cap handed to the progress tracker.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>
    /// <see langword="null"/> when the document, its semantic model, or the declared symbol cannot be resolved;
    /// otherwise the value produced by <paramref name="onResults"/> or <paramref name="onCapped"/>.
    /// </returns>
    private static async Task<T?> FindAsync<T>(Solution solution, DocumentId documentId, SyntaxNode syntaxNode,
        Func<CodeLensFindReferencesProgress, Task<T>> onResults, Func<CodeLensFindReferencesProgress, Task<T>> onCapped,
        int searchCap, CancellationToken cancellationToken) where T : struct
    {
        var document = await solution.GetDocumentAsync(documentId, includeSourceGenerated: true, cancellationToken).ConfigureAwait(false);
        if (document == null)
        {
            return null;
        }

        var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        cancellationToken.ThrowIfCancellationRequested();

        if (semanticModel == null)
        {
            // Without a semantic model there is no declared symbol to search for; report "no result" the same
            // way as for a missing document or symbol instead of failing on a null dereference below.
            return null;
        }

        var symbol = semanticModel.GetDeclaredSymbol(syntaxNode, cancellationToken);
        if (symbol == null)
        {
            return null;
        }

        using var progress = new CodeLensFindReferencesProgress(symbol, syntaxNode, searchCap, cancellationToken);
        try
        {
            // The search observes the progress tracker's own token rather than the caller's token.
            await SymbolFinder.FindReferencesAsync(
                symbol, solution, progress, documents: null, s_nonParallelSearch, progress.CancellationToken).ConfigureAwait(false);

            return await onResults(progress).ConfigureAwait(false);
        }
        catch (OperationCanceledException) when (onCapped != null && progress.SearchCapReached)
        {
            // The search was cancelled, and it was cancelled by us because the cap was reached.  Any other
            // cancellation does not match the filter and keeps propagating to the caller.
            return await onCapped(progress).ConfigureAwait(false);
        }
    }
 
    /// <summary>
    /// Gets the dependent version of the project identified by <paramref name="projectId"/>.
    /// </summary>
    /// <param name="solution">Solution that contains the project.</param>
    /// <param name="projectId">Id of the project whose version is requested.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The project's dependent version stamp.</returns>
    public async ValueTask<VersionStamp> GetProjectCodeLensVersionAsync(Solution solution, ProjectId projectId, CancellationToken cancellationToken)
    {
        var project = solution.GetRequiredProject(projectId);
        var dependentVersion = await project.GetDependentVersionAsync(cancellationToken).ConfigureAwait(false);
        return dependentVersion;
    }
 
    /// <summary>
    /// Counts the references to the symbol declared by <paramref name="syntaxNode"/>, stopping once
    /// <paramref name="maxSearchResults"/> is reached.
    /// </summary>
    /// <param name="solution">Solution that is searched.</param>
    /// <param name="documentId">Document that contains <paramref name="syntaxNode"/>.</param>
    /// <param name="syntaxNode">Node whose declared symbol is counted.</param>
    /// <param name="maxSearchResults">Search cap; a value that is not positive leaves the reported count unclamped.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>
    /// The reference count stamped with the project's version, or <see langword="null"/> when the search could
    /// not be started.
    /// </returns>
    public async Task<ReferenceCount?> GetReferenceCountAsync(Solution solution, DocumentId documentId, SyntaxNode syntaxNode, int maxSearchResults, CancellationToken cancellationToken)
    {
        var projectVersion = await GetProjectCodeLensVersionAsync(solution, documentId.ProjectId, cancellationToken).ConfigureAwait(false);

        // The search ran to completion: report what was found, never exceeding a positive cap.
        Func<CodeLensFindReferencesProgress, Task<ReferenceCount>> whenCompleted = progress =>
        {
            var count = progress.SearchCap > 0
                ? Math.Min(progress.ReferencesCount, progress.SearchCap)
                : progress.ReferencesCount;

            return Task.FromResult(new ReferenceCount(count, progress.SearchCapReached, projectVersion.ToString()));
        };

        // The search was stopped at the cap: report the cap itself and flag the count as capped.
        Func<CodeLensFindReferencesProgress, Task<ReferenceCount>> whenCapped = progress =>
        {
            return Task.FromResult(new ReferenceCount(progress.SearchCap, IsCapped: true, projectVersion.ToString()));
        };

        var referenceCount = await FindAsync(
            solution, documentId, syntaxNode, whenCompleted, whenCapped, maxSearchResults, cancellationToken).ConfigureAwait(false);

        return referenceCount;
    }
 
    /// <summary>
    /// Builds the descriptor for a single reference location: the display name and glyph of the enclosing code
    /// element, the position of the referencing token, and the referencing line together with up to two lines of
    /// context before and after it.
    /// </summary>
    /// <param name="solution">Solution used to resolve the document that owns <paramref name="location"/>.</param>
    /// <param name="location">Source location of the reference.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>
    /// The descriptor, or <see langword="null"/> when no document in <paramref name="solution"/> owns the
    /// location's syntax tree.
    /// </returns>
    /// <exception cref="ArgumentException">
    /// The document's language does not provide an <see cref="ICodeLensDisplayInfoService"/>.
    /// </exception>
    private static async Task<ReferenceLocationDescriptor> GetDescriptorOfEnclosingSymbolAsync(Solution solution, Location location, CancellationToken cancellationToken)
    {
        var document = solution.GetDocument(location.SourceTree);
        if (document == null)
        {
            return null;
        }

        var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);

        var langServices = document.GetLanguageService<ICodeLensDisplayInfoService>();
        if (langServices == null)
        {
            throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, "Unsupported language '{0}'", semanticModel.Language), nameof(semanticModel));
        }

        var position = location.SourceSpan.Start;
        var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
        var token = root.FindToken(position, true);
        var node = GetEnclosingCodeElementNode(document, token, langServices, cancellationToken);
        var longName = langServices.GetDisplayName(semanticModel, node);

        // Locate the line of source text that contains the reference.
        var text = await document.GetValueTextAsync(cancellationToken).ConfigureAwait(false);
        var lines = text.Lines;
        var textLine = lines.GetLineFromPosition(position);
        var lineNumber = textLine.LineNumber;

        // Returns the text of the requested line without trailing whitespace, or an empty string when the line
        // lies outside the document (the reference is within two lines of the start or end of the file).
        string GetTrimmedLineOrEmpty(int index)
            => index >= 0 && index < lines.Count
                ? lines[index].ToString().TrimEnd()
                : string.Empty;

        // Turn the token span from document relative to line relative.
        var referenceSpan = new TextSpan(token.Span.Start - textLine.Span.Start, token.Span.Length);

        var symbol = semanticModel.GetDeclaredSymbol(node, cancellationToken);
        var glyph = symbol?.GetGlyph();
        var startLinePosition = location.GetLineSpan().StartLinePosition;

        // The document was already resolved (and null-checked) above, so take the id from it directly.
        var documentId = document.Id;

        return new ReferenceLocationDescriptor(
            longName,
            semanticModel.Language,
            glyph,
            token.Span.Start,
            token.Span.Length,
            startLinePosition.Line,
            startLinePosition.Character,
            documentId.ProjectId.Id,
            documentId.Id,
            document.FilePath,
            textLine.ToString().TrimEnd(),
            referenceSpan.Start,
            referenceSpan.Length,
            GetTrimmedLineOrEmpty(lineNumber - 1),
            GetTrimmedLineOrEmpty(lineNumber - 2),
            GetTrimmedLineOrEmpty(lineNumber + 1),
            GetTrimmedLineOrEmpty(lineNumber + 2));
    }
 
    /// <summary>
    /// Walks up from <paramref name="token"/> to the nearest enclosing declaration, using/extern/import directive,
    /// or global assembly attribute, and returns the display node the language service reports for it.
    /// </summary>
    /// <param name="document">Document that provides the syntax facts service.</param>
    /// <param name="token">Token at which the walk starts.</param>
    /// <param name="langServices">Service that maps the found node to its display node.</param>
    /// <param name="cancellationToken">Token checked on every step of the walk.</param>
    /// <returns>
    /// The display node of the enclosing code element; when no enclosing element is found, the display node of
    /// the token's parent.
    /// </returns>
    private static SyntaxNode GetEnclosingCodeElementNode(Document document, SyntaxToken token, ICodeLensDisplayInfoService langServices, CancellationToken cancellationToken)
    {
        var syntaxFacts = document.GetLanguageService<ISyntaxFactsService>();

        var current = token.Parent;
        while (current != null)
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (syntaxFacts.IsDocumentationComment(current))
            {
                // Documentation comments are structured trivia: continue from the parent of the token that
                // owns the trivia.
                var owningTrivia = ((IStructuredTriviaSyntax)current).ParentTrivia;
                current = owningTrivia.Token.Parent;
                continue;
            }

            var isCodeElement =
                syntaxFacts.IsDeclaration(current) ||
                syntaxFacts.IsUsingOrExternOrImport(current) ||
                syntaxFacts.IsGlobalAssemblyAttribute(current);
            if (isCodeElement)
            {
                return langServices.GetDisplayNode(current);
            }

            current = current.Parent;
        }

        // Nothing suitable encloses the token; fall back to the token's immediate parent.
        return langServices.GetDisplayNode(token.Parent);
    }
 
    /// <summary>
    /// Finds every reference to the symbol declared by <paramref name="syntaxNode"/> and describes each reference
    /// location.
    /// </summary>
    /// <param name="solution">Solution that is searched.</param>
    /// <param name="documentId">Document that contains <paramref name="syntaxNode"/>.</param>
    /// <param name="syntaxNode">Node whose declared symbol is searched for.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>
    /// The descriptors of the locations that could be resolved to a document, or <see langword="null"/> when the
    /// search could not be started.
    /// </returns>
    public async Task<ImmutableArray<ReferenceLocationDescriptor>?> FindReferenceLocationsAsync(Solution solution, DocumentId documentId, SyntaxNode syntaxNode, CancellationToken cancellationToken)
    {
        return await FindAsync(solution, documentId, syntaxNode,
            async progress =>
            {
                var referenceTasks = progress.Locations
                    .Select(location => GetDescriptorOfEnclosingSymbolAsync(solution, location, cancellationToken))
                    .ToArray();

                var result = await Task.WhenAll(referenceTasks).ConfigureAwait(false);

                // A location whose document cannot be resolved yields a null descriptor; drop those so callers
                // never see null entries, the same way FindReferenceMethodsAsync filters its results.
                return result.OfType<ReferenceLocationDescriptor>().ToImmutableArray();
            }, onCapped: null, searchCap: 0, cancellationToken: cancellationToken).ConfigureAwait(false);
    }
 
    /// <summary>
    /// Finds the method-like symbol that encloses <paramref name="location"/>, walking outwards through the
    /// containing symbols.
    /// </summary>
    /// <param name="semanticModel">Semantic model of the tree that contains <paramref name="location"/>.</param>
    /// <param name="location">Location whose start position is examined.</param>
    /// <param name="cancellationToken">Token checked on every step of the walk.</param>
    /// <returns>
    /// For an accessor, its associated symbol; otherwise the first enclosing method that is not an anonymous
    /// function; <see langword="null"/> when no such symbol exists.
    /// </returns>
    private static ISymbol GetEnclosingMethod(SemanticModel semanticModel, Location location, CancellationToken cancellationToken)
    {
        var candidate = semanticModel.GetEnclosingSymbol(location.SourceSpan.Start, cancellationToken);

        while (candidate != null)
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (candidate.Kind == SymbolKind.Method)
            {
                var method = (IMethodSymbol)candidate;
                if (method.IsAccessor())
                {
                    return method.AssociatedSymbol;
                }

                // Anonymous functions are skipped so that the walk continues outwards from them.
                if (method.MethodKind != MethodKind.AnonymousFunction)
                {
                    return method;
                }
            }

            candidate = candidate.ContainingSymbol;
        }

        return null;
    }
 
    /// <summary>
    /// Describes the method that encloses <paramref name="commonLocation"/>.
    /// </summary>
    /// <param name="commonLocation">Location of a reference.</param>
    /// <param name="solution">Solution used to resolve the document that owns the location.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>
    /// A descriptor holding the enclosing method's display name, the document's file path and the project's
    /// output file path; <see langword="null"/> when the document cannot be resolved or no enclosing method name
    /// is available.
    /// </returns>
    private static async Task<ReferenceMethodDescriptor> TryGetMethodDescriptorAsync(Location commonLocation, Solution solution, CancellationToken cancellationToken)
    {
        var document = solution.GetDocument(commonLocation.SourceTree);
        if (document == null)
        {
            return null;
        }

        var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var enclosingMethod = GetEnclosingMethod(semanticModel, commonLocation, cancellationToken);
        var fullName = enclosingMethod?.ToDisplayString(MethodDisplayFormat);

        if (string.IsNullOrEmpty(fullName))
        {
            return null;
        }

        return new ReferenceMethodDescriptor(fullName, document.FilePath, document.Project.OutputFilePath);
    }
 
    /// <summary>
    /// Finds every reference to the symbol declared by <paramref name="syntaxNode"/> and describes the method
    /// that encloses each reference.
    /// </summary>
    /// <param name="solution">Solution that is searched.</param>
    /// <param name="documentId">Document that contains <paramref name="syntaxNode"/>.</param>
    /// <param name="syntaxNode">Node whose declared symbol is searched for.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>
    /// The descriptors of the enclosing methods that could be determined, or <see langword="null"/> when the
    /// search could not be started.
    /// </returns>
    public Task<ImmutableArray<ReferenceMethodDescriptor>?> FindReferenceMethodsAsync(Solution solution, DocumentId documentId, SyntaxNode syntaxNode, CancellationToken cancellationToken)
    {
        Func<CodeLensFindReferencesProgress, Task<ImmutableArray<ReferenceMethodDescriptor>>> describeEnclosingMethods = async progress =>
        {
            var pending = progress.Locations
                .Select(location => TryGetMethodDescriptorAsync(location, solution, cancellationToken))
                .ToArray();

            var descriptors = await Task.WhenAll(pending).ConfigureAwait(false);

            // Locations without a resolvable enclosing method produce null and are left out.
            return descriptors.Where(descriptor => descriptor != null).ToImmutableArray();
        };

        return FindAsync(
            solution, documentId, syntaxNode, describeEnclosingMethods,
            onCapped: null, searchCap: 0, cancellationToken: cancellationToken);
    }
 
    /// <summary>
    /// Produces the fully qualified name of the symbol declared by <paramref name="syntaxNode"/>, using
    /// <c>+</c> instead of the regular punctuation between a containing type and a type nested in it.
    /// </summary>
    /// <param name="solution">Solution used to resolve the document that owns <paramref name="syntaxNode"/>.</param>
    /// <param name="documentId">Not consulted; the document is resolved from the syntax tree of <paramref name="syntaxNode"/>.</param>
    /// <param name="syntaxNode">Node whose declared symbol is named.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>
    /// The qualified name, or an empty string when the document, its semantic model, or the declared symbol
    /// cannot be resolved.
    /// </returns>
    public async Task<string> GetFullyQualifiedNameAsync(Solution solution, DocumentId documentId, SyntaxNode syntaxNode,
        CancellationToken cancellationToken)
    {
        var document = solution.GetDocument(syntaxNode.GetLocation().SourceTree);
        if (document == null)
        {
            // The node's tree does not belong to this solution; there is nothing to name.
            return string.Empty;
        }

        var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
        var declaredSymbol = semanticModel?.GetDeclaredSymbol(syntaxNode, cancellationToken);

        if (declaredSymbol == null)
        {
            return string.Empty;
        }

        return BuildQualifiedNameWithNestedTypeSeparators(declaredSymbol.ToDisplayParts(MethodDisplayFormat));
    }

    /// <summary>
    /// Concatenates display parts into a name, writing <c>+</c> in place of a punctuation part that sits between
    /// a class, record, interface or struct name and a following type name.
    /// </summary>
    private static string BuildQualifiedNameWithNestedTypeSeparators(ImmutableArray<SymbolDisplayPart> parts)
    {
        var pool = PooledStringBuilder.GetInstance();

        try
        {
            var actualBuilder = pool.Builder;
            var previousWasClass = false;

            for (var index = 0; index < parts.Length; index++)
            {
                var part = parts[index];

                // A punctuation part directly after a containing type and directly before another type name
                // separates a nested type from its container.
                var separatesNestedType =
                    previousWasClass &&
                    part.Kind == SymbolDisplayPartKind.Punctuation &&
                    index < parts.Length - 1 &&
                    parts[index + 1].Kind is SymbolDisplayPartKind.ClassName or
                                             SymbolDisplayPartKind.RecordClassName or
                                             SymbolDisplayPartKind.DelegateName or
                                             SymbolDisplayPartKind.EnumName or
                                             SymbolDisplayPartKind.ErrorTypeName or
                                             SymbolDisplayPartKind.InterfaceName or
                                             SymbolDisplayPartKind.StructName or
                                             SymbolDisplayPartKind.RecordStructName;

                if (separatesNestedType)
                {
                    actualBuilder.Append('+');
                }
                else
                {
                    actualBuilder.Append(part);
                }

                previousWasClass = part.Kind is SymbolDisplayPartKind.ClassName or
                                   SymbolDisplayPartKind.RecordClassName or
                                   SymbolDisplayPartKind.InterfaceName or
                                   SymbolDisplayPartKind.StructName or
                                   SymbolDisplayPartKind.RecordStructName;
            }

            return actualBuilder.ToString();
        }
        finally
        {
            pool.Free();
        }
    }
}