|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
// Loadable-class registration for WordHashBagProducingTransformer: it is exposed as an IDataTransform
// configured through WordHashBagProducingTransformer.Options. The trailing strings are the names
// under which the component is registered (presumably user-facing name first, then load names).
[assembly: LoadableClass(WordHashBagProducingTransformer.Summary, typeof(IDataTransform), typeof(WordHashBagProducingTransformer), typeof(WordHashBagProducingTransformer.Options), typeof(SignatureDataTransform),
"Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")]
// Loadable-class registration for NgramHashExtractingTransformer: it is exposed as an
// INgramExtractorFactory configured through NgramHashExtractorArguments.
[assembly: LoadableClass(NgramHashExtractingTransformer.Summary, typeof(INgramExtractorFactory), typeof(NgramHashExtractingTransformer), typeof(NgramHashExtractingTransformer.NgramHashExtractorArguments),
typeof(SignatureNgramExtractorFactory), "Ngram Hash Extractor Transform", "NgramHashExtractorTransform", "NgramHash", NgramHashExtractingTransformer.LoaderSignature)]
// Declares NgramHashExtractorArguments as an entry-point module.
[assembly: EntryPointModule(typeof(NgramHashExtractingTransformer.NgramHashExtractorArguments))]
namespace Microsoft.ML.Transforms.Text
{
internal static class WordHashBagProducingTransformer
{
/// <summary>
/// Column definition of the word hash bag transform: one output column produced from one or more
/// source columns, with the optional per-column overrides declared on
/// <see cref="NgramHashExtractingTransformer.ColumnBase"/>.
/// </summary>
internal sealed class Column : NgramHashExtractingTransformer.ColumnBase
{
    /// <summary>
    /// Parses the short form of a column definition.
    /// </summary>
    /// <param name="str">The short-form text; must be non-empty.</param>
    /// <returns>The parsed column, or <see langword="null"/> when <paramref name="str"/> is not a valid short form.</returns>
    internal static Column Parse(string str)
    {
        Contracts.AssertNonEmpty(str);

        var res = new Column();
        return res.TryParse(str) ? res : null;
    }

    /// <summary>
    /// Fills this instance from the short form of a column definition.
    /// </summary>
    /// <param name="str">The short-form text; must be non-empty.</param>
    /// <returns><see langword="true"/> when parsing succeeded.</returns>
    private protected override bool TryParse(string str)
    {
        Contracts.AssertNonEmpty(str);

        // We accept N:B:S where N is the new column name, B is the number of bits,
        // and S is source column names.
        if (!base.TryParse(str, out string extra))
            return false;

        // No middle part: only the name and the sources were given.
        if (extra == null)
            return true;

        if (!int.TryParse(extra, out int bits))
            return false;

        NumberOfBits = bits;
        return true;
    }

