File: SocketServer.cs
Web Access
Project: src\src\vstest\src\Microsoft.TestPlatform.CommunicationUtilities\Microsoft.TestPlatform.CommunicationUtilities.csproj (Microsoft.TestPlatform.CommunicationUtilities)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.Interfaces;
using Microsoft.VisualStudio.TestPlatform.ObjectModel;
using Microsoft.VisualStudio.TestPlatform.Utilities;

namespace Microsoft.VisualStudio.TestPlatform.CommunicationUtilities;

/// <summary>
/// Communication server implementation over sockets.
/// </summary>
[SuppressMessage("Design", "CA1001:Types that own disposable fields should be disposable", Justification = "Would cause a breaking change if users are inheriting this class and implement IDisposable")]
public class SocketServer : ICommunicationEndPoint
{
    private readonly CancellationTokenSource _cancellation;
    private readonly Func<Stream, ICommunicationChannel> _channelFactory;

    private ICommunicationChannel? _channel;
    private TcpListener? _tcpListener;
    private TcpClient? _tcpClient;
    private bool _stopped;
    private string? _endPoint;

    /// <summary>
    /// Initializes a new instance of the <see cref="SocketServer"/> class.
    /// </summary>
    public SocketServer()
        : this(stream => new LengthPrefixCommunicationChannel(stream))
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="SocketServer"/> class with given channel
    /// factory implementation.
    /// </summary>
    /// <param name="channelFactory">Factory to create communication channel.</param>
    protected SocketServer(Func<Stream, ICommunicationChannel> channelFactory)
    {
        // Used to cancel the message loop
        _cancellation = new CancellationTokenSource();

        _channelFactory = channelFactory;
    }

    /// <inheritdoc />
    public event EventHandler<ConnectedEventArgs>? Connected;

    /// <inheritdoc />
    public event EventHandler<DisconnectedEventArgs>? Disconnected;

    public string? Start(string endPoint)
    {
        try
        {
            _tcpListener = new TcpListener(endPoint.GetIpEndPoint());

            _tcpListener.Start();

            _endPoint = _tcpListener.LocalEndpoint.ToString();
            EqtTrace.Info("SocketServer.Start: Listening on endpoint : {0}", _endPoint);

            // Serves a single client at the moment. An error in connection, or message loop just
            // terminates the entire server.
            _tcpListener.AcceptTcpClientAsync().ContinueWith(t => OnClientConnected(t.Result));
            return _endPoint;
        }
        catch (SocketException ex)
        {
            EqtTrace.Error("Failed for address {0}, with: {1}", endPoint, ex);
            throw;
        }
    }

    /// <inheritdoc />
    public void Stop()
    {
        EqtTrace.Info("SocketServer.Stop: Stop server endPoint: {0}", _endPoint);
        if (!_stopped)
        {
            EqtTrace.Info("SocketServer.Stop: Cancellation requested. Stopping message loop.");
            _cancellation.Cancel();
        }
    }

    private void OnClientConnected(TcpClient client)
    {
        _tcpClient = client;
        _tcpClient.Client.NoDelay = true;

        if (Connected == null)
        {
            return;
        }

        _channel = _channelFactory(_tcpClient.GetStream());
        Connected.SafeInvoke(this, new ConnectedEventArgs(_channel), "SocketServer: ClientConnected");

        EqtTrace.Verbose("SocketServer.OnClientConnected: Client connected for endPoint: {0}, starting MessageLoopAsync:", _endPoint);

        // Start the message loop
        Task.Run(() => _tcpClient.MessageLoopAsync(_channel, error => StopOnError(error), _cancellation.Token)).ConfigureAwait(false);
    }

    /// <summary>
    /// Stop the connection when error was encountered. Dispose all communication, and notify subscribers of Disconnected event
    /// that we aborted.
    /// </summary>
    /// <param name="error"></param>
    private void StopOnError(Exception? error)
    {
        EqtTrace.Info("SocketServer.PrivateStop: Stopping server endPoint: {0} error: {1}", _endPoint, error);

        if (_stopped)
        {
            return;
        }

        TPDebug.Assert(_tcpListener is not null, $"{nameof(_tcpListener)} is null");
        TPDebug.Assert(_channel is not null, $"{nameof(_channel)} is null");

        // Do not allow stop to be called multiple times.
        _stopped = true;

        // Stop accepting any other connections
        _tcpListener.Stop();

        // Close the client and dispose the underlying stream
        // tcpClient.Close() calls tcpClient.Dispose().
        _tcpClient?.Close();
        _channel.Dispose();
        _cancellation.Dispose();

        EqtTrace.Info("SocketServer.Stop: Raise disconnected event endPoint: {0} error: {1}", _endPoint, error);
        Disconnected?.SafeInvoke(this, new DisconnectedEventArgs { Error = error }, "SocketServer: ClientDisconnected");
    }
}