// File: Text\TextFeaturizingEstimator.cs
// Web Access
// Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
 
[assembly: LoadableClass(TextFeaturizingEstimator.Summary, typeof(IDataTransform), typeof(TextFeaturizingEstimator), typeof(TextFeaturizingEstimator.Options), typeof(SignatureDataTransform),
    TextFeaturizingEstimator.UserName, "TextTransform", TextFeaturizingEstimator.LoaderSignature)]
 
[assembly: LoadableClass(TextFeaturizingEstimator.Summary, typeof(ITransformer), typeof(TextFeaturizingEstimator), null, typeof(SignatureLoadModel),
    TextFeaturizingEstimator.UserName, "TextTransform", TextFeaturizingEstimator.LoaderSignature)]
 
namespace Microsoft.ML.Transforms.Text
{
    using CaseMode = TextNormalizingEstimator.CaseMode;
    using StopWordsCol = StopWordsRemovingTransformer.Column;
 
    /// <summary>
    /// Marker interface implemented by the option types that select which stop words remover to use.
    /// </summary>
    /// <remarks>
    /// The interface declares no members. <see cref="TextFeaturizingEstimator.Options.StopWordsRemoverOptions"/>
    /// recognizes <see cref="StopWordsRemovingEstimator.Options"/> and
    /// <see cref="CustomStopWordsRemovingEstimator.Options"/>; any other implementation results in no
    /// stop words remover being configured.
    /// </remarks>
    public interface IStopWordsRemoverOptions { }
 
    /// <summary>
    ///  An estimator that turns a collection of text documents into numerical feature vectors.
    ///  The feature vectors are normalized counts of word and/or character n-grams (based on the options supplied).
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    ///
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | Yes. |
    /// | Input column data type | [text](xref:Microsoft.ML.Data.TextDataViewType) |
    /// | Output column data type | Vector of <xref:System.Single> |
    /// | Exportable to ONNX | No |
    ///
    /// This estimator gives the user one-stop solution for doing:
    /// * Language Detection
    /// * [Tokenization](https://en.wikipedia.org/wiki/Lexical_analysis#Tokenization)
    /// * [Text normalization](https://en.wikipedia.org/wiki/Text_normalization)
    /// * [Predefined and custom stopwords removal](https://en.wikipedia.org/wiki/Stop_words)
    /// * [Word-based or character-based Ngram extraction and SkipGram extraction (through the advanced [options](xref:Microsoft.ML.Transforms.Text.TextFeaturizingEstimator.Options.WordFeatureExtractor))](https://en.wikipedia.org/wiki/N-gram)
    /// * [TF, IDF or TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf)
    /// * [L-p vector normalization](xref:Microsoft.ML.Transforms.LpNormNormalizingTransformer)
    ///
    /// By default the features are made of (word/character) n-grams/skip-grams and the number of features is equal to the vocabulary size found by analyzing the data.
    /// To output an additional column with the tokens generated, use [OutputTokensColumnName](xref:Microsoft.ML.Transforms.Text.TextFeaturizingEstimator.Options.OutputTokensColumnName).
    /// The number of features can also be specified by selecting the maximum number of n-grams to keep in the <xref:Microsoft.ML.Transforms.Text.TextFeaturizingEstimator.Options>, where the estimator can be further tuned.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]></format>
    /// </remarks>
    /// <seealso cref="TextCatalog.FeaturizeText(TransformsCatalog.TextTransforms, string, Options, string[])"/>
    /// <seealso cref="TextCatalog.FeaturizeText(TransformsCatalog.TextTransforms, string, string)"/>
    public sealed class TextFeaturizingEstimator : IEstimator<ITransformer>
    {
        /// <summary>
        /// Text language. This enumeration is serialized.
        /// </summary>
        /// <remarks>
        /// Every member carries an explicit numeric value, starting at 1. Because the enumeration is
        /// serialized, existing values should not be renumbered.
        /// The member names are also parsed into <see cref="StopWordsRemovingEstimator.Language"/> by name
        /// when the stop words remover is configured, so the names are expected to exist there as well.
        /// </remarks>
        public enum Language
        {
            English = 1,
            French = 2,
            German = 3,
            Dutch = 4,
            Italian = 5,
            Spanish = 6,
            Japanese = 7
        }
 
        /// <summary>
        /// Text vector normalizer kind.
        /// </summary>
        /// <remarks>
        /// <see cref="None"/> skips the normalization step entirely; the other members are translated to the
        /// member of the same name in <see cref="LpNormNormalizingEstimatorBase.NormFunction"/>.
        /// </remarks>
        public enum NormFunction
        {
            /// <summary>
            /// Use this to disable normalization.
            /// </summary>
            None = 0,
            /// <summary>
            /// L1-norm.
            /// </summary>
            L1 = 1,
            /// <summary>
            /// L2-norm.
            /// </summary>
            L2 = 2,
            /// <summary>
            /// Infinity-norm.
            /// </summary>
            Infinity = 3
        }
 
