|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Runtime.Serialization; // For SR
using System.Text;
namespace System.Xml
{
    // Stream adapter used by the XML readers/writers.
    //  - Reading mode: detects whether the inner stream holds UTF-8, UTF-16LE or UTF-16BE
    //    (from the byte order mark / first bytes), validates it against the optional
    //    out-of-band encoding and the XML declaration, and exposes the content as UTF-8 bytes.
    //  - Writing mode: accepts UTF-8 bytes and writes them to the inner stream in the
    //    requested encoding (optionally preceded by that encoding's preamble).
    //
    // This wrapper does not support seek.
    // Constructors consume/emit byte order mark.
    // Supports: UTF-8, Unicode, BigEndianUnicode
    // ASSUMPTION (Microsoft): This class will only be used for EITHER reading OR writing. Supporting both would be possible but would require more buffers.
    // ASSUMPTION (Microsoft): The byte buffer is large enough to hold the declaration
    // ASSUMPTION (Microsoft): The buffer manipulation methods (FillBuffer/Compare/etc.) will only be used to parse the declaration
    // during construction.
    internal sealed class EncodingStreamWrapper : Stream
    {
        private enum SupportedEncoding
        {
            UTF8,
            UTF16LE,
            UTF16BE,
            None
        }

        // Length of the char buffer; the byte buffer is four times this size. Also the maximum
        // number of declaration bytes examined by CheckUTF8DeclarationEncoding.
        private const int BufferLength = 128;

        // Encoding names are stored as UTF-8 bytes: declarations of UTF-16 documents are
        // transcoded to UTF-8 before they are compared against these.
        private static readonly byte[] s_encodingUTF8 = "utf-8"u8.ToArray();
        private static readonly byte[] s_encodingUnicode = "utf-16"u8.ToArray();
        private static readonly byte[] s_encodingUnicodeLE = "utf-16le"u8.ToArray();
        private static readonly byte[] s_encodingUnicodeBE = "utf-16be"u8.ToArray();

        // The reading constructor only assigns this for non-UTF-8 documents; UTF-8 documents rely
        // on the default value, which is SupportedEncoding.UTF8 (the first enum member).
        private SupportedEncoding _encodingCode;
        private Encoding? _encoding;
        private readonly Encoder? _enc;     // writing mode (non-UTF-8) only: chars -> target encoding
        private readonly Decoder? _dec;     // writing mode (non-UTF-8) only: incoming UTF-8 bytes -> chars
        private readonly bool _isReading;
        private readonly Stream _stream;
        private char[]? _chars;
        private byte[]? _bytes;
        private int _byteOffset;            // start of the pending (not yet returned) bytes in _bytes
        private int _byteCount;             // number of pending bytes in _bytes
        private readonly byte[] _byteBuffer = new byte[1];  // scratch buffer for ReadByte/WriteByte

        // Reading constructor.
        // 'encoding' is the optional out-of-band encoding; when null the encoding is taken from the
        // document itself and UTF-16 documents must carry an XML declaration.
        // Throws XmlException when the encoding is unsupported, the detected/declared/expected
        // encodings disagree, a required declaration is missing, the stream ends prematurely,
        // or the content contains invalid bytes.
        public EncodingStreamWrapper(Stream stream, Encoding? encoding)
        {
            try
            {
                _isReading = true;
                _stream = stream;

                // Decode the expected encoding
                SupportedEncoding expectedEnc = GetSupportedEncoding(encoding);

                // Get the byte order mark so we can determine the encoding
                // May want to try to delay allocating everything until we know the BOM
                SupportedEncoding declEnc = ReadBOMEncoding(encoding == null);

                // Check that the expected encoding matches the decl encoding.
                if (expectedEnc != SupportedEncoding.None && expectedEnc != declEnc)
                    ThrowExpectedEncodingMismatch(expectedEnc, declEnc);

                // Fastpath: UTF-8 (with or without BOM). Bytes are passed through unchanged;
                // only the declaration, if present, is validated.
                if (declEnc == SupportedEncoding.UTF8)
                {
                    // Fastpath: UTF-8, No declaration
                    FillBuffer(2);
                    if (_bytes[_byteOffset + 1] != '?' || _bytes[_byteOffset] != '<')
                    {
                        return;
                    }

                    FillBuffer(BufferLength);
                    CheckUTF8DeclarationEncoding(_bytes, _byteOffset, _byteCount, declEnc, expectedEnc);
                }
                else
                {
                    // Convert the first chunk to UTF-8
                    EnsureBuffers();
                    FillBuffer((BufferLength - 1) * 2);
                    SetReadDocumentEncoding(declEnc);
                    CleanupCharBreak();
                    int count = _encoding.GetChars(_bytes, _byteOffset, _byteCount, _chars, 0);
                    _byteOffset = 0;
                    _byteCount = DataContractSerializer.ValidatingUTF8.GetBytes(_chars, 0, count, _bytes, 0);

                    // Check for declaration
                    if (_bytes[1] == '?' && _bytes[0] == '<')
                    {
                        CheckUTF8DeclarationEncoding(_bytes, 0, _byteCount, declEnc, expectedEnc);
                    }
                    else
                    {
                        // Declaration required if no out-of-band encoding
                        if (expectedEnc == SupportedEncoding.None)
                            throw new XmlException(SR.XmlDeclarationRequired);
                    }
                }
            }
            catch (DecoderFallbackException ex)
            {
                throw new XmlException(SR.XmlInvalidBytes, ex);
            }
        }

