File: TrainerExtensions\BinaryTrainerExtensions.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;
 
namespace Microsoft.ML.AutoML
{
    using ITrainerEstimator = ITrainerEstimator<IPredictionTransformer<object>, object>;
 
    /// <summary>
    /// Trainer extension for the AveragedPerceptron binary classification trainer.
    /// Whenever the sweep does not supply "NumberOfIterations", the trainer and the
    /// pipeline node both fall back to <see cref="DefaultNumIterations"/>.
    /// </summary>
    internal class AveragedPerceptronBinaryExtension : ITrainerExtension
    {
        private const int DefaultNumIterations = 10;

        // Name of the sweepable parameter that, when present, overrides DefaultNumIterations.
        private const string NumIterationsParamName = nameof(AveragedPerceptronTrainer.Options.NumberOfIterations);

        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildAveragePerceptronParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildAveragePerceptronParams();
        }

        /// <summary>
        /// Creates an AveragedPerceptron trainer configured from <paramref name="sweepParams"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply; may be null or empty.</param>
        /// <param name="columnInfo">Supplies the label column name.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        /// <returns>The configured trainer estimator.</returns>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            AveragedPerceptronTrainer.Options options;
            if (sweepParams == null || !sweepParams.Any())
            {
                options = new AveragedPerceptronTrainer.Options
                {
                    NumberOfIterations = DefaultNumIterations,
                    LabelColumnName = columnInfo.LabelColumnName
                };
            }
            else
            {
                options = TrainerExtensionUtil.CreateOptions<AveragedPerceptronTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
                if (!SweepsNumberOfIterations(sweepParams))
                {
                    options.NumberOfIterations = DefaultNumIterations;
                }
            }
            return mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);
        }

        /// <summary>
        /// Builds the pipeline node describing this trainer for the given sweep.
        /// </summary>
        /// <param name="sweepParams">Hyperparameter values to record; may be null or empty.</param>
        /// <param name="columnInfo">Supplies the label column name.</param>
        /// <returns>The pipeline node for this trainer.</returns>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            Dictionary<string, object> additionalProperties = null;

            // Record the default iteration count exactly when CreateInstance applies it,
            // i.e. when the sweep does not provide "NumberOfIterations". This keeps the
            // node in sync with the trainer and never replaces a swept value.
            if (!SweepsNumberOfIterations(sweepParams))
            {
                additionalProperties = new Dictionary<string, object>()
                {
                    { NumIterationsParamName, DefaultNumIterations }
                };
            }

            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, additionalProperties: additionalProperties);
        }

        // True when the sweep contains a parameter named "NumberOfIterations"; false for null or empty sweeps.
        private static bool SweepsNumberOfIterations(IEnumerable<SweepableParam> sweepParams)
        {
            return sweepParams != null && sweepParams.Any(p => p.Name == NumIterationsParamName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the FastForest binary classification trainer.
    /// </summary>
    internal class FastForestBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildFastForestParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildFastForestParams();
        }

        /// <summary>
        /// Creates a FastForest trainer from the swept options, using the label and
        /// example weight columns named in <paramref name="columnInfo"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Supplies the label and example weight column names.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            var trainerOptions = TrainerExtensionUtil.CreateOptions<FastForestBinaryTrainer.Options>(
                sweepParams,
                columnInfo.LabelColumnName);
            trainerOptions.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;

            var binaryTrainers = mlContext.BinaryClassification.Trainers;
            return binaryTrainers.FastForest(trainerOptions);
        }

        /// <summary>
        /// Builds the pipeline node for this trainer, recording the label and example weight columns.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(
                trainerName,
                sweepParams,
                columnInfo.LabelColumnName,
                columnInfo.ExampleWeightColumnName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the FastTree binary classification trainer.
    /// </summary>
    internal class FastTreeBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildFastTreeParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildFastTreeParams();
        }

        /// <summary>
        /// Creates a FastTree trainer from the swept options, using the label and
        /// example weight columns named in <paramref name="columnInfo"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Supplies the label and example weight column names.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            var fastTreeOptions = TrainerExtensionUtil.CreateOptions<FastTreeBinaryTrainer.Options>(
                sweepParams,
                columnInfo.LabelColumnName);
            fastTreeOptions.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;

            var catalog = mlContext.BinaryClassification.Trainers;
            return catalog.FastTree(fastTreeOptions);
        }

        /// <summary>
        /// Builds the pipeline node for this trainer, recording the label and example weight columns.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var nodeName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(
                nodeName,
                sweepParams,
                columnInfo.LabelColumnName,
                columnInfo.ExampleWeightColumnName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the LightGBM binary classification trainer.
    /// </summary>
    internal class LightGbmBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildLightGbmParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLightGbmParams();
        }

        /// <summary>
        /// Creates a LightGBM binary trainer whose options are produced by
        /// <c>TrainerExtensionUtil.CreateLightGbmOptions</c> from the sweep and column information.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Column information forwarded to the options builder.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            LightGbmBinaryTrainer.Options lightGbmOptions = TrainerExtensionUtil.CreateLightGbmOptions<
                LightGbmBinaryTrainer.Options,
                float,
                BinaryPredictionTransformer<CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>>,
                CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>>(sweepParams, columnInfo);

            var binaryTrainers = mlContext.BinaryClassification.Trainers;
            return binaryTrainers.LightGbm(lightGbmOptions);
        }

        /// <summary>
        /// Builds the LightGBM pipeline node, recording the label, example weight and group id columns.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildLightGbmPipelineNode(
                trainerName,
                sweepParams,
                columnInfo.LabelColumnName,
                columnInfo.ExampleWeightColumnName,
                columnInfo.GroupIdColumnName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the LinearSvm binary classification trainer.
    /// </summary>
    internal class LinearSvmBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildLinearSvmParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLinearSvmParams();
        }

        /// <summary>
        /// Creates a LinearSvm trainer from the swept options and the label column
        /// named in <paramref name="columnInfo"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Supplies the label column name.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            var svmOptions = TrainerExtensionUtil.CreateOptions<LinearSvmTrainer.Options>(
                sweepParams,
                columnInfo.LabelColumnName);

            var binaryTrainers = mlContext.BinaryClassification.Trainers;
            return binaryTrainers.LinearSvm(svmOptions);
        }

        /// <summary>
        /// Builds the pipeline node for this trainer, recording the label column.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(
                trainerName,
                sweepParams,
                columnInfo.LabelColumnName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the SdcaLogisticRegression binary classification trainer.
    /// </summary>
    internal class SdcaLogisticRegressionBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildSdcaParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSdcaParams();
        }

        /// <summary>
        /// Creates an SdcaLogisticRegression trainer from the swept options and the label
        /// column named in <paramref name="columnInfo"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Supplies the label column name.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            var sdcaOptions = TrainerExtensionUtil.CreateOptions<SdcaLogisticRegressionBinaryTrainer.Options>(
                sweepParams,
                columnInfo.LabelColumnName);

            var catalog = mlContext.BinaryClassification.Trainers;
            return catalog.SdcaLogisticRegression(sdcaOptions);
        }

        /// <summary>
        /// Builds the pipeline node for this trainer, recording the label column.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var nodeName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(
                nodeName,
                sweepParams,
                columnInfo.LabelColumnName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the LbfgsLogisticRegression binary classification trainer.
    /// </summary>
    internal class LbfgsLogisticRegressionBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildLbfgsLogisticRegressionParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLbfgsLogisticRegressionParams();
        }

        /// <summary>
        /// Creates an LbfgsLogisticRegression trainer from the swept options, using the label
        /// and example weight columns named in <paramref name="columnInfo"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Supplies the label and example weight column names.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            var lbfgsOptions = TrainerExtensionUtil.CreateOptions<LbfgsLogisticRegressionBinaryTrainer.Options>(
                sweepParams,
                columnInfo.LabelColumnName);
            lbfgsOptions.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;

            var binaryTrainers = mlContext.BinaryClassification.Trainers;
            return binaryTrainers.LbfgsLogisticRegression(lbfgsOptions);
        }

        /// <summary>
        /// Builds the pipeline node for this trainer, recording the label and example weight columns.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(
                trainerName,
                sweepParams,
                columnInfo.LabelColumnName,
                columnInfo.ExampleWeightColumnName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the SgdCalibrated binary classification trainer.
    /// </summary>
    internal class SgdCalibratedBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildSgdParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSgdParams();
        }

        /// <summary>
        /// Creates an SgdCalibrated trainer from the swept options, using the label and
        /// example weight columns named in <paramref name="columnInfo"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Supplies the label and example weight column names.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            var sgdOptions = TrainerExtensionUtil.CreateOptions<SgdCalibratedTrainer.Options>(
                sweepParams,
                columnInfo.LabelColumnName);
            sgdOptions.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;

            var catalog = mlContext.BinaryClassification.Trainers;
            return catalog.SgdCalibrated(sgdOptions);
        }

        /// <summary>
        /// Builds the pipeline node for this trainer, recording the label and example weight columns.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var nodeName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(
                nodeName,
                sweepParams,
                columnInfo.LabelColumnName,
                columnInfo.ExampleWeightColumnName);
        }
    }
 
    /// <summary>
    /// Trainer extension for the SymbolicSgdLogisticRegression binary classification trainer.
    /// </summary>
    internal class SymbolicSgdLogisticRegressionBinaryExtension : ITrainerExtension
    {
        /// <summary>
        /// Returns the sweep ranges built by <c>SweepableParams.BuildSymSgdLogisticRegressionParams</c>.
        /// </summary>
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSymSgdLogisticRegressionParams();
        }

        /// <summary>
        /// Creates a SymbolicSgdLogisticRegression trainer from the swept options and the label
        /// column named in <paramref name="columnInfo"/>.
        /// </summary>
        /// <param name="mlContext">Context whose binary classification catalog creates the trainer.</param>
        /// <param name="sweepParams">Hyperparameter values to apply.</param>
        /// <param name="columnInfo">Supplies the label column name.</param>
        /// <param name="validationSet">Not used by this extension.</param>
        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo, IDataView validationSet)
        {
            var symSgdOptions = TrainerExtensionUtil.CreateOptions<SymbolicSgdLogisticRegressionBinaryTrainer.Options>(
                sweepParams,
                columnInfo.LabelColumnName);

            var binaryTrainers = mlContext.BinaryClassification.Trainers;
            return binaryTrainers.SymbolicSgdLogisticRegression(symSgdOptions);
        }

        /// <summary>
        /// Builds the pipeline node for this trainer, recording the label column.
        /// </summary>
        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            var trainerName = TrainerExtensionCatalog.GetTrainerName(this);
            return TrainerExtensionUtil.BuildPipelineNode(
                trainerName,
                sweepParams,
                columnInfo.LabelColumnName);
        }
    }
}