// File: WordPieceTests.cs
// Web Access
// Project: src\test\Microsoft.ML.Tokenizers.Tests\Microsoft.ML.Tokenizers.Tests.csproj (Microsoft.ML.Tokenizers.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Xunit;
 
namespace Microsoft.ML.Tokenizers.Tests
{
    /// <summary>
    /// Tests for <see cref="WordPieceTokenizer"/> that run against a small vocabulary
    /// written to a temporary file.
    /// </summary>
    public class WordPieceTests
    {
        // Vocabulary shared by all tests. The assertions below expect a token's id to be its
        // position in this array (e.g. "[UNK]" -> 0, "un" -> 7, "##ing" -> 9).
        // readonly: the tests share this instance, so none of them may swap it out.
        private static readonly string[] _vocabTokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"];

        /// <summary>
        /// Writes the given tokens, one per line, to a newly created temporary file.
        /// </summary>
        /// <param name="vocabTokens">The vocabulary entries to write, in order.</param>
        /// <returns>The path of the temporary file. The caller is responsible for deleting it.</returns>
        internal static string CreateVocabFile(string[] vocabTokens)
        {
            string vocabFile = Path.GetTempFileName();
            File.WriteAllLines(vocabFile, vocabTokens);
            return vocabFile;
        }

        /// <summary>
        /// Checks the properties of tokenizers created from a vocabulary file path and from a stream.
        /// </summary>
        [Fact]
        public void TestCreation()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                // The using declaration is scoped to the try block, so the stream is disposed
                // before the finally block deletes the file.
                using Stream vocabStream = File.OpenRead(vocabFile);
                WordPieceTokenizer[] wordPieceTokenizers = [WordPieceTokenizer.Create(vocabFile), WordPieceTokenizer.Create(vocabStream)];

                // Both creation paths are expected to produce the same configuration.
                foreach (WordPieceTokenizer tokenizer in wordPieceTokenizers)
                {
                    Assert.NotNull(tokenizer.PreTokenizer);
                    Assert.Equal("[UNK]", tokenizer.UnknownToken);
                    Assert.Equal(0, tokenizer.UnknownTokenId);
                    Assert.Null(tokenizer.Normalizer);
                    Assert.Equal(100, tokenizer.MaxInputCharsPerWord);
                    Assert.Equal("##", tokenizer.ContinuingSubwordPrefix);
                }
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }

