// File: System\Net\WebSockets\Compression\WebSocketInflater.cs
// Web Access
// Project: src\src\libraries\System.Net.WebSockets\src\System.Net.WebSockets.csproj (System.Net.WebSockets)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics;
using static System.IO.Compression.ZLibNative;
 
namespace System.Net.WebSockets.Compression
{
    /// <summary>
    /// Provides a wrapper around the ZLib decompression API.
    /// </summary>
    internal sealed class WebSocketInflater : IDisposable
    {
        /// <summary>Length, in bytes, of <see cref="FlushMarker"/>.</summary>
        internal const int FlushMarkerLength = 4;

        /// <summary>
        /// The four octets (0x00 0x00 0xFF 0xFF) appended to the compressed data once the final
        /// payload of a message has been added. permessage-deflate senders strip this trailer,
        /// so the receiver must put it back before inflating (RFC 7692, section 7.2.2).
        /// </summary>
        internal static ReadOnlySpan<byte> FlushMarker => [0x00, 0x00, 0xFF, 0xFF];

        /// <summary>Window bits handed to zlib; stored negated so zlib works on raw deflate data.</summary>
        private readonly int _windowBits;

        /// <summary>
        /// The zlib stream, created lazily on the first <see cref="Inflate(Span{byte}, out int)"/> call.
        /// Reset to <c>null</c> after a message completes when <see cref="_persisted"/> is <c>false</c>.
        /// </summary>
        private ZLibStreamHandle? _stream;

        /// <summary>
        /// When <c>false</c>, the zlib stream is disposed once a message has been fully inflated,
        /// so the next message starts with a fresh decompression context.
        /// </summary>
        private readonly bool _persisted;

        /// <summary>
        /// There is no way of knowing, when decoding data, if the underlying inflater
        /// has flushed all outstanding data to consumer other than to provide a buffer
        /// and see whether any bytes are written. There are cases when the consumers
        /// provide a buffer exactly the size of the uncompressed data and in this case
        /// to avoid requiring another read we will use this field.
        /// </summary>
        private byte? _remainingByte;

        /// <summary>
        /// The last added bytes to the inflater were part of the final
        /// payload for the message being sent.
        /// </summary>
        private bool _endOfMessage;

        /// <summary>
        /// Pooled input buffer holding received compressed bytes. <c>null</c> when nothing is buffered.
        /// </summary>
        private byte[]? _buffer;

        /// <summary>
        /// The position for the next unconsumed byte in the inflate buffer.
        /// </summary>
        private int _position;

        /// <summary>
        /// How many unconsumed bytes are left in the inflate buffer.
        /// </summary>
        private int _available;

        /// <param name="windowBits">The LZ77 window size (base-two logarithm) used by the peer's compressor.</param>
        /// <param name="persisted">
        /// <c>true</c> to keep the decompression context across messages; <c>false</c> to reset it after every message.
        /// </param>
        internal WebSocketInflater(int windowBits, bool persisted)
        {
            _windowBits = -windowBits; // Negative for raw deflate
            _persisted = persisted;
        }

        /// <summary>
        /// The free region of the input buffer that follows the unconsumed bytes. Received payload
        /// is written here and then registered with <see cref="AddBytes(int, bool)"/>.
        /// </summary>
        public Memory<byte> Memory => _buffer.AsMemory(_position + _available);

        /// <summary>
        /// Span view of the same free region exposed by <see cref="Memory"/>.
        /// </summary>
        public Span<byte> Span => _buffer.AsSpan(_position + _available);

        /// <summary>
        /// Disposes the zlib stream (if any) and returns the input buffer to the pool.
        /// </summary>
        public void Dispose()
        {
            _stream?.Dispose();
            ReleaseBuffer();
        }

