// File: System\Net\WebSockets\Compression\WebSocketDeflater.cs
// Project: src\src\libraries\System.Net.WebSockets\src\System.Net.WebSockets.csproj (System.Net.WebSockets)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics;
using static System.IO.Compression.ZLibNative;
 
namespace System.Net.WebSockets.Compression
{
    /// <summary>
    /// Compresses outgoing WebSocket payloads with raw DEFLATE through the native zlib API.
    /// Compressed output is written into a buffer rented from <see cref="ArrayPool{T}.Shared"/>;
    /// the span handed back by <see cref="Deflate(ReadOnlySpan{byte}, bool)"/> is only valid until
    /// <see cref="ReleaseBuffer"/> is called.
    /// </summary>
    internal sealed class WebSocketDeflater : IDisposable
    {
        // Negated window size handed to zlib (a negative value selects raw deflate).
        private readonly int _windowBits;

        // When false, the zlib stream is thrown away after every complete message.
        private readonly bool _persisted;

        // Created lazily on first use; reset to null after a message when not persisted.
        private ZLibStreamHandle? _stream;

        // Pooled output buffer owned by this instance between Deflate and ReleaseBuffer.
        private byte[]? _buffer;

        /// <summary>
        /// Initializes a deflater. No native resources are allocated until the first compression.
        /// </summary>
        /// <param name="windowBits">zlib window size (base-two logarithm); negated internally to request raw deflate.</param>
        /// <param name="persisted">
        /// true to reuse one zlib stream across messages; false to dispose it once each message is complete.
        /// </param>
        internal WebSocketDeflater(int windowBits, bool persisted)
        {
            _persisted = persisted;
            _windowBits = -windowBits;
        }

        /// <summary>
        /// Releases the native zlib stream, if one exists. This does not return the rented
        /// output buffer; <see cref="ReleaseBuffer"/> is responsible for that.
        /// </summary>
        public void Dispose() => _stream?.Dispose();

        /// <summary>
        /// Returns the output buffer rented by the last <see cref="Deflate(ReadOnlySpan{byte}, bool)"/>
        /// call to the shared pool. Calling it when nothing is rented does nothing.
        /// </summary>
        public void ReleaseBuffer()
        {
            byte[]? rented = _buffer;
            if (rented is null)
            {
                return;
            }

            _buffer = null;
            ArrayPool<byte>.Shared.Return(rented);
        }

        /// <summary>
        /// Compresses <paramref name="payload"/> and returns the compressed bytes.
        /// </summary>
        /// <param name="payload">The bytes to compress; may be empty.</param>
        /// <param name="endOfMessage">
        /// true when this is the last fragment of the message, in which case the trailing
        /// flush marker is removed from the output.
        /// </param>
        /// <returns>
        /// A view over an internal pooled buffer. Consume it before calling <see cref="ReleaseBuffer"/>,
        /// and call <see cref="ReleaseBuffer"/> before compressing again.
        /// </returns>
        public ReadOnlySpan<byte> Deflate(ReadOnlySpan<byte> payload, bool endOfMessage)
        {
            Debug.Assert(_buffer is null, "Invalid state, ReleaseBuffer not called.");

            // Compressing a very small payload can yield more bytes than it started with,
            // so the first rental is never smaller than 4 KB.
            const int MinInitialBufferLength = 4 * 1024;

            byte[] buffer = ArrayPool<byte>.Shared.Rent(Math.Max(payload.Length, MinInitialBufferLength));
            _buffer = buffer;
            int filled = 0;

            for (;;)
            {
                DeflatePrivate(payload, buffer.AsSpan(filled), endOfMessage,
                    out int consumed, out int written, out bool needsMoreOutput);
                filled += written;

                if (!needsMoreOutput)
                {
                    Debug.Assert(consumed == payload.Length);
                    return buffer.AsSpan(0, filled);
                }

                // Input zlib has not accepted yet is fed again on the next pass.
                payload = payload.Slice(consumed);

                // Grow by roughly 30%, carrying over the bytes produced so far. The field is
                // switched to the new array before the old one goes back to the pool, so
                // ReleaseBuffer always sees the array that is actually in use.
                byte[] previous = buffer;
                buffer = ArrayPool<byte>.Shared.Rent((int)(previous.Length * 1.3));
                previous.AsSpan(0, filled).CopyTo(buffer);
                _buffer = buffer;
                ArrayPool<byte>.Shared.Return(previous);
            }
        }

        /// <summary>
        /// Performs one compression pass of <paramref name="payload"/> into <paramref name="output"/>.
        /// Once all input has been accepted, it also flushes the pending output.
        /// </summary>
        /// <param name="consumed">How many payload bytes zlib accepted.</param>
        /// <param name="written">How many bytes were placed into <paramref name="output"/>.</param>
        /// <param name="needsMoreOutput">
        /// true when <paramref name="output"/> ran out of room. The caller must provide more space and
        /// call again with the unconsumed remainder of the payload.
        /// </param>
        private void DeflatePrivate(ReadOnlySpan<byte> payload, Span<byte> output, bool endOfMessage,
            out int consumed, out int written, out bool needsMoreOutput)
        {
            _stream ??= CreateDeflater();

            consumed = 0;
            written = 0;

            if (!payload.IsEmpty)
            {
                UnsafeDeflate(payload, output, out consumed, out written, out needsMoreOutput);

                if (needsMoreOutput)
                {
                    // Out of room before the data could be flushed; the flush happens on a later pass.
                    Debug.Assert(written == output.Length);
                    return;
                }
            }

            int flushed = UnsafeFlush(output.Slice(written), out needsMoreOutput);
            written += flushed;

            if (needsMoreOutput)
            {
                return;
            }

            Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength)
                               .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker.");

