|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
{
/// <summary>
/// This delegate represents a function that gets an n-gram as input, and outputs the id (slot) of
/// the n-gram and whether or not to continue processing n-grams.
/// </summary>
/// <param name="ngram">The array containing the n-gram. Only the first <paramref name="lim"/> entries
/// belong to the current n-gram; entries at index <paramref name="lim"/> and beyond are not part of it
/// and may hold unrelated values.</param>
/// <param name="lim">The length of the n-gram: it is stored in ngram[0],...ngram[lim-1].</param>
/// <param name="icol">The index of the column the transform is applied to.</param>
/// <param name="more">True if processing should continue, false if it should stop.
/// It is true on input, so only needs to be set to false.</param>
/// <returns>The n-gram slot if it was found, -1 otherwise. NgramBufferBuilder only counts
/// non-negative results, so any negative value means "not counted".</returns>
internal delegate int NgramIdFinder(uint[] ngram, int lim, int icol, ref bool more);
// A class that, given a VBuffer of keys, finds all the n-grams in it and maintains a vector of n-gram counts.
// The id (slot) of each n-gram is found by calling an NgramIdFinder delegate. This class can also be used to build
// an n-gram dictionary, by defining an NgramIdFinder that adds the n-grams to a dictionary and always returns -1,
// so that no counts are accumulated.
internal sealed class NgramBufferBuilder
{
    // Accumulates the vector of n-gram counts, indexed by n-gram slot.
    private readonly BufferBuilder<float> _bldr;

    // A queue that holds up to _ngramLength+_skipLength keys. When full, it contains every key that can
    // take part in an n-gram (with skips) whose first key is _queue[0].
    private readonly FixedSizeQueue<uint> _queue;

    // The maximum n-gram length.
    private readonly int _ngramLength;

    // The maximum number of skips contained in an n-gram.
    private readonly int _skipLength;

    // Scratch array of length _ngramLength holding the n-gram currently handed to _finder.
    private readonly uint[] _ngram;

    // Exclusive upper bound on n-gram slots: every slot returned by _finder must lie in [0, _slotLim).
    private readonly int _slotLim;

    // Maps an n-gram to its slot (a negative result means the n-gram is not counted).
    private readonly NgramIdFinder _finder;

    /// <summary>
    /// Upper bound on ngramLength + skipLength accepted by the constructor.
    /// </summary>
    public const int MaxSkipNgramLength = 10;

    /// <summary>
    /// True when the slot limit is zero, i.e. there are no slots to count into.
    /// </summary>
    public bool IsEmpty => _slotLim == 0;

    /// <summary>
    /// Creates a builder that offers n-grams of length 1 through <paramref name="ngramLength"/>, with up to
    /// <paramref name="skipLength"/> skipped keys, to <paramref name="finder"/>.
    /// </summary>
    /// <param name="ngramLength">The maximum n-gram length; must be positive.</param>
    /// <param name="skipLength">The maximum number of skips inside an n-gram; must be non-negative, and
    /// <paramref name="ngramLength"/> + <paramref name="skipLength"/> must not exceed <see cref="MaxSkipNgramLength"/>.</param>
    /// <param name="slotLim">The number of count slots; the finder must return slots below this value.</param>
    /// <param name="finder">Maps each n-gram to its slot.</param>
    public NgramBufferBuilder(int ngramLength, int skipLength, int slotLim, NgramIdFinder finder)
    {
        Contracts.Assert(ngramLength > 0);
        Contracts.Assert(skipLength >= 0);
        Contracts.Assert(ngramLength <= MaxSkipNgramLength - skipLength);
        Contracts.Assert(slotLim >= 0);

        _ngramLength = ngramLength;
        _skipLength = skipLength;
        _slotLim = slotLim;

        _ngram = new uint[_ngramLength];
        _queue = new FixedSizeQueue<uint>(_ngramLength + _skipLength);
        _bldr = BufferBuilder<float>.CreateDefault();
        _finder = finder;
    }

    /// <summary>
    /// Resets the count builder to a vector of the slot-limit length and empties the key window.
    /// NOTE(review): the constructor does not call this, so presumably callers invoke it before each input.
    /// </summary>
    public void Reset()
    {
        _bldr.Reset(_slotLim, false);
        _queue.Clear();
    }

    /// <summary>
    /// Slides the keys of <paramref name="src"/> through the key window. For every start position, it
    /// reports each n-gram to the finder and adds 1 to the count of every slot the finder returns.
    /// </summary>
    /// <param name="src">The key values. Implicit (sparse) entries and values above
    /// <paramref name="keyMax"/> are treated as key 0.</param>
    /// <param name="icol">The column index, passed through to the finder.</param>
    /// <param name="keyMax">The largest key value that is kept as-is.</param>
    /// <returns>False if the finder asked to stop processing; true otherwise.</returns>
    public bool AddNgrams(in VBuffer<uint> src, int icol, uint keyMax)
    {
        Contracts.Assert(icol >= 0);
        Contracts.Assert(keyMax > 0);

        var srcValues = src.GetValues();
        if (src.IsDense)
        {
            for (int i = 0; i < src.Length; i++)
            {
                uint curKey = srcValues[i];
                if (curKey > keyMax)
                    curKey = 0;

                _queue.AddLast(curKey);

                // Once the window is full, all n-grams starting at _queue[0] are available.
                if (_queue.IsFull && !ProcessNgrams(icol))
                    return false;
            }
        }
        else
        {
            var srcIndices = src.GetIndices();
            int iindex = 0;
            for (int i = 0; i < src.Length; i++)
            {
                // Positions not listed in srcIndices are implicit zeros.
                uint curKey = 0;
                if (iindex < srcIndices.Length && i == srcIndices[iindex])
                {
                    curKey = srcValues[iindex];
                    if (curKey > keyMax)
                        curKey = 0;
                    iindex++;
                }

                _queue.AddLast(curKey);

                // Once the window is full, all n-grams starting at _queue[0] are available.
                if (_queue.IsFull && !ProcessNgrams(icol))
                    return false;
            }
        }

        // If the window filled up, its current first key was already processed by the last loop iteration.
        if (_queue.IsFull)
            _queue.RemoveFirst();

        // Process the n-grams starting at each of the remaining keys.
        while (_queue.Count > 0)
        {
            if (!ProcessNgrams(icol))
                return false;
            _queue.RemoveFirst();
        }
        return true;
    }

    /// <summary>
    /// Copies the accumulated n-gram counts into <paramref name="dst"/>.
    /// </summary>
    public void GetResult(ref VBuffer<float> dst)
    {
        _bldr.GetResult(ref dst);
    }

    // Reports every n-gram whose first key is _queue[0] to the finder. It starts with the unigram, then
    // handles either the contiguous n-grams (no skips) or, through ProcessSkipNgrams, all n-grams with up
    // to _skipLength skips. Returns false if there is no need to process more n-grams.
    private bool ProcessNgrams(int icol)
    {
        Contracts.Assert(_queue.Count > 0);

        _ngram[0] = _queue[0];

        int slot;
        bool more = true;
        if ((slot = _finder(_ngram, 1, icol, ref more)) >= 0)
        {
            Contracts.Assert(0 <= slot && slot < _slotLim);
            _bldr.AddFeature(slot, 1);
        }

        // With _ngramLength == 1 only unigrams exist. When _skipLength > 0 the window can still hold
        // several keys, but extending the n-gram would write past the end of _ngram (length 1).
        if (_queue.Count == 1 || _ngramLength == 1 || !more)
            return more;

        if (_skipLength > 0)
            return ProcessSkipNgrams(icol, 1, 0);

        // No skips: the window capacity equals _ngramLength, so i stays within _ngram.
        for (int i = 1; i < _queue.Count; i++)
        {
            _ngram[i] = _queue[i];
            Contracts.Assert(more);
            if ((slot = _finder(_ngram, i + 1, icol, ref more)) >= 0)
            {
                Contracts.Assert(0 <= slot && slot < _slotLim);
                _bldr.AddFeature(slot, 1);
            }

            if (!more)
                return false;
        }

        return true;
    }

    // Uses DFS. When called with i and skips, it assumes that the first i terms in the _ngram array are
    // already populated using "skips" skips, and it adds the (i+1)st term. That term is taken from every
    // window position that keeps the total number of skips at most _skipLength. It then recursively calls
    // ProcessSkipNgrams to add the next term. Returns false if the finder asked to stop.
    private bool ProcessSkipNgrams(int icol, int i, int skips)
    {
        Contracts.Assert(0 < i && i < _ngramLength);
        Contracts.Assert(0 <= skips && skips <= _skipLength);
        Contracts.Assert(i + skips < _queue.Count);
        Contracts.Assert(i > 1 || skips == 0);
        Contracts.Assert(_ngram.Length == _ngramLength);

        bool more = true;
        for (int k = skips; k <= _skipLength && k + i < _queue.Count; k++)
        {
            _ngram[i] = _queue[k + i];
            int slot;
            Contracts.Assert(more);
            if ((slot = _finder(_ngram, i + 1, icol, ref more)) >= 0)
            {
                Contracts.Assert(0 <= slot && slot < _slotLim);
                _bldr.AddFeature(slot, 1);
            }

            // Recurse only if a longer n-gram is allowed and the window still has a key after position i + k.
            if (!more || (i + 1 < _ngramLength && i + k + 1 < _queue.Count && !ProcessSkipNgrams(icol, i + 1, k)))
                return false;
        }

        return more;
    }
}
internal static class NgramUtils
{
    /// <summary>
    /// Returns whether a key column whose raw type is <paramref name="rawType"/> can be used for n-gram extraction.
    /// </summary>
    /// <param name="rawType">Presumably the underlying raw type of a key column; verify against callers.</param>
    /// <returns>False for <see cref="ulong"/>; true for every other type.</returns>
    /// <remarks>
    /// N-gram processing works on uint keys (NgramBufferBuilder.AddNgrams takes a VBuffer of uint), so only
    /// key types that can be converted to U4 (uint) are accepted. ulong keys may exceed the uint range and
    /// are rejected.
    /// </remarks>
    public static bool IsValidNgramRawType(Type rawType) => rawType != typeof(ulong);
}
}
|