|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
using static Microsoft.ML.Transforms.Text.WordBagBuildingTransformer;
[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
[assembly: LoadableClass(NgramExtractorTransform.Summary, typeof(INgramExtractorFactory), typeof(NgramExtractorTransform), typeof(NgramExtractorTransform.NgramExtractorArguments),
typeof(SignatureNgramExtractorFactory), "Ngram Extractor Transform", "NgramExtractorTransform", "Ngram", NgramExtractorTransform.LoaderSignature)]
[assembly: EntryPointModule(typeof(NgramExtractorTransform.NgramExtractorArguments))]
// These are for the internal-only TextExpandingTransformer. Not exposed publicly.
[assembly: LoadableClass(TextExpandingTransformer.Summary, typeof(IDataTransform), typeof(TextExpandingTransformer), null, typeof(SignatureLoadDataTransform),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
[assembly: LoadableClass(typeof(TextExpandingTransformer), null, typeof(SignatureLoadModel),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
[assembly: LoadableClass(typeof(IRowMapper), typeof(TextExpandingTransformer), null, typeof(SignatureLoadRowMapper),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
namespace Microsoft.ML.Transforms.Text
{
/// <summary>
/// Signature for creating an <see cref="INgramExtractorFactory"/>.
/// Loadable components matching this signature receive the optional term-loader arguments.
/// </summary>
internal delegate void SignatureNgramExtractorFactory(TermLoaderArguments termLoaderArgs);
/// <summary>
/// A many-to-one column common to both <see cref="NgramExtractorTransform"/>
/// and <see cref="NgramHashExtractingTransformer"/>.
/// </summary>
internal sealed class ExtractorColumn : ManyToOneColumn
{
    // For all source columns, use these friendly names for the source
    // column names instead of the real column names (e.g. when the real
    // sources are auto-generated temporary columns).
    public string[] FriendlyNames;
}
internal static class WordBagBuildingTransformer
{
internal sealed class Column : ManyToOneColumn
{
    [Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram")]
    public int? NgramLength;

    [Argument(ArgumentType.AtMostOnce,
        HelpText = "Maximum number of tokens to skip when constructing an n-gram",
        ShortName = "skips")]
    public int? SkipLength;

    [Argument(ArgumentType.AtMostOnce,
        HelpText = "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength),
        Name = "AllLengths", ShortName = "all")]
    public bool? UseAllLengths;

    [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
    public int[] MaxNumTerms = null;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Statistical measure used to evaluate how important a word is to a document in a corpus")]
    public NgramExtractingEstimator.WeightingCriteria? Weighting;

    // Parses the command-line form of a column definition; null when parsing fails.
    internal static Column Parse(string str)
    {
        Contracts.AssertNonEmpty(str);
        var column = new Column();
        return column.TryParse(str) ? column : null;
    }

    // Emits the short form into sb. That is only possible when no per-column
    // option is set; otherwise the long form is required and this returns false.
    internal bool TryUnparse(StringBuilder sb)
    {
        Contracts.AssertValue(sb);
        bool hasColumnOptions = NgramLength != null || SkipLength != null || UseAllLengths != null ||
            Utils.Size(MaxNumTerms) > 0 || Weighting != null;
        return !hasColumnOptions && TryUnparseCore(sb);
    }
}
// Full argument set for the word-bag transform: the shared n-gram settings from
// ArgumentsBase plus the many-to-one column mappings.
internal sealed class Options : NgramExtractorTransform.ArgumentsBase
{
    [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)]
    public Column[] Columns;
}
// Name under which the host registers this component.
private const string RegistrationName = "WordBagTransform";

// Summary shown in component catalogs / command-line help.
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
    + "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";
/// <summary>
/// Builds the word-bag pipeline: an optional concat step (for many-to-one columns),
/// an optional text-expanding step, a word tokenizer, and finally n-gram extraction.
/// </summary>
internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(RegistrationName);
    host.CheckValue(options, nameof(options));
    host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

    // Compose the WordBagTransform from a tokenize transform,
    // followed by a NgramExtractionTransform.
    // Since WordBagTransform is a many-to-one column transform, for each
    // WordBagTransform.Column with multiple sources, we first apply a ConcatTransform.
    // REVIEW: In order to not get n-grams that cross between vector slots, we need to
    // enable tokenize transforms to insert a special token between slots.
    // REVIEW: In order to make it possible to output separate bags for different columns
    // using the same dictionary, we need to find a way to make ConcatTransform remember the boundaries.
    int columnCount = options.Columns.Length;
    var tokenizeColumns = new WordTokenizingEstimator.ColumnOptions[columnCount];
    var ngramArgs = new NgramExtractorTransform.Options()
    {
        MaxNumTerms = options.MaxNumTerms,
        NgramLength = options.NgramLength,
        SkipLength = options.SkipLength,
        UseAllLengths = options.UseAllLengths,
        Weighting = options.Weighting,
        Columns = new NgramExtractorTransform.Column[columnCount]
    };

    for (int i = 0; i < columnCount; i++)
    {
        var col = options.Columns[i];
        host.CheckUserArg(!string.IsNullOrWhiteSpace(col.Name), nameof(col.Name));
        host.CheckUserArg(Utils.Size(col.Source) > 0, nameof(col.Source));
        host.CheckUserArg(col.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(col.Source));

        // With multiple sources the tokenizer reads from the concatenated column
        // (which carries the output name); with one source it reads that source directly.
        tokenizeColumns[i] = new WordTokenizingEstimator.ColumnOptions(col.Name, col.Source.Length > 1 ? col.Name : col.Source[0]);
        ngramArgs.Columns[i] = new NgramExtractorTransform.Column()
        {
            Name = col.Name,
            Source = col.Name,
            MaxNumTerms = col.MaxNumTerms,
            NgramLength = col.NgramLength,
            SkipLength = col.SkipLength,
            Weighting = col.Weighting,
            UseAllLengths = col.UseAllLengths,
        };
    }

    IEstimator<ITransformer> estimator = NgramExtractionUtils.GetConcatEstimator(host, options.Columns);
    if (options.FreqSeparator != default)
    {
        // Expand "term:freq" encoded text before tokenization. NOTE(review): this is
        // applied only to the first column's input column — confirm that is intended.
        estimator = estimator.Append(new TextExpandingEstimator(host, tokenizeColumns[0].InputColumnName, options.FreqSeparator, options.TermSeparator));
    }
    estimator = estimator.Append(new WordTokenizingEstimator(host, tokenizeColumns));
    estimator = estimator.Append(NgramExtractorTransform.CreateEstimator(host, ngramArgs, estimator.GetOutputSchema(inputSchema)));
    return estimator;
}
// Fits the word-bag estimator on the input and applies it, yielding an IDataTransform.
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
    var transformer = CreateEstimator(env, options, SchemaShape.Create(input.Schema)).Fit(input);
    return (IDataTransform)transformer.Transform(input);
}
#region TextExpander
// Internal-only estimator used to facilitate the expansion of n-grams with pre-defined weights.
internal sealed class TextExpandingEstimator : TrivialEstimator<TextExpandingTransformer>
{
    private readonly string _columnName;

    /// <summary>
    /// Creates an estimator that expands "term:freq"-encoded text in <paramref name="columnName"/>,
    /// repeating each term the number of times given by its frequency.
    /// </summary>
    public TextExpandingEstimator(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingEstimator)), new TextExpandingTransformer(env, columnName, freqSeparator, termSeparator))
    {
        _columnName = columnName;
    }

    public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        Host.CheckValue(inputSchema, nameof(inputSchema));
        // The column must exist AND be of text type. The original condition used '&&',
        // which made the type check unreachable: whenever the column was found the
        // whole condition was already false, so non-text columns slipped through.
        if (!inputSchema.TryFindColumn(_columnName, out SchemaShape.Column outCol) || outCol.ItemType != TextDataViewType.Instance)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columnName);
        }
        // Schema is unchanged: the text column is rewritten in place with the same type.
        return inputSchema;
    }
}
// Internal-only transformer used to facilitate the expansion of n-grams with pre-defined weights.
// Rewrites a text column of "term<freqSep>count" pairs (pairs joined by termSep) into plain
// space-separated text where each term is repeated 'count' times.
internal sealed class TextExpandingTransformer : RowToRowTransformerBase
{
    internal const string Summary = "Expands text in the format of term:freq; to have the correct number of terms";
    internal const string UserName = "Text Expanding Transform";
    internal const string LoadName = "TextExpand";
    internal const string LoaderSignature = "TextExpandTransform";

    private readonly string _columnName;
    private readonly char _freqSeparator;
    private readonly char _termSeparator;

    public TextExpandingTransformer(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
    {
        _columnName = columnName;
        _freqSeparator = freqSeparator;
        _termSeparator = termSeparator;
    }

    private static VersionInfo GetVersionInfo()
    {
        return new VersionInfo(
            modelSignature: "TEXT EXP",
            verWrittenCur: 0x00010001, // Initial
            verReadableCur: 0x00010001,
            verWeCanReadBack: 0x00010001,
            loaderSignature: LoaderSignature,
            loaderAssemblyName: typeof(TextExpandingTransformer).Assembly.FullName);
    }

    /// <summary>
    /// Factory method for SignatureLoadModel.
    /// </summary>
    // Fix: this constructor previously registered under nameof(ColumnConcatenatingTransformer),
    // a copy-paste slip; it now uses this transformer's own name, matching the public constructor.
    private TextExpandingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
        base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
    {
        Host.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel(GetVersionInfo());

        // *** Binary format ***
        // string: column name
        // char: frequency separator
        // char: term separator
        _columnName = ctx.Reader.ReadString();
        _freqSeparator = ctx.Reader.ReadChar();
        _termSeparator = ctx.Reader.ReadChar();
    }

    /// <summary>
    /// Factory method for SignatureLoadRowMapper.
    /// </summary>
    private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
        => new TextExpandingTransformer(env, ctx).MakeRowMapper(inputSchema);

    /// <summary>
    /// Factory method for SignatureLoadDataTransform.
    /// </summary>
    private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        => new TextExpandingTransformer(env, ctx).MakeDataTransform(input);

    private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
    {
        return new Mapper(Host, schema, this);
    }

    private protected override void SaveModel(ModelSaveContext ctx)
    {
        Host.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel();
        ctx.SetVersionInfo(GetVersionInfo());

        // *** Binary format ***
        // string: column name
        // char: frequency separator
        // char: term separator
        ctx.Writer.Write(_columnName);
        ctx.Writer.Write(_freqSeparator);
        ctx.Writer.Write(_termSeparator);
    }

    private sealed class Mapper : MapperBase
    {
        private readonly TextExpandingTransformer _parent;

        public Mapper(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
            : base(host, inputSchema, parent)
        {
            _parent = (TextExpandingTransformer)parent;
        }

        protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
        {
            // One output column: the expanded text under the same column name.
            return new DataViewSchema.DetachedColumn[]
            {
                new DataViewSchema.DetachedColumn(_parent._columnName, TextDataViewType.Instance)
            };
        }

        protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
        {
            disposer = null;
            ValueGetter<ReadOnlyMemory<char>> srcGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema.GetColumnOrNull(_parent._columnName).Value);
            ReadOnlyMemory<char> inputMem = default;
            var sb = new StringBuilder();
            ValueGetter<ReadOnlyMemory<char>> result = (ref ReadOnlyMemory<char> dst) =>
            {
                sb.Clear();
                srcGetter(ref inputMem);
                var inputText = inputMem.ToString();
                foreach (var termFreq in inputText.Split(_parent._termSeparator))
                {
                    var tf = termFreq.Split(_parent._freqSeparator);
                    if (tf.Length != 2)
                    {
                        // Not a well-formed "term:freq" pair: emit the term once.
                        sb.Append(tf[0]).Append(' ');
                    }
                    else
                    {
                        // Fix: parse the frequency once instead of re-parsing it in the
                        // loop condition on every iteration. NOTE(review): int.Parse still
                        // throws on a non-numeric frequency; assumed acceptable for this
                        // internal-only transform.
                        int freq = int.Parse(tf[1]);
                        for (int i = 0; i < freq; i++)
                            sb.Append(tf[0]).Append(' ');
                    }
                }
                dst = sb.ToString().AsMemory();
            };
            return result;
        }

        private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
        {
            // Only the source text column is required, and only when the output is active.
            var active = new bool[InputSchema.Count];
            if (activeOutput(0))
            {
                active[InputSchema.GetColumnOrNull(_parent._columnName).Value.Index] = true;
            }
            return col => active[col];
        }

        private protected override void SaveModel(ModelSaveContext ctx)
        {
            _parent.SaveModel(ctx);
        }
    }
}
#endregion TextExpander
}
/// <summary>
/// A transform that turns a collection of tokenized text (vector of ReadOnlyMemory), or vectors of keys into numerical
/// feature vectors. The feature vectors are counts of n-grams (sequences of consecutive *tokens* -words or keys-
/// of length 1-n).
/// </summary>
internal static class NgramExtractorTransform
{
internal sealed class Column : OneToOneColumn
{
    [Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length (stores all lengths up to the specified Ngram length)", ShortName = "ngram")]
    public int? NgramLength;

    [Argument(ArgumentType.AtMostOnce,
        HelpText = "Maximum number of tokens to skip when constructing an n-gram",
        ShortName = "skips")]
    public int? SkipLength;

    [Argument(ArgumentType.AtMostOnce, HelpText =
        "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength),
        Name = "AllLengths", ShortName = "all")]
    public bool? UseAllLengths;

    // REVIEW: This argument is confusing. If only one value is specified it is applied
    // to every n-gram length (e.g. with 3 n-gram lengths the total is maxNumTerms * 3).
    // The term transform also picks only the first value from this array, so specifying
    // something like 1,1,10000 runs the term transform limited to a single term.
    [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
    public int[] MaxNumTerms = null;

    [Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
    public NgramExtractingEstimator.WeightingCriteria? Weighting;

    // Parses the command-line form of a column definition; null when parsing fails.
    internal static Column Parse(string str)
    {
        Contracts.AssertNonEmpty(str);
        var column = new Column();
        return column.TryParse(str) ? column : null;
    }

    // Emits the short form into sb. That is only possible when no per-column
    // option is set; otherwise the long form is required and this returns false.
    internal bool TryUnparse(StringBuilder sb)
    {
        Contracts.AssertValue(sb);
        bool hasColumnOptions = NgramLength != null || SkipLength != null || UseAllLengths != null ||
            Utils.Size(MaxNumTerms) > 0 || Weighting != null;
        return !hasColumnOptions && TryUnparseCore(sb);
    }
}
/// <summary>
/// This class is a merger of <see cref="ValueToKeyMappingTransformer.Options"/> and
/// <see cref="NgramExtractingTransformer.Options"/>, with the allLength option removed.
/// These are the transform-wide defaults; per-column values in <see cref="Column"/> override them.
/// </summary>
internal abstract class ArgumentsBase
{
    [Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram")]
    public int NgramLength = 1;

    [Argument(ArgumentType.AtMostOnce,
        HelpText = "Maximum number of tokens to skip when constructing an n-gram",
        ShortName = "skips")]
    public int SkipLength = NgramExtractingEstimator.Defaults.SkipLength;

    [Argument(ArgumentType.AtMostOnce,
        HelpText = "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength),
        Name = "AllLengths", ShortName = "all")]
    public bool UseAllLengths = NgramExtractingEstimator.Defaults.UseAllLengths;

    [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
    public int[] MaxNumTerms = new int[] { NgramExtractingEstimator.Defaults.MaximumNgramsCount };

    [Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
    public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;

    // default ('\0') means "not specified". WordBagBuildingTransformer.CreateEstimator only
    // inserts the text-expanding step when FreqSeparator differs from default.
    [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms/frequency pairs.")]
    public char TermSeparator = default;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms from their frequency.")]
    public char FreqSeparator = default;
}
[TlcModule.Component(Name = "NGram", FriendlyName = "NGram Extractor Transform", Alias = "NGramExtractorTransform,NGramExtractor",
    Desc = "Extracts NGrams from text and convert them to vector using dictionary.")]
public sealed class NgramExtractorArguments : ArgumentsBase, INgramExtractorFactoryFactory
{
    /// <summary>
    /// Builds an <see cref="INgramExtractorFactory"/> from these arguments and the
    /// optional term-loader arguments.
    /// </summary>
    public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderArguments loaderArgs)
        => Create(env, this, loaderArgs);
}
// Full argument set for the n-gram extractor: shared settings plus one-to-one column mappings.
internal sealed class Options : ArgumentsBase
{
    [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
    public Column[] Columns;
}
// Summary shown in component catalogs / command-line help.
internal const string Summary = "A transform that turns a collection of tokenized text ReadOnlyMemory, or vectors of keys into numerical " +
    "feature vectors. The feature vectors are counts of n-grams (sequences of consecutive *tokens* -words or keys- of length 1-n).";

// Signature under which this extractor is registered and loaded.
internal const string LoaderSignature = "NgramExtractor";
/// <summary>
/// Builds the n-gram extraction pipeline: text columns are first mapped to keys via a
/// value-to-key transform (optionally seeded from a term list/file), then all columns
/// flow into the n-gram extracting estimator.
/// </summary>
internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema, TermLoaderArguments termLoaderArgs = null)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(LoaderSignature);
    host.CheckValue(options, nameof(options));
    host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

    var chain = new EstimatorChain<ITransformer>();

    // Identify which source columns carry text: only those require the term (value-to-key)
    // transform; key columns feed the n-gram transform directly. The n-gram transform
    // natively validates columns that are neither text nor keys.
    var textColumns = new List<Column>();
    var needsTermTransform = new bool[options.Columns.Length];
    for (int i = 0; i < options.Columns.Length; i++)
    {
        var col = options.Columns[i];
        host.CheckNonWhiteSpace(col.Name, nameof(col.Name));
        host.CheckNonWhiteSpace(col.Source, nameof(col.Source));
        if (inputSchema.TryFindColumn(col.Source, out var colShape) &&
            colShape.ItemType is TextDataViewType)
        {
            textColumns.Add(col);
            needsTermTransform[i] = true;
        }
    }

    if (textColumns.Count > 0)
    {
        var keyMappingColumns = new List<ValueToKeyMappingEstimator.ColumnOptionsBase>();
        string[] missingDropColumns = termLoaderArgs != null && termLoaderArgs.DropUnknowns ? new string[textColumns.Count] : null;

        for (int i = 0; i < textColumns.Count; i++)
        {
            var col = textColumns[i];
            // Per-column MaxNumTerms wins over the transform-wide setting. When neither is
            // given: unbounded if the dictionary comes from a term loader, default otherwise.
            int maxKeys = Utils.Size(col.MaxNumTerms) > 0 ? col.MaxNumTerms[0] :
                Utils.Size(options.MaxNumTerms) > 0 ? options.MaxNumTerms[0] :
                termLoaderArgs == null ? NgramExtractingEstimator.Defaults.MaximumNgramsCount : int.MaxValue;
            var colOptions = new ValueToKeyMappingEstimator.ColumnOptions(
                col.Name,
                col.Source,
                maximumNumberOfKeys: maxKeys,
                keyOrdinality: termLoaderArgs?.Sort ?? ValueToKeyMappingEstimator.KeyOrdinality.ByOccurrence);
            if (termLoaderArgs != null)
            {
                colOptions.Key = termLoaderArgs.Term;
                colOptions.Keys = termLoaderArgs.Terms;
            }
            keyMappingColumns.Add(colOptions);
            if (missingDropColumns != null)
                missingDropColumns[i] = col.Name;
        }

        // When a data file supplies the terms, load it into a key data view.
        IDataView keyData = null;
        if (termLoaderArgs?.DataFile != null)
        {
            using (var ch = env.Start("Create key data view"))
                keyData = ValueToKeyMappingTransformer.GetKeyDataViewOrNull(env, ch, termLoaderArgs.DataFile, termLoaderArgs.TermsColumn, termLoaderArgs.Loader, out var autoConvert);
        }

        chain = chain.Append<ITransformer>(new ValueToKeyMappingEstimator(host, keyMappingColumns.ToArray(), keyData));
        if (missingDropColumns != null)
            chain = chain.Append<ITransformer>(new MissingValueDroppingEstimator(host, missingDropColumns.Select(x => (x, x)).ToArray()));
    }

    var ngramColumns = new NgramExtractingEstimator.ColumnOptions[options.Columns.Length];
    for (int i = 0; i < options.Columns.Length; i++)
    {
        var col = options.Columns[i];
        ngramColumns[i] = new NgramExtractingEstimator.ColumnOptions(col.Name,
            col.NgramLength ?? options.NgramLength,
            col.SkipLength ?? options.SkipLength,
            col.UseAllLengths ?? options.UseAllLengths,
            col.Weighting ?? options.Weighting,
            col.MaxNumTerms ?? options.MaxNumTerms,
            // A term-transformed column was renamed to col.Name; otherwise read the raw source.
            needsTermTransform[i] ? col.Name : col.Source
        );
    }
    return chain.Append<ITransformer>(new NgramExtractingEstimator(env, ngramColumns));
}
/// <summary>
/// Fits the n-gram extraction pipeline on <paramref name="input"/> and applies it.
/// </summary>
internal static IDataTransform CreateDataTransform(IHostEnvironment env, Options options, IDataView input,
    TermLoaderArguments termLoaderArgs = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(input, nameof(input));
    // Removed a stale commented-out call that was spliced into the middle of this expression.
    // NOTE(review): 'as' yields null (rather than throwing) if the result is not an
    // IDataTransform — behavior preserved from the original; confirm callers expect that.
    return CreateEstimator(env, options, SchemaShape.Create(input.Schema), termLoaderArgs).Fit(input).Transform(input) as IDataTransform;
}
/// <summary>
/// Converts the component-style extractor arguments plus extractor columns into the
/// one-to-one <see cref="Options"/> form used by <see cref="CreateEstimator"/>.
/// </summary>
internal static Options CreateNgramExtractorOptions(NgramExtractorArguments extractorArgs, ExtractorColumn[] cols)
{
    var extractorCols = new Column[cols.Length];
    for (int i = 0; i < cols.Length; i++)
    {
        // This transform is one-to-one, so each extractor column must have exactly one source.
        Contracts.Check(Utils.Size(cols[i].Source) == 1, "too many source columns");
        extractorCols[i] = new Column { Name = cols[i].Name, Source = cols[i].Source[0] };
    }
    return new Options
    {
        Columns = extractorCols,
        NgramLength = extractorArgs.NgramLength,
        SkipLength = extractorArgs.SkipLength,
        UseAllLengths = extractorArgs.UseAllLengths,
        MaxNumTerms = extractorArgs.MaxNumTerms,
        Weighting = extractorArgs.Weighting
    };
}
// Validates the arguments and wraps them in a dictionary-based extractor factory.
internal static INgramExtractorFactory Create(IHostEnvironment env, NgramExtractorArguments extractorArgs,
    TermLoaderArguments termLoaderArgs)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(LoaderSignature);
    host.CheckValue(extractorArgs, nameof(extractorArgs));
    host.CheckValueOrNull(termLoaderArgs);
    return new NgramExtractorFactory(extractorArgs, termLoaderArgs);
}
}
/// <summary>
/// Arguments for defining custom list of terms or data file containing the terms.
/// The class includes a subset of <see cref="ValueToKeyMappingTransformer"/>'s arguments.
/// </summary>
internal sealed class TermLoaderArguments
{
    // Command-line surface: a single comma-separated string of terms.
    [Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of terms", Name = "Terms", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
    public string Term;

    // Entry-point surface: the same terms as an array.
    [Argument(ArgumentType.AtMostOnce, HelpText = "List of terms", Name = "Term", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
    public string[] Terms;

    [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Data file containing the terms", ShortName = "data", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
    public string DataFile;

    [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "<Auto>", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureDataLoader))]
    internal IComponentFactory<IMultiStreamSource, ILegacyDataLoader> Loader;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the text column containing the terms", ShortName = "termCol", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
    public string TermsColumn;

    [Argument(ArgumentType.AtMostOnce, HelpText = "How items should be ordered when vectorized. By default, they will be in the order encountered. " +
        "If by value, items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a').", SortOrder = 5)]
    public ValueToKeyMappingEstimator.KeyOrdinality Sort = ValueToKeyMappingEstimator.KeyOrdinality.ByOccurrence;

    [Argument(ArgumentType.AtMostOnce, HelpText = "Drop unknown terms instead of mapping them to NA term.", ShortName = "dropna", SortOrder = 6)]
    public bool DropUnknowns = false;
}
/// <summary>
/// An n-gram extractor factory interface to create an n-gram extractor transform.
/// </summary>
internal interface INgramExtractorFactory
{
    /// <summary>
    /// Whether the extractor transform created by this factory uses the hashing trick
    /// (by using <see cref="HashingTransformer"/> or <see cref="NgramHashingTransformer"/>, for example).
    /// </summary>
    bool UseHashingTrick { get; }

    /// <summary>
    /// Fits an extractor over <paramref name="input"/> for the given columns and returns the transformer.
    /// </summary>
    ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols);
}
// Component-kind registration so that n-gram extractor factories are discoverable
// by the component catalog under the "NgramExtractor" kind.
[TlcModule.ComponentKind("NgramExtractor")]
internal interface INgramExtractorFactoryFactory : IComponentFactory<TermLoaderArguments, INgramExtractorFactory> { }
/// <summary>
/// An implementation of <see cref="INgramExtractorFactory"/> to create <see cref="NgramExtractorTransform"/>.
/// </summary>
internal class NgramExtractorFactory : INgramExtractorFactory
{
    private readonly NgramExtractorTransform.NgramExtractorArguments _extractorArgs;
    private readonly TermLoaderArguments _termLoaderArgs;

    // Dictionary-based extraction: no hashing involved.
    public bool UseHashingTrick => false;

    public NgramExtractorFactory(NgramExtractorTransform.NgramExtractorArguments extractorArgs,
        TermLoaderArguments termLoaderArgs)
    {
        Contracts.CheckValue(extractorArgs, nameof(extractorArgs));
        Contracts.CheckValueOrNull(termLoaderArgs);
        _extractorArgs = extractorArgs;
        _termLoaderArgs = termLoaderArgs;
    }

    public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols)
    {
        var options = NgramExtractorTransform.CreateNgramExtractorOptions(_extractorArgs, cols);
        var estimator = NgramExtractorTransform.CreateEstimator(env, options, SchemaShape.Create(input.Schema), _termLoaderArgs);
        return estimator.Fit(input);
    }
}
/// <summary>
/// An implementation of <see cref="INgramExtractorFactory"/> to create <see cref="NgramHashExtractingTransformer"/>.
/// </summary>
internal class NgramHashExtractorFactory : INgramExtractorFactory
{
    private readonly NgramHashExtractingTransformer.NgramHashExtractorArguments _extractorArgs;

    // Hash-based extraction: the hashing trick is always used.
    public bool UseHashingTrick => true;

    public NgramHashExtractorFactory(NgramHashExtractingTransformer.NgramHashExtractorArguments extractorArgs)
    {
        Contracts.CheckValue(extractorArgs, nameof(extractorArgs));
        _extractorArgs = extractorArgs;
    }

    public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols)
        => NgramHashExtractingTransformer.Create(_extractorArgs, env, input, cols);
}
internal static class NgramExtractionUtils
{
    /// <summary>
    /// Builds a chain of concat estimators, one per column that has more than one source.
    /// Columns with a single source need no concatenation and add nothing to the chain.
    /// </summary>
    public static IEstimator<ITransformer> GetConcatEstimator(IHostEnvironment env, ManyToOneColumn[] columns)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(columns, nameof(columns));

        var chain = new EstimatorChain<ITransformer>();
        foreach (var column in columns)
        {
            env.CheckUserArg(column != null, nameof(WordBagBuildingTransformer.Options.Columns));
            env.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
            env.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source));
            env.CheckUserArg(column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));
            if (column.Source.Length > 1)
                chain = chain.Append<ITransformer>(new ColumnConcatenatingEstimator(env, column.Name, column.Source));
        }
        return chain;
    }

    /// <summary>
    /// Generates and returns unique names for columns source. Each element of the returned array is
    /// an array of unique source names per specific column.
    /// </summary>
    public static string[][] GenerateUniqueSourceNames(IHostEnvironment env, ManyToOneColumn[] columns, DataViewSchema schema)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(columns, nameof(columns));
        env.CheckValue(schema, nameof(schema));

        var uniqueNames = new string[columns.Length][];
        int counter = 0;
        for (int i = 0; i < columns.Length; i++)
        {
            var column = columns[i];
            env.CheckUserArg(column != null, nameof(WordHashBagProducingTransformer.Options.Columns));
            env.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
            env.CheckUserArg(Utils.Size(column.Source) > 0 &&
                column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));

            int sourceCount = column.Source.Length;
            uniqueNames[i] = new string[sourceCount];
            for (int j = 0; j < sourceCount; j++)
            {
                // Probe "_tmpNNN" names until one is absent from the schema. The counter is
                // shared across all columns, so generated names never collide with each other.
                string candidate;
                do
                {
                    candidate = string.Format("_tmp{0:000}", counter++);
                }
                while (schema.TryGetColumnIndex(candidate, out _));
                uniqueNames[i][j] = candidate;
            }
        }
        return uniqueNames;
    }
}
}
|