    /// <summary>
    /// Writes the short form of this column definition to <paramref name="sb"/>, if it has one.
    /// </summary>
    /// <param name="sb">The builder that receives the short form.</param>
    /// <returns>
    /// <see langword="false"/> when the column carries a setting the short form cannot express;
    /// otherwise the result of the base unparse.
    /// </returns>
    internal bool TryUnparse(StringBuilder sb)
    {
        Contracts.AssertValue(sb);

        // The short form only carries the name, the number of bits and the sources. Any other
        // per-column override (including UseAllLengths) would be lost, so refuse the short form
        // whenever one of them is set.
        if (NgramLength != null || SkipLength != null || Seed != null ||
            Ordered != null || MaximumNumberOfInverts != null || UseAllLengths != null)
        {
            return false;
        }

        if (NumberOfBits == null)
            return TryUnparseCore(sb);

        string extra = NumberOfBits.Value.ToString();
        return TryUnparseCore(sb, extra);
    }
}
/// <summary>
/// Arguments of the word hash bag transform: the transform-level n-gram hashing settings inherited from
/// <see cref="NgramHashExtractingTransformer.ArgumentsBase"/> plus the column definitions.
/// </summary>
internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase
{
// Output column definitions. Marked Required; CreateTransformer additionally checks that at
// least one column is present.
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:numberOfBits:srcs)",
Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
}
// Name under which CreateTransformer registers its host with the environment.
private const string RegistrationName = "WordHashBagTransform";
// Human-readable description of the transform; also passed to the assembly-level LoadableClass registration.
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. "
+ "It does so by hashing each n-gram and using the hash value as the index in the bag.";
/// <summary>
/// Builds the transformer chain of the word hash bag transform: tokenize every source column into a
/// temporary column, hash the n-grams of the tokens into the requested output columns, then drop the
/// temporary columns.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="options">The transform settings; must define at least one column.</param>
/// <param name="input">The data the tokenizer is fitted on.</param>
/// <returns>A chain of the tokenizing, n-gram hashing and column-dropping transformers.</returns>
internal static ITransformer CreateTransformer(IHostEnvironment env, Options options, IDataView input)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(RegistrationName);
    h.CheckValue(options, nameof(options));
    h.CheckValue(input, nameof(input));
    h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

    // Every source column is tokenized into its own uniquely named temporary column. Because one
    // output column may have several sources, a single Column can produce several tokenizer
    // columns. The n-gram hash extractor then consumes the temporary columns, and a final
    // column-selecting transformer removes them again.
    IDataView view = input;
    var uniqueSourceNames = NgramExtractionUtils.GenerateUniqueSourceNames(h, options.Columns, view.Schema);
    Contracts.Assert(uniqueSourceNames.Length == options.Columns.Length);

    int columnCount = options.Columns.Length;
    var tokenizerColumns = new List<WordTokenizingEstimator.ColumnOptions>();
    var hashExtractorColumns = new NgramHashExtractingTransformer.Column[columnCount];
    var intermediateNames = new List<string>();

    for (int c = 0; c < columnCount; c++)
    {
        var column = options.Columns[c];
        int sourceCount = column.Source.Length;
        var tokenizedNames = new string[sourceCount];
        Contracts.Assert(uniqueSourceNames[c].Length == column.Source.Length);

        for (int s = 0; s < sourceCount; s++)
        {
            tokenizedNames[s] = uniqueSourceNames[c][s];
            tokenizerColumns.Add(new WordTokenizingEstimator.ColumnOptions(tokenizedNames[s], column.Source[s]));
        }
        intermediateNames.AddRange(tokenizedNames);

        // The extractor reads the temporary columns but keeps the original source names as
        // friendly names.
        var extractorColumn = new NgramHashExtractingTransformer.Column();
        extractorColumn.Name = column.Name;
        extractorColumn.Source = tokenizedNames;
        extractorColumn.NumberOfBits = column.NumberOfBits;
        extractorColumn.NgramLength = column.NgramLength;
        extractorColumn.Seed = column.Seed;
        extractorColumn.SkipLength = column.SkipLength;
        extractorColumn.Ordered = column.Ordered;
        extractorColumn.MaximumNumberOfInverts = column.MaximumNumberOfInverts;
        extractorColumn.FriendlyNames = column.Source;
        extractorColumn.UseAllLengths = column.UseAllLengths;
        hashExtractorColumns[c] = extractorColumn;
    }

    ITransformer tokenizer = new WordTokenizingEstimator(env, tokenizerColumns.ToArray()).Fit(view);

    var extractorOptions = new NgramHashExtractingTransformer.Options();
    extractorOptions.UseAllLengths = options.UseAllLengths;
    extractorOptions.NumberOfBits = options.NumberOfBits;
    extractorOptions.NgramLength = options.NgramLength;
    extractorOptions.SkipLength = options.SkipLength;
    extractorOptions.Ordered = options.Ordered;
    extractorOptions.Seed = options.Seed;
    extractorOptions.Columns = hashExtractorColumns.ToArray();
    extractorOptions.MaximumNumberOfInverts = options.MaximumNumberOfInverts;

