File: Transforms\NormalizeColumnDbl.cs
Project: src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Runtime.CompilerServices;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Newtonsoft.Json.Linq;
 
namespace Microsoft.ML.Transforms
{
    // !!! WARNING !!!
    // This file contains the Double version of the normalizers and is almost identical to NormalizeColumnSng.cs.
    // When making changes to one, use BeyondCompare or a similar tool to view diffs and propagate
    // appropriate changes to the other.
    using TFloat = Double;
 
    internal static partial class AffineNormSerializationUtils
    {
        public static void SaveModel(ModelSaveContext ctx,
            int numFeatures, int[] indices, TFloat[] scales, TFloat[] offsets, bool saveText = false)
        {
            Contracts.AssertValue(ctx);
            ctx.CheckAtModel();
            Contracts.Check(numFeatures > 0);
            Contracts.CheckValueOrNull(indices);
            Contracts.CheckValue(scales, nameof(scales));
            Contracts.CheckValueOrNull(offsets);
 
            ctx.SetVersionInfo(GetVersionInfo());
 
            // *** Binary format ***
            // int: sizeof(TFloat)
            // int: number of features (size)
            // int: number of indices morphed (morph: -1 means that we assume all are, zero means none are)
            // int[]: morphed indices (max(0, morph) of them)
            // int: number of scales (if morph >= 0, this should be morph, otherwise, should be size)
            // TFloat[]: scale values
            // int: number of offsets (zero if they are all zero, otherwise, should be morph or size - same as scales)
            // TFloat[]: offset values
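            //
            // For example, a dense model (indices == null) with three features and non-trivial
            // offsets serializes as:
            //   8, 3, -1, 3, s0, s1, s2, 3, o0, o1, o2
            // where 8 is sizeof(Double) and a count precedes each array.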
            ctx.Writer.Write(sizeof(TFloat));
            ctx.Writer.Write(numFeatures);
 
            Contracts.Assert(offsets == null || offsets.Length == scales.Length);
            if (indices == null)
            {
                Contracts.Assert(scales.Length == numFeatures);
                ctx.Writer.Write(-1);
            }
            else
            {
                Contracts.Assert(indices.Length < numFeatures);
                Contracts.Assert(scales.Length == indices.Length);
                ctx.Writer.WriteIntArray(indices);
            }
            ctx.Writer.WriteDoubleArray(scales);
            ctx.Writer.WriteDoubleArray(offsets);
 
            if (saveText)
            {
                ctx.SaveTextStream("AffineNormalizer.txt",
                    writer =>
                    {
                        writer.WriteLine("NumNormalizationFeatures={0}", numFeatures);
                        if (indices == null)
                        {
                            for (int i = 0; i < numFeatures; i++)
                                writer.WriteLine("{0}\t{1}\t{2}", i, offsets != null ? offsets[i] : 0, scales[i]);
                        }
                        else
                        {
                            for (int ii = 0; ii < indices.Length; ii++)
                                writer.WriteLine("{0}\t{1}\t{2}", indices[ii], offsets != null ? offsets[ii] : 0,
                                    scales[ii]);
                        }
                        writer.WriteLine();
                    });
            }
        }
 
        public static void LoadModel(ModelLoadContext ctx, ref List<int> indicesShift,
            out int numFeatures, out TFloat[] scales, out TFloat[] offsets,
            out int[] indicesMorph, out TFloat[] scalesSparse, out TFloat[] offsetsSparse)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
 
            // *** Binary format ***
            // int: sizeof(TFloat)
            // int: number of features (size)
            // int: number of indices morphed (morph: -1 means that we assume all are, zero means none are)
            // int[]: morphed indices (max(0, morph) of them)
            // int: number of scales (if morph >= 0, this should be morph, otherwise, should be size)
            // TFloat[]: scale values
            // int: number of offsets (zero if they are all zero, otherwise, should be morph or size - same as scales)
            // TFloat[]: offset values
 
            int cbFloat = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(cbFloat == sizeof(TFloat));
 
            int size = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(size > 0);
 
            numFeatures = size;
 
            int morphCount = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(-1 <= morphCount && morphCount < size);
 
            if (indicesShift != null)
                indicesShift.Clear();
            if (morphCount == -1)
            {
                // Not using sparsity.
                indicesMorph = null;
                int scaleCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(scaleCount == size);
                scalesSparse = ctx.Reader.ReadDoubleArray(scaleCount);
                int offsetCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(offsetCount == 0 || offsetCount == size);
                offsetsSparse = ctx.Reader.ReadDoubleArray(offsetCount);
 
                scales = scalesSparse;
                offsets = offsetsSparse;
                for (int iv = 0; iv < scales.Length; iv++)
                {
                    TFloat scale = scales[iv];
                    Contracts.CheckDecode(!TFloat.IsNaN(scale));
                    if (offsets == null)
                        continue;
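                    // A zero scale forces the normalized output to zero regardless of the offset,
                    // so normalize the stored offset to zero and skip the shift bookkeeping.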
                    if (scale == 0)
                    {
                        offsets[iv] = 0;
                        continue;
                    }
                    TFloat offset = offsets[iv];
                    Contracts.CheckDecode(!TFloat.IsNaN(offset));
                    if (!(offset == 0))
                        Utils.Add(ref indicesShift, iv);
                }
            }
            else
            {
                // Using sparsity.
                indicesMorph = ctx.Reader.ReadIntArray(morphCount) ?? new int[0];
 
                int scaleCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(scaleCount == morphCount);
                scalesSparse = ctx.Reader.ReadDoubleArray(scaleCount) ?? new TFloat[0];
                int offsetCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(offsetCount == 0 || offsetCount == morphCount);
                offsetsSparse = ctx.Reader.ReadDoubleArray(offsetCount);
 
                // Construct the dense representations.
                scales = Utils.CreateArray<TFloat>(numFeatures, 1);
                offsets = offsetsSparse != null ? new TFloat[numFeatures] : null;
                int ivPrev = -1;
                for (int iiv = 0; iiv < indicesMorph.Length; iiv++)
                {
                    int iv = indicesMorph[iiv];
                    Contracts.CheckDecode(ivPrev < iv && iv < numFeatures);
                    ivPrev = iv;
                    TFloat scale = scales[iv] = scalesSparse[iiv];
                    Contracts.CheckDecode(!TFloat.IsNaN(scale));
                    if (offsetsSparse == null)
                        continue;
                    if (scale == 0)
                    {
                        offsetsSparse[iiv] = 0;
                        continue;
                    }
                    TFloat offset = offsets[iv] = offsetsSparse[iiv];
                    Contracts.CheckDecode(!TFloat.IsNaN(offset));
                    if (!(offset == 0))
                        Utils.Add(ref indicesShift, iv);
                }
            }
 
            Contracts.Assert(numFeatures > 0);
            Contracts.Assert(scalesSparse != null);
            Contracts.Assert(indicesMorph == null || indicesMorph.Length == scalesSparse.Length);
            Contracts.Assert(offsetsSparse == null || offsetsSparse.Length == scalesSparse.Length);
            Contracts.Assert((offsets == null) == (offsetsSparse == null));
        }
    }
 
    internal static partial class BinNormSerializationUtils
    {
        public static void SaveModel(ModelSaveContext ctx, TFloat[][] binUpperBounds, bool saveText = false)
        {
            Contracts.AssertValue(ctx);
 
            ctx.SetVersionInfo(GetVersionInfo());
 
            // *** Binary format ***
            // int: sizeof(TFloat)
            // int: number of bin upper bounds arrays = number of features
            // for each array:
            //     int: number of elements in bin upper bounds
            //     TFloat[]: bin upper bounds
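            //
            // For example, two features with bounds { 1.5, inf } and { 0, 2, inf } serialize as:
            //   8, 2, 2, 1.5, inf, 3, 0, 2, inf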
            ctx.Writer.Write(sizeof(TFloat));
 
            ctx.Writer.Write(binUpperBounds.Length);
            foreach (var featureUpperBounds in binUpperBounds)
                ctx.Writer.WriteDoubleArray(featureUpperBounds);
 
            if (saveText)
            {
                ctx.SaveTextStream("BinNormalizer.txt",
                    writer =>
                    {
                        writer.WriteLine("NumNormalizationFeatures={0}", binUpperBounds.Length);
                        for (int i = 0; i < binUpperBounds.Length; i++)
                        {
                            string pre = "";
                            for (int j = 0; j < binUpperBounds[i].Length - 1; j++)
                            {
                                writer.Write(pre);
                                pre = "\t";
                                writer.Write(binUpperBounds[i][j]);
                            }
                            writer.WriteLine();
                        }
                    });
            }
        }
 
        public static void LoadModel(ModelLoadContext ctx, out TFloat[][] binUpperBounds)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
 
            // *** Binary format ***
            // int: sizeof(TFloat)
            // int: number of bin upper bounds arrays = number of features
            // for each array:
            //     int: number of elements in bin upper bounds
            //     TFloat[]: bin upper bounds
            int cbFloat = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(cbFloat == sizeof(TFloat));
 
            // Core model
            int numFeatures = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(numFeatures > 0);
            binUpperBounds = new TFloat[numFeatures][];
            for (int i = 0; i < numFeatures; i++)
            {
                TFloat[] curUpperBounds = ctx.Reader.ReadDoubleArray();
                Contracts.CheckDecode(Utils.Size(curUpperBounds) > 0);
                binUpperBounds[i] = curUpperBounds;
                for (int j = 1; j < curUpperBounds.Length; j++)
                    Contracts.CheckDecode(curUpperBounds[j - 1] < curUpperBounds[j]);
                Contracts.CheckDecode(curUpperBounds[curUpperBounds.Length - 1] == TFloat.PositiveInfinity);
            }
        }
    }
 
    internal static partial class CdfNormSerializationUtils
    {
        public static void SaveModel(ModelSaveContext ctx, bool useLog, TFloat[] mean, TFloat[] stddev)
        {
            // *** Binary format ***
            // int: sizeof(TFloat)
            // bool: useLog
            // int: number of features (size)
            // TFloat[]: mean values
            // TFloat[]: stddev values
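            //
            // Unlike the other normalizers, the mean and stddev arrays are written without
            // per-array counts; both use the single length written above. For example, two
            // features serialize as: 8, useLog, 2, m0, m1, sd0, sd1.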
            ctx.Writer.Write(sizeof(TFloat));
            ctx.Writer.WriteBoolByte(useLog);
            ctx.Writer.Write(mean.Length);
            ctx.Writer.WriteDoublesNoCount(mean);
            ctx.Writer.WriteDoublesNoCount(stddev.AsSpan(0, mean.Length));
 
            ctx.SaveTextStream("CdfNormalizer.txt",
                writer =>
                {
                    writer.WriteLine("NumNormalizationFeatures={0}", mean.Length);
                    writer.WriteLine("Log={0}", useLog);
                    for (int i = 0; i < mean.Length; i++)
                        writer.WriteLine("{0}\t{1}", mean[i], stddev[i]);
                });
        }
 
        public static void LoadModel(ModelLoadContext ctx, int cv, out bool useLog, out TFloat[] mean, out TFloat[] stddev)
        {
            // *** Binary format ***
            // int: sizeof(TFloat)
            // bool: useLog
            // int: number of features (size)
            // TFloat[]: mean values
            // TFloat[]: stddev values
 
            int cbFloat = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(cbFloat == sizeof(TFloat));
 
            useLog = ctx.Reader.ReadBoolByte();
 
            int size = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(size > 0);
            if (size != cv)
                throw Contracts.Except("Normalizer expected {0} slots, but the input data column has {1} slots.", size, cv);
            mean = ctx.Reader.ReadDoubleArray(size);
            stddev = ctx.Reader.ReadDoubleArray(size);
        }
    }
 
    /// <summary>
    /// Base class for tracking min and max values for a vector valued column.
    /// It tracks min, max, number of non-sparse values (vCount) and number of ProcessValue() calls (trainCount).
    /// NaNs are ignored when updating min and max.
    /// </summary>
    internal sealed class MinMaxDblAggregator : IColumnAggregator<VBuffer<TFloat>>
    {
        private readonly TFloat[] _min;
        private readonly TFloat[] _max;
        private readonly long[] _vCount;
        private long _trainCount;
 
        public MinMaxDblAggregator(int size)
        {
            Contracts.Check(size > 0);
            _min = new TFloat[size];
            _max = new TFloat[size];
            _vCount = new long[size];
            for (int i = 0; i < size; i++)
            {
                _min[i] = TFloat.PositiveInfinity;
                _max[i] = TFloat.NegativeInfinity;
            }
        }
 
        public TFloat[] Min
        {
            get { return _min; }
        }
 
        public TFloat[] Max
        {
            get { return _max; }
        }
 
        public long[] Count
        {
            get { return _vCount; }
        }
 
        public void ProcessValue(in VBuffer<TFloat> value)
        {
            var size = _min.Length;
            Contracts.Check(value.Length == size);
            _trainCount++;
            var values = value.GetValues();
            Contracts.Assert(0 <= values.Length && values.Length <= size);
            if (values.Length == 0)
                return;
 
            if (values.Length == size)
            {
                for (int j = 0; j < values.Length; j++)
                {
                    var val = values[j];
                    _vCount[j]++;
                    Update(j, val);
                }
            }
            else
            {
                var indices = value.GetIndices();
                for (int k = 0; k < values.Length; k++)
                {
                    var val = values[k];
                    var j = indices[k];
                    _vCount[j]++;
                    Update(j, val);
                }
            }
        }
 
        public void Finish()
        {
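            // Any slot that did not receive an explicit value in every row implicitly held
            // zeros, so fold a zero into that slot's min and max.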
            var size = _min.Length;
            for (int i = 0; i < size; i++)
            {
                if (_vCount[i] < _trainCount)
                    Update(i, 0);
            }
        }
 
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private void Update(int j, TFloat val)
        {
            if (_max[j] < val)
                _max[j] = val;
            if (_min[j] > val)
                _min[j] = val;
        }
    }
 
    /// <summary>
    /// Class for computing the mean and variance for a vector valued column.
    /// It tracks the current mean and the M2 (sum of squared diffs of the values from the mean),
    /// the number of NaNs and the number of non-zero elements.
    /// Uses the algorithm described here: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
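    /// For each accepted value x (non-zero and finite after the optional log transform), the update is:
    ///   n += 1; delta = x - mean; mean += delta / n; M2 += delta * (x - mean),
    /// where the last factor uses the already-updated mean. The variance is then
    /// M2 / n (population) or M2 / (n - 1) (sample).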
    /// </summary>
    internal sealed class MeanVarDblAggregator
    {
        private readonly bool _useLog;
        private readonly Double[] _mean;
        private readonly Double[] _m2;
        private readonly long[] _cnan;
        private readonly long[] _cnz;
        private long _trainCount;
 
        public MeanVarDblAggregator(int size, bool useLog)
        {
            _useLog = useLog;
            _mean = new Double[size];
            _m2 = new Double[size];
            if (!_useLog)
                _cnan = new long[size];
            _cnz = new long[size];
        }
 
        public long[] Counts
        {
            get { return _cnz; }
        }
 
        public Double[] Mean
        {
            get { return _mean; }
        }
 
        public Double[] StdDevPopulation
        {
            get { return _m2.Select((m2, i) => Math.Sqrt(m2 / _cnz[i])).ToArray(); }
        }
 
        public Double[] StdDevSample
        {
            get { return _m2.Select((m2, i) => Math.Sqrt(m2 / Math.Max(0, _cnz[i] - 1))).ToArray(); }
        }
 
        public Double[] MeanSquareError
        {
            get { return _m2.Select((m2, i) => m2 / _cnz[i]).ToArray(); }
        }
 
        public Double[] SampleVariance
        {
            get { return _m2.Select((m2, i) => m2 / Math.Max(0, _cnz[i] - 1)).ToArray(); }
        }
 
        public Double[] M2
        {
            get { return _m2; }
        }
 
        public void ProcessValue(in VBuffer<TFloat> value)
        {
            _trainCount++;
            var size = _mean.Length;
            var values = value.GetValues();
            Contracts.Assert(0 <= values.Length && values.Length <= size);
            if (values.Length == 0)
                return;
 
            if (values.Length == size)
            {
                for (int j = 0; j < values.Length; j++)
                {
                    var origVal = values[j];
                    Update(j, origVal);
                }
            }
            else
            {
                var indices = value.GetIndices();
                for (int k = 0; k < values.Length; k++)
                {
                    var origVal = values[k];
                    var j = indices[k];
                    Update(j, origVal);
                }
            }
        }
 
        public void Finish()
        {
            if (!_useLog)
            {
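                // Zeros are skipped by Update, so fold the implicit zero count
                // (rows minus NaNs minus explicit non-zeros) back into the mean and M2.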
                for (int i = 0; i < _mean.Length; i++)
                {
                    Contracts.Assert(_trainCount >= _cnan[i] + _cnz[i]);
                    MeanVarUtils.AdjustForZeros(ref _mean[i], ref _m2[i], ref _cnz[i], _trainCount - _cnan[i] - _cnz[i]);
                }
            }
        }
 
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private void Update(int j, TFloat origVal)
        {
            if (origVal == 0)
                return;
            var val = _useLog ? (TFloat)Math.Log(origVal) : origVal;
            if (!FloatUtils.IsFinite(val))
            {
                if (!_useLog)
                    _cnan[j]++;
                return;
            }
 
            _cnz[j]++;
            var delta = val - _mean[j];
            _mean[j] += delta / _cnz[j];
            var dm2 = delta * (val - _mean[j]);
            Contracts.Assert(dm2 >= 0);
            _m2[j] += dm2;
            Contracts.Assert(_m2[j] >= 0);
        }
    }
 
    [BestFriend]
    internal static partial class MedianAggregatorUtils
    {
        /// <summary>
        /// Based on the algorithm on GeeksForGeeks https://www.geeksforgeeks.org/median-of-stream-of-integers-running-integers/.
        /// </summary>
        /// <param name="num">The new number to account for in our median calculation.</param>
        /// <param name="median">The current median.</param>
        /// <param name="belowMedianHeap">The MaxHeap that has all the numbers below the median.</param>
        /// <param name="aboveMedianHeap">The MinHeap that has all the numbers above the median.</param>
        [BestFriend]
        internal static void GetMedianSoFar(in double num, ref double median, ref MaxHeap<double> belowMedianHeap, ref MinHeap<double> aboveMedianHeap)
        {
            int comparison = belowMedianHeap.Count().CompareTo(aboveMedianHeap.Count());
 
            if (comparison < 0)
            { // More elements in aboveMedianHeap than belowMedianHeap.
                if (num < median)
                { // Current element belongs in the belowMedianHeap.
                    // Insert new number into belowMedianHeap
                    belowMedianHeap.Add(num);
 
                }
                else
                { // Current element belongs in aboveMedianHeap.
                    // Need to move one element over to belowMedianHeap to keep the heaps balanced.
                    belowMedianHeap.Add(aboveMedianHeap.Pop());
 
                    aboveMedianHeap.Add(num);
                }
 
                // Both heaps now hold the same number of elements, so the median is the average of the two roots.
                median = (aboveMedianHeap.Peek() + belowMedianHeap.Peek()) / 2;
 
            }
            else if (comparison == 0)
            { // Both heaps have the same number of elements. Simply put the number where it belongs.
                if (num < median)
                { // Current element belongs in the belowMedianHeap.
                    belowMedianHeap.Add(num);
 
                    // Now we have an odd number of items, median is the new root of the belowMedianHeap
                    median = belowMedianHeap.Peek();
 
                }
                else
                { // Current element belongs in above median heap.
                    aboveMedianHeap.Add(num);
 
                    // Now we have an odd number of items, median is the new root of the aboveMedianHeap
                    median = aboveMedianHeap.Peek();
                }
 
            }
            else
            { // More elements in belowMedianHeap than aboveMedianHeap.
                if (num < median)
                { // Current element belongs in the belowMedianHeap.
                    // Need to move one element over to aboveMedianHeap to keep the heaps balanced.
                    aboveMedianHeap.Add(belowMedianHeap.Pop());
 
                    // Insert new number into belowMedianHeap
                    belowMedianHeap.Add(num);
 
                }
                else
                { // Current element belongs in aboveMedianHeap.
                    aboveMedianHeap.Add(num);
                }
 
                // Both heaps now hold the same number of elements, so the median is the average of the two roots.
                median = (aboveMedianHeap.Peek() + belowMedianHeap.Peek()) / 2;
            }
        }
    }
 
    /// <summary>
    /// Base class for tracking median values for a single valued column.
    /// It tracks median values of non-sparse values (vCount).
    /// NaNs are ignored when updating min and max.
    /// </summary>
    [BestFriend]
    internal sealed class MedianDblAggregator : IColumnAggregator<double>
    {
        private MedianAggregatorUtils.MaxHeap<double> _belowMedianHeap;
        private MedianAggregatorUtils.MinHeap<double> _aboveMedianHeap;
        private double _median;
 
        public MedianDblAggregator(int containerStartingSize = 1000)
        {
            Contracts.Check(containerStartingSize > 0);
            _belowMedianHeap = new MedianAggregatorUtils.MaxHeap<double>(containerStartingSize);
            _aboveMedianHeap = new MedianAggregatorUtils.MinHeap<double>(containerStartingSize);
            _median = default;
        }
 
        public double Median
        {
            get { return _median; }
        }
 
        public void ProcessValue(in double value)
        {
            MedianAggregatorUtils.GetMedianSoFar(value, ref _median, ref _belowMedianHeap, ref _aboveMedianHeap);
        }
 
        public void Finish()
        {
            // Finish is a no-op because we are updating the median continually as we go
        }
    }
 
    internal sealed partial class NormalizeTransform
    {
        internal abstract partial class AffineColumnFunction
        {
            public static IColumnFunction Create(IHost host, TFloat scale, TFloat offset)
            {
                return new Dbl.ImplOne(host, scale, offset);
            }
 
            public static IColumnFunction Create(IHost host, TFloat[] scale, TFloat[] offset, int[] indicesNonZeroOffset)
            {
                return new Dbl.ImplVec(host, scale, offset, indicesNonZeroOffset);
            }
 
            private static class Dbl
            {
                // REVIEW: Should we have separate classes for offset==0 and/or scale==1?
                public sealed class ImplOne : ImplOne<TFloat>
                {
                    public ImplOne(IHost host, TFloat scale, TFloat offset)
                        : base(host, scale, offset)
                    {
                    }
 
                    public static new ImplOne Create(ModelLoadContext ctx, IHost host, DataViewType typeSrc)
                    {
                        host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be Double.");
                        List<int> nz = null;
                        int cfeat;
                        TFloat[] scales;
                        TFloat[] offsets;
                        int[] indices;
                        TFloat[] scalesSparse;
                        TFloat[] offsetsSparse;
 
                        AffineNormSerializationUtils.LoadModel(ctx, ref nz, out cfeat, out scales, out offsets,
                            out indices, out scalesSparse, out offsetsSparse);
                        host.Assert(scales.Length == cfeat);
                        host.Assert(offsets == null || offsets.Length == cfeat);
                        host.Assert(Utils.Size(nz) == 0 || offsets != null);
                        if (cfeat != 1)
                            throw host.Except("Normalizer expected {0} slots, but the input data column has 1 slot.", cfeat);
 
                        return new ImplOne(host, scales[0], (offsets != null) ? offsets[0] : 0);
                    }
 
                    private void GetResult(ref TFloat input, ref TFloat value)
                    {
                        value = (input - Offset) * Scale;
                    }
 
                    private protected override void SaveModel(ModelSaveContext ctx)
                    {
                        AffineNormSerializationUtils.SaveModel(ctx, 1, null, new[] { Scale }, new[] { Offset }, saveText: true);
                    }
 
                    public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
                        => PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale);
 
                    public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
                    {
                        nodeProtoWrapper.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount));
                        nodeProtoWrapper.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount));
                        return true;
                    }
 
                    public override Delegate GetGetter(DataViewRow input, int icol)
                    {
                        var getSrc = input.GetGetter<TFloat>(input.Schema[icol]);
                        ValueGetter<TFloat> del =
                            (ref TFloat dst) =>
                            {
                                getSrc(ref dst);
                                GetResult(ref dst, ref dst);
                            };
                        return del;
                    }
 
                }
 
                // REVIEW: Does it make sense to have 3 separate classes for the 3 cases in GetResult?
                public sealed class ImplVec : ImplVec<TFloat>
                {
                    public ImplVec(IHost host, TFloat[] scale, TFloat[] offset, int[] indicesNonZeroOffset)
                        : base(host, scale, offset, indicesNonZeroOffset)
                    {
                    }
 
                    public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorDataViewType typeSrc)
                    {
                        host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be a vector of Double.");
                        int cv = Math.Max(1, typeSrc.Size);
                        List<int> nz = null;
                        int cfeat;
                        TFloat[] scales;
                        TFloat[] offsets;
                        int[] indices;
                        TFloat[] scalesSparse;
                        TFloat[] offsetsSparse;
 
                        AffineNormSerializationUtils.LoadModel(ctx, ref nz, out cfeat, out scales, out offsets,
                            out indices, out scalesSparse, out offsetsSparse);
                        host.Assert(scales.Length == cfeat);
                        host.Assert(offsets == null || offsets.Length == cfeat);
                        host.Assert(Utils.Size(nz) == 0 || offsets != null);
                        if (cfeat != cv)
                            throw host.Except("Normalizer expected {0} slots, but the input data column has {1} slots.", cfeat, cv);
 
                        return new ImplVec(host, scales, offsets, (offsets != null && nz != null && nz.Count < cv / 2) ? nz.ToArray() : null);
                    }
 
                    private protected override void SaveModel(ModelSaveContext ctx)
                    {
                        AffineNormSerializationUtils.SaveModel(ctx, Scale.Length, null, Scale, Offset, saveText: true);
                    }
 
                    public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
                    {
                        var itemType = PfaUtils.Type.Double;
                        var arrType = PfaUtils.Type.Array(itemType);
                        var cellName = ctx.DeclareCell("AffNormScale", arrType, new JArray(Scale));
                        var scaleCell = PfaUtils.Cell(cellName);
                        if (Offset != null)
                        {
                            cellName = ctx.DeclareCell("AffNormOffset", arrType, new JArray(Offset));
                            var offsetCell = PfaUtils.Cell(cellName);
                            srcToken = PfaUtils.Call("a.zipmap", srcToken, offsetCell, PfaUtils.FuncRef(ctx.Pfa.EnsureSub(itemType)));
                        }
                        return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType)));
                    }
 
                    public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
                    {
                        if (Offset != null)
                            node.AddAttribute("offset", Offset);
                        else
                            node.AddAttribute("offset", Enumerable.Repeat<TFloat>(0, featureCount));
 
                        node.AddAttribute("scale", Scale);
                        return true;
                    }
 
                    public override Delegate GetGetter(DataViewRow input, int icol)
                    {
                        var getSrc = input.GetGetter<VBuffer<TFloat>>(input.Schema[icol]);
                        var bldr = new BufferBuilder<TFloat>(R8Adder.Instance);
                        ValueGetter<VBuffer<TFloat>> del;
                        if (Offset == null)
                        {
                            del = (ref VBuffer<TFloat> dst) =>
                            {
                                getSrc(ref dst);
                                Contracts.Check(dst.Length == Scale.Length);
                                FillValues(in dst, bldr, Scale);
                                bldr.GetResult(ref dst);
                            };
                        }
                        else if (IndicesNonZeroOffset == null)
                        {
                            del = (ref VBuffer<TFloat> dst) =>
                            {
                                getSrc(ref dst);
                                Contracts.Check(dst.Length == Scale.Length);
                                FillValues(in dst, bldr, Scale, Offset);
                                bldr.GetResult(ref dst);
                            };
                        }
                        else
                        {
                            del = (ref VBuffer<TFloat> dst) =>
                            {
                                getSrc(ref dst);
                                Contracts.Check(dst.Length == Scale.Length);
                                FillValues(in dst, bldr, Scale, Offset, IndicesNonZeroOffset);
                                bldr.GetResult(ref dst);
                            };
                        }
 
                        return del;
                    }
 
                    // REVIEW: Change to normalize in place when there are no offsets.
                    private static void FillValues(in VBuffer<TFloat> input, BufferBuilder<TFloat> bldr, TFloat[] scale)
                    {
                        Contracts.Assert(input.Length == scale.Length);
                        int size = scale.Length;
                        var values = input.GetValues();
                        Contracts.Assert(0 <= values.Length && values.Length <= size);
 
                        // We always start with sparse, since we may make things sparser than the source.
                        bldr.Reset(size, dense: false);
                        if (values.Length == 0)
                            return;
 
                        if (values.Length >= size)
                        {
                            for (int i = 0; i < size; i++)
                                bldr.AddFeature(i, values[i] * scale[i]);
                            return;
                        }
 
                        // The input is sparse.
                        var indices = input.GetIndices();
                        for (int ii = 0; ii < values.Length; ii++)
                        {
                            int i = indices[ii];
                            Contracts.Assert(0 <= i && i < size);
                            bldr.AddFeature(i, values[ii] * scale[i]);
                        }
                    }
 
                    private static void FillValues(in VBuffer<TFloat> input, BufferBuilder<TFloat> bldr, TFloat[] scale,
                        TFloat[] offset)
                    {
                        Contracts.Assert(input.Length == scale.Length);
                        int size = scale.Length;
                        var values = input.GetValues();
                        Contracts.Assert(0 <= values.Length && values.Length <= size);
 
                        // We always start with sparse, since we may make things sparser than the source.
                        bldr.Reset(size, dense: false);
 
                        if (values.Length == 0)
                        {
                            for (int i = 0; i < size; i++)
                                bldr.AddFeature(i, -offset[i] * scale[i]);
                            return;
                        }
 
                        if (values.Length >= size)
                        {
                            for (int i = 0; i < size; i++)
                                bldr.AddFeature(i, (values[i] - offset[i]) * scale[i]);
                            return;
                        }
 
                        // The input is sparse.
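                        // Walk every destination slot ivDst, consuming the sparse source entries
                        // (ii, ivSrc) in order: explicit values produce (value - offset) * scale,
                        // implicit zeros produce -offset * scale.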
                        var indices = input.GetIndices();
                        int ii = 0;
                        int ivSrc = indices[ii];
                        Contracts.Assert(ivSrc < size);
                        for (int ivDst = 0; ivDst < size; ivDst++)
                        {
                            Contracts.Assert(ivDst <= ivSrc && ivSrc <= size);
                            if (ivDst == ivSrc)
                            {
                                bldr.AddFeature(ivDst, (values[ii] - offset[ivDst]) * scale[ivDst]);
                                ivSrc = ++ii < values.Length ? indices[ii] : size;
                                Contracts.Assert(ii == values.Length || ivSrc < size);
                            }
                            else
                                bldr.AddFeature(ivDst, -offset[ivDst] * scale[ivDst]);
                        }
                    }
 
                    private static void FillValues(in VBuffer<TFloat> input, BufferBuilder<TFloat> bldr, TFloat[] scale,
                        TFloat[] offset, int[] nz)
                    {
                        Contracts.Assert(input.Length == scale.Length);
 
                        int size = scale.Length;
                        var values = input.GetValues();
                        Contracts.Assert(0 <= values.Length && values.Length <= size);
 
                        // We always start with sparse, since we may make things sparser than the source.
                        bldr.Reset(size, dense: false);
 
                        if (values.Length == 0)
                        {
                            foreach (int i in nz)
                                bldr.AddFeature(i, -offset[i] * scale[i]);
                            return;
                        }
 
                        if (values.Length >= size)
                        {
                            for (int i = 0; i < size; i++)
                                bldr.AddFeature(i, (values[i] - offset[i]) * scale[i]);
                            return;
                        }
 
                        // The input is sparse.
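                        // Merge-join the sparse source indices with the non-zero-offset indices (nz):
                        // an index only in nz contributes -offset * scale, an index only in the
                        // source contributes value * scale, and a shared index contributes
                        // (value - offset) * scale. Indices in neither set map to zero and are skipped.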
                        var indices = input.GetIndices();
                        int ii = 0;
                        int ivSrc = indices[ii];
                        int inz = 0;
                        int ivDst = nz[inz];
                        for (; ; )
                        {
                            Contracts.Assert(0 <= ivDst && ivDst <= size);
                            Contracts.Assert(0 <= ivSrc && ivSrc <= size);
                            Contracts.Assert(ii < values.Length && ivSrc == indices[ii] || ii == values.Length && ivSrc == size);
                            Contracts.Assert(inz < nz.Length && ivDst == nz[inz] || inz == nz.Length && ivDst == size);
 
                            int diff = ivSrc - ivDst;
                            if (diff > 0)
                            {
                                // Offset but no value
                                bldr.AddFeature(ivDst, -offset[ivDst] * scale[ivDst]);
                                ivDst = ++inz < nz.Length ? nz[inz] : size;
                            }
                            else if (diff < 0)
                            {
                                // Value but no offset
                                bldr.AddFeature(ivSrc, values[ii] * scale[ivSrc]);
                                ivSrc = ++ii < values.Length ? indices[ii] : size;
                                Contracts.Assert((ii == values.Length) == (ivSrc >= size));
                            }
                            else
                            {
                                Contracts.Assert(ivSrc == ivDst);
                                if (ivDst >= size)
                                    break;
 
                                bldr.AddFeature(ivDst, (values[ii] - offset[ivDst]) * scale[ivDst]);
                                ivSrc = ++ii < values.Length ? indices[ii] : size;
                                Contracts.Assert((ii == values.Length) == (ivSrc >= size));
                                ivDst = ++inz < nz.Length ? nz[inz] : size;
                                Contracts.Assert((inz == nz.Length) == (ivDst >= size));
                            }
                        }
                        Contracts.Assert(ii == values.Length);
                        Contracts.Assert(inz == nz.Length);
                    }
                }
            }
        }
 
        internal abstract partial class CdfColumnFunction
        {
            public static IColumnFunction Create(IHost host, TFloat mean, TFloat stddev, bool useLog)
            {
                return new Dbl.ImplOne(host, mean, stddev, useLog);
            }
 
            public static IColumnFunction Create(IHost host, TFloat[] mean, TFloat[] stddev, bool useLog)
            {
                return new Dbl.ImplVec(host, mean, stddev, useLog);
            }
 
            private static class Dbl
            {
                public sealed class ImplOne : ImplOne<TFloat>
                {
                    public ImplOne(IHost host, TFloat mean, TFloat stddev, bool useLog)
                        : base(host, mean, stddev, useLog)
                    {
                    }
 
                    public static new ImplOne Create(ModelLoadContext ctx, IHost host, DataViewType typeSrc)
                    {
                        host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be Double.");
                        host.CheckValue(ctx, nameof(ctx));
                        ctx.CheckAtModel(GetVersionInfo());
 
                        bool useLog;
                        TFloat[] mean;
                        TFloat[] stddev;
                        CdfNormSerializationUtils.LoadModel(ctx, 1, out useLog, out mean, out stddev);
 
                        return new ImplOne(host, mean[0], stddev[0], useLog);
                    }
 
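                    // Maps the (optionally log-transformed) input through the Gaussian CDF with
                    // the learned mean and standard deviation, producing a value in [0, 1].
                    // Non-finite intermediates (e.g. the log of a non-positive input) map to zero.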
                    private void GetResult(ref TFloat input, ref TFloat value)
                    {
                        var val = UseLog ? (TFloat)Math.Log(input) : input;
                        if (!FloatUtils.IsFinite(val))
                        {
                            value = 0;
                            return;
                        }
 
                        value = CdfUtils.Cdf(val, Mean, Stddev);
                    }
 
                    private protected override void SaveModel(ModelSaveContext ctx)
                    {
                        Contracts.AssertValue(ctx);
                        ctx.CheckAtModel();
                        ctx.SetVersionInfo(GetVersionInfo());
 
                        CdfNormSerializationUtils.SaveModel(ctx, UseLog, new[] { Mean }, new[] { Stddev });
                    }
 
                    public override Delegate GetGetter(DataViewRow input, int icol)
                    {
                        if (Stddev <= TFloat.Epsilon)
                        {
                            ValueGetter<TFloat> trivial =
                                (ref TFloat dst) =>
                                {
                                    dst = 0;
                                };
                            return trivial;
                        }
 
                        var getSrc = input.GetGetter<TFloat>(input.Schema[icol]);
                        ValueGetter<TFloat> del =
                            (ref TFloat dst) =>
                            {
                                getSrc(ref dst);
                                GetResult(ref dst, ref dst);
                            };
                        return del;
                    }
                }
 
                public sealed class ImplVec : ImplVec<TFloat>
                {
                    public ImplVec(IHost host, TFloat[] mean, TFloat[] stddev, bool useLog)
                        : base(host, mean, stddev, useLog)
                    {
                    }
 
                    public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorDataViewType typeSrc)
                    {
                        host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be a vector of Double.");
                        int cv = Math.Max(1, typeSrc.Size);
 
                        host.CheckValue(ctx, nameof(ctx));
                        ctx.CheckAtModel(GetVersionInfo());
 
                        bool useLog;
                        TFloat[] mean;
                        TFloat[] stddev;
                        CdfNormSerializationUtils.LoadModel(ctx, cv, out useLog, out mean, out stddev);
 
                        return new ImplVec(host, mean, stddev, useLog);
                    }
 
                    private protected override void SaveModel(ModelSaveContext ctx)
                    {
                        Contracts.AssertValue(ctx);
                        ctx.CheckAtModel();
                        ctx.SetVersionInfo(GetVersionInfo());
 
                        CdfNormSerializationUtils.SaveModel(ctx, UseLog, Mean, Stddev);
                    }
 
                    public override Delegate GetGetter(DataViewRow input, int icol)
                    {
                        var getSrc = input.GetGetter<VBuffer<TFloat>>(input.Schema[icol]);
                        var bldr = new BufferBuilder<TFloat>(R8Adder.Instance);
                        ValueGetter<VBuffer<TFloat>> del;
                        del = (ref VBuffer<TFloat> dst) =>
                        {
                            getSrc(ref dst);
                            Host.Check(dst.Length == Mean.Length);
                            FillValues(in dst, bldr, Mean, Stddev, UseLog);
                            bldr.GetResult(ref dst);
                        };
 
                        return del;
                    }
 
                    private static void FillValues(in VBuffer<TFloat> input, BufferBuilder<TFloat> bldr, TFloat[] mean,
                        TFloat[] stddev, bool useLog)
                    {
                        Contracts.Assert(input.Length == mean.Length);
                        int size = mean.Length;
                        var values = input.GetValues();
                        Contracts.Assert(0 <= values.Length && values.Length <= size);
 
                        // We always start with sparse, since we may make things sparser than the source.
                        bldr.Reset(size, dense: false);
 
                        if (values.Length == 0)
                            return;
 
                        if (values.Length >= size)
                        {
                            for (int i = 0; i < size; i++)
                            {
                                var sigma = stddev[i];
                                if (sigma > TFloat.Epsilon)
                                {
                                    var val = useLog ? (TFloat)Math.Log(values[i]) : values[i];
                                    if (FloatUtils.IsFinite(val))
                                        bldr.AddFeature(i, CdfUtils.Cdf(val, mean[i], sigma));
                                }
                            }
                            return;
                        }
 
                        // The input is sparse.
                        var indices = input.GetIndices();
                        for (int ii = 0; ii < values.Length; ii++)
                        {
                            var ivDst = indices[ii];
                            var sigma = stddev[ivDst];
                            if (sigma > TFloat.Epsilon)
                            {
                                var val = useLog ? (TFloat)Math.Log(values[ii]) : values[ii];
                                if (FloatUtils.IsFinite(val))
                                    bldr.AddFeature(ivDst, CdfUtils.Cdf(val, mean[ivDst], sigma));
                            }
                        }
                    }
                }
            }
        }
 
        internal abstract partial class BinColumnFunction
        {
            public static IColumnFunction Create(IHost host, TFloat[] binUpperBounds, bool fixZero)
            {
                return new Dbl.ImplOne(host, binUpperBounds, fixZero);
            }
 
            public static IColumnFunction Create(IHost host, TFloat[][] binUpperBounds, bool fixZero)
            {
                return new Dbl.ImplVec(host, binUpperBounds, fixZero);
            }
 
            private static class Dbl
            {
                public sealed class ImplOne : BinColumnFunction
                {
                    private readonly TFloat[] _binUpperBounds;
                    private readonly TFloat _den;
                    private readonly TFloat _offset;
 
                    public ImplOne(IHost host, TFloat[] binUpperBounds, bool fixZero)
                        : base(host)
                    {
                        _binUpperBounds = binUpperBounds;
                        _den = Math.Max(1, _binUpperBounds.Length - 1);
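                        // With fixZero, shift the output so that an input of zero maps to zero:
                        // the offset is the bin index of zero, normalized by the denominator.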
                        if (fixZero)
                            _offset = _binUpperBounds.FindIndexSorted(0) / _den;
                        Host.Assert(0 <= _offset && _offset <= 1);
                    }
 
                    public static new ImplOne Create(ModelLoadContext ctx, IHost host, DataViewType typeSrc)
                    {
                        host.Check(typeSrc.RawType == typeof(TFloat), "The column type must be Double.");
                        host.CheckValue(ctx, nameof(ctx));
                        ctx.CheckAtModel(GetVersionInfo());
 
                        // *** Binary format ***
                        // Byte: fixZero bool
                        bool fixZero = ctx.Reader.ReadBoolByte();
 
                        TFloat[][] binUpperBounds = null;
                        if (!ctx.TryProcessSubModel("BinNormalizer",
                            c => BinNormSerializationUtils.LoadModel(c, out binUpperBounds)))
                        {
                            throw host.ExceptDecode();
                        }
                        if (binUpperBounds.Length != 1)
                            throw host.Except("Normalizer expected {0} slots, but the input data column has 1 slot.", binUpperBounds.Length);
 
                        return new ImplOne(host, binUpperBounds[0], fixZero);
                    }
 
                    private protected override void SaveModel(ModelSaveContext ctx)
                    {
                        Contracts.AssertValue(ctx);
                        ctx.CheckAtModel();
                        ctx.SetVersionInfo(GetVersionInfo());
 
                        // *** Binary format ***
                        // Byte: fixZero bool
                        ctx.Writer.WriteBoolByte(_offset != 0);
 
                        ctx.SaveSubModel("BinNormalizer",
                            c => BinNormSerializationUtils.SaveModel(c, new[] { _binUpperBounds }, saveText: true));
                    }
 
                    public override Delegate GetGetter(DataViewRow input, int icol)
                    {
                        var getSrc = input.GetGetter<TFloat>(input.Schema[icol]);
                        ValueGetter<TFloat> del =
                            (ref TFloat dst) =>
                            {
                                getSrc(ref dst);
                                GetResult(ref dst, ref dst);
                            };
                        return del;
                    }
 
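                    // Bins the input and rescales the bin index to [0, 1] (dividing by the bin
                    // count minus one), then subtracts the fixZero offset.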
                    private void GetResult(ref TFloat input, ref TFloat value)
                    {
                        value = BinUtils.GetValue(in input, _binUpperBounds, _den, _offset);
                    }
 
                    public override NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams()
                         => new NormalizingTransformer.BinNormalizerModelParameters<TFloat>(ImmutableArray.Create(_binUpperBounds), _den, _offset);
                }
 
                public sealed class ImplVec : BinColumnFunction
                {
                    private readonly TFloat[][] _binUpperBounds;
                    private readonly TFloat[] _den;
                    private readonly TFloat[] _offset;
 
                    public ImplVec(IHost host, TFloat[][] binUpperBounds, bool fixZero)
                        : base(host)
                    {
                        _binUpperBounds = binUpperBounds;
                        _den = new TFloat[_binUpperBounds.Length];
                        for (int i = 0; i < _binUpperBounds.Length; i++)
                            _den[i] = Math.Max(1, _binUpperBounds[i].Length - 1);
                        if (fixZero)
                        {
                            _offset = new TFloat[_binUpperBounds.Length];
                            bool any = false;
                            for (int i = 0; i < _binUpperBounds.Length; i++)
                            {
                                _offset[i] = _binUpperBounds[i].FindIndexSorted(0) / _den[i];
                                Host.Assert(0 <= _offset[i] && _offset[i] <= 1);
                                any |= _offset[i] != 0;
                            }
                            if (!any)
                                _offset = null;
                        }
                    }
 
                    public static ImplVec Create(ModelLoadContext ctx, IHost host, VectorDataViewType typeSrc)
                    {
                        host.Check(typeSrc.ItemType.RawType == typeof(TFloat), "The column type must be a vector of Double.");
                        int cv = Math.Max(1, typeSrc.Size);
                        host.CheckValue(ctx, nameof(ctx));
                        ctx.CheckAtModel(GetVersionInfo());
 
                        // *** Binary format ***
                        // Byte: fixZero bool
                        bool fixZero = ctx.Reader.ReadBoolByte();
 
                        TFloat[][] binUpperBounds = null;
                        if (!ctx.TryProcessSubModel("BinNormalizer",
                            c => BinNormSerializationUtils.LoadModel(c, out binUpperBounds)))
                        {
                            throw host.ExceptDecode();
                        }
                        if (binUpperBounds.Length != cv)
                            throw host.Except("Normalizer expected {0} slots, but the input data column has {1} slots.", binUpperBounds.Length, cv);
 
                        return new ImplVec(host, binUpperBounds, fixZero);
                    }
 
                    private protected override void SaveModel(ModelSaveContext ctx)
                    {
                        Contracts.AssertValue(ctx);
                        ctx.CheckAtModel();
                        ctx.SetVersionInfo(GetVersionInfo());
 
                        // *** Binary format ***
                        // Byte: fixZero bool
                        ctx.Writer.WriteBoolByte(_offset != null);
 
                        ctx.SaveSubModel("BinNormalizer", c => BinNormSerializationUtils.SaveModel(c, _binUpperBounds, saveText: true));
                    }
 
                    public override Delegate GetGetter(DataViewRow input, int icol)
                    {
                        var getSrc = input.GetGetter<VBuffer<TFloat>>(input.Schema[icol]);
                        var bldr = new BufferBuilder<TFloat>(R8Adder.Instance);
                        ValueGetter<VBuffer<TFloat>> del =
                            (ref VBuffer<TFloat> dst) =>
                            {
                                getSrc(ref dst);
                                Host.Check(dst.Length == _binUpperBounds.Length);
                                GetResult(in dst, ref dst, bldr);
                            };
                        return del;
                    }
 
                    private void GetResult(in VBuffer<TFloat> input, ref VBuffer<TFloat> value, BufferBuilder<TFloat> bldr)
                    {
                        Contracts.Assert(input.Length == _binUpperBounds.Length);
                        int size = _binUpperBounds.Length;
                        var values = input.GetValues();
                        Contracts.Assert(0 <= values.Length && values.Length <= size);
 
                        // We always start with sparse, since we may make things sparser than the source.
                        bldr.Reset(size, dense: false);
                        if (values.Length == 0)
                        {
                            bldr.GetResult(ref value);
                            return;
                        }
 
                        if (values.Length >= size)
                        {
                            if (_offset != null)
                            {
                                for (int i = 0; i < size; i++)
                                    bldr.AddFeature(i, BinUtils.GetValue(in values[i], _binUpperBounds[i], _den[i], _offset[i]));
                            }
                            else
                            {
                                for (int i = 0; i < size; i++)
                                    bldr.AddFeature(i, BinUtils.GetValue(in values[i], _binUpperBounds[i], _den[i]));
                            }
                            bldr.GetResult(ref value);
                            return;
                        }
 
                        // The input is sparse.
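                        // When an offset is present, implicit zeros cannot simply be skipped: the
                        // offset was computed with an unclamped bin search, so bin(0) / den - offset
                        // is not guaranteed to be exactly zero in every slot, and each implicit zero
                        // is pushed through the same bin-and-offset mapping as an explicit value.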
                        if (_offset != null)
                        {
                            var indices = input.GetIndices();
                            int ii = 0;
                            int ivSrc = indices[ii];
                            Contracts.Assert(ivSrc < size);
                            TFloat zero = 0;
                            for (int ivDst = 0; ivDst < size; ivDst++)
                            {
                                Contracts.Assert(ivDst <= ivSrc && ivSrc <= size);
                                if (ivDst == ivSrc)
                                {
                                    bldr.AddFeature(ivDst,
                                        BinUtils.GetValue(in values[ii], _binUpperBounds[ivDst], _den[ivDst], _offset[ivDst]));
                                    ivSrc = ++ii < values.Length ? indices[ii] : size;
                                    Contracts.Assert(ii == values.Length || ivSrc < size);
                                }
                                else
                                    bldr.AddFeature(ivDst,
                                        BinUtils.GetValue(in zero, _binUpperBounds[ivDst], _den[ivDst], _offset[ivDst]));
                            }
                        }
                        else
                        {
                            var indices = input.GetIndices();
                            for (int ii = 0; ii < values.Length; ii++)
                            {
                                int i = indices[ii];
                                Contracts.Assert(0 <= i && i < size);
                                bldr.AddFeature(i, BinUtils.GetValue(in values[ii], _binUpperBounds[i], _den[i]));
                            }
                        }
 
                        bldr.GetResult(ref value);
                    }
 
                    public override NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams()
                          => new NormalizingTransformer.BinNormalizerModelParameters<ImmutableArray<TFloat>>(_binUpperBounds.Select(b => ImmutableArray.Create(b)).ToImmutableArray(),
                              ImmutableArray.Create(_den),
                              ImmutableArray.Create(_offset));
                }
            }
        }
 
        internal static partial class MinMaxUtils
        {
            public static void ComputeScaleAndOffset(bool fixZero, TFloat max, TFloat min, out TFloat scale, out TFloat offset)
            {
                if (fixZero)
                    ComputeScaleAndOffsetFixZero(max, min, out scale, out offset);
                else
                    ComputeScaleAndOffset(max, min, out scale, out offset);
            }
 
            private static void ComputeScaleAndOffset(TFloat max, TFloat min, out TFloat scale, out TFloat offset)
            {
                Contracts.Assert(!TFloat.IsNaN(min));
                Contracts.Assert(!TFloat.IsNaN(max));
 
                // If the column has only NaNs, or has no rows at all, then min==infinity and max==-infinity. In all
                // other cases, min<=max.
                Contracts.Assert(min <= max || (TFloat.IsPositiveInfinity(min) && TFloat.IsNegativeInfinity(max)));
 
                // In the case where max <= min, the slot contains no useful information (since it is either constant, or
                // is all NaNs, or has no rows), so we force it to zero.
                // Note that setting scale to zero effectively maps finite values to zero,
                // but infinities and NaN to NaN.
                // REVIEW: If min <= 0 and max >= 0, then why not fix zero for this slot and simply scale by 1 / max(abs(..))?
                // We could even be more aggressive about it, and fix zero if 0 < min < max <= 2 * min.
                // Then the common case where features are in the range [1, N] (and integer valued) wouldn't subtract 1 every time....
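                // Illustrative example: min = 2, max = 6 give scale = 1 / 4 and offset = 2, so the
                // resulting affine map x => (x - 2) * 0.25 sends [2, 6] onto [0, 1].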
                if (!(max > min))
                    scale = offset = 0;
                else if ((scale = 1 / (max - min)) == 0)
                    offset = 0;
                else
                    offset = min;
                Contracts.Assert(0 <= scale && scale < TFloat.PositiveInfinity);
            }
 
            private static void ComputeScaleAndOffsetFixZero(TFloat max, TFloat min, out TFloat scale, out TFloat offset)
            {
                Contracts.Assert(!TFloat.IsNaN(min));
                Contracts.Assert(!TFloat.IsNaN(max));
 
                // If the column has only NaNs, or has no rows at all, then min==infinity and max==-infinity. In all
                // other cases, min<=max.
                Contracts.Assert(min <= max || (TFloat.IsPositiveInfinity(min) && TFloat.IsNegativeInfinity(max)));
 
                // In the case where max <= min, the slot contains no useful information (since it is either constant, or
                // is all NaNs, or has no rows), so we force it to zero.
                // Note that setting scale to zero effectively maps finite values to zero,
                // but infinities and NaN to NaN.
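                // Illustrative example: min = -2, max = 6 give scale = 1 / 6, which maps [-2, 6]
                // into [-1/3, 1] while keeping zero exactly at zero.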
                offset = 0;
                if (!(max > min))
                    scale = 0;
                else
                    scale = 1 / Math.Max(Math.Abs(max), Math.Abs(min));
                Contracts.Assert(0 <= scale && scale < TFloat.PositiveInfinity);
            }
        }
 
        internal static partial class MeanVarUtils
        {
            public static void ComputeScaleAndOffset(Double mean, Double stddev, out TFloat scale, out TFloat offset)
            {
                Contracts.Assert(!Double.IsNaN(mean));
                Contracts.Assert(stddev >= 0);
 
                // In the case where stdev==0, the slot contains no useful information (since it is constant),
                // so we force it to zero. Note that setting scale to zero effectively maps finite values to zero,
                // but infinities and NaN to NaN.
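                // Illustrative example: mean = 10, stddev = 2 give scale = 0.5 and offset = 10, so the
                // affine map x => (x - 10) * 0.5 standardizes the slot to zero mean and unit variance.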
                if (stddev == 0)
                    scale = offset = 0;
                else if ((scale = 1 / (TFloat)stddev) == 0)
                    offset = 0;
                else
                    offset = (TFloat)mean;
                Contracts.Assert(0 <= scale && scale < TFloat.PositiveInfinity);
            }
 
            public static void ComputeScaleAndOffsetFixZero(Double mean, Double meanSquaredError, out TFloat scale, out TFloat offset)
            {
                Contracts.Assert(!Double.IsNaN(mean));
                Contracts.Assert(meanSquaredError >= 0);
 
                // In the case where the variance (meanSquaredError) is zero, the slot contains no useful information (since it is constant),
                // so we force it to zero. Note that setting scale to zero effectively maps finite values to zero,
                // but infinities and NaN to NaN.
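                // When meanSquaredError is the population variance E[(x - mean)^2], the denominator
                // sqrt(meanSquaredError + mean^2) equals sqrt(E[x^2]), the root mean square, so the
                // scaled values have unit second moment while zero stays fixed.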
                offset = 0;
                if (meanSquaredError == 0)
                    scale = 0;
                else
                    scale = 1 / (TFloat)Math.Sqrt(meanSquaredError + mean * mean);
                Contracts.Assert(0 <= scale && scale < TFloat.PositiveInfinity);
            }
        }
 
        private static partial class CdfUtils
        {
            public static TFloat Cdf(TFloat input, TFloat mean, TFloat stddev)
            {
                // REVIEW: This should be changed to call the AML stats library.
                // Temporarily, it does the following:
                // Using CDF(x) = 0.5 ( 1 + erf( ( x - mu ) / ( sigma * sqrt(2) ) ) )
                // Also using an approximation for erf(x) from https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions
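                // With a = 0.147 this is Winitzki's elementary approximation of erf, accurate to
                // roughly 1e-4 in absolute error. Note that x2 below is the squared argument of
                // erf, i.e. ((input - mean) / (stddev * sqrt(2)))^2.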
                var x = (input - mean) / stddev;
                var x2 = x * x / 2;
                const TFloat a = (TFloat)0.147;
                var ax2 = a * x2;
                return (TFloat)(0.5 + 0.5 * Math.Sign(x) * Math.Sqrt(1 - Math.Exp(-x2 * (4 / Math.PI + ax2) / (1 + ax2))));
            }
        }
 
        internal static partial class BinUtils
        {
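            // Both overloads map the input to its bin via binary search over binUpperBounds and
            // rescale the bin index by den so outputs land in [0, 1]; the offset overload then
            // shifts by the binned value of zero so that an input of zero maps to zero.
            // Illustrative example: binUpperBounds = { 1, 3, PositiveInfinity } and den = 2 give
            // input 0.5 -> bin 0 -> 0.0, input 2 -> bin 1 -> 0.5, input 7 -> bin 2 -> 1.0.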
            public static TFloat GetValue(in TFloat input, TFloat[] binUpperBounds, TFloat den, TFloat offset)
            {
                if (TFloat.IsNaN(input))
                    return input;
                int binIdx = binUpperBounds.FindIndexSorted(0, binUpperBounds.Length - 1, input);
                Contracts.Check(binIdx < binUpperBounds.Length);
                var value = binIdx / den - offset;
                Contracts.Assert(-1 <= value && value <= 1);
                return value;
            }
 
            public static TFloat GetValue(in TFloat input, TFloat[] binUpperBounds, TFloat den)
            {
                if (TFloat.IsNaN(input))
                    return input;
                int binIdx = binUpperBounds.FindIndexSorted(0, binUpperBounds.Length - 1, input);
                Contracts.Check(binIdx < binUpperBounds.Length);
                var value = binIdx / den;
                Contracts.Assert(0 <= value && value <= 1);
                return value;
            }
        }
 
        private static class Dbl
        {
            public abstract class MinMaxOneColumnFunctionBuilderBase : OneColumnFunctionBuilderBase<TFloat>
            {
                protected readonly bool Fix;
                protected readonly MinMaxDblAggregator Aggregator;
                private VBuffer<TFloat> _buffer;
 
                protected MinMaxOneColumnFunctionBuilderBase(IHost host, long lim, bool fix, ValueGetter<TFloat> getSrc)
                    : base(host, lim, getSrc)
                {
                    Fix = fix;
                    Aggregator = new MinMaxDblAggregator(1);
                    _buffer = new VBuffer<TFloat>(1, new TFloat[1]);
                }
 
                protected override bool ProcessValue(in TFloat val)
                {
                    if (!base.ProcessValue(in val))
                        return false;
                    VBufferEditor.CreateFromBuffer(ref _buffer).Values[0] = val;
                    Aggregator.ProcessValue(in _buffer);
                    return true;
                }
            }
 
            public sealed class MinMaxOneColumnFunctionBuilder : MinMaxOneColumnFunctionBuilderBase
            {
                private MinMaxOneColumnFunctionBuilder(IHost host, long lim, bool fix, ValueGetter<TFloat> getSrc)
                    : base(host, lim, fix, getSrc)
                {
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.MinMaxColumnOptions column, IHost host, DataViewType srcType,
                    ValueGetter<TFloat> getter)
                {
                    host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    return new MinMaxOneColumnFunctionBuilder(host, column.MaximumExampleCount, column.EnsureZeroUntouched, getter);
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    Aggregator.Finish();
                    TFloat scale;
                    TFloat offset;
                    MinMaxUtils.ComputeScaleAndOffset(Fix, Aggregator.Max[0], Aggregator.Min[0], out scale, out offset);
 
                    return AffineColumnFunction.Create(Host, scale, offset);
                }
            }
 
            public abstract class MinMaxVecColumnFunctionBuilderBase : VecColumnFunctionBuilderBase<TFloat>
            {
                protected readonly MinMaxDblAggregator Aggregator;
                protected readonly bool Fix;
 
                protected MinMaxVecColumnFunctionBuilderBase(IHost host, int cv, long lim, bool fix, ValueGetter<VBuffer<TFloat>> getSrc)
                    : base(host, lim, getSrc)
                {
                    Fix = fix;
                    Aggregator = new MinMaxDblAggregator(cv);
                }
 
                protected override bool ProcessValue(in VBuffer<TFloat> buffer)
                {
                    if (!base.ProcessValue(in buffer))
                        return false;
                    var size = Aggregator.Min.Length;
                    if (buffer.Length != size)
                        throw Host.Except("Normalizer expected {0} slots but got {1}", size, buffer.Length);
                    Aggregator.ProcessValue(in buffer);
                    return true;
                }
            }
 
            public sealed class MinMaxVecColumnFunctionBuilder : MinMaxVecColumnFunctionBuilderBase
            {
                private MinMaxVecColumnFunctionBuilder(IHost host, int cv, long lim, bool fix,
                    ValueGetter<VBuffer<TFloat>> getSrc)
                    : base(host, cv, lim, fix, getSrc)
                {
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.MinMaxColumnOptions column, IHost host, VectorDataViewType srcType,
                    ValueGetter<VBuffer<TFloat>> getter)
                {
                    host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    var cv = srcType.Size;
                    return new MinMaxVecColumnFunctionBuilder(host, cv, column.MaximumExampleCount, column.EnsureZeroUntouched, getter);
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    Aggregator.Finish();
                    var cv = Aggregator.Min.Length;
                    // These are ignored if fix is true.
                    int lim = cv / 2;
                    var nz = new List<int>();
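                    // nz collects the slots whose offset turns out non-zero; if strictly fewer than
                    // lim (half the slots) end up in it, the offsets are stored sparsely via
                    // indicesNonZeroOffset below, otherwise they are kept dense.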
 
                    for (int i = 0; i < cv; i++)
                    {
                        MinMaxUtils.ComputeScaleAndOffset(Fix, Aggregator.Max[i], Aggregator.Min[i], out Aggregator.Max[i], out Aggregator.Min[i]);
                        if (Aggregator.Min[i] != 0 && nz.Count < lim)
                            nz.Add(i);
                    }
 
                    var min = Aggregator.Min;
                    // Note: There is a special case when cv == 1. In this case lim == 0, so nz will be empty regardless
                    // of whether the offset is non-zero.
                    Host.Assert((lim == 0) == (cv == 1));
                    int[] indicesNonZeroOffset = null;
                    if (Fix)
                        min = null;
                    else if (cv == 1)
                    {
                        if (min[0] == 0)
                            min = null;
                    }
                    else if (nz.Count == 0)
                        min = null;
                    else if (nz.Count < lim)
                        indicesNonZeroOffset = nz.ToArray();
 
                    return AffineColumnFunction.Create(Host, Aggregator.Max, min, indicesNonZeroOffset);
                }
            }
 
            public sealed class MeanVarOneColumnFunctionBuilder : OneColumnFunctionBuilderBase<TFloat>
            {
                private readonly bool _useLog;
                private readonly bool _useCdf;
                private readonly bool _fix;
                private readonly bool _useSampleVariance;
                private readonly MeanVarDblAggregator _aggregator;
                private VBuffer<TFloat> _buffer;
 
                private MeanVarOneColumnFunctionBuilder(IHost host, long lim, bool fix, ValueGetter<TFloat> getSrc, bool useLog, bool useCdf, bool useSampleVariance)
                    : base(host, lim, getSrc)
                {
                    _useLog = useLog;
                    _useCdf = useCdf;
                    _fix = fix;
                    _useSampleVariance = useSampleVariance;
                    _aggregator = new MeanVarDblAggregator(1, useLog);
                    _buffer = new VBuffer<TFloat>(1, new TFloat[1]);
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.MeanVarianceColumnOptions column, IHost host, DataViewType srcType,
                    ValueGetter<TFloat> getter)
                {
                    host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    return new MeanVarOneColumnFunctionBuilder(host, column.MaximumExampleCount, column.EnsureZeroUntouched, getter, false, column.UseCdf, column.UseSampleVariance);
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.LogMeanVarianceColumnOptions column, IHost host, DataViewType srcType,
                    ValueGetter<TFloat> getter)
                {
                    var lim = column.MaximumExampleCount;
                    host.CheckUserArg(lim > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    return new MeanVarOneColumnFunctionBuilder(host, lim, false, getter, true, column.UseCdf, column.UseSampleVariance);
                }
 
                protected override bool ProcessValue(in TFloat origVal)
                {
                    if (!base.ProcessValue(in origVal))
                        return false;
                    VBufferEditor.CreateFromBuffer(ref _buffer).Values[0] = origVal;
                    _aggregator.ProcessValue(in _buffer);
                    return true;
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    _aggregator.Finish();
                    if (_useCdf)
                        return CreateCdfColumnFunction();
                    return CreateAffineColumnFunction();
                }
 
                private IColumnFunction CreateAffineColumnFunction()
                {
                    Contracts.Assert(_aggregator.M2[0] >= 0);
                    if (_aggregator.M2[0] == 0)
                        return AffineColumnFunction.Create(Host, (TFloat)0, (TFloat)0);
                    TFloat scale;
                    TFloat offset;
                    var stdDev = _useSampleVariance ? _aggregator.StdDevSample[0] : _aggregator.StdDevPopulation[0];
                    var variance = _useSampleVariance ? _aggregator.SampleVariance[0] : _aggregator.MeanSquareError[0];
 
                    if (_fix)
                        MeanVarUtils.ComputeScaleAndOffsetFixZero(_aggregator.Mean[0], variance, out scale, out offset);
                    else
                        MeanVarUtils.ComputeScaleAndOffset(_aggregator.Mean[0], stdDev, out scale, out offset);
 
                    return AffineColumnFunction.Create(Host, scale, offset);
                }
 
                private IColumnFunction CreateCdfColumnFunction()
                {
                    Contracts.Assert(_aggregator.M2[0] >= 0);
                    if (_aggregator.M2[0] == 0 || _aggregator.Counts[0] == 0)
                        return CdfColumnFunction.Create(Host, (TFloat)0, (TFloat)0, _useLog);
 
                    var stdDev = _useSampleVariance ? _aggregator.StdDevSample[0] : _aggregator.StdDevPopulation[0];
 
                    return CdfColumnFunction.Create(Host, (TFloat)_aggregator.Mean[0], (TFloat)stdDev, _useLog);
                }
            }
 
            public sealed class MeanVarVecColumnFunctionBuilder : VecColumnFunctionBuilderBase<TFloat>
            {
                private readonly bool _fix;
                private readonly bool _useLog;
                private readonly bool _useCdf;
                private readonly bool _useSampleVariance;
                private readonly MeanVarDblAggregator _aggregator;
 
                private MeanVarVecColumnFunctionBuilder(IHost host, int cv, long lim, bool fix,
                    ValueGetter<VBuffer<TFloat>> getSrc, bool useLog, bool useCdf, bool useSampleVariance)
                    : base(host, lim, getSrc)
                {
                    _aggregator = new MeanVarDblAggregator(cv, useLog);
                    _fix = fix;
                    _useLog = useLog;
                    _useCdf = useCdf;
                    _useSampleVariance = useSampleVariance;
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.MeanVarianceColumnOptions column, IHost host, VectorDataViewType srcType,
                    ValueGetter<VBuffer<TFloat>> getter)
                {
                    host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    var cv = srcType.Size;
                    return new MeanVarVecColumnFunctionBuilder(host, cv, column.MaximumExampleCount, column.EnsureZeroUntouched, getter, false, column.UseCdf, column.UseSampleVariance);
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.LogMeanVarianceColumnOptions column, IHost host, VectorDataViewType srcType,
                    ValueGetter<VBuffer<TFloat>> getter)
                {
                    var lim = column.MaximumExampleCount;
                    host.CheckUserArg(lim > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    var cv = srcType.Size;
                    return new MeanVarVecColumnFunctionBuilder(host, cv, lim, false, getter, true, column.UseCdf, column.UseSampleVariance);
                }
 
                protected override bool ProcessValue(in VBuffer<TFloat> buffer)
                {
                    if (!base.ProcessValue(in buffer))
                        return false;
 
                    _aggregator.ProcessValue(in buffer);
                    return true;
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    _aggregator.Finish();
                    if (_useCdf)
                        return CreateCdfColumnFunction();
                    return CreateAffineColumnFunction();
                }
 
                private IColumnFunction CreateAffineColumnFunction()
                {
                    int cv = _aggregator.Mean.Length;
                    // These are ignored if fix is true.
                    int lim = cv / 2;
                    var nz = new List<int>();
 
                    var scale = new TFloat[cv];
                    var offset = new TFloat[cv];
 
                    for (int i = 0; i < cv; i++)
                    {
                        Contracts.Assert(_aggregator.M2[i] >= 0);
                        if (_aggregator.M2[i] == 0)
                        {
                            scale[i] = offset[i] = 0;
                            continue;
                        }
 
                        var stdDev = _useSampleVariance ? _aggregator.StdDevSample[i] : _aggregator.StdDevPopulation[i];
                        var variance = _useSampleVariance ? _aggregator.SampleVariance[i] : _aggregator.MeanSquareError[i];
 
                        if (_fix)
                            MeanVarUtils.ComputeScaleAndOffsetFixZero(_aggregator.Mean[i], variance, out scale[i], out offset[i]);
                        else
                            MeanVarUtils.ComputeScaleAndOffset(_aggregator.Mean[i], stdDev, out scale[i], out offset[i]);
                        if (offset[i] != 0 && nz.Count < lim)
                            nz.Add(i);
                    }
 
                    // Note: There is a special case when cv == 1. In this case lim == 0, so nz will be empty regardless
                    // of whether the offset is non-zero.
                    Host.Assert((lim == 0) == (cv == 1));
                    int[] indicesNonZeroOffset = null;
                    if (_fix)
                        offset = null;
                    else if (cv == 1)
                    {
                        if (offset[0] == 0)
                            offset = null;
                    }
                    else if (nz.Count == 0)
                        offset = null;
                    else if (nz.Count < lim)
                        indicesNonZeroOffset = nz.ToArray();
 
                    return AffineColumnFunction.Create(Host, scale, offset, indicesNonZeroOffset);
                }
 
                private IColumnFunction CreateCdfColumnFunction()
                {
                    int cv = _aggregator.Mean.Length;
 
                    var mean = new TFloat[cv];
                    var stddev = new TFloat[cv];
 
                    for (int i = 0; i < cv; i++)
                    {
                        Contracts.Assert(_aggregator.M2[i] >= 0);
                        if (_aggregator.M2[i] == 0 || _aggregator.Counts[i] == 0)
                        {
                            mean[i] = stddev[i] = 0;
                            continue;
                        }
                        mean[i] = (TFloat)_aggregator.Mean[i];
                        stddev[i] = (TFloat)(_useSampleVariance ? _aggregator.StdDevSample[i] : _aggregator.StdDevPopulation[i]);
                    }
 
                    return CdfColumnFunction.Create(Host, mean, stddev, _useLog);
                }
            }
 
            public sealed class BinOneColumnFunctionBuilder : OneColumnFunctionBuilderBase<TFloat>
            {
                private readonly bool _fix;
                private readonly int _numBins;
                private readonly List<TFloat> _values;
 
                private BinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, ValueGetter<TFloat> getSrc)
                    : base(host, lim, getSrc)
                {
                    _fix = fix;
                    _numBins = numBins;
                    _values = new List<TFloat>();
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.BinningColumnOptions column, IHost host, DataViewType srcType,
                    ValueGetter<TFloat> getter)
                {
                    var lim = column.MaximumExampleCount;
                    host.CheckUserArg(lim > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    bool fix = column.EnsureZeroUntouched;
                    var numBins = column.MaximumBinCount;
                    host.CheckUserArg(numBins > 1, nameof(column.MaximumBinCount), "Must be greater than 1");
                    return new BinOneColumnFunctionBuilder(host, lim, fix, numBins, getter);
                }
 
                protected override bool ProcessValue(in TFloat val)
                {
                    if (!base.ProcessValue(in val))
                        return false;
                    if (val != 0)
                        _values.Add(val);
                    return true;
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    var binFinder = new GreedyBinFinder();
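                    // (Lim - Rem) is the number of examples consumed; whatever was not buffered as a
                    // non-zero value is an implicit zero (NaNs are removed before the bins are found).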
                    var numZeroes = checked((int)(Lim - Rem - _values.Count));
                    _values.RemoveAll(TFloat.IsNaN);
                    var binUpperBounds = binFinder.FindBins(_numBins, _values, numZeroes);
                    return BinColumnFunction.Create(Host, binUpperBounds, _fix);
                }
            }
 
            public sealed class BinVecColumnFunctionBuilder : VecColumnFunctionBuilderBase<TFloat>
            {
                private readonly bool _fix;
                private readonly int _numBins;
                private readonly List<TFloat>[] _values;
 
                private BinVecColumnFunctionBuilder(IHost host, int cv, long lim, bool fix, int numBins,
                    ValueGetter<VBuffer<TFloat>> getSrc)
                    : base(host, lim, getSrc)
                {
                    _fix = fix;
                    _numBins = numBins;
                    _values = new List<TFloat>[cv];
                    for (int i = 0; i < cv; i++)
                    {
                        _values[i] = new List<TFloat>();
                    }
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.BinningColumnOptions column, IHost host, VectorDataViewType srcType,
                    ValueGetter<VBuffer<TFloat>> getter)
                {
                    var lim = column.MaximumExampleCount;
                    host.CheckUserArg(lim > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    bool fix = column.EnsureZeroUntouched;
                    var numBins = column.MaximumBinCount;
                    host.CheckUserArg(numBins > 1, nameof(column.MaximumBinCount), "Must be greater than 1");
                    var cv = srcType.Size;
                    return new BinVecColumnFunctionBuilder(host, cv, lim, fix, numBins, getter);
                }
 
                protected override bool ProcessValue(in VBuffer<TFloat> buffer)
                {
                    if (!base.ProcessValue(in buffer))
                        return false;
 
                    int size = _values.Length;
                    Host.Check(buffer.Length == size);
 
                    var values = buffer.GetValues();
                    Host.Assert(0 <= values.Length && values.Length <= size);
                    if (values.Length == 0)
                        return true;
 
                    if (values.Length == size)
                    {
                        for (int j = 0; j < values.Length; j++)
                            _values[j].Add(values[j]);
                    }
                    else
                    {
                        var indices = buffer.GetIndices();
                        for (int k = 0; k < values.Length; k++)
                        {
                            var val = values[k];
                            var j = indices[k];
                            _values[j].Add(val);
                        }
                    }
                    return true;
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    var binFinder = new GreedyBinFinder();
                    var count = _values.Length;
                    var binUpperBounds = new TFloat[count][];
                    for (int i = 0; i < count; i++)
                    {
                        var numZeroes = checked((int)(Lim - Rem - _values[i].Count));
                        _values[i].RemoveAll(TFloat.IsNaN);
                        binUpperBounds[i] = binFinder.FindBins(_numBins, _values[i], numZeroes);
                    }
                    return BinColumnFunction.Create(Host, binUpperBounds, _fix);
                }
            }
 
            public sealed class SupervisedBinOneColumnFunctionBuilder : OneColumnSupervisedBinFunctionBuilderBase<TFloat>
            {
                private readonly bool _fix;
                private readonly int _numBins;
                private readonly int _minBinSize;
 
                private SupervisedBinOneColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, DataViewRow dataRow)
                    : base(host, lim, valueColumnId, labelColumnId, dataRow)
                {
                    _fix = fix;
                    _numBins = numBins;
                    _minBinSize = minBinSize;
                }
 
                protected override bool AcceptColumnValue(in TFloat colValue)
                {
                    return !TFloat.IsNaN(colValue);
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    var binFinder = new SupervisedBinFinder();
                    var binUpperBounds = binFinder.FindBins(_numBins, _minBinSize, LabelCardinality, ColValues, Labels);
                    return BinColumnFunction.Create(Host, binUpperBounds, _fix);
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinningColumOptions column, IHost host, int valueColumnId, int labelColumnId, DataViewRow dataRow)
                {
                    var lim = column.MaximumExampleCount;
                    host.CheckUserArg(lim > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    bool fix = column.EnsureZeroUntouched;
                    var numBins = column.MaximumBinCount;
                    host.CheckUserArg(numBins > 1, nameof(column.MaximumBinCount), "Must be greater than 1");
                    host.CheckUserArg(column.MininimumBinSize > 0, nameof(column.MininimumBinSize), "Must be positive");
                    return new SupervisedBinOneColumnFunctionBuilder(host, lim, fix, numBins, column.MininimumBinSize, valueColumnId, labelColumnId, dataRow);
                }
            }
 
            public sealed class SupervisedBinVecColumnFunctionBuilder : VecColumnSupervisedBinFunctionBuilderBase<TFloat>
            {
                private readonly bool _fix;
                private readonly int _numBins;
                private readonly int _minBinSize;
 
                private SupervisedBinVecColumnFunctionBuilder(IHost host, long lim, bool fix, int numBins, int minBinSize, int valueColumnId, int labelColumnId, DataViewRow dataRow)
                    : base(host, lim, valueColumnId, labelColumnId, dataRow)
                {
                    _fix = fix;
                    _numBins = numBins;
                    _minBinSize = minBinSize;
                }
 
                protected override bool AcceptColumnValue(in VBuffer<TFloat> colValuesBuffer)
                {
                    return !VBufferUtils.HasNaNs(in colValuesBuffer);
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    var binFinder = new SupervisedBinFinder();
                    TFloat[][] binUpperBounds = new TFloat[ColumnSlotCount][];
                    for (int i = 0; i < ColumnSlotCount; i++)
                        binUpperBounds[i] = binFinder.FindBins(_numBins, _minBinSize, LabelCardinality, ColValues[i], Labels);
                    return BinColumnFunction.Create(Host, binUpperBounds, _fix);
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinningColumOptions column, IHost host, int valueColumnId, int labelColumnId, DataViewRow dataRow)
                {
                    var lim = column.MaximumExampleCount;
                    host.CheckUserArg(lim > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    bool fix = column.EnsureZeroUntouched;
                    var numBins = column.MaximumBinCount;
                    host.CheckUserArg(numBins > 1, nameof(column.MaximumBinCount), "Must be greater than 1");
                    host.CheckUserArg(column.MininimumBinSize > 0, nameof(column.MininimumBinSize), "Must be positive");
                    return new SupervisedBinVecColumnFunctionBuilder(host, lim, fix, numBins, column.MininimumBinSize, valueColumnId, labelColumnId, dataRow);
                }
            }
 
            public sealed class RobustScalerOneColumnFunctionBuilder : OneColumnFunctionBuilderBase<double>
            {
                private readonly MinMaxDblAggregator _minMaxAggregator;
                private readonly MedianDblAggregator _medianAggregator;
                private readonly bool _centerData;
                private readonly uint _quantileMin;
                private readonly uint _quantileMax;
                private VBuffer<double> _buffer;
 
                private RobustScalerOneColumnFunctionBuilder(IHost host, long lim, bool centerData, uint quantileMin, uint quantileMax, ValueGetter<double> getSrc)
                    : base(host, lim, getSrc)
                {
                    // Use the MinMax aggregator since the min and max are needed here as well;
                    // the difference is in how they are used.
                    _minMaxAggregator = new MinMaxDblAggregator(1);
                    _medianAggregator = new MedianDblAggregator();
                    _buffer = new VBuffer<double>(1, new double[1]);
                    _centerData = centerData;
                    _quantileMin = quantileMin;
                    _quantileMax = quantileMax;
                }
 
                protected override bool ProcessValue(in double val)
                {
                    if (!base.ProcessValue(in val))
                        return false;
                    VBufferEditor.CreateFromBuffer(ref _buffer).Values[0] = val;
                    _minMaxAggregator.ProcessValue(in _buffer);
                    _medianAggregator.ProcessValue(in val);
                    return true;
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.RobustScalingColumnOptions column, IHost host, DataViewType srcType,
                    bool centerData, uint quantileMin, uint quantileMax, ValueGetter<double> getter)
                {
                    host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    return new RobustScalerOneColumnFunctionBuilder(host, column.MaximumExampleCount, centerData, quantileMin, quantileMax, getter);
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    _minMaxAggregator.Finish();
                    _medianAggregator.Finish();
 
                    double median = _medianAggregator.Median;
                    double range = _minMaxAggregator.Max[0] - _minMaxAggregator.Min[0];
                    // Divide by 100 to convert the quantile bounds from percentages (e.g. 75) into fractions (e.g. 0.75).
                    double quantileRange = (_quantileMax - _quantileMin) / 100f;
                    double scale = 1 / (range * quantileRange);
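                    // The resulting function is x => (x - median) * scale (or x => x * scale when not
                    // centering), with the quantile range approximated as a fixed fraction of the
                    // observed min-max range rather than measured from the data's quantiles.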
 
                    if (_centerData)
                        return AffineColumnFunction.Create(Host, scale, median);
                    else
                        return AffineColumnFunction.Create(Host, scale, 0);
                }
            }
 
            public sealed class RobustScalerVecFunctionBuilder : OneColumnFunctionBuilderBase<VBuffer<double>>
            {
                private readonly MinMaxDblAggregator _minMaxAggregator;
                private readonly MedianDblAggregator[] _medianAggregators;
                private readonly bool _centerData;
                private readonly uint _quantileMin;
                private readonly uint _quantileMax;
 
                private RobustScalerVecFunctionBuilder(IHost host, long lim, int vectorSize, bool centerData, uint quantileMin, uint quantileMax, ValueGetter<VBuffer<double>> getSrc)
                    : base(host, lim, getSrc)
                {
                    // Use the MinMax aggregator since the min and max are needed here as well;
                    // the difference is in how they are used.
                    _minMaxAggregator = new MinMaxDblAggregator(vectorSize);
 
                    // The median is only used when centering the data, but it is computed regardless.
                    _medianAggregators = new MedianDblAggregator[vectorSize];
 
                    for (int i = 0; i < vectorSize; i++)
                    {
                        _medianAggregators[i] = new MedianDblAggregator();
                    }
 
                    _centerData = centerData;
                    _quantileMin = quantileMin;
                    _quantileMax = quantileMax;
                }
 
                protected override bool ProcessValue(in VBuffer<double> val)
                {
                    if (!base.ProcessValue(in val))
                        return false;
                    _minMaxAggregator.ProcessValue(in val);
 
                    // Have to calculate the median per slot
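                    // Note: indexing the value span positionally assumes the incoming buffer is
                    // dense; for a sparse buffer the per-slot medians would not line up.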
                    var span = val.GetValues();
                    for (int i = 0; i < _medianAggregators.Length; i++)
                    {
                        _medianAggregators[i].ProcessValue(span[i]);
                    }
 
                    return true;
                }
 
                public static IColumnFunctionBuilder Create(NormalizingEstimator.RobustScalingColumnOptions column, IHost host, VectorDataViewType srcType,
                    bool centerData, uint quantileMin, uint quantileMax, ValueGetter<VBuffer<double>> getter)
                {
                    host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
                    var vectorSize = srcType.Size;
                    return new RobustScalerVecFunctionBuilder(host, column.MaximumExampleCount, vectorSize, centerData, quantileMin, quantileMax, getter);
                }
 
                public override IColumnFunction CreateColumnFunction()
                {
                    _minMaxAggregator.Finish();
 
                    double[] scale = new double[_medianAggregators.Length];
                    double[] median = new double[_medianAggregators.Length];
 
                    // Have to calculate the median per slot
                    for (int i = 0; i < _medianAggregators.Length; i++)
                    {
                        _medianAggregators[i].Finish();
                        median[i] = _medianAggregators[i].Median;
 
                        double range = _minMaxAggregator.Max[i] - _minMaxAggregator.Min[i];
 
                        // Divide by 100 to convert the quantile bounds from percentages (e.g. 75) into fractions (e.g. 0.75).
                        double quantileRange = (_quantileMax - _quantileMin) / 100f;
                        scale[i] = 1 / (range * quantileRange);
                    }
 
                    if (_centerData)
                        return AffineColumnFunction.Create(Host, scale, median, null);
                    else
                        return AffineColumnFunction.Create(Host, scale, null, null);
                }
            }
        }
    }
}