File: FrameworkFork\System.ServiceModel\System\ServiceModel\Channels\SocketConnection.cs
Web Access
Project: src\src\dotnet-svcutil\lib\src\dotnet-svcutil-lib.csproj (dotnet-svcutil-lib)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Diagnostics.Contracts;
using System.Net;
using System.Net.Sockets;
using System.Runtime;
using System.Text;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Common base for socket-backed <see cref="IConnection"/> implementations. It owns the shared
    /// timeout, close and async read/write state, maps <see cref="SocketException"/> error codes onto
    /// WCF exception types, and forwards transport-specific work to the abstract *Core members that
    /// derived classes implement.
    /// </summary>
    internal abstract class SocketConnection : IConnection
    {
        // Numeric value of System.Diagnostics.TraceEventType.Warning; Abort(int, ...) takes the raw value.
        private const int TraceEventTypeWarning = 4;

        // common state
        protected TimeSpan _sendTimeout;
        protected TimeSpan _receiveTimeout;
        protected CloseState _closeState;
        protected bool _aborted;

        // close state
        protected TimeoutHelper _closeTimeoutHelper;
        private bool _isShutdown;

        // read state
        protected int _asyncReadSize;
        protected byte[] _readBuffer;
        protected int _asyncReadBufferSize;
        protected object _asyncReadState;
        protected Action<object> _asyncReadCallback;
        protected Exception _asyncReadException;
        protected bool _asyncReadPending;

        // write state
        protected object _asyncWriteState;
        protected Action<object> _asyncWriteCallback;
        protected Exception _asyncWriteException;
        protected bool _asyncWritePending;

        // When _timeoutErrorString is non-null, ConvertObjectDisposedException and ConvertTransferException
        // report failures as timeout-related errors built from this message. It is presumably set by the
        // derived Abort implementation when a send/receive timeout fires (not visible here).
        protected string _timeoutErrorString;
        protected TransferOperation _timeoutErrorTransferOperation;
        private ConnectionBufferPool _connectionBufferPool;

        /// <summary>
        /// Initializes the connection in the <see cref="CloseState.Open"/> state, takes a read buffer from
        /// <paramref name="connectionBufferPool"/>, and sets both send and receive timeouts to
        /// <see cref="TimeSpan.MaxValue"/>.
        /// </summary>
        /// <param name="connectionBufferPool">Pool the read buffer is taken from and later returned to; must not be null.</param>
        public SocketConnection(ConnectionBufferPool connectionBufferPool)
        {
            Contract.Assert(connectionBufferPool != null, "Argument connectionBufferPool cannot be null");

            _closeState = CloseState.Open;
            _connectionBufferPool = connectionBufferPool;
            _readBuffer = _connectionBufferPool.Take();
            _asyncReadBufferSize = _readBuffer.Length;
            _sendTimeout = _receiveTimeout = TimeSpan.MaxValue;
        }

        /// <summary>Length of the read buffer taken from the pool at construction.</summary>
        public int AsyncReadBufferSize
        {
            get { return _asyncReadBufferSize; }
        }

        /// <summary>The buffer async reads fill; null once it has been returned to the pool.</summary>
        public byte[] AsyncReadBuffer
        {
            get
            {
                return _readBuffer;
            }
        }

        /// <summary>Object used to synchronize close/shutdown state; this is the connection instance itself.</summary>
        protected object ThisLock
        {
            get { return this; }
        }

        protected abstract IPEndPoint RemoteEndPoint { get; }

        /// <summary>Timer-style callback: aborts the <see cref="SocketConnection"/> passed as <paramref name="state"/> because a receive timed out.</summary>
        protected static void OnReceiveTimeout(object state)
        {
            SocketConnection thisPtr = (SocketConnection)state;
            thisPtr.Abort(string.Format(SRServiceModel.SocketAbortedReceiveTimedOut, thisPtr._receiveTimeout), TransferOperation.Read);
        }

        /// <summary>Timer-style callback: aborts the <see cref="SocketConnection"/> passed as <paramref name="state"/> because a send timed out.</summary>
        protected static void OnSendTimeout(object state)
        {
            SocketConnection thisPtr = (SocketConnection)state;
            thisPtr.Abort(TraceEventTypeWarning,
                string.Format(SRServiceModel.SocketAbortedSendTimedOut, thisPtr._sendTimeout), TransferOperation.Write);
        }

        /// <summary>Aborts the connection without a timeout message.</summary>
        public void Abort()
        {
            Abort(null, TransferOperation.Undefined);
        }

        /// <summary>Aborts the connection, tracing at warning level.</summary>
        /// <param name="timeoutErrorString">Timeout message to report on later failures, or null.</param>
        /// <param name="transferOperation">The operation (read/write) that timed out, if any.</param>
        protected void Abort(string timeoutErrorString, TransferOperation transferOperation)
        {
            // Warning rather than error: the connection being timed out could be a cached (idle) one.
            Abort(TraceEventTypeWarning, timeoutErrorString, transferOperation);
        }

        /// <summary>Aborts the connection with the given raw TraceEventType value and no timeout message.</summary>
        protected void Abort(int traceEventType)
        {
            Abort(traceEventType, null, TransferOperation.Undefined);
        }

        protected abstract void Abort(int traceEventType, string timeoutErrorString, TransferOperation transferOperation);

        protected abstract void AbortRead();

        /// <summary>
        /// Closes the connection once: moves to <see cref="CloseState.Closing"/>, shuts down the send side
        /// and then calls <see cref="CloseCore"/>. Calls made while closing or after close return immediately.
        /// </summary>
        public void Close(TimeSpan timeout, bool asyncAndLinger)
        {
            lock (ThisLock)
            {
                if (_closeState == CloseState.Closing || _closeState == CloseState.Closed)
                {
                    // already closing or closed, so just return
                    return;
                }
                _closeState = CloseState.Closing;
            }

            _closeTimeoutHelper = new TimeoutHelper(timeout);

            // first we shutdown our send-side
            Shutdown(timeout);
            CloseCore(asyncAndLinger);
        }

        protected abstract void CloseCore(bool asyncAndLinger);

        // Runs ShutdownCore at most once per connection; the flag is flipped under ThisLock.
        private void Shutdown(TimeSpan timeout)
        {
            lock (ThisLock)
            {
                if (_isShutdown)
                {
                    return;
                }

                _isShutdown = true;
            }

            ShutdownCore(timeout);
        }

        protected abstract void ShutdownCore(TimeSpan timeout);

        /// <summary>Throws if the connection is closing or closed.</summary>
        protected void ThrowIfNotOpen()
        {
            if (_closeState == CloseState.Closing || _closeState == CloseState.Closed)
            {
                throw CreateConnectionDisposedException();
            }
        }

        /// <summary>Throws if the connection is fully closed (a connection that is still closing is allowed).</summary>
        protected void ThrowIfClosed()
        {
            if (_closeState == CloseState.Closed)
            {
                throw CreateConnectionDisposedException();
            }
        }

        // Builds the exception used by ThrowIfNotOpen/ThrowIfClosed, converted to a timeout/aborted
        // exception when the connection was torn down by a timeout or an abort.
        private Exception CreateConnectionDisposedException()
        {
            return DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                ConvertObjectDisposedException(new ObjectDisposedException(
                    this.GetType().ToString(), SRServiceModel.SocketConnectionDisposed), TransferOperation.Undefined));
        }

        /// <summary>Maps a socket error raised while writing to a WCF exception.</summary>
        protected Exception ConvertSendException(SocketException socketException, TimeSpan remainingTime)
        {
            return ConvertTransferException(socketException, _sendTimeout, socketException,
                TransferOperation.Write, _aborted, _timeoutErrorString, _timeoutErrorTransferOperation, this, remainingTime);
        }

        /// <summary>Maps a socket error raised while reading to a WCF exception.</summary>
        protected Exception ConvertReceiveException(SocketException socketException, TimeSpan remainingTime)
        {
            return ConvertTransferException(socketException, _receiveTimeout, socketException,
                TransferOperation.Read, _aborted, _timeoutErrorString, _timeoutErrorTransferOperation, this, remainingTime);
        }

        /// <summary>Maps a socket error to a WCF exception without any per-connection state.</summary>
        internal static Exception ConvertTransferException(SocketException socketException, TimeSpan timeout, Exception originalException)
        {
            return ConvertTransferException(socketException, timeout, originalException,
                TransferOperation.Undefined, false, null, TransferOperation.Undefined, null, TimeSpan.MaxValue);
        }

        /// <summary>
        /// Converts an <see cref="ObjectDisposedException"/> into a timeout error if the connection was
        /// aborted by a timeout, into <see cref="CommunicationObjectAbortedException"/> if it was aborted,
        /// and otherwise returns it unchanged.
        /// </summary>
        protected Exception ConvertObjectDisposedException(ObjectDisposedException originalException, TransferOperation transferOperation)
        {
            if (_timeoutErrorString != null)
            {
                return ConvertTimeoutErrorException(originalException, transferOperation, _timeoutErrorString, _timeoutErrorTransferOperation);
            }
            else if (_aborted)
            {
                return new CommunicationObjectAbortedException(SRServiceModel.SocketConnectionDisposed, originalException);
            }
            else
            {
                return originalException;
            }
        }

        // Central socket-error mapping shared by the send/receive/static overloads.
        private static Exception ConvertTransferException(SocketException socketException, TimeSpan timeout, Exception originalException,
            TransferOperation transferOperation, bool aborted, string timeoutErrorString, TransferOperation timeoutErrorTransferOperation,
            SocketConnection socketConnection, TimeSpan remainingTime)
        {
            int errorCode = (int)socketException.SocketErrorCode;

            if (errorCode == UnsafeNativeMethods.ERROR_INVALID_HANDLE)
            {
                return new CommunicationObjectAbortedException(socketException.Message, socketException);
            }

            if (timeoutErrorString != null)
            {
                return ConvertTimeoutErrorException(originalException, transferOperation, timeoutErrorString, timeoutErrorTransferOperation);
            }

            // 10053 can occur due to our timeout sockopt firing, so map to TimeoutException in that case
            if (errorCode == UnsafeNativeMethods.WSAECONNABORTED &&
                remainingTime <= TimeSpan.Zero)
            {
                return new TimeoutException(string.Format(SRServiceModel.TcpConnectionTimedOut, timeout), originalException);
            }

            if (errorCode == UnsafeNativeMethods.WSAENETRESET ||
                errorCode == UnsafeNativeMethods.WSAECONNABORTED ||
                errorCode == UnsafeNativeMethods.WSAECONNRESET)
            {
                if (aborted)
                {
                    return new CommunicationObjectAbortedException(SRServiceModel.TcpLocalConnectionAborted, originalException);
                }

                return new CommunicationException(string.Format(SRServiceModel.TcpConnectionResetError, timeout), originalException);
            }

            if (errorCode == UnsafeNativeMethods.WSAETIMEDOUT)
            {
                return new TimeoutException(string.Format(SRServiceModel.TcpConnectionTimedOut, timeout), originalException);
            }

            string transferErrorMessage = string.Format(SRServiceModel.TcpTransferError, errorCode, socketException.Message);
            if (aborted)
            {
                return new CommunicationObjectAbortedException(transferErrorMessage, originalException);
            }

            return new CommunicationException(transferErrorMessage, originalException);
        }

        // The operation that actually timed out gets a TimeoutException; any other operation that fails
        // as a consequence gets a CommunicationException carrying the same message.
        private static Exception ConvertTimeoutErrorException(Exception originalException,
            TransferOperation transferOperation, string timeoutErrorString, TransferOperation timeoutErrorTransferOperation)
        {
            Contract.Assert(timeoutErrorString != null, "Argument timeoutErrorString must not be null.");

            if (transferOperation == timeoutErrorTransferOperation)
            {
                return new TimeoutException(timeoutErrorString, originalException);
            }
            else
            {
                return new CommunicationException(timeoutErrorString, originalException);
            }
        }

        /// <summary>Starts an asynchronous write, emitting the write-start trace event when it is enabled.</summary>
        public AsyncCompletionResult BeginWrite(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout,
            Action<object> callback, object state)
        {
            if (WcfEventSource.Instance.SocketAsyncWriteStartIsEnabled())
            {
                TraceWriteStart(size, true);
            }

            return BeginWriteCore(buffer, offset, size, immediate, timeout, callback, state);
        }

        protected abstract void TraceWriteStart(int size, bool async);

        protected abstract AsyncCompletionResult BeginWriteCore(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout,
            Action<object> callback, object state);

        /// <summary>Completes an asynchronous write started with <see cref="BeginWrite"/>.</summary>
        public void EndWrite()
        {
            EndWriteCore();
        }

        protected abstract void EndWriteCore();

        /// <summary>Clears the pending write callback/state and then invokes the callback with that state.</summary>
        protected void FinishWrite()
        {
            Action<object> asyncWriteCallback = _asyncWriteCallback;
            object asyncWriteState = _asyncWriteState;

            _asyncWriteState = null;
            _asyncWriteCallback = null;

            asyncWriteCallback(asyncWriteState);
        }

        /// <summary>Synchronously writes <paramref name="size"/> bytes of <paramref name="buffer"/> starting at <paramref name="offset"/>.</summary>
        public void Write(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout)
        {
            WriteCore(buffer, offset, size, immediate, timeout);
        }

        protected abstract void WriteCore(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout);

        /// <summary>Synchronously writes the buffer and always returns it to <paramref name="bufferManager"/>, even if the write throws.</summary>
        public void Write(byte[] buffer, int offset, int size, bool immediate, TimeSpan timeout, BufferManager bufferManager)
        {
            try
            {
                Write(buffer, offset, size, immediate, timeout);
            }
            finally
            {
                bufferManager.ReturnBuffer(buffer);
            }
        }

        /// <summary>
        /// Synchronously reads into <paramref name="buffer"/> after validating the bounds and that the
        /// connection is open; traces the read-stop event on success when it is enabled.
        /// </summary>
        /// <returns>The number of bytes read.</returns>
        public int Read(byte[] buffer, int offset, int size, TimeSpan timeout)
        {
            ConnectionUtilities.ValidateBufferBounds(buffer, offset, size);
            ThrowIfNotOpen();
            int bytesRead = ReadCore(buffer, offset, size, timeout, false);
            if (WcfEventSource.Instance.SocketReadStopIsEnabled())
            {
                TraceSocketReadStop(bytesRead, false);
            }

            return bytesRead;
        }

        protected abstract int ReadCore(byte[] buffer, int offset, int size, TimeSpan timeout, bool closing);

        /// <summary>
        /// Starts an asynchronous read into <see cref="AsyncReadBuffer"/>. If the read completes
        /// synchronously, the read-stop event is traced here; otherwise <see cref="FinishRead"/> traces it.
        /// </summary>
        public virtual AsyncCompletionResult BeginRead(int offset, int size, TimeSpan timeout,
            Action<object> callback, object state)
        {
            ConnectionUtilities.ValidateBufferBounds(AsyncReadBufferSize, offset, size);
            this.ThrowIfNotOpen();
            var completionResult = this.BeginReadCore(offset, size, timeout, callback, state);
            if (completionResult == AsyncCompletionResult.Completed && WcfEventSource.Instance.SocketReadStopIsEnabled())
            {
                TraceSocketReadStop(_asyncReadSize, true);
            }

            return completionResult;
        }

        protected abstract void TraceSocketReadStop(int bytesRead, bool async);

        protected abstract AsyncCompletionResult BeginReadCore(int offset, int size, TimeSpan timeout,
            Action<object> callback, object state);

        /// <summary>
        /// Completes an asynchronous read: traces the read-stop event for a successful read, clears the
        /// pending callback/state, and invokes the callback with that state.
        /// </summary>
        protected void FinishRead()
        {
            // Like Read and BeginRead, trace the read-stop (with the byte count) only for reads that
            // succeeded; a failed read has no meaningful _asyncReadSize.
            if (_asyncReadException == null && WcfEventSource.Instance.SocketReadStopIsEnabled())
            {
                TraceSocketReadStop(_asyncReadSize, true);
            }

            Action<object> asyncReadCallback = _asyncReadCallback;
            object asyncReadState = _asyncReadState;

            _asyncReadState = null;
            _asyncReadCallback = null;

            asyncReadCallback(asyncReadState);
        }

        // Both BeginRead/ReadAsync paths completed themselves. EndRead's only job is to deliver the result.
        public int EndRead()
        {
            return EndReadCore();
        }

        protected abstract int EndReadCore();

        // This method should be called inside ThisLock
        protected void ReturnReadBuffer()
        {
            // We release the buffer only if there is no outstanding I/O
            this.TryReturnReadBuffer();
        }

        // This method should be called inside ThisLock
        protected void TryReturnReadBuffer()
        {
            // The buffer must not be returned and nulled when an abort occurs. Since the buffer
            // is also accessed by higher layers, code that has not yet realized the stack is
            // aborted may be attempting to read from the buffer.
            if (_readBuffer != null && !_aborted)
            {
                _connectionBufferPool.Return(_readBuffer);
                _readBuffer = null;
            }
        }

        /// <summary>Lifecycle of the connection as tracked by <see cref="Close"/> and the ThrowIf* checks.</summary>
        protected enum CloseState
        {
            Open,
            Closing,
            Closed,
        }

        /// <summary>The kind of transfer an error or timeout relates to.</summary>
        protected enum TransferOperation
        {
            Write,
            Read,
            Undefined,
        }
    }
 
    /// <summary>
    /// Base class for initiators that resolve a URI to IP addresses and try each address in turn until
    /// a derived class succeeds in creating a connection.
    /// </summary>
    internal abstract class SocketConnectionInitiator : IConnectionInitiator
    {
        private int _bufferSize;
        protected ConnectionBufferPool _connectionBufferPool;

        /// <summary>Creates the initiator and a buffer pool of <paramref name="bufferSize"/>-sized buffers.</summary>
        public SocketConnectionInitiator(int bufferSize)
        {
            _bufferSize = bufferSize;
            _connectionBufferPool = new ConnectionBufferPool(bufferSize);
        }

        protected abstract IConnection CreateConnection(IPAddress address, int port);

        protected abstract Task<IConnection> CreateConnectionAsync(IPAddress address, int port);

        /// <summary>
        /// Maps a socket error raised while connecting to <paramref name="remoteUri"/> to a WCF exception.
        /// Unreachable/refused/timed-out style errors become <see cref="EndpointNotFoundException"/>,
        /// resource exhaustion becomes <see cref="OutOfMemoryException"/>, and anything else becomes
        /// <see cref="CommunicationException"/>.
        /// </summary>
        /// <param name="timeSpent">Time spent connecting; <see cref="TimeSpan.MaxValue"/> omits it from the message.</param>
        public static Exception ConvertConnectException(SocketException socketException, Uri remoteUri, TimeSpan timeSpent, Exception innerException)
        {
            int errorCode = (int)socketException.SocketErrorCode;

            if (errorCode == UnsafeNativeMethods.ERROR_INVALID_HANDLE)
            {
                return new CommunicationObjectAbortedException(socketException.Message, socketException);
            }

            if (errorCode == UnsafeNativeMethods.WSAEADDRNOTAVAIL ||
                errorCode == UnsafeNativeMethods.WSAECONNREFUSED ||
                errorCode == UnsafeNativeMethods.WSAENETDOWN ||
                errorCode == UnsafeNativeMethods.WSAENETUNREACH ||
                errorCode == UnsafeNativeMethods.WSAEHOSTDOWN ||
                errorCode == UnsafeNativeMethods.WSAEHOSTUNREACH ||
                errorCode == UnsafeNativeMethods.WSAETIMEDOUT)
            {
                if (timeSpent == TimeSpan.MaxValue)
                {
                    return new EndpointNotFoundException(string.Format(SRServiceModel.TcpConnectError, remoteUri.AbsoluteUri, errorCode, socketException.Message), innerException);
                }

                return new EndpointNotFoundException(string.Format(SRServiceModel.TcpConnectErrorWithTimeSpan, remoteUri.AbsoluteUri, errorCode, socketException.Message, timeSpent), innerException);
            }

            if (errorCode == UnsafeNativeMethods.WSAENOBUFS)
            {
                return new OutOfMemoryException(SRServiceModel.TcpConnectNoBufs, innerException);
            }

            if (errorCode == UnsafeNativeMethods.ERROR_NOT_ENOUGH_MEMORY ||
                errorCode == UnsafeNativeMethods.ERROR_NO_SYSTEM_RESOURCES ||
                errorCode == UnsafeNativeMethods.ERROR_OUTOFMEMORY)
            {
                return new OutOfMemoryException(SRServiceModel.InsufficentMemory, socketException);
            }

            if (timeSpent == TimeSpan.MaxValue)
            {
                return new CommunicationException(string.Format(SRServiceModel.TcpConnectError, remoteUri.AbsoluteUri, errorCode, socketException.Message), innerException);
            }

            return new CommunicationException(string.Format(SRServiceModel.TcpConnectErrorWithTimeSpan, remoteUri.AbsoluteUri, errorCode, socketException.Message, timeSpent), innerException);
        }

        /// <summary>
        /// Returns the addresses for <paramref name="uri"/>. Literal IPv4/IPv6 hosts are parsed directly;
        /// other hosts are resolved through DnsCache.
        /// </summary>
        /// <exception cref="EndpointNotFoundException">Resolution failed or returned no addresses.</exception>
        private static async Task<IPAddress[]> GetIPAddressesAsync(Uri uri)
        {
            if (uri.HostNameType == UriHostNameType.IPv4 ||
                uri.HostNameType == UriHostNameType.IPv6)
            {
                IPAddress ipAddress = IPAddress.Parse(uri.DnsSafeHost);
                return new IPAddress[] { ipAddress };
            }

            IPAddress[] addresses = null;

            try
            {
                // ConfigureAwait(false): Connect blocks on this task synchronously, so resuming on a
                // captured SynchronizationContext could deadlock.
                addresses = await DnsCache.ResolveAsync(uri).ConfigureAwait(false);
            }
            catch (SocketException socketException)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new EndpointNotFoundException(string.Format(SRServiceModel.UnableToResolveHost, uri.Host), socketException));
            }

            if (addresses.Length == 0)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new EndpointNotFoundException(string.Format(SRServiceModel.UnableToResolveHost, uri.Host)));
            }

            return addresses;
        }

        /// <summary>
        /// Builds (but does not throw) the timeout error reported when the connect budget runs out. It
        /// lists the first <paramref name="invalidAddressCount"/> addresses, which are the ones already tried.
        /// </summary>
        private static TimeoutException CreateTimeoutException(Uri uri, TimeSpan timeout, IPAddress[] addresses, int invalidAddressCount,
            SocketException innerException)
        {
            StringBuilder addressStringBuilder = new StringBuilder();
            for (int i = 0; i < invalidAddressCount; i++)
            {
                if (addresses[i] == null)
                {
                    continue;
                }

                if (addressStringBuilder.Length > 0)
                {
                    addressStringBuilder.Append(", ");
                }
                addressStringBuilder.Append(addresses[i].ToString());
            }

            // Returned rather than thrown: callers wrap it in ThrowHelperError and throw it themselves.
            return new TimeoutException(
                string.Format(SRServiceModel.TcpConnectingToViaTimedOut, uri.AbsoluteUri, timeout.ToString(),
                invalidAddressCount, addresses.Length, addressStringBuilder.ToString()), innerException);
        }

        /// <summary>
        /// Synchronously connects to <paramref name="uri"/>, trying each resolved address in order until
        /// one succeeds. Uses <see cref="TcpUri.DefaultPort"/> when the URI has no port.
        /// </summary>
        /// <exception cref="TimeoutException">The timeout elapsed before an address could be tried.</exception>
        /// <exception cref="EndpointNotFoundException">No connection could be made (see <see cref="ConvertConnectException"/>).</exception>
        public IConnection Connect(Uri uri, TimeSpan timeout)
        {
            int port = uri.Port;
            IPAddress[] addresses = SocketConnectionInitiator.GetIPAddressesAsync(uri).GetAwaiter().GetResult();
            IConnection socketConnection = null;
            SocketException lastException = null;

            if (port == -1)
            {
                port = TcpUri.DefaultPort;
            }

            int invalidAddressCount = 0;
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            for (int i = 0; i < addresses.Length; i++)
            {
                if (timeoutHelper.RemainingTime() == TimeSpan.Zero)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        CreateTimeoutException(uri, timeoutHelper.OriginalTimeout, addresses, invalidAddressCount, lastException));
                }

                try
                {
                    socketConnection = CreateConnection(addresses[i], port);
                    lastException = null;
                    break;
                }
                catch (SocketException socketException)
                {
                    invalidAddressCount++;
                    lastException = socketException;
                }
            }

            // Check the recorded failure first: when every address failed, report the real socket error
            // (e.g. connection refused) instead of the generic "no endpoints found" message.
            if (lastException != null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    SocketConnectionInitiator.ConvertConnectException(lastException, uri,
                    timeoutHelper.ElapsedTime(), lastException));
            }

            if (socketConnection == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new EndpointNotFoundException(string.Format(SRServiceModel.NoIPEndpointsFoundForHost, uri.Host)));
            }

            return socketConnection;
        }

        /// <summary>
        /// Asynchronously connects to <paramref name="uri"/>, trying each resolved address in order until
        /// one succeeds. Uses <see cref="TcpUri.DefaultPort"/> when the URI has no port.
        /// </summary>
        /// <exception cref="TimeoutException">The timeout elapsed before an address could be tried.</exception>
        /// <exception cref="EndpointNotFoundException">No connection could be made (see <see cref="ConvertConnectException"/>).</exception>
        public async Task<IConnection> ConnectAsync(Uri uri, TimeSpan timeout)
        {
            int port = uri.Port;
            IPAddress[] addresses = await SocketConnectionInitiator.GetIPAddressesAsync(uri).ConfigureAwait(false);
            IConnection socketConnection = null;
            SocketException lastException = null;

            if (port == -1)
            {
                port = TcpUri.DefaultPort;
            }

            int invalidAddressCount = 0;
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            for (int i = 0; i < addresses.Length; i++)
            {
                if (timeoutHelper.RemainingTime() == TimeSpan.Zero)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        CreateTimeoutException(uri, timeoutHelper.OriginalTimeout, addresses, invalidAddressCount, lastException));
                }

                try
                {
                    socketConnection = await CreateConnectionAsync(addresses[i], port).ConfigureAwait(false);
                    lastException = null;
                    break;
                }
                catch (SocketException socketException)
                {
                    invalidAddressCount++;
                    lastException = socketException;
                }
            }

            // Check the recorded failure first: when every address failed, report the real socket error
            // (e.g. connection refused) instead of the generic "no endpoints found" message.
            if (lastException != null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    SocketConnectionInitiator.ConvertConnectException(lastException, uri,
                    timeoutHelper.ElapsedTime(), lastException));
            }

            if (socketConnection == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new EndpointNotFoundException(string.Format(SRServiceModel.NoIPEndpointsFoundForHost, uri.Host)));
            }

            return socketConnection;
        }
    }
}