// File: Dracula\CountTargetEncodingTransformer.cs
// Web Access
// Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Loadable-class registration for the data-transform signature: the loader type is
// CountTargetEncodingEstimator, the argument type is CountTargetEncodingEstimator.Options,
// and "Dracula" is listed as an additional name next to the loader signature.
[assembly: LoadableClass(CountTargetEncodingTransformer.Summary, typeof(IDataTransform), typeof(CountTargetEncodingEstimator), typeof(CountTargetEncodingEstimator.Options), typeof(SignatureDataTransform),
    CountTargetEncodingTransformer.UserName, CountTargetEncodingTransformer.LoaderSignature, "Dracula")]

// Loadable-class registration for the model-loading signature (no argument type).
// NOTE(review): presumably resolved to the private static Create(env, ctx) on the estimator - confirm.
[assembly: LoadableClass(CountTargetEncodingTransformer.Summary, typeof(CountTargetEncodingTransformer), typeof(CountTargetEncodingEstimator), null, typeof(SignatureLoadModel),
    CountTargetEncodingTransformer.UserName, CountTargetEncodingTransformer.LoaderSignature)]

// Declares CountTargetEncoder as a class containing entry points.
[assembly: EntryPointModule(typeof(CountTargetEncoder))]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// Transforms a categorical column into a set of features that includes the count of each label class,
    /// the log-odds for each label class and the back-off indicator.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | Yes |
    /// | Input column data type | Any |
    /// | Output column data type | Vector of <xref:System.Single>. |
    ///
    /// The resulting <xref:Microsoft.ML.Transforms.CountTargetEncodingTransformer> creates a new column, named as specified in the output column name parameters,
    /// containing three parts: the count of each label class, the log-odds for each label class and the back-off indicator.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, InputOutputColumnPair[], CountTargetEncodingTransformer, string)" />
    /// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, InputOutputColumnPair[], string, CountTableBuilderBase, float, float, bool, int, bool, uint)" />
    /// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, string, CountTargetEncodingTransformer, string, string)"/>
    /// <seealso cref="CountTargetEncodingCatalog.CountTargetEncode(TransformsCatalog, string, string, string, CountTableBuilderBase, float, float, int, bool, uint)"/>
    internal class CountTargetEncodingEstimator : IEstimator<CountTargetEncodingTransformer>
    {
        /// <summary>
        /// This is a merger of arguments for <see cref="CountTableTransformer"/> and <see cref="HashingTransformer"/>.
        /// The values declared here act as defaults: a <see cref="Column"/> entry may override
        /// <see cref="CountTable"/>, <see cref="PriorCoefficient"/>, <see cref="LaplaceScale"/>,
        /// <see cref="Seed"/> and <see cref="Combine"/> for that single column.
        /// </summary>
        internal sealed class Options : TransformInputBase
        {
            // The columns to encode; the estimator constructor rejects an empty or null array.
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
                ShortName = "col", SortOrder = 1)]
            public Column[] Columns;

            // Count table builder factory used when a column does not specify its own.
            [Argument(ArgumentType.Multiple, HelpText = "Count table settings", ShortName = "table", SignatureType = typeof(SignatureCountTableBuilder))]
            public ICountTableBuilderFactory CountTable = new CMCountTableBuilder.Options();

            // Default prior smoothing coefficient.
            [Argument(ArgumentType.AtMostOnce, HelpText = "The coefficient with which to apply the prior smoothing to the features", ShortName = "prior")]
            public float PriorCoefficient = 1;

            // Default Laplacian noise scale.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Laplacian noise diversity/scale-parameter. Suggest keeping it less than 1.", ShortName = "laplace")]
            public float LaplaceScale = 0;

            // Default seed for the Laplacian noise generator (distinct from HashingSeed below).
            [Argument(ArgumentType.AtMostOnce, HelpText = "Seed for the random generator for the laplacian noise.", ShortName = "seed")]
            public int Seed = 314489979;

            // Name of the label column; the estimator constructor rejects null/whitespace values.
            [Argument(ArgumentType.Required, HelpText = "Label column", ShortName = "label,lab", Purpose = SpecialPurpose.ColumnName)]
            public string LabelColumn;

            // When non-empty, the count table estimator is built from the model stored in this file
            // instead of from the count table settings above.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Optional model file to load counts from. If this is specified all other options are ignored.", ShortName = "inmodel, extfile")]
            public string InitialCountsModel;

            // Selects the shared-table code path, where one builder (CountTable) serves every column.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Keep counts for all columns in one shared count table", ShortName = "shared")]
            public bool SharedTable = false;

            // Default for the hashing "combine" setting.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the values need to be combined for a single hash", SortOrder = 3)]
            public bool Combine = true;

            // Number of hash bits, applied to every column (there is no per-column override).
            [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive.",
                ShortName = "bits")]
            public int NumberOfBits = HashingEstimator.NumBitsLim - 1;

            // Hashing seed, applied to every column (there is no per-column override).
            [Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
            public uint HashingSeed = 314489979;
        }
 
        /// <summary>
        /// Per-column settings. Every nullable member left unset falls back to the
        /// corresponding value on <see cref="Options"/>.
        /// </summary>
        internal sealed class Column : OneToOneColumn
        {
            [Argument(ArgumentType.Multiple, HelpText = "Count table settings", ShortName = "table", SignatureType = typeof(SignatureCountTableBuilder))]
            public ICountTableBuilderFactory CountTable;

            [Argument(ArgumentType.AtMostOnce, HelpText = "The coefficient with which to apply the prior smoothing to the features", ShortName = "prior")]
            public float? PriorCoefficient;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Laplacian noise diversity/scale-parameter. Suggest keeping it less than 1.", ShortName = "laplace")]
            public float? LaplaceScale;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Seed for the random generator for the laplacian noise.", ShortName = "seed")]
            public int? Seed;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the values need to be combined for a single hash")]
            public bool? Combine;

            /// <summary>
            /// Builds a <see cref="Column"/> from its textual form.
            /// </summary>
            /// <param name="str">The text to parse.</param>
            /// <returns>The parsed column, or <see langword="null"/> when <paramref name="str"/> cannot be parsed.</returns>
            public static Column Parse(string str)
            {
                var parsed = new Column();
                return parsed.TryParse(str) ? parsed : null;
            }
        }
 
        // Host registered from the environment passed to the constructor; used for argument checks.
        private readonly IHost _host;
        // Hashing settings per column. Not assigned by the constructor that takes initial counts,
        // so it stays null on that path.
        private readonly HashingEstimator.ColumnOptions[] _hashingColumns;
        // First stage: hashes the input columns.
        private readonly HashingEstimator _hashingEstimator;
        // Second stage: fitted on the output of the hashing stage.
        private readonly CountTableEstimator _countTableEstimator;
 
        /// <summary>
        /// Creates an estimator in which every column carries its own count table builder.
        /// The hashing stage maps each column's input name to its output name; the count table
        /// stage then works in place on that output column.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="labelColumnName">The name of the label column.</param>
        /// <param name="columnOptions">The per-column count table settings.</param>
        /// <param name="numberOfBits">The number of bits to hash into.</param>
        /// <param name="combine">The hashing "combine" setting applied to every column.</param>
        /// <param name="hashingSeed">The hashing seed applied to every column.</param>
        internal CountTargetEncodingEstimator(IHostEnvironment env, string labelColumnName, CountTableEstimator.ColumnOptions[] columnOptions,
            int numberOfBits = HashingEstimator.Defaults.NumberOfBits,
            bool combine = HashingEstimator.Defaults.Combine, uint hashingSeed = HashingEstimator.Defaults.Seed)
            : this(
                env,
                new CountTableEstimator(
                    env,
                    labelColumnName,
                    columnOptions
                        .Select(source => new CountTableEstimator.ColumnOptions(
                            source.Name, source.Name, source.CountTableBuilder, source.PriorCoefficient, source.LaplaceScale, source.Seed))
                        .ToArray()),
                columnOptions,
                numberOfBits,
                combine,
                hashingSeed)
        {
        }
 
        /// <summary>
        /// Creates an estimator in which all columns use the single supplied count table builder.
        /// The hashing stage maps each column's input name to its output name; the count table
        /// stage then works in place on that output column.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="labelColumnName">The name of the label column.</param>
        /// <param name="columnOptions">The per-column settings for the shared table.</param>
        /// <param name="countTableBuilder">The builder used for all columns.</param>
        /// <param name="numberOfBits">The number of bits to hash into.</param>
        /// <param name="combine">The hashing "combine" setting applied to every column.</param>
        /// <param name="hashingSeed">The hashing seed applied to every column.</param>
        internal CountTargetEncodingEstimator(IHostEnvironment env, string labelColumnName, CountTableEstimator.SharedColumnOptions[] columnOptions,
            CountTableBuilderBase countTableBuilder, int numberOfBits = HashingEstimator.Defaults.NumberOfBits,
            bool combine = HashingEstimator.Defaults.Combine, uint hashingSeed = HashingEstimator.Defaults.Seed)
            : this(
                env,
                new CountTableEstimator(
                    env,
                    labelColumnName,
                    countTableBuilder,
                    columnOptions
                        .Select(source => new CountTableEstimator.SharedColumnOptions(
                            source.Name, source.Name, source.PriorCoefficient, source.LaplaceScale, source.Seed))
                        .ToArray()),
                columnOptions,
                numberOfBits,
                combine,
                hashingSeed)
        {
        }
 
        /// <summary>
        /// Creates an estimator that starts from a previously fitted transformer. The hashing stage
        /// reuses all hashing columns of <paramref name="initialCounts"/>; the count table stage is
        /// built from its count table and works in place on the requested output columns.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="labelColumnName">The name of the label column.</param>
        /// <param name="initialCounts">The previously fitted transformer supplying the initial counts.</param>
        /// <param name="columns">The columns to encode; must be non-empty and pass <c>initialCounts.VerifyColumns</c>.</param>
        internal CountTargetEncodingEstimator(IHostEnvironment env, string labelColumnName, CountTargetEncodingTransformer initialCounts, params InputOutputColumnPair[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(initialCounts, nameof(initialCounts));
            _host.CheckNonEmpty(columns, nameof(columns));

            bool columnsAreKnown = initialCounts.VerifyColumns(columns);
            _host.Check(columnsAreKnown, nameof(columns));

            var reusedHashingOptions = initialCounts.HashingTransformer.Columns.ToArray();
            _hashingEstimator = new HashingEstimator(_host, reusedHashingOptions);

            // The count table reads and writes the hashed column, hence output name on both sides.
            var inPlacePairs = new InputOutputColumnPair[columns.Length];
            for (int i = 0; i < inPlacePairs.Length; i++)
            {
                var requested = columns[i];
                inPlacePairs[i] = new InputOutputColumnPair(requested.OutputColumnName, requested.OutputColumnName);
            }
            _countTableEstimator = new CountTableEstimator(_host, labelColumnName, initialCounts.CountTable, inPlacePairs);
        }
 
        /// <summary>
        /// Shared tail of the column-options constructors: stores the already-built count table
        /// estimator and derives the hashing stage from <paramref name="columns"/>.
        /// </summary>
        private CountTargetEncodingEstimator(IHostEnvironment env, CountTableEstimator estimator, CountTableEstimator.ColumnOptionsBase[] columns,
            int numberOfBits, bool combine, uint hashingSeed)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _countTableEstimator = estimator;

            var hashingOptions = InitializeHashingColumnOptions(columns, numberOfBits, combine, hashingSeed);
            _hashingColumns = hashingOptions;
            _hashingEstimator = new HashingEstimator(_host, hashingOptions);
        }
 
        /// <summary>
        /// Creates an estimator from command-line/entry-point <see cref="Options"/>.
        /// The count table stage comes from one of three sources, tried in this order:
        /// a model file (<see cref="Options.InitialCountsModel"/>), one shared table
        /// (<see cref="Options.SharedTable"/>), or one table per column.
        /// The hashing stage is always built from <paramref name="options"/>.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="options">The options; must specify at least one column and a label column.</param>
        internal CountTargetEncodingEstimator(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(options, nameof(options));
            _host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(options.LabelColumn), nameof(options.LabelColumn), "Must specify the label column name");

            bool loadCountsFromFile = !string.IsNullOrEmpty(options.InitialCountsModel);
            if (loadCountsFromFile)
            {
                var pairs = options.Columns.Select(col => new InputOutputColumnPair(col.Name)).ToArray();
                _countTableEstimator = LoadFromFile(env, options.InitialCountsModel, options.LabelColumn, pairs);
                if (_countTableEstimator == null)
                    throw env.Except($"The file {options.InitialCountsModel} does not contain a CountTableTransformer");
            }
            else if (options.SharedTable)
            {
                // Per-column values win over the option-level defaults.
                var sharedColumns = options.Columns
                    .Select(col => new CountTableEstimator.SharedColumnOptions(
                        col.Name,
                        col.Name,
                        col.PriorCoefficient ?? options.PriorCoefficient,
                        col.LaplaceScale ?? options.LaplaceScale,
                        col.Seed ?? options.Seed))
                    .ToArray();

                var sharedFactory = options.CountTable;
                _host.CheckValue(sharedFactory, nameof(options.CountTable));
                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, sharedFactory.CreateComponent(_host), sharedColumns);
            }
            else
            {
                var perColumn = options.Columns
                    .Select(col =>
                    {
                        // A column-level builder factory overrides the option-level one.
                        var factory = col.CountTable ?? options.CountTable;
                        _host.CheckValue(factory, nameof(options.CountTable));
                        return new CountTableEstimator.ColumnOptions(
                            col.Name,
                            col.Name,
                            factory.CreateComponent(_host),
                            col.PriorCoefficient ?? options.PriorCoefficient,
                            col.LaplaceScale ?? options.LaplaceScale,
                            col.Seed ?? options.Seed);
                    })
                    .ToArray();
                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, perColumn);
            }

            var hashingOptions = InitializeHashingColumnOptions(options);
            _hashingColumns = hashingOptions;
            _hashingEstimator = new HashingEstimator(_host, hashingOptions);
        }
 
        /// <summary>
        /// Loads a model and builds a <see cref="CountTableEstimator"/> from the count table found in it.
        /// The loaded object is inspected directly first; if it yields nothing and is a transformer
        /// chain, the chain is searched from its last transformer towards its first.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="initialCountsModel">The model to load, as accepted by <c>ModelOperationsCatalog.Load</c>.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <param name="columns">The columns for the resulting estimator.</param>
        /// <returns>The estimator, or <see langword="null"/> when no count table was found.</returns>
        internal static CountTableEstimator LoadFromFile(IHostEnvironment env, string initialCountsModel, string labelColumn, InputOutputColumnPair[] columns)
        {
            var loaded = new ModelOperationsCatalog(env).Load(initialCountsModel, out _);

            var counts = ExtractCountTable(loaded);
            if (counts == null && loaded is TransformerChain<ITransformer> chain)
            {
                var links = ((ITransformerChainAccessor)chain).Transformers;
                for (int i = links.Length - 1; i >= 0 && counts == null; i--)
                    counts = ExtractCountTable(links[i]);
            }

            return counts == null ? null : new CountTableEstimator(env, labelColumn, counts, columns);

            // Returns the count table carried by a single transformer, or null if it has none.
            CountTableTransformer ExtractCountTable(ITransformer candidate)
            {
                if (candidate is CountTableTransformer direct)
                    return direct;
                if (candidate is CountTargetEncodingTransformer encoder)
                    return encoder.CountTable;
                return null;
            }
        }
 
        /// <summary>
        /// Builds one hashing column per count table column, hashing the column's input name into
        /// its output name with the same bits, seed and combine setting for every column.
        /// </summary>
        private HashingEstimator.ColumnOptions[] InitializeHashingColumnOptions(CountTableEstimator.ColumnOptionsBase[] columns, int numberOfBits, bool combine, uint hashingSeed)
        {
            return columns
                .Select(source => new HashingEstimator.ColumnOptions(
                    source.Name, source.InputColumnName, numberOfBits, hashingSeed, combine: combine))
                .ToArray();
        }
 
        /// <summary>
        /// Builds one hashing column per entry of <see cref="Options.Columns"/>. Bits and seed come
        /// from <paramref name="options"/>; the combine setting may be overridden per column.
        /// </summary>
        private HashingEstimator.ColumnOptions[] InitializeHashingColumnOptions(Options options)
        {
            return options.Columns
                .Select(source => new HashingEstimator.ColumnOptions(
                    source.Name,
                    source.Source,
                    options.NumberOfBits,
                    options.HashingSeed,
                    combine: source.Combine ?? options.Combine))
                .ToArray();
        }
 
        /// <summary>
        /// Factory method for SignatureDataTransform. Fits an estimator built from
        /// <paramref name="options"/> on <paramref name="input"/> and returns the training-time
        /// mapping of the fitted transformer applied to the same data.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="options">The transform options; must specify at least one column.</param>
        /// <param name="input">The data to fit on and to transform.</param>
        /// <returns>The transformed data as an <see cref="IDataTransform"/>.</returns>
        internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));
            env.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns));

            var estimator = new CountTargetEncodingEstimator(env, options);

            // The fitted transformer statically implements the training-time interface, so no
            // run-time 'as' (and no possible null dereference) is needed to reach it.
            ITransformerWithDifferentMappingAtTrainingTime transformer = estimator.Fit(input);

            // A hard cast fails here, at the source, instead of handing a silent null to the caller.
            return (IDataTransform)transformer.TransformForTrainingPipeline(input);
        }
 
        /// <summary>
        /// Model-loading factory; forwards to <see cref="CountTargetEncodingTransformer.Create(IHostEnvironment, ModelLoadContext)"/>.
        /// </summary>
        private static CountTargetEncodingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            return CountTargetEncodingTransformer.Create(env, ctx);
        }
 
        /// <summary>
        /// Fits the hashing stage on <paramref name="input"/>, then fits the count table stage on
        /// the hashed data, and wraps both fitted stages in one transformer.
        /// </summary>
        /// <param name="input">The training data.</param>
        /// <returns>The fitted transformer.</returns>
        public CountTargetEncodingTransformer Fit(IDataView input)
        {
            var fittedHashing = _hashingEstimator.Fit(input);
            var hashedInput = fittedHashing.Transform(input);
            var fittedCounts = _countTableEstimator.Fit(hashedInput);
            return new CountTargetEncodingTransformer(_host, fittedHashing, fittedCounts);
        }
 
        /// <summary>
        /// Computes the output schema shape by passing <paramref name="inputSchema"/> through the
        /// hashing stage and then through the count table stage.
        /// </summary>
        /// <param name="inputSchema">The shape of the input data.</param>
        /// <returns>The shape of the data after both stages.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
            => _countTableEstimator.GetOutputSchema(_hashingEstimator.GetOutputSchema(inputSchema));
    }
 
    /// <summary>
    /// <see cref="ITransformer"/> resulting from fitting a <see cref="CountTargetEncodingEstimator"/>.
    /// </summary>
    internal sealed class CountTargetEncodingTransformer : ITransformerWithDifferentMappingAtTrainingTime
    {
        // Host used for argument checks and for loading the sub-models.
        private readonly IHost _host;
        //internal readonly HashingEstimator.ColumnOptions[] HashingColumns;
        // Second stage: the fitted count table transformer, applied to the hashed columns.
        public readonly CountTableTransformer CountTable;
        // First stage: the fitted hashing transformer.
        public readonly HashingTransformer HashingTransformer;

        // Description, loader signature and user-facing name referenced by the assembly-level
        // LoadableClass attributes and by the entry point.
        internal const string Summary = "Transforms the categorical column into the set of features: count of each label class, "
            + "log-odds for each label class, back-off indicator. The columns can be of arbitrary type.";
        internal const string LoaderSignature = "CountTargetEncode";
        internal const string UserName = "Count Target Encoding Transform";

        // Always reports true.
        bool ITransformer.IsRowToRowMapper => true;
 
        /// <summary>
        /// Describes the serialized model format: signature, written/readable versions and loader.
        /// </summary>
        private static VersionInfo GetVersionInfo()
            => new VersionInfo(
                modelSignature: "DRACULA ",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(CountTargetEncodingTransformer).Assembly.FullName);
 
        /// <summary>
        /// Wraps an already fitted hashing transformer and count table transformer.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="hashingTransformer">The fitted hashing stage.</param>
        /// <param name="countTable">The fitted count table stage.</param>
        internal CountTargetEncodingTransformer(IHostEnvironment env, HashingTransformer hashingTransformer, CountTableTransformer countTable)
        {
            Contracts.AssertValue(env);
            env.AssertValue(hashingTransformer);
            env.AssertValue(countTable);

            CountTable = countTable;
            HashingTransformer = hashingTransformer;
            _host = env.Register(nameof(CountTargetEncodingTransformer));
        }
 
        /// <summary>
        /// Deserializing constructor; reads the two sub-models in the order they were saved.
        /// </summary>
        private CountTargetEncodingTransformer(IHost host, ModelLoadContext ctx)
        {
            _host = host;

            // *** Binary format ***
            // sub-model "HashingTransformer"
            // sub-model "CountTable"

            ctx.LoadModel<HashingTransformer, SignatureLoadModel>(_host, out var loadedHashing, "HashingTransformer");
            HashingTransformer = loadedHashing;

            ctx.LoadModel<CountTableTransformer, SignatureLoadModel>(_host, out var loadedCounts, "CountTable");
            CountTable = loadedCounts;
        }
 
        /// <summary>
        /// Writes the version information followed by the two sub-models.
        /// </summary>
        /// <param name="ctx">The context to save into.</param>
        public void Save(ModelSaveContext ctx)
        {
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();

            var version = GetVersionInfo();
            ctx.SetVersionInfo(version);

            // *** Binary format ***
            // sub-model "HashingTransformer"
            // sub-model "CountTable"
            ctx.SaveModel(HashingTransformer, "HashingTransformer");
            ctx.SaveModel(CountTable, "CountTable");
        }
 
        /// <summary>
        /// Loads a transformer from <paramref name="ctx"/> after checking the stored version information.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="ctx">The context to load from.</param>
        /// <returns>The loaded transformer.</returns>
        internal static CountTargetEncodingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));

            var loaderHost = env.Register(LoaderSignature);
            loaderHost.CheckValue(ctx, nameof(ctx));

            var expectedVersion = GetVersionInfo();
            ctx.CheckAtModel(expectedVersion);

            return new CountTargetEncodingTransformer(loaderHost, ctx);
        }
 
        /// <summary>
        /// Returns the schema produced by applying the hashing stage and then the count table stage.
        /// </summary>
        /// <param name="inputSchema">The schema of the input data.</param>
        /// <returns>The schema after both stages.</returns>
        public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            return new TransformerChain<ITransformer>(HashingTransformer, CountTable).GetOutputSchema(inputSchema);
        }
 
        /// <summary>
        /// Returns the row-to-row mapper of the hashing stage chained with the count table stage.
        /// </summary>
        /// <param name="inputSchema">The schema of the input rows.</param>
        IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var stages = new TransformerChain<ITransformer>(HashingTransformer, CountTable);
            // Called through ITransformer, exactly as the chain is accessed elsewhere in this class.
            return ((ITransformer)stages).GetRowToRowMapper(inputSchema);
        }
 
        /// <summary>
        /// Applies the hashing stage and then the count table stage to <paramref name="input"/>.
        /// </summary>
        /// <param name="input">The data to transform.</param>
        /// <returns>The transformed data.</returns>
        public IDataView Transform(IDataView input)
        {
            _host.CheckValue(input, nameof(input));
            return new TransformerChain<ITransformer>(HashingTransformer, CountTable).Transform(input);
        }
 
        /// <summary>
        /// Training-time mapping: hashes <paramref name="input"/> with the regular hashing stage and
        /// then applies the count table's own training-time mapping to the hashed data.
        /// </summary>
        /// <param name="input">The data to transform; must not be <see langword="null"/>.</param>
        /// <returns>The transformed data for use while fitting later stages of a pipeline.</returns>
        IDataView ITransformerWithDifferentMappingAtTrainingTime.TransformForTrainingPipeline(IDataView input)
        {
            // Same argument validation as Transform.
            _host.CheckValue(input, nameof(input));

            var hashedData = HashingTransformer.Transform(input);

            // Direct cast instead of 'as': a failed conversion surfaces as a cast error rather
            // than as a null dereference on the next call.
            var trainingTimeCounts = (ITransformerWithDifferentMappingAtTrainingTime)CountTable;
            return trainingTimeCounts.TransformForTrainingPipeline(hashedData);
        }
 
        /// <summary>
        /// Checks that every requested column pair corresponds to one of the hashing columns of
        /// this transformer, i.e. that some hashing column has both the same input column name
        /// and the same output column name as the pair.
        /// </summary>
        /// <remarks>
        /// The names are matched per pair. Checking input and output names independently would
        /// accept a pair that combines the input of one hashing column with the output of another,
        /// a mapping this transformer never produces.
        /// </remarks>
        /// <param name="columns">The column pairs to verify.</param>
        /// <returns><see langword="true"/> when every pair matches a hashing column; otherwise <see langword="false"/>.</returns>
        internal bool VerifyColumns(InputOutputColumnPair[] columns)
        {
            // Materialize once instead of re-projecting the hashing columns for every pair.
            var hashingColumns = HashingTransformer.Columns.ToArray();

            return columns.All(pair => hashingColumns.Any(hc =>
                hc.InputColumnName == pair.InputColumnName && hc.Name == pair.OutputColumnName));
        }
    }
 
    internal static class CountTargetEncodingCatalog
    {
        /// <summary>
        /// Transforms categorical columns into a set of features that includes the count of each label class,
        /// the log-odds for each label class and the back-off indicator.
        /// </summary>
        /// <param name="catalog">The transforms catalog.</param>
        /// <param name="columns">The input and output columns.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <param name="builder">The builder that creates the count tables from the training data.
        /// When <see langword="null"/>, a new <see cref="CMCountTableBuilder"/> is used.</param>
        /// <param name="priorCoefficient">The coefficient with which to apply the prior smoothing to the features.</param>
        /// <param name="laplaceScale">The Laplacian noise diversity/scale-parameter. Recommended values are between 0 and 1. Note that the noise
        /// will only be applied if the estimator is part of an <see cref="EstimatorChain{TLastTransformer}"/>, when fitting the next estimator in the chain.</param>
        /// <param name="sharedTable">Indicates whether to keep counts for all columns and slots in one shared count table. If true, the keys in the count table
        /// will include a hash of the column and slot indices.</param>
        /// <param name="numberOfBits">The number of bits to hash the input into. Must be between 1 and 31, inclusive.</param>
        /// <param name="combine">In case the input is a vector column, indicates whether the values should be combined into a single hash to create a single
        /// count table, or be left as a vector of hashes with multiple count tables.</param>
        /// <param name="hashingSeed">The seed used for hashing the input columns.</param>
        /// <returns>The estimator that produces a <see cref="CountTargetEncodingTransformer"/> when fitted.</returns>
        public static CountTargetEncodingEstimator CountTargetEncode(this TransformsCatalog catalog,
            InputOutputColumnPair[] columns, string labelColumn = DefaultColumnNames.Label,
            CountTableBuilderBase builder = null,
            float priorCoefficient = CountTableTransformer.Defaults.PriorCoefficient,
            float laplaceScale = CountTableTransformer.Defaults.LaplaceScale,
            bool sharedTable = CountTableTransformer.Defaults.SharedTable,
            int numberOfBits = HashingEstimator.Defaults.NumberOfBits,
            bool combine = HashingEstimator.Defaults.Combine,
            uint hashingSeed = HashingEstimator.Defaults.Seed)
        {
            var env = CatalogUtils.GetEnvironment(catalog);
            env.CheckValue(columns, nameof(columns));

            if (builder == null)
                builder = new CMCountTableBuilder();

            if (!sharedTable)
            {
                // One count table per column, all created by the same builder.
                var perColumnOptions = columns
                    .Select(pair => new CountTableEstimator.ColumnOptions(
                        pair.OutputColumnName, pair.InputColumnName, builder, priorCoefficient, laplaceScale))
                    .ToArray();
                return new CountTargetEncodingEstimator(env, labelColumn, perColumnOptions, numberOfBits: numberOfBits, combine: combine, hashingSeed: hashingSeed);
            }

            var sharedOptions = columns
                .Select(pair => new CountTableEstimator.SharedColumnOptions(
                    pair.OutputColumnName, pair.InputColumnName, priorCoefficient, laplaceScale))
                .ToArray();
            return new CountTargetEncodingEstimator(env, labelColumn, sharedOptions, builder, numberOfBits, combine, hashingSeed);
        }
 
        /// <summary>
        /// Transforms categorical columns into a set of features that includes the count of each label class,
        /// the log-odds for each label class and the back-off indicator, starting from previously trained counts.
        /// </summary>
        /// <param name="catalog">The transforms catalog.</param>
        /// <param name="columns">The input and output columns.</param>
        /// <param name="initialCounts">A previously trained count table containing initial counts.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <returns>The estimator that produces a <see cref="CountTargetEncodingTransformer"/> when fitted.</returns>
        public static CountTargetEncodingEstimator CountTargetEncode(this TransformsCatalog catalog,
            InputOutputColumnPair[] columns, CountTargetEncodingTransformer initialCounts, string labelColumn = "Label")
        {
            var env = CatalogUtils.GetEnvironment(catalog);
            return new CountTargetEncodingEstimator(env, labelColumn, initialCounts, columns);
        }
 
        /// <summary>
        /// Transforms a categorical column into a set of features that includes the count of each label class,
        /// the log-odds for each label class and the back-off indicator.
        /// </summary>
        /// <param name="catalog">The transforms catalog.</param>
        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/> or empty, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <param name="builder">The builder that creates the count tables from the training data.
        /// When <see langword="null"/>, a new <see cref="CMCountTableBuilder"/> is used.</param>
        /// <param name="priorCoefficient">The coefficient with which to apply the prior smoothing to the features.</param>
        /// <param name="laplaceScale">The Laplacian noise diversity/scale-parameter. Recommended values are between 0 and 1. Note that the noise
        /// will only be applied if the estimator is part of an <see cref="EstimatorChain{TLastTransformer}"/>, when fitting the next estimator in the chain.</param>
        /// <param name="numberOfBits">The number of bits to hash the input into. Must be between 1 and 31, inclusive.</param>
        /// <param name="combine">In case the input is a vector column, indicates whether the values should be combined into a single hash to create a single
        /// count table, or be left as a vector of hashes with multiple count tables.</param>
        /// <param name="hashingSeed">The seed used for hashing the input columns.</param>
        /// <returns>The estimator that produces a <see cref="CountTargetEncodingTransformer"/> when fitted.</returns>
        public static CountTargetEncodingEstimator CountTargetEncode(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null,
            string labelColumn = DefaultColumnNames.Label,
            CountTableBuilderBase builder = null,
            float priorCoefficient = CountTableTransformer.Defaults.PriorCoefficient,
            float laplaceScale = CountTableTransformer.Defaults.LaplaceScale,
            int numberOfBits = HashingEstimator.Defaults.NumberOfBits,
            bool combine = HashingEstimator.Defaults.Combine,
            uint hashingSeed = HashingEstimator.Defaults.Seed)
        {
            var env = CatalogUtils.GetEnvironment(catalog);
            env.CheckNonEmpty(outputColumnName, nameof(outputColumnName));

            if (string.IsNullOrEmpty(inputColumnName))
                inputColumnName = outputColumnName;
            if (builder == null)
                builder = new CMCountTableBuilder();

            var singleColumn = new CountTableEstimator.ColumnOptions(outputColumnName, inputColumnName, builder, priorCoefficient, laplaceScale);
            CountTableEstimator.ColumnOptions[] columnOptions = { singleColumn };
            return new CountTargetEncodingEstimator(env, labelColumn, columnOptions, numberOfBits, combine, hashingSeed);
        }
 
        /// <summary>
        /// Transforms a categorical column into a set of features that includes the count of each label class,
        /// the log-odds for each label class and the back-off indicator, starting from previously trained counts.
        /// </summary>
        /// <param name="catalog">The transforms catalog.</param>
        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="initialCounts">A previously trained count table containing initial counts.</param>
        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <returns>The estimator that produces a <see cref="CountTargetEncodingTransformer"/> when fitted.</returns>
        public static CountTargetEncodingEstimator CountTargetEncode(this TransformsCatalog catalog, string outputColumnName,
            CountTargetEncodingTransformer initialCounts,
            string inputColumnName = null, string labelColumn = "Label")
        {
            var env = CatalogUtils.GetEnvironment(catalog);
            var singlePair = new InputOutputColumnPair(outputColumnName, inputColumnName);
            InputOutputColumnPair[] columns = { singlePair };
            return new CountTargetEncodingEstimator(env, labelColumn, initialCounts, columns);
        }
    }
 
    /// <summary>
    /// Entry-point wrapper around <see cref="CountTargetEncodingEstimator"/>.
    /// </summary>
    internal static class CountTargetEncoder
    {
        /// <summary>
        /// Fits the transform on <c>input.Data</c> and returns both the transformed data and a
        /// transform model built from it.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The transform options, including the input data.</param>
        /// <returns>The transform model and the output data.</returns>
        [TlcModule.EntryPoint(Name = "Transforms.CountTargetEncoder", Desc = CountTargetEncodingTransformer.Summary, UserName = CountTableTransformer.UserName, ShortName = "Count")]
        internal static CommonOutputs.TransformOutput Create(IHostEnvironment env, CountTargetEncodingEstimator.Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));

            var host = EntryPointUtils.CheckArgsAndCreateHost(env, nameof(CountTargetEncoder), input);
            var transformed = CountTargetEncodingEstimator.Create(host, input, input.Data);

            var output = new CommonOutputs.TransformOutput();
            output.Model = new TransformModelImpl(host, transformed, input.Data);
            output.OutputData = transformed;
            return output;
        }
    }
}