    // The extractor has to be created against the tokenized view, which contains the temporary columns.
    view = tokenizer.Transform(view);
    ITransformer extractor = NgramHashExtractingTransformer.Create(h, extractorOptions, view);

    // Since we added columns with new names, we need to explicitly drop them again.
    ITransformer dropTemporaries = new ColumnSelectingTransformer(env, null, intermediateNames.ToArray());

    return new TransformerChain<ITransformer>(new[] { tokenizer, extractor, dropTemporaries });
}
/// <summary>
/// Builds the word hash bag transformer chain for <paramref name="options"/> and applies it to
/// <paramref name="input"/>.
/// </summary>
/// <returns>The transformed data, cast to <see cref="IDataTransform"/>.</returns>
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
    ITransformer chain = CreateTransformer(env, options, input);
    return (IDataTransform)chain.Transform(input);
}
}
/// <summary>
/// A transform that turns a collection of tokenized text (vector of ReadOnlyMemory) into numerical feature vectors
/// using the hashing trick.
/// </summary>
internal static class NgramHashExtractingTransformer
{
/// <summary>
/// Per-column settings shared by the column types of the n-gram hash extractor and the word hash bag
/// transform. Every setting is nullable; when a value is null, the extractor's Create method falls
/// back to the corresponding transform-level option.
/// </summary>
internal abstract class ColumnBase : ManyToOneColumn
{
// Per-column override of the n-gram length.
[Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length (stores all lengths up to the specified Ngram length)", ShortName = "ngram")]
public int? NgramLength;
// Per-column override of the skip length.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Maximum number of tokens to skip when constructing an n-gram",
ShortName = "skips")]
public int? SkipLength;
// Per-column override of the hash width; this is the only override the short column form can carry.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Number of bits to hash into. Must be between 1 and 30, inclusive.",
ShortName = "bits")]
public int? NumberOfBits;
// Per-column override of the hashing seed.
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
public uint? Seed;
// Per-column override of ordered hashing.
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each source column should be included in the hash (when there are multiple source columns).", ShortName = "ord")]
public bool? Ordered;
// Per-column override of the invert-hashing limit.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
ShortName = "ih")]
public int? MaximumNumberOfInverts;
// Per-column override of whether shorter n-grams are included as well.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength),
Name = "AllLengths", ShortName = "all", SortOrder = 4)]
public bool? UseAllLengths;
}
/// <summary>
/// Column definition of the n-gram hash extractor: one output column produced from one or more
/// tokenized source columns, with the optional per-column overrides declared on
/// <see cref="ColumnBase"/>.
/// </summary>
internal sealed class Column : ColumnBase
{
    // For all source columns, use these friendly names for the source
    // column names instead of the real column names.
    public string[] FriendlyNames;

    /// <summary>
    /// Parses the short form of a column definition.
    /// </summary>
    /// <param name="str">The short-form text; must be non-empty.</param>
    /// <returns>The parsed column, or <see langword="null"/> when <paramref name="str"/> is not a valid short form.</returns>
    internal static Column Parse(string str)
    {
        Contracts.AssertNonEmpty(str);

        var res = new Column();
        return res.TryParse(str) ? res : null;
    }

    /// <summary>
    /// Fills this instance from the short form of a column definition.
    /// </summary>
    /// <param name="str">The short-form text; must be non-empty.</param>
    /// <returns><see langword="true"/> when parsing succeeded.</returns>
    private protected override bool TryParse(string str)
    {
        Contracts.AssertNonEmpty(str);

        // We accept N:B:S where N is the new column name, B is the number of bits,
        // and S is source column names.
        if (!base.TryParse(str, out string extra))
            return false;

        // No middle part: only the name and the sources were given.
        if (extra == null)
            return true;

        if (!int.TryParse(extra, out int bits))
            return false;

        NumberOfBits = bits;
        return true;
    }

