File: FrameworkFork\System.ServiceModel\System\ServiceModel\Channels\ProducerConsumerStream.cs
Web Access
Project: src\src\dotnet-svcutil\lib\src\dotnet-svcutil-lib.csproj (dotnet-svcutil-lib)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Diagnostics.Contracts;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    internal class ProducerConsumerStream : Stream
    {
        private TaskCompletionSource<WriteBufferWrapper> _buffer;
        private TaskCompletionSource<int> _dataAvail;
        private WriteBufferWrapper _currentBuffer;
        private bool _disposed;
 
        public ProducerConsumerStream()
        {
            _buffer = new TaskCompletionSource<WriteBufferWrapper>();
            _dataAvail = new TaskCompletionSource<int>();
            _currentBuffer = WriteBufferWrapper.EmptyContainer;
        }
 
        public override bool CanRead { get { return !_disposed; } }
 
        public override bool CanSeek { get { return false; } }
 
        public override bool CanWrite { get { return !_disposed; } }
 
        public override void Flush()
        {
            _dataAvail.TrySetResult(0);
        }
 
        public override int Read(byte[] buffer, int offset, int count)
        {
            try
            {
                return ReadAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult();
            }
            catch (OperationCanceledException oce)
            {
                // If WriteAsync is canceled, an OperationCanceledException might get thrown on the Read path.
                // The stream is no longer usable so converting to an ObjectDisposedException.
                throw new ObjectDisposedException("ProducerConsumerStream", oce);
            }
        }
 
        public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            if (buffer == null)
            {
                throw new ArgumentNullException("buffer");
            }
            if (offset < 0 || offset > buffer.Length)
            {
                throw new ArgumentOutOfRangeException("offset");
            }
            if (count <= 0 || count > buffer.Length - offset)
            {
                throw new ArgumentOutOfRangeException("count");
            }
 
            cancellationToken.ThrowIfCancellationRequested();
 
            if (_disposed)
            {
                return 0;
            }
 
            using (cancellationToken.Register(CancelAndDispose, this))
            {
                _buffer.TrySetResult(new WriteBufferWrapper(buffer, offset, count));
                int totalBytesRead = 0;
                while (count > 0)
                {
                    int bytesRead = await _dataAvail.Task;
                    _dataAvail = new TaskCompletionSource<int>();
                    if (bytesRead == 0)
                    {
                        break;
                    }
 
                    totalBytesRead += bytesRead;
                    count -= bytesRead;
                }
 
                return totalBytesRead;
            }
        }
 
        public override void Write(byte[] buffer, int offset, int count)
        {
            try
            {
                WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult();
            }
            catch (OperationCanceledException oce)
            {
                // If ReadAsync is canceled, an OperationCanceledException might get thrown on the write path.
                // The stream is no longer usable so converting to an ObjectDisposedException.
                throw new ObjectDisposedException("ProducerConsumerStream", oce);
            }
        }
 
        public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            if (buffer == null)
            {
                throw new ArgumentNullException("buffer");
            }
            if (offset < 0 || offset > buffer.Length)
            {
                throw new ArgumentOutOfRangeException("offset");
            }
            if (count < 0 || count > buffer.Length - offset)
            {
                throw new ArgumentOutOfRangeException("count");
            }
 
            cancellationToken.ThrowIfCancellationRequested();
 
            if (_disposed)
            {
                throw new ObjectDisposedException("ProducerConsumerStream");
            }
 
            using (cancellationToken.Register(CancelAndDispose, this))
            {
                while (count > 0)
                {
                    if (_currentBuffer == WriteBufferWrapper.EmptyContainer)
                    {
                        _currentBuffer = await _buffer.Task;
                        _buffer = new TaskCompletionSource<WriteBufferWrapper>();
                    }
 
                    int bytesWritten = _currentBuffer.Write(buffer, offset, count);
                    count -= bytesWritten;
                    offset += bytesWritten;
                    if (_currentBuffer.Count == 0)
                    {
                        _currentBuffer = WriteBufferWrapper.EmptyContainer;
                    }
 
                    _dataAvail.TrySetResult(bytesWritten);
                }
            }
        }
 
        public override long Seek(long offset, SeekOrigin origin)
        {
            throw new NotSupportedException();
        }
 
        public override void SetLength(long value)
        {
            throw new NotSupportedException();
        }
 
        public override long Length
        {
            get { throw new NotSupportedException(); }
        }
 
        public override long Position
        {
            get { throw new NotSupportedException(); }
            set { throw new NotSupportedException(); }
        }
 
        protected override void Dispose(bool disposing)
        {
            if (disposing && !_disposed)
            {
                _disposed = true;
                _dataAvail.TrySetResult(0);
            }
 
            base.Dispose(disposing);
        }
 
        private static void CancelAndDispose(object state)
        {
            var thisPtr = state as ProducerConsumerStream;
            if (thisPtr != null)
            {
                thisPtr._dataAvail.TrySetCanceled();
                thisPtr._buffer.TrySetCanceled();
                thisPtr.Dispose();
            }
        }
 
#pragma warning disable 659,661  // Hashcode not needed, Equals overriden for fast path comparison of EmptyContainer
        private struct WriteBufferWrapper : IEquatable<WriteBufferWrapper>
#pragma warning restore 659,661
        {
            public static readonly WriteBufferWrapper EmptyContainer = new WriteBufferWrapper(null, -1, -1, true);
 
            private byte[] _buffer;
            private int _offset;
            private int _count;
            private readonly bool _readOnly;
 
            public WriteBufferWrapper(byte[] buffer, int offset, int count) : this(buffer, offset, count, false) { }
            private WriteBufferWrapper(byte[] buffer, int offset, int count, bool isReadOnly) : this()
            {
                _buffer = buffer;
                _offset = offset;
                _count = count;
                _readOnly = isReadOnly;
            }
 
            public byte[] ByteBuffer
            {
                get { return _buffer; }
            }
 
            public int Offset
            {
                get { return _offset; }
            }
 
            public int Count
            {
                get { return _count; }
            }
 
            public override bool Equals(object obj)
            {
                if (obj is WriteBufferWrapper)
                    return this.Equals((WriteBufferWrapper)obj);
                return false;
            }
 
            public bool Equals(WriteBufferWrapper other)
            {
                // Doing _readOnly comparison first as main use case is comparing against
                // BufferContainer.EmptyContainer which is the only readonly instance of BufferContainer.
                // If either instance is the read only, then this field is sufficient for comparison.
                if (_readOnly || other._readOnly)
                {
                    return _readOnly == other._readOnly;
                }
 
                if (other._buffer == _buffer && other._offset == _offset)
                {
                    return other._count == _count;
                }
 
                return false;
            }
 
            public static bool operator ==(WriteBufferWrapper a, WriteBufferWrapper b)
            {
                return a.Equals(b);
            }
 
            public static bool operator !=(WriteBufferWrapper a, WriteBufferWrapper b)
            {
                return !(a == b);
            }
 
            internal int Write(byte[] srcBuffer, int srcOffset, int srcCount)
            {
                if (_readOnly)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SRServiceModel.ObjectIsReadOnly));
                }
 
                int bytesToCopy = Math.Min(_count, srcCount);
                Buffer.BlockCopy(srcBuffer, srcOffset, _buffer, _offset, bytesToCopy);
                _offset += bytesToCopy;
                _count -= bytesToCopy;
                return bytesToCopy;
            }
        }
    }
}