            if (!endOfMessage)
            {
                return;
            }

            // The final fragment of a message is emitted without its trailing flush marker, as the RFC requires.
            written -= WebSocketInflater.FlushMarkerLength;

            if (!_persisted)
            {
                // Without persisted state, the next message starts from a brand-new compressor.
                _stream.Dispose();
                _stream = null;
            }
        }

        /// <summary>
        /// Hands <paramref name="input"/> to zlib using Z_NO_FLUSH and collects whatever compressed
        /// output zlib decides to emit into <paramref name="output"/>.
        /// </summary>
        /// <param name="consumed">How many input bytes zlib accepted.</param>
        /// <param name="written">How many bytes zlib wrote to <paramref name="output"/>.</param>
        /// <param name="needsMoreBuffer">true when the caller must supply more output space before flushing.</param>
        private unsafe void UnsafeDeflate(ReadOnlySpan<byte> input, Span<byte> output, out int consumed, out int written, out bool needsMoreBuffer)
        {
            Debug.Assert(_stream is not null);

            fixed (byte* inputPtr = input)
            fixed (byte* outputPtr = output)
            {
                _stream.NextIn = (IntPtr)inputPtr;
                _stream.AvailIn = (uint)input.Length;

                _stream.NextOut = (IntPtr)outputPtr;
                _stream.AvailOut = (uint)output.Length;

                // No flush: zlib may hold input back internally to get a better compression ratio.
                ErrorCode errorCode = Deflate(_stream, FlushCode.NoFlush);

                uint inputLeft = _stream.AvailIn;
                consumed = input.Length - (int)inputLeft;
                written = output.Length - (int)_stream.AvailOut;

                // A completely full output buffer also counts as "needs more". A flush always
                // follows, and flushing into an empty output buffer throws.
                needsMoreBuffer = errorCode == ErrorCode.BufError
                    || inputLeft != 0
                    || written == output.Length;
            }
        }

        /// <summary>
        /// Drains all pending compressed data into <paramref name="output"/>. When enough space remains,
        /// it then writes a sync-flush block boundary ending in the 00 00 ff ff marker.
        /// </summary>
        /// <param name="needsMoreBuffer">true when there was not enough space to write the block boundary.</param>
        /// <returns>How many bytes were written to <paramref name="output"/>.</returns>
        private unsafe int UnsafeFlush(Span<byte> output, out bool needsMoreBuffer)
        {
            // Free space that guarantees the sync flush can emit the whole deflate block boundary.
            const uint MinSyncFlushSpace = 6;

            Debug.Assert(_stream is not null);
            Debug.Assert(_stream.AvailIn == 0);
            Debug.Assert(output.Length > 0);

            fixed (byte* outputPtr = output)
            {
                _stream.NextIn = IntPtr.Zero;
                _stream.AvailIn = 0;

                _stream.NextOut = (IntPtr)outputPtr;
                _stream.AvailOut = (uint)output.Length;

                // Step 1: Z_BLOCK pushes out all buffered data without emitting a block boundary yet.
                ErrorCode errorCode = Deflate(_stream, FlushCode.Block);
                Debug.Assert(errorCode is ErrorCode.Ok or ErrorCode.BufError);

                needsMoreBuffer = _stream.AvailOut < MinSyncFlushSpace;

                if (!needsMoreBuffer)
                {
                    // Step 2: Z_SYNC_FLUSH byte-aligns the output. It closes the current block with an
                    // empty stored block whose tail is 00 00 ff ff, so the peer can decode everything so far.
                    errorCode = Deflate(_stream, FlushCode.SyncFlush);
                    Debug.Assert(errorCode == ErrorCode.Ok);
                }

                return output.Length - (int)_stream.AvailOut;
            }
        }

        /// <summary>
        /// Invokes zlib deflate on <paramref name="stream"/> and translates fatal results into
        /// <see cref="WebSocketException"/>.
        /// </summary>
        /// <returns><c>Ok</c>, <c>StreamEnd</c> or <c>BufError</c>; any other code throws.</returns>
        private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode)
        {
            ErrorCode errorCode = stream.Deflate(flushCode);

            return errorCode switch
            {
                ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError => errorCode,
                ErrorCode.StreamError => throw new WebSocketException(SR.ZLibErrorInconsistentStream),
                _ => throw new WebSocketException(SR.Format(SR.ZLibErrorUnexpected, (int)errorCode)),
            };
        }

        /// <summary>
        /// Creates a zlib deflate stream. It uses default compression level, memory level and
        /// strategy, and this instance's (negated) window bits.
        /// </summary>
        /// <exception cref="WebSocketException">Creating or initializing the native stream failed.</exception>
        private ZLibStreamHandle CreateDeflater()
        {
            ZLibStreamHandle? stream = null;
            ErrorCode errorCode;

            try
            {
                errorCode = CreateZLibStreamForDeflate(out stream,
                    level: CompressionLevel.DefaultCompression,
                    windowBits: _windowBits,
                    memLevel: Deflate_DefaultMemLevel,
                    strategy: CompressionStrategy.DefaultStrategy);
            }
            catch (Exception cause)
            {
                // Any exception here is wrapped using the DLL-load error message.
                stream?.Dispose();
                throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause);
            }

            if (errorCode != ErrorCode.Ok)
            {
                stream.Dispose();

                string message = errorCode switch
                {
                    ErrorCode.MemError => SR.ZLibErrorNotEnoughMemory,
                    _ => SR.Format(SR.ZLibErrorUnexpected, (int)errorCode),
                };
                throw new WebSocketException(message);
            }

            return stream;
        }
    }
}