        /// <summary>
        /// Column specification for this transform; parsing and unparsing are delegated to
        /// <see cref="ManyToOneColumn"/>.
        /// </summary>
        internal sealed class Column : ManyToOneColumn
        {
            /// <summary>
            /// Parses <paramref name="str"/> into a new <see cref="Column"/>.
            /// </summary>
            /// <param name="str">The textual column specification.</param>
            /// <returns>The parsed column, or <see langword="null"/> when parsing fails.</returns>
            internal static Column Parse(string str)
            {
                var column = new Column();
                return column.TryParse(str) ? column : null;
            }

            /// <summary>
            /// Writes the textual form of this column into <paramref name="sb"/>.
            /// </summary>
            /// <param name="sb">The builder to write to; must not be null.</param>
            /// <returns>Whatever the base unparse routine reports.</returns>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                bool unparsed = TryUnparseCore(sb);
                return unparsed;
            }
        }
 
        /// <summary>
        /// Advanced options for the <see cref="TextFeaturizingEstimator"/>.
        /// </summary>
        /// <remarks>
        /// Several public properties here are facades over internal factory fields: assigning the public
        /// property stores the user-facing value and rebuilds the matching internal factory.
        /// </remarks>
        public sealed class Options : TransformInputBase
        {
            [Argument(ArgumentType.Required, HelpText = "New column definition (optional form: name:srcs).", Name = "Column", ShortName = "col", SortOrder = 1)]
            internal Column Columns;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Dataset language or 'AutoDetect' to detect language per row.", ShortName = "lang", SortOrder = 3)]
            internal Language Language = DefaultLanguage;

            [Argument(ArgumentType.Multiple, Name = "StopWordsRemover", HelpText = "Stopwords remover.", ShortName = "remover", NullName = "<None>", SortOrder = 4)]
            internal IStopWordsRemoverFactory StopWordsRemover;

            /// <summary>
            /// Backing value for <see cref="StopWordsRemoverOptions"/>; <see cref="StopWordsRemover"/> is derived from it.
            /// </summary>
            private IStopWordsRemoverOptions _stopWordsRemoverOptions;

            /// <summary>
            /// Option to set type of stop word remover to use.
            /// The following options are available
            /// <list type="bullet">
            ///     <item>
            ///         <description>The <see cref="StopWordsRemovingEstimator.Options"/> removes the language specific list of stop words from the input.</description>
            ///     </item>
            ///     <item>
            ///        <description>The <see cref="CustomStopWordsRemovingEstimator.Options"/> uses user provided list of stop words.</description>
            ///     </item>
            /// </list>
            /// Setting this to 'null' does not remove stop words from the input.
            /// </summary>
            public IStopWordsRemoverOptions StopWordsRemoverOptions
            {
                get => _stopWordsRemoverOptions;
                set
                {
                    _stopWordsRemoverOptions = value;

                    // A null value, or an option type that is not recognized below, leaves the factory null.
                    IStopWordsRemoverFactory factory = null;
                    switch (value)
                    {
                        case StopWordsRemovingEstimator.Options predefined:
                            factory = new PredefinedStopWordsRemoverFactory();
                            // The predefined remover also dictates the language used by this estimator.
                            Language = predefined.Language;
                            break;

                        case CustomStopWordsRemovingEstimator.Options custom:
                            var stopwords = custom.StopWords;
                            factory = new CustomStopWordsRemovingTransformer.LoaderArguments()
                            {
                                Stopwords = stopwords,
                                Stopword = string.Join(",", stopwords)
                            };
                            break;
                    }

                    StopWordsRemover = factory;
                }
            }

            /// <summary>
            /// Casing to apply to the text.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Casing text using the rules of the invariant culture.", Name = "TextCase", ShortName = "case", SortOrder = 5)]
            public CaseMode CaseMode = TextNormalizingEstimator.Defaults.Mode;

            /// <summary>
            /// Whether to keep diacritical marks or remove them.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep diacritical marks or remove them.", ShortName = "diac", SortOrder = 6)]
            public bool KeepDiacritics = TextNormalizingEstimator.Defaults.KeepDiacritics;

            /// <summary>
            /// Whether to keep punctuation marks or remove them.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep punctuation marks or remove them.", ShortName = "punc", SortOrder = 7)]
            public bool KeepPunctuations = TextNormalizingEstimator.Defaults.KeepPunctuations;

            /// <summary>
            /// Whether to keep numbers or remove them.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep numbers or remove them.", ShortName = "num", SortOrder = 8)]
            public bool KeepNumbers = TextNormalizingEstimator.Defaults.KeepNumbers;

            /// <summary>
            /// Name of the column that receives the transformed text tokens; leave null or empty to skip it.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Column containing the transformed text tokens.", ShortName = "tokens,showtext,showTransformedText", SortOrder = 9)]
            public string OutputTokensColumnName;

            [Argument(ArgumentType.Multiple, HelpText = "A dictionary of allowed terms.", ShortName = "dict", NullName = "<None>", SortOrder = 10, Hide = true)]
            internal TermLoaderArguments Dictionary;