    /// <summary>
    /// Writes the short form of this column definition to <paramref name="sb"/>, if it has one.
    /// </summary>
    /// <param name="sb">The builder that receives the short form.</param>
    /// <returns>
    /// <see langword="false"/> when the column carries a setting the short form cannot express;
    /// otherwise the result of the base unparse.
    /// </returns>
    internal bool TryUnparse(StringBuilder sb)
    {
        Contracts.AssertValue(sb);

        // The short form only carries the name, the number of bits and the sources. Any other
        // per-column override (including UseAllLengths) would be lost, so refuse the short form
        // whenever one of them is set.
        if (NgramLength != null || SkipLength != null || Seed != null ||
            Ordered != null || MaximumNumberOfInverts != null || UseAllLengths != null)
        {
            return false;
        }

        if (NumberOfBits == null)
            return TryUnparseCore(sb);

        string extra = NumberOfBits.Value.ToString();
        return TryUnparseCore(sb, extra);
    }
}
/// <summary>
/// Transform-level settings shared by the n-gram hash extractor and the word hash bag transform.
/// Each value applies to every column that does not override it.
/// </summary>
/// <remarks>
/// This class combines settings of <see cref="HashingTransformer.Options"/> and
/// <see cref="NgramHashingTransformer.Options"/>. It declares both the ordered option and the
/// all-lengths option; it declares no rehashUnigrams option.
/// </remarks>
internal abstract class ArgumentsBase
{
// Default n-gram length.
[Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram", SortOrder = 3)]
public int NgramLength = 1;
// Default skip length.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Maximum number of tokens to skip when constructing an n-gram",
ShortName = "skips", SortOrder = 4)]
public int SkipLength = 0;
// Default width of the n-gram hash.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Number of bits to hash into. Must be between 1 and 30, inclusive.",
ShortName = "bits", SortOrder = 2)]
public int NumberOfBits = 16;
// Default hashing seed.
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
public uint Seed = 314489979;
// Default for ordered hashing.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Whether the position of each source column should be included in the hash (when there are multiple source columns).",
ShortName = "ord")]
public bool Ordered = true;
// No initializer, so the default is 0, which the help text defines as "no invert hashing".
[Argument(ArgumentType.AtMostOnce,
HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
ShortName = "ih")]
public int MaximumNumberOfInverts;
// Default for including all n-gram lengths up to NgramLength.
[Argument(ArgumentType.AtMostOnce,
HelpText = "Whether to include all n-gram lengths up to ngramLength or only ngramLength",
Name = "AllLengths", ShortName = "all", SortOrder = 4)]
public bool UseAllLengths = true;
}
/// <summary>
/// Compile-time constants holding the same values as the field defaults declared on
/// <see cref="ArgumentsBase"/>. The two sets of values are maintained separately, so they must be
/// kept in sync by hand.
/// </summary>
internal static class DefaultArguments
{
public const int NgramLength = 1;
public const int SkipLength = 0;
public const int NumberOfBits = 16;
public const uint Seed = 314489979;
public const bool Ordered = true;
// 0 means no invert hashing (see the MaximumNumberOfInverts help text).
public const int MaximumNumberOfInverts = 0;
public const bool UseAllLengths = true;
}
/// <summary>
/// Component arguments of the n-gram hash extractor. Carries the settings of
/// <see cref="ArgumentsBase"/> and acts as the factory for the matching
/// <see cref="INgramExtractorFactory"/>.
/// </summary>
[TlcModule.Component(Name = "NGramHash", FriendlyName = "NGram Hash Extractor Transform", Alias = "NGramHashExtractorTransform,NGramHashExtractor",
    Desc = "Extracts NGrams from text and convert them to vector using hashing trick.")]
