// Source-browser header (not part of the upstream file):
// File: ReversedServer\ReversedDiagnosticsServer.cs
// Project: src\src\diagnostics\src\Microsoft.Diagnostics.NETCore.Client\Microsoft.Diagnostics.NETCore.Client.csproj (Microsoft.Diagnostics.NETCore.Client)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.IO;
using System.IO.Pipes;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Diagnostics.NETCore.Client
{
    /// <summary>
    /// Establishes server endpoint for runtime instances to connect when
    /// configured to provide diagnostic endpoints in reverse mode.
    /// </summary>
    internal sealed class ReversedDiagnosticsServer : IAsyncDisposable
    {
        // Parse timeout used by the constructors that do not take an explicit timeout.
        private static readonly TimeSpan DefaultParseAdvertiseTimeout = TimeSpan.FromMilliseconds(250);

        // The amount of time to allow parsing of the advertise data before cancelling. This allows the server to
        // remain responsive in case the advertise data is incomplete and the stream is not closed.
        private readonly TimeSpan _parseAdvertiseTimeout;

        private readonly CancellationTokenSource _disposalSource = new();
        private readonly HandleableCollection<IpcEndpointInfo> _endpointInfos = new();
        private readonly ConcurrentDictionary<Guid, HandleableCollection<Stream>> _streamCollections = new();
        private readonly string _address;
        private readonly Kind _kind;

        private bool _disposed;
        private Task _acceptTransportTask;
        private IpcServerTransport _transport;

        /// <summary>
        /// The transport used by the server to accept runtime connections.
        /// </summary>
        public enum Kind
        {
            Tcp,
            Ipc,
            WebSocket,
        }

        /// <summary>
        /// Constructs the <see cref="ReversedDiagnosticsServer"/> instance with an endpoint bound
        /// to the location specified by <paramref name="address"/>, using <see cref="Kind.Ipc"/>.
        /// </summary>
        /// <param name="address">
        /// The server endpoint.
        /// On Windows, this can be a full pipe path or the name without the "\\.\pipe\" prefix.
        /// On all other systems, this must be the full file path of the socket.
        /// </param>
        public ReversedDiagnosticsServer(string address)
            : this(address, Kind.Ipc)
        {
        }

        /// <summary>
        /// Constructs the <see cref="ReversedDiagnosticsServer"/> instance with an endpoint bound
        /// to the location specified by <paramref name="address"/>.
        /// </summary>
        /// <param name="address">
        /// The server endpoint.
        /// On Windows, this can be a full pipe path or the name without the "\\.\pipe\" prefix.
        /// On all other systems, this must be the full file path of the socket.
        /// When <paramref name="kind"/> is <see cref="Kind.Tcp"/>, this can also be host:port of the listening socket.
        /// </param>
        /// <param name="kind">
        /// The transport to use. If kind is <see cref="Kind.WebSocket"/>, a Kestrel web server is started.
        /// If kind is <see cref="Kind.Tcp"/>, <paramref name="address"/> is analyzed; if it has the format host:port,
        /// the server tries to bind a TCP/IP listener to that host and port, otherwise it uses a Unix domain socket
        /// or a Windows named pipe.
        /// </param>
        public ReversedDiagnosticsServer(string address, Kind kind)
            : this(address, kind, DefaultParseAdvertiseTimeout)
        {
        }

        /// <summary>
        /// Constructs the <see cref="ReversedDiagnosticsServer"/> instance with an endpoint bound
        /// to the location specified by <paramref name="address"/>.
        /// </summary>
        /// <param name="address">
        /// The server endpoint.
        /// On Windows, this can be a full pipe path or the name without the "\\.\pipe\" prefix.
        /// On all other systems, this must be the full file path of the socket.
        /// When <paramref name="kind"/> is <see cref="Kind.Tcp"/>, this can also be host:port of the listening socket.
        /// </param>
        /// <param name="kind">
        /// The transport to use. If kind is <see cref="Kind.WebSocket"/>, a Kestrel web server is started.
        /// If kind is <see cref="Kind.Tcp"/>, <paramref name="address"/> is analyzed; if it has the format host:port,
        /// the server tries to bind a TCP/IP listener to that host and port, otherwise it uses a Unix domain socket
        /// or a Windows named pipe.
        /// </param>
        /// <param name="timeout">
        /// The amount of time to allow parsing of the advertise data before cancelling. This allows the server to
        /// remain responsive in case the advertise data is incomplete and the stream is not closed.
        /// </param>
        public ReversedDiagnosticsServer(string address, Kind kind, TimeSpan timeout)
        {
            _address = address;
            _kind = kind;
            _parseAdvertiseTimeout = timeout;
        }

        /// <summary>
        /// Stops listening, waits for the accept loop to finish, and disposes all tracked endpoint
        /// information and runtime streams. Calling this more than once after it has completed has no effect.
        /// </summary>
        public async ValueTask DisposeAsync()
        {
            if (!_disposed)
            {
                // Dispose the server transport before signaling cancellation in order to prevent the
                // AcceptAsync call on the server transport from recreating the server stream.
                try
                {
                    _transport?.Dispose();
                }
                catch (Exception ex)
                {
                    Debug.Fail(ex.Message);
                }

                _disposalSource.Cancel();

                if (null != _acceptTransportTask)
                {
                    try
                    {
                        await _acceptTransportTask.ConfigureAwait(false);
                    }
                    catch (Exception ex)
                    {
                        Debug.Fail(ex.Message);
                    }
                }

                _endpointInfos.Dispose();

                foreach (HandleableCollection<Stream> streamCollection in _streamCollections.Values)
                {
                    streamCollection.Dispose();
                }

                _streamCollections.Clear();

                _disposalSource.Dispose();

                _disposed = true;
            }
        }

        /// <summary>
        /// Starts listening at the address for new connections, allowing <see cref="MaxAllowedConnections"/> connections.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The server has been disposed.</exception>
        /// <exception cref="InvalidOperationException">The server has already been started.</exception>
        public void Start()
        {
            Start(MaxAllowedConnections);
        }

        /// <summary>
        /// Starts listening at the address for new connections.
        /// </summary>
        /// <param name="maxConnections">The maximum number of connections the server will support.</param>
        /// <exception cref="ObjectDisposedException">The server has been disposed.</exception>
        /// <exception cref="InvalidOperationException">The server has already been started.</exception>
        /// <exception cref="AggregateException">The accept loop faulted synchronously while starting.</exception>
        public void Start(int maxConnections)
        {
            VerifyNotDisposed();

            if (IsStarted)
            {
                throw new InvalidOperationException(nameof(ReversedDiagnosticsServer.Start) + " method can only be called once.");
            }

            _transport = IpcServerTransport.Create(_address, maxConnections, _kind, TransportCallback);

            _acceptTransportTask = AcceptTransportAsync(_transport, _disposalSource.Token);

            // Surface failures that happened before the accept loop's first asynchronous wait.
            if (_acceptTransportTask.IsFaulted)
            {
                _acceptTransportTask.Wait(); // Rethrow aggregated exception.
            }
        }

        /// <summary>
        /// Gets endpoint information when a new runtime instance connects to the server.
        /// </summary>
        /// <param name="timeout">The amount of time to wait before cancelling the accept operation.</param>
        /// <returns>An <see cref="IpcEndpointInfo"/> value that contains information about the new runtime instance connection.</returns>
        /// <exception cref="ObjectDisposedException">The server has been disposed.</exception>
        /// <exception cref="InvalidOperationException">The server has not been started.</exception>
        public IpcEndpointInfo Accept(TimeSpan timeout)
        {
            VerifyNotDisposed();
            VerifyIsStarted();

            return _endpointInfos.Handle(timeout);
        }

        /// <summary>
        /// Gets endpoint information when a new runtime instance connects to the server.
        /// </summary>
        /// <param name="token">The token to monitor for cancellation requests.</param>
        /// <returns>A task that completes with a <see cref="IpcEndpointInfo"/> value that contains information about the new runtime instance connection.</returns>
        /// <exception cref="ObjectDisposedException">The server has been disposed.</exception>
        /// <exception cref="InvalidOperationException">The server has not been started.</exception>
        public Task<IpcEndpointInfo> AcceptAsync(CancellationToken token)
        {
            VerifyNotDisposed();
            VerifyIsStarted();

            return _endpointInfos.HandleAsync(token);
        }

        /// <summary>
        /// Removes the streams tracked for a runtime instance so that its connection is no longer tracked.
        /// </summary>
        /// <param name="runtimeCookie">The runtime instance cookie that corresponds to the endpoint to be removed.</param>
        /// <returns>True if the endpoint existed and was removed; otherwise false.</returns>
        /// <exception cref="ObjectDisposedException">The server has been disposed.</exception>
        /// <exception cref="InvalidOperationException">The server has not been started.</exception>
        public bool RemoveConnection(Guid runtimeCookie)
        {
            VerifyNotDisposed();
            VerifyIsStarted();

            if (_streamCollections.TryRemove(runtimeCookie, out HandleableCollection<Stream> streamCollection))
            {
                streamCollection.Dispose();
                return true;
            }
            return false;
        }

        private void VerifyNotDisposed()
        {
#pragma warning disable CA1513 // Use ObjectDisposedException throw helper
            if (_disposed)
            {
                throw new ObjectDisposedException(nameof(ReversedDiagnosticsServer));
            }
#pragma warning restore CA1513 // Use ObjectDisposedException throw helper
        }

        private void VerifyIsStarted()
        {
            if (!IsStarted)
            {
                throw new InvalidOperationException(nameof(ReversedDiagnosticsServer.Start) + " method must be called before invoking this operation.");
            }
        }

        /// <summary>
        /// Accept connections from the transport.
        /// </summary>
        /// <param name="transport">The server transport from which connections are accepted.</param>
        /// <param name="token">The token to monitor for cancellation requests.</param>
        /// <returns>A task that completes when the server is no longer listening at the address.</returns>
        private async Task AcceptTransportAsync(IpcServerTransport transport, CancellationToken token)
        {
            while (!token.IsCancellationRequested)
            {
                Stream stream = null;
                IpcAdvertise advertise = null;
                try
                {
                    stream = await transport.AcceptAsync(token).ConfigureAwait(false);
                }
                catch (OperationCanceledException)
                {
                }
                catch (Exception)
                {
                    // Accepting a connection failed; ignore the failure and continue waiting for a new connection.
                }

                if (null != stream)
                {
                    // Cancel parsing of advertise data after timeout period to
                    // mitigate runtimes that write partial data and do not close the stream (avoid waiting forever).
                    using CancellationTokenSource parseCancellationSource = new();
                    using CancellationTokenSource linkedSource = CancellationTokenSource.CreateLinkedTokenSource(token, parseCancellationSource.Token);
                    try
                    {
                        parseCancellationSource.CancelAfter(_parseAdvertiseTimeout);

                        advertise = await IpcAdvertise.ParseAsync(stream, linkedSource.Token).ConfigureAwait(false);
                    }
                    catch (Exception)
                    {
                        // The advertise data could be incomplete if the runtime shuts down before completely writing
                        // the information, or parsing may have timed out or been cancelled. Drop the connection and
                        // continue waiting for a new one.
                        stream.Dispose();
                    }
                }

                if (null != advertise)
                {
                    Guid runtimeCookie = advertise.RuntimeInstanceCookie;
                    int pid = unchecked((int)advertise.ProcessId);

                    // The valueFactory parameter of the GetOrAdd overload that uses Func<TKey, TValue> valueFactory
                    // does not execute the factory under a lock thus it is not thread-safe. Create the collection and
                    // use a thread-safe version of GetOrAdd; use reference equality on the result to determine if
                    // the new collection was added to the dictionary or if an existing one was returned.
                    HandleableCollection<Stream> newStreamCollection = new();
                    HandleableCollection<Stream> streamCollection = _streamCollections.GetOrAdd(runtimeCookie, newStreamCollection);
                    bool isNewCollection = ReferenceEquals(newStreamCollection, streamCollection);

                    if (!isNewCollection)
                    {
                        // The speculative collection was never published; dispose it now so that it is not leaked
                        // if clearing/adding to the existing collection throws below.
                        newStreamCollection.Dispose();
                    }

                    try
                    {
                        streamCollection.ClearItems();
                        streamCollection.Add(stream);

                        if (isNewCollection)
                        {
                            ServerIpcEndpoint endpoint = new(this, runtimeCookie);
                            _endpointInfos.Add(new IpcEndpointInfo(endpoint, pid, runtimeCookie));
                        }
                    }
                    catch (ObjectDisposedException)
                    {
                        // The stream collection could be disposed by RemoveConnection which would cause an
                        // ObjectDisposedException to be thrown if trying to clear/add the stream.
                        stream.Dispose();
                    }
                }
            }
        }

        // Returns the stream collection for the runtime instance, creating an empty one if none is tracked yet.
        private HandleableCollection<Stream> GetStreams(Guid runtimeCookie)
        {
            return _streamCollections.GetOrAdd(runtimeCookie, _ => new HandleableCollection<Stream>());
        }

        internal Stream Connect(Guid runtimeInstanceCookie, TimeSpan timeout)
        {
            VerifyNotDisposed();
            VerifyIsStarted();

            return GetStreams(runtimeInstanceCookie).Handle(timeout);
        }

        internal Task<Stream> ConnectAsync(Guid runtimeInstanceCookie, CancellationToken token)
        {
            VerifyNotDisposed();
            VerifyIsStarted();

            return GetStreams(runtimeInstanceCookie).HandleAsync(token);
        }

        internal void WaitForConnection(Guid runtimeInstanceCookie, TimeSpan timeout)
        {
            VerifyNotDisposed();
            VerifyIsStarted();

            GetStreams(runtimeInstanceCookie).Handle(WaitForConnectionHandler, timeout);
        }

        internal async Task WaitForConnectionAsync(Guid runtimeInstanceCookie, CancellationToken token)
        {
            VerifyNotDisposed();
            VerifyIsStarted();

            await GetStreams(runtimeInstanceCookie).HandleAsync(WaitForConnectionHandler, token).ConfigureAwait(false);
        }

        // Handler for WaitForConnection(Async): accepts a stream that is still connected and leaves it in the
        // collection. A disconnected stream is disposed and flagged for removal so that waiting continues.
        private static bool WaitForConnectionHandler(Stream item, out bool removeItem)
        {
            if (!TestStream(item))
            {
                item?.Dispose();
                removeItem = true;
                return false;
            }

            removeItem = false;
            return true;
        }

        // Probes whether the underlying connection of the stream is still alive.
        // Unknown stream types are reported as not connected.
        private static bool TestStream(Stream stream)
        {
            if (stream is null)
            {
                throw new ArgumentNullException(nameof(stream));
            }

            if (stream is ExposedSocketNetworkStream networkStream)
            {
                // Update Connected state of socket by sending non-blocking zero-byte data.
                Socket socket = networkStream.Socket;
                bool blocking = socket.Blocking;
                try
                {
                    socket.Blocking = false;
                    socket.Send(Array.Empty<byte>(), 0, SocketFlags.None);
                }
                catch (Exception)
                {
                    // A failed send is expected for a broken connection; Socket.Connected below reflects it.
                }
                finally
                {
                    socket.Blocking = blocking;
                }
                return socket.Connected;
            }
            else if (stream is PipeStream pipeStream)
            {
                Debug.Assert(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "Pipe stream should only be used on Windows.");

                // PeekNamedPipe will return false if the pipe is disconnected/broken.
                return NativeMethods.PeekNamedPipe(
                    pipeStream.SafePipeHandle,
                    null,
                    0,
                    IntPtr.Zero,
                    IntPtr.Zero,
                    IntPtr.Zero);
            }
            else if (stream is WebSocketServer.IWebSocketStreamAdapter adapter)
            {
                return adapter.IsConnected;
            }

            return false;
        }

        private bool IsStarted => null != _transport;

        // NOTE(review): this is a public mutable static field; consider making it readonly once no callers assign it.
        public static int MaxAllowedConnections = IpcServerTransport.MaxAllowedConnections;

        // Callback passed to IpcServerTransport.Create when the server is started; must be set before Start is called to take effect.
        internal IIpcServerTransportCallbackInternal TransportCallback { get; set; }
    }
}