File: Remote\InProcRemostHostClient.cs
Web Access
Project: src\src\Workspaces\CoreTestUtilities\Microsoft.CodeAnalysis.Workspaces.Test.Utilities.csproj (Microsoft.CodeAnalysis.Workspaces.Test.Utilities)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.EditAndContinue;
using Microsoft.CodeAnalysis.ErrorReporting;
using Microsoft.CodeAnalysis.Host;
using Microsoft.ServiceHub.Framework;
using Nerdbank.Streams;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.Remote.Testing
{
    internal sealed partial class InProcRemoteHostClient : RemoteHostClient
    {
        private readonly SolutionServices _workspaceServices;
        private readonly InProcRemoteServices _inprocServices;
        private readonly RemoteServiceCallbackDispatcherRegistry _callbackDispatchers;
 
        public static RemoteHostClient Create(SolutionServices services, RemoteServiceCallbackDispatcherRegistry callbackDispatchers, TraceListener? traceListener, RemoteHostTestData testData)
        {
            var inprocServices = new InProcRemoteServices(services, traceListener, testData);
            var instance = new InProcRemoteHostClient(services, inprocServices, callbackDispatchers);
 
            // return instance
            return instance;
        }
 
        private InProcRemoteHostClient(
            SolutionServices services,
            InProcRemoteServices inprocServices,
            RemoteServiceCallbackDispatcherRegistry callbackDispatchers)
        {
            _workspaceServices = services;
            _callbackDispatchers = callbackDispatchers;
            _inprocServices = inprocServices;
        }
 
        public static async Task<InProcRemoteHostClient> GetTestClientAsync(Workspace workspace)
        {
            var client = (InProcRemoteHostClient?)await TryGetClientAsync(workspace, CancellationToken.None).ConfigureAwait(false);
            Contract.ThrowIfNull(client);
            return client;
        }
 
        public RemoteWorkspace GetRemoteWorkspace()
            => TestData.WorkspaceManager.GetWorkspace();
 
        public RemoteHostTestData TestData => _inprocServices.TestData;
 
        public override RemoteServiceConnection<T> CreateConnection<T>(object? callbackTarget) where T : class
        {
            var descriptor = ServiceDescriptors.Instance.GetServiceDescriptor(typeof(T), RemoteProcessConfiguration.ServerGC);
            var callbackDispatcher = (descriptor.ClientInterface != null) ? _callbackDispatchers.GetDispatcher(typeof(T)) : null;
 
            return new BrokeredServiceConnection<T>(
                descriptor,
                callbackTarget,
                callbackDispatcher,
                _inprocServices.ServiceBrokerClient,
                _workspaceServices.GetRequiredService<ISolutionAssetStorageProvider>().AssetStorage,
                _workspaceServices.GetRequiredService<IErrorReportingService>(),
                shutdownCancellationService: null,
                remoteProcess: null);
        }
 
        public override void Dispose()
        {
            _inprocServices.Dispose();
        }
 
        public sealed class ServiceProvider : IServiceProvider
        {
            public readonly TraceSource TraceSource;
            public readonly RemoteHostTestData TestData;
 
            public ServiceProvider(TraceSource traceSource, RemoteHostTestData testData)
            {
                TraceSource = traceSource;
                TestData = testData;
            }
 
            public object GetService(Type serviceType)
            {
                if (typeof(TraceSource) == serviceType)
                {
                    return TraceSource;
                }
 
                if (typeof(RemoteHostTestData) == serviceType)
                {
                    return TestData;
                }
 
                throw ExceptionUtilities.UnexpectedValue(serviceType);
            }
        }
 
        private sealed class InProcServiceBroker : IServiceBroker
        {
            private readonly InProcRemoteServices _services;
 
            public InProcServiceBroker(InProcRemoteServices services)
            {
                _services = services;
            }
 
            public event EventHandler<BrokeredServicesChangedEventArgs>? AvailabilityChanged { add { } remove { } }
 
            // This method is currently not needed for our IServiceBroker usage patterns.
            public ValueTask<IDuplexPipe?> GetPipeAsync(ServiceMoniker serviceMoniker, ServiceActivationOptions options, CancellationToken cancellationToken)
                => throw ExceptionUtilities.Unreachable();
 
            public ValueTask<T?> GetProxyAsync<T>(ServiceRpcDescriptor descriptor, ServiceActivationOptions options, CancellationToken cancellationToken) where T : class
            {
                var pipePair = FullDuplexStream.CreatePipePair();
 
                var clientConnection = descriptor
                    .WithTraceSource(_services.ServiceProvider.TraceSource)
                    .ConstructRpcConnection(pipePair.Item2);
 
                Contract.ThrowIfFalse(options.ClientRpcTarget is null == descriptor.ClientInterface is null);
 
                if (descriptor.ClientInterface != null)
                {
                    Contract.ThrowIfNull(options.ClientRpcTarget);
                    clientConnection.AddLocalRpcTarget(options.ClientRpcTarget);
                }
 
                // Clear RPC target so that the server connection is forced to create a new proxy for the callback
                // instead of just invoking the callback object directly (this emulates the product that does
                // not serialize the callback object over).
                options.ClientRpcTarget = null;
 
                // Creates service instance and connects it to the pipe. 
                // We don't need to store the instance anywhere.
                _ = _services.CreateBrokeredService(descriptor, pipePair.Item1, options);
 
                clientConnection.StartListening();
 
                return ValueTaskFactory.FromResult((T?)clientConnection.ConstructRpcClient<T>());
            }
        }
 
        private sealed class InProcRemoteServices : IDisposable
        {
            public readonly ServiceProvider ServiceProvider;
            private readonly Dictionary<ServiceMoniker, Func<object>> _inProcBrokeredServicesMap = [];
            private readonly Dictionary<ServiceMoniker, BrokeredServiceBase.IFactory> _remoteBrokeredServicesMap = [];
 
            public readonly IServiceBroker ServiceBroker;
            public readonly ServiceBrokerClient ServiceBrokerClient;
 
            public InProcRemoteServices(SolutionServices workspaceServices, TraceListener? traceListener, RemoteHostTestData testData)
            {
                var remoteLogger = new TraceSource("InProcRemoteClient")
                {
                    Switch = { Level = SourceLevels.Warning },
                };
 
                if (traceListener != null)
                {
                    remoteLogger.Listeners.Add(traceListener);
                }
 
                ServiceProvider = new ServiceProvider(remoteLogger, testData);
 
                ServiceBroker = new InProcServiceBroker(this);
#pragma warning disable VSTHRD012 // Provide JoinableTaskFactory where allowed
                ServiceBrokerClient = new ServiceBrokerClient(ServiceBroker);
#pragma warning restore
 
                RegisterInProcBrokeredService(SolutionAssetProvider.ServiceDescriptor, () => new SolutionAssetProvider(workspaceServices));
                RegisterRemoteBrokeredService(new RemoteAssetSynchronizationService.Factory());
                RegisterRemoteBrokeredService(new RemoteAsynchronousOperationListenerService.Factory());
                RegisterRemoteBrokeredService(new RemoteCodeLensReferencesService.Factory());
                RegisterRemoteBrokeredService(new RemoteConvertTupleToStructCodeRefactoringService.Factory());
                RegisterRemoteBrokeredService(new RemoteDependentTypeFinderService.Factory());
                RegisterRemoteBrokeredService(new RemoteDesignerAttributeDiscoveryService.Factory());
                RegisterRemoteBrokeredService(new RemoteDiagnosticAnalyzerService.Factory());
                RegisterRemoteBrokeredService(new RemoteDocumentHighlightsService.Factory());
                RegisterRemoteBrokeredService(new RemoteEditAndContinueService.Factory());
                RegisterRemoteBrokeredService(new RemoteEncapsulateFieldService.Factory());
                RegisterRemoteBrokeredService(new RemoteExtensionMethodImportCompletionService.Factory());
                RegisterRemoteBrokeredService(new RemoteFindUsagesService.Factory());
                RegisterRemoteBrokeredService(new RemoteFullyQualifyService.Factory());
                RegisterRemoteBrokeredService(new RemoteInheritanceMarginService.Factory());
                RegisterRemoteBrokeredService(new RemoteKeepAliveService.Factory());
                RegisterRemoteBrokeredService(new RemoteLegacySolutionEventsAggregationService.Factory());
                RegisterRemoteBrokeredService(new RemoteMissingImportDiscoveryService.Factory());
                RegisterRemoteBrokeredService(new RemoteNavigateToSearchService.Factory());
                RegisterRemoteBrokeredService(new RemoteNavigationBarItemService.Factory());
                RegisterRemoteBrokeredService(new RemoteProcessTelemetryService.Factory());
                RegisterRemoteBrokeredService(new RemoteRelatedDocumentsService.Factory());
                RegisterRemoteBrokeredService(new RemoteRenamerService.Factory());
                RegisterRemoteBrokeredService(new RemoteSemanticClassificationService.Factory());
                RegisterRemoteBrokeredService(new RemoteSemanticSearchService.Factory());
                RegisterRemoteBrokeredService(new RemoteSourceGenerationService.Factory());
                RegisterRemoteBrokeredService(new RemoteStackTraceExplorerService.Factory());
                RegisterRemoteBrokeredService(new RemoteSymbolFinderService.Factory());
                RegisterRemoteBrokeredService(new RemoteSymbolSearchUpdateService.Factory());
                RegisterRemoteBrokeredService(new RemoteTaskListService.Factory());
                RegisterRemoteBrokeredService(new RemoteUnitTestingSearchService.Factory());
                RegisterRemoteBrokeredService(new RemoteUnusedReferenceAnalysisService.Factory());
                RegisterRemoteBrokeredService(new RemoteValueTrackingService.Factory());
            }
 
            public void Dispose()
                => ServiceBrokerClient.Dispose();
 
            public RemoteHostTestData TestData => ServiceProvider.TestData;
 
            public void RegisterInProcBrokeredService(ServiceDescriptor serviceDescriptor, Func<object> serviceFactory)
            {
                _inProcBrokeredServicesMap.Add(serviceDescriptor.Moniker, serviceFactory);
            }
 
            public void RegisterRemoteBrokeredService(BrokeredServiceBase.IFactory serviceFactory)
            {
                var moniker = ServiceDescriptors.Instance.GetServiceDescriptorForServiceFactory(serviceFactory.ServiceType).Moniker;
                _remoteBrokeredServicesMap.Add(moniker, serviceFactory);
            }
 
            public object CreateBrokeredService(ServiceRpcDescriptor descriptor, IDuplexPipe pipe, ServiceActivationOptions options)
            {
                if (_inProcBrokeredServicesMap.TryGetValue(descriptor.Moniker, out var inProcFactory))
                {
                    // This code is similar to service creation implemented in BrokeredServiceBase.FactoryBase.
                    // Currently don't support callback creation as we don't have in-proc service with callbacks yet.
                    Contract.ThrowIfFalse(descriptor.ClientInterface == null);
 
                    var serviceConnection = descriptor.WithTraceSource(ServiceProvider.TraceSource).ConstructRpcConnection(pipe);
                    var service = inProcFactory();
 
                    serviceConnection.AddLocalRpcTarget(service);
                    serviceConnection.StartListening();
 
                    return service;
                }
 
                if (_remoteBrokeredServicesMap.TryGetValue(descriptor.Moniker, out var remoteFactory))
                {
                    return remoteFactory.Create(pipe, ServiceProvider, options, ServiceBroker);
                }
 
                throw ExceptionUtilities.UnexpectedValue(descriptor.Moniker);
            }
 
            private sealed class WrappedStream : Stream
            {
                private readonly IDisposable _service;
                private readonly Stream _stream;
 
                public WrappedStream(IDisposable service, Stream stream)
                {
                    // tie service's lifetime with that of stream
                    _service = service;
                    _stream = stream;
                }
 
                public override long Position
                {
                    get { return _stream.Position; }
                    set { _stream.Position = value; }
                }
 
                public override int ReadTimeout
                {
                    get { return _stream.ReadTimeout; }
                    set { _stream.ReadTimeout = value; }
                }
 
                public override int WriteTimeout
                {
                    get { return _stream.WriteTimeout; }
                    set { _stream.WriteTimeout = value; }
                }
 
                public override bool CanRead => _stream.CanRead;
                public override bool CanSeek => _stream.CanSeek;
                public override bool CanWrite => _stream.CanWrite;
                public override long Length => _stream.Length;
                public override bool CanTimeout => _stream.CanTimeout;
 
                public override void Flush() => _stream.Flush();
                public override Task FlushAsync(CancellationToken cancellationToken) => _stream.FlushAsync(cancellationToken);
 
                public override long Seek(long offset, SeekOrigin origin) => _stream.Seek(offset, origin);
                public override void SetLength(long value) => _stream.SetLength(value);
 
                public override int ReadByte() => _stream.ReadByte();
                public override void WriteByte(byte value) => _stream.WriteByte(value);
 
                public override int Read(byte[] buffer, int offset, int count) => _stream.Read(buffer, offset, count);
                public override void Write(byte[] buffer, int offset, int count) => _stream.Write(buffer, offset, count);
 
                public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _stream.ReadAsync(buffer, offset, count, cancellationToken);
                public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _stream.WriteAsync(buffer, offset, count, cancellationToken);
 
#if NET // nullability annotations differ
                public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => _stream.BeginRead(buffer, offset, count, callback, state);
                public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => _stream.BeginWrite(buffer, offset, count, callback, state);
#else
                public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object? state) => _stream.BeginRead(buffer, offset, count, callback, state);
                public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object? state) => _stream.BeginWrite(buffer, offset, count, callback, state);
#endif
                public override int EndRead(IAsyncResult asyncResult) => _stream.EndRead(asyncResult);
                public override void EndWrite(IAsyncResult asyncResult) => _stream.EndWrite(asyncResult);
 
                public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) => _stream.CopyToAsync(destination, bufferSize, cancellationToken);
 
                public override void Close()
                {
                    _service.Dispose();
                    _stream.Close();
                }
 
                protected override void Dispose(bool disposing)
                {
                    base.Dispose(disposing);
 
                    _service.Dispose();
                    _stream.Dispose();
                }
            }
        }
    }
}