        // Records the document encoding for reading mode and makes sure the buffers exist.
        [MemberNotNull(nameof(_encoding))]
        private void SetReadDocumentEncoding(SupportedEncoding e)
        {
            EnsureBuffers();
            _encodingCode = e;
            _encoding = GetEncoding(e);
        }

        // Validating encoding instance for 'e'; throws XmlException for SupportedEncoding.None.
        private static Encoding GetEncoding(SupportedEncoding e) =>
            e switch
            {
                SupportedEncoding.UTF8 => DataContractSerializer.ValidatingUTF8,
                SupportedEncoding.UTF16LE => DataContractSerializer.ValidatingUTF16,
                SupportedEncoding.UTF16BE => DataContractSerializer.ValidatingBEUTF16,
                _ => throw new XmlException(SR.XmlEncodingNotSupported),
            };

        // The *NoBom encoding instance for 'e' (used to sniff the declaration in ProcessBuffer);
        // throws XmlException for SupportedEncoding.None.
        private static Encoding GetSafeEncoding(SupportedEncoding e) =>
            e switch
            {
                SupportedEncoding.UTF8 => DataContractSerializer.UTF8NoBom,
                SupportedEncoding.UTF16LE => DataContractSerializer.UTF16NoBom,
                SupportedEncoding.UTF16BE => DataContractSerializer.BEUTF16NoBom,
                _ => throw new XmlException(SR.XmlEncodingNotSupported),
            };

        // Name used in error messages; throws XmlException for SupportedEncoding.None.
        private static string GetEncodingName(SupportedEncoding enc) =>
            enc switch
            {
                SupportedEncoding.UTF8 => "utf-8",
                SupportedEncoding.UTF16LE => "utf-16LE",
                SupportedEncoding.UTF16BE => "utf-16BE",
                _ => throw new XmlException(SR.XmlEncodingNotSupported),
            };

        // Maps an Encoding (by WebName) to SupportedEncoding. null maps to None; any encoding
        // other than UTF-8/UTF-16LE/UTF-16BE throws XmlException.
        private static SupportedEncoding GetSupportedEncoding(Encoding? encoding)
        {
            if (encoding == null)
                return SupportedEncoding.None;
            else if (encoding.WebName == DataContractSerializer.ValidatingUTF8.WebName)
                return SupportedEncoding.UTF8;
            else if (encoding.WebName == DataContractSerializer.ValidatingUTF16.WebName)
                return SupportedEncoding.UTF16LE;
            else if (encoding.WebName == DataContractSerializer.ValidatingBEUTF16.WebName)
                return SupportedEncoding.UTF16BE;
            else
                throw new XmlException(SR.XmlEncodingNotSupported);
        }

