File: Utils\TokenizerExtensions.cs
Web Access
Project: src\src\Microsoft.ML.Tokenizers\Microsoft.ML.Tokenizers.csproj (Microsoft.ML.Tokenizers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
 
namespace Microsoft.ML.Tokenizers
{
    /// <summary>
    /// Small collection helpers used by the tokenizers: minimum search, dictionary get-or-add,
    /// and key/value inversion of read-only dictionaries.
    /// </summary>
    internal static class TokenizerExtensions
    {
        /// <summary>
        /// Finds the element of <paramref name="source"/> for which <paramref name="getValue"/> returns the smallest value.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        /// <param name="source">The sequence to search. It is enumerated exactly once.</param>
        /// <param name="getValue">Computes the score of an element; invoked once per element, in order.</param>
        /// <returns>
        /// The first element with the smallest score, or <c>default</c> when the sequence is empty
        /// or when no element scores below <see cref="int.MaxValue"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="getValue"/> is null.</exception>
        public static T? ArgMin<T>(this IEnumerable<T> source, Func<T, int> getValue)
        {
            if (source is null)
            {
                throw new ArgumentNullException(nameof(source));
            }

            if (getValue is null)
            {
                throw new ArgumentNullException(nameof(getValue));
            }

            // Starting at int.MaxValue with a strict comparison gives three properties at once:
            // an empty sequence yields default, elements scoring int.MaxValue are never selected,
            // and ties keep the earliest element.
            T? best = default;
            int bestValue = int.MaxValue;

            foreach (T item in source)
            {
                int value = getValue(item);
                if (value < bestValue)
                {
                    best = item;
                    bestValue = value;
                }
            }

            return best;
        }

        /// <summary>
        /// Returns the value stored under <paramref name="key"/>; if the key is absent, stores
        /// <paramref name="setValue"/> under it and returns <paramref name="setValue"/>.
        /// </summary>
        /// <param name="dic">The dictionary to read from and, when the key is missing, write to.</param>
        /// <param name="key">The key to look up.</param>
        /// <param name="setValue">The value to insert when <paramref name="key"/> is not present.</param>
        /// <returns>The existing value for <paramref name="key"/>, or <paramref name="setValue"/> if it was just added.</returns>
        public static TValue GetOrAdd<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue setValue) where TKey : notnull
        {
            if (dic.TryGetValue(key, out var value))
            {
                return value;
            }

            dic[key] = setValue;
            return setValue;
        }

        /// <summary>
        /// Builds a new dictionary that maps each value of <paramref name="source"/> back to its key.
        /// </summary>
        /// <param name="source">The dictionary to invert. A null reference produces an empty result.</param>
        /// <returns>
        /// A new dictionary keyed by the values of <paramref name="source"/>. When several keys share a value,
        /// the key enumerated last wins.
        /// </returns>
        public static IReadOnlyDictionary<TValue, TKey> Reverse<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> source) where TValue : notnull
        {
            if (source is null)
            {
                return new Dictionary<TValue, TKey>();
            }

            // Presize: the result can never hold more entries than the source.
            Dictionary<TValue, TKey> reversed = new Dictionary<TValue, TKey>(source.Count);
            foreach (KeyValuePair<TKey, TValue> pair in source)
            {
                // Indexer assignment (not Add) so duplicate values overwrite instead of throwing.
                reversed[pair.Value] = pair.Key;
            }

            return reversed;
        }

        /// <summary>
        /// Builds a new <see cref="SortedDictionary{TKey, TValue}"/> that maps each value of
        /// <paramref name="source"/> back to its key, ordered by those values.
        /// </summary>
        /// <param name="source">The dictionary to invert. A null reference produces an empty result.</param>
        /// <returns>
        /// A new sorted dictionary keyed by the values of <paramref name="source"/>. When several keys share a value,
        /// the key enumerated last wins.
        /// </returns>
        public static SortedDictionary<TValue, TKey> ReverseSorted<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> source) where TKey : notnull where TValue : notnull
        {
            SortedDictionary<TValue, TKey> reversed = new SortedDictionary<TValue, TKey>();
            if (source is null)
            {
                return reversed;
            }

            foreach (KeyValuePair<TKey, TValue> pair in source)
            {
                // Indexer assignment (not Add) so duplicate values overwrite instead of throwing.
                reversed[pair.Value] = pair.Key;
            }

            return reversed;
        }
    }
 
    /// <summary>
    /// JSON converter that serializes a <see cref="SortedDictionary{TKey, TValue}"/> of int-to-string entries
    /// with the roles swapped: each string becomes a property name and its integer key becomes the property value.
    /// </summary>
    internal class DictReversingConverter : JsonConverter<SortedDictionary<int, string>>
    {
        /// <summary>
        /// Does not consume any JSON; always returns <c>null</c>.
        /// </summary>
        public override SortedDictionary<int, string>? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            return null;
        }

        /// <summary>
        /// Writes <paramref name="value"/> as a single JSON object, emitting one numeric property per entry
        /// whose key is non-negative. Entries with negative keys are omitted.
        /// </summary>
        /// <param name="writer">The writer receiving the JSON object.</param>
        /// <param name="value">The dictionary to serialize, enumerated in its sorted key order.</param>
        /// <param name="options">Serializer options; not consulted by this converter.</param>
        public override void Write(Utf8JsonWriter writer, SortedDictionary<int, string> value, JsonSerializerOptions options)
        {
            writer.WriteStartObject();

            foreach (KeyValuePair<int, string> entry in value)
            {
                int id = entry.Key;
                if (id < 0)
                {
                    // Negative keys are left out of the output.
                    continue;
                }

                // Swapped on purpose: the string is the property name, the integer is its value.
                writer.WriteNumber(entry.Value, id);
            }

            writer.WriteEndObject();
        }
    }
}