File: SocketClient.cs
Web Access
Project: src\src\vstest\src\Microsoft.TestPlatform.CommunicationUtilities\Microsoft.TestPlatform.CommunicationUtilities.csproj (Microsoft.TestPlatform.CommunicationUtilities)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.VisualStudio.TestPlatform.CommunicationUtilities.Interfaces;
using Microsoft.VisualStudio.TestPlatform.ObjectModel;
using Microsoft.VisualStudio.TestPlatform.Utilities;

namespace Microsoft.VisualStudio.TestPlatform.CommunicationUtilities;

/// <summary>
/// Communication client implementation over sockets.
/// </summary>
/// <remarks>
/// <see cref="Start"/> begins an asynchronous TCP connection to the server endpoint. When the attempt
/// completes, <see cref="Connected"/> is raised with either the connection fault or the newly created
/// <see cref="ICommunicationChannel"/>; on success a background message loop is then started.
/// The loop is given <c>StopOnError</c> as its stop callback, which closes the socket, disposes the
/// channel and the cancellation source, and raises <see cref="Disconnected"/>.
/// <see cref="Stop"/> only requests cancellation of that loop.
/// Note: if nobody is subscribed to <see cref="Connected"/> when the connection attempt completes,
/// no channel is created and no message loop is started.
/// </remarks>
[SuppressMessage("Design", "CA1001:Types that own disposable fields should be disposable", Justification = "Would cause a breaking change if users are inheriting this class and implement IDisposable")]
public class SocketClient : ICommunicationEndPoint
{
    // Cancels the message loop; disposed in StopOnError.
    private readonly CancellationTokenSource _cancellation;
    private readonly TcpClient _tcpClient;

    // Wraps the connected network stream in a communication channel.
    private readonly Func<Stream, ICommunicationChannel> _channelFactory;

    // Created in OnServerConnected once the connection succeeds; disposed in StopOnError.
    private ICommunicationChannel? _channel;

    // Set by StopOnError so its teardown is not repeated. Not synchronized: Stop() reads it from
    // another thread and tolerates the race by catching ObjectDisposedException.
    private bool _stopped;

    // Endpoint string passed to Start; used only for tracing.
    private string? _endPoint;

    /// <summary>
    /// Initializes a new instance of the <see cref="SocketClient"/> class that communicates
    /// through a <see cref="LengthPrefixCommunicationChannel"/>.
    /// </summary>
    public SocketClient()
        : this(stream => new LengthPrefixCommunicationChannel(stream))
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="SocketClient"/> class.
    /// </summary>
    /// <param name="channelFactory">Creates the communication channel over the connected network stream.</param>
    protected SocketClient(Func<Stream, ICommunicationChannel> channelFactory)
    {
        // Used to cancel the message loop
        _cancellation = new CancellationTokenSource();
        _stopped = false;

        // NoDelay disables Nagle's algorithm so small messages are not held back for coalescing.
        _tcpClient = new TcpClient { NoDelay = true };
        _channelFactory = channelFactory;
    }

    /// <inheritdoc />
    public event EventHandler<ConnectedEventArgs>? Connected;

    /// <inheritdoc />
    public event EventHandler<DisconnectedEventArgs>? Disconnected;

    /// <inheritdoc />
    public string Start(string endPoint)
    {
        _endPoint = endPoint;
        var ipEndPoint = endPoint.GetIpEndPoint();

        EqtTrace.Info("SocketClient.Start: connecting to server endpoint: {0}", endPoint);

        // Connect asynchronously; OnServerConnected reports the outcome. The continuation is pinned
        // to the default (thread pool) scheduler instead of whatever TaskScheduler.Current is for
        // the caller (CA2008).
        _tcpClient.ConnectAsync(ipEndPoint.Address, ipEndPoint.Port)
            .ContinueWith(OnServerConnected, TaskScheduler.Default);
        return ipEndPoint.ToString();
    }

    /// <inheritdoc />
    /// <remarks>
    /// Only requests cancellation of the message loop; the socket and channel are released in
    /// <c>StopOnError</c>, which the message loop is expected to invoke once it observes cancellation.
    /// </remarks>
    public void Stop()
    {
        EqtTrace.Info("SocketClient.Stop: Stop communication from server endpoint: {0}", _endPoint);

        if (!_stopped)
        {
            EqtTrace.Info("SocketClient: Stop: Cancellation requested. Stopping message loop.");

            try
            {
                _cancellation.Cancel();
            }
            catch (ObjectDisposedException)
            {
                // This is race condition with stop on error: StopOnError may have disposed the
                // cancellation source after _stopped was read above.
            }
        }
    }

    /// <summary>
    /// Continuation of the connect task started in <see cref="Start"/>: raises <see cref="Connected"/>
    /// and, when the connection succeeded, creates the channel and starts the message loop.
    /// </summary>
    /// <param name="connectAsyncTask">The completed connect task.</param>
    private void OnServerConnected(Task connectAsyncTask)
    {
        EqtTrace.Info("SocketClient.OnServerConnected: connected to server endpoint: {0}", _endPoint);

        // Read the event once so the null check and the invocations below see the same subscribers.
        var connected = Connected;
        if (connected == null)
        {
            // Nobody is listening: no channel is created and the message loop is not started.
            return;
        }

        if (connectAsyncTask.IsFaulted)
        {
            connected.SafeInvoke(this, new ConnectedEventArgs(connectAsyncTask.Exception), "SocketClient: Server Failed to Connect");
            EqtTrace.Verbose("Unable to connect to server, Exception occurred: {0}", connectAsyncTask.Exception);
            return;
        }

        var channel = _channelFactory(_tcpClient.GetStream());
        _channel = channel;
        connected.SafeInvoke(this, new ConnectedEventArgs(channel), "SocketClient: ServerConnected");

        EqtTrace.Verbose("Connected to server, and starting MessageLoopAsync");

        // Start the message loop as a fire-and-forget background task. The loop is handed
        // StopOnError as its stop callback and is cancelled through _cancellation by Stop().
        // The task is intentionally discarded; ConfigureAwait would have no effect on a task
        // that is never awaited.
        _ = Task.Run(() => _tcpClient.MessageLoopAsync(
                channel,
                StopOnError,
                _cancellation.Token));
    }

    /// <summary>
    /// Stop callback handed to the message loop: closes the TCP connection, disposes the channel and
    /// the cancellation source, and raises <see cref="Disconnected"/>.
    /// </summary>
    /// <param name="error">The error reported by the message loop, or <c>null</c>; used only for tracing.</param>
    private void StopOnError(Exception? error)
    {
        EqtTrace.Info("SocketClient.PrivateStop: Stop communication from server endpoint: {0}, error:{1}", _endPoint, error);

        // This is here to prevent stack overflow: do not allow stop to be called multiple times.
        if (_stopped)
        {
            return;
        }

        _stopped = true;

        // Close the client and dispose the underlying stream.
        // TcpClient.Close() calls TcpClient.Dispose().
        _tcpClient.Close();
        _channel?.Dispose();
        _cancellation.Dispose();

        Disconnected?.SafeInvoke(this, new DisconnectedEventArgs(), "SocketClient: ServerDisconnected");
    }
}