File: System\ServiceModel\Channels\SingletonConnectionReader.cs
Web Access
Project: src\src\System.ServiceModel.NetFramingBase\src\System.ServiceModel.NetFramingBase.csproj (System.ServiceModel.NetFramingBase)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics.Contracts;
using System.IO;
using System.Runtime;
using System.Threading.Tasks;
using System.Threading;
using System.Xml;
using System.ServiceModel.Security;
using System.ServiceModel.Diagnostics;
using System.Buffers;
using System.Runtime.InteropServices;
 
namespace System.ServiceModel.Channels
{
    internal abstract class SingletonConnectionReader
    {
        private bool _doneReceiving;
        private bool _doneSending;
        private bool _isAtEof;
        private bool _isClosed;
        private SecurityMessageProperty _security;
        private IConnectionOrientedTransportFactorySettings _transportSettings;
        private Uri _via;
        private Stream _inputStream;
 
        protected SingletonConnectionReader(IConnection connection, SecurityMessageProperty security,
            IConnectionOrientedTransportFactorySettings transportSettings, Uri via)
        {
            Contract.Assert(connection != null);
 
            Connection = connection;
            _security = security;
            _transportSettings = transportSettings;
            _via = via;
        }
 
        protected IConnection Connection { get; }
 
        protected object ThisLock { get; } = new object();
 
        protected virtual string ContentType
        {
            get { return null; }
        }
 
        protected abstract long StreamPosition { get; }
 
        public void Abort()
        {
            Connection.Abort();
        }
 
        public void DoneReceiving(bool atEof)
        {
            DoneReceiving(atEof, _transportSettings.CloseTimeout);
        }
 
        private void DoneReceiving(bool atEof, TimeSpan timeout)
        {
            if (!_doneReceiving)
            {
                _isAtEof = atEof;
                _doneReceiving = true;
 
                if (_doneSending)
                {
                    Close(timeout);
                }
            }
        }
 
        public void Close(TimeSpan timeout)
        {
            lock (ThisLock)
            {
                if (_isClosed)
                {
                    return;
                }
 
                _isClosed = true;
            }
 
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            bool success = false;
            try
            {
                // first drain our stream if necessary
                if (_inputStream != null)
                {
                    byte[] dummy = Fx.AllocateByteArray(_transportSettings.ConnectionBufferSize);
                    while (!_isAtEof)
                    {
                        _inputStream.ReadTimeout = TimeoutHelper.ToMilliseconds(timeoutHelper.RemainingTime());
                        int bytesRead = _inputStream.Read(dummy, 0, dummy.Length);
                        if (bytesRead == 0)
                        {
                            _isAtEof = true;
                        }
                    }
                }
                OnClose(timeoutHelper.RemainingTime());
                success = true;
            }
            finally
            {
                if (!success)
                {
                    Abort();
                }
            }
        }
 
        protected abstract void OnClose(TimeSpan timeout);
 
        public void DoneSending(TimeSpan timeout)
        {
            _doneSending = true;
            if (_doneReceiving)
            {
                Close(timeout);
            }
        }
 
        protected abstract bool DecodeBytes(byte[] buffer, ref int offset, ref int size, ref bool isAtEof);
 
        protected virtual void PrepareMessage(Message message)
        {
            message.Properties.Via = _via;
            message.Properties.Security = (_security != null) ? (SecurityMessageProperty)_security.CreateCopy() : null;
        }
 
