File: Text\NgramTransform.cs
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
 
[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), typeof(NgramExtractingTransformer.Options), typeof(SignatureDataTransform),
    "Ngram Transform", "NgramTransform", "Ngram")]
 
[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(IDataTransform), typeof(NgramExtractingTransformer), null, typeof(SignatureLoadDataTransform),
    "Ngram Transform", NgramExtractingTransformer.LoaderSignature)]
 
[assembly: LoadableClass(NgramExtractingTransformer.Summary, typeof(NgramExtractingTransformer), null, typeof(SignatureLoadModel),
    "Ngram Transform", NgramExtractingTransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(IRowMapper), typeof(NgramExtractingTransformer), null, typeof(SignatureLoadRowMapper),
    "Ngram Transform", NgramExtractingTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms.Text
{
    /// <summary>
    /// <see cref="ITransformer"/> resulting from fitting an <see cref="NgramExtractingEstimator"/>.
    /// </summary>
    public sealed class NgramExtractingTransformer : OneToOneTransformerBase
    {
        internal sealed class Column : OneToOneColumn
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum n-gram length", ShortName = "ngram")]
            public int? NgramLength;
 
            [Argument(ArgumentType.AtMostOnce, HelpText =
                "Whether to include all n-gram lengths up to " + nameof(NgramLength) + " or only " + nameof(NgramLength), Name = "AllLengths", ShortName = "all")]
            public bool? UseAllLengths;
 
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Maximum number of tokens to skip when constructing an n-gram",
                ShortName = "skips")]
            public int? SkipLength;
 
            [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
            public int[] MaxNumTerms = null;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Statistical measure used to evaluate how important a word is to a document in a corpus")]
            public NgramExtractingEstimator.WeightingCriteria? Weighting;
 
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);
 
                var res = new Column();
                if (res.TryParse(str))
                    return res;
                return null;
            }
 
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                if (NgramLength != null || UseAllLengths != null || SkipLength != null || Utils.Size(MaxNumTerms) != 0)
                    return false;
                return TryUnparseCore(sb);
            }
        }
 
        internal sealed class Options : TransformInputBase
        {
            [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum n-gram length", ShortName = "ngram")]
            public int NgramLength = NgramExtractingEstimator.Defaults.NgramLength;
 
            [Argument(ArgumentType.AtMostOnce, HelpText =
                "Whether to store all n-gram lengths up to ngramLength, or only ngramLength", Name = "AllLengths", ShortName = "all")]
            public bool UseAllLengths = NgramExtractingEstimator.Defaults.UseAllLengths;
 
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Maximum number of tokens to skip when constructing an n-gram",
                ShortName = "skips")]
            public int SkipLength = NgramExtractingEstimator.Defaults.SkipLength;
 
            [Argument(ArgumentType.Multiple, HelpText = "Maximum number of n-grams to store in the dictionary", ShortName = "max")]
            public int[] MaxNumTerms = new int[] { NgramExtractingEstimator.Defaults.MaximumNgramsCount };
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
            public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;
        }
 
        private const uint VerTfIdfSupported = 0x00010002;
 
        internal const string LoaderSignature = "NgramTransform";
        internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive values of length 1-n) in a given vector of keys. "
            + "It does so by building a dictionary of n-grams and using the id in the dictionary as the index in the bag.";
 
        internal const string UserName = "NGram Transform";
 
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "NGRAMTRN",
                // verWrittenCur: 0x00010001, // Initial
                // verWrittenCur: 0x00010002, // Add support for TF-IDF
                verWrittenCur: 0x00010003, // Get rid of writing float size in model context
                verReadableCur: 0x00010003,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(NgramExtractingTransformer).Assembly.FullName);
        }
 
        private sealed class TransformInfo
        {
            // Position i indicates whether the pool contains any (i+1)-grams.
            public readonly bool[] NonEmptyLevels;
            public readonly int NgramLength;
            public readonly int SkipLength;
            public readonly bool UseAllLengths;
            public readonly NgramExtractingEstimator.WeightingCriteria Weighting;
 
            public bool RequireIdf => Weighting == NgramExtractingEstimator.WeightingCriteria.Idf || Weighting == NgramExtractingEstimator.WeightingCriteria.TfIdf;
 
            public TransformInfo(NgramExtractingEstimator.ColumnOptions info)
            {
                NgramLength = info.NgramLength;
                SkipLength = info.SkipLength;
                Weighting = info.Weighting;
                UseAllLengths = info.UseAllLengths;
                NonEmptyLevels = new bool[NgramLength];
            }
 
            public TransformInfo(ModelLoadContext ctx, bool readWeighting)
            {
                Contracts.AssertValue(ctx);
 
                // *** Binary format ***
                // int: NgramLength
                // int: SkipLength
                // int: Weighting Criteria (if readWeighting == true)
                // bool[NgramLength]: NonEmptyLevels
 
                NgramLength = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
                SkipLength = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
                Contracts.CheckDecode(NgramLength <= NgramBufferBuilder.MaxSkipNgramLength - SkipLength);
 
                if (readWeighting)
                    Weighting = (NgramExtractingEstimator.WeightingCriteria)ctx.Reader.ReadInt32();
                Contracts.CheckDecode(Enum.IsDefined(typeof(NgramExtractingEstimator.WeightingCriteria), Weighting));
                NonEmptyLevels = ctx.Reader.ReadBoolArray(NgramLength);
            }
 
            internal void Save(ModelSaveContext ctx)
            {
                Contracts.AssertValue(ctx);
 
                // *** Binary format ***
                // int: NgramLength
                // int: SkipLength
                // int: Weighting Criteria
                // bool[NgramLength]: NonEmptyLevels
 
                Contracts.Assert(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength);
                ctx.Writer.Write(NgramLength);
                Contracts.Assert(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
                Contracts.Assert(NgramLength + SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
                ctx.Writer.Write(SkipLength);
                Contracts.Assert(Enum.IsDefined(typeof(NgramExtractingEstimator.WeightingCriteria), Weighting));
                ctx.Writer.Write((int)Weighting);
                Contracts.Assert(Utils.Size(NonEmptyLevels) == NgramLength);
                ctx.Writer.WriteBoolBytesNoCount(NonEmptyLevels);
            }
        }
 
        private readonly ImmutableArray<TransformInfo> _transformInfos;
 
        // These contain the n-gram maps
        private readonly SequencePool[] _ngramMaps;
 
        // Ngram inverse document frequencies
        private readonly double[][] _invDocFreqs;
 
        private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(NgramExtractingEstimator.ColumnOptions[] columns)
        {
            Contracts.CheckValue(columns, nameof(columns));
            return columns.Select(x => (x.Name, x.InputColumnName)).ToArray();
        }
 
        private protected override void CheckInputColumn(DataViewSchema inputSchema, int col, int srcCol)
        {
            var type = inputSchema[srcCol].Type;
            if (!NgramExtractingEstimator.IsColumnTypeValid(type))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, NgramExtractingEstimator.ExpectedColumnType, type.ToString());
        }
 
        internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, NgramExtractingEstimator.ColumnOptions[] columns)
           : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NgramExtractingTransformer)), GetColumnPairs(columns))
        {
            var transformInfos = new TransformInfo[columns.Length];
            for (int i = 0; i < columns.Length; i++)
            {
                input.Schema.TryGetColumnIndex(columns[i].InputColumnName, out int srcCol);
                var typeSrc = input.Schema[srcCol].Type;
                transformInfos[i] = new TransformInfo(columns[i]);
            }
            _transformInfos = transformInfos.ToImmutableArray();
            _ngramMaps = Train(Host, columns, _transformInfos, input, out _invDocFreqs);
        }
 
        private static SequencePool[] Train(IHostEnvironment env, NgramExtractingEstimator.ColumnOptions[] columns, ImmutableArray<TransformInfo> transformInfos, IDataView trainingData, out double[][] invDocFreqs)
        {
            var helpers = new NgramBufferBuilder[columns.Length];
            var getters = new ValueGetter<VBuffer<uint>>[columns.Length];
            var src = new VBuffer<uint>[columns.Length];
 
            // Keep track of how many n-grams are in the pool for each value of n. Position
            // i in counts[iinfo] indicates how many (i+1)-grams are in the pool for column iinfo.
            var counts = new int[columns.Length][];
            var ngramMaps = new SequencePool[columns.Length];
            var activeCols = new List<DataViewSchema.Column>();
            var srcTypes = new DataViewType[columns.Length];
            var srcCols = new int[columns.Length];
            for (int iinfo = 0; iinfo < columns.Length; iinfo++)
            {
                trainingData.Schema.TryGetColumnIndex(columns[iinfo].InputColumnName, out srcCols[iinfo]);
                srcTypes[iinfo] = trainingData.Schema[srcCols[iinfo]].Type;
                activeCols.Add(trainingData.Schema[srcCols[iinfo]]);
            }
            using (var cursor = trainingData.GetRowCursor(activeCols))
            using (var pch = env.StartProgressChannel("Building n-gram dictionary"))
            {
                for (int iinfo = 0; iinfo < columns.Length; iinfo++)
                {
                    env.Assert(srcTypes[iinfo] is VectorDataViewType vectorType && vectorType.ItemType is KeyDataViewType);
                    var ngramLength = columns[iinfo].NgramLength;
                    var skipLength = columns[iinfo].SkipLength;
 
                    getters[iinfo] = RowCursorUtils.GetVecGetterAs<uint>(NumberDataViewType.UInt32, cursor, srcCols[iinfo]);
                    src[iinfo] = default;
                    counts[iinfo] = new int[ngramLength];
                    ngramMaps[iinfo] = new SequencePool();
 
                    // Note: GetNgramIdFinderAdd will control how many n-grams of a specific length will
                    // be added (using lims[iinfo]), therefore we set slotLim to the maximum
                    helpers[iinfo] = new NgramBufferBuilder(ngramLength, skipLength, Utils.ArrayMaxSize,
                        GetNgramIdFinderAdd(env, counts[iinfo], columns[iinfo].MaximumNgramsCounts, ngramMaps[iinfo], transformInfos[iinfo].RequireIdf));
                }
 
                int cInfoFull = 0;
                bool[] infoFull = new bool[columns.Length];
 
                invDocFreqs = new double[columns.Length][];
 
                long totalDocs = 0;
                var rowCount = trainingData.GetRowCount() ?? double.NaN;
                var buffers = new VBuffer<float>[columns.Length];
                pch.SetHeader(new ProgressHeader(new[] { "Total n-grams" }, new[] { "documents" }),
                    e => e.SetProgress(0, totalDocs, rowCount));
                while (cInfoFull < columns.Length && cursor.MoveNext())
                {
                    totalDocs++;
                    for (int iinfo = 0; iinfo < columns.Length; iinfo++)
                    {
                        getters[iinfo](ref src[iinfo]);
                        var keyCount = (uint)srcTypes[iinfo].GetItemType().GetKeyCount();
                        if (keyCount == 0)
                            keyCount = uint.MaxValue;
                        if (!infoFull[iinfo])
                        {
                            if (transformInfos[iinfo].RequireIdf)
                                helpers[iinfo].Reset();
 
                            helpers[iinfo].AddNgrams(in src[iinfo], 0, keyCount);
                            if (transformInfos[iinfo].RequireIdf)
                            {
                                int totalNgrams = counts[iinfo].Sum();
                                Utils.EnsureSize(ref invDocFreqs[iinfo], totalNgrams);
                                helpers[iinfo].GetResult(ref buffers[iinfo]);
                                foreach (var pair in buffers[iinfo].Items())
                                {
                                    if (pair.Value >= 1)
                                        invDocFreqs[iinfo][pair.Key] += 1;
                                }
                            }
                        }
                        AssertValid(env, counts[iinfo], columns[iinfo].MaximumNgramsCounts, ngramMaps[iinfo]);
                    }
                }
 
                pch.Checkpoint(counts.Sum(c => c.Sum()), totalDocs);
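                // Convert the accumulated document frequencies into inverse document frequencies:
                // idf = ln(totalDocs / docFreq). For example (hypothetical numbers), with 100 documents
                // and an n-gram that occurs in 10 of them, idf = ln(100 / 10) ~= 2.30.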
                for (int iinfo = 0; iinfo < columns.Length; iinfo++)
                {
                    for (int i = 0; i < Utils.Size(invDocFreqs[iinfo]); i++)
                        if (invDocFreqs[iinfo][i] != 0)
                            invDocFreqs[iinfo][i] = Math.Log(totalDocs / invDocFreqs[iinfo][i]);
                }
 
                for (int iinfo = 0; iinfo < columns.Length; iinfo++)
                {
                    AssertValid(env, counts[iinfo], columns[iinfo].MaximumNgramsCounts, ngramMaps[iinfo]);
 
                    int ngramLength = transformInfos[iinfo].NgramLength;
                    for (int i = 0; i < ngramLength; i++)
                        transformInfos[iinfo].NonEmptyLevels[i] = counts[iinfo][i] > 0;
                }
 
                return ngramMaps;
            }
        }
 
        [Conditional("DEBUG")]
        private static void AssertValid(IHostEnvironment env, int[] counts, IReadOnlyList<int> lims, SequencePool pool)
        {
            int count = 0;
            int countFull = 0;
            for (int i = 0; i < lims.Count; i++)
            {
                env.Assert(counts[i] >= 0);
                env.Assert(counts[i] <= lims[i]);
                if (counts[i] == lims[i])
                    countFull++;
                count += counts[i];
            }
            env.Assert(count == pool.Count);
        }
 
        private static NgramIdFinder GetNgramIdFinderAdd(IHostEnvironment env, int[] counts, IReadOnlyList<int> lims, SequencePool pool, bool requireIdf)
        {
            Contracts.AssertValue(env);
            env.Assert(lims.Count > 0);
            env.Assert(lims.Count == Utils.Size(counts));
 
            int numFull = lims.Count(l => l <= 0);
            int ngramLength = lims.Count;
            return
                (uint[] ngram, int lim, int icol, ref bool more) =>
                {
                    env.Assert(0 < lim && lim <= Utils.Size(ngram));
                    env.Assert(lim <= Utils.Size(counts));
                    env.Assert(lim <= lims.Count);
                    env.Assert(icol == 0);
 
                    var max = lim - 1;
                    int slot = -1;
                    if (counts[max] < lims[max] && pool.TryAdd(ngram, 0, lim, out slot) && ++counts[max] >= lims[max])
                        numFull++;
 
                    // Note: 'slot' is either the id of the added n-gram or -1. In case it is -1, find its id.
                    // Note: 'more' controls whether more n-grams/skip-grams should be processed in the current
                    //       row. For IDF, as we are interested in counting the occurrence of n-grams/skip-
                    //       grams, more should not be updated.
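                    // In the non-IDF case the id itself is not needed while training; return -1 and, once every
                    // n-gram length has reached its limit (numFull == ngramLength), set 'more' to false so the
                    // builder stops generating n-grams for this row.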
                    if (requireIdf)
                        return slot != -1 ? slot : pool.Get(ngram, 0, lim);
 
                    more = numFull < ngramLength;
                    return -1;
                };
        }
 
        // Factory method for SignatureLoadDataTransform.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
            => Create(env, ctx).MakeDataTransform(input);
 
        // Factory method for SignatureLoadRowMapper.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
            => Create(env, ctx).MakeRowMapper(inputSchema);
 
        private NgramExtractingTransformer(IHost host, ModelLoadContext ctx) :
            base(host, ctx)
        {
            var columnsLength = ColumnPairs.Length;
            // *** Binary format ***
            // <prefix handled in static Create method>
            // <base>
            // for each column
            //   _transformInfo
            //   the n-gram SequencePool
            //   the n-gram inverse document frequencies
            var transformInfos = new TransformInfo[columnsLength];
            _ngramMaps = new SequencePool[columnsLength];
            _invDocFreqs = new double[columnsLength][];
            for (int i = 0; i < columnsLength; i++)
            {
                transformInfos[i] = new TransformInfo(ctx, ctx.Header.ModelVerWritten >= VerTfIdfSupported);
                _ngramMaps[i] = new SequencePool(ctx.Reader);
 
                if (ctx.Header.ModelVerWritten >= VerTfIdfSupported)
                {
                    _invDocFreqs[i] = ctx.Reader.ReadDoubleArray();
                    for (int j = 0; j < Utils.Size(_invDocFreqs[i]); j++)
                        Host.CheckDecode(_invDocFreqs[i][j] >= 0);
                }
            }
            _transformInfos = transformInfos.ToImmutableArray();
        }
 
        // Factory method for SignatureDataTransform.
        internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));
 
            env.CheckValue(options.Columns, nameof(options.Columns));
            var cols = new NgramExtractingEstimator.ColumnOptions[options.Columns.Length];
            using (var ch = env.Start("ValidateArgs"))
            {
 
                for (int i = 0; i < cols.Length; i++)
                {
                    var item = options.Columns[i];
                    var maxNumTerms = Utils.Size(item.MaxNumTerms) > 0 ? item.MaxNumTerms : options.MaxNumTerms;
                    cols[i] = new NgramExtractingEstimator.ColumnOptions(
                        item.Name,
                        item.NgramLength ?? options.NgramLength,
                        item.SkipLength ?? options.SkipLength,
                        item.UseAllLengths ?? options.UseAllLengths,
                        item.Weighting ?? options.Weighting,
                        maxNumTerms,
                        item.Source ?? item.Name);
                }
            }
            return new NgramExtractingTransformer(env, input, cols).MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadModel.
        private static NgramExtractingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(nameof(NgramExtractingTransformer));
 
            host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            if (ctx.Header.ModelVerWritten < 0x00010003)
            {
                int cbFloat = ctx.Reader.ReadInt32();
                env.CheckDecode(cbFloat == sizeof(float));
            }
            return new NgramExtractingTransformer(host, ctx);
        }
 
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());
            // *** Binary format ***
            // <base>
            // for each added column
            //   _transformInfo
            //   the n-gram SequencePool
            //   the n-gram inverse document frequencies
            SaveColumns(ctx);
            for (int i = 0; i < _transformInfos.Length; i++)
            {
                _transformInfos[i].Save(ctx);
                _ngramMaps[i].Save(ctx.Writer);
                ctx.Writer.WriteDoubleArray(_invDocFreqs[i]);
            }
        }
 
        private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
 
        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
        {
            private readonly DataViewType[] _srcTypes;
            private readonly int[] _srcCols;
            private readonly DataViewType[] _types;
            private readonly NgramExtractingTransformer _parent;
 
            public Mapper(NgramExtractingTransformer parent, DataViewSchema inputSchema)
               : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _parent = parent;
                _types = new DataViewType[_parent.ColumnPairs.Length];
                _srcTypes = new DataViewType[_parent.ColumnPairs.Length];
                _srcCols = new int[_parent.ColumnPairs.Length];
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    _types[i] = new VectorDataViewType(NumberDataViewType.Single, _parent._ngramMaps[i].Count);
                    inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]);
                    _srcTypes[i] = inputSchema[_srcCols[i]].Type;
                }
            }
 
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    var builder = new DataViewSchema.Annotations.Builder();
                    AddMetadata(i, builder);
 
                    result[i] = new DataViewSchema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.ToAnnotations());
                }
                return result;
            }
 
            private void AddMetadata(int iinfo, DataViewSchema.Annotations.Builder builder)
            {
                if (InputSchema[_srcCols[iinfo]].HasKeyValues())
                {
                    ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                    {
                        GetSlotNames(iinfo, _parent._ngramMaps[iinfo].Count, ref dst);
                    };
 
                    var slotNamesType = new VectorDataViewType(TextDataViewType.Instance, _parent._ngramMaps[iinfo].Count);
                    builder.AddSlotNames(_parent._ngramMaps[iinfo].Count, getter);
                }
            }
 
            private void GetSlotNames(int iinfo, int size, ref VBuffer<ReadOnlyMemory<char>> dst)
            {
                var itemType = _srcTypes[iinfo].GetItemType();
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                Host.Assert(InputSchema[_srcCols[iinfo]].HasKeyValues());
 
                var unigramNames = new VBuffer<ReadOnlyMemory<char>>();
 
                // Get the key values of the unigrams.
                var keyCount = itemType.GetKeyCountAsInt32(Host);
                InputSchema[_srcCols[iinfo]].GetKeyValues(ref unigramNames);
                Host.Check(unigramNames.Length == keyCount);
 
                var pool = _parent._ngramMaps[iinfo];
 
                var ngramCount = pool.Count;
                var dstEditor = VBufferEditor.Create(ref dst, ngramCount);
 
                StringBuilder sb = new StringBuilder();
                uint[] ngram = new uint[_parent._transformInfos[iinfo].NgramLength];
                for (int slot = 0; slot < pool.Count; slot++)
                {
                    var n = pool.GetById(slot, ref ngram);
                    Host.Assert(n >= 0);
 
                    // Get the unigrams composing the current n-gram.
                    ComposeNgramString(ngram, n, sb, keyCount, in unigramNames);
                    dstEditor.Values[slot] = sb.ToString().AsMemory();
                }
 
                dst = dstEditor.Commit();
            }
 
            private IEnumerable<long> GetNgramData(int iinfo, out long[] ngramCounts, out double[] weights, out List<long> indexes)
            {
                var transformInfo = _parent._transformInfos[iinfo];
                var itemType = _srcTypes[iinfo].GetItemType();
 
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                Host.Assert(InputSchema[_srcCols[iinfo]].HasKeyValues());
 
                // Get the key values of the unigrams.
                var keyCount = itemType.GetKeyCountAsInt32(Host);
 
                var maxNGramLength = transformInfo.NgramLength;
 
                var pool = _parent._ngramMaps[iinfo];
 
                // the ngrams in ML.NET are sequentially organized. e.g. {a, a|b, b, b|c...}
                // in onnx, they need to be separated by type. e.g. {a, b, c, a|b, b|c...}
                // since the resulting vectors need to match, we need to create a mapping
                // between the two and store it in the node attributes
 
                // create separate lists to track the ids of 1-grams, 2-grams etc
                // because they need to be in adjacent regions in the same list
                // when supplied to onnx
                // We later concatenate all these separate n-gram lists
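                // A hypothetical illustration: if the pool contains, in id order, { a, a|b, b, b|c } for key
                // ids a=1, b=2, c=3, then the ONNX attributes become
                //   pool_int64s   = { 1, 2, 1, 2, 2, 3 }   (the 1-grams a, b followed by the 2-grams a|b, b|c)
                //   ngram_counts  = { 0, 2 }               (start offset of each n-gram length in pool_int64s)
                //   ngram_indexes = { 0, 2, 1, 3 }         (output slot of each n-gram, in ONNX pool order: 1-grams first, then 2-grams)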
                var ngramIds = new List<long>[maxNGramLength];
                var ngramIndexes = new List<long>[maxNGramLength];
                for (int i = 0; i < ngramIds.Length; i++)
                {
                    ngramIds[i] = new List<long>();
                    ngramIndexes[i] = new List<long>();
                    //ngramWeights[i] = new List<float>();
                }
 
                weights = new double[pool.Count];
 
                uint[] ngram = new uint[maxNGramLength];
                for (int i = 0; i < pool.Count; i++)
                {
                    var n = pool.GetById(i, ref ngram);
                    Host.Assert(n >= 0);
 
                    // add the id of each gram to the corresponding ids list
                    for (int j = 0; j < n; j++)
                        ngramIds[n - 1].Add(ngram[j]);
 
                    // add the indexes to the corresponding list
                    ngramIndexes[n - 1].Add(i);
 
                    if (transformInfo.RequireIdf)
                        weights[i] = _parent._invDocFreqs[iinfo][i];
                    else
                        weights[i] = 1.0f;
                }
 
                // initialize the ngramCounts array with start-index of each n-gram
                int start = 0;
                ngramCounts = new long[maxNGramLength];
                for (int i = 0; i < maxNGramLength; i++)
                {
                    ngramCounts[i] = start;
                    start += ngramIds[i].Count;
                }
 
                // concatenate all the lists and return
                IEnumerable<long> allNGramIds = ngramIds[0];
                indexes = ngramIndexes[0];
                for (int i = 1; i < maxNGramLength; i++)
                {
                    allNGramIds = Enumerable.Concat(allNGramIds, ngramIds[i]);
                    indexes = indexes.Concat(ngramIndexes[i]).ToList();
                }
 
                return allNGramIds;
            }
 
            private void ComposeNgramString(uint[] ngram, int count, StringBuilder sb, int keyCount, in VBuffer<ReadOnlyMemory<char>> terms)
            {
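                // Builds the display string of an n-gram by joining the key values of its unigrams with '|',
                // e.g. (hypothetical) a 2-gram of key ids (3, 5) whose key values are "cat" and "sat" becomes "cat|sat";
                // ids outside the key range are rendered as "*".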
                Host.AssertValue(sb);
                Host.AssertValue(ngram);
                Host.Assert(keyCount > 0);
 
                sb.Clear();
                ReadOnlyMemory<char> term = default;
                string sep = "";
                for (int iterm = 0; iterm < count; iterm++)
                {
                    sb.Append(sep);
                    sep = "|";
                    var unigram = ngram[iterm];
                    if (unigram <= 0 || unigram > keyCount)
                        sb.Append("*");
                    else
                    {
                        term = terms.GetItemOrDefault((int)unigram - 1);
                        sb.AppendMemory(term);
                    }
                }
            }
 
            private NgramIdFinder GetNgramIdFinder(int iinfo)
            {
                return
                    (uint[] ngram, int lim, int icol, ref bool more) =>
                    {
                        Host.Assert(0 < lim && lim <= Utils.Size(ngram));
                        Host.Assert(lim <= Utils.Size(_parent._transformInfos[iinfo].NonEmptyLevels));
 
                        if (!_parent._transformInfos[iinfo].NonEmptyLevels[lim - 1])
                            return -1;
                        return _parent._ngramMaps[iinfo].Get(ngram, 0, lim);
                    };
            }
 
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Contracts.AssertValue(input);
                Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                disposer = null;
 
                var getSrc = RowCursorUtils.GetVecGetterAs<uint>(NumberDataViewType.UInt32, input, _srcCols[iinfo]);
                var src = default(VBuffer<uint>);
                var bldr = new NgramBufferBuilder(_parent._transformInfos[iinfo].NgramLength, _parent._transformInfos[iinfo].SkipLength,
                    _parent._ngramMaps[iinfo].Count, GetNgramIdFinder(iinfo));
                var keyCount = (uint)_srcTypes[iinfo].GetItemType().GetKeyCount();
                if (keyCount == 0)
                    keyCount = uint.MaxValue;
 
                ValueGetter<VBuffer<float>> del;
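                // Each branch below differs only in how the raw n-gram counts are weighted per output slot:
                //   Tf    -> the raw count of the n-gram in the row,
                //   Idf   -> the inverse document frequency if the n-gram occurs at least once, otherwise 0,
                //   TfIdf -> the raw count multiplied by the inverse document frequency.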
                switch (_parent._transformInfos[iinfo].Weighting)
                {
                    case NgramExtractingEstimator.WeightingCriteria.TfIdf:
                        Host.AssertValue(_parent._invDocFreqs[iinfo]);
                        del =
                            (ref VBuffer<float> dst) =>
                            {
                                getSrc(ref src);
                                if (!bldr.IsEmpty)
                                {
                                    bldr.Reset();
                                    bldr.AddNgrams(in src, 0, keyCount);
                                    bldr.GetResult(ref dst);
                                    VBufferUtils.Apply(ref dst, (int i, ref float v) => v = (float)(v * _parent._invDocFreqs[iinfo][i]));
                                }
                                else
                                    VBufferUtils.Resize(ref dst, 0);
                            };
                        break;
                    case NgramExtractingEstimator.WeightingCriteria.Idf:
                        Host.AssertValue(_parent._invDocFreqs[iinfo]);
                        del =
                            (ref VBuffer<float> dst) =>
                            {
                                getSrc(ref src);
                                if (!bldr.IsEmpty)
                                {
                                    bldr.Reset();
                                    bldr.AddNgrams(in src, 0, keyCount);
                                    bldr.GetResult(ref dst);
                                    VBufferUtils.Apply(ref dst, (int i, ref float v) => v = v >= 1 ? (float)_parent._invDocFreqs[iinfo][i] : 0);
                                }
                                else
                                    VBufferUtils.Resize(ref dst, 0);
                            };
                        break;
                    case NgramExtractingEstimator.WeightingCriteria.Tf:
                        del =
                            (ref VBuffer<float> dst) =>
                            {
                                getSrc(ref src);
                                if (!bldr.IsEmpty)
                                {
                                    bldr.Reset();
                                    bldr.AddNgrams(in src, 0, keyCount);
                                    bldr.GetResult(ref dst);
                                }
                                else
                                    VBufferUtils.Resize(ref dst, 0);
                            };
                        break;
                    default:
                        throw Host.Except("Unsupported weighting criteria");
                }
                return del;
            }
 
            public bool CanSaveOnnx(OnnxContext ctx) => true;
 
            public void SaveAsOnnx(OnnxContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));
 
                int numColumns = _parent.ColumnPairs.Length;
                for (int iinfo = 0; iinfo < numColumns; ++iinfo)
                {
                    string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
                    if (!ctx.ContainsColumn(inputColumnName))
                        continue;
 
                    string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
 
                    // NOTE:
                    // There is a subtle side effect in how OnnxContextImpl works
                    // AddIntermediateVariable picks a unique name for the variable being added
                    // and overwrites any existing old name with the new name in its column map
                    // If the pipeline consists of the same column name being transformed by more
                    // than one transformer, then the srcVariableName used must be the old name
                    // and the dstVariableName must be the newly chosen unique name
                    // Therefore, GetVariableName must always be called before AddIntermediateVariable here
                    string srcVariableName = ctx.GetVariableName(inputColumnName);
                    string dstVariableName = ctx.AddIntermediateVariable(_srcTypes[iinfo], outputColumnName, true);
                    SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName);
                }
            }
 
            private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
            {
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
 
                var transformInfo = _parent._transformInfos[iinfo];
 
                // TfIdfVectorizer accepts string, int32 and int64 tensors.
                // The ML.NET NgramTransform, however, only accepts keys as inputs: these come either from the
                // ValueToKeyMapping transformer, which outputs UInt32 values, or from TokenizingByCharacters,
                // which outputs UInt16 values.
                // So, if the input is UInt32, UInt64, or UInt16, cast it here to Int32/Int64.
                string opType;
                var vectorType = _srcTypes[iinfo] as VectorDataViewType;
 
                if ((vectorType != null) &&
                    ((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>)) ||
                    (vectorType.RawType == typeof(VBuffer<UInt16>))))
                {
                    var dataKind = _srcTypes[iinfo] == NumberDataViewType.UInt32 ? DataKind.Int32 : DataKind.Int64;
 
                    opType = "Cast";
                    string castOutput = ctx.AddIntermediateVariable(_srcTypes[iinfo], "CastOutput", true);
 
                    var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
                    var t = InternalDataKindExtensions.ToInternalDataKind(dataKind).ToType();
                    castNode.AddAttribute("to", t);
 
                    srcVariableName = castOutput;
                }
 
                opType = "TfIdfVectorizer";
                var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
                node.AddAttribute("max_gram_length", transformInfo.NgramLength);
                node.AddAttribute("max_skip_count", transformInfo.SkipLength);
                node.AddAttribute("min_gram_length", transformInfo.UseAllLengths ? 1 : transformInfo.NgramLength);
 
                string mode;
                if (transformInfo.RequireIdf)
                {
                    mode = transformInfo.Weighting == NgramExtractingEstimator.WeightingCriteria.Idf ? "IDF" : "TFIDF";
                }
                else
                {
                    mode = "TF";
                }
                node.AddAttribute("mode", mode);
 
                long[] ngramCounts;
                double[] ngramWeights;
                List<long> ngramIndexes;
 
                var ngramIds = GetNgramData(iinfo, out ngramCounts, out ngramWeights, out ngramIndexes);
 
                node.AddAttribute("ngram_counts", ngramCounts);
                node.AddAttribute("pool_int64s", ngramIds);
                node.AddAttribute("ngram_indexes", ngramIndexes);
                node.AddAttribute("weights", ngramWeights);
            }
 
        }
    }
 
    /// <summary>
    /// Produces a vector of counts of n-grams (sequences of consecutive words) encountered in the input text.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    ///
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | Yes |
    /// | Input column data type | Vector of [key](xref:Microsoft.ML.Data.KeyDataViewType) type. |
    /// | Output column data type | Known-sized vector of <xref:System.Single> |
    /// | Exportable to ONNX | Yes |
    ///
    /// The resulting <xref:Microsoft.ML.Transforms.Text.NgramExtractingTransformer>
    /// creates a new column, named as specified in the output column name parameters, where each
    /// input vector is mapped to a vector of counts of n-grams (sequences of consecutive words) encountered in the input text.
    ///
    /// The estimator builds a dictionary of n-grams and the <xref:Microsoft.ML.Transforms.Text.NgramExtractingTransformer>
    /// uses the id in the dictionary as the index in the count vector that it produces.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]></format>
    /// </remarks>
    /// <seealso cref="TextCatalog.ProduceNgrams(TransformsCatalog.TextTransforms, string, string, int, int, bool, int, WeightingCriteria)"/>
    public sealed class NgramExtractingEstimator : IEstimator<NgramExtractingTransformer>
    {
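        // A minimal usage sketch (illustrative only; the MLContext instance and the column names "Text",
        // "Tokens" and "Ngrams" are assumptions, not part of this file). The estimator is normally reached
        // through the TextCatalog extension listed in the See Also reference above:
        //
        //   var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Tokens", "Text")
        //       .Append(mlContext.Transforms.Conversion.MapValueToKey("Tokens"))
        //       .Append(mlContext.Transforms.Text.ProduceNgrams("Ngrams", "Tokens",
        //           ngramLength: 2, useAllLengths: true, weighting: NgramExtractingEstimator.WeightingCriteria.TfIdf));
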
        /// <summary>
        /// A statistical measure used to evaluate how important a word is to a document in a corpus.
        /// This enumeration is serialized.
        /// </summary>
        public enum WeightingCriteria
        {
            /// <summary>Term Frequency. Calculated based on the number of occurrences in the document.</summary>
            [EnumValueDisplay("TF (Term Frequency)")]
            Tf = 0,
 
            /// <summary>
            /// Inverse Document Frequency. A ratio (the logarithm of inverse relative frequency)
            /// that measures the information a slot provides by determining how common or rare it is across the entire corpus.
            /// </summary>
            [EnumValueDisplay("IDF (Inverse Document Frequency)")]
            Idf = 1,
 
            /// <summary>The product of the term frequency and the inverse document frequency.</summary>
            [EnumValueDisplay("TF-IDF")]
            TfIdf = 2
        }
 
        internal static class Defaults
        {
            public const int NgramLength = 2;
            public const bool UseAllLengths = true;
            public const int SkipLength = 0;
            public const int MaximumNgramsCount = 10000000;
            public const WeightingCriteria Weighting = WeightingCriteria.Tf;
        }
 
        private readonly IHost _host;
        private readonly ColumnOptions[] _columns;
 
        /// <summary>
        /// Produces a bag of counts of n-grams (sequences of consecutive words) in <paramref name="inputColumnName"/>
        /// and outputs the vector of n-gram counts as <paramref name="outputColumnName"/>.
        /// </summary>
        /// <param name="env">The environment.</param>
        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
        /// <param name="ngramLength">Ngram length.</param>
        /// <param name="skipLength">Maximum number of tokens to skip when constructing an n-gram.</param>
        /// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
        /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
        /// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
        internal NgramExtractingEstimator(IHostEnvironment env,
            string outputColumnName, string inputColumnName = null,
            int ngramLength = Defaults.NgramLength,
            int skipLength = Defaults.SkipLength,
            bool useAllLengths = Defaults.UseAllLengths,
            int maximumNgramsCount = Defaults.MaximumNgramsCount,
            WeightingCriteria weighting = Defaults.Weighting)
            : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
        {
        }
 
        /// <summary>
        /// Produces a bag of counts of n-grams (sequences of consecutive words) in the input columns of <paramref name="columns"/>
        /// and outputs a vector of n-gram counts for each corresponding output column.
        /// </summary>
        /// <param name="env">The environment.</param>
        /// <param name="columns">Pairs of columns to compute bag of word vector.</param>
        /// <param name="ngramLength">Ngram length.</param>
        /// <param name="skipLength">Maximum number of tokens to skip when constructing an n-gram.</param>
        /// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
        /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
        /// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
        internal NgramExtractingEstimator(IHostEnvironment env,
            (string outputColumnName, string inputColumnName)[] columns,
            int ngramLength = Defaults.NgramLength,
            int skipLength = Defaults.SkipLength,
            bool useAllLengths = Defaults.UseAllLengths,
            int maximumNgramsCount = Defaults.MaximumNgramsCount,
            WeightingCriteria weighting = Defaults.Weighting)
            : this(env, columns.Select(x => new ColumnOptions(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, useAllLengths, weighting, maximumNgramsCount)).ToArray())
        {
        }
 
        /// <summary>
        /// Produces a bag of counts of n-grams (sequences of consecutive words) in the input columns of <paramref name="columns"/>
        /// and outputs a vector of n-gram counts for each corresponding output column.
        /// </summary>
        /// <param name="env">The environment.</param>
        /// <param name="columns">Array of columns with information how to transform data.</param>
        internal NgramExtractingEstimator(IHostEnvironment env, params ColumnOptions[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(NgramExtractingEstimator));
            _columns = columns;
        }
 
        /// <summary>
        /// Trains and returns a <see cref="NgramExtractingTransformer"/>.
        /// </summary>
        public NgramExtractingTransformer Fit(IDataView input) => new NgramExtractingTransformer(_host, input, _columns);
 
        internal static bool IsColumnTypeValid(DataViewType type)
        {
            if (!(type is VectorDataViewType vectorType))
                return false;
            if (!(vectorType.ItemType is KeyDataViewType itemKeyType))
                return false;
            // Can only accept key types that can be converted to U4.
            if (itemKeyType.Count == 0 && !NgramUtils.IsValidNgramRawType(itemKeyType.RawType))
                return false;
            return true;
        }
 
        internal static bool IsSchemaColumnValid(SchemaShape.Column col)
        {
            if (col.Kind == SchemaShape.Column.VectorKind.Scalar)
                return false;
            if (!col.IsKey)
                return false;
            // Can only accept key types that can be converted to U4.
            if (!NgramUtils.IsValidNgramRawType(col.ItemType.RawType))
                return false;
            return true;
        }
 
        internal const string ExpectedColumnType = "Expected vector of Key type, and Key is convertible to UInt32";
 
        /// <summary>
        /// Describes how the transformer handles one column pair.
        /// </summary>
        [BestFriend]
        internal sealed class ColumnOptions
        {
            /// <summary>Name of the column resulting from the transformation of <see cref="InputColumnName"/>.</summary>
            public readonly string Name;
            /// <summary>Name of column to transform.</summary>
            public readonly string InputColumnName;
            /// <summary>Maximum n-gram length.</summary>
            public readonly int NgramLength;
            /// <summary>Maximum number of tokens to skip when constructing an n-gram.</summary>
            public readonly int SkipLength;
            /// <summary>Whether to store all n-gram lengths up to ngramLength, or only ngramLength.</summary>
            public readonly bool UseAllLengths;
            /// <summary>The weighting criteria.</summary>
            public readonly WeightingCriteria Weighting;
            /// <summary>
            /// Underlying state of <see cref="MaximumNgramsCounts"/>.
            /// </summary>
            private readonly ImmutableArray<int> _maximumNgramsCounts;
            /// <summary>
            /// Contains the maximum number of terms (that is, n-grams) to store in the dictionary, for each level of n-grams,
            /// from n=1 (in position 0) up to n=<see cref="NgramLength"/> (in position <see cref="NgramLength"/>-1)
            /// </summary>
            public IReadOnlyList<int> MaximumNgramsCounts => _maximumNgramsCounts;
 
            /// <summary>
            /// Describes how the transformer handles one column pair.
            /// </summary>
            /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
            /// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param>
            /// <param name="ngramLength">Maximum n-gram length.</param>
            /// <param name="skipLength">Maximum number of tokens to skip when constructing an n-gram.</param>
            /// <param name="useAllLengths">Whether to store all n-gram lengths up to ngramLength, or only ngramLength.</param>
            /// <param name="weighting">The weighting criteria.</param>
            /// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
            public ColumnOptions(string name, string inputColumnName = null,
                int ngramLength = Defaults.NgramLength,
                int skipLength = Defaults.SkipLength,
                bool useAllLengths = Defaults.UseAllLengths,
                WeightingCriteria weighting = Defaults.Weighting,
                int maximumNgramsCount = Defaults.MaximumNgramsCount)
                : this(name, ngramLength, skipLength, useAllLengths, weighting, new int[] { maximumNgramsCount }, inputColumnName ?? name)
            {
            }
 
            internal ColumnOptions(string name,
                int ngramLength,
                int skipLength,
                bool useAllLengths,
                WeightingCriteria weighting,
                int[] maximumNgramsCounts,
                string inputColumnName = null)
            {
                if (ngramLength == 1 && skipLength != 0)
                    throw Contracts.ExceptUserArg(nameof(skipLength), string.Format(
                        "{0} (actual value: {1}) can only be zero when {2} set to one.", nameof(skipLength), skipLength, nameof(ngramLength)));
                if (ngramLength + skipLength > NgramBufferBuilder.MaxSkipNgramLength)
                    throw Contracts.ExceptUserArg(nameof(skipLength),
                        $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}");
                Contracts.CheckUserArg(0 < ngramLength && ngramLength <= NgramBufferBuilder.MaxSkipNgramLength, nameof(ngramLength));
 
                var limits = new int[ngramLength];
                if (!useAllLengths)
                {
                    Contracts.CheckUserArg(Utils.Size(maximumNgramsCounts) == 0 ||
                        Utils.Size(maximumNgramsCounts) == 1 && maximumNgramsCounts[0] > 0, nameof(maximumNgramsCounts));
                    limits[ngramLength - 1] = Utils.Size(maximumNgramsCounts) == 0 ? Defaults.MaximumNgramsCount : maximumNgramsCounts[0];
                }
                else
                {
                    Contracts.CheckUserArg(Utils.Size(maximumNgramsCounts) <= ngramLength, nameof(maximumNgramsCounts));
                    Contracts.CheckUserArg(Utils.Size(maximumNgramsCounts) == 0 || maximumNgramsCounts.All(i => i >= 0) && maximumNgramsCounts[maximumNgramsCounts.Length - 1] > 0, nameof(maximumNgramsCounts));
                    var extend = Utils.Size(maximumNgramsCounts) == 0 ? Defaults.MaximumNgramsCount : maximumNgramsCounts[maximumNgramsCounts.Length - 1];
                    limits = Utils.BuildArray(ngramLength, i => i < Utils.Size(maximumNgramsCounts) ? maximumNgramsCounts[i] : extend);
                }
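                // For example (hypothetical values): ngramLength = 3 with maximumNgramsCounts = { 100, 50 } and
                // useAllLengths = true gives limits = { 100, 50, 50 }; with useAllLengths = false and { 100 },
                // limits = { 0, 0, 100 } (only 3-grams are kept in the dictionary).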
                _maximumNgramsCounts = ImmutableArray.Create(limits);
 
                Name = name;
                InputColumnName = inputColumnName ?? name;
                NgramLength = ngramLength;
                SkipLength = skipLength;
                UseAllLengths = useAllLengths;
                Weighting = weighting;
            }
        }
 
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.ToDictionary(x => x.Name);
            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
                if (!IsSchemaColumnValid(col))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, ExpectedColumnType, col.GetTypeString());
                var metadata = new List<SchemaShape.Column>();
                if (col.NeedsSlotNames())
                    metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));
                result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
            }
            return new SchemaShape(result.Values);
        }
    }
}