// File: Evaluators\RankingEvaluator.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
// Registers RankingEvaluator as a loadable evaluator, configured through RankingEvaluatorOptions.
[assembly: LoadableClass(typeof(RankingEvaluator), typeof(RankingEvaluator), typeof(RankingEvaluatorOptions), typeof(SignatureEvaluator),
    "Ranking Evaluator", RankingEvaluator.LoadName, "Ranking", "rank")]

// Registers RankingMamlEvaluator under the same load names, using the MAML evaluator signature.
[assembly: LoadableClass(typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
    "Ranking Evaluator", RankingEvaluator.LoadName, "Ranking", "rank")]

// Registers the per-instance transform so that it can be re-created from a saved model by its loader signature.
[assembly: LoadableClass(typeof(RankingPerInstanceTransform), null, typeof(SignatureLoadDataTransform),
    "", RankingPerInstanceTransform.LoaderSignature)]
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Options to control the output of the RankingEvaluator.
    /// </summary>
    public sealed class RankingEvaluatorOptions
    {
        /// <value>
        /// Maximum truncation level for computing (N)DCG. The evaluator rejects values
        /// that are not greater than 0. Defaults to 10.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")]
        public int DcgTruncationLevel = 10;

        /// <value>
        /// Label relevance gains, given as a comma-separated list of numbers. The evaluator
        /// rejects a null value and any entry that cannot be parsed as a floating point number.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Label relevance gains", ShortName = "gains")]
        public string LabelGains = "0,3,7,15,31";

        /// <value>
        /// Whether the evaluator should also produce the per-group (N)DCG summary data view
        /// in addition to the overall metrics.
        /// </value>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Generate per-group (N)DCG", ShortName = "ogs")]
        internal bool OutputGroupSummary;
    }
 
    [BestFriend]
    internal sealed class RankingEvaluator : EvaluatorBase<RankingEvaluator.Aggregator>
    {
        // Name under which this evaluator is registered with the host and the component catalog.
        internal const string LoadName = "RankingEvaluator";

        // Names of the metric columns produced by this evaluator.
        public const string Ndcg = "NDCG";
        public const string Dcg = "DCG";
        public const string MaxDcg = "MaxDCG";

        /// <value>
        /// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.
        /// It contains four columns: GroupId, NDCG, DCG and MaxDCG. Each row in the data view corresponds to one
        /// group in the scored data.
        /// </value>
        public const string GroupSummary = "GroupSummary";

        // Name of the group id column in the group summary data view.
        private const string GroupId = DefaultColumnNames.GroupId;

        // Number of truncation levels (1.._truncationLevel) at which (N)DCG is reported.
        private readonly int _truncationLevel;
        // Whether the per-group summary data view is produced.
        private readonly bool _groupSummary;
        // Gains parsed from RankingEvaluatorOptions.LabelGains, in the order they were listed.
        private readonly Double[] _labelGains;
 
        /// <summary>
        /// Creates a ranking evaluator configured by <paramref name="options"/>.
        /// </summary>
        /// <param name="env">The host environment used to register this evaluator.</param>
        /// <param name="options">The truncation level, label gains and group-summary switch.
        /// The truncation level must be greater than 0 and the label gains must be a non-null,
        /// comma-separated list of numbers.</param>
        public RankingEvaluator(IHostEnvironment env, RankingEvaluatorOptions options)
            : base(env, LoadName)
        {
            Host.CheckValue(options, nameof(options));

            // REVIEW: What kind of checking should be applied to labelGains?
            // add the setter to utils here
            if (options.DcgTruncationLevel <= 0)
                throw Host.ExceptUserArg(nameof(options.DcgTruncationLevel), "DCG Truncation Level must be greater than 0");
            Host.CheckUserArg(options.LabelGains != null, nameof(options.LabelGains), "Label gains cannot be null");

            _truncationLevel = options.DcgTruncationLevel;
            _groupSummary = options.OutputGroupSummary;

            string[] gains = options.LabelGains.Split(',');
            var labelGains = new Double[gains.Length];
            for (int i = 0; i < gains.Length; i++)
            {
                // The gains string is a machine-readable, comma-separated list, so it is parsed with the
                // invariant culture: under a culture where '.' is a group separator (e.g. de-DE) a
                // current-culture parse would silently read "0.5" as 5.
                Double gain;
                if (!Double.TryParse(gains[i],
                    System.Globalization.NumberStyles.Float | System.Globalization.NumberStyles.AllowThousands,
                    System.Globalization.CultureInfo.InvariantCulture, out gain))
                {
                    throw Host.ExceptUserArg(nameof(options.LabelGains), "Label Gains must be of floating or integral type");
                }
                labelGains[i] = gain;
            }
            _labelGains = labelGains;
        }
 
        /// <summary>
        /// Verifies that the label column is a Single or a key, and that the score column is a Single.
        /// Throws a schema-mismatch exception naming the offending column and its actual type otherwise.
        /// </summary>
        private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
        {
            var labelType = schema.Label.Value.Type;
            if (labelType != NumberDataViewType.Single && !(labelType is KeyDataViewType))
            {
                throw Host.ExceptSchemaMismatch(nameof(RankingMamlEvaluator.Arguments.LabelColumn),
                    "label", schema.Label.Value.Name, "Single or a Key", labelType.ToString());
            }
            var scoreCol = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
            if (scoreCol.Type != NumberDataViewType.Single)
            {
                // Report the score column's own type; the label type is irrelevant for this mismatch.
                throw Host.ExceptSchemaMismatch(nameof(RankingMamlEvaluator.Arguments.ScoreColumn),
                    "score", scoreCol.Name, "Single", scoreCol.Type.ToString());
            }
        }
 
        /// <summary>
        /// Verifies that the group column is of a key type; throws a schema-mismatch exception otherwise.
        /// </summary>
        private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
        {
            var groupType = schema.Group.Value.Type;
            if (groupType is KeyDataViewType)
                return;

            throw Host.ExceptSchemaMismatch(nameof(RankingMamlEvaluator.Arguments.GroupIdColumn),
                "group", schema.Group.Value.Name, "Key", groupType.ToString());
        }
 
        // The group column is activated on top of whatever columns the base evaluator needs.
        private protected override Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
        {
            Func<int, bool> baseIsActive = base.GetActiveColsCore(schema);
            return col =>
            {
                if (col == schema.Group.Value.Index)
                    return true;
                return baseIsActive(col);
            };
        }
 
        /// <summary>
        /// Creates the aggregator for one stratum. The aggregator is weighted whenever the schema has a weight column.
        /// </summary>
        private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
        {
            bool hasWeightColumn = schema.Weight != null;
            var aggregator = new Aggregator(Host, _labelGains, _truncationLevel, _groupSummary, hasWeightColumn, stratName);
            return aggregator;
        }
 
        /// <summary>
        /// Wraps the scored data in a <see cref="RankingPerInstanceTransform"/> that uses this evaluator's
        /// truncation level and label gains. The schema must have label, score and group columns.
        /// </summary>
        internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data)
        {
            Host.CheckValue(data, nameof(data));

            var roles = data.Schema;
            Host.CheckParam(roles.Label.HasValue, nameof(data), "Schema must contain a label column");
            var scoreColumn = roles.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
            Host.CheckParam(roles.Group.HasValue, nameof(data), "Schema must contain a group column");

            string labelName = roles.Label.Value.Name;
            string groupName = roles.Group.Value.Name;
            return new RankingPerInstanceTransform(Host, data.Data,
                labelName, scoreColumn.Name, groupName, _truncationLevel, _labelGains);
        }
 
        /// <summary>
        /// Describes the overall metric columns: NDCG, DCG and MaxDCG, each a vector column whose
        /// slots are matched by the "name@number" pattern.
        /// </summary>
        public override IEnumerable<MetricColumn> GetOverallMetricColumns()
        {
            yield return VectorMetric("NDCG@<number>", Ndcg);
            yield return VectorMetric("DCG@<number>", Dcg);
            yield return VectorMetric("MaxDcg@<number>", MaxDcg);

            // All three columns share the same shape and differ only in their names.
            MetricColumn VectorMetric(string loadName, string columnName)
            {
                var pattern = new Regex(string.Format(@"^{0}@(?<at>\d+)", columnName), RegexOptions.IgnoreCase);
                string format = string.Format("{0} @{{0}}", columnName);
                return new MetricColumn(loadName, columnName, isVector: true,
                    namePattern: pattern, groupName: "at", nameFormat: format);
            }
        }
 
        /// <summary>
        /// Produces the two delegates used to merge per-stratum aggregators into result data views.
        /// </summary>
        /// <param name="aggregator">The reference aggregator. It determines whether results are weighted,
        /// whether a group summary is produced, and supplies the slot names.</param>
        /// <param name="dictionaries">The stratification dictionaries. When non-empty, stratification
        /// columns are added to the produced data views.</param>
        /// <param name="addAgg">Receives a delegate that appends the metrics of one aggregator (with its
        /// stratification key and value) to the accumulated rows. An unweighted row is always added;
        /// a weighted row is added as well when the aggregator is weighted.</param>
        /// <param name="consolidate">Receives a delegate that builds the result dictionary: the overall
        /// metrics under <see cref="MetricKinds.OverallMetrics"/> and, when group summaries are enabled,
        /// the per-group metrics under <see cref="GroupSummary"/>.</param>
        private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
            out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
        {
            // Columns of the overall metrics data view, one entry per (stratum, weighted/unweighted) row.
            var stratCol = new List<uint>();
            var stratVal = new List<ReadOnlyMemory<char>>();
            var isWeighted = new List<bool>();
            var ndcg = new List<Double[]>();
            var dcg = new List<Double[]>();

            // Columns of the group summary data view, one entry per group.
            var groupName = new List<ReadOnlyMemory<char>>();
            var groupNdcg = new List<Double[]>();
            var groupDcg = new List<Double[]>();
            var groupMaxDcg = new List<Double[]>();
            var groupStratCol = new List<uint>();
            var groupStratVal = new List<ReadOnlyMemory<char>>();

            bool hasStrats = Utils.Size(dictionaries) > 0;
            bool hasWeight = aggregator.Weighted;
            bool groupSummary = aggregator.UnweightedCounters.GroupSummary;

            addAgg =
                (stratColKey, stratColVal, agg) =>
                {
                    Host.Check(agg.Weighted == hasWeight, "All aggregators must either be weighted or unweighted");
                    Host.Check(agg.UnweightedCounters.GroupSummary == groupSummary,
                        "All aggregators must either compute group summary or not compute group summary");

                    stratCol.Add(stratColKey);
                    stratVal.Add(stratColVal);
                    isWeighted.Add(false);
                    ndcg.Add(agg.UnweightedCounters.Ndcg);
                    dcg.Add(agg.UnweightedCounters.Dcg);
                    if (agg.UnweightedCounters.GroupSummary)
                    {
                        // The GroupDcg getter copies its backing list, so fetch it only once.
                        Double[][] perGroupDcg = agg.UnweightedCounters.GroupDcg;
                        groupStratCol.AddRange(Enumerable.Repeat(stratColKey, perGroupDcg.Length));
                        groupStratVal.AddRange(Enumerable.Repeat(stratColVal, perGroupDcg.Length));
                        // Group ids are already stored as text memory; no string round-trip is needed.
                        groupName.AddRange(agg.GroupId);
                        groupNdcg.AddRange(agg.UnweightedCounters.GroupNdcg);
                        groupDcg.AddRange(perGroupDcg);
                        groupMaxDcg.AddRange(agg.UnweightedCounters.GroupMaxDcg);
                    }

                    if (agg.Weighted)
                    {
                        stratCol.Add(stratColKey);
                        stratVal.Add(stratColVal);
                        isWeighted.Add(true);
                        ndcg.Add(agg.WeightedCounters.Ndcg);
                        dcg.Add(agg.WeightedCounters.Dcg);
                    }
                };

            consolidate =
                () =>
                {
                    var overallDvBldr = new ArrayDataViewBuilder(Host);
                    if (hasStrats)
                    {
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, stratCol.ToArray());
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, stratVal.ToArray());
                    }
                    if (hasWeight)
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BooleanDataViewType.Instance, isWeighted.ToArray());
                    overallDvBldr.AddColumn(Ndcg, aggregator.GetSlotNames, NumberDataViewType.Double, ndcg.ToArray());
                    overallDvBldr.AddColumn(Dcg, aggregator.GetSlotNames, NumberDataViewType.Double, dcg.ToArray());

                    var groupDvBldr = new ArrayDataViewBuilder(Host);
                    if (hasStrats)
                    {
                        groupDvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, groupStratCol.ToArray());
                        groupDvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, groupStratVal.ToArray());
                    }
                    if (groupSummary)
                    {
                        groupDvBldr.AddColumn(GroupId, TextDataViewType.Instance, groupName.ToArray());
                        groupDvBldr.AddColumn(Ndcg, aggregator.GetGroupSummarySlotNames("NDCG"), NumberDataViewType.Double, groupNdcg.ToArray());
                        groupDvBldr.AddColumn(Dcg, aggregator.GetGroupSummarySlotNames("DCG"), NumberDataViewType.Double, groupDcg.ToArray());
                        groupDvBldr.AddColumn(MaxDcg, aggregator.GetGroupSummarySlotNames("MaxDCG"), NumberDataViewType.Double, groupMaxDcg.ToArray());
                    }

                    // The group data view is only materialized when group summaries were requested.
                    var result = new Dictionary<string, IDataView>();
                    result.Add(MetricKinds.OverallMetrics, overallDvBldr.GetDataView());
                    if (groupSummary)
                        result.Add(GroupSummary, groupDvBldr.GetDataView());
                    return result;
                };
        }
 
        /// <summary>
        /// Evaluates scored ranking data.
        /// </summary>
        /// <param name="data">The data to evaluate.</param>
        /// <param name="label">The name of the label column.</param>
        /// <param name="groupId">The name of the groupId column.</param>
        /// <param name="score">The name of the predicted score column.</param>
        /// <returns>The evaluation metrics for these outputs.</returns>
        public RankingMetrics Evaluate(IDataView data, string label, string groupId, string score)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckNonEmpty(label, nameof(label));
            // The group column is mandatory for ranking, so validate its name like the other column names.
            Host.CheckNonEmpty(groupId, nameof(groupId));
            Host.CheckNonEmpty(score, nameof(score));
            var roles = new RoleMappedData(data, opt: false,
                RoleMappedSchema.ColumnRole.Label.Bind(label),
                RoleMappedSchema.ColumnRole.Group.Bind(groupId),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score));

            var resultDict = ((IEvaluator)this).Evaluate(roles);
            Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
            var overall = resultDict[MetricKinds.OverallMetrics];

            // The overall metrics view is expected to hold exactly one row here, since no
            // stratification or weight column is bound above.
            RankingMetrics result;
            using (var cursor = overall.GetRowCursorForAllColumns())
            {
                var moved = cursor.MoveNext();
                Host.Assert(moved);
                result = new RankingMetrics(Host, cursor);
                moved = cursor.MoveNext();
                Host.Assert(!moved);
            }
            return result;
        }
 
        public sealed class Aggregator : AggregatorBase
        {
            /// <summary>
            /// Accumulates the labels and scores of the current group and, on each call to
            /// <see cref="UpdateGroup"/>, folds the group's DCG and NDCG at every truncation level
            /// into weighted running sums. Optionally keeps the per-group values as well.
            /// </summary>
            public sealed class Counters
            {
                public readonly int TruncationLevel;

                // Per-group results; these lists are null unless a group summary was requested.
                private readonly List<Double[]> _groupNdcg;
                private readonly List<Double[]> _groupDcg;
                private readonly List<Double[]> _groupMaxDcg;

                // Scratch buffers that receive the DCG / max DCG of the group being closed.
                private readonly Double[] _groupDcgCur;
                private readonly Double[] _groupMaxDcgCur;

                // Weighted sums over all closed groups, one slot per truncation level.
                private readonly Double[] _sumNdcgAtN;
                private readonly Double[] _sumDcgAtN;
                private Double _sumWeights;

                // Labels and scores of the rows seen since the last call to UpdateGroup.
                private readonly List<short> _queryLabels;
                private readonly List<Single> _queryOutputs;
                private readonly Double[] _labelGains;
                private readonly Double[] _discountMap;

                /// <summary>Whether per-group results are being recorded.</summary>
                public bool GroupSummary => _groupNdcg != null;

                /// <summary>The NDCG sums divided by the total group weight, one value per truncation level.</summary>
                public Double[] Ndcg => DivideBySumOfWeights(_sumNdcgAtN);

                /// <summary>The DCG sums divided by the total group weight, one value per truncation level.</summary>
                public Double[] Dcg => DivideBySumOfWeights(_sumDcgAtN);

                /// <summary>A copy of the per-group DCG values, or null when no group summary is recorded.</summary>
                public Double[][] GroupDcg => _groupDcg?.ToArray();

                /// <summary>A copy of the per-group NDCG values, or null when no group summary is recorded.</summary>
                public Double[][] GroupNdcg => _groupNdcg?.ToArray();

                /// <summary>A copy of the per-group max DCG values, or null when no group summary is recorded.</summary>
                public Double[][] GroupMaxDcg => _groupMaxDcg?.ToArray();

                public Counters(Double[] labelGains, int truncationLevel, bool groupSummary)
                {
                    Contracts.Assert(truncationLevel > 0);
                    Contracts.AssertValue(labelGains);

                    TruncationLevel = truncationLevel;
                    _labelGains = labelGains;
                    _discountMap = RankingUtils.GetDiscountMap(truncationLevel);

                    _sumDcgAtN = new Double[truncationLevel];
                    _sumNdcgAtN = new Double[truncationLevel];
                    _groupDcgCur = new Double[truncationLevel];
                    _groupMaxDcgCur = new Double[truncationLevel];

                    _queryLabels = new List<short>();
                    _queryOutputs = new List<Single>();

                    if (!groupSummary)
                        return;

                    _groupNdcg = new List<Double[]>();
                    _groupDcg = new List<Double[]>();
                    _groupMaxDcg = new List<Double[]>();
                }

                /// <summary>
                /// Records one row (label and score) of the current group.
                /// </summary>
                public void Update(short label, Single output)
                {
                    _queryLabels.Add(label);
                    _queryOutputs.Add(output);
                }

                /// <summary>
                /// Closes the current group: computes its max DCG, DCG and NDCG at every truncation level,
                /// adds them (scaled by <paramref name="weight"/>) to the running sums, records the per-group
                /// values when a group summary is kept, and clears the buffered rows.
                /// </summary>
                public void UpdateGroup(Single weight)
                {
                    RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupMaxDcgCur);
                    if (_groupMaxDcg != null)
                    {
                        // The scratch buffer is reused for the next group, so store a snapshot.
                        _groupMaxDcg.Add((Double[])_groupMaxDcgCur.Clone());
                    }

                    RankingUtils.QueryDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupDcgCur);
                    if (_groupDcg != null)
                        _groupDcg.Add((Double[])_groupDcgCur.Clone());

                    var ndcgAtLevel = new Double[TruncationLevel];
                    for (int level = 0; level < TruncationLevel; level++)
                    {
                        Double best = _groupMaxDcgCur[level];
                        Double actual = _groupDcgCur[level];

                        // A group whose max DCG is not positive contributes an NDCG of 0.
                        Double ndcg = 0;
                        if (best > 0)
                            ndcg = actual / best;

                        ndcgAtLevel[level] = ndcg;
                        _sumNdcgAtN[level] += ndcg * weight;
                        _sumDcgAtN[level] += actual * weight;
                    }
                    _sumWeights += weight;

                    if (_groupNdcg != null)
                        _groupNdcg.Add(ndcgAtLevel);

                    _queryLabels.Clear();
                    _queryOutputs.Clear();
                }

                // Returns a new array holding each of the sums divided by the total weight seen so far.
                private Double[] DivideBySumOfWeights(Double[] sums)
                {
                    var averages = new Double[TruncationLevel];
                    for (int level = 0; level < averages.Length; level++)
                        averages[level] = sums[level] / _sumWeights;
                    return averages;
                }
            }
 
            // Weight of the group currently being read; NaN until the first weight of the group is seen.
            private Single _currentQueryWeight;

            // Getters and delegates bound to the current row in InitializeNextPass.
            private ValueGetter<Single> _labelGetter;
            private ValueGetter<Single> _scoreGetter;
            // Stays null when the schema has no weight column.
            private ValueGetter<Single> _weightGetter;
            // Returns true when the current row starts a new group.
            private Func<bool> _newGroupDel;
            // Refreshes _groupSb with the current row's group id; a no-op when no group summary is kept.
            private Action _groupSbUpdate;
            private StringBuilder _groupSb;

            public readonly Counters UnweightedCounters;
            // Null unless the aggregator is weighted.
            public readonly Counters WeightedCounters;
            public readonly bool Weighted;
            // Text of the group ids, one per processed group; null unless a group summary is kept.
            public readonly List<ReadOnlyMemory<char>> GroupId;
            // Number of rows with a valid score accumulated for the current group.
            private int _groupSize;
 
            /// <summary>
            /// Creates an aggregator with unweighted counters and, when <paramref name="weighted"/> is set,
            /// a second set of weighted counters. Group ids are collected only when
            /// <paramref name="groupSummary"/> is set.
            /// </summary>
            public Aggregator(IHostEnvironment env, Double[] labelGains, int truncationLevel, bool groupSummary, bool weighted, string stratName)
                : base(env, stratName)
            {
                Host.AssertValue(labelGains);
                Host.Assert(truncationLevel > 0);

                Weighted = weighted;
                UnweightedCounters = new Counters(labelGains, truncationLevel, groupSummary);

                // Per-group results are tracked by the unweighted counters only.
                if (weighted)
                    WeightedCounters = new Counters(labelGains, truncationLevel, false);
                else
                    WeightedCounters = null;

                GroupId = groupSummary ? new List<ReadOnlyMemory<char>>() : null;
                _currentQueryWeight = Single.NaN;
            }
 
            /// <summary>
            /// Binds the label, score, group-boundary and (optional) weight getters to <paramref name="row"/>,
            /// and sets up the delegate that captures the group id text when a group summary is kept.
            /// </summary>
            internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema schema)
            {
                Contracts.Assert(PassNum < 1);
                Contracts.Assert(schema.Label.HasValue);
                Contracts.Assert(schema.Group.HasValue);

                var scoreColumn = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);

                _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
                _scoreGetter = row.GetGetter<Single>(scoreColumn);
                _newGroupDel = RowCursorUtils.GetIsNewGroupDelegate(row, schema.Group.Value.Index);
                if (schema.Weight.HasValue)
                    _weightGetter = row.GetGetter<Single>(schema.Weight.Value);

                if (!UnweightedCounters.GroupSummary)
                {
                    // Nobody reads the group id text, so skip fetching it.
                    _groupSbUpdate = () => { };
                    return;
                }

                ValueGetter<StringBuilder> readGroupId = RowCursorUtils.GetGetterAsStringBuilder(row, schema.Group.Value.Index);
                _groupSbUpdate = () => readGroupId(ref _groupSb);
            }
 
            /// <summary>
            /// Processes the current row: closes the previous group when a new one starts, skips rows whose
            /// score is NaN (counting them as bad scores), and otherwise adds the row's label and score to the
            /// counters. When a weight column exists, all rows of a group must carry the same weight.
            /// </summary>
            public override void ProcessRow()
            {
                bool startsNewGroup = _newGroupDel();
                if (startsNewGroup)
                {
                    // Flush the group accumulated so far before starting on the new one.
                    if (_groupSize > 0)
                    {
                        ProcessGroup();
                        _groupSize = 0;
                    }
                    _groupSbUpdate();
                }

                Single label = 0;
                Single score = 0;
                _labelGetter(ref label);
                _scoreGetter(ref score);

                if (Single.IsNaN(score))
                {
                    NumBadScores++;
                    return;
                }

                short relevance = (short)label;
                UnweightedCounters.Update(relevance, score);
                WeightedCounters?.Update(relevance, score);
                _groupSize++;

                if (_weightGetter == null)
                    return;

                Single weight = 1;
                _weightGetter(ref weight);
                if (Single.IsNaN(_currentQueryWeight))
                {
                    // First weighted row of the group: remember its weight.
                    _currentQueryWeight = weight;
                }
                else
                {
                    Contracts.Check(weight == _currentQueryWeight, "Weights within query differ");
                }
            }
 
            // Closes the current group: updates the unweighted counters with weight 1, the weighted
            // counters (if any) with the group's weight, records the group id when a summary is kept,
            // and resets the group weight.
            private void ProcessGroup()
            {
                UnweightedCounters.UpdateGroup(1);
                WeightedCounters?.UpdateGroup(_currentQueryWeight);
                GroupId?.Add(_groupSb.ToString().AsMemory());
                _currentQueryWeight = Single.NaN;
            }
 
            /// <summary>
            /// Finishes the pass, flushing the last group if it has any accumulated rows.
            /// </summary>
            protected override void FinishPassCore()
            {
                base.FinishPassCore();

                if (_groupSize <= 0)
                    return;

                ProcessGroup();
            }
 
            /// <summary>
            /// Returns a getter that fills a buffer with the slot names "prefix@1" through "prefix@N",
            /// where N is the truncation level.
            /// </summary>
            public ValueGetter<VBuffer<ReadOnlyMemory<char>>> GetGroupSummarySlotNames(string prefix)
            {
                ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter =
                    (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                    {
                        int slotCount = UnweightedCounters.TruncationLevel;
                        var editor = VBufferEditor.Create(ref dst, slotCount);
                        for (int slot = 0; slot < slotCount; slot++)
                        {
                            string slotName = string.Format("{0}@{1}", prefix, slot + 1);
                            editor.Values[slot] = slotName.AsMemory();
                        }
                        dst = editor.Commit();
                    };
                return getter;
            }
 
            /// <summary>
            /// Fills <paramref name="slotNames"/> with the names "@1" through "@N", where N is the truncation level.
            /// </summary>
            public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames)
            {
                int slotCount = UnweightedCounters.TruncationLevel;
                var editor = VBufferEditor.Create(ref slotNames, slotCount);
                for (int slot = 0; slot < slotCount; slot++)
                {
                    string slotName = string.Format("@{0}", slot + 1);
                    editor.Values[slot] = slotName.AsMemory();
                }
                slotNames = editor.Commit();
            }
        }
    }
 
    internal sealed class RankingPerInstanceTransform : IDataTransform
    {
        // Signature under which this transform is saved and loaded; it doubles as the host registration name.
        public const string LoaderSignature = "RankerPerInstTransform";
        private const string RegistrationName = LoaderSignature;

        // Version information written to, and checked against, the saved model.
        private static VersionInfo GetVersionInfo() =>
            new VersionInfo(
                modelSignature: "RNK INST",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(RankingPerInstanceTransform).Assembly.FullName);
 
        // Names of the columns added by this transform.
        public const string Ndcg = "NDCG";
        public const string Dcg = "DCG";
        public const string MaxDcg = "MaxDCG";

        // All data-view operations are delegated to this inner transform.
        private readonly Transform _transform;

        /// <summary>
        /// The input data view of the wrapped transform.
        /// </summary>
        public IDataView Source => _transform.Source;

        /// <summary>
        /// Whether the wrapped transform supports shuffling.
        /// </summary>
        public bool CanShuffle => _transform.CanShuffle;

        /// <summary>
        /// Explicit implementation prevents Schema from being accessed from derived classes.
        /// It's our first step to separate data produced by transform from transform.
        /// </summary>
        DataViewSchema IDataView.Schema => OutputSchema;

        /// <summary>
        /// Shape information of the output produced by the wrapped transform.
        /// </summary>
        public DataViewSchema OutputSchema => _transform.OutputSchema;
 
        /// <summary>
        /// Creates the transform over <paramref name="input"/> using the given label, score and group
        /// columns, truncation level and label gains. All arguments are forwarded to the inner transform.
        /// </summary>
        public RankingPerInstanceTransform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol,
                int truncationLevel, Double[] labelGains)
        {
            _transform = new Transform(env, input, labelCol, scoreCol, groupCol, truncationLevel, labelGains);
        }

        // Deserialization constructor; the inner transform reads its state from ctx. Used by Create.
        private RankingPerInstanceTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            _transform = new Transform(env, ctx, input);
        }
 
        /// <summary>
        /// Loads a <see cref="RankingPerInstanceTransform"/> from a model context and applies it to
        /// <paramref name="input"/>. The context must be positioned at a model with a matching version.
        /// </summary>
        public static RankingPerInstanceTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            host.CheckValue(input, nameof(input));

            return host.Apply("Loading Model", channel => new RankingPerInstanceTransform(host, ctx, input));
        }
 
        // Saves this transform: stamps the version information on the context, then lets the inner
        // transform write its own state.
        void ICanSaveModel.Save(ModelSaveContext ctx)
        {
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());
            ((ICanSaveModel)_transform).Save(ctx);
        }
 
        /// <summary>
        /// Returns the row count reported by the wrapped transform, which may be null.
        /// </summary>
        public long? GetRowCount() => _transform.GetRowCount();
 
        /// <summary>
        /// Returns a cursor over the requested columns, as provided by the wrapped transform.
        /// </summary>
        public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            return _transform.GetRowCursor(columnsNeeded, rand);
        }

        /// <summary>
        /// Returns a set of cursors over the requested columns, as provided by the wrapped transform.
        /// </summary>
        public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            return _transform.GetRowCursorSet(columnsNeeded, n, rand);
        }
 
        private sealed class Transform : PerGroupTransformBase<short, Single, Transform.RowCursorState>
        {
            /// <summary>
            /// Column bindings for the three output columns (NDCG, DCG and MaxDCG). Each output is a
            /// vector of doubles with one slot per truncation level and carries slot-name annotations.
            /// </summary>
            private sealed class Bindings : BindingsBase
            {
                private readonly DataViewType _outputType;
                private readonly DataViewType _slotNamesType;
                private readonly int _truncationLevel;
                private readonly AnnotationUtils.AnnotationGetter<VBuffer<ReadOnlyMemory<char>>> _slotNamesGetter;

                public Bindings(IExceptionContext ectx, DataViewSchema input, bool user, string labelCol, string scoreCol, string groupCol,
                    int truncationLevel)
                    : base(ectx, input, labelCol, scoreCol, groupCol, user, Ndcg, Dcg, MaxDcg)
                {
                    _truncationLevel = truncationLevel;
                    _outputType = new VectorDataViewType(NumberDataViewType.Double, truncationLevel);
                    _slotNamesType = new VectorDataViewType(TextDataViewType.Instance, truncationLevel);
                    _slotNamesGetter = SlotNamesGetter;
                }

                // Every output column has the same vector-of-double type.
                protected override DataViewType GetColumnTypeCore(int iinfo)
                {
                    Contracts.Assert(0 <= iinfo && iinfo < InfoCount);
                    return _outputType;
                }

                // Adds the slot-names annotation in front of the annotations reported by the base class.
                protected override IEnumerable<KeyValuePair<string, DataViewType>> GetAnnotationTypesCore(int iinfo)
                {
                    Contracts.Assert(0 <= iinfo && iinfo < InfoCount);
                    var inherited = base.GetAnnotationTypesCore(iinfo);
                    return inherited.Prepend(_slotNamesType.GetPair(AnnotationUtils.Kinds.SlotNames));
                }

                protected override DataViewType GetAnnotationTypeCore(string kind, int iinfo)
                {
                    Contracts.Assert(0 <= iinfo && iinfo < InfoCount);
                    return kind == AnnotationUtils.Kinds.SlotNames
                        ? _slotNamesType
                        : base.GetAnnotationTypeCore(kind, iinfo);
                }

                protected override void GetAnnotationCore<TValue>(string kind, int iinfo, ref TValue value)
                {
                    Contracts.Assert(0 <= iinfo && iinfo < InfoCount);
                    if (kind != AnnotationUtils.Kinds.SlotNames)
                    {
                        base.GetAnnotationCore(kind, iinfo, ref value);
                        return;
                    }
                    _slotNamesGetter.Marshal(iinfo, ref value);
                }

                // Fills dst with "<metric>@1" .. "<metric>@N" for the metric selected by iinfo.
                private void SlotNamesGetter(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
                {
                    Contracts.Assert(0 <= iinfo && iinfo < InfoCount);

                    string metric;
                    switch (iinfo)
                    {
                        case NdcgCol:
                            metric = Ndcg;
                            break;
                        case DcgCol:
                            metric = Dcg;
                            break;
                        default:
                            metric = MaxDcg;
                            break;
                    }

                    var editor = VBufferEditor.Create(ref dst, _truncationLevel);
                    for (int slot = 0; slot < _truncationLevel; slot++)
                        editor.Values[slot] = string.Format("{0}@{1}", metric, slot + 1).AsMemory();
                    dst = editor.Commit();
                }
            }
 
            // Indices of the output columns, in the order their names are passed to the Bindings base constructor.
            private const int NdcgCol = 0;
            private const int DcgCol = 1;
            private const int MaxDcgCol = 2;

            private readonly Bindings _bindings;
            // Number of truncation levels, i.e. the length of each output vector.
            private readonly int _truncationLevel;
            // Obtained from RankingUtils.GetDiscountMap for the truncation level.
            private readonly Double[] _discountMap;
            private readonly Double[] _labelGains;
 
            /// <summary>
            /// Creates the transform from explicit arguments. The truncation level must be greater than 0
            /// and the label gains must not be null.
            /// </summary>
            public Transform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol,
                int truncationLevel, Double[] labelGains)
                : base(env, input, labelCol, scoreCol, groupCol, RegistrationName)
            {
                Host.CheckParam(truncationLevel > 0, nameof(truncationLevel),
                    "Truncation level must be greater than 0");
                Host.CheckValue(labelGains, nameof(labelGains));

                _labelGains = labelGains;
                _truncationLevel = truncationLevel;
                _discountMap = RankingUtils.GetDiscountMap(truncationLevel);

                // The bindings are created in "user" mode for this construction path.
                _bindings = new Bindings(Host, Source.Schema, true, LabelCol, ScoreCol, GroupCol, truncationLevel);
            }
 
            /// <summary>
            /// Restores the transform from a model stream in the layout written by <see cref="SaveModel"/>.
            /// </summary>
            /// <param name="env">Host environment, forwarded to the base class.</param>
            /// <param name="ctx">Load context positioned at this transform's data; the base class reads its part first.</param>
            /// <param name="input">Data view to wrap; its schema is used to build the bindings.</param>
            public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
                : base(env, ctx, input, RegistrationName)
            {
                // *** Binary format ***
                // base
                // int: _truncationLevel
                // int: _labelGains.Length
                // double[]: _labelGains

                _truncationLevel = ctx.Reader.ReadInt32();
                Host.CheckDecode(0 < _truncationLevel);
                _labelGains = ctx.Reader.ReadDoubleArray();

                // The discount map is not part of the binary format, so it has to be rebuilt from the
                // truncation level. Leaving it unset would hand a null map to RankingUtils.QueryMaxDcg
                // and RankingUtils.QueryDcg from UpdateState.
                _discountMap = RankingUtils.GetDiscountMap(_truncationLevel);

                _bindings = new Bindings(Host, input.Schema, false, LabelCol, ScoreCol, GroupCol, _truncationLevel);
            }
 
            /// <summary>
            /// Serializes the base state, then the truncation level, then the label gains.
            /// </summary>
            private protected override void SaveModel(ModelSaveContext ctx)
            {
                Host.AssertValue(ctx);

                // *** Binary format ***
                // base
                // int: _truncationLevel
                // int: _labelGains.Length
                // double[]: _labelGains

                base.SaveModel(ctx);

                Host.Assert(_truncationLevel > 0);
                var writer = ctx.Writer;
                writer.Write(_truncationLevel);
                writer.WriteDoubleArray(_labelGains);
            }
 
            // Hands the bindings built in the constructors to the base class.
            private protected override BindingsBase GetBindings() => _bindings;
 
            /// <summary>
            /// Builds the getters for the NDCG, DCG and MaxDCG output columns. A column for which
            /// <paramref name="predicate"/> returns false keeps a null entry in the returned array.
            /// </summary>
            protected override Delegate[] CreateGetters(RowCursorState state, Func<int, bool> predicate)
            {
                // Every getter copies one of the state's current result arrays into the caller's buffer.
                ValueGetter<VBuffer<Double>> MakeGetter(Double[] current)
                {
                    return (ref VBuffer<Double> dst) => Copy(current, ref dst);
                }

                var result = new Delegate[_bindings.InfoCount];

                if (predicate(NdcgCol))
                    result[NdcgCol] = MakeGetter(state.NdcgCur);
                if (predicate(DcgCol))
                    result[DcgCol] = MakeGetter(state.DcgCur);
                if (predicate(MaxDcgCol))
                    result[MaxDcgCol] = MakeGetter(state.MaxDcgCur);

                return result;
            }
 
            // Writes every element of src into dst through a VBufferEditor created with length src.Length.
            private void Copy(Double[] src, ref VBuffer<Double> dst)
            {
                Host.AssertValue(src);

                int length = src.Length;
                var editor = VBufferEditor.Create(ref dst, length);
                src.AsSpan().CopyTo(editor.Values);
                dst = editor.Commit();
            }
 
            /// <summary>
            /// Wraps the label getter of <paramref name="row"/> so that the label, read as a
            /// <see cref="Single"/>, is delivered as a <see cref="short"/>.
            /// </summary>
            protected override ValueGetter<short> GetLabelGetter(DataViewRow row)
            {
                var floatGetter = RowCursorUtils.GetLabelGetter(row, _bindings.LabelIndex);

                void ReadAsShort(ref short dst)
                {
                    Single raw = 0;
                    floatGetter(ref raw);
                    // Explicit narrowing conversion from Single to short.
                    dst = (short)raw;
                }

                return ReadAsShort;
            }
 
            // Returns the row's getter for the score column identified by the bindings.
            protected override ValueGetter<Single> GetScoreGetter(DataViewRow row)
            {
                var scoreColumn = row.Schema[_bindings.ScoreIndex];
                return row.GetGetter<Single>(scoreColumn);
            }
 
            // Creates fresh cursor state with result arrays sized to the truncation level.
            protected override RowCursorState InitializeState(DataViewRow input)
                => new RowCursorState(_truncationLevel);
 
            // Buffers one example's score and label in the state; UpdateState consumes and clears both buffers.
            protected override void ProcessExample(RowCursorState state, short label, Single score)
            {
                state.QueryOutputs.Add(score);
                state.QueryLabels.Add(label);
            }
 
            /// <summary>
            /// Computes MaxDCG, DCG and NDCG at every truncation position for the buffered examples,
            /// then empties the label and score buffers.
            /// </summary>
            protected override void UpdateState(RowCursorState state)
            {
                var labels = state.QueryLabels;
                var scores = state.QueryOutputs;
                var maxDcg = state.MaxDcgCur;
                var dcg = state.DcgCur;
                var ndcg = state.NdcgCur;

                RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, _discountMap, labels, scores, maxDcg);
                RankingUtils.QueryDcg(_labelGains, _truncationLevel, _discountMap, labels, scores, dcg);

                // NDCG is DCG divided by MaxDCG. It is reported as 0 whenever MaxDCG is not greater
                // than 0; that comparison is also false when MaxDCG is NaN.
                for (int pos = 0; pos < _truncationLevel; pos++)
                {
                    if (maxDcg[pos] > 0)
                        ndcg[pos] = dcg[pos] / maxDcg[pos];
                    else
                        ndcg[pos] = 0;
                }

                labels.Clear();
                scores.Clear();
            }
 
            /// <summary>
            /// Mutable scratch state: the labels and scores buffered so far, plus the most recently
            /// computed NDCG, DCG and MaxDCG arrays (one entry per truncation position).
            /// </summary>
            public sealed class RowCursorState
            {
                // Buffered labels and scores; filled by ProcessExample and cleared by UpdateState.
                public readonly List<short> QueryLabels = new List<short>();
                public readonly List<Single> QueryOutputs = new List<Single>();

                // Result arrays, each of length truncationLevel.
                public readonly Double[] NdcgCur;
                public readonly Double[] DcgCur;
                public readonly Double[] MaxDcgCur;

                public RowCursorState(int truncationLevel)
                {
                    Contracts.Assert(truncationLevel > 0);

                    MaxDcgCur = new Double[truncationLevel];
                    DcgCur = new Double[truncationLevel];
                    NdcgCur = new Double[truncationLevel];
                }
            }
        }
    }
 
    [BestFriend]
    internal sealed class RankingMamlEvaluator : MamlEvaluatorBase
    {
        /// <summary>
        /// Arguments of the ranking MAML evaluator; also the input type of the
        /// "Models.RankingEvaluator" entry point.
        /// </summary>
        public sealed class Arguments : ArgumentsBase
        {
            /// <value>
            /// Column to use for the group ID
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for the group ID", ShortName = "group")]
            public string GroupIdColumn;
 
            /// <value>
            /// Maximum truncation level for computing (N)DCG
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")]
            public int DcgTruncationLevel = 10;
 
            /// <value>
            /// Label relevance gains
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Label relevance gains", ShortName = "gains")]
            public string LabelGains = "0,3,7,15,31";
 
            /// <value>
            /// Group summary filename. A non-empty value makes the evaluator request the group
            /// summary and save it to this file.
            /// </value>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Group summary filename", ShortName = "gsf", Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
            public string GroupSummaryFilename;
        }
 
        // Evaluator created in the constructor and exposed through the Evaluator property.
        private readonly RankingEvaluator _evaluator;
        // Group-id column name taken from the arguments; null when the argument was not set.
        private readonly string _groupIdCol;
 
        // Destination file for the group summary; when null or empty no summary is saved.
        private readonly string _groupSummaryFilename;
 
        private protected override IEvaluator Evaluator => _evaluator;
 
        /// <summary>
        /// Creates the MAML evaluator and the <see cref="RankingEvaluator"/> it wraps, configured
        /// from <paramref name="args"/>.
        /// </summary>
        public RankingMamlEvaluator(IHostEnvironment env, Arguments args)
            : base(args, env, AnnotationUtils.Const.ScoreColumnKind.Ranking, "RankerMamlEvaluator")
        {
            Host.CheckValue(args, nameof(args));
            Utils.CheckOptionalUserDirectory(args.GroupSummaryFilename, nameof(args.GroupSummaryFilename));

            _groupIdCol = args.GroupIdColumn;
            _groupSummaryFilename = args.GroupSummaryFilename;

            // The group summary is only requested from the evaluator when there is a file to save it to.
            var evalOpts = new RankingEvaluatorOptions
            {
                DcgTruncationLevel = args.DcgTruncationLevel,
                LabelGains = args.LabelGains,
                OutputGroupSummary = !string.IsNullOrEmpty(_groupSummaryFilename)
            };
            _evaluator = new RankingEvaluator(Host, evalOpts);
        }
 
        // Returns the base class's input column roles with the group-id column role placed in front.
        private protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
        {
            var baseRoles = base.GetInputColumnRolesCore(schema);
            var groupName = EvaluateUtils.GetColName(_groupIdCol, schema.Group, DefaultColumnNames.GroupId);
            var groupRole = RoleMappedSchema.ColumnRole.Group.Bind(groupName);
            return baseRoles.Prepend(groupRole);
        }
 
        /// <summary>
        /// Saves the group summary to the configured file, if one was requested.
        /// Throws when a file was requested but the metrics contain no group summary.
        /// </summary>
        private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
        {
            ch.AssertNonEmpty(metrics);

            // Nothing to save unless a group summary file was configured.
            if (string.IsNullOrEmpty(_groupSummaryFilename))
                return;

            if (!TryGetGroupSummaryMetrics(metrics, out IDataView summary))
                throw ch.Except("Did not find group summary metrics");

            ch.Trace("Saving group-summary results");
            // If the data view contains stratification columns, filter so that only the overall metrics
            // will be present, and drop them.
            IDataView overall = MetricWriter.GetNonStratifiedMetrics(Host, summary);
            MetricWriter.SavePerInstance(Host, ch, _groupSummaryFilename, overall);
        }
 
        /// <summary>
        /// Looks up the group summary in <paramref name="metrics"/>. With a single entry its
        /// summary is returned as is; with several entries each summary gets a fold index added
        /// and the results are appended into one data view.
        /// </summary>
        /// <returns>False, with <paramref name="gs"/> null, as soon as one entry has no group summary.</returns>
        private bool TryGetGroupSummaryMetrics(Dictionary<string, IDataView>[] metrics, out IDataView gs)
        {
            Host.AssertNonEmpty(metrics);

            int foldCount = metrics.Length;
            if (foldCount == 1)
                return metrics[0].TryGetValue(RankingEvaluator.GroupSummary, out gs);

            gs = null;
            var perFold = new List<IDataView>();
            for (int fold = 0; fold < foldCount; fold++)
            {
                if (!metrics[fold].TryGetValue(RankingEvaluator.GroupSummary, out IDataView foldSummary))
                    return false;

                perFold.Add(EvaluateUtils.AddFoldIndex(Host, foldSummary, fold, foldCount));
            }

            gs = AppendRowsDataView.Create(Host, perFold[0].Schema, perFold.ToArray());
            return true;
        }
 
        /// <summary>
        /// Lists the columns to keep when saving per-instance results: the group, label and score
        /// columns followed by the NDCG, DCG and MaxDCG output columns.
        /// </summary>
        /// <param name="schema">Role-mapped schema; must be non-null and have both a label and a group column.</param>
        /// <returns>A lazily evaluated sequence of column names.</returns>
        private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
        {
            // Validate here rather than inside the iterator: an iterator body does not start running
            // until the sequence is enumerated, which would postpone these failures to an arbitrary
            // later point, away from the call that passed the bad schema.
            Host.CheckValue(schema, nameof(schema));
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column");
            Host.CheckParam(schema.Group.HasValue, nameof(schema), "Data must contain a group column");

            return Iterate();

            IEnumerable<string> Iterate()
            {
                // The ranking evaluator outputs the label, group key and score columns.
                yield return schema.Group.Value.Name;
                yield return schema.Label.Value.Name;
                var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
                    AnnotationUtils.Const.ScoreColumnKind.Ranking);
                yield return scoreCol.Name;

                // Return the output columns.
                yield return RankingPerInstanceTransform.Ndcg;
                yield return RankingPerInstanceTransform.Dcg;
                yield return RankingPerInstanceTransform.MaxDcg;
            }
        }
    }
 
    internal static class RankingUtils
    {
        // Truncation levels are typically less than 100, so we maintain a fixed discount map of size 100.
        // If a truncation level greater than 100 is required, we build a new one and return that.
        private const int FixedDiscountMapSize = 100;
        // Shared fixed-size map; created on first use by GetDiscountMap and reused afterwards.
        private static Double[] _discountMapFixed;
 
        // Builds a new discount table of the requested length. The entry at zero-based position p
        // is 1 / ln(p + 2), i.e. the natural-log discount of one-based rank p + 1.
        private static Double[] GetDiscountMapCore(int truncationLevel)
        {
            var table = new Double[truncationLevel];

            // "shifted" is the zero-based slot plus 2, the argument of the logarithm.
            for (int slot = 0, shifted = 2; slot < table.Length; slot++, shifted++)
                table[slot] = 1 / Math.Log(shifted);

            return table;
        }
 
        /// <summary>
        /// Returns a discount map with at least <paramref name="truncationLevel"/> entries.
        /// Requests that fit in the fixed-size map get the shared instance, which may be longer
        /// than requested; larger requests get a newly built map of exactly the requested length.
        /// </summary>
        public static Double[] GetDiscountMap(int truncationLevel)
        {
            var shared = _discountMapFixed;
            if (shared == null)
            {
                // Publish the table only if nobody else has yet. CompareExchange returns the previous
                // field value: null means our candidate was stored, otherwise it is the table that
                // was already published and our candidate is discarded.
                var candidate = GetDiscountMapCore(FixedDiscountMapSize);
                shared = Interlocked.CompareExchange(ref _discountMapFixed, candidate, null) ?? candidate;
            }

            return truncationLevel > shared.Length ? GetDiscountMapCore(truncationLevel) : shared;
        }
 
        /// <summary>
        /// Calculates natural-based max DCG at all truncations from 1 to truncationLevel.
        /// </summary>
        /// <remarks>
        /// The labels are counted per label value, and positions are then filled from the highest
        /// label downwards. Positions beyond the number of labels repeat the last computed value.
        /// When there is nothing to rank, every entry is set to NaN.
        /// <paramref name="queryOutputs"/> is not read by this method.
        /// </remarks>
        public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, Double[] discountMap,
            List<short> queryLabels, List<Single> queryOutputs, Double[] groupMaxDcgCur)
        {
            Contracts.Assert(Utils.Size(groupMaxDcgCur) == truncationLevel);

            // histogram[v] is the number of not-yet-placed documents whose label is v.
            var histogram = new int[labelGains.Length];
            int filled = Math.Min(truncationLevel, queryLabels.Count);

            if (filled == 0)
            {
                for (int pos = 0; pos < truncationLevel; pos++)
                    groupMaxDcgCur[pos] = Double.NaN;
                return;
            }

            foreach (short label in queryLabels)
                histogram[label]++;

            // Place the best remaining label at each position in turn.
            int best = labelGains.Length - 1;
            for (int pos = 0; pos < filled; pos++)
            {
                while (histogram[best] == 0)
                    best--;

                Double term = labelGains[best] * discountMap[pos];
                groupMaxDcgCur[pos] = pos == 0 ? term : groupMaxDcgCur[pos - 1] + term;
                histogram[best]--;
            }

            // Fewer documents than truncation positions: carry the last value forward.
            for (int pos = filled; pos < truncationLevel; pos++)
                groupMaxDcgCur[pos] = groupMaxDcgCur[pos - 1];
        }
 
        /// <summary>
        /// Calculates the DCG of one query at all truncations from 1 to truncationLevel.
        /// </summary>
        /// <remarks>
        /// The documents are ordered with the comparison from GetCompareItems, and the discounted
        /// gains are accumulated position by position. Positions beyond the number of documents
        /// repeat the last accumulated value (0 when there are no documents).
        /// </remarks>
        public static void QueryDcg(Double[] labelGains, int truncationLevel, Double[] discountMap,
            List<short> queryLabels, List<Single> queryOutputs, Double[] groupDcgCur)
        {
            // Sort document indices rather than the documents themselves.
            int total = queryLabels.Count;
            int[] order = Utils.GetIdentityPermutation(total);
            Array.Sort(order, GetCompareItems(queryLabels, queryOutputs));

            int ranked = Math.Min(total, truncationLevel);
            Double running = 0;
            int pos = 0;
            for (; pos < ranked; pos++)
            {
                short label = queryLabels[order[pos]];
                running += labelGains[label] * discountMap[pos];
                groupDcgCur[pos] = running;
            }
            for (; pos < truncationLevel; pos++)
                groupDcgCur[pos] = running;
        }
 
        /// <summary>
        /// Builds the comparison used for sorting the document indices of one query.
        /// </summary>
        /// <remarks>
        /// Ordering: higher score first; on equal scores the lower label first; on equal labels the
        /// lower index first. Scores are compared with <see cref="Single.CompareTo(Single)"/>, which
        /// is a total order that places NaN below every number, so documents with a NaN score sort
        /// last. Comparing with the relational operators alone would treat NaN as tied with every
        /// score, which is not transitive and therefore not a valid comparison for Array.Sort.
        /// For scores that are not NaN both approaches give the same order.
        /// </remarks>
        private static Comparison<int> GetCompareItems(List<short> queryLabels, List<Single> queryOutputs)
        {
            return
                (i, j) =>
                {
                    Contracts.Assert(0 <= i && i < queryLabels.Count && i < queryOutputs.Count);
                    Contracts.Assert(0 <= j && j < queryLabels.Count && j < queryOutputs.Count);

                    // Descending by score: the operands are deliberately swapped.
                    int byScore = queryOutputs[j].CompareTo(queryOutputs[i]);
                    if (byScore != 0)
                        return byScore;

                    // Ascending by label on tied scores.
                    int byLabel = queryLabels[i].CompareTo(queryLabels[j]);
                    if (byLabel != 0)
                        return byLabel;

                    // The index is the final tie-breaker, so distinct items never compare as equal.
                    return i.CompareTo(j);
                };
        }
    }
 
    internal static partial class Evaluate
    {
        /// <summary>
        /// Entry point that evaluates a scored ranking dataset and returns the warnings, the
        /// overall metrics and the per-instance metrics.
        /// </summary>
        [TlcModule.EntryPoint(Name = "Models.RankingEvaluator", Desc = "Evaluates a ranking scored dataset.")]
        public static CommonOutputs.CommonEvaluateOutput Ranking(IHostEnvironment env, RankingMamlEvaluator.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("EvaluateRanker");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // Resolve the column names that the role-mapped data is built from.
            MatchColumns(host, input, out string label, out string weight, out string name);
            string groupId = TrainUtils.MatchNameOrDefaultOrNull(host, input.Data.Schema,
                nameof(RankingMamlEvaluator.Arguments.GroupIdColumn),
                input.GroupIdColumn, DefaultColumnNames.GroupId);

            IMamlEvaluator evaluator = new RankingMamlEvaluator(host, input);
            var roleMapped = new RoleMappedData(input.Data, label, null, groupId, weight, name);
            var metrics = evaluator.Evaluate(roleMapped);

            var warnings = ExtractWarnings(host, metrics);
            var overall = ExtractOverallMetrics(host, metrics, evaluator);
            var perInstance = evaluator.GetPerInstanceMetrics(roleMapped);

            return new CommonOutputs.CommonEvaluateOutput()
            {
                Warnings = warnings,
                OverallMetrics = overall,
                PerInstanceMetrics = perInstance
            };
        }
    }
}