        // Writing constructor.
        // UTF-8 output is passed straight through (no BOM is ever written for UTF-8). For UTF-16
        // output, incoming UTF-8 bytes are transcoded and, when 'emitBOM' is true, the encoding's
        // preamble is written first.
        public EncodingStreamWrapper(Stream stream, Encoding encoding, bool emitBOM)
        {
            _isReading = false;
            _encoding = encoding;
            _stream = stream;

            // Set the encoding code
            _encodingCode = GetSupportedEncoding(encoding);

            if (_encodingCode != SupportedEncoding.UTF8)
            {
                EnsureBuffers();
                _dec = DataContractSerializer.ValidatingUTF8.GetDecoder();
                _enc = _encoding.GetEncoder();

                // Emit BOM
                if (emitBOM)
                {
                    ReadOnlySpan<byte> bom = _encoding.Preamble;
                    if (bom.Length > 0)
                        _stream.Write(bom);
                }
            }
        }

        // Reads the first four bytes of the stream, determines the encoding from them and stores
        // the non-BOM part of those bytes at the start of _bytes (with _byteCount set accordingly).
        // Throws XmlException if the stream holds fewer than four bytes.
        [MemberNotNull(nameof(_bytes))]
        private SupportedEncoding ReadBOMEncoding(bool notOutOfBand)
        {
            int b1 = _stream.ReadByte();
            int b2 = _stream.ReadByte();
            int b3 = _stream.ReadByte();
            int b4 = _stream.ReadByte();

            // Premature end of stream
            if (b4 == -1)
                throw new XmlException(SR.UnexpectedEndOfFile);

            int preserve;
            SupportedEncoding e = ReadBOMEncoding((byte)b1, (byte)b2, (byte)b3, (byte)b4, notOutOfBand, out preserve);

            EnsureByteBuffer();
            switch (preserve)
            {
                case 1:
                    _bytes[0] = (byte)b4;
                    break;

                case 2:
                    _bytes[0] = (byte)b3;
                    _bytes[1] = (byte)b4;
                    break;

                case 4:
                    _bytes[0] = (byte)b1;
                    _bytes[1] = (byte)b2;
                    _bytes[2] = (byte)b3;
                    _bytes[3] = (byte)b4;
                    break;
            }

            _byteCount = preserve;

            return e;
        }

        // Classifies the first four bytes of a document. On return 'preserve' holds how many of
        // the trailing bytes among b1..b4 are document content (the rest is BOM: 4 = no BOM,
        // 2 = UTF-16 BOM, 1 = UTF-8 BOM). When 'notOutOfBand' is true (no caller-supplied
        // encoding), BOM-less UTF-16 must start with "<?" and a UTF-8 BOM must be complete,
        // otherwise XmlException is thrown. Unrecognized patterns are treated as UTF-8.
        private static SupportedEncoding ReadBOMEncoding(byte b1, byte b2, byte b3, byte b4, bool notOutOfBand, out int preserve)
        {
            SupportedEncoding e = SupportedEncoding.UTF8; // Default

            preserve = 0;
            if (b1 == '<' && b2 != 0x00) // UTF-8, no BOM
            {
                e = SupportedEncoding.UTF8;
                preserve = 4;
            }
            else if (b1 == 0xFF && b2 == 0xFE) // UTF-16 little endian
            {
                e = SupportedEncoding.UTF16LE;
                preserve = 2;
            }
            else if (b1 == 0xFE && b2 == 0xFF) // UTF-16 big endian
            {
                e = SupportedEncoding.UTF16BE;
                preserve = 2;
            }
            else if (b1 == 0x00 && b2 == '<') // UTF-16 big endian, no BOM
            {
                e = SupportedEncoding.UTF16BE;

                if (notOutOfBand && (b3 != 0x00 || b4 != '?'))
                    throw new XmlException(SR.XmlDeclMissing);
                preserve = 4;
            }
            else if (b1 == '<' && b2 == 0x00) // UTF-16 little endian, no BOM
            {
                e = SupportedEncoding.UTF16LE;

                if (notOutOfBand && (b3 != '?' || b4 != 0x00))
                    throw new XmlException(SR.XmlDeclMissing);
                preserve = 4;
            }
            else if (b1 == 0xEF && b2 == 0xBB) // UTF8 with BOM
            {
                // Encoding error
                if (notOutOfBand && b3 != 0xBF)
                    throw new XmlException(SR.XmlBadBOM);
                preserve = 1;
            }
            else  // Assume UTF8
            {
                preserve = 4;
            }

            return e;
        }