            [TGUI(Label = "Word Gram Extractor")]
            [Argument(ArgumentType.Multiple, Name = "WordFeatureExtractor", HelpText = "Ngram feature extractor to use for words (WordBag/WordHashBag).", ShortName = "wordExtractor", NullName = "<None>", SortOrder = 11)]
            internal INgramExtractorFactoryFactory WordFeatureExtractorFactory;

            /// <summary>
            /// Backing value for <see cref="WordFeatureExtractor"/>; <see cref="WordFeatureExtractorFactory"/> is derived from it.
            /// </summary>
            private WordBagEstimator.Options _wordFeatureExtractor;

            /// <summary>
            /// Norm of the output vector. It will be normalized to one.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Normalize vectors (rows) individually by rescaling them to unit norm.", Name = "VectorNormalizer", ShortName = "norm", SortOrder = 13)]
            public NormFunction Norm = NormFunction.L2;

            /// <summary>
            /// Ngram feature extractor to use for words (WordBag/WordHashBag).
            /// Set to <see langword="null" /> to turn off n-gram generation for words.
            /// </summary>
            public WordBagEstimator.Options WordFeatureExtractor
            {
                get => _wordFeatureExtractor;
                set
                {
                    _wordFeatureExtractor = value;
                    WordFeatureExtractorFactory = ToExtractorArguments(value);
                }
            }

            [TGUI(Label = "Char Gram Extractor")]
            [Argument(ArgumentType.Multiple, Name = "CharFeatureExtractor", HelpText = "Ngram feature extractor to use for characters (WordBag/WordHashBag).", ShortName = "charExtractor", NullName = "<None>", SortOrder = 12)]
            internal INgramExtractorFactoryFactory CharFeatureExtractorFactory;

            /// <summary>
            /// Backing value for <see cref="CharFeatureExtractor"/>; <see cref="CharFeatureExtractorFactory"/> is derived from it.
            /// </summary>
            private WordBagEstimator.Options _charFeatureExtractor;

            /// <summary>
            /// Ngram feature extractor to use for characters (WordBag/WordHashBag).
            /// Set to <see langword="null" /> to turn off n-gram generation for characters.
            /// </summary>
            public WordBagEstimator.Options CharFeatureExtractor
            {
                get => _charFeatureExtractor;
                set
                {
                    _charFeatureExtractor = value;
                    CharFeatureExtractorFactory = ToExtractorArguments(value);
                }
            }

            /// <summary>
            /// Creates options with the default word extractor and a character extractor that uses
            /// n-grams of length 3 without the shorter lengths.
            /// </summary>
            public Options()
            {
                var wordDefaults = new WordBagEstimator.Options();
                WordFeatureExtractor = wordDefaults;

                var charDefaults = new WordBagEstimator.Options() { NgramLength = 3, UseAllLengths = false };
                CharFeatureExtractor = charDefaults;
            }

            /// <summary>
            /// Copies the n-gram settings of <paramref name="source"/> into a new set of extractor arguments.
            /// </summary>
            /// <param name="source">The public extractor options; may be null.</param>
            /// <returns>The populated arguments, or <see langword="null"/> when <paramref name="source"/> is null.</returns>
            private static NgramExtractorTransform.NgramExtractorArguments ToExtractorArguments(WordBagEstimator.Options source)
            {
                if (source == null)
                    return null;

                return new NgramExtractorTransform.NgramExtractorArguments()
                {
                    NgramLength = source.NgramLength,
                    SkipLength = source.SkipLength,
                    UseAllLengths = source.UseAllLengths,
                    MaxNumTerms = source.MaximumNgramsCount,
                    Weighting = source.Weighting
                };
            }
        }
 
        // Name of the column that receives the produced feature vector.
        internal readonly string OutputColumn;
        // Names of the input text columns, copied from the constructor argument.
        private readonly string[] _inputColumns;
        // Read-only view over the input column names.
        private IReadOnlyCollection<string> InputColumns => _inputColumns.AsReadOnly();
        // Settings in effect: the options passed to the constructor, or a default instance when none were passed.
        internal Options OptionalSettings { get; }

        // These parameters are hardcoded for now.
        // REVIEW: expose them once sub-transforms are estimators.
        // They are initialized by the constructor and reassigned by the SignatureDataTransform factory method,
        // which is why they are not readonly.
        private IStopWordsRemoverFactory _stopWordsRemover;
        private TermLoaderArguments _dictionary;
        private INgramExtractorFactoryFactory _wordFeatureExtractor;
        private INgramExtractorFactoryFactory _charFeatureExtractor;

        // Host registered for this estimator; used for argument checks and handed to the sub-transforms.
        private readonly IHost _host;
 
        /// <summary>
        /// Immutable snapshot of the estimator settings that <see cref="Fit(IDataView)"/> needs in order to
        /// decide which sub-transforms to build and how to configure them.
        /// </summary>
        private sealed class TransformApplierParams
        {
            public readonly INgramExtractorFactory WordExtractorFactory;
            public readonly INgramExtractorFactory CharExtractorFactory;

            public readonly NormFunction Norm;
            public readonly Language Language;
            public readonly IStopWordsRemoverFactory StopWordsRemover;
            public readonly CaseMode TextCase;
            public readonly bool KeepDiacritics;
            public readonly bool KeepPunctuations;
            public readonly bool KeepNumbers;
            public readonly string OutputTextTokensColumnName;
            public readonly TermLoaderArguments Dictionary;

