// File: Text\TokenizingByCharacters.cs
// Web Access
// Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
 
[assembly: LoadableClass(TokenizingByCharactersTransformer.Summary, typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), typeof(TokenizingByCharactersTransformer.Options), typeof(SignatureDataTransform),
    TokenizingByCharactersTransformer.UserName, "CharTokenize", TokenizingByCharactersTransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(IDataTransform), typeof(TokenizingByCharactersTransformer), null, typeof(SignatureLoadDataTransform),
    TokenizingByCharactersTransformer.UserName, TokenizingByCharactersTransformer.LoaderSignature)]
 
[assembly: LoadableClass(TokenizingByCharactersTransformer.Summary, typeof(TokenizingByCharactersTransformer), null, typeof(SignatureLoadModel),
    TokenizingByCharactersTransformer.UserName, TokenizingByCharactersTransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(IRowMapper), typeof(TokenizingByCharactersTransformer), null, typeof(SignatureLoadRowMapper),
    TokenizingByCharactersTransformer.UserName, TokenizingByCharactersTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms.Text
{
    /// <summary>
    /// <see cref="ITransformer"/> resulting from fitting a <see cref="TokenizingByCharactersEstimator"/>.
    /// </summary>
    public sealed class TokenizingByCharactersTransformer : OneToOneTransformerBase
    {
        /// <summary>
        /// One output/input column pair in its textual form (optional form: <c>name:src</c>).
        /// </summary>
        internal sealed class Column : OneToOneColumn
        {
            /// <summary>
            /// Builds a <see cref="Column"/> from <paramref name="str"/>.
            /// </summary>
            /// <returns>The populated column, or <see langword="null"/> when <paramref name="str"/> cannot be parsed.</returns>
            internal static Column Parse(string str)
            {
                var column = new Column();
                return column.TryParse(str) ? column : null;
            }

            /// <summary>
            /// Appends the textual form of this column to <paramref name="sb"/> by delegating to <c>TryUnparseCore</c>.
            /// </summary>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                bool written = TryUnparseCore(sb);
                return written;
            }
        }
 
        /// <summary>
        /// Argument bag consumed by the options-based <c>Create</c> factory of this transformer.
        /// </summary>
        internal sealed class Options : TransformInputBase
        {
            /// <summary>Column pairs to tokenize; a column without an explicit source uses its own name as the source.</summary>
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;

            /// <summary>Whether the start-of-text (0x02) and end-of-text (0x03) marker characters are emitted.</summary>
            [Argument(ArgumentType.Multiple, HelpText = "Whether to mark the beginning/end of each row/slot with start of text character (0x02)/end of text character (0x03)",
                ShortName = "mark", SortOrder = 2)]
            public bool UseMarkerChars = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters;

            // REVIEW: support UTF-32 encoding through an argument option?

            // REVIEW: support encoding surrogate pairs in UTF-16?
        }
 
        // Human-readable description used in the LoadableClass registrations at the top of the file.
        internal const string Summary = "Character-oriented tokenizer where text is considered a sequence of characters.";

        // Signature and display name used in the LoadableClass registrations and in the model version info.
        internal const string LoaderSignature = "CharToken";
        internal const string UserName = "Character Tokenizer Transform";

        // True when vector inputs must be tokenized the way models saved with ver:0x00010001 did,
        // i.e. with start/end markers around every slot instead of a unit separator between slots.
        // It is set while loading a model and stays false for transformers created from arguments.
        private readonly bool _isSeparatorStartEnd;
 
        // Version stamp written by SaveModel and checked when a model is loaded.
        //   0x00010001: Initial.
        //   0x00010002: Updated to use UnitSeparator <US> character instead of using <ETX><STX> for vector inputs.
        private static VersionInfo GetVersionInfo()
            => new VersionInfo(
                modelSignature: "CHARTOKN",
                verWrittenCur: 0x00010002,
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(TokenizingByCharactersTransformer).Assembly.FullName);
 
        // Controls whether to mark the beginning/end of each row/slot with TextStartMarker/TextEndMarker.
        private readonly bool _useMarkerChars;

        // Control characters written into the output: <US> between the pieces of a vector input,
        // <STX>/<ETX> around the text when marker characters are enabled.
        private const ushort UnitSeparator = 0x1f;
        private const ushort TextStartMarker = 0x02;
        private const ushort TextEndMarker = 0x03;
        // Number of extra output slots taken by one start marker plus one end marker.
        private const int TextMarkersCount = 2;

        // For now, this transform supports input text formatted as UTF-16 only.
        // Note: Null-char is mapped to NA. Therefore, we have UInt16.MaxValue unique key values.
        internal const int CharsCount = ushort.MaxValue;
        // Name under which the transformer registers its host with the environment.
        private const string RegistrationName = "CharTokenizer";
 
        /// <summary>
        /// Tokenize incoming text in input columns and output the tokens as output columns.
        /// </summary>
        /// <param name="env">The environment. Must not be <see langword="null"/>.</param>
        /// <param name="useMarkerCharacters">Whether to prepend a marker character, <c>0x02</c>, to the beginning,
        /// and append another marker character, <c>0x03</c>, to the end of the output vector of characters.</param>
        /// <param name="columns">Pairs of columns to run the tokenization on.</param>
        internal TokenizingByCharactersTransformer(IHostEnvironment env, bool useMarkerCharacters = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters,
            params (string outputColumnName, string inputColumnName)[] columns) :
            base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
        {
            // _isSeparatorStartEnd is left at its default (false): only loaded models can turn it on.
            _useMarkerChars = useMarkerCharacters;
        }
 
        /// <summary>
        /// Read-only view of the (output, input) column name pairs this transformer operates on.
        /// </summary>
        internal IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns
        {
            get { return ColumnPairs.AsReadOnly(); }
        }
 
        /// <summary>
        /// Verifies that the source column is text, or a vector of text, before the transformer is bound to a schema.
        /// </summary>
        /// <param name="inputSchema">Schema the transformer is being applied to.</param>
        /// <param name="col">Index of the column pair being checked.</param>
        /// <param name="srcCol">Index of the source column inside <paramref name="inputSchema"/>.</param>
        private protected override void CheckInputColumn(DataViewSchema inputSchema, int col, int srcCol)
        {
            var type = inputSchema[srcCol].Type;
            if (!TokenizingByCharactersEstimator.IsColumnTypeValid(type))
            {
                // Report the offending column together with the expected and actual types, in the same way
                // TokenizingByCharactersEstimator.GetOutputSchema does, instead of a message that only says "String".
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName,
                    TokenizingByCharactersEstimator.ExpectedColumnType, type.ToString());
            }
        }
 
        // Deserializing constructor: reads the fields in the order SaveModel writes them.
        private TokenizingByCharactersTransformer(IHost host, ModelLoadContext ctx) :
          base(host, ctx)
        {
            // *** Binary format ***
            // <base>
            // byte: _useMarkerChars value.
            // byte: _isSeparatorStartEnd value (only read when the model's readable version is 0x00010002 or later).
            _useMarkerChars = ctx.Reader.ReadBoolByte();
            // For versions before 0x00010002 the short-circuit skips the read and forces the
            // legacy start/end-marker behavior for vector inputs.
            _isSeparatorStartEnd = ctx.Header.ModelVerReadable < 0x00010002 || ctx.Reader.ReadBoolByte();
        }
 
        // Factory method for SignatureLoadDataTransform: loads the transformer and binds it to the input data.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            var transformer = Create(env, ctx);
            return transformer.MakeDataTransform(input);
        }
 
        // Serializes the transformer: version info first, then the column pairs, then the two flags.
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // <base>
            // byte: _useMarkerChars value.
            // byte: _isSeparatorStartEnd value.
            SaveColumns(ctx);
            ctx.Writer.WriteBoolByte(_useMarkerChars);
            ctx.Writer.WriteBoolByte(_isSeparatorStartEnd);
        }
 
        // Factory method for SignatureLoadModel: validates the arguments and the model version, then deserializes.
        private static TokenizingByCharactersTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));

            IHost registered = env.Register(RegistrationName);
            registered.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            return new TokenizingByCharactersTransformer(registered, ctx);
        }
 
        // Factory method for SignatureDataTransform: turns the option columns into (output, input) name pairs
        // and applies a freshly built transformer to the input data.
        internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));
            env.CheckValue(options.Columns, nameof(options.Columns));

            var requested = options.Columns;
            var pairs = new (string outputColumnName, string inputColumnName)[requested.Length];
            int next = 0;
            foreach (var column in requested)
            {
                // A column without an explicit source reads from the column it is named after.
                pairs[next++] = (column.Name, column.Source ?? column.Name);
            }

            var transformer = new TokenizingByCharactersTransformer(env, options.UseMarkerChars, pairs);
            return transformer.MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadRowMapper: loads the transformer and builds a row mapper over the schema.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
        {
            var transformer = Create(env, ctx);
            return transformer.MakeRowMapper(inputSchema);
        }
 
        // Creates the per-schema mapper that produces the tokenized output columns.
        private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
        {
            return new Mapper(this, schema);
        }
 
        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
        {
            private readonly DataViewType _type;
            private readonly TokenizingByCharactersTransformer _parent;
            private readonly bool[] _isSourceVector;
            private readonly int[] _sourceVectorLength;
            // Constructed and cached the first time it is needed.
            private volatile string _keyValuesStr;
            private volatile int[] _keyValuesBoundaries;
 
            /// <summary>
            /// Binds <paramref name="parent"/> to <paramref name="inputSchema"/> and records, for every column pair,
            /// whether the source column is a vector and how many values it holds.
            /// </summary>
            public Mapper(TokenizingByCharactersTransformer parent, DataViewSchema inputSchema)
             : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _parent = parent;

                // Every output column is a vector of ushort keys with CharsCount possible values.
                var keyType = new KeyDataViewType(typeof(ushort), CharsCount);
                _type = new VectorDataViewType(keyType);

                int columnCount = _parent.ColumnPairs.Length;
                _isSourceVector = new bool[columnCount];
                _sourceVectorLength = new int[columnCount];
                for (int col = 0; col < columnCount; col++)
                {
                    DataViewType sourceType = inputSchema[_parent.ColumnPairs[col].inputColumnName].Type;
                    _isSourceVector[col] = sourceType is VectorDataViewType;
                    _sourceVectorLength[col] = sourceType.GetValueCount();
                }
            }
 
            // This mapper always reports that it can be exported to ONNX.
            public bool CanSaveOnnx(OnnxContext ctx)
            {
                return true;
            }
 
            /// <summary>
            /// Exports every column pair whose input column is known to <paramref name="ctx"/>; other pairs are skipped.
            /// </summary>
            public void SaveAsOnnx(OnnxContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));

                int columnCount = _isSourceVector.Length;
                for (int col = 0; col < columnCount; col++)
                {
                    var pair = _parent.ColumnPairs[col];
                    if (ctx.ContainsColumn(pair.inputColumnName))
                    {
                        string source = ctx.GetVariableName(pair.inputColumnName);
                        string destination = ctx.AddIntermediateVariable(_type, pair.outputColumnName, true);
                        SaveAsOnnxCore(ctx, col, source, destination);
                    }
                }
            }
 
            /// <summary>
            /// Emits the ONNX nodes for one column pair as a chain of four operators:
            /// Tokenizer ("com.microsoft" domain) -> Squeeze -> LabelEncoder -> Cast.
            /// </summary>
            /// <param name="ctx">Export context that receives the nodes and intermediate variables.</param>
            /// <param name="iinfo">Index of the column pair being exported.</param>
            /// <param name="srcVariableName">ONNX variable holding the input text.</param>
            /// <param name="dstVariableName">ONNX variable that receives the final keys.</param>
            private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
            {
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                // Stage 1: Tokenizer. The intermediate type mirrors the input: a text vector of the recorded
                // length for vector sources, plain text otherwise.
                string opType = "Tokenizer";
                DataViewType dataViewType;
                if (_isSourceVector[iinfo])
                    dataViewType = new VectorDataViewType(TextDataViewType.Instance, _sourceVectorLength[iinfo]);
                else
                    dataViewType = TextDataViewType.Instance;

                string tokenizerOutput = ctx.AddIntermediateVariable(dataViewType, "TokenizerOutput", true);
                var node = ctx.CreateNode(opType, srcVariableName, tokenizerOutput, ctx.GetNodeName(opType), "com.microsoft");
                node.AddAttribute("mark", _parent._useMarkerChars);
                node.AddAttribute("mincharnum", 1);
                node.AddAttribute("pad_value", "");
                // NOTE(review): the single empty separator presumably makes the operator split into
                // individual characters - confirm against the Tokenizer contrib operator documentation.
                node.AddAttribute("separators", new string[] { "" });

                // Stage 2: Squeeze axis 1 of the tokenizer output.
                opType = "Squeeze";
                var squeezeOutput = ctx.AddIntermediateVariable(dataViewType, "SqueezeOutput");
                node = ctx.CreateNode(opType, tokenizerOutput, squeezeOutput, ctx.GetNodeName(opType), "");
                node.AddAttribute("axes", new long[] { 1 });

                // Stage 3: LabelEncoder maps each one-character string to the numeric value of that character.
                opType = "LabelEncoder";
                var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput");
                node = ctx.CreateNode(opType, squeezeOutput, labelEncoderOutput, ctx.GetNodeName(opType));

                // NOTE(review): Enumerable.Range(0, 65535) yields 0..65534, so (char)0xFFFF has no explicit
                // entry in the table - confirm that falling back to the operator's default is intended.
                IEnumerable<string> charStrings = Enumerable.Range(0, 65535).Select(x => ((char)x).ToString());
                IEnumerable<long> charValues = Enumerable.Range(0, 65535).Select(x => Convert.ToInt64(x));
                node.AddAttribute("keys_strings", charStrings);
                node.AddAttribute("values_int64s", charValues);

                // Stage 4: Cast the Int64 labels to the type that corresponds to DataKind.UInt16.
                opType = "Cast";
                var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
                var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt16).ToType();
                castNode.AddAttribute("to", t);
            }
 
            /// <summary>
            /// Describes one output column per column pair: the key-vector type plus the annotations built by <c>AddMetadata</c>.
            /// </summary>
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                var columns = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
                for (int col = 0; col < _parent.ColumnPairs.Length; col++)
                {
                    var annotations = new DataViewSchema.Annotations.Builder();
                    AddMetadata(col, annotations);

                    string outputName = _parent.ColumnPairs[col].outputColumnName;
                    columns[col] = new DataViewSchema.DetachedColumn(outputName, _type, annotations.ToAnnotations());
                }

                return columns;
            }
 
            /// <summary>
            /// Fills <paramref name="builder"/> with the slot names of the input column (when present)
            /// and with the key values of the output column.
            /// </summary>
            private void AddMetadata(int iinfo, DataViewSchema.Annotations.Builder builder)
            {
                var inputAnnotations = InputSchema[_parent.ColumnPairs[iinfo].inputColumnName].Annotations;
                builder.Add(inputAnnotations, kind => kind == AnnotationUtils.Kinds.SlotNames);

                ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter =
                    (ref VBuffer<ReadOnlyMemory<char>> dst) => GetKeyValues(iinfo, ref dst);
                builder.AddKeyValues(CharsCount, TextDataViewType.Instance, keyValueGetter);
            }
 
            /// <summary>
            /// Get the key values (chars) corresponding to keys in the output columns.
            /// The textual representations are built once, cached in <c>_keyValuesStr</c> /
            /// <c>_keyValuesBoundaries</c>, and handed out as slices of that single cached string.
            /// </summary>
            /// <param name="iinfo">Index of the column pair; only range-checked, the key values are the same for every column.</param>
            /// <param name="dst">Buffer that receives <c>CharsCount</c> key value strings.</param>
            private void GetKeyValues(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
            {
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);

                if (_keyValuesStr == null)
                {
                    // Create key values corresponding to the character. This will
                    // often just be the character itself, but sometimes (control characters,
                    // illegal codepoints, spaces, etc.) it is better to use something else
                    // to represent the character.
                    // boundaries[i] is the end offset of the representation of char i, and boundaries[0] stays 0,
                    // so the representation of char i spans [boundaries[i - 1], boundaries[i]).
                    int[] boundaries = new int[CharsCount + 1];
                    var bldr = new StringBuilder();
                    for (int i = 1; i <= CharsCount; i++)
                    {
                        AppendCharRepr((char)i, bldr);
                        boundaries[i] = bldr.Length;
                    }

                    Host.Assert(bldr.Length == boundaries[boundaries.Length - 1]);
                    // Publish the boundaries before the string: the null check above tests the string, so a
                    // caller that sees a non-null string also sees the boundaries. If several callers build
                    // the cache at the same time, only the first value stored in each field is kept.
                    Interlocked.CompareExchange(ref _keyValuesBoundaries, boundaries, null);
                    Interlocked.CompareExchange(ref _keyValuesStr, bldr.ToString(), null);
                    bldr.Length = 0;
                }

                var keyValuesStr = _keyValuesStr;
                var keyValuesBoundaries = _keyValuesBoundaries;
                Host.AssertValue(keyValuesBoundaries);

                // Slot i of the output holds the representation of char i + 1 (the null char has no key value).
                var editor = VBufferEditor.Create(ref dst, CharsCount);
                for (int i = 0; i < CharsCount; i++)
                    editor.Values[i] = keyValuesStr.AsMemory().Slice(keyValuesBoundaries[i], keyValuesBoundaries[i + 1] - keyValuesBoundaries[i]);
                dst = editor.Commit();
            }
 
            /// <summary>
            /// Appends a printable representation of <paramref name="c"/> to <paramref name="bldr"/>.
            /// Most characters are appended as-is; control characters, space, surrogates and the
            /// line/paragraph separators are replaced by a readable stand-in.
            /// </summary>
            /// <param name="c">The character to represent.</param>
            /// <param name="bldr">The builder that receives the representation.</param>
            private void AppendCharRepr(char c, StringBuilder bldr)
            {
                // Special handling of characters identified in https://en.wikipedia.org/wiki/Unicode_control_characters,
                // as well as space, using the control pictures.
                if (c <= 0x20)
                {
                    // Use the control pictures unicode code block.
                    bldr.Append('<');
                    bldr.Append((char)(c + '\u2400'));
                    bldr.Append('>');
                    return;
                }
                if ('\uD800' <= c && c <= '\uDFFF')
                {
                    // These aren't real characters, and so will cause an exception
                    // when we try to write them to the file.
                    // "X4" is the standard hexadecimal specifier padded to four digits. The previous "4X"
                    // was parsed as a custom format made of the literals '4' and 'X', which gave every
                    // surrogate the same representation.
                    bldr.AppendFormat("\\u{0:X4}", (int)c);
                    return;
                }

                switch (c)
                {
                    case '\u007f':
                        bldr.Append("<\u2421>");
                        return; // DEL
                    case '\u0080':
                        bldr.Append("<PAD>");
                        return;
                    case '\u0081':
                        bldr.Append("<HOP>");
                        return;
                    case '\u0082':
                        bldr.Append("<BPH>");
                        return;
                    case '\u0083':
                        bldr.Append("<NBH>");
                        return;
                    case '\u0084':
                        bldr.Append("<IND>");
                        return;
                    case '\u0085':
                        bldr.Append("<NEL>");
                        return;
                    case '\u0086':
                        bldr.Append("<SSA>");
                        return;
                    case '\u0087':
                        bldr.Append("<ESA>");
                        return;
                    case '\u0088':
                        bldr.Append("<HTS>");
                        return;
                    case '\u0089':
                        bldr.Append("<HTJ>");
                        return;
                    case '\u008a':
                        bldr.Append("<VTS>");
                        return;
                    case '\u008b':
                        bldr.Append("<PLD>");
                        return;
                    case '\u008c':
                        bldr.Append("<PLU>");
                        return;
                    case '\u008d':
                        bldr.Append("<RI>");
                        return;
                    case '\u008e':
                        bldr.Append("<SS2>");
                        return;
                    case '\u008f':
                        bldr.Append("<SS3>");
                        return;
                    case '\u0090':
                        bldr.Append("<DCS>");
                        return;
                    case '\u0091':
                        bldr.Append("<PU1>");
                        return;
                    case '\u0092':
                        bldr.Append("<PU2>");
                        return;
                    case '\u0093':
                        bldr.Append("<STS>");
                        return;
                    case '\u0094':
                        bldr.Append("<CCH>");
                        return;
                    case '\u0095':
                        bldr.Append("<MW>");
                        return;
                    case '\u0096':
                        bldr.Append("<SPA>");
                        return;
                    case '\u0097':
                        bldr.Append("<EPA>");
                        return;
                    case '\u0098':
                        bldr.Append("<SOS>");
                        return;
                    case '\u0099':
                        bldr.Append("<SGCI>");
                        return;
                    case '\u009a':
                        bldr.Append("<SCI>");
                        return;
                    case '\u009b':
                        bldr.Append("<CSI>");
                        return;
                    case '\u009c':
                        bldr.Append("<ST>");
                        return;
                    case '\u009d':
                        bldr.Append("<OSC>");
                        return;
                    case '\u009e':
                        bldr.Append("<PM>");
                        return;
                    case '\u009f':
                        bldr.Append("<APC>");
                        return;
                    case '\u2028':
                        bldr.Append("<LSEP>");
                        return;
                    case '\u2029':
                        bldr.Append("<PSEP>");
                        return;
                    default:
                        bldr.Append(c);
                        return;
                }
            }
 
            /// <summary>
            /// Creates the getter for output column <paramref name="iinfo"/>, choosing the vector or the
            /// scalar implementation from the type of the input column. No disposer is needed.
            /// </summary>
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                disposer = null;

                var sourceType = input.Schema[_parent.ColumnPairs[iinfo].inputColumnName].Type;
                if (sourceType is VectorDataViewType)
                    return MakeGetterVec(input, iinfo);
                return MakeGetterOne(input, iinfo);
            }
 
            /// <summary>
            /// Creates the getter for a scalar text input: one key per UTF-16 character, wrapped in the
            /// start/end markers when they are enabled. Empty text produces an empty vector.
            /// </summary>
            private ValueGetter<VBuffer<ushort>> MakeGetterOne(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);
                var fetchText = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[ColMapNewToOld[iinfo]]);
                var text = default(ReadOnlyMemory<char>);

                return (ref VBuffer<ushort> dst) =>
                {
                    fetchText(ref text);

                    bool withMarkers = _parent._useMarkerChars;
                    int total = 0;
                    if (!text.IsEmpty)
                        total = withMarkers ? text.Length + TextMarkersCount : text.Length;

                    var editor = VBufferEditor.Create(ref dst, total);
                    if (total > 0)
                    {
                        var values = editor.Values;
                        int pos = 0;
                        if (withMarkers)
                            values[pos++] = TextStartMarker;
                        foreach (char ch in text.Span)
                            values[pos++] = ch;
                        if (withMarkers)
                            values[pos++] = TextEndMarker;
                        Contracts.Assert(pos == total);
                    }

                    dst = editor.Commit();
                };
            }
 
            /// <summary>
            /// Creates the getter for a vector-of-text input. Two implementations are built and one is returned:
            /// the legacy one (markers around every non-empty slot) when <c>_parent._isSeparatorStartEnd</c> is set,
            /// otherwise the one that treats the vector as a single text stream with unit separators.
            /// </summary>
            private ValueGetter<VBuffer<ushort>> MakeGetterVec(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);

                // cv is only used by the assertion below.
                int cv = input.Schema[ColMapNewToOld[iinfo]].Type.GetVectorSize();
                Contracts.Assert(cv >= 0);

                var getSrc = input.GetGetter<VBuffer<ReadOnlyMemory<char>>>(input.Schema[ColMapNewToOld[iinfo]]);
                // Both closures below share this buffer; only one of them is ever returned.
                var src = default(VBuffer<ReadOnlyMemory<char>>);

                // Legacy layout: every non-empty slot is emitted on its own and, when markers are enabled,
                // wrapped in TextStartMarker/TextEndMarker. Empty slots contribute nothing.
                ValueGetter<VBuffer<ushort>> getterWithStartEndSep = (ref VBuffer<ushort> dst) =>
                {
                    getSrc(ref src);

                    // First pass: compute the output length.
                    int len = 0;
                    var srcValues = src.GetValues();
                    for (int i = 0; i < srcValues.Length; i++)
                    {
                        if (!srcValues[i].IsEmpty)
                        {
                            len += srcValues[i].Length;
                            if (_parent._useMarkerChars)
                                len += TextMarkersCount;
                        }
                    }

                    // Second pass: fill the output; it must mirror the length computation above exactly.
                    var editor = VBufferEditor.Create(ref dst, len);
                    if (len > 0)
                    {
                        int index = 0;
                        for (int i = 0; i < srcValues.Length; i++)
                        {
                            if (srcValues[i].IsEmpty)
                                continue;
                            if (_parent._useMarkerChars)
                                editor.Values[index++] = TextStartMarker;
                            var span = srcValues[i].Span;
                            for (int ich = 0; ich < srcValues[i].Length; ich++)
                                editor.Values[index++] = span[ich];
                            if (_parent._useMarkerChars)
                                editor.Values[index++] = TextEndMarker;
                        }
                        Contracts.Assert(index == len);
                    }

                    dst = editor.Commit();
                };

                // Current layout: the whole vector is one text stream, with a UnitSeparator written before
                // every non-empty value whose index is greater than zero and a single pair of markers around it all.
                ValueGetter<VBuffer<ushort>> getterWithUnitSep = (ref VBuffer<ushort> dst) =>
                {
                    getSrc(ref src);

                    int len = 0;

                    var srcValues = src.GetValues();
                    for (int i = 0; i < srcValues.Length; i++)
                    {
                        if (!srcValues[i].IsEmpty)
                        {
                            len += srcValues[i].Length;

                            // NOTE(review): the test is on the index, not on "a value was already written",
                            // so when value 0 is empty the output starts with a separator - confirm intended.
                            if (i > 0)
                                len += 1;  // room for the UnitSeparator written before this value
                        }
                    }

                    // NOTE(review): the markers are counted even when every value is empty, so such a vector
                    // yields just <STX><ETX>, whereas MakeGetterOne yields an empty vector for empty text - confirm intended.
                    if (_parent._useMarkerChars)
                        len += TextMarkersCount;

                    var editor = VBufferEditor.Create(ref dst, len);
                    if (len > 0)
                    {
                        int index = 0;

                        // ReadOnlyMemory can be a result of either concatenating text columns together
                        // or application of word tokenizer before char tokenizer in TextFeaturizingEstimator.
                        //
                        // Considering VBuffer<ReadOnlyMemory> as a single text stream.
                        // Therefore, prepend and append start and end markers only once i.e. at the start and at end of vector.
                        // Insert UnitSeparator before every non-empty piece of text except the one at index 0.
                        if (_parent._useMarkerChars)
                            editor.Values[index++] = TextStartMarker;

                        for (int i = 0; i < srcValues.Length; i++)
                        {
                            if (srcValues[i].IsEmpty)
                                continue;

                            if (i > 0)
                                editor.Values[index++] = UnitSeparator;

                            var span = srcValues[i].Span;
                            for (int ich = 0; ich < srcValues[i].Length; ich++)
                                editor.Values[index++] = span[ich];
                        }

                        if (_parent._useMarkerChars)
                            editor.Values[index++] = TextEndMarker;

                        Contracts.Assert(index == len);
                    }

                    dst = editor.Commit();
                };
                // Models loaded from version 0x00010001 keep their original per-slot marker layout.
                return _parent._isSeparatorStartEnd ? getterWithStartEndSep : getterWithUnitSep;
            }
        }
    }
 
    /// <summary>
    /// <see cref="IEstimator{TTransformer}"/> for the <see cref="TokenizingByCharactersTransformer"/>.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    ///
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | No |
    /// | Input column data type | Scalar or Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType)  |
    /// | Output column data type | Variable-sized vector of [key](xref:Microsoft.ML.Data.KeyDataViewType) type. |
    /// | Exportable to ONNX | Yes |
    ///
    /// The estimator tokenizes text by splitting it into its individual UTF-16 characters.
    /// Each character is mapped to a key from a fixed set of key values, so no dictionary is built from the data.
    ///
    /// The <xref:Microsoft.ML.Transforms.Text.TokenizingByCharactersTransformer> resulting from fitting the estimator
    /// creates a new column, named as specified in the output column name parameters, which contains the keys of the
    /// characters that were encountered in the input.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="TextCatalog.TokenizeIntoCharactersAsKeys(TransformsCatalog.TextTransforms, string, string, bool)" />
    public sealed class TokenizingByCharactersEstimator : TrivialEstimator<TokenizingByCharactersTransformer>
    {
        /// <summary>
        /// Default argument values shared by the estimator, the transformer and its options.
        /// </summary>
        internal static class Defaults
        {
            // Marker characters are emitted unless the caller opts out.
            public const bool UseMarkerCharacters = true;
        }
 
        // A column is accepted when its item type is text, i.e. it is text or a vector of text.
        internal static bool IsColumnTypeValid(DataViewType type)
        {
            return type.GetItemType() is TextDataViewType;
        }
 
        internal const string ExpectedColumnType = "String";
 
        /// <summary>
        /// Tokenize incoming text in <paramref name="inputColumnName"/> and output the tokens as <paramref name="outputColumnName"/>.
        /// </summary>
        /// <param name="env">The environment.</param>
        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
        /// <param name="useMarkerCharacters">Whether to prepend a marker character, <c>0x02</c>, to the beginning,
        /// and append another marker character, <c>0x03</c>, to the end of the output vector of characters.</param>
        internal TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
            bool useMarkerCharacters = Defaults.UseMarkerCharacters)
            : this(env, useMarkerCharacters, new[] { (outputColumnName, inputColumnName ?? outputColumnName) })
        {
        }
 
        /// <summary>
        /// Tokenize incoming text in input columns and output the tokens as output columns.
        /// The transformer is created here, at construction time, and handed to the base class.
        /// </summary>
        /// <param name="env">The environment. Must not be <see langword="null"/>.</param>
        /// <param name="useMarkerCharacters">Whether to prepend a marker character, <c>0x02</c>, to the beginning,
        /// and append another marker character, <c>0x03</c>, to the end of the output vector of characters.</param>
        /// <param name="columns">Pairs of columns to run the tokenization on.</param>

        internal TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerCharacters = Defaults.UseMarkerCharacters,
            params (string outputColumnName, string inputColumnName)[] columns)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TokenizingByCharactersEstimator)), new TokenizingByCharactersTransformer(env, useMarkerCharacters, columns))
        {
        }
 
        /// <summary>
        /// Computes the <see cref="SchemaShape"/> produced by the transformer: the input columns plus, for each
        /// column pair, a variable-length vector of UInt16 keys carrying key values (and slot names when the input has them).
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            var columnsByName = inputSchema.ToDictionary(x => x.Name);
            foreach (var pair in Transformer.Columns)
            {
                if (!inputSchema.TryFindColumn(pair.inputColumnName, out var inputColumn))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", pair.inputColumnName);
                if (!IsColumnTypeValid(inputColumn.ItemType))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", pair.inputColumnName, ExpectedColumnType, inputColumn.ItemType.ToString());

                // Slot names are carried over only when the input has them; key values are always present.
                var annotations = new List<SchemaShape.Column>();
                if (inputColumn.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames))
                    annotations.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, slotNames.ItemType, false));
                annotations.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false));

                var outputColumn = new SchemaShape.Column(pair.outputColumnName, SchemaShape.Column.VectorKind.VariableVector, NumberDataViewType.UInt16, true, new SchemaShape(annotations.ToArray()));
                columnsByName[pair.outputColumnName] = outputColumn;
            }

            return new SchemaShape(columnsByName.Values);
        }
    }
}