File: Services\FindUsages\RemoteFindUsagesService.cs
Web Access
Project: src\src\Workspaces\Remote\ServiceHub\Microsoft.CodeAnalysis.Remote.ServiceHub.csproj (Microsoft.CodeAnalysis.Remote.ServiceHub)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Classification;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.FindUsages;
using Microsoft.CodeAnalysis.Notification;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
 
namespace Microsoft.CodeAnalysis.Remote;
 
/// <summary>
/// Brokered-service implementation of <see cref="IRemoteFindUsagesService"/>. Each request is executed against the
/// solution snapshot identified by its <see cref="Checksum"/>, and the results and progress produced by the search
/// are streamed back to the caller through the <c>callback</c> supplied at construction.
/// </summary>
internal sealed class RemoteFindUsagesService(in BrokeredServiceBase.ServiceConstructionArguments arguments, RemoteCallback<IRemoteFindUsagesService.ICallback> callback)
    : BrokeredServiceBase(arguments), IRemoteFindUsagesService
{
    /// <summary>
    /// Creates <see cref="RemoteFindUsagesService"/> instances bound to the client's callback.
    /// </summary>
    internal sealed class Factory : FactoryBase<IRemoteFindUsagesService, IRemoteFindUsagesService.ICallback>
    {
        protected override IRemoteFindUsagesService CreateService(in ServiceConstructionArguments arguments, RemoteCallback<IRemoteFindUsagesService.ICallback> callback)
            => new RemoteFindUsagesService(arguments, callback);
    }

    /// <summary>
    /// Finds references to the symbol described by <paramref name="symbolAndProjectId"/> and streams the results
    /// back to the client identified by <paramref name="callbackId"/>.
    /// </summary>
    /// <param name="solutionChecksum">Identifies the solution snapshot to search.</param>
    /// <param name="callbackId">Identifies the client-side context that receives results and progress.</param>
    /// <param name="symbolAndProjectId">Serialized symbol to search for, together with the id of its project.</param>
    /// <param name="options">Search options, passed through unchanged to the find-references engine.</param>
    /// <param name="cancellationToken">Cancels the search.</param>
    /// <remarks>
    /// If the symbol cannot be rehydrated in this solution snapshot, the call completes without reporting anything.
    /// </remarks>
    public ValueTask FindReferencesAsync(
        Checksum solutionChecksum,
        RemoteServiceCallbackId callbackId,
        SerializableSymbolAndProjectId symbolAndProjectId,
        FindReferencesSearchOptions options,
        CancellationToken cancellationToken)
    {
        return RunServiceAsync(solutionChecksum, async solution =>
        {
            var project = solution.GetRequiredProject(symbolAndProjectId.ProjectId);

            var symbol = await symbolAndProjectId.TryRehydrateAsync(
                solution, cancellationToken).ConfigureAwait(false);

            // The symbol may not resolve in this snapshot; there is nothing to search for in that case.
            if (symbol == null)
                return;

            var context = new RemoteFindUsageContext(callback, callbackId);

            // Classification options are obtained through the client callback rather than passed in the request.
            var classificationOptions = GetClientOptionsProvider<ClassificationOptions, IRemoteFindUsagesService.ICallback>(callback, callbackId);

            await AbstractFindUsagesService.FindReferencesAsync(
                context, symbol, project, options, classificationOptions, cancellationToken).ConfigureAwait(false);
        }, cancellationToken);
    }

    /// <summary>
    /// Finds implementations of the symbol described by <paramref name="symbolAndProjectId"/> and streams the
    /// results back to the client identified by <paramref name="callbackId"/>.
    /// </summary>
    /// <param name="solutionChecksum">Identifies the solution snapshot to search.</param>
    /// <param name="callbackId">Identifies the client-side context that receives results and progress.</param>
    /// <param name="symbolAndProjectId">Serialized symbol to search for, together with the id of its project.</param>
    /// <param name="cancellationToken">Cancels the search.</param>
    /// <remarks>
    /// If the symbol cannot be rehydrated in this solution snapshot, the call completes without reporting anything.
    /// </remarks>
    public ValueTask FindImplementationsAsync(
        Checksum solutionChecksum,
        RemoteServiceCallbackId callbackId,
        SerializableSymbolAndProjectId symbolAndProjectId,
        CancellationToken cancellationToken)
    {
        return RunServiceAsync(solutionChecksum, async solution =>
        {
            var project = solution.GetRequiredProject(symbolAndProjectId.ProjectId);

            var symbol = await symbolAndProjectId.TryRehydrateAsync(
                solution, cancellationToken).ConfigureAwait(false);

            // The symbol may not resolve in this snapshot; there is nothing to search for in that case.
            if (symbol == null)
                return;

            var context = new RemoteFindUsageContext(callback, callbackId);

            // Classification options are obtained through the client callback rather than passed in the request.
            var classificationOptions = GetClientOptionsProvider<ClassificationOptions, IRemoteFindUsagesService.ICallback>(callback, callbackId);

            await AbstractFindUsagesService.FindImplementationsAsync(
                context, symbol, project, classificationOptions, cancellationToken).ConfigureAwait(false);
        }, cancellationToken);
    }

    /// <summary>
    /// <see cref="IFindUsagesContext"/> (and its own <see cref="IStreamingProgressTracker"/>) that forwards every
    /// notification to the client callback. Definitions and references are converted to their serializable forms
    /// before crossing the process boundary. Definitions are identified by integer ids that this context assigns.
    /// </summary>
    private sealed class RemoteFindUsageContext : IFindUsagesContext, IStreamingProgressTracker
    {
        private readonly RemoteCallback<IRemoteFindUsagesService.ICallback> _callback;
        private readonly RemoteServiceCallbackId _callbackId;

        // Maps each definition seen so far to the id used for it on the wire. Always accessed under a lock on itself.
        private readonly Dictionary<DefinitionItem, int> _definitionItemToId = [];

        public RemoteFindUsageContext(RemoteCallback<IRemoteFindUsagesService.ICallback> callback, RemoteServiceCallbackId callbackId)
        {
            _callback = callback;
            _callbackId = callbackId;
        }

        #region IStreamingProgressTracker

        public ValueTask AddItemsAsync(int count, CancellationToken cancellationToken)
            => _callback.InvokeAsync((callback, cancellationToken) => callback.AddItemsAsync(_callbackId, count, cancellationToken), cancellationToken);

        public ValueTask ItemsCompletedAsync(int count, CancellationToken cancellationToken)
            => _callback.InvokeAsync((callback, cancellationToken) => callback.ItemsCompletedAsync(_callbackId, count, cancellationToken), cancellationToken);

        #endregion

        #region IFindUsagesContext

        // This context reports its own progress, so it is also the progress tracker.
        public IStreamingProgressTracker ProgressTracker => this;

        public ValueTask ReportNoResultsAsync(string message, CancellationToken cancellationToken)
            => _callback.InvokeAsync((callback, cancellationToken) => callback.ReportMessageAsync(_callbackId, message, cancellationToken), cancellationToken);

        // NOTE: severity is not forwarded. The callback method used here accepts only the message text.
        public ValueTask ReportMessageAsync(string message, NotificationSeverity severity, CancellationToken cancellationToken)
            => _callback.InvokeAsync((callback, cancellationToken) => callback.ReportInformationalMessageAsync(_callbackId, message, cancellationToken), cancellationToken);

        public ValueTask SetSearchTitleAsync(string title, CancellationToken cancellationToken)
            => _callback.InvokeAsync((callback, cancellationToken) => callback.SetSearchTitleAsync(_callbackId, title, cancellationToken), cancellationToken);

        /// <summary>
        /// Serializes <paramref name="definition"/> together with its id and sends it to the client.
        /// </summary>
        public ValueTask OnDefinitionFoundAsync(DefinitionItem definition, CancellationToken cancellationToken)
        {
            var id = GetOrAddDefinitionItemId(definition);
            var dehydratedDefinition = SerializableDefinitionItem.Dehydrate(id, definition);
            return _callback.InvokeAsync((callback, cancellationToken) => callback.OnDefinitionFoundAsync(_callbackId, dehydratedDefinition, cancellationToken), cancellationToken);
        }

        /// <summary>
        /// Returns the id already assigned to <paramref name="item"/>, or assigns the next sequential id
        /// (0, 1, 2, ... in first-seen order) if the item has not been seen before.
        /// </summary>
        /// <remarks>
        /// Ids may be assigned here for definitions that are first seen through a reference, before
        /// <see cref="OnDefinitionFoundAsync"/> has been called for them. The lock keeps assignment consistent
        /// if this is called from multiple threads at once.
        /// </remarks>
        private int GetOrAddDefinitionItemId(DefinitionItem item)
        {
            lock (_definitionItemToId)
            {
                if (!_definitionItemToId.TryGetValue(item, out var id))
                {
                    // Entries are never removed, so Count always yields a fresh, unused id.
                    id = _definitionItemToId.Count;
                    _definitionItemToId.Add(item, id);
                }

                return id;
            }
        }

        /// <summary>
        /// Drains <paramref name="references"/>, serializes each reference with the id of its definition, and sends
        /// the whole batch to the client in a single callback.
        /// </summary>
        public async ValueTask OnReferencesFoundAsync(IAsyncEnumerable<SourceReferenceItem> references, CancellationToken cancellationToken)
        {
            using var _ = ArrayBuilder<SerializableSourceReferenceItem>.GetInstance(out var dehydrated);

            // Flow the cancellation token into the async enumeration so a cancelled search stops pulling items.
            // Do not capture a synchronization context, matching every other await in this file.
            await foreach (var reference in references.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                var dehydratedReference = SerializableSourceReferenceItem.Dehydrate(
                    GetOrAddDefinitionItemId(reference.Definition), reference);
                dehydrated.Add(dehydratedReference);
            }

            var dehydratedReferences = dehydrated.ToImmutableAndClear();
            await _callback.InvokeAsync((callback, cancellationToken) => callback.OnReferencesFoundAsync(
                _callbackId, dehydratedReferences, cancellationToken), cancellationToken).ConfigureAwait(false);
        }

        #endregion
    }
}