File: Utilities\NormStr.cs
Web Access
Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Internal.Utilities
{
    using Conditional = System.Diagnostics.ConditionalAttribute;
 
    /// <summary>
    /// Normalized string type. For string pooling.
    /// Instances are created only by <see cref="Pool"/>. A pool hands out at most one <see cref="NormStr"/>
    /// per distinct character sequence, so for strings obtained from the same pool reference equality
    /// coincides with content equality.
    /// </summary>
    [BestFriend]
    internal sealed class NormStr
    {
        /// <summary>The pooled characters.</summary>
        public readonly ReadOnlyMemory<char> Value;

        /// <summary>
        /// Index of this string in its pool's insertion order; see <see cref="Pool.GetNormStrById(int)"/>.
        /// </summary>
        public readonly int Id;

        // Hash of Value, cached so the pool can re-bucket on growth without rehashing the characters.
        private readonly uint _hash;

        /// <summary>
        /// NormStr's can only be created by the Pool.
        /// </summary>
        private NormStr(ReadOnlyMemory<char> str, int id, uint hash)
        {
            Contracts.Assert(id >= 0 || id == -1 && str.IsEmpty);
            Value = str;
            Id = id;
            _hash = hash;
        }

        /// <summary>
        /// Returns the content hash computed when the string was pooled.
        /// <see cref="object.Equals(object)"/> is intentionally not overridden: a pool never creates two
        /// instances with the same content, so the default reference equality is consistent with this hash.
        /// </summary>
        public override int GetHashCode()
        {
            return (int)_hash;
        }

        /// <summary>
        /// A hash-based string pool. Strings are stored in insertion order (a string's <see cref="Id"/> is its
        /// insertion index) and chained into a power-of-two number of hash buckets.
        /// This type performs no locking; callers that add from multiple threads must synchronize externally.
        /// </summary>
        public sealed class Pool : IEnumerable<NormStr>
        {
            private int _mask; // Number of buckets minus 1. The number of buckets must be a power of two.
            private int[] _rgins;  // Buckets of size _mask + 1. Each holds the index of its chain head, or -1 if empty.

            // The number of strings.
            private int _cns;
            // The strings, in insertion order.
            private NormStr[] _rgns;
            // Parallel to _rgns. Each ulong contains the length of the string (for speedy checks) and the
            // index of the next string in the same bucket. The length is the low int and the next index
            // is the high int. Doing this gives better perf than more structured alternatives.
            private ulong[] _rgmeta;

            /// <summary>The number of strings currently in the pool.</summary>
            public int Count { get { return _cns; } }

            /// <summary>Creates an empty pool with 32 buckets.</summary>
            public Pool()
            {
                _mask = 31;
                _rgins = new int[_mask + 1];
                for (int i = 0; i < _rgins.Length; i++)
                    _rgins[i] = -1;

                AssertValid();
            }

            [Conditional("DEBUG")]
            private void AssertValid()
            {
                // Number of buckets must be a power of two.
                Contracts.AssertValue(_rgins);
                Contracts.Assert(_rgins.Length == _mask + 1);
                Contracts.Assert(Utils.IsPowerOfTwo(_mask + 1));

                Contracts.Assert(0 <= _cns && _cns <= Utils.Size(_rgns));
                Contracts.Assert(Utils.Size(_rgns) == Utils.Size(_rgmeta));
            }

            /// <summary>Returns the index of the chain head for <paramref name="hash"/>, or -1 if the bucket is empty.</summary>
            private int GetIns(uint hash)
            {
                return _rgins[GetIins(hash)];
            }

            /// <summary>Maps <paramref name="hash"/> to a bucket index.</summary>
            private int GetIins(uint hash)
            {
                return (int)hash & _mask;
            }

            /// <summary>
            /// Walks the bucket chain for <paramref name="hash"/> and returns the pooled string whose characters
            /// equal <paramref name="span"/>, or null if there is none.
            /// </summary>
            private NormStr FindCore(ReadOnlySpan<char> span, uint hash)
            {
                int ins = GetIns(hash);
                while (ins >= 0)
                {
                    ulong meta = _rgmeta[ins];
                    // The cached length rejects most non-matches without touching the characters.
                    if ((int)Utils.GetLo(meta) == span.Length)
                    {
                        var ns = GetNs(ins);
                        if (ns.Value.Span.SequenceEqual(span))
                            return ns;
                    }
                    ins = (int)Utils.GetHi(meta);
                }
                Contracts.Assert(ins == -1);
                return null;
            }

            /// <summary>
            /// Find the given string in the pool. If not found, adds it when <paramref name="add"/> is true and
            /// returns null otherwise. A null <paramref name="str"/> is treated as the empty string.
            /// </summary>
            public NormStr Get(string str, bool add = false)
            {
                AssertValid();

                if (str == null)
                    str = "";

                var span = str.AsSpan();
                uint hash = Hashing.HashString(span);
                NormStr found = FindCore(span, hash);
                if (found != null || !add)
                    return found;

                return AddCore(str.AsMemory(), hash);
            }

            /// <summary>
            /// Find the given characters in the pool. If not found, adds them when <paramref name="add"/> is true
            /// and returns null otherwise.
            /// </summary>
            /// <param name="str">The characters to look up.</param>
            /// <param name="add">Whether to add the characters when they are not already pooled.</param>
            /// <param name="duplicateStr">
            /// When adding, whether to copy the characters into a new string first instead of pooling
            /// <paramref name="str"/> (and therefore its backing buffer) directly.
            /// </param>
            public NormStr Get(ReadOnlyMemory<char> str, bool add = false, bool duplicateStr = true)
            {
                AssertValid();

                var span = str.Span;
                uint hash = Hashing.HashString(span);
                NormStr found = FindCore(span, hash);
                if (found != null || !add)
                    return found;

                // To avoid the case where 'str' actually stores a string with the
                // content of a whole row in the dataset, a new 'str' is created
                // See issue https://github.com/dotnet/machinelearning/issues/4571
                // and PR https://github.com/dotnet/machinelearning/pull/4576
                return AddCore(duplicateStr ? str.ToString().AsMemory() : str, hash);
            }

            /// <summary>
            /// Make sure the given string has an equivalent NormStr in the pool and return it.
            /// </summary>
            public NormStr Add(string str)
            {
                return Get(str, true);
            }

            /// <summary>
            /// Make sure the given characters have an equivalent NormStr in the pool and return it.
            /// See <see cref="Get(ReadOnlyMemory{char}, bool, bool)"/> for the meaning of <paramref name="duplicateStr"/>.
            /// </summary>
            public NormStr Add(ReadOnlyMemory<char> str, bool duplicateStr = true)
            {
                return Get(str, true, duplicateStr);
            }

            /// <summary>
            /// Find the contents of the given string builder in the pool. If not found, adds a copy when
            /// <paramref name="add"/> is true and returns null otherwise. A null <paramref name="sb"/> is treated
            /// as the empty string.
            /// </summary>
            public NormStr Get(StringBuilder sb, bool add = false)
            {
                AssertValid();

                if (sb == null)
                    return Get("", add);

                int cch = sb.Length;

                uint hash = Hashing.HashString(sb);
                int ins = GetIns(hash);
                while (ins >= 0)
                {
                    ulong meta = _rgmeta[ins];
                    if ((int)Utils.GetLo(meta) == cch)
                    {
                        var ns = GetNs(ins);
                        if (SpanEqualsBuilder(ns.Value.Span, sb))
                            return ns;
                    }
                    ins = (int)Utils.GetHi(meta);
                }
                Contracts.Assert(ins == -1);

                return add ? AddCore(sb.ToString().AsMemory(), hash) : null;
            }

            /// <summary>
            /// Character-by-character comparison of <paramref name="span"/> against <paramref name="sb"/>.
            /// The caller guarantees equal lengths. The span is obtained once by the caller rather than
            /// re-reading <see cref="ReadOnlyMemory{T}.Span"/> for every character.
            /// </summary>
            private static bool SpanEqualsBuilder(ReadOnlySpan<char> span, StringBuilder sb)
            {
                Contracts.Assert(span.Length == sb.Length);
                for (int ich = 0; ich < span.Length; ich++)
                {
                    if (span[ich] != sb[ich])
                        return false;
                }
                return true;
            }

            /// <summary>
            /// Make sure the given string builder has an equivalent NormStr in the pool and return it.
            /// </summary>
            public NormStr Add(StringBuilder sb)
            {
                return Get(sb, true);
            }

            /// <summary>
            /// Adds the item. Does NOT check for whether the item is already present.
            /// The string storage starts at 10 entries and grows by half its size when full. The bucket
            /// table doubles once the string count reaches the bucket count.
            /// </summary>
            private NormStr AddCore(ReadOnlyMemory<char> str, uint hash)
            {
                Contracts.Assert(str.Length >= 0);
                Contracts.Assert(Hashing.HashString(str.Span) == hash);

                if (_rgns == null)
                {
                    Contracts.Assert(_cns == 0);
                    _rgmeta = new ulong[10];
                    _rgns = new NormStr[10];
                }
                else if (_cns >= _rgns.Length)
                {
                    Contracts.Assert(_cns == _rgns.Length);
                    int size = checked(_rgns.Length / 2 + _rgns.Length);
                    Array.Resize(ref _rgmeta, size);
                    Array.Resize(ref _rgns, size);
                }
                Contracts.Assert(_cns < _rgns.Length);

                NormStr ns = new NormStr(str, _cns, hash);
                int iins = GetIins(hash);
                _rgns[_cns] = ns;
                // Push the new string onto the front of its bucket's chain.
                _rgmeta[_cns] = Utils.MakeUlong((uint)_rgins[iins], (uint)ns.Value.Length);
                _rgins[iins] = _cns;

                if (++_cns >= _rgins.Length)
                    GrowTable();

                AssertValid();
                return ns;
            }

            /// <summary>
            /// Returns the string whose <see cref="Id"/> is <paramref name="id"/>.
            /// Throws if <paramref name="id"/> is not in [0, <see cref="Count"/>).
            /// </summary>
            public NormStr GetNormStrById(int id)
            {
                Contracts.CheckParam(0 <= id && id < _cns, nameof(id));
                return GetNs(id);
            }

            private NormStr GetNs(int ins)
            {
                Contracts.Assert(0 <= ins && ins < _cns);
                Contracts.Assert(_rgns[ins].Id == ins);
                return _rgns[ins];
            }

            /// <summary>
            /// Doubles the number of buckets and re-chains every string using its cached hash.
            /// </summary>
            private void GrowTable()
            {
                AssertValid();

                int size = checked(2 * _rgins.Length);
                _rgins = new int[size];
                _mask = size - 1;
                for (int i = 0; i < _rgins.Length; i++)
                    _rgins[i] = -1;

                for (int ins = 0; ins < _cns; ins++)
                {
                    var ns = GetNs(ins);
                    int iins = GetIins(ns._hash);
                    _rgmeta[ins] = Utils.MakeUlong((uint)_rgins[iins], (uint)ns.Value.Length);
                    _rgins[iins] = ins;
                }

                AssertValid();
            }

            /// <summary>Enumerates the pooled strings in insertion (<see cref="Id"/>) order.</summary>
            public IEnumerator<NormStr> GetEnumerator()
            {
                for (int ins = 0; ins < _cns; ins++)
                    yield return GetNs(ins);
            }

            IEnumerator IEnumerable.GetEnumerator()
            {
                return GetEnumerator();
            }
        }
    }
}