File: CodeGenerator\CSharp\TrainerGeneratorFactory.cs
Web Access
Project: src\src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj (Microsoft.ML.CodeGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.AutoML;
using static Microsoft.ML.CodeGenerator.CSharp.TrainerGenerators;
 
namespace Microsoft.ML.CodeGenerator.CSharp
{
    /// <summary>
    /// Contract for objects that emit source-code text for a single trainer
    /// in generated C# code.
    /// </summary>
    internal interface ITrainerGenerator
    {
        /// <summary>
        /// Returns the using directives / namespaces the generated trainer code needs.
        /// NOTE(review): exact format (full "using X;" lines vs. bare namespaces) is
        /// defined by the implementations — confirm there.
        /// </summary>
        /// <returns>An array of strings, one entry per required import.</returns>
        string[] GenerateUsings();

        /// <summary>
        /// Returns the generated code that constructs the trainer.
        /// </summary>
        /// <returns>The trainer construction code as a single string.</returns>
        string GenerateTrainer();
    }
 
    /// <summary>
    /// Maps an AutoML <see cref="PipelineNode"/> to the <see cref="ITrainerGenerator"/>
    /// that emits code for the trainer named by the node.
    /// </summary>
    internal static class TrainerGeneratorFactory
    {
        /// <summary>
        /// Creates the trainer generator for <paramref name="node"/>. The node's <c>Name</c>
        /// must be the exact (case-sensitive) name of a <see cref="TrainerName"/> member.
        /// </summary>
        /// <param name="node">Pipeline node describing the trainer; it is passed to the generator's constructor.</param>
        /// <returns>A new generator instance for the named trainer.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="node"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">
        /// The node's name is not exactly a <see cref="TrainerName"/> member name, or the
        /// named trainer has no code generator.
        /// </exception>
        internal static ITrainerGenerator GetInstance(PipelineNode node)
        {
            if (node == null)
            {
                throw new ArgumentNullException(nameof(node));
            }

            // Enum.TryParse on its own also accepts numeric strings ("3"), comma-separated
            // name lists (whose OR-ed value may equal an unrelated member) and surrounding
            // whitespace, any of which could silently map a bogus name onto a real trainer.
            // Require an exact member name before parsing.
            if (node.Name == null
                || !Enum.IsDefined(typeof(TrainerName), node.Name)
                || !Enum.TryParse(node.Name, out TrainerName trainer))
            {
                throw new ArgumentException($"The trainer '{node.Name}' is not handled currently.");
            }

            switch (trainer)
            {
                case TrainerName.LightGbmBinary:
                    return new LightGbmBinary(node);
                case TrainerName.LightGbmMulti:
                    return new LightGbmMulti(node);
                case TrainerName.LightGbmRegression:
                    return new LightGbmRegression(node);
                case TrainerName.AveragedPerceptronBinary:
                    return new AveragedPerceptron(node);
                case TrainerName.FastForestBinary:
                    return new FastForestClassification(node);
                case TrainerName.FastForestRegression:
                    return new FastForestRegression(node);
                case TrainerName.FastTreeBinary:
                    return new FastTreeClassification(node);
                case TrainerName.FastTreeRegression:
                    return new FastTreeRegression(node);
                case TrainerName.FastTreeTweedieRegression:
                    return new FastTreeTweedie(node);
                case TrainerName.LinearSvmBinary:
                    return new LinearSvm(node);
                case TrainerName.LbfgsLogisticRegressionBinary:
                    return new LbfgsLogisticRegressionBinary(node);
                case TrainerName.LbfgsMaximumEntropyMulti:
                    return new LbfgsMaximumEntropyMulti(node);
                case TrainerName.OnlineGradientDescentRegression:
                    return new OnlineGradientDescentRegression(node);
                case TrainerName.OlsRegression:
                    return new OlsRegression(node);
                case TrainerName.LbfgsPoissonRegression:
                    return new LbfgsPoissonRegression(node);
                case TrainerName.SdcaLogisticRegressionBinary:
                    return new StochasticDualCoordinateAscentBinary(node);
                case TrainerName.SdcaMaximumEntropyMulti:
                    return new StochasticDualCoordinateAscentMulti(node);
                case TrainerName.SdcaRegression:
                    return new StochasticDualCoordinateAscentRegression(node);
                case TrainerName.SgdCalibratedBinary:
                    return new SgdCalibratedBinary(node);
                case TrainerName.SymbolicSgdLogisticRegressionBinary:
                    return new SymbolicSgdLogisticRegressionBinary(node);
                case TrainerName.Ova:
                    return new OneVersusAll(node);
                case TrainerName.ImageClassification:
                    return new ImageClassificationTrainer(node);
                case TrainerName.MatrixFactorization:
                    return new MatrixFactorization(node);
                case TrainerName.LightGbmRanking:
                    return new LightGbmRanking(node);
                case TrainerName.FastTreeRanking:
                    return new FastTreeRanking(node);
                default:
                    // A valid TrainerName member for which no code generator exists.
                    throw new ArgumentException($"The trainer '{trainer}' is not handled currently.");
            }
        }
    }
}