File: Plugins\MessageDispatcher.cs
Web Access
Project: src\src\nuget-client\src\NuGet.Core\NuGet.Protocol\NuGet.Protocol.csproj (NuGet.Protocol)
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

#nullable disable

using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.Threading;
using System.Threading.Tasks;

namespace NuGet.Protocol.Plugins
{
    /// <summary>
    /// A message dispatcher that maintains state for outstanding requests
    /// and routes messages to configured request handlers.
    /// </summary>
    /// <summary>
    /// A message dispatcher that maintains state for outstanding requests
    /// and routes messages to configured request handlers.
    /// </summary>
    public sealed class MessageDispatcher : IMessageDispatcher, IResponseHandler
    {
        private IConnection _connection;
        private readonly IIdGenerator _idGenerator;
        private bool _isClosed;
        private bool _isDisposed;
        // Requests received from the remote side for which a response has not yet been dispatched, keyed by request ID.
        private readonly ConcurrentDictionary<string, InboundRequestContext> _inboundRequestContexts;
        private readonly IPluginLogger _logger;
        // Requests sent by this dispatcher, keyed by request ID. Entries may be retained after cancellation.
        private readonly ConcurrentDictionary<string, OutboundRequestContext> _outboundRequestContexts;
        private readonly InboundRequestProcessingHandler _inboundRequestProcessingContext;

        /// <summary>
        /// Gets the request handlers for use by the dispatcher.
        /// </summary>
        public IRequestHandlers RequestHandlers { get; }

        /// <summary>
        /// Instantiates a new <see cref="MessageDispatcher" /> class.
        /// </summary>
        /// <param name="requestHandlers">Request handlers.</param>
        /// <param name="idGenerator">A unique identifier generator.</param>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="requestHandlers" />
        /// is <see langword="null" />.</exception>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="idGenerator" />
        /// is <see langword="null" />.</exception>
        public MessageDispatcher(IRequestHandlers requestHandlers, IIdGenerator idGenerator)
            : this(requestHandlers, idGenerator, new InboundRequestProcessingHandler(), PluginLogger.DefaultInstance)
        {
        }

        /// <summary>
        /// Instantiates a new <see cref="MessageDispatcher" /> class.
        /// </summary>
        /// <param name="requestHandlers">Request handlers.</param>
        /// <param name="idGenerator">A unique identifier generator.</param>
        /// <param name="inboundRequestProcessingHandler">The handler used to process inbound requests.
        /// It is disposed when this dispatcher is disposed.</param>
        /// <param name="logger">A plugin logger. It is not disposed by this dispatcher.</param>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="requestHandlers" />
        /// is <see langword="null" />.</exception>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="idGenerator" />
        /// is <see langword="null" />.</exception>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="inboundRequestProcessingHandler" />
        /// is <see langword="null" />.</exception>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="logger" />
        /// is <see langword="null" />.</exception>
        internal MessageDispatcher(IRequestHandlers requestHandlers, IIdGenerator idGenerator, InboundRequestProcessingHandler inboundRequestProcessingHandler, IPluginLogger logger)
        {
            if (requestHandlers == null)
            {
                throw new ArgumentNullException(nameof(requestHandlers));
            }

            if (idGenerator == null)
            {
                throw new ArgumentNullException(nameof(idGenerator));
            }

            if (inboundRequestProcessingHandler == null)
            {
                throw new ArgumentNullException(nameof(inboundRequestProcessingHandler));
            }

            if (logger == null)
            {
                throw new ArgumentNullException(nameof(logger));
            }

            RequestHandlers = requestHandlers;
            _idGenerator = idGenerator;
            _logger = logger;

            _inboundRequestContexts = new ConcurrentDictionary<string, InboundRequestContext>();
            _outboundRequestContexts = new ConcurrentDictionary<string, OutboundRequestContext>();
            _inboundRequestProcessingContext = inboundRequestProcessingHandler;
        }

        /// <summary>
        /// Disposes of this instance.
        /// </summary>
        /// <remarks>Closes the dispatcher, disposes the inbound request processing handler and
        /// detaches from the current connection. Subsequent calls are no-ops.</remarks>
        public void Dispose()
        {
            if (_isDisposed)
            {
                return;
            }

            Close();
            _inboundRequestProcessingContext.Dispose();
            SetConnection(connection: null);

            // Do not dispose of _logger.  This message dispatcher does not own it.

            GC.SuppressFinalize(this);

            _isDisposed = true;
        }

