File: Evaluators\RegressionEvaluatorBase.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
 
namespace Microsoft.ML.Data
{
    [BestFriend]
    internal abstract class RegressionLossEvaluatorBase<TAgg> : RowToRowEvaluatorBase<TAgg>
        where TAgg : EvaluatorBase<TAgg>.AggregatorBase
    {
        public abstract class ArgumentsBase
        {
            [Argument(ArgumentType.Multiple, HelpText = "Loss function", ShortName = "loss")]
            public ISupportRegressionLossFactory LossFunction = new SquaredLossFactory();
        }
 
        public const string L1 = "L1(avg)";
        public const string L2 = "L2(avg)";
        public const string Rms = "RMS(avg)";
        public const string Loss = "Loss-fn(avg)";
        public const string RSquared = "R Squared";
 
        protected readonly IRegressionLoss LossFunction;
 
        protected RegressionLossEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName)
            : base(env, registrationName)
        {
            Host.CheckUserArg(args.LossFunction != null, nameof(args.LossFunction), "Loss function must be specified.");
            LossFunction = args.LossFunction.CreateComponent(env);
        }
    }
 
    [BestFriend]
    internal abstract class RegressionEvaluatorBase<TAgg, TScore, TMetrics> : RegressionLossEvaluatorBase<TAgg>
        where TAgg : RegressionEvaluatorBase<TAgg, TScore, TMetrics>.RegressionAggregatorBase
    {
        [BestFriend]
        private protected RegressionEvaluatorBase(ArgumentsBase args, IHostEnvironment env, string registrationName)
            : base(args, env, registrationName)
        {
        }
 
        private protected override void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries,
            out Action<uint, ReadOnlyMemory<char>, TAgg> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
        {
            var stratCol = new List<uint>();
            var stratVal = new List<ReadOnlyMemory<char>>();
            var isWeighted = new List<bool>();
            var l1 = new List<TMetrics>();
            var l2 = new List<TMetrics>();
            var rms = new List<TMetrics>();
            var loss = new List<TMetrics>();
            var rSquared = new List<TMetrics>();
 
            bool hasStrats = Utils.Size(dictionaries) > 0;
            bool hasWeight = aggregator.Weighted;
 
            addAgg =
                (stratColKey, stratColVal, agg) =>
                {
                    Host.Check(agg.Weighted == hasWeight, "All aggregators must either be weighted or unweighted");
 
                    stratCol.Add(stratColKey);
                    stratVal.Add(stratColVal);
                    isWeighted.Add(false);
                    l1.Add(agg.UnweightedCounters.L1);
                    l2.Add(agg.UnweightedCounters.L2);
                    rms.Add(agg.UnweightedCounters.Rms);
                    loss.Add(agg.UnweightedCounters.Loss);
                    rSquared.Add(agg.UnweightedCounters.RSquared);
                    if (agg.Weighted)
                    {
                        stratCol.Add(stratColKey);
                        stratVal.Add(stratColVal);
                        isWeighted.Add(true);
                        l1.Add(agg.WeightedCounters.L1);
                        l2.Add(agg.WeightedCounters.L2);
                        rms.Add(agg.WeightedCounters.Rms);
                        loss.Add(agg.WeightedCounters.Loss);
                        rSquared.Add(agg.WeightedCounters.RSquared);
                    }
                };
 
            consolidate =
                () =>
                {
                    var overallDvBldr = new ArrayDataViewBuilder(Host);
                    if (hasStrats)
                    {
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, stratCol.ToArray());
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, stratVal.ToArray());
                    }
                    if (hasWeight)
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BooleanDataViewType.Instance, isWeighted.ToArray());
                    aggregator.AddColumn(overallDvBldr, L1, l1.ToArray());
                    aggregator.AddColumn(overallDvBldr, L2, l2.ToArray());
                    aggregator.AddColumn(overallDvBldr, Rms, rms.ToArray());
                    aggregator.AddColumn(overallDvBldr, Loss, loss.ToArray());
                    aggregator.AddColumn(overallDvBldr, RSquared, rSquared.ToArray());
 
                    var result = new Dictionary<string, IDataView>();
                    result.Add(MetricKinds.OverallMetrics, overallDvBldr.GetDataView());
                    return result;
                };
        }
 
        public abstract class RegressionAggregatorBase : AggregatorBase
        {
            public abstract class CountersBase
            {
                protected Double SumWeights;
                protected TMetrics TotalL1Loss;
                protected TMetrics TotalL2Loss;
                protected TMetrics TotalLoss;
                protected Double TotalLabelW;
                protected Double TotalLabelSquaredW;
 
                public TMetrics L1
                {
                    get
                    {
                        var res = Zero();
                        if (SumWeights > 0)
                            Normalize(in TotalL1Loss, ref res);
                        return res;
                    }
                }
 
                public TMetrics L2
                {
                    get
                    {
                        var res = Zero();
                        if (SumWeights > 0)
                            Normalize(in TotalL2Loss, ref res);
                        return res;
                    }
                }
 
                public abstract TMetrics Rms { get; }
 
                //Note this can be NaN if regressor reports loss as NaN
                public TMetrics Loss
                {
                    get
                    {
                        var res = Zero();
                        if (SumWeights > 0)
                            Normalize(in TotalLoss, ref res);
                        return res;
                    }
                }
 
                public abstract TMetrics RSquared { get; }
 
                public void Update(ref TScore score, float label, float weight, ref TMetrics loss)
                {
                    SumWeights += weight;
                    TotalLabelW += label * weight;
                    TotalLabelSquaredW += label * label * weight;
                    UpdateCore(label, in score, in loss, weight);
                }
 
                protected abstract void UpdateCore(float label, in TScore score, in TMetrics loss, float weight);
 
                protected abstract void Normalize(in TMetrics src, ref TMetrics dst);
 
                protected abstract TMetrics Zero();
            }
 
            private ValueGetter<float> _labelGetter;
            private ValueGetter<TScore> _scoreGetter;
            private ValueGetter<float> _weightGetter;
            protected TScore Score;
            protected TMetrics Loss;
 
            protected readonly IRegressionLoss LossFunction;
 
            public readonly bool Weighted;
 
            public abstract CountersBase UnweightedCounters { get; }
            public abstract CountersBase WeightedCounters { get; }
 
            [BestFriend]
            private protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName)
                : base(env, stratName)
            {
                Host.AssertValue(lossFunction);
                LossFunction = lossFunction;
                Weighted = weighted;
            }
 
            internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema schema)
            {
                Contracts.Assert(PassNum < 1);
                Contracts.Assert(schema.Label.HasValue);
 
                var score = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
 
                _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
                _scoreGetter = row.GetGetter<TScore>(score);
                Contracts.AssertValue(_labelGetter);
                Contracts.AssertValue(_scoreGetter);
 
                if (schema.Weight.HasValue)
                    _weightGetter = row.GetGetter<float>(schema.Weight.Value);
            }
 
            public override void ProcessRow()
            {
                float label = 0;
                _labelGetter(ref label);
                _scoreGetter(ref Score);
 
                if (float.IsNaN(label))
                {
                    NumUnlabeledInstances++;
                    return;
                }
 
                if (IsNaN(in Score))
                {
                    NumBadScores++;
                    return;
                }
 
                float weight = 1;
                if (_weightGetter != null)
                {
                    _weightGetter(ref weight);
                    if (!FloatUtils.IsFinite(weight))
                    {
                        NumBadWeights++;
                        weight = 1;
                    }
                }
 
                ApplyLossFunction(in Score, label, ref Loss);
                UnweightedCounters.Update(ref Score, label, 1, ref Loss);
                if (WeightedCounters != null)
                    WeightedCounters.Update(ref Score, label, weight, ref Loss);
            }
 
            protected abstract void ApplyLossFunction(in TScore score, float label, ref TMetrics loss);
 
            protected abstract bool IsNaN(in TScore score);
 
            public abstract void AddColumn(ArrayDataViewBuilder dvBldr, string metricName, params TMetrics[] metric);
        }
    }
}