internal sealed class NgramHashExtractorArguments : ArgumentsBase, INgramExtractorFactoryFactory
{
    /// <summary>
    /// Creates the extractor factory for these arguments by delegating to the enclosing class's
    /// <c>Create</c> overload that takes term loader arguments.
    /// </summary>
    public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderArguments loaderArgs)
        => Create(env, this, loaderArgs);
}
/// <summary>
/// Arguments of the n-gram hash extractor: the transform-level settings of
/// <see cref="ArgumentsBase"/> plus the column definitions.
/// </summary>
internal sealed class Options : ArgumentsBase
{
// Output column definitions. Not marked Required here, but Create rejects an empty set.
// NOTE(review): the help text advertises "name:srcs", while Column.TryParse also accepts a
// number of bits between the name and the sources.
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)]
public Column[] Columns;
}
// Human-readable description of the transform; also passed to the assembly-level LoadableClass registration.
internal const string Summary = "A transform that turns a collection of tokenized text (vector of ReadOnlyMemory) into numerical feature vectors using the hashing trick.";
// Used as a load name in the LoadableClass registration and as the name under which the Create
// methods register their host.
internal const string LoaderSignature = "NgramHashExtractor";
/// <summary>
/// Builds the transformer chain that turns tokenized text columns into hashed n-gram feature vectors.
/// </summary>
/// <remarks>
/// Every source column is first hashed into a temporary column, the temporary columns of each
/// output column are then n-gram hashed into that output column, and finally the temporary
/// columns are dropped. Per-column settings take precedence over the transform-level ones.
/// </remarks>
/// <param name="env">The host environment.</param>
/// <param name="options">The transform settings; must define at least one column.</param>
/// <param name="input">The data the hashing and n-gram hashing estimators are fitted on.</param>
/// <returns>A chain of the hashing, n-gram hashing and column-dropping transformers.</returns>
internal static ITransformer Create(IHostEnvironment env, Options options, IDataView input)
{
    // Width, in bits, of the per-token hash that precedes the n-gram hash. It is fixed and
    // independent of the NumberOfBits option; a wide hash is used to minimize collisions.
    const int tokenHashBits = 30;

    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(LoaderSignature);
    h.CheckValue(options, nameof(options));
    h.CheckValue(input, nameof(input));
    h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

    var chain = new TransformerChain<ITransformer>();

    // To each input column of the extractor, a hashing transform using tokenHashBits bits
    // is applied first, followed by an n-gram hashing transform.
    var hashColumns = new List<HashingEstimator.ColumnOptions>();
    var ngramHashColumns = new NgramHashingEstimator.ColumnOptions[options.Columns.Length];

    var colCount = options.Columns.Length;

    // The NGramHashExtractor has a ManyToOne column type. To avoid stepping over the source
    // column name when a 'name' destination column name was specified, we use temporary column names.
    string[][] tmpColNames = new string[colCount][];
    for (int iinfo = 0; iinfo < colCount; iinfo++)
    {
        var column = options.Columns[iinfo];
        h.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
        h.CheckUserArg(Utils.Size(column.Source) > 0 &&
            column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));

        int srcCount = column.Source.Length;
        tmpColNames[iinfo] = new string[srcCount];
        for (int isrc = 0; isrc < srcCount; isrc++)
        {
            var tmpName = input.Schema.GetTempColumnName(column.Source[isrc]);
            tmpColNames[iinfo][isrc] = tmpName;

            hashColumns.Add(new HashingEstimator.ColumnOptions(tmpName, column.Source[isrc],
                tokenHashBits, column.Seed ?? options.Seed, false, column.MaximumNumberOfInverts ?? options.MaximumNumberOfInverts));
        }

        ngramHashColumns[iinfo] =
            new NgramHashingEstimator.ColumnOptions(column.Name, tmpColNames[iinfo],
                column.NgramLength ?? options.NgramLength,
                column.SkipLength ?? options.SkipLength,
                column.UseAllLengths ?? options.UseAllLengths,
                column.NumberOfBits ?? options.NumberOfBits,
                column.Seed ?? options.Seed,
                column.Ordered ?? options.Ordered,
                column.MaximumNumberOfInverts ?? options.MaximumNumberOfInverts);
        ngramHashColumns[iinfo].FriendlyNames = column.FriendlyNames;
    }

    var hashing = new HashingEstimator(h, hashColumns.ToArray()).Fit(input);

    // The n-gram hashing estimator is fitted on the hashed data, which contains the temporary
    // columns; the last transformer removes those temporary columns from the output.
    return chain.Append(hashing)
        .Append(new NgramHashingEstimator(h, ngramHashColumns).Fit(hashing.Transform(input)))
        .Append(new ColumnSelectingTransformer(h, null, tmpColNames.SelectMany(cols => cols).ToArray()));
}
/// <summary>
/// Builds the n-gram hash extractor chain for a set of extractor columns, applying the settings in
/// <paramref name="extractorArgs"/> to every column.
/// </summary>
/// <param name="extractorArgs">The transform-level settings; the skip length must be less than the n-gram length.</param>
/// <param name="env">The host environment.</param>
/// <param name="input">The data the chain is fitted on.</param>
/// <param name="cols">The columns to extract; at least one is required.</param>
/// <returns>The transformer chain produced by the <see cref="Options"/>-based overload.</returns>
internal static ITransformer Create(NgramHashExtractorArguments extractorArgs, IHostEnvironment env, IDataView input,
    ExtractorColumn[] cols)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(LoaderSignature);
    host.CheckValue(extractorArgs, nameof(extractorArgs));
    host.CheckValue(input, nameof(input));
    host.CheckUserArg(extractorArgs.SkipLength < extractorArgs.NgramLength, nameof(extractorArgs.SkipLength), "Should be less than " + nameof(extractorArgs.NgramLength));
    host.CheckUserArg(Utils.Size(cols) > 0, nameof(Options.Columns), "Must be specified");

    // No per-column overrides are set, so every column uses the transform-level settings below.
    Column[] columns = cols
        .Select(col => new Column
        {
            Name = col.Name,
            Source = col.Source,
            FriendlyNames = col.FriendlyNames
        })
        .ToArray();

    var options = new Options();
    options.Columns = columns;
    options.NgramLength = extractorArgs.NgramLength;
    options.SkipLength = extractorArgs.SkipLength;
    options.NumberOfBits = extractorArgs.NumberOfBits;
    options.MaximumNumberOfInverts = extractorArgs.MaximumNumberOfInverts;
    options.Ordered = extractorArgs.Ordered;
    options.Seed = extractorArgs.Seed;
    options.UseAllLengths = extractorArgs.UseAllLengths;

    return Create(host, options, input);
}
/// <summary>
/// Creates the extractor factory for the n-gram hash extractor.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="extractorArgs">The extractor settings handed to the factory.</param>
/// <param name="termLoaderArgs">Must be <see langword="null"/>; term loader arguments are rejected by this extractor.</param>
/// <returns>A new <see cref="NgramHashExtractorFactory"/> for <paramref name="extractorArgs"/>.</returns>
internal static INgramExtractorFactory Create(IHostEnvironment env, NgramHashExtractorArguments extractorArgs,
    TermLoaderArguments termLoaderArgs)
{
    Contracts.CheckValue(env, nameof(env));

    var host = env.Register(LoaderSignature);
    host.CheckValue(extractorArgs, nameof(extractorArgs));

    bool noTermLoader = termLoaderArgs == null;
    host.CheckParam(noTermLoader, nameof(termLoaderArgs), "Argument cannot be used with NgramHashExtractor, use NgramExtractor instead");

    return new NgramHashExtractorFactory(extractorArgs);
}
}
}
|