File: System\Xml\XmlStreamNodeWriter.cs
Web Access
Project: src\src\libraries\System.Private.DataContractSerialization\src\System.Private.DataContractSerialization.csproj (System.Private.DataContractSerialization)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers.Binary;
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Serialization;
using System.Text;
using System.Threading.Tasks;
 
namespace System.Xml
{
    internal abstract class XmlStreamNodeWriter : XmlNodeWriter
    {
        private readonly byte[] _buffer;
        private int _offset;
        private bool _ownsStream;
        private const int bufferLength = 512;
        private const int maxBytesPerChar = 3;
        private Encoding? _encoding;
 
        protected XmlStreamNodeWriter()
        {
            _buffer = new byte[bufferLength];
            OutputStream = null!; // Always initialized by SetOutput()
        }
 
        protected void SetOutput(Stream stream, bool ownsStream, Encoding? encoding)
        {
            OutputStream = stream;
            _ownsStream = ownsStream;
            _offset = 0;
 
            if (encoding != null)
                _encoding = encoding;
        }
 
        // Getting/Setting the Stream exists for fragmenting
        public Stream OutputStream { get; set; }
 
        // StreamBuffer/BufferOffset exists only for the BinaryWriter to fix up nodes
        public byte[] StreamBuffer
        {
            get
            {
                return _buffer;
            }
        }
        public int BufferOffset
        {
            get
            {
                return _offset;
            }
        }
 
        public int Position
        {
            get
            {
                return (int)OutputStream.Position + _offset;
            }
        }
 
        protected byte[] GetBuffer(int count, out int offset)
        {
            Debug.Assert(count >= 0 && count <= bufferLength);
            int bufferOffset = _offset;
            if (bufferOffset + count <= bufferLength)
            {
                offset = bufferOffset;
            }
            else
            {
                FlushBuffer();
                offset = 0;
            }
#if DEBUG
            Debug.Assert(offset + count <= bufferLength);
            for (int i = 0; i < count; i++)
            {
                _buffer[offset + i] = (byte)'<';
            }
#endif
            return _buffer;
        }
 
        protected async Task<BytesWithOffset> GetBufferAsync(int count)
        {
            int offset;
            Debug.Assert(count >= 0 && count <= bufferLength);
            int bufferOffset = _offset;
            if (bufferOffset + count <= bufferLength)
            {
                offset = bufferOffset;
            }
            else
            {
                await FlushBufferAsync().ConfigureAwait(false);
                offset = 0;
            }
#if DEBUG
            Debug.Assert(offset + count <= bufferLength);
            for (int i = 0; i < count; i++)
            {
                _buffer[offset + i] = (byte)'<';
            }
#endif
            return new BytesWithOffset(_buffer, offset);
        }
 
        protected void Advance(int count)
        {
            Debug.Assert(_offset + count <= bufferLength);
            _offset += count;
        }
 
        private void EnsureByte()
        {
            if (_offset >= bufferLength)
            {
                FlushBuffer();
            }
        }
 
        protected void WriteByte(byte b)
        {
            EnsureByte();
            _buffer[_offset++] = b;
        }
 
        protected Task WriteByteAsync(byte b)
        {
            if (_offset >= bufferLength)
            {
                return FlushBufferAndWriteByteAsync(b);
            }
            else
            {
                _buffer[_offset++] = b;
                return Task.CompletedTask;
            }
        }
 
        private async Task FlushBufferAndWriteByteAsync(byte b)
        {
            await FlushBufferAsync().ConfigureAwait(false);
            _buffer[_offset++] = b;
        }
 
        protected void WriteByte(char ch)
        {
            Debug.Assert(ch < 0x80);
            WriteByte((byte)ch);
        }
 
        protected Task WriteByteAsync(char ch)
        {
            Debug.Assert(ch < 0x80);
            return WriteByteAsync((byte)ch);
        }
 
        protected void WriteBytes(byte b1, byte b2)
        {
            byte[] buffer = _buffer;
            int offset = _offset;
            if (offset + 1 >= bufferLength)
            {
                FlushBuffer();
                offset = 0;
            }
            buffer[offset + 0] = b1;
            buffer[offset + 1] = b2;
            _offset += 2;
        }
 