            /// <summary>
            /// The configured language expressed as the stop words remover's language enumeration,
            /// obtained by parsing the member name.
            /// </summary>
            public StopWordsRemovingEstimator.Language StopwordsLanguage
            {
                get
                {
                    string languageName = Language.ToString();
                    return (StopWordsRemovingEstimator.Language)Enum.Parse(typeof(StopWordsRemovingEstimator.Language), languageName);
                }
            }

            /// <summary>
            /// The configured norm expressed as the Lp-norm normalizer's enumeration.
            /// Any value other than L1, L2 or Infinity trips an assert and falls back to L2.
            /// </summary>
            internal LpNormNormalizingEstimatorBase.NormFunction LpNorm
            {
                get
                {
                    if (Norm == NormFunction.L1)
                        return LpNormNormalizingEstimatorBase.NormFunction.L1;
                    if (Norm == NormFunction.L2)
                        return LpNormNormalizingEstimatorBase.NormFunction.L2;
                    if (Norm == NormFunction.Infinity)
                        return LpNormNormalizingEstimatorBase.NormFunction.Infinity;

                    Contracts.Assert(false, "Unexpected normalizer type");
                    return LpNormNormalizingEstimatorBase.NormFunction.L2;
                }
            }

            // These properties encode the logic needed to determine which transforms to apply.
            #region NeededTransforms
            /// <summary>
            /// Word tokenization is required by the word extractor, by stop words removal, and by the
            /// optional tokens output column.
            /// </summary>
            public bool NeedsWordTokenizationTransform
                => WordExtractorFactory != null || NeedsRemoveStopwordsTransform || !string.IsNullOrEmpty(OutputTextTokensColumnName);

            /// <summary>
            /// Stop words are removed only when a remover factory has been configured.
            /// </summary>
            public bool NeedsRemoveStopwordsTransform => StopWordsRemover != null;

            /// <summary>
            /// Text normalization can be skipped only when casing is left alone and diacritics,
            /// punctuation and numbers are all kept.
            /// </summary>
            public bool NeedsNormalizeTransform
                => !(TextCase == CaseMode.None && KeepDiacritics && KeepPunctuations && KeepNumbers);

            /// <summary>
            /// True when every configured extractor uses the hashing trick; an absent extractor counts as hashing.
            /// </summary>
            private bool UsesHashExtractors
                => (WordExtractorFactory == null || WordExtractorFactory.UseHashingTrick) &&
                   (CharExtractorFactory == null || CharExtractorFactory.UseHashingTrick);

            /// <summary>
            /// True when at least one configured extractor does not use the hashing trick. In that case the
            /// source columns are concatenated up front so that a dictionary-based extractor operates on a
            /// single column.
            /// </summary>
            public bool NeedInitialSourceColumnConcatTransform => !UsesHashExtractors;
            #endregion

            /// <summary>
            /// Captures the settings of <paramref name="parent"/> and instantiates its extractor factories.
            /// </summary>
            /// <param name="parent">The estimator whose configuration is being snapshotted.</param>
            public TransformApplierParams(TextFeaturizingEstimator parent)
            {
                var host = parent._host;
                var settings = parent.OptionalSettings;
                var dictionary = parent._dictionary;

                // Reject enum values that were cast from arbitrary integers.
                host.Check(Enum.IsDefined(typeof(Language), settings.Language));
                host.Check(Enum.IsDefined(typeof(CaseMode), settings.CaseMode));

                // A missing factory-factory means the corresponding extractor is switched off.
                WordExtractorFactory = parent._wordFeatureExtractor?.CreateComponent(host, dictionary);
                CharExtractorFactory = parent._charFeatureExtractor?.CreateComponent(host, dictionary);

                Norm = settings.Norm;
                Language = settings.Language;
                StopWordsRemover = parent._stopWordsRemover;
                TextCase = settings.CaseMode;
                KeepDiacritics = settings.KeepDiacritics;
                KeepPunctuations = settings.KeepPunctuations;
                KeepNumbers = settings.KeepNumbers;
                OutputTextTokensColumnName = settings.OutputTokensColumnName;
                Dictionary = dictionary;
            }
        }
 
        // Description, display name and loader signature referenced by the assembly-level
        // LoadableClass registrations of this type.
        internal const string Summary = "A transform that turns a collection of text documents into numerical feature vectors. " +
            "The feature vectors are normalized counts of (word and/or character) n-grams in a given tokenized text.";

        internal const string UserName = "Text Transform";
        internal const string LoaderSignature = "Text";

        // Language used by Options when the caller does not pick one.
        internal const Language DefaultLanguage = Language.English;
 