        /// <summary>
        /// Initializes the inflater by allocating a buffer so the websocket can receive directly onto it.
        /// </summary>
        /// <remarks>
        /// If a buffer is already held, its unconsumed bytes are moved to the start of the buffer so the
        /// free region exposed by <see cref="Memory"/> is as large as possible.
        /// </remarks>
        /// <param name="payloadLength">the length of the message payload</param>
        /// <param name="userBufferLength">the length of the buffer where the payload will be inflated</param>
        public void Prepare(long payloadLength, int userBufferLength)
        {
            if (_buffer is not null)
            {
                // A buffer is only kept around while it still holds unconsumed input.
                Debug.Assert(_available > 0);

                _buffer.AsSpan(_position, _available).CopyTo(_buffer);
                _position = 0;
            }
            else
            {
                // Rent a buffer as close to the size of the user buffer as possible.
                // If the payload is smaller than the user buffer, rent only as much as we need.
                // The result is bounded by userBufferLength (an int), so the cast cannot overflow.
                _buffer = ArrayPool<byte>.Shared.Rent((int)Math.Min(userBufferLength, payloadLength));
            }
        }

        /// <summary>
        /// Registers bytes that were received into <see cref="Memory"/>/<see cref="Span"/> as available input.
        /// </summary>
        /// <param name="totalBytesReceived">How many bytes were written into the free region.</param>
        /// <param name="endOfMessage">
        /// <c>true</c> if these bytes complete the message; the <see cref="FlushMarker"/> is then appended.
        /// </param>
        public void AddBytes(int totalBytesReceived, bool endOfMessage)
        {
            Debug.Assert(totalBytesReceived == 0 || _buffer is not null, "Prepare must be called.");

            _available += totalBytesReceived;
            _endOfMessage = endOfMessage;

            if (!endOfMessage)
            {
                return;
            }

            if (_buffer is null)
            {
                // Empty final payload: the flush marker is the only input.
                Debug.Assert(_available == 0);

                _buffer = ArrayPool<byte>.Shared.Rent(FlushMarkerLength);
                _position = 0;
                _available = FlushMarkerLength;
                FlushMarker.CopyTo(_buffer);
                return;
            }

            // Unconsumed input occupies [_position, _position + _available); the marker must follow it directly.
            if (_buffer.Length < _position + _available + FlushMarkerLength)
            {
                byte[] newBuffer = ArrayPool<byte>.Shared.Rent(_available + FlushMarkerLength);
                _buffer.AsSpan(_position, _available).CopyTo(newBuffer);

                byte[] toReturn = _buffer;
                _buffer = newBuffer;
                _position = 0;

                ArrayPool<byte>.Shared.Return(toReturn);
            }

            FlushMarker.CopyTo(_buffer.AsSpan(_position + _available));
            _available += FlushMarkerLength;
        }

        /// <summary>
        /// Inflates the last receive payload into the provided buffer.
        /// </summary>
        /// <param name="output">Destination for the decompressed bytes.</param>
        /// <param name="written">Number of bytes written to <paramref name="output"/>.</param>
        /// <returns>
        /// <c>true</c> if all buffered input has been consumed and, when the end of the message has been
        /// added, all outstanding output has been flushed; <c>false</c> if input or output is still pending
        /// and the method must be called again with more output space.
        /// </returns>
        /// <exception cref="WebSocketException">The zlib stream could not be created or reported an error.</exception>
        public unsafe bool Inflate(Span<byte> output, out int written)
        {
            _stream ??= CreateInflater();

            if (_available > 0 && output.Length > 0)
            {
                int consumed;

                fixed (byte* bufferPtr = _buffer)
                {
                    _stream.NextIn = (IntPtr)(bufferPtr + _position);
                    _stream.AvailIn = (uint)_available;

                    written = Inflate(_stream, output, FlushCode.NoFlush);
                    consumed = _available - (int)_stream.AvailIn;
                }

                _position += consumed;
                _available -= consumed;
            }
            else
            {
                written = 0;
            }

            if (_available == 0)
            {
                // All input consumed: the pooled buffer is no longer needed.
                ReleaseBuffer();
                return _endOfMessage ? Finish(output, ref written) : true;
            }

            return false;
        }

