// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.ML.Tokenizers
internal struct Word
// Source of randomness for BPE dropout; it is created lazily on first use in MergeAll.
// NOTE(review): this single static instance is shared by every Word, and System.Random instance methods are not
// thread-safe. Confirm that MergeAll is never called with dropout from several threads at once
// (otherwise consider [ThreadStatic]).
private static Random? _random;
// The symbols of this word, in order. Each symbol carries its id (C), the indices of its neighbours
// (Prev/Next, -1 when there is none), and the number of characters it covers (Len).
// MergeAll uses Len == 0 to tag a symbol that has been absorbed into its left neighbour.
private Vec<Symbol> _symbols;
/// <summary>Creates an empty word that contains no symbols.</summary>
public Word()
{
    _symbols = new Vec<Symbol>();
}
/// <summary>
/// Creates an empty word whose symbol storage is pre-sized for <paramref name="capacity"/> symbols.
/// </summary>
/// <param name="capacity">The number of symbols to reserve room for; must not be negative.</param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="capacity"/> is negative.</exception>
public Word(int capacity)
{
    // An int can never exceed int.MaxValue, so the only invalid capacity is a negative one.
    if (capacity < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(capacity));
    }

    _symbols = new Vec<Symbol>(capacity);
}
/// <summary>Factory equivalent of the <c>Word(int)</c> constructor.</summary>
/// <param name="capacity">The number of symbols to reserve room for.</param>
/// <returns>A new, empty word.</returns>
public static Word WithCapacity(int capacity)
{
    Word created = new Word(capacity);
    return created;
}
/// <summary>Gets the number of symbols currently stored in this word.</summary>
public int SymbolsCount
{
    get { return _symbols.Count; }
}
/// <summary>
/// Appends a symbol with id <paramref name="c"/> to the end of the word and links it to the previous last symbol.
/// </summary>
/// <param name="c">The id of the symbol to append.</param>
/// <param name="charLength">The number of characters the symbol covers.</param>
public void Add(int c, int charLength)
{
    int newIndex = _symbols.Count;
    int previousIndex = -1;

    if (newIndex != 0)
    {
        previousIndex = newIndex - 1;

        // The old tail now has a successor: the symbol being appended.
        _symbols[previousIndex].Next = newIndex;
    }

    // A freshly appended symbol is always the last one, so it has no successor (-1).
    _symbols.Push(new Symbol(c, previousIndex, -1, charLength));
}
/// <summary>
/// Scans the word from left to right and replaces each occurrence of the adjacent pair
/// (<paramref name="c1"/>, <paramref name="c2"/>) with a single symbol whose id is <paramref name="replacement"/>.
/// The merged symbol's length is the sum of both parts' lengths.
/// </summary>
/// <param name="c1">Id of the left symbol of the pair.</param>
/// <param name="c2">Id of the right symbol of the pair.</param>
/// <param name="replacement">Id given to the merged symbol.</param>
/// <returns>
/// The neighbouring-pair changes caused by the merges. A pair that disappeared is recorded with -1 and a pair
/// that was created is recorded with +1.
/// </returns>
/// <remarks>
/// Only the merged symbol's own Prev/Next are assigned here; no other symbol's links are updated. Assuming
/// <c>Remove</c> shifts later entries down, those links may be stale after this call.
/// </remarks>
public Vec<(Pair<int>, int)> Merge(int c1, int c2, int replacement)
{
    Vec<(Pair<int>, int)> changes = new();

    // The bound is re-read on every iteration because a merge shrinks the list. After a merge at i,
    // advancing to i + 1 continues with the symbol that followed the pair, so matches never overlap.
    for (int i = 0; i < _symbols.Count; i++)
    {
        // Found a pair
        if (_symbols[i].C == c1 && i + 1 < _symbols.Count && _symbols[i + 1].C == c2)
        {
            Symbol first = _symbols[i];
            Symbol second = _symbols[i + 1];

            // If there are other characters before the pair
            if (i > 0)
            {
                changes.Push((Pair<int>.Create(_symbols[i - 1].C, first.C), -1));
                changes.Push((Pair<int>.Create(_symbols[i - 1].C, replacement), 1));
            }

            // Merge in place: the symbol at i becomes the replacement, and the symbol at i + 1 is removed.
            _symbols[i].C = replacement;
            _symbols[i].Prev = first.Prev;
            _symbols[i].Next = second.Next;
            _symbols[i].Len = first.Len + second.Len;
            _symbols.Remove(i + 1);

            // If there are other characters after the pair
            if (i < _symbols.Count - 1)
            {
                changes.Push((Pair<int>.Create(second.C, _symbols[i + 1].C), -1));
                changes.Push((Pair<int>.Create(replacement, _symbols[i + 1].C), 1));
            }
        }
    }

    return changes;
}
/// <summary>
/// Applies the merges in <paramref name="merges"/> to this word, always taking the next entry from
/// <paramref name="priorityQueue"/>. Afterwards it walks the symbols to filter out those tagged as removed (Len == 0).
/// </summary>
/// <param name="merges">
/// Maps an adjacent (left id, right id) pair to a tuple whose second item is the id of the merged symbol.
/// The first item is passed as the second argument of <c>new Merge(...)</c>. It is presumably the merge rank
/// that orders the queue; confirm against the Merge type.
/// </param>
/// <param name="dropout">
/// When non-null, the dropout branch is taken for a dequeued merge whenever a draw from
/// <see cref="Random.NextDouble"/> is below this value. When null, that branch is never taken.
/// </param>
/// <param name="priorityQueue">
/// Queue of pending merges. It is allocated here when null. Because it is passed by <c>ref</c>, the caller can
/// keep the allocated instance for later calls.
/// </param>
public void MergeAll(Dictionary<Pair<int>, (int, int)> merges, float? dropout, ref PriorityQueue<Merge>? priorityQueue)
priorityQueue ??= new PriorityQueue<Merge>(_symbols.Count);
// Presumably holds merges set aside by dropout (see "Re-insert the skipped elements" below).
Vec<Merge> skip = new Vec<Merge>(priorityQueue.Count);
// Seed the queue with every adjacent pair that has a known merge, keyed by the left symbol's index.
for (int i = 0; i < _symbols.Count - 1; i++)
if (merges.TryGetValue(Pair<int>.Create(_symbols[i].C, _symbols[i + 1].C), out (int m1, int m2) value))
priorityQueue.Enqueue(new Merge(i, value.m1, value.m2));
// Keep applying merges until no queued entry remains.
while (priorityQueue.Count > 0)
Merge top = priorityQueue.Dequeue();
// BPE dropout: the shared static Random is created on first use.
// NOTE(review): in this copy, the statement(s) of this dropout branch, the matching `else`, and the body of the
// re-insert loop below are not visible. Presumably `top` is added to `skip` here, and the loop re-enqueues and
// clears `skip`. Confirm against the authoritative source.
if (dropout.HasValue && (_random ??= new()).NextDouble() < dropout)
// Re-insert the skipped elements
for (int i = 0; i < skip.Count; i++)
// Do nothing if we are the last symbol
// An entry is also ignored when its symbol was already absorbed by an earlier merge (Len == 0).
// NOTE(review): the body of this `if` is not visible here; presumably it moves on to the next queue entry.
if (_symbols.Count == 0 || _symbols[top.Pos].Len == 0 || _symbols[top.Pos].Next == -1)
int nextPos = _symbols[top.Pos].Next;
Symbol right = _symbols[nextPos];
// Make sure we are not processing an expired queue entry
// (the pair at top.Pos may have changed since the entry was queued, so it must still map to top.NewId).
// NOTE(review): the body of this `if` is not visible here; presumably it moves on to the next queue entry.
Pair<int> targetNewPair = Pair<int>.Create(_symbols[top.Pos].C, right.C);
if (!merges.TryGetValue(targetNewPair, out (int m1, int m2) value) || value.m2 != top.NewId)
// Otherwise, let's merge
// MergeWith is defined on Symbol (not shown here); presumably it folds `right` into this symbol under id top.NewId.
_symbols[top.Pos].MergeWith(ref right, top.NewId);
// Tag the right part as removed
_symbols[nextPos].Len = 0;
// Update `prev` on the new `next` to the current pos
if (right.Next > -1 && right.Next < _symbols.Count)
_symbols[right.Next].Prev = top.Pos;
// Insert the new pair formed with the previous symbol
Symbol current = _symbols[top.Pos];
if (current.Prev >= 0)
int prev = current.Prev;
Symbol prevSymbol = _symbols[prev];
Pair<int> newPair = Pair<int>.Create(prevSymbol.C, current.C);
if (merges.TryGetValue(newPair, out value))
priorityQueue.Enqueue(new Merge(current.Prev, value.m1, value.m2));
// Insert the new pair formed with the next symbol
int next = current.Next;
// The uint casts turn the -1 "no next symbol" marker into a huge value, so one comparison checks both bounds.
if ((uint)next < (uint)_symbols.Count)
Symbol nextSymbol = _symbols[next];
Pair<int> newPair = Pair<int>.Create(current.C, nextSymbol.C);
if (merges.TryGetValue(newPair, out value))
priorityQueue.Enqueue(new Merge(top.Pos, value.m1, value.m2));
// Filter out the removed symbols
// Iterating backwards means that removing index i does not shift the symbols not yet visited.
// NOTE(review): the body of this `if` is not visible here; presumably it removes the symbol at index i.
for (int i = _symbols.Count - 1; i >= 0; i--)
if (_symbols[i].Len == 0)
/// <summary>
/// Appends the id of every symbol in this word, in order, to <paramref name="accumulatedIds"/>.
/// </summary>
/// <param name="accumulatedIds">The list that receives the ids.</param>
public void PopulateIds(IList<int> accumulatedIds)
{
    for (int i = 0; i < SymbolsCount; i++)
    {
        // A symbol's C is its vocabulary id (it is the key looked up in the reverse vocabulary by ToTokens).
        accumulatedIds.Add(_symbols[i].C);
    }
}
/// <summary>
/// Appends the ids of at most <paramref name="maxTokens"/> leading symbols to <paramref name="accumulatedIds"/>.
/// </summary>
/// <param name="accumulatedIds">The list that receives the ids.</param>
/// <param name="maxTokens">The maximum number of symbols to take from the start of the word.</param>
/// <param name="charsConsumed">Receives the total character length (sum of Len) of the symbols taken.</param>
/// <returns>The number of symbols taken: the smaller of <see cref="SymbolsCount"/> and <paramref name="maxTokens"/>.</returns>
public int PopulateIdsUpToMax(IList<int> accumulatedIds, int maxTokens, out int charsConsumed)
{
    charsConsumed = 0;
    int count = Math.Min(SymbolsCount, maxTokens);

    for (int i = 0; i < count; i++)
    {
        accumulatedIds.Add(_symbols[i].C);
        charsConsumed += _symbols[i].Len;
    }

    return count;
}
/// <summary>
/// Appends the ids of at most <paramref name="maxTokens"/> trailing symbols, in their original order,
/// to <paramref name="accumulatedIds"/>.
/// </summary>
/// <param name="accumulatedIds">The list that receives the ids.</param>
/// <param name="maxTokens">The maximum number of symbols to take from the end of the word.</param>
/// <param name="fullTextLength">The value <paramref name="textIndex"/> starts from.</param>
/// <param name="textIndex">
/// Receives <paramref name="fullTextLength"/> minus the total character length of the symbols taken
/// (presumably the start offset of the first taken symbol when the word ends the text — confirm with callers).
/// </param>
/// <returns>The number of symbols taken: the smaller of <see cref="SymbolsCount"/> and <paramref name="maxTokens"/>.</returns>
public int PopulateIdsUpToMaxFromEnd(IList<int> accumulatedIds, int maxTokens, int fullTextLength, out int textIndex)
{
    textIndex = fullTextLength;
    int count = Math.Min(SymbolsCount, maxTokens);

    for (int i = SymbolsCount - count; i < SymbolsCount; i++)
    {
        accumulatedIds.Add(_symbols[i].C);
        textIndex -= _symbols[i].Len;
    }

    return count;
}
/// <summary>
/// Counts the symbols that would be taken from the start of the word, at most <paramref name="maxTokens"/>,
/// without producing their ids.
/// </summary>
/// <param name="maxTokens">The maximum number of symbols to take.</param>
/// <param name="charsConsumed">Receives the total character length (sum of Len) of the symbols taken.</param>
/// <returns>The number of symbols taken: the smaller of <see cref="SymbolsCount"/> and <paramref name="maxTokens"/>.</returns>
public int CountIdsUpToMax(int maxTokens, out int charsConsumed)
{
    charsConsumed = 0;
    int taken = maxTokens < SymbolsCount ? maxTokens : SymbolsCount;

    int position = 0;
    while (position < taken)
    {
        charsConsumed += _symbols[position].Len;
        position++;
    }

    return taken;
}
/// <summary>
/// Counts the symbols that would be taken from the end of the word, at most <paramref name="maxTokens"/>,
/// without producing their ids.
/// </summary>
/// <param name="maxTokens">The maximum number of symbols to take.</param>
/// <param name="fullTextLength">The value <paramref name="textIndex"/> starts from.</param>
/// <param name="textIndex">Receives <paramref name="fullTextLength"/> minus the total character length of the symbols taken.</param>
/// <returns>The number of symbols taken: the smaller of <see cref="SymbolsCount"/> and <paramref name="maxTokens"/>.</returns>
public int CountIdsUpToMaxFromEnd(int maxTokens, int fullTextLength, out int textIndex)
{
    textIndex = fullTextLength;
    int total = SymbolsCount;
    int taken = Math.Min(maxTokens, total);

    int position = total - taken;
    while (position < total)
    {
        textIndex -= _symbols[position].Len;
        position++;
    }

    return taken;
}
/// <summary>
/// Returns a new vector that contains the id (C) of every symbol in this word, in order.
/// </summary>
/// <returns>The symbol ids; empty when the word has no symbols.</returns>
public Vec<int> GetChars()
{
    Vec<int> chars = new Vec<int>();

    for (int i = 0; i < _symbols.Count; i++)
    {
        chars.Push(_symbols[i].C);
    }

    return chars;
}
/// <summary>
/// Formats the symbol ids as a bracketed, comma-separated list, e.g. "[1, 2, 3]". An empty word yields "[]".
/// </summary>
/// <returns>The formatted list of symbol ids.</returns>
public override string ToString()
{
    if (_symbols.Count == 0)
    {
        return "[]";
    }

    StringBuilder sb = new StringBuilder();
    sb.Append('[');

    // The first id has no leading separator; every later id is prefixed with ", ".
    sb.Append(_symbols[0].C);
    for (int i = 1; i < _symbols.Count; i++)
    {
        sb.Append($", {_symbols[i].C}");
    }

    sb.Append(']');
    return sb.ToString();
}
/// <summary>
/// Appends one <see cref="EncodedToken"/> per symbol to <paramref name="tokens"/>. Each token carries the symbol's id,
/// its string from <paramref name="vocabReverse"/>, and its character range shifted by <paramref name="offset"/>.
/// </summary>
/// <param name="vocabReverse">Maps symbol ids to token strings; every id in this word must be present.</param>
/// <param name="tokens">The list that receives the tokens.</param>
/// <param name="offset">
/// Added to every range start and end (presumably the position of this word within the original text).
/// </param>
/// <exception cref="KeyNotFoundException">A symbol id is missing from <paramref name="vocabReverse"/>.</exception>
public void ToTokens(SortedDictionary<int, string> vocabReverse, List<EncodedToken> tokens, int offset)
{
    // Symbols are contiguous, so each one starts where the previous one ended.
    int start = offset;

    for (int i = 0; i < SymbolsCount; i++)
    {
        Symbol symbol = _symbols[i];
        int end = start + symbol.Len;
        tokens.Add(new EncodedToken(symbol.C, vocabReverse[symbol.C], new Range(start, end)));
        start = end;
    }
}