        /// <summary>
        /// Exercises encoding, decoding, token counting and index-by-token-count on empty input
        /// and on text that is fully covered by the vocabulary.
        /// </summary>
        [Fact]
        public void TestTokenization()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile);

                Assert.Null(tokenizer.SpecialTokens);

                // Empty input: every API is expected to report "nothing".
                IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens("", out _);
                Assert.Empty(tokens);
                Assert.Equal(0, tokenizer.CountTokens(""));
                IReadOnlyList<int> ids = tokenizer.EncodeToIds("");
                Assert.Empty(ids);
                int index = tokenizer.GetIndexByTokenCount("", maxTokenCount: 10, normalizedString: out _, tokenCount: out int tokenCount);
                Assert.Equal(0, index);
                Assert.Equal(0, tokenCount);
                index = tokenizer.GetIndexByTokenCountFromEnd("", maxTokenCount: 10, normalizedString: out _, tokenCount: out tokenCount);
                Assert.Equal(0, index);
                Assert.Equal(0, tokenCount);

                // "unwanted" is expected to split into 3 tokens and "running" into 2.
                string text = "unwanted running";
                tokens = tokenizer.EncodeToTokens(text, out _);
                Assert.Equal(
                    [
                        new EncodedToken(7, "un", new Range(0, 2)),
                        new EncodedToken(4, "##want", new Range(2, 6)),
                        new EncodedToken(5, "##ed", new Range(6, 8)),
                        new EncodedToken(8, "runn", new Range(9, 13)),
                        new EncodedToken(9, "##ing", new Range(13, 16))
                    ],
                    tokens
                );

                ids = tokenizer.EncodeToIds(text);
                Assert.Equal([7, 4, 5, 8, 9], ids);

                // Expected number of ids for maxTokenCount = 1..5: none until the 3 tokens of the
                // first word fit, and the second word only once all 5 tokens fit.
                int[] expectedTokenCount = [0, 0, 3, 3, 5];
                for (int i = 1; i <= 5; i++)
                {
                    // The trailing out value is not checked here, so it is discarded instead of
                    // being stored in a local that is overwritten before its next read.
                    Assert.Equal(ids.Take(expectedTokenCount[i - 1]).ToArray(), tokenizer.EncodeToIds(text, maxTokenCount: i, normalizedText: out _, out _));
                }

                Assert.Equal(text, tokenizer.Decode(ids));

                // Decode into every buffer that is too small for the whole text, including the
                // largest such buffer (text.Length - 1 chars). Only whole tokens are expected to be
                // written, so the expected output is the text up to the end of the last token that fits.
                Span<char> buffer = stackalloc char[text.Length];
                for (int i = 0; i < text.Length; i++)
                {
                    Span<char> bufferSlice = buffer.Slice(0, i);
                    OperationStatus result = tokenizer.Decode(ids, bufferSlice, out int idsConsumed, out int charsWritten);
                    Assert.Equal(OperationStatus.DestinationTooSmall, result);

                    // j = number of leading tokens whose end offset fits in i chars. The last token
                    // ends at text.Length > i, so j never runs past the end of the token list.
                    int j = 0;

                    while (i >= tokens[j].Offset.End.Value)
                    {
                        j++;
                    }

                    Assert.Equal(j, idsConsumed);
                    Assert.Equal(j == 0 ? 0 : tokens[j - 1].Offset.End.Value, charsWritten);
                    Assert.Equal(j == 0 ? "" : text.Substring(0, tokens[j - 1].Offset.End.Value), bufferSlice.Slice(0, charsWritten).ToString());
                }

                Assert.Equal(5, tokenizer.CountTokens(text));

                // Expected index / token count for maxTokenCount = 1..5, counting from the start.
                int[] expectedIndexes = [0, 0, 8, 9, 16];
                expectedTokenCount = [0, 0, 3, 3, 5];

                for (int i = 1; i <= 5; i++)
                {
                    index = tokenizer.GetIndexByTokenCount(text, maxTokenCount: i, normalizedString: out _, out tokenCount);
                    Assert.Equal(expectedTokenCount[i - 1], tokenCount);
                    Assert.Equal(expectedIndexes[i - 1], index);
                }

                // Same check counting from the end: the 2 tokens of "running" come first.
                expectedIndexes = [16, 9, 8, 8, 0];
                expectedTokenCount = [0, 2, 2, 2, 5];

                for (int i = 1; i <= 5; i++)
                {
                    index = tokenizer.GetIndexByTokenCountFromEnd(text, maxTokenCount: i, normalizedString: out _, out tokenCount);
                    Assert.Equal(expectedTokenCount[i - 1], tokenCount);
                    Assert.Equal(expectedIndexes[i - 1], index);
                }
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }

        /// <summary>
        /// Checks that a word which cannot be fully tokenized is encoded as a single unknown token.
        /// </summary>
        [Fact]
        public void TestTokenizationWithUnknownTokens()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile);

                // The trailing 'X' has no matching vocabulary entry, so the whole first word
                // (offsets 0..9) is expected to collapse into one "[UNK]" token.
                string text = "unwantedX running";

                IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out _);
                Assert.Equal(
                    [
                        new EncodedToken(0, "[UNK]", new Range(0, 9)),
                        new EncodedToken(8, "runn",  new Range(10, 14)),
                        new EncodedToken(9, "##ing", new Range(14, 17))
                    ],
                    tokens
                );

                IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
                Assert.Equal([0, 8, 9], ids);

                // Decoding cannot recover the original word; the unknown token text is emitted instead.
                Assert.Equal("[UNK] running", tokenizer.Decode(ids));
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }

        /// <summary>
        /// Checks that special tokens supplied at creation are recognized verbatim inside the text.
        /// </summary>
        [Fact]
        public void TestTokenizationWithSpecialTokens()
        {
            string vocabFile = CreateVocabFile(_vocabTokens);

            try
            {
                Dictionary<string, int> specialTokens = new Dictionary<string, int>
                {
                    { "[UNK]", 0 }, { "[CLS]", 1 }, { "[SEP]", 2 }
                };
                WordPieceTokenizer tokenizer = WordPieceTokenizer.Create(vocabFile, specialTokens: specialTokens);

                Assert.Equal(specialTokens, tokenizer.SpecialTokens);

                // "[SEP][CLS]" has no separator on purpose: the two adjacent special tokens
                // are still expected to be split apart.
                string text = "[UNK] unwanted [SEP][CLS] running [CLS]";

                IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(text, out _);
                Assert.Equal(
                    [
                        new EncodedToken(0, "[UNK]", new Range(0, 5)),
                        new EncodedToken(7, "un", new Range(6, 8)),
                        new EncodedToken(4, "##want", new Range(8, 12)),
                        new EncodedToken(5, "##ed", new Range(12, 14)),
                        new EncodedToken(2, "[SEP]", new Range(15, 20)),
                        new EncodedToken(1, "[CLS]", new Range(20, 25)),
                        new EncodedToken(8, "runn", new Range(26, 30)),
                        new EncodedToken(9, "##ing", new Range(30, 33)),
                        new EncodedToken(1, "[CLS]", new Range(34, 39)),
                    ],
                    tokens
                );

                IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
                Assert.Equal([0, 7, 4, 5, 2, 1, 8, 9, 1], ids);

                // The decoded text is expected to differ from the input: a space appears
                // between the adjacent "[SEP]" and "[CLS]".
                Assert.Equal("[UNK] unwanted [SEP] [CLS] running [CLS]", tokenizer.Decode(ids));
            }
            finally
            {
                File.Delete(vocabFile);
            }
        }
    }
}