2 writes to FastTreeTrainerOptions
Microsoft.ML.FastTree (2)
FastTree.cs (2)
113FastTreeTrainerOptions = new TOptions(); 147FastTreeTrainerOptions = options;
249 references to FastTreeTrainerOptions
Microsoft.ML.FastTree (249)
BoostingFastTree.cs (48)
31FastTreeTrainerOptions.LearningRate = learningRate; 36if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeOptions.OptimizationAlgorithmType.AcceleratedGradientDescent) 37FastTreeTrainerOptions.UseLineSearch = true; 38if (FastTreeTrainerOptions.OptimizationAlgorithm == BoostedTreeOptions.OptimizationAlgorithmType.ConjugateGradientDescent) 39FastTreeTrainerOptions.UseLineSearch = true; 41if (FastTreeTrainerOptions.CompressEnsemble && FastTreeTrainerOptions.WriteLastEnsemble) 44if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1) 47if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1) 50if (FastTreeTrainerOptions.EnablePruning && !HasValidSet) 53bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null; 57if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet)) 66TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.EntropyCoefficient, 67FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature, 68FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.FilterZeroLambdas, 69FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode, 70FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, BsrMaxTreeOutput(), ParallelTraining, 71FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, 72FastTreeTrainerOptions.Bias, Host); 81switch (FastTreeTrainerOptions.OptimizationAlgorithm) 93throw ch.Except("Unknown optimization algorithm '{0}'", FastTreeTrainerOptions.OptimizationAlgorithm); 98optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing; 99optimizationAlgorithm.DropoutRate = FastTreeTrainerOptions.DropoutRate; 100optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.Seed); 108if (!FastTreeTrainerOptions.BestStepRankingRegressionTrees) 121if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null) 137if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null) 138earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter); 157if (!FastTreeTrainerOptions.WriteLastEnsemble && PruningTest != null) 170if (FastTreeTrainerOptions.BestStepRankingRegressionTrees) 171return FastTreeTrainerOptions.MaximumTreeOutput; 178return FastTreeTrainerOptions.RandomStart;
FastTree.cs (58)
94private protected string InnerOptions => CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions()); 117FastTreeTrainerOptions.NumberOfLeaves = numberOfLeaves; 118FastTreeTrainerOptions.NumberOfTrees = numberOfTrees; 119FastTreeTrainerOptions.MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf; 121FastTreeTrainerOptions.LabelColumnName = label.Name; 122FastTreeTrainerOptions.FeatureColumnName = featureColumnName; 123FastTreeTrainerOptions.ExampleWeightColumnName = exampleWeightColumnName; 124FastTreeTrainerOptions.RowGroupColumnName = rowGroupColumnName; 180ParallelTraining = FastTreeTrainerOptions.ParallelTrainer != null ? FastTreeTrainerOptions.ParallelTrainer.CreateComponent(env) : new SingleTrainer(); 185InitializeThreads(FastTreeTrainerOptions.NumberOfThreads ?? Environment.ProcessorCount); 191var useTranspose = UseTranspose(FastTreeTrainerOptions.DiskTranspose, trainData) && (ValidData == null || UseTranspose(FastTreeTrainerOptions.DiskTranspose, ValidData)); 192var instanceConverter = new ExamplesToFastTreeBins(Host, FastTreeTrainerOptions.MaximumBinCountPerFeature, useTranspose, !FastTreeTrainerOptions.FeatureFlocks, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, GetMaxLabel()); 194TrainSet = instanceConverter.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit); 197ValidSet = instanceConverter.GetCompatibleDataset(ValidData, PredictionKind, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit); 199TestSets = new[] { instanceConverter.GetCompatibleDataset(TestData, PredictionKind, CategoricalFeatures, FastTreeTrainerOptions.CategoricalSplit) }; 226if (FastTreeTrainerOptions.MemoryStatistics) 231if (FastTreeTrainerOptions.ExecutionTime) 259FastTreeTrainerOptions.Check(ch); 261IntArray.CompatibilityLevel = FastTreeTrainerOptions.FeatureCompressionLevel; 264if (FastTreeTrainerOptions.HistogramPoolSize < 2) 265FastTreeTrainerOptions.HistogramPoolSize = FastTreeTrainerOptions.NumberOfLeaves * 2 / 3; 266if (FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1) 267FastTreeTrainerOptions.HistogramPoolSize = FastTreeTrainerOptions.NumberOfLeaves - 1; 269if (FastTreeTrainerOptions.BaggingSize > 0) 271int bagCount = FastTreeTrainerOptions.NumberOfTrees / FastTreeTrainerOptions.BaggingSize; 272if (bagCount * FastTreeTrainerOptions.BaggingSize != FastTreeTrainerOptions.NumberOfTrees) 276if (!(0 <= FastTreeTrainerOptions.GainConfidenceLevel && FastTreeTrainerOptions.GainConfidenceLevel < 1)) 327if (!FastTreeTrainerOptions.PrintTestGraph) 413if (FastTreeTrainerOptions.FeatureFraction < 1.0) 416_featureSelectionRandom = new Random(FastTreeTrainerOptions.FeatureSelectionSeed); 421activeFeatures[i] = _featureSelectionRandom.NextDouble() <= FastTreeTrainerOptions.FeatureFraction; 585Contracts.Assert(FastTreeTrainerOptions.BaggingSize > 0); 586return new BaggingProvider(TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.BaggingExampleFraction); 597int numTotalTrees = FastTreeTrainerOptions.NumberOfTrees; 617OptimizationAlgorithm.TrainingScores.RandomizeScores(FastTreeTrainerOptions.Seed, false); 623BaggingProvider baggingProvider = FastTreeTrainerOptions.BaggingSize > 0 ? CreateBaggingProvider() : null; 661if (FastTreeTrainerOptions.BaggingSize > 0 && Ensemble.NumTrees % FastTreeTrainerOptions.BaggingSize == 0) 684else if (FastTreeTrainerOptions.BaggingSize > 0 && Ensemble.Trees.Count() > 0) 707OptimizationAlgorithm.TrainingScores.RandomizeScores(FastTreeTrainerOptions.Seed, true); 798if (FastTreeTrainerOptions.TestFrequency != int.MaxValue && (Ensemble.NumTrees % FastTreeTrainerOptions.TestFrequency == 0 || Ensemble.NumTrees == FastTreeTrainerOptions.NumberOfTrees)) 818ch.Trace("CommandLine = {0}", CmdParser.GetSettings(Host, FastTreeTrainerOptions, new TOptions())); 820ch.Trace("{0}", FastTreeTrainerOptions); 848if (!FastTreeTrainerOptions.CompressEnsemble)
FastTreeClassification.cs (17)
167_sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRate; 179_sigmoidParameter = 2.0 * FastTreeTrainerOptions.LearningRate; 221FastTreeTrainerOptions.LearningRate, 222FastTreeTrainerOptions.Shrinkage, 224FastTreeTrainerOptions.UnbalancedSets, 225FastTreeTrainerOptions.MaximumTreeOutput, 226FastTreeTrainerOptions.GetDerivativesSampleRate, 227FastTreeTrainerOptions.BestStepRankingRegressionTrees, 228FastTreeTrainerOptions.Seed, 235if (FastTreeTrainerOptions.UseLineSearch) 239optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, FastTreeTrainerOptions.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/, FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps, FastTreeTrainerOptions.MinimumStepSize); 286if (FastTreeTrainerOptions.EnablePruning && ValidSet != null) 288if (!FastTreeTrainerOptions.UseTolerantPruning) 296PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold);
FastTreeRanking.cs (40)
159Host.AssertValue(FastTreeTrainerOptions.CustomGains); 160return FastTreeTrainerOptions.CustomGains; 172if (FastTreeTrainerOptions.CustomGains != null) 174var gains = FastTreeTrainerOptions.CustomGains; 177throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains), 185bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || 186FastTreeTrainerOptions.EnablePruning; 189ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3, 190nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3."); 198if (FastTreeTrainerOptions.CompressEnsemble) 201_ensembleCompressor.Initialize(FastTreeTrainerOptions.NumberOfTrees, TrainSet, TrainSet.Ratings, FastTreeTrainerOptions.Seed); 207return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, FastTreeTrainerOptions, ParallelTraining); 213if (FastTreeTrainerOptions.UseLineSearch) 215_specialTrainSetTest = new FastNdcgTest(optimizationAlgorithm.TrainingScores, TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm, FastTreeTrainerOptions.EarlyStoppingMetrics); 216optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(_specialTrainSetTest, 0, FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps, FastTreeTrainerOptions.MinimumStepSize); 223Host.Assert(FastTreeTrainerOptions.BaggingSize > 0); 224return new RankingBaggingProvider(TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.BaggingExampleFraction); 233return new NdcgTest(ConstructScoreTracker(TrainSet), TrainSet.Ratings, FastTreeTrainerOptions.SortingAlgorithm); 238if (FastTreeTrainerOptions.TestFrequency != int.MaxValue) 243if (FastTreeTrainerOptions.PrintTestGraph) 260if (FastTreeTrainerOptions.PrintTrainValidGraph && FastTreeTrainerOptions.EnablePruning && _specialTrainSetTest == null) 265if (FastTreeTrainerOptions.EnablePruning && ValidTest != null) 267if (!FastTreeTrainerOptions.UseTolerantPruning) 275PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold); 397if (tree != null && FastTreeTrainerOptions.CompressEnsemble) 417FastTreeTrainerOptions.SortingAlgorithm); 430FastTreeTrainerOptions.SortingAlgorithm, 431FastTreeTrainerOptions.EarlyStoppingMetrics); 443FastTreeTrainerOptions.SortingAlgorithm, 444FastTreeTrainerOptions.EarlyStoppingMetrics); 464if (FastTreeTrainerOptions.PrintTrainValidGraph) 469FastTreeTrainerOptions.EarlyStoppingMetrics);
FastTreeRegression.cs (20)
133bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || 134FastTreeTrainerOptions.EnablePruning; 137ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2, 138nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); 148return new ObjectiveImpl(TrainSet, FastTreeTrainerOptions); 154if (FastTreeTrainerOptions.UseLineSearch) 158optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps, FastTreeTrainerOptions.MinimumStepSize); 259if (FastTreeTrainerOptions.TestFrequency != int.MaxValue) 262if (FastTreeTrainerOptions.PrintTestGraph) 274if (FastTreeTrainerOptions.PrintTrainValidGraph && _trainRegressionTest == null) 280if (FastTreeTrainerOptions.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0) 284TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), FastTreeTrainerOptions.EarlyStoppingMetrics); 286ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), FastTreeTrainerOptions.EarlyStoppingMetrics); 288if (FastTreeTrainerOptions.EnablePruning && ValidTest != null) 290if (FastTreeTrainerOptions.UseTolerantPruning) // Use simple early stopping condition. 291PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold); 341if (FastTreeTrainerOptions.PrintTrainValidGraph)
FastTreeTweedie.cs (24)
146if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0) 149bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || 150FastTreeTrainerOptions.EnablePruning; 154ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2, 155nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm)."); 160return new ObjectiveImpl(TrainSet, FastTreeTrainerOptions); 166if (FastTreeTrainerOptions.UseLineSearch) 171optimizationAlgorithm.AdjustTreeOutputsOverride = new LineSearch(lossCalculator, 1 /*L2 error*/, FastTreeTrainerOptions.MaximumNumberOfLineSearchSteps, FastTreeTrainerOptions.MinimumStepSize); 207Host.CheckUserArg(1 <= FastTreeTrainerOptions.Index && FastTreeTrainerOptions.Index <= 2, nameof(FastTreeTrainerOptions.Index), "Must be in the range [1, 2]"); 238if (FastTreeTrainerOptions.TestFrequency != int.MaxValue) 241if (FastTreeTrainerOptions.PrintTestGraph) 253if (FastTreeTrainerOptions.PrintTrainValidGraph && _trainRegressionTest == null) 259if (FastTreeTrainerOptions.PrintTrainValidGraph && _testRegressionTest == null && TestSets != null && TestSets.Length > 0) 263TrainTest = new RegressionTest(ConstructScoreTracker(TrainSet), FastTreeTrainerOptions.EarlyStoppingMetrics); 265ValidTest = new RegressionTest(ConstructScoreTracker(ValidSet), FastTreeTrainerOptions.EarlyStoppingMetrics); 267if (FastTreeTrainerOptions.EnablePruning && ValidTest != null) 269if (FastTreeTrainerOptions.UseTolerantPruning) // Use simple early stopping condition. 270PruningTest = new TestWindowWithTolerance(ValidTest, 0, FastTreeTrainerOptions.PruningWindowSize, FastTreeTrainerOptions.PruningThreshold); 285if (FastTreeTrainerOptions.PrintTrainValidGraph)
RandomForest.cs (19)
50optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing; 66TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.EntropyCoefficient, 67FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature, 68FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, 69FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode, 70FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, _quantileEnabled, FastTreeTrainerOptions.NumberOfQuantileSamples, ParallelTraining, 71FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, 72FastTreeTrainerOptions.Bias, Host);
RandomForestClassification.cs (11)
229if (FastTreeTrainerOptions.FeatureFraction != 1.0) 231ch.Warning($"oneDAL decision forest doesn't support 'FeatureFraction'[per tree] != 1.0, changing it from {FastTreeTrainerOptions.FeatureFraction} to 1.0"); 232FastTreeTrainerOptions.FeatureFraction = 1.0; 274int numberOfLeaves = FastTreeTrainerOptions.NumberOfLeaves; 275int numberOfTrees = FastTreeTrainerOptions.NumberOfTrees; 278if (FastTreeTrainerOptions.NumberOfThreads.HasValue) 279numberOfThreads = FastTreeTrainerOptions.NumberOfThreads.Value; 305numberOfThreads, (float)FastTreeTrainerOptions.FeatureFractionPerSplit, numberOfTrees, 306numberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.MaximumBinCountPerFeature, 350return new ObjectiveFunctionImpl(TrainSet, _trainSetLabels, FastTreeTrainerOptions);
RandomForestRegression.cs (12)
368if (FastTreeTrainerOptions.FeatureFraction != 1.0) 370ch.Warning($"oneDAL decision forest doesn't support 'FeatureFraction'[per tree] != 1.0, changing it from {FastTreeTrainerOptions.FeatureFraction} to 1.0"); 371FastTreeTrainerOptions.FeatureFraction = 1.0; 384return new FastForestRegressionModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions, FastTreeTrainerOptions.NumberOfQuantileSamples); 406int numberOfLeaves = FastTreeTrainerOptions.NumberOfLeaves; 407int numberOfTrees = FastTreeTrainerOptions.NumberOfTrees; 410if (FastTreeTrainerOptions.NumberOfThreads.HasValue) 411numberOfThreads = FastTreeTrainerOptions.NumberOfThreads.Value; 437numberOfThreads, (float)FastTreeTrainerOptions.FeatureFractionPerSplit, numberOfTrees, 438numberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.MaximumBinCountPerFeature, 486return ObjectiveFunctionImplBase.Create(TrainSet, FastTreeTrainerOptions);