File: Transforms\NormalizeColumn.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
using Newtonsoft.Json.Linq;
 
[assembly: LoadableClass(NormalizeTransform.MinMaxNormalizerSummary, typeof(IDataTransform), typeof(NormalizeTransform), typeof(NormalizeTransform.MinMaxArguments), typeof(SignatureDataTransform),
    NormalizeTransform.MinMaxNormalizerUserName, "MinMaxNormalizer", NormalizeTransform.MinMaxNormalizerShortName)]
 
[assembly: LoadableClass(NormalizeTransform.MeanVarNormalizerSummary, typeof(IDataTransform), typeof(NormalizeTransform), typeof(NormalizeTransform.MeanVarArguments), typeof(SignatureDataTransform),
    NormalizeTransform.MeanVarNormalizerUserName, "MeanVarNormalizer", NormalizeTransform.MeanVarNormalizerShortName, "ZScoreNormalizer", "ZScore", "GaussianNormalizer", "Gaussian")]
 
[assembly: LoadableClass(NormalizeTransform.LogMeanVarNormalizerSummary, typeof(IDataTransform), typeof(NormalizeTransform), typeof(NormalizeTransform.LogMeanVarArguments), typeof(SignatureDataTransform),
    NormalizeTransform.LogMeanVarNormalizerUserName, "LogMeanVarNormalizer", NormalizeTransform.LogMeanVarNormalizerShortName, "LogNormalNormalizer", "LogNormal")]
 
[assembly: LoadableClass(NormalizeTransform.BinNormalizerSummary, typeof(IDataTransform), typeof(NormalizeTransform), typeof(NormalizeTransform.BinArguments), typeof(SignatureDataTransform),
    NormalizeTransform.BinNormalizerUserName, "BinNormalizer", NormalizeTransform.BinNormalizerShortName)]
 
[assembly: LoadableClass(NormalizeTransform.RobustScalingNormalizerSummary, typeof(IDataTransform), typeof(NormalizeTransform), typeof(NormalizeTransform.RobustScalingArguments), typeof(SignatureDataTransform),
    NormalizeTransform.RobustScalingNormalizerUserName, "RobustScalingNormalizer", NormalizeTransform.RobustScalingNormalizerShortName)]
 
[assembly: LoadableClass(typeof(NormalizeTransform.AffineColumnFunction), null, typeof(SignatureLoadColumnFunction),
    "Affine Normalizer", AffineNormSerializationUtils.LoaderSignature)]
 
[assembly: LoadableClass(typeof(NormalizeTransform.CdfColumnFunction), null, typeof(SignatureLoadColumnFunction),
    "CDF Normalizer", NormalizeTransform.CdfColumnFunction.LoaderSignature)]
 
