// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
// Registers RegressionEvaluator as a SignatureEvaluator component under the load names
// RegressionEvaluator.LoadName and "Regression".
[assembly: LoadableClass(typeof(RegressionEvaluator), typeof(RegressionEvaluator), typeof(RegressionEvaluator.Arguments), typeof(SignatureEvaluator),
    "Regression Evaluator", RegressionEvaluator.LoadName, "Regression")]
// Registers the MAML wrapper (RegressionMamlEvaluator) as a SignatureMamlEvaluator component under the same load names.
[assembly: LoadableClass(typeof(RegressionMamlEvaluator), typeof(RegressionMamlEvaluator), typeof(RegressionMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
    "Regression Evaluator", RegressionEvaluator.LoadName, "Regression")]
// This is for deserialization from a binary model file: the per-instance mapper is registered as a
// SignatureLoadRowMapper keyed by RegressionPerInstanceEvaluator.LoaderSignature.
[assembly: LoadableClass(typeof(RegressionPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper),
    "", RegressionPerInstanceEvaluator.LoaderSignature)]
namespace Microsoft.ML.Data
    /// <summary>
    /// Regression evaluator: aggregates per-example errors into the overall metrics
    /// L1, L2, Rms, Loss and RSquared. Score and label columns must both be of type Single.
    /// </summary>
    internal sealed class RegressionEvaluator :
        RegressionEvaluatorBase<RegressionEvaluator.Aggregator, float, Double>
        /// <summary>
        /// Options for <see cref="RegressionEvaluator"/>, derived from <c>ArgumentsBase</c>.
        /// </summary>
        public sealed class Arguments : ArgumentsBase
        /// <summary>
        /// Metric identifiers for this evaluator (presumably one per overall metric column — members not visible here).
        /// </summary>
        public enum Metrics
        /// <summary>
        /// Load name under which this evaluator is registered as a loadable component.
        /// </summary>
        public const string LoadName = "RegressionEvaluator";
        /// <summary>
        /// Creates a regression evaluator. <paramref name="env"/> and <paramref name="args"/> are forwarded
        /// to the base class together with <see cref="LoadName"/>.
        /// </summary>
        public RegressionEvaluator(IHostEnvironment env, Arguments args)
            : base(args, env, LoadName)
        /// <summary>
        /// Validates that the score column is of type Single, that a label column is present,
        /// and that the label column is of type Single.
        /// </summary>
        /// <param name="schema">The role-mapped schema to validate.</param>
        private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
        {
            var scoreColumn = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
            var scoreType = scoreColumn.Type;
            if (scoreType != NumberDataViewType.Single)
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "score", scoreColumn.Name, "Single", scoreType.ToString());
            }

            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column");
            var labelColumn = schema.Label.Value;
            var labelType = labelColumn.Type;
            if (labelType != NumberDataViewType.Single)
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "label", labelColumn.Name, "Single", labelType.ToString());
            }
        }
        /// <summary>
        /// Creates the aggregator for one stratum. Weighted counters are requested only when the
        /// schema maps a weight column.
        /// </summary>
        private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
        {
            bool hasWeightColumn = schema.Weight != null;
            return new Aggregator(Host, LossFunction, hasWeightColumn, stratName);
        }
        /// <summary>
        /// Creates the row mapper that computes per-example L1/L2 losses from the mapped score and label columns.
        /// </summary>
        /// <param name="schema">Role-mapped schema; must contain a label column and a unique score column.</param>
        /// <returns>A <see cref="RegressionPerInstanceEvaluator"/> bound to those columns.</returns>
        private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema)
        {
            // Validate through the evaluator's Host, consistent with CheckScoreAndLabelTypes,
            // and reject a null schema explicitly rather than failing with a NullReferenceException.
            Host.CheckValue(schema, nameof(schema));
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column");
            var scoreInfo = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
            return new RegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Value.Name);
        }
        /// <summary>
        /// Describes the overall metric columns produced by this evaluator. The error metrics
        /// (L1, L2, Rms, Loss) are marked as minimized; RSquared uses MetricColumn's default objective.
        /// </summary>
        public override IEnumerable<MetricColumn> GetOverallMetricColumns()
        {
            return new[]
            {
                new MetricColumn("L1", L1, MetricColumn.Objective.Minimize),
                new MetricColumn("L2", L2, MetricColumn.Objective.Minimize),
                new MetricColumn("Rms", Rms, MetricColumn.Objective.Minimize),
                new MetricColumn("Loss", Loss, MetricColumn.Objective.Minimize),
                new MetricColumn("RSquared", RSquared),
            };
        }
        /// <summary>
        /// Accumulates the running sums needed for the regression metrics. Unweighted counters are
        /// always kept; weighted counters are kept only when <c>Weighted</c> is set.
        /// </summary>
        public sealed class Aggregator : RegressionAggregatorBase
        {
            /// <summary>
            /// Accumulators for L1 loss, L2 loss and the loss reported by the loss function,
            /// plus the metrics derived from them.
            /// </summary>
            private sealed class Counters : CountersBase
            {
                /// <summary>
                /// Root-mean-squared error: sqrt(TotalL2Loss / SumWeights), or 0 when SumWeights is not positive.
                /// </summary>
                public override double Rms
                {
                    get
                    {
                        return SumWeights > 0 ? Math.Sqrt(TotalL2Loss / SumWeights) : 0;
                    }
                }

                /// <summary>
                /// Coefficient of determination, 1 - TotalL2Loss / SST. Returns 0 when SumWeights is not
                /// positive and NaN when SST is exactly zero.
                /// </summary>
                public override double RSquared
                {
                    get
                    {
                        if (SumWeights > 0)
                        {
                            // SST: sum of squares of the labels about their mean. Assumes TotalLabelSquaredW and
                            // TotalLabelW hold the (weighted) sums of label^2 and label maintained by the base class.
                            double labelSumOfSquares = TotalLabelSquaredW - TotalLabelW * TotalLabelW / SumWeights;

                            // RSquared value cannot be well-defined with less than two samples (or whenever
                            // SST is zero, e.g. all labels equal). Return NaN instead of -Infinity.
                            if (labelSumOfSquares == 0)
                                return double.NaN;

                            return 1 - TotalL2Loss / labelSumOfSquares;
                        }
                        return 0;
                    }
                }

                /// <summary>
                /// Adds one example's absolute error, squared error and reported loss, each scaled by <paramref name="weight"/>.
                /// </summary>
                protected override void UpdateCore(float label, in float score, in double loss, float weight)
                {
                    double absoluteError = Math.Abs((double)label - score);
                    TotalL1Loss += absoluteError * weight;
                    TotalL2Loss += absoluteError * absoluteError * weight;
                    // REVIEW: Fix this! Loss as reported by the loss function; note it can result in NaN if loss is NaN.
                    TotalLoss += loss * weight;
                }

                /// <summary>
                /// Normalizes an accumulated sum by the total weight.
                /// </summary>
                protected override void Normalize(in double src, ref double dst)
                {
                    dst = src / SumWeights;
                }

                /// <summary>
                /// The additive identity for the loss accumulators.
                /// </summary>
                protected override double Zero()
                {
                    return 0;
                }
            }

            private readonly Counters _counters;
            private readonly Counters _weightedCounters;

            public override CountersBase UnweightedCounters => _counters;

            // Null when the aggregator is unweighted.
            public override CountersBase WeightedCounters => _weightedCounters;

            /// <summary>
            /// Creates an aggregator that scores examples with <paramref name="lossFunction"/>.
            /// </summary>
            public Aggregator(IHostEnvironment env, IRegressionLoss lossFunction, bool weighted, string stratName)
                : base(env, lossFunction, weighted, stratName)
            {
                _counters = new Counters();
                _weightedCounters = Weighted ? new Counters() : null;
            }

            /// <summary>
            /// Computes the configured loss for one (score, label) pair.
            /// </summary>
            protected override void ApplyLossFunction(in float score, float label, ref double loss)
            {
                loss = LossFunction.Loss(score, label);
            }

            protected override bool IsNaN(in float score)
            {
                return float.IsNaN(score);
            }

            /// <summary>
            /// Adds a Double-typed metric column named <paramref name="metricName"/> to the builder.
            /// </summary>
            public override void AddColumn(ArrayDataViewBuilder dvBldr, string metricName, params double[] metric)
            {
                dvBldr.AddColumn(metricName, NumberDataViewType.Double, metric);
            }
        }
        /// <summary>
        /// Evaluates scored regression data.
        /// </summary>
        /// <param name="data">The data to evaluate.</param>
        /// <param name="label">The name of the label column.</param>
        /// <param name="score">The name of the predicted score column.</param>
        /// <returns>The evaluation metrics for these outputs.</returns>
        public RegressionMetrics Evaluate(IDataView data, string label, string score)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckNonEmpty(label, nameof(label));
            Host.CheckNonEmpty(score, nameof(score));

            // Bind BOTH roles. Without the label binding the label argument is validated but ignored,
            // and CheckScoreAndLabelTypes rejects the schema because no label column is mapped.
            var roles = new RoleMappedData(data, opt: false,
                RoleMappedSchema.ColumnRole.Label.Bind(label),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score));

            var resultDict = ((IEvaluator)this).Evaluate(roles);
            Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
            var overall = resultDict[MetricKinds.OverallMetrics];

            RegressionMetrics result;
            using (var cursor = overall.GetRowCursorForAllColumns())
            {
                // The overall-metrics view is expected to contain exactly one row.
                bool moved = cursor.MoveNext();
                Host.Assert(moved);
                result = new RegressionMetrics(Host, cursor);
                moved = cursor.MoveNext();
                Host.Assert(!moved);
            }
            return result;
        }
    /// <summary>
    /// Per-instance row mapper for regression: produces an absolute-error ("L1-loss") column and a
    /// squared-error ("L2-loss") column from Single-typed label and score columns.
    /// </summary>
    internal sealed class RegressionPerInstanceEvaluator : PerInstanceEvaluatorBase
        // Loader signature under which this mapper is registered for deserialization.
        public const string LoaderSignature = "RegressionPerInstance";
        /// <summary>
        /// Version information describing the serialized format of this mapper.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            // Only the initial format version exists so far; written, readable and read-back versions coincide.
            const int initialVersion = 0x00010001;
            return new VersionInfo(
                modelSignature: "REG INST",
                verWrittenCur: initialVersion, // Initial
                verReadableCur: initialVersion,
                verWeCanReadBack: initialVersion,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(RegressionPerInstanceEvaluator).Assembly.FullName);
        }
        // Indices of the two output columns within the mapper's output.
        private const int L1Col = 0;
        private const int L2Col = 1;
        // Output column names: per-example absolute error and squared error.
        public const string L1 = "L1-loss";
        public const string L2 = "L2-loss";
        /// <summary>
        /// Creates a per-instance evaluator over <paramref name="schema"/> for the given score and label columns.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="schema">The input schema; label and score columns must be of type Single.</param>
        /// <param name="scoreCol">Name of the score column.</param>
        /// <param name="labelCol">Name of the label column.</param>
        public RegressionPerInstanceEvaluator(IHostEnvironment env, DataViewSchema schema, string scoreCol, string labelCol)
            : base(env, schema, scoreCol, labelCol)
        {
            // Reject non-Single label/score columns up front instead of failing later inside the getters.
            CheckInputColumnTypes(schema);
        }
        /// <summary>
        /// Deserialization constructor; the base class reads the persisted state from <paramref name="ctx"/>.
        /// </summary>
        private RegressionPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
            : base(env, ctx, schema)
        {
            // Validate the column types of the schema the model is being loaded against.
            CheckInputColumnTypes(schema);

            // *** Binary format **
            // base
        }
        /// <summary>
        /// Factory used when loading this mapper from a model (registered via SignatureLoadRowMapper).
        /// </summary>
        public static RegressionPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            // Verify the stored model matches this mapper's signature/version before reading it.
            ctx.CheckAtModel(GetVersionInfo());

            return new RegressionPerInstanceEvaluator(env, ctx, schema);
        }
        /// <summary>
        /// Serializes this mapper: version information followed by the base class state.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            // Stamp the version info so that Create can validate it on load.
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format **
            // base
            base.SaveModel(ctx);
        }
        /// <summary>
        /// Both outputs are computed from the label and score, so exactly those two input columns are
        /// required whenever either output column is active.
        /// </summary>
        private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
        {
            return
                col => (activeOutput(L1Col) || activeOutput(L2Col)) && (col == ScoreIndex || col == LabelIndex);
        }
        /// <summary>
        /// Declares the two Double-typed output columns, placed at indices L1Col and L2Col.
        /// </summary>
        private protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
        {
            var absoluteErrorColumn = new DataViewSchema.DetachedColumn(L1, NumberDataViewType.Double, null);
            var squaredErrorColumn = new DataViewSchema.DetachedColumn(L2, NumberDataViewType.Double, null);

            var result = new DataViewSchema.DetachedColumn[2];
            result[L1Col] = absoluteErrorColumn;
            result[L2Col] = squaredErrorColumn;
            return result;
        }
        /// <summary>
        /// Builds getters for the active output columns. Each getter refreshes the cached label/score
        /// for the current row and returns |label - score| (L1) or its square (L2).
        /// </summary>
        private protected override Delegate[] CreateGettersCore(DataViewRow input, Func<int, bool> activeCols, out Action disposer)
        {
            Host.Assert(LabelIndex >= 0);
            Host.Assert(ScoreIndex >= 0);

            disposer = null;

            // Label/score cached for the row at cachedPosition, so reading both outputs of the same row
            // fetches the inputs only once.
            long cachedPosition = -1;
            float label = 0;
            float score = 0;

            // Inputs are only read when at least one output column is requested.
            bool needInputs = activeCols(L1Col) || activeCols(L2Col);
            ValueGetter<float> nan = (ref float value) => value = Single.NaN;
            ValueGetter<float> labelGetter = needInputs ? RowCursorUtils.GetLabelGetter(input, LabelIndex) : nan;
            ValueGetter<float> scoreGetter = needInputs ? input.GetGetter<float>(input.Schema[ScoreIndex]) : nan;

            Action updateCacheIfNeeded =
                () =>
                {
                    if (cachedPosition != input.Position)
                    {
                        labelGetter(ref label);
                        scoreGetter(ref score);
                        cachedPosition = input.Position;
                    }
                };

            var getters = new Delegate[2];
            if (activeCols(L1Col))
            {
                ValueGetter<double> l1Fn =
                    (ref double dst) =>
                    {
                        // Refresh label/score for the current row before computing the loss.
                        updateCacheIfNeeded();
                        dst = Math.Abs((Double)label - score);
                    };
                getters[L1Col] = l1Fn;
            }
            if (activeCols(L2Col))
            {
                ValueGetter<double> l2Fn =
                    (ref double dst) =>
                    {
                        updateCacheIfNeeded();
                        dst = Math.Abs((Double)label - score);
                        dst *= dst;
                    };
                getters[L2Col] = l2Fn;
            }
            return getters;
        }
        /// <summary>
        /// Throws a schema-mismatch exception unless the label column and then the score column
        /// are both of type Single.
        /// </summary>
        private void CheckInputColumnTypes(DataViewSchema schema)
        {
            var labelType = schema[(int)LabelIndex].Type;
            if (labelType != NumberDataViewType.Single)
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "label", LabelCol, "Single", labelType.ToString());
            }

            var scoreType = schema[ScoreIndex].Type;
            if (scoreType != NumberDataViewType.Single)
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol, "Single", scoreType.ToString());
            }
        }
    /// <summary>
    /// MAML evaluator for regression; metric computation is delegated to an inner <see cref="RegressionEvaluator"/>.
    /// </summary>
    internal sealed class RegressionMamlEvaluator : MamlEvaluatorBase
        /// <summary>
        /// Options for <see cref="RegressionMamlEvaluator"/>.
        /// </summary>
        public sealed class Arguments : ArgumentsBase
            /// <summary>
            /// Loss function forwarded to the inner <see cref="RegressionEvaluator"/>; defaults to squared loss.
            /// </summary>
            [Argument(ArgumentType.Multiple, HelpText = "Loss function", ShortName = "loss")]
            public ISupportRegressionLossFactory LossFunction = new SquaredLossFactory();
        // Inner evaluator that performs the actual metric computation.
        private readonly RegressionEvaluator _evaluator;
        private protected override IEvaluator Evaluator => _evaluator;
        /// <summary>
        /// Creates the MAML wrapper and its inner <see cref="RegressionEvaluator"/>, which receives the
        /// configured loss function.
        /// </summary>
        public RegressionMamlEvaluator(IHostEnvironment env, Arguments args)
            : base(args, env, AnnotationUtils.Const.ScoreColumnKind.Regression, "RegressionMamlEvaluator")
        {
            Host.CheckUserArg(args.LossFunction != null, nameof(args.LossFunction), "Loss function must be specified.");

            var innerArgs = new RegressionEvaluator.Arguments { LossFunction = args.LossFunction };
            _evaluator = new RegressionEvaluator(Host, innerArgs);
        }
        /// <summary>
        /// Lists the per-instance columns to save: label, score, the L1/L2 loss outputs and, when present,
        /// a "FeatureContributions" column.
        /// </summary>
        private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
        {
            Host.CheckValue(schema, nameof(schema));
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column");

            // The regression evaluator outputs the label and score columns.
            yield return schema.Label.Value.Name;
            // The score column is resolved using the same regression score-column kind passed to the base constructor.
            var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
                AnnotationUtils.Const.ScoreColumnKind.Regression);
            yield return scoreCol.Name;

            // Return the output columns.
            yield return RegressionPerInstanceEvaluator.L1;
            yield return RegressionPerInstanceEvaluator.L2;

            // REVIEW: Identify by metadata.
            if (schema.Schema.TryGetColumnIndex("FeatureContributions", out _))
                yield return "FeatureContributions";
        }
    /// <summary>
    /// Entry points for evaluators; this part of the partial class adds the regression evaluator entry point.
    /// </summary>
    internal static partial class Evaluate
        /// <summary>
        /// Entry point that evaluates a regression-scored dataset and returns its warnings, overall metrics
        /// and per-instance metrics.
        /// </summary>
        [TlcModule.EntryPoint(Name = "Models.RegressionEvaluator", Desc = "Evaluates a regression scored dataset.")]
        public static CommonOutputs.CommonEvaluateOutput Regression(IHostEnvironment env, RegressionMamlEvaluator.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("EvaluateRegression");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // Obtain the label, weight and name column names from the input arguments.
            MatchColumns(host, input, out string labelColumn, out string weightColumn, out string nameColumn);

            IMamlEvaluator evaluator = new RegressionMamlEvaluator(host, input);
            var roleMappedData = new RoleMappedData(input.Data, labelColumn, null, null, weightColumn, nameColumn);
            var allMetrics = evaluator.Evaluate(roleMappedData);

            var warningsView = ExtractWarnings(host, allMetrics);
            var overallView = ExtractOverallMetrics(host, allMetrics, evaluator);
            var perInstanceView = evaluator.GetPerInstanceMetrics(roleMappedData);

            return new CommonOutputs.CommonEvaluateOutput()
            {
                Warnings = warningsView,
                OverallMetrics = overallView,
                PerInstanceMetrics = perInstanceView
            };
        }