File: SentencepieceModel.cs
Web Access
Project: src\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj (Microsoft.ML.Tokenizers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
// Minimal protobuf parser for SentencePiece model files (sentencepiece_model.proto).
// Replaces a full Google.Protobuf dependency with just enough wire-format reading
// to parse the fields the tokenizer implementation actually consumes.
// SentencePiece is under the Apache License 2.0 https://github.com/google/sentencepiece/blob/master/LICENSE
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
 
namespace Sentencepiece
{
    /// <summary>
    /// Forward-only helpers for decoding protobuf wire-format values out of a byte array.
    /// Each method starts reading at <c>pos</c>, treats <c>end</c> as the exclusive upper bound of
    /// the readable region, and leaves <c>pos</c> just past whatever it consumed.
    /// </summary>
    internal static class SentencePieceProtobufReader
    {
        /// <summary>
        /// Decodes a base-128 varint and returns its low 32 bits.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The region ends before the varint terminates, or the varint is longer than ten bytes.
        /// </exception>
        internal static int ReadRawVarint32(byte[] data, int end, ref int pos)
        {
            int value = 0;

            // A varint occupies at most ten bytes. Only the first five can contribute to a 32-bit
            // result; the rest (present when a negative int32 was sign-extended to 64 bits) are
            // consumed and discarded.
            for (int index = 0; index < 10; index++)
            {
                if (pos >= end)
                {
                    throw new InvalidDataException("Unexpected end of data while reading varint.");
                }

                byte current = data[pos++];
                if (index < 5)
                {
                    value |= (current & 0x7F) << (7 * index);
                }

                bool isLastByte = (current & 0x80) == 0;
                if (isLastByte)
                {
                    return value;
                }
            }

            throw new InvalidDataException("Malformed varint.");
        }

        /// <summary>
        /// Reads the varint length of a length-delimited field and verifies that many bytes
        /// remain between the updated <c>pos</c> and <c>end</c>.
        /// </summary>
        /// <exception cref="InvalidDataException">The length is negative or exceeds the remaining bytes.</exception>
        internal static int ReadLengthPrefix(byte[] data, int end, ref int pos)
        {
            int length = ReadRawVarint32(data, end, ref pos);
            int available = end - pos;

            // The unsigned comparison also rejects negative lengths in a single test.
            if ((uint)length <= (uint)available)
            {
                return length;
            }

            throw new InvalidDataException("Invalid length-delimited field size.");
        }

        /// <summary>Reads a length-prefixed UTF-8 string.</summary>
        internal static string ReadString(byte[] data, int end, ref int pos)
        {
            int byteCount = ReadLengthPrefix(data, end, ref pos);
            string text = Encoding.UTF8.GetString(data, pos, byteCount);
            pos += byteCount;
            return text;
        }

        /// <summary>Reads a four-byte little-endian IEEE single-precision value.</summary>
        /// <exception cref="InvalidDataException">Fewer than four bytes remain before <c>end</c>.</exception>
        internal static float ReadFloat(byte[] data, int end, ref int pos)
        {
            if (end - 4 < pos)
            {
                throw new InvalidDataException("Unexpected end of data while reading float.");
            }

            byte[] source = data;
            int offset = pos;
            if (!BitConverter.IsLittleEndian)
            {
                // The wire format is little-endian regardless of platform, so hand BitConverter
                // a byte-reversed copy when running on a big-endian machine.
                source = new byte[4];
                source[0] = data[pos + 3];
                source[1] = data[pos + 2];
                source[2] = data[pos + 1];
                source[3] = data[pos];
                offset = 0;
            }

            float value = BitConverter.ToSingle(source, offset);
            pos += 4;
            return value;
        }

        /// <summary>
        /// Advances <c>pos</c> past the payload of a field with the given wire type without decoding it.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The payload is truncated or malformed, or the wire type is not 0, 1, 2 or 5.
        /// </exception>
        internal static void SkipField(byte[] data, int end, int wireType, ref int pos)
        {
            switch (wireType)
            {
                case 0: // varint: at most ten bytes, the last one has its high bit clear
                    int continued = 0;
                    while (true)
                    {
                        if (pos >= end)
                        {
                            throw new InvalidDataException("Unexpected end of data while skipping varint.");
                        }

                        bool hasMore = (data[pos++] & 0x80) != 0;
                        if (!hasMore)
                        {
                            return;
                        }

                        continued++;
                        if (continued == 10)
                        {
                            throw new InvalidDataException("Malformed varint.");
                        }
                    }

                case 1: // fixed 64-bit
                    if (end - 8 < pos)
                    {
                        throw new InvalidDataException("Unexpected end of data while skipping fixed64.");
                    }

                    pos += 8;
                    return;

                case 2: // length-delimited
                    // The prefix must be read into a local first: ReadLengthPrefix itself moves pos.
                    int payloadLength = ReadLengthPrefix(data, end, ref pos);
                    pos += payloadLength;
                    return;

                case 5: // fixed 32-bit
                    if (end - 4 < pos)
                    {
                        throw new InvalidDataException("Unexpected end of data while skipping fixed32.");
                    }

                    pos += 4;
                    return;

                default:
                    throw new InvalidDataException($"Unknown or unsupported protobuf wire type {wireType}.");
            }
        }
    }
 
    /// <summary>
    /// A (buffer, offset, length) view over bytes held in an existing array; stands in for
    /// Google.Protobuf.ByteString where only a span over the bytes is needed. No bytes are copied.
    /// </summary>
    internal readonly struct SentencePieceByteString
    {
        private readonly byte[] _data;
        private readonly int _offset;
        private readonly int _length;

        public SentencePieceByteString(byte[] data, int offset, int length)
        {
            _data = data;
            _offset = offset;
            _length = length;
        }

        /// <summary>
        /// The viewed bytes. A value with no backing array (such as <c>default</c>) yields an empty span.
        /// </summary>
        internal ReadOnlySpan<byte> Span
        {
            get
            {
                if (_data is null)
                {
                    return ReadOnlySpan<byte>.Empty;
                }

                return _data.AsSpan(_offset, _length);
            }
        }
    }
 
    /// <summary>
    /// Top-level <c>ModelProto</c> message (field numbers match sentencepiece_model.proto).
    /// Only the pieces, the trainer spec and the normalizer spec are decoded; any other field is skipped.
    /// </summary>
    internal sealed class ModelProto
    {
        // A protobuf tag is (field_number << 3) | wire_type. All three decoded fields are embedded
        // messages, i.e. length-delimited (wire type 2).
        private const int PiecesTag = (1 << 3) | 2;         // repeated SentencePiece pieces = 1
        private const int TrainerSpecTag = (2 << 3) | 2;    // TrainerSpec trainer_spec = 2
        private const int NormalizerSpecTag = (3 << 3) | 2; // NormalizerSpec normalizer_spec = 3

        internal static readonly ModelProtoParser Parser = new();

        internal List<Types.SentencePiece> Pieces { get; } = new();
        internal TrainerSpec TrainerSpec { get; private set; } = new();
        internal NormalizerSpec NormalizerSpec { get; private set; } = new();

        /// <summary>
        /// Decodes a <c>ModelProto</c> from <c>data[start..end)</c>.
        /// </summary>
        /// <remarks>
        /// A known field number only counts when it arrives with wire type 2. The same number with
        /// any other wire type is treated like an unknown field and skipped, so a field whose
        /// encoding differs from what is expected here is ignored rather than misread.
        /// A second occurrence of <c>trainer_spec</c> or <c>normalizer_spec</c> replaces the first
        /// (last one wins); the occurrences are not merged.
        /// </remarks>
        internal static ModelProto Parse(byte[] data, int start, int end)
        {
            ModelProto model = new();
            int cursor = start;

            while (cursor < end)
            {
                int tag = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);

                bool isKnownMessage = tag == PiecesTag || tag == TrainerSpecTag || tag == NormalizerSpecTag;
                if (!isKnownMessage)
                {
                    // The low three bits of the tag are the wire type.
                    SentencePieceProtobufReader.SkipField(data, end, tag & 7, ref cursor);
                    continue;
                }

                // Each embedded message is parsed against its own end bound.
                int size = SentencePieceProtobufReader.ReadLengthPrefix(data, end, ref cursor);
                int messageEnd = cursor + size;

                if (tag == PiecesTag)
                {
                    model.Pieces.Add(Types.SentencePiece.Parse(data, cursor, messageEnd));
                }
                else if (tag == TrainerSpecTag)
                {
                    model.TrainerSpec = TrainerSpec.Parse(data, cursor, messageEnd);
                }
                else
                {
                    model.NormalizerSpec = NormalizerSpec.Parse(data, cursor, messageEnd);
                }

                cursor = messageEnd;
            }

            return model;
        }

        internal static class Types
        {
            /// <summary>One vocabulary entry: the piece text, its score and its kind.</summary>
            internal sealed class SentencePiece
            {
                private const int PieceTag = (1 << 3) | 2; // string piece = 1 (length-delimited)
                private const int ScoreTag = (2 << 3) | 5; // float score = 2 (fixed 32-bit)
                private const int TypeTag = (3 << 3) | 0;  // Type type = 3 (varint)

                internal string Piece { get; set; } = "";
                internal float Score { get; set; }
                internal Types.Type Type { get; set; } = Types.Type.Normal;

                /// <summary>
                /// Decodes a <c>SentencePiece</c> from <c>data[start..end)</c>. Fields that are
                /// unknown, or that carry an unexpected wire type, are skipped.
                /// </summary>
                internal static SentencePiece Parse(byte[] data, int start, int end)
                {
                    SentencePiece entry = new();
                    int cursor = start;

                    while (cursor < end)
                    {
                        int tag = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                        switch (tag)
                        {
                            case PieceTag:
                                entry.Piece = SentencePieceProtobufReader.ReadString(data, end, ref cursor);
                                break;

                            case ScoreTag:
                                entry.Score = SentencePieceProtobufReader.ReadFloat(data, end, ref cursor);
                                break;

                            case TypeTag:
                                entry.Type = (Types.Type)SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                                break;

                            default:
                                SentencePieceProtobufReader.SkipField(data, end, tag & 7, ref cursor);
                                break;
                        }
                    }

                    return entry;
                }

                internal static class Types
                {
                    internal enum Type
                    {
                        Normal = 1,
                        Unknown = 2,
                        Control = 3,
                        UserDefined = 4,
                        Unused = 5,
                        Byte = 6,
                    }
                }
            }
        }
    }
 
    /// <summary>
    /// Stream entry point for model parsing, reached through <c>ModelProto.Parser.ParseFrom(Stream)</c>.
    /// </summary>
    internal sealed class ModelProtoParser
    {
        /// <summary>
        /// Parses a <see cref="ModelProto"/> from the bytes between the stream's current position and its end.
        /// </summary>
        /// <remarks>
        /// When <paramref name="stream"/> is a <see cref="MemoryStream"/> that exposes its buffer, the
        /// buffer is parsed in place and the stream's position is left where it was. Any other stream
        /// is first copied into a temporary <see cref="MemoryStream"/> with <c>CopyTo</c>.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
        internal ModelProto ParseFrom(Stream stream)
        {
            if (stream is null)
            {
                throw new ArgumentNullException(nameof(stream));
            }

            // Zero-copy path: read straight out of the MemoryStream's backing array.
            if (stream is MemoryStream source && source.TryGetBuffer(out ArraySegment<byte> exposed))
            {
                int origin = exposed.Offset;
                int first = origin + (int)source.Position;
                int limit = origin + (int)source.Length;
                return ModelProto.Parse(exposed.Array!, first, limit);
            }

            // General path: buffer whatever is left in the stream, then parse the buffer.
            MemoryStream copy = CreateCopyTarget(stream);
            stream.CopyTo(copy);
            return ModelProto.Parse(copy.GetBuffer(), 0, (int)copy.Length);
        }

        // Creates the destination for the copy, pre-sized to the remaining byte count when the
        // source is seekable and that count is positive and fits in an int.
        private static MemoryStream CreateCopyTarget(Stream stream)
        {
            if (!stream.CanSeek)
            {
                return new MemoryStream();
            }

            long remaining = stream.Length - stream.Position;
            if (remaining <= 0 || remaining > int.MaxValue)
            {
                return new MemoryStream();
            }

            return new MemoryStream((int)remaining);
        }
    }
 
    /// <summary>
    /// The fields of the <c>TrainerSpec</c> message that are decoded here
    /// (defaults match sentencepiece_model.proto). All other fields are skipped.
    /// </summary>
    internal sealed class TrainerSpec
    {
        // Wire types of the decoded fields; a tag is (field_number << 3) | wire_type.
        private const int Varint = 0;
        private const int LengthDelimited = 2;

        internal Types.ModelType ModelType { get; private set; } = Types.ModelType.Unigram;
        internal bool TreatWhitespaceAsSuffix { get; private set; }
        internal bool ByteFallback { get; private set; }
        internal int UnkId { get; private set; }
        internal int BosId { get; private set; } = 1;
        internal int EosId { get; private set; } = 2;
        internal int PadId { get; private set; } = -1;
        internal string UnkPiece { get; private set; } = "<unk>";
        internal string BosPiece { get; private set; } = "<s>";
        internal string EosPiece { get; private set; } = "</s>";
        internal string PadPiece { get; private set; } = "<pad>";

        /// <summary>
        /// Decodes a <c>TrainerSpec</c> from <c>data[start..end)</c>. A field is only accepted when
        /// both its number and its wire type match; everything else is skipped, and properties whose
        /// field never appears keep their defaults.
        /// </summary>
        internal static TrainerSpec Parse(byte[] data, int start, int end)
        {
            TrainerSpec spec = new();

            for (int cursor = start; cursor < end;)
            {
                int tag = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                switch (tag)
                {
                    case (3 << 3) | Varint: // ModelType model_type = 3
                        spec.ModelType = (Types.ModelType)SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                        break;

                    case (24 << 3) | Varint: // bool treat_whitespace_as_suffix = 24
                        spec.TreatWhitespaceAsSuffix = ReadBool(data, end, ref cursor);
                        break;

                    case (35 << 3) | Varint: // bool byte_fallback = 35
                        spec.ByteFallback = ReadBool(data, end, ref cursor);
                        break;

                    case (40 << 3) | Varint: // int32 unk_id = 40
                        spec.UnkId = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                        break;

                    case (41 << 3) | Varint: // int32 bos_id = 41
                        spec.BosId = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                        break;

                    case (42 << 3) | Varint: // int32 eos_id = 42
                        spec.EosId = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                        break;

                    case (43 << 3) | Varint: // int32 pad_id = 43
                        spec.PadId = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                        break;

                    case (45 << 3) | LengthDelimited: // string unk_piece = 45
                        spec.UnkPiece = SentencePieceProtobufReader.ReadString(data, end, ref cursor);
                        break;

                    case (46 << 3) | LengthDelimited: // string bos_piece = 46
                        spec.BosPiece = SentencePieceProtobufReader.ReadString(data, end, ref cursor);
                        break;

                    case (47 << 3) | LengthDelimited: // string eos_piece = 47
                        spec.EosPiece = SentencePieceProtobufReader.ReadString(data, end, ref cursor);
                        break;

                    case (48 << 3) | LengthDelimited: // string pad_piece = 48
                        spec.PadPiece = SentencePieceProtobufReader.ReadString(data, end, ref cursor);
                        break;

                    default:
                        // The low three bits of the tag are the wire type.
                        SentencePieceProtobufReader.SkipField(data, end, tag & 7, ref cursor);
                        break;
                }
            }

            return spec;
        }

        // A bool is a varint on the wire; any non-zero 32-bit result is read as true.
        private static bool ReadBool(byte[] data, int end, ref int pos)
        {
            int raw = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref pos);
            return raw != 0;
        }

        internal static class Types
        {
            internal enum ModelType
            {
                Unigram = 1,
                Bpe = 2,
                Word = 3,
                Char = 4,
            }
        }
    }
 
    /// <summary>
    /// The fields of the <c>NormalizerSpec</c> message that are decoded here
    /// (defaults match sentencepiece_model.proto). All other fields are skipped.
    /// </summary>
    internal sealed class NormalizerSpec
    {
        // A tag is (field_number << 3) | wire_type; wire type 2 is length-delimited, 0 is varint.
        private const int NameTag = (1 << 3) | 2;                   // string name = 1
        private const int PrecompiledCharsmapTag = (2 << 3) | 2;    // bytes precompiled_charsmap = 2
        private const int AddDummyPrefixTag = (3 << 3) | 0;         // bool add_dummy_prefix = 3
        private const int RemoveExtraWhitespacesTag = (4 << 3) | 0; // bool remove_extra_whitespaces = 4
        private const int EscapeWhitespacesTag = (5 << 3) | 0;      // bool escape_whitespaces = 5

        internal string Name { get; private set; } = "";
        internal SentencePieceByteString PrecompiledCharsmap { get; private set; }
        internal bool AddDummyPrefix { get; private set; } = true;
        internal bool RemoveExtraWhitespaces { get; private set; } = true;
        internal bool EscapeWhitespaces { get; private set; } = true;

        /// <summary>
        /// Decodes a <c>NormalizerSpec</c> from <c>data[start..end)</c>. A field is only accepted when
        /// both its number and its wire type match; everything else is skipped.
        /// </summary>
        /// <remarks>
        /// <see cref="PrecompiledCharsmap"/> is a view into <paramref name="data"/>, not a copy.
        /// </remarks>
        internal static NormalizerSpec Parse(byte[] data, int start, int end)
        {
            NormalizerSpec spec = new();

            for (int cursor = start; cursor < end;)
            {
                int tag = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref cursor);
                switch (tag)
                {
                    case NameTag:
                        spec.Name = SentencePieceProtobufReader.ReadString(data, end, ref cursor);
                        break;

                    case PrecompiledCharsmapTag:
                        int size = SentencePieceProtobufReader.ReadLengthPrefix(data, end, ref cursor);
                        spec.PrecompiledCharsmap = new SentencePieceByteString(data, cursor, size);
                        cursor += size;
                        break;

                    case AddDummyPrefixTag:
                        spec.AddDummyPrefix = ReadBool(data, end, ref cursor);
                        break;

                    case RemoveExtraWhitespacesTag:
                        spec.RemoveExtraWhitespaces = ReadBool(data, end, ref cursor);
                        break;

                    case EscapeWhitespacesTag:
                        spec.EscapeWhitespaces = ReadBool(data, end, ref cursor);
                        break;

                    default:
                        // The low three bits of the tag are the wire type.
                        SentencePieceProtobufReader.SkipField(data, end, tag & 7, ref cursor);
                        break;
                }
            }

            return spec;
        }

        // A bool is a varint on the wire; any non-zero 32-bit result is read as true.
        private static bool ReadBool(byte[] data, int end, ref int pos)
        {
            int raw = SentencePieceProtobufReader.ReadRawVarint32(data, end, ref pos);
            return raw != 0;
        }
    }
}