        // Reads from the inner stream until _byteCount reaches 'count' or the stream ends.
        private void FillBuffer(int count)
        {
            count -= _byteCount;
            if (count > 0)
            {
                _byteCount += _stream.ReadAtLeast(_bytes.AsSpan(_byteOffset + _byteCount, count), count, throwOnEndOfStream: false);
            }
        }

        [MemberNotNull(nameof(_bytes))]
        [MemberNotNull(nameof(_chars))]
        private void EnsureBuffers()
        {
            EnsureByteBuffer();
            _chars ??= new char[BufferLength];
        }

        // Allocates _bytes (BufferLength * 4 bytes) on first use and resets the pending range.
        [MemberNotNull(nameof(_bytes))]
        private void EnsureByteBuffer()
        {
            if (_bytes != null)
                return;

            _bytes = new byte[BufferLength * 4];
            _byteOffset = 0;
            _byteCount = 0;
        }

        // Validates the encoding attribute of a UTF-8 (or already transcoded) XML declaration.
        // Callers only invoke this when buffer[offset..] starts with "<?". At most BufferLength
        // bytes starting at 'offset' are examined, and never more than 'count'.
        // 'e' is the encoding detected from the BOM/first bytes; 'expectedEnc' is the
        // out-of-band encoding (None if the caller supplied none).
        // Throws XmlException when the declaration is malformed or truncated, names an encoding
        // that disagrees with 'e', or is required but has no encoding attribute.
        private static void CheckUTF8DeclarationEncoding(byte[] buffer, int offset, int count, SupportedEncoding e, SupportedEncoding expectedEnc)
        {
            byte quot = 0;
            int encEq = -1;
            int max = offset + Math.Min(count, BufferLength);

            // Encoding should be second "=", abort at first "?"
            int i;
            int eq = 0;
            for (i = offset + 2; i < max; i++) // Skip the "<?" so we don't get caught by the first "?"
            {
                if (quot != 0)
                {
                    if (buffer[i] == quot)
                    {
                        quot = 0;
                    }
                    continue;
                }

                if (buffer[i] == (byte)'\'' || buffer[i] == (byte)'"')
                {
                    quot = buffer[i];
                }
                else if (buffer[i] == (byte)'=')
                {
                    if (eq == 1)
                    {
                        encEq = i;
                        break;
                    }
                    eq++;
                }
                else if (buffer[i] == (byte)'?') // Not legal character in a decl before second "="
                {
                    break;
                }
            }

            // No encoding found
            if (encEq == -1)
            {
                if (e != SupportedEncoding.UTF8 && expectedEnc == SupportedEncoding.None)
                    throw new XmlException(SR.XmlDeclarationRequired);
                return;
            }

            // The shortest legal prefix, <?xml version="1.0" encoding=, puts the second "=" at
            // position 28 relative to the start of the declaration (not of the buffer).
            if (encEq - offset < 28)
                throw new XmlException(SR.XmlMalformedDecl);

            // Back off whitespace; stops at the latest on the '?' of the leading "<?".
            for (i = encEq - 1; IsWhitespace(buffer[i]); i--) ;

            // Check for encoding attribute
            if (!buffer.AsSpan(offset, i + 1 - offset).EndsWith("encoding"u8))
            {
                if (e != SupportedEncoding.UTF8 && expectedEnc == SupportedEncoding.None)
                    throw new XmlException(SR.XmlDeclarationRequired);
                return;
            }

            // Move ahead of whitespace
            for (i = encEq + 1; i < max && IsWhitespace(buffer[i]); i++) ;

            // Find the opening quote. Running off the examined range means the declaration is
            // truncated; never index past 'max' (it may equal buffer.Length).
            if (i >= max || (buffer[i] != '\'' && buffer[i] != '"'))
                throw new XmlException(SR.XmlMalformedDecl);
            quot = buffer[i];

            // Find the closing quote, again bounded by 'max' before touching the buffer.
            int q = i;
            for (i = q + 1; i < max && buffer[i] != quot; ++i) ;

            if (i >= max)
                throw new XmlException(SR.XmlMalformedDecl);

            int encStart = q + 1;
            int encCount = i - encStart;

            // lookup the encoding
            SupportedEncoding declEnc = e;
            if (encCount == s_encodingUTF8.Length && CompareCaseInsensitive(s_encodingUTF8, buffer, encStart))
            {
                declEnc = SupportedEncoding.UTF8;
            }
            else if (encCount == s_encodingUnicodeLE.Length && CompareCaseInsensitive(s_encodingUnicodeLE, buffer, encStart))
            {
                declEnc = SupportedEncoding.UTF16LE;
            }
            else if (encCount == s_encodingUnicodeBE.Length && CompareCaseInsensitive(s_encodingUnicodeBE, buffer, encStart))
            {
                declEnc = SupportedEncoding.UTF16BE;
            }
            else if (encCount == s_encodingUnicode.Length && CompareCaseInsensitive(s_encodingUnicode, buffer, encStart))
            {
                // "utf-16" without byte order is accepted for either UTF-16 variant, but not for UTF-8.
                if (e == SupportedEncoding.UTF8)
                    ThrowEncodingMismatch(DataContractSerializer.UTF8NoBom.GetString(buffer, encStart, encCount), DataContractSerializer.UTF8NoBom.GetString(s_encodingUTF8, 0, s_encodingUTF8.Length));
            }
            else
            {
                ThrowEncodingMismatch(DataContractSerializer.UTF8NoBom.GetString(buffer, encStart, encCount), e);
            }

            if (e != declEnc)
                ThrowEncodingMismatch(DataContractSerializer.UTF8NoBom.GetString(buffer, encStart, encCount), e);
        }

