// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
/// <summary>
/// Represents the WordPiece tokenizer.
/// </summary>
/// <remarks>
/// The WordPiece tokenizer is a sub-word tokenizer that is used in BERT and other transformer models.
/// The implementation is based on the Hugging Face WordPiece tokenizer https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.WordPiece.
/// </remarks>
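/// <example>
/// A minimal end-to-end sketch; "vocab.txt" is a hypothetical WordPiece vocabulary file (one token per line):
/// <code>
/// WordPieceTokenizer tokenizer = WordPieceTokenizer.Create("vocab.txt");
/// var ids = tokenizer.EncodeToIds("Hello world!");
/// string roundTripped = tokenizer.Decode(ids);
/// </code>
/// </example>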
public partial class WordPieceTokenizer : Tokenizer
{
private readonly PreTokenizer? _preTokenizer;
private readonly Normalizer? _normalizer;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
private readonly Dictionary<int, string> _vocabReverse;
internal WordPieceTokenizer(
Dictionary<StringSpanOrdinalKey, int> vocab,
Dictionary<int, string> vocabReverse,
WordPieceOptions? options)
{
Debug.Assert(vocab is not null);
Debug.Assert(vocabReverse is not null);
_vocab = vocab!;
_vocabReverse = vocabReverse!;
options ??= new();
SpecialTokens = options.SpecialTokens;
SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.GroupBy(kvp => kvp.Value).ToDictionary(g => g.Key, g => g.First().Key) : null;
if (options.UnknownToken is null)
{
throw new ArgumentNullException(nameof(options.UnknownToken));
}
if (options.ContinuingSubwordPrefix is null)
{
throw new ArgumentNullException(nameof(options.ContinuingSubwordPrefix));
}
if (options.MaxInputCharsPerWord <= 0)
{
throw new ArgumentOutOfRangeException(nameof(options.MaxInputCharsPerWord), "The maximum number of characters per word must be greater than zero.");
}
if (!vocab!.TryGetValue(options.UnknownToken, out int id))
{
throw new ArgumentException($"The unknown token '{options.UnknownToken}' is not in the vocabulary.");
}
UnknownToken = options.UnknownToken;
UnknownTokenId = id;
ContinuingSubwordPrefix = options.ContinuingSubwordPrefix;
MaxInputCharsPerWord = options.MaxInputCharsPerWord;
_preTokenizer = options.PreTokenizer ?? PreTokenizer.CreateWhiteSpace(options.SpecialTokens);
_normalizer = options.Normalizer;
}
/// <summary>
/// Gets the unknown token ID.
/// A token that is not in the vocabulary cannot be converted to an ID and is assigned this ID instead.
/// </summary>
public int UnknownTokenId { get; }
/// <summary>
/// Gets the prefix to use for sub-words that are not the first part of a word.
/// </summary>
public string ContinuingSubwordPrefix { get; }
/// <summary>
/// Gets the maximum number of characters allowed in a single word.
/// </summary>
public int MaxInputCharsPerWord { get; }
internal static async ValueTask<(Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, string>)> LoadVocabAsync(Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
Dictionary<StringSpanOrdinalKey, int> vocab = new Dictionary<StringSpanOrdinalKey, int>();
Dictionary<int, string> vocabReverse = new Dictionary<int, string>();
StreamReader reader = new StreamReader(vocabStream);
string? line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
int lineNumber = 0;
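// The vocabulary file lists one token per line; each token's id is its zero-based line
// number. Empty lines are skipped, but their line numbers are still consumed, so ids
// stay aligned with the file's physical lines.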
while (line is not null)
{
if (line.Length != 0)
{
vocab.Add(new StringSpanOrdinalKey(line), lineNumber);
vocabReverse.Add(lineNumber, line);
}
lineNumber++;
line = useAsync ? await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : reader.ReadLine();
}
return (vocab, vocabReverse);
}
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class.
/// </summary>
/// <param name="vocabFilePath">The path to the WordPiece vocab file.</param>
/// <param name="options">The options to use for the WordPiece tokenizer.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <see cref="WordPieceOptions.PreTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// When creating the tokenizer, ensure that the vocabulary file is sourced from a trusted provider.
/// </remarks>
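/// <example>
/// A sketch, assuming a hypothetical local vocabulary file; the option values shown are the conventional BERT-style settings:
/// <code>
/// var options = new WordPieceOptions { UnknownToken = "[UNK]", ContinuingSubwordPrefix = "##" };
/// WordPieceTokenizer tokenizer = WordPieceTokenizer.Create("vocab.txt", options);
/// </code>
/// </example>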
public static WordPieceTokenizer Create(
string vocabFilePath,
WordPieceOptions? options = null) =>
Create(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), options, disposeStream: true);
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class.
/// </summary>
/// <param name="vocabStream">The path to the WordPiece vocab file.</param>
/// <param name="options">The options to use for the WordPiece tokenizer.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <see cref="WordPieceOptions.PreTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
/// </remarks>
public static WordPieceTokenizer Create(
Stream vocabStream,
WordPieceOptions? options = null) =>
Create(vocabStream, options, disposeStream: false);
private static WordPieceTokenizer Create(
Stream vocabStream,
WordPieceOptions? options,
bool disposeStream)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
try
{
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = LoadVocabAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
return new WordPieceTokenizer(vocab, vocabReverse, options);
}
finally
{
if (disposeStream)
{
vocabStream.Dispose();
}
}
}
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
/// </summary>
/// <param name="vocabFilePath">The path to the WordPiece vocab file.</param>
/// <param name="options">The options to use for the WordPiece tokenizer.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <see cref="WordPieceOptions.PreTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// When creating the tokenizer, ensure that the vocabulary file is sourced from a trusted provider.
/// </remarks>
public static async Task<WordPieceTokenizer> CreateAsync(
string vocabFilePath,
WordPieceOptions? options = null,
CancellationToken cancellationToken = default) =>
await CreateAsync(
string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath),
options,
cancellationToken,
disposeStream: true).ConfigureAwait(false);
/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
/// </summary>
/// <param name="vocabStream">The path to the WordPiece vocab file.</param>
/// <param name="options">The options to use for the WordPiece tokenizer.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A new instance of the <see cref="WordPieceTokenizer"/> class.</returns>
/// <remarks>
/// If the <see cref="WordPieceOptions.PreTokenizer"/> is null, the whitespace pre-tokenizer will be used.
/// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
/// </remarks>
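/// <example>
/// A sketch; this overload leaves the caller's stream open, so disposing it remains the caller's responsibility:
/// <code>
/// using Stream vocabStream = File.OpenRead("vocab.txt");
/// WordPieceTokenizer tokenizer = await WordPieceTokenizer.CreateAsync(vocabStream);
/// </code>
/// </example>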
public static async Task<WordPieceTokenizer> CreateAsync(
Stream vocabStream,
WordPieceOptions? options = null,
CancellationToken cancellationToken = default) =>
await CreateAsync(vocabStream, options, cancellationToken, disposeStream: false).ConfigureAwait(false);
private static async Task<WordPieceTokenizer> CreateAsync(
Stream vocabStream,
WordPieceOptions? options,
CancellationToken cancellationToken,
bool disposeStream)
{
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(vocabStream));
}
try
{
(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, string> vocabReverse) = await LoadVocabAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
return new WordPieceTokenizer(vocab, vocabReverse, options);
}
finally
{
if (disposeStream)
{
vocabStream.Dispose();
}
}
}
/// <summary>
/// Gets the PreTokenizer used by the Tokenizer.
/// </summary>
public override PreTokenizer? PreTokenizer => _preTokenizer;
/// <summary>
/// Gets the Normalizer in use by the Tokenizer.
/// </summary>
public override Normalizer? Normalizer => _normalizer;
/// <summary>
/// Gets the unknown token.
/// A token that is not in the vocabulary cannot be converted to an ID and is replaced by this token instead.
/// </summary>
public string UnknownToken { get; }
/// <summary>
/// Gets the special tokens and their corresponding ids.
/// </summary>
public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
/// <summary>
/// Gets the Ids to tokens mapping for special tokens.
/// </summary>
internal IReadOnlyDictionary<int, string>? SpecialTokensReverse { get; }
/// <summary>
/// Encodes input text to a list of <see cref="EncodedToken" />s.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
protected override EncodeResults<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return new EncodeResults<EncodedToken> { NormalizedText = null, Tokens = [], CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
List<EncodedToken> tokens = new();
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
EncodeToTokens(textSpanToEncode.Slice(split.Offset, split.Length), tokens, split.Offset);
}
}
else
{
EncodeToTokens(textSpanToEncode, tokens, 0);
}
return new EncodeResults<EncodedToken> { NormalizedText = normalizedText, Tokens = tokens, CharsConsumed = charsConsumed };
}
/// <summary>
/// Encode text to a list of tokens.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="tokens">The list of tokens to populate.</param>
/// <param name="offset">The offset to start encoding from.</param>
private void EncodeToTokens(ReadOnlySpan<char> text, List<EncodedToken> tokens, int offset)
{
Debug.Assert(!text.IsEmpty);
if (text.Length > MaxInputCharsPerWord)
{
tokens.Add(new EncodedToken(UnknownTokenId, UnknownToken, new Range(offset, offset + text.Length)));
return;
}
int maxLength = MaxInputCharsPerWord + ContinuingSubwordPrefix.Length;
char[]? arrayPool = maxLength <= 250 ? null : ArrayPool<char>.Shared.Rent(maxLength);
Span<char> buffer = arrayPool is null ? stackalloc char[maxLength] : arrayPool;
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
int initialTokensCount = tokens.Count;
int textLength = text.Length;
bool isBad = false;
int start = 0;
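// Greedy longest-match-first: at each position, find the longest vocabulary entry that
// matches the remaining characters (non-initial sub-words are looked up with the
// ContinuingSubwordPrefix prepended). If no prefix matches at some position, the whole
// word is replaced by the unknown token below.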
while (start < textLength)
{
int end = textLength;
EncodedToken curToken = default;
while (start < end)
{
scoped ReadOnlySpan<char> subStr = text.Slice(start, end - start);
if (start > 0)
{
subStr.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
subStr = buffer.Slice(0, ContinuingSubwordPrefix.Length + subStr.Length);
}
if (_vocab.TryGetValue(subStr, out int id))
{
Debug.Assert(_vocabReverse.ContainsKey(id));
curToken = new EncodedToken(id, _vocabReverse[id], new Range(offset + start, offset + end));
break;
}
end -= 1;
}
if (curToken.Value is null)
{
isBad = true;
break;
}
tokens.Add(curToken);
start = end;
}
if (isBad)
{
// remove previously added tokens and add the unknown token
tokens.RemoveRange(initialTokensCount, tokens.Count - initialTokensCount);
tokens.Add(new EncodedToken(UnknownTokenId, UnknownToken, new Range(offset, offset + textLength)));
}
if (arrayPool is not null)
{
ArrayPool<char>.Shared.Return(arrayPool);
}
}
/// <summary>
/// Encodes input text to token Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The encoded results containing the list of encoded Ids.</returns>
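/// <example>
/// A sketch using the public EncodeToIds overloads exposed by the base Tokenizer class:
/// <code>
/// var ids = tokenizer.EncodeToIds("Hello world");
/// </code>
/// </example>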
protected override EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
int maxTokenCount = settings.MaxTokenCount;
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return new EncodeResults<int> { NormalizedText = null, Tokens = [], CharsConsumed = 0 };
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
List<int> ids = new();
if (splits is not null)
{
charsConsumed = 0;
foreach ((int Offset, int Length) split in splits)
{
EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), ids, out int length, maxTokenCount - ids.Count);
if (length == split.Length)
{
// The split was fully encoded; account for its characters even if the token budget is now exhausted.
charsConsumed = split.Offset + length;
}
if (length < split.Length || ids.Count >= maxTokenCount)
{
break;
}
}
}
else
{
EncodeToIds(textSpanToEncode, ids, out charsConsumed);
}
return new EncodeResults<int> { NormalizedText = normalizedText, Tokens = ids, CharsConsumed = charsConsumed };
}
/// <summary>
/// Encode text to a list of Ids.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="accumulatedIds">The list of accumulated Ids.</param>
/// <param name="charsConsumed">The length of the text that encompasses the maximum encoded tokens.</param>
/// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
private int EncodeToIds(ReadOnlySpan<char> text, List<int>? accumulatedIds, out int charsConsumed, int maxTokenCount = int.MaxValue)
{
Debug.Assert(maxTokenCount > 0);
if (text.IsEmpty)
{
charsConsumed = 0;
return 0;
}
if (text.Length > MaxInputCharsPerWord)
{
accumulatedIds?.Add(UnknownTokenId);
charsConsumed = text.Length;
return 1;
}
int maxLength = MaxInputCharsPerWord + ContinuingSubwordPrefix.Length;
char[]? arrayPool = maxLength <= 250 ? null : ArrayPool<char>.Shared.Rent(maxLength);
Span<char> buffer = arrayPool is null ? stackalloc char[maxLength] : arrayPool;
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
int addedIds = 0;
int textLength = text.Length;
bool isBad = false;
int start = 0;
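// Same greedy longest-match-first scan as EncodeToTokens above, accumulating only ids;
// passing a null accumulatedIds turns this into a pure counting pass.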
while (start < textLength)
{
int end = textLength;
int curId = 0;
bool found = false;
while (start < end)
{
scoped ReadOnlySpan<char> subStr = text.Slice(start, end - start);
if (start > 0)
{
subStr.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
subStr = buffer.Slice(0, ContinuingSubwordPrefix.Length + subStr.Length);
}
if (_vocab.TryGetValue(subStr, out curId))
{
found = true;
break;
}
end -= 1;
}
if (!found)
{
isBad = true;
break;
}
accumulatedIds?.Add(curId);
addedIds++;
start = end;
}
charsConsumed = textLength;
if (addedIds > maxTokenCount)
{
// not enough space to hold added ids. Remove previously added ids
accumulatedIds?.RemoveRange(accumulatedIds.Count - addedIds, addedIds);
addedIds = 0;
charsConsumed = 0;
}
else if (isBad)
{
// remove previously added ids and add the unknown token id
accumulatedIds?.RemoveRange(accumulatedIds.Count - addedIds, addedIds);
accumulatedIds?.Add(UnknownTokenId);
addedIds = 1;
}
if (arrayPool is not null)
{
ArrayPool<char>.Shared.Return(arrayPool);
}
return addedIds;
}
/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <returns>The number of token Ids that the input text will be encoded to.</returns>
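/// <example>
/// A sketch; handy for validating input against a model's token budget before encoding:
/// <code>
/// int count = tokenizer.CountTokens("Hello world");
/// </code>
/// </example>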
protected override int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
{
int maxTokenCount = settings.MaxTokenCount;
if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The maximum number of tokens must be greater than zero.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out string? normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out int charsConsumed);
int count = 0;
if (splits is not null)
{
foreach ((int Offset, int Length) split in splits)
{
count += EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), accumulatedIds: null, out int length, maxTokenCount - count);
if (length < split.Length || count >= maxTokenCount)
{
break;
}
}
}
else
{
count = EncodeToIds(textSpanToEncode, accumulatedIds: null, out charsConsumed, maxTokenCount);
}
return count;
}
/// <summary>
/// Find the index of the maximum encoding capacity without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="textSpan">The span of the text to encode which will be used if the <paramref name="text"/> is <see langword="null"/>.</param>
/// <param name="settings">The settings used to encode the text.</param>
/// <param name="fromEnd">Indicate whether to find the index from the end of the text.</param>
/// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramRef name="settings" /> has <see cref="EncodeSettings.ConsiderNormalization"/> is <see langword="false"/>, this will be set to <paramRef name="text" /> in its normalized form; otherwise, this value will be set to <see langword="null"/>.</param>
/// <param name="tokenCount">The token count can be generated which should be smaller than the maximum token count.</param>
/// <returns>
/// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
/// If <paramRef name="fromEnd" /> is <see langword="false"/>, it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
/// if all tokens fit, the result will be length of the input text or the <paramref name="normalizedText"/> if the normalization is enabled.
/// If <paramRef name="fromEnd" /> is <see langword="true"/>, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
/// if all tokens fit, the result will be zero.
/// </returns>
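/// <example>
/// A sketch using the public GetIndexByTokenCount helper exposed by the base Tokenizer class to truncate text to a token budget (the budget shown is illustrative):
/// <code>
/// string text = "some long document ...";
/// int index = tokenizer.GetIndexByTokenCount(text, maxTokenCount: 512, out string? normalized, out int tokenCount);
/// string truncated = (normalized ?? text).Substring(0, index);
/// </code>
/// </example>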
protected override int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedText, out int tokenCount)
{
if (settings.MaxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(settings.MaxTokenCount), "The max token count must be greater than 0.");
}
if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
{
normalizedText = null;
tokenCount = 0;
return 0;
}
IEnumerable<(int Offset, int Length)>? splits = InitializeForEncoding(
text,
textSpan,
settings.ConsiderPreTokenization,
settings.ConsiderNormalization,
_normalizer,
_preTokenizer,
out normalizedText,
out ReadOnlySpan<char> textSpanToEncode,
out _);
int charsConsumed;
if (splits is null)
{
tokenCount = EncodeToIds(textSpanToEncode, accumulatedIds: null, out charsConsumed, settings.MaxTokenCount);
if (charsConsumed != textSpanToEncode.Length)
{
tokenCount = 0;
return fromEnd ? textSpanToEncode.Length : 0;
}
return fromEnd ? 0 : textSpanToEncode.Length;
}
if (fromEnd)
{
splits = splits.Reverse();
}
tokenCount = 0;
foreach ((int Offset, int Length) split in splits)
{
int count = EncodeToIds(textSpanToEncode.Slice(split.Offset, split.Length), accumulatedIds: null, out charsConsumed, settings.MaxTokenCount - tokenCount);
if (charsConsumed != split.Length)
{
return fromEnd ? split.Offset + split.Length : split.Offset;
}
tokenCount += count;
if (tokenCount >= settings.MaxTokenCount)
{
return fromEnd ? split.Offset : split.Offset + split.Length;
}
}
return fromEnd ? 0 : textSpanToEncode.Length;
}
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <returns>The decoded string.</returns>
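/// <example>
/// A sketch; sub-words carrying the <see cref="ContinuingSubwordPrefix"/> are rejoined to the preceding token:
/// <code>
/// var ids = tokenizer.EncodeToIds("unaffable");
/// string text = tokenizer.Decode(ids); // with a typical BERT vocabulary: "unaffable"
/// </code>
/// </example>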
public override string Decode(IEnumerable<int> ids) => Decode(ids, skipSpecialTokens: false);
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
/// <returns>The decoded string.</returns>
public string Decode(IEnumerable<int> ids, bool skipSpecialTokens)
{
ValueStringBuilder sb = new ValueStringBuilder();
bool first = true;
bool ignoreSpecialTokens = skipSpecialTokens && SpecialTokensReverse is not null;
foreach (int id in ids)
{
if (ignoreSpecialTokens && SpecialTokensReverse!.TryGetValue(id, out _))
{
continue;
}
if (_vocabReverse.TryGetValue(id, out string? token))
{
if (token.StartsWith(ContinuingSubwordPrefix, StringComparison.Ordinal))
{
sb.Append(token.AsSpan().Slice(ContinuingSubwordPrefix.Length));
}
else
{
if (!first && token[0] is not ('.' or ',' or '!' or '?' or '\''))
{
sb.Append(' ');
}
sb.Append(token);
}
}
first = false;
}
return sb.ToString();
}
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten) =>
Decode(ids, destination, skipSpecialTokens: false, out idsConsumed, out charsWritten);
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="destination">The span to store the decoded text.</param>
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
/// <param name="idsConsumed">The number of ids consumed during the decoding.</param>
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
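/// <example>
/// A sketch of allocation-free decoding into a caller-owned buffer:
/// <code>
/// var ids = tokenizer.EncodeToIds("Hello world");
/// Span&lt;char&gt; buffer = stackalloc char[256];
/// OperationStatus status = tokenizer.Decode(ids, buffer, skipSpecialTokens: true, out int idsConsumed, out int charsWritten);
/// if (status == OperationStatus.DestinationTooSmall)
/// {
///     // Grow the buffer and resume with the ids that were not yet consumed.
/// }
/// </code>
/// </example>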
public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool skipSpecialTokens, out int idsConsumed, out int charsWritten)
{
charsWritten = 0;
idsConsumed = 0;
Span<char> buffer = destination;
bool first = true;
bool ignoreSpecialTokens = SpecialTokensReverse is not null && skipSpecialTokens;
foreach (int id in ids)
{
if (ignoreSpecialTokens && SpecialTokensReverse!.TryGetValue(id, out _))
{
continue;
}
if (_vocabReverse.TryGetValue(id, out string? token))
{
if (token.StartsWith(ContinuingSubwordPrefix, StringComparison.Ordinal))
{
if (token.Length - ContinuingSubwordPrefix.Length > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
token.AsSpan().Slice(ContinuingSubwordPrefix.Length).CopyTo(buffer);
buffer = buffer.Slice(token.Length - ContinuingSubwordPrefix.Length);
charsWritten += token.Length - ContinuingSubwordPrefix.Length;
}
else
{
if (!first)
{
if (token.Length + 1 > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
buffer[0] = ' ';
token.AsSpan().CopyTo(buffer.Slice(1));
buffer = buffer.Slice(token.Length + 1);
charsWritten += token.Length + 1;
}
else
{
if (token.Length > buffer.Length)
{
return OperationStatus.DestinationTooSmall;
}
token.AsSpan().CopyTo(buffer);
buffer = buffer.Slice(token.Length);
charsWritten += token.Length;
}
}
first = false;
idsConsumed++;
}
else
{
return OperationStatus.InvalidData;
}
}
return OperationStatus.Done;
}
}
}