        /// <summary>
        /// Finishes the decoding by flushing any outstanding data to the output.
        /// </summary>
        /// <param name="output">The same destination passed to <see cref="Inflate(Span{byte}, out int)"/>.</param>
        /// <param name="written">Bytes already written to <paramref name="output"/>; updated with any flushed bytes.</param>
        /// <returns>true if the flush completed, false to indicate that there is more outstanding data.</returns>
        private unsafe bool Finish(Span<byte> output, ref int written)
        {
            Debug.Assert(_stream is not null && _stream.AvailIn == 0);
            Debug.Assert(_available == 0);

            // Deliver the byte that a previous IsFinished probe pulled out of zlib.
            if (_remainingByte is not null)
            {
                if (output.Length == written)
                {
                    return false;
                }
                output[written] = _remainingByte.GetValueOrDefault();
                _remainingByte = null;
                written += 1;
            }

            // If we have more space in the output, try to inflate
            if (output.Length > written)
            {
                written += Inflate(_stream, output[written..], FlushCode.SyncFlush);
            }

            // After inflate, if we have more space in the output then it means that we
            // have finished. Otherwise we need to manually check for more data.
            if (written < output.Length || IsFinished(_stream, out _remainingByte))
            {
                if (!_persisted)
                {
                    _stream.Dispose();
                    _stream = null;
                }
                return true;
            }

            return false;
        }

        /// <summary>
        /// Returns the input buffer (if any) to the pool and resets the position/availability counters.
        /// </summary>
        private void ReleaseBuffer()
        {
            if (_buffer is byte[] toReturn)
            {
                _buffer = null;
                _available = 0;
                _position = 0;

                ArrayPool<byte>.Shared.Return(toReturn);
            }
        }

        /// <summary>
        /// Probes the stream with a one-byte output buffer to learn whether any decompressed data is still pending.
        /// </summary>
        /// <param name="stream">The zlib stream to probe.</param>
        /// <param name="remainingByte">The byte produced by the probe, or <c>null</c> if nothing was produced.</param>
        /// <returns><c>true</c> if the probe produced no output.</returns>
        private static bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte)
        {
            // There is no other way to make sure that we've consumed all data
            // but to try to inflate again with at least one byte of output buffer.
            byte b = 0;
            if (Inflate(stream, new Span<byte>(ref b), FlushCode.SyncFlush) == 0)
            {
                remainingByte = null;
                return true;
            }

            remainingByte = b;
            return false;
        }

        /// <summary>
        /// Runs one zlib inflate call into <paramref name="destination"/>.
        /// </summary>
        /// <param name="stream">The zlib stream, with its input pointer and count already set.</param>
        /// <param name="destination">Non-empty output buffer.</param>
        /// <param name="flushCode">The zlib flush mode to use.</param>
        /// <returns>The number of bytes written to <paramref name="destination"/>.</returns>
        /// <exception cref="WebSocketException">zlib returned an error other than Ok, StreamEnd or BufError.</exception>
        private static unsafe int Inflate(ZLibStreamHandle stream, Span<byte> destination, FlushCode flushCode)
        {
            Debug.Assert(destination.Length > 0);
            ErrorCode errorCode;

            fixed (byte* bufPtr = destination)
            {
                stream.NextOut = (IntPtr)bufPtr;
                stream.AvailOut = (uint)destination.Length;

                errorCode = stream.Inflate(flushCode);

                // BufError only means no progress was possible; it is not fatal.
                if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError)
                {
                    return destination.Length - (int)stream.AvailOut;
                }
            }

            string message = errorCode switch
            {
                ErrorCode.MemError => SR.ZLibErrorNotEnoughMemory,
                ErrorCode.DataError => SR.ZLibUnsupportedCompression,
                ErrorCode.StreamError => SR.ZLibErrorInconsistentStream,
                _ => SR.Format(SR.ZLibErrorUnexpected, (int)errorCode)
            };
            throw new WebSocketException(message);
        }

        /// <summary>
        /// Creates a zlib stream configured for inflation with <see cref="_windowBits"/>.
        /// </summary>
        /// <returns>An initialized zlib stream.</returns>
        /// <exception cref="WebSocketException">
        /// The native zlib library could not be loaded, or stream initialization failed.
        /// </exception>
        private ZLibStreamHandle CreateInflater()
        {
            ZLibStreamHandle? stream = null;
            ErrorCode errorCode;

            try
            {
                errorCode = CreateZLibStreamForInflate(out stream, _windowBits);
            }
            catch (Exception exception)
            {
                stream?.Dispose();
                throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception);
            }

            if (errorCode == ErrorCode.Ok)
            {
                return stream;
            }

            stream.Dispose();

            string message = errorCode == ErrorCode.MemError
                ? SR.ZLibErrorNotEnoughMemory
                : SR.Format(SR.ZLibErrorUnexpected, (int)errorCode);
            throw new WebSocketException(message);
        }
    }
}