        /// <summary>
        /// Creates an estimator that featurizes a single text column using default options.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="outputColumnName">Name of the column that receives the features.</param>
        /// <param name="inputColumnName">
        /// Name of the text column to read; when null, <paramref name="outputColumnName"/> is used as the source.
        /// </param>
        internal TextFeaturizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null)
            : this(env, outputColumnName, new string[] { inputColumnName ?? outputColumnName })
        {
        }
 
        /// <summary>
        /// Creates an estimator that featurizes one or more text columns into a single output column.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="name">Name of the output column; must be non-empty.</param>
        /// <param name="source">
        /// Names of the input text columns. Must be non-null, contain at least one entry, and contain no
        /// null, empty or whitespace-only names.
        /// </param>
        /// <param name="options">Advanced options; when null, a default <see cref="Options"/> instance is used.</param>
        internal TextFeaturizingEstimator(IHostEnvironment env, string name, IEnumerable<string> source, Options options = null)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(TextFeaturizingEstimator));
            _host.CheckValue(source, nameof(source));

            // Materialize the sequence exactly once. Validating the IEnumerable directly and then calling
            // ToArray() would enumerate it three times, which is wasteful and, for lazy or single-pass
            // sequences, can validate different contents than the ones that end up being stored.
            var inputColumns = source.ToArray();
            _host.CheckParam(inputColumns.Length > 0, nameof(source));
            _host.CheckParam(!inputColumns.Any(string.IsNullOrWhiteSpace), nameof(source));
            _host.CheckNonEmpty(name, nameof(name));
            _host.CheckValueOrNull(options);

            _inputColumns = inputColumns;
            OutputColumn = name;

            // Only build a default Options instance when the caller did not supply one.
            OptionalSettings = options ?? new Options();

            _stopWordsRemover = OptionalSettings.StopWordsRemover;
            _dictionary = null;
            _wordFeatureExtractor = OptionalSettings.WordFeatureExtractorFactory;
            _charFeatureExtractor = OptionalSettings.CharFeatureExtractorFactory;
        }
 
        /// <summary>
        /// Trains and returns a <see cref="ITransformer"/>.
        /// </summary>
        /// <remarks>
        /// The method builds a chain of sub-transforms stage by stage. Each stage is applied to the running
        /// view immediately, so the next stage is configured against the schema produced so far. Names of
        /// intermediate columns are collected in a list and handed to a final column-selecting transform.
        /// </remarks>
        /// <param name="input">The data to train on; must not be null.</param>
        /// <returns>A transformer wrapping the assembled chain.</returns>
        public ITransformer Fit(IDataView input)
        {
            var h = _host;
            h.CheckValue(input, nameof(input));

            // Snapshot the settings; this also validates the enum options and creates the extractor factories.
            var tparams = new TransformApplierParams(this);
            // Current text columns; replaced as stages produce new columns.
            string[] textCols = _inputColumns;
            string[] wordTokCols = null;
            string[] charTokCols = null;
            string wordFeatureCol = null;
            string charFeatureCol = null;
            // Intermediate columns created along the way.
            List<string> tempCols = new List<string>();
            IDataView view = input;
            TransformerChain<ITransformer> chain = new TransformerChain<ITransformer>();

            // Stage 1: when required and there are several inputs, merge them into one temporary column.
            if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
            {
                var srcCols = textCols;
                textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") };
                tempCols.Add(textCols[0]);
                chain = AddToChainAndTransform(chain, new ColumnConcatenatingTransformer(h, textCols[0], srcCols), ref view);
            }

            // Stage 2: text normalization (casing, diacritics, punctuation, numbers), one output per text column.
            if (tparams.NeedsNormalizeTransform)
            {
                var xfCols = new (string outputColumnName, string inputColumnName)[textCols.Length];
                string[] dstCols = new string[textCols.Length];
                for (int i = 0; i < textCols.Length; i++)
                {
                    dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer");
                    tempCols.Add(dstCols[i]);
                    xfCols[i] = (dstCols[i], textCols[i]);
                }
                chain = AddToChainAndTransform(chain,
                    new TextNormalizingEstimator(h, tparams.TextCase, tparams.KeepDiacritics, tparams.KeepPunctuations,
                    tparams.KeepNumbers, xfCols).Fit(view), ref view);

                textCols = dstCols;
            }

            // Stage 3: word tokenization, one token column per text column.
            if (tparams.NeedsWordTokenizationTransform)
            {
                var xfCols = new WordTokenizingEstimator.ColumnOptions[textCols.Length];
                wordTokCols = new string[textCols.Length];
                for (int i = 0; i < textCols.Length; i++)
                {
                    var col = new WordTokenizingEstimator.ColumnOptions(GenerateColumnName(view.Schema, textCols[i], "WordTokenizer"), textCols[i]);
                    xfCols[i] = col;
                    wordTokCols[i] = col.Name;
                    tempCols.Add(col.Name);
                }

                chain = AddToChainAndTransform(chain, new WordTokenizingEstimator(h, xfCols).Fit(view), ref view);
            }

            // Stage 4: stop words removal over the token columns; its outputs replace wordTokCols.
            if (tparams.NeedsRemoveStopwordsTransform)
            {
                // NeedsWordTokenizationTransform is true whenever stop words removal is requested,
                // so the token columns exist at this point.
                Contracts.Assert(wordTokCols != null, "StopWords transform requires that word tokenization has been applied to the input text.");
                var xfCols = new StopWordsCol[wordTokCols.Length];
                var dstCols = new string[wordTokCols.Length];
                for (int i = 0; i < wordTokCols.Length; i++)
                {
                    var col = new StopWordsCol();
                    col.Source = wordTokCols[i];
                    col.Name = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform");
                    dstCols[i] = col.Name;
                    tempCols.Add(col.Name);
                    col.Language = tparams.StopwordsLanguage;

                    xfCols[i] = col;
                }
                chain = AddToChainAndTransform(chain, tparams.StopWordsRemover.CreateComponent(h, view, xfCols), ref view);
                wordTokCols = dstCols;
            }

            // Stage 5: word n-gram extraction from all token columns into one temporary feature column.
            if (tparams.WordExtractorFactory != null)
            {
                var dstCol = GenerateColumnName(view.Schema, OutputColumn, "WordExtractor");
                tempCols.Add(dstCol);
                chain = AddToChainAndTransform(chain, tparams.WordExtractorFactory.Create(h, view, new[] {
                    new ExtractorColumn()
                    {
                        Name = dstCol,
                        Source = wordTokCols,
                        FriendlyNames = _inputColumns
                    }}), ref view);
                wordFeatureCol = dstCol;
            }

            // Stage 6: optionally expose the tokens under the user-chosen name. This column is not added
            // to tempCols.
            if (!string.IsNullOrEmpty(tparams.OutputTextTokensColumnName))
            {
                string[] srcCols = wordTokCols ?? textCols;
                chain = AddToChainAndTransform(chain, new ColumnConcatenatingTransformer(h, tparams.OutputTextTokensColumnName, srcCols), ref view);
            }

            // Stage 7: character tokenization followed by character n-gram extraction.
            if (tparams.CharExtractorFactory != null)
            {
                // Characters are taken from the stop-word-filtered tokens when that stage ran,
                // otherwise from the (possibly normalized) text columns.
                var srcCols = tparams.NeedsRemoveStopwordsTransform ? wordTokCols : textCols;
                charTokCols = new string[srcCols.Length];
                var xfCols = new (string outputColumnName, string inputColumnName)[srcCols.Length];
                for (int i = 0; i < srcCols.Length; i++)
                {
                    xfCols[i] = (GenerateColumnName(view.Schema, srcCols[i], "CharTokenizer"), srcCols[i]);
                    tempCols.Add(xfCols[i].outputColumnName);
                    charTokCols[i] = xfCols[i].outputColumnName;
                }
                chain = AddToChainAndTransform(chain, new TokenizingByCharactersTransformer(h, columns: xfCols), ref view);

                charFeatureCol = GenerateColumnName(view.Schema, OutputColumn, "CharExtractor");
                tempCols.Add(charFeatureCol);
                chain = AddToChainAndTransform(chain, tparams.CharExtractorFactory.Create(h, view, new[] {
                    new ExtractorColumn()
                    {
                        Source = charTokCols,
                        FriendlyNames = _inputColumns,
                        Name = charFeatureCol
                    } }), ref view);
            }

            // Stage 8: Lp-normalize whichever feature columns exist; the normalized columns replace them.
            if (tparams.Norm != NormFunction.None)
            {
                var xfCols = new List<LpNormNormalizingEstimator.ColumnOptions>(2);

                if (charFeatureCol != null)
                {
                    var dstCol = GenerateColumnName(view.Schema, charFeatureCol, "LpCharNorm");
                    tempCols.Add(dstCol);
                    xfCols.Add(new LpNormNormalizingEstimator.ColumnOptions(dstCol, charFeatureCol, norm: tparams.LpNorm));
                    charFeatureCol = dstCol;
                }

                if (wordFeatureCol != null)
                {
                    var dstCol = GenerateColumnName(view.Schema, wordFeatureCol, "LpWordNorm");
                    tempCols.Add(dstCol);
                    xfCols.Add(new LpNormNormalizingEstimator.ColumnOptions(dstCol, wordFeatureCol, norm: tparams.LpNorm));
                    wordFeatureCol = dstCol;
                }

                if (xfCols.Count > 0)
                    chain = AddToChainAndTransform(chain, new LpNormNormalizingTransformer(h, xfCols.ToArray()), ref view);
            }

            // Stage 9: assemble the final output column from the available feature columns.
            {
                var srcTaggedCols = new List<KeyValuePair<string, string>>(2);
                if (charFeatureCol != null && wordFeatureCol != null)
                {
                    // If we're producing both char and word grams, then we need to disambiguate
                    // between them (for example, the word 'a' vs. the char gram 'a').
                    srcTaggedCols.Add(new KeyValuePair<string, string>("Char", charFeatureCol));
                    srcTaggedCols.Add(new KeyValuePair<string, string>("Word", wordFeatureCol));
                }
                else
                {
                    // Otherwise, simply use the slot names, omitting the original source column names
                    // entirely. For the Concat transform setting the Key == Value of the TaggedColumn
                    // KVP signals this intent.
                    Contracts.Assert(charFeatureCol != null || wordFeatureCol != null || !string.IsNullOrEmpty(tparams.OutputTextTokensColumnName));
                    if (charFeatureCol != null)
                        srcTaggedCols.Add(new KeyValuePair<string, string>(charFeatureCol, charFeatureCol));
                    else if (wordFeatureCol != null)
                        srcTaggedCols.Add(new KeyValuePair<string, string>(wordFeatureCol, wordFeatureCol));
                }
                // With both extractors disabled there is nothing to concatenate and the output column
                // is not produced by this method.
                if (srcTaggedCols.Count > 0)
                {
                    chain = AddToChainAndTransform(chain, new ColumnConcatenatingTransformer(h, new ColumnConcatenatingTransformer.ColumnOptions(OutputColumn,
                        srcTaggedCols.Select(kvp => (kvp.Value, kvp.Key)))), ref view);
                }
            }

            // Final stage: hand the temporary column names to the column selector.
            // NOTE(review): the second list is presumably the set of columns to drop - confirm against
            // the ColumnSelectingTransformer constructor.
            chain = AddToChainAndTransform(chain, new ColumnSelectingTransformer(h, null, tempCols.ToArray()), ref view);
            return new Transformer(_host, chain);
        }
 
        /// <summary>
        /// Applies <paramref name="transformer"/> to <paramref name="view"/> and appends it to <paramref name="chain"/>.
        /// </summary>
        /// <param name="chain">The chain built so far.</param>
        /// <param name="transformer">The transformer to apply and record.</param>
        /// <param name="view">On entry, the current data; on exit, the transformed data.</param>
        /// <returns>The chain returned by appending <paramref name="transformer"/>.</returns>
        private static TransformerChain<ITransformer> AddToChainAndTransform(TransformerChain<ITransformer> chain, ITransformer transformer, ref IDataView view)
        {
            Contracts.AssertValue(chain);
            Contracts.AssertValue(transformer);
            Contracts.AssertValue(view);

            // Advance the running view first so that the caller sees the new schema, then record the step.
            var transformed = transformer.Transform(view);
            view = transformed;

            var extended = chain.Append(transformer);
            return extended;
        }
 
        // Factory method that builds the transformer from a model load context.
        private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            return new Transformer(env, ctx);
        }
 
        /// <summary>
        /// Asks <paramref name="schema"/> for a temporary column name based on the tag
        /// "<paramref name="srcName"/>_<paramref name="xfTag"/>".
        /// </summary>
        private static string GenerateColumnName(DataViewSchema schema, string srcName, string xfTag)
        {
            string tag = string.Format("{0}_{1}", srcName, xfTag);
            return schema.GetTempColumnName(tag);
        }
 
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">The incoming schema shape; must not be null.</param>
        /// <returns>
        /// The input columns plus the feature vector column and, when a tokens column name is configured,
        /// a variable-length text vector column of that name. Columns with the same name are replaced.
        /// </returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnsByName = inputSchema.ToDictionary(x => x.Name);

            // Every input column has to exist and hold text items.
            foreach (var inputName in _inputColumns)
            {
                bool found = inputSchema.TryFindColumn(inputName, out var inputColumn);
                if (!found)
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
                if (inputColumn.ItemType is TextDataViewType)
                    continue;
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, "scalar or vector of String", inputColumn.GetTypeString());
            }

            // The feature column always carries slot names; it is flagged as normalized only when a norm is applied.
            var annotations = new List<SchemaShape.Column>(2)
            {
                new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false)
            };
            bool appliesNorm = OptionalSettings.Norm != NormFunction.None;
            if (appliesNorm)
            {
                annotations.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
            }

            var featureAnnotations = new SchemaShape(annotations);
            columnsByName[OutputColumn] = new SchemaShape.Column(OutputColumn, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false,
                featureAnnotations);

            string tokensName = OptionalSettings.OutputTokensColumnName;
            if (!string.IsNullOrEmpty(tokensName))
            {
                columnsByName[tokensName] = new SchemaShape.Column(tokensName, SchemaShape.Column.VectorKind.VariableVector, TextDataViewType.Instance, false);
            }

            return new SchemaShape(columnsByName.Values);
        }
 
        // Factory method for SignatureDataTransform.
        // Builds an estimator from the supplied options, fits it on the data and returns the transformed
        // data as an IDataTransform (null if the transformed view does not implement that interface).
        internal static IDataTransform Create(IHostEnvironment env, Options args, IDataView data)
        {
            var outputName = args.Columns.Name;
            // When no explicit sources are given, the output column name doubles as the single source column.
            var sourceNames = args.Columns.Source ?? new[] { args.Columns.Name };

            var featurizer = new TextFeaturizingEstimator(env, outputName, sourceNames, args)
            {
                _stopWordsRemover = args.StopWordsRemover,
                _dictionary = args.Dictionary,
                // Review: the following two assignments are possibly not needed.
                _wordFeatureExtractor = args.WordFeatureExtractorFactory,
                _charFeatureExtractor = args.CharFeatureExtractorFactory
            };

            var fitted = featurizer.Fit(data);
            var transformed = fitted.Transform(data);
            return transformed as IDataTransform;
        }
 
        /// <summary>
        /// The transformer returned by fitting the estimator. It wraps an inner transformer chain and
        /// forwards schema, transform, row-mapper and save requests to it.
        /// </summary>
        private sealed class Transformer : ITransformer
        {
            // Name pattern of the per-step sub-models read by the legacy load path.
            private const string TransformDirTemplate = "Step_{0:000}";
            // Readable version of models that were saved step by step instead of as a single "Chain" sub-model.
            private const uint VerIDataTransform = 0x00010001;

            private readonly IHost _host;
            private readonly TransformerChain<ITransformer> _chain;

            /// <summary>
            /// Wraps an already built <paramref name="chain"/>.
            /// </summary>
            internal Transformer(IHostEnvironment env, TransformerChain<ITransformer> chain)
            {
                Contracts.AssertValue(env);
                env.AssertValue(chain);

                _host = env.Register(nameof(Transformer));
                _chain = chain;
            }

            /// <summary>
            /// Loads the transformer from <paramref name="ctx"/>, handling both the current format
            /// (a single "Chain" sub-model) and the legacy step-by-step format.
            /// </summary>
            public Transformer(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Transformer));
                _host.CheckValue(ctx, nameof(ctx));

                ctx.CheckAtModel(GetVersionInfo());

                if (ctx.Header.ModelVerReadable != VerIDataTransform)
                {
                    // Current format: the whole chain was saved as one sub-model.
                    ctx.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out _chain, "Chain");
                    return;
                }

                _chain = LoadLegacyChain(_host, env, ctx);
            }

            public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
            {
                _host.CheckValue(inputSchema, nameof(inputSchema));
                var outputSchema = _chain.GetOutputSchema(inputSchema);
                return outputSchema;
            }

            public IDataView Transform(IDataView input)
            {
                _host.CheckValue(input, nameof(input));
                var output = _chain.Transform(input);
                return output;
            }

            bool ITransformer.IsRowToRowMapper
            {
                get { return true; }
            }

            IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
            {
                _host.CheckValue(inputSchema, nameof(inputSchema));

                // The call goes through the ITransformer interface of the chain.
                ITransformer chainAsTransformer = (ITransformer)_chain;
                return chainAsTransformer.GetRowToRowMapper(inputSchema);
            }

            void ICanSaveModel.Save(ModelSaveContext ctx)
            {
                _host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();
                ctx.SetVersionInfo(GetVersionInfo());

                // Always written in the current format: the chain as a single sub-model.
                ctx.SaveModel(_chain, "Chain");
            }

            // Rebuilds the chain from the legacy layout: a step count in the stream, a "Loader" sub-model,
            // and one sub-model per step named after TransformDirTemplate.
            private static TransformerChain<ITransformer> LoadLegacyChain(IHost host, IHostEnvironment env, ModelLoadContext ctx)
            {
                int stepCount = ctx.Reader.ReadInt32();
                var chain = new TransformerChain<ITransformer>();
                ctx.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

                // Each step is applied to the output of the previous one, starting from the loader.
                IDataView current = loader;
                for (int step = 0; step < stepCount; step++)
                {
                    var stepDir = string.Format(TransformDirTemplate, step);

                    // Try to load as an ITransformer.
                    ITransformer loaded = TryLoadTransformer(env, ctx, stepDir);

                    // If that didn't work, this should be a RowToRowMapperTransform with a "Mapper" folder in it containing an ITransformer.
                    var mapperDir = Path.Combine(stepDir, "Mapper");
                    if (loaded == null && ctx.ContainsModel(mapperDir))
                        ctx.LoadModelOrNull<ITransformer, SignatureLoadModel>(env, out loaded, mapperDir);

                    if (loaded == null)
                    {
                        // Last resort: load the step as an IDataTransform over the current data and wrap it.
                        ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var dataTransform, stepDir, current);
                        current = dataTransform;
                        loaded = new TransformWrapper(host, dataTransform);
                    }
                    else
                    {
                        current = loaded.Transform(current);
                    }

                    chain = chain.Append(loaded);
                }

                return chain;
            }

            // Attempts to load the sub-model in the given directory as an ITransformer; a FormatException
            // raised while loading is treated the same as "nothing loaded" and yields null.
            private static ITransformer TryLoadTransformer(IHostEnvironment env, ModelLoadContext ctx, string dirName)
            {
                try
                {
                    ctx.LoadModelOrNull<ITransformer, SignatureLoadModel>(env, out ITransformer candidate, dirName);
                    return candidate;
                }
                catch (FormatException)
                {
                    return null;
                }
            }

            // Version history:
            //   0x00010001 - Initial
            //   0x00010002 - Save as TransformerChain instead of an array of IDataTransform
            private static VersionInfo GetVersionInfo()
                => new VersionInfo(
                    modelSignature: "TEXT XFR",
                    verWrittenCur: 0x00010002,
                    verReadableCur: 0x00010002,
                    verWeCanReadBack: 0x00010001,
                    loaderSignature: LoaderSignature,
                    loaderAssemblyName: typeof(Transformer).Assembly.FullName);
        }
    }
}