        protected Task WriteBytesAsync(byte b1, byte b2)
        {
            if (_offset + 1 >= bufferLength)
            {
                return FlushAndWriteBytesAsync(b1, b2);
            }
            else
            {
                _buffer[_offset++] = b1;
                _buffer[_offset++] = b2;
                return Task.CompletedTask;
            }
        }
 
        private async Task FlushAndWriteBytesAsync(byte b1, byte b2)
        {
            await FlushBufferAsync().ConfigureAwait(false);
            _buffer[_offset++] = b1;
            _buffer[_offset++] = b2;
        }
 
        protected void WriteBytes(char ch1, char ch2)
        {
            Debug.Assert(ch1 < 0x80 && ch2 < 0x80);
            WriteBytes((byte)ch1, (byte)ch2);
        }
 
        protected Task WriteBytesAsync(char ch1, char ch2)
        {
            Debug.Assert(ch1 < 0x80 && ch2 < 0x80);
            return WriteBytesAsync((byte)ch1, (byte)ch2);
        }
 
        public void WriteBytes(byte[] byteBuffer, int byteOffset, int byteCount)
        {
            if (byteCount < bufferLength)
            {
                int offset;
                byte[] buffer = GetBuffer(byteCount, out offset);
                Buffer.BlockCopy(byteBuffer, byteOffset, buffer, offset, byteCount);
                Advance(byteCount);
            }
            else
            {
                FlushBuffer();
                OutputStream.Write(byteBuffer, byteOffset, byteCount);
            }
        }
 
        protected void WriteBytes(ReadOnlySpan<byte> bytes)
        {
            if (bytes.Length < bufferLength)
            {
                var buffer = GetBuffer(bytes.Length, out int offset).AsSpan(offset, bytes.Length);
                bytes.CopyTo(buffer);
                Advance(bytes.Length);
            }
            else
            {
                FlushBuffer();
                OutputStream.Write(bytes);
            }
        }
 
        protected unsafe void WriteUTF8Char(int ch)
        {
            if (ch < 0x80)
            {
                WriteByte((byte)ch);
            }
            else if (ch <= char.MaxValue)
            {
                char* chars = stackalloc char[1];
                chars[0] = (char)ch;
                UnsafeWriteUTF8Chars(chars, 1);
            }
            else
            {
                SurrogateChar surrogateChar = new SurrogateChar(ch);
                char* chars = stackalloc char[2];
                chars[0] = surrogateChar.HighChar;
                chars[1] = surrogateChar.LowChar;
                UnsafeWriteUTF8Chars(chars, 2);
            }
        }
 
        protected unsafe void WriteUTF8Chars(string value)
        {
            int count = value.Length;
            if (count > 0)
            {
                fixed (char* chars = value)
                {
                    UnsafeWriteUTF8Chars(chars, count);
                }
            }
        }
 
        protected void WriteUTF8Bytes(ReadOnlySpan<byte> value)
        {
            if (value.Length < bufferLength)
            {
                byte[] buffer = GetBuffer(value.Length, out int offset);
                value.CopyTo(buffer.AsSpan(offset));
                Advance(value.Length);
            }
            else
            {
                FlushBuffer();
                OutputStream.Write(value);
            }
        }
 
        protected unsafe void UnsafeWriteUTF8Chars(char* chars, int charCount)
        {
            const int charChunkSize = bufferLength / maxBytesPerChar;
            while (charCount > charChunkSize)
            {
                int offset;
                int chunkSize = charChunkSize;
                if ((int)(chars[chunkSize - 1] & 0xFC00) == 0xD800) // This is a high surrogate
                    chunkSize--;
                byte[] buffer = GetBuffer(chunkSize * maxBytesPerChar, out offset);
                Advance(UnsafeGetUTF8Chars(chars, chunkSize, buffer, offset));
                charCount -= chunkSize;
                chars += chunkSize;
            }
            if (charCount > 0)
            {
                int offset;
                byte[] buffer = GetBuffer(charCount * maxBytesPerChar, out offset);
                Advance(UnsafeGetUTF8Chars(chars, charCount, buffer, offset));
            }
        }
 