        public async Task<Message> ReceiveAsync(TimeoutHelper timeoutHelper)
        {
            byte[] buffer = ArrayPool<byte>.Shared.Rent(Connection.ConnectionBufferSize);
            int offset = 0, size = 0;
            for (; ; )
            {
                if (DecodeBytes(buffer, ref offset, ref size, ref _isAtEof))
                {
                    break;
                }
 
                if (_isAtEof)
                {
                    DoneReceiving(true, timeoutHelper.RemainingTime());
                    return null;
                }
 
                if (size == 0)
                {
                    offset = 0;
                    size = await Connection.ReadAsync(new Memory<byte>(buffer, 0, Connection.ConnectionBufferSize), timeoutHelper.RemainingTime());
                    if (size == 0)
                    {
                        DoneReceiving(true, timeoutHelper.RemainingTime());
                        ArrayPool<byte>.Shared.Return(buffer);
                        return null;
                    }
                }
            }
 
            // we're ready to read a message
            IConnection singletonConnection = Connection;
            if (size > 0)
            {
                byte[] initialData = Fx.AllocateByteArray(size);
                Buffer.BlockCopy(buffer, offset, initialData, 0, size);
                singletonConnection = new PreReadConnection(singletonConnection, initialData);
            }
 
            ArrayPool<byte>.Shared.Return(buffer);
 
            Stream connectionStream = new SingletonInputConnectionStream(this, singletonConnection, _transportSettings);
            _inputStream = new MaxMessageSizeStream(connectionStream, _transportSettings.MaxReceivedMessageSize);
            using (ServiceModelActivity activity = DiagnosticUtility.ShouldUseActivity ? ServiceModelActivity.CreateBoundedActivity(true) : null)
            {
                if (DiagnosticUtility.ShouldUseActivity)
                {
                    ServiceModelActivity.Start(activity, SR.Format(SR.ActivityProcessingMessage, TraceUtility.RetrieveMessageNumber()), ActivityType.ProcessMessage);
                }
 
                Message message = null;
                try
                {
                    message = _transportSettings.MessageEncoderFactory.Encoder.ReadMessage(
                        _inputStream, _transportSettings.MaxBufferSize, ContentType);
                }
                catch (XmlException xmlException)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        new ProtocolException(SR.Format(SR.MessageXmlProtocolError), xmlException));
                }
 
                if (DiagnosticUtility.ShouldUseActivity)
                {
                    TraceUtility.TransferFromTransport(message);
                }
 
                PrepareMessage(message);
 
