File: PreTokenizer\RegexPreTokenizer.cs
Web Access
Project: src\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj (Microsoft.ML.Tokenizers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
 
namespace Microsoft.ML.Tokenizers
{
    /// <summary>
    /// The pre-tokenizer for Tiktoken tokenizer.
    /// </summary>
    public sealed partial class RegexPreTokenizer : PreTokenizer
    {
        private readonly Regex? _specialTokensRegex;
        private readonly Regex _regex;
 
        /// <summary>
        /// Initializes a new instance of the <see cref="RegexPreTokenizer"/> class.
        /// </summary>
        /// <param name="regex">The regex to use for splitting the text into smaller tokens in the pre-tokenization process.</param>
        /// <param name="specialTokens">The dictionary containing the special tokens and their corresponding ids.</param>
        /// <exception cref="ArgumentNullException">When regex is null</exception>
        public RegexPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? specialTokens)
        {
            if (regex is null)
            {
                throw new ArgumentNullException(nameof(regex));
            }
 
            _regex = regex;
 
            if (specialTokens is { Count: > 0 })
            {
                // We create this Regex object without a timeout, as we expect the match operation to complete in \(O(N)\) time complexity. Note that `specialTokens` is treated as constants after the pre-tokenizer is created.
                _specialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
            }
        }
 
        /// <summary>
        /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
        /// </summary>
        /// <param name="text">The string to split into tokens.</param>
        /// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
        public override IEnumerable<(int Offset, int Length)> PreTokenize(string text)
        {
            if (string.IsNullOrEmpty(text))
            {
                return [];
            }
 
            return SplitText(text, _regex, _specialTokensRegex);
 
            static IEnumerable<(int Offset, int Length)> SplitText(string text, Regex regex, Regex? specialTokensRegex)
            {
                (int Offset, int Length) match;
                int beginning = 0;
 
                if (specialTokensRegex is not null)
                {
                    while (true)
                    {
                        (int Offset, int Length) specialMatch;
                        if (!TryGetMatch(specialTokensRegex, text, beginning, text.Length - beginning, out specialMatch))
                        {
                            break;
                        }
 
                        while (TryGetMatch(regex, text, beginning, specialMatch.Offset - beginning, out match))
                        {
                            yield return (match.Offset, match.Length);
                            beginning = match.Offset + match.Length;
                        }
 
                        yield return (specialMatch.Offset, specialMatch.Length);
                        beginning = specialMatch.Offset + specialMatch.Length;
                    }
                }
 
                while (TryGetMatch(regex, text, beginning, text.Length - beginning, out match))
                {
                    yield return (match.Offset, match.Length);
                    beginning = match.Length + match.Offset;
                }
            }
        }
 
        /// <summary>
        /// Get the offsets and lengths of the tokens relative to the <paramref name="text"/>.
        /// </summary>
        /// <param name="text">The string to split into tokens.</param>
        /// <returns>The offsets and lengths of the tokens, expressed as pairs, are relative to the original string.</returns>
        public override IEnumerable<(int Offset, int Length)> PreTokenize(ReadOnlySpan<char> text)
        {
            if (text.IsEmpty)
            {
                return [];
            }
 
#if NET7_0_OR_GREATER
            char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
            text.CopyTo(buffer);
            return SplitText(buffer, _regex, _specialTokensRegex, text.Length);
 
            static IEnumerable<(int Offset, int Length)> SplitText(char[] text, Regex regex, Regex? specialTokensRegex, int textLength)
            {
                (int Offset, int Length) match;
                int beginning = 0;
 
                if (specialTokensRegex is not null)
                {
                    while (true)
                    {
                        (int Offset, int Length) specialMatch;
                        if (!TryGetMatch(specialTokensRegex, text.AsSpan(), beginning, textLength - beginning, out specialMatch))
                        {
                            break;
                        }
 
                        while (TryGetMatch(regex, text.AsSpan(), beginning, specialMatch.Offset - beginning, out match))
                        {
                            yield return (match.Offset, match.Length);
                            beginning = match.Offset + match.Length;
                        }
 
                        yield return (specialMatch.Offset, specialMatch.Length);
                        beginning = specialMatch.Offset + specialMatch.Length;
                    }
                }
 
                while (TryGetMatch(regex, text.AsSpan(), beginning, textLength - beginning, out match))
                {
                    yield return (match.Offset, match.Length);
                    beginning = match.Length + match.Offset;
                }
 
                ArrayPool<char>.Shared.Return(text);
            }
#else
            return PreTokenize(text.ToString());
#endif // NET7_0_OR_GREATER
        }
    }
}