File: FrameworkFork\System.ServiceModel\System\ServiceModel\Channels\TransportDuplexSessionChannel.cs
Web Access
Project: src\src\dotnet-svcutil\lib\src\dotnet-svcutil-lib.csproj (dotnet-svcutil-lib)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Runtime;
using System.Runtime.Diagnostics;
using System.Security.Authentication.ExtendedProtection;
using System.ServiceModel.Diagnostics;
using System.ServiceModel.Security;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Base class for duplex session channels that sit directly on a transport connection.
    /// </summary>
    /// <remarks>
    /// Sends and the output-session shutdown are serialized through a single-entry
    /// <see cref="SemaphoreSlim"/>. Receives go through a <see cref="SynchronizedMessageSource"/>
    /// that a derived class installs with <see cref="SetMessageSource"/>; receive operations
    /// dereference that source, so it must be installed first. The input and output halves of the
    /// session are tracked separately. When the output half closes after the input half has
    /// already closed, the connection is handed back with
    /// <c>ReturnConnectionIfNecessary(false, ...)</c>. On abort or fault it is handed back with
    /// <c>ReturnConnectionIfNecessary(true, TimeSpan.Zero)</c>.
    /// </remarks>
    public abstract class TransportDuplexSessionChannel : TransportOutputChannel, IDuplexSessionChannel
    {
        private readonly BufferManager _bufferManager;
        private IDuplexSession _duplexSession;

        // Written under ThisLock by OnInputSessionClosed.
        private bool _isInputSessionClosed;

        // Guarded by _sendLock; once set, further sends throw InvalidOperationException.
        private bool _isOutputSessionClosed;
        private MessageEncoder _messageEncoder;
        private SynchronizedMessageSource _messageSource;
        private SecurityMessageProperty _remoteSecurity;
        private readonly EndpointAddress _localAddress;

        // Initial count of 1: serializes OnSend/OnSendAsync with CloseOutputSession(Async).
        private readonly SemaphoreSlim _sendLock;
        private readonly Uri _localVia;

        // Created once and shared by every asynchronous send as the write-completion callback.
        private static readonly Action<object> s_onWriteComplete = new Action<object>(OnWriteComplete);

        /// <summary>
        /// Initializes the channel, creates its session encoder, and creates its
        /// <see cref="ConnectionDuplexSession"/>.
        /// </summary>
        /// <param name="manager">The channel manager passed to the base channel.</param>
        /// <param name="settings">
        /// Supplies the addressing mode, message version, buffer manager, and the encoder factory
        /// used to create the session encoder.
        /// </param>
        /// <param name="localAddress">The address exposed through <see cref="LocalAddress"/>.</param>
        /// <param name="localVia">Stamped as <c>Via</c> on every received message.</param>
        /// <param name="remoteAddresss">The remote endpoint address passed to the base channel.</param>
        /// <param name="via">The via URI passed to the base channel.</param>
        protected TransportDuplexSessionChannel(
                  ChannelManagerBase manager,
                  ITransportFactorySettings settings,
                  EndpointAddress localAddress,
                  Uri localVia,
                  EndpointAddress remoteAddresss,
                  Uri via)
                : base(manager, remoteAddresss, via, settings.ManualAddressing, settings.MessageVersion)
        {
            _localAddress = localAddress;
            _localVia = localVia;
            _bufferManager = settings.BufferManager;
            _sendLock = new SemaphoreSlim(1);
            _messageEncoder = settings.MessageEncoderFactory.CreateSessionEncoder();
            this.Session = new ConnectionDuplexSession(this);
        }

        /// <summary>Gets the local endpoint address supplied at construction.</summary>
        public EndpointAddress LocalAddress
        {
            get { return _localAddress; }
        }

        /// <summary>Gets the security properties of the remote party, as set by a derived class.</summary>
        public SecurityMessageProperty RemoteSecurity
        {
            get { return _remoteSecurity; }
            protected set { _remoteSecurity = value; }
        }

        /// <summary>Gets the duplex session associated with this channel.</summary>
        public IDuplexSession Session
        {
            get { return _duplexSession; }
            protected set { _duplexSession = value; }
        }

        /// <summary>Gets the buffer manager that encoded send buffers are returned to.</summary>
        protected BufferManager BufferManager
        {
            get
            {
                return _bufferManager;
            }
        }

        /// <summary>Gets or sets the session message encoder.</summary>
        protected MessageEncoder MessageEncoder
        {
            get { return _messageEncoder; }
            set { _messageEncoder = value; }
        }

        /// <summary>Gets the message source installed by <see cref="SetMessageSource"/>.</summary>
        internal SynchronizedMessageSource MessageSource
        {
            get { return _messageSource; }
        }

        /// <summary>
        /// When <c>true</c>, <see cref="OnSendAsync"/> writes with <see cref="StartWritingStreamedMessage"/>.
        /// Otherwise it encodes with <see cref="EncodeMessage"/> and writes with
        /// <see cref="StartWritingBufferedMessage"/>.
        /// </summary>
        protected abstract bool IsStreamedOutput { get; }

        /// <summary>Receives the next message using <c>DefaultReceiveTimeout</c>.</summary>
        /// <returns>The received message, or <c>null</c> (see <see cref="Receive(TimeSpan)"/>).</returns>
        public Message Receive()
        {
            return this.Receive(this.DefaultReceiveTimeout);
        }

        /// <summary>Synchronously receives the next message from the message source.</summary>
        /// <param name="timeout">Time allowed for the receive.</param>
        /// <returns>
        /// The received message. Returns <c>null</c> when no more messages are expected in the
        /// current state, or when the source returned <c>null</c>; the latter also marks the
        /// input session closed.
        /// </returns>
        /// <remarks>
        /// If the receive or message preparation throws, any partially received message is closed
        /// and the channel is faulted before the exception propagates.
        /// </remarks>
        public Message Receive(TimeSpan timeout)
        {
            Message message = null;
            if (DoneReceivingInCurrentState())
            {
                return null;
            }

            bool shouldFault = true;
            try
            {
                message = _messageSource.Receive(timeout);
                this.OnReceiveMessage(message);
                shouldFault = false;
                return message;
            }
            finally
            {
                if (shouldFault)
                {
                    if (message != null)
                    {
                        message.Close();
                        message = null;
                    }

                    this.Fault();
                }
            }
        }

        /// <summary>Asynchronously receives the next message from the message source.</summary>
        /// <param name="timeout">Time allowed for the receive.</param>
        /// <returns>
        /// The received message. Returns <c>null</c> when no more messages are expected in the
        /// current state, or when the source returned <c>null</c>; the latter also marks the
        /// input session closed.
        /// </returns>
        /// <remarks>
        /// If the receive or message preparation throws, any partially received message is closed
        /// and the channel is faulted before the exception propagates.
        /// </remarks>
        public async Task<Message> ReceiveAsync(TimeSpan timeout)
        {
            Message message = null;
            if (DoneReceivingInCurrentState())
            {
                return null;
            }

            bool shouldFault = true;
            try
            {
                message = await _messageSource.ReceiveAsync(timeout);
                this.OnReceiveMessage(message);
                shouldFault = false;
                return message;
            }
            finally
            {
                if (shouldFault)
                {
                    if (message != null)
                    {
                        message.Close();
                        message = null;
                    }

                    this.Fault();
                }
            }
        }

        /// <summary>APM wrapper over <see cref="ReceiveAsync"/> using <c>DefaultReceiveTimeout</c>.</summary>
        public IAsyncResult BeginReceive(AsyncCallback callback, object state)
        {
            return this.BeginReceive(this.DefaultReceiveTimeout, callback, state);
        }

        /// <summary>APM wrapper over <see cref="ReceiveAsync"/>.</summary>
        public IAsyncResult BeginReceive(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return this.ReceiveAsync(timeout).ToApm(callback, state);
        }

        /// <summary>Completes a <see cref="BeginReceive(TimeSpan, AsyncCallback, object)"/> operation.</summary>
        public Message EndReceive(IAsyncResult result)
        {
            return result.ToApmEnd<Message>();
        }

        /// <summary>APM wrapper over <see cref="ReceiveAsync"/>; complete with <see cref="EndTryReceive"/>.</summary>
        public IAsyncResult BeginTryReceive(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return this.ReceiveAsync(timeout).ToApm(callback, state);
        }

        /// <summary>Completes a <see cref="BeginTryReceive"/> operation.</summary>
        /// <param name="result">The result returned by <see cref="BeginTryReceive"/>.</param>
        /// <param name="message">The received message, or <c>null</c> on timeout.</param>
        /// <returns>
        /// <c>false</c> if the receive failed with a <see cref="TimeoutException"/>; otherwise
        /// <c>true</c>. Other exceptions propagate.
        /// </returns>
        public bool EndTryReceive(IAsyncResult result, out Message message)
        {
            try
            {
                message = result.ToApmEnd<Message>();
                return true;
            }
            catch (TimeoutException e)
            {
                if (WcfEventSource.Instance.ReceiveTimeoutIsEnabled())
                {
                    WcfEventSource.Instance.ReceiveTimeout(e.Message);
                }

                message = null;
                return false;
            }
        }

        /// <summary>Synchronously receives a message, reporting a timeout as <c>false</c> instead of throwing.</summary>
        /// <param name="timeout">Time allowed for the receive.</param>
        /// <param name="message">The received message, or <c>null</c> on timeout.</param>
        /// <returns>
        /// <c>false</c> if <see cref="Receive(TimeSpan)"/> threw a <see cref="TimeoutException"/>;
        /// otherwise <c>true</c>.
        /// </returns>
        public bool TryReceive(TimeSpan timeout, out Message message)
        {
            try
            {
                message = this.Receive(timeout);
                return true;
            }
            catch (TimeoutException e)
            {
                if (WcfEventSource.Instance.ReceiveTimeoutIsEnabled())
                {
                    WcfEventSource.Instance.ReceiveTimeout(e.Message);
                }
                message = null;
                return false;
            }
        }

        /// <summary>Asynchronously waits until a message is available.</summary>
        /// <param name="timeout">Time allowed for the wait.</param>
        /// <returns>
        /// <c>true</c> immediately when no more receiving is expected in the current state;
        /// otherwise the result of the message source's wait.
        /// </returns>
        /// <remarks>The channel is faulted if the wait returns <c>false</c> (timed out) or throws.</remarks>
        public async Task<bool> WaitForMessageAsync(TimeSpan timeout)
        {
            if (DoneReceivingInCurrentState())
            {
                return true;
            }

            bool shouldFault = true;
            try
            {
                bool success = await _messageSource.WaitForMessageAsync(timeout);
                shouldFault = !success; // need to fault if we've timed out because we're now toast
                return success;
            }
            finally
            {
                if (shouldFault)
                {
                    this.Fault();
                }
            }
        }

        /// <summary>Synchronously waits until a message is available.</summary>
        /// <param name="timeout">Time allowed for the wait.</param>
        /// <returns>
        /// <c>true</c> immediately when no more receiving is expected in the current state;
        /// otherwise the result of the message source's wait.
        /// </returns>
        /// <remarks>The channel is faulted if the wait returns <c>false</c> (timed out) or throws.</remarks>
        public bool WaitForMessage(TimeSpan timeout)
        {
            if (DoneReceivingInCurrentState())
            {
                return true;
            }

            bool shouldFault = true;
            try
            {
                bool success = _messageSource.WaitForMessage(timeout);
                shouldFault = !success; // need to fault if we've timed out because we're now toast
                return success;
            }
            finally
            {
                if (shouldFault)
                {
                    this.Fault();
                }
            }
        }

        /// <summary>APM wrapper over <see cref="WaitForMessageAsync"/>.</summary>
        public IAsyncResult BeginWaitForMessage(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return this.WaitForMessageAsync(timeout).ToApm(callback, state);
        }

        /// <summary>Completes a <see cref="BeginWaitForMessage"/> operation.</summary>
        public bool EndWaitForMessage(IAsyncResult result)
        {
            return result.ToApmEnd<bool>();
        }

        /// <summary>
        /// Installs the transport's message source, wrapped in a <see cref="SynchronizedMessageSource"/>.
        /// Must be called before any receive or wait operation, because those dereference the source.
        /// </summary>
        /// <param name="messageSource">The transport-level message source.</param>
        protected void SetMessageSource(IMessageSource messageSource)
        {
            _messageSource = new SynchronizedMessageSource(messageSource);
        }

        /// <summary>
        /// Transport-specific asynchronous shutdown of the output half of the session.
        /// <see cref="CloseOutputSessionAsync"/> calls it at most once, while the send lock is held.
        /// </summary>
        protected abstract Task CloseOutputSessionCoreAsync(TimeSpan timeout);

        /// <summary>
        /// Transport-specific synchronous shutdown of the output half of the session.
        /// <see cref="CloseOutputSession"/> calls it at most once, while the send lock is held.
        /// </summary>
        protected abstract void CloseOutputSessionCore(TimeSpan timeout);

        /// <summary>
        /// Asynchronously closes the output half of the session. A second call after a successful
        /// close returns without doing anything.
        /// </summary>
        /// <param name="timeout">Total time allowed for acquiring the send lock and closing.</param>
        /// <exception cref="TimeoutException">The send lock could not be acquired within <paramref name="timeout"/>.</exception>
        /// <remarks>If the transport-specific close throws, the channel is faulted.</remarks>
        protected async Task CloseOutputSessionAsync(TimeSpan timeout)
        {
            ThrowIfNotOpened();
            ThrowIfFaulted();
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);

            // If timeout == TimeSpan.MaxValue, then we want to pass Timeout.Infinite as 
            // SemaphoreSlim doesn't accept timeouts > Int32.MaxValue.
            // Using TimeoutHelper.RemainingTime() would yield a value less than TimeSpan.MaxValue
            // and would result in the value Int32.MaxValue so we must use the original timeout specified.
            if (!await _sendLock.WaitAsync(TimeoutHelper.ToMilliseconds(timeout)))
            {
                string timeoutMessage = string.Format(SRServiceModel.CloseTimedOut, timeout);
                if (WcfEventSource.Instance.CloseTimeoutIsEnabled())
                {
                    WcfEventSource.Instance.CloseTimeout(timeoutMessage);
                }

                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(
                                                timeoutMessage,
                                                TimeoutHelper.CreateEnterTimedOutException(timeout)));
            }

            try
            {
                // check again in case the previous send faulted while we were waiting for the lock
                ThrowIfFaulted();

                // we're synchronized by sendLock here
                if (_isOutputSessionClosed)
                {
                    return;
                }

                _isOutputSessionClosed = true;
                bool shouldFault = true;
                try
                {
                    // Use the helper's remaining budget (as OnSend does) rather than the raw
                    // timeout, so the core close and the connection release that follows share a
                    // single timeout budget instead of each getting the full value.
                    await this.CloseOutputSessionCoreAsync(timeoutHelper.RemainingTime());
                    this.OnOutputSessionClosed(ref timeoutHelper);
                    shouldFault = false;
                }
                finally
                {
                    if (shouldFault)
                    {
                        this.Fault();
                    }
                }
            }
            finally
            {
                _sendLock.Release();
            }
        }

        /// <summary>
        /// Synchronously closes the output half of the session. A second call after a successful
        /// close returns without doing anything.
        /// </summary>
        /// <param name="timeout">Total time allowed for acquiring the send lock and closing.</param>
        /// <exception cref="TimeoutException">The send lock could not be acquired within <paramref name="timeout"/>.</exception>
        /// <remarks>If the transport-specific close throws, the channel is faulted.</remarks>
        protected void CloseOutputSession(TimeSpan timeout)
        {
            ThrowIfNotOpened();
            ThrowIfFaulted();
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);

            // If timeout == TimeSpan.MaxValue, then we want to pass Timeout.Infinite as 
            // SemaphoreSlim doesn't accept timeouts > Int32.MaxValue.
            // Using TimeoutHelper.RemainingTime() would yield a value less than TimeSpan.MaxValue
            // and would result in the value Int32.MaxValue so we must use the original timeout specified.
            if (!_sendLock.Wait(TimeoutHelper.ToMilliseconds(timeout)))
            {
                string timeoutMessage = string.Format(SRServiceModel.CloseTimedOut, timeout);
                if (WcfEventSource.Instance.CloseTimeoutIsEnabled())
                {
                    WcfEventSource.Instance.CloseTimeout(timeoutMessage);
                }

                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(
                                                timeoutMessage,
                                                TimeoutHelper.CreateEnterTimedOutException(timeout)));
            }

            try
            {
                // check again in case the previous send faulted while we were waiting for the lock
                ThrowIfFaulted();

                // we're synchronized by sendLock here
                if (_isOutputSessionClosed)
                {
                    return;
                }

                _isOutputSessionClosed = true;
                bool shouldFault = true;
                try
                {
                    // Use the helper's remaining budget (as OnSend does) rather than the raw
                    // timeout, so the core close and the connection release that follows share a
                    // single timeout budget instead of each getting the full value.
                    this.CloseOutputSessionCore(timeoutHelper.RemainingTime());
                    this.OnOutputSessionClosed(ref timeoutHelper);
                    shouldFault = false;
                }
                finally
                {
                    if (shouldFault)
                    {
                        this.Fault();
                    }
                }
            }
            finally
            {
                _sendLock.Release();
            }
        }

        // used to return cached connection to the pool/reader pool
        // Called with (true, TimeSpan.Zero) on abort/fault, and with (false, remaining time) when the
        // output session closes after the input session has already closed.
        protected abstract void ReturnConnectionIfNecessary(bool abort, TimeSpan timeout);

        /// <summary>Hands the connection back in abort mode, with no time allowance.</summary>
        protected override void OnAbort()
        {
            this.ReturnConnectionIfNecessary(true, TimeSpan.Zero);
        }

        /// <summary>Runs base fault handling, then hands the connection back in abort mode.</summary>
        protected override void OnFaulted()
        {
            base.OnFaulted();
            this.ReturnConnectionIfNecessary(true, TimeSpan.Zero);
        }

        /// <summary>
        /// Asynchronous graceful close. It closes the output session, drains the input session
        /// (which is expected to yield no further messages), and then calls <see cref="CompleteClose"/>.
        /// All steps share <paramref name="timeout"/>.
        /// </summary>
        protected internal override async Task OnCloseAsync(TimeSpan timeout)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            await this.CloseOutputSessionAsync(timeoutHelper.RemainingTime());

            // close input session if necessary
            if (!_isInputSessionClosed)
            {
                await this.EnsureInputClosedAsync(timeoutHelper.RemainingTime());
                this.OnInputSessionClosed();
            }

            this.CompleteClose(timeoutHelper.RemainingTime());
        }

        /// <summary>
        /// Synchronous graceful close. It closes the output session, drains the input session
        /// (which is expected to yield no further messages), and then calls <see cref="CompleteClose"/>.
        /// All steps share <paramref name="timeout"/>.
        /// </summary>
        protected override void OnClose(TimeSpan timeout)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            this.CloseOutputSession(timeoutHelper.RemainingTime());

            // close input session if necessary
            if (!_isInputSessionClosed)
            {
                this.EnsureInputClosed(timeoutHelper.RemainingTime());
                this.OnInputSessionClosed();
            }

            this.CompleteClose(timeoutHelper.RemainingTime());
        }

        /// <summary>APM wrapper over <see cref="OnCloseAsync"/>.</summary>
        protected override IAsyncResult OnBeginClose(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return OnCloseAsync(timeout).ToApm(callback, state);
        }

        /// <summary>Completes an <see cref="OnBeginClose"/> operation.</summary>
        protected override void OnEndClose(IAsyncResult result)
        {
            result.ToApmEnd();
        }

        protected override void OnClosed()
        {
            base.OnClosed();
            // clean up the CBT after transitioning to the closed state
            //ChannelBindingUtility.Dispose(ref this.channelBindingToken);
        }

        /// <summary>
        /// Post-receive hook. A <c>null</c> message marks the input session closed; any other
        /// message is passed to <see cref="PrepareMessage"/>.
        /// </summary>
        protected virtual void OnReceiveMessage(Message message)
        {
            if (message == null)
            {
                this.OnInputSessionClosed();
            }
            else
            {
                this.PrepareMessage(message);
            }
        }

        /// <summary>
        /// Currently a no-op; channel binding token support is commented out in this port.
        /// </summary>
        protected void ApplyChannelBinding(Message message)
        {
            //ChannelBindingUtility.TryAddToMessage(this.channelBindingToken, message, false);
        }

        /// <summary>
        /// Prepares a received message. It stamps the local via as <c>Via</c> and applies channel
        /// binding. When end-to-end activity tracing is enabled, it also ensures the message
        /// carries an activity and emits the MessageReceivedByTransport event.
        /// </summary>
        protected virtual void PrepareMessage(Message message)
        {
            message.Properties.Via = _localVia;

            this.ApplyChannelBinding(message);

            if (FxTrace.Trace.IsEnd2EndActivityTracingEnabled)
            {
                EventTraceActivity eventTraceActivity = EventTraceActivityHelper.TryExtractActivity(message);
                Guid relatedActivityId = EventTraceActivity.GetActivityIdFromThread();
                if (eventTraceActivity == null)
                {
                    eventTraceActivity = EventTraceActivity.GetFromThreadOrCreate();
                    EventTraceActivityHelper.TryAttachActivity(message, eventTraceActivity);
                }

                if (WcfEventSource.Instance.MessageReceivedByTransportIsEnabled())
                {
                    WcfEventSource.Instance.MessageReceivedByTransport(
                        eventTraceActivity,
                        this.LocalAddress != null && this.LocalAddress.Uri != null ? this.LocalAddress.Uri.AbsoluteUri : string.Empty,
                        relatedActivityId);
                }
            }
        }

        /// <summary>
        /// Starts writing an already-encoded message. <see cref="OnSendAsync"/> treats a
        /// <see cref="AsyncCompletionResult.Completed"/> result as a synchronous completion.
        /// Otherwise it waits for <paramref name="callback"/> to be invoked with
        /// <paramref name="state"/>.
        /// </summary>
        protected abstract AsyncCompletionResult StartWritingBufferedMessage(Message message, ArraySegment<byte> messageData, bool allowOutputBatching, TimeSpan timeout, Action<object> callback, object state);

        /// <summary>Starts closing the output; not invoked by this base class.</summary>
        protected abstract AsyncCompletionResult BeginCloseOutput(TimeSpan timeout, Action<object> callback, object state);

        /// <summary>Hook invoked by <see cref="OnSendAsync"/> after a write completes; no-op by default.</summary>
        protected virtual void FinishWritingMessage()
        {
        }

        /// <summary>
        /// Encodes a message for a buffered write. After a successful send, the returned segment's
        /// array is returned to <see cref="BufferManager"/>.
        /// </summary>
        protected abstract ArraySegment<byte> EncodeMessage(Message message);

        /// <summary>Transport-specific synchronous send; called by <see cref="OnSend"/> under the send lock.</summary>
        protected abstract void OnSendCore(Message message, TimeSpan timeout);

        /// <summary>
        /// Starts writing a message in streamed mode. Completion is signalled the same way as for
        /// <see cref="StartWritingBufferedMessage"/>.
        /// </summary>
        protected abstract AsyncCompletionResult StartWritingStreamedMessage(Message message, TimeSpan timeout, Action<object> callback, object state);

        /// <summary>
        /// Asynchronously writes <paramref name="message"/> to the transport while holding the send lock.
        /// </summary>
        /// <param name="message">The message to send.</param>
        /// <param name="timeout">Time allowed for acquiring the send lock and writing the message.</param>
        /// <exception cref="TimeoutException">The send lock could not be acquired within <paramref name="timeout"/>.</exception>
        /// <remarks>
        /// Any failure after the lock is acquired faults the channel. The encoded buffer (buffered
        /// mode only) is returned to the <see cref="BufferManager"/> only after a successful send.
        /// </remarks>
        protected override async Task OnSendAsync(Message message, TimeSpan timeout)
        {
            this.ThrowIfDisposedOrNotOpen();

            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);

            // If timeout == TimeSpan.MaxValue, then we want to pass Timeout.Infinite as 
            // SemaphoreSlim doesn't accept timeouts > Int32.MaxValue.
            // Using TimeoutHelper.RemainingTime() would yield a value less than TimeSpan.MaxValue
            // and would result in the value Int32.MaxValue so we must use the original timeout specified.
            if (!await _sendLock.WaitAsync(TimeoutHelper.ToMilliseconds(timeout)))
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(
                                            string.Format(SRServiceModel.SendToViaTimedOut, Via, timeout),
                                            TimeoutHelper.CreateEnterTimedOutException(timeout)));
            }

            byte[] buffer = null;

            try
            {
                // check again in case the previous send faulted while we were waiting for the lock
                this.ThrowIfDisposedOrNotOpen();
                this.ThrowIfOutputSessionClosed();

                bool success = false;
                try
                {
                    this.ApplyChannelBinding(message);

                    // s_onWriteComplete completes this TCS. OnWriteComplete only accepts a
                    // TaskCompletionSource<bool> as its state, so the TCS (not the channel) must be
                    // the callback state on BOTH the streamed and the buffered path. Otherwise an
                    // asynchronously completing write never completes this task, and the send lock
                    // below is never released.
                    var tcs = new TaskCompletionSource<bool>(this);

                    AsyncCompletionResult completionResult;
                    if (this.IsStreamedOutput)
                    {
                        completionResult = this.StartWritingStreamedMessage(message, timeoutHelper.RemainingTime(), s_onWriteComplete, tcs);
                    }
                    else
                    {
                        bool allowOutputBatching = message.Properties.AllowOutputBatching;
                        ArraySegment<byte> messageData = this.EncodeMessage(message);

                        buffer = messageData.Array;
                        completionResult = this.StartWritingBufferedMessage(
                                                                          message,
                                                                          messageData,
                                                                          allowOutputBatching,
                                                                          timeoutHelper.RemainingTime(),
                                                                          s_onWriteComplete,
                                                                          tcs);
                    }

                    if (completionResult == AsyncCompletionResult.Completed)
                    {
                        // Completed synchronously. TrySetResult is harmless if the callback also fired.
                        tcs.TrySetResult(true);
                    }

                    await tcs.Task;

                    this.FinishWritingMessage();

                    success = true;
                    if (WcfEventSource.Instance.MessageSentByTransportIsEnabled())
                    {
                        EventTraceActivity eventTraceActivity = EventTraceActivityHelper.TryExtractActivity(message);
                        WcfEventSource.Instance.MessageSentByTransport(eventTraceActivity, this.RemoteAddress.Uri.AbsoluteUri);
                    }
                }
                finally
                {
                    if (!success)
                    {
                        this.Fault();
                    }
                }
            }
            finally
            {
                _sendLock.Release();
            }

            // Reached only when the send succeeded; on an exception the buffer is not returned to the pool.
            if (buffer != null)
            {
                _bufferManager.ReturnBuffer(buffer);
            }
        }

        /// <summary>
        /// Write-completion callback shared by all asynchronous sends. Completes the
        /// <see cref="TaskCompletionSource{TResult}"/> passed as <paramref name="state"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="state"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="state"/> is not a <c>TaskCompletionSource&lt;bool&gt;</c>.</exception>
        private static void OnWriteComplete(object state)
        {
            if (state == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull(nameof(state));
            }

            var tcs = state as TaskCompletionSource<bool>;
            if (tcs == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument(nameof(state), SRServiceModel.SPS_InvalidAsyncResult);
            }

            tcs.TrySetResult(true);
        }

        /// <summary>
        /// Synchronously sends <paramref name="message"/> through <see cref="OnSendCore"/> while
        /// holding the send lock. Any failure after the lock is acquired faults the channel.
        /// </summary>
        /// <exception cref="TimeoutException">The send lock could not be acquired within <paramref name="timeout"/>.</exception>
        protected override void OnSend(Message message, TimeSpan timeout)
        {
            this.ThrowIfDisposedOrNotOpen();

            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);

            // If timeout == TimeSpan.MaxValue, then we want to pass Timeout.Infinite as 
            // SemaphoreSlim doesn't accept timeouts > Int32.MaxValue.
            // Using TimeoutHelper.RemainingTime() would yield a value less than TimeSpan.MaxValue
            // and would result in the value Int32.MaxValue so we must use the original timeout specified.
            if (!_sendLock.Wait(TimeoutHelper.ToMilliseconds(timeout)))
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(
                                            string.Format(SRServiceModel.SendToViaTimedOut, Via, timeout),
                                            TimeoutHelper.CreateEnterTimedOutException(timeout)));
            }

            try
            {
                // check again in case the previous send faulted while we were waiting for the lock
                this.ThrowIfDisposedOrNotOpen();
                this.ThrowIfOutputSessionClosed();

                bool success = false;
                try
                {
                    this.ApplyChannelBinding(message);

                    this.OnSendCore(message, timeoutHelper.RemainingTime());
                    success = true;
                    if (WcfEventSource.Instance.MessageSentByTransportIsEnabled())
                    {
                        EventTraceActivity eventTraceActivity = EventTraceActivityHelper.TryExtractActivity(message);
                        WcfEventSource.Instance.MessageSentByTransport(eventTraceActivity, this.RemoteAddress.Uri.AbsoluteUri);
                    }
                }
                finally
                {
                    if (!success)
                    {
                        this.Fault();
                    }
                }
            }
            finally
            {
                _sendLock.Release();
            }
        }

        // cleanup after the framing handshake has completed
        // Called as the last step of OnClose/OnCloseAsync with the remaining close budget.
        protected abstract void CompleteClose(TimeSpan timeout);

        // must be called under sendLock 
        private void ThrowIfOutputSessionClosed()
        {
            if (_isOutputSessionClosed)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SRServiceModel.SendCannotBeCalledAfterCloseOutputSession));
            }
        }

        // Reads once more from the source during close. Any non-null message is a protocol
        // violation: it is disposed and reported as a ProtocolException.
        private async Task EnsureInputClosedAsync(TimeSpan timeout)
        {
            Message message = await this.MessageSource.ReceiveAsync(timeout);
            if (message != null)
            {
                using (message)
                {
                    ProtocolException error = ProtocolException.ReceiveShutdownReturnedNonNull(message);
                    throw TraceUtility.ThrowHelperError(error, message);
                }
            }
        }

        // Synchronous counterpart of EnsureInputClosedAsync.
        private void EnsureInputClosed(TimeSpan timeout)
        {
            Message message = this.MessageSource.Receive(timeout);
            if (message != null)
            {
                using (message)
                {
                    ProtocolException error = ProtocolException.ReceiveShutdownReturnedNonNull(message);
                    throw TraceUtility.ThrowHelperError(error, message);
                }
            }
        }

        // Marks the input half of the session closed; setting the flag again is harmless.
        private void OnInputSessionClosed()
        {
            lock (ThisLock)
            {
                _isInputSessionClosed = true;
            }
        }

        // Called after the output half closes. If the input half is already closed, the session is
        // fully shut down, so the connection is handed back (non-abort) with the remaining budget.
        private void OnOutputSessionClosed(ref TimeoutHelper timeoutHelper)
        {
            bool releaseConnection = false;
            lock (ThisLock)
            {
                if (_isInputSessionClosed)
                {
                    // we're all done, release the connection
                    releaseConnection = true;
                }
            }

            if (releaseConnection)
            {
                this.ReturnConnectionIfNecessary(false, timeoutHelper.RemainingTime());
            }
        }

        /// <summary>
        /// <see cref="IDuplexSession"/> implementation that forwards output-session shutdown to the
        /// owning channel and exposes a session identifier generated on first access.
        /// </summary>
        public class ConnectionDuplexSession : IDuplexSession
        {
            // Shared by every session; created on first use.
            private static UriGenerator s_uriGenerator;
            private readonly TransportDuplexSessionChannel _channel;
            private string _id;

            /// <summary>Creates a session bound to <paramref name="channel"/>.</summary>
            public ConnectionDuplexSession(TransportDuplexSessionChannel channel)
                : base()
            {
                _channel = channel;
            }

            /// <summary>Gets the session identifier, generated once on first access.</summary>
            public string Id
            {
                get
                {
                    if (_id == null)
                    {
                        lock (_channel)
                        {
                            if (_id == null)
                            {
                                _id = UriGenerator.Next();
                            }
                        }
                    }

                    return _id;
                }
            }

            /// <summary>Gets the channel this session belongs to.</summary>
            public TransportDuplexSessionChannel Channel
            {
                get { return _channel; }
            }

            private static UriGenerator UriGenerator
            {
                get
                {
                    // Id locks on each session's own channel, so sessions of different channels can
                    // reach this getter concurrently. EnsureInitialized publishes a single generator
                    // that every caller observes, instead of letting racing callers each install their own.
                    return LazyInitializer.EnsureInitialized(ref s_uriGenerator, () => new UriGenerator());
                }
            }

            /// <summary>APM close of the output session using the channel's <c>DefaultCloseTimeout</c>.</summary>
            public IAsyncResult BeginCloseOutputSession(AsyncCallback callback, object state)
            {
                return this.BeginCloseOutputSession(_channel.DefaultCloseTimeout, callback, state);
            }

            /// <summary>APM wrapper over the channel's asynchronous output-session close.</summary>
            public IAsyncResult BeginCloseOutputSession(TimeSpan timeout, AsyncCallback callback, object state)
            {
                return _channel.CloseOutputSessionAsync(timeout).ToApm(callback, state);
            }

            /// <summary>Completes a <c>BeginCloseOutputSession</c> operation.</summary>
            public void EndCloseOutputSession(IAsyncResult result)
            {
                result.ToApmEnd();
            }

            /// <summary>Closes the output session using the channel's <c>DefaultCloseTimeout</c>.</summary>
            public void CloseOutputSession()
            {
                this.CloseOutputSession(_channel.DefaultCloseTimeout);
            }

            /// <summary>Closes the output session of the owning channel.</summary>
            public void CloseOutputSession(TimeSpan timeout)
            {
                _channel.CloseOutputSession(timeout);
            }
        }
    }
}