File: System\ServiceModel\Channels\MaxMessageSizeStream.cs
Web Access
Project: src\src\System.ServiceModel.Primitives\src\System.ServiceModel.Primitives.csproj (System.ServiceModel.Primitives)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.IO;
using System.Runtime;
using System.Threading.Tasks;
using System.Threading;
 
namespace System.ServiceModel.Channels
{
    internal class MaxMessageSizeStream : DelegatingStream
    {
        private long _maxMessageSize;
        private long _totalBytesRead;
        private long _bytesWritten;
 
        public MaxMessageSizeStream(Stream stream, long maxMessageSize)
            : base(stream)
        {
            _maxMessageSize = maxMessageSize;
        }
 
        public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            count = PrepareRead(count);
            int bytesRead = await base.ReadAsync(buffer, offset, count, cancellationToken);
            return FinishRead(bytesRead);
        }
 
        public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
        {
            int readCount = PrepareRead(buffer.Length);
            if (readCount < buffer.Length)
            {
                buffer = buffer.Slice(0, readCount);
            }
            int bytesRead = await base.ReadAsync(buffer, cancellationToken);
            return FinishRead(bytesRead);
        }
 
        public override int Read(byte[] buffer, int offset, int count)
        {
            count = PrepareRead(count);
            return FinishRead(base.Read(buffer, offset, count));
        }
 
        public override int ReadByte()
        {
            PrepareRead(1);
            int i = base.ReadByte();
            if (i != -1)
            {
                FinishRead(1);
            }
 
            return i;
        }
 
        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            PrepareWrite(count);
            return base.WriteAsync(buffer, offset, count, cancellationToken);
        }
 
        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
        {
            PrepareWrite(buffer.Length);
            return base.WriteAsync(buffer, cancellationToken);
        }
 
        public override void Write(byte[] buffer, int offset, int count)
        {
            PrepareWrite(count);
            base.Write(buffer, offset, count);
        }
 
        public override void WriteByte(byte value)
        {
            PrepareWrite(1);
            base.WriteByte(value);
        }
 
        public static Exception CreateMaxReceivedMessageSizeExceededException(long maxMessageSize)
        {
            string message = SRP.Format(SRP.MaxReceivedMessageSizeExceeded, maxMessageSize);
            Exception inner = new QuotaExceededException(message);
 
            if (WcfEventSource.Instance.MaxReceivedMessageSizeExceededIsEnabled())
            {
                WcfEventSource.Instance.MaxReceivedMessageSizeExceeded(message);
            }
 
            return new CommunicationException(message, inner);
        }
 
        internal static Exception CreateMaxSentMessageSizeExceededException(long maxMessageSize)
        {
            string message = SRP.Format(SRP.MaxSentMessageSizeExceeded, maxMessageSize);
            Exception inner = new QuotaExceededException(message);
 
            if (WcfEventSource.Instance.MaxSentMessageSizeExceededIsEnabled())
            {
                WcfEventSource.Instance.MaxSentMessageSizeExceeded(message);
            }
 
            return new CommunicationException(message, inner);
        }
 
        private int PrepareRead(int bytesToRead)
        {
            if (_totalBytesRead >= _maxMessageSize)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(CreateMaxReceivedMessageSizeExceededException(_maxMessageSize));
            }
 
            long bytesRemaining = _maxMessageSize - _totalBytesRead;
 
            if (bytesRemaining > int.MaxValue)
            {
                return bytesToRead;
            }
            else
            {
                return Math.Min(bytesToRead, (int)bytesRemaining);
            }
        }
 
        private int FinishRead(int bytesRead)
        {
            _totalBytesRead += bytesRead;
            return bytesRead;
        }
 
        private void PrepareWrite(int bytesToWrite)
        {
            if (_bytesWritten + bytesToWrite > _maxMessageSize)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(CreateMaxSentMessageSizeExceededException(_maxMessageSize));
            }
 
            _bytesWritten += bytesToWrite;
        }
    }
}