|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
namespace System.IO.Compression
{
internal sealed class InflaterManaged
{
// const tables used in decoding:
// Extra bits for length code 257 - 285, indexed by (code - 257).
// Codes 257 - 264 (the leading zero entries) are handled by DecodeBlock without consulting this table.
// The last entry (code 285) is 16; DecodeBlock only reaches it when _deflate64 is set, because
// without Deflate64 code 285 is special-cased as length 258 with no extra bits.
private static ReadOnlySpan<byte> ExtraLengthBits =>
[
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3,
3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 16
];
// The base length for length code 257 - 285, indexed by (code - 257).
// The formula to get the real length for a length code is LengthBase[code - 257] + (value stored in extraBits).
// DecodeBlock only performs this lookup when the code has a non-zero number of extra bits.
private static ReadOnlySpan<byte> LengthBase =>
[
3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51,
59, 67, 83, 99, 115, 131, 163, 195, 227, 3
];
// The base distance for distance code 0 - 31.
// The real distance for a distance code is DistanceBasePosition[code] + (value stored in extraBits).
// DecodeBlock only performs this lookup for codes above 3; for codes 0 - 3 the distance is code + 1.
private static ReadOnlySpan<ushort> DistanceBasePosition =>
[
1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513,
769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 32769, 49153
];
// Code lengths for the code length alphabet are stored in the stream in the following symbol order.
private static ReadOnlySpan<byte> CodeOrder => [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15];
// Maps the 5 bits read from the input for a static block to the distance code.
// Each entry is the 5-bit bit-reversal of its index (e.g. index 1 = 00001b maps to 0x10 = 10000b).
private static ReadOnlySpan<byte> StaticDistanceTreeTable =>
[
0x00, 0x10, 0x08, 0x18, 0x04, 0x14, 0x0c, 0x1c, 0x02, 0x12, 0x0a, 0x1a,
0x06, 0x16, 0x0e, 0x1e, 0x01, 0x11, 0x09, 0x19, 0x05, 0x15, 0x0d, 0x1d,
0x03, 0x13, 0x0b, 0x1b, 0x07, 0x17, 0x0f, 0x1f
];
// Window that decoded bytes are written into and that Inflate copies out of.
private readonly OutputWindow _output;
// Bit-level reader over the compressed bytes supplied through SetInput.
private readonly InputBuffer _input;
// Huffman trees used by DecodeBlock; assigned when a static block starts or when a
// dynamic block header has been fully decoded.
private HuffmanTree? _literalLengthTree;
private HuffmanTree? _distanceTree;
// Current position in the decoding state machine; lets every Decode* method resume
// where it stopped when the input ran out.
private InflaterState _state;
// BFINAL bit of the block being decoded; non-zero marks the last block of the stream.
private int _bfinal;
// BTYPE of the block being decoded.
private BlockType _blockType;
// uncompressed block
// Holds the four header bytes of a stored block: LEN (bytes 0-1) and NLEN (bytes 2-3).
private readonly byte[] _blockLengthBuffer = new byte[4];
// Number of stored-block bytes that still have to be copied to the output.
private int _blockLength;
// compressed block
// While extra length bits are still pending this is an index into LengthBase;
// afterwards it is the actual match length.
private int _length;
// Distance code of the match currently being decoded.
private int _distanceCode;
// Number of extra bits to read for the current length or distance code.
private int _extraBits;
// Progress index used while reading code lengths in the dynamic block header.
private int _loopCounter;
// HLIT + 257, HDIST + 1 and HCLEN + 4 from the dynamic block header.
private int _literalLengthCodeCount;
private int _distanceCodeCount;
private int _codeLengthCodeCount;
// _literalLengthCodeCount + _distanceCodeCount: total number of code lengths to read.
private int _codeArraySize;
// Most recently decoded code-length symbol; kept in a field so that a repeat code (16-18)
// can be completed on a later call if its extra bits were not yet available.
private int _lengthCode;
private readonly byte[] _codeList; // temporary array to store the code length for literal/Length and distance
// Code lengths of the code length alphabet, indexed by symbol (filled through CodeOrder).
private readonly byte[] _codeLengthTreeCodeLength;
// When false, length code 285 is decoded as a fixed length of 258 with no extra bits.
private readonly bool _deflate64;
// Tree used to decode the literal/length and distance code lengths of a dynamic block.
private HuffmanTree? _codeLengthTree;
// Maximum number of bytes Inflate will hand out, or -1 for no limit.
private readonly long _uncompressedSize;
// Number of bytes Inflate has handed out so far (only tracked when _uncompressedSize is not -1).
private long _currentInflatedCount;
// Creates an inflater that is ready to read the first block header.
// deflate64:        stored in _deflate64 and consulted by DecodeBlock when it sees length code 285.
// uncompressedSize: upper bound on the number of bytes Inflate returns, or -1 for no bound.
internal InflaterManaged(bool deflate64, long uncompressedSize)
{
    // Caller-supplied configuration.
    _deflate64 = deflate64;
    _uncompressedSize = uncompressedSize;

    // Output window and input bit reader.
    _output = new OutputWindow();
    _input = new InputBuffer();

    // Scratch arrays reused for every dynamic block header.
    _codeLengthTreeCodeLength = new byte[HuffmanTree.NumberOfCodeLengthTreeElements];
    _codeList = new byte[HuffmanTree.MaxLiteralTreeElements + HuffmanTree.MaxDistTreeElements];

    // Decoding always begins with the BFINAL bit of the first block.
    _state = InflaterState.ReadingBFinal;
}
// Forwards the given block of compressed bytes to the input buffer.
public void SetInput(Memory<byte> inputBytes)
{
    _input.SetInput(inputBytes);
}
// Forwards the given array segment of compressed bytes to the input buffer.
public void SetInput(byte[] inputBytes, int offset, int length)
{
    // append the bytes
    _input.SetInput(inputBytes, offset, length);
}
// True once the state machine has reached Done or VerifyingFooter; Inflate and Decode
// stop decoding as soon as this returns true.
public bool Finished()
{
    return _state is InflaterState.Done or InflaterState.VerifyingFooter;
}
// Number of bytes the output window currently reports as available.
public int AvailableOutput
{
    get { return _output.AvailableBytes; }
}
// Fills the destination span with decompressed data and returns how many bytes were written.
// Each pass first drains whatever the output window already holds, then decodes more input.
// The loop ends when the destination is full, the stream is finished, or Decode reports
// that it needs more input.
public int Inflate(Span<byte> bytes)
{
    int total = 0;

    while (true)
    {
        int drained = 0;

        if (_uncompressedSize == -1)
        {
            // No size limit: hand out whatever the window has.
            drained = _output.CopyTo(bytes);
        }
        else if (_uncompressedSize > _currentInflatedCount)
        {
            // Size limit not reached yet: never hand out more than what is left of it.
            long remaining = _uncompressedSize - _currentInflatedCount;
            bytes = bytes.Slice(0, (int)Math.Min(bytes.Length, remaining));
            drained = _output.CopyTo(bytes);
            _currentInflatedCount += drained;
        }
        else
        {
            // The declared size has been produced: stop and discard anything still buffered.
            _state = InflaterState.Done;
            _output.ClearBytesUsed();
        }

        if (drained > 0)
        {
            bytes = bytes.Slice(drained);
            total += drained;
        }

        // Stop when the destination is full, the stream is done, or Decode needs more input
        // (Decode returns false in that case). The order of these checks matters: Decode must
        // not run once the destination is full or the stream has finished.
        if (bytes.IsEmpty || Finished() || !Decode())
        {
            break;
        }
    }

    return total;
}
// Array-based convenience overload: inflates into bytes[offset .. offset + length).
public int Inflate(byte[] bytes, int offset, int length)
{
    Span<byte> destination = bytes.AsSpan(offset, length);
    return Inflate(destination);
}
//Each block of compressed data begins with 3 header bits
// containing the following data:
// first bit BFINAL
// next 2 bits BTYPE
// Note that the header bits do not necessarily begin on a byte
// boundary, since a block does not necessarily occupy an integral
// number of bytes.
// BFINAL is set if and only if this is the last block of the data
// set.
// BTYPE specifies how the data are compressed, as follows:
// 00 - no compression
// 01 - compressed with fixed Huffman codes
// 10 - compressed with dynamic Huffman codes
// 11 - reserved (error)
// The only difference between the two compressed cases is how the
// Huffman codes for the literal/length and distance alphabets are
// defined.
//
// This function returns true on success (end of block reached, or the output window is full)
// and false if we are short of input.
//
// Advances the decoder by one step: reads the block header if one is due, then runs the
// decoder that matches the current block type.
// Returns the result of that decoder, or true immediately if the stream is already finished;
// returns false when the header bits are not yet available.
// Throws InvalidDataException for the reserved block type.
private bool Decode()
{
    if (Finished())
    {
        return true;
    }

    if (_state == InflaterState.ReadingBFinal)
    {
        // BFINAL needs 1 bit.
        if (!_input.EnsureBitsAvailable(1))
        {
            return false;
        }

        _bfinal = _input.GetBits(1);
        _state = InflaterState.ReadingBType;
    }

    if (_state == InflaterState.ReadingBType)
    {
        // BTYPE needs 2 bits. If they are missing, _state is already ReadingBType,
        // so the next call resumes right here.
        if (!_input.EnsureBitsAvailable(2))
        {
            return false;
        }

        _blockType = (BlockType)_input.GetBits(2);
        switch (_blockType)
        {
            case BlockType.Dynamic:
                _state = InflaterState.ReadingNumLitCodes;
                break;

            case BlockType.Static:
                _literalLengthTree = HuffmanTree.StaticLiteralLengthTree;
                _distanceTree = HuffmanTree.StaticDistanceTree;
                _state = InflaterState.DecodeTop;
                break;

            case BlockType.Uncompressed:
                _state = InflaterState.UncompressedAligning;
                break;

            default:
                throw new InvalidDataException(SR.UnknownBlockType);
        }
    }

    bool blockEnded = false;
    bool outcome;
    switch (_blockType)
    {
        case BlockType.Dynamic:
            // States before DecodeTop mean the dynamic header is still being read.
            // DecodeBlock can return true when the output is full.
            outcome = _state < InflaterState.DecodeTop
                ? DecodeDynamicBlockHeader()
                : DecodeBlock(out blockEnded);
            break;

        case BlockType.Static:
            outcome = DecodeBlock(out blockEnded);
            break;

        case BlockType.Uncompressed:
            outcome = DecodeUncompressedBlock(out blockEnded);
            break;

        default:
            throw new InvalidDataException(SR.UnknownBlockType);
    }

    // Reaching the end of a block whose BFINAL bit was set ends the stream.
    if (blockEnded && _bfinal != 0)
    {
        _state = InflaterState.Done;
    }

    return outcome;
}
// Format of Non-compressed blocks (BTYPE=00):
//
// Any bits of input up to the next byte boundary are ignored.
// The rest of the block consists of the following information:
//
// 0 1 2 3 4...
// +---+---+---+---+================================+
// | LEN | NLEN |... LEN bytes of literal data...|
// +---+---+---+---+================================+
//
// LEN is the number of data bytes in the block. NLEN is the
// one's complement of LEN.
// Decodes a stored (BTYPE=00) block, resuming from whatever _state a previous call left behind.
// end_of_block is set to true only when all LEN bytes of the block have been copied.
// Returns true when the block is complete or the output window has no free space left;
// returns false when more input is needed.
// Throws InvalidDataException when NLEN is not the one's complement of LEN.
private bool DecodeUncompressedBlock(out bool end_of_block)
{
end_of_block = false;
while (true)
{
switch (_state)
{
case InflaterState.UncompressedAligning: // initial state when calling this function
// we must skip to a byte boundary
_input.SkipToByteBoundary();
_state = InflaterState.UncompressedByte1;
goto case InflaterState.UncompressedByte1;
case InflaterState.UncompressedByte1: // decoding block length
case InflaterState.UncompressedByte2:
case InflaterState.UncompressedByte3:
case InflaterState.UncompressedByte4:
// One header byte is read per state. A negative result means the input ran out;
// _state is left unchanged so the same byte is retried on the next call.
int bits = _input.GetBits(8);
if (bits < 0)
{
return false;
}
// The distance from UncompressedByte1 selects the slot: 0-1 hold LEN, 2-3 hold NLEN.
_blockLengthBuffer[_state - InflaterState.UncompressedByte1] = (byte)bits;
if (_state == InflaterState.UncompressedByte4)
{
// Both 16-bit values are stored least-significant byte first.
_blockLength = _blockLengthBuffer[0] + ((int)_blockLengthBuffer[1]) * 256;
int blockLengthComplement = _blockLengthBuffer[2] + ((int)_blockLengthBuffer[3]) * 256;
// make sure complement matches
if ((ushort)_blockLength != (ushort)(~blockLengthComplement))
{
throw new InvalidDataException(SR.InvalidBlockLength);
}
}
// Move on to the next header byte or, after the fourth byte, to the state that follows
// UncompressedByte4 (presumably DecodingUncompressed; this relies on the declaration
// order of InflaterState, which is not visible here).
_state += 1;
break;
case InflaterState.DecodingUncompressed: // copying block data
// Directly copy bytes from input to output.
int bytesCopied = _output.CopyFrom(_input, _blockLength);
_blockLength -= bytesCopied;
if (_blockLength == 0)
{
// Done with this block, need to re-init bit buffer for next block
_state = InflaterState.ReadingBFinal;
end_of_block = true;
return true;
}
// We can fail to copy all bytes for two reasons:
// Running out of Input
// running out of free space in output window
if (_output.FreeBytes == 0)
{
// Output window is full: report success so the caller can drain it and call again.
return true;
}
// Otherwise the input ran out.
return false;
default:
Debug.Fail("check why we are here!");
throw new InvalidDataException(SR.UnknownState);
}
}
}
// Decodes literal and length/distance symbols of a compressed (static or dynamic Huffman) block
// into the output window, resuming from _state.
// end_of_block_code_seen is set to true when symbol 256 (end of block) is decoded.
// Returns false when the input runs out in the middle of a symbol (the _state field records
// where to resume); returns true when the end-of-block code is seen or when the free space
// in the output window is no longer above 65536 bytes.
// Throws InvalidDataException for a length symbol outside the supported range.
private bool DecodeBlock(out bool end_of_block_code_seen)
{
end_of_block_code_seen = false;
int freeBytes = _output.FreeBytes; // it is a little bit faster than frequently accessing the property
while (freeBytes > 65536)
{
// With Deflate64 we can have up to a 64kb length, so we ensure at least that much space is available
// in the OutputWindow to avoid overwriting previous unflushed output data.
int symbol;
switch (_state)
{
case InflaterState.DecodeTop:
// decode an element from the literal tree
Debug.Assert(_literalLengthTree != null);
// TODO: optimize this!!!
symbol = _literalLengthTree.GetNextSymbol(_input);
if (symbol < 0)
{
// running out of input
return false;
}
if (symbol < 256)
{
// literal
_output.Write((byte)symbol);
--freeBytes;
}
else if (symbol == 256)
{
// end of block
end_of_block_code_seen = true;
// Reset state
_state = InflaterState.ReadingBFinal;
return true;
}
else
{
// length/distance pair
symbol -= 257; // length code started at 257
if (symbol < 8)
{
symbol += 3; // match length = 3,4,5,6,7,8,9,10
_extraBits = 0;
}
else if (!_deflate64 && symbol == 28)
{
// extra bits for code 285 is 0
symbol = 258; // code 285 means length 258
_extraBits = 0;
}
else
{
// Unsigned compare rejects any index past the end of the table.
if ((uint)symbol >= ExtraLengthBits.Length)
{
throw new InvalidDataException(SR.GenericInvalidData);
}
_extraBits = ExtraLengthBits[symbol];
Debug.Assert(_extraBits != 0, "We handle other cases separately!");
}
// With extra bits pending, _length holds the LengthBase index; otherwise it is already
// the final match length. _state is not updated here: HaveInitialLength records it
// itself before it may have to return for more input.
_length = symbol;
goto case InflaterState.HaveInitialLength;
}
break;
case InflaterState.HaveInitialLength:
if (_extraBits > 0)
{
// Record the state before reading so a shortage of input resumes at this case.
_state = InflaterState.HaveInitialLength;
int bits = _input.GetBits(_extraBits);
if (bits < 0)
{
return false;
}
if (_length < 0 || _length >= LengthBase.Length)
{
throw new InvalidDataException(SR.GenericInvalidData);
}
_length = LengthBase[_length] + bits;
}
_state = InflaterState.HaveFullLength;
goto case InflaterState.HaveFullLength;
case InflaterState.HaveFullLength:
if (_blockType == BlockType.Dynamic)
{
Debug.Assert(_distanceTree != null);
_distanceCode = _distanceTree.GetNextSymbol(_input);
}
else
{
// get distance code directly for static block
_distanceCode = _input.GetBits(5);
if (_distanceCode >= 0)
{
// Translate the 5 raw bits into the distance code.
_distanceCode = StaticDistanceTreeTable[_distanceCode];
}
}
if (_distanceCode < 0)
{
// running out of input
return false;
}
_state = InflaterState.HaveDistCode;
goto case InflaterState.HaveDistCode;
case InflaterState.HaveDistCode:
// To avoid a table lookup we note that for distanceCode > 3,
// extra_bits = (distanceCode-2) >> 1
int offset;
if (_distanceCode > 3)
{
_extraBits = (_distanceCode - 2) >> 1;
int bits = _input.GetBits(_extraBits);
if (bits < 0)
{
// _state is already HaveDistCode, so the next call retries reading the extra bits.
return false;
}
offset = DistanceBasePosition[_distanceCode] + bits;
}
else
{
// Distance codes 0 - 3 have no extra bits and stand for distances 1 - 4.
offset = _distanceCode + 1;
}
_output.WriteLengthDistance(_length, offset);
freeBytes -= _length;
_state = InflaterState.DecodeTop;
break;
default:
Debug.Fail("check why we are here!");
throw new InvalidDataException(SR.UnknownState);
}
}
// Free space is no longer above the threshold; the caller has to drain the output window first.
return true;
}
// Format of the dynamic block header:
// 5 Bits: HLIT, # of Literal/Length codes - 257 (257 - 286)
// 5 Bits: HDIST, # of Distance codes - 1 (1 - 32)
// 4 Bits: HCLEN, # of Code Length codes - 4 (4 - 19)
//
// (HCLEN + 4) x 3 bits: code lengths for the code length
// alphabet given just above, in the order: 16, 17, 18,
// 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
//
// These code lengths are interpreted as 3-bit integers
// (0-7); as above, a code length of 0 means the
// corresponding symbol (literal/length or distance code
// length) is not used.
//
// HLIT + 257 code lengths for the literal/length alphabet,
// encoded using the code length Huffman code
//
// HDIST + 1 code lengths for the distance alphabet,
// encoded using the code length Huffman code
//
// The code length repeat codes can cross from HLIT + 257 to the
// HDIST + 1 code lengths. In other words, all code lengths form
// a single sequence of HLIT + HDIST + 258 values.
// Reads the header of a dynamic Huffman block and builds _literalLengthTree and _distanceTree
// from it, resuming from _state if a previous call ran out of input.
// Returns false when more input is needed; returns true once both trees have been built and
// _state has been set to DecodeTop.
// Throws InvalidDataException when a "repeat previous" code appears first, when a repeat runs
// past the declared number of code lengths, or when the end-of-block symbol has no code.
private bool DecodeDynamicBlockHeader()
{
switch (_state)
{
case InflaterState.ReadingNumLitCodes:
// HLIT: a negative GetBits result means the bits are not available yet.
_literalLengthCodeCount = _input.GetBits(5);
if (_literalLengthCodeCount < 0)
{
return false;
}
_literalLengthCodeCount += 257;
_state = InflaterState.ReadingNumDistCodes;
goto case InflaterState.ReadingNumDistCodes;
case InflaterState.ReadingNumDistCodes:
// HDIST
_distanceCodeCount = _input.GetBits(5);
if (_distanceCodeCount < 0)
{
return false;
}
_distanceCodeCount += 1;
_state = InflaterState.ReadingNumCodeLengthCodes;
goto case InflaterState.ReadingNumCodeLengthCodes;
case InflaterState.ReadingNumCodeLengthCodes:
// HCLEN
_codeLengthCodeCount = _input.GetBits(4);
if (_codeLengthCodeCount < 0)
{
return false;
}
_codeLengthCodeCount += 4;
_loopCounter = 0;
_state = InflaterState.ReadingCodeLengthCodes;
goto case InflaterState.ReadingCodeLengthCodes;
case InflaterState.ReadingCodeLengthCodes:
// _loopCounter is a field, so this loop continues where it stopped after a shortage of input.
while (_loopCounter < _codeLengthCodeCount)
{
int bits = _input.GetBits(3);
if (bits < 0)
{
return false;
}
// Lengths arrive in CodeOrder order; store each one at the index of its symbol.
_codeLengthTreeCodeLength[CodeOrder[_loopCounter]] = (byte)bits;
++_loopCounter;
}
// Symbols not listed in this header get length 0; the array is reused between blocks,
// so stale values from an earlier header must be cleared.
for (int i = _codeLengthCodeCount; i < CodeOrder.Length; i++)
{
_codeLengthTreeCodeLength[CodeOrder[i]] = 0;
}
// create huffman tree for code length
_codeLengthTree = new HuffmanTree(_codeLengthTreeCodeLength);
_codeArraySize = _literalLengthCodeCount + _distanceCodeCount;
_loopCounter = 0; // reset loop count
_state = InflaterState.ReadingTreeCodesBefore;
goto case InflaterState.ReadingTreeCodesBefore;
case InflaterState.ReadingTreeCodesBefore:
case InflaterState.ReadingTreeCodesAfter:
// ReadingTreeCodesBefore: the next code-length symbol still has to be decoded.
// ReadingTreeCodesAfter: a repeat symbol is already in _lengthCode and only its
// extra bits are outstanding.
while (_loopCounter < _codeArraySize)
{
if (_state == InflaterState.ReadingTreeCodesBefore)
{
Debug.Assert(_codeLengthTree != null);
if ((_lengthCode = _codeLengthTree.GetNextSymbol(_input)) < 0)
{
return false;
}
}
// The alphabet for code lengths is as follows:
// 0 - 15: Represent code lengths of 0 - 15
// 16: Copy the previous code length 3 - 6 times.
// The next 2 bits indicate repeat length
// (0 = 3, ... , 3 = 6)
// Example: Codes 8, 16 (+2 bits 11),
// 16 (+2 bits 10) will expand to
// 12 code lengths of 8 (1 + 6 + 5)
// 17: Repeat a code length of 0 for 3 - 10 times.
// (3 bits of length)
// 18: Repeat a code length of 0 for 11 - 138 times
// (7 bits of length)
if (_lengthCode <= 15)
{
_codeList[_loopCounter++] = (byte)_lengthCode;
}
else
{
int repeatCount;
if (_lengthCode == 16)
{
if (!_input.EnsureBitsAvailable(2))
{
// Remember that the symbol has been consumed and only the extra bits are missing.
_state = InflaterState.ReadingTreeCodesAfter;
return false;
}
if (_loopCounter == 0)
{
// can't have "prev code" on first code
throw new InvalidDataException();
}
byte previousCode = _codeList[_loopCounter - 1];
repeatCount = _input.GetBits(2) + 3;
// A repeat must not run past the declared number of code lengths.
if (_loopCounter + repeatCount > _codeArraySize)
{
throw new InvalidDataException();
}
for (int j = 0; j < repeatCount; j++)
{
_codeList[_loopCounter++] = previousCode;
}
}
else if (_lengthCode == 17)
{
if (!_input.EnsureBitsAvailable(3))
{
_state = InflaterState.ReadingTreeCodesAfter;
return false;
}
repeatCount = _input.GetBits(3) + 3;
if (_loopCounter + repeatCount > _codeArraySize)
{
throw new InvalidDataException();
}
for (int j = 0; j < repeatCount; j++)
{
_codeList[_loopCounter++] = 0;
}
}
else
{
// code == 18
if (!_input.EnsureBitsAvailable(7))
{
_state = InflaterState.ReadingTreeCodesAfter;
return false;
}
repeatCount = _input.GetBits(7) + 11;
if (_loopCounter + repeatCount > _codeArraySize)
{
throw new InvalidDataException();
}
for (int j = 0; j < repeatCount; j++)
{
_codeList[_loopCounter++] = 0;
}
}
}
_state = InflaterState.ReadingTreeCodesBefore; // we want to read the next code.
}
break;
default:
Debug.Fail("check why we are here!");
throw new InvalidDataException(SR.UnknownState);
}
// All code lengths have been read. Split _codeList into separate arrays of the maximum
// alphabet size; entries beyond the declared counts keep the default length 0.
byte[] literalTreeCodeLength = new byte[HuffmanTree.MaxLiteralTreeElements];
byte[] distanceTreeCodeLength = new byte[HuffmanTree.MaxDistTreeElements];
// Create literal and distance tables
Array.Copy(_codeList, literalTreeCodeLength, _literalLengthCodeCount);
Array.Copy(_codeList, _literalLengthCodeCount, distanceTreeCodeLength, 0, _distanceCodeCount);
// Make sure there is an end-of-block code, otherwise how could we ever end?
if (literalTreeCodeLength[HuffmanTree.EndOfBlockCode] == 0)
{
throw new InvalidDataException();
}
_literalLengthTree = new HuffmanTree(literalTreeCodeLength);
_distanceTree = new HuffmanTree(distanceTreeCodeLength);
_state = InflaterState.DecodeTop;
return true;
}
}
}
|