        // True when buffer[offset..offset + key.Length) equals 'key' after lower-casing each buffer
        // byte (interpreted as a Latin-1 char). 'key' is expected to be lower case.
        private static bool CompareCaseInsensitive(byte[] key, byte[] buffer, int offset)
        {
            for (int i = 0; i < key.Length; i++)
            {
                if (key[i] == buffer[offset + i])
                    continue;

                if (key[i] != char.ToLowerInvariant((char)buffer[offset + i]))
                    return false;
            }
            return true;
        }

        // XML declaration whitespace: space, LF, tab, CR.
        private static bool IsWhitespace(byte ch)
        {
            return ch == (byte)' ' || ch == (byte)'\n' || ch == (byte)'\t' || ch == (byte)'\r';
        }

        // In-memory counterpart of the reading constructor: validates the encoding of
        // buffer[offset..offset + count) and returns the content as UTF-8 with any BOM removed.
        // For UTF-8 input the returned segment refers to 'buffer' itself; UTF-16 input is
        // transcoded into a new array.
        // Throws XmlException for fewer than four bytes, encoding problems or invalid bytes.
        internal static ArraySegment<byte> ProcessBuffer(byte[] buffer, int offset, int count, Encoding? encoding)
        {
            if (count < 4)
                throw new XmlException(SR.UnexpectedEndOfFile);

            try
            {
                int preserve;
                ArraySegment<byte> seg;

                SupportedEncoding expectedEnc = GetSupportedEncoding(encoding);
                SupportedEncoding declEnc = ReadBOMEncoding(buffer[offset], buffer[offset + 1], buffer[offset + 2], buffer[offset + 3], encoding == null, out preserve);
                if (expectedEnc != SupportedEncoding.None && expectedEnc != declEnc)
                    ThrowExpectedEncodingMismatch(expectedEnc, declEnc);

                // Skip the BOM bytes
                offset += 4 - preserve;
                count -= 4 - preserve;

                // Fastpath: UTF-8
                char[] chars;
                byte[] bytes;
                Encoding localEnc;
                if (declEnc == SupportedEncoding.UTF8)
                {
                    // Fastpath: No declaration. After a UTF-8 BOM only one byte may remain, so
                    // check the length before looking at buffer[offset + 1].
                    if (count < 2 || buffer[offset + 1] != '?' || buffer[offset] != '<')
                    {
                        seg = new ArraySegment<byte>(buffer, offset, count);
                        return seg;
                    }

                    CheckUTF8DeclarationEncoding(buffer, offset, count, declEnc, expectedEnc);
                    seg = new ArraySegment<byte>(buffer, offset, count);
                    return seg;
                }

                // Convert the start of the document to UTF-8 to inspect the declaration
                localEnc = GetSafeEncoding(declEnc);
                int inputCount = Math.Min(count, BufferLength * 2);
                chars = new char[localEnc.GetMaxCharCount(inputCount)];
                int ccount = localEnc.GetChars(buffer, offset, inputCount, chars, 0);
                bytes = new byte[DataContractSerializer.ValidatingUTF8.GetMaxByteCount(ccount)];
                int bcount = DataContractSerializer.ValidatingUTF8.GetBytes(chars, 0, ccount, bytes, 0);

                // Check for declaration
                if (bytes[1] == '?' && bytes[0] == '<')
                {
                    CheckUTF8DeclarationEncoding(bytes, 0, bcount, declEnc, expectedEnc);
                }
                else
                {
                    // Declaration required if no out-of-band encoding
                    if (expectedEnc == SupportedEncoding.None)
                        throw new XmlException(SR.XmlDeclarationRequired);
                }

                // Transcode the whole document with the validating encodings
                seg = new ArraySegment<byte>(DataContractSerializer.ValidatingUTF8.GetBytes(GetEncoding(declEnc).GetChars(buffer, offset, count)));
                return seg;
            }
            catch (DecoderFallbackException e)
            {
                throw new XmlException(SR.XmlInvalidBytes, e);
            }
        }

