File: System\ServiceModel\Channels\WebSocketTransportDuplexSessionChannel.cs
Web Access
Project: src\src\System.ServiceModel.Http\src\System.ServiceModel.Http.csproj (System.ServiceModel.Http)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
 
using System.Diagnostics.Contracts;
using System.IO;
using System.Net.WebSockets;
using System.Runtime;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    internal abstract class WebSocketTransportDuplexSessionChannel : TransportDuplexSessionChannel
    {
        // Completion callback passed to MessageEncoder.BeginWriteMessage for streamed sends.
        private static readonly AsyncCallback s_streamedWriteCallback = Fx.ThunkCallback(StreamWriteCallback);
        // Close status/description handed to WebSocket.CloseAsync / CloseOutputAsync and reported in close traces;
        // the nested message source also records the received close status on this same instance.
        private readonly WebSocketCloseDetails _webSocketCloseDetails = new WebSocketCloseDetails();
        // Continuation for the streamed write in flight; stored by StartWritingStreamedMessage and invoked by StreamWriteCallback.
        private Action<object> _waitCallback;
        // Backing field of the WebSocket property.
        private WebSocket _webSocket;
        // Stream used by the streamed write in flight; StreamWriteCallback writes the end-of-message marker on it.
        private WebSocketStream _webSocketStream;
        // Argument passed to _waitCallback when the streamed write completes.
        private object _state;
        // Switched from OperationNotStarted to OperationFinished by Cleanup() so that OnCleanup() runs at most once.
        private int _cleanupStatus = WebSocketHelper.OperationNotStarted;
        // When true, OnCleanup() disposes the WebSocket.
        private bool _shouldDisposeWebSocketAfterClosed = true;
        // Failure captured by an asynchronous send / close-output; rethrown (and cleared) by FinishWritingMessage().
        private Exception _pendingWritingMessageException;
 
        /// <summary>
        /// Creates the channel and copies the transport settings it needs from the owning factory.
        /// </summary>
        /// <param name="channelFactory">Factory supplying the WebSocket settings, transfer mode and buffer quota; also passed to the base channel.</param>
        /// <param name="remoteAddress">Remote endpoint address, forwarded to the base channel.</param>
        /// <param name="via">Transport address, forwarded to the base channel.</param>
        public WebSocketTransportDuplexSessionChannel(HttpChannelFactory<IDuplexSessionChannel> channelFactory, EndpointAddress remoteAddress, Uri via)
            : base(
                  channelFactory,
                  channelFactory,
                  EndpointAddress.AnonymousAddress,
                  channelFactory.MessageVersion.Addressing.AnonymousUri,
                  remoteAddress,
                  via)
        {
            Fx.Assert(channelFactory.WebSocketSettings != null, "channelFactory.WebSocketTransportSettings should not be null.");

            // The factory doubles as the ITransportFactorySettings for this channel.
            TransportFactorySettings = channelFactory;
            MaxBufferSize = channelFactory.MaxBufferSize;
            TransferMode = channelFactory.TransferMode;
            WebSocketSettings = channelFactory.WebSocketSettings;
        }
 
        /// <summary>
        /// The underlying WebSocket. The setter asserts (via Fx.Assert) that it receives a
        /// non-null value and that no socket was assigned before.
        /// </summary>
        protected WebSocket WebSocket
        {
            get { return _webSocket; }

            set
            {
                Fx.Assert(value != null, "value should not be null.");
                Fx.Assert(_webSocket == null, "webSocket should not be set before this set call.");

                _webSocket = value;
            }
        }
 
        /// <summary>WebSocket transport settings taken from the channel factory in the constructor.</summary>
        protected WebSocketTransportSettings WebSocketSettings { get; }

        /// <summary>Transfer mode taken from the channel factory in the constructor.</summary>
        protected TransferMode TransferMode { get; }

        /// <summary>Maximum buffer size taken from the channel factory in the constructor.</summary>
        protected int MaxBufferSize { get; }

        /// <summary>The channel factory, exposed through its transport-settings interface.</summary>
        protected ITransportFactorySettings TransportFactorySettings { get; }
 
        /// <summary>
        /// Traces the abort (when that event is enabled) and runs the one-time cleanup.
        /// </summary>
        protected override void OnAbort()
        {
            if (WcfEventSource.Instance.WebSocketConnectionAbortedIsEnabled())
            {
                // -1 is reported when no socket has been assigned yet.
                int socketId = WebSocket?.GetHashCode() ?? -1;
                WcfEventSource.Instance.WebSocketConnectionAborted(EventTraceActivity, socketId);
            }

            Cleanup();
        }
 
        /// <summary>
        /// Sends the WebSocket close and blocks until it completes or <paramref name="timeout"/> elapses,
        /// emitting the "close sent" and "connection closed" trace events when they are enabled.
        /// </summary>
        /// <param name="timeout">Maximum time to wait for the close to complete.</param>
        protected override void CompleteClose(TimeSpan timeout)
        {
            if (WcfEventSource.Instance.WebSocketCloseSentIsEnabled())
            {
                int socketId = WebSocket.GetHashCode();
                string closeStatus = _webSocketCloseDetails.OutputCloseStatus.ToString();
                string remote = RemoteAddress != null ? RemoteAddress.ToString() : string.Empty;
                WcfEventSource.Instance.WebSocketCloseSent(socketId, closeStatus, remote);
            }

            // Failures are translated by WebSocketHelper.ThrowCorrectException, tagged as a close operation.
            CloseAsync().Wait(timeout, WebSocketHelper.ThrowCorrectException, WebSocketHelper.CloseOperation);

            if (WcfEventSource.Instance.WebSocketConnectionClosedIsEnabled())
            {
                int socketId = WebSocket.GetHashCode();
                WcfEventSource.Instance.WebSocketConnectionClosed(socketId);
            }
        }
 
        /// <summary>
        /// Synchronously closes the output side of the WebSocket, waiting at most <paramref name="timeout"/>.
        /// </summary>
        /// <param name="timeout">Maximum time to wait for the close-output to complete.</param>
        protected override void CloseOutputSessionCore(TimeSpan timeout)
        {
            if (WcfEventSource.Instance.WebSocketCloseOutputSentIsEnabled())
            {
                int socketId = WebSocket.GetHashCode();
                string closeStatus = _webSocketCloseDetails.OutputCloseStatus.ToString();
                string remote = RemoteAddress != null ? RemoteAddress.ToString() : string.Empty;
                WcfEventSource.Instance.WebSocketCloseOutputSent(socketId, closeStatus, remote);
            }

            // No cancellation token here: the wait below enforces the timeout.
            CloseOutputAsync(CancellationToken.None)
                .Wait(timeout, WebSocketHelper.ThrowCorrectException, WebSocketHelper.CloseOperation);
        }
 
        /// <summary>
        /// Asynchronously closes the output side of the WebSocket, bounded by <paramref name="timeout"/>.
        /// </summary>
        /// <param name="timeout">Time budget; its expiry cancels the close-output call.</param>
        /// <exception cref="TimeoutException">The timeout's cancellation token fired before the close-output finished.</exception>
        protected override async Task CloseOutputSessionCoreAsync(TimeSpan timeout)
        {
            if (WcfEventSource.Instance.WebSocketCloseOutputSentIsEnabled())
            {
                WcfEventSource.Instance.WebSocketCloseOutputSent(
                    WebSocket.GetHashCode(),
                    _webSocketCloseDetails.OutputCloseStatus.ToString(),
                    RemoteAddress != null ? RemoteAddress.ToString() : string.Empty);
            }

            var helper = new TimeoutHelper(timeout);
            var cancelToken = await helper.GetCancellationTokenAsync();
            try
            {
                await CloseOutputAsync(cancelToken);
            }
            catch (Exception ex)
            {
                if (Fx.IsFatal(ex))
                {
                    throw;
                }

                // A failure after the token fired is reported as a timeout rather than as the raw error.
                if (cancelToken.IsCancellationRequested)
                {
                    throw Fx.Exception.AsError(new TimeoutException(InternalSR.TaskTimedOutError(timeout)));
                }

                // This is a close operation; tag it as such, matching CloseOutputSessionCore and BeginCloseOutput
                // (previously it was reported as a receive operation).
                throw WebSocketHelper.ConvertAndTraceException(ex, timeout, WebSocketHelper.CloseOperation);
            }
        }
 
        /// <summary>
        /// Runs the base close and then the one-time cleanup, whether or not the base close throws.
        /// </summary>
        /// <param name="timeout">Time budget passed to the base close.</param>
        protected override void OnClose(TimeSpan timeout)
        {
            try
            {
                base.OnClose(timeout);
            }
            finally
            {
                // Cleanup() guards itself, so it is harmless if an abort already ran it.
                Cleanup();
            }
        }
 
        /// <summary>
        /// Awaits the base close and then runs the one-time cleanup, whether or not the base close throws.
        /// </summary>
        /// <param name="timeout">Time budget passed to the base close.</param>
        protected internal override async Task OnCloseAsync(TimeSpan timeout)
        {
            try
            {
                await base.OnCloseAsync(timeout);
            }
            finally
            {
                // Cleanup() guards itself, so it is harmless if an abort already ran it.
                Cleanup();
            }
        }
 
        /// <summary>
        /// No-op override: this channel does nothing when asked to return its connection.
        /// </summary>
        protected override void ReturnConnectionIfNecessary(bool abort, TimeSpan timeout)
        {
            // NOTE(review): presumably there is no pooled connection to hand back for WebSockets — confirm against the base class.
        }
 
        /// <summary>
        /// Starts sending an already-encoded message as a single WebSocket message.
        /// </summary>
        /// <param name="message">Message being sent; used only to pick the WebSocket message type.</param>
        /// <param name="messageData">Encoded bytes to send.</param>
        /// <param name="allowOutputBatching">Not used by this implementation.</param>
        /// <param name="timeout">Time budget for the send.</param>
        /// <param name="callback">Invoked with <paramref name="state"/> when a queued send finishes.</param>
        /// <param name="state">Argument for <paramref name="callback"/>.</param>
        /// <returns>
        /// <c>Completed</c> when the send task had already finished (any failure is parked for
        /// <see cref="FinishWritingMessage"/>); otherwise <c>Queued</c>.
        /// </returns>
        protected override AsyncCompletionResult StartWritingBufferedMessage(Message message, ArraySegment<byte> messageData, bool allowOutputBatching, TimeSpan timeout, Action<object> callback, object state)
        {
            Contract.Assert(callback != null, "callback should not be null.");
            ConnectionUtilities.ValidateBufferBounds(messageData);

            TimeoutHelper sendTimeout = new TimeoutHelper(timeout);
            WebSocketMessageType frameType = GetWebSocketMessageType(message);

            if (WcfEventSource.Instance.WebSocketAsyncWriteStartIsEnabled())
            {
                int socketId = WebSocket.GetHashCode();
                int byteCount = messageData.Count;
                string remote = RemoteAddress != null ? RemoteAddress.ToString() : string.Empty;
                WcfEventSource.Instance.WebSocketAsyncWriteStart(socketId, byteCount, remote);
            }

            Task sendTask = WebSocket.SendAsync(messageData, frameType, true, sendTimeout.GetCancellationToken());
            Contract.Assert(_pendingWritingMessageException == null, "'pendingWritingMessageException' MUST be NULL at this point.");

            if (!sendTask.IsCompleted)
            {
                // Still in flight: the completion handler records any failure and signals the caller.
                HandleSendAsyncCompletion(sendTask, timeout, callback, state);
                return AsyncCompletionResult.Queued;
            }

            if (WcfEventSource.Instance.WebSocketAsyncWriteStopIsEnabled())
            {
                WcfEventSource.Instance.WebSocketAsyncWriteStop(WebSocket.GetHashCode());
            }

            _pendingWritingMessageException = WebSocketHelper.CreateExceptionOnTaskFailure(sendTask, timeout, WebSocketHelper.SendOperation);
            return AsyncCompletionResult.Completed;
        }
 
        /// <summary>
        /// Rethrows (and clears) any failure recorded by the preceding write, then completes the write in the base class.
        /// </summary>
        protected override void FinishWritingMessage()
        {
            // Must come first: if the write failed, the base completion is skipped.
            ThrowOnPendingException(ref _pendingWritingMessageException);
            base.FinishWritingMessage();
        }
 
        /// <summary>
        /// Starts writing a message in streamed mode through a <c>WebSocketStream</c> wrapped in a <c>TimeoutStream</c>.
        /// </summary>
        /// <param name="message">Message to encode onto the stream.</param>
        /// <param name="timeout">Time budget applied to the stream writes.</param>
        /// <param name="callback">Continuation to run when the write has finished.</param>
        /// <param name="state">Argument for <paramref name="callback"/>.</param>
        /// <returns>Always <c>Queued</c>.</returns>
        protected override AsyncCompletionResult StartWritingStreamedMessage(Message message, TimeSpan timeout, Action<object> callback, object state)
        {
            WebSocketMessageType frameType = GetWebSocketMessageType(message);
            WebSocketStream output = new WebSocketStream(WebSocket, frameType, ((IDefaultCommunicationTimeouts)this).CloseTimeout);

            // Publish the per-write state before starting the write: StreamWriteCallback reads these fields.
            _waitCallback = callback;
            _state = state;
            _webSocketStream = output;

            TimeoutStream guardedOutput = new TimeoutStream(output, timeout);
            IAsyncResult writeResult = MessageEncoder.BeginWriteMessage(message, guardedOutput, s_streamedWriteCallback, this);

            if (writeResult.CompletedSynchronously)
            {
                // StreamWriteCallback ignores synchronous completions, so finish the write here.
                MessageEncoder.EndWriteMessage(writeResult);
                output.WriteEndOfMessageAsync(callback, state);
            }

            // Either way the caller is signalled through the callback.
            return AsyncCompletionResult.Queued;
        }
 
        // TODO: Make TimeoutHelper disposable so that it disposes its cancellation token source.
        /// <summary>
        /// Starts closing the output side of the WebSocket.
        /// </summary>
        /// <param name="timeout">Time budget; its expiry cancels the close-output call.</param>
        /// <param name="callback">Invoked with <paramref name="state"/> when a queued close-output finishes.</param>
        /// <param name="state">Argument for <paramref name="callback"/>.</param>
        /// <returns>
        /// <c>Completed</c> when the task had already finished (any failure is parked in the pending-exception field);
        /// otherwise <c>Queued</c>.
        /// </returns>
        protected override AsyncCompletionResult BeginCloseOutput(TimeSpan timeout, Action<object> callback, object state)
        {
            Fx.Assert(callback != null, "callback should not be null.");

            var closeTimeout = new TimeoutHelper(timeout);
            Task closeTask = CloseOutputAsync(closeTimeout.GetCancellationToken());
            Fx.Assert(_pendingWritingMessageException == null, "'pendingWritingMessageException' MUST be NULL at this point.");

            if (!closeTask.IsCompleted)
            {
                HandleCloseOutputAsyncCompletion(closeTask, timeout, callback, state);
                return AsyncCompletionResult.Queued;
            }

            _pendingWritingMessageException = WebSocketHelper.CreateExceptionOnTaskFailure(closeTask, timeout, WebSocketHelper.CloseOperation);
            return AsyncCompletionResult.Completed;
        }
 
        /// <summary>
        /// Synchronously sends a message, either streamed through a <c>WebSocketStream</c> or as one encoded buffer.
        /// </summary>
        /// <param name="message">Message to send.</param>
        /// <param name="timeout">Time budget for the send.</param>
        protected override void OnSendCore(Message message, TimeSpan timeout)
        {
            Fx.Assert(message != null, "message should not be null.");

            TimeoutHelper sendTimeout = new TimeoutHelper(timeout);
            WebSocketMessageType frameType = GetWebSocketMessageType(message);

            if (IsStreamedOutput)
            {
                WebSocketStream output = new WebSocketStream(WebSocket, frameType, ((IDefaultCommunicationTimeouts)this).CloseTimeout);
                TimeoutStream guardedOutput = new TimeoutStream(output, timeout);
                MessageEncoder.WriteMessage(message, guardedOutput);
                output.WriteEndOfMessage();
                return;
            }

            ArraySegment<byte> encoded = EncodeMessage(message);
            bool sent = false;
            try
            {
                if (WcfEventSource.Instance.WebSocketAsyncWriteStartIsEnabled())
                {
                    int socketId = WebSocket.GetHashCode();
                    int byteCount = encoded.Count;
                    string remote = RemoteAddress != null ? RemoteAddress.ToString() : string.Empty;
                    WcfEventSource.Instance.WebSocketAsyncWriteStart(socketId, byteCount, remote);
                }

                Task sendTask = WebSocket.SendAsync(encoded, frameType, true, sendTimeout.GetCancellationToken());
                sendTask.Wait(sendTimeout.RemainingTime(), WebSocketHelper.ThrowCorrectException, WebSocketHelper.SendOperation);

                if (WcfEventSource.Instance.WebSocketAsyncWriteStopIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketAsyncWriteStop(_webSocket.GetHashCode());
                }

                sent = true;
            }
            finally
            {
                // The encoded buffer goes back to the BufferManager whether or not the send worked.
                try
                {
                    BufferManager.ReturnBuffer(encoded.Array);
                }
                catch (Exception ex)
                {
                    // If the send already failed, a non-fatal error here is only traced so it
                    // does not replace the send failure that is propagating.
                    if (Fx.IsFatal(ex) || sent)
                    {
                        throw;
                    }

                    FxTrace.Exception.TraceUnhandledException(ex);
                }
            }
        }
 
        /// <summary>
        /// Encodes <paramref name="message"/> into a buffer taken from the channel's BufferManager,
        /// with a maximum size of <see cref="int.MaxValue"/> and a message offset of 0.
        /// </summary>
        /// <param name="message">Message to encode.</param>
        /// <returns>The segment produced by the message encoder.</returns>
        protected override ArraySegment<byte> EncodeMessage(Message message)
            => MessageEncoder.WriteMessage(message, int.MaxValue, BufferManager, 0);
 
        /// <summary>
        /// Runs <see cref="OnCleanup"/> for the first caller only; later calls do nothing.
        /// </summary>
        protected void Cleanup()
        {
            int previousStatus = Interlocked.CompareExchange(
                ref _cleanupStatus,
                WebSocketHelper.OperationFinished,
                WebSocketHelper.OperationNotStarted);

            if (previousStatus != WebSocketHelper.OperationNotStarted)
            {
                // Another caller already claimed the cleanup.
                return;
            }

            OnCleanup();
        }
 
        /// <summary>
        /// Releases channel resources: disposes the WebSocket when one exists and disposal is enabled.
        /// Meant to be reached only through <see cref="Cleanup"/>.
        /// </summary>
        protected virtual void OnCleanup()
        {
            Fx.Assert(_cleanupStatus == WebSocketHelper.OperationFinished,
                "This method should only be called by this.Cleanup(). Make sure that you never call overriden OnCleanup()-methods directly in subclasses");

            if (_shouldDisposeWebSocketAfterClosed)
            {
                _webSocket?.Dispose();
            }
        }
 
        /// <summary>
        /// If <paramref name="pendingException"/> holds an exception, clears the slot and throws that exception.
        /// </summary>
        /// <param name="pendingException">Slot holding a deferred failure, or null.</param>
        private static void ThrowOnPendingException(ref Exception pendingException)
        {
            Exception deferred = pendingException;
            if (deferred == null)
            {
                return;
            }

            // Clear before throwing so the same failure is not reported twice.
            pendingException = null;
            throw FxTrace.Exception.AsError(deferred);
        }
 
        /// <summary>
        /// Starts the WebSocket close with the configured output close status, or returns a completed
        /// task when the socket is already closed. Non-fatal exceptions thrown while starting the close are converted.
        /// </summary>
        private Task CloseAsync()
        {
            try
            {
                // Nothing to do when the socket is already closed.
                return WebSocket.State == WebSocketState.Closed
                    ? Task.CompletedTask
                    : WebSocket.CloseAsync(
                        _webSocketCloseDetails.OutputCloseStatus,
                        _webSocketCloseDetails.OutputCloseStatusDescription,
                        CancellationToken.None);
            }
            catch (Exception e)
            {
                if (!Fx.IsFatal(e))
                {
                    throw WebSocketHelper.ConvertAndTraceException(e);
                }

                throw;
            }
        }
 
        /// <summary>
        /// Starts closing the output side of the WebSocket with the configured output close status.
        /// Non-fatal exceptions thrown while starting the operation are converted.
        /// </summary>
        /// <param name="cancellationToken">Token passed to the WebSocket call.</param>
        private Task CloseOutputAsync(CancellationToken cancellationToken)
        {
            try
            {
                var status = _webSocketCloseDetails.OutputCloseStatus;
                var description = _webSocketCloseDetails.OutputCloseStatusDescription;
                return WebSocket.CloseOutputAsync(status, description, cancellationToken);
            }
            catch (Exception e)
            {
                if (!Fx.IsFatal(e))
                {
                    throw WebSocketHelper.ConvertAndTraceException(e);
                }

                throw;
            }
        }
 
        /// <summary>
        /// Returns the WebSocket message type to use for <paramref name="message"/>;
        /// the default type is always returned and the message itself is not inspected.
        /// </summary>
        private static WebSocketMessageType GetWebSocketMessageType(Message message)
            => WebSocketDefaults.DefaultWebSocketMessageType;
 
        /// <summary>
        /// Awaits a queued close-output task, parks any failure for <see cref="FinishWritingMessage"/>,
        /// and always invokes <paramref name="callback"/> afterwards.
        /// </summary>
        private async void HandleCloseOutputAsyncCompletion(Task task, TimeSpan timeout, Action<object> callback, object state)
        {
            try
            {
                await task;
            }
            catch (Exception)
            {
                // The converted exception is built from the task itself, not from the caught exception.
                Exception failure = WebSocketHelper.CreateExceptionOnTaskFailure(task, timeout, WebSocketHelper.CloseOperation);
                _pendingWritingMessageException = failure;
            }
            finally
            {
                callback(state);
            }
        }
 
        /// <summary>
        /// Awaits a queued send task, parks any failure for <see cref="FinishWritingMessage"/>,
        /// emits the write-stop trace when enabled, and always invokes <paramref name="callback"/>.
        /// </summary>
        private async void HandleSendAsyncCompletion(Task task, TimeSpan timeout, Action<object> callback, object state)
        {
            try
            {
                await task;
            }
            catch (Exception)
            {
                // The converted exception is built from the task itself, not from the caught exception.
                Exception failure = WebSocketHelper.CreateExceptionOnTaskFailure(task, timeout, WebSocketHelper.SendOperation);
                _pendingWritingMessageException = failure;
            }
            finally
            {
                if (WcfEventSource.Instance.WebSocketAsyncWriteStopIsEnabled())
                {
                    int socketId = WebSocket.GetHashCode();
                    WcfEventSource.Instance.WebSocketAsyncWriteStop(socketId);
                }

                callback(state);
            }
        }
 
        /// <summary>
        /// Completion callback for the streamed write started by StartWritingStreamedMessage.
        /// Finishes the encoder write, writes the end-of-message marker and signals the waiting caller.
        /// </summary>
        /// <param name="ar">Async result whose AsyncState is the owning channel.</param>
        /// <remarks>
        /// A non-fatal failure of the write is recorded in the pending-exception field (the same field the
        /// buffered send path uses, rethrown by FinishWritingMessage) and the completion callback is still
        /// invoked. Previously such a failure was swallowed and the callback never ran, leaving the queued
        /// send without a completion signal.
        /// </remarks>
        private static void StreamWriteCallback(IAsyncResult ar)
        {
            if (ar.CompletedSynchronously)
            {
                // Synchronous completions are finished inline by StartWritingStreamedMessage.
                return;
            }

            WebSocketTransportDuplexSessionChannel thisPtr = (WebSocketTransportDuplexSessionChannel)ar.AsyncState;

            try
            {
                thisPtr.MessageEncoder.EndWriteMessage(ar);
                thisPtr._webSocketStream.WriteEndOfMessage();
            }
            catch (Exception ex)
            {
                if (Fx.IsFatal(ex))
                {
                    throw;
                }

                // Park the failure so it surfaces to the sender instead of being lost.
                thisPtr._pendingWritingMessageException = ex;
            }

            try
            {
                // Signal completion on success and on failure alike, as the buffered send path does.
                thisPtr._waitCallback.Invoke(thisPtr._state);
            }
            catch (Exception ex)
            {
                if (Fx.IsFatal(ex))
                {
                    throw;
                }

                // As before, a non-fatal exception from the continuation does not escape this callback,
                // but it is now traced instead of being dropped silently.
                FxTrace.Exception.TraceUnhandledException(ex);
            }
        }
 
        protected class WebSocketMessageSource : IMessageSource
        {
            // Collaborators copied from the owning channel in Initialize().
            private MessageEncoder _encoder;
            private BufferManager _bufferManager;
            private EndpointAddress _localAddress;
            // Message produced by the most recent receive, waiting to be handed out by GetPendingMessage().
            private Message _pendingMessage;
            // Failure recorded by a receive; rethrown (and cleared) by ThrowOnPendingException.
            private Exception _pendingException;
            private WebSocket _webSocket;
            // Set by CheckCloseStatus once a Close frame has been received.
            private bool _closureReceived;
            // True when messages are received in streamed mode, false for buffered mode.
            private bool _useStreaming;
            // Size of the buffer taken for each receive: min(helper-computed receive size, _maxBufferSize).
            private int _receiveBufferSize;
            private int _maxBufferSize;
            private long _maxReceivedMessageSize;
            // Streamed mode only: completed by FinishUsingMessageStream so that the next receive may start.
            private TaskCompletionSource<object> _streamWaitTask;
            private IDefaultCommunicationTimeouts _defaultTimeouts;
            // Shared with the owning channel; receives the close status/description reported by the peer.
            private WebSocketCloseDetails _closeDetails;
            // Receive timeout taken from _defaultTimeouts; bounds the streamed-mode receive.
            private TimeSpan _asyncReceiveTimeout;
            // Completed when the receive started by StartNextReceiveAsync finishes (or is cancelled).
            private TaskCompletionSource<object> _receiveTask;
            // One of the AsyncReceiveState values; changed only through Interlocked.CompareExchange.
            private int _asyncReceiveState;
 
            /// <summary>
            /// Creates the message source for <paramref name="webSocket"/> and immediately starts the first receive.
            /// </summary>
            /// <param name="webSocketTransportDuplexSessionChannel">Owning channel; supplies the encoder, buffer manager and quotas.</param>
            /// <param name="webSocket">Socket to receive from.</param>
            /// <param name="useStreaming">True to receive in streamed mode, false for buffered mode.</param>
            /// <param name="defaultTimeouts">Source of the receive timeout.</param>
            public WebSocketMessageSource(WebSocketTransportDuplexSessionChannel webSocketTransportDuplexSessionChannel, WebSocket webSocket,
                    bool useStreaming, IDefaultCommunicationTimeouts defaultTimeouts)
            {
                Initialize(webSocketTransportDuplexSessionChannel, webSocket, useStreaming, defaultTimeouts);

                // Fire-and-forget: the receive's outcome is observed later through _receiveTask.
                StartNextReceiveAsync();
            }
 
            /// <summary>
            /// Copies everything the message source needs from the owning channel and the constructor arguments,
            /// and marks the receive state as <c>Finished</c> so that the first receive can start.
            /// </summary>
            private void Initialize(WebSocketTransportDuplexSessionChannel webSocketTransportDuplexSessionChannel, WebSocket webSocket, bool useStreaming, IDefaultCommunicationTimeouts defaultTimeouts)
            {
                WebSocketTransportDuplexSessionChannel channel = webSocketTransportDuplexSessionChannel;

                // Constructor arguments.
                _webSocket = webSocket;
                _useStreaming = useStreaming;
                _defaultTimeouts = defaultTimeouts;
                _asyncReceiveTimeout = defaultTimeouts.ReceiveTimeout;

                // Collaborators owned by the channel.
                _encoder = channel.MessageEncoder;
                _bufferManager = channel.BufferManager;
                _localAddress = channel.LocalAddress;
                _closeDetails = channel._webSocketCloseDetails;

                // Quotas: the receive buffer never exceeds the maximum buffer size.
                _maxBufferSize = channel.MaxBufferSize;
                _maxReceivedMessageSize = channel.TransportFactorySettings.MaxReceivedMessageSize;
                int preferredReceiveSize = WebSocketHelper.GetReceiveBufferSize(_maxReceivedMessageSize);
                _receiveBufferSize = Math.Min(preferredReceiveSize, _maxBufferSize);

                _asyncReceiveState = AsyncReceiveState.Finished;
            }
 
            /// <summary>
            /// Static trampoline: casts <paramref name="target"/> to the message source and forwards the cancellation to it.
            /// </summary>
            private static void OnAsyncReceiveCancelled(object target)
                => ((WebSocketMessageSource)target).AsyncReceiveCancelled();
 
            /// <summary>
            /// Moves the receive state from <c>Started</c> to <c>Cancelled</c> and, only if that transition
            /// succeeded, completes the receive task.
            /// </summary>
            private void AsyncReceiveCancelled()
            {
                int previousState = Interlocked.CompareExchange(
                    ref _asyncReceiveState,
                    AsyncReceiveState.Cancelled,
                    AsyncReceiveState.Started);

                if (previousState != AsyncReceiveState.Started)
                {
                    // No receive was in the Started state, so there is nothing to cancel.
                    return;
                }

                _receiveTask.SetResult(null);
            }
 
            /// <summary>
            /// Waits up to <paramref name="timeout"/> for the receive in progress and returns its message.
            /// </summary>
            /// <param name="timeout">Maximum time to wait for the receive to complete.</param>
            /// <returns>The received message, or null when none is pending.</returns>
            /// <exception cref="TimeoutException">The receive did not complete within <paramref name="timeout"/>.</exception>
            public async Task<Message> ReceiveAsync(TimeSpan timeout)
            {
                bool completedInTime = await _receiveTask.Task.AwaitWithTimeout(timeout);

                // A recorded receive failure takes precedence over the timeout report.
                ThrowOnPendingException(ref _pendingException);

                if (!completedInTime)
                {
                    Exception inner = TimeoutHelper.CreateEnterTimedOutException(timeout);
                    string text = SR.Format(SR.WaitForMessageTimedOut, timeout);
                    throw FxTrace.Exception.AsError(new TimeoutException(text, inner));
                }

                Message received = GetPendingMessage();
                if (received == null)
                {
                    return received;
                }

                // A message was consumed, so start receiving the next one.
                StartNextReceiveAsync();
                return received;
            }
 
            /// <summary>
            /// Synchronous receive: blocks on the asynchronous receive path until it completes.
            /// </summary>
            /// <param name="timeout">Maximum time to wait for a message.</param>
            public Message Receive(TimeSpan timeout)
                => ReceiveAsyncInternal(timeout).WaitForCompletionNoSpin();
 
            /// <summary>
            /// Helper for the synchronous <see cref="Receive"/>: awaits
            /// <c>TaskHelpers.EnsureDefaultTaskScheduler()</c> and then performs the regular asynchronous receive.
            /// </summary>
            private async Task<Message> ReceiveAsyncInternal(TimeSpan timeout)
            {
                // NOTE(review): presumably this hops off any custom scheduler so the blocking caller cannot deadlock — confirm in TaskHelpers.
                await TaskHelpers.EnsureDefaultTaskScheduler();

                Message received = await ReceiveAsync(timeout);
                return received;
            }
 
            /// <summary>
            /// Buffered-mode receive: reads frames until the end of the message (or a Close frame), growing the
            /// working buffer as needed, then stores the decoded message in <c>_pendingMessage</c>.
            /// Non-fatal failures are not thrown; they are recorded in <c>_pendingException</c>.
            /// </summary>
            private async Task ReadBufferedMessageAsync()
            {
                byte[] internalBuffer = null;
                try
                {
                    internalBuffer = _bufferManager.TakeBuffer(_receiveBufferSize);

                    int receivedByteCount = 0;
                    bool endOfMessage = false;
                    WebSocketReceiveResult result = null;
                    do
                    {
                        try
                        {
                            if (WcfEventSource.Instance.WebSocketAsyncReadStartIsEnabled())
                            {
                                WcfEventSource.Instance.WebSocketAsyncReadStart(_webSocket.GetHashCode());
                            }

                            // Receive into the unused tail of the working buffer.
                            result = await _webSocket.ReceiveAsync(
                                                new ArraySegment<byte>(internalBuffer, receivedByteCount, internalBuffer.Length - receivedByteCount),
                                                CancellationToken.None);

                            CheckCloseStatus(result);
                            endOfMessage = result.EndOfMessage;

                            receivedByteCount += result.Count;
                            // Buffer is full but the message continues: grow it, or give up once the quota is reached.
                            if (receivedByteCount >= internalBuffer.Length && !result.EndOfMessage)
                            {
                                if (internalBuffer.Length >= _maxBufferSize)
                                {
                                    // The outer finally still returns internalBuffer to the buffer manager.
                                    _pendingException = FxTrace.Exception.AsError(new QuotaExceededException(SR.Format(SR.MaxReceivedMessageSizeExceeded, _maxBufferSize)));
                                    return;
                                }

                                // Double the size, capped at _maxBufferSize; the double arithmetic avoids int overflow.
                                int newSize = (int)Math.Min(((double)internalBuffer.Length) * 2, _maxBufferSize);
                                Fx.Assert(newSize > 0, "buffer size should be larger than zero.");
                                byte[] newBuffer = _bufferManager.TakeBuffer(newSize);
                                Buffer.BlockCopy(internalBuffer, 0, newBuffer, 0, receivedByteCount);
                                _bufferManager.ReturnBuffer(internalBuffer);
                                internalBuffer = newBuffer;
                            }

                            if (WcfEventSource.Instance.WebSocketAsyncReadStopIsEnabled())
                            {
                                WcfEventSource.Instance.WebSocketAsyncReadStop(
                                    _webSocket.GetHashCode(),
                                    receivedByteCount,
                                    string.Empty);
                            }
                        }
                        catch (AggregateException ex)
                        {
                            WebSocketHelper.ThrowCorrectException(ex, TimeSpan.MaxValue, WebSocketHelper.ReceiveOperation);
                        }
                    }
                    while (!endOfMessage && !_closureReceived);

                    // Copy the payload into a separate buffer for the message; the working buffer is
                    // returned in the outer finally.
                    byte[] buffer = null;
                    bool success = false;
                    try
                    {
                        buffer = _bufferManager.TakeBuffer(receivedByteCount);
                        Buffer.BlockCopy(internalBuffer, 0, buffer, 0, receivedByteCount);
                        Fx.Assert(result != null, "Result should not be null");
                        _pendingMessage = PrepareMessage(result, buffer, receivedByteCount);
                        success = true;
                    }
                    finally
                    {
                        // Return the copy when message creation failed or no message is pending.
                        if (buffer != null && (!success || _pendingMessage == null))
                        {
                            _bufferManager.ReturnBuffer(buffer);
                        }
                    }
                }
                catch (Exception ex)
                {
                    if (Fx.IsFatal(ex))
                    {
                        throw;
                    }

                    // Deferred: the failure is surfaced later through ThrowOnPendingException.
                    _pendingException = WebSocketHelper.ConvertAndTraceException(ex, TimeSpan.MaxValue, WebSocketHelper.ReceiveOperation);
                }
                finally
                {
                    if (internalBuffer != null)
                    {
                        _bufferManager.ReturnBuffer(internalBuffer);
                    }
                }
            }
 
            /// <summary>
            /// Blocks until a receive completes or <paramref name="timeout"/> elapses. The received message
            /// is kept as the pending message rather than returned.
            /// </summary>
            /// <param name="timeout">Maximum time to wait.</param>
            /// <returns>True when the receive completed; false when it timed out.</returns>
            public bool WaitForMessage(TimeSpan timeout)
            {
                try
                {
                    // Store the message back so a later receive can hand it out.
                    _pendingMessage = Receive(timeout);
                    return true;
                }
                catch (TimeoutException timedOut)
                {
                    if (WcfEventSource.Instance.ReceiveTimeoutIsEnabled())
                    {
                        string detail = timedOut.Message;
                        WcfEventSource.Instance.ReceiveTimeout(detail);
                    }

                    return false;
                }
            }
 
            /// <summary>
            /// Asynchronously waits until a receive completes or <paramref name="timeout"/> elapses. The received
            /// message is kept as the pending message rather than returned.
            /// </summary>
            /// <param name="timeout">Maximum time to wait.</param>
            /// <returns>True when the receive completed; false when it timed out.</returns>
            public async Task<bool> WaitForMessageAsync(TimeSpan timeout)
            {
                bool completedInTime = await _receiveTask.Task.AwaitWithTimeout(timeout);

                if (!completedInTime)
                {
                    if (WcfEventSource.Instance.ReceiveTimeoutIsEnabled())
                    {
                        string detail = SR.Format(SR.WaitForMessageTimedOut, timeout);
                        WcfEventSource.Instance.ReceiveTimeout(detail);
                    }

                    return false;
                }

                // Store the message back so a later receive can hand it out.
                Message received = await ReceiveAsync(timeout);
                _pendingMessage = received;
                return true;
            }
 
            /// <summary>
            /// Signals that the consumer of the current streamed message is done with its stream, optionally
            /// recording a failure, so that the next streamed receive may proceed.
            /// </summary>
            /// <param name="ex">Failure raised while the stream was in use, or null.</param>
            internal void FinishUsingMessageStream(Exception ex)
            {
                //// The task is used as follows:
                //// 1) Only one thread can take and consume the stream. A new task is created at the moment it takes the stream.
                //// 2) Only one other thread can enter the lock and wait on the task.
                //// 3) Cleaning up the stream returns it to the message source; the cleanup is limited to a single call.
                // An exception that is already pending is not overwritten.
                if (ex != null && _pendingException == null)
                {
                    _pendingException = ex;
                }

                _streamWaitTask.SetResult(null);
            }
 
            /// <summary>
            /// If <paramref name="result"/> is a Close frame, marks the closure as received and stores the
            /// peer's close status and description; other frames are ignored.
            /// </summary>
            /// <param name="result">Result of the latest WebSocket receive.</param>
            internal void CheckCloseStatus(WebSocketReceiveResult result)
            {
                if (result.MessageType != WebSocketMessageType.Close)
                {
                    return;
                }

                if (WcfEventSource.Instance.WebSocketCloseStatusReceivedIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketCloseStatusReceived(
                        _webSocket.GetHashCode(),
                        result.CloseStatus.ToString());
                }

                _closureReceived = true;

                _closeDetails.InputCloseStatus = result.CloseStatus;
                _closeDetails.InputCloseStatusDescription = result.CloseStatusDescription;
            }
 
            /// <summary>
            /// Starts the next receive. Creates a fresh <c>_receiveTask</c>, moves the receive state from
            /// <c>Finished</c> to <c>Started</c>, performs a buffered or streamed receive (unless an exception
            /// is already pending), and finally completes <c>_receiveTask</c>.
            /// Non-fatal receive failures are recorded in <c>_pendingException</c> instead of being thrown.
            /// </summary>
            /// <exception cref="InvalidOperationException">The previous receive was not in the <c>Finished</c> state.</exception>
            private async void StartNextReceiveAsync()
            {
                Fx.Assert(_receiveTask == null || _receiveTask.Task.IsCompleted, "this.receiveTask is not completed.");
                _receiveTask = new TaskCompletionSource<object>();
                int currentState = Interlocked.CompareExchange(ref _asyncReceiveState, AsyncReceiveState.Started, AsyncReceiveState.Finished);
                Fx.Assert(currentState == AsyncReceiveState.Finished, "currentState is not AsyncReceiveState.Finished: " + currentState);
                if (currentState != AsyncReceiveState.Finished)
                {
                    throw FxTrace.Exception.AsError(new InvalidOperationException());
                }

                try
                {
                    if (_useStreaming)
                    {
                        if (_streamWaitTask != null)
                        {
                            //// Wait until the previous stream message finished.
                            await _streamWaitTask.Task;
                        }

                        // Completed by FinishUsingMessageStream when the consumer is done with this message's stream.
                        _streamWaitTask = new TaskCompletionSource<object>();
                    }

                    // Skip the receive when a failure is already waiting to be reported.
                    if (_pendingException == null)
                    {
                        if (!_useStreaming)
                        {
                            await ReadBufferedMessageAsync();
                        }
                        else
                        {
                            // Streamed mode: only the first chunk is read here and handed to PrepareMessage with the buffer.
                            byte[] buffer = _bufferManager.TakeBuffer(_receiveBufferSize);
                            bool success = false;
                            try
                            {
                                if (WcfEventSource.Instance.WebSocketAsyncReadStartIsEnabled())
                                {
                                    WcfEventSource.Instance.WebSocketAsyncReadStart(_webSocket.GetHashCode());
                                }

                                try
                                {
                                    WebSocketReceiveResult result;
                                    // The receive timeout is enforced through the helper's cancellation token.
                                    TimeoutHelper helper = new TimeoutHelper(_asyncReceiveTimeout);
                                    var cancelToken = await helper.GetCancellationTokenAsync();
                                    result = await _webSocket.ReceiveAsync(new ArraySegment<byte>(buffer, 0, _receiveBufferSize), cancelToken);
                                    CheckCloseStatus(result);
                                    _pendingMessage = PrepareMessage(result, buffer, result.Count);

                                    if (WcfEventSource.Instance.WebSocketAsyncReadStopIsEnabled())
                                    {
                                        WcfEventSource.Instance.WebSocketAsyncReadStop(
                                            _webSocket.GetHashCode(),
                                            result.Count,
                                            String.Empty);
                                    }
                                }
                                catch (AggregateException ex)
                                {
                                    WebSocketHelper.ThrowCorrectException(ex, _asyncReceiveTimeout, WebSocketHelper.ReceiveOperation);
                                }
                                // NOTE(review): on success the buffer is not returned here, presumably because the prepared
                                // message/stream now owns it; confirm PrepareMessage also covers the Close-frame case.
                                success = true;
                            }
                            catch (Exception ex)
                            {
                                if (Fx.IsFatal(ex))
                                {
                                    throw;
                                }

                                _pendingException = WebSocketHelper.ConvertAndTraceException(ex, _asyncReceiveTimeout, WebSocketHelper.ReceiveOperation);
                            }
                            finally
                            {
                                if (!success)
                                {
                                    _bufferManager.ReturnBuffer(buffer);
                                }
                            }
                        }
                    }
                }
                finally
                {
                    // Complete the task only if the state is still Started; AsyncReceiveCancelled completes it
                    // itself when it wins the transition to Cancelled.
                    if (Interlocked.CompareExchange(ref _asyncReceiveState, AsyncReceiveState.Finished, AsyncReceiveState.Started) == AsyncReceiveState.Started)
                    {
                        _receiveTask.SetResult(null);
                    }
                }
            }
 
            // Returns the message buffered by the receive path, handing it out at most once.
            // _pendingException (set when a receive attempt fails) is passed to ThrowOnPendingException
            // first; presumably that helper rethrows and clears it - confirm against its definition.
            private Message GetPendingMessage()
            {
                ThrowOnPendingException(ref _pendingException);

                Message message = _pendingMessage;
                if (message == null)
                {
                    return null;
                }

                // Empty the slot so the same message is never returned twice.
                _pendingMessage = null;
                return message;
            }
 
            // Turns the first received frame of a WebSocket message into a Message.
            //   result - receive result describing the frame (type, end-of-message flag).
            //   buffer - pooled buffer holding the received bytes, starting at index 0.
            //   count  - number of valid bytes in buffer.
            // Returns null for a Close frame; otherwise the decoded message, with the local address
            // applied and a default Action filled in when the message carries no addressing.
            private Message PrepareMessage(WebSocketReceiveResult result, byte[] buffer, int count)
            {
                if (result.MessageType == WebSocketMessageType.Close)
                {
                    return null;
                }

                var payload = new ArraySegment<byte>(buffer, 0, count);
                Message message;
                if (!_useStreaming)
                {
                    message = _encoder.ReadMessage(payload, _bufferManager);
                }
                else
                {
                    // Layering, innermost first: socket reader -> receive timeout -> max-size guard.
                    var socketStream = new WebSocketStream(
                        this,
                        payload,
                        _webSocket,
                        result.EndOfMessage,
                        _bufferManager,
                        _defaultTimeouts.CloseTimeout);
                    var timedStream = new TimeoutStream(socketStream, _defaultTimeouts.ReceiveTimeout);
                    var wrappedStream = new MaxMessageSizeStream(timedStream, _maxReceivedMessageSize);
                    message = _encoder.ReadMessage(wrappedStream, _maxBufferSize);
                }

                bool hasAddressing = message.Version.Addressing != AddressingVersion.None;
                if (hasAddressing || !_localAddress.IsAnonymous)
                {
                    _localAddress.ApplyTo(message);
                }

                if (message.Version.Addressing == AddressingVersion.None && message.Headers.Action == null)
                {
                    if (result.MessageType != WebSocketMessageType.Binary)
                    {
                        // WebSocketMessageType should always be binary or text at this point. The layer below us will help protect this.
                        Fx.Assert(result.MessageType == WebSocketMessageType.Text, "result.MessageType must be WebSocketMessageType.Text.");
                        message.Headers.Action = WebSocketTransportSettings.TextMessageReceivedAction;
                    }
                    else
                    {
                        message.Headers.Action = WebSocketTransportSettings.BinaryMessageReceivedAction;
                    }
                }

                return message;
            }
 
            // Integer state values for _asyncReceiveState. They are plain int constants so the field can
            // be updated with Interlocked.CompareExchange.
            private static class AsyncReceiveState
            {
                // A receive is in flight. The receive path's finally block moves Started -> Finished and
                // completes _receiveTask only when that transition succeeds.
                internal const int Started = 0;
                // The receive ran to its end (successfully or with a recorded pending exception).
                internal const int Finished = 1;
                // Not referenced in this part of the file; presumably set when a receive is abandoned so
                // that the Started -> Finished transition above fails - TODO confirm against the caller.
                internal const int Cancelled = 2;
            }
        }
 
        private class WebSocketStream : Stream
        {
            // Socket used for every ReceiveAsync/SendAsync issued by this stream.
            private readonly WebSocket _webSocket;
            // Owner of a read stream; assigned only by the read constructor and notified through
            // FinishUsingMessageStream when read-side cleanup completes.
            private readonly WebSocketMessageSource _messageSource;
            // Time budget for WriteEndOfMessage/WriteEndOfMessageAsync and for draining in Cleanup.
            private readonly TimeSpan _closeTimeout;
            // Bytes that were already received before the stream was created. ReadAsync serves these
            // first; the underlying array is handed back to _bufferManager in Cleanup.
            private ArraySegment<byte> _initialReadBuffer;
            // True once the reader has consumed the whole message; ReadAsync returns 0 from then on.
            private bool _endOfMessageReached;
            // True for streams built by the read constructor, false for write streams.
            private readonly bool _isForRead;
            // True once the socket has delivered a frame flagged EndOfMessage (which may happen before
            // the reader has consumed _initialReadBuffer, hence the separate _endOfMessageReached flag).
            private bool _endofMessageReceived;
            // Message type passed to SendAsync for write streams.
            private readonly WebSocketMessageType _outgoingMessageType;
            // Pool that owns _initialReadBuffer.Array; only assigned for read streams.
            private readonly BufferManager _bufferManager;
            // OperationNotStarted/OperationFinished guard so read-side cleanup runs only once.
            private int _messageSourceCleanState;
            // OperationNotStarted/OperationFinished guard around sending the final (end-of-message) frame.
            private int _endOfMessageWritten;
            // Values backing the ReadTimeout / WriteTimeout properties.
            private int _readTimeout;
            private int _writeTimeout;
            // Timeout helpers rebuilt each time the corresponding property is set.
            private TimeoutHelper _readTimeoutHelper;
            private TimeoutHelper _writeTimeoutHelper;
 
            // Creates a read stream over an incoming message whose first frame is already in memory.
            //   messageSource        - owner to notify once the stream has been fully consumed/cleaned up.
            //   initialBuffer        - bytes received so far; served to the reader before the socket is touched.
            //   webSocket            - socket to pull the remaining frames from.
            //   endofMessageReceived - whether the frame in initialBuffer was flagged EndOfMessage.
            //   bufferManager        - pool that initialBuffer.Array is returned to during cleanup.
            //   closeTimeout         - time budget for draining the message during cleanup.
            public WebSocketStream(
                        WebSocketMessageSource messageSource,
                        ArraySegment<byte> initialBuffer,
                        WebSocket webSocket,
                        bool endofMessageReceived,
                        BufferManager bufferManager,
                        TimeSpan closeTimeout)
                : this(webSocket, WebSocketDefaults.DefaultWebSocketMessageType, closeTimeout)
            {
                Fx.Assert(messageSource != null, "messageSource should not be null.");

                // The chained constructor sets the stream up for writing; switch it to read mode
                // and re-arm the one-shot guards.
                _isForRead = true;
                _messageSourceCleanState = WebSocketHelper.OperationNotStarted;
                _endOfMessageWritten = WebSocketHelper.OperationNotStarted;

                _messageSource = messageSource;
                _bufferManager = bufferManager;
                _initialReadBuffer = initialBuffer;
                _endofMessageReceived = endofMessageReceived;
            }
 
            // Creates a write stream that sends frames of the given type on webSocket.
            //   webSocket           - socket to send on; must not be null.
            //   outgoingMessageType - message type passed to every SendAsync call.
            //   closeTimeout        - time budget for sending the final (end-of-message) frame.
            public WebSocketStream(
                    WebSocket webSocket,
                    WebSocketMessageType outgoingMessageType,
                    TimeSpan closeTimeout)
            {
                Fx.Assert(webSocket != null, "webSocket should not be null.");

                _isForRead = false;
                // A write stream has no message source to clean up, so mark that work as already done.
                _messageSourceCleanState = WebSocketHelper.OperationFinished;

                _webSocket = webSocket;
                _closeTimeout = closeTimeout;
                _outgoingMessageType = outgoingMessageType;
            }
 
            // Readable only when the stream was created through the read constructor.
            public override bool CanRead => _isForRead;
 
            // Seeking is never supported; Seek, Position and Length all throw.
            public override bool CanSeek => false;
 
            // ReadTimeout and WriteTimeout are both implemented by this stream.
            public override bool CanTimeout => true;
 
            // Writable only when the stream was created through the write constructor.
            public override bool CanWrite => !_isForRead;
 
            // Not supported: the getter always throws NotSupportedException (CanSeek is false).
            public override long Length
            {
                get { throw FxTrace.Exception.AsError(new NotSupportedException(InternalSR.SeekNotSupported)); }
            }
 
            // Not supported: both accessors always throw NotSupportedException (CanSeek is false).
            public override long Position
            {
                get
                {
                    throw FxTrace.Exception.AsError(new NotSupportedException(InternalSR.SeekNotSupported));
                }

                set
                {
                    throw FxTrace.Exception.AsError(new NotSupportedException(InternalSR.SeekNotSupported));
                }
            }
 
            // Read timeout in milliseconds. Values below -1 trip the contract assertion.
            // Assigning a value also restarts the read timeout helper that ReadAsync consults.
            public override int ReadTimeout
            {
                get { return _readTimeout; }

                set
                {
                    Contract.Assert(value >= -1, "ReadTimeout should not be negative.");
                    _readTimeout = value;
                    _readTimeoutHelper = new TimeoutHelper(TimeoutHelper.FromMilliseconds(value));
                }
            }
 
            // Write timeout in milliseconds. Values below -1 trip the contract assertion.
            // Assigning a value also rebuilds the write timeout helper from that same value.
            public override int WriteTimeout
            {
                get
                {
                    return _writeTimeout;
                }

                set
                {
                    Contract.Assert(value >= -1, "WriteTimeout should not be negative.");
                    _writeTimeout = value;
                    // Build the helper from the WRITE timeout; it was previously built from _readTimeout,
                    // so the write helper silently tracked whatever the read timeout happened to be.
                    _writeTimeoutHelper = new TimeoutHelper(TimeoutHelper.FromMilliseconds(_writeTimeout));
                }
            }
 
            // Disposes the base stream, then runs Cleanup. Cleanup is called regardless of the
            // 'disposing' flag; its own one-shot guards keep repeated calls harmless.
            protected override void Dispose(bool disposing)
            {
                base.Dispose(disposing);
                Cleanup();
            }
 
            // Intentionally a no-op: this class keeps no write buffer of its own (WriteAsync hands each
            // call straight to the socket), so there is nothing to flush.
            public override void Flush()
            {
            }
 
            // Reads message bytes into buffer[offset..offset+count).
            // Order of precedence:
            //   1. read timeout already expired  -> throws the timeout exception synchronously;
            //   2. message fully consumed        -> completes with 0;
            //   3. bytes left in the initial buffer -> served from memory, no socket access;
            //   4. final frame already received  -> marks the stream finished, cleans up, completes with 0;
            //   5. otherwise                     -> receives the next frame from the socket.
            public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                Contract.Assert(_messageSource != null, "messageSource should not be null in read case.");

                var timeoutToken = _readTimeoutHelper.GetCancellationToken();
                if (timeoutToken.IsCancellationRequested)
                {
                    var timeoutException = WebSocketHelper.GetTimeoutException(
                        null,
                        _readTimeoutHelper.OriginalTimeout,
                        WebSocketHelper.ReceiveOperation);
                    throw FxTrace.Exception.AsError(timeoutException);
                }

                if (!_endOfMessageReached)
                {
                    if (_initialReadBuffer.Count != 0)
                    {
                        return Task.FromResult(GetBytesFromInitialReadBuffer(buffer, offset, count));
                    }

                    if (!_endofMessageReceived)
                    {
                        return ReadAsyncCore(buffer, offset, count, cancellationToken);
                    }

                    // The initial buffer is drained and it held the final frame: the message is done.
                    _endOfMessageReached = true;
                    Cleanup();
                }

                return Task.FromResult(0);
            }
 
            // Receives the next frame of the current message from the socket into
            // buffer[offset..offset+count) and returns the number of bytes received.
            // Throws TimeoutException when the receive fails after cancellationToken was signalled;
            // other non-fatal failures are converted by WebSocketHelper.ConvertAndTraceException.
            // A Close frame in the middle of a message is rejected by CheckResultAndEnsureNotCloseMessage.
            private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                if (WcfEventSource.Instance.WebSocketAsyncReadStartIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketAsyncReadStart(_webSocket.GetHashCode());
                }

                WebSocketReceiveResult result;
                try
                {
                    result = await _webSocket.ReceiveAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken);
                }
                catch (Exception ex)
                {
                    if (Fx.IsFatal(ex))
                    {
                        throw;
                    }

                    // ReadTimeout is expressed in milliseconds. new TimeSpan(int) would interpret it as
                    // 100ns ticks and report a timeout four orders of magnitude too small, so convert
                    // the same way the ReadTimeout setter does.
                    TimeSpan readTimeout = TimeoutHelper.FromMilliseconds(ReadTimeout);

                    if (cancellationToken.IsCancellationRequested)
                    {
                        throw Fx.Exception.AsError(new TimeoutException(InternalSR.TaskTimedOutError(readTimeout)));
                    }

                    throw WebSocketHelper.ConvertAndTraceException(ex, readTimeout, WebSocketHelper.ReceiveOperation);
                }

                if (result.EndOfMessage)
                {
                    _endofMessageReceived = true;
                    _endOfMessageReached = true;
                }

                int receivedBytes = result.Count;
                CheckResultAndEnsureNotCloseMessage(_messageSource, result);

                if (WcfEventSource.Instance.WebSocketAsyncReadStopIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketAsyncReadStop(_webSocket.GetHashCode(), receivedBytes, string.Empty);
                }

                if (_endOfMessageReached)
                {
                    // Last frame consumed: release the pooled buffer and notify the message source.
                    Cleanup();
                }

                return receivedBytes;
            }
 
            // Synchronous reads are not supported on this stream; always throws NotSupportedException.
            public override int Read(byte[] buffer, int offset, int count)
            {
                // WebSocketStream is never used directly but is wrapped in a TimeoutStream which calls the Async
                // implementation in the synchronous Read method.
                throw FxTrace.Exception.AsError(new NotSupportedException("this method should never get called"));
            }
 
            // Seeking is not supported (CanSeek is false); always throws NotSupportedException.
            public override long Seek(long offset, SeekOrigin origin)
            {
                throw FxTrace.Exception.AsError(new NotSupportedException());
            }
 
            // The stream has no settable length; always throws NotSupportedException.
            public override void SetLength(long value)
            {
                throw FxTrace.Exception.AsError(new NotSupportedException());
            }
 
            // Synchronous writes are not supported on this stream; always throws NotSupportedException.
            public override void Write(byte[] buffer, int offset, int count)
            {
                // WebSocketStream is never used directly but is wrapped in a TimeoutStream which calls the Async
                // implementation in the synchronous Write method.
                throw FxTrace.Exception.AsError(new NotSupportedException("this method should never get called"));
            }
 
            // Sends buffer[offset..offset+count) as a non-final frame of the outgoing message.
            // Throws InvalidOperationException once the end-of-message frame has been written,
            // TimeoutException when the send fails after cancellationToken was signalled, and the
            // exception produced by WebSocketHelper.ConvertAndTraceException for other non-fatal failures.
            public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                if (_endOfMessageWritten == WebSocketHelper.OperationFinished)
                {
                    throw FxTrace.Exception.AsError(new InvalidOperationException(SR.WebSocketStreamWriteCalledAfterEOMSent));
                }

                if (WcfEventSource.Instance.WebSocketAsyncWriteStartIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketAsyncWriteStart(
                            _webSocket.GetHashCode(),
                            count,
                            string.Empty);
                }

                try
                {
                    // endOfMessage is false here; the final frame is sent separately by WriteEndOfMessage*.
                    await _webSocket.SendAsync(new ArraySegment<byte>(buffer, offset, count), _outgoingMessageType, false, cancellationToken);
                }
                catch (Exception ex)
                {
                    if (Fx.IsFatal(ex))
                    {
                        throw;
                    }

                    // WriteTimeout is expressed in milliseconds. new TimeSpan(int) would interpret it as
                    // 100ns ticks and report the wrong duration, so convert the same way the timeout
                    // property setters do.
                    TimeSpan writeTimeout = TimeoutHelper.FromMilliseconds(WriteTimeout);

                    if (cancellationToken.IsCancellationRequested)
                    {
                        throw Fx.Exception.AsError(new TimeoutException(InternalSR.TaskTimedOutError(writeTimeout)));
                    }

                    throw WebSocketHelper.ConvertAndTraceException(ex, writeTimeout, WebSocketHelper.SendOperation);
                }

                if (WcfEventSource.Instance.WebSocketAsyncWriteStopIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketAsyncWriteStop(_webSocket.GetHashCode());
                }
            }
 
            // Awaits TaskHelpers.EnsureDefaultTaskScheduler() and then forwards to WriteAsync with the
            // same arguments. Presumably the first await moves the continuation onto the default task
            // scheduler before the send is issued - confirm against TaskHelpers.
            private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                await TaskHelpers.EnsureDefaultTaskScheduler();
                await WriteAsync(buffer, offset, count, cancellationToken);
            }
 
            // Synchronously sends the empty final frame that terminates the outgoing message.
            // The frame is sent at most once: the _endOfMessageWritten guard turns later calls into
            // no-ops (the write start/stop events are still emitted). The send is bounded by _closeTimeout.
            public void WriteEndOfMessage()
            {
                if (WcfEventSource.Instance.WebSocketAsyncWriteStartIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketAsyncWriteStart(_webSocket.GetHashCode(), 0, string.Empty);
                }

                var closeTimeoutHelper = new TimeoutHelper(_closeTimeout);

                int previousState = Interlocked.CompareExchange(
                    ref _endOfMessageWritten,
                    WebSocketHelper.OperationFinished,
                    WebSocketHelper.OperationNotStarted);

                if (previousState == WebSocketHelper.OperationNotStarted)
                {
                    // Zero-length payload with endOfMessage: true closes out the message.
                    var emptyFinalFrame = new ArraySegment<byte>(Array.Empty<byte>(), 0, 0);
                    Task sendTask = _webSocket.SendAsync(emptyFinalFrame, _outgoingMessageType, true, closeTimeoutHelper.GetCancellationToken());
                    sendTask.Wait(closeTimeoutHelper.RemainingTime(), WebSocketHelper.ThrowCorrectException, WebSocketHelper.SendOperation);
                }

                if (WcfEventSource.Instance.WebSocketAsyncWriteStopIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketAsyncWriteStop(_webSocket.GetHashCode());
                }
            }
 
            // Asynchronously sends the empty final frame that terminates the outgoing message, bounded by
            // _closeTimeout, and invokes callback(state) when the attempt ends - whether it succeeded or not.
            //
            // NOTE(review): this is an 'async void' method, so the exceptions rethrown from the catch block
            // below cannot be observed by the caller through a Task; the callback in 'finally' is the only
            // completion signal the caller receives. Confirm how failures are expected to surface.
            // NOTE(review): unlike WriteEndOfMessage, this path neither checks nor sets _endOfMessageWritten.
            public async void WriteEndOfMessageAsync(Action<object> callback, object state)
            {
                if (WcfEventSource.Instance.WebSocketAsyncWriteStartIsEnabled())
                {
                    // TODO: Open bug about not emitting the hostname/port
                    WcfEventSource.Instance.WebSocketAsyncWriteStart(
                        _webSocket.GetHashCode(),
                        0,
                        string.Empty);
                }

                var timeoutHelper = new TimeoutHelper(_closeTimeout);
                // Keep the task (not just the token) so the catch block can inspect the token afterwards.
                var cancelTokenTask = timeoutHelper.GetCancellationTokenAsync();
                try
                {
                    var cancelToken = await cancelTokenTask;
                    // Zero-length payload with endOfMessage: true closes out the message.
                    await _webSocket.SendAsync(new ArraySegment<byte>(Array.Empty<byte>(), 0, 0), _outgoingMessageType, true, cancelToken);

                    if (WcfEventSource.Instance.WebSocketAsyncWriteStopIsEnabled())
                    {
                        WcfEventSource.Instance.WebSocketAsyncWriteStop(_webSocket.GetHashCode());
                    }
                }
                catch (Exception ex)
                {
                    if (Fx.IsFatal(ex))
                    {
                        throw;
                    }

                    // A failure after the timeout token fired is reported as a timeout.
                    // NOTE(review): if awaiting cancelTokenTask itself was what failed, .Result would throw
                    // here instead of returning a token - confirm that task cannot fault.
                    if (cancelTokenTask.Result.IsCancellationRequested)
                    {
                        throw Fx.Exception.AsError(
                            new TimeoutException(InternalSR.TaskTimedOutError(timeoutHelper.OriginalTimeout)));
                    }

                    throw WebSocketHelper.ConvertAndTraceException(ex, timeoutHelper.OriginalTimeout,
                        WebSocketHelper.SendOperation);
                }
                finally
                {
                    // Always signal completion, on success and on failure alike.
                    callback.Invoke(state);
                }
            }
 
            // Validates a frame received while a message is being streamed: first lets the message source
            // inspect the close status, then throws ProtocolException if the frame is a Close frame,
            // since a close is not expected in the middle of a message.
            private static void CheckResultAndEnsureNotCloseMessage(WebSocketMessageSource messageSource, WebSocketReceiveResult result)
            {
                messageSource.CheckCloseStatus(result);
                if (result.MessageType == WebSocketMessageType.Close)
                {
                    throw FxTrace.Exception.AsError(new ProtocolException(SR.WebSocketUnexpectedCloseMessageError));
                }
            }
 
            // Copies up to 'count' bytes from the front of the initial read buffer into
            // buffer[offset..], advances the initial buffer past the copied bytes, and returns
            // how many bytes were copied.
            private int GetBytesFromInitialReadBuffer(byte[] buffer, int offset, int count)
            {
                ArraySegment<byte> pending = _initialReadBuffer;
                int bytesToCopy = Math.Min(pending.Count, count);

                Buffer.BlockCopy(pending.Array, pending.Offset, buffer, offset, bytesToCopy);

                // Shrink the window so the next call continues where this one stopped.
                _initialReadBuffer = new ArraySegment<byte>(
                    pending.Array,
                    pending.Offset + bytesToCopy,
                    pending.Count - bytesToCopy);

                return bytesToCopy;
            }
 
            // Releases the resources held by the stream. Called from Dispose and from the read path once
            // the end of the message is reached.
            //  - Read stream: runs once (guarded by _messageSourceCleanState). Drains any unread remainder
            //    of the current message from the socket, returns the pooled buffer, and reports completion
            //    (with any drain failure) to the message source.
            //  - Write stream: see the review note on the else branch.
            private void Cleanup()
            {
                if (_isForRead)
                {
                    if (Interlocked.CompareExchange(ref _messageSourceCleanState, WebSocketHelper.OperationFinished, WebSocketHelper.OperationNotStarted) == WebSocketHelper.OperationNotStarted)
                    {
                        Exception pendingException = null;
                        try
                        {
                            if (!_endofMessageReceived && (_webSocket.State == WebSocketState.Open || _webSocket.State == WebSocketState.CloseSent))
                            {
                                // Drain the reading stream
                                var closeTimeoutHelper = new TimeoutHelper(_closeTimeout);
                                do
                                {
                                    var cancelToken = closeTimeoutHelper.GetCancellationToken();
                                    // The whole pooled array is reused as scratch space; the drained bytes are discarded.
                                    Task<WebSocketReceiveResult> receiveTask =
                                        _webSocket.ReceiveAsync(new ArraySegment<byte>(_initialReadBuffer.Array), cancelToken);
                                    receiveTask.Wait(closeTimeoutHelper.RemainingTime(),
                                        WebSocketHelper.ThrowCorrectException, WebSocketHelper.ReceiveOperation);
                                    _endofMessageReceived = receiveTask.GetAwaiter().GetResult().EndOfMessage;
                                } while (!_endofMessageReceived &&
                                         (_webSocket.State == WebSocketState.Open ||
                                          _webSocket.State == WebSocketState.CloseSent));
                            }
                        }
                        catch (Exception ex)
                        {
                            if (Fx.IsFatal(ex))
                            {
                                throw;
                            }

                            // Not throwing out this exception during stream cleanup. The exception
                            // will be thrown out when we are trying to receive the next message using the same
                            // WebSocket object.
                            pendingException = WebSocketHelper.ConvertAndTraceException(ex, _closeTimeout, WebSocketHelper.CloseOperation);
                        }

                        // The buffer goes back to the pool whether or not draining succeeded.
                        _bufferManager.ReturnBuffer(_initialReadBuffer.Array);
                        Fx.Assert(_messageSource != null, "messageSource should not be null.");
                        _messageSource.FinishUsingMessageStream(pendingException);
                    }
                }
                else
                {
                    // NOTE(review): this compare-exchange moves _endOfMessageWritten to OperationFinished
                    // before WriteEndOfMessage runs, and WriteEndOfMessage performs the very same
                    // compare-exchange before sending. As written, the second one cannot succeed, so no
                    // end-of-message frame is sent from this path; the call only emits the write
                    // start/stop events. Be careful before "fixing" this: WriteEndOfMessageAsync does not
                    // set the flag, so sending here could emit a second final frame after an async
                    // end-of-message. Confirm the intended behavior.
                    if (Interlocked.CompareExchange(ref _endOfMessageWritten, WebSocketHelper.OperationFinished, WebSocketHelper.OperationNotStarted) == WebSocketHelper.OperationNotStarted)
                    {
                        WriteEndOfMessage();
                    }
                }
            }
        }
 
        // Holds the close status/description pair for each direction of the WebSocket close handshake.
        // The output status starts as NormalClosure with no description; both output values are
        // changed together through SetOutputCloseStatus.
        private class WebSocketCloseDetails
        {
            // Close status received from the remote side; null until one is recorded.
            public WebSocketCloseStatus? InputCloseStatus { get; internal set; }

            // Description that accompanied the received close status.
            public string InputCloseStatusDescription { get; internal set; }

            // Close status to send to the remote side.
            internal WebSocketCloseStatus OutputCloseStatus { get; private set; } = WebSocketCloseStatus.NormalClosure;

            // Description to send with OutputCloseStatus; null unless SetOutputCloseStatus supplied one.
            internal string OutputCloseStatusDescription { get; private set; }

            // Records the status and description to use when closing the socket.
            public void SetOutputCloseStatus(WebSocketCloseStatus closeStatus, string closeStatusDescription)
            {
                OutputCloseStatusDescription = closeStatusDescription;
                OutputCloseStatus = closeStatus;
            }
        }
    }
}