        /// <summary>
        /// Closes the dispatcher.
        /// </summary>
        /// <remarks>This does not call <see cref="IDisposable.Dispose" />.
        /// It detaches from the current connection and disposes every tracked
        /// inbound and outbound request context. Subsequent calls are no-ops.</remarks>
        public void Close()
        {
            if (!_isClosed)
            {
                SetConnection(connection: null);

                foreach (var entry in _inboundRequestContexts)
                {
                    entry.Value.Dispose();
                }

                foreach (var entry in _outboundRequestContexts)
                {
                    entry.Value.Dispose();
                }

                _isClosed = true;
            }
        }

        /// <summary>
        /// Creates a message.
        /// </summary>
        /// <param name="type">The message type.</param>
        /// <param name="method">The message method.</param>
        /// <returns>A message with a newly generated request ID.</returns>
        public Message CreateMessage(MessageType type, MessageMethod method)
        {
            var requestId = _idGenerator.GenerateUniqueId();

            return MessageUtilities.Create(requestId, type, method);
        }

        /// <summary>
        /// Creates a message.
        /// </summary>
        /// <typeparam name="TPayload">The message payload.</typeparam>
        /// <param name="type">The message type.</param>
        /// <param name="method">The message method.</param>
        /// <param name="payload">The message payload.</param>
        /// <returns>A message with a newly generated request ID.</returns>
        /// <exception cref="ArgumentNullException">Throws if <paramref name="payload" /> is <see langword="null" />.</exception>
        public Message CreateMessage<TPayload>(MessageType type, MessageMethod method, TPayload payload)
            where TPayload : class
        {
            if (payload == null)
            {
                throw new ArgumentNullException(nameof(payload));
            }

            var requestId = _idGenerator.GenerateUniqueId();

            return MessageUtilities.Create(requestId, type, method, payload);
        }

        /// <summary>
        /// Asynchronously dispatches a cancellation request for the specified request.
        /// </summary>
        /// <param name="request">The request.</param>
        /// <param name="cancellationToken">A cancellation token.</param>
        /// <returns>A task that represents the asynchronous operation.
        /// If no connection is currently set, nothing is sent.</returns>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="request" /> is <see langword="null" />.</exception>
        /// <exception cref="OperationCanceledException">Thrown if <paramref name="cancellationToken" />
        /// is cancelled.</exception>
        public Task DispatchCancelAsync(Message request, CancellationToken cancellationToken)
        {
            if (request == null)
            {
                throw new ArgumentNullException(nameof(request));
            }

            cancellationToken.ThrowIfCancellationRequested();

            // Capture _connection as SetConnection(...) could null it out later.
            var connection = _connection;

            if (connection == null)
            {
                return Task.CompletedTask;
            }

            return DispatchCancelAsync(connection, request, cancellationToken);
        }

        /// <summary>
        /// Asynchronously dispatches a fault notification for the specified request.
        /// </summary>
        /// <param name="request">The request that faulted. May be <see langword="null" />, in which case
        /// the fault is sent with a newly generated request ID and <see cref="MessageMethod.None" />.</param>
        /// <param name="fault">The fault payload.</param>
        /// <param name="cancellationToken">A cancellation token.</param>
        /// <returns>A task that represents the asynchronous operation.
        /// If no connection is currently set, nothing is sent.</returns>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="fault" /> is <see langword="null" />.</exception>
        /// <exception cref="OperationCanceledException">Thrown if <paramref name="cancellationToken" />
        /// is cancelled.</exception>
        public Task DispatchFaultAsync(Message request, Fault fault, CancellationToken cancellationToken)
        {
            if (fault == null)
            {
                throw new ArgumentNullException(nameof(fault));
            }

            cancellationToken.ThrowIfCancellationRequested();

            // Capture _connection as SetConnection(...) could null it out later.
            var connection = _connection;

            if (connection == null)
            {
                return Task.CompletedTask;
            }

            return DispatchFaultAsync(connection, request, fault, cancellationToken);
        }

