File: Evaluators\RegressionEvaluator.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
 
// LoadableClass declaration for RegressionEvaluator: arguments type RegressionEvaluator.Arguments,
// signature SignatureEvaluator, user name "Regression Evaluator", load names LoadName and "Regression".
[assembly: LoadableClass(typeof(RegressionEvaluator), typeof(RegressionEvaluator), typeof(RegressionEvaluator.Arguments), typeof(SignatureEvaluator),
    "Regression Evaluator", RegressionEvaluator.LoadName, "Regression")]

// LoadableClass declaration for the MAML wrapper (RegressionMamlEvaluator, signature SignatureMamlEvaluator),
// sharing the same user name and load names as the plain evaluator above.
[assembly: LoadableClass(typeof(RegressionMamlEvaluator), typeof(RegressionMamlEvaluator), typeof(RegressionMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
    "Regression Evaluator", RegressionEvaluator.LoadName, "Regression")]

// This is for deserialization from a binary model file: RegressionPerInstanceEvaluator is loaded through
// SignatureLoadRowMapper, keyed by its LoaderSignature.
[assembly: LoadableClass(typeof(RegressionPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper),
    "", RegressionPerInstanceEvaluator.LoaderSignature)]
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Evaluator for scored regression data. It reports the L1, L2, RMS, loss-function and R-squared
    /// metrics, and requires both the label column and the score column to be of type Single.
    /// </summary>
    [BestFriend]
    internal sealed class RegressionEvaluator :
        RegressionEvaluatorBase<RegressionEvaluator.Aggregator, float, double>
    {
        /// <summary>
        /// Arguments for this evaluator. No settings are added beyond those of the base arguments class.
        /// </summary>
        public sealed class Arguments : ArgumentsBase
        {
        }

        /// <summary>
        /// The overall metrics reported by this evaluator. Each member is displayed using the
        /// matching metric-name constant.
        /// </summary>
        public enum Metrics
        {
            [EnumValueDisplay(RegressionEvaluator.L1)]
            L1,
            [EnumValueDisplay(RegressionEvaluator.L2)]
            L2,
            [EnumValueDisplay(RegressionEvaluator.Rms)]
            Rms,
            [EnumValueDisplay(RegressionEvaluator.Loss)]
            Loss,
            [EnumValueDisplay(RegressionEvaluator.RSquared)]
            RSquared,
        }

        public const string LoadName = "RegressionEvaluator";

        public RegressionEvaluator(IHostEnvironment env, Arguments args)
            : base(args, env, LoadName)
        {
        }

        /// <summary>
        /// Verifies that the schema has a unique score column and a label column, and that both are of
        /// type Single. Throws a schema-mismatch exception otherwise.
        /// </summary>
        private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
        {
            var score = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
            var t = score.Type;
            if (t != NumberDataViewType.Single)
                throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "Single", t.ToString());
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column");
            t = schema.Label.Value.Type;
            if (t != NumberDataViewType.Single)
                throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "Single", t.ToString());
        }

        /// <summary>
        /// Creates the metric aggregator. The weighted flag is set when the schema has a weight column.
        /// </summary>
        private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
        {
            return new Aggregator(Host, LossFunction, schema.Weight.HasValue, stratName);
        }

        /// <summary>
        /// Creates the row mapper that produces per-instance L1/L2 losses from the score and label columns.
        /// </summary>
        private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema)
        {
            // Use the host (like CheckScoreAndLabelTypes does) so the failure carries this evaluator's context.
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column");
            var scoreInfo = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);

            return new RegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Value.Name);
        }

        /// <summary>
        /// Describes the overall metric columns. The loss-style metrics are marked as objectives to minimize.
        /// R-squared uses the default objective.
        /// </summary>
        public override IEnumerable<MetricColumn> GetOverallMetricColumns()
        {
            yield return new MetricColumn("L1", L1, MetricColumn.Objective.Minimize);
            yield return new MetricColumn("L2", L2, MetricColumn.Objective.Minimize);
            yield return new MetricColumn("Rms", Rms, MetricColumn.Objective.Minimize);
            yield return new MetricColumn("Loss", Loss, MetricColumn.Objective.Minimize);
            yield return new MetricColumn("RSquared", RSquared);
        }

        /// <summary>
        /// Aggregates regression metrics over float labels/scores. It keeps one set of unweighted counters
        /// and, when weighted, a second set of weighted counters.
        /// </summary>
        public sealed class Aggregator : RegressionAggregatorBase
        {
            private sealed class Counters : CountersBase
            {
                /// <summary>Root of the mean L2 loss, or 0 when no weight has been accumulated.</summary>
                public override double Rms => SumWeights > 0 ? Math.Sqrt(TotalL2Loss / SumWeights) : 0;

                /// <summary>
                /// Coefficient of determination: 1 - TotalL2Loss / denominator. This is 0 when no weight has
                /// been accumulated, and NaN when the denominator is exactly zero.
                /// </summary>
                public override double RSquared
                {
                    get
                    {
                        if (SumWeights > 0)
                        {
                            // NOTE(review): assumes TotalLabelW / TotalLabelSquaredW are the weighted sums of the
                            // labels and the squared labels (maintained by CountersBase). If so, this is the
                            // weighted total sum of squares.
                            double denominator = TotalLabelSquaredW - TotalLabelW * TotalLabelW / SumWeights;

                            // R-squared cannot be well-defined with fewer than two samples (zero denominator).
                            // Return NaN instead of the -Infinity the division would produce.
                            if (denominator == 0)
                                return double.NaN;
                            return 1 - TotalL2Loss / denominator;
                        }
                        return 0;
                    }
                }

                /// <summary>
                /// Adds one example's weighted absolute error, squared error and reported loss to the totals.
                /// </summary>
                protected override void UpdateCore(float label, in float score, in double loss, float weight)
                {
                    double currL1Loss = Math.Abs((double)label - score);
                    TotalL1Loss += currL1Loss * weight;
                    TotalL2Loss += currL1Loss * currL1Loss * weight;
                    // REVIEW: Fix this! This is the loss as reported by the loss function; a NaN loss makes TotalLoss NaN.
                    TotalLoss += loss * weight;
                }

                /// <summary>Divides an accumulated total by the accumulated weight.</summary>
                protected override void Normalize(in double src, ref double dst)
                {
                    dst = src / SumWeights;
                }

                protected override double Zero()
                {
                    return 0;
                }
            }

            private readonly Counters _counters;
            private readonly Counters _weightedCounters;

            public override CountersBase UnweightedCounters { get { return _counters; } }

            public override CountersBase WeightedCounters { get { return _weightedCounters; } }

            public Aggregator(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName)
                : base(env, lossFunction, weighted, stratName)
            {
                _counters = new Counters();
                // Weighted counters exist only for weighted evaluation; otherwise WeightedCounters is null.
                _weightedCounters = Weighted ? new Counters() : null;
            }

            /// <summary>Computes the configured loss for one (score, label) pair.</summary>
            protected override void ApplyLossFunction(in float score, float label, ref double loss)
            {
                loss = LossFunction.Loss(score, label);
            }

            protected override bool IsNaN(in float score)
            {
                return float.IsNaN(score);
            }

            /// <summary>Adds a double-typed metric column to the metrics data view being built.</summary>
            public override void AddColumn(ArrayDataViewBuilder dvBldr, string metricName, params double[] metric)
            {
                Host.AssertValue(dvBldr);
                dvBldr.AddColumn(metricName, NumberDataViewType.Double, metric);
            }
        }

        /// <summary>
        /// Evaluates scored regression data.
        /// </summary>
        /// <param name="data">The data to evaluate.</param>
        /// <param name="label">The name of the label column.</param>
        /// <param name="score">The name of the predicted score column.</param>
        /// <returns>The evaluation metrics for these outputs.</returns>
        public RegressionMetrics Evaluate(IDataView data, string label, string score)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckNonEmpty(label, nameof(label));
            Host.CheckNonEmpty(score, nameof(score));
            var roles = new RoleMappedData(data, opt: false,
                RoleMappedSchema.ColumnRole.Label.Bind(label),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score));

            var resultDict = ((IEvaluator)this).Evaluate(roles);
            Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
            var overall = resultDict[MetricKinds.OverallMetrics];

            // The overall-metrics view is expected to hold exactly one row.
            RegressionMetrics result;
            using (var cursor = overall.GetRowCursorForAllColumns())
            {
                var moved = cursor.MoveNext();
                Host.Assert(moved);
                result = new RegressionMetrics(Host, cursor);
                moved = cursor.MoveNext();
                Host.Assert(!moved);
            }
            return result;
        }
    }
 
    /// <summary>
    /// Row mapper that appends two double columns, "L1-loss" and "L2-loss". They hold the per-row absolute
    /// and squared differences between the Single label and Single score columns.
    /// </summary>
    internal sealed class RegressionPerInstanceEvaluator : PerInstanceEvaluatorBase
    {
        public const string LoaderSignature = "RegressionPerInstance";
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "REG INST",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(RegressionPerInstanceEvaluator).Assembly.FullName);
        }

        // Positions of the output columns within the arrays built by GetOutputColumnsCore / CreateGettersCore.
        private const int L1Col = 0;
        private const int L2Col = 1;

        public const string L1 = "L1-loss";
        public const string L2 = "L2-loss";

        public RegressionPerInstanceEvaluator(IHostEnvironment env, DataViewSchema schema, string scoreCol, string labelCol)
            : base(env, schema, scoreCol, labelCol)
        {
            CheckInputColumnTypes(schema);
        }

        private RegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
            : base(env, ctx, schema)
        {
            CheckInputColumnTypes(schema);

            // *** Binary format **
            // base
        }

        /// <summary>Loader entry point used when deserializing from a model file.</summary>
        public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            return new RegressionPerInstanceEvaluator(env, ctx, schema);
        }

        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format **
            // base
            base.SaveModel(ctx);
        }

        /// <summary>
        /// The score and label input columns are needed only when at least one output column is active.
        /// </summary>
        private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
        {
            return
                col => (activeOutput(L1Col) || activeOutput(L2Col)) && (col == ScoreIndex || col == LabelIndex);
        }

        private protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
        {
            var infos = new DataViewSchema.DetachedColumn[2];
            infos[L1Col] = new DataViewSchema.DetachedColumn(L1, NumberDataViewType.Double, null);
            infos[L2Col] = new DataViewSchema.DetachedColumn(L2, NumberDataViewType.Double, null);
            return infos;
        }

        /// <summary>
        /// Builds getters for the active output columns. The label and score are read at most once per row
        /// and shared between the two getters. Inactive outputs leave a null slot in the returned array.
        /// </summary>
        private protected override Delegate[] CreateGettersCore(DataViewRow input, Func<int, bool> activeCols, out Action disposer)
        {
            Host.Assert(LabelIndex >= 0);
            Host.Assert(ScoreIndex >= 0);

            disposer = null;

            bool l1Active = activeCols(L1Col);
            bool l2Active = activeCols(L2Col);
            // Inputs are fetched only if some output depends on them; otherwise NaN placeholders are used.
            bool needInputs = l1Active || l2Active;

            long cachedPosition = -1;
            float label = 0;
            float score = 0;

            ValueGetter<float> nan = (ref float value) => value = float.NaN;
            ValueGetter<float> labelGetter = needInputs ? RowCursorUtils.GetLabelGetter(input, LabelIndex) : nan;
            ValueGetter<float> scoreGetter = needInputs ? input.GetGetter<float>(input.Schema[ScoreIndex]) : nan;

            // Refresh the cached label/score only when the cursor has moved to a new row.
            void UpdateCacheIfNeeded()
            {
                if (cachedPosition != input.Position)
                {
                    labelGetter(ref label);
                    scoreGetter(ref score);
                    cachedPosition = input.Position;
                }
            }

            var getters = new Delegate[2];
            if (l1Active)
            {
                ValueGetter<double> l1Fn =
                    (ref double dst) =>
                    {
                        UpdateCacheIfNeeded();
                        dst = Math.Abs((double)label - score);
                    };
                getters[L1Col] = l1Fn;
            }
            if (l2Active)
            {
                ValueGetter<double> l2Fn =
                    (ref double dst) =>
                    {
                        UpdateCacheIfNeeded();
                        // Squaring discards the sign, so no absolute value is needed here.
                        double diff = (double)label - score;
                        dst = diff * diff;
                    };
                getters[L2Col] = l2Fn;
            }
            return getters;
        }

        /// <summary>
        /// Throws a schema-mismatch exception unless both the label and the score columns are of type Single.
        /// </summary>
        private void CheckInputColumnTypes(DataViewSchema schema)
        {
            Host.AssertNonEmpty(ScoreCol);
            Host.AssertNonEmpty(LabelCol);

            var t = schema[LabelIndex].Type;
            if (t != NumberDataViewType.Single)
                throw Host.ExceptSchemaMismatch(nameof(schema), "label", LabelCol, "Single", t.ToString());

            t = schema[ScoreIndex].Type;
            if (t != NumberDataViewType.Single)
                throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol, "Single", t.ToString());
        }
    }
 
    /// <summary>
    /// MAML wrapper around RegressionEvaluator. It adds a configurable loss function and selects which
    /// per-instance columns are saved.
    /// </summary>
    [BestFriend]
    internal sealed class RegressionMamlEvaluator : MamlEvaluatorBase
    {
        public sealed class Arguments : ArgumentsBase
        {
            [Argument(ArgumentType.Multiple, HelpText = "Loss function", ShortName = "loss")]
            public ISupportRegressionLossFactory LossFunction = new SquaredLossFactory();
        }

        // REVIEW: Identify by metadata rather than by this fixed column name.
        private const string FeatureContributionsColumnName = "FeatureContributions";

        private readonly RegressionEvaluator _evaluator;

        private protected override IEvaluator Evaluator => _evaluator;

        public RegressionMamlEvaluator(IHostEnvironment env, Arguments args)
            : base(args, env, AnnotationUtils.Const.ScoreColumnKind.Regression, "RegressionMamlEvaluator")
        {
            Host.CheckUserArg(args.LossFunction != null, nameof(args.LossFunction), "Loss function must be specified.");

            var evalArgs = new RegressionEvaluator.Arguments();
            evalArgs.LossFunction = args.LossFunction;
            _evaluator = new RegressionEvaluator(Host, evalArgs);
        }

        /// <summary>
        /// Returns, in order: the label column, the score column, the L1/L2 per-instance loss columns and,
        /// if present in the schema, the feature-contributions column.
        /// </summary>
        private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
        {
            // Not an iterator method, so these checks run at call time instead of on the first MoveNext().
            Host.CheckValue(schema, nameof(schema));
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column");

            var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
                AnnotationUtils.Const.ScoreColumnKind.Regression);

            // The regression evaluator outputs the label and score columns, followed by its own output columns.
            var columns = new List<string>
            {
                schema.Label.Value.Name,
                scoreCol.Name,
                RegressionPerInstanceEvaluator.L1,
                RegressionPerInstanceEvaluator.L2,
            };

            if (schema.Schema.TryGetColumnIndex(FeatureContributionsColumnName, out _))
                columns.Add(FeatureContributionsColumnName);

            return columns;
        }
    }
 
    internal static partial class Evaluate
    {
        /// <summary>
        /// Entry point that evaluates a scored regression data set. It returns the evaluation warnings,
        /// the overall metrics and the per-instance metrics.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">Evaluator arguments, including the scored data and column names.</param>
        [TlcModule.EntryPoint(Name = "Models.RegressionEvaluator", Desc = "Evaluates a regression scored dataset.")]
        public static CommonOutputs.CommonEvaluateOutput Regression(IHostEnvironment env, RegressionMamlEvaluator.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var evalHost = env.Register("EvaluateRegression");
            evalHost.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(evalHost, input);

            // Resolve which input columns play the label / weight / name roles.
            MatchColumns(evalHost, input, out string labelName, out string weightName, out string nameColumn);

            IMamlEvaluator mamlEvaluator = new RegressionMamlEvaluator(evalHost, input);
            var roleMapped = new RoleMappedData(input.Data, labelName, null, null, weightName, nameColumn);
            var evaluation = mamlEvaluator.Evaluate(roleMapped);

            // Extract the three outputs in the same order as before constructing the result.
            var extractedWarnings = ExtractWarnings(evalHost, evaluation);
            var extractedOverall = ExtractOverallMetrics(evalHost, evaluation, mamlEvaluator);
            var extractedPerInstance = mamlEvaluator.GetPerInstanceMetrics(roleMapped);

            var output = new CommonOutputs.CommonEvaluateOutput();
            output.Warnings = extractedWarnings;
            output.OverallMetrics = extractedOverall;
            output.PerInstanceMetrics = extractedPerInstance;
            return output;
        }
    }
}