// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Represents a [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
/// </summary>
public sealed class BpeTokenizer : Tokenizer
{
private const int MaxWordLengthToCache = 15;
private string? _unknownToken;
private int? _unknownTokenId;
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
private readonly Dictionary<StringSpanOrdinalKey, (int, string)>? _specialTokens;
private readonly Dictionary<int, string>? _specialTokensReverse;
/// <summary>
/// Gets the special tokens.
/// </summary>
public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
/// <summary>
/// Gets the unknown token. The unknown token is used when an unknown character or subword is encountered.
/// </summary>
public string? UnknownToken
{
get
{
return _unknownToken;
}
private set
{
if (value is null)
{
_unknownToken = value;
_unknownTokenId = null;
return;
}
if (!_vocab.TryGetValue(value, out int id))
{
throw new InvalidOperationException($"Unknown Token '{value}' was not present in '{nameof(Vocabulary)}'.");
}
_unknownTokenId = id;
_unknownToken = value;
}
}
/// <summary>
/// Gets the prefix attached to every subword that is not at the beginning of a word.
/// </summary>
public string? ContinuingSubwordPrefix { get; }
/// <summary>
/// Gets the optional suffix used to characterize an end-of-word subword.
/// </summary>
public string? EndOfWordSuffix { get; }
/// <summary>
/// Gets a value indicating whether multiple consecutive unknown tokens are fused into a single token.
/// </summary>
public bool FuseUnknownTokens { get; }
/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
/// </summary>
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
/// <remarks>
/// When creating the tokenizer, ensure that the vocabulary file is sourced from a trusted provider.
/// </remarks>
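/// <example>
/// A minimal usage sketch; the file names below are illustrative placeholders, and the
/// EncodeToIds/Decode calls assume the public <see cref="Tokenizer"/> base surface:
/// <code>
/// // Load the model files (sourced from a trusted provider) and encode some text.
/// BpeTokenizer tokenizer = BpeTokenizer.Create("vocab.json", "merges.txt");
/// var ids = tokenizer.EncodeToIds("Hello world!");
/// string roundTripped = tokenizer.Decode(ids);
/// </code>
/// </example>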
public static BpeTokenizer Create(string vocabFile, string? mergesFile)
=> Create(vocabFile, mergesFile, preTokenizer: PreTokenizer.CreateWordOrNonWord(), normalizer: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
/// </summary>
/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
/// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
/// <remarks>
/// When creating the tokenizer, ensure that the vocabulary file is sourced from a trusted provider.
/// </remarks>
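/// <example>
/// A sketch of the optional configuration; all paths and token strings are illustrative,
/// and the unknown token must be present in the vocabulary file:
/// <code>
/// var specialTokens = new Dictionary&lt;string, int&gt; { ["[CLS]"] = 101, ["[SEP]"] = 102 };
/// BpeTokenizer tokenizer = BpeTokenizer.Create(
/// "vocab.json",
/// "merges.txt",
/// specialTokens: specialTokens,
/// unknownToken: "[UNK]",
/// fuseUnknownTokens: true);
/// </code>
/// </example>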
public static BpeTokenizer Create(
string vocabFile,
string? mergesFile,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
string? unknownToken = null,
string? continuingSubwordPrefix = null,
string? endOfWordSuffix = null,
bool fuseUnknownTokens = false)
{
if (vocabFile is null)
{
throw new ArgumentNullException(nameof(vocabFile));
}
using Stream vocabStream = File.OpenRead(vocabFile);
using Stream? mergesStream = mergesFile is null ? null : File.OpenRead(mergesFile);
(Dictionary<StringSpanOrdinalKey, int>? vocab, Vec<(string, string)> merges) result = ReadModelDataAsync(vocabStream, mergesStream, useAsync: false).GetAwaiter().GetResult();
return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}
/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
/// </summary>
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
/// <remarks>
/// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
/// </remarks>
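/// <example>
/// A sketch using streams; the paths are illustrative, and since the tokenizer does not take
/// ownership of the streams, the caller disposes them:
/// <code>
/// using Stream vocabStream = File.OpenRead("vocab.json");
/// using Stream mergesStream = File.OpenRead("merges.txt");
/// BpeTokenizer tokenizer = BpeTokenizer.Create(vocabStream, mergesStream);
/// </code>
/// </example>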
public static BpeTokenizer Create(Stream vocabStream, Stream? mergesStream)
=> Create(vocabStream, mergesStream, preTokenizer: PreTokenizer.CreateWordOrNonWord(), normalizer: null, specialTokens: null, unknownToken: null, continuingSubwordPrefix: null, endOfWordSuffix: null, fuseUnknownTokens: false);
/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
/// </summary>
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
/// <remarks>
/// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
/// </remarks>
public static BpeTokenizer Create(
Stream vocabStream,
Stream? mergesStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
string? unknownToken = null,
string? continuingSubwordPrefix = null,
string? endOfWordSuffix = null,
bool fuseUnknownTokens = false)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<StringSpanOrdinalKey, int>? vocab, Vec<(string, string)> merges) result = ReadModelDataAsync(vocabStream, mergesStream, useAsync: false).GetAwaiter().GetResult();
return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}
/// <summary>
/// Create a new Bpe tokenizer object asynchronously to use for text encoding.
/// </summary>
/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
/// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
/// <remarks>
/// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
/// </remarks>
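/// <example>
/// A sketch of asynchronous creation; the paths are illustrative:
/// <code>
/// using Stream vocabStream = File.OpenRead("vocab.json");
/// using Stream mergesStream = File.OpenRead("merges.txt");
/// BpeTokenizer tokenizer = await BpeTokenizer.CreateAsync(vocabStream, mergesStream);
/// </code>
/// </example>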
public static async Task<BpeTokenizer> CreateAsync(
Stream vocabStream,
Stream? mergesStream,
PreTokenizer? preTokenizer = null,
Normalizer? normalizer = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
string? unknownToken = null,
string? continuingSubwordPrefix = null,
string? endOfWordSuffix = null,
bool fuseUnknownTokens = false)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
(Dictionary<StringSpanOrdinalKey, int>? vocab, Vec<(string, string)> merges) result = await ReadModelDataAsync(vocabStream, mergesStream, useAsync: true).ConfigureAwait(false);
return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}
/// <summary>
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
/// <param name="vocab">The dictionary vocabulary mapping token string to ids.</param>
/// <param name="merges">The pairs list help in merging tokens during the encoding process.</param>
/// <param name="preTokenizer">The pre-tokenizer to use.</param>
/// <param name="normalizer">The normalizer to use.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="unknownToken"> The unknown token to be used by the model.</param>
/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
private BpeTokenizer(
Dictionary<StringSpanOrdinalKey, int>? vocab,
Vec<(string, string)> merges,
PreTokenizer? preTokenizer,
Normalizer? normalizer,
IReadOnlyDictionary<string, int>? specialTokens,
string? unknownToken,
string? continuingSubwordPrefix,
string? endOfWordSuffix,
bool fuseUnknownTokens)
{
FuseUnknownTokens = fuseUnknownTokens;
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
_preTokenizer = preTokenizer ?? PreTokenizer.CreateWordOrNonWord(); // Default to WordOrNonWord pre-tokenizer
_normalizer = normalizer;
_vocab = vocab ?? new Dictionary<StringSpanOrdinalKey, int>();
Cache = new StringSpanOrdinalKeyCache<Word>();
VocabReverse = new();
foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
{
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
}
if (specialTokens is not null)
{
SpecialTokens = specialTokens;
_specialTokens = specialTokens.ToDictionary(kvp => new StringSpanOrdinalKey(kvp.Key), kvp => (kvp.Value, kvp.Key));
_specialTokensReverse = specialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
}
UnknownToken = unknownToken;
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
Merges = new();
for (int i = 0; i < merges.Count; i++)
{
(string a, string b) mergeValues = merges[i];
if (!_vocab.TryGetValue(mergeValues.a, out int aId))
{
throw new InvalidOperationException($"Trying to merge a token '{mergeValues.a}' which not exist in the vocabulary.");
}
if (!_vocab.TryGetValue(mergeValues.b, out int bId))
{
throw new InvalidOperationException($"Trying to merge a token '{mergeValues.b}' which not exist in the vocabulary.");
}
if (mergeValues.b.Length <= prefixLen)
{
throw new InvalidOperationException($"The merge value '{mergeValues.b}' is too short to be merged with a prefix of length {prefixLen}. This implies that the merge file is either damaged or missing the prefix in its entries.");
}
string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}";
if (!_vocab.TryGetValue(newToken, out int newId))
{
throw new InvalidOperationException($"Trying to merge a token '{newToken}' which not exist in the vocabulary.");
}
Merges.Add(new Pair<int>(aId, bId), (i, newId));
}
}
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
public override PreTokenizer? PreTokenizer => _preTokenizer;
/// <summary>
/// Gets the Normalizer in use by the Tokenizer.
/// </summary>
public override Normalizer? Normalizer => _normalizer;
/// <summary>
/// Encodes input text to a list of <see cref="EncodedToken" />s.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
protected override EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return new EncodeResults<EncodedToken> { Tokens = [], NormalizedText = null, CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
List<EncodedToken> tokens = new();
PriorityQueue<Merge>? priorityQueue = null;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
EncodeWithCache(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset, ref priorityQueue);
}
}
else
{
EncodeWithCache(textSpanToEncode, tokens, 0, ref priorityQueue);
}
return new EncodeResults<EncodedToken> { Tokens = tokens, NormalizedText = normalizedText, CharsConsumed = charsConsumed };
}
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The encoded results containing the list of encoded Ids.</returns>
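/// <example>
/// A sketch through the public base surface, assuming an overload that caps the number of
/// produced ids; 'tokenizer' is a previously created instance:
/// <code>
/// var ids = tokenizer.EncodeToIds("Hello world", maxTokenCount: 5, out string? normalizedText, out int charsConsumed);
/// </code>
/// </example>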
protected override EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
int maxTokenCount = settings.MaxTokenCount;
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return new EncodeResults<int> { Tokens = [], NormalizedText = null, CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out _);
List<int> ids = new();
PriorityQueue<Merge>? priorityQueue = null;
int charsConsumed = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), ids, maxTokenCount - ids.Count, out int length, ref priorityQueue);
charsConsumed = split.Offset + length;
if (length < split.Length || ids.Count >= maxTokenCount)
{
break;
}
}
}
else
{
EncodeToIdsWithCache(textSpanToEncode, ids, maxTokenCount, out charsConsumed, ref priorityQueue);
}
return new EncodeResults<int> { Tokens = ids, NormalizedText = normalizedText, CharsConsumed = charsConsumed };
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
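/// <example>
/// A sketch through the assumed public base method:
/// <code>
/// int tokenCount = tokenizer.CountTokens("Hello world");
/// </code>
/// </example>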
protected override int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
int maxTokenCount = settings.MaxTokenCount;
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out _);
PriorityQueue<Merge>? priorityQueue = null;
int count = 0;
int textLength = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - count, out int length, ref priorityQueue);
textLength = split.Offset + length;
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count = EncodeToIdsWithCache(textSpanToEncode, null, maxTokenCount, out textLength, ref priorityQueue);
}
return count;
}
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <param name="fromEnd">Indicate whether to find the index from the end of the text.</param>
/// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="settings" /> has <see cref="EncodeSettings.ConsiderNormalization"/> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// If <paramRef name="fromEnd" /> is <see langword="false"/>, it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the input text or the <paramref name="normalizedText"/> if the normalization is enabled.
/// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
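/// <example>
/// A sketch of truncating text to a token budget, assuming the public base method
/// GetIndexByTokenCount with this shape; 'tokenizer' is a previously created instance:
/// <code>
/// string input = "some long text to truncate";
/// int index = tokenizer.GetIndexByTokenCount(input, maxTokenCount: 10, out string? normalized, out int tokenCount);
/// string prefixThatFits = (normalized ?? input).Substring(0, index);
/// </code>
/// </example>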
protected override int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedText, out int tokenCount)
{
if (fromEnd)
{
return LastIndexOf(text, textSpan, settings.MaxTokenCount, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out normalizedText, out tokenCount);
}
tokenCount = CountTokens(text, textSpan, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out normalizedText, out int charsConsumed, settings.MaxTokenCount);
return charsConsumed;
}
private int CountTokens(string? text, ReadOnlySpan<char> textSpan, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
}
charsConsumed = 0;
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedText = null;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
considerPreTokenization,
considerNormalization,
_normalizer,
_preTokenizer,
out normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out _);
PriorityQueue<Merge>? priorityQueue = null;
int count = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIdsWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - count, out int length, ref priorityQueue);
charsConsumed = split.Offset + length;
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count = EncodeToIdsWithCache(textSpanToEncode, null, maxTokenCount, out charsConsumed, ref priorityQueue);
}
return count;
}
private int LastIndexOf(string? text, ReadOnlySpan<char> textSpan, int maxTokenCount, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int tokenCount)
{
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedText = null;
tokenCount = 0;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
considerPreTokenization,
considerNormalization,
_normalizer,
_preTokenizer,
out normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out _);
PriorityQueue<Merge>? priorityQueue = null;
if (splits is not null)
{
tokenCount = 0;
foreach ((int Offset, int Length) split in splits.Reverse())
{
tokenCount += EncodeToIdsFromEndWithCache(textSpanToEncode.Slice(split.Offset, split.Length), null, maxTokenCount - tokenCount, out int textIndex, ref priorityQueue);
if (textIndex > 0 || tokenCount >= maxTokenCount)
{
return split.Offset + textIndex;
}
}
}
else
{
tokenCount = EncodeToIdsFromEndWithCache(textSpanToEncode, null, maxTokenCount, out int charsConsumed, ref priorityQueue);
return charsConsumed;
}
return 0;
}
/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <returns>The mapped Id of the token.</returns>
private int? MapTokenToId(ReadOnlySpan<char> token) => _vocab.TryGetValue(token, out int value) ? value : null;
/// <summary>
/// Map the encoded Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <returns>The mapped token of the Id.</returns>
private string? MapIdToToken(int id)
{
if (VocabReverse.TryGetValue(id, out string? value))
{
return value;
}
return null;
}
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocabulary => _vocabOriginal ??= new ReadOnlyDictionary<string, int>(_vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value));
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
public override string Decode(IEnumerable<int> ids) => Decode(ids, considerSpecialTokens: true);
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider special tokens or not.</param>
/// <returns>The decoded string.</returns>
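/// <example>
/// A sketch; the ids would normally come from a prior encoding call on the same instance:
/// <code>
/// var ids = tokenizer.EncodeToIds("Hello world");
/// string withUnknowns = tokenizer.Decode(ids, considerSpecialTokens: true);
/// string withoutUnknowns = tokenizer.Decode(ids, considerSpecialTokens: false); // skips unknown-token ids
/// </code>
/// </example>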
public string Decode(IEnumerable<int> ids, bool considerSpecialTokens)
{
if (ids is null)
{
throw new ArgumentNullException(nameof(ids));
}
ValueStringBuilder sb = new ValueStringBuilder();
bool decodeUnknownToken = _unknownTokenId.HasValue && considerSpecialTokens;
if (decodeUnknownToken)
{
foreach (int id in ids)
{
if (MapIdToToken(id) is string s)
{
sb.Append(s);
}
}
}
else
{
foreach (int id in ids)
{
if (id == _unknownTokenId)
{
continue;
}
if (MapIdToToken(id) is string s)
{
sb.Append(s);
}
}
}
if (EndOfWordSuffix is not null)
{
sb.RemoveSuffix(EndOfWordSuffix);
sb.Replace(EndOfWordSuffix, " ");
}
if (ContinuingSubwordPrefix is not null)
{
sb.Replace(ContinuingSubwordPrefix, string.Empty);
}
return sb.ToString();
}
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten)
=> Decode(ids, destination, considerSpecialTokens: true, out idsConsumed, out charsWritten);
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider special tokens or not.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
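/// <example>
/// A sketch of decoding into a caller-owned buffer, growing it when the destination is too
/// small; 'tokenizer' is a previously created instance:
/// <code>
/// var ids = tokenizer.EncodeToIds("Hello world");
/// char[] buffer = new char[64];
/// OperationStatus status = tokenizer.Decode(ids, buffer, considerSpecialTokens: true, out int idsConsumed, out int charsWritten);
/// while (status == OperationStatus.DestinationTooSmall)
/// {
/// // Retry from scratch with a larger destination buffer.
/// buffer = new char[buffer.Length * 2];
/// status = tokenizer.Decode(ids, buffer, considerSpecialTokens: true, out idsConsumed, out charsWritten);
/// }
/// string text = new string(buffer, 0, charsWritten);
/// </code>
/// </example>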
public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool considerSpecialTokens, out int idsConsumed, out int charsWritten)
{
if (ids is null)
{
throw new ArgumentNullException(nameof(ids));
}
idsConsumed = 0;
charsWritten = 0;
Span<char> buffer = destination;
bool skipUnknownToken = !_unknownTokenId.HasValue || !considerSpecialTokens;
bool addSpace = false;
bool continuingSubwordPrefix = ContinuingSubwordPrefix is not null && ContinuingSubwordPrefix.Length > 0;
bool endOfWordSuffix = EndOfWordSuffix is not null && EndOfWordSuffix.Length > 0;
int previousCharsWritten = 0;
int previousIdsConsumed = 0;
foreach (int id in ids)
{
if (skipUnknownToken && id == _unknownTokenId)
{
idsConsumed++;
continue;
}
if (addSpace)
{
if (buffer.Length == 0)
{
charsWritten = previousCharsWritten;
return OperationStatus.DestinationTooSmall;
}
buffer[0] = ' ';
buffer = buffer.Slice(1);
charsWritten++;
}
if (MapIdToToken(id) is string s)
{
ReadOnlySpan<char> sSpan = s.AsSpan();
if (continuingSubwordPrefix && sSpan.StartsWith(ContinuingSubwordPrefix.AsSpan(), StringComparison.Ordinal))
{
sSpan = sSpan.Slice(ContinuingSubwordPrefix!.Length);
}
addSpace = false;
if (endOfWordSuffix && sSpan.EndsWith(EndOfWordSuffix!.AsSpan(), StringComparison.Ordinal))
{
sSpan = sSpan.Slice(0, sSpan.Length - EndOfWordSuffix!.Length);
addSpace = true;
previousIdsConsumed = idsConsumed;
previousCharsWritten = charsWritten;
}
if (sSpan.Length > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
sSpan.CopyTo(buffer);
charsWritten += sSpan.Length;
buffer = buffer.Slice(sSpan.Length);
}
idsConsumed++;
}
return OperationStatus.Done;
}
/// Read the given files to extract the vocab and merges
internal static async ValueTask<(Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>)> ReadModelDataAsync(Stream vocab, Stream? merges, bool useAsync, CancellationToken cancellationToken = default)
{
Dictionary<StringSpanOrdinalKey, int>? dic = useAsync
? await JsonSerializer.DeserializeAsync(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32, cancellationToken).ConfigureAwait(false)
: JsonSerializer.Deserialize(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32);
var m = useAsync ?
await ConvertMergesToHashmapAsync(merges, useAsync, cancellationToken).ConfigureAwait(false) :
ConvertMergesToHashmapAsync(merges, useAsync).GetAwaiter().GetResult();
return (dic, m);
}
/// The vocabulary assigns a number to each token.
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private IReadOnlyDictionary<string, int>? _vocabOriginal;
/// Contains the mapping between Pairs and their (rank, newId).
internal Dictionary<Pair<int>, (int, int)> Merges { get; }
/// Contains the cache for optimizing the encoding step.
internal StringSpanOrdinalKeyCache<Word>? Cache { get; }
internal static readonly int DefaultCacheCapacity = 10_000;
/// Reversed vocabulary, to rebuild the text.
internal SortedDictionary<int, string> VocabReverse { get; }
/// Dropout probability for merges. The default, 0, applies no dropout. At 1.0, tokenization
/// performs no merges, so the result is just the individual characters.
internal float? Dropout { get; }
/// Converts the merge strings (for example, from a `merges.txt` file) with the format
/// "{pair_a} {pair_b}" into the format expected by the BPE struct.
internal static async ValueTask<Vec<(string, string)>> ConvertMergesToHashmapAsync(Stream? mergesStream, bool useAsync = false, CancellationToken cancellationToken = default)
{
if (mergesStream is null)
{
return new Vec<(string, string)>();
}
// Don't dispose the reader as it will dispose the underlying stream mergesStream. The caller is responsible for disposing the stream.
StreamReader reader = new StreamReader(mergesStream);
Vec<(string, string)> merges = new(1000);
int lineNumber = 0;
while (true)
{
string? line = useAsync ?
await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
reader.ReadLine();
if (line is null)
{
break;
}
lineNumber++;
if (line.StartsWith("#version", StringComparison.Ordinal) || line.Length == 0)
{
continue;
}
int index = line.IndexOf(' ');
if (index < 0 || index == line.Length - 1 || line.IndexOf(' ', index + 1) >= 0)
{
throw new InvalidOperationException($"Invalid merger file format at line: {lineNumber}");
}
merges.Push((line.Substring(0, index), line.Substring(index + 1)));
}
return merges;
}
private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal string CharToString(char c)
{
if (_charToString.TryGetValue(c, out string? v))
{
return v;
}
string s = c.ToString();
_charToString[c] = s;
return s;
}
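// Worked illustration (hypothetical vocabulary and merges): with vocab entries "h", "u", "g",
// "hu", and "hug", and the ordered merges ("h","u") then ("hu","g"), MergeWord("hug") starts
// from the single-character symbols [h, u, g], applies the best-ranked merge to get [hu, g],
// and merges again to produce [hug]. Characters missing from the vocabulary map to the
// unknown token, and consecutive unknowns are fused into one entry when FuseUnknownTokens is true.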
internal Word MergeWord(ReadOnlySpan<char> w, ref PriorityQueue<Merge>? priorityQueue)
{
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
int i = 0;
Span<char> buffer = stackalloc char[256];
scoped ReadOnlySpan<char> s;
while (i < w.Length)
{
int length;
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
{
length = 2;
s = w.Slice(i, 2);
}
else
{
length = 1;
s = w.Slice(i, 1);
}
// Add the `continuing_subword_prefix` if relevant
if (i > 0 && ContinuingSubwordPrefix is not null)
{
if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
{
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
}
else
{
#if NETCOREAPP
s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
#endif
}
}
// Add the `end_of_word_suffix` if relevant
if (i + length >= w.Length && EndOfWordSuffix is not null)
{
if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
{
s.CopyTo(buffer);
EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
}
else
{
#if NETCOREAPP
s = $"{s}{EndOfWordSuffix}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{s1}{EndOfWordSuffix}".AsSpan();
#endif
}
}
if (_vocab.TryGetValue(s, out int id))
{
if (unk.HasValue)
{
word.Add(unk.Value.Id, unk.Value.Len);
unk = null;
}
word.Add(id, length);
}
else if (UnknownToken is not null)
{
if (unk.HasValue)
{
if (FuseUnknownTokens)
{
// Fuse unk
unk = (unk.Value.Id, unk.Value.Len + length);
}
else
{
// Do not fuse unk, add the previous one
word.Add(unk.Value.Id, unk.Value.Len);
if (!_vocab.TryGetValue(UnknownToken, out int value))
{
throw new InvalidOperationException($"Unknown Token Out Of Vocabulary.");
}
unk = (value, length);
}
}
else
{
if (!_vocab.TryGetValue(UnknownToken, out int value))
{
throw new InvalidOperationException($"Unknown Token Out Of Vocabulary.");
}
unk = (value, length);
}
}
i += length;
}
if (unk.HasValue)
{
word.Add(unk.Value.Id, unk.Value.Len);
}
word.MergeAll(Merges, Dropout, ref priorityQueue);
return word;
}
internal void WordToTokens(ref Word word, List<EncodedToken> tokens, int offset) => word.ToTokens(VocabReverse, tokens, offset);
internal void EncodeWithCache(ReadOnlySpan<char> text, List<EncodedToken> tokens, int offset, ref PriorityQueue<Merge>? priorityQueue)
{
if (_specialTokens?.TryGetValue(text, out (int specialTokenId, string specialToken) value) is true)
{
tokens.Add(new EncodedToken(value.specialTokenId, value.specialToken, new Range(offset, offset + text.Length)));
return;
}
Word word;
if (Cache is not null)
{
if (Cache.TryGetValue(text, out word))
{
WordToTokens(ref word, tokens, offset);
return;
}
word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text, ref priorityQueue);
}
WordToTokens(ref word, tokens, offset);
}
internal int WordToIds(ref Word word, IList<int>? accumulatedIds, out int charsConsumed, int fullTextLength, int maxTokens)
{
if (word.SymbolsCount <= maxTokens)
{
charsConsumed = fullTextLength;
if (accumulatedIds is not null)
{
word.PopulateIds(accumulatedIds);
}
return word.SymbolsCount;
}
if (accumulatedIds is not null)
{
return word.PopulateIdsUpToMax(accumulatedIds, maxTokens, out charsConsumed);
}
return word.CountIdsUpToMax(maxTokens, out charsConsumed);
}
internal int WordToIdsFromEnd(ref Word word, IList<int>? accumulatedIds, out int textIndex, int fullTextLength, int maxTokens)
{
if (word.SymbolsCount <= maxTokens)
{
textIndex = 0;
if (accumulatedIds is not null)
{
word.PopulateIds(accumulatedIds);
}
return word.SymbolsCount;
}
if (accumulatedIds is not null)
{
return word.PopulateIdsUpToMaxFromEnd(accumulatedIds, maxTokens, fullTextLength, out textIndex);
}
return word.CountIdsUpToMaxFromEnd(maxTokens, fullTextLength, out textIndex);
}
private int EncodeToIdsWithCache(ReadOnlySpan<char> text, List<int>? accumulatedIds, int maxTokens, out int charsConsumed, ref PriorityQueue<Merge>? priorityQueue)
{
if (_specialTokens?.TryGetValue(text, out (int specialTokenId, string specialToken) value) is true && maxTokens > 0)
{
accumulatedIds?.Add(value.specialTokenId);
charsConsumed = text.Length;
return 1;
}
Word word;
if (Cache is not null)
{
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds, out charsConsumed, text.Length, maxTokens);
}
word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text, ref priorityQueue);
}
return WordToIds(ref word, accumulatedIds, out charsConsumed, text.Length, maxTokens);
}
internal int EncodeToIdsFromEndWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds, int maxTokens, out int textIndex, ref PriorityQueue<Merge>? priorityQueue)
{
Word word;
if (_specialTokens?.TryGetValue(text, out (int specialTokenId, string specialToken) value) is true && maxTokens > 0)
{
accumulatedIds?.Add(value.specialTokenId);
textIndex = 0;
return 1;
}
if (Cache is not null)
{
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIdsFromEnd(ref hit, accumulatedIds, out textIndex, text.Length, maxTokens);
}
word = MergeWord(text, ref priorityQueue);
if (text.Length <= MaxWordLengthToCache)
{
Cache.Set(text.ToString(), word);
}
}
else
{
word = MergeWord(text, ref priorityQueue);
}
return WordToIdsFromEnd(ref word, accumulatedIds, out textIndex, text.Length, maxTokens);
}
}
}