File: LdSvm\LdSvmTrainer.cs
Project: src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.SearchSpace;
using Microsoft.ML.Trainers;
 
[assembly: LoadableClass(LdSvmTrainer.Summary, typeof(LdSvmTrainer), typeof(LdSvmTrainer.Options),
    new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
    LdSvmTrainer.UserNameValue,
    LdSvmTrainer.LoadNameValue
    )]
 
[assembly: LoadableClass(typeof(void), typeof(LdSvmTrainer), null, typeof(SignatureEntryPointModule), LdSvmTrainer.LoadNameValue)]
 
namespace Microsoft.ML.Trainers
{
    /// <summary>
    /// The <see cref="IEstimator{TTransformer}"/> to predict a target using a non-linear binary classification model
    /// trained with Local Deep SVM.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// To create this trainer, use [LdSvm](xref:Microsoft.ML.StandardTrainersCatalog.LdSvm(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String,System.String,System.String,System.Int32,System.Int32,System.Boolean,System.Boolean))
    /// or [LdSvm(Options)](xref:Microsoft.ML.StandardTrainersCatalog.LdSvm(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.LdSvmTrainer.Options)).
    ///
    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-binary-classification-no-prob.md)]
    ///
    /// ### Trainer Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Machine learning task | Binary classification |
    /// | Is normalization required? | Yes |
    /// | Is caching required? | No |
    /// | Required NuGet in addition to Microsoft.ML | None |
    /// | Exportable to ONNX | No |
    ///
    /// ### Training Algorithm Details
    /// Local Deep SVM (LD-SVM) is a generalization of Localized Multiple Kernel Learning for non-linear SVM. Multiple kernel methods learn a different
    /// kernel, and hence a different classifier, for each point in the feature space. Their prediction cost can be prohibitive for large training sets
    /// because it is proportional to the number of support vectors, which grows linearly with the size of the training set. LD-SVM reduces the prediction
    /// cost by learning a tree-based local feature embedding that is high dimensional and sparse, efficiently encoding non-linearities. With LD-SVM, the
    /// prediction cost grows logarithmically with the size of the training set, rather than linearly, at a tolerable loss in classification accuracy.
    ///
    /// Local Deep SVM is an implementation of the algorithm described in [C. Jose, P. Goyal, P. Aggrwal, and M. Varma, Local Deep
    /// Kernel Learning for Efficient Non-linear SVM Prediction, ICML, 2013](http://proceedings.mlr.press/v28/jose13.pdf).
    ///
    /// Check the See Also section for links to usage examples.
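    ///
    /// A minimal usage sketch is shown below; the `trainingData` IDataView and its column names are assumed here for illustration:
    /// ```csharp
    /// // using Microsoft.ML; using Microsoft.ML.Trainers;
    /// var mlContext = new MLContext(seed: 0);
    /// // trainingData: any IDataView with a Boolean "Label" column and a
    /// // Single-vector "Features" column (hypothetical input for this sketch).
    /// var pipeline = mlContext.BinaryClassification.Trainers.LdSvm(new LdSvmTrainer.Options
    /// {
    ///     LabelColumnName = "Label",
    ///     FeatureColumnName = "Features",
    ///     TreeDepth = 3,
    ///     NumberOfIterations = 15000
    /// });
    /// var model = pipeline.Fit(trainingData);
    /// ```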
    /// ]]>
    /// </format>
    /// </remarks>
    /// <seealso cref="StandardTrainersCatalog.LdSvm(BinaryClassificationCatalog.BinaryClassificationTrainers, LdSvmTrainer.Options)"/>
    /// <seealso cref="StandardTrainersCatalog.LdSvm(BinaryClassificationCatalog.BinaryClassificationTrainers, string, string, string, int, int, bool, bool)"/>
    public sealed class LdSvmTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<LdSvmModelParameters>, LdSvmModelParameters>
    {
        internal const string LoadNameValue = "LDSVM";
        internal const string UserNameValue = "Local Deep SVM (LDSVM)";
        internal const string Summary = "LD-SVM learns a binary, non-linear SVM classifier with a kernel that is specifically designed to reduce prediction time. "
            + "LD-SVM learns decision boundaries that are locally linear.";
 
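        /// <summary>
        /// Options for the <see cref="LdSvmTrainer"/>.
        /// </summary>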
        public sealed class Options : TrainerInputBaseWithWeight
        {
            /// <summary>
            /// Depth of the LD-SVM tree.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Depth of Local Deep SVM tree", ShortName = "depth", SortOrder = 50)]
            [TGUI(SuggestedSweeps = "1,3,5,7")]
            [TlcModule.SweepableDiscreteParam("TreeDepth", new object[] { 1, 3, 5, 7 })]
            [Range(1, 128, 1, true)]
            public int TreeDepth = Defaults.TreeDepth;
 
            /// <summary>
            /// Regularizer for classifier parameter W.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer for classifier parameter W", ShortName = "lw", SortOrder = 50)]
            [TGUI(SuggestedSweeps = "0.1,0.01,0.001")]
            [TlcModule.SweepableDiscreteParam("LambdaW", new object[] { 0.1f, 0.01f, 0.001f })]
            [Range(1e-4f, 1f, 1e-4f, true)]
            public float LambdaW = Defaults.LambdaW;
 
            /// <summary>
            /// Regularizer for kernel parameter Theta.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer for kernel parameter Theta", ShortName = "lt", SortOrder = 50)]
            [TGUI(SuggestedSweeps = "0.1,0.01,0.001")]
            [TlcModule.SweepableDiscreteParam("LambdaTheta", new object[] { 0.1f, 0.01f, 0.001f })]
            [Range(1e-4f, 1f, 1e-4f, true)]
            public float LambdaTheta = Defaults.LambdaTheta;
 
            /// <summary>
            /// Regularizer for kernel parameter ThetaPrime.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer for kernel parameter Thetaprime", ShortName = "lp", SortOrder = 50)]
            [TGUI(SuggestedSweeps = "0.1,0.01,0.001")]
            [TlcModule.SweepableDiscreteParam("LambdaThetaprime", new object[] { 0.1f, 0.01f, 0.001f })]
            [Range(1e-4f, 1f, 1e-4f, true)]
            public float LambdaThetaprime = Defaults.LambdaThetaprime;
 
            /// <summary>
            /// Parameter for sigmoid sharpness.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for sigmoid sharpness", ShortName = "s", SortOrder = 50)]
            [TGUI(SuggestedSweeps = "1.0,0.1,0.01")]
            [TlcModule.SweepableDiscreteParam("Sigma", new object[] { 1.0f, 0.1f, 0.01f })]
            [Range(1e-4f, 1f, 1e-4f, true)]
            public float Sigma = Defaults.Sigma;
 
            /// <summary>
            /// Indicates whether the model should include a bias term.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "No bias", ShortName = "bias")]
            [TlcModule.SweepableDiscreteParam("NoBias", null, isBool: true)]
            [BooleanChoice(true)]
            public bool UseBias = Defaults.UseBias;
 
            /// <summary>
            /// Number of training iterations.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "Number of iterations", ShortName = "iter,NumIterations", SortOrder = 50)]
            [TGUI(SuggestedSweeps = "10000,15000")]
            [TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 10000, 15000 })]
            [Range(1, int.MaxValue, 1, true)]
            public int NumberOfIterations = Defaults.NumberOfIterations;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
            internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory();
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
            internal int MaxCalibrationExamples = 1000000;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to cache the data before the first iteration")]
            public bool Cache = Defaults.Cache;
 
            internal class Defaults
            {
                public const int NumberOfIterations = 15000;
                public const bool UseBias = true;
                public const float Sigma = 1.0f;
                public const float LambdaThetaprime = 0.01f;
                public const float LambdaTheta = 0.01f;
                public const float LambdaW = 0.1f;
                public const int TreeDepth = 3;
                public const bool Cache = true;
            }
        }
 
        private const int NumberOfSamplesForGammaUpdate = 100;
 
        private readonly Options _options;
 
        internal LdSvmTrainer(IHostEnvironment env, Options options)
            : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue),
                  TrainerUtils.MakeR4VecFeature(options.FeatureColumnName),
                  TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName),
                  TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName))
        {
            Host.CheckValue(options, nameof(options));
            CheckOptions(Host, options);
            _options = options;
        }
 
        private static readonly TrainerInfo _info = new TrainerInfo(calibration: true, caching: false);
        public override TrainerInfo Info => _info;
 
        private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
 
        private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            return new[]
            {
                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation())),
                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()))
            };
        }
 
        private protected override LdSvmModelParameters TrainModelCore(TrainContext trainContext)
        {
            Host.CheckValue(trainContext, nameof(trainContext));
            using (var ch = Host.Start("Training"))
            {
                trainContext.TrainingSet.CheckFeatureFloatVector(out var numFeatures);
                trainContext.TrainingSet.CheckBinaryLabel();
 
                var numLeaf = 1 << _options.TreeDepth;
                return TrainCore(ch, trainContext.TrainingSet, numLeaf, numFeatures);
            }
        }
 
        /// <summary>
        /// Compute the gradient w.r.t. Theta for an instance x.
        /// </summary>
        private void ComputeGradTheta(in VBuffer<float> feat, float[] gradTheta, int numLeaf, float gamma,
            VBuffer<float>[] theta, float[] biasTheta, float[] pathWt, float[] localWt, VBuffer<float>[] w, float[] biasW)
        {
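            // Nodes are stored in breadth-first order: node i has children 2i + 1 (left)
            // and 2i + 2 (right), and parent (i - 1) / 2. Each node's contribution to the
            // prediction is propagated up its root path; the sign of the tanh term depends
            // on whether the path descends through a left (odd) or right (even) child.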
            Array.Clear(gradTheta, 0, numLeaf - 1);
            int numNodes = 2 * numLeaf - 1;
            float[] tanhThetaTx = new float[numLeaf - 1];
            for (int i = 0; i < numLeaf - 1; i++)
                tanhThetaTx[i] = (float)Math.Tanh(gamma * (VectorUtils.DotProduct(in feat, in theta[i]) + biasTheta[i]));
            for (int i = 0; i < numNodes; i++)
            {
                int current = i;
                float tempGrad = pathWt[i] * localWt[i] * (VectorUtils.DotProduct(in feat, in w[i]) + biasW[i]);
                while (current > 0)
                {
                    int parent = (current - 1) / 2;
                    gradTheta[parent] += tempGrad * (current % 2 == 1 ? (1 - tanhThetaTx[parent]) : (-1 - tanhThetaTx[parent]));
                    current = parent;
                }
            }
        }
 
        /// <summary>
        /// Adaptively update gamma for indicator function approximation.
        /// </summary>
        private void UpdateGamma(int iter, int numLeaf, ref float gamma, Data data, VBuffer<float>[] theta, float[] biasTheta)
        {
            if (numLeaf == 1)
                gamma = 1.0f;
            else
            {
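                // Estimate the average magnitude of theta . x + bias over a small random
                // sample, and set gamma so the tanh argument starts out small (~0.1), i.e.
                // in the near-linear regime. Gamma is then annealed upward, doubling every
                // NumberOfIterations / 10 iterations, so the soft split indicator gradually
                // sharpens toward a hard tree split.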
                float tempSum = 0;
                var sample = data.SampleForGammaUpdate(Host.Rand);
                int sampleSize = 0;
                foreach (var s in sample)
                {
                    int thetaIdx = Host.Rand.Next(numLeaf - 1);
                    tempSum += Math.Abs(VectorUtils.DotProduct(in s, in theta[thetaIdx]) + biasTheta[thetaIdx]);
                    sampleSize++;
                }
                tempSum /= sampleSize;
                gamma = 0.1f / tempSum;
                gamma *= (float)Math.Pow(2.0, iter / (_options.NumberOfIterations / 10.0));
            }
        }
 
        /// <summary>
        /// Main LDSVM training routine.
        /// </summary>
        private LdSvmModelParameters TrainCore(IChannel ch, RoleMappedData trainingData, int numLeaf, int numFeatures)
        {
            int numNodes = 2 * numLeaf - 1;
 
            var w = new VBuffer<float>[numNodes];
            var thetaPrime = new VBuffer<float>[numNodes];
            var theta = new VBuffer<float>[numLeaf - 1];
            var biasW = new float[numNodes];
            var biasTheta = new float[numLeaf - 1];
            var biasThetaPrime = new float[numNodes];
 
            var tempW = new VBuffer<float>[numNodes];
            var tempThetaPrime = new VBuffer<float>[numNodes];
            var tempTheta = new VBuffer<float>[numLeaf - 1];
            var tempBiasW = new float[numNodes];
            var tempBiasTheta = new float[numLeaf - 1];
            var tempBiasThetaPrime = new float[numNodes];
 
            InitClassifierParam(numLeaf, numFeatures, tempW, w, theta, thetaPrime, biasW,
                biasTheta, biasThetaPrime, tempThetaPrime, tempTheta, tempBiasW, tempBiasTheta, tempBiasThetaPrime);
 
            var gamma = 0.01f;
            Data data = _options.Cache ?
                (Data)new CachedData(ch, trainingData) :
                new StreamingData(ch, trainingData);
            var pathWt = new float[numNodes];
            var localWt = new float[numNodes];
            var gradTheta = new float[numLeaf - 1];
            var wDotX = new float[numNodes];
 
            // Number of examples processed in each iteration: a mini-batch of roughly sqrt(n)
            int sampleSize = Math.Max(1, (int)Math.Sqrt(data.Length));
            for (int iter = 1; iter <= _options.NumberOfIterations; iter++)
            {
                // Update gamma adaptively
                if (iter % 100 == 1)
                    UpdateGamma(iter, numLeaf, ref gamma, data, theta, biasTheta);
 
                // Update learning rate
                float etaTW = (float)1.0 / (_options.LambdaW * (float)Math.Sqrt(iter + 1));
                float etaTTheta = (float)1.0 / (_options.LambdaTheta * (float)Math.Sqrt(iter + 1));
                float etaTThetaPrime = (float)1.0 / (_options.LambdaThetaprime * (float)Math.Sqrt(iter + 1));
                float coef = iter / (float)(iter + 1);
 
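                // The temp* arrays hold the solution being built this iteration. Scaling them
                // by iter / (iter + 1) plays the role of the regularization shrinkage in a
                // Pegasos-style stochastic gradient step, paired with the 1 / (lambda * sqrt(t))
                // learning rates above; the result is copied back into w/theta/thetaPrime at
                // the end of the iteration.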
                // Update classifier parameters
                for (int i = 0; i < tempW.Length; ++i)
                    VectorUtils.ScaleBy(ref tempW[i], coef);
                for (int i = 0; i < tempTheta.Length; ++i)
                    VectorUtils.ScaleBy(ref tempTheta[i], coef);
                for (int i = 0; i < tempThetaPrime.Length; ++i)
                    VectorUtils.ScaleBy(ref tempThetaPrime[i], coef);
 
                for (int i = 0; i < numNodes; i++)
                {
                    tempBiasW[i] *= coef;
                    tempBiasThetaPrime[i] *= coef;
                }
                for (int i = 0; i < numLeaf - 1; i++)
                    tempBiasTheta[i] *= coef;
 
                var sample = data.SampleExamples(Host.Rand);
                foreach (var s in sample)
                {
                    float trueLabel = s.Label;
                    var features = s.Features;
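                    // LD-SVM decision function (a soft decision tree):
                    //   f(x) = sum over nodes l of pathWt_l(x) * localWt_l(x) * (w_l . x + biasW_l)
                    // where pathWt_l is a tanh-gated indicator of x reaching node l and
                    // localWt_l(x) = tanh(Sigma * (thetaPrime_l . x + biasThetaPrime_l)).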
 
                    // Compute path weight
                    for (int i = 0; i < numNodes; i++)
                        pathWt[i] = 1;
                    for (int i = 0; i < numLeaf - 1; i++)
                    {
                        var tanhDist = (float)Math.Tanh(gamma * (VectorUtils.DotProduct(in features, in theta[i]) + biasTheta[i]));
                        pathWt[2 * i + 1] = pathWt[i] * (1 + tanhDist) / (float)2.0;
                        pathWt[2 * i + 2] = pathWt[i] * (1 - tanhDist) / (float)2.0;
                    }
 
                    // Compute local weight
                    for (int l = 0; l < numNodes; l++)
                        localWt[l] = (float)Math.Tanh(_options.Sigma * (VectorUtils.DotProduct(in features, in thetaPrime[l]) + biasThetaPrime[l]));
 
                    // Make prediction
                    float yPredicted = 0;
                    for (int l = 0; l < numNodes; l++)
                    {
                        wDotX[l] = VectorUtils.DotProduct(in features, in w[l]) + biasW[l];
                        yPredicted += pathWt[l] * localWt[l] * wDotX[l];
                    }
                    float loss = 1 - trueLabel * yPredicted;
 
                    // Update the classifier parameters on a margin violation (hinge loss > 0)
                    if (loss > 0)
                    {
                        // Compute the gradient w.r.t. Theta for the current instance
                        ComputeGradTheta(in features, gradTheta, numLeaf, gamma, theta, biasTheta, pathWt, localWt, w, biasW);
 
                        // Check whether the bias term is used or not
                        int biasUpdateMult = _options.UseBias ? 1 : 0;
 
                        // Update W
                        for (int l = 0; l < numNodes; l++)
                        {
                            float tempGradW = trueLabel * etaTW / sampleSize * pathWt[l] * localWt[l];
                            VectorUtils.AddMult(in features, tempGradW, ref tempW[l]);
                            tempBiasW[l] += biasUpdateMult * tempGradW;
                        }
 
                        // Update ThetaPrime
                        for (int l = 0; l < numNodes; l++)
                        {
                            float tempGradThetaPrime = (1 - localWt[l] * localWt[l]) * trueLabel * etaTThetaPrime / sampleSize * pathWt[l] * wDotX[l];
                            VectorUtils.AddMult(in features, tempGradThetaPrime, ref tempThetaPrime[l]);
                            tempBiasThetaPrime[l] += biasUpdateMult * tempGradThetaPrime;
                        }
 
                        // Update Theta
                        for (int m = 0; m < numLeaf - 1; m++)
                        {
                            float tempGradTheta = trueLabel * etaTTheta / sampleSize * gradTheta[m];
                            VectorUtils.AddMult(in features, tempGradTheta, ref tempTheta[m]);
                            tempBiasTheta[m] += biasUpdateMult * tempGradTheta;
                        }
                    }
                }
 
                // Copy solution
                for (int i = 0; i < numNodes; i++)
                {
                    tempW[i].CopyTo(ref w[i]);
                    biasW[i] = tempBiasW[i];
 
                    tempThetaPrime[i].CopyTo(ref thetaPrime[i]);
                    biasThetaPrime[i] = tempBiasThetaPrime[i];
                }
                for (int i = 0; i < numLeaf - 1; i++)
                {
                    tempTheta[i].CopyTo(ref theta[i]);
                    biasTheta[i] = tempBiasTheta[i];
                }
            }
 
            return new LdSvmModelParameters(Host, w, thetaPrime, theta, _options.Sigma, biasW, biasTheta,
                biasThetaPrime, _options.TreeDepth);
        }
 
        /// <summary>
        /// Initialize classifier parameters.
        /// </summary>
        private void InitClassifierParam(int numLeaf, int numFeatures, VBuffer<float>[] tempW, VBuffer<float>[] w,
            VBuffer<float>[] theta, VBuffer<float>[] thetaPrime, float[] biasW, float[] biasTheta,
            float[] biasThetaPrime, VBuffer<float>[] tempThetaPrime, VBuffer<float>[] tempTheta,
            float[] tempBiasW, float[] tempBiasTheta, float[] tempBiasThetaPrime)
        {
            int count = 2 * numLeaf - 1;
            int half = numLeaf - 1;
 
            Host.Assert(Utils.Size(tempW) == count);
            Host.Assert(Utils.Size(w) == count);
            Host.Assert(Utils.Size(theta) == half);
            Host.Assert(Utils.Size(thetaPrime) == count);
            Host.Assert(Utils.Size(biasW) == count);
            Host.Assert(Utils.Size(biasTheta) == half);
            Host.Assert(Utils.Size(biasThetaPrime) == count);
            Host.Assert(Utils.Size(tempThetaPrime) == count);
            Host.Assert(Utils.Size(tempTheta) == half);
            Host.Assert(Utils.Size(tempBiasW) == count);
            Host.Assert(Utils.Size(tempBiasTheta) == half);
            Host.Assert(Utils.Size(tempBiasThetaPrime) == count);
 
            for (int i = 0; i < count; i++)
            {
                VBufferEditor<float> thetaInit = default;
                if (i < numLeaf - 1)
                    thetaInit = VBufferEditor.Create(ref theta[i], numFeatures);
                var wInit = VBufferEditor.Create(ref w[i], numFeatures);
                var thetaPrimeInit = VBufferEditor.Create(ref thetaPrime[i], numFeatures);
                for (int j = 0; j < numFeatures; j++)
                {
                    wInit.Values[j] = 2 * Host.Rand.NextSingle() - 1;
                    thetaPrimeInit.Values[j] = 2 * Host.Rand.NextSingle() - 1;
                    if (i < numLeaf - 1)
                        thetaInit.Values[j] = 2 * Host.Rand.NextSingle() - 1;
                }
 
                w[i] = wInit.Commit();
                w[i].CopyTo(ref tempW[i]);
                thetaPrime[i] = thetaPrimeInit.Commit();
                thetaPrime[i].CopyTo(ref tempThetaPrime[i]);
 
                if (_options.UseBias)
                {
                    float bW = 2 * Host.Rand.NextSingle() - 1;
                    biasW[i] = bW;
                    tempBiasW[i] = bW;
                    float bTP = 2 * Host.Rand.NextSingle() - 1;
                    biasThetaPrime[i] = bTP;
                    tempBiasThetaPrime[i] = bTP;
                }
 
                if (i >= half)
                    continue;
 
                theta[i] = thetaInit.Commit();
                theta[i].CopyTo(ref tempTheta[i]);
 
                if (_options.UseBias)
                {
                    float bT = 2 * Host.Rand.NextSingle() - 1;
                    biasTheta[i] = bT;
                    tempBiasTheta[i] = bT;
                }
            }
        }
 
        /// <summary>
        /// Validate the user-specified training options.
        /// </summary>
        private static void CheckOptions(IExceptionContext ectx, Options options)
        {
            ectx.AssertValue(options);
 
            ectx.CheckUserArg(options.TreeDepth >= 0, nameof(options.TreeDepth), "Tree depth cannot be negative.");
            ectx.CheckUserArg(options.TreeDepth <= 24, nameof(options.TreeDepth), "Tree depth must be at most 24. Try a smaller depth first and cross-validate over the other parameters.");
            ectx.CheckUserArg(options.LambdaW > 0, nameof(options.LambdaW), "Regularizer for W must be positive.");
            ectx.CheckUserArg(options.LambdaTheta > 0, nameof(options.LambdaTheta), "Regularizer for Theta must be positive.");
            ectx.CheckUserArg(options.LambdaThetaprime > 0, nameof(options.LambdaThetaprime), "Regularizer for ThetaPrime must be positive.");
        }
 
        internal struct LabelFeatures
        {
            public float Label;
            public VBuffer<float> Features;
        }
 
        private abstract class Data
        {
            protected readonly IChannel Ch;
 
            public abstract long Length { get; }
 
            protected Data(IChannel ch)
            {
                Ch = ch;
            }
 
            public abstract IEnumerable<VBuffer<float>> SampleForGammaUpdate(Random rand);
            public abstract IEnumerable<LabelFeatures> SampleExamples(Random rand);
        }
 
        private sealed class CachedData : Data
        {
            private readonly LabelFeatures[] _examples;
            private readonly int[] _indices;
 
            public override long Length => _examples.Length;
 
            public CachedData(IChannel ch, RoleMappedData data)
                : base(ch)
            {
                var examples = new List<LabelFeatures>();
                using (var cursor = new FloatLabelCursor(data, CursOpt.Label | CursOpt.Features))
                {
                    while (cursor.MoveNext())
                    {
                        var example = new LabelFeatures();
                        cursor.Features.CopyTo(ref example.Features);
                        example.Label = cursor.Label > 0 ? 1 : -1;
                        examples.Add(example);
                    }
                    Ch.Check(cursor.KeptRowCount > 0, NoTrainingInstancesMessage);
                    if (cursor.SkippedRowCount > 0)
                        Ch.Warning("Skipped {0} rows with missing feature/label values", cursor.SkippedRowCount);
                }
                _examples = examples.ToArray();
                _indices = Utils.GetIdentityPermutation((int)Length);
            }
 
            public override IEnumerable<LabelFeatures> SampleExamples(Random rand)
            {
                var sampleSize = Math.Max(1, (int)Math.Sqrt(Length));
                var length = (int)Length;
                // Select random subset of data - the first sampleSize indices will be
                // our subset.
                for (int k = 0; k < sampleSize; k++)
                {
                    int randIdx = k + rand.Next(length - k);
                    Utils.Swap(ref _indices[k], ref _indices[randIdx]);
                }
 
                for (int k = 0; k < sampleSize; k++)
                {
                    yield return _examples[_indices[k]];
                }
            }
 
            public override IEnumerable<VBuffer<float>> SampleForGammaUpdate(Random rand)
            {
                int length = (int)Length;
                for (int i = 0; i < NumberOfSamplesForGammaUpdate; i++)
                {
                    int index = rand.Next(length);
                    yield return _examples[index].Features;
                }
            }
        }
 
        private sealed class StreamingData : Data
        {
            private readonly RoleMappedData _data;
            private readonly int[] _indices;
            private readonly int[] _indices2;
 
            public override long Length { get; }
 
            public StreamingData(IChannel ch, RoleMappedData data)
                : base(ch)
            {
                Ch.AssertValue(data);
 
                _data = data;
 
                using (var cursor = _data.Data.GetRowCursor())
                {
                    while (cursor.MoveNext())
                        Length++;
                }
                _indices = Utils.GetIdentityPermutation((int)Length);
                _indices2 = new int[NumberOfSamplesForGammaUpdate];
            }
 
            public override IEnumerable<VBuffer<float>> SampleForGammaUpdate(Random rand)
            {
                int length = (int)Length;
                for (int i = 0; i < NumberOfSamplesForGammaUpdate; i++)
                {
                    _indices2[i] = rand.Next(length);
                }
                Array.Sort(_indices2);
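                // Sorting the sampled row indices lets us retrieve them in a single forward
                // pass of the cursor; a duplicated index yields the same row multiple times.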
 
                using (var cursor = _data.Data.GetRowCursor(_data.Data.Schema[_data.Schema.Feature.Value.Name]))
                {
                    var getter = cursor.GetGetter<VBuffer<float>>(_data.Data.Schema[_data.Schema.Feature.Value.Name]);
                    var features = default(VBuffer<float>);
                    int iIndex = 0;
                    while (cursor.MoveNext())
                    {
                        if (cursor.Position == _indices2[iIndex])
                        {
                            iIndex++;
                            getter(ref features);
                            var noNaNs = FloatUtils.IsFinite(features.GetValues());
                            if (noNaNs)
                                yield return features;
                            while (iIndex < NumberOfSamplesForGammaUpdate && cursor.Position == _indices2[iIndex])
                            {
                                iIndex++;
                                if (noNaNs)
                                    yield return features;
                            }
                            if (iIndex == NumberOfSamplesForGammaUpdate)
                                break;
                        }
                    }
                }
            }
 
            public override IEnumerable<LabelFeatures> SampleExamples(Random rand)
            {
                var sampleSize = Math.Max(1, (int)Math.Sqrt(Length));
                var length = (int)Length;
                // Select random subset of data - the first sampleSize indices will be
                // our subset.
                for (int k = 0; k < sampleSize; k++)
                {
                    int randIdx = k + rand.Next(length - k);
                    Utils.Swap(ref _indices[k], ref _indices[randIdx]);
                }
 
                Array.Sort(_indices, 0, sampleSize);
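                // Sort the chosen indices so the sampled rows can be read in one forward
                // pass of the cursor below.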
 
                var featureCol = _data.Data.Schema[_data.Schema.Feature.Value.Name];
                var labelCol = _data.Data.Schema[_data.Schema.Label.Value.Name];
                using (var cursor = _data.Data.GetRowCursor(featureCol, labelCol))
                {
                    var featureGetter = cursor.GetGetter<VBuffer<float>>(featureCol);
                    var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol.Index);
                    ValueGetter<LabelFeatures> getter =
                        (ref LabelFeatures dst) =>
                        {
                            featureGetter(ref dst.Features);
                            var label = default(float);
                            labelGetter(ref label);
                            dst.Label = label > 0 ? 1 : -1;
                        };
 
                    int iIndex = 0;
                    while (cursor.MoveNext())
                    {
                        if (cursor.Position == _indices[iIndex])
                        {
                            var example = new LabelFeatures();
                            getter(ref example);
                            iIndex++;
                            if (FloatUtils.IsFinite(example.Features.GetValues()))
                                yield return example;
                            if (iIndex == sampleSize)
                                break;
                        }
                    }
                }
            }
        }
 
        private protected override BinaryPredictionTransformer<LdSvmModelParameters> MakeTransformer(LdSvmModelParameters model, DataViewSchema trainSchema)
            => new BinaryPredictionTransformer<LdSvmModelParameters>(Host, model, trainSchema, _options.FeatureColumnName);
 
        [TlcModule.EntryPoint(Name = "Trainers.LocalDeepSvmBinaryClassifier", Desc = Summary, UserName = UserNameValue, ShortName = LoadNameValue)]
        internal static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("TrainLDSVM");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            return TrainerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
                () => new LdSvmTrainer(host, input),
                () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
                calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
        }
    }
}