File: Text\NgramUtils.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// This delegate represents a function that gets an n-gram as input, and outputs the id (slot) of
    /// the n-gram and whether or not to continue processing n-grams.
    /// </summary>
    /// <param name="ngram">The array containing the n-gram. The array may be longer than the n-gram itself;
    /// only the first <paramref name="lim"/> entries are meaningful.</param>
    /// <param name="lim">The n-gram is stored in ngram[0],...ngram[lim-1].</param>
    /// <param name="icol">The index of the column the transform is applied to.</param>
    /// <param name="more">True if processing should continue, false if it should stop.
    /// It is true on input, so only needs to be set to false.</param>
    /// <returns>The n-gram slot if it was found, -1 otherwise. A negative result means that
    /// nothing is recorded for this n-gram by <see cref="NgramBufferBuilder"/>.</returns>
    internal delegate int NgramIdFinder(uint[] ngram, int lim, int icol, ref bool more);
 
    // A class that, given a VBuffer of keys, finds all the n-grams in it and maintains a vector of n-gram counts.
    // The id of each n-gram is found by calling an NgramIdFinder delegate. This class can also be used to build
    // an n-gram dictionary, by defining an NgramIdFinder that adds the n-grams to a dictionary and always
    // returns -1 (the delegate's "not found" value), so that no counts are recorded.
    internal sealed class NgramBufferBuilder
    {
        // This buffer builder maintains the vector of n-gram-counts.
        private readonly BufferBuilder<float> _bldr;
        // A queue that holds up to _ngramLength+_skipLength keys, so that it contains all the n-grams
        // (including skip-grams) that start with the first key in the queue.
        private readonly FixedSizeQueue<uint> _queue;
        // The maximum n-gram length.
        private readonly int _ngramLength;
        // The maximum number of skips contained in an n-gram.
        private readonly int _skipLength;
        // An array of length _ngramLength, containing the current n-gram.
        private readonly uint[] _ngram;
        // The exclusive upper bound on n-gram ids: every slot reported by _finder must satisfy
        // 0 <= slot < _slotLim.
        private readonly int _slotLim;
        // Maps the current n-gram to its slot (or to a negative value when it has none).
        private readonly NgramIdFinder _finder;

        public const int MaxSkipNgramLength = 10;

        /// <summary>
        /// True when the builder was created with a slot limit of zero, i.e. no n-gram can have a slot.
        /// </summary>
        public bool IsEmpty => _slotLim == 0;

        /// <summary>
        /// Creates a builder for n-grams of up to <paramref name="ngramLength"/> keys with up to
        /// <paramref name="skipLength"/> skips.
        /// </summary>
        /// <param name="ngramLength">The maximum n-gram length; must be positive.</param>
        /// <param name="skipLength">The maximum number of skips in an n-gram; must be non-negative, and
        /// ngramLength + skipLength must not exceed <see cref="MaxSkipNgramLength"/>.</param>
        /// <param name="slotLim">The exclusive upper bound on the slots returned by <paramref name="finder"/>.</param>
        /// <param name="finder">The delegate that maps an n-gram to its slot.</param>
        public NgramBufferBuilder(int ngramLength, int skipLength, int slotLim, NgramIdFinder finder)
        {
            Contracts.Assert(ngramLength > 0);
            Contracts.Assert(skipLength >= 0);
            Contracts.Assert(ngramLength <= MaxSkipNgramLength - skipLength);
            Contracts.Assert(slotLim >= 0);

            _ngramLength = ngramLength;
            _skipLength = skipLength;
            _slotLim = slotLim;

            _ngram = new uint[_ngramLength];
            _queue = new FixedSizeQueue<uint>(_ngramLength + _skipLength);
            _bldr = BufferBuilder<float>.CreateDefault();
            _finder = finder;
        }

        /// <summary>
        /// Resets the underlying buffer builder (using the slot limit as its length) and clears the key queue.
        /// </summary>
        public void Reset()
        {
            _bldr.Reset(_slotLim, false);
            _queue.Clear();
        }

        /// <summary>
        /// Walks over all the keys of <paramref name="src"/> in order and records every n-gram found.
        /// </summary>
        /// <param name="src">The vector of keys. Both dense and sparse representations are handled; slots
        /// that are not explicitly stored in a sparse vector are treated as key 0.</param>
        /// <param name="icol">The column index that is forwarded to the n-gram id finder.</param>
        /// <param name="keyMax">The largest valid key; any larger key is replaced by 0.</param>
        /// <returns>False as soon as the finder asks to stop processing, true otherwise. When false is
        /// returned the queue is not drained.</returns>
        public bool AddNgrams(in VBuffer<uint> src, int icol, uint keyMax)
        {
            Contracts.Assert(icol >= 0);
            Contracts.Assert(keyMax > 0);

            var srcValues = src.GetValues();
            if (src.IsDense)
            {
                for (int i = 0; i < src.Length; i++)
                {
                    if (!Enqueue(srcValues[i], keyMax, icol))
                        return false;
                }
            }
            else
            {
                var srcIndices = src.GetIndices();

                // iindex walks the explicitly stored entries while i walks the logical positions.
                int iindex = 0;
                for (int i = 0; i < src.Length; i++)
                {
                    uint curKey = 0;
                    if (iindex < srcIndices.Length && i == srcIndices[iindex])
                        curKey = srcValues[iindex++];

                    if (!Enqueue(curKey, keyMax, icol))
                        return false;
                }
            }

            // A full queue has already been processed inside the loop above, so skip its first key.
            if (_queue.IsFull)
                _queue.RemoveFirst();

            // Process the grams of the remaining terms
            while (_queue.Count > 0)
            {
                if (!ProcessNgrams(icol))
                    return false;
                _queue.RemoveFirst();
            }
            return true;
        }

        /// <summary>
        /// Copies the n-gram vector maintained by the buffer builder into <paramref name="dst"/>.
        /// </summary>
        public void GetResult(ref VBuffer<float> dst)
        {
            _bldr.GetResult(ref dst);
        }

        // Appends a key to the queue (keys above keyMax are replaced by 0) and, once the queue is full,
        // processes the n-grams starting at the first key of the queue.
        // Returns false if there is no need to process more n-grams.
        private bool Enqueue(uint key, uint keyMax, int icol)
        {
            if (key > keyMax)
                key = 0;

            _queue.AddLast(key);

            // Add the n-gram counts
            return !_queue.IsFull || ProcessNgrams(icol);
        }

        // Asks the finder for the slot of the n-gram stored in _ngram[0..lim-1] and, if it has one,
        // adds it to the buffer builder. Returns the finder's "more" flag, which is always true on input.
        private bool CountNgram(int lim, int icol)
        {
            bool more = true;
            int slot = _finder(_ngram, lim, icol, ref more);
            if (slot >= 0)
            {
                Contracts.Assert(0 <= slot && slot < _slotLim);
                _bldr.AddFeature(slot, 1);
            }
            return more;
        }

        // Processes all the n-grams that start with the first key in the queue.
        // Returns false if there is no need to process more n-grams.
        private bool ProcessNgrams(int icol)
        {
            Contracts.Assert(_queue.Count > 0);

            // The unigram made of the first key.
            _ngram[0] = _queue[0];
            if (!CountNgram(1, icol))
                return false;

            if (_queue.Count == 1)
                return true;

            if (_skipLength > 0)
                return ProcessSkipNgrams(icol, 1, 0);

            // No skips: extend the n-gram one consecutive key at a time.
            for (int i = 1; i < _queue.Count; i++)
            {
                _ngram[i] = _queue[i];
                if (!CountNgram(i + 1, icol))
                    return false;
            }

            return true;
        }

        // Uses DFS. When called with i and skips, it assumes that the
        // first i terms in the _ngram array are already populated using "skips" skips,
        // and it adds the (i+1)st term. It then recursively calls ProcessSkipNgrams
        // to add the next term.
        // Returns false if there is no need to process more n-grams.
        private bool ProcessSkipNgrams(int icol, int i, int skips)
        {
            Contracts.Assert(0 < i && i < _ngramLength);
            Contracts.Assert(0 <= skips && skips <= _skipLength);
            Contracts.Assert(i + skips < _queue.Count);
            Contracts.Assert(i > 1 || skips == 0);
            Contracts.Assert(_ngram.Length == _ngramLength);

            // k is the total number of skips used once the (i+1)st term is taken from position k + i.
            for (int k = skips; k <= _skipLength && k + i < _queue.Count; k++)
            {
                _ngram[i] = _queue[k + i];
                if (!CountNgram(i + 1, icol))
                    return false;

                // Recurse only if the n-gram can still grow and there are keys left in the queue.
                if (i + 1 < _ngramLength && i + k + 1 < _queue.Count && !ProcessSkipNgrams(icol, i + 1, k))
                    return false;
            }
            return true;
        }
    }
 
    internal static class NgramUtils
    {
        /// <summary>
        /// Determines whether a key column with the given raw type can be used as n-gram input.
        /// </summary>
        /// <param name="rawType">The raw type of the key values.</param>
        /// <returns>False when <paramref name="rawType"/> is <see cref="ulong"/>, true otherwise.</returns>
        public static bool IsValidNgramRawType(Type rawType)
        {
            // N-grams are built from uint keys, so only key types that can be converted to U4 (uint)
            // are accepted; U8 (ulong) is rejected.
            return rawType != typeof(ulong);
        }
    }
}