File: Text\WordBagTransform.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
using static Microsoft.ML.Transforms.Text.WordBagBuildingTransformer;
 
// Loadable-class declaration for the word-bag transform: an IDataTransform configured through
// WordBagBuildingTransformer.Options, declared with SignatureDataTransform.
[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform),
    "Word Bag Transform", "WordBagTransform", "WordBag")]

// Loadable-class declaration for the n-gram extractor factory: an INgramExtractorFactory configured through
// NgramExtractorTransform.NgramExtractorArguments, declared with SignatureNgramExtractorFactory.
[assembly: LoadableClass(NgramExtractorTransform.Summary, typeof(INgramExtractorFactory), typeof(NgramExtractorTransform), typeof(NgramExtractorTransform.NgramExtractorArguments),
    typeof(SignatureNgramExtractorFactory), "Ngram Extractor Transform", "NgramExtractorTransform", "Ngram", NgramExtractorTransform.LoaderSignature)]

// Declares NgramExtractorArguments as an entry-point module.
[assembly: EntryPointModule(typeof(NgramExtractorTransform.NgramExtractorArguments))]

// These are for the internal-only TextExpandingTransformer, which is not exposed publicly.
// Declaration for SignatureLoadDataTransform.
[assembly: LoadableClass(TextExpandingTransformer.Summary, typeof(IDataTransform), typeof(TextExpandingTransformer), null, typeof(SignatureLoadDataTransform),
    TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

// Declaration for SignatureLoadModel.
[assembly: LoadableClass(typeof(TextExpandingTransformer), null, typeof(SignatureLoadModel),
    TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

// Declaration for SignatureLoadRowMapper.
[assembly: LoadableClass(typeof(IRowMapper), typeof(TextExpandingTransformer), null, typeof(SignatureLoadRowMapper),
    TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms.Text
{
    /// <summary>
    /// Signature for creating an <see cref="INgramExtractorFactory"/>.
    /// It is the signature type named in the assembly-level loadable-class declaration
    /// for <see cref="NgramExtractorTransform"/>.
    /// </summary>
    /// <param name="termLoaderArgs">The term-loader arguments that form this signature's parameter list.</param>
    internal delegate void SignatureNgramExtractorFactory(TermLoaderArguments termLoaderArgs);
 
    /// <summary>
    /// A many-to-one column common to both <see cref="NgramExtractorTransform"/>
    /// and <see cref="NgramHashExtractingTransformer"/>.
    /// </summary>
    internal sealed class ExtractorColumn : ManyToOneColumn
    {
        /// <summary>
        /// For all source columns, use these friendly names for the source
        /// column names instead of the real column names.
        /// </summary>
        public string[] FriendlyNames;
    }
 
    internal static class WordBagBuildingTransformer
    {
        /// <summary>
        /// Column definition for the word-bag transform. Every setting is optional; when a
        /// value is left unset, <see cref="CreateEstimator"/> forwards the unset value to the
        /// n-gram extractor column it builds from this one.
        /// </summary>
        internal sealed class Column : ManyToOneColumn
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram")]
            public int? NgramLength;

            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Maximum number of tokens to skip when constructing an n-gram",
                ShortName = "skips")]
            public int? SkipLength;

            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength),
                Name = "AllLengths", ShortName = "all")]
            public bool? UseAllLengths;

            [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
            public int[] MaxNumTerms = null;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Statistical measure used to evaluate how important a word is to a document in a corpus")]
            public NgramExtractingEstimator.WeightingCriteria? Weighting;

            /// <summary>
            /// Parses a column specification string.
            /// </summary>
            /// <param name="str">The non-empty text to parse.</param>
            /// <returns>The parsed column, or null when the text cannot be parsed.</returns>
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);

                var parsed = new Column();
                return parsed.TryParse(str) ? parsed : null;
            }

            /// <summary>
            /// Writes the short form of this column into <paramref name="sb"/>.
            /// </summary>
            /// <param name="sb">The builder that receives the text.</param>
            /// <returns>
            /// False when any per-column setting is present (the short form has no place for them);
            /// otherwise the result of the base unparse routine.
            /// </returns>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);

                bool hasOverrides =
                    NgramLength != null
                    || SkipLength != null
                    || UseAllLengths != null
                    || Utils.Size(MaxNumTerms) > 0
                    || Weighting != null;

                return !hasOverrides && TryUnparseCore(sb);
            }
        }
 
        /// <summary>
        /// Options for the word-bag transform: the shared n-gram settings inherited from
        /// <see cref="NgramExtractorTransform.ArgumentsBase"/> plus the list of many-to-one columns to produce.
        /// </summary>
        internal sealed class Options : NgramExtractorTransform.ArgumentsBase
        {
            // The columns to build; CreateEstimator requires at least one entry.
            [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:srcs)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;
        }
 
        // Name passed to IHostEnvironment.Register when CreateEstimator creates its host.
        private const string RegistrationName = "WordBagTransform";

        // Human-readable description; also used by the assembly-level loadable-class declaration for this transform.
        internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
            + "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";
 
        /// <summary>
        /// Builds the word-bag estimator chain: concatenation of the source columns, optional text
        /// expansion, word tokenization and finally n-gram extraction.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="options">The transform options; at least one column must be specified.</param>
        /// <param name="inputSchema">The schema shape the chain will be applied to.</param>
        /// <returns>The composed estimator.</returns>
        internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);
            host.CheckValue(options, nameof(options));
            host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

            // The word bag is composed from a tokenize transform followed by an n-gram extraction
            // transform. Because it is a many-to-one column transform, every column with several
            // sources is first run through a concat transform.

            // REVIEW: In order to not get n-grams that cross between vector slots, we need to
            // enable tokenize transforms to insert a special token between slots.

            // REVIEW: In order to make it possible to output separate bags for different columns
            // using the same dictionary, we need to find a way to make ConcatTransform remember the boundaries.

            int columnCount = options.Columns.Length;
            var tokenizerColumns = new WordTokenizingEstimator.ColumnOptions[columnCount];
            var ngramColumns = new NgramExtractorTransform.Column[columnCount];

            for (int i = 0; i < columnCount; i++)
            {
                var current = options.Columns[i];
                host.CheckUserArg(!string.IsNullOrWhiteSpace(current.Name), nameof(current.Name));
                host.CheckUserArg(Utils.Size(current.Source) > 0, nameof(current.Source));
                host.CheckUserArg(current.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(current.Source));

                // With several sources the tokenizer reads the concatenated column, which carries the
                // output name; with a single source it reads that source directly.
                string tokenizerInput = current.Source.Length > 1 ? current.Name : current.Source[0];
                tokenizerColumns[i] = new WordTokenizingEstimator.ColumnOptions(current.Name, tokenizerInput);

                // The n-gram extractor works in place on the tokenized column.
                ngramColumns[i] = new NgramExtractorTransform.Column
                {
                    Name = current.Name,
                    Source = current.Name,
                    MaxNumTerms = current.MaxNumTerms,
                    NgramLength = current.NgramLength,
                    SkipLength = current.SkipLength,
                    Weighting = current.Weighting,
                    UseAllLengths = current.UseAllLengths,
                };
            }

            var extractorOptions = new NgramExtractorTransform.Options
            {
                MaxNumTerms = options.MaxNumTerms,
                NgramLength = options.NgramLength,
                SkipLength = options.SkipLength,
                UseAllLengths = options.UseAllLengths,
                Weighting = options.Weighting,
                Columns = ngramColumns
            };

            IEstimator<ITransformer> pipeline = NgramExtractionUtils.GetConcatEstimator(host, options.Columns);

            if (options.FreqSeparator != default)
            {
                // REVIEW: only the first column's tokenizer input is expanded here; confirm whether
                // the remaining columns are intentionally left as they are.
                var expander = new TextExpandingEstimator(host, tokenizerColumns[0].InputColumnName, options.FreqSeparator, options.TermSeparator);
                pipeline = pipeline.Append(expander);
            }

            pipeline = pipeline.Append(new WordTokenizingEstimator(host, tokenizerColumns));
            return pipeline.Append(NgramExtractorTransform.CreateEstimator(host, extractorOptions, pipeline.GetOutputSchema(inputSchema)));
        }
 
        /// <summary>
        /// Creates the estimator for <paramref name="options"/>, fits it on <paramref name="input"/>
        /// and returns the transformed data as an <see cref="IDataTransform"/>.
        /// </summary>
        internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            var estimator = CreateEstimator(env, options, SchemaShape.Create(input.Schema));
            var fitted = estimator.Fit(input);
            return (IDataTransform)fitted.Transform(input);
        }
 
        #region TextExpander
 
        /// <summary>
        /// Internal only estimator used to facilitate the expansion of ngrams with pre-defined weights.
        /// It wraps a <see cref="TextExpandingTransformer"/> built from the constructor arguments.
        /// </summary>
        internal sealed class TextExpandingEstimator : TrivialEstimator<TextExpandingTransformer>
        {
            // Name of the text column that is expanded; used for schema validation.
            private readonly string _columnName;

            /// <summary>
            /// Creates the estimator together with the transformer it wraps.
            /// </summary>
            /// <param name="env">The host environment; must not be null.</param>
            /// <param name="columnName">The column to expand.</param>
            /// <param name="freqSeparator">The character separating a term from its frequency.</param>
            /// <param name="termSeparator">The character separating term/frequency pairs.</param>
            public TextExpandingEstimator(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
                : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingEstimator)), new TextExpandingTransformer(env, columnName, freqSeparator, termSeparator))
            {
                _columnName = columnName;
            }

            /// <summary>
            /// Validates that the input schema has the expected text column and returns the schema
            /// unchanged.
            /// </summary>
            /// <param name="inputSchema">The schema shape to validate; must not be null.</param>
            /// <returns><paramref name="inputSchema"/> itself.</returns>
            public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
            {
                Host.CheckValue(inputSchema, nameof(inputSchema));

                // Reject the schema when the column is missing OR when it exists with a non-text item
                // type. The conditions used to be joined with '&&', which let a column of the wrong
                // type pass validation.
                if (!inputSchema.TryFindColumn(_columnName, out SchemaShape.Column outCol) || !(outCol.ItemType is TextDataViewType))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columnName);
                }

                return inputSchema;
            }
        }
 
        /// <summary>
        /// Internal only transformer used to facilitate the expansion of ngrams with pre-defined weights.
        /// For every "term/frequency" entry of the input text it emits the term the given number of
        /// times, each followed by a single space.
        /// </summary>
        internal sealed class TextExpandingTransformer : RowToRowTransformerBase
        {
            internal const string Summary = "Expands text in the format of term:freq; to have the correct number of terms";
            internal const string UserName = "Text Expanding Transform";
            internal const string LoadName = "TextExpand";

            internal const string LoaderSignature = "TextExpandTransform";

            // Column that is read and re-emitted (under the same name) in expanded form.
            private readonly string _columnName;
            // Separates a term from its frequency inside one entry.
            private readonly char _freqSeparator;
            // Separates consecutive term/frequency entries.
            private readonly char _termSeparator;

            /// <summary>
            /// Creates a transformer that expands <paramref name="columnName"/>.
            /// </summary>
            /// <param name="env">The host environment; must not be null.</param>
            /// <param name="columnName">The text column to expand.</param>
            /// <param name="freqSeparator">The character separating a term from its frequency.</param>
            /// <param name="termSeparator">The character separating term/frequency pairs.</param>
            public TextExpandingTransformer(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
                : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
            {
                _columnName = columnName;
                _freqSeparator = freqSeparator;
                _termSeparator = termSeparator;
            }

            // Version information written to and checked against saved models.
            private static VersionInfo GetVersionInfo()
            {
                return new VersionInfo(
                    modelSignature: "TEXT EXP",
                    verWrittenCur: 0x00010001, // Initial
                    verReadableCur: 0x00010001,
                    verWeCanReadBack: 0x00010001,
                    loaderSignature: LoaderSignature,
                    loaderAssemblyName: typeof(TextExpandingTransformer).Assembly.FullName);
            }

            /// <summary>
            /// Factory method for SignatureLoadModel.
            /// </summary>
            // The host is registered under this transformer's own name, matching the public
            // constructor (it used to be registered under the name of an unrelated transformer).
            private TextExpandingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
                base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
            {
                Host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());
                // *** Binary format ***
                // string: column name
                // char: frequency separator
                // char: term separator

                _columnName = ctx.Reader.ReadString();
                _freqSeparator = ctx.Reader.ReadChar();
                _termSeparator = ctx.Reader.ReadChar();
            }

            /// <summary>
            /// Factory method for SignatureLoadRowMapper.
            /// </summary>
            private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
                => new TextExpandingTransformer(env, ctx).MakeRowMapper(inputSchema);

            /// <summary>
            /// Factory method for SignatureLoadDataTransform.
            /// </summary>
            private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
                => new TextExpandingTransformer(env, ctx).MakeDataTransform(input);

            private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
            {
                return new Mapper(Host, schema, this);
            }

            // Persists the fields in the same order the load constructor reads them back.
            private protected override void SaveModel(ModelSaveContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();
                ctx.SetVersionInfo(GetVersionInfo());

                // *** Binary format ***
                // string: column name
                // char: frequency separator
                // char: term separator

                ctx.Writer.Write(_columnName);
                ctx.Writer.Write(_freqSeparator);
                ctx.Writer.Write(_termSeparator);
            }

            // Row mapper producing the single expanded text column.
            private sealed class Mapper : MapperBase
            {
                private readonly TextExpandingTransformer _parent;

                public Mapper(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
                    : base(host, inputSchema, parent)
                {
                    _parent = (TextExpandingTransformer)parent;
                }

                // One text output column carrying the same name as the input column.
                protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
                {
                    return new DataViewSchema.DetachedColumn[]
                    {
                        new DataViewSchema.DetachedColumn(_parent._columnName, TextDataViewType.Instance)
                    };
                }

                protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
                {
                    disposer = null;
                    ValueGetter<ReadOnlyMemory<char>> srcGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema.GetColumnOrNull(_parent._columnName).Value);
                    ReadOnlyMemory<char> inputMem = default;
                    // Reused across rows; cleared at the start of every call.
                    var sb = new StringBuilder();

                    ValueGetter<ReadOnlyMemory<char>> result = (ref ReadOnlyMemory<char> dst) =>
                    {
                        sb.Clear();
                        srcGetter(ref inputMem);
                        var inputText = inputMem.ToString();
                        foreach (var termFreq in inputText.Split(_parent._termSeparator))
                        {
                            var tf = termFreq.Split(_parent._freqSeparator);

                            // An entry that does not split into exactly "term" and "freq" is emitted once.
                            // Otherwise the frequency is parsed a single time (it used to be re-parsed on
                            // every loop iteration) and with the invariant culture, since the text is
                            // data rather than user-facing input.
                            int repeat = tf.Length == 2
                                ? int.Parse(tf[1], System.Globalization.CultureInfo.InvariantCulture)
                                : 1;

                            for (int i = 0; i < repeat; i++)
                                sb.Append(tf[0]).Append(' ');
                        }

                        dst = sb.ToString().AsMemory();
                    };

                    return result;
                }

                // The only input this mapper depends on is the column being expanded.
                private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
                {
                    var active = new bool[InputSchema.Count];
                    if (activeOutput(0))
                    {
                        active[InputSchema.GetColumnOrNull(_parent._columnName).Value.Index] = true;
                    }
                    return col => active[col];
                }

                private protected override void SaveModel(ModelSaveContext ctx)
                {
                    _parent.SaveModel(ctx);
                }
            }
        }
 
        #endregion TextExpander
    }
 
    /// <summary>
    /// A transform that turns a collection of tokenized text (vector of ReadOnlyMemory), or vectors of keys into numerical
    /// feature vectors. The feature vectors are counts of n-grams (sequences of consecutive *tokens* -words or keys-
    /// of length 1-n).
    /// </summary>
    internal static class NgramExtractorTransform
    {
        /// <summary>
        /// Column definition for the n-gram extractor. Every setting is optional; an unset
        /// <see cref="NgramLength"/>, <see cref="SkipLength"/>, <see cref="UseAllLengths"/>,
        /// <see cref="Weighting"/> or <see cref="MaxNumTerms"/> falls back to the transform-level value
        /// when <see cref="CreateEstimator"/> builds the n-gram columns.
        /// </summary>
        internal sealed class Column : OneToOneColumn
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length (stores all lengths up to the specified Ngram length)", ShortName = "ngram")]
            public int? NgramLength;

            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Maximum number of tokens to skip when constructing an n-gram",
                ShortName = "skips")]
            public int? SkipLength;

            [Argument(ArgumentType.AtMostOnce, HelpText =
                "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength),
                Name = "AllLengths", ShortName = "all")]
            public bool? UseAllLengths;

            // REVIEW: This argument is actually confusing. If only one value is set, that value is used for every
            // n-gram length, so asking for 3 n-gram lengths gives maxNumTerms * 3. The first value of this array is
            // also the one used to run the term transform, so something like 1,1,10000 runs the term transform
            // with a limit of a single term.
            [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
            public int[] MaxNumTerms = null;

            [Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
            public NgramExtractingEstimator.WeightingCriteria? Weighting;

            /// <summary>
            /// Parses a column specification string.
            /// </summary>
            /// <param name="str">The non-empty text to parse.</param>
            /// <returns>The parsed column, or null when the text cannot be parsed.</returns>
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);

                var parsed = new Column();
                return parsed.TryParse(str) ? parsed : null;
            }

            /// <summary>
            /// Writes the short form of this column into <paramref name="sb"/>.
            /// </summary>
            /// <param name="sb">The builder that receives the text.</param>
            /// <returns>
            /// False when any per-column setting is present (the short form has no place for them);
            /// otherwise the result of the base unparse routine.
            /// </returns>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);

                bool hasOverrides =
                    NgramLength != null
                    || SkipLength != null
                    || UseAllLengths != null
                    || Utils.Size(MaxNumTerms) > 0
                    || Weighting != null;

                return !hasOverrides && TryUnparseCore(sb);
            }
        }
 
        /// <summary>
        /// This class is a merger of <see cref="ValueToKeyMappingTransformer.Options"/> and
        /// <see cref="NgramExtractingTransformer.Options"/>, with the allLength option removed.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the summary says the allLength option was removed, but this class does
        /// declare <see cref="UseAllLengths"/> (argument name "AllLengths") - confirm and update the summary.
        /// These are the transform-level values; the per-column classes declare nullable
        /// counterparts of the first five settings.
        /// </remarks>
        internal abstract class ArgumentsBase
        {
            // Transform-level n-gram length; defaults to 1.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Ngram length", ShortName = "ngram")]
            public int NgramLength = 1;

            // Transform-level skip length; default taken from NgramExtractingEstimator.Defaults.
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Maximum number of tokens to skip when constructing an n-gram",
                ShortName = "skips")]
            public int SkipLength = NgramExtractingEstimator.Defaults.SkipLength;

            // Transform-level all-lengths switch; default taken from NgramExtractingEstimator.Defaults.
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength),
                Name = "AllLengths", ShortName = "all")]
            public bool UseAllLengths = NgramExtractingEstimator.Defaults.UseAllLengths;

            // Transform-level dictionary limits; defaults to a single-element array holding the default maximum.
            [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
            public int[] MaxNumTerms = new int[] { NgramExtractingEstimator.Defaults.MaximumNgramsCount };

            // Transform-level weighting; default taken from NgramExtractingEstimator.Defaults.
            [Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
            public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;

            // Left at default(char) unless specified.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms/frequency pairs.")]
            public char TermSeparator = default;

            // Left at default(char) unless specified; the word-bag estimator only adds its text-expansion
            // step when this differs from default(char).
            [Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms from their frequency.")]
            public char FreqSeparator = default;

        }
 
        /// <summary>
        /// Component arguments for the n-gram extractor; also acts as the factory that turns
        /// these arguments into an <see cref="INgramExtractorFactory"/>.
        /// </summary>
        [TlcModule.Component(Name = "NGram", FriendlyName = "NGram Extractor Transform", Alias = "NGramExtractorTransform,NGramExtractor",
            Desc = "Extracts NGrams from text and convert them to vector using dictionary.")]
        public sealed class NgramExtractorArguments : ArgumentsBase, INgramExtractorFactoryFactory
        {
            /// <summary>
            /// Creates the extractor factory for these arguments by delegating to
            /// <see cref="NgramExtractorTransform.Create"/>.
            /// </summary>
            public INgramExtractorFactory CreateComponent(IHostEnvironment env, TermLoaderArguments loaderArgs)
                => Create(env, this, loaderArgs);
        }
 
        /// <summary>
        /// Options for the n-gram extractor: the shared settings inherited from
        /// <see cref="ArgumentsBase"/> plus the list of one-to-one columns to produce.
        /// </summary>
        internal sealed class Options : ArgumentsBase
        {
            // The columns to build; CreateEstimator requires at least one entry.
            [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;
        }
 
        // Human-readable description; also used by the assembly-level loadable-class declaration for this transform.
        internal const string Summary = "A transform that turns a collection of tokenized text ReadOnlyMemory, or vectors of keys into numerical " +
            "feature vectors. The feature vectors are counts of n-grams (sequences of consecutive *tokens* -words or keys- of length 1-n).";

        // Used both as a load name in the assembly-level declaration and as the name under which
        // this class registers its host.
        internal const string LoaderSignature = "NgramExtractor";
 
        /// <summary>
        /// Builds the n-gram extraction estimator chain. Columns whose item type is text are first
        /// mapped to keys (optionally dropping missing values); every column is then passed to the
        /// n-gram extracting estimator.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="options">The transform options; at least one column must be specified.</param>
        /// <param name="inputSchema">The schema shape used to detect which source columns are text.</param>
        /// <param name="termLoaderArgs">Optional term-loader arguments applied to the key-mapping step.</param>
        /// <returns>The composed estimator.</returns>
        internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, Options options, SchemaShape inputSchema, TermLoaderArguments termLoaderArgs = null)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(LoaderSignature);
            host.CheckValue(options, nameof(options));
            host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

            var pipeline = new EstimatorChain<ITransformer>();

            int columnCount = options.Columns.Length;
            var textColumns = new List<Column>();
            var isTextColumn = new bool[columnCount];

            for (int i = 0; i < columnCount; i++)
            {
                var col = options.Columns[i];

                host.CheckNonWhiteSpace(col.Name, nameof(col.Name));
                host.CheckNonWhiteSpace(col.Source, nameof(col.Source));

                bool isText = inputSchema.TryFindColumn(col.Source, out var shape)
                    && shape.ItemType is TextDataViewType;
                if (!isText)
                    continue;

                textColumns.Add(col);
                isTextColumn[i] = true;
            }

            // Text columns go through the term transform to become keys; anything else skips it and
            // is handed to the n-gram transform directly. This lets the extractor accept both text
            // and key input columns.
            // Note: the n-gram transform validates the types natively (in case a column is neither
            // text nor keys).
            if (textColumns.Count > 0)
            {
                var keyColumns = new List<ValueToKeyMappingEstimator.ColumnOptionsBase>();
                bool dropUnknowns = termLoaderArgs != null && termLoaderArgs.DropUnknowns;
                string[] dropColumns = dropUnknowns ? new string[textColumns.Count] : null;
                var ordinality = termLoaderArgs?.Sort ?? ValueToKeyMappingEstimator.KeyOrdinality.ByOccurrence;

                for (int i = 0; i < textColumns.Count; i++)
                {
                    var textColumn = textColumns[i];

                    // Precedence: per-column limit, then transform-level limit, then a default that
                    // depends on whether term-loader arguments were supplied.
                    int maxKeys;
                    if (Utils.Size(textColumn.MaxNumTerms) > 0)
                        maxKeys = textColumn.MaxNumTerms[0];
                    else if (Utils.Size(options.MaxNumTerms) > 0)
                        maxKeys = options.MaxNumTerms[0];
                    else if (termLoaderArgs == null)
                        maxKeys = NgramExtractingEstimator.Defaults.MaximumNgramsCount;
                    else
                        maxKeys = int.MaxValue;

                    var keyColumn = new ValueToKeyMappingEstimator.ColumnOptions(
                        textColumn.Name,
                        textColumn.Source,
                        maximumNumberOfKeys: maxKeys,
                        keyOrdinality: ordinality);
                    if (termLoaderArgs != null)
                    {
                        keyColumn.Key = termLoaderArgs.Term;
                        keyColumn.Keys = termLoaderArgs.Terms;
                    }
                    keyColumns.Add(keyColumn);

                    if (dropColumns != null)
                        dropColumns[i] = textColumn.Name;
                }

                IDataView keyData = null;
                if (termLoaderArgs?.DataFile != null)
                {
                    using (var ch = env.Start("Create key data view"))
                        keyData = ValueToKeyMappingTransformer.GetKeyDataViewOrNull(env, ch, termLoaderArgs.DataFile, termLoaderArgs.TermsColumn, termLoaderArgs.Loader, out _);
                }

                pipeline = pipeline.Append<ITransformer>(new ValueToKeyMappingEstimator(host, keyColumns.ToArray(), keyData));
                if (dropColumns != null)
                    pipeline = pipeline.Append<ITransformer>(new MissingValueDroppingEstimator(host, dropColumns.Select(x => (x, x)).ToArray()));
            }

            var ngramColumns = new NgramExtractingEstimator.ColumnOptions[columnCount];
            for (int i = 0; i < columnCount; i++)
            {
                var col = options.Columns[i];

                // Text columns were already mapped to keys under the output name, so the n-gram step
                // reads from there; other columns are read straight from their source.
                string ngramInput = isTextColumn[i] ? col.Name : col.Source;

                ngramColumns[i] = new NgramExtractingEstimator.ColumnOptions(col.Name,
                    col.NgramLength ?? options.NgramLength,
                    col.SkipLength ?? options.SkipLength,
                    col.UseAllLengths ?? options.UseAllLengths,
                    col.Weighting ?? options.Weighting,
                    col.MaxNumTerms ?? options.MaxNumTerms,
                    ngramInput);
            }

            return pipeline.Append<ITransformer>(new NgramExtractingEstimator(env, ngramColumns));
        }
 
        /// <summary>
        /// Creates the estimator for <paramref name="options"/>, fits it on <paramref name="input"/>
        /// and returns the transformed data as an <see cref="IDataTransform"/>.
        /// </summary>
        /// <returns>The transformed data, or null when it is not an <see cref="IDataTransform"/>.</returns>
        internal static IDataTransform CreateDataTransform(IHostEnvironment env, Options options, IDataView input,
            TermLoaderArguments termLoaderArgs = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));

            var estimator = CreateEstimator(env, options, SchemaShape.Create(input.Schema), termLoaderArgs);
            var transformed = estimator.Fit(input).Transform(input);
            return transformed as IDataTransform;
        }
 
        /// <summary>
        /// Converts factory-level arguments and extractor columns into <see cref="Options"/>.
        /// Each extractor column must have exactly one source column.
        /// </summary>
        /// <param name="extractorArgs">The n-gram settings copied onto the result.</param>
        /// <param name="cols">The columns to convert.</param>
        /// <returns>Options holding one column per entry of <paramref name="cols"/>.</returns>
        internal static Options CreateNgramExtractorOptions(NgramExtractorArguments extractorArgs, ExtractorColumn[] cols)
        {
            var converted = new Column[cols.Length];
            int next = 0;
            foreach (var extractorColumn in cols)
            {
                Contracts.Check(Utils.Size(extractorColumn.Source) == 1, "too many source columns");
                converted[next++] = new Column { Name = extractorColumn.Name, Source = extractorColumn.Source[0] };
            }

            return new Options
            {
                Columns = converted,
                NgramLength = extractorArgs.NgramLength,
                SkipLength = extractorArgs.SkipLength,
                UseAllLengths = extractorArgs.UseAllLengths,
                MaxNumTerms = extractorArgs.MaxNumTerms,
                Weighting = extractorArgs.Weighting
            };
        }
 
        /// <summary>
        /// Validates the arguments and wraps them in an <see cref="NgramExtractorFactory"/>.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="extractorArgs">The extractor arguments; must not be null.</param>
        /// <param name="termLoaderArgs">Optional term-loader arguments; may be null.</param>
        internal static INgramExtractorFactory Create(IHostEnvironment env, NgramExtractorArguments extractorArgs,
            TermLoaderArguments termLoaderArgs)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register(LoaderSignature);
            host.CheckValue(extractorArgs, nameof(extractorArgs));
            host.CheckValueOrNull(termLoaderArgs);

            INgramExtractorFactory factory = new NgramExtractorFactory(extractorArgs, termLoaderArgs);
            return factory;
        }
    }
 
    /// <summary>
    /// Arguments for defining custom list of terms or data file containing the terms.
    /// The class includes a subset of <see cref="ValueToKeyMappingTransformer"/>'s arguments.
    /// </summary>
    /// <remarks>
    /// Note the naming: the field <see cref="Term"/> is exposed under the argument name "Terms"
    /// (command line only), while the field <see cref="Terms"/> is exposed under the argument
    /// name "Term" (entry points only).
    /// </remarks>
    internal sealed class TermLoaderArguments
    {
        // Single comma-separated string form of the term list.
        [Argument(ArgumentType.AtMostOnce, HelpText = "Comma separated list of terms", Name = "Terms", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
        public string Term;

        // Array form of the term list.
        [Argument(ArgumentType.AtMostOnce, HelpText = "List of terms", Name = "Term", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string[] Terms;

        // When non-null, NgramExtractorTransform.CreateEstimator builds a key data view from this file.
        [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Data file containing the terms", ShortName = "data", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
        public string DataFile;

        // Loader factory passed along when the key data view is built from DataFile.
        [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "<Auto>", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureDataLoader))]
        internal IComponentFactory<IMultiStreamSource, ILegacyDataLoader> Loader;

        // Column name passed along when the key data view is built from DataFile.
        [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the text column containing the terms", ShortName = "termCol", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
        public string TermsColumn;

        // Key ordering used for the key-mapping step; defaults to order of occurrence.
        [Argument(ArgumentType.AtMostOnce, HelpText = "How items should be ordered when vectorized. By default, they will be in the order encountered. " +
            "If by value, items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a').", SortOrder = 5)]
        public ValueToKeyMappingEstimator.KeyOrdinality Sort = ValueToKeyMappingEstimator.KeyOrdinality.ByOccurrence;

        // When true, NgramExtractorTransform.CreateEstimator appends a missing-value dropping step after key mapping.
        [Argument(ArgumentType.AtMostOnce, HelpText = "Drop unknown terms instead of mapping them to NA term.", ShortName = "dropna", SortOrder = 6)]
        public bool DropUnknowns = false;
    }
 
    /// <summary>
    /// An n-gram extractor factory interface to create an n-gram extractor transform.
    /// </summary>
    internal interface INgramExtractorFactory
    {
        /// <summary>
        /// Whether the extractor transform created by this factory uses the hashing trick
        /// (by using <see cref="HashingTransformer"/> or <see cref="NgramHashingTransformer"/>, for example).
        /// </summary>
        bool UseHashingTrick { get; }

        /// <summary>
        /// Creates the extractor transformer for the given columns.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The data the transformer is created against.</param>
        /// <param name="cols">The columns the extractor should produce.</param>
        /// <returns>The created transformer.</returns>
        ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols);
    }
 
    /// <summary>
    /// Component factory that turns <see cref="TermLoaderArguments"/> into an
    /// <see cref="INgramExtractorFactory"/>; declared under the component kind "NgramExtractor".
    /// </summary>
    [TlcModule.ComponentKind("NgramExtractor")]
    internal interface INgramExtractorFactoryFactory : IComponentFactory<TermLoaderArguments, INgramExtractorFactory> { }
 
    /// <summary>
    /// An implementation of <see cref="INgramExtractorFactory"/> to create <see cref="NgramExtractorTransform"/>.
    /// </summary>
    internal class NgramExtractorFactory : INgramExtractorFactory
    {
        private readonly NgramExtractorTransform.NgramExtractorArguments _extractorArgs;
        private readonly TermLoaderArguments _termLoaderArgs;

        /// <summary>
        /// Always false: this factory builds the dictionary-based extractor.
        /// </summary>
        public bool UseHashingTrick => false;

        /// <summary>
        /// Stores the arguments used later by <see cref="Create"/>.
        /// </summary>
        /// <param name="extractorArgs">The extractor arguments; must not be null.</param>
        /// <param name="termLoaderArgs">Optional term-loader arguments; may be null.</param>
        public NgramExtractorFactory(NgramExtractorTransform.NgramExtractorArguments extractorArgs,
            TermLoaderArguments termLoaderArgs)
        {
            Contracts.CheckValue(extractorArgs, nameof(extractorArgs));
            Contracts.CheckValueOrNull(termLoaderArgs);

            _termLoaderArgs = termLoaderArgs;
            _extractorArgs = extractorArgs;
        }

        /// <summary>
        /// Builds the extractor options for <paramref name="cols"/>, creates the estimator and fits
        /// it on <paramref name="input"/>.
        /// </summary>
        public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols)
        {
            var extractorOptions = NgramExtractorTransform.CreateNgramExtractorOptions(_extractorArgs, cols);
            var inputShape = SchemaShape.Create(input.Schema);
            var estimator = NgramExtractorTransform.CreateEstimator(env, extractorOptions, inputShape, _termLoaderArgs);
            return estimator.Fit(input);
        }
    }
 
    /// <summary>
    /// The hashing <see cref="INgramExtractorFactory"/>: it produces an <see cref="NgramHashExtractingTransformer"/>
    /// from a fixed set of hash extractor arguments.
    /// </summary>
    internal class NgramHashExtractorFactory : INgramExtractorFactory
    {
        private readonly NgramHashExtractingTransformer.NgramHashExtractorArguments _extractorArgs;

        /// <summary>
        /// Always <see langword="true"/> for this factory.
        /// </summary>
        public bool UseHashingTrick => true;

        /// <summary>
        /// Stores the arguments used later by <see cref="Create"/>.
        /// </summary>
        /// <param name="extractorArgs">The hash extractor arguments; must not be null.</param>
        public NgramHashExtractorFactory(NgramHashExtractingTransformer.NgramHashExtractorArguments extractorArgs)
        {
            Contracts.CheckValue(extractorArgs, nameof(extractorArgs));
            _extractorArgs = extractorArgs;
        }

        /// <summary>
        /// Forwards the stored arguments together with <paramref name="env"/>, <paramref name="input"/> and
        /// <paramref name="cols"/> to <see cref="NgramHashExtractingTransformer"/>'s <c>Create</c> method.
        /// </summary>
        public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColumn[] cols)
        {
            var transformer = NgramHashExtractingTransformer.Create(_extractorArgs, env, input, cols);
            return transformer;
        }
    }
 
    /// <summary>
    /// Shared helpers for the n-gram extraction transforms that operate on <see cref="ManyToOneColumn"/> specifications.
    /// </summary>
    internal static class NgramExtractionUtils
    {
        /// <summary>
        /// Validates every column specification and builds an estimator chain containing one
        /// <see cref="ColumnConcatenatingEstimator"/> for each column that has more than one source.
        /// Columns with exactly one source contribute nothing to the chain.
        /// </summary>
        /// <param name="env">The host environment; used for argument checks and handed to each concatenating estimator.</param>
        /// <param name="columns">The column specifications; each needs a non-blank name and at least one non-blank source.</param>
        /// <returns>The resulting estimator chain, which is empty when no column has multiple sources.</returns>
        public static IEstimator<ITransformer> GetConcatEstimator(IHostEnvironment env, ManyToOneColumn[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(columns, nameof(columns));

            var chain = new EstimatorChain<ITransformer>();
            for (int i = 0; i < columns.Length; i++)
            {
                var column = columns[i];
                env.CheckUserArg(column != null, nameof(WordBagBuildingTransformer.Options.Columns));
                env.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
                env.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source));
                env.CheckUserArg(!column.Source.Any(src => string.IsNullOrWhiteSpace(src)), nameof(column.Source));

                // A single source needs no concatenation step.
                if (column.Source.Length <= 1)
                    continue;

                var concat = new ColumnConcatenatingEstimator(env, column.Name, column.Source);
                chain = chain.Append<ITransformer>(concat);
            }

            return chain;
        }

        /// <summary>
        /// Produces, for every source of every column, a temporary column name of the form <c>_tmpNNN</c>
        /// that does not already exist in <paramref name="schema"/>. Element <c>[i][j]</c> of the result is
        /// the name generated for source <c>j</c> of column <c>i</c>.
        /// </summary>
        /// <param name="env">The host environment used for argument checks.</param>
        /// <param name="columns">The column specifications; each needs a non-blank name and at least one non-blank source.</param>
        /// <param name="schema">The schema whose existing column names must be avoided.</param>
        /// <returns>A jagged array of generated names, one inner array per column.</returns>
        public static string[][] GenerateUniqueSourceNames(IHostEnvironment env, ManyToOneColumn[] columns, DataViewSchema schema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(columns, nameof(columns));
            env.CheckValue(schema, nameof(schema));

            var result = new string[columns.Length][];

            // One counter is shared across all columns, so no generated name is handed out twice.
            int nextSuffix = 0;
            for (int c = 0; c < columns.Length; c++)
            {
                var column = columns[c];
                env.CheckUserArg(column != null, nameof(WordHashBagProducingTransformer.Options.Columns));
                env.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));

                bool sourcesValid = Utils.Size(column.Source) > 0
                    && !column.Source.Any(src => string.IsNullOrWhiteSpace(src));
                env.CheckUserArg(sourcesValid, nameof(column.Source));

                var names = new string[column.Source.Length];
                for (int s = 0; s < names.Length; s++)
                    names[s] = NextUnusedName(schema, ref nextSuffix);

                result[c] = names;
            }

            return result;
        }

        /// <summary>
        /// Advances <paramref name="counter"/> until the formatted candidate name is not a column of
        /// <paramref name="schema"/>, and returns that candidate. The counter is left one past the value used.
        /// </summary>
        private static string NextUnusedName(DataViewSchema schema, ref int counter)
        {
            string candidate;
            int existingIndex;
            do
            {
                candidate = string.Format("_tmp{0:000}", counter++);
            }
            while (schema.TryGetColumnIndex(candidate, out existingIndex));

            return candidate;
        }
    }
}