File: System\ServiceModel\Channels\FramingChannels.cs
Web Access
Project: src\src\System.ServiceModel.NetFramingBase\src\System.ServiceModel.NetFramingBase.csproj (System.ServiceModel.NetFramingBase)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
 
using System.Diagnostics.Contracts;
using System.IO;
using System.Runtime;
using System.ServiceModel.Security;
using System.Threading.Tasks;
using System.IdentityModel.Claims;
using System.Buffers;
using System.Security.Authentication.ExtendedProtection;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Base duplex session channel for framing (connection-oriented) transports. Outgoing messages are
    /// encoded as buffered message frames via the session encoder and written to <see cref="Connection"/>.
    /// Streamed output is not supported by this channel.
    /// </summary>
    internal abstract class FramingDuplexSessionChannel : TransportDuplexSessionChannel
    {
        // Local address handed to the base channel by the protected constructor. It is only ever
        // assigned here, so it is readonly.
        private static readonly EndpointAddress s_anonymousEndpointAddress = new EndpointAddress(EndpointAddress.AnonymousUri, new AddressHeader[0]);

        private FramingDuplexSessionChannel(ChannelManagerBase manager, IConnectionOrientedTransportFactorySettings settings,
            EndpointAddress localAddress, Uri localVia, EndpointAddress remoteAddress, Uri via)
            : base(manager, settings, localAddress, localVia, remoteAddress, via)
        {
        }

        /// <summary>
        /// Creates the channel with the anonymous endpoint as its local address. The local via is
        /// the WS-Addressing anonymous URI unless the message version uses no addressing, in which
        /// case it is null.
        /// </summary>
        /// <param name="factory">Owning channel manager.</param>
        /// <param name="settings">Transport settings; <c>settings.Upgrade</c> selects the session type.</param>
        /// <param name="remoteAddress">Remote endpoint address.</param>
        /// <param name="via">Physical address the connection targets.</param>
        protected FramingDuplexSessionChannel(ChannelManagerBase factory, IConnectionOrientedTransportFactorySettings settings,
            EndpointAddress remoteAddress, Uri via)
            : this(factory, settings, s_anonymousEndpointAddress, settings.MessageVersion.Addressing == AddressingVersion.None ? null : new Uri("http://www.w3.org/2005/08/addressing/anonymous"),
            remoteAddress, via)
        {
            Session = FramingConnectionDuplexSession.CreateSession(this, settings.Upgrade);
        }

        /// <summary>The underlying connection that frames are written to.</summary>
        protected IConnection Connection { get; set; }

        /// <summary>Always false: this channel only produces buffered output.</summary>
        protected override bool IsStreamedOutput => false;

        /// <summary>Writes the session end record to the connection.</summary>
        protected override ValueTask CloseOutputSessionCoreAsync(TimeSpan timeout)
        {
            return Connection.WriteAsync(SessionEncoder.EndBytes, true, timeout);
        }

        /// <summary>Releases the connection (non-abortive) once the channel has closed.</summary>
        protected override void CompleteClose(TimeSpan timeout)
        {
            ReturnConnectionIfNecessary(false, timeout);
        }

        /// <summary>
        /// Encodes <paramref name="message"/> into a pooled frame buffer and writes it to the connection.
        /// </summary>
        protected override async ValueTask OnSendCoreAsync(Message message, TimeSpan timeout)
        {
            bool allowOutputBatching = message.Properties.AllowOutputBatching;
            ArraySegment<byte> messageData = EncodeMessage(message);

            await TaskHelpers.EnsureDefaultTaskScheduler();
            // The write's bool argument is the negation of AllowOutputBatching
            // (presumably "write immediately" — confirm against IConnection.WriteAsync).
            await Connection.WriteAsync(messageData, !allowOutputBatching, timeout);

            // Only reached when the write completes successfully; if it throws, the buffer is
            // not handed back to the BufferManager.
            BufferManager.ReturnBuffer(messageData.Array);
        }

        /// <summary>Writes the session end record to the connection.</summary>
        protected override ValueTask CloseOutputAsync(TimeSpan timeout)
        {
            return Connection.WriteAsync(SessionEncoder.EndBytes, true, timeout);
        }

        /// <summary>Writes an already-encoded message frame to the connection.</summary>
        protected override ValueTask StartWritingBufferedMessage(Message message, ArraySegment<byte> messageData, bool allowOutputBatching, TimeSpan timeout)
        {
            return Connection.WriteAsync(messageData, !allowOutputBatching, timeout);
        }

        /// <summary>Not supported: <see cref="IsStreamedOutput"/> is always false for this channel.</summary>
        protected override ValueTask StartWritingStreamedMessage(Message message, TimeSpan timeout)
        {
            Contract.Assert(false, "Streamed output should never be called in this channel.");
            return ValueTask.FromException(new InvalidOperationException());
        }

        /// <summary>
        /// Serializes <paramref name="message"/> with the message encoder into a buffer from
        /// <c>BufferManager</c>, then wraps it as a session message frame.
        /// </summary>
        protected override ArraySegment<byte> EncodeMessage(Message message)
        {
            ArraySegment<byte> messageData = MessageEncoder.WriteMessage(message,
                int.MaxValue, BufferManager, SessionEncoder.MaxMessageFrameSize);

            return SessionEncoder.EncodeMessageFrame(messageData);
        }

        /// <summary>
        /// Duplex session for framing channels. A security-capable subtype is used when the
        /// transport upgrade is a <see cref="StreamSecurityUpgradeProvider"/>.
        /// </summary>
        internal class FramingConnectionDuplexSession : ConnectionDuplexSession
        {
            private FramingConnectionDuplexSession(FramingDuplexSessionChannel channel)
                : base(channel)
            {
            }

            /// <summary>
            /// Creates a <see cref="SecureConnectionDuplexSession"/> when <paramref name="upgrade"/> is a
            /// <see cref="StreamSecurityUpgradeProvider"/>; otherwise (including when it is null) a
            /// plain framing session.
            /// </summary>
            public static FramingConnectionDuplexSession CreateSession(FramingDuplexSessionChannel channel,
                StreamUpgradeProvider upgrade)
            {
                return upgrade is StreamSecurityUpgradeProvider
                    ? new SecureConnectionDuplexSession(channel)
                    : new FramingConnectionDuplexSession(channel);
            }

            /// <summary>
            /// Session that exposes the remote identity derived from the channel's remote security
            /// properties, as an <see cref="ISecuritySession"/>.
            /// </summary>
            private class SecureConnectionDuplexSession : FramingConnectionDuplexSession, ISecuritySession
            {
                // Cached once successfully computed; stays null until security info is available.
                private EndpointIdentity _remoteIdentity;

                public SecureConnectionDuplexSession(FramingDuplexSessionChannel channel)
                    : base(channel)
                {
                    // empty
                }

                /// <summary>
                /// Identity built from the primary identity claim of the remote security context.
                /// Returns null if no security context, claim or primary identity is available yet.
                /// </summary>
                EndpointIdentity ISecuritySession.RemoteIdentity
                {
                    get
                    {
                        if (_remoteIdentity == null)
                        {
                            SecurityMessageProperty security = Channel.RemoteSecurity;
                            if (security != null && security.ServiceSecurityContext != null)
                            {
                                Claim identityClaim = SecurityUtils.GetPrimaryIdentityClaim(security.ServiceSecurityContext.AuthorizationContext);
                                if (identityClaim != null && security.ServiceSecurityContext.PrimaryIdentity != null)
                                {
                                    _remoteIdentity = EndpointIdentity.CreateIdentity(identityClaim);
                                }
                            }
                        }

                        return _remoteIdentity;
                    }
                }
            }
        }
    }
 
    /// <summary>
    /// Client side of a framing duplex session. On open it obtains a connection through a
    /// <see cref="DuplexConnectionPoolHelper"/>, which sends the session preamble (performing any
    /// configured stream upgrade) and validates the server's acknowledgement. The channel then
    /// starts reading messages from that connection.
    /// </summary>
    internal class ClientFramingDuplexSessionChannel : FramingDuplexSessionChannel
    {
        private readonly IConnectionOrientedTransportChannelFactorySettings _settings;
        // Recreated on every preamble exchange and passed to the connection reader by AcceptConnection.
        private ClientDuplexDecoder _decoder;
        private readonly StreamUpgradeProvider _upgrade;
        private readonly ConnectionPoolHelper _connectionPoolHelper;

        public ClientFramingDuplexSessionChannel(ChannelManagerBase factory, IConnectionOrientedTransportChannelFactorySettings settings,
            EndpointAddress remoteAddress, Uri via, IConnectionInitiator connectionInitiator, ConnectionPool connectionPool)
            : base(factory, settings, remoteAddress, via)
        {
            _settings = settings;
            MessageEncoder = settings.MessageEncoderFactory.CreateSessionEncoder();
            _upgrade = settings.Upgrade;
            // Must run last: the helper's constructor calls CreatePreamble, which reads
            // MessageEncoder and _upgrade.
            _connectionPoolHelper = new DuplexConnectionPoolHelper(this, connectionPool, connectionInitiator);
        }

        /// <summary>
        /// Builds the preamble bytes: mode bytes, then the start record (via and content type).
        /// When no upgrade is configured, the preamble-end record is appended too. With an upgrade,
        /// the preamble end is written separately after the upgrade completes (see SendPreambleAsync).
        /// </summary>
        private Memory<byte> CreatePreamble()
        {
            EncodedVia encodedVia = new EncodedVia(Via.AbsoluteUri);
            EncodedContentType encodedContentType = EncodedContentType.Create(MessageEncoder.ContentType);

            // calculate preamble length
            int startSize = ClientDuplexEncoder.ModeBytes.Length + SessionEncoder.CalcStartSize(encodedVia, encodedContentType);
            int preambleEndOffset = 0;
            if (_upgrade == null)
            {
                preambleEndOffset = startSize;
                startSize += SessionEncoder.PreambleEndBytes.Length;
            }

            Memory<byte> startBytes = Fx.AllocateByteArray(startSize);
            ClientDuplexEncoder.ModeBytes.CopyTo(startBytes);
            SessionEncoder.EncodeStart(startBytes.Slice(ClientDuplexEncoder.ModeBytes.Length), encodedVia, encodedContentType);
            if (preambleEndOffset > 0)
            {
                SessionEncoder.PreambleEndBytes.CopyTo(startBytes.Slice(preambleEndOffset));
            }

            return startBytes;
        }

        /// <summary>APM adapter over <see cref="OnOpenAsync"/>.</summary>
        protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return OnOpenAsync(timeout).ToApm(callback, state);
        }

        /// <summary>Completes an open started by <see cref="OnBeginOpen"/>.</summary>
        protected override void OnEndOpen(IAsyncResult result)
        {
            result.ToApmEnd();
        }

        /// <summary>
        /// Returns the base channel's property, falling back to the stream upgrade provider (if any)
        /// when the base channel does not supply one.
        /// </summary>
        public override T GetProperty<T>()
        {
            T result = base.GetProperty<T>();

            if (result == null && _upgrade != null)
            {
                result = _upgrade.GetProperty<T>();
            }

            return result;
        }

        /// <summary>
        /// Sends the preamble on <paramref name="connection"/>, performs the stream upgrade when one is
        /// configured, and validates the one-byte acknowledgement.
        /// </summary>
        /// <returns>
        /// The connection to use from now on. After an upgrade this is the upgraded connection returned by
        /// <see cref="ConnectionUpgradeHelper.InitiateUpgradeAsync"/>.
        /// </returns>
        private async Task<IConnection> SendPreambleAsync(IConnection connection, Memory<byte> preamble, TimeSpan timeout)
        {
            var timeoutHelper = new TimeoutHelper(timeout);

            // initialize a new decoder
            _decoder = new ClientDuplexDecoder(0);
            byte[] ackBuffer = new byte[1];

            if (!await SendLock.WaitAsync(TimeoutHelper.ToMilliseconds(timeout)))
            {
                // The preamble is sent while the transport session is being established during open,
                // so report an open timeout rather than a close timeout.
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(
                                                SR.Format(SR.OpenTimedOutEstablishingTransportSession, timeout, Via.AbsoluteUri),
                                                TimeoutHelper.CreateEnterTimedOutException(timeout)));
            }

            try
            {
                await connection.WriteAsync(preamble, true, timeoutHelper.RemainingTime());

                if (_upgrade != null)
                {
                    IStreamUpgradeChannelBindingProvider channelBindingProvider = _upgrade.GetProperty<IStreamUpgradeChannelBindingProvider>();
                    StreamUpgradeInitiator upgradeInitiator = _upgrade.CreateUpgradeInitiator(RemoteAddress, Via);

                    await upgradeInitiator.OpenAsync(timeoutHelper.RemainingTime());
                    bool upgradeInitiated;
                    (upgradeInitiated, connection) = await ConnectionUpgradeHelper.InitiateUpgradeAsync(upgradeInitiator, connection, _decoder, this, timeoutHelper.RemainingTime());
                    if (!upgradeInitiated)
                    {
                        // Server did not accept the upgrade; this decodes its fault and throws.
                        await ConnectionUpgradeHelper.DecodeFramingFaultAsync(_decoder, connection, Via, MessageEncoder.ContentType, timeoutHelper.RemainingTime());
                    }

                    if (channelBindingProvider != null && channelBindingProvider.IsChannelBindingSupportEnabled)
                    {
                        SetChannelBinding(channelBindingProvider.GetChannelBinding(upgradeInitiator, ChannelBindingKind.Endpoint));
                    }

                    SetRemoteSecurity(upgradeInitiator);
                    await upgradeInitiator.CloseAsync(timeoutHelper.RemainingTime());

                    // With an upgrade, CreatePreamble omitted the preamble end; send it now.
                    await connection.WriteAsync(SessionEncoder.PreambleEndBytes, true, timeoutHelper.RemainingTime());
                }

                int ackBytesRead = await connection.ReadAsync(ackBuffer, timeoutHelper.RemainingTime());

                if (!ConnectionUpgradeHelper.ValidatePreambleResponse(ackBuffer, ackBytesRead, _decoder, Via))
                {
                    await ConnectionUpgradeHelper.DecodeFramingFaultAsync(_decoder, connection, Via,
                        MessageEncoder.ContentType, timeoutHelper.RemainingTime());
                }

                return connection;
            }
            finally
            {
                SendLock.Release();
            }
        }

        /// <summary>
        /// Establishes the connection and attaches it to the channel. A <see cref="TimeoutException"/>
        /// from establishing the connection is rethrown with an open-timeout message. If attaching
        /// fails, the pool helper is aborted.
        /// </summary>
        internal protected override async Task OnOpenAsync(TimeSpan timeout)
        {
            IConnection connection;
            try
            {
                using (TaskHelpers.RunTaskContinuationsOnOurThreads())
                {
                    connection = await _connectionPoolHelper.EstablishConnectionAsync(timeout);
                }
            }
            catch (TimeoutException exception)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new TimeoutException(SR.Format(SR.TimeoutOnOpen, timeout), exception));
            }

            bool connectionAccepted = false;
            try
            {
                AcceptConnection(connection);
                connectionAccepted = true;
            }
            finally
            {
                if (!connectionAccepted)
                {
                    _connectionPoolHelper.Abort();
                }
            }
        }

        /// <summary>Synchronous open; blocks on <see cref="OnOpenAsync"/>.</summary>
        protected override void OnOpen(TimeSpan timeout)
        {
            OnOpenAsync(timeout).GetAwaiter().GetResult();
        }

        /// <summary>Aborts or gracefully closes the pooled connection under the channel lock.</summary>
        protected override void ReturnConnectionIfNecessary(bool abort, TimeSpan timeout)
        {
            lock (ThisLock)
            {
                if (abort)
                {
                    _connectionPoolHelper.Abort();
                }
                else
                {
                    _connectionPoolHelper.Close(timeout);
                }
            }
        }

        /// <summary>
        /// Installs the message reader for <paramref name="connection"/> and publishes the connection.
        /// Throws if the channel left the Opening state (e.g. was aborted) in the meantime.
        /// </summary>
        private void AcceptConnection(IConnection connection)
        {
            SetMessageSource(new ClientDuplexConnectionReader(this, connection, _decoder, _settings, MessageEncoder));

            lock (ThisLock)
            {
                if (State != CommunicationState.Opening)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        new CommunicationObjectAbortedException(SR.Format(SR.DuplexChannelAbortedDuringOpen, Via)));
                }

                Connection = connection;
            }
        }

        /// <summary>Captures the remote security properties produced by the upgrade initiator.</summary>
        private void SetRemoteSecurity(StreamUpgradeInitiator upgradeInitiator)
        {
            RemoteSecurity = StreamSecurityUpgradeInitiator.GetRemoteSecurity(upgradeInitiator);
        }

        /// <summary>
        /// Connection pool helper that, for each connection it accepts, runs the channel's preamble
        /// exchange using a preamble built once at construction.
        /// </summary>
        internal class DuplexConnectionPoolHelper : ConnectionPoolHelper
        {
            private readonly ClientFramingDuplexSessionChannel _channel;
            private readonly Memory<byte> _preamble;

            public DuplexConnectionPoolHelper(ClientFramingDuplexSessionChannel channel,
                ConnectionPool connectionPool, IConnectionInitiator connectionInitiator)
                : base(connectionPool, connectionInitiator, channel.Via)
            {
                _channel = channel;
                _preamble = channel.CreatePreamble();
            }

            protected override TimeoutException CreateNewConnectionTimeoutException(TimeSpan timeout, TimeoutException innerException)
            {
                return new TimeoutException(SR.Format(SR.OpenTimedOutEstablishingTransportSession,
                        timeout, _channel.Via.AbsoluteUri), innerException);
            }

            protected override Task<IConnection> AcceptPooledConnectionAsync(IConnection connection, TimeoutHelper timeoutHelper)
            {
                return _channel.SendPreambleAsync(connection, _preamble, timeoutHelper.RemainingTime());
            }
        }
    }
 
    // used by StreamedFramingRequestChannel and ClientFramingDuplexSessionChannel
    /// <summary>
    /// Helpers for the client side of the framing preamble: stream upgrade negotiation,
    /// preamble/upgrade response validation, and decoding of server framing faults.
    /// </summary>
    internal class ConnectionUpgradeHelper
    {
        /// <summary>
        /// Reads and decodes a framing fault string from <paramref name="connection"/>. This method
        /// never returns normally. When the fault is fully decoded, it closes the connection (no-throw)
        /// and throws the corresponding fault exception. If the connection ends first, it throws the
        /// decoder's premature-EOF exception.
        /// </summary>
        /// <exception cref="MessageSecurityException">The decoder is not in the ReadingFaultString state on entry.</exception>
        /// <exception cref="InvalidOperationException">The decoder leaves the ReadingFaultString state without reaching Fault.</exception>
        public static async Task DecodeFramingFaultAsync(ClientFramingDecoder decoder, IConnection connection,
            Uri via, string contentType, TimeSpan timeout)
        {
            var timeoutHelper = new TimeoutHelper(timeout);
            ValidateReadingFaultString(decoder);

            byte[] faultBuffer = ArrayPool<byte>.Shared.Rent(FaultStringDecoder.FaultSizeQuota);
            try
            {
                int size = await connection.ReadAsync(new Memory<byte>(faultBuffer, 0,
                    Math.Min(FaultStringDecoder.FaultSizeQuota, connection.ConnectionBufferSize)),
                    timeoutHelper.RemainingTime());

                int offset = 0;
                while (size > 0)
                {
                    int bytesDecoded = decoder.Decode(faultBuffer, offset, size);
                    offset += bytesDecoded;
                    size -= bytesDecoded;

                    if (decoder.CurrentState == ClientFramingDecoderState.Fault)
                    {
                        await ConnectionUtilities.CloseNoThrowAsync(connection, timeoutHelper.RemainingTime());
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                            FaultStringDecoder.GetFaultException(decoder.Fault, via.ToString(), contentType));
                    }

                    if (decoder.CurrentState != ClientFramingDecoderState.ReadingFaultString)
                    {
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                            new InvalidOperationException("invalid framing client state machine"));
                    }

                    if (size == 0)
                    {
                        // Current chunk fully consumed; refill the buffer from the start.
                        offset = 0;
                        size = await connection.ReadAsync(new Memory<byte>(faultBuffer, 0,
                            Math.Min(FaultStringDecoder.FaultSizeQuota, connection.ConnectionBufferSize)),
                            timeoutHelper.RemainingTime());
                    }
                }

                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(decoder.CreatePrematureEOFException());
            }
            finally
            {
                // Every exit from this method is a throw, so the rented buffer must be returned here
                // or it is lost to the pool.
                ArrayPool<byte>.Shared.Return(faultBuffer);
            }
        }

        /// <summary>
        /// Runs every upgrade offered by <paramref name="upgradeInitiator"/>. For each one it sends the
        /// upgrade request record, validates the one-byte response, then lets the initiator upgrade
        /// the stream and re-wraps the connection around the upgraded stream. All network waits share
        /// one overall <paramref name="timeout"/> budget.
        /// </summary>
        /// <returns>
        /// <c>(true, upgradedConnection)</c> when all upgrades completed. <c>(false, connection)</c> when
        /// the server answered an upgrade request with something other than an upgrade response.
        /// </returns>
        public static async Task<(bool success, IConnection connection)> InitiateUpgradeAsync(
            StreamUpgradeInitiator upgradeInitiator, IConnection connection,
            ClientFramingDecoder decoder, IDefaultCommunicationTimeouts defaultTimeouts, TimeSpan timeout)
        {
            // One deadline for the whole negotiation instead of the full timeout per I/O call.
            var timeoutHelper = new TimeoutHelper(timeout);
            byte[] buffer = new byte[1];

            string upgradeContentType = upgradeInitiator.GetNextUpgrade();

            while (upgradeContentType != null)
            {
                EncodedUpgrade encodedUpgrade = new EncodedUpgrade(upgradeContentType);
                // write upgrade request framing for synchronization
                await connection.WriteAsync(encodedUpgrade.EncodedBytes, true, timeoutHelper.RemainingTime());

                // read upgrade response framing
                int size = await connection.ReadAsync(buffer, timeoutHelper.RemainingTime());

                if (!ValidateUpgradeResponse(buffer, size, decoder)) // we have a problem
                {
                    return (false, connection);
                }

                // initiate wire upgrade
                ConnectionStream connectionStream = new ConnectionStream(connection, defaultTimeouts);
                Stream upgradedStream = await upgradeInitiator.InitiateUpgradeAsync(connectionStream);

                // and re-wrap connection
                connection = new StreamConnection(upgradedStream, connectionStream);
                upgradeContentType = upgradeInitiator.GetNextUpgrade();
            }

            return (true, connection);
        }

        /// <summary>Throws a <see cref="MessageSecurityException"/> unless the decoder is reading a fault string.</summary>
        private static void ValidateReadingFaultString(ClientFramingDecoder decoder)
        {
            if (decoder.CurrentState != ClientFramingDecoderState.ReadingFaultString)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new MessageSecurityException(
                    SR.ServerRejectedUpgradeRequest));
            }
        }

        /// <summary>
        /// Feeds the preamble acknowledgement byte(s) to the decoder.
        /// </summary>
        /// <returns>True if the decoder ends in the Start state (preamble accepted).</returns>
        /// <exception cref="ProtocolException">No bytes were read (connection closed).</exception>
        public static bool ValidatePreambleResponse(byte[] buffer, int count, ClientFramingDecoder decoder, Uri via)
        {
            if (count == 0)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new ProtocolException(SR.Format(SR.ServerRejectedSessionPreamble, via),
                    decoder.CreatePrematureEOFException()));
            }

            // decode until the framing byte has been processed (it always will be)
            while (decoder.Decode(buffer, 0, count) == 0)
            {
                // do nothing
            }

            return decoder.CurrentState == ClientFramingDecoderState.Start;
        }

        /// <summary>
        /// Feeds the upgrade response byte(s) to the decoder.
        /// </summary>
        /// <returns>True if the decoder ends in the UpgradeResponse state.</returns>
        /// <exception cref="MessageSecurityException">No bytes were read (connection closed).</exception>
        private static bool ValidateUpgradeResponse(byte[] buffer, int count, ClientFramingDecoder decoder)
        {
            if (count == 0)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new MessageSecurityException(SR.ServerRejectedUpgradeRequest, decoder.CreatePrematureEOFException()));
            }

            // decode until the framing byte has been processed (it always will be)
            while (decoder.Decode(buffer, 0, count) == 0)
            {
                // do nothing
            }

            return decoder.CurrentState == ClientFramingDecoderState.UpgradeResponse;
        }
    }
}