// File: FindSymbols\FindLiterals\FindLiteralsSearchEngine.cs
// Web Access
// Project: src\src\Workspaces\Core\Portable\Microsoft.CodeAnalysis.Workspaces.csproj (Microsoft.CodeAnalysis.Workspaces)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.FindSymbols;
 
/// <summary>
/// Searches the documents of a <see cref="Solution"/> for literal tokens whose value equals a given
/// constant and reports every match through an <see cref="IStreamingFindLiteralReferencesProgress"/>.
/// </summary>
internal sealed class FindLiteralsSearchEngine
{
    /// <summary>
    /// The category of literal token this engine looks for, derived from the runtime type of the
    /// searched value.
    /// </summary>
    private enum SearchKind
    {
        /// <summary>The value is not searchable (currently <see cref="decimal"/>); nothing is scanned.</summary>
        None,

        /// <summary>The value is a <see cref="string"/>.</summary>
        StringLiterals,

        /// <summary>The value is a <see cref="char"/>.</summary>
        CharacterLiterals,

        /// <summary>Any other value; it is treated as a number.</summary>
        NumericLiterals,
    }

    private readonly Solution _solution;
    private readonly IStreamingFindLiteralReferencesProgress _progress;
    private readonly IStreamingProgressTracker _progressTracker;

    /// <summary>The value being searched for, compared against <see cref="SyntaxToken.Value"/>.</summary>
    private readonly object _value;

    /// <summary>Set only for string searches; used to pre-filter documents through the syntax tree index.</summary>
    private readonly string? _stringValue;

    /// <summary>Set for character and numeric searches; used to pre-filter documents through the syntax tree index.</summary>
    private readonly long _longValue;

    private readonly SearchKind _searchKind;

    /// <summary>
    /// Creates an engine that looks for <paramref name="value"/> in <paramref name="solution"/>.
    /// </summary>
    /// <param name="solution">The solution whose documents are scanned.</param>
    /// <param name="progress">Receives found references; its progress tracker receives progress updates.</param>
    /// <param name="value">The literal value to look for. Its runtime type selects the kind of search.</param>
    public FindLiteralsSearchEngine(
        Solution solution,
        IStreamingFindLiteralReferencesProgress progress, object value)
    {
        _solution = solution;
        _progress = progress;
        _progressTracker = progress.ProgressTracker;
        _value = value;

        switch (value)
        {
            case string s:
                _stringValue = s;
                _searchKind = SearchKind.StringLiterals;
                break;
            case double d:
                _longValue = BitConverter.DoubleToInt64Bits(d);
                _searchKind = SearchKind.NumericLiterals;
                break;
            case float f:
                // The float is implicitly widened to double before its bit pattern is taken.
                _longValue = BitConverter.DoubleToInt64Bits(f);
                _searchKind = SearchKind.NumericLiterals;
                break;
            case decimal: // unsupported
                _searchKind = SearchKind.None;
                break;
            case char:
                _longValue = IntegerUtilities.ToInt64(value);
                _searchKind = SearchKind.CharacterLiterals;
                break;
            default:
                // Everything else is handed to IntegerUtilities.ToInt64 and searched as a number.
                _longValue = IntegerUtilities.ToInt64(value);
                _searchKind = SearchKind.NumericLiterals;
                break;
        }
    }