        protected unsafe void UnsafeWriteUnicodeChars(char* chars, int charCount)
        {
            const int charChunkSize = bufferLength / 2;
            while (charCount > charChunkSize)
            {
                int offset;
                int chunkSize = charChunkSize;
                if ((int)(chars[chunkSize - 1] & 0xFC00) == 0xD800) // This is a high surrogate
                    chunkSize--;
                byte[] buffer = GetBuffer(chunkSize * 2, out offset);
                Advance(UnsafeGetUnicodeChars(chars, chunkSize, buffer, offset));
                charCount -= chunkSize;
                chars += chunkSize;
            }
            if (charCount > 0)
            {
                int offset;
                byte[] buffer = GetBuffer(charCount * 2, out offset);
                Advance(UnsafeGetUnicodeChars(chars, charCount, buffer, offset));
            }
        }
 
        protected static unsafe int UnsafeGetUnicodeChars(char* chars, int charCount, byte[] buffer, int offset)
        {
            if (BitConverter.IsLittleEndian)
            {
                new ReadOnlySpan<char>(chars, charCount)
                    .CopyTo(MemoryMarshal.Cast<byte, char>(buffer.AsSpan(offset)));
            }
            else
            {
                BinaryPrimitives.ReverseEndianness(new ReadOnlySpan<short>(chars, charCount),
                    MemoryMarshal.Cast<byte, short>(buffer.AsSpan(offset)));
            }
 
            return charCount * 2;
        }
 
        protected unsafe int UnsafeGetUTF8Length(char* chars, int charCount)
        {
            // Length will always be at least ( 128 / maxBytesPerChar) = 42
            return (_encoding ?? DataContractSerializer.ValidatingUTF8).GetByteCount(chars, charCount);
        }
 
        protected unsafe int UnsafeGetUTF8Chars(char* chars, int charCount, byte[] buffer, int offset)
        {
            if (charCount > 0)
            {
                fixed (byte* _bytes = &buffer[offset])
                {
                    // Fast path for small strings, use Encoding.GetBytes for larger strings since it is faster when vectorization is possible
                    if (!Vector128.IsHardwareAccelerated || (uint)charCount < 32)
                    {
                        byte* bytes = _bytes;
                        char* charsMax = &chars[charCount];
 
                        while (chars < charsMax)
                        {
                            char t = *chars;
                            if (t >= 0x80)
                                goto NonAscii;
 
                            *bytes = (byte)t;
                            bytes++;
                            chars++;
                        }
                        return charCount;
 
                    NonAscii:
                        byte* bytesMax = _bytes + buffer.Length - offset;
                        return (int)(bytes - _bytes) + (_encoding ?? DataContractSerializer.ValidatingUTF8).GetBytes(chars, (int)(charsMax - chars), bytes, (int)(bytesMax - bytes));
                    }
                    else
                    {
                        return (_encoding ?? DataContractSerializer.ValidatingUTF8).GetBytes(chars, charCount, _bytes, buffer.Length - offset);
                    }
                }
            }
            return 0;
        }
 
        protected virtual void FlushBuffer()
        {
            if (_offset != 0)
            {
                OutputStream.Write(_buffer, 0, _offset);
                _offset = 0;
            }
        }
 
        protected virtual Task FlushBufferAsync()
        {
            if (_offset != 0)
            {
                var task = OutputStream.WriteAsync(_buffer, 0, _offset);
                _offset = 0;
                return task;
            }
 
            return Task.CompletedTask;
        }
 
        public override void Flush()
        {
            FlushBuffer();
            OutputStream.Flush();
        }
 
        public override async Task FlushAsync()
        {
            await FlushBufferAsync().ConfigureAwait(false);
            await OutputStream.FlushAsync().ConfigureAwait(false);
        }
 
        public override void Close()
        {
            if (OutputStream != null)
            {
                if (_ownsStream)
                {
                    OutputStream.Dispose();
                }
                OutputStream = null!;
            }
        }
    }
}