File: FrameworkFork\System.ServiceModel\System\ServiceModel\Channels\FramingChannels.cs
Web Access
Project: src\src\dotnet-svcutil\lib\src\dotnet-svcutil-lib.csproj (dotnet-svcutil-lib)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Diagnostics.Contracts;
using System.IO;
using System.Runtime;
using System.ServiceModel.Security;
using System.ServiceModel.Channels.ConnectionHelpers;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Base class for duplex session channels that write messages to an <c>IConnection</c>.
    /// Each outgoing message is produced by <c>MessageEncoder.WriteMessage</c> and then wrapped by
    /// <c>SessionEncoder.EncodeMessageFrame</c>; closing the output session writes
    /// <c>SessionEncoder.EndBytes</c>. Output is always buffered (<c>IsStreamedOutput</c> is false).
    /// </summary>
    internal abstract class FramingDuplexSessionChannel : TransportDuplexSessionChannel
    {
        // Local address shared by every channel created through the protected constructor.
        private static readonly EndpointAddress s_anonymousEndpointAddress = new EndpointAddress(EndpointAddress.AnonymousUri, new AddressHeader[0]);

        // Connection that messages are written to; assigned through the Connection property by derived classes.
        private IConnection _connection;

        // When true, PrepareMessage stores the connection in each message's properties.
        private readonly bool _exposeConnectionProperty;

        private FramingDuplexSessionChannel(ChannelManagerBase manager, IConnectionOrientedTransportFactorySettings settings,
            EndpointAddress localAddress, Uri localVia, EndpointAddress remoteAddresss, Uri via, bool exposeConnectionProperty)
            : base(manager, settings, localAddress, localVia, remoteAddresss, via)
        {
            _exposeConnectionProperty = exposeConnectionProperty;
        }

        /// <summary>
        /// Creates a channel whose local address is anonymous. The local via is the
        /// <c>http://www.w3.org/2005/08/addressing/anonymous</c> URI, or <c>null</c> when the
        /// message version's addressing is <c>AddressingVersion.None</c>. The session object is
        /// chosen from <c>settings.Upgrade</c>.
        /// </summary>
        protected FramingDuplexSessionChannel(ChannelManagerBase factory, IConnectionOrientedTransportFactorySettings settings,
            EndpointAddress remoteAddresss, Uri via, bool exposeConnectionProperty)
            : this(factory, settings, s_anonymousEndpointAddress,
                  settings.MessageVersion.Addressing != AddressingVersion.None
                      ? new Uri("http://www.w3.org/2005/08/addressing/anonymous")
                      : null,
                  remoteAddresss, via, exposeConnectionProperty)
        {
            Session = FramingConnectionDuplexSession.CreateSession(this, settings.Upgrade);
        }

        /// <summary>The connection used for all writes on this channel.</summary>
        protected IConnection Connection
        {
            get { return _connection; }
            set { _connection = value; }
        }

        /// <summary>Always false: this channel only produces buffered output.</summary>
        protected override bool IsStreamedOutput
        {
            get { return false; }
        }

        /// <summary>Synchronously writes the session end record and flushes it.</summary>
        protected override void CloseOutputSessionCore(TimeSpan timeout)
        {
            var endRecord = SessionEncoder.EndBytes;
            Connection.Write(endRecord, 0, endRecord.Length, true, timeout);
        }

        /// <summary>Asynchronously writes the session end record and flushes it.</summary>
        protected override Task CloseOutputSessionCoreAsync(TimeSpan timeout)
        {
            var endRecord = SessionEncoder.EndBytes;
            return Connection.WriteAsync(endRecord, 0, endRecord.Length, true, timeout);
        }

        /// <summary>Completes a graceful close by returning (not aborting) the connection.</summary>
        protected override void CompleteClose(TimeSpan timeout)
        {
            ReturnConnectionIfNecessary(false, timeout);
        }

        /// <summary>
        /// Optionally attaches the connection under <c>ConnectionMessageProperty.Name</c>, then
        /// defers to the base implementation.
        /// </summary>
        protected override void PrepareMessage(Message message)
        {
            bool attachConnection = _exposeConnectionProperty;
            if (attachConnection)
            {
                message.Properties[ConnectionMessageProperty.Name] = _connection;
            }

            base.PrepareMessage(message);
        }

        /// <summary>
        /// Encodes and frames <paramref name="message"/> and writes it synchronously. The write is
        /// flushed unless the message allows output batching; the buffer manager is handed to Write
        /// together with the encoded bytes.
        /// </summary>
        protected override void OnSendCore(Message message, TimeSpan timeout)
        {
            // Read the batching preference before the message is handed to the encoder.
            bool flushNow = !message.Properties.AllowOutputBatching;
            ArraySegment<byte> frame = this.EncodeMessage(message);

            this.Connection.Write(frame.Array, frame.Offset, frame.Count, flushNow, timeout, this.BufferManager);
        }

        /// <summary>Begins an asynchronous, flushed write of the session end record.</summary>
        protected override AsyncCompletionResult BeginCloseOutput(TimeSpan timeout, Action<object> callback, object state)
        {
            var endRecord = SessionEncoder.EndBytes;
            return this.Connection.BeginWrite(endRecord, 0, endRecord.Length, true, timeout, callback, state);
        }

        /// <summary>Completes the pending asynchronous write.</summary>
        protected override void FinishWritingMessage()
        {
            this.Connection.EndWrite();
        }

        /// <summary>
        /// Begins an asynchronous write of already-encoded message bytes; flushed unless batching is allowed.
        /// </summary>
        protected override AsyncCompletionResult StartWritingBufferedMessage(Message message, ArraySegment<byte> messageData, bool allowOutputBatching, TimeSpan timeout, Action<object> callback, object state)
        {
            bool flushNow = !allowOutputBatching;
            return this.Connection.BeginWrite(messageData.Array, messageData.Offset, messageData.Count,
                    flushNow, timeout, callback, state);
        }

        /// <summary>Not supported: output is never streamed on this channel.</summary>
        /// <exception cref="InvalidOperationException">Always.</exception>
        protected override AsyncCompletionResult StartWritingStreamedMessage(Message message, TimeSpan timeout, Action<object> callback, object state)
        {
            Contract.Assert(false, "Streamed output should never be called in this channel.");
            throw new InvalidOperationException();
        }

        /// <summary>
        /// Serializes <paramref name="message"/> into a buffer from <c>BufferManager</c> (leaving
        /// <c>SessionEncoder.MaxMessageFrameSize</c> bytes of room for the frame header) and wraps it
        /// as a framed message.
        /// </summary>
        protected override ArraySegment<byte> EncodeMessage(Message message)
        {
            ArraySegment<byte> payload = MessageEncoder.WriteMessage(message,
                int.MaxValue, this.BufferManager, SessionEncoder.MaxMessageFrameSize);
            return SessionEncoder.EncodeMessageFrame(payload);
        }

        /// <summary>Session object for framing duplex channels.</summary>
        internal class FramingConnectionDuplexSession : ConnectionDuplexSession
        {
            private FramingConnectionDuplexSession(FramingDuplexSessionChannel channel)
                : base(channel)
            {
            }

            /// <summary>
            /// Returns a security-aware session (implementing <c>ISecuritySession</c>) when
            /// <paramref name="upgrade"/> is a <c>StreamSecurityUpgradeProvider</c>, otherwise a plain session.
            /// </summary>
            public static FramingConnectionDuplexSession CreateSession(FramingDuplexSessionChannel channel,
                StreamUpgradeProvider upgrade)
            {
                if (upgrade is StreamSecurityUpgradeProvider)
                {
                    return new SecureConnectionDuplexSession(channel);
                }

                return new FramingConnectionDuplexSession(channel);
            }

            /// <summary>Session that exposes the remote identity derived from the channel's remote security.</summary>
            private class SecureConnectionDuplexSession : FramingConnectionDuplexSession, ISecuritySession
            {
                // Cached once successfully created; stays null until the remote security context is available.
                private EndpointIdentity _remoteIdentity;

                public SecureConnectionDuplexSession(FramingDuplexSessionChannel channel)
                    : base(channel)
                {
                }

                EndpointIdentity ISecuritySession.RemoteIdentity
                {
                    get
                    {
                        if (_remoteIdentity == null)
                        {
                            _remoteIdentity = CreateRemoteIdentityOrNull(this.Channel.RemoteSecurity);
                        }

                        return _remoteIdentity;
                    }
                }

                // Builds an identity from the context's IdentityClaim, or returns null when the security
                // property, its context, the identity claim, or the primary identity is missing.
                private static EndpointIdentity CreateRemoteIdentityOrNull(SecurityMessageProperty security)
                {
                    if (security == null)
                    {
                        return null;
                    }

                    var context = security.ServiceSecurityContext;
                    if (context == null || context.IdentityClaim == null || context.PrimaryIdentity == null)
                    {
                        return null;
                    }

                    return EndpointIdentity.CreateIdentity(context.IdentityClaim);
                }
            }
        }
    }
 
    /// <summary>
    /// Client side of a framing duplex session. Opening the channel obtains a connection through
    /// <c>DuplexConnectionPoolHelper</c>, which sends the session preamble (performing any configured
    /// stream upgrade) and validates the server's one-byte acknowledgement before the connection is
    /// accepted as this channel's message source.
    /// </summary>
    internal class ClientFramingDuplexSessionChannel : FramingDuplexSessionChannel
    {
        private readonly IConnectionOrientedTransportChannelFactorySettings _settings;

        // Recreated for every preamble exchange; reused afterwards by the ClientDuplexConnectionReader.
        private ClientDuplexDecoder _decoder;

        // Optional stream upgrade; null when no upgrade is configured.
        private readonly StreamUpgradeProvider _upgrade;
        private readonly ConnectionPoolHelper _connectionPoolHelper;
        private readonly bool _flowIdentity;

        public ClientFramingDuplexSessionChannel(ChannelManagerBase factory, IConnectionOrientedTransportChannelFactorySettings settings,
            EndpointAddress remoteAddresss, Uri via, IConnectionInitiator connectionInitiator, ConnectionPool connectionPool,
            bool exposeConnectionProperty, bool flowIdentity)
            : base(factory, settings, remoteAddresss, via, exposeConnectionProperty)
        {
            _settings = settings;
            this.MessageEncoder = settings.MessageEncoderFactory.CreateSessionEncoder();
            _upgrade = settings.Upgrade;
            _flowIdentity = flowIdentity;

            // Must come last: the helper's constructor calls CreatePreamble, which reads Via,
            // MessageEncoder and _upgrade.
            _connectionPoolHelper = new DuplexConnectionPoolHelper(this, connectionPool, connectionInitiator);
        }

        /// <summary>
        /// Builds the preamble: mode bytes followed by the encoded via and content type. When no
        /// upgrade is configured, the preamble-end record is appended to the same buffer; otherwise it
        /// is written separately after the upgrade (see SendPreamble / SendPreambleAsync).
        /// </summary>
        private ArraySegment<byte> CreatePreamble()
        {
            var via = new EncodedVia(this.Via.AbsoluteUri);
            EncodedContentType contentType = EncodedContentType.Create(this.MessageEncoder.ContentType);
            byte[] modeBytes = ClientDuplexEncoder.ModeBytes;

            int headerSize = modeBytes.Length + SessionEncoder.CalcStartSize(via, contentType);
            bool noUpgrade = _upgrade == null;
            int preambleEndOffset = noUpgrade ? headerSize : 0;
            int totalSize = noUpgrade ? headerSize + ClientDuplexEncoder.PreambleEndBytes.Length : headerSize;

            byte[] buffer = Fx.AllocateByteArray(totalSize);
            Buffer.BlockCopy(modeBytes, 0, buffer, 0, modeBytes.Length);
            SessionEncoder.EncodeStart(buffer, modeBytes.Length, via, contentType);

            if (preambleEndOffset > 0)
            {
                byte[] preambleEnd = ClientDuplexEncoder.PreambleEndBytes;
                Buffer.BlockCopy(preambleEnd, 0, buffer, preambleEndOffset, preambleEnd.Length);
            }

            return new ArraySegment<byte>(buffer, 0, totalSize);
        }

        /// <summary>APM wrapper over <c>OnOpenAsync</c>.</summary>
        protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return OnOpenAsync(timeout).ToApm(callback, state);
        }

        /// <summary>Completes an open started by <c>OnBeginOpen</c>.</summary>
        protected override void OnEndOpen(IAsyncResult result)
        {
            result.ToApmEnd();
        }

        /// <summary>
        /// Returns the base channel's property, falling back to the upgrade provider's property when
        /// the base returns null and an upgrade is configured.
        /// </summary>
        public override T GetProperty<T>()
        {
            T property = base.GetProperty<T>();
            if (property != null || _upgrade == null)
            {
                return property;
            }

            return _upgrade.GetProperty<T>();
        }

        /// <summary>
        /// Asynchronously sends the preamble over <paramref name="connection"/>, performs any configured
        /// upgrade, and validates the server acknowledgement. Returns the connection to use from now on,
        /// which may be an upgraded wrapper of the one passed in.
        /// </summary>
        private async Task<IConnection> SendPreambleAsync(IConnection connection, ArraySegment<byte> preamble, TimeSpan timeout)
        {
            var timeoutHelper = new TimeoutHelper(timeout);

            _decoder = new ClientDuplexDecoder(0);
            await connection.WriteAsync(preamble.Array, preamble.Offset, preamble.Count, true, timeoutHelper.RemainingTime());

            if (_upgrade != null)
            {
                StreamUpgradeInitiator initiator = _upgrade.CreateUpgradeInitiator(this.RemoteAddress, this.Via);
                await initiator.OpenAsync(timeoutHelper.RemainingTime());

                // The upgrade may replace the connection, so it travels in and out through a wrapper.
                var connectionHolder = new OutWrapper<IConnection> { Value = connection };
                bool upgraded = await ConnectionUpgradeHelper.InitiateUpgradeAsync(initiator, connectionHolder, _decoder, this, timeoutHelper.RemainingTime());
                connection = connectionHolder.Value;

                if (!upgraded)
                {
                    // Reads the server's fault and always throws.
                    await ConnectionUpgradeHelper.DecodeFramingFaultAsync(_decoder, connection, this.Via, MessageEncoder.ContentType, timeoutHelper.RemainingTime());
                }

                SetRemoteSecurity(initiator);
                await initiator.CloseAsync(timeoutHelper.RemainingTime());

                byte[] preambleEnd = ClientDuplexEncoder.PreambleEndBytes;
                await connection.WriteAsync(preambleEnd, 0, preambleEnd.Length, true, timeoutHelper.RemainingTime());
            }

            byte[] ack = new byte[1];
            int ackLength = await connection.ReadAsync(ack, 0, ack.Length, timeoutHelper.RemainingTime());

            if (!ConnectionUpgradeHelper.ValidatePreambleResponse(ack, ackLength, _decoder, Via))
            {
                await ConnectionUpgradeHelper.DecodeFramingFaultAsync(_decoder, connection, Via,
                    MessageEncoder.ContentType, timeoutHelper.RemainingTime());
            }

            return connection;
        }

        /// <summary>
        /// Synchronous counterpart of <c>SendPreambleAsync</c>; consumes time from the caller's
        /// <paramref name="timeoutHelper"/>.
        /// </summary>
        private IConnection SendPreamble(IConnection connection, ArraySegment<byte> preamble, ref TimeoutHelper timeoutHelper)
        {
            _decoder = new ClientDuplexDecoder(0);
            connection.Write(preamble.Array, preamble.Offset, preamble.Count, true, timeoutHelper.RemainingTime());

            if (_upgrade != null)
            {
                StreamUpgradeInitiator initiator = _upgrade.CreateUpgradeInitiator(this.RemoteAddress, this.Via);
                initiator.Open(timeoutHelper.RemainingTime());

                // On success, InitiateUpgrade replaces 'connection' with the upgraded wrapper.
                bool upgraded = ConnectionUpgradeHelper.InitiateUpgrade(initiator, ref connection, _decoder, this, ref timeoutHelper);
                if (!upgraded)
                {
                    // Reads the server's fault and always throws.
                    ConnectionUpgradeHelper.DecodeFramingFault(_decoder, connection, this.Via, MessageEncoder.ContentType, ref timeoutHelper);
                }

                SetRemoteSecurity(initiator);
                initiator.Close(timeoutHelper.RemainingTime());

                byte[] preambleEnd = ClientDuplexEncoder.PreambleEndBytes;
                connection.Write(preambleEnd, 0, preambleEnd.Length, true, timeoutHelper.RemainingTime());
            }

            // The server acknowledges the preamble with a single framing byte.
            byte[] ack = new byte[1];
            int ackLength = connection.Read(ack, 0, ack.Length, timeoutHelper.RemainingTime());

            if (!ConnectionUpgradeHelper.ValidatePreambleResponse(ack, ackLength, _decoder, Via))
            {
                ConnectionUpgradeHelper.DecodeFramingFault(_decoder, connection, Via,
                    MessageEncoder.ContentType, ref timeoutHelper);
            }

            return connection;
        }

        /// <summary>
        /// Asynchronously establishes the session connection and accepts it. A timeout while
        /// establishing the connection is rethrown as a TimeoutException carrying the open timeout.
        /// </summary>
        protected internal override async Task OnOpenAsync(TimeSpan timeout)
        {
            IConnection connection;
            try
            {
                connection = await _connectionPoolHelper.EstablishConnectionAsync(timeout);
            }
            catch (TimeoutException exception)
            {
                throw CreateOpenTimeoutException(timeout, exception);
            }

            AcceptConnectionOrAbortPool(connection);
        }

        /// <summary>Synchronous counterpart of <c>OnOpenAsync</c>.</summary>
        protected override void OnOpen(TimeSpan timeout)
        {
            IConnection connection;
            try
            {
                connection = _connectionPoolHelper.EstablishConnection(timeout);
            }
            catch (TimeoutException exception)
            {
                throw CreateOpenTimeoutException(timeout, exception);
            }

            AcceptConnectionOrAbortPool(connection);
        }

        // Wraps a connection-establishment timeout into the TimeoutOnOpen error, routed through
        // DiagnosticUtility so it is traced like other helper errors.
        private static Exception CreateOpenTimeoutException(TimeSpan timeout, TimeoutException innerException)
        {
            return DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                new TimeoutException(string.Format(SRServiceModel.TimeoutOnOpen, timeout), innerException));
        }

        // Accepts the connection; if anything escapes AcceptConnection, the pool helper is aborted.
        private void AcceptConnectionOrAbortPool(IConnection connection)
        {
            bool accepted = false;
            try
            {
                AcceptConnection(connection);
                accepted = true;
            }
            finally
            {
                if (!accepted)
                {
                    _connectionPoolHelper.Abort();
                }
            }
        }

        /// <summary>
        /// Aborts the pool helper when <paramref name="abort"/> is true, otherwise closes it within
        /// <paramref name="timeout"/>. Runs under <c>ThisLock</c>.
        /// </summary>
        protected override void ReturnConnectionIfNecessary(bool abort, TimeSpan timeout)
        {
            lock (ThisLock)
            {
                if (!abort)
                {
                    _connectionPoolHelper.Close(timeout);
                    return;
                }

                _connectionPoolHelper.Abort();
            }
        }

        // Installs a reader for the connection as the message source, then (under ThisLock) stores the
        // connection, failing with CommunicationObjectAbortedException if the channel left the Opening state.
        private void AcceptConnection(IConnection connection)
        {
            base.SetMessageSource(new ClientDuplexConnectionReader(this, connection, _decoder, _settings, MessageEncoder));

            lock (ThisLock)
            {
                if (this.State != CommunicationState.Opening)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        new CommunicationObjectAbortedException(string.Format(SRServiceModel.DuplexChannelAbortedDuringOpen, this.Via)));
                }

                this.Connection = connection;
            }
        }

        // Records the remote security reported by the upgrade initiator.
        private void SetRemoteSecurity(StreamUpgradeInitiator upgradeInitiator)
        {
            this.RemoteSecurity = StreamSecurityUpgradeInitiator.GetRemoteSecurity(upgradeInitiator);
        }

        /// <summary>Defers entirely to the base implementation.</summary>
        protected override void PrepareMessage(Message message)
        {
            base.PrepareMessage(message);
        }

        /// <summary>
        /// Pool helper whose accept step sends this channel's session preamble over the connection.
        /// The preamble bytes are built once, in the constructor.
        /// </summary>
        internal class DuplexConnectionPoolHelper : ConnectionPoolHelper
        {
            private readonly ClientFramingDuplexSessionChannel _channel;
            private readonly ArraySegment<byte> _preamble;

            public DuplexConnectionPoolHelper(ClientFramingDuplexSessionChannel channel,
                ConnectionPool connectionPool, IConnectionInitiator connectionInitiator)
                : base(connectionPool, connectionInitiator, channel.Via)
            {
                _channel = channel;
                _preamble = channel.CreatePreamble();
            }

            /// <summary>Creates the timeout error for a transport session that could not be established in time.</summary>
            protected override TimeoutException CreateNewConnectionTimeoutException(TimeSpan timeout, TimeoutException innerException)
            {
                string message = string.Format(SRServiceModel.OpenTimedOutEstablishingTransportSession,
                    timeout, _channel.Via.AbsoluteUri);
                return new TimeoutException(message, innerException);
            }

            /// <summary>Sends the preamble synchronously; returns the (possibly upgraded) connection.</summary>
            protected override IConnection AcceptPooledConnection(IConnection connection, ref TimeoutHelper timeoutHelper)
            {
                return _channel.SendPreamble(connection, _preamble, ref timeoutHelper);
            }

            /// <summary>Sends the preamble asynchronously using the helper's remaining time.</summary>
            protected override Task<IConnection> AcceptPooledConnectionAsync(IConnection connection, ref TimeoutHelper timeoutHelper)
            {
                TimeSpan remaining = timeoutHelper.RemainingTime();
                return _channel.SendPreambleAsync(connection, _preamble, remaining);
            }
        }
    }
 
    /// <summary>
    /// Client-side helpers for the framing handshake: stream upgrade negotiation, preamble/upgrade
    /// response validation, and decoding of server fault strings. Used by StreamedFramingRequestChannel
    /// and ClientFramingDuplexSessionChannel.
    /// </summary>
    internal class ConnectionUpgradeHelper
    {
        /// <summary>
        /// Reads the server's fault string from <paramref name="connection"/> (into its AsyncReadBuffer)
        /// and throws the corresponding exception. This method never completes normally.
        /// </summary>
        /// <exception cref="System.ServiceModel.Security.MessageSecurityException">
        /// The decoder is not in the ReadingFaultString state.</exception>
        /// <remarks>
        /// A decoded fault closes the connection (best effort) and is thrown as the exception built by
        /// <c>FaultStringDecoder.GetFaultException</c>; running out of data throws the decoder's
        /// premature-EOF exception.
        /// </remarks>
        public static async Task DecodeFramingFaultAsync(ClientFramingDecoder decoder, IConnection connection,
            Uri via, string contentType, TimeSpan timeout)
        {
            var timeoutHelper = new TimeoutHelper(timeout);
            ValidateReadingFaultString(decoder);

            int size = await connection.ReadAsync(0,
                Math.Min(FaultStringDecoder.FaultSizeQuota, connection.AsyncReadBufferSize),
                timeoutHelper.RemainingTime());

            int offset = 0;
            while (size > 0)
            {
                int bytesDecoded = decoder.Decode(connection.AsyncReadBuffer, offset, size);
                offset += bytesDecoded;
                size -= bytesDecoded;

                if (decoder.CurrentState == ClientFramingDecoderState.Fault)
                {
                    ConnectionUtilities.CloseNoThrow(connection, timeoutHelper.RemainingTime());
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        FaultStringDecoder.GetFaultException(decoder.Fault, via.ToString(), contentType));
                }
                else
                {
                    if (decoder.CurrentState != ClientFramingDecoderState.ReadingFaultString)
                    {
                        throw new Exception("invalid framing client state machine");
                    }
                    if (size == 0)
                    {
                        // Buffer consumed but the fault string is incomplete: read the next chunk.
                        offset = 0;
                        size = await connection.ReadAsync(0,
                            Math.Min(FaultStringDecoder.FaultSizeQuota, connection.AsyncReadBufferSize),
                            timeoutHelper.RemainingTime());
                    }
                }
            }

            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(decoder.CreatePrematureEOFException());
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="DecodeFramingFaultAsync"/>: reads into a
        /// FaultSizeQuota-sized buffer and always throws.
        /// </summary>
        public static void DecodeFramingFault(ClientFramingDecoder decoder, IConnection connection,
            Uri via, string contentType, ref TimeoutHelper timeoutHelper)
        {
            ValidateReadingFaultString(decoder);

            int offset = 0;
            byte[] faultBuffer = Fx.AllocateByteArray(FaultStringDecoder.FaultSizeQuota);
            int size = connection.Read(faultBuffer, offset, faultBuffer.Length, timeoutHelper.RemainingTime());

            while (size > 0)
            {
                int bytesDecoded = decoder.Decode(faultBuffer, offset, size);
                offset += bytesDecoded;
                size -= bytesDecoded;

                if (decoder.CurrentState == ClientFramingDecoderState.Fault)
                {
                    ConnectionUtilities.CloseNoThrow(connection, timeoutHelper.RemainingTime());
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        FaultStringDecoder.GetFaultException(decoder.Fault, via.ToString(), contentType));
                }
                else
                {
                    if (decoder.CurrentState != ClientFramingDecoderState.ReadingFaultString)
                    {
                        throw new Exception("invalid framing client state machine");
                    }
                    if (size == 0)
                    {
                        // Buffer consumed but the fault string is incomplete: read the next chunk.
                        offset = 0;
                        size = connection.Read(faultBuffer, offset, faultBuffer.Length, timeoutHelper.RemainingTime());
                    }
                }
            }

            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(decoder.CreatePrematureEOFException());
        }

        /// <summary>
        /// Negotiates every upgrade offered by <paramref name="upgradeInitiator"/>. For each one it writes
        /// the upgrade request record, reads a one-byte response, and on success re-wraps
        /// <paramref name="connection"/> as a StreamConnection over the upgraded stream.
        /// </summary>
        /// <returns>
        /// True when all upgrades succeed; false as soon as a response is not an upgrade acknowledgement
        /// (the caller is then expected to decode the server's fault).
        /// </returns>
        public static bool InitiateUpgrade(StreamUpgradeInitiator upgradeInitiator, ref IConnection connection,
            ClientFramingDecoder decoder, IDefaultCommunicationTimeouts defaultTimeouts, ref TimeoutHelper timeoutHelper)
        {
            string upgradeContentType = upgradeInitiator.GetNextUpgrade();

            while (upgradeContentType != null)
            {
                EncodedUpgrade encodedUpgrade = new EncodedUpgrade(upgradeContentType);
                // write upgrade request framing for synchronization
                connection.Write(encodedUpgrade.EncodedBytes, 0, encodedUpgrade.EncodedBytes.Length, true, timeoutHelper.RemainingTime());
                byte[] buffer = new byte[1];

                // read upgrade response framing
                int size = connection.Read(buffer, 0, buffer.Length, timeoutHelper.RemainingTime());

                if (!ValidateUpgradeResponse(buffer, size, decoder)) // we have a problem
                {
                    return false;
                }

                // initiate wire upgrade
                ConnectionStream connectionStream = new ConnectionStream(connection, defaultTimeouts);
                Stream upgradedStream = upgradeInitiator.InitiateUpgrade(connectionStream);

                // and re-wrap connection
                connection = new StreamConnection(upgradedStream, connectionStream);

                upgradeContentType = upgradeInitiator.GetNextUpgrade();
            }

            return true;
        }

        /// <summary>
        /// Asynchronous counterpart of <see cref="InitiateUpgrade"/>. The current (possibly re-wrapped)
        /// connection is published through <paramref name="connectionWrapper"/> after every upgrade.
        /// </summary>
        /// <param name="timeout">Total time budget for all upgrade round trips.</param>
        public static async Task<bool> InitiateUpgradeAsync(StreamUpgradeInitiator upgradeInitiator, OutWrapper<IConnection> connectionWrapper,
            ClientFramingDecoder decoder, IDefaultCommunicationTimeouts defaultTimeouts, TimeSpan timeout)
        {
            // Track a single deadline across every write/read, as the synchronous InitiateUpgrade does.
            // Passing the raw timeout to each operation let the total exceed the caller's budget.
            var timeoutHelper = new TimeoutHelper(timeout);
            IConnection connection = connectionWrapper.Value;
            string upgradeContentType = upgradeInitiator.GetNextUpgrade();

            while (upgradeContentType != null)
            {
                EncodedUpgrade encodedUpgrade = new EncodedUpgrade(upgradeContentType);
                // write upgrade request framing for synchronization
                await connection.WriteAsync(encodedUpgrade.EncodedBytes, 0, encodedUpgrade.EncodedBytes.Length, true, timeoutHelper.RemainingTime());
                byte[] buffer = new byte[1];

                // read upgrade response framing
                int size = await connection.ReadAsync(buffer, 0, buffer.Length, timeoutHelper.RemainingTime());

                if (!ValidateUpgradeResponse(buffer, size, decoder)) // we have a problem
                {
                    return false;
                }

                // initiate wire upgrade
                ConnectionStream connectionStream = new ConnectionStream(connection, defaultTimeouts);
                Stream upgradedStream = await upgradeInitiator.InitiateUpgradeAsync(connectionStream);

                // and re-wrap connection
                connection = new StreamConnection(upgradedStream, connectionStream);
                connectionWrapper.Value = connection;

                upgradeContentType = upgradeInitiator.GetNextUpgrade();
            }

            return true;
        }

        // Throws MessageSecurityException (ServerRejectedUpgradeRequest) unless the decoder is reading a fault string.
        private static void ValidateReadingFaultString(ClientFramingDecoder decoder)
        {
            if (decoder.CurrentState != ClientFramingDecoderState.ReadingFaultString)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new System.ServiceModel.Security.MessageSecurityException(
                    SRServiceModel.ServerRejectedUpgradeRequest));
            }
        }

        /// <summary>
        /// Feeds the preamble acknowledgement byte(s) to <paramref name="decoder"/>.
        /// </summary>
        /// <returns>True when the decoder ends in the Start state; false otherwise.</returns>
        /// <exception cref="ProtocolException">No bytes were read (ServerRejectedSessionPreamble).</exception>
        public static bool ValidatePreambleResponse(byte[] buffer, int count, ClientFramingDecoder decoder, Uri via)
        {
            if (count == 0)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new ProtocolException(string.Format(SRServiceModel.ServerRejectedSessionPreamble, via),
                    decoder.CreatePrematureEOFException()));
            }

            // decode until the framing byte has been processed (it always will be)
            while (decoder.Decode(buffer, 0, count) == 0)
            {
                // do nothing
            }

            if (decoder.CurrentState != ClientFramingDecoderState.Start) // we have a problem
            {
                return false;
            }

            return true;
        }

        // Feeds the upgrade response byte(s) to the decoder; true when it ends in the UpgradeResponse
        // state. Throws MessageSecurityException when no bytes were read.
        private static bool ValidateUpgradeResponse(byte[] buffer, int count, ClientFramingDecoder decoder)
        {
            if (count == 0)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new MessageSecurityException(SRServiceModel.ServerRejectedUpgradeRequest, decoder.CreatePrematureEOFException()));
            }

            // decode until the framing byte has been processed (it always will be)
            while (decoder.Decode(buffer, 0, count) == 0)
            {
                // do nothing
            }

            if (decoder.CurrentState != ClientFramingDecoderState.UpgradeResponse) // we have a problem
            {
                return false;
            }

            return true;
        }
    }
}