        private static void ThrowExpectedEncodingMismatch(SupportedEncoding expEnc, SupportedEncoding actualEnc)
        {
            throw new XmlException(SR.Format(SR.XmlExpectedEncoding, GetEncodingName(expEnc), GetEncodingName(actualEnc)));
        }

        private static void ThrowEncodingMismatch(string declEnc, SupportedEncoding enc)
        {
            ThrowEncodingMismatch(declEnc, GetEncodingName(enc));
        }

        private static void ThrowEncodingMismatch(string declEnc, string docEnc)
        {
            throw new XmlException(SR.Format(SR.XmlEncodingMismatch, declEnc, docEnc));
        }

        // This stream wrapper does not support duplex
        public override bool CanRead
        {
            get
            {
                if (!_isReading)
                    return false;

                return _stream.CanRead;
            }
        }

        // The encoding conversion and buffering breaks seeking.
        public override bool CanSeek
        {
            get
            {
                return false;
            }
        }

        // This stream wrapper does not support duplex
        public override bool CanWrite
        {
            get
            {
                if (_isReading)
                    return false;

                return _stream.CanWrite;
            }
        }

        // The encoding conversion and buffering breaks seeking.
        public override long Position
        {
            get
            {
                throw new NotSupportedException();
            }
            set
            {
                throw new NotSupportedException();
            }
        }

        // Flushes the inner stream (if writable) and then disposes it.
        public override void Close()
        {
            if (_stream.CanWrite)
            {
                Flush();
            }
            base.Close();
            _stream.Dispose();
        }

        // Flushes the inner stream only. Any partial UTF-8 sequence held by the writing-mode
        // decoder is not flushed here.
        public override void Flush()
        {
            _stream.Flush();
        }

        public override int ReadByte()
        {
            if (_byteCount == 0 && _encodingCode == SupportedEncoding.UTF8)
                return _stream.ReadByte();
            if (Read(_byteBuffer, 0, 1) == 0)
                return -1;
            return _byteBuffer[0];
        }

        public override int Read(byte[] buffer, int offset, int count) =>
            Read(new Span<byte>(buffer, offset, count));

        // Returns document content as UTF-8. Bytes buffered during construction are returned
        // first; after that UTF-8 documents are read straight from the inner stream, while
        // UTF-16 documents are read in chunks and transcoded.
        public override int Read(Span<byte> buffer)
        {
            try
            {
                if (_byteCount == 0)
                {
                    if (_encodingCode == SupportedEncoding.UTF8)
                        return _stream.Read(buffer);

                    Debug.Assert(_bytes != null);
                    Debug.Assert(_chars != null);

                    // No more bytes than can be turned into characters. CleanupCharBreak may
                    // append up to three more bytes, which still fits in _chars.
                    _byteOffset = 0;
                    _byteCount = _stream.Read(_bytes, 0, (_chars.Length - 1) * 2);

                    // Check for end of stream
                    if (_byteCount == 0)
                        return 0;

                    // Fix up incomplete chars
                    CleanupCharBreak();

                    // Change encoding
                    int charCount = _encoding!.GetChars(_bytes, 0, _byteCount, _chars, 0);
                    _byteCount = Encoding.UTF8.GetBytes(_chars, 0, charCount, _bytes, 0);
                }

                // Give them bytes
                int count = Math.Min(buffer.Length, _byteCount);
                _bytes.AsSpan(_byteOffset, count).CopyTo(buffer);
                _byteOffset += count;
                _byteCount -= count;
                return count;
            }
            catch (DecoderFallbackException ex)
            {
                throw new XmlException(SR.XmlInvalidBytes, ex);
            }
        }

