File: Training\StepSearch.cs
Web Access
Project: src\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj (Microsoft.ML.FastTree)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// Strategy for adjusting the outputs of a newly trained tree before it is used.
    /// The implementation in this file, <see cref="LineSearch"/>, picks a multiplier and
    /// scales the tree outputs by it.
    /// </summary>
    internal interface IStepSearch
    {
        /// <summary>
        /// Adjusts the outputs of <paramref name="tree"/> in place.
        /// </summary>
        /// <param name="ch">Channel available to the implementation for messages (<see cref="LineSearch"/> logs the chosen multiplier to it).</param>
        /// <param name="tree">The tree whose outputs are adjusted.</param>
        /// <param name="partitioning">Document partitioning associated with <paramref name="tree"/>; presumably maps documents to leaves - confirm against callers.</param>
        /// <param name="trainingScores">The scores as they stand before <paramref name="tree"/> is applied.</param>
        void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores);
    }
 
    /// <summary>
    /// Step search that evaluates a loss at several candidate multipliers for a tree's outputs,
    /// brackets a good multiplier by growing or shrinking the step by the golden ratio, optionally
    /// narrows the bracket with a fixed number of further steps, and finally scales the tree
    /// outputs by the chosen multiplier.
    /// </summary>
    /// <remarks>
    /// Throughout the search a <see cref="TestResult"/> that compares greater is the one the search
    /// moves toward (presumably "greater" means "better" in <c>TestResult.CompareTo</c> - confirm).
    /// The multiplier chosen by one call is used as the starting point of the next call.
    /// </remarks>
    internal sealed class LineSearch : IStepSearch, IFastTrainingScoresUpdate
    {
        // Golden ratio; the factor by which candidate steps are grown and shrunk.
        private static readonly double _phi = (1.0 + Math.Sqrt(5)) / 2;

        // Multiplier chosen by the previous call to AdjustTreeOutputs (initial value set in the constructor).
        private double _historicStepSize;
        // Maximum number of refinement iterations performed once a bracket has been found.
        private readonly int _numPostbracketSteps;
        // Lower bound on the step; the search stops shrinking once it reaches this value.
        private readonly double _minStepSize;

        // The four candidate points of the search. The objects are recycled between roles via
        // Rotate/Swap so that their score buffers can be reused. _left is the point that is
        // ultimately returned.
        private StepScoresAndLoss _lo;
        private StepScoresAndLoss _left;
        private StepScoresAndLoss _right;
        private StepScoresAndLoss _hi;

        /// <summary>
        /// Creates a line search with post-bracket refinement and a minimum step size.
        /// </summary>
        /// <param name="lossCalculator">Test used to compute the loss of each candidate step.</param>
        /// <param name="lossIndex">Index, within the results returned by <paramref name="lossCalculator"/>, of the loss to compare.</param>
        /// <param name="numPostbracketSteps">Maximum number of refinement iterations after bracketing; values less than or equal to zero disable refinement.</param>
        /// <param name="minStepSize">Smallest step the search is allowed to settle on.</param>
        public LineSearch(Test lossCalculator, int lossIndex, int numPostbracketSteps, double minStepSize)
        {
            _numPostbracketSteps = numPostbracketSteps;
            _minStepSize = minStepSize;

            _lo = new StepScoresAndLoss(lossCalculator, lossIndex);
            _hi = new StepScoresAndLoss(lossCalculator, lossIndex);
            _left = new StepScoresAndLoss(lossCalculator, lossIndex);
            _right = new StepScoresAndLoss(lossCalculator, lossIndex);

            // Must be computed after _minStepSize has been assigned. Previously this ran in the
            // chained-to constructor, before the assignment, so the minimum step was always
            // seen as 0 and the initial step was always 1.0.
            _historicStepSize = Math.Max(1.0, _minStepSize);
        }

        /// <summary>
        /// Creates a line search with no post-bracket refinement and a minimum step size of zero.
        /// </summary>
        /// <param name="lossCalculator">Test used to compute the loss of each candidate step.</param>
        /// <param name="lossIndex">Index, within the results returned by <paramref name="lossCalculator"/>, of the loss to compare.</param>
        public LineSearch(Test lossCalculator, int lossIndex)
            : this(lossCalculator, lossIndex, 0, 0.0)
        {
        }

        /// <summary>Exchanges the values of <paramref name="a"/> and <paramref name="b"/>.</summary>
        private static void Swap<T>(ref T a, ref T b)
        {
            T t = a;
            a = b;
            b = t;
        }

        /// <summary>
        /// Rotates three values one position: <paramref name="a"/> receives <paramref name="b"/>,
        /// <paramref name="b"/> receives <paramref name="c"/>, and <paramref name="c"/> receives the old <paramref name="a"/>.
        /// </summary>
        private static void Rotate<T>(ref T a, ref T b, ref T c)
        {
            T t = a;
            a = b;
            b = c;
            c = t;
        }

        /// <summary>
        /// One candidate point of the search: a step size together with the scores and the loss
        /// obtained when the tree is applied with that step.
        /// </summary>
        private sealed class StepScoresAndLoss
        {
            private readonly Test _lossCalculator;
            private readonly int _lossIndex;

            private InternalRegressionTree _tree;
            private DocumentPartitioning _partitioning;
            private ScoreTracker _previousScores;

            private double _step;

            /// <summary>Scores for the current <see cref="Step"/>; null until a step has been set, or after the owner has handed the tracker out.</summary>
            public ScoreTracker Scores;
            /// <summary>Loss computed from <see cref="Scores"/> the last time <see cref="Step"/> was set.</summary>
            public TestResult Loss;

            public StepScoresAndLoss(Test lossCalculator, int lossIndex)
            {
                _lossCalculator = lossCalculator;
                _lossIndex = lossIndex;
            }

            /// <summary>
            /// Records the tree, partitioning and baseline scores that subsequent assignments to
            /// <see cref="Step"/> are evaluated against. Does not compute anything by itself.
            /// </summary>
            public void Initialize(InternalRegressionTree tree, DocumentPartitioning partitioning, ScoreTracker previousScores)
            {
                _tree = tree;
                _partitioning = partitioning;
                _previousScores = previousScores;
            }

            /// <summary>
            /// The step size of this point. Setting it re-initializes <see cref="Scores"/> for the
            /// new step and recomputes <see cref="Loss"/>, so every assignment costs a full loss evaluation.
            /// </summary>
            public double Step
            {
                get { return _step; }
                set
                {
                    // Allocate a tracker if there is none, or if the existing one belongs to a different dataset.
                    if (Scores == null || Scores.Dataset != _previousScores.Dataset)
                        Scores = new ScoreTracker(_previousScores);
                    _step = value;
                    Scores.Initialize(_previousScores, _tree, _partitioning, _step);
                    Loss = _lossCalculator.ComputeTests(Scores.Scores).ToList()[_lossIndex];
                }
            }
        }

        /// <summary>
        /// Searches for a multiplier for the outputs of <paramref name="tree"/>, logs it to
        /// <paramref name="ch"/>, remembers it as the starting point for the next call, and scales
        /// the tree outputs by it.
        /// </summary>
        /// <param name="ch">Channel that receives the "multiplier" info message.</param>
        /// <param name="tree">Tree whose outputs are scaled by the chosen multiplier.</param>
        /// <param name="partitioning">Partitioning passed through to the score computation for each candidate step.</param>
        /// <param name="previousScores">Baseline scores that each candidate step is evaluated against.</param>
        void IStepSearch.AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree, DocumentPartitioning partitioning,
            ScoreTracker previousScores)
        {
            _lo.Initialize(tree, partitioning, previousScores);
            _hi.Initialize(tree, partitioning, previousScores);
            _left.Initialize(tree, partitioning, previousScores);
            _right.Initialize(tree, partitioning, previousScores);

            // Start from the previous multiplier and one golden-ratio step below it.
            _lo.Step = _historicStepSize / _phi;
            _left.Step = _historicStepSize;

            // CompareTo results are tested by sign only, as the IComparable contract specifies.
            if (_lo.Loss.CompareTo(_left.Loss) > 0) // backtrack
            {
                do
                {
                    // Shift the bracket down: hi <- left, left <- lo; the old hi object is recycled as the new lo.
                    Rotate(ref _hi, ref _left, ref _lo);
                    if (_hi.Step <= _minStepSize)
                        goto FINISHED;
                    _lo.Step = _left.Step / _phi;
                } while (_lo.Loss.CompareTo(_left.Loss) > 0);
            }
            else // extend (or stay)
            {
                _hi.Step = _historicStepSize * _phi;
                while (_hi.Loss.CompareTo(_left.Loss) > 0)
                {
                    // Shift the bracket up: lo <- left, left <- hi; the old lo object is recycled as the new hi.
                    Rotate(ref _lo, ref _left, ref _hi);
                    _hi.Step = _left.Step * _phi;
                }
            }

            if (_numPostbracketSteps > 0)
            {
                // Refine inside [lo, hi] using a second interior point, _right.
                _right.Step = _lo.Step + (_hi.Step - _lo.Step) / _phi;
                for (int step = 0; step < _numPostbracketSteps; ++step)
                {
                    int cmp = _right.Loss.CompareTo(_left.Loss);
                    if (cmp == 0)
                        break;

                    if (cmp > 0) // move right
                    {
                        Rotate(ref _lo, ref _left, ref _right);
                        _right.Step = _lo.Step + (_hi.Step - _lo.Step) / _phi;
                    }
                    else // move left
                    {
                        Rotate(ref _hi, ref _right, ref _left);
                        if (_hi.Step <= _minStepSize)
                            goto FINISHED;
                        _left.Step = _hi.Step - (_hi.Step - _lo.Step) / _phi;
                    }
                }

                // prepare to return _left
                if (_right.Loss.CompareTo(_left.Loss) > 0)
                    Swap(ref _left, ref _right);
            }

FINISHED:
            // Enforce the minimum step. If the upper point fell below the minimum, evaluate _left
            // exactly at the minimum; if it sits exactly on the minimum, reuse it as the result.
            if (_hi.Step < _minStepSize)
                _left.Step = _minStepSize;
            else if (_hi.Step == _minStepSize)
                Swap(ref _hi, ref _left);

            double bestStep = _left.Step;

            ch.Info("multiplier: {0}", bestStep);
            _historicStepSize = bestStep;
            tree.ScaleOutputsBy(bestStep);
        }

        /// <summary>
        /// Returns the scores computed for the step chosen by the most recent call to
        /// <see cref="IStepSearch.AdjustTreeOutputs"/> and relinquishes ownership of them.
        /// </summary>
        /// <returns>The score tracker held by the chosen point; null if it has already been handed out or no step has been evaluated yet.</returns>
        ScoreTracker IFastTrainingScoresUpdate.GetUpdatedTrainingScores()
        {
            ScoreTracker result = _left.Scores;
            _left.Scores = null; //We need to set it to null so that next call to AdjustTreeOutputs will not destroy returned object
            return result;
        }
    }
}