File: System\ServiceModel\Channels\StreamedFramingRequestChannel.cs
Web Access
Project: src\src\System.ServiceModel.NetFramingBase\src\System.ServiceModel.NetFramingBase.csproj (System.ServiceModel.NetFramingBase)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics.Contracts;
using System.Runtime;
using System.Security.Authentication.ExtendedProtection;
using System.ServiceModel.Security;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Request channel that sends each request over a connection obtained from a connection pool using
    /// the singleton framing mode: a preamble is written, an optional stream upgrade is negotiated, the
    /// request message is written, and the reply is read back from the same connection.
    /// </summary>
    internal class StreamedFramingRequestChannel : RequestChannel
    {
        internal IConnectionInitiator _connectionInitiator;
        internal ConnectionPool _connectionPool;
        private readonly MessageEncoder _messageEncoder;
        private readonly IConnectionOrientedTransportFactorySettings _settings;
        private readonly StreamUpgradeProvider _upgrade;

        // Channel binding token captured after a successful stream upgrade (if supported); disposed in OnClosed.
        private ChannelBinding _channelBindingToken;

        public StreamedFramingRequestChannel(ChannelManagerBase factory, IConnectionOrientedTransportChannelFactorySettings settings,
            EndpointAddress remoteAddress, Uri via, IConnectionInitiator connectionInitiator, ConnectionPool connectionPool)
            : base(factory, remoteAddress, via, settings.ManualAddressing)
        {
            _settings = settings;
            _connectionInitiator = connectionInitiator;
            _connectionPool = connectionPool;

            _messageEncoder = settings.MessageEncoderFactory.Encoder;
            _upgrade = settings.Upgrade;
        }

        /// <summary>
        /// Framing preamble built in <see cref="OnOpened"/> and written at the start of every connection.
        /// </summary>
        private Memory<byte> Preamble { get; set; }

        protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state) => Task.CompletedTask.ToApm(callback, state);

        protected override void OnEndOpen(IAsyncResult result) => result.ToApmEnd();

        protected override void OnOpen(TimeSpan timeout) { }

        protected internal override Task OnOpenAsync(TimeSpan timeout) => Task.CompletedTask;

        /// <summary>
        /// Builds the preamble shared by all connections this channel establishes, then transitions to Opened.
        /// </summary>
        protected override void OnOpened()
        {
            EncodedVia encodedVia = new EncodedVia(Via.AbsoluteUri);
            EncodedContentType encodedContentType = EncodedContentType.Create(_settings.MessageEncoderFactory.Encoder.ContentType);
            int startSize = ClientSingletonEncoder.ModeBytes.Length + ClientSingletonEncoder.CalcStartSize(encodedVia, encodedContentType);

            // Without an upgrade, the preamble-end record is appended directly to the preamble.
            // With an upgrade, it is written only after the upgrade has been negotiated (see SendPreambleAsync).
            bool includePreambleEnd = _upgrade == null;
            int preambleEndOffset = startSize;
            int totalSize = includePreambleEnd
                ? startSize + ClientSingletonEncoder.PreambleEndBytes.Length // size from the same array copied below
                : startSize;

            Preamble = Fx.AllocateByteArray(totalSize);
            ClientSingletonEncoder.ModeBytes.CopyTo(Preamble);
            ClientSingletonEncoder.EncodeStart(Preamble.Slice(ClientSingletonEncoder.ModeBytes.Length), encodedVia, encodedContentType);
            if (includePreambleEnd)
            {
                ClientSingletonEncoder.PreambleEndBytes.CopyTo(Preamble.Slice(preambleEndOffset));
            }

            // and then transition to the Opened state
            base.OnOpened();
        }

        protected override IAsyncRequest CreateAsyncRequest(Message message) => new StreamedConnectionPoolHelper.StreamedFramingAsyncRequest(this);

        /// <summary>
        /// Writes the preamble to <paramref name="connection"/>, performs the stream upgrade when one is
        /// configured, and validates the server's one-byte preamble acknowledgement.
        /// </summary>
        /// <param name="connection">Freshly established connection.</param>
        /// <param name="timeoutHelper">Tracks the time remaining for the whole operation.</param>
        /// <param name="decoder">Decoder used to interpret upgrade responses and framing faults.</param>
        /// <returns>
        /// The connection to continue using (it may be replaced by the upgraded connection) and the
        /// remote security obtained from the upgrade initiator (null when there is no upgrade).
        /// </returns>
        internal async Task<(IConnection connection, SecurityMessageProperty remoteSecurity)> SendPreambleAsync(IConnection connection, TimeoutHelper timeoutHelper, ClientFramingDecoder decoder)
        {
            SecurityMessageProperty remoteSecurity = null;
            await connection.WriteAsync(Preamble, true, timeoutHelper.RemainingTime());

            if (_upgrade != null)
            {
                IStreamUpgradeChannelBindingProvider channelBindingProvider = _upgrade.GetProperty<IStreamUpgradeChannelBindingProvider>();
                StreamUpgradeInitiator upgradeInitiator = _upgrade.CreateUpgradeInitiator(RemoteAddress, Via);

                bool upgradeInitiated;
                (upgradeInitiated, connection) = await ConnectionUpgradeHelper.InitiateUpgradeAsync(upgradeInitiator, connection, decoder, this, timeoutHelper.RemainingTime());
                if (!upgradeInitiated)
                {
                    await ConnectionUpgradeHelper.DecodeFramingFaultAsync(decoder, connection, Via, _messageEncoder.ContentType, timeoutHelper.RemainingTime());
                }

                if (channelBindingProvider != null && channelBindingProvider.IsChannelBindingSupportEnabled)
                {
                    _channelBindingToken = channelBindingProvider.GetChannelBinding(upgradeInitiator, ChannelBindingKind.Endpoint);
                }

                remoteSecurity = StreamSecurityUpgradeInitiator.GetRemoteSecurity(upgradeInitiator);

                // The preamble-end record was left out of Preamble in OnOpened because an upgrade is configured.
                await connection.WriteAsync(ClientSingletonEncoder.PreambleEndBytes, true, timeoutHelper.RemainingTime());
            }

            byte[] ackBuffer = new byte[1];
            int ackBytesRead = await connection.ReadAsync(ackBuffer, timeoutHelper.RemainingTime());

            if (!ConnectionUpgradeHelper.ValidatePreambleResponse(ackBuffer, ackBytesRead, decoder, Via))
            {
                await ConnectionUpgradeHelper.DecodeFramingFaultAsync(decoder, connection, Via,
                    _messageEncoder.ContentType, timeoutHelper.RemainingTime());
            }

            return (connection, remoteSecurity);
        }

        protected override void OnClose(TimeSpan timeout) => base.WaitForPendingRequests(timeout);

        // Wait for in-flight requests just like OnClose and OnBeginClose do. Otherwise OnClosed could dispose
        // _channelBindingToken while a pending request is still attaching it to a message.
        protected internal override Task OnCloseAsync(TimeSpan timeout) => WaitForPendingRequestsAsync(timeout);

        protected override void OnClosed()
        {
            base.OnClosed();

            // clean up the CBT after transitioning to the closed state
            ChannelBindingUtility.Dispose(ref _channelBindingToken);
        }

        protected override IAsyncResult OnBeginClose(TimeSpan timeout, AsyncCallback callback, object state) => WaitForPendingRequestsAsync(timeout).ToApm(callback, state);

        protected override void OnEndClose(IAsyncResult result) => result.ToApmEnd();

        /// <summary>
        /// Connection pool helper that runs the singleton preamble/upgrade handshake on each connection it
        /// establishes and remembers the resulting decoder and remote security.
        /// </summary>
        internal class StreamedConnectionPoolHelper : ConnectionPoolHelper
        {
            private readonly StreamedFramingRequestChannel _channel;
            private SecurityMessageProperty _remoteSecurity;

            public StreamedConnectionPoolHelper(StreamedFramingRequestChannel channel)
                : base(channel._connectionPool, channel._connectionInitiator, channel.Via)
            {
                _channel = channel;
            }

            /// <summary>Decoder created for the current connection in <see cref="AcceptPooledConnectionAsync"/>.</summary>
            public ClientSingletonDecoder Decoder { get; private set; }

            /// <summary>Remote security returned by the preamble handshake (null when no upgrade is configured).</summary>
            public SecurityMessageProperty RemoteSecurity
            {
                get { return _remoteSecurity; }
            }

            protected override TimeoutException CreateNewConnectionTimeoutException(TimeSpan timeout, TimeoutException innerException) => new TimeoutException(SR.Format(SR.RequestTimedOutEstablishingTransportSession,
                        timeout, _channel.Via.AbsoluteUri), innerException);

            protected override async Task<IConnection> AcceptPooledConnectionAsync(IConnection connection, TimeoutHelper timeoutHelper)
            {
                Decoder = new ClientSingletonDecoder(0);
                (connection, _remoteSecurity) = await _channel.SendPreambleAsync(connection, timeoutHelper, Decoder);
                return connection;
            }

            /// <summary>
            /// Reads the reply from the connection, using the pool helper's decoder to locate the envelope.
            /// </summary>
            private class ClientSingletonConnectionReader : SingletonConnectionReader
            {
                private readonly StreamedConnectionPoolHelper _connectionPoolHelper;

                public ClientSingletonConnectionReader(IConnection connection, StreamedConnectionPoolHelper connectionPoolHelper,
                    IConnectionOrientedTransportFactorySettings settings)
                    : base(connection, connectionPoolHelper.RemoteSecurity, settings, null)
                {
                    Contract.Assert(connectionPoolHelper != null);
                    _connectionPoolHelper = connectionPoolHelper;
                }

                protected override long StreamPosition
                {
                    get { return _connectionPoolHelper.Decoder.StreamPosition; }
                }

                /// <summary>
                /// Feeds <paramref name="buffer"/> to the decoder, advancing <paramref name="offset"/> and
                /// shrinking <paramref name="size"/> by the bytes consumed.
                /// </summary>
                /// <returns>
                /// true when the decoder reaches the envelope start. false when it reaches End (and
                /// <paramref name="isAtEof"/> is set), or when the buffer is exhausted first.
                /// </returns>
                protected override bool DecodeBytes(byte[] buffer, ref int offset, ref int size, ref bool isAtEof)
                {
                    while (size > 0)
                    {
                        int bytesRead = _connectionPoolHelper.Decoder.Decode(buffer, offset, size);
                        if (bytesRead > 0)
                        {
                            offset += bytesRead;
                            size -= bytesRead;
                        }

                        switch (_connectionPoolHelper.Decoder.CurrentState)
                        {
                            case ClientFramingDecoderState.EnvelopeStart:
                                // we're at the envelope
                                return true;

                            case ClientFramingDecoderState.End:
                                isAtEof = true;
                                return false;
                        }
                    }

                    return false;
                }

                protected override void OnClose(TimeSpan timeout) => _connectionPoolHelper.Close(timeout);
            }

            /// <summary>
            /// A single request/reply exchange: establishes a pooled connection, writes the request, then
            /// reads the reply from that same connection.
            /// </summary>
            internal class StreamedFramingAsyncRequest : IAsyncRequest
            {
                private readonly StreamedFramingRequestChannel _channel;
                private IConnection _connection;
                private readonly StreamedConnectionPoolHelper _connectionPoolHelper;
                private Message _message;
                private ClientSingletonConnectionReader _connectionReader;

                public StreamedFramingAsyncRequest(StreamedFramingRequestChannel channel)
                {
                    _channel = channel;
                    _connectionPoolHelper = new StreamedConnectionPoolHelper(channel);
                }

                /// <summary>
                /// Establishes the connection and writes <paramref name="message"/>. The connection is
                /// aborted if anything fails.
                /// </summary>
                /// <exception cref="TimeoutException">The connection or write did not finish in time.</exception>
                public async Task SendRequestAsync(Message message, TimeoutHelper timeoutHelper)
                {
                    _message = message;

                    bool success = false;
                    try
                    {
                        try
                        {
                            _connection = await _connectionPoolHelper.EstablishConnectionAsync(timeoutHelper.RemainingTime());

                            ChannelBindingUtility.TryAddToMessage(_channel._channelBindingToken, _message, false);
                            await StreamingConnectionHelper.WriteMessageAsync(_message, _connection, true, _channel._settings, timeoutHelper);
                        }
                        catch (TimeoutException exception)
                        {
                            // Report the allotted timeout; the remaining time is ~zero once the timeout has fired.
                            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                                new TimeoutException(SR.Format(SR.TimeoutOnRequest, timeoutHelper.OriginalTimeout), exception));
                        }

                        success = true;
                    }
                    finally
                    {
                        if (!success)
                        {
                            Cleanup();
                        }
                    }
                }

                public void Abort(RequestChannel requestChannel) => Cleanup();

                public void Fault(RequestChannel requestChannel) => Cleanup();

                private void Cleanup() => _connectionPoolHelper.Abort();

                public void OnReleaseRequest()
                {
                }

                /// <summary>
                /// Reads the reply from the connection established in <see cref="SendRequestAsync"/> and
                /// attaches the channel binding token to it, if one exists.
                /// </summary>
                /// <exception cref="TimeoutException">The reply did not arrive within <paramref name="timeoutHelper"/>.</exception>
                public async Task<Message> ReceiveReplyAsync(TimeoutHelper timeoutHelper)
                {
                    try
                    {
                        _connectionReader = new ClientSingletonConnectionReader(_connection, _connectionPoolHelper, _channel._settings);
                        _connectionReader.DoneSending(TimeSpan.Zero);
                        Message message = await _connectionReader.ReceiveAsync(timeoutHelper);
                        if (message != null)
                        {
                            ChannelBindingUtility.TryAddToMessage(_channel._channelBindingToken, message, false);
                        }

                        return message;
                    }
                    catch (OperationCanceledException)
                    {
                        // Check the helper that drove this receive (and whose OriginalTimeout is reported below),
                        // not the one captured during send.
                        var cancelToken = timeoutHelper.GetCancellationToken();
                        if (cancelToken.IsCancellationRequested)
                        {
                            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(SR.Format(
                                SR.RequestChannelWaitForReplyTimedOut, timeoutHelper.OriginalTimeout)));
                        }
                        else
                        {
                            // Cancellation came from somewhere other than timeoutCts and needs to be handled differently.
                            throw;
                        }
                    }
                }
            }
        }
    }
}