File: Rpc\RpcServer.cs
Web Access
Project: src\roslyn\src\Workspaces\MSBuild\BuildHost\Microsoft.CodeAnalysis.Workspaces.MSBuild.BuildHost.csproj (Microsoft.CodeAnalysis.Workspaces.MSBuild.BuildHost)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipes;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.MSBuild;

/// <summary>
/// The server half of the RPC channel that is used to talk to the build host.
/// </summary>
/// <remarks>
/// What is implemented here closely resembles JSON-RPC. Because the Build Host has to be usable in Source Build scenarios, we can only
/// depend on what ships in .NET or can easily be made buildable in Source Build, which rules out existing solutions such as StreamJsonRpc.
/// If a standard RPC mechanism ever becomes available in .NET or Source Build, this type should be deleted in favor of it.
/// </remarks>
internal sealed class RpcServer
#if NETFRAMEWORK
    : MarshalByRefObject
#endif
{
    private readonly TextWriter _streamWriter;

    // Serializes writers so that only one response at a time is written to _streamWriter.
    private readonly SemaphoreSlim _sendingStreamSemaphore = new SemaphoreSlim(initialCount: 1);
    private readonly TextReader _streamReader;
    private readonly RpcMethodInvoker _rpcMethodInvoker;

    // Objects that requests can be dispatched to, keyed by the index handed out from AddTarget.
    private readonly ConcurrentDictionary<int, object> _rpcTargets = new();

    // Starts at -1 so that the first increment produces an index of zero.
    private volatile int _nextRpcTargetIndex = -1;

    // Cancelled by Shutdown() to stop the read loop in RunAsync().
    private readonly CancellationTokenSource _shutdownTokenSource = new();

    /// <summary>
    /// Creates a server over <paramref name="stream"/> that uses a newly created <see cref="RpcMethodInvoker"/>.
    /// </summary>
    public RpcServer(PipeStream stream) : this(stream, new RpcMethodInvoker())
    {
    }

    /// <summary>
    /// Creates a server that reads requests from and writes responses to <paramref name="stream"/>, invoking target methods
    /// through <paramref name="methodInvoker"/>.
    /// </summary>
    public RpcServer(PipeStream stream, RpcMethodInvoker methodInvoker)
    {
        _rpcMethodInvoker = methodInvoker;
        _streamWriter = new StreamWriter(stream, JsonSettings.StreamEncoding);
        _streamReader = new StreamReader(stream, JsonSettings.StreamEncoding);
    }

    /// <summary>
    /// Registers <paramref name="rpcTarget"/> so that incoming requests can be dispatched to it.
    /// </summary>
    /// <returns>The index under which the target was registered.</returns>
    public int AddTarget(object rpcTarget)
    {
        // Keep claiming indices until one is free. A collision is not expected in practice, as it would require billions of long lived
        // projects, but retrying keeps this correct regardless.
        int candidateIndex;
        do
        {
            candidateIndex = Interlocked.Increment(ref _nextRpcTargetIndex);
        }
        while (!_rpcTargets.TryAdd(candidateIndex, rpcTarget));

        return candidateIndex;
    }

    /// <summary>
    /// Runs the server, waiting for incoming requests. The returned task completes once the receiving stream has closed (so no further
    /// requests can arrive) or <see cref="Shutdown"/> has been called, and all requests already started have finished.
    /// </summary>
    public async Task RunAsync()
    {
        var inFlightRequests = new ConcurrentSet<Task>();

        while (true)
        {
            var requestLine = await _streamReader.TryReadLineOrReturnNullIfCancelledAsync(_shutdownTokenSource.Token).ConfigureAwait(false);
            if (requestLine is null)
                break;

            var request = DeserializeRequest(requestLine);
            var requestTask = Task.Run(() => ProcessRequestAsync(request));

            // The task is added to the set first and only then is the removing continuation attached. Doing it in this order guarantees
            // the removal can never run ahead of the addition, even when the task has already completed by this point.
            inFlightRequests.Add(requestTask);
            _ = requestTask.ContinueWith(
                completedTask => Contract.ThrowIfFalse(inFlightRequests.Remove(requestTask)),
                CancellationToken.None,
                TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Default);
        }

        // Before waiting on the outstanding requests we take a snapshot of them by hand, because the set can still shrink while we
        // are looking at it. ConcurrentSet (as of this writing) implements ICollection, and the usual helpers (ToArray(), ToList(), ...)
        // optimize for ICollection by asking for the count and pre-allocating an array of that size. If the collection shrinks afterwards
        // the array is never resized, leaving nulls behind. For an example of this concern, see the comment at
        // https://github.com/dotnet/runtime/blob/46c8a668eb4bbc66d9eb988d2988ecc84074be10/src/libraries/Common/src/System/Collections/Generic/EnumerableHelpers.cs#L29-L34
        var pendingTasks = new List<Task>(capacity: inFlightRequests.Count);
        foreach (var pendingTask in inFlightRequests)
        {
            pendingTasks.Add(pendingTask);
        }

        await Task.WhenAll(pendingTasks).ConfigureAwait(false);
    }

    /// <summary>
    /// Parses a single line read from the stream into a <see cref="Request"/>. Any failure, including a null result, is wrapped in an
    /// exception whose message contains the offending line.
    /// </summary>
    private static Request DeserializeRequest(string line)
    {
        try
        {
            var request = JsonSerializer.Deserialize<Request>(line, JsonSettings.SingleLineSerializerOptions);
            Contract.ThrowIfNull(request);
            return request;
        }
        catch (Exception e)
        {
            throw new Exception($"Failure while deserializing '{line}'", innerException: e);
        }
    }

    /// <summary>
    /// Handles one request: looks up the target object and method, deserializes the arguments, invokes the method, and writes a single
    /// line response (carrying either the result or the exception details) back to the stream.
    /// </summary>
    private async Task ProcessRequestAsync(Request request)
    {
        Response response;

        try
        {
            Contract.ThrowIfFalse(
                _rpcTargets.TryGetValue(request.TargetObject, out var rpcTarget),
                $"Received a request for target object {request.TargetObject} but we don't have a registered object for that.");

            var method = rpcTarget.GetType().GetMethod(request.Method, BindingFlags.Public | BindingFlags.Instance);

            Contract.ThrowIfNull(method, $"The invoked method '{request.Method}' could not be found.");

            var methodParameters = method.GetParameters();
            var parameterCount = methodParameters.Length;

            var lastParameterIsCancellationToken =
                parameterCount > 0 &&
                methodParameters[parameterCount - 1].ParameterType == typeof(CancellationToken);

            if (lastParameterIsCancellationToken)
            {
                Contract.ThrowIfFalse(request.Parameters.Length == methodParameters.Length - 1, $"The arguments list should contain every parameter for {request.Method} except the final CancellationToken.");
            }
            else
            {
                Contract.ThrowIfFalse(request.Parameters.Length == methodParameters.Length, $"The arguments list should contain every parameter for {request.Method}.");
            }

            // A trailing CancellationToken is never sent over the wire; its slot is left null here because the RpcMethodInvoker is
            // what fills it in. Every other slot is deserialized from the request.
            var transmittedArgumentCount = lastParameterIsCancellationToken ? parameterCount - 1 : parameterCount;
            var arguments = new object?[parameterCount];

            for (var index = 0; index < transmittedArgumentCount; index++)
            {
                arguments[index] = request.Parameters[index].Deserialize(methodParameters[index].ParameterType, JsonSettings.SingleLineSerializerOptions);
            }

            var result = _rpcMethodInvoker.InvokeMethod(rpcTarget, method, arguments, lastParameterIsCancellationToken);

            // A method that returned a task is awaited so that the response carries whatever GetTaskResultAsync produces for it.
            if (result is Task resultTask)
                result = await RpcMethodInvoker.GetTaskResultAsync(resultTask, calledMethod: method).ConfigureAwait(false);

            response = new Response
            {
                Id = request.Id,
                Value = result is null ? null : JsonSerializer.SerializeToElement(result, JsonSettings.SingleLineSerializerOptions)
            };
        }
        catch (Exception e)
        {
            // For a TargetInvocationException, report the inner exception when there is one rather than the wrapper itself.
            var reportedException = e is TargetInvocationException { InnerException: { } innerException } ? innerException : e;

            response = new Response
            {
                Id = request.Id,
                ExceptionMessage = $"An exception of type {reportedException.GetType()} was thrown: {reportedException.Message}",
                ExceptionStackTrace = reportedException.StackTrace
            };
        }

        var serializedResponse = JsonSerializer.Serialize(response, JsonSettings.SingleLineSerializerOptions);

#if DEBUG
        // The response must stay on one line; a newline inside it would leave the receiving side unable to parse it.
        Contract.ThrowIfTrue(serializedResponse.Contains("\r") || serializedResponse.Contains("\n"));
#endif
        // Hold the semaphore for the write and the flush together so that concurrent responses cannot interleave.
        using (await _sendingStreamSemaphore.DisposableWaitAsync().ConfigureAwait(false))
        {
            await _streamWriter.WriteLineAsync(serializedResponse).ConfigureAwait(false);
            await _streamWriter.FlushAsync().ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Requests that the server stop by cancelling the token that <see cref="RunAsync"/> reads with.
    /// </summary>
    public void Shutdown() => _shutdownTokenSource.Cancel();
}