        /// <summary>
        /// Asynchronously dispatches a progress notification for the specified request.
        /// </summary>
        /// <param name="request">The request.</param>
        /// <param name="progress">The progress payload.</param>
        /// <param name="cancellationToken">A cancellation token.</param>
        /// <returns>A task that represents the asynchronous operation.
        /// If no connection is currently set, nothing is sent.</returns>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="request" /> is <see langword="null" />.</exception>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="progress" /> is <see langword="null" />.</exception>
        /// <exception cref="OperationCanceledException">Thrown if <paramref name="cancellationToken" />
        /// is cancelled.</exception>
        public Task DispatchProgressAsync(Message request, Progress progress, CancellationToken cancellationToken)
        {
            if (request == null)
            {
                throw new ArgumentNullException(nameof(request));
            }

            if (progress == null)
            {
                throw new ArgumentNullException(nameof(progress));
            }

            cancellationToken.ThrowIfCancellationRequested();

            // Capture _connection as SetConnection(...) could null it out later.
            var connection = _connection;

            if (connection == null)
            {
                return Task.CompletedTask;
            }

            return DispatchProgressAsync(connection, request, progress, cancellationToken);
        }

        /// <summary>
        /// Asynchronously dispatches a request.
        /// </summary>
        /// <typeparam name="TOutbound">The request payload type.</typeparam>
        /// <typeparam name="TInbound">The expected response payload type.</typeparam>
        /// <param name="method">The request method.</param>
        /// <param name="payload">The request payload.</param>
        /// <param name="cancellationToken">A cancellation token.</param>
        /// <returns>A task that represents the asynchronous operation.
        /// The task result (<see cref="Task{TResult}.Result" />) returns a <typeparamref name="TInbound" />
        /// from the target. If no connection is currently set, the request is not sent.</returns>
        /// <exception cref="OperationCanceledException">Thrown if <paramref name="cancellationToken" />
        /// is cancelled.</exception>
        public Task<TInbound> DispatchRequestAsync<TOutbound, TInbound>(
            MessageMethod method,
            TOutbound payload,
            CancellationToken cancellationToken)
            where TOutbound : class
            where TInbound : class
        {
            cancellationToken.ThrowIfCancellationRequested();

            // Capture _connection as SetConnection(...) could null it out later.
            var connection = _connection;

            if (connection == null)
            {
                return TaskResult.Null<TInbound>();
            }

            return DispatchWithNewContextAsync<TOutbound, TInbound>(
                connection,
                MessageType.Request,
                method,
                payload,
                cancellationToken);
        }

        /// <summary>
        /// Asynchronously dispatches a response.
        /// </summary>
        /// <typeparam name="TOutbound">The request payload type.</typeparam>
        /// <param name="request">The associated request.</param>
        /// <param name="responsePayload">The response payload.</param>
        /// <param name="cancellationToken">A cancellation token.</param>
        /// <returns>A task that represents the asynchronous operation.
        /// If no connection is currently set, nothing is sent.</returns>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="request" /> is <see langword="null" />.</exception>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="responsePayload" /> is <see langword="null" />.</exception>
        /// <exception cref="OperationCanceledException">Thrown if <paramref name="cancellationToken" />
        /// is cancelled.</exception>
        public Task DispatchResponseAsync<TOutbound>(
            Message request,
            TOutbound responsePayload,
            CancellationToken cancellationToken)
            where TOutbound : class
        {
            if (request == null)
            {
                throw new ArgumentNullException(nameof(request));
            }

            if (responsePayload == null)
            {
                throw new ArgumentNullException(nameof(responsePayload));
            }

            cancellationToken.ThrowIfCancellationRequested();

            // Capture _connection as SetConnection(...) could null it out later.
            var connection = _connection;

            if (connection == null)
            {
                return Task.CompletedTask;
            }

            return DispatchAsync(connection, MessageType.Response, request, responsePayload, cancellationToken);
        }

        /// <summary>
        /// Sets the connection to be used for dispatching messages.
        /// </summary>
        /// <param name="connection">A connection instance.  Can be <see langword="null" />.</param>
        /// <remarks>Unsubscribes from the previous connection's <c>MessageReceived</c> event
        /// and subscribes to the new connection's event.</remarks>
        public void SetConnection(IConnection connection)
        {
            if (_connection == connection)
            {
                return;
            }

            if (_connection != null)
            {
                _connection.MessageReceived -= OnMessageReceived;
            }

            _connection = connection;

            if (_connection != null)
            {
                _connection.MessageReceived += OnMessageReceived;
            }
        }

