File: Evaluators\EvaluatorBase.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// This is a base class for TLC evaluators. It implements both of the <see cref="IEvaluator"/> methods: <see cref="Evaluate"/> and
    ///  <see cref="GetPerInstanceMetricsCore"/>. Note that the input <see cref="RoleMappedData"/> is assumed to contain all the column
    /// roles needed for evaluation, including the score column.
    /// </summary>
    [BestFriend]
    internal abstract partial class EvaluatorBase<TAgg> : IEvaluator
        where TAgg : EvaluatorBase<TAgg>.AggregatorBase
    {
        protected readonly IHost Host;
 
        [BestFriend]
        private protected EvaluatorBase(IHostEnvironment env, string registrationName)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(registrationName);
        }
 
        Dictionary<string, IDataView> IEvaluator.Evaluate(RoleMappedData data)
        {
            CheckColumnTypes(data.Schema);
            Func<int, bool> activeCols = GetActiveCols(data.Schema);
            var agg = GetAggregator(data.Schema);
            AggregatorDictionaryBase[] dictionaries = GetAggregatorDictionaries(data.Schema);
 
            var dict = ProcessData(data.Data, data.Schema, activeCols, agg, dictionaries);
            agg.GetWarnings(dict, Host);
            return dict;
        }
 
        /// <summary>
        /// Checks the column types of the evaluator's input columns. The base class implementation checks only the type
        /// of the weight column, and all other columns should be checked by the deriving classes in <see cref="CheckCustomColumnTypesCore"/>.
        /// </summary>
        [BestFriend]
        private protected void CheckColumnTypes(RoleMappedSchema schema)
        {
            // Check the weight column type.
            if (schema.Weight.HasValue)
                EvaluateUtils.CheckWeightType(Host, schema.Weight.Value.Type);
            CheckScoreAndLabelTypes(schema);
            // Check the other column types.
            CheckCustomColumnTypesCore(schema);
        }
 
        /// <summary>
        /// Check that the types of the score and label columns are as expected by the evaluator. The <see cref="RoleMappedSchema"/>
        /// is assumed to contain the label column (if it exists) and the score column.
        /// Access the label column with the <see cref="RoleMappedSchema.Label"/> property, and the score column with the
        /// <see cref="RoleMappedSchema.GetUniqueColumn"/> or <see cref="RoleMappedSchema.GetColumns"/> methods.
        /// </summary>
        [BestFriend]
        private protected abstract void CheckScoreAndLabelTypes(RoleMappedSchema schema);
 
        /// <summary>
        /// Check the types of any other columns needed by the evaluator. Only override if the evaluator uses
        /// columns other than label, score and weight.
        /// </summary>
        [BestFriend]
        private protected virtual void CheckCustomColumnTypesCore(RoleMappedSchema schema)
        {
        }
 
        private Func<int, bool> GetActiveCols(RoleMappedSchema schema)
        {
            Func<int, bool> pred = GetActiveColsCore(schema);
            var stratCols = schema.GetColumns(MamlEvaluatorBase.Strat);
            var stratIndices = Utils.Size(stratCols) > 0 ? new HashSet<int>(stratCols.Select(col => col.Index)) : new HashSet<int>();
            return i => pred(i) || stratIndices.Contains(i);
        }
 
        /// <summary>
        /// Used in the Evaluate() method, to get the predicate for cursoring over the data.
        /// The base class implementation activates the score column, the label column if it exists, the weight column if it exists
        /// and the stratification columns.
        /// Override if other input columns need to be activated.
        /// </summary>
        [BestFriend]
        private protected virtual Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
        {
            var score = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
            int label = schema.Label?.Index ?? -1;
            int weight = schema.Weight?.Index ?? -1;
            return i => i == score.Index || i == label || i == weight;
        }
 
        /// <summary>
        /// Get an aggregator for the specific evaluator given the current RoleMappedSchema.
        /// </summary>
        private TAgg GetAggregator(RoleMappedSchema schema)
        {
            return GetAggregatorCore(schema, "");
        }
 
        /// <summary>
        /// For each stratification column, get an aggregator dictionary.
        /// </summary>
        private AggregatorDictionaryBase[] GetAggregatorDictionaries(RoleMappedSchema schema)
        {
            var list = new List<AggregatorDictionaryBase>();
            var stratCols = schema.GetColumns(MamlEvaluatorBase.Strat);
            if (Utils.Size(stratCols) > 0)
            {
                Func<string, TAgg> createAgg = stratName => GetAggregatorCore(schema, stratName);
                foreach (var stratCol in stratCols)
                    list.Add(AggregatorDictionaryBase.Create(schema, stratCol.Name, stratCol.Type, createAgg));
            }
            return list.ToArray();
        }
 
        [BestFriend]
        private protected abstract TAgg GetAggregatorCore(RoleMappedSchema schema, string stratName);
 
        // This method does as many passes over the data as needed by the evaluator, and computes the metrics, outputting the
        // results in a dictionary from the metric kind (overal/per-fold/confusion matrix/PR-curves etc.), to a data view containing
        // the metric. If there are stratified metrics, an additional column is added to the data view containing the
        // stratification value as text in the format "column x = y".
        private Dictionary<string, IDataView> ProcessData(IDataView data, RoleMappedSchema schema,
            Func<int, bool> activeColsIndices, TAgg aggregator, AggregatorDictionaryBase[] dictionaries)
        {
            Func<bool> finishPass =
                () =>
                {
                    var need = aggregator.FinishPass();
                    foreach (var agg in dictionaries.SelectMany(dict => dict.GetAll()))
                        need |= agg.FinishPass();
                    return need;
                };
 
            bool needMorePasses = aggregator.Start();
 
            var activeCols = data.Schema.Where(x => activeColsIndices(x.Index));
            // REVIEW: Add progress reporting.
            while (needMorePasses)
            {
                using (var cursor = data.GetRowCursor(activeCols))
                {
                    if (aggregator.IsActive())
                        aggregator.InitializeNextPass(cursor, schema);
                    for (int i = 0; i < Utils.Size(dictionaries); i++)
                    {
                        dictionaries[i].Reset(cursor);
 
                        foreach (var agg in dictionaries[i].GetAll())
                        {
                            if (agg.IsActive())
                                agg.InitializeNextPass(cursor, schema);
                        }
                    }
                    while (cursor.MoveNext())
                    {
                        if (aggregator.IsActive())
                            aggregator.ProcessRow();
                        for (int i = 0; i < Utils.Size(dictionaries); i++)
                        {
                            var agg = dictionaries[i].Get();
                            if (agg.IsActive())
                                agg.ProcessRow();
                        }
                    }
                }
                needMorePasses = finishPass();
            }
 
            Action<uint, ReadOnlyMemory<char>, TAgg> addAgg;
            Func<Dictionary<string, IDataView>> consolidate;
            GetAggregatorConsolidationFuncs(aggregator, dictionaries, out addAgg, out consolidate);
 
            uint stratColKey = 0;
            addAgg(stratColKey, default, aggregator);
            for (int i = 0; i < Utils.Size(dictionaries); i++)
            {
                var dict = dictionaries[i];
                stratColKey++;
                foreach (var agg in dict.GetAll())
                    addAgg(stratColKey, agg.StratName.AsMemory(), agg);
            }
            return consolidate();
        }
 
        /// <summary>
        /// This method returns two functions used to create the data views of metrics computed by the different
        /// aggregators (the overall one, and any stratified ones if they exist). The <paramref name="addAgg"/>
        /// function is called for every aggregator, and it is where the aggregators should finish their aggregations
        /// and the aggregator results should be stored. The <paramref name="consolidate"/> function
        /// is called after <paramref name="addAgg"/> has been called on all the aggregators, and it returns
        /// the dictionary of metric data views.
        /// </summary>
        [BestFriend]
        private protected abstract void GetAggregatorConsolidationFuncs(TAgg aggregator, AggregatorDictionaryBase[] dictionaries,
            out Action<uint, ReadOnlyMemory<char>, TAgg> addAgg, out Func<Dictionary<string, IDataView>> consolidate);
 
        [BestFriend]
        private protected ValueGetter<VBuffer<ReadOnlyMemory<char>>> GetKeyValueGetter(AggregatorDictionaryBase[] dictionaries)
        {
            if (Utils.Size(dictionaries) == 0)
                return null;
            return
                (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                {
                    var editor = VBufferEditor.Create(ref dst, dictionaries.Length);
                    for (int i = 0; i < dictionaries.Length; i++)
                        editor.Values[i] = dictionaries[i].ColName.AsMemory();
                    dst = editor.Commit();
                };
        }
 
        IDataTransform IEvaluator.GetPerInstanceMetrics(RoleMappedData data) => GetPerInstanceMetricsCore(data);
 
        [BestFriend]
        internal abstract IDataTransform GetPerInstanceMetricsCore(RoleMappedData data);
 
        public abstract IEnumerable<MetricColumn> GetOverallMetricColumns();
 
        /// <summary>
        /// This is a helper class for evaluators deriving from EvaluatorBase, used for computing aggregate metrics.
        /// Aggregators should keep track of the number of passes done. The <see cref="InitializeNextPass"/> method should get
        /// the input getters of the given <see cref="DataViewRow"/> that are needed for the current pass, assuming that all the needed column
        /// information is stored in the given <see cref="RoleMappedSchema"/>.
        /// In <see cref="ProcessRow"/> the aggregator should call the getters once, and process the input as needed.
        /// <see cref="FinishPass"/> increments the pass count after each pass.
        /// </summary>
        public abstract class AggregatorBase
        {
            public readonly string StratName;
 
            protected long NumUnlabeledInstances;
            protected long NumBadScores;
            protected long NumBadWeights;
 
            protected readonly IHost Host;
 
            protected int PassNum;
 
            [BestFriend]
            private protected AggregatorBase(IHostEnvironment env, string stratName)
            {
                Contracts.AssertValue(env);
                Host = env.Register("Aggregator");
                Host.AssertValueOrNull(stratName);
 
                PassNum = -1;
                StratName = stratName;
            }
 
            public bool Start()
            {
                Host.Check(PassNum == -1, "Start() should only be called before processing any data.");
                PassNum = 0;
                return IsActive();
            }
 
            /// <summary>
            /// This method should get the getters of the new <see cref="DataViewRow"/> that are needed for the next pass.
            /// </summary>
            [BestFriend]
            internal abstract void InitializeNextPass(DataViewRow row, RoleMappedSchema schema);
 
            /// <summary>
            /// Call the getters once, and process the input as necessary.
            /// </summary>
            public abstract void ProcessRow();
 
            /// <summary>
            /// Increment the pass count. Return true if additional passes are needed.
            /// </summary>
            public bool FinishPass()
            {
                FinishPassCore();
                PassNum++;
                return IsActive();
            }
 
            // REVIEW: A more proper way to do this is to make this method protected, and have the AggregatorDictionary
            // class maintain the information about which aggregator is done and have the Get() method return either the appropriate aggregator
            // or null.
            public virtual bool IsActive()
            {
                return PassNum < 1;
            }
 
            protected virtual void FinishPassCore()
            {
                Host.Assert(PassNum < 1);
            }
 
            /// <summary>
            /// Returns a dictionary from metric kinds to data views containing the metrics.
            /// </summary>
            //public abstract Dictionary<string, IDataView> Finish();
 
            public void GetWarnings(Dictionary<string, IDataView> dict, IHostEnvironment env)
            {
                var warnings = GetWarningsCore();
                if (Utils.Size(warnings) > 0)
                {
                    var dvBldr = new ArrayDataViewBuilder(env);
                    dvBldr.AddColumn(MetricKinds.ColumnNames.WarningText, TextDataViewType.Instance,
                        warnings.Select(s => s.AsMemory()).ToArray());
                    dict.Add(MetricKinds.Warnings, dvBldr.GetDataView());
                }
            }
 
            protected virtual List<string> GetWarningsCore()
            {
                var warnings = new List<string>();
                if (NumUnlabeledInstances > 0)
                    warnings.Add(string.Format("Encountered {0} unlabeled instances during testing.", NumUnlabeledInstances));
 
                if (NumBadWeights > 0)
                    warnings.Add(string.Format("Encountered {0} non-finite weights during testing. These weights have been replaced with 1.", NumBadWeights));
 
                if (NumBadScores > 0)
                {
                    warnings.Add(string.Format("The predictor produced non-finite prediction values on {0} instances during testing. " +
                        "Possible causes: abnormal data or the predictor is numerically unstable.", NumBadScores));
                }
                return warnings;
            }
        }
 
        // This class is a dictionary for aggregators that are used to compute aggregate metrics on stratified subsets
        // of the data. The dictionary holds a getter for a stratification column, and when the Get() method is called,
        // it calls this getter, and returns the appropriate aggregator based on the value in the stratification column.
        // When a new value is encountered, it uses a callback for creating a new aggregator.
        protected abstract class AggregatorDictionaryBase
        {
            private protected DataViewRow Row;
            private protected readonly Func<string, TAgg> CreateAgg;
            private protected readonly RoleMappedSchema Schema;
 
            public string ColName { get; }
 
            public abstract int Count { get; }
 
            private protected AggregatorDictionaryBase(RoleMappedSchema schema, string stratCol, Func<string, TAgg> createAgg)
            {
                Contracts.AssertValue(schema);
                Contracts.AssertNonWhiteSpace(stratCol);
                Contracts.AssertValue(createAgg);
 
                Schema = schema;
                CreateAgg = createAgg;
                ColName = stratCol;
            }
 
            /// <summary>
            /// Gets the stratification column getter for the new <see cref="DataViewRow"/>.
            /// </summary>
            public abstract void Reset(DataViewRow row);
 
            internal static AggregatorDictionaryBase Create(RoleMappedSchema schema, string stratCol, DataViewType stratType,
                Func<string, TAgg> createAgg)
            {
                Contracts.AssertNonWhiteSpace(stratCol);
                Contracts.AssertValue(createAgg);
 
                if (stratType.GetKeyCount() == 0 && !(stratType is TextDataViewType))
                {
                    throw Contracts.ExceptUserArg(nameof(MamlEvaluatorBase.ArgumentsBase.StratColumns),
                        "Stratification column '{stratCol}' has type '{stratType}', but must be a known count key or text");
                }
                return Utils.MarshalInvoke(CreateDictionary<int>, stratType.RawType, schema, stratCol, stratType, createAgg);
            }
 
            private static AggregatorDictionaryBase CreateDictionary<TStrat>(RoleMappedSchema schema, string stratCol,
                DataViewType stratType, Func<string, TAgg> createAgg)
            {
                return new GenericAggregatorDictionary<TStrat>(schema, stratCol, stratType, createAgg);
            }
 
            /// <summary>
            /// This method calls the getter of the stratification column, and returns the aggregator corresponding to
            /// the stratification value.
            /// </summary>
            /// <returns></returns>
            public abstract TAgg Get();
 
            /// <summary>
            /// This method returns the aggregators corresponding to all the stratification values seen so far.
            /// </summary>
            public abstract IEnumerable<TAgg> GetAll();
 
            private sealed class GenericAggregatorDictionary<TStrat> : AggregatorDictionaryBase
            {
                private readonly Dictionary<TStrat, TAgg> _dict;
                private ValueGetter<TStrat> _stratGetter;
 
                // This is used to get the current stratification value in the Get() method.
                private TStrat _value;
 
                public override int Count => _dict.Count;
 
                public GenericAggregatorDictionary(RoleMappedSchema schema, string stratCol, DataViewType stratType, Func<string, TAgg> createAgg)
                    : base(schema, stratCol, createAgg)
                {
                    Contracts.Assert(stratType.RawType == typeof(TStrat));
                    _dict = new Dictionary<TStrat, TAgg>();
                }
 
                public override void Reset(DataViewRow row)
                {
                    Row = row;
                    var col = row.Schema.GetColumnOrNull(ColName);
                    Contracts.Assert(col.HasValue);
                    _stratGetter = row.GetGetter<TStrat>(col.Value);
                    Contracts.AssertValue(_stratGetter);
                }
 
                public override TAgg Get()
                {
                    _stratGetter(ref _value);
 
                    TAgg agg;
                    if (!_dict.TryGetValue(_value, out agg))
                    {
                        // REVIEW: Consider adding a specific implementation for key types
                        // that would call _createAgg with the key value instead of the raw value.
                        agg = CreateAgg(_value.ToString());
                        agg.Start();
                        agg.InitializeNextPass(Row, Schema);
                        _dict.Add(_value, agg);
                    }
                    return agg;
                }
 
                public override IEnumerable<TAgg> GetAll()
                {
                    return _dict.Select(kvp => kvp.Value);
                }
            }
        }
    }
 
    [BestFriend]
    internal abstract class RowToRowEvaluatorBase<TAgg> : EvaluatorBase<TAgg>
        where TAgg : EvaluatorBase<TAgg>.AggregatorBase
    {
        [BestFriend]
        private protected RowToRowEvaluatorBase(IHostEnvironment env, string registrationName)
            : base(env, registrationName)
        {
        }
 
        internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data)
        {
            var mapper = CreatePerInstanceRowMapper(data.Schema);
            return new RowToRowMapperTransform(Host, data.Data, mapper, null);
        }
 
        [BestFriend]
        private protected abstract IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema);
    }
 
    /// <summary>
    /// This is a helper class for creating the per-instance IDV.
    /// </summary>
    [BestFriend]
    internal abstract class PerInstanceEvaluatorBase : IRowMapper
    {
        protected readonly IHost Host;
        protected readonly string ScoreCol;
        protected readonly string LabelCol;
        protected readonly int ScoreIndex;
        protected readonly int LabelIndex;
 
        protected PerInstanceEvaluatorBase(IHostEnvironment env, DataViewSchema schema, string scoreCol, string labelCol)
        {
            Contracts.AssertValue(env);
            Contracts.AssertNonEmpty(scoreCol);
 
            Host = env.Register("PerInstanceRowMapper");
            ScoreCol = scoreCol;
            LabelCol = labelCol;
 
            if (!string.IsNullOrEmpty(LabelCol) && !schema.TryGetColumnIndex(LabelCol, out LabelIndex))
                throw Host.ExceptSchemaMismatch(nameof(schema), "label", LabelCol);
            if (!schema.TryGetColumnIndex(ScoreCol, out ScoreIndex))
                throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol);
        }
 
        protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
        {
            Host = env.Register("PerInstanceRowMapper");
 
            // *** Binary format **
            // int: Id of the score column name
            // int: Id of the label column name
 
            ScoreCol = ctx.LoadNonEmptyString();
            LabelCol = ctx.LoadStringOrNull();
            if (!string.IsNullOrEmpty(LabelCol) && !schema.TryGetColumnIndex(LabelCol, out LabelIndex))
                throw Host.ExceptSchemaMismatch(nameof(schema), "label", LabelCol);
            if (!schema.TryGetColumnIndex(ScoreCol, out ScoreIndex))
                throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol);
        }
 
        void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
 
        /// <summary>
        /// Derived class, for example A, should overwrite <see cref="SaveModel"/> so that ((<see cref="ICanSaveModel"/>)A).Save(ctx) can correctly dump A.
        /// </summary>
        private protected virtual void SaveModel(ModelSaveContext ctx)
        {
            // *** Binary format **
            // int: Id of the score column name
            // int: Id of the label column name
 
            ctx.SaveNonEmptyString(ScoreCol);
            ctx.SaveStringOrNull(LabelCol);
        }
 
        Func<int, bool> IRowMapper.GetDependencies(Func<int, bool> activeOutput)
            => GetDependenciesCore(activeOutput);
 
        [BestFriend]
        private protected abstract Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput);
 
        DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns()
            => GetOutputColumnsCore();
 
        [BestFriend]
        private protected abstract DataViewSchema.DetachedColumn[] GetOutputColumnsCore();
 
        Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeCols, out Action disposer)
            => CreateGettersCore(input, activeCols, out disposer);
 
        [BestFriend]
        private protected abstract Delegate[] CreateGettersCore(DataViewRow input, Func<int, bool> activeCols, out Action disposer);
 
        public ITransformer GetTransformer()
        {
            throw Host.ExceptNotSupp();
        }
    }
}