File: System\Net\StreamFramer.cs
Web Access
Project: src\src\libraries\System.Net.Security\src\System.Net.Security.csproj (System.Net.Security)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Globalization;
using System.IO;
using System.Net.Security;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.Net
{
    // Reads and writes framed messages on a stream. Every message on the wire is a
    // FrameHeader (FrameHeader.Size bytes) followed by PayloadSize bytes of payload.
    internal sealed class StreamFramer
    {
        // Scratch buffers for the serialized header, reused across calls.
        private readonly byte[] _readHeaderBytes = new byte[FrameHeader.Size];
        private readonly byte[] _writeHeaderBytes = new byte[FrameHeader.Size];

        // Header of the most recently read frame / header applied to outgoing frames.
        private readonly FrameHeader _readHeader = new FrameHeader();
        private readonly FrameHeader _outgoingHeader = new FrameHeader();

        // Set once a read hits end-of-stream exactly on a frame boundary.
        private bool _eof;

        // Header parsed by the latest ReadMessageAsync call.
        public FrameHeader ReadHeader
        {
            get { return _readHeader; }
        }

        // Header that WriteMessageAsync serializes in front of each message; callers may
        // adjust it (e.g. MessageId) before writing. PayloadSize is overwritten on every write.
        public FrameHeader WriteHeader
        {
            get { return _outgoingHeader; }
        }

        // Reads one frame (header + payload) from the stream and returns the payload.
        //
        // Returns:
        //   - null when the stream ended before any header byte was read; every later call
        //     then also returns null without touching the stream.
        //   - a freshly allocated array holding the payload otherwise (empty for a
        //     zero-length payload).
        //
        // Throws:
        //   - IOException when the stream ends part-way through the header or the payload.
        //   - InvalidOperationException when the header announces a payload larger than
        //     FrameHeader.MaxMessageSize.
        public async ValueTask<byte[]?> ReadMessageAsync<TAdapter>(Stream stream, CancellationToken cancellationToken)
            where TAdapter : IReadWriteAdapter
        {
            if (_eof)
            {
                return null;
            }

            byte[] headerBytes = _readHeaderBytes;
            int headerBytesRead = await TAdapter.ReadAtLeastAsync(
                stream, headerBytes, headerBytes.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);

            bool headerIncomplete = headerBytesRead < headerBytes.Length;
            if (headerIncomplete && headerBytesRead == 0)
            {
                // Clean end of stream between frames: remember it and report "no more messages".
                _eof = true;
                return null;
            }

            if (headerIncomplete)
            {
                // The peer went away in the middle of a header.
                throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
            }

            _readHeader.CopyFrom(headerBytes, 0);

            int payloadSize = _readHeader.PayloadSize;
            if (payloadSize > FrameHeader.MaxMessageSize)
            {
                string reportedSize = payloadSize.ToString(NumberFormatInfo.InvariantInfo);
                throw new InvalidOperationException(
                    SR.Format(SR.net_frame_size, FrameHeader.MaxMessageSize, reportedSize));
            }

            byte[] payload = new byte[payloadSize];
            if (payload.Length == 0)
            {
                // Header-only frame: nothing further to read.
                return payload;
            }

            int payloadBytesRead = await TAdapter.ReadAtLeastAsync(
                stream, payload, payload.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
            if (payloadBytesRead < payload.Length)
            {
                // The peer went away in the middle of the payload.
                throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
            }

            return payload;
        }

        // Writes one frame: the write header (with PayloadSize set to message.Length)
        // followed by the message bytes. The payload write is skipped for an empty message.
        // Setting PayloadSize throws ArgumentException when message.Length exceeds
        // FrameHeader.MaxMessageSize; in that case nothing is written.
        public async Task WriteMessageAsync<TAdapter>(Stream stream, byte[] message, CancellationToken cancellationToken)
            where TAdapter : IReadWriteAdapter
        {
            byte[] headerBytes = _writeHeaderBytes;

            _outgoingHeader.PayloadSize = message.Length;
            _outgoingHeader.CopyTo(headerBytes, 0);

            await TAdapter.WriteAsync(stream, headerBytes, cancellationToken).ConfigureAwait(false);

            if (message.Length == 0)
            {
                return;
            }

            await TAdapter.WriteAsync(stream, message, cancellationToken).ConfigureAwait(false);
        }
    }
 
    // Describes the header used in framing of the stream data.
    //
    // Serialized layout (Size bytes, see CopyTo/CopyFrom):
    //   byte 0    : MessageId
    //   byte 1    : MajorV
    //   byte 2    : MinorV
    //   bytes 3-4 : PayloadSize, high byte first
    internal sealed class FrameHeader
    {
        public const int IgnoreValue = -1;
        public const int HandshakeDoneId = 20;
        public const int HandshakeErrId = 21;
        public const int HandshakeId = 22;
        public const int DefaultMajorV = 1;
        public const int DefaultMinorV = 0;

        // Number of bytes CopyTo writes and CopyFrom reads.
        public const int Size = 5;

        // Largest payload size the PayloadSize setter accepts (fits the two size bytes).
        public const int MaxMessageSize = 0xFFFF;

        // Starts at -1 until assigned through the setter or CopyFrom.
        private int _payloadSize = -1;

        public int MessageId { get; set; } = HandshakeId;
        public int MajorV { get; private set; } = DefaultMajorV;
        public int MinorV { get; private set; } = DefaultMinorV;

        // Size in bytes of the payload that follows the header.
        // The setter throws ArgumentException for values above MaxMessageSize; only the
        // upper bound is checked, so negative values are stored as given.
        public int PayloadSize
        {
            get
            {
                return _payloadSize;
            }
            set
            {
                if (value <= MaxMessageSize)
                {
                    _payloadSize = value;
                    return;
                }

                throw new ArgumentException(SR.Format(SR.net_frame_max_size, MaxMessageSize, value), nameof(PayloadSize));
            }
        }

        // Serializes this header into dest[start .. start + 4]. Each field is truncated to
        // its low 8 bits (16 bits for the payload size, written high byte first).
        public void CopyTo(byte[] dest, int start)
        {
            int size = _payloadSize;

            dest[start] = (byte)MessageId;
            dest[start + 1] = (byte)MajorV;
            dest[start + 2] = (byte)MinorV;
            dest[start + 3] = (byte)((size >> 8) & 0xFF);
            dest[start + 4] = (byte)(size & 0xFF);
        }

        // Populates this header from bytes[start .. start + 4], the inverse of CopyTo.
        // The payload size is assigned directly to the backing field, bypassing the setter.
        public void CopyFrom(byte[] bytes, int start)
        {
            MessageId = bytes[start];
            MajorV = bytes[start + 1];
            MinorV = bytes[start + 2];

            int highByte = bytes[start + 3];
            int lowByte = bytes[start + 4];
            _payloadSize = (highByte << 8) | lowByte;
        }
    }
}