        Task IResponseHandler.SendResponseAsync<TPayload>(
            Message request,
            TPayload payload,
            CancellationToken cancellationToken)
        {
            return DispatchResponseAsync(request, payload, cancellationToken);
        }

        /// <summary>
        /// Sends a message answering a tracked inbound request, then stops tracking that request.
        /// Does nothing if no inbound context exists for the request ID.
        /// </summary>
        private async Task DispatchAsync<TOutgoing>(
            IConnection connection,
            MessageType type,
            Message request,
            TOutgoing payload,
            CancellationToken cancellationToken)
            where TOutgoing : class
        {
            if (!_inboundRequestContexts.ContainsKey(request.RequestId))
            {
                return;
            }

            var message = MessageUtilities.Create(request.RequestId, type, request.Method, payload);

            try
            {
                await connection.SendAsync(message, cancellationToken);
            }
            finally
            {
                // The inbound request is complete whether or not the send succeeded.
                RemoveInboundRequestContext(request.RequestId);
            }
        }

        /// <summary>
        /// Sends a cancel message that reuses the request's ID and method.
        /// </summary>
        private async Task DispatchCancelAsync(
            IConnection connection,
            Message request,
            CancellationToken cancellationToken)
        {
            var message = MessageUtilities.Create(request.RequestId, MessageType.Cancel, request.Method);

            await DispatchWithExistingContextAsync(connection, message, cancellationToken);
        }

        /// <summary>
        /// Sends a fault message. With no request, the fault gets a fresh ID and is sent directly.
        /// Otherwise it reuses the request's ID and method and requires an existing outbound context.
        /// </summary>
        private async Task DispatchFaultAsync(
            IConnection connection,
            Message request,
            Fault fault,
            CancellationToken cancellationToken)
        {
            Message message;

            if (request == null)
            {
                var requestId = _idGenerator.GenerateUniqueId();

                message = new Message(requestId, MessageType.Fault, MessageMethod.None, fault);

                await connection.SendAsync(message, cancellationToken);
            }
            else
            {
                message = new Message(request.RequestId, MessageType.Fault, request.Method, fault);

                await DispatchWithExistingContextAsync(connection, message, cancellationToken);
            }
        }

        /// <summary>
        /// Sends a progress message that reuses the request's ID and method.
        /// </summary>
        private async Task DispatchProgressAsync(
            IConnection connection,
            Message request,
            Progress progress,
            CancellationToken cancellationToken)
        {
            var message = MessageUtilities.Create(request.RequestId, MessageType.Progress, request.Method, progress);

            await DispatchWithExistingContextAsync(connection, message, cancellationToken);
        }

        /// <summary>
        /// Sends a message for a request that must already have an outbound request context.
        /// </summary>
        /// <exception cref="ProtocolException">Thrown if no outbound request context exists
        /// for the message's request ID.</exception>
        private async Task DispatchWithExistingContextAsync(
            IConnection connection,
            Message response,
            CancellationToken cancellationToken)
        {
            // Called only to validate that the context exists; the context itself is not needed.
            _ = GetOutboundRequestContext(response.RequestId);

            await connection.SendAsync(response, cancellationToken);
        }

