File: Utils\Helpers.cs
Web Access
Project: src\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj (Microsoft.ML.Tokenizers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Buffers;
using System.Diagnostics;
using System.Text;
 
namespace Microsoft.ML.Tokenizers
{
    internal static partial class Helpers
    {
        /// <summary>
        /// Swaps <paramref name="arrayPoolArray"/> for a larger array rented from <see cref="ArrayPool{T}.Shared"/>,
        /// copying the existing elements across and returning the old array to the shared pool.
        /// </summary>
        /// <typeparam name="T">Element type of the array.</typeparam>
        /// <param name="arrayPoolArray">The array to grow. It is returned to the shared pool, and on exit the reference points at the newly rented array.</param>
        /// <param name="requiredCapacity">Minimum length to request for the new array.</param>
        internal static void ArrayPoolGrow<T>(ref T[] arrayPoolArray, int requiredCapacity)
        {
            // Rent whichever is larger: twice the current length or the requested capacity.
            int doubledLength = arrayPoolArray.Length * 2;
            int rentLength = doubledLength >= requiredCapacity ? doubledLength : requiredCapacity;

            T[] larger = ArrayPool<T>.Shared.Rent(rentLength);

            // Carry over every element of the old array before handing it back to the pool.
            arrayPoolArray.AsSpan().CopyTo(larger);
            ArrayPool<T>.Shared.Return(arrayPoolArray);

            arrayPoolArray = larger;
        }
 
        /// <summary>
        /// Encodes UTF-16 <paramref name="text"/> as UTF-8 bytes and records, for every byte written,
        /// the index of the source char that produced it.
        /// </summary>
        /// <param name="text">The UTF-16 input. Asserted to be non-empty in debug builds.</param>
        /// <param name="destination">Receives the encoded bytes. Asserted to hold at least <c>Encoding.UTF8.GetMaxByteCount(text.Length)</c> bytes.</param>
        /// <param name="indexMapping">
        /// Receives, at each written byte position, the index in <paramref name="text"/> of the char the byte came from
        /// (for a surrogate pair, the index of the high surrogate). Asserted to be at least as long as <paramref name="destination"/>.
        /// </param>
        /// <returns>The number of bytes written to <paramref name="destination"/>.</returns>
        /// <remarks>
        /// A surrogate char that is not part of a valid pair is not replaced: it takes the three-byte path
        /// and is encoded from its raw UTF-16 value.
        /// </remarks>
        internal static int EncodeToUtf8(ReadOnlySpan<char> text, Span<byte> destination, Span<int> indexMapping)
        {
            Debug.Assert(!text.IsEmpty);
            Debug.Assert(Encoding.UTF8.GetMaxByteCount(text.Length) <= destination.Length);
            Debug.Assert(indexMapping.Length >= destination.Length);

            int written = 0;
            int source = 0;

            while (source < text.Length)
            {
                uint c = text[source];

                if (c <= 0x7Fu)
                {
                    // ASCII: one char maps to one identical byte.
                    destination[written] = (byte)c;
                    indexMapping[written] = source;
                    written += 1;
                    source += 1;
                }
                else if (c <= 0x7FFu)
                {
                    // Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
                    uint lead = 0xC0u | (c >> 6);
                    uint trail = 0x80u | (c & 0x3Fu);

                    destination[written] = (byte)lead;
                    destination[written + 1] = (byte)trail;
                    indexMapping[written] = indexMapping[written + 1] = source;
                    written += 2;
                    source += 1;
                }
                else if (source + 1 < text.Length && char.IsSurrogatePair(text[source], text[source + 1]))
                {
                    // Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
                    uint scalar = (uint)char.ConvertToUtf32(text[source], text[source + 1]);
                    uint lead = 0xF0u | (scalar >> 18);
                    uint second = 0x80u | ((scalar >> 12) & 0x3Fu);
                    uint third = 0x80u | ((scalar >> 6) & 0x3Fu);
                    uint fourth = 0x80u | (scalar & 0x3Fu);

                    destination[written] = (byte)lead;
                    destination[written + 1] = (byte)second;
                    destination[written + 2] = (byte)third;
                    destination[written + 3] = (byte)fourth;

                    // All four bytes point back at the high surrogate's position.
                    indexMapping[written] = indexMapping[written + 1] = indexMapping[written + 2] = indexMapping[written + 3] = source;
                    written += 4;

                    // Both halves of the pair have been consumed.
                    source += 2;
                }
                else
                {
                    // Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
                    uint lead = 0xE0u | (c >> 12);
                    uint second = 0x80u | ((c >> 6) & 0x3Fu);
                    uint third = 0x80u | (c & 0x3Fu);

                    destination[written] = (byte)lead;
                    destination[written + 1] = (byte)second;
                    destination[written + 2] = (byte)third;
                    indexMapping[written] = indexMapping[written + 1] = indexMapping[written + 2] = source;
                    written += 3;
                    source += 1;
                }
            }

            return written;
        }
 
        /// <summary>
        /// Computes the UTF-8 encoding of <paramref name="text"/> and writes, for each UTF-8 byte, the char obtained by
        /// looking that byte value up in <c>ByteToUnicodeEncoding.Instance.ByteToUnicode</c>. It also records, for every
        /// char written, the index of the source char that produced it.
        /// </summary>
        /// <param name="text">The UTF-16 input. Asserted to be non-empty in debug builds.</param>
        /// <param name="destination">Receives one mapped char per UTF-8 byte. Asserted to hold at least <c>Encoding.UTF8.GetMaxByteCount(text.Length)</c> chars.</param>
        /// <param name="indexMapping">
        /// Receives, at each written position, the index in <paramref name="text"/> of the char the byte came from
        /// (for a surrogate pair, the index of the high surrogate). Asserted to be at least as long as <paramref name="destination"/>.
        /// </param>
        /// <returns>The number of chars written to <paramref name="destination"/>.</returns>
        /// <remarks>
        /// A surrogate char that is not part of a valid pair is not replaced: it takes the three-byte path
        /// and is encoded from its raw UTF-16 value.
        /// </remarks>
        internal static int EncodeToUtf8AndTransform(ReadOnlySpan<char> text, Span<char> destination, Span<int> indexMapping)
        {
            Debug.Assert(!text.IsEmpty);
            Debug.Assert(Encoding.UTF8.GetMaxByteCount(text.Length) <= destination.Length);
            Debug.Assert(indexMapping.Length >= destination.Length);

            ByteToUnicodeEncoding encoder = ByteToUnicodeEncoding.Instance;
            int written = 0;
            int source = 0;

            while (source < text.Length)
            {
                uint c = text[source];

                if (c <= 0x7Fu)
                {
                    // ASCII: a single byte whose value equals the char.
                    destination[written] = encoder.ByteToUnicode[(char)c];
                    indexMapping[written] = source;
                    written += 1;
                    source += 1;
                }
                else if (c <= 0x7FFu)
                {
                    // Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ]
                    uint lead = 0xC0u | (c >> 6);
                    uint trail = 0x80u | (c & 0x3Fu);

                    destination[written] = encoder.ByteToUnicode[(char)lead];
                    destination[written + 1] = encoder.ByteToUnicode[(char)trail];
                    indexMapping[written] = indexMapping[written + 1] = source;
                    written += 2;
                    source += 1;
                }
                else if (source + 1 < text.Length && char.IsSurrogatePair(text[source], text[source + 1]))
                {
                    // Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ]
                    uint scalar = (uint)char.ConvertToUtf32(text[source], text[source + 1]);
                    uint lead = 0xF0u | (scalar >> 18);
                    uint second = 0x80u | ((scalar >> 12) & 0x3Fu);
                    uint third = 0x80u | ((scalar >> 6) & 0x3Fu);
                    uint fourth = 0x80u | (scalar & 0x3Fu);

                    destination[written] = encoder.ByteToUnicode[(char)lead];
                    destination[written + 1] = encoder.ByteToUnicode[(char)second];
                    destination[written + 2] = encoder.ByteToUnicode[(char)third];
                    destination[written + 3] = encoder.ByteToUnicode[(char)fourth];

                    // All four outputs point back at the high surrogate's position.
                    indexMapping[written] = indexMapping[written + 1] = indexMapping[written + 2] = indexMapping[written + 3] = source;
                    written += 4;

                    // Both halves of the pair have been consumed.
                    source += 2;
                }
                else
                {
                    // Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ]
                    uint lead = 0xE0u | (c >> 12);
                    uint second = 0x80u | ((c >> 6) & 0x3Fu);
                    uint third = 0x80u | (c & 0x3Fu);

                    destination[written] = encoder.ByteToUnicode[(char)lead];
                    destination[written + 1] = encoder.ByteToUnicode[(char)second];
                    destination[written + 2] = encoder.ByteToUnicode[(char)third];
                    indexMapping[written] = indexMapping[written + 1] = indexMapping[written + 2] = source;
                    written += 3;
                    source += 1;
                }
            }

            return written;
        }
 
        /// <summary>
        /// Decodes UTF-8 bytes into UTF-16 chars, reporting how far it got.
        /// </summary>
        /// <param name="utf8Bytes">The UTF-8 input.</param>
        /// <param name="utf16Chars">Receives the decoded chars. Asserted to hold at least <c>Encoding.UTF8.GetMaxCharCount(utf8Bytes.Length)</c> chars.</param>
        /// <param name="bytesConsumed">Number of input bytes belonging to sequences that were completely decoded.</param>
        /// <param name="charsWritten">Number of chars written to <paramref name="utf16Chars"/>.</param>
        /// <returns>
        /// <see langword="false"/> when an invalid lead byte, an invalid continuation byte, or a code point above U+10FFFF is
        /// encountered. Otherwise <see langword="true"/>, including when the input ends in the middle of a multi-byte
        /// sequence, in which case <paramref name="bytesConsumed"/> stops before that partial sequence.
        /// In every case the out values describe what was decoded before the method returned.
        /// </returns>
        /// <remarks>
        /// Overlong encodings and three-byte encodings of surrogate code points are not rejected; they are decoded as-is.
        /// The end-of-input check runs before continuation bytes are inspected, so a lead byte followed by too few bytes
        /// is reported as incomplete even if the bytes that are present are not valid continuation bytes.
        /// </remarks>
        public static bool ConvertUtf8ToUtf16(ReadOnlySpan<byte> utf8Bytes, Span<char> utf16Chars, out int bytesConsumed, out int charsWritten)
        {
            Debug.Assert(utf16Chars.Length >= Encoding.UTF8.GetMaxCharCount(utf8Bytes.Length));

            int byteIndex = 0;
            int charIndex = 0;
            bytesConsumed = 0;
            charsWritten = 0;

            while (byteIndex < utf8Bytes.Length)
            {
                byte firstByte = utf8Bytes[byteIndex];

                if ((firstByte & 0x80) == 0)
                {
                    // 1-byte sequence (ASCII)
                    utf16Chars[charIndex++] = (char)firstByte;
                    charsWritten = charIndex;
                    bytesConsumed = ++byteIndex;
                    continue;
                }

                uint codePoint;
                int additionalBytes;

                if ((firstByte & 0xE0) == 0xC0)
                {
                    // 2-byte sequence
                    codePoint = (uint)(firstByte & 0x1F);
                    additionalBytes = 1;
                }
                else if ((firstByte & 0xF0) == 0xE0)
                {
                    // 3-byte sequence
                    codePoint = (uint)(firstByte & 0x0F);
                    additionalBytes = 2;
                }
                else if ((firstByte & 0xF8) == 0xF0)
                {
                    // 4-byte sequence
                    codePoint = (uint)(firstByte & 0x07);
                    additionalBytes = 3;
                }
                else
                {
                    // A continuation byte in lead position, or a lead byte of the form 11111xxx.
                    return false;
                }

                if (byteIndex + additionalBytes >= utf8Bytes.Length)
                {
                    return true; // incomplete utf-8 sequence
                }

                for (int i = 1; i <= additionalBytes; i++)
                {
                    byte nextByte = utf8Bytes[byteIndex + i];
                    if ((nextByte & 0xC0) != 0x80)
                    {
                        return false;
                    }

                    codePoint = (codePoint << 6) | (uint)(nextByte & 0x3F);
                }

                // A four-byte sequence can carry values up to 0x1FFFFF, but nothing above U+10FFFF can be
                // expressed as a surrogate pair: the high-surrogate computation below would overflow into
                // the low-surrogate range and emit malformed UTF-16. Reject before consuming or writing.
                if (codePoint > 0x10FFFF)
                {
                    return false;
                }

                byteIndex += additionalBytes + 1;
                bytesConsumed = byteIndex;

                if (codePoint <= 0xFFFF)
                {
                    utf16Chars[charIndex++] = (char)codePoint;
                }
                else
                {
                    // Split a supplementary-plane code point into a high/low surrogate pair.
                    codePoint -= 0x10000;
                    utf16Chars[charIndex++] = (char)((codePoint >> 10) + 0xD800);
                    utf16Chars[charIndex++] = (char)((codePoint & 0x3FF) + 0xDC00);
                }

                charsWritten = charIndex;
            }

            return true;
        }
    }
}