// File: FrameworkFork\System.ServiceModel\System\ServiceModel\Channels\SessionConnectionReader.cs
// Web Access
// Project: src\src\dotnet-svcutil\lib\src\dotnet-svcutil-lib.csproj (dotnet-svcutil-lib)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Runtime;
using System.ServiceModel.Security;
using System.Threading.Tasks;
using Microsoft.Xml;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Base class that reads framed messages from a session connection. Raw bytes read from
    /// <see cref="IConnection"/> are buffered and handed to the derived
    /// <see cref="DecodeMessage(byte[], ref int, ref int, ref bool, TimeSpan)"/> until it yields a
    /// <see cref="Message"/> or the stream reaches end-of-file.
    /// </summary>
    internal abstract class SessionConnectionReader : IMessageSource
    {
        // Text of the invariant failure raised when a decode pass neither consumed the buffered bytes nor produced a message.
        private const string UnconsumedBufferMessage = "Receive: DecodeMessage() should consume the outstanding buffer or return a message.";

        private bool _isAtEOF;
        private bool _usingAsyncReadBuffer;
        private readonly IConnection _connection;
        private byte[] _buffer;
        private int _offset;
        private int _size;
        private byte[] _envelopeBuffer;
        private int _envelopeOffset;
        private int _envelopeSize;
        private bool _readIntoEnvelopeBuffer;
        private Message _pendingMessage;
        // NOTE(review): never assigned in this class, so GetPendingMessage never throws from it.
        private Exception _pendingException;
        private readonly SecurityMessageProperty _security;
        // Raw connection that we will revert to after end handshake
        private IConnection _rawConnection;

        /// <summary>
        /// Initializes the reader.
        /// </summary>
        /// <param name="connection">Connection that bytes are read from.</param>
        /// <param name="rawConnection">Connection handed back by <see cref="GetRawConnection"/>; may be null.</param>
        /// <param name="offset">Start of already-buffered input within <c>connection.AsyncReadBuffer</c>.</param>
        /// <param name="size">Number of already-buffered bytes; when positive, <c>connection.AsyncReadBuffer</c> is used as the initial buffer.</param>
        /// <param name="security">Security property copied onto every received message; may be null.</param>
        protected SessionConnectionReader(IConnection connection, IConnection rawConnection,
            int offset, int size, SecurityMessageProperty security)
        {
            _offset = offset;
            _size = size;
            if (size > 0)
            {
                _buffer = connection.AsyncReadBuffer;
            }
            _connection = connection;
            _rawConnection = rawConnection;
            _security = security;
        }

        /// <summary>
        /// Decodes from whichever buffer the last read filled: the regular read buffer, or the
        /// envelope buffer when <see cref="Receive"/> read directly into it.
        /// </summary>
        private Message DecodeMessage(TimeSpan timeout)
        {
            if (!_readIntoEnvelopeBuffer)
            {
                return DecodeMessage(_buffer, ref _offset, ref _size, ref _isAtEOF, timeout);
            }
            else
            {
                // decode from the envelope buffer; the envelope offset itself is advanced by the derived class
                int dummyOffset = _envelopeOffset;
                return DecodeMessage(_envelopeBuffer, ref dummyOffset, ref _size, ref _isAtEOF, timeout);
            }
        }

        /// <summary>
        /// Decodes bytes from <paramref name="buffer"/>, advancing <paramref name="offset"/> and shrinking
        /// <paramref name="size"/> as bytes are consumed. Returns a message, or null when more data is needed
        /// or the end record was seen (in which case <paramref name="isAtEof"/> is set).
        /// </summary>
        protected abstract Message DecodeMessage(byte[] buffer, ref int offset, ref int size, ref bool isAtEof, TimeSpan timeout);

        /// <summary>Buffer that accumulates the bytes of the envelope currently being decoded, or null.</summary>
        protected byte[] EnvelopeBuffer
        {
            get { return _envelopeBuffer; }
            set { _envelopeBuffer = value; }
        }

        /// <summary>Number of envelope bytes already placed in <see cref="EnvelopeBuffer"/>.</summary>
        protected int EnvelopeOffset
        {
            get { return _envelopeOffset; }
            set { _envelopeOffset = value; }
        }

        /// <summary>Total size of the envelope currently being decoded.</summary>
        protected int EnvelopeSize
        {
            get { return _envelopeSize; }
            set { _envelopeSize = value; }
        }

        /// <summary>
        /// Hands back the raw connection exactly once (later calls return null). Any bytes still
        /// buffered by this reader are attached as pre-read data so they are not lost.
        /// </summary>
        public IConnection GetRawConnection()
        {
            IConnection result = null;
            if (_rawConnection != null)
            {
                result = _rawConnection;
                _rawConnection = null;
                if (_size > 0)
                {
                    PreReadConnection preReadConnection = result as PreReadConnection;
                    if (preReadConnection != null) // make sure we don't keep wrapping
                    {
                        preReadConnection.AddPreReadData(_buffer, _offset, _size);
                    }
                    else
                    {
                        result = new PreReadConnection(result, _buffer, _offset, _size);
                    }
                }
            }

            return result;
        }

        /// <summary>
        /// Asynchronously receives the next message, or null once the session reaches end-of-file.
        /// A message stashed by <see cref="WaitForMessageAsync"/>/<see cref="WaitForMessage"/> is returned first.
        /// </summary>
        public async Task<Message> ReceiveAsync(TimeSpan timeout)
        {
            Message message = GetPendingMessage();
            if (message != null)
            {
                return message;
            }

            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            for (; ; )
            {
                if (TryCompleteFromBufferedData(ref timeoutHelper, out message))
                {
                    return message;
                }

                if (!_usingAsyncReadBuffer)
                {
                    _buffer = _connection.AsyncReadBuffer;
                    _usingAsyncReadBuffer = true;
                }

                // RunContinuationsAsynchronously keeps the rest of this loop (decode + next BeginRead)
                // from running inline inside the connection's completion callback.
                var readCompleted = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
                AsyncCompletionResult result = _connection.BeginRead(0, _buffer.Length, timeoutHelper.RemainingTime(), TaskHelpers.OnAsyncCompletionCallback, readCompleted);
                if (result == AsyncCompletionResult.Completed)
                {
                    readCompleted.TrySetResult(true);
                }
                await readCompleted.Task.ConfigureAwait(false);

                int bytesRead = _connection.EndRead();
                HandleReadComplete(bytesRead, false);
            }
        }

        /// <summary>
        /// Synchronously receives the next message, or null once the session reaches end-of-file.
        /// A message stashed by <see cref="WaitForMessageAsync"/>/<see cref="WaitForMessage"/> is returned first.
        /// </summary>
        public Message Receive(TimeSpan timeout)
        {
            Message message = GetPendingMessage();
            if (message != null)
            {
                return message;
            }

            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            for (; ; )
            {
                if (TryCompleteFromBufferedData(ref timeoutHelper, out message))
                {
                    return message;
                }

                if (_buffer == null)
                {
                    _buffer = Fx.AllocateByteArray(_connection.AsyncReadBufferSize);
                }

                int bytesRead;
                if (EnvelopeBuffer != null &&
                    (EnvelopeSize - EnvelopeOffset) >= _buffer.Length)
                {
                    // The envelope still has room for a full read, so read directly into it.
                    bytesRead = _connection.Read(EnvelopeBuffer, EnvelopeOffset, _buffer.Length, timeoutHelper.RemainingTime());
                    HandleReadComplete(bytesRead, true);
                }
                else
                {
                    bytesRead = _connection.Read(_buffer, 0, _buffer.Length, timeoutHelper.RemainingTime());
                    HandleReadComplete(bytesRead, false);
                }
            }
        }

        /// <summary>
        /// Decode step shared by <see cref="Receive"/> and <see cref="ReceiveAsync"/>.
        /// Returns true when the receive loop must stop and return <paramref name="message"/>
        /// (a decoded message, or null at end-of-file); returns false when more bytes must be read.
        /// </summary>
        /// <exception cref="InvalidOperationException">The decoder left buffered bytes unconsumed without producing a message.</exception>
        private bool TryCompleteFromBufferedData(ref TimeoutHelper timeoutHelper, out Message message)
        {
            message = null;
            if (_isAtEOF)
            {
                return true;
            }

            if (_size > 0)
            {
                message = DecodeMessage(timeoutHelper.RemainingTime());
                if (message != null)
                {
                    PrepareMessage(message);
                    return true;
                }

                if (_isAtEOF) // could have read the END record under DecodeMessage
                {
                    return true;
                }
            }

            if (_size != 0)
            {
                throw new InvalidOperationException(UnconsumedBufferMessage);
            }

            return false;
        }

        /// <summary>Returns the message stashed by a successful wait, if any.</summary>
        public Message EndReceive()
        {
            return GetPendingMessage();
        }

        /// <summary>
        /// Takes (and clears) the pending exception, throwing it, or else takes (and clears) the
        /// pending message. Returns null when neither is set.
        /// </summary>
        private Message GetPendingMessage()
        {
            if (_pendingException != null)
            {
                Exception exception = _pendingException;
                _pendingException = null;
                throw exception;
            }

            if (_pendingMessage != null)
            {
                Message message = _pendingMessage;
                _pendingMessage = null;
                return message;
            }

            return null;
        }

        /// <summary>
        /// Receives the next message into the pending slot. Returns false (after tracing) if the
        /// receive timed out, true otherwise (including when the session reached end-of-file).
        /// </summary>
        public async Task<bool> WaitForMessageAsync(TimeSpan timeout)
        {
            try
            {
                Message message = await ReceiveAsync(timeout).ConfigureAwait(false);
                _pendingMessage = message;
                return true;
            }
            catch (TimeoutException e)
            {
                if (WcfEventSource.Instance.ReceiveTimeoutIsEnabled())
                {
                    WcfEventSource.Instance.ReceiveTimeout(e.Message);
                }

                return false;
            }
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="WaitForMessageAsync"/>.
        /// </summary>
        public bool WaitForMessage(TimeSpan timeout)
        {
            try
            {
                Message message = Receive(timeout);
                _pendingMessage = message;
                return true;
            }
            catch (TimeoutException e)
            {
                if (WcfEventSource.Instance.ReceiveTimeoutIsEnabled())
                {
                    WcfEventSource.Instance.ReceiveTimeout(e.Message);
                }
                return false;
            }
        }

        /// <summary>
        /// Invoked when a read returns zero bytes, before the reader is marked at end-of-file;
        /// lets the derived decoder reject an end of stream that arrives in an unexpected state.
        /// </summary>
        protected abstract void EnsureDecoderAtEof();

        /// <summary>
        /// Records the outcome of a read: zero bytes means end-of-file; otherwise the freshly read
        /// bytes become the buffered input starting at offset 0.
        /// </summary>
        private void HandleReadComplete(int bytesRead, bool readIntoEnvelopeBuffer)
        {
            _readIntoEnvelopeBuffer = readIntoEnvelopeBuffer;

            if (bytesRead == 0)
            {
                EnsureDecoderAtEof();
                _isAtEOF = true;
            }
            else
            {
                _offset = 0;
                _size = bytesRead;
            }
        }

        /// <summary>
        /// Hook applied to each decoded message before it is returned; by default attaches a copy
        /// of the session's security property when one was supplied.
        /// </summary>
        protected virtual void PrepareMessage(Message message)
        {
            if (_security != null)
            {
                message.Properties.Security = (SecurityMessageProperty)_security.CreateCopy();
            }
        }
    }
 
 
    /// <summary>
    /// Client-side duplex session reader. Feeds bytes through a <see cref="ClientDuplexDecoder"/>,
    /// collects each framed envelope into a buffer taken from the transport's <see cref="BufferManager"/>,
    /// and turns each completed envelope into a <see cref="Message"/> using the channel's encoder.
    /// </summary>
    internal class ClientDuplexConnectionReader : SessionConnectionReader
    {
        private readonly ClientDuplexDecoder _decoder;
        private readonly int _maxBufferSize;
        private readonly BufferManager _bufferManager;
        private readonly MessageEncoder _messageEncoder;
        private readonly ClientFramingDuplexSessionChannel _channel;

        public ClientDuplexConnectionReader(ClientFramingDuplexSessionChannel channel, IConnection connection, ClientDuplexDecoder decoder,
            IConnectionOrientedTransportFactorySettings settings, MessageEncoder messageEncoder)
            : base(connection, null, 0, 0, null)
        {
            _decoder = decoder;
            _maxBufferSize = settings.MaxBufferSize;
            _bufferManager = settings.BufferManager;
            _messageEncoder = messageEncoder;
            _channel = channel;
        }

        /// <summary>
        /// Throws the decoder's premature-EOF exception unless the decoder is in the End, EnvelopeEnd,
        /// ReadingUpgradeRecord or UpgradeResponse state.
        /// </summary>
        protected override void EnsureDecoderAtEof()
        {
            ClientFramingDecoderState state = _decoder.CurrentState;
            if (!(state == ClientFramingDecoderState.End
                || state == ClientFramingDecoderState.EnvelopeEnd
                || state == ClientFramingDecoderState.ReadingUpgradeRecord
                || state == ClientFramingDecoderState.UpgradeResponse))
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(_decoder.CreatePrematureEOFException());
            }
        }

        /// <summary>
        /// Runs the framing decoder over <paramref name="buffer"/> until a full envelope has been read
        /// (returns the decoded message), the end record is seen (sets <paramref name="isAtEOF"/>, returns null),
        /// or the input is exhausted (returns null).
        /// </summary>
        /// <exception cref="ProtocolException">The envelope is not well-formed XML.</exception>
        protected override Message DecodeMessage(byte[] buffer, ref int offset, ref int size, ref bool isAtEOF, TimeSpan timeout)
        {
            while (size > 0)
            {
                int bytesDecoded = _decoder.Decode(buffer, offset, size);
                if (bytesDecoded > 0)
                {
                    if (EnvelopeBuffer != null)
                    {
                        // When decoding directly from EnvelopeBuffer the bytes are already in place.
                        if (!object.ReferenceEquals(buffer, EnvelopeBuffer))
                        {
                            System.Buffer.BlockCopy(buffer, offset, EnvelopeBuffer, EnvelopeOffset, bytesDecoded);
                        }
                        EnvelopeOffset += bytesDecoded;
                    }

                    offset += bytesDecoded;
                    size -= bytesDecoded;
                }

                switch (_decoder.CurrentState)
                {
                    case ClientFramingDecoderState.Fault:
                        _channel.Session.CloseOutputSession(_channel.GetInternalCloseTimeout());
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(FaultStringDecoder.GetFaultException(_decoder.Fault, _channel.RemoteAddress.Uri.ToString(), _messageEncoder.ContentType));

                    case ClientFramingDecoderState.End:
                        isAtEOF = true;
                        return null; // we're done

                    case ClientFramingDecoderState.EnvelopeStart:
                        int envelopeSize = _decoder.EnvelopeSize;
                        if (envelopeSize > _maxBufferSize)
                        {
                            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                                ExceptionHelper.CreateMaxReceivedMessageSizeExceededException(_maxBufferSize));
                        }
                        EnvelopeBuffer = _bufferManager.TakeBuffer(envelopeSize);
                        EnvelopeOffset = 0;
                        EnvelopeSize = envelopeSize;
                        break;

                    case ClientFramingDecoderState.EnvelopeEnd:
                        if (EnvelopeBuffer != null)
                        {
                            Message message = null;
                            try
                            {
                                message = _messageEncoder.ReadMessage(new ArraySegment<byte>(EnvelopeBuffer, 0, EnvelopeSize), _bufferManager);
                            }
                            catch (XmlException xmlException)
                            {
                                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                                    new ProtocolException(SRServiceModel.MessageXmlProtocolError, xmlException));
                            }
                            EnvelopeBuffer = null;
                            return message;
                        }
                        break;
                }
            }
            return null;
        }
    }
}