        /// <summary>
        /// Creates a message with a new request ID and tracks it in a new outbound request context.
        /// For request, response and fault types it sends the message and awaits the context's completion.
        /// For any other type it returns null without sending.
        /// </summary>
        private async Task<TIncoming> DispatchWithNewContextAsync<TOutgoing, TIncoming>(
            IConnection connection,
            MessageType type,
            MessageMethod method,
            TOutgoing payload,
            CancellationToken cancellationToken)
            where TOutgoing : class
            where TIncoming : class
        {
            var message = CreateMessage(type, method, payload);
            var timeout = GetRequestTimeout(connection, type, method);
            var isKeepAlive = GetIsKeepAlive(type, method);

            // Use the captured connection rather than re-reading _connection, which SetConnection(...)
            // may have changed or nulled out since the caller captured it.
            var requestContext = CreateOutboundRequestContext<TIncoming>(
                connection,
                message,
                timeout,
                isKeepAlive,
                cancellationToken);

            _outboundRequestContexts.TryAdd(message.RequestId, requestContext);

            switch (type)
            {
                case MessageType.Request:
                case MessageType.Response:
                case MessageType.Fault:
                    var removeRequestContext = true;

                    try
                    {
                        await connection.SendAsync(message, requestContext.CancellationToken);

                        return await requestContext.CompletionTask;
                    }
                    catch (OperationCanceledException) when (requestContext.CancellationToken.IsCancellationRequested)
                    {
                        if (_logger.IsEnabled)
                        {
                            _logger.Write(new CommunicationLogMessage(_logger.Now, message.RequestId, message.Method, message.Type, MessageState.Cancelled));
                        }

                        // Keep the request context around if cancellation was requested.
                        // A race condition exists where after sending a cancellation request,
                        // we could receive a response (which was in flight) or a cancellation
                        // response.
                        // If a normal response (success/failure) and not a cancellation response
                        // is received after a cancellation request, we need to have an active
                        // request context to avoid a protocol exception.
                        removeRequestContext = false;

                        throw;
                    }
                    finally
                    {
                        if (removeRequestContext)
                        {
                            RemoveOutboundRequestContext(message.RequestId);
                        }
                    }

                default:
                    break;
            }

            return null;
        }

        /// <summary>
        /// Routes a received message. If an outbound context exists for its request ID, the message is
        /// forwarded to that context. Otherwise cancel, request and fault messages are handled as inbound.
        /// </summary>
        /// <exception cref="ProtocolException">Thrown for a message type that is invalid in its context,
        /// or (via <see cref="HandleInboundFault" />) when an inbound fault is received.</exception>
        private void OnMessageReceived(object sender, MessageEventArgs e)
        {
            // Capture _connection as SetConnection(...) could null it out later.
            var connection = _connection;

            if (connection == null)
            {
                return;
            }

            OutboundRequestContext requestContext;

            if (_outboundRequestContexts.TryGetValue(e.Message.RequestId, out requestContext))
            {
                switch (e.Message.Type)
                {
                    case MessageType.Response:
                        requestContext.HandleResponse(e.Message);
                        break;

                    case MessageType.Progress:
                        requestContext.HandleProgress(e.Message);
                        break;

                    case MessageType.Fault:
                        requestContext.HandleFault(e.Message);
                        break;

                    case MessageType.Cancel:
                        requestContext.HandleCancelResponse();
                        break;

                    default:
                        throw new ProtocolException(
                            string.Format(
                                CultureInfo.CurrentCulture,
                                Strings.Plugin_InvalidMessageType,
                                e.Message.Type));
                }

                return;
            }

            switch (e.Message.Type)
            {
                case MessageType.Cancel:
                    HandleInboundCancel(e.Message);
                    break;

                case MessageType.Request:
                    HandleInboundRequest(connection, e.Message);
                    break;

                case MessageType.Fault:
                    HandleInboundFault(e.Message);
                    break;

                default:
                    throw new ProtocolException(
                        string.Format(
                            CultureInfo.CurrentCulture,
                            Strings.Plugin_InvalidMessageType,
                            e.Message.Type));
            }
        }

        /// <summary>
        /// Cancels the tracked inbound request with the message's request ID, if one exists.
        /// </summary>
        private void HandleInboundCancel(Message message)
        {
            InboundRequestContext requestContext;

            if (_inboundRequestContexts.TryGetValue(message.RequestId, out requestContext))
            {
                requestContext.Cancel();
            }
        }

        /// <summary>
        /// Deserializes an inbound fault payload and surfaces it as a <see cref="ProtocolException" />.
        /// </summary>
        private void HandleInboundFault(Message fault)
        {
            if (fault == null)
            {
                throw new ArgumentNullException(nameof(fault));
            }

            var payload = MessageUtilities.DeserializePayload<Fault>(fault);

            throw new ProtocolException(payload.Message);
        }