                return message;
            }
        }
 
        // ensures that the reader is notified at end-of-stream, and takes care of the framing chunk headers
        private class SingletonInputConnectionStream : ConnectionStream
        {
            private SingletonMessageDecoder _decoder;
            private SingletonConnectionReader _reader;
            private bool _atEof;
            private byte[] _chunkBuffer; // used for when we have overflow
            private int _chunkBufferOffset;
            private int _chunkBufferSize;
            private int _chunkBytesRemaining;
 
            public SingletonInputConnectionStream(SingletonConnectionReader reader, IConnection connection,
                IDefaultCommunicationTimeouts defaultTimeouts)
                : base(connection, defaultTimeouts)
            {
                _reader = reader;
                _decoder = new SingletonMessageDecoder(reader.StreamPosition);
                _chunkBytesRemaining = 0;
                _chunkBuffer = new byte[IntEncoder.MaxEncodedSize];
            }
 
            private void AbortReader()
            {
                _reader.Abort();
            }
 
            protected override void Dispose(bool disposing)
            {
                if (disposing)
                {
                    _reader.DoneReceiving(_atEof);
                }
            }
 
            // run chunk data through the decoder
            private void DecodeData(byte[] buffer, int offset, int size)
            {
                while (size > 0)
                {
                    int bytesRead = _decoder.Decode(buffer, offset, size);
                    offset += bytesRead;
                    size -= bytesRead;
                    Fx.Assert(_decoder.CurrentState == SingletonMessageDecoder.State.ReadingEnvelopeBytes || _decoder.CurrentState == SingletonMessageDecoder.State.ChunkEnd, "");
                }
            }
 
            // run the current data through the decoder to get valid message bytes
            private void DecodeSize(byte[] buffer, ref int offset, ref int size)
            {
                while (size > 0)
                {
                    int bytesRead = _decoder.Decode(buffer, offset, size);
 
                    if (bytesRead > 0)
                    {
                        offset += bytesRead;
                        size -= bytesRead;
                    }
 
                    switch (_decoder.CurrentState)
                    {
                        case SingletonMessageDecoder.State.ChunkStart:
                            _chunkBytesRemaining = _decoder.ChunkSize;
 
                            // if we have overflow and we're not decoding out of our buffer, copy over
                            if (size > 0 && !object.ReferenceEquals(buffer, _chunkBuffer))
                            {
                                Fx.Assert(size <= _chunkBuffer.Length, "");
                                Buffer.BlockCopy(buffer, offset, _chunkBuffer, 0, size);
                                _chunkBufferOffset = 0;
                                _chunkBufferSize = size;
                            }
                            return;
 
                        case SingletonMessageDecoder.State.End:
                            ProcessEof();
                            return;
                    }
                }
            }
 
            private int ReadCore(byte[] buffer, int offset, int count)
            {
                int bytesRead = -1;
                try
                {
                    bytesRead = base.Read(buffer, offset, count);
                    if (count != 0 && bytesRead == 0)
                    {
                        ProcessEof();
                    }
                }
                finally
                {
                    if (bytesRead == -1)  // there was an exception
                    {
                        AbortReader();
                    }
                }
 
                return bytesRead;
            }
 
            private async ValueTask<int> ReadCoreAsync(Memory<byte> buffer)
            {
                int bytesRead = -1;
                try
                {
                    bytesRead = await base.ReadAsync(buffer);
                    if (!buffer.IsEmpty && bytesRead == 0)
                    {
                        ProcessEof();
                    }
                }
                finally
                {
                    if (bytesRead == -1)  // there was an exception
                    {
                        AbortReader();
                    }
                }
 
                return bytesRead;
            }
 
            public override int Read(byte[] buffer, int offset, int count)
            {
                int result = 0;
                while (true)
                {
                    if (count == 0)
                    {
                        return result;
                    }
 
                    if (_atEof)
                    {
                        return result;
                    }
 
                    // first deal with any residual carryover
                    if (_chunkBufferSize > 0)
                    {
                        int bytesToCopy = Math.Min(_chunkBytesRemaining,
                            Math.Min(_chunkBufferSize, count));
 
                        Buffer.BlockCopy(_chunkBuffer, _chunkBufferOffset, buffer, offset, bytesToCopy);
                        // keep decoder up to date
                        DecodeData(_chunkBuffer, _chunkBufferOffset, bytesToCopy);
 
                        _chunkBufferOffset += bytesToCopy;
                        _chunkBufferSize -= bytesToCopy;
                        _chunkBytesRemaining -= bytesToCopy;
                        if (_chunkBytesRemaining == 0 && _chunkBufferSize > 0)
                        {
                            DecodeSize(_chunkBuffer, ref _chunkBufferOffset, ref _chunkBufferSize);
                        }
 
                        result += bytesToCopy;
                        offset += bytesToCopy;
                        count -= bytesToCopy;
                    }
                    else if (_chunkBytesRemaining > 0)
                    {
                        // We're in the middle of a chunk. Try and include the next chunk size as well
 
                        int bytesToRead = count;
                        if (int.MaxValue - _chunkBytesRemaining >= IntEncoder.MaxEncodedSize)
                        {
                            bytesToRead = Math.Min(count, _chunkBytesRemaining + IntEncoder.MaxEncodedSize);
                        }
 
                        int bytesRead = ReadCore(buffer, offset, bytesToRead);
 
                        // keep decoder up to date
                        DecodeData(buffer, offset, Math.Min(bytesRead, _chunkBytesRemaining));
 
                        if (bytesRead > _chunkBytesRemaining)
                        {
                            result += _chunkBytesRemaining;
                            int overflowCount = bytesRead - _chunkBytesRemaining;
                            int overflowOffset = offset + _chunkBytesRemaining;
                            _chunkBytesRemaining = 0;
                            // read at least part of the next chunk, and put any overflow in this.chunkBuffer
                            DecodeSize(buffer, ref overflowOffset, ref overflowCount);
                        }
                        else
                        {
                            result += bytesRead;
                            _chunkBytesRemaining -= bytesRead;
                        }
 
                        return result;
                    }
                    else
                    {
                        // Final case: we have a new chunk. Read the size, and loop around again
                        if (count < IntEncoder.MaxEncodedSize)
                        {
                            // we don't have space for MaxEncodedSize, so it's worth the copy cost to read into a temp buffer
                            _chunkBufferOffset = 0;
                            _chunkBufferSize = ReadCore(_chunkBuffer, 0, _chunkBuffer.Length);
                            DecodeSize(_chunkBuffer, ref _chunkBufferOffset, ref _chunkBufferSize);
                        }
                        else
                        {
                            int bytesRead = ReadCore(buffer, offset, IntEncoder.MaxEncodedSize);
                            int sizeOffset = offset;
                            DecodeSize(buffer, ref sizeOffset, ref bytesRead);
                        }
                    }
                }
            }
 
            private void ProcessEof()
            {
                if (!_atEof)
                {
                    _atEof = true;
                    if (_chunkBufferSize > 0 || _chunkBytesRemaining > 0
                        || _decoder.CurrentState != SingletonMessageDecoder.State.End)
                    {
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(_decoder.CreatePrematureEOFException());
                    }
 
                    _reader.DoneReceiving(true);
                }
            }
 
            public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                int result = 0;
                while (true)
                {
                    if (count == 0)
                    {
                        return result;
                    }
 
                    if (_atEof)
                    {
                        return result;
                    }
 
                    // first deal with any residual carryover
                    if (_chunkBufferSize > 0)
                    {
                        int bytesToCopy = Math.Min(_chunkBytesRemaining,
                            Math.Min(_chunkBufferSize, count));
 
                        Buffer.BlockCopy(_chunkBuffer, _chunkBufferOffset, buffer, offset, bytesToCopy);
                        // keep decoder up to date
                        DecodeData(_chunkBuffer, _chunkBufferOffset, bytesToCopy);
 
                        _chunkBufferOffset += bytesToCopy;
                        _chunkBufferSize -= bytesToCopy;
                        _chunkBytesRemaining -= bytesToCopy;
                        if (_chunkBytesRemaining == 0 && _chunkBufferSize > 0)
                        {
                            DecodeSize(_chunkBuffer, ref _chunkBufferOffset, ref _chunkBufferSize);
                        }
 
                        result += bytesToCopy;
                        offset += bytesToCopy;
                        count -= bytesToCopy;
                    }
                    else if (_chunkBytesRemaining > 0)
                    {
                        // We're in the middle of a chunk. Try and include the next chunk size as well
 
                        int bytesToRead = count;
                        if (int.MaxValue - _chunkBytesRemaining >= IntEncoder.MaxEncodedSize)
                        {
                            bytesToRead = Math.Min(count, _chunkBytesRemaining + IntEncoder.MaxEncodedSize);
                        }
 
                        int bytesRead = await ReadCoreAsync(new Memory<byte>(buffer, offset, bytesToRead));
 
                        // keep decoder up to date
                        DecodeData(buffer, offset, Math.Min(bytesRead, _chunkBytesRemaining));
 
                        if (bytesRead > _chunkBytesRemaining)
                        {
                            result += _chunkBytesRemaining;
                            int overflowCount = bytesRead - _chunkBytesRemaining;
                            int overflowOffset = offset + _chunkBytesRemaining;
                            _chunkBytesRemaining = 0;
                            // read at least part of the next chunk, and put any overflow in this.chunkBuffer
                            DecodeSize(buffer, ref overflowOffset, ref overflowCount);
                        }
                        else
                        {
                            result += bytesRead;
                            _chunkBytesRemaining -= bytesRead;
                        }
 
                        return result;
                    }
                    else
                    {
                        // Final case: we have a new chunk. Read the size, and loop around again
                        if (count < IntEncoder.MaxEncodedSize)
                        {
                            // we don't have space for MaxEncodedSize, so it's worth the copy cost to read into a temp buffer
                            _chunkBufferOffset = 0;
                            _chunkBufferSize = await ReadCoreAsync(_chunkBuffer);
                            DecodeSize(_chunkBuffer, ref _chunkBufferOffset, ref _chunkBufferSize);
                        }
                        else
                        {
                            int bytesRead = await ReadCoreAsync(new Memory<byte>(buffer, offset, IntEncoder.MaxEncodedSize));
                            int sizeOffset = offset;
                            DecodeSize(buffer, ref sizeOffset, ref bytesRead);
                        }
                    }
                }
            }
 
            public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
            {
                if (buffer.IsEmpty)
                {
                    return new ValueTask<int>(0);
                }
 
                if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> array))
                {
                    return new ValueTask<int>(ReadAsync(array.Array!, array.Offset, array.Count, cancellationToken));
                }
 
                byte[] sharedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
                return FinishReadAsync(ReadAsync(sharedBuffer, 0, buffer.Length, cancellationToken), sharedBuffer, buffer);
 
                static async ValueTask<int> FinishReadAsync(Task<int> readTask, byte[] localBuffer, Memory<byte> localDestination)
                {
                    try
                    {
                        int result = await readTask.ConfigureAwait(false);
                        new ReadOnlySpan<byte>(localBuffer, 0, result).CopyTo(localDestination.Span);
                        return result;
                    }
                    finally
                    {
                        ArrayPool<byte>.Shared.Return(localBuffer);
                    }
                }
            }
        }
    }
 
    internal static class StreamingConnectionHelper
    {
        public async static Task WriteMessageAsync(Message message, IConnection connection, bool isRequest,
            IConnectionOrientedTransportFactorySettings settings, TimeoutHelper timeoutHelper)
        {
            Memory<byte> endBytes = default;
            if (message != null)
            {
                MessageEncoder messageEncoder = settings.MessageEncoderFactory.Encoder;
                Memory<byte> envelopeStartBytes = SingletonEncoder.EnvelopeStartBytes;
 
                bool writeStreamed;
                if (isRequest)
                {
                    endBytes = SingletonEncoder.EnvelopeEndFramingEndBytes;
                    writeStreamed = TransferModeHelper.IsRequestStreamed(settings.TransferMode);
                }
                else
                {
                    endBytes = SingletonEncoder.EnvelopeEndBytes;
                    writeStreamed = TransferModeHelper.IsResponseStreamed(settings.TransferMode);
                }
 
                if (writeStreamed)
                {
                    await connection.WriteAsync(envelopeStartBytes, false, timeoutHelper.RemainingTime());
                    Stream connectionStream = new StreamingOutputConnectionStream(connection, settings);
                    Stream writeTimeoutStream = new TimeoutStream(connectionStream, timeoutHelper.RemainingTime());
                    await messageEncoder.WriteMessageAsync(message, writeTimeoutStream);
                }
                else
                {
                    ArraySegment<byte> messageData = await messageEncoder.WriteMessageAsync(message,
                        int.MaxValue, settings.BufferManager, envelopeStartBytes.Length + IntEncoder.MaxEncodedSize);
                    messageData = SingletonEncoder.EncodeMessageFrame(messageData);
                    envelopeStartBytes.CopyTo(new Memory<byte>(messageData.Array, messageData.Offset - envelopeStartBytes.Length, envelopeStartBytes.Length));
                    await connection.WriteAsync(new Memory<byte>(messageData.Array, messageData.Offset - envelopeStartBytes.Length,
                        messageData.Count + envelopeStartBytes.Length), true, timeoutHelper.RemainingTime());
                }
            }
            else if (isRequest) // context handles response end bytes
            {
                endBytes = SingletonEncoder.EndBytes;
            }
 
            if (!endBytes.IsEmpty)
            {
                await connection.WriteAsync(endBytes,
                    true, timeoutHelper.RemainingTime());
            }
        }
 
        // overrides ConnectionStream to add a Framing int at the beginning of each record
        private class StreamingOutputConnectionStream : ConnectionStream
        {
            private Memory<byte> _encodedSize;
 
            public StreamingOutputConnectionStream(IConnection connection, IDefaultCommunicationTimeouts timeouts)
                : base(connection, timeouts)
            {
                _encodedSize = new byte[IntEncoder.MaxEncodedSize];
            }
 
            private ValueTask WriteChunkSizeAsync(int size)
            {
                if (size > 0)
                {
                    int bytesEncoded = IntEncoder.Encode(size, _encodedSize);
                    return Connection.WriteAsync(_encodedSize.Slice(0, bytesEncoded), false, TimeoutHelper.FromMilliseconds(WriteTimeout));
                }
                else
                {
                    return default;
                }
            }
 
            public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
            {
                await WriteChunkSizeAsync(count);
                await base.WriteAsync(buffer, offset, count, cancellationToken);
            }
 
            public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
            {
                await WriteChunkSizeAsync(buffer.Length);
                await base.WriteAsync(buffer, cancellationToken);
            }
 
            public override void WriteByte(byte value)
            {
                // We're write the chunk size as a variable length integer before we write the data. For a single byte
                // the length will always be 1 and that's encoded as 0x01 so no need to call WriteChunkSizeAsync. We
                // only need to send 2 bytes so can use _encodedSize to save an extra allocation.
                _encodedSize.Span[0] = 0x01;
                _encodedSize.Span[1] = value;
                base.WriteAsync(_encodedSize.Slice(0, 2)).GetAwaiter().GetResult();
            }
 
            public override void Write(byte[] buffer, int offset, int count)
            {
                // Do NOT call base.WriteAsync as the chunk size must be written first.
                WriteAsync(new Memory<byte>(buffer, offset, count)).GetAwaiter().GetResult();
            }
        }
    }
 
}