File: Utils\StringSpanOrdinalKey.cs
Web Access
Project: src\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj (Microsoft.ML.Tokenizers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
 
namespace Microsoft.ML.Tokenizers
{
    /// <summary>Used as a key in a dictionary to enable querying with either a string or a span.</summary>
    /// <remarks>
    /// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
    /// always be used with a string.
    /// </remarks>
    [JsonConverter(typeof(StringSpanOrdinalKeyConverter))]
    internal readonly unsafe struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
    {
        public readonly char* Ptr;
        public readonly int Length;
        public readonly string? Data;
 
        public StringSpanOrdinalKey(char* ptr, int length)
        {
            Ptr = ptr;
            Length = length;
        }
 
        public StringSpanOrdinalKey(string data) =>
            Data = data;
 
        private ReadOnlySpan<char> Span => Ptr is not null ?
            new ReadOnlySpan<char>(Ptr, Length) :
            Data.AsSpan();
 
        public override string ToString() => Data ?? Span.ToString();
 
        public override bool Equals(object? obj) =>
            obj is StringSpanOrdinalKey wrapper && Equals(wrapper);
 
        public bool Equals(StringSpanOrdinalKey other) =>
            Span.SequenceEqual(other.Span);
 
        public override int GetHashCode() => Helpers.GetHashCode(Span);
    }
 
    internal readonly unsafe struct StringSpanOrdinalKeyPair : IEquatable<StringSpanOrdinalKeyPair>
    {
        private readonly StringSpanOrdinalKey _left;
        private readonly StringSpanOrdinalKey _right;
 
        public StringSpanOrdinalKeyPair(char* ptr1, int length1, char* ptr2, int length2)
        {
            _left = new StringSpanOrdinalKey(ptr1, length1);
            _right = new StringSpanOrdinalKey(ptr2, length2);
        }
 
        public StringSpanOrdinalKeyPair(string data1, string data2)
        {
            _left = new StringSpanOrdinalKey(data1);
            _right = new StringSpanOrdinalKey(data2);
        }
        public override bool Equals(object? obj) =>
            obj is StringSpanOrdinalKeyPair wrapper && wrapper._left.Equals(_left) && wrapper._right.Equals(_right);
 
        public bool Equals(StringSpanOrdinalKeyPair other) => other._left.Equals(_left) && other._right.Equals(_right);
 
        public override int GetHashCode() => HashCode.Combine(_left.GetHashCode(), _right.GetHashCode());
    }
 
 
    internal sealed class StringSpanOrdinalKeyCache<TValue>
    {
        private readonly int _capacity;
        private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;
 
        private object SyncObj => _map;
 
        internal StringSpanOrdinalKeyCache() : this(BpeTokenizer.DefaultCacheCapacity) { }
 
        internal StringSpanOrdinalKeyCache(int capacity)
        {
            _capacity = capacity;
            _map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
        }
 
        internal bool TryGetValue(string key, out TValue value)
        {
            lock (SyncObj)
            {
                return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
            }
        }
 
        internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
        {
            lock (SyncObj)
            {
                fixed (char* ptr = key)
                {
                    return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
                }
            }
        }
 
        internal void Remove(string key)
        {
            lock (SyncObj)
            {
                _map.Remove(new StringSpanOrdinalKey(key));
            }
        }
 
        internal void Set(string k, TValue v)
        {
            lock (SyncObj)
            {
                if (_map.Count < _capacity)
                {
                    _map[new StringSpanOrdinalKey(k)] = v;
                }
            }
        }
    }
 
    [JsonConverter(typeof(VocabularyConverter))]
    internal sealed class Vocabulary : Dictionary<StringSpanOrdinalKey, (int, string)>;
 
    /// <summary>
    /// Custom JSON converter for <see cref="StringSpanOrdinalKey"/>.
    /// </summary>
    internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
    {
        public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
            new StringSpanOrdinalKey(reader.GetString()!);
 
        public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
            writer.WriteStringValue(value.Data!);
 
        public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => new StringSpanOrdinalKey(reader.GetString()!);
        public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
    }
 
    internal class VocabularyConverter : JsonConverter<Vocabulary>
    {
        public override Vocabulary Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            var dictionary = new Vocabulary();
            while (reader.Read())
            {
                if (reader.TokenType == JsonTokenType.EndObject)
                {
                    return dictionary;
                }
 
                if (reader.TokenType == JsonTokenType.PropertyName)
                {
                    var key = reader.GetString();
                    reader.Read();
                    var value = reader.GetInt32();
                    dictionary.Add(new StringSpanOrdinalKey(key!), (value, key!));
                }
            }
            throw new JsonException("Invalid JSON.");
        }
 
        public override void Write(Utf8JsonWriter writer, Vocabulary value, JsonSerializerOptions options) => throw new NotImplementedException();
    }
 
    /// <summary>
    /// Extension methods for <see cref="StringSpanOrdinalKey"/>.
    /// </summary>
    internal static class StringSpanOrdinalKeyExtensions
    {
        public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
        {
            fixed (char* ptr = key)
            {
                return map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
            }
        }
 
        public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
            map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
 
        public static unsafe bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, ReadOnlySpan<char> key1, ReadOnlySpan<char> key2, out TValue value)
        {
            fixed (char* ptr1 = key1)
            fixed (char* ptr2 = key2)
            {
                return map.TryGetValue(new StringSpanOrdinalKeyPair(ptr1, key1.Length, ptr2, key2.Length), out value!);
            }
        }
 
        public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKeyPair, TValue> map, string key1, string key2, out TValue value) =>
            map.TryGetValue(new StringSpanOrdinalKeyPair(key1, key2), out value!);
    }
}