        // Makes the pending UTF-16 bytes in _bytes safe to decode on their own. It reads one more
        // byte if the count is odd, and reads the next code unit if the last one is a high
        // surrogate. Throws XmlException if the inner stream ends while doing so.
        private void CleanupCharBreak()
        {
            Debug.Assert(_bytes != null);
            int max = _byteOffset + _byteCount;

            // Read on 2 byte boundaries
            if ((_byteCount % 2) != 0)
            {
                int b = _stream.ReadByte();
                if (b < 0)
                    throw new XmlException(SR.UnexpectedEndOfFile);

                _bytes[max++] = (byte)b;
                _byteCount++;
            }

            // Don't cut off a surrogate character
            int w;
            if (_encodingCode == SupportedEncoding.UTF16LE)
            {
                w = _bytes[max - 2] + (_bytes[max - 1] << 8);
            }
            else
            {
                w = _bytes[max - 1] + (_bytes[max - 2] << 8);
            }

            if (char.IsHighSurrogate((char)w)) // First 16-bit number of surrogate pair
            {
                int b1 = _stream.ReadByte();
                int b2 = _stream.ReadByte();
                if (b1 < 0 || b2 < 0)
                    throw new XmlException(SR.UnexpectedEndOfFile);

                _bytes[max++] = (byte)b1;
                _bytes[max++] = (byte)b2;
                _byteCount += 2;
            }
        }

        public override long Seek(long offset, SeekOrigin origin)
        {
            throw new NotSupportedException();
        }

        public override void WriteByte(byte b)
        {
            if (_encodingCode == SupportedEncoding.UTF8)
            {
                _stream.WriteByte(b);
                return;
            }
            _byteBuffer[0] = b;
            Write(_byteBuffer, 0, 1);
        }

        // Accepts UTF-8 bytes. For UTF-8 output they are written through unchanged; otherwise
        // they are decoded (carrying partial sequences across calls) and re-encoded into the
        // target encoding.
        public override void Write(byte[] buffer, int offset, int count)
        {
            // Optimize UTF-8 case
            if (_encodingCode == SupportedEncoding.UTF8)
            {
                _stream.Write(buffer, offset, count);
                return;
            }

            Debug.Assert(_bytes != null);
            Debug.Assert(_chars != null);

            // A chunk of N bytes can decode to N + 1 chars. This happens when the decoder still
            // holds the first three bytes of a four-byte UTF-8 sequence from the previous chunk
            // and one new byte completes it as a surrogate pair. Feed at most _chars.Length - 1
            // bytes so the output always fits in _chars.
            int maxChunk = _chars.Length - 1;
            while (count > 0)
            {
                int size = Math.Min(count, maxChunk);
                int charCount = _dec!.GetChars(buffer, offset, size, _chars, 0, false);
                _byteCount = _enc!.GetBytes(_chars, 0, charCount, _bytes, 0, false);
                _stream.Write(_bytes, 0, _byteCount);
                offset += size;
                count -= size;
            }
        }

        // Delegate properties
        public override bool CanTimeout { get { return _stream.CanTimeout; } }
        public override long Length { get { return _stream.Length; } }
        public override int ReadTimeout
        {
            get { return _stream.ReadTimeout; }
            set { _stream.ReadTimeout = value; }
        }
        public override int WriteTimeout
        {
            get { return _stream.WriteTimeout; }
            set { _stream.WriteTimeout = value; }
        }

        // Delegate methods
        public override void SetLength(long value)
        {
            throw new NotSupportedException();
        }
    }
}
|