[assembly: LoadableClass(NormalizeTransform.BinNormalizerSummary, typeof(NormalizeTransform.BinColumnFunction), null, typeof(SignatureLoadColumnFunction),
    "Bin Normalizer", NormalizeTransform.BinColumnFunction.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// The normalize transform for support of normalization via the <see cref="IDataTransform"/> mechanism.
    /// More contemporaneous API usage of normalization ought to use <see cref="NormalizingEstimator"/>
    /// and <see cref="NormalizingTransformer"/> rather than this structure.
    /// </summary>
    internal sealed partial class NormalizeTransform
    {
        public abstract class ColumnBase : OneToOneColumn
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Max number of examples used to train the normalizer",
                Name = "MaxTrainingExamples", ShortName = "maxtrain")]
            public long? MaximumExampleCount;
 
            private protected ColumnBase()
            {
            }
 
            private protected override bool TryUnparseCore(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                if (MaximumExampleCount != null)
                    return false;
                return base.TryUnparseCore(sb);
            }
        }
 
        // REVIEW: Support different aggregators on different columns, eg, MinMax vs Variance/ZScore.
        public abstract class ControlZeroColumnBase : ColumnBase
        {
            // REVIEW: This only allows mapping either zero or min to zero. It might make sense to allow also max, midpoint and mean to be mapped to zero.
            [Argument(ArgumentType.AtMostOnce, Name = "FixZero", HelpText = "Whether to map zero to zero, preserving sparsity", ShortName = "zero")]
            public bool? EnsureZeroUntouched;
 
            private protected override bool TryUnparseCore(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                if (EnsureZeroUntouched != null)
                    return false;
                return base.TryUnparseCore(sb);
            }
        }
 
        public sealed class AffineColumn : ControlZeroColumnBase
        {
            internal static AffineColumn Parse(string str)
            {
                Contracts.AssertNonEmpty(str);
 
                var res = new AffineColumn();
                if (res.TryParse(str))
                    return res;
                return null;
            }
 
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                return TryUnparseCore(sb);
            }
        }
 
        public sealed class BinColumn : ControlZeroColumnBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Max number of bins, power of 2 recommended", ShortName = "bins")]
            [TGUI(Label = "Max number of bins")]
            public int? NumBins;
 
            internal static BinColumn Parse(string str)
            {
                Contracts.AssertNonEmpty(str);
 
                var res = new BinColumn();
                if (res.TryParse(str))
                    return res;
                return null;
            }
 
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                if (NumBins != null)
                    return false;
                return TryUnparseCore(sb);
            }
        }
 
        public sealed class LogNormalColumn : ColumnBase
        {
            internal static LogNormalColumn Parse(string str)
            {
                Contracts.AssertNonEmpty(str);
 
                var res = new LogNormalColumn();
                if (res.TryParse(str))
                    return res;
                return null;
            }
 
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);
                return TryUnparseCore(sb);
            }
        }
 
        private static class Defaults
        {
            public const bool EnsureZeroUntouched = true;
            public const bool MeanVarCdf = false;
            public const bool LogMeanVarCdf = true;
            public const int NumBins = 1024;
            public const int MinBinSize = 10;
            public const bool CenterData = true;
            public const int QuantileMin = 25;
            public const int QuantileMax = 75;
        }
 
        public abstract class ControlZeroArgumentsBase : ArgumentsBase
        {
            // REVIEW: This only allows mapping either zero or min to zero. It might make sense to allow also max, midpoint and mean to be mapped to zero.
            // REVIEW: Convert this to bool? or even an enum{Auto, No, Yes}, and automatically map zero to zero when it is null/Auto.
            [Argument(ArgumentType.AtMostOnce, Name = "FixZero", HelpText = "Whether to map zero to zero, preserving sparsity", ShortName = "zero")]
            public bool EnsureZeroUntouched = Defaults.EnsureZeroUntouched;
        }
 
        public abstract class AffineArgumentsBase : ControlZeroArgumentsBase
        {
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public AffineColumn[] Columns;
 
            public override OneToOneColumn[] GetColumns() => Columns;
        }
 
        public sealed class MinMaxArguments : AffineArgumentsBase
        {
        }
 
        public sealed class MeanVarArguments : AffineArgumentsBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use CDF as the output", ShortName = "cdf")]
            public bool UseCdf = Defaults.MeanVarCdf;
        }
 
        public abstract class ArgumentsBase : TransformInputBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Max number of examples used to train the normalizer",
                Name = "MaxTrainingExamples", ShortName = "maxtrain")]
            public long MaximumExampleCount = 1000000000;
 
            public abstract OneToOneColumn[] GetColumns();
 
            public string TestType(DataViewType type)
            {
                DataViewType itemType = type;
                if (type is VectorDataViewType vectorType)
                {
                    // We require vectors to be of known size.
                    if (!vectorType.IsKnownSize)
                        return "Expected known size vector";
 
                    itemType = vectorType.ItemType;
                }
 
                if (itemType != NumberDataViewType.Single && itemType != NumberDataViewType.Double)
                    return "Expected Single or Double item type";
 
                return null;
            }
        }
 
        public sealed class LogMeanVarArguments : ArgumentsBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use CDF as the output", ShortName = "cdf")]
            public bool UseCdf = Defaults.LogMeanVarCdf;
 
            [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public LogNormalColumn[] Columns;
 
            public override OneToOneColumn[] GetColumns() => Columns;
        }
 
        public abstract class BinArgumentsBase : ControlZeroArgumentsBase
        {
            [Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public BinColumn[] Columns;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Max number of bins, power of 2 recommended", ShortName = "bins")]
            [TGUI(Label = "Max number of bins")]
            public int NumBins = Defaults.NumBins;
 
            public override OneToOneColumn[] GetColumns() => Columns;
        }
 
        public sealed class BinArguments : BinArgumentsBase
        {
        }
 
        public sealed class SupervisedBinArguments : BinArgumentsBase
        {
            // REVIEW: factor in a loss function / optimization algorithm to make it work better in regression case
            [Argument(ArgumentType.Required, HelpText = "Label column for supervised binning", ShortName = "label,lab",
                Purpose = SpecialPurpose.ColumnName)]
            public string LabelColumn;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Minimum number of examples per bin")]
            public int MinBinSize = Defaults.MinBinSize;
        }
 
        public sealed class RobustScalingArguments : AffineArgumentsBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Should the data be centered around 0", Name = "CenterData", ShortName = "center", SortOrder = 1)]
            public bool CenterData = Defaults.CenterData;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Minimum quantile value. Defaults to 25", Name = "QuantileMin", ShortName = "qmin", SortOrder = 2)]
            public uint QuantileMin = Defaults.QuantileMin;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum quantile value. Defaults to 75", Name = "QuantileMax", ShortName = "qmax", SortOrder = 3)]
            public uint QuantileMax = Defaults.QuantileMax;
        }
 
        internal const string MinMaxNormalizerSummary = "Normalizes the data based on the observed minimum and maximum values of the data.";
        internal const string MeanVarNormalizerSummary = "Normalizes the data based on the computed mean and variance of the data.";
        internal const string LogMeanVarNormalizerSummary = "Normalizes the data based on the computed mean and variance of the logarithm of the data.";
        internal const string BinNormalizerSummary = "The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins.";
        internal const string SupervisedBinNormalizerSummary = "Similar to BinNormalizer, but calculates bins based on correlation with the label column, not equi-density. "
            + "The new value is bin_number / number_of_bins.";
        internal const string RobustScalingNormalizerSummary = "Optionally centers the data and scales based on the range of data and the quantile min and max values provided. "
            + "This method is more robust to outliers.";
 
        internal const string MinMaxNormalizerUserName = "Min-Max Normalizer";
        internal const string MeanVarNormalizerUserName = "MeanVar Normalizer";
        internal const string LogMeanVarNormalizerUserName = "LogMeanVar Normalizer";
        internal const string BinNormalizerUserName = "Binning Normalizer";
        internal const string SupervisedBinNormalizerUserName = "Supervised Binning Normalizer";
        internal const string RobustScalingNormalizerUserName = "Robust Scaling Normalizer";
 
        internal const string MinMaxNormalizerShortName = "MinMax";
        internal const string MeanVarNormalizerShortName = "MeanVar";
        internal const string LogMeanVarNormalizerShortName = "LogMeanVar";
        internal const string BinNormalizerShortName = "Bin";
        internal const string SupervisedBinNormalizerShortName = "SupBin";
        internal const string RobustScalingNormalizerShortName = "RobScal";
 
        /// <summary>
        /// A helper method to create a MinMax normalizer.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
        /// <param name="outputColumnName">Name of the output column.</param>
        /// <param name="inputColumnName">Name of the column to be transformed. If this is null '<paramref name="outputColumnName"/>' will be used.</param>
        public static IDataView CreateMinMaxNormalizer(IHostEnvironment env, IDataView input, string outputColumnName, string inputColumnName = null)
        {
            Contracts.CheckValue(env, nameof(env));
 
            var normalizer = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(outputColumnName, inputColumnName ?? outputColumnName));
            return normalizer.Fit(input).MakeDataTransform(input);
        }
 
        /// <summary>
        /// Factory method corresponding to SignatureDataTransform.
        /// </summary>
        internal static IDataTransform Create(IHostEnvironment env, MinMaxArguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(args.Columns, nameof(args.Columns));
 
            var columns = args.Columns
                .Select(col => new NormalizingEstimator.MinMaxColumnOptions(
                    col.Name,
                    col.Source ?? col.Name,
                    col.MaximumExampleCount ?? args.MaximumExampleCount,
                    col.EnsureZeroUntouched ?? args.EnsureZeroUntouched))
                .ToArray();
            var normalizer = new NormalizingEstimator(env, columns);
            return normalizer.Fit(input).MakeDataTransform(input);
        }
 
        // Factory method corresponding to SignatureDataTransform.
        internal static IDataTransform Create(IHostEnvironment env, MeanVarArguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(args.Columns, nameof(args.Columns));
 
            var columns = args.Columns
                .Select(col => new NormalizingEstimator.MeanVarianceColumnOptions(
                    col.Name,
                    col.Source ?? col.Name,
                    col.MaximumExampleCount ?? args.MaximumExampleCount,
                    col.EnsureZeroUntouched ?? args.EnsureZeroUntouched))
                .ToArray();
            var normalizer = new NormalizingEstimator(env, columns);
            return normalizer.Fit(input).MakeDataTransform(input);
        }
 
        /// <summary>
        /// Factory method corresponding to SignatureDataTransform.
        /// </summary>
        internal static IDataTransform Create(IHostEnvironment env, LogMeanVarArguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(args.Columns, nameof(args.Columns));
 
            var columns = args.Columns
                .Select(col => new NormalizingEstimator.LogMeanVarianceColumnOptions(
                    col.Name,
                    col.Source ?? col.Name,
                    col.MaximumExampleCount ?? args.MaximumExampleCount,
                    args.UseCdf))
                .ToArray();
            var normalizer = new NormalizingEstimator(env, columns);
            return normalizer.Fit(input).MakeDataTransform(input);
        }
 
        /// <summary>
        /// Factory method corresponding to SignatureDataTransform.
        /// </summary>
        internal static IDataTransform Create(IHostEnvironment env, BinArguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(args.Columns, nameof(args.Columns));
 
            var columns = args.Columns
                .Select(col => new NormalizingEstimator.BinningColumnOptions(
                    col.Name,
                    col.Source ?? col.Name,
                    col.MaximumExampleCount ?? args.MaximumExampleCount,
                    col.EnsureZeroUntouched ?? args.EnsureZeroUntouched,
                    col.NumBins ?? args.NumBins))
                .ToArray();
            var normalizer = new NormalizingEstimator(env, columns);
            return normalizer.Fit(input).MakeDataTransform(input);
        }
 
        /// <summary>
        /// Factory method corresponding to SignatureDataTransform.
        /// </summary>
        internal static IDataTransform Create(IHostEnvironment env, RobustScalingArguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(args.Columns, nameof(args.Columns));
 
            var columns = args.Columns
                .Select(col => new NormalizingEstimator.RobustScalingColumnOptions(
                    col.Name,
                    col.Source ?? col.Name,
                    col.MaximumExampleCount ?? args.MaximumExampleCount,
                    args.CenterData,
                    args.QuantileMin,
                    args.QuantileMax))
                .ToArray();
            var normalizer = new NormalizingEstimator(env, columns);
            return normalizer.Fit(input).MakeDataTransform(input);
        }
 
        internal abstract partial class AffineColumnFunction : IColumnFunction
        {
            protected readonly IHost Host;
 
            // The only derived classes are private inner classes
            private AffineColumnFunction(IHost host)
            {
                Contracts.CheckValue(host, nameof(host));
                Host = host;
            }
 
            void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
 
            private protected abstract void SaveModel(ModelSaveContext ctx);
 
            public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
            public bool CanSaveOnnx(OnnxContext ctx) => true;
            public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);
 
            public abstract Delegate GetGetter(DataViewRow input, int icol);
 
            public abstract void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc);
 
            public abstract NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams();
 
            public static AffineColumnFunction Create(ModelLoadContext ctx, IHost host, DataViewType typeSrc)
            {
                Contracts.CheckValue(host, nameof(host));
                if (typeSrc is NumberDataViewType)
                {
                    if (typeSrc == NumberDataViewType.Single)
                        return Sng.ImplOne.Create(ctx, host, typeSrc);
                    if (typeSrc == NumberDataViewType.Double)
                        return Dbl.ImplOne.Create(ctx, host, typeSrc);
                }
                else if (typeSrc is VectorDataViewType vectorType && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.ImplVec.Create(ctx, host, vectorType);
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.ImplVec.Create(ctx, host, vectorType);
                }
                throw host.ExceptUserArg(nameof(AffineArgumentsBase.Columns), "Wrong column type. Expected: Single, Double, or Vector of Single or Vector of Double. Got: {0}.", typeSrc.ToString());
            }
 
            private abstract class ImplOne<TFloat> : AffineColumnFunction
            {
                protected readonly TFloat Scale;
                protected readonly TFloat Offset;
                protected ImplOne(IHost host, TFloat scale, TFloat offset)
                    : base(host)
                {
                    Scale = scale;
                    Offset = offset;
                }
 
                public override void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc)
                {
                    Host.CheckValue(bldr, nameof(bldr));
                    Host.CheckValue(typeSrc, nameof(typeSrc));
                    Host.Check(typeSrc.RawType == typeof(TFloat));
                    bldr.AddPrimitive("AffineScale", typeSrc, Scale);
                    bldr.AddPrimitive("AffineOffset", typeSrc, Offset);
                }
 
                public override NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams()
                    => new NormalizingTransformer.AffineNormalizerModelParameters<TFloat>(Scale, Offset);
 
            }
 
            private abstract class ImplVec<TFloat> : AffineColumnFunction
            {
                protected readonly TFloat[] Scale;
                protected readonly TFloat[] Offset;
                protected readonly int[] IndicesNonZeroOffset;
 
                protected ImplVec(IHost host, TFloat[] scale, TFloat[] offset, int[] indicesNonZeroOffset)
                    : base(host)
                {
                    Host.AssertValue(scale);
                    Host.AssertValueOrNull(offset);
                    Host.Assert(indicesNonZeroOffset == null || offset != null);
                    Host.Assert(Offset == null || Offset.Length == Scale.Length);
 
                    Scale = scale;
                    Offset = offset;
                    IndicesNonZeroOffset = indicesNonZeroOffset;
                }
 
                public override void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc)
                {
                    Host.CheckValue(bldr, nameof(bldr));
                    Host.CheckValue(typeSrc, nameof(typeSrc));
                    Host.Check(typeSrc.GetVectorSize() == Scale.Length);
                    Host.Check(typeSrc.GetItemType().RawType == typeof(TFloat));
                    bldr.AddGetter<VBuffer<TFloat>>("AffineScale", typeSrc, ScaleMetadataGetter);
                    if (Offset != null)
                        bldr.AddGetter<VBuffer<TFloat>>("AffineOffset", typeSrc, OffsetMetadataGetter);
                }
 
                private void ScaleMetadataGetter(int col, ref VBuffer<TFloat> dst)
                {
                    var src = new VBuffer<TFloat>(Scale.Length, Scale);
                    src.CopyTo(ref dst);
                }
 
                private void OffsetMetadataGetter(int col, ref VBuffer<TFloat> dst)
                {
                    Host.AssertValue(Offset);
                    var src = new VBuffer<TFloat>(Offset.Length, Offset);
                    src.CopyTo(ref dst);
                }
 
                public override NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams()
                    => new NormalizingTransformer.AffineNormalizerModelParameters<ImmutableArray<TFloat>>(ImmutableArray.Create(Scale), ImmutableArray.Create(Offset));
            }
        }
 
        internal abstract partial class CdfColumnFunction : IColumnFunction
        {
            protected readonly IHost Host;
 
            // The only derived classes are private inner classes
            private CdfColumnFunction(IHost host)
            {
                Contracts.CheckValue(host, nameof(host));
                Host = host;
            }
 
            void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
 
            private protected abstract void SaveModel(ModelSaveContext ctx);
 
            public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null;
 
            public bool CanSaveOnnx(OnnxContext ctx) => false;
 
            public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
                => throw Host.ExceptNotSupp();
 
            public abstract Delegate GetGetter(DataViewRow input, int icol);
            public abstract void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc);
            public abstract NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams();
 
            public static CdfColumnFunction Create(ModelLoadContext ctx, IHost host, DataViewType typeSrc)
            {
                Contracts.CheckValue(host, nameof(host));
                if (typeSrc is NumberDataViewType)
                {
                    if (typeSrc == NumberDataViewType.Single)
                        return Sng.ImplOne.Create(ctx, host, typeSrc);
                    if (typeSrc == NumberDataViewType.Double)
                        return Dbl.ImplOne.Create(ctx, host, typeSrc);
                }
                else if (typeSrc is VectorDataViewType vectorType && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.ImplVec.Create(ctx, host, vectorType);
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.ImplVec.Create(ctx, host, vectorType);
                }
                throw host.ExceptUserArg(nameof(AffineArgumentsBase.Columns), "Wrong column type. Expected: Single, Double, Vector of Single or Vector of Double. Got: {0}.", typeSrc);
            }
 
            private abstract class ImplOne<TFloat> : CdfColumnFunction
            {
                protected readonly TFloat Mean;
                protected readonly TFloat Stddev;
                protected readonly bool UseLog;
 
                protected ImplOne(IHost host, TFloat mean, TFloat stddev, bool useLog)
                    : base(host)
                {
                    Mean = mean;
                    Stddev = stddev;
                    UseLog = useLog;
                }
 
                public override void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc)
                {
                    Host.CheckValue(bldr, nameof(bldr));
                    Host.CheckValue(typeSrc, nameof(typeSrc));
                    Host.Check(typeSrc.RawType == typeof(TFloat));
                    bldr.AddPrimitive("CdfMean", typeSrc, Mean);
                    bldr.AddPrimitive("CdfStdDev", typeSrc, Stddev);
                    bldr.AddPrimitive("CdfUseLog", BooleanDataViewType.Instance, UseLog);
                }
 
                public override NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams()
                    => new NormalizingTransformer.CdfNormalizerModelParameters<TFloat>(Mean, Stddev, UseLog);
            }
 
            private abstract class ImplVec<TFloat> : CdfColumnFunction
            {
                protected readonly TFloat[] Mean;
                protected readonly TFloat[] Stddev;
                protected readonly bool UseLog;
 
                protected ImplVec(IHost host, TFloat[] mean, TFloat[] stddev, bool useLog)
                    : base(host)
                {
                    Host.AssertValue(mean);
                    Host.AssertValue(stddev);
                    Host.Assert(mean.Length == stddev.Length);
                    Mean = mean;
                    Stddev = stddev;
                    UseLog = useLog;
                }
 
                public override void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc)
                {
                    Host.CheckValue(bldr, nameof(bldr));
                    Host.CheckValue(typeSrc, nameof(typeSrc));
                    Host.Check(typeSrc.GetVectorSize() == Mean.Length);
                    Host.Check(typeSrc.GetItemType().RawType == typeof(TFloat));
                    bldr.AddGetter<VBuffer<TFloat>>("CdfMean", typeSrc, MeanMetadataGetter);
                    bldr.AddGetter<VBuffer<TFloat>>("CdfStdDev", typeSrc, StddevMetadataGetter);
                    bldr.AddPrimitive("CdfUseLog", BooleanDataViewType.Instance, UseLog);
                }
 
                private void MeanMetadataGetter(int col, ref VBuffer<TFloat> dst)
                {
                    var src = new VBuffer<TFloat>(Mean.Length, Mean);
                    src.CopyTo(ref dst);
                }
 
                private void StddevMetadataGetter(int col, ref VBuffer<TFloat> dst)
                {
                    var src = new VBuffer<TFloat>(Stddev.Length, Stddev);
                    src.CopyTo(ref dst);
                }
 
                public override NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams()
                    => new NormalizingTransformer.CdfNormalizerModelParameters<ImmutableArray<TFloat>>(ImmutableArray.Create(Mean), ImmutableArray.Create(Stddev), UseLog);
            }
 
            public const string LoaderSignature = "CdfNormalizeFunction";
 
            public static VersionInfo GetVersionInfo()
            {
                return new VersionInfo(
                    modelSignature: "CDFNORMF",
                    verWrittenCur: 0x00010001, // Initial
                    verReadableCur: 0x00010001,
                    verWeCanReadBack: 0x00010001,
                    loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(CdfColumnFunction).Assembly.FullName);
            }
        }
 
        internal abstract partial class BinColumnFunction : IColumnFunction
        {
            protected readonly IHost Host;
 
            protected BinColumnFunction(IHost host)
            {
                Contracts.CheckValue(host, nameof(host));
                Host = host;
            }
 
            void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
 
            private protected abstract void SaveModel(ModelSaveContext ctx);
 
            public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) => null;
 
            public bool CanSaveOnnx(OnnxContext ctx) => false;
 
            public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
                => throw Host.ExceptNotSupp();
 
            public abstract Delegate GetGetter(DataViewRow input, int icol);
 
            public void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc)
            {
                // REVIEW: How to attach information on the bins, to metadata?
            }
 
            public abstract NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams();
 
            public static BinColumnFunction Create(ModelLoadContext ctx, IHost host, DataViewType typeSrc)
            {
                Contracts.CheckValue(host, nameof(host));
                if (typeSrc is NumberDataViewType)
                {
                    if (typeSrc == NumberDataViewType.Single)
                        return Sng.ImplOne.Create(ctx, host, typeSrc);
                    if (typeSrc == NumberDataViewType.Double)
                        return Dbl.ImplOne.Create(ctx, host, typeSrc);
                }
                if (typeSrc is VectorDataViewType vectorType && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.ImplVec.Create(ctx, host, vectorType);
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.ImplVec.Create(ctx, host, vectorType);
                }
                throw host.ExceptUserArg(nameof(BinArguments.Columns), "Wrong column type. Expected: Single, Double, Vector of Single or Vector of Double. Got: {0}.", typeSrc);
            }
 
            public const string LoaderSignature = "BinNormalizeFunction";
 
            public static VersionInfo GetVersionInfo()
            {
                return new VersionInfo(
                    modelSignature: "BINNORMF",
                    verWrittenCur: 0x00010001, // Initial
                    verReadableCur: 0x00010001,
                    verWeCanReadBack: 0x00010001,
                    loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(BinColumnFunction).Assembly.FullName);
            }
        }
 
        private abstract class OneColumnFunctionBuilderBase<TFloat> : IColumnFunctionBuilder
        {
            protected IHost Host;
            protected readonly long Lim;
            protected long Rem;
            private readonly ValueGetter<TFloat> _getSrc;
 
            protected OneColumnFunctionBuilderBase(IHost host, long lim, ValueGetter<TFloat> getSrc)
            {
                Contracts.CheckValue(host, nameof(host));
                Host = host;
                Rem = lim;
                Lim = lim;
                _getSrc = getSrc;
            }
 
            public bool ProcessValue()
            {
                TFloat tmp = default(TFloat);
                _getSrc(ref tmp);
                return ProcessValue(in tmp);
            }
 
            protected virtual bool ProcessValue(in TFloat val)
            {
                Host.Assert(Rem >= 0);
                if (Rem == 0)
                    return false;
                Rem--;
                return true;
            }
 
            public abstract IColumnFunction CreateColumnFunction();
        }
 
        private abstract class VecColumnFunctionBuilderBase<TFloat> : IColumnFunctionBuilder
        {
            protected IHost Host;
            protected readonly long Lim;
            protected long Rem;
            private readonly ValueGetter<VBuffer<TFloat>> _getSrc;
            private VBuffer<TFloat> _buffer;
 
            protected VecColumnFunctionBuilderBase(IHost host, long lim, ValueGetter<VBuffer<TFloat>> getSrc)
            {
                Contracts.CheckValue(host, nameof(host));
                Host = host;
                Rem = lim;
                Lim = lim;
                _getSrc = getSrc;
            }
 
            public bool ProcessValue()
            {
                _getSrc(ref _buffer);
                return ProcessValue(in _buffer);
            }
 
            protected virtual bool ProcessValue(in VBuffer<TFloat> buffer)
            {
                Host.Assert(Rem >= 0);
                if (Rem == 0)
                    return false;
                Rem--;
                return true;
            }
 
            public abstract IColumnFunction CreateColumnFunction();
        }
 
        private abstract class SupervisedBinFunctionBuilderBase : IColumnFunctionBuilder
        {
            protected readonly IHost Host;
            protected readonly long Lim;
            protected long Rem;
 
            protected readonly List<int> Labels;
            protected readonly int LabelCardinality;
            private readonly ValueGetter<int> _labelGetterSrc;
 
            protected SupervisedBinFunctionBuilderBase(IHost host, long lim, int labelColId, DataViewRow dataRow)
            {
                Contracts.CheckValue(host, nameof(host));
                Host = host;
                Rem = lim;
                Lim = lim;
                Labels = new List<int>();
                _labelGetterSrc = GetLabelGetter(dataRow, labelColId, out LabelCardinality);
            }
 
            private ValueGetter<int> GetLabelGetter(DataViewRow row, int col, out int labelCardinality)
            {
                // The label column type is checked as part of args validation.
                var type = row.Schema[col].Type;
                Host.Assert(type is KeyDataViewType || type is NumberDataViewType);
 
                if (type is KeyDataViewType keyType)
                {
                    Host.Assert(type.GetKeyCountAsInt32(Host) > 0);
                    labelCardinality = type.GetKeyCountAsInt32(Host);
 
                    int size = type.GetKeyCountAsInt32(Host);
                    ulong src = 0;
                    var getSrc = RowCursorUtils.GetGetterAs<ulong>(NumberDataViewType.UInt64, row, col);
                    return
                        (ref int dst) =>
                        {
                            getSrc(ref src);
                            // The value should fall between 0 and _labelCardinality inclusive, where 0 is considered
                            // missing/invalid (this is the contract of the KeyType). However, we still handle the
                            // cases of too large values correctly (by treating them as invalid).
                            if (src <= (ulong)size)
                                dst = (int)src - 1;
                            else
                                dst = -1;
                        };
                }
                else
                {
                    // REVIEW: replace with trainable binning for numeric value
                    labelCardinality = 2; // any numeric column is split into 0 and 1
 
                    Double src = 0;
                    var getSrc = RowCursorUtils.GetGetterAs<Double>(NumberDataViewType.Double, row, col);
                    return
                        (ref int dst) =>
                        {
                            getSrc(ref src);
                            // NaN maps to -1.
                            if (src > 0)
                                dst = 1;
                            else if (src <= 0)
                                dst = 0;
                            else
                                dst = -1;
                        };
                }
            }
 
            public virtual bool ProcessValue()
            {
                Host.Assert(Rem >= 0);
                if (Rem == 0)
                    return false;
                Rem--;
                int label = 0;
                _labelGetterSrc(ref label);
                var accept = label >= 0 && AcceptColumnValue(); // skip examples with negative label
                if (accept)
                    Labels.Add(label);
                return true;
            }
 
            public abstract IColumnFunction CreateColumnFunction();
 
            protected abstract bool AcceptColumnValue();
        }
 
        private abstract class OneColumnSupervisedBinFunctionBuilderBase<TFloat> : SupervisedBinFunctionBuilderBase
        {
            private readonly ValueGetter<TFloat> _colGetterSrc;
            protected readonly List<TFloat> ColValues;
 
            protected OneColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId,
                DataViewRow dataRow)
                : base(host, lim, labelColId, dataRow)
            {
                _colGetterSrc = dataRow.GetGetter<TFloat>(dataRow.Schema[valueColId]);
                ColValues = new List<TFloat>();
            }
 
            protected override bool AcceptColumnValue()
            {
                TFloat colValue = default(TFloat);
                _colGetterSrc(ref colValue);
                var result = AcceptColumnValue(in colValue);
                if (result)
                    ColValues.Add(colValue);
                return result;
            }
 
            protected abstract bool AcceptColumnValue(in TFloat colValue);
        }
 
        private abstract class VecColumnSupervisedBinFunctionBuilderBase<TFloat> : SupervisedBinFunctionBuilderBase
        {
            private readonly ValueGetter<VBuffer<TFloat>> _colValueGetter;
            private VBuffer<TFloat> _buffer;
 
            protected readonly List<TFloat>[] ColValues;
            protected readonly int ColumnSlotCount;
 
            protected VecColumnSupervisedBinFunctionBuilderBase(IHost host, long lim, int valueColId, int labelColId, DataViewRow dataRow)
                : base(host, lim, labelColId, dataRow)
            {
                var valueCol = dataRow.Schema[valueColId];
                _colValueGetter = dataRow.GetGetter<VBuffer<TFloat>>(valueCol);
 
                Host.Assert(valueCol.Type.IsKnownSizeVector());
                ColumnSlotCount = valueCol.Type.GetValueCount();
 
                ColValues = new List<TFloat>[ColumnSlotCount];
                for (int i = 0; i < ColumnSlotCount; i++)
                    ColValues[i] = new List<TFloat>();
 
                _buffer = default(VBuffer<TFloat>);
            }
 
            protected override bool AcceptColumnValue()
            {
                _colValueGetter(ref _buffer);
                bool result = AcceptColumnValue(in _buffer);
                if (result)
                {
                    if (_buffer.IsDense)
                    {
                        var values = _buffer.GetValues();
                        for (int i = 0; i < ColumnSlotCount; i++)
                            ColValues[i].Add(values[i]);
                    }
                    else
                    {
                        var indices = _buffer.GetIndices();
                        var values = _buffer.GetValues();
                        int k = 0;
                        for (int i = 0; i < values.Length; i++)
                        {
                            var val = values[i];
                            var index = indices[i];
                            while (k < index)
                                ColValues[k++].Add(default(TFloat));
 
                            ColValues[k++].Add(val);
                        }
 
                        while (k < ColumnSlotCount)
                            ColValues[k++].Add(default(TFloat));
                    }
                }
                return result;
            }
 
            protected abstract bool AcceptColumnValue(in VBuffer<TFloat> buffer);
        }
 
        internal static partial class MinMaxUtils
        {
            public static IColumnFunctionBuilder CreateBuilder(MinMaxArguments args, IHost host,
                int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);
 
                return CreateBuilder(new NormalizingEstimator.MinMaxColumnOptions(
                    args.Columns[icol].Name,
                    args.Columns[icol].Source ?? args.Columns[icol].Name,
                    args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
                    args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched), host, srcIndex, srcType, cursor);
            }
 
            public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MinMaxColumnOptions column, IHost host,
                int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                var srcColumn = cursor.Schema[srcIndex];
                if (srcType is NumberDataViewType)
                {
                    if (srcType == NumberDataViewType.Single)
                        return Sng.MinMaxOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Single>(srcColumn));
                    if (srcType == NumberDataViewType.Double)
                        return Dbl.MinMaxOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Double>(srcColumn));
                }
                if (srcType is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.MinMaxVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Single>>(srcColumn));
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.MinMaxVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Double>>(srcColumn));
                }
                throw host.ExceptParam(nameof(srcType), "Wrong column type for input column. Expected: Single, Double, Vector of Single or Vector of Double. Got: {0}.", srcType.ToString());
            }
        }
 
        internal static partial class MeanVarUtils
        {
            public static IColumnFunctionBuilder CreateBuilder(MeanVarArguments args, IHost host,
                int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);
 
                return CreateBuilder(new NormalizingEstimator.MeanVarianceColumnOptions(
                    args.Columns[icol].Name,
                    args.Columns[icol].Source ?? args.Columns[icol].Name,
                    args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
                    args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched,
                    args.UseCdf), host, srcIndex, srcType, cursor);
            }
 
            public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.MeanVarianceColumnOptions column, IHost host,
                int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                var srcColumn = cursor.Schema[srcIndex];
                if (srcType is NumberDataViewType)
                {
                    if (srcType == NumberDataViewType.Single)
                        return Sng.MeanVarOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Single>(srcColumn));
                    if (srcType == NumberDataViewType.Double)
                        return Dbl.MeanVarOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Double>(srcColumn));
                }
                if (srcType is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.MeanVarVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Single>>(srcColumn));
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.MeanVarVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Double>>(srcColumn));
                }
                throw host.ExceptParam(nameof(srcType), "Wrong column type for input column. Expected: Single, Double, Vector of Single or Vector of Double. Got: {0}.", srcType.ToString());
            }
 
        }
 
        internal static partial class LogMeanVarUtils
        {
            public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHost host,
                int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);
 
                return CreateBuilder(new NormalizingEstimator.LogMeanVarianceColumnOptions(
                    args.Columns[icol].Name,
                    args.Columns[icol].Source ?? args.Columns[icol].Name,
                    args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
                    args.UseCdf), host, srcIndex, srcType, cursor);
            }
 
            public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.LogMeanVarianceColumnOptions column, IHost host,
                int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(column);
 
                var srcColumn = cursor.Schema[srcIndex];
                if (srcType is NumberDataViewType)
                {
                    if (srcType == NumberDataViewType.Single)
                        return Sng.MeanVarOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Single>(srcColumn));
                    if (srcType == NumberDataViewType.Double)
                        return Dbl.MeanVarOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Double>(srcColumn));
                }
                if (srcType is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.MeanVarVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Single>>(srcColumn));
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.MeanVarVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Double>>(srcColumn));
                }
                throw host.ExceptUserArg(nameof(column), "Wrong column type for column {0}. Expected: Single, Double, Vector of Single or Vector of Double. Got: {1}.", column.InputColumnName, srcType.ToString());
            }
        }
 
        internal static partial class BinUtils
        {
            public static IColumnFunctionBuilder CreateBuilder(BinArguments args, IHost host,
                int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);
 
                return CreateBuilder(new NormalizingEstimator.BinningColumnOptions(
                    args.Columns[icol].Name,
                    args.Columns[icol].Source ?? args.Columns[icol].Name,
                    args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
                    args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched,
                    args.Columns[icol].NumBins ?? args.NumBins), host, srcIndex, srcType, cursor);
            }
 
            public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.BinningColumnOptions column, IHost host,
                int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
 
                var srcColumn = cursor.Schema[srcIndex];
                if (srcType is NumberDataViewType)
                {
                    if (srcType == NumberDataViewType.Single)
                        return Sng.BinOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Single>(srcColumn));
                    if (srcType == NumberDataViewType.Double)
                        return Dbl.BinOneColumnFunctionBuilder.Create(column, host, srcType, cursor.GetGetter<Double>(srcColumn));
                }
                if (srcType is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.BinVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Single>>(srcColumn));
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.BinVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter<VBuffer<Double>>(srcColumn));
                }
                throw host.ExceptParam(nameof(column), "Wrong column type for column {0}. Expected: Single, Double, Vector of Single or Vector of Double. Got: {1}.", column.InputColumnName, srcType.ToString());
            }
        }
 
        internal static class SupervisedBinUtils
        {
            public static IColumnFunctionBuilder CreateBuilder(SupervisedBinArguments args, IHost host,
                int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);
 
                // checking for label column
                host.CheckUserArg(!string.IsNullOrWhiteSpace(args.LabelColumn), nameof(args.LabelColumn), "Must specify the label column name");
                int labelColumnId = GetLabelColumnId(host, cursor.Schema, args.LabelColumn);
                var labelColumnType = cursor.Schema[labelColumnId].Type;
                if (labelColumnType is KeyDataViewType labelKeyType)
                    host.CheckUserArg(labelKeyType.Count > 0, nameof(args.LabelColumn), "Label column must have a known cardinality");
                else
                    host.CheckUserArg(labelColumnType is NumberDataViewType, nameof(args.LabelColumn), "Label column must be a number or a key type");
 
                return CreateBuilder(
                    new NormalizingEstimator.SupervisedBinningColumOptions(
                        args.Columns[icol].Name,
                        args.Columns[icol].Source ?? args.Columns[icol].Name,
                        args.LabelColumn ?? DefaultColumnNames.Label,
                        args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
                        args.Columns[icol].EnsureZeroUntouched ?? args.EnsureZeroUntouched,
                        args.Columns[icol].NumBins ?? args.NumBins,
                        args.MinBinSize),
                    host, labelColumnId, srcIndex, srcType, cursor);
            }
 
            public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.SupervisedBinningColumOptions column, IHost host,
                 string labelColumn, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                int labelColumnId = GetLabelColumnId(host, cursor.Schema, labelColumn);
                return CreateBuilder(column, host, labelColumnId, srcIndex, srcType, cursor);
            }
 
            private static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.SupervisedBinningColumOptions column, IHost host,
                int labelColumnId, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
 
                if (srcType is NumberDataViewType)
                {
                    if (srcType == NumberDataViewType.Single)
                        return Sng.SupervisedBinOneColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor);
                    if (srcType == NumberDataViewType.Double)
                        return Dbl.SupervisedBinOneColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor);
                }
                if (srcType is VectorDataViewType vectorType && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.SupervisedBinVecColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor);
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.SupervisedBinVecColumnFunctionBuilder.Create(column, host, srcIndex, labelColumnId, cursor);
                }
 
                throw host.ExceptParam(nameof(column), "Wrong column type for column {0}. Expected: Single, Double, Vec<Single, n> or Vec<Double, n>. Got: {1}.",
                    column.InputColumnName,
                    srcType.ToString());
            }
 
            public static int GetLabelColumnId(IExceptionContext host, DataViewSchema schema, string labelColumnName)
            {
                Contracts.AssertValue(host);
                host.AssertValue(schema);
                int labelColumnId;
                if (!schema.TryGetColumnIndex(labelColumnName, out labelColumnId))
                    throw host.ExceptUserArg(nameof(SupervisedBinArguments.LabelColumn), "Label column '{0}' not found", labelColumnName);
                return labelColumnId;
            }
        }
 
        internal static partial class RobustScaleUtils
        {
            public static IColumnFunctionBuilder CreateBuilder(RobustScalingArguments args, IHost host,
                int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);
 
                return CreateBuilder(new NormalizingEstimator.RobustScalingColumnOptions(
                    args.Columns[icol].Name,
                    args.Columns[icol].Source ?? args.Columns[icol].Name,
                    args.Columns[icol].MaximumExampleCount ?? args.MaximumExampleCount,
                    args.CenterData,
                    args.QuantileMin,
                    args.QuantileMax), host, srcIndex, srcType, cursor);
            }
 
            public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.RobustScalingColumnOptions column, IHost host,
                int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
            {
                var srcColumn = cursor.Schema[srcIndex];
                if (srcType is NumberDataViewType)
                {
                    if (srcType == NumberDataViewType.Single)
                        return Sng.RobustScalerOneColumnFunctionBuilder.Create(column, host, srcType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<Single>(srcColumn));
                    if (srcType == NumberDataViewType.Double)
                        return Dbl.RobustScalerOneColumnFunctionBuilder.Create(column, host, srcType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<double>(srcColumn));
                }
                if (srcType is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType is NumberDataViewType)
                {
                    if (vectorType.ItemType == NumberDataViewType.Single)
                        return Sng.RobustScalerVecFunctionBuilder.Create(column, host, srcType as VectorDataViewType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<VBuffer<float>>(srcColumn));
                    if (vectorType.ItemType == NumberDataViewType.Double)
                        return Dbl.RobustScalerVecFunctionBuilder.Create(column, host, srcType as VectorDataViewType, column.CenterData, column.QuantileMin, column.QuantileMax, cursor.GetGetter<VBuffer<double>>(srcColumn));
                }
 
                throw host.ExceptParam(nameof(srcType), "Wrong column type for input column. Expected: Single, Double, Vector of Single or Vector of Double. Got: {0}.", srcType.ToString());
            }
        }
    }
 
    internal static partial class AffineNormSerializationUtils
    {
        public const string LoaderSignature = "AffineNormExec";
 
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "AFF NORM",
                // verWrittenCur: 0x00010001, // Initial
                //verWrittenCur: 0x00010002, // Sparse representation
                verWrittenCur: 0x00010003, // Scales multiply instead of divide
                verReadableCur: 0x00010003,
                verWeCanReadBack: 0x00010003,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(AffineNormSerializationUtils).Assembly.FullName);
        }
    }
 
    internal static partial class BinNormSerializationUtils
    {
        public const string LoaderSignature = "BinNormExec";
 
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "BIN NORM",
                verWrittenCur: 0x00010001,
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(BinNormSerializationUtils).Assembly.FullName);
        }
    }
 
    internal static class MeanVarUtils
    {
        internal static void AdjustForZeros(ref Double mean, ref Double m2, ref long count, long numZeros)
        {
            Contracts.Assert(m2 >= 0);
            Contracts.Assert(count >= 0);
            Contracts.Assert(numZeros >= 0);
 
            if (numZeros == 0)
                return;
            count += numZeros;
            var delta = 0 - mean;
            mean += delta * numZeros / count;
            var d2 = delta * (0 - mean);
            Contracts.Assert(d2 >= 0);
            m2 += d2 * numZeros;
            Contracts.Assert(m2 >= 0);
        }
    }
}