    /// <summary>
    /// Runs the search. A single progress item is registered for the duration of the call; when the
    /// value's type is unsupported, no document is scanned.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the search.</param>
    public async Task FindReferencesAsync(CancellationToken cancellationToken)
    {
        var disposable = await _progressTracker.AddSingleItemAsync(cancellationToken).ConfigureAwait(false);
        await using var _ = disposable.ConfigureAwait(false);

        if (_searchKind != SearchKind.None)
        {
            await FindReferencesWorkerAsync(cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Registers the expected number of work items with the progress tracker and then processes every
    /// document of every project through <see cref="RoslynParallel"/>.
    /// </summary>
    private async Task FindReferencesWorkerAsync(CancellationToken cancellationToken)
    {
        // NOTE(review): this count is built from p.DocumentIds only, while the loop below enumerates
        // regular *and* source-generated documents, each of which reports one completed item. Confirm
        // that the tracker tolerates more completions than added items.
        var count = _solution.Projects.SelectMany(p => p.DocumentIds).Count();
        await _progressTracker.AddItemsAsync(count, cancellationToken).ConfigureAwait(false);

        await RoslynParallel.ForEachAsync(
            source: SelectManyAsync(_solution.Projects, p => p.GetAllRegularAndSourceGeneratedDocumentsAsync(cancellationToken)),
            cancellationToken,
            ProcessDocumentAsync).ConfigureAwait(false);

        // Flattens one async sequence per source item into a single async sequence.
        static async IAsyncEnumerable<TResult> SelectManyAsync<TSource, TResult>(IEnumerable<TSource> source, Func<TSource, IAsyncEnumerable<TResult>> selector)
        {
            foreach (var item in source)
            {
                await foreach (var result in selector(item))
                    yield return result;
            }
        }
    }

    /// <summary>
    /// Processes one document and marks one progress item as completed afterwards, whether or not
    /// processing threw.
    /// </summary>
    private async ValueTask ProcessDocumentAsync(Document document, CancellationToken cancellationToken)
    {
        try
        {
            await ProcessDocumentWorkerAsync(document, cancellationToken).ConfigureAwait(false);
        }
        finally
        {
            await _progressTracker.ItemCompletedAsync(cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Consults the document's <see cref="SyntaxTreeIndex"/> and only walks the syntax tree when the
    /// index reports that the document probably contains the searched value.
    /// </summary>
    private async Task ProcessDocumentWorkerAsync(Document document, CancellationToken cancellationToken)
    {
        var index = await SyntaxTreeIndex.GetIndexAsync(
            document, cancellationToken).ConfigureAwait(false);

        Contract.ThrowIfNull(index);
        if (_searchKind == SearchKind.StringLiterals)
        {
            Contract.ThrowIfNull(_stringValue);
            if (index.ProbablyContainsStringValue(_stringValue))
            {
                await SearchDocumentAsync(document, cancellationToken).ConfigureAwait(false);
            }
        }
        else if (index.ProbablyContainsInt64Value(_longValue))
        {
            // Character and numeric searches both pre-filter on the 64-bit key.
            await SearchDocumentAsync(document, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Collects every matching literal token in <paramref name="document"/> and then reports the span
    /// of each one, in document order.
    /// </summary>
    private async Task SearchDocumentAsync(Document document, CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();
        var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
        var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

        // Matches are gathered first and reported afterwards, so nothing is reported for a document
        // whose scan is cancelled part-way through.
        using var _ = ArrayBuilder<SyntaxToken>.GetInstance(out var matches);
        CollectMatchingTokens(syntaxFacts, root, matches, cancellationToken);

        foreach (var token in matches)
            await _progress.OnReferenceFoundAsync(document, token.Span, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Adds to <paramref name="matches"/> every token under <paramref name="root"/> that is a literal
    /// of the searched kind and whose value equals the searched value.
    /// </summary>
    /// <remarks>
    /// The tokens are obtained from <see cref="SyntaxNode.DescendantTokens"/> instead of a method that
    /// calls itself for every child node, so this class no longer adds a stack frame per nesting level
    /// of the syntax tree. Tokens inside structured trivia are not visited.
    /// </remarks>
    private void CollectMatchingTokens(
        ISyntaxFactsService syntaxFacts, SyntaxNode root,
        ArrayBuilder<SyntaxToken> matches, CancellationToken cancellationToken)
    {
        foreach (var token in root.DescendantTokens())
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (IsLiteralOfSearchedKind(syntaxFacts, token) && _value.Equals(token.Value))
                matches.Add(token);
        }
    }

    /// <summary>
    /// Returns <see langword="true"/> when <paramref name="token"/> is a literal of the kind selected
    /// by the searched value's type.
    /// </summary>
    private bool IsLiteralOfSearchedKind(ISyntaxFactsService syntaxFacts, SyntaxToken token)
        => _searchKind switch
        {
            SearchKind.StringLiterals => syntaxFacts.IsStringLiteral(token),
            SearchKind.CharacterLiterals => syntaxFacts.IsCharacterLiteral(token),
            SearchKind.NumericLiterals => syntaxFacts.IsNumericLiteral(token),
            _ => false,
        };
}