File: Utils\BytePairEncoder.cs
Project: src\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj (Microsoft.ML.Tokenizers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Buffers;
using System.Collections.Generic;
 
namespace Microsoft.ML.Tokenizers
{
    /// <summary>
    /// This class implements the byte pair encoding algorithm.
    /// </summary>
    internal static class BytePairEncoder
    {
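        /// <summary>
        /// Encodes <paramref name="mergingBytes"/> into tokens by repeatedly merging the adjacent pair of pieces
        /// with the lowest rank until no adjacent pair remains in <paramref name="ranks"/>.
        /// </summary>
        /// <param name="mergingBytes">The byte sequence to encode.</param>
        /// <param name="ranks">Maps byte sequences to their rank; in the Tiktoken tokenizer the rank is also the token Id.</param>
        /// <param name="indexMappingSpan">Maps each byte offset to the corresponding offset in the original text,
        /// with a trailing entry for the end offset; used to report token index and length in terms of the original elements.</param>
        /// <returns>The encoded tokens as (Id, TokenIndex, TokenLength) tuples.</returns>
        /// <remarks>
        /// For illustration, assuming a hypothetical vocabulary where "ab" has rank 1 and "abc" has rank 2: the bytes
        /// of "abc" start as the pieces a|b|c, the lowest-ranked pair "ab" merges first, and the result then merges
        /// with "c", yielding the single token "abc" with Id 2.
        /// </remarks>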
        public static (int Id, int TokenIndex, int TokenLength)[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
        {
            if (mergingBytes.Length == 1)
            {
                return [(ranks[mergingBytes], 0, 1)];
            }
 
            // For large inputs, use a heap-based algorithm to avoid O(n²) behavior.
            // The threshold of 128 was chosen empirically: the linear scan is cache-friendly for small inputs,
            // while the heap's O(log n) per-operation overhead becomes worthwhile for larger ones.
            // The upstream tiktoken implementation uses 100; the threshold is adjusted upward here because
            // C#'s span operations keep the linear scan cheap.
            if (mergingBytes.Length > 128)
            {
                return BytePairEncodeLarge(mergingBytes, ranks, indexMappingSpan);
            }
 
            (int Index, int Rank)[]? arrayPoolArray = null;
            int requiredLength = mergingBytes.Length + 1;
            Span<(int Index, int Rank)> byteIndicesAndRanks = requiredLength <= 64 ?
                stackalloc (int, int)[64] :
                (arrayPoolArray = ArrayPool<(int, int)>.Shared.Rent(requiredLength));
            byteIndicesAndRanks = byteIndicesAndRanks.Slice(0, requiredLength);
 
            for (int i = 0; i < byteIndicesAndRanks.Length; i++)
            {
                byteIndicesAndRanks[i] = (i, int.MaxValue);
            }
 
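            // byteIndicesAndRanks keeps one entry per remaining piece boundary: entry i's piece covers the bytes
            // [Index(i), Index(i+1)), and Rank is the rank of merging piece i with piece i+1 (int.MaxValue when
            // that pair is not in the vocabulary). GetRank returns the rank of the byte span covering pieces
            // startIndex through startIndex + skip + 1, or int.MaxValue if that span is not in the vocabulary.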
            int GetRank(Span<(int Index, int Rank)> byteIndicesAndRanks, int startIndex, int skip = 0)
            {
                if (startIndex + skip + 2 < byteIndicesAndRanks.Length)
                {
                    var slice = mergingBytes.SliceStartEnd(byteIndicesAndRanks[startIndex].Index, byteIndicesAndRanks[startIndex + skip + 2].Index);
                    if (ranks.TryGetValue(slice, out var rank))
                    {
                        return rank;
                    }
                }
 
                return int.MaxValue;
            }
 
            for (int i = 0; i < byteIndicesAndRanks.Length - 2; i++)
            {
                int rank = GetRank(byteIndicesAndRanks, i);
                if (rank != int.MaxValue)
                {
                    byteIndicesAndRanks[i].Rank = rank;
                }
            }
 
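            // Repeatedly locate the adjacent pair with the lowest rank and merge it, recomputing the ranks of the
            // two neighboring pairs, until no adjacent pair present in the vocabulary remains.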
            while (byteIndicesAndRanks.Length > 1)
            {
                var minRank = (Index: 0, Rank: int.MaxValue);
                for (int i = 0; i < byteIndicesAndRanks.Length - 1; i++)
                {
                    if (byteIndicesAndRanks[i].Rank < minRank.Rank)
                    {
                        minRank = (i, byteIndicesAndRanks[i].Rank);
                    }
                }
 
                if (minRank.Rank != int.MaxValue)
                {
                    int j = minRank.Index;
                    byteIndicesAndRanks[j].Rank = GetRank(byteIndicesAndRanks, j, 1);
                    if (j > 0)
                    {
                        byteIndicesAndRanks[j - 1].Rank = GetRank(byteIndicesAndRanks, j - 1, 1);
                    }
 
                    byteIndicesAndRanks.Slice(j + 2).CopyTo(byteIndicesAndRanks.Slice(j + 1));
                    byteIndicesAndRanks = byteIndicesAndRanks.Slice(0, byteIndicesAndRanks.Length - 1);
                }
                else
                {
                    break;
                }
            }
 
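            // Each remaining adjacent pair of boundaries defines one token; map byte offsets back to
            // original-text offsets through indexMappingSpan.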
            var result = new (int Id, int TokenIndex, int TokenLength)[byteIndicesAndRanks.Length - 1];
            for (int i = 0; i < result.Length; i++)
            {
                int startIndex = byteIndicesAndRanks[i].Index;
                int endIndex = byteIndicesAndRanks[i + 1].Index;
 
                int mappedStartIndex = indexMappingSpan[startIndex];
                int mappedEndIndex = indexMappingSpan[endIndex];
 
                int finalEndIndex = endIndex;
 
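                // If the token ends in the middle of a multi-byte character/element (the byte just before the
                // boundary maps to the same original element as the boundary itself), extend the end index so the
                // reported length covers the whole element.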
                if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
                {
                    // The partial character/element should be included in the current token.
                    finalEndIndex++;
                    while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
                    {
                        finalEndIndex++;
                    }
                }
 
                result[i] = (ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)], mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex);
            }
 
            if (arrayPoolArray is not null)
            {
                ArrayPool<(int, int)>.Shared.Return(arrayPoolArray);
            }
 
            return result;
        }
 
        private struct State
        {
            // Index of the previous active piece in the linked list (only meaningful for active pieces after the first).
            public int Prev;
            // End (exclusive) of this piece's byte range; also the index of the next active piece.
            public int End;
            // End (exclusive) of the byte range that would result from merging this piece with the next one.
            public int NextEnd;
            // Rank of merging this piece with the next one, or int.MaxValue if that merge is not in the vocabulary.
            public int NextRank;
            // Note: In the Tiktoken tokenizer, the rank is also the token Id.
            // This field caches the rank/Id after a merge so it does not have to be looked up again.
            // Using this code with a different tokenizer where rank != token Id would produce wrong results.
            public int CurRank;
        }
 
        private struct MergeEntry : IComparable<MergeEntry>
        {
            public int Rank;
            public int Start;
 
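            // Candidates are ordered by rank; equal ranks are ordered by start position so the left-most
            // candidate is dequeued first.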
            public int CompareTo(MergeEntry other)
            {
                int rankComparison = Rank.CompareTo(other.Rank);
                if (rankComparison != 0)
                {
                    return rankComparison;
                }
                return Start.CompareTo(other.Start);
            }
        }
 
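        /// <summary>
        /// Heap-based variant of <see cref="BytePairEncode"/> used for large inputs. Pieces are tracked as a linked
        /// list (<see cref="State.End"/> points to the next piece, <see cref="State.Prev"/> to the previous one) and
        /// candidate merges are kept in a min-heap ordered by rank. Stale heap entries are detected by comparing the
        /// dequeued rank against the live <see cref="State.NextRank"/>, giving O(n log n) behavior overall.
        /// </summary>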
        private static (int Id, int TokenIndex, int TokenLength)[] BytePairEncodeLarge(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
        {
            int stateLength = mergingBytes.Length;
            State[] statePoolArray = ArrayPool<State>.Shared.Rent(stateLength);
            Span<State> state = statePoolArray.AsSpan(0, stateLength);
 
            state[0] = new State
            {
                Prev = int.MaxValue,
                End = 1,
                NextEnd = 2,
                NextRank = int.MaxValue,
                CurRank = int.MaxValue
            };
 
            var heap = new PriorityQueue<MergeEntry>();
 
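            // Seed the linked list with one piece per byte and enqueue every adjacent byte pair that exists in the
            // vocabulary as an initial merge candidate.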
            for (int i = 0; i < mergingBytes.Length - 1; i++)
            {
                var slice = mergingBytes.Slice(i, 2);
                if (ranks.TryGetValue(slice, out int rank))
                {
                    heap.Enqueue(new MergeEntry { Start = i, Rank = rank });
                    state[i].NextRank = rank;
                }
 
                state[i + 1] = new State
                {
                    Prev = i,
                    End = i + 2,
                    NextEnd = i + 3,
                    NextRank = int.MaxValue,
                    CurRank = int.MaxValue
                };
            }
 
            // Local function that records the candidate merge for the piece starting at `start` (its merged span
            // would end at `nextEndItem`) and enqueues it on the heap if that span exists in the vocabulary.
            void PotentialMerge(Span<State> stateSpan, PriorityQueue<MergeEntry> heapQueue, int start, int nextEndItem)
            {
                stateSpan[start].NextEnd = nextEndItem;
                stateSpan[start].NextRank = int.MaxValue;
 
                if (nextEndItem <= mergingBytes.Length)
                {
                    var slice = mergingBytes.Slice(start, nextEndItem - start);
                    if (ranks.TryGetValue(slice, out int rank))
                    {
                        heapQueue.Enqueue(new MergeEntry { Start = start, Rank = rank });
                        stateSpan[start].NextRank = rank;
                    }
                }
            }
 
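            // Apply merges in ascending rank order. A dequeued entry whose rank no longer matches the live NextRank
            // for its start position is stale (the pair it described has since changed) and is skipped.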
            while (heap.Count > 0)
            {
                MergeEntry left = heap.Dequeue();
 
                if (left.Rank == int.MaxValue)
                {
                    break;
                }
 
                if (left.Rank != state[left.Start].NextRank)
                {
                    continue;
                }
 
                int leftStart = left.Start;
                int rightStart = state[leftStart].End;
                int rightEnd = state[leftStart].NextEnd;
                int rightNextEnd = state[rightStart].NextEnd;
 
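                // The merged piece now spans [leftStart, rightEnd); cache its rank as the token Id and refresh its
                // candidate merge with the following piece, whose end is rightNextEnd.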
                state[leftStart].CurRank = state[leftStart].NextRank;
                state[leftStart].End = rightEnd;
                PotentialMerge(state, heap, leftStart, rightNextEnd);
 
                if (rightEnd < state.Length)
                {
                    state[rightEnd].Prev = leftStart;
                }
 
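                // The piece before leftStart now has a larger right neighbor, so recompute its candidate merge
                // against the new boundary at rightEnd.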
                if (leftStart > 0)
                {
                    int prevStart = state[leftStart].Prev;
                    PotentialMerge(state, heap, prevStart, rightEnd);
                }
 
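                // rightStart has been absorbed into the merged piece; invalidate its pending candidate so any stale
                // heap entries that still reference it are skipped.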
                state[rightStart].NextRank = int.MaxValue;
            }
 
            // Use ArrayPool for the result buffer to avoid List<T> overhead.
            // The maximum number of tokens is mergingBytes.Length (when no merges occur at all).
            var resultPoolArray = ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Rent(mergingBytes.Length);
            int resultCount = 0;
            int currentIndex = 0;
 
            while (currentIndex < state.Length)
            {
                int startIndex = currentIndex;
                int endIndex = state[currentIndex].End;
 
                int mappedStartIndex = indexMappingSpan[startIndex];
                int mappedEndIndex = indexMappingSpan[endIndex];
 
                int finalEndIndex = endIndex;
 
                // Handle partial characters/elements at token boundaries.
                // If the byte at endIndex-1 maps to the same character as endIndex,
                // extend the token to include the complete character.
                if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
                {
                    finalEndIndex++;
                    while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
                    {
                        finalEndIndex++;
                    }
                }
 
                int tokenId = state[currentIndex].CurRank != int.MaxValue
                    ? state[currentIndex].CurRank
                    : ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)];
 
                resultPoolArray[resultCount++] = (tokenId, mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex);
 
                currentIndex = state[currentIndex].End;
            }
 
            ArrayPool<State>.Shared.Return(statePoolArray);
 
            var result = resultPoolArray.AsSpan(0, resultCount).ToArray();
            ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Return(resultPoolArray);
            return result;
        }
 
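        /// <summary>Returns the slice of <paramref name="memory"/> covering the range [start, end).</summary>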
        private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end) => memory.Slice(start, end - start);
    }
}