        /// <summary>
        /// Starts handling an inbound request. When a handler is registered for the method, the request
        /// is tracked and a response is begun. Otherwise a fault is begun using the lookup exception.
        /// </summary>
        /// <param name="connection">The connection captured by the caller; used for the request context.</param>
        /// <param name="message">The inbound request message.</param>
        private void HandleInboundRequest(IConnection connection, Message message)
        {
            var cancellationToken = CancellationToken.None;
            IRequestHandler requestHandler = null;
            ProtocolException exception = null;

            try
            {
                requestHandler = GetInboundRequestHandler(message.Method);
                cancellationToken = requestHandler.CancellationToken;
            }
            catch (ProtocolException ex)
            {
                exception = ex;
            }

            var requestContext = CreateInboundRequestContext(connection, message, cancellationToken);

            if (exception == null && requestHandler != null)
            {
                _inboundRequestContexts.TryAdd(message.RequestId, requestContext);

                requestContext.BeginResponseAsync(message, requestHandler, this);
            }
            else
            {
                requestContext.BeginFaultAsync(message, exception);
            }
        }

        /// <summary>
        /// Looks up the request handler registered for <paramref name="method" />.
        /// </summary>
        /// <exception cref="ProtocolException">Thrown if no handler is registered.</exception>
        private IRequestHandler GetInboundRequestHandler(MessageMethod method)
        {
            IRequestHandler handler;

            if (!RequestHandlers.TryGet(method, out handler))
            {
                throw new ProtocolException(
                    string.Format(CultureInfo.CurrentCulture, Strings.Plugin_RequestHandlerDoesNotExist, method));
            }

            return handler;
        }

        /// <summary>
        /// Gets the tracked outbound request context for <paramref name="requestId" />.
        /// </summary>
        /// <exception cref="ProtocolException">Thrown if no such context exists.</exception>
        private OutboundRequestContext GetOutboundRequestContext(string requestId)
        {
            OutboundRequestContext requestContext;

            if (!_outboundRequestContexts.TryGetValue(requestId, out requestContext))
            {
                throw new ProtocolException(
                    string.Format(CultureInfo.CurrentCulture, Strings.Plugin_RequestContextDoesNotExist, requestId));
            }

            return requestContext;
        }

        /// <summary>
        /// Stops tracking the inbound request context for <paramref name="requestId" /> and disposes it, if present.
        /// </summary>
        private void RemoveInboundRequestContext(string requestId)
        {
            InboundRequestContext requestContext;

            if (_inboundRequestContexts.TryRemove(requestId, out requestContext))
            {
                requestContext.Dispose();
            }
        }

        /// <summary>
        /// Stops tracking the outbound request context for <paramref name="requestId" /> and disposes it, if present.
        /// </summary>
        private void RemoveOutboundRequestContext(string requestId)
        {
            OutboundRequestContext requestContext;

            if (_outboundRequestContexts.TryRemove(requestId, out requestContext))
            {
                requestContext.Dispose();
            }
        }

        /// <summary>
        /// Creates an inbound request context bound to the given (caller-captured) connection.
        /// </summary>
        private InboundRequestContext CreateInboundRequestContext(
            IConnection connection,
            Message message,
            CancellationToken cancellationToken)
        {
            return new InboundRequestContext(
                connection,
                message.RequestId,
                cancellationToken,
                _inboundRequestProcessingContext,
                _logger);
        }

        /// <summary>
        /// Creates an outbound request context bound to the given (caller-captured) connection.
        /// </summary>
        private OutboundRequestContext<TIncoming> CreateOutboundRequestContext<TIncoming>(
            IConnection connection,
            Message message,
            TimeSpan? timeout,
            bool isKeepAlive,
            CancellationToken cancellationToken)
        {
            return new OutboundRequestContext<TIncoming>(
                connection,
                message,
                timeout,
                isKeepAlive,
                cancellationToken,
                _logger);
        }

        /// <summary>
        /// Returns <see langword="false" /> for handshake requests and <see langword="true" /> for everything else.
        /// </summary>
        private static bool GetIsKeepAlive(MessageType type, MessageMethod method)
        {
            if (type == MessageType.Request && method == MessageMethod.Handshake)
            {
                return false;
            }

            return true;
        }

        /// <summary>
        /// Returns the connection's handshake timeout for handshake requests and its request timeout otherwise.
        /// </summary>
        private static TimeSpan GetRequestTimeout(IConnection connection, MessageType type, MessageMethod method)
        {
            if (type == MessageType.Request && method == MessageMethod.Handshake)
            {
                return connection.Options.HandshakeTimeout;
            }

            return connection.Options.RequestTimeout;
        }
    }
}