|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers.FastTree;
using Xunit;
using Xunit.Abstractions;
[assembly: LoadableClass(typeof(FastTreeParallelInterfaceChecker),
null, typeof(Microsoft.ML.Trainers.FastTree.SignatureParallelTrainer), "FastTreeParallelInterfaceChecker")]
namespace Microsoft.ML.RunTests
{
using LeafSplitCandidates = LeastSquaresRegressionTreeLearner.LeafSplitCandidates;
using SplitInfo = LeastSquaresRegressionTreeLearner.SplitInfo;
internal sealed class FastTreeParallelInterfaceChecker : Trainers.FastTree.IParallelTraining
{
private bool _isInitEnv = false;
private bool _isInitTreeLearner = false;
private bool _isInitIteration = false;
private bool _isCache = false;
public void CacheHistogram(bool isSmallerLeaf, int featureIdx, int subfeature, SufficientStatsBase sufficientStatsBase, bool hasWeights)
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.True(_isInitIteration);
Assert.NotNull(sufficientStatsBase);
Assert.False(!_isCache);
_isCache = true;
return;
}
public bool IsNeedFindLocalBestSplit()
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.True(_isInitIteration);
return true;
}
public void FindGlobalBestSplit(LeafSplitCandidates smallerChildSplitCandidates,
LeafSplitCandidates largerChildSplitCandidates,
Microsoft.ML.Trainers.FastTree.FindBestThresholdFromRawArrayFun findFunction,
SplitInfo[] bestSplits)
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.True(_isInitIteration);
Assert.True(_isCache);
_isCache = false;
Assert.NotNull(smallerChildSplitCandidates);
Assert.NotNull(bestSplits);
return;
}
public void GetGlobalDataCountInLeaf(int leafIdx, ref int cnt)
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.True(_isInitIteration);
Assert.True(leafIdx >= 0);
return;
}
public bool[] GetLocalBinConstructionFeatures(int numFeatures)
{
Assert.True(_isInitEnv);
Assert.True(numFeatures >= 0);
return Utils.CreateArray<bool>(numFeatures, true);
}
public double[] GlobalMean(Dataset dataset, InternalRegressionTree tree, DocumentPartitioning partitioning, double[] weights, bool filterZeroLambdas)
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.NotNull(dataset);
Assert.NotNull(tree);
Assert.NotNull(partitioning);
double[] means = new double[tree.NumLeaves];
for (int l = 0; l < tree.NumLeaves; ++l)
{
means[l] = partitioning.Mean(weights, dataset.SampleWeights, l, filterZeroLambdas);
}
return means;
}
public void PerformGlobalSplit(int leaf, int lteChild, int gtChild, SplitInfo splitInfo)
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.True(_isInitIteration);
return;
}
public void InitIteration(ref bool[] activeFeatures)
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.False(_isInitIteration);
_isInitIteration = true;
Assert.NotNull(activeFeatures);
return;
}
public void InitEnvironment()
{
Assert.False(_isInitEnv);
_isInitEnv = true;
return;
}
public void InitTreeLearner(Dataset trainData, int maxNumLeaves, int maxCatSplitPoints, ref int minDocInLeaf)
{
Assert.True(_isInitEnv);
Assert.False(_isInitTreeLearner);
_isInitTreeLearner = true;
Assert.NotNull(trainData);
return;
}
public void SyncGlobalBoundary(int numFeatures, int maxBin, Double[][] binUpperBounds)
{
Assert.True(_isInitEnv);
Assert.NotNull(binUpperBounds);
return;
}
public void FinalizeEnvironment()
{
Assert.True(_isInitEnv);
Assert.False(_isInitTreeLearner);
Assert.False(_isInitIteration);
_isInitEnv = false;
return;
}
public void FinalizeTreeLearner()
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.False(_isInitIteration);
_isInitTreeLearner = false;
return;
}
public void FinalizeIteration()
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.True(_isInitIteration);
_isInitIteration = false;
return;
}
public bool IsSkipNonSplittableHistogram()
{
Assert.True(_isInitEnv);
Assert.True(_isInitTreeLearner);
Assert.True(_isInitIteration);
return true;
}
}
public class TestParallelFasttreeInterface : BaseTestBaseline
{
public TestParallelFasttreeInterface(ITestOutputHelper helper)
: base(helper)
{
}
[Fact(Skip = "'checker' is not a valid value for the 'parag' argument in FastTree")]
[TestCategory("ParallelFasttree")]
public void CheckFastTreeParallelInterface()
{
var dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
var outRoot = @"..\Common\CheckInterface";
var modelOutPath = DeleteOutputPath(outRoot, "codegen-model.zip");
var trainArgs = string.Format(
"train data={{{0}}} loader=Text{{col=Label:0 col=F!1:1-5 col=F2:6-9}} xf=Concat{{col=Features:F!1,F2}} tr=FastTreeBinaryClassification{{lr=0.1 nl=12 mil=10 iter=1 parag=checker}} out={{{1}}}",
dataPath, modelOutPath);
var res = MainForTest(trainArgs);
Assert.Equal(0, res);
}
}
}
|