// File: Evaluators\EvaluatorUtils.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
namespace Microsoft.ML.Data
{
    [BestFriend]
    internal static class EvaluateUtils
    {
        /// <summary>
        /// A named metric together with two running totals.
        /// NOTE(review): presumably accumulated across several evaluation results (e.g. folds) to derive a mean and
        /// standard deviation; the accumulation code is not visible here, so confirm against callers.
        /// </summary>
        public struct AggregatedMetric
        {
            // Sum of the metric values accumulated so far (presumably; verify against callers).
            public double Sum;
            // Sum of the squared metric values accumulated so far (presumably; verify against callers).
            public double SumSq;
            // The metric's name.
            public string Name;
        }
 
        /// <summary>
        /// Lazily built, process-wide table that maps a score column kind to a factory creating the default
        /// MAML evaluator for that kind.
        /// </summary>
        private static class DefaultEvaluatorTable
        {
            // Published at most once through Interlocked.CompareExchange; never reset afterwards.
            private static volatile Dictionary<string, Func<IHostEnvironment, IMamlEvaluator>> _knownEvaluatorFactories;

            /// <summary>
            /// Gets the score-kind to evaluator-factory table. Concurrent first callers may each build a candidate
            /// table, but only the first one published is ever returned to anybody.
            /// </summary>
            public static Dictionary<string, Func<IHostEnvironment, IMamlEvaluator>> Instance
            {
                get
                {
                    if (_knownEvaluatorFactories == null)
                        Interlocked.CompareExchange(ref _knownEvaluatorFactories, BuildTable(), null);
                    return _knownEvaluatorFactories;
                }
            }

            // Creates a fresh table containing one default evaluator factory per supported score column kind.
            private static Dictionary<string, Func<IHostEnvironment, IMamlEvaluator>> BuildTable()
            {
                var table = new Dictionary<string, Func<IHostEnvironment, IMamlEvaluator>>();
                table.Add(AnnotationUtils.Const.ScoreColumnKind.BinaryClassification,
                    env => new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()));
                table.Add(AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification,
                    env => new MulticlassClassificationMamlEvaluator(env, new MulticlassClassificationMamlEvaluator.Arguments()));
                table.Add(AnnotationUtils.Const.ScoreColumnKind.Regression,
                    env => new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()));
                table.Add(AnnotationUtils.Const.ScoreColumnKind.MultiOutputRegression,
                    env => new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()));
                table.Add(AnnotationUtils.Const.ScoreColumnKind.QuantileRegression,
                    env => new QuantileRegressionMamlEvaluator(env, new QuantileRegressionMamlEvaluator.Arguments()));
                table.Add(AnnotationUtils.Const.ScoreColumnKind.Ranking,
                    env => new RankingMamlEvaluator(env, new RankingMamlEvaluator.Arguments()));
                table.Add(AnnotationUtils.Const.ScoreColumnKind.Clustering,
                    env => new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()));
                table.Add(AnnotationUtils.Const.ScoreColumnKind.AnomalyDetection,
                    env => new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()));
                return table;
            }
        }
 
        /// <summary>
        /// Creates the default evaluator for the score column set detected in <paramref name="schema"/>.
        /// </summary>
        /// <param name="env">The host environment handed to the evaluator factory. May be null.</param>
        /// <param name="schema">The schema of the scored data. Must not be null.</param>
        /// <returns>A new instance of the default evaluator registered for the detected score column kind.</returns>
        /// <remarks>
        /// The score column set is chosen with <c>GetMaxAnnotationKind</c> over the <c>ScoreColumnSetId</c> annotation,
        /// considering only columns whose text <c>ScoreColumnKind</c> annotation has a registered default evaluator.
        /// If there is no such column but some column does carry a text score kind, a user-argument error naming that
        /// kind is thrown. If no score column is detected at all, a parameter error is thrown.
        /// </remarks>
        public static IMamlEvaluator GetEvaluator(IHostEnvironment env, DataViewSchema schema)
        {
            Contracts.CheckValueOrNull(env);
            Contracts.CheckValue(schema, nameof(schema));
            ReadOnlyMemory<char> tmp = default;

            // First choice: the latest score set whose kind has a registered default evaluator.
            schema.GetMaxAnnotationKind(out int col, AnnotationUtils.Kinds.ScoreColumnSetId, CheckScoreColumnKindIsKnown);
            if (col >= 0)
            {
                schema[col].Annotations.GetValue(AnnotationUtils.Kinds.ScoreColumnKind, ref tmp);
                var kind = tmp.ToString();
                var map = DefaultEvaluatorTable.Instance;
                // The next assert is guaranteed because it is checked in CheckScoreColumnKindIsKnown, which is the
                // filter passed to GetMaxAnnotationKind above.
                Contracts.Assert(map.ContainsKey(kind));
                return map[kind](env);
            }

            // Otherwise, report the kind of any text-kind score column so the user knows which kind is unsupported.
            schema.GetMaxAnnotationKind(out col, AnnotationUtils.Kinds.ScoreColumnSetId, CheckScoreColumnKind);
            if (col >= 0)
            {
                schema[col].Annotations.GetValue(AnnotationUtils.Kinds.ScoreColumnKind, ref tmp);
                throw env.ExceptUserArg(nameof(EvaluateCommand.Arguments.Evaluator), "No default evaluator found for score column kind '{0}'.", tmp.ToString());
            }

            throw env.ExceptParam(nameof(schema), "No score columns have been automatically detected.");
        }
 
        // Lambda used as validator/filter in calls to GetMaxAnnotationKind.
        /// <summary>
        /// Returns true when column <paramref name="col"/> carries a text <c>ScoreColumnKind</c> annotation whose value
        /// has a registered default evaluator.
        /// </summary>
        private static bool CheckScoreColumnKindIsKnown(DataViewSchema schema, int col)
        {
            var annotations = schema[col].Annotations;
            bool hasTextKind = annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreColumnKind)?.Type is TextDataViewType;
            if (!hasTextKind)
                return false;

            var kindText = default(ReadOnlyMemory<char>);
            annotations.GetValue(AnnotationUtils.Kinds.ScoreColumnKind, ref kindText);
            return DefaultEvaluatorTable.Instance.ContainsKey(kindText.ToString());
        }
 
        // Lambda used as validator/filter in calls to GetMaxAnnotationKind.
        /// <summary>
        /// Returns true when column <paramref name="col"/> has a <c>ScoreColumnKind</c> annotation of text type,
        /// whatever its value.
        /// </summary>
        private static bool CheckScoreColumnKind(DataViewSchema schema, int col)
            => schema[col].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreColumnKind)?.Type is TextDataViewType;
 
        /// <summary>
        /// Find the score column to use. If <paramref name="name"/> is specified, that is used. Otherwise, this searches
        /// for the most recent score set of the given <paramref name="kind"/>. If there is no such score set and
        /// <paramref name="defName"/> is specifed it uses <paramref name="defName"/>. Otherwise, it throws.
        /// </summary>
        /// <summary>
        /// Resolves the score column to use.
        /// <list type="number">
        /// <item><description>A non-blank <paramref name="name"/> always wins. The column must exist.</description></item>
        /// <item><description>Otherwise, searches the most recent score set of the given <paramref name="kind"/> for a
        /// visible column whose score value kind equals <paramref name="valueKind"/>.</description></item>
        /// <item><description>Otherwise, uses <paramref name="defName"/> if it is non-blank and present in the schema.</description></item>
        /// </list>
        /// If none of these applies, a user-argument error keyed on <paramref name="argName"/> is thrown.
        /// </summary>
        public static DataViewSchema.Column GetScoreColumn(IExceptionContext ectx, DataViewSchema schema, string name, string argName, string kind,
            string valueKind = AnnotationUtils.Const.ScoreValueKind.Score, string defName = null)
        {
            Contracts.CheckValueOrNull(ectx);
            ectx.CheckValue(schema, nameof(schema));
            ectx.CheckValueOrNull(name);
            ectx.CheckNonEmpty(argName, nameof(argName));
            ectx.CheckNonEmpty(kind, nameof(kind));
            ectx.CheckNonEmpty(valueKind, nameof(valueKind));

            // An explicitly named column short-circuits the score-set search.
            if (!string.IsNullOrWhiteSpace(name))
            {
#pragma warning disable MSML_ContractsNameUsesNameof // argName mirrors whatever argument name the caller exposes, so nameof cannot be used.
                return schema.GetColumnOrNull(name) ?? throw ectx.ExceptUserArg(argName, "Score column is missing");
#pragma warning restore MSML_ContractsNameUsesNameof
            }

            var newestSet = schema.GetMaxAnnotationKind(out int _, AnnotationUtils.Kinds.ScoreColumnSetId,
                (s, c) => IsScoreColumnKind(ectx, s, c, kind));

            var text = default(ReadOnlyMemory<char>);
            foreach (var index in schema.GetColumnSet(AnnotationUtils.Kinds.ScoreColumnSetId, newestSet))
            {
                var candidate = schema[index];
#if DEBUG
                candidate.Annotations.GetValue(AnnotationUtils.Kinds.ScoreColumnKind, ref text);
                ectx.Assert(ReadOnlyMemoryUtils.EqualsStr(kind, text));
#endif
                // REVIEW: Hidden columns are skipped; whether they should ever be considered is an open question.
                if (candidate.IsHidden)
                    continue;
                if (candidate.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreValueKind)?.Type != TextDataViewType.Instance)
                    continue;

                candidate.Annotations.GetValue(AnnotationUtils.Kinds.ScoreValueKind, ref text);
                if (ReadOnlyMemoryUtils.EqualsStr(valueKind, text))
                    return candidate;
            }

            // Fall back to the default column name, when one was supplied and exists.
            if (!string.IsNullOrWhiteSpace(defName))
            {
                var fallback = schema.GetColumnOrNull(defName);
                if (fallback.HasValue)
                    return fallback.Value;
            }

#pragma warning disable MSML_ContractsNameUsesNameof
            throw ectx.ExceptUserArg(argName, "Score column is missing");
#pragma warning restore MSML_ContractsNameUsesNameof
        }
 
        /// <summary>
        /// Find the optional auxiliary score column to use. If <paramref name="name"/> is specified, that is used.
        /// Otherwise, if <paramref name="colScore"/> is part of a score set, this looks in the score set for a column
        /// with the given <paramref name="valueKind"/>. If none is found, it returns <see langword="null"/>.
        /// </summary>
        /// <summary>
        /// Resolves an optional auxiliary score column.
        /// A non-blank <paramref name="name"/> must name an existing column accepted by <paramref name="testType"/>;
        /// otherwise a user-argument error is thrown. Without a name, the method looks at the score set that
        /// <paramref name="colScore"/> belongs to. It returns the first visible column there whose score value kind
        /// equals <paramref name="valueKind"/> and whose type passes <paramref name="testType"/>.
        /// Returns <see langword="null"/> when <paramref name="colScore"/> has no U4 key score-set id or no column matches.
        /// </summary>
        public static DataViewSchema.Column? GetOptAuxScoreColumn(IExceptionContext ectx, DataViewSchema schema, string name, string argName,
            int colScore, string valueKind, Func<DataViewType, bool> testType)
        {
            Contracts.CheckValueOrNull(ectx);
            ectx.CheckValue(schema, nameof(schema));
            ectx.CheckValueOrNull(name);
            ectx.CheckNonEmpty(argName, nameof(argName));
            ectx.CheckParam(0 <= colScore && colScore < schema.Count, nameof(colScore));
            ectx.CheckNonEmpty(valueKind, nameof(valueKind));

            if (!string.IsNullOrWhiteSpace(name))
            {
#pragma warning disable MSML_ContractsNameUsesNameof
                var named = schema.GetColumnOrNull(name);
                if (named == null)
                    throw ectx.ExceptUserArg(argName, "{0} column is missing", valueKind);
                if (!testType(named.Value.Type))
                    throw ectx.ExceptUserArg(argName, "{0} column has incompatible type", valueKind);
#pragma warning restore MSML_ContractsNameUsesNameof
                return named;
            }

            // Only a score column tagged with a U4 key score-set id can point us at sibling auxiliary columns.
            var scoreAnnotations = schema[colScore].Annotations;
            var setIdType = scoreAnnotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreColumnSetId)?.Type;
            if (!(setIdType is KeyDataViewType && setIdType.RawType == typeof(uint)))
                return null;

            uint setId = 0;
            scoreAnnotations.GetValue(AnnotationUtils.Kinds.ScoreColumnSetId, ref setId);

            var text = default(ReadOnlyMemory<char>);
            foreach (var index in schema.GetColumnSet(AnnotationUtils.Kinds.ScoreColumnSetId, setId))
            {
                // REVIEW: Hidden columns are skipped; whether they should ever be considered is an open question.
                var candidate = schema[index];
                if (candidate.IsHidden)
                    continue;
                if (candidate.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreValueKind)?.Type != TextDataViewType.Instance)
                    continue;

                candidate.Annotations.GetValue(AnnotationUtils.Kinds.ScoreValueKind, ref text);
                if (ReadOnlyMemoryUtils.EqualsStr(valueKind, text) && testType(candidate.Type))
                    return candidate;
            }

            // Nothing suitable in the score set.
            return null;
        }
 
        /// <summary>
        /// Returns true when column <paramref name="col"/> has a text <c>ScoreColumnKind</c> annotation equal to
        /// <paramref name="kind"/>.
        /// </summary>
        private static bool IsScoreColumnKind(IExceptionContext ectx, DataViewSchema schema, int col, string kind)
        {
            Contracts.CheckValueOrNull(ectx);
            ectx.CheckValue(schema, nameof(schema));
            ectx.CheckParam(0 <= col && col < schema.Count, nameof(col));
            ectx.CheckNonEmpty(kind, nameof(kind));

            var annotations = schema[col].Annotations;
            bool hasTextKind = annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.ScoreColumnKind)?.Type is TextDataViewType;
            if (!hasTextKind)
                return false;

            ReadOnlyMemory<char> actualKind = default;
            annotations.GetValue(AnnotationUtils.Kinds.ScoreColumnKind, ref actualKind);
            return ReadOnlyMemoryUtils.EqualsStr(kind, actualKind);
        }
 
        /// <summary>
        /// If <paramref name="str"/> is non-empty, returns it. Otherwise if <paramref name="info"/> is non-<see langword="null"/>,
        /// returns its <see cref="DataViewSchema.Column.Name"/>. Otherwise, returns <paramref name="def"/>.
        /// </summary>
        /// <summary>
        /// Picks a column name.
        /// Returns <paramref name="str"/> when it is non-empty. Otherwise returns the name of <paramref name="info"/>
        /// when it has a value. Otherwise returns <paramref name="def"/>.
        /// </summary>
        public static string GetColName(string str, DataViewSchema.Column? info, string def)
        {
            Contracts.CheckValueOrNull(str);
            Contracts.CheckValueOrNull(def);

            return string.IsNullOrEmpty(str) ? (info?.Name ?? def) : str;
        }
 
        /// <summary>
        /// Throws a user-argument error (keyed on the weight column argument) unless <paramref name="type"/> is
        /// <see cref="NumberDataViewType.Single"/>.
        /// </summary>
        public static void CheckWeightType(IExceptionContext ectx, DataViewType type)
        {
            ectx.AssertValue(type);
            if (type == NumberDataViewType.Single)
                return;

            throw ectx.ExceptUserArg(nameof(EvaluateCommand.Arguments.WeightColumn), "Incompatible Weight column. Weight column type must be {0}.", NumberDataViewType.Single);
        }
 
        /// <summary>
        /// Helper method to get an IEnumerable of double metrics from an overall metrics IDV produced by an evaluator.
        /// </summary>
        /// <summary>
        /// Helper method to get an IEnumerable of double metrics from an overall metrics IDV produced by an evaluator.
        /// </summary>
        /// <param name="metricsView">The overall metrics data view.</param>
        /// <param name="getVectorMetrics">Whether to also report the slots of known-size double vector columns.</param>
        /// <returns>
        /// Name/value pairs taken from the single row that is neither weighted nor stratified.
        /// A scalar double or float column is reported under its column name. A known-size double vector column yields
        /// one pair per slot, named "{column} {slot name}", where the slot name is "(j)" when slot names are unavailable.
        /// </returns>
        /// <remarks>
        /// Arguments and the stratification columns are validated when this method is called. The metric values are
        /// read lazily, each time the returned sequence is enumerated. Enumeration fails if more than one unweighted,
        /// unstratified row exists.
        /// </remarks>
        public static IEnumerable<KeyValuePair<string, double>> GetMetrics(IDataView metricsView, bool getVectorMetrics = true)
        {
            Contracts.CheckValue(metricsView, nameof(metricsView));
            var schema = metricsView.Schema;

            // Figure out whether there is an "IsWeighted" column.
            bool hasWeighted = schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out int isWeightedCol);

            // Figure out whether there are stratification columns; a StratCol column requires a StratVal column.
            int stratVal = -1;
            bool hasStrats = schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int stratCol);
            if (hasStrats && !schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal))
            {
                throw Contracts.Except("If data contains a '{0}' column, it must also contain a '{1}' column",
                    MetricKinds.ColumnNames.StratCol, MetricKinds.ColumnNames.StratVal);
            }

            // The checks above run eagerly; only the cursor work is deferred to the iterator.
            return GetMetricsIterator(metricsView, getVectorMetrics, hasWeighted, isWeightedCol, hasStrats, stratCol, stratVal);
        }

        /// <summary>
        /// Lazily reads the metrics for <see cref="GetMetrics"/> from an already-validated <paramref name="metricsView"/>.
        /// <paramref name="isWeightedCol"/> is only meaningful when <paramref name="hasWeighted"/> is true.
        /// <paramref name="stratCol"/> and <paramref name="stratVal"/> are only meaningful when <paramref name="hasStrats"/> is true.
        /// </summary>
        private static IEnumerable<KeyValuePair<string, double>> GetMetricsIterator(IDataView metricsView, bool getVectorMetrics,
            bool hasWeighted, int isWeightedCol, bool hasStrats, int stratCol, int stratVal)
        {
            var schema = metricsView.Schema;
            using (var cursor = metricsView.GetRowCursorForAllColumns())
            {
                bool isWeighted = false;
                ValueGetter<bool> isWeightedGetter;
                if (hasWeighted)
                    isWeightedGetter = cursor.GetGetter<bool>(schema[isWeightedCol]);
                else
                    isWeightedGetter = (ref bool dst) => dst = false;

                ValueGetter<uint> stratColGetter;
                if (hasStrats)
                {
                    var type = cursor.Schema[stratCol].Type;
                    stratColGetter = RowCursorUtils.GetGetterAs<uint>(type, cursor, stratCol);
                }
                else
                    stratColGetter = (ref uint dst) => dst = 0;

                // We currently have only double valued or vector of double valued metrics.
                var colCount = schema.Count;
                var getters = new ValueGetter<double>[colCount];
                var vBufferGetters = getVectorMetrics ? new ValueGetter<VBuffer<double>>[colCount] : null;

                for (int i = 0; i < colCount; i++)
                {
                    var column = schema[i];
                    if (column.IsHidden || hasWeighted && i == isWeightedCol ||
                        hasStrats && (i == stratCol || i == stratVal))
                        continue;

                    var type = column.Type;
                    if (type == NumberDataViewType.Double || type == NumberDataViewType.Single)
                        getters[i] = RowCursorUtils.GetGetterAs<double>(NumberDataViewType.Double, cursor, i);
                    else if (type is VectorDataViewType vectorType
                        && vectorType.IsKnownSize
                        && vectorType.ItemType == NumberDataViewType.Double
                        && getVectorMetrics)
                        vBufferGetters[i] = cursor.GetGetter<VBuffer<double>>(column);
                }

                double metricVal = 0;
                VBuffer<double> metricVals = default(VBuffer<double>);
                uint strat = 0;
                bool foundRow = false;
                while (cursor.MoveNext())
                {
                    // Only the overall row (unweighted, stratum 0) is reported.
                    isWeightedGetter(ref isWeighted);
                    if (isWeighted)
                        continue;

                    stratColGetter(ref strat);
                    if (strat > 0)
                        continue;

                    // There should only be one row where isWeighted is false and strat=0.
                    Contracts.Check(!foundRow, "Multiple metric rows found in metrics data view.");

                    foundRow = true;
                    for (int i = 0; i < colCount; i++)
                    {
                        if (hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal))
                            continue;

                        if (getters[i] != null)
                        {
                            getters[i](ref metricVal);
                            // For R8 valued columns the metric name is the column name.
                            yield return new KeyValuePair<string, double>(schema[i].Name, metricVal);
                        }
                        else if (getVectorMetrics && vBufferGetters[i] != null)
                        {
                            vBufferGetters[i](ref metricVals);

                            // For R8 vector valued columns the names of the metrics are the column name,
                            // followed by the slot name if it exists, or "(i)" (the slot index in parentheses) if it doesn't.
                            VBuffer<ReadOnlyMemory<char>> names = default;
                            var size = schema[i].Type.GetVectorSize();
                            var slotNamesType = schema[i].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
                            if (slotNamesType != null && slotNamesType.Size == size && slotNamesType.ItemType is TextDataViewType)
                                schema[i].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref names);
                            else
                            {
                                var namesArray = new ReadOnlyMemory<char>[size];
                                for (int j = 0; j < size; j++)
                                    namesArray[j] = string.Format("({0})", j).AsMemory();
                                names = new VBuffer<ReadOnlyMemory<char>>(size, namesArray);
                            }
                            var colName = schema[i].Name;
                            foreach (var metric in metricVals.Items(all: true))
                            {
                                yield return new KeyValuePair<string, double>(
                                    string.Format("{0} {1}", colName, names.GetItemOrDefault(metric.Key)), metric.Value);
                            }
                        }
                    }
                }
            }
        }
 
        /// <summary>
        /// Wraps <paramref name="input"/> in a <c>LambdaColumnMapper</c> that adds text column
        /// <paramref name="outputColName"/>. Every row of that column holds <paramref name="value"/>, and the value of
        /// <paramref name="inputColName"/> is ignored.
        /// </summary>
        private static IDataView AddTextColumn<TSrc>(IHostEnvironment env, IDataView input, string inputColName, string outputColName,
            DataViewType typeSrc, string value, string registrationName)
        {
            Contracts.Check(typeSrc.RawType == typeof(TSrc));

            // The output is a constant, so wrap the string once instead of on every row.
            ReadOnlyMemory<char> constantText = value.AsMemory();
            ValueMapper<TSrc, ReadOnlyMemory<char>> mapper = (in TSrc src, ref ReadOnlyMemory<char> dst) => dst = constantText;
            return LambdaColumnMapper.Create(env, registrationName, input, inputColName, outputColName, typeSrc,
                TextDataViewType.Instance, mapper);
        }
 
        /// <summary>
        /// Add a text column containing a fold index to a data view.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The data view to which we add the column</param>
        /// <param name="curFold">The current fold this data view belongs to.</param>
        /// <returns>The input data view with an additional text column containing the current fold index.</returns>
        /// <summary>
        /// Add a text column containing a fold index to a data view.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The data view to which we add the column. Must have at least one visible column.</param>
        /// <param name="curFold">The current fold this data view belongs to.</param>
        /// <returns>The input data view with an additional text column containing the current fold index.</returns>
        public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int curFold)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
            env.CheckParam(curFold >= 0, nameof(curFold));

            // LambdaColumnMapper requires an input column, so use the first visible one.
            // Its values are ignored by the text mapper.
            int inputCol = 0;
            while (inputCol < input.Schema.Count && input.Schema[inputCol].IsHidden)
                inputCol++;
            env.CheckParam(inputCol < input.Schema.Count, nameof(input), "Data view must contain at least one visible column");

            // Use the visible column found above rather than column 0. A hidden column 0 is looked up by name, which
            // presumably resolves to a different, visible column whose type need not match inputColType.
            var inputColName = input.Schema[inputCol].Name;
            var inputColType = input.Schema[inputCol].Type;
            return Utils.MarshalInvoke(AddTextColumn<int>, inputColType.RawType, env,
                input, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, $"Fold {curFold}", "FoldName");
        }
 
        /// <summary>
        /// Wraps <paramref name="input"/> in a <c>LambdaColumnMapper</c> that adds a U4 key column
        /// <paramref name="outputColName"/> with <paramref name="keyCount"/> values. Every row holds
        /// <paramref name="value"/>, or 0 when <paramref name="value"/> lies outside [0, <paramref name="keyCount"/>].
        /// The value of <paramref name="inputColName"/> is ignored.
        /// </summary>
        private static IDataView AddKeyColumn<TSrc>(IHostEnvironment env, IDataView input, string inputColName, string outputColName,
            DataViewType typeSrc, int keyCount, int value, string registrationName, ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter)
        {
            Contracts.Check(typeSrc.RawType == typeof(TSrc));

            // The key written is the same for every row, so compute it once.
            uint constantKey = 0 <= value && value <= keyCount ? (uint)value : 0;
            var keyType = new KeyDataViewType(typeof(uint), keyCount);
            ValueMapper<TSrc, uint> mapper = (in TSrc src, ref uint dst) => dst = constantKey;
            return LambdaColumnMapper.Create(env, registrationName, input, inputColName, outputColName, typeSrc,
                keyType, mapper, keyValueGetter);
        }
 
        /// <summary>
        /// Add a key type column containing a fold index to a data view.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The data view to which we add the column</param>
        /// <param name="curFold">The current fold this data view belongs to.</param>
        /// <param name="numFolds">The total number of folds.</param>
        /// <returns>The input data view with an additional key type column containing the current fold index.</returns>
        /// <summary>
        /// Add a key type column containing a fold index to a data view.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The data view to which we add the column. Must have at least one visible column.</param>
        /// <param name="curFold">The current fold this data view belongs to.</param>
        /// <param name="numFolds">The total number of folds.</param>
        /// <returns>The input data view with an additional key type column containing the current fold index.</returns>
        public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int curFold, int numFolds)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
            env.CheckParam(curFold >= 0, nameof(curFold));
            env.CheckParam(numFolds > 0, nameof(numFolds));

            // LambdaColumnMapper requires an input column, so use the first visible one.
            // Its values are ignored by the key mapper.
            int inputCol = 0;
            while (inputCol < input.Schema.Count && input.Schema[inputCol].IsHidden)
                inputCol++;
            // Report a clear argument error (in release builds too) when there is no visible column to use.
            env.CheckParam(inputCol < input.Schema.Count, nameof(input), "Data view must contain at least one visible column");

            var inputColName = input.Schema[inputCol].Name;
            var inputColType = input.Schema[inputCol].Type;
            // Fold k is written as key value k + 1, so fold 0 maps to the first key value rather than to 0.
            return Utils.MarshalInvoke(AddKeyColumn<int>, inputColType.RawType, env,
                input, inputColName, MetricKinds.ColumnNames.FoldIndex,
                inputColType, numFolds, curFold + 1, "FoldIndex", default(ValueGetter<VBuffer<ReadOnlyMemory<char>>>));
        }
 
        /// <summary>
        /// This method takes an array of data views and a specified input vector column, and adds a new output column to each of the data views.
        /// First, we find the union set of the slot names in the different data views. Next we define a new vector column for each
        /// data view, indexed by the union of the slot names. For each data view, every slot value is the value in the slot corresponding
        /// to its slot name in the original column. If a reconciled slot name does not exist in an input column, the value in the output
        /// column is def.
        /// </summary>
        /// <summary>
        /// This method takes an array of data views and a specified input vector column, and adds a new output column to each of the data views.
        /// First, we find the union set of the slot names in the different data views. Next we define a new vector column for each
        /// data view, indexed by the union of the slot names. For each data view, every slot value is the value in the slot corresponding
        /// to its slot name in the original column. If a reconciled slot name does not exist in an input column, the value in the output
        /// column is def.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="views">The data views to reconcile; each element is replaced by its wrapped view.</param>
        /// <param name="columnName">The vector column, with slot names, present in every view. It is replaced in place.</param>
        /// <param name="itemType">The item type of the output vector; its raw type must be <typeparamref name="T"/>.</param>
        /// <param name="def">The value written to reconciled slots that a view's column does not have.</param>
        public static void ReconcileSlotNames<T>(IHostEnvironment env, IDataView[] views, string columnName, PrimitiveDataViewType itemType, T def = default(T))
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckNonEmpty(views, nameof(views));
            Contracts.CheckNonEmpty(columnName, nameof(columnName));
            Contracts.CheckValue(itemType, nameof(itemType));
            Contracts.CheckParam(typeof(T) == itemType.RawType, nameof(itemType), "Generic type does not match the item type");

            var numIdvs = views.Length;
            // Reconciled slot name -> reconciled slot index, plus the names themselves in index order. The explicit list
            // keeps names aligned with their indices without relying on Dictionary enumeration order.
            var slotNames = new Dictionary<string, int>();
            var orderedSlotNames = new List<ReadOnlyMemory<char>>();
            var maps = new int[numIdvs][];
            var slotNamesCur = default(VBuffer<ReadOnlyMemory<char>>);
            var typeSrc = new DataViewType[numIdvs];
            // Create mappings from the original slots to the reconciled slots.
            for (int i = 0; i < numIdvs; i++)
            {
                var idv = views[i];
                if (!idv.Schema.TryGetColumnIndex(columnName, out int col))
                    throw env.Except("Data view number {0} does not contain column '{1}'", i, columnName);

                var type = typeSrc[i] = idv.Schema[col].Type;
                if (!idv.Schema[col].HasSlotNames(type.GetVectorSize()))
                    throw env.Except("Column '{0}' in data view number {1} did not contain slot names metadata", columnName, i);
                idv.Schema[col].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref slotNamesCur);

                var map = maps[i] = new int[slotNamesCur.Length];
                foreach (var kvp in slotNamesCur.Items(true))
                {
                    var name = kvp.Value.ToString();
                    if (!slotNames.TryGetValue(name, out int reconciledSlot))
                    {
                        reconciledSlot = slotNames.Count;
                        slotNames.Add(name, reconciledSlot);
                        orderedSlotNames.Add(name.AsMemory());
                    }
                    map[kvp.Key] = reconciledSlot;
                }
            }

            int reconciledCount = slotNames.Count;
            var reconciledSlotNames = new VBuffer<ReadOnlyMemory<char>>(reconciledCount, orderedSlotNames.ToArray());
            ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter =
                (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                {
                    reconciledSlotNames.CopyTo(ref dst);
                };

            // EqualityComparer avoids boxing and a NullReferenceException when T is a reference type and def is null.
            bool defIsDefault = EqualityComparer<T>.Default.Equals(def, default(T));

            // For each input data view, create the reconciled column by wrapping it in a LambdaColumnMapper.
            for (int i = 0; i < numIdvs; i++)
            {
                var map = maps[i];

                ValueMapper<VBuffer<T>, VBuffer<T>> mapper;
                if (defIsDefault)
                {
                    mapper =
                        (in VBuffer<T> src, ref VBuffer<T> dst) =>
                        {
                            Contracts.Assert(src.Length == Utils.Size(map));
                            var editor = VBufferEditor.Create(ref dst, reconciledCount);

                            // The editor may reuse dst's buffer from a previous row. Clear it so that slots not explicitly
                            // present in src (implicit values of a sparse src, or slots src lacks) end up as default(T).
                            editor.Values.Clear();
                            foreach (var kvp in src.Items())
                                editor.Values[map[kvp.Key]] = kvp.Value;
                            dst = editor.Commit();
                        };
                }
                else
                {
                    // Create a list of the slots in the reconciled output column that do not correspond to any slots
                    // in the input column, so we can populate them with def.
                    var mappedIndices = new bool[reconciledCount];
                    for (int j = 0; j < map.Length; j++)
                        mappedIndices[map[j]] = true;
                    var naIndices = new List<int>();
                    for (int j = 0; j < mappedIndices.Length; j++)
                    {
                        if (!mappedIndices[j])
                            naIndices.Add(j);
                    }
                    mapper =
                        (in VBuffer<T> src, ref VBuffer<T> dst) =>
                        {
                            Contracts.Assert(src.Length == Utils.Size(map));
                            var editor = VBufferEditor.Create(ref dst, reconciledCount);

                            // Every output slot is written: all src slots (implicit ones included) plus the unmapped ones.
                            foreach (var kvp in src.Items(true))
                                editor.Values[map[kvp.Key]] = kvp.Value;
                            foreach (var j in naIndices)
                                editor.Values[j] = def;
                            dst = editor.Commit();
                        };
                }

                var typeDst = new VectorDataViewType(itemType, reconciledCount);
                views[i] = LambdaColumnMapper.Create(env, "ReconciledSlotNames", views[i],
                    columnName, columnName, typeSrc[i], typeDst, mapper, slotNamesGetter: slotNamesGetter);
            }
        }
 
        /// <summary>
        /// Builds, for each schema, a map from that schema's key values to indices in the reconciled set of key value
        /// names. Names not seen before are appended to <paramref name="reconciledKeyNames"/>.
        /// </summary>
        /// <typeparam name="T">The raw item type of the key values annotation.</typeparam>
        /// <param name="schemas">The schemas to reconcile.</param>
        /// <param name="columnName">The U4 key column to look up in every schema.</param>
        /// <param name="isVec">Whether the column must be a vector of keys (true) or a scalar key (false).</param>
        /// <param name="indices">Receives, per schema, the index of <paramref name="columnName"/>.</param>
        /// <param name="reconciledKeyNames">
        /// Reconciled names mapped to their 0-based indices. Extended in place; new names get the next index, in
        /// first-seen order.
        /// </param>
        /// <returns>One array per schema, whose element j is the reconciled index of that schema's j-th key value.</returns>
        private static int[][] MapKeys<T>(DataViewSchema[] schemas, string columnName, bool isVec,
            int[] indices, Dictionary<ReadOnlyMemory<char>, int> reconciledKeyNames)
        {
            Contracts.AssertValue(indices);
            Contracts.AssertValue(reconciledKeyNames);

            // ReadOnlyMemory<char> equality, used by the dictionary's default comparer, compares the underlying buffer,
            // offset and length rather than the characters. Freshly allocated name strings from different schemas would
            // therefore never match. Match names by their text through this side table, seeded with existing entries.
            var reconciledByText = new Dictionary<string, int>();
            foreach (var entry in reconciledKeyNames)
                reconciledByText[entry.Key.ToString()] = entry.Value;

            var dvCount = schemas.Length;
            var keyValueMappers = new int[dvCount][];
            var keyNamesCur = default(VBuffer<T>);
            for (int i = 0; i < dvCount; i++)
            {
                var schema = schemas[i];
                if (!schema.TryGetColumnIndex(columnName, out indices[i]))
                    throw Contracts.Except($"Schema number {i} does not contain column '{columnName}'");

                var type = schema[indices[i]].Type;
                var keyValueType = schema[indices[i]].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
                VectorDataViewType vectorType = type as VectorDataViewType;
                bool typeIsVector = vectorType != null;
                if (typeIsVector != isVec)
                    throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type");
                DataViewType keyValueItemType = (keyValueType as VectorDataViewType)?.ItemType ?? keyValueType;
                if (keyValueItemType == null || keyValueItemType.RawType != typeof(T))
                    throw Contracts.Except($"Column '{columnName}' in schema number {i} does not have the correct type of key values");
                DataViewType typeItemType = vectorType?.ItemType ?? type;
                if (!(typeItemType is KeyDataViewType itemKeyType) || typeItemType.RawType != typeof(uint))
                    throw Contracts.Except($"Column '{columnName}' must be a U4 key type, but is '{typeItemType}'");

                schema[indices[i]].GetKeyValues(ref keyNamesCur);

                keyValueMappers[i] = new int[itemKeyType.Count];
                foreach (var kvp in keyNamesCur.Items(true))
                {
                    var text = kvp.Value.ToString();
                    if (!reconciledByText.TryGetValue(text, out int reconciled))
                    {
                        reconciled = reconciledKeyNames.Count;
                        reconciledKeyNames[text.AsMemory()] = reconciled;
                        reconciledByText.Add(text, reconciled);
                    }
                    keyValueMappers[i][kvp.Key] = reconciled;
                }
            }
            return keyValueMappers;
        }
 
        /// <summary>
        /// Makes a scalar key column compatible across several data views. The key values of <paramref name="columnName"/>
        /// in every view are merged into one shared set (in order of first appearance), and in each view the column is
        /// replaced by a new column of the same name whose type is a U4 key over that shared set. Each original key is
        /// remapped to the position of its key value in the shared set; key 0 and keys outside the original mapping become 0.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="views">The data views to update; each entry is replaced by the wrapped view.</param>
        /// <param name="columnName">The name of the key column to reconcile.</param>
        /// <param name="keyValueType">The type of the key values; its raw type selects the generic instantiation of MapKeys.</param>
        public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, string columnName, DataViewType keyValueType)
        {
            Contracts.CheckNonEmpty(views, nameof(views));
            Contracts.CheckNonEmpty(columnName, nameof(columnName));

            int viewCount = views.Length;
            var columnIndices = new int[viewCount];
            var unionNames = new Dictionary<ReadOnlyMemory<char>, int>();
            var schemas = views.Select(view => view.Schema).ToArray();

            // MapKeys is generic over the raw type of the key values, which is only known at runtime.
            var mappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, schemas, columnName, false, columnIndices, unionNames);

            var unionKeyType = new KeyDataViewType(typeof(uint), unionNames.Count);
            var unionNamesBuffer = new VBuffer<ReadOnlyMemory<char>>(unionNames.Count, unionNames.Keys.ToArray());
            ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter =
                (ref VBuffer<ReadOnlyMemory<char>> dst) => unionNamesBuffer.CopyTo(ref dst);

            for (int i = 0; i < viewCount; i++)
            {
                // Per-view table from (original key - 1) to (shared key - 1).
                var mapping = mappers[i];
                ValueMapper<uint, uint> remap =
                    (in uint src, ref uint dst) =>
                        dst = src == 0 || src > mapping.Length ? 0u : (uint)mapping[src - 1] + 1;

                var sourceType = views[i].Schema[columnIndices[i]].Type;
                views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
                    sourceType, unionKeyType, remap, keyValueGetter);
            }
        }
 
        /// <summary>
        /// Converts the key column <paramref name="columnName"/> of every data view to a U4 key type with
        /// <paramref name="keyCount"/> values, without reconciling key values. Unlike <see cref="ReconcileKeyValues"/>,
        /// no union of key values is built and no key-values annotation is attached to the new column. A key that is
        /// at most <paramref name="keyCount"/> is kept unchanged; a larger key is mapped to 0.
        /// </summary>
        /// <remarks>
        /// Each entry of <paramref name="views"/> is replaced by a view in which a new column named
        /// <paramref name="columnName"/> is added on top of the original one and holds the converted keys.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="views">The data views to update in place; each must contain <paramref name="columnName"/>.</param>
        /// <param name="columnName">The name of the key column to convert.</param>
        /// <param name="keyCount">The key count of the resulting U4 key type.</param>
        public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView[] views, string columnName, ulong keyCount)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckNonEmpty(views, nameof(views));
            Contracts.CheckNonEmpty(columnName, nameof(columnName));

            var keyType = new KeyDataViewType(typeof(uint), keyCount);

            // The mapping is the same for every view, so a single (stateless) delegate is shared.
            // Keys beyond the new key count cannot be represented and become 0; all others pass through unchanged.
            ValueMapper<uint, uint> mapper =
                (in uint src, ref uint dst) =>
                {
                    if (src > keyCount)
                        dst = 0;
                    else
                        dst = src;
                };

            // For each input data view, create the converted key column by wrapping it in a LambdaColumnMapper.
            for (int i = 0; i < views.Length; i++)
            {
                if (!views[i].Schema.TryGetColumnIndex(columnName, out var index))
                    throw env.Except($"Data view {i} doesn't contain a column '{columnName}'");
                views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
                    views[i].Schema[index].Type, keyType, mapper);
            }
        }
 
        /// <summary>
        /// Vector counterpart of <see cref="ReconcileKeyValues"/>. The key values of the vector key column
        /// <paramref name="columnName"/> are merged across all <paramref name="views"/>, and in each view the column is
        /// replaced by a vector of U4 keys over the merged key values, with the same dimensions as the original column.
        /// </summary>
        /// <remarks>
        /// Every stored element is remapped independently; key 0 and keys outside the view's original mapping become 0.
        /// Sparse inputs keep their indices. When the original column has slot names, the new column exposes them too.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="views">The data views to update; each entry is replaced by the wrapped view.</param>
        /// <param name="columnName">The name of the vector key column to reconcile.</param>
        /// <param name="keyValueType">The type of the key values; its raw type selects the generic instantiation of MapKeys.</param>
        public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] views, string columnName, DataViewType keyValueType)
        {
            Contracts.CheckNonEmpty(views, nameof(views));
            Contracts.CheckNonEmpty(columnName, nameof(columnName));

            int viewCount = views.Length;
            var unionNames = new Dictionary<ReadOnlyMemory<char>, int>();
            var colIndices = new int[viewCount];
            var schemas = views.Select(view => view.Schema).ToArray();
            var mappers = Utils.MarshalInvoke(MapKeys<int>, keyValueType.RawType, schemas, columnName, true, colIndices, unionNames);

            var unionKeyType = new KeyDataViewType(typeof(uint), unionNames.Count);
            var unionNamesBuffer = new VBuffer<ReadOnlyMemory<char>>(unionNames.Count, unionNames.Keys.ToArray());
            ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter =
                (ref VBuffer<ReadOnlyMemory<char>> dst) => unionNamesBuffer.CopyTo(ref dst);

            for (int i = 0; i < viewCount; i++)
            {
                var mapping = mappers[i];

                // Translates one original key of this view into its key in the merged set.
                uint Remap(uint key) => key == 0 || key > mapping.Length ? 0u : (uint)mapping[key - 1] + 1;

                ValueMapper<VBuffer<uint>, VBuffer<uint>> mapper =
                    (in VBuffer<uint> src, ref VBuffer<uint> dst) =>
                    {
                        var values = src.GetValues();
                        var editor = VBufferEditor.Create(ref dst, src.Length, values.Length);
                        if (src.IsDense)
                        {
                            for (int slot = 0; slot < src.Length; slot++)
                                editor.Values[slot] = Remap(values[slot]);
                        }
                        else
                        {
                            var srcIndices = src.GetIndices();
                            for (int k = 0; k < values.Length; k++)
                            {
                                editor.Values[k] = Remap(values[k]);
                                editor.Indices[k] = srcIndices[k];
                            }
                        }
                        dst = editor.Commit();
                    };

                // Capture the schema of the view before it is replaced below.
                var sourceSchema = views[i].Schema;
                int colIndex = colIndices[i];
                var sourceType = sourceSchema[colIndex].Type;

                ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter = null;
                if (sourceSchema[colIndex].HasSlotNames(sourceType.GetVectorSize()))
                {
                    slotNamesGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                        sourceSchema[colIndex].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref dst);
                }

                var targetType = new VectorDataViewType(unionKeyType, ((VectorDataViewType)sourceType).Dimensions);
                views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
                    sourceType, targetType, mapper, keyValueGetter, slotNamesGetter);
            }
        }
 
        /// <summary>
        /// This method gets the per-instance metrics from multiple scored data views and either returns them as an
        /// array or combines them into a single data view, based on user specifications.
        /// </summary>
        /// <param name="env">A host environment.</param>
        /// <param name="eval">The evaluator to use for getting the per-instance metrics.</param>
        /// <param name="collate">If true, data views are combined into a single data view. Otherwise, data views
        /// are returned as an array.</param>
        /// <param name="outputFoldIndex">If true, a column containing the fold index is added to the returned data views.</param>
        /// <param name="perInstance">The array of scored data views to evaluate. These are passed as <see cref="RoleMappedData"/>
        /// so that the evaluator can know the role mappings it needs.</param>
        /// <param name="variableSizeVectorColumnNames">When <paramref name="collate"/> is true, receives the names of vector
        /// columns whose size or slot names differ between the data views; in the combined data view each of them is
        /// replaced by a variable-length copy named "&lt;name&gt;_VarLength". Empty when <paramref name="collate"/> is false.</param>
        /// <returns>When <paramref name="collate"/> is true, a single-element array holding the combined data view;
        /// otherwise one per-instance data view per element of <paramref name="perInstance"/>, in the same order.</returns>
        public static IDataView[] ConcatenatePerInstanceDataViews(IHostEnvironment env, IMamlEvaluator eval, bool collate, bool outputFoldIndex, RoleMappedData[] perInstance, out string[] variableSizeVectorColumnNames)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(eval, nameof(eval));
            env.CheckNonEmpty(perInstance, nameof(perInstance));

            Func<RoleMappedData, int, IDataView> getPerInstance =
                (rmd, i) =>
                {
                    var perInst = eval.GetPerInstanceDataViewToSave(rmd);

                    if (!outputFoldIndex)
                        return perInst;

                    // If the fold index is requested, add a column containing it. We use the first column in the data view
                    // as an input column to the LambdaColumnMapper, because it must have an input.
                    return AddFoldIndex(env, perInst, i, perInstance.Length);
                };

            var foldDataViews = perInstance.Select(getPerInstance).ToArray();
            if (collate)
            {
                var combined = AppendPerInstanceDataViews(env, perInstance[0].Schema.Label?.Name, foldDataViews, out variableSizeVectorColumnNames);
                return new[] { combined };
            }

            variableSizeVectorColumnNames = Array.Empty<string>();
            // foldDataViews is a fresh array owned by this call, so it can be returned without another copy.
            return foldDataViews;
        }
 
        /// <summary>
        /// Builds one data view holding the rows of all the given overall-metrics data views, one after another.
        /// </summary>
        /// <remarks>
        /// A single input is returned as-is. Otherwise every input is first passed through AddFoldIndex together with
        /// its position in <paramref name="metrics"/>, and the results are appended using the schema of the first one.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="metrics">The per-fold overall metrics data views; must be non-empty.</param>
        /// <returns>The concatenated metrics data view.</returns>
        public static IDataView ConcatenateOverallMetrics(IHostEnvironment env, IDataView[] metrics)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckNonEmpty(metrics, nameof(metrics));

            if (metrics.Length == 1)
                return metrics[0];

            var withFoldIndex = new IDataView[metrics.Length];
            for (int fold = 0; fold < metrics.Length; fold++)
            {
                // Add a fold column so the rows of different folds remain distinguishable once appended.
                withFoldIndex[fold] = AddFoldIndex(env, metrics[fold], fold);
            }

            return AppendRowsDataView.Create(env, withFoldIndex[0].Schema, withFoldIndex);
        }
 
        /// <summary>
        /// Appends the per-instance data views of several folds into a single data view.
        /// </summary>
        /// <remarks>
        /// Hidden columns are dropped from every view. Key columns are then made compatible across views:
        /// <list type="bullet">
        /// <item>Scalar and vector key columns that have key values in the first view, as well as the label column
        /// when it has key values, are reconciled to a shared set of key values and then converted back to their values.</item>
        /// <item>Other key columns without key values are only converted to U4 keys, using the largest key count seen
        /// in any view.</item>
        /// </list>
        /// Vector columns whose size or slot names differ from those recorded for the column first are replaced by a
        /// copy named "&lt;name&gt;_VarLength" with a variable-size vector type, and the original column is dropped.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="labelColName">The name of the label column, or null if there is none.</param>
        /// <param name="foldDataViews">The per-fold data views to append. Enumerated once.</param>
        /// <param name="variableSizeVectorColumnNames">Receives the distinct names of the vector columns that did not
        /// match across views.</param>
        /// <returns>The appended data view.</returns>
        private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string labelColName,
            IEnumerable<IDataView> foldDataViews, out string[] variableSizeVectorColumnNames)
        {
            Contracts.AssertValue(env);
            env.AssertValue(foldDataViews);

            // Make sure there are no variable size vector columns.
            // This is a dictionary from the column name to its vector size.
            var vectorSizes = new Dictionary<string, int>();
            var firstDvSlotNames = new Dictionary<string, VBuffer<ReadOnlyMemory<char>>>();
            DataViewType labelColKeyValuesType = null;
            var firstDvKeyWithNamesColumns = new List<string>();
            var firstDvKeyNoNamesColumns = new Dictionary<string, ulong>();
            var firstDvVectorKeyColumns = new List<string>();
            var variableSizeVectorColumnNamesList = new List<string>();
            var list = new List<IDataView>();
            int dvNumber = 0;
            foreach (var dv in foldDataViews)
            {
                var hidden = new List<int>();
                for (int i = 0; i < dv.Schema.Count; i++)
                {
                    if (dv.Schema[i].IsHidden)
                    {
                        hidden.Add(i);
                        continue;
                    }

                    var type = dv.Schema[i].Type;
                    var name = dv.Schema[i].Name;
                    ulong typeKeyCount = type.GetKeyCount();
                    if (type is VectorDataViewType vectorType)
                    {
                        if (dvNumber == 0)
                        {
                            if (dv.Schema[i].HasKeyValues())
                                firstDvVectorKeyColumns.Add(name);
                            // Store the slot names of the 1st idv and use them as baseline.
                            if (dv.Schema[i].HasSlotNames(vectorType.Size))
                            {
                                VBuffer<ReadOnlyMemory<char>> slotNames = default;
                                dv.Schema[i].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref slotNames);
                                firstDvSlotNames.Add(name, slotNames);
                            }
                        }

                        if (vectorSizes.TryGetValue(name, out int cachedSize))
                        {
                            // In the event that no slot names were recorded here, then slotNames will be
                            // the default, length 0 vector.
                            firstDvSlotNames.TryGetValue(name, out var slotNames);

                            // Record each mismatching column only once: with three or more views several of them may
                            // disagree with the first, and a duplicate name would make the var-length conversion below
                            // look up a column it has already dropped.
                            if (!VerifyVectorColumnsMatch(cachedSize, i, dv, vectorType, in slotNames)
                                && !variableSizeVectorColumnNamesList.Contains(name))
                            {
                                variableSizeVectorColumnNamesList.Add(name);
                            }
                        }
                        else
                            vectorSizes.Add(name, vectorType.Size);
                    }
                    else if (dvNumber == 0 && name == labelColName)
                    {
                        // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform.
                        labelColKeyValuesType = dv.Schema[i].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
                    }
                    else if (dvNumber == 0 && dv.Schema[i].HasKeyValues())
                        firstDvKeyWithNamesColumns.Add(name);
                    else if (typeKeyCount > 0 && name != labelColName && !dv.Schema[i].HasKeyValues())
                    {
                        // For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4.
                        // Keep the largest key count seen in any view.
                        if (!firstDvKeyNoNamesColumns.TryGetValue(name, out ulong maxKeyCount) || maxKeyCount < typeKeyCount)
                            firstDvKeyNoNamesColumns[name] = typeKeyCount;
                    }
                }
                var idv = dv;
                if (hidden.Count > 0)
                {
                    var args = new ChooseColumnsByIndexTransform.Options();
                    args.Drop = true;
                    args.Indices = hidden.ToArray();
                    idv = new ChooseColumnsByIndexTransform(env, args, idv);
                }
                list.Add(idv);
                dvNumber++;
            }
            variableSizeVectorColumnNames = variableSizeVectorColumnNamesList.ToArray();

            var views = list.ToArray();
            foreach (var keyCol in firstDvKeyWithNamesColumns)
                ReconcileKeyValues(env, views, keyCol, TextDataViewType.Instance);
            if (labelColKeyValuesType != null)
                ReconcileKeyValues(env, views, labelColName, labelColKeyValuesType.GetItemType());
            foreach (var keyCol in firstDvKeyNoNamesColumns)
                ReconcileKeyValuesWithNoNames(env, views, keyCol.Key, keyCol.Value);
            foreach (var vectorKeyCol in firstDvVectorKeyColumns)
                ReconcileVectorKeyValues(env, views, vectorKeyCol, TextDataViewType.Instance);

            // Convert reconciled keys back to values and drop the hidden originals left behind by the reconciliation.
            Func<IDataView, int, IDataView> keyToValue =
                (idv, i) =>
                {
                    foreach (var keyCol in AnnotationUtils.Prepend(firstDvVectorKeyColumns.Concat(firstDvKeyWithNamesColumns), labelColName))
                    {
                        if (keyCol == labelColName && labelColKeyValuesType == null)
                            continue;

                        idv = new KeyToValueMappingTransformer(env, keyCol).Transform(idv);
                        var hidden = FindHiddenColumns(idv.Schema, keyCol);
                        idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Options() { Drop = true, Indices = hidden.ToArray() }, idv);
                    }
                    foreach (var keyCol in firstDvKeyNoNamesColumns)
                    {
                        var hidden = FindHiddenColumns(idv.Schema, keyCol.Key);
                        idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Options() { Drop = true, Indices = hidden.ToArray() }, idv);
                    }
                    return idv;
                };

            // Replace each mismatching vector column by a variable-length copy so that the views can be appended.
            Func<IDataView, IDataView> selectDropNonVarLengthCol =
                (idv) =>
                {
                    foreach (var variableSizeVectorColumnName in variableSizeVectorColumnNamesList)
                    {
                        int index;
                        idv.Schema.TryGetColumnIndex(variableSizeVectorColumnName, out index);
                        var vectorType = idv.Schema[index].Type as VectorDataViewType;
                        env.AssertValue(vectorType);

                        idv = Utils.MarshalInvoke(AddVarLengthColumn<int>, vectorType.ItemType.RawType, env, idv,
                                 variableSizeVectorColumnName, vectorType);

                        // Drop the old column that does not have variable length.
                        idv = ColumnSelectingTransformer.CreateDrop(env, idv, variableSizeVectorColumnName);
                    }
                    return idv;
                };

            return AppendRowsDataView.Create(env, null, views.Select(keyToValue).Select(selectDropNonVarLengthCol).ToArray());
        }
 
        /// <summary>
        /// Lazily yields, in schema order, the indices of the hidden columns of <paramref name="schema"/> named
        /// <paramref name="colName"/>.
        /// </summary>
        private static IEnumerable<int> FindHiddenColumns(DataViewSchema schema, string colName)
        {
            foreach (var column in schema)
            {
                bool isMatch = column.IsHidden && column.Name == colName;
                if (isMatch)
                    yield return column.Index;
            }
        }
 
        /// <summary>
        /// Checks whether vector column <paramref name="col"/> of <paramref name="dv"/> can be appended to the column of the
        /// same name seen earlier: its size must equal <paramref name="cachedSize"/>, and its slot names (if any) must
        /// be identical to <paramref name="firstDvSlotNames"/>.
        /// </summary>
        /// <param name="cachedSize">The vector size previously recorded for this column name.</param>
        /// <param name="col">The index of the column in <paramref name="dv"/>'s schema.</param>
        /// <param name="dv">The data view being checked.</param>
        /// <param name="type">The vector type of the column.</param>
        /// <param name="firstDvSlotNames">The slot names recorded from the first data view; empty when it had none.</param>
        /// <returns>True when the column matches; false otherwise.</returns>
        private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView dv,
            VectorDataViewType type, in VBuffer<ReadOnlyMemory<char>> firstDvSlotNames)
        {
            if (cachedSize != type.Size)
                return false;

            var column = dv.Schema[col];

            // Without slot names here, the first data view must not have had slot names either.
            if (!column.HasSlotNames(type.Size))
                return firstDvSlotNames.Length == 0;

            // Differing slot names are a sign that the slots were reshuffled.
            VBuffer<ReadOnlyMemory<char>> currSlotNames = default;
            column.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref currSlotNames);
            if (currSlotNames.Length != firstDvSlotNames.Length)
                return false;

            bool allEqual = true;
            VBufferUtils.ForEachEitherDefined(in currSlotNames, in firstDvSlotNames,
                (slot, left, right) => allEqual = allEqual && left.Span.SequenceEqual(right.Span));
            return allEqual;
        }
 
        /// <summary>
        /// Adds a column named "<paramref name="variableSizeVectorColumnName"/>_VarLength" that copies the given vector
        /// column but is typed as a vector of the same item type with no fixed size.
        /// </summary>
        /// <typeparam name="TSrc">The raw item type of the vector column.</typeparam>
        private static IDataView AddVarLengthColumn<TSrc>(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, VectorDataViewType typeSrc)
        {
            var varLengthType = new VectorDataViewType((PrimitiveDataViewType)typeSrc.ItemType);
            ValueMapper<VBuffer<TSrc>, VBuffer<TSrc>> copy =
                (in VBuffer<TSrc> src, ref VBuffer<TSrc> dst) => src.CopyTo(ref dst);

            string outputName = variableSizeVectorColumnName + "_VarLength";
            return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName,
                outputName, typeSrc, varLengthType, copy);
        }
 
        /// <summary>
        /// Collects the metric names for the non-hidden, non-ignored columns of <paramref name="schema"/>, and fills
        /// <paramref name="getters"/> / <paramref name="vBufferGetters"/> for the columns that produce metrics.
        /// </summary>
        /// <remarks>
        /// A numeric scalar column yields one metric named after the column, read through a getter converting to double.
        /// A double vector column of non-zero size yields one metric per slot. The column name is used as a format string
        /// with the slot name; when that leaves the name unchanged, the slot name is appended instead. Slot names come from
        /// the SlotNames annotation when it is a text vector of the same size, and are "Label_j" otherwise. A double vector
        /// column of size 0 is skipped with a warning. Columns of any other type are ignored.
        /// </remarks>
        private static List<string> GetMetricNames(IChannel ch, DataViewSchema schema, DataViewRow row, Func<int, bool> ignoreCol,
            ValueGetter<double>[] getters, ValueGetter<VBuffer<double>>[] vBufferGetters)
        {
            ch.AssertValue(schema);
            ch.AssertValue(row);
            ch.Assert(Utils.Size(getters) == schema.Count);
            ch.Assert(Utils.Size(vBufferGetters) == schema.Count);

            var metricNames = new List<string>();
            int expectedCount = 0;
            VBuffer<ReadOnlyMemory<char>> slotNames = default;

            for (int col = 0; col < schema.Count; col++)
            {
                if (schema[col].IsHidden || ignoreCol(col))
                    continue;

                var columnType = schema[col].Type;
                string columnName = row.Schema[col].Name;

                if (columnType is NumberDataViewType)
                {
                    getters[col] = RowCursorUtils.GetGetterAs<double>(NumberDataViewType.Double, row, col);
                    metricNames.Add(columnName);
                    expectedCount++;
                    continue;
                }

                var vecType = columnType as VectorDataViewType;
                if (vecType == null || vecType.ItemType != NumberDataViewType.Double)
                    continue;

                if (vecType.Size == 0)
                {
                    ch.Warning("Vector metric '{0}' has different lengths in different folds and will not be averaged for overall results.", columnName);
                    continue;
                }

                vBufferGetters[col] = row.GetGetter<VBuffer<double>>(schema[col]);
                expectedCount += vecType.Size;

                var slotNamesType = schema[col].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
                bool hasUsableSlotNames = slotNamesType != null && slotNamesType.Size == vecType.Size && slotNamesType.ItemType is TextDataViewType;
                if (hasUsableSlotNames)
                {
                    schema[col].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref slotNames);
                }
                else
                {
                    var editor = VBufferEditor.Create(ref slotNames, vecType.Size);
                    for (int slot = 0; slot < vecType.Size; slot++)
                        editor.Values[slot] = string.Format("Label_{0}", slot).AsMemory();
                    slotNames = editor.Commit();
                }

                foreach (var slotName in slotNames.Items(all: true))
                {
                    string formatted = string.Format(columnName, slotName.Value);
                    // A column name without placeholders is unchanged by formatting; append the slot name in that case.
                    metricNames.Add(formatted == columnName
                        ? string.Format("{0}{1}", columnName, slotName.Value)
                        : formatted);
                }
            }

            ch.Assert(metricNames.Count == expectedCount);
            return metricNames;
        }
 
        /// <summary>
        /// Builds the overall metrics data view. Its first rows are the averaged metrics that GetAverageToDataView
        /// computes from <paramref name="data"/>. When there are several folds or stratified results, they are followed
        /// by the rows of <paramref name="data"/> itself, without the columns that GetAverageToDataView reports as not averaged.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="data">The metrics data view to summarize.</param>
        /// <param name="numFolds">The number of folds represented in <paramref name="data"/>.</param>
        /// <param name="agg">Receives the sums over the unweighted metric rows.</param>
        /// <param name="weightedAgg">Receives the sums over the weighted metric rows, or null when there are none.</param>
        /// <returns>The overall metrics data view; its stratification column is converted from keys to values when present.</returns>
        internal static IDataView GetOverallMetricsData(IHostEnvironment env, IDataView data, int numFolds, out AggregatedMetric[] agg,
            out AggregatedMetric[] weightedAgg)
        {
            agg = ComputeMetricsSum(env, data, numFolds, out int isWeightedCol, out int stratCol, out int stratVal, out int foldCol, out weightedAgg);

            bool multipleFolds = numFolds > 1;
            var nonAveragedCols = new List<string>();
            var avgMetrics = GetAverageToDataView(env, data.Schema, agg, weightedAgg, numFolds, stratCol, stratVal,
                isWeightedCol, foldCol, multipleFolds, nonAveragedCols);

            bool hasStrat = stratCol >= 0;
            IDataView[] sources;
            if (multipleFolds || hasStrat)
            {
                IDataView detailRows = nonAveragedCols.Count > 0
                    ? ColumnSelectingTransformer.CreateDrop(env, data, nonAveragedCols.ToArray())
                    : data;
                sources = new IDataView[] { avgMetrics, detailRows };
            }
            else
            {
                sources = new IDataView[] { avgMetrics };
            }

            IDataView overall = AppendRowsDataView.Create(env, avgMetrics.Schema, sources);

            // Stratified rows identify their stratum through a key column; convert it back to the key values.
            if (hasStrat)
                overall = new KeyToValueMappingTransformer(env, MetricKinds.ColumnNames.StratCol).Transform(overall);
            return overall;
        }
 
        /// <summary>
        /// Sums the metric values of the overall rows of a metrics data view, separately for unweighted and weighted
        /// rows. Sums of squares are also accumulated when <paramref name="numFolds"/> is greater than 1.
        /// </summary>
        /// <remarks>
        /// Overall rows are those whose stratification column is 0, or all rows when there is no stratification column.
        /// At most <paramref name="numFolds"/> unweighted and <paramref name="numFolds"/> weighted overall rows are accepted;
        /// one more of either kind throws (via Contracts.Except) instead of being silently added to the sums. Reading stops
        /// once <paramref name="numFolds"/> rows of each expected kind have been seen.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="data">The metrics data view.</param>
        /// <param name="numFolds">The number of folds whose rows are summed.</param>
        /// <param name="isWeightedCol">Receives the index of the IsWeighted column, or -1 if absent.</param>
        /// <param name="stratCol">Receives the index of the stratification column, or -1 if absent.</param>
        /// <param name="stratVal">Receives the index of the stratification value column, or -1 if absent.</param>
        /// <param name="foldCol">Receives the index of the fold index column, or -1 if absent.</param>
        /// <param name="weightedAgg">Receives the sums over weighted rows, or null when there is no IsWeighted column.</param>
        /// <returns>The sums over unweighted rows, one entry per metric.</returns>
        internal static AggregatedMetric[] ComputeMetricsSum(IHostEnvironment env, IDataView data, int numFolds, out int isWeightedCol,
            out int stratCol, out int stratVal, out int foldCol, out AggregatedMetric[] weightedAgg)
        {
            var isWeightedColumn = data.Schema.GetColumnOrNull(MetricKinds.ColumnNames.IsWeighted);
            var hasWeighted = isWeightedColumn.HasValue;
            var hasStrats = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int scol);
            var hasStratVals = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out int svalcol);
            env.Assert(hasStrats == hasStratVals);
            var hasFoldCol = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out int fcol);

            isWeightedCol = hasWeighted ? isWeightedColumn.Value.Index : -1;
            stratCol = hasStrats ? scol : -1;
            stratVal = hasStratVals ? svalcol : -1;
            foldCol = hasFoldCol ? fcol : -1;

            // We currently have only double valued or vector of double valued metrics.
            int colCount = data.Schema.Count;
            var getters = new ValueGetter<double>[colCount];
            var vBufferGetters = new ValueGetter<VBuffer<double>>[colCount];
            int numResults = 0;
            int numWeightedResults = 0;
            AggregatedMetric[] agg;
            using (var cursor = data.GetRowCursorForAllColumns())
            {
                bool isWeighted = false;
                ValueGetter<bool> isWeightedGetter;
                if (hasWeighted)
                    isWeightedGetter = cursor.GetGetter<bool>(isWeightedColumn.Value);
                else
                    isWeightedGetter = (ref bool dst) => dst = false;

                ValueGetter<uint> stratColGetter;
                if (hasStrats)
                {
                    var type = cursor.Schema[stratCol].Type;
                    stratColGetter = RowCursorUtils.GetGetterAs<uint>(type, cursor, stratCol);
                }
                else
                    stratColGetter = (ref uint dst) => dst = 0;

                // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns
                // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't.
                List<string> metricNames;
                using (var ch = env.Register("GetMetricsAsString").Start("Get Metric Names"))
                {
                    metricNames = GetMetricNames(ch, data.Schema, cursor,
                        i => hasWeighted && i == isWeightedColumn.Value.Index || hasStrats && (i == scol || i == svalcol) ||
                            hasFoldCol && i == fcol, getters, vBufferGetters);
                }
                agg = new AggregatedMetric[metricNames.Count];

                double metricVal = 0;
                VBuffer<double> metricVals = default(VBuffer<double>);
                if (hasWeighted)
                    weightedAgg = new AggregatedMetric[metricNames.Count];
                else
                    weightedAgg = null;
                uint strat = 0;
                while (cursor.MoveNext())
                {
                    stratColGetter(ref strat);
                    // REVIEW: how to print stratified results?
                    if (strat > 0)
                        continue;

                    isWeightedGetter(ref isWeighted);
                    if (isWeighted)
                    {
                        // Each fold contributes at most one weighted row. The check runs before the row is summed,
                        // so an extra row is rejected instead of being silently added to the aggregates.
                        if (numWeightedResults >= numFolds)
                            throw Contracts.Except("Multiple weighted rows found in metrics data view.");

                        numWeightedResults++;
                        UpdateSums(isWeightedCol, stratCol, stratVal, weightedAgg, numFolds > 1, metricNames, hasWeighted,
                            hasStrats, colCount, getters, vBufferGetters, ref metricVal, ref metricVals);
                    }
                    else
                    {
                        // Each fold contributes at most one unweighted row (see the weighted case above).
                        if (numResults >= numFolds)
                            throw Contracts.Except("Multiple unweighted rows found in metrics data view.");

                        numResults++;
                        UpdateSums(isWeightedCol, stratCol, stratVal, agg, numFolds > 1, metricNames, hasWeighted, hasStrats,
                            colCount, getters, vBufferGetters, ref metricVal, ref metricVals);
                    }

                    // Stop as soon as every expected row has been seen.
                    if (numResults == numFolds && (!hasWeighted || numWeightedResults == numFolds))
                        break;
                }
            }
            return agg;
        }
 
        /// <summary>
        /// Adds the metric values of the current row to <paramref name="aggregated"/>. Columns are visited in schema order,
        /// and the next entry of <paramref name="aggregated"/> is used for every scalar metric and for every slot of a
        /// vector metric. Each entry's name is set from <paramref name="metricNames"/> at the same position.
        /// </summary>
        /// <param name="hasStdev">When true, squared values are also accumulated into <see cref="AggregatedMetric.SumSq"/>.</param>
        /// <param name="metricVal">Scratch buffer for scalar values, reused across calls.</param>
        /// <param name="metricVals">Scratch buffer for vector values, reused across calls.</param>
        private static void UpdateSums(int isWeightedCol, int stratCol, int stratVal, AggregatedMetric[] aggregated, bool hasStdev, List<string> metricNames, bool hasWeighted, bool hasStrats, int colCount, ValueGetter<double>[] getters, ValueGetter<VBuffer<double>>[] vBufferGetters, ref double metricVal, ref VBuffer<double> metricVals)
        {
            int metricIndex = 0;

            // Folds one value into the current aggregate entry and moves on to the next entry.
            void Accumulate(double value)
            {
                aggregated[metricIndex].Sum += value;
                if (hasStdev)
                    aggregated[metricIndex].SumSq += value * value;
                aggregated[metricIndex].Name = metricNames[metricIndex];
                metricIndex++;
            }

            for (int col = 0; col < colCount; col++)
            {
                bool skip = (hasWeighted && col == isWeightedCol) || (hasStrats && (col == stratCol || col == stratVal));
                if (skip)
                    continue;

                var scalarGetter = getters[col];
                var vectorGetter = vBufferGetters[col];
                if (scalarGetter != null)
                {
                    scalarGetter(ref metricVal);
                    Accumulate(metricVal);
                }
                else if (vectorGetter != null)
                {
                    vectorGetter(ref metricVals);
                    foreach (var entry in metricVals.Items(all: true))
                        Accumulate(entry.Value);
                }
                // Columns without either getter hold no double-valued metric and are skipped.
                // REVIEW: What to do with metrics that are not doubles?
            }

            Contracts.Assert(metricIndex == metricNames.Count);
        }
 
        /// <summary>
        /// Builds a data view of fold-averaged metrics from the aggregated sums in <paramref name="agg"/>. Each metric column
        /// gets the average (Sum / numFolds) and, when <paramref name="hasStdev"/> is true, a second value holding the
        /// standard deviation. When <paramref name="isWeightedCol"/> is non-negative, a second set of rows built from
        /// <paramref name="weightedAgg"/> is appended after the unweighted rows.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="schema">Schema of the per-fold metrics data view; visible columns are emitted in this order.</param>
        /// <param name="agg">Unweighted aggregates. Expected to be laid out in the same order as the numeric columns and
        /// the elements of the known-size double vector columns of <paramref name="schema"/> (see the iMetric bookkeeping below).</param>
        /// <param name="weightedAgg">Weighted aggregates; only read when <paramref name="isWeightedCol"/> is non-negative.</param>
        /// <param name="numFolds">Number of folds the sums were accumulated over; used as the divisor.</param>
        /// <param name="stratCol">Index of the StratCol column (presumably negative when absent).</param>
        /// <param name="stratVal">Index of the StratVal column (presumably negative when absent).</param>
        /// <param name="isWeightedCol">Index of the IsWeighted column, or negative when there are no weighted metrics.</param>
        /// <param name="foldCol">Index of the FoldIndex column; when non-negative the fixed columns get two rows
        /// ("Average" and "Standard Deviation"), otherwise one.</param>
        /// <param name="hasStdev">Whether metric columns get a second, standard-deviation value.</param>
        /// <param name="nonAveragedCols">If non-null, receives the names of visible columns that could not be averaged.</param>
        internal static IDataView GetAverageToDataView(IHostEnvironment env, DataViewSchema schema, AggregatedMetric[] agg, AggregatedMetric[] weightedAgg,
            int numFolds, int stratCol, int stratVal, int isWeightedCol, int foldCol, bool hasStdev, List<string> nonAveragedCols = null)
        {
            Contracts.AssertValue(env);

            int colCount = schema.Count;

            // Unweighted rows go in dvBldr; weighted rows (only when an IsWeighted column exists) go in weightedDvBldr.
            var dvBldr = new ArrayDataViewBuilder(env);
            var weightedDvBldr = isWeightedCol >= 0 ? new ArrayDataViewBuilder(env) : null;

            // Index into agg/weightedAgg. It advances by 1 for a scalar metric and by the vector size for a vector metric.
            int iMetric = 0;
            for (int i = 0; i < colCount; i++)
            {
                if (schema[i].IsHidden)
                    continue;

                var type = schema[i].Type;
                var name = schema[i].Name;
                // NOTE(review): the fixed columns below use foldCol >= 0 to pick one or two rows, while metric columns use
                // hasStdev. Callers presumably keep these consistent so that all columns have the same length; confirm.
                if (i == stratCol)
                {
                    int typeKeyCount = type.GetKeyCountAsInt32(env);

                    var keyValuesType = schema[i].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
                    if (keyValuesType == null || !(keyValuesType.ItemType is TextDataViewType) ||
                        keyValuesType.Size != typeKeyCount)
                    {
                        throw env.Except("Column '{0}' must have key values metadata",
                            MetricKinds.ColumnNames.StratCol);
                    }

                    // Re-expose the input column's key values so the output key column keeps the same metadata.
                    ValueGetter<VBuffer<ReadOnlyMemory<char>>> getKeyValues =
                        (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                        {
                            schema[stratCol].GetKeyValues(ref dst);
                            Contracts.Assert(dst.IsDense);
                        };

                    // The averaged rows use key value 0, i.e. they are not tied to any stratification value.
                    var keys = foldCol >= 0 ? new uint[] { 0, 0 } : new uint[] { 0 };
                    dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, type.GetKeyCount(), keys);
                    weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, type.GetKeyCount(), keys);
                }
                else if (i == stratVal)
                {
                    //REVIEW: Not sure if empty string makes sense here.
                    var stratVals = foldCol >= 0 ? new[] { "".AsMemory(), "".AsMemory() } : new[] { "".AsMemory() };
                    dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, stratVals);
                    weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, stratVals);
                }
                else if (i == isWeightedCol)
                {
                    // The unweighted builder is marked false and the weighted builder is marked true.
                    env.AssertValue(weightedDvBldr);
                    dvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BooleanDataViewType.Instance, foldCol >= 0 ? new[] { false, false } : new[] { false });
                    weightedDvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BooleanDataViewType.Instance, foldCol >= 0 ? new[] { true, true } : new[] { true });
                }
                else if (i == foldCol)
                {
                    // Labels the two output rows in place of a fold number.
                    var foldVals = new[] { "Average".AsMemory(), "Standard Deviation".AsMemory() };
                    dvBldr.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextDataViewType.Instance, foldVals);
                    weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextDataViewType.Instance, foldVals);
                }
                else if (type is NumberDataViewType)
                {
                    // Scalar metric: consumes one aggregate slot.
                    dvBldr.AddScalarColumn(schema, agg, hasStdev, numFolds, iMetric);
                    weightedDvBldr?.AddScalarColumn(schema, weightedAgg, hasStdev, numFolds, iMetric);
                    iMetric++;
                }
                else if (type is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType == NumberDataViewType.Double)
                {
                    // Known-size double vector metric: consumes one aggregate slot per element.
                    dvBldr.AddVectorColumn(env, schema, agg, hasStdev, numFolds, iMetric, i, vectorType, name);
                    weightedDvBldr?.AddVectorColumn(env, schema, weightedAgg, hasStdev, numFolds, iMetric, i, vectorType, name);
                    iMetric += vectorType.Size;
                }
                else
                    nonAveragedCols?.Add(name);
            }
            // Weighted rows, if any, are appended after the unweighted rows.
            var idv = dvBldr.GetDataView();
            if (weightedDvBldr != null)
                idv = AppendRowsDataView.Create(env, idv.Schema, idv, weightedDvBldr.GetDataView());
            return idv;
        }
 
        /// <summary>
        /// Adds a known-size vector metric column to <paramref name="dvBldr"/>. The first value of the column is the
        /// per-slot average over <paramref name="numFolds"/> folds. When <paramref name="hasStdev"/> is true, a second
        /// value holds the per-slot standard deviation. Slot names are taken from the aggregated metric names.
        /// </summary>
        /// <param name="dvBldr">Builder receiving the new column.</param>
        /// <param name="env">Host environment, used for assertions.</param>
        /// <param name="schema">Not used by this method; kept for signature compatibility.</param>
        /// <param name="agg">Aggregated sums; slots <c>iMetric .. iMetric + type.Size - 1</c> belong to this column.</param>
        /// <param name="hasStdev">Whether to emit a standard-deviation row.</param>
        /// <param name="numFolds">Number of folds the sums were accumulated over.</param>
        /// <param name="iMetric">Index of this column's first slot in <paramref name="agg"/>.</param>
        /// <param name="i">Not used by this method; kept for signature compatibility.</param>
        /// <param name="type">The column's vector type; its size determines how many slots are read.</param>
        /// <param name="columnName">Name of the output column.</param>
        private static void AddVectorColumn(this ArrayDataViewBuilder dvBldr, IHostEnvironment env, DataViewSchema schema,
            AggregatedMetric[] agg, bool hasStdev, int numFolds, int iMetric, int i, VectorDataViewType type, string columnName)
        {
            int size = type.Size;
            var vectorMetrics = new double[size];
            env.Assert(vectorMetrics.Length > 0);
            for (int j = 0; j < size; j++)
                vectorMetrics[j] = agg[iMetric + j].Sum / numFolds;

            var names = new ReadOnlyMemory<char>[size];
            for (int j = 0; j < size; j++)
                names[j] = agg[iMetric + j].Name.AsMemory();
            var slotNames = new VBuffer<ReadOnlyMemory<char>>(size, names);
            ValueGetter<VBuffer<ReadOnlyMemory<char>>> getSlotNames = (ref VBuffer<ReadOnlyMemory<char>> dst) => dst = slotNames;

            if (!hasStdev)
            {
                dvBldr.AddColumn(columnName, getSlotNames, NumberDataViewType.Double, new[] { vectorMetrics });
                return;
            }

            var vectorStdevMetrics = new double[size];
            for (int j = 0; j < size; j++)
            {
                // E[x^2] - E[x]^2 can come out slightly negative through floating-point cancellation (for example
                // when every fold reports the same value). Clamp at zero so the result is 0 instead of NaN; a NaN
                // variance (from NaN inputs) still propagates because Math.Max returns NaN in that case.
                double variance = agg[iMetric + j].SumSq / numFolds - vectorMetrics[j] * vectorMetrics[j];
                vectorStdevMetrics[j] = Math.Sqrt(Math.Max(variance, 0.0));
            }
            dvBldr.AddColumn(columnName, getSlotNames, NumberDataViewType.Double, new[] { vectorMetrics, vectorStdevMetrics });
        }
 
        /// <summary>
        /// Adds a scalar metric column named after <c>agg[iMetric].Name</c> to <paramref name="dvBldr"/>. The first value is
        /// the average over <paramref name="numFolds"/> folds. When <paramref name="hasStdev"/> is true, a second value
        /// holds the standard deviation.
        /// </summary>
        /// <param name="dvBldr">Builder receiving the new column.</param>
        /// <param name="schema">Not used by this method; kept for signature compatibility.</param>
        /// <param name="agg">Aggregated sums.</param>
        /// <param name="hasStdev">Whether to emit a standard-deviation value.</param>
        /// <param name="numFolds">Number of folds the sums were accumulated over.</param>
        /// <param name="iMetric">Index of this metric in <paramref name="agg"/>.</param>
        private static void AddScalarColumn(this ArrayDataViewBuilder dvBldr, DataViewSchema schema, AggregatedMetric[] agg, bool hasStdev, int numFolds, int iMetric)
        {
            Contracts.AssertValue(dvBldr);

            var metric = agg[iMetric];
            var avg = metric.Sum / numFolds;
            if (hasStdev)
            {
                // E[x^2] - E[x]^2 can be slightly negative due to floating-point cancellation; clamp at zero so the
                // standard deviation is 0 rather than NaN. NaN inputs still propagate (Math.Max returns NaN).
                var variance = metric.SumSq / numFolds - avg * avg;
                dvBldr.AddColumn(metric.Name, NumberDataViewType.Double, avg, Math.Sqrt(Math.Max(variance, 0.0)));
            }
            else
                dvBldr.AddColumn(metric.Name, NumberDataViewType.Double, avg);
        }
 
        /// <summary>
        /// Appends summary rows to a data view of per-fold metrics: the average of each metric over the folds,
        /// and its standard deviation. The aggregated sums computed along the way are discarded.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="data">A data view with one or more rows of metrics.</param>
        /// <param name="numFolds">The number of folds the metrics were computed over.</param>
        /// <returns>The input metrics with the additional average and standard-deviation rows.</returns>
        public static IDataView CombineFoldMetricsDataViews(IHostEnvironment env, IDataView data, int numFolds)
            => GetOverallMetricsData(env, data, numFolds, out _, out _);
    }
 
    /// <summary>
    /// Helpers that render evaluator outputs (confusion tables, per-fold and overall metrics, warnings) as text,
    /// and that save metric data views with the text saver.
    /// </summary>
    internal static class MetricWriter
    {
        /// <summary>
        /// Get the confusion tables as strings to be printed to the Console.
        /// </summary>
        /// <param name="host">The host is used for getting the random number generator for sampling classes.</param>
        /// <param name="confusionDataView">The data view containing the confusion matrix. It must contain an R8 vector column
        /// named "Count" whose SlotNames annotation holds the class names: in the row corresponding to label i, slot j should
        /// contain the number of class i examples that were predicted as j by the predictor.</param>
        /// <param name="weightedConfusionTable">If there is an R8 vector column named "Weight" containing the weighted counts, this parameter
        /// is assigned the string representation of the weighted confusion table. Otherwise it is assigned null.</param>
        /// <param name="binary">Indicates whether the confusion table is for binary classification.</param>
        /// <param name="sample">Indicates how many classes to sample from the confusion table (-1 indicates no sampling).
        /// When sampling, the weighted and the unweighted tables show the same sampled classes.</param>
        /// <returns>The string representation of the unweighted confusion table.</returns>
        public static string GetConfusionTableAsFormattedString(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
        {
            host.CheckValue(confusionDataView, nameof(confusionDataView));
            host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");

            bool isWeighted = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight).HasValue;

            // Read the class names and draw the class sample only once. Sampling separately for each table would draw
            // two independent random permutations, so the weighted and unweighted tables could show different classes.
            var labelNames = GetClassNames(host, confusionDataView);
            var labelIndexToConfIndexMap = SampleClasses(host, labelNames.Length, sample, out int numConfusionTableLabels);

            var confusionMatrix = BuildConfusionMatrix(host, confusionDataView, binary, false, in labelNames,
                labelIndexToConfIndexMap, numConfusionTableLabels);
            var confusionTableString = GetConfusionTableAsString(confusionMatrix, false);

            // If there is a Weight column, return the weighted confusion table as well.
            if (isWeighted)
            {
                var weightedConfusionMatrix = BuildConfusionMatrix(host, confusionDataView, binary, true, in labelNames,
                    labelIndexToConfIndexMap, numConfusionTableLabels);
                weightedConfusionTable = GetConfusionTableAsString(weightedConfusionMatrix, true);
            }
            else
                weightedConfusionTable = null;

            return confusionTableString;
        }

        /// <summary>
        /// Builds a <see cref="ConfusionMatrix"/> from a confusion-matrix data view.
        /// </summary>
        /// <param name="host">The host; its random number generator is used when sampling classes.</param>
        /// <param name="confusionDataView">The data view containing the "Count" column (and optionally the "Weight" column).</param>
        /// <param name="binary">Indicates whether the confusion matrix is for binary classification.</param>
        /// <param name="sample">How many classes to sample (-1 indicates no sampling).</param>
        /// <param name="getWeighted">When true, the matrix is built from the "Weight" column, which must exist.</param>
        /// <returns>The confusion matrix with per-class precision and recall.</returns>
        public static ConfusionMatrix GetConfusionMatrix(IHost host, IDataView confusionDataView, bool binary = true, int sample = -1, bool getWeighted = false)
        {
            host.CheckValue(confusionDataView, nameof(confusionDataView));
            host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");

            // Check that there is a Weight column if the weighted confusion matrix was requested.
            if (getWeighted)
            {
                host.CheckParam(confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight).HasValue,
                    nameof(getWeighted), "There is no Weight column in the confusionMatrix data view.");
            }

            var labelNames = GetClassNames(host, confusionDataView);
            var labelIndexToConfIndexMap = SampleClasses(host, labelNames.Length, sample, out int numConfusionTableLabels);
            return BuildConfusionMatrix(host, confusionDataView, binary, getWeighted, in labelNames,
                labelIndexToConfIndexMap, numConfusionTableLabels);
        }

        /// <summary>
        /// Reads the class names from the SlotNames annotation of the "Count" column.
        /// </summary>
        private static VBuffer<ReadOnlyMemory<char>> GetClassNames(IHost host, IDataView confusionDataView)
        {
            var countColumn = confusionDataView.Schema[MetricKinds.ColumnNames.Count];
            var type = countColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
            // A real check rather than a debug-only assert, so malformed input also fails with a clear message in release builds.
            host.Check(type != null && type.IsKnownSize && type.ItemType is TextDataViewType,
                "The Count column does not have a text vector metadata of kind SlotNames.");

            var labelNames = default(VBuffer<ReadOnlyMemory<char>>);
            countColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref labelNames);
            host.Assert(labelNames.IsDense, "Slot names vector must be dense");
            return labelNames;
        }

        /// <summary>
        /// Chooses which classes appear in the confusion table. Returns, for each class index, its index in the confusion
        /// table, or -1 if the class is dropped. When <paramref name="sample"/> is negative or at least
        /// <paramref name="numClasses"/>, every class is kept in its original order.
        /// </summary>
        private static int[] SampleClasses(IHost host, int numClasses, int sample, out int numConfusionTableLabels)
        {
            numConfusionTableLabels = sample < 0 ? numClasses : Math.Min(numClasses, sample);

            var labelIndexToConfIndexMap = new int[numClasses];
            if (numConfusionTableLabels < numClasses)
            {
                // Choose a random permutation, keep the last 'numConfusionTableLabels' indices (in ascending order) and drop the rest.
                var tempPerm = Utils.GetRandomPermutation(host.Rand, numClasses);
                var sampledIndices = tempPerm.Skip(numClasses - numConfusionTableLabels).OrderBy(i => i);

                for (int i = 0; i < labelIndexToConfIndexMap.Length; i++)
                    labelIndexToConfIndexMap[i] = -1;
                int countNotDropped = 0;
                foreach (var i in sampledIndices)
                    labelIndexToConfIndexMap[i] = countNotDropped++;
            }
            else
            {
                for (int i = 0; i < numClasses; i++)
                    labelIndexToConfIndexMap[i] = i;
            }
            return labelIndexToConfIndexMap;
        }

        /// <summary>
        /// Builds the confusion matrix for the classes selected by <paramref name="labelIndexToConfIndexMap"/>. It reads
        /// the "Weight" column when <paramref name="getWeighted"/> is true, and the "Count" column otherwise.
        /// </summary>
        private static ConfusionMatrix BuildConfusionMatrix(IHost host, IDataView confusionDataView, bool binary, bool getWeighted,
            in VBuffer<ReadOnlyMemory<char>> labelNames, int[] labelIndexToConfIndexMap, int numConfusionTableLabels)
        {
            var valuesColumnName = getWeighted ? MetricKinds.ColumnNames.Weight : MetricKinds.ColumnNames.Count;
            int valuesIndex = confusionDataView.Schema[valuesColumnName].Index;

            var confusionTable = GetConfusionTableAsArray(confusionDataView, valuesIndex, labelNames.Length,
                labelIndexToConfIndexMap, numConfusionTableLabels, out double[] precisionSums, out double[] recallSums);

            double[] precision = new double[numConfusionTableLabels];
            double[] recall = new double[numConfusionTableLabels];
            for (int i = 0; i < numConfusionTableLabels; i++)
            {
                // Recall divides the diagonal by the row sum, and precision divides it by the column sum.
                // An empty row or column yields 0.
                recall[i] = recallSums[i] > 0 ? confusionTable[i][i] / recallSums[i] : 0;
                precision[i] = precisionSums[i] > 0 ? confusionTable[i][i] / precisionSums[i] : 0;
            }

            var predictedLabelNames = GetPredictedLabelNames(in labelNames, labelIndexToConfIndexMap);
            bool sampled = numConfusionTableLabels < labelNames.Length;

            return new ConfusionMatrix(host, precision, recall, confusionTable, predictedLabelNames, sampled, binary);
        }

        /// <summary>
        /// Returns the names of the classes kept in the confusion table, in class-index order.
        /// </summary>
        private static List<ReadOnlyMemory<char>> GetPredictedLabelNames(in VBuffer<ReadOnlyMemory<char>> labelNames, int[] labelIndexToConfIndexMap)
        {
            List<ReadOnlyMemory<char>> result = new List<ReadOnlyMemory<char>>();
            var values = labelNames.GetValues();
            for (int i = 0; i < values.Length; i++)
            {
                if (labelIndexToConfIndexMap[i] >= 0)
                {
                    result.Add(values[i]);
                }
            }
            return result;
        }

        // This method is given a data view and the index of a count (or weight) column. It returns the confusion table
        // restricted to the kept classes. It also outputs the per-class column sums (precisionSums, the precision
        // denominators) and row sums (recallSums, the recall denominators). Rows with a non-zero StratCol value are skipped.
        private static double[][] GetConfusionTableAsArray(IDataView confusionDataView, int countIndex, int numClasses,
            int[] labelIndexToConfIndexMap, int numConfusionTableLabels, out double[] precisionSums, out double[] recallSums)
        {
            var confusionTable = new double[numConfusionTableLabels][];
            for (int i = 0; i < numConfusionTableLabels; i++)
                confusionTable[i] = new double[numConfusionTableLabels];

            precisionSums = new double[numConfusionTableLabels];
            recallSums = new double[numConfusionTableLabels];

            var hasStrat = confusionDataView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int stratCol);
            using (var cursor = confusionDataView.GetRowCursor(confusionDataView.Schema.Where(col => col.Index == countIndex || hasStrat && col.Index == stratCol)))
            {
                var type = cursor.Schema[countIndex].Type as VectorDataViewType;
                Contracts.Check(type != null && type.IsKnownSize && type.ItemType == NumberDataViewType.Double);
                var countGetter = cursor.GetGetter<VBuffer<double>>(cursor.Schema[countIndex]);
                ValueGetter<uint> stratGetter = null;
                if (hasStrat)
                {
                    var stratType = cursor.Schema[stratCol].Type;
                    stratGetter = RowCursorUtils.GetGetterAs<uint>(stratType, cursor, stratCol);
                }

                var count = default(VBuffer<double>);
                // Index of the current non-stratified row, which is the true class index.
                int numRows = -1;
                while (cursor.MoveNext())
                {
                    uint strat = 0;
                    if (stratGetter != null)
                        stratGetter(ref strat);
                    if (strat > 0)
                        continue;

                    numRows++;
                    if (labelIndexToConfIndexMap[numRows] < 0)
                        continue;

                    countGetter(ref count);
                    if (count.Length != numClasses)
                        throw Contracts.Except("Expected {0} values in 'Count' column, but got {1}.", numClasses, count.Length);
                    int row = labelIndexToConfIndexMap[numRows];
                    foreach (var val in count.Items())
                    {
                        var index = val.Key;
                        if (labelIndexToConfIndexMap[index] < 0)
                            continue;

                        confusionTable[row][labelIndexToConfIndexMap[index]] = val.Value;
                        precisionSums[labelIndexToConfIndexMap[index]] += val.Value;
                        recallSums[row] += val.Value;
                    }

                    if (numRows == numClasses - 1)
                        break;
                }
            }
            return confusionTable;
        }

        /// <summary>
        /// This method returns the per-fold metrics as a string. If weighted metrics are present they are returned in a separate string.
        /// </summary>
        /// <param name="env">An IHostEnvironment.</param>
        /// <param name="fold">The data view containing the per-fold metrics. Each row in the data view represents a set of metrics
        /// calculated either on the whole dataset or on a subset of it defined by a stratification column. If the data view contains
        /// stratified metrics, it must contain two text columns named "StratCol" and "StratVal", containing the stratification column
        /// name, and a text description of the value. In this case, the value of column StratVal in the row corresponding to the entire
        /// dataset should contain the text "overall", and the value of column StratCol should be DvText.NA. If weighted metrics are present
        /// then the data view should also contain a bool column named "IsWeighted".</param>
        /// <param name="weightedMetrics">If the IsWeighted column exists, this is assigned the string representation of the weighted
        /// metrics. Otherwise it is assigned null.</param>
        public static string GetPerFoldResults(IHostEnvironment env, IDataView fold, out string weightedMetrics)
        {
            return GetFoldMetricsAsString(env, fold, out weightedMetrics);
        }

        /// <summary>
        /// Formats one line per metric: the average over <paramref name="numFolds"/> and, when <paramref name="average"/> is
        /// true, the standard deviation in parentheses.
        /// </summary>
        private static string GetOverallMetricsAsString(double[] sumMetrics, double[] sumSqMetrics, int numFolds, bool weighted, bool average, List<string> metricNames)
        {
            var sb = new StringBuilder();
            for (int i = 0; i < metricNames.Count; i++)
            {
                var avg = sumMetrics[i] / numFolds;
                sb.Append(string.Format("{0}{1}: ", weighted ? "Weighted " : "", metricNames[i]).PadRight(20));
                sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", avg));
                if (average)
                {
                    Contracts.Assert(sumSqMetrics != null || numFolds == 1);
                    // E[x^2] - E[x]^2 can be slightly negative due to floating-point cancellation when all folds agree;
                    // clamp at zero so the standard deviation prints as 0 instead of NaN.
                    double stdev = numFolds == 1 ? 0 : Math.Sqrt(Math.Max(sumSqMetrics[i] / numFolds - avg * avg, 0.0));
                    // Invariant culture, matching the average printed just before it on the same line.
                    sb.AppendLine(string.Format(CultureInfo.InvariantCulture, " ({0:N4})", stdev));
                }
                else
                    sb.AppendLine();
            }
            return sb.ToString();
        }

        // Returns a string with one "Name: value" line per metric. The metrics are aggregated by
        // EvaluateUtils.ComputeMetricsSum with numFolds = 1, so each printed value is the metric from the single
        // matching row. NOTE(review): which rows ComputeMetricsSum selects (e.g. non-stratified only) is decided there;
        // confirm. If the data has an IsWeighted column, the weighted metrics are formatted into weightedMetricsString;
        // otherwise it is set to null.
        private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView data, out string weightedMetricsString)
        {
            var metrics = EvaluateUtils.ComputeMetricsSum(env, data, 1, out int isWeightedCol, out int stratCol,
                out int stratVal, out int foldCol, out var weightedMetrics);

            var sb = new StringBuilder();
            var weightedSb = isWeightedCol >= 0 ? new StringBuilder() : null;
            for (int i = 0; i < metrics.Length; i++)
            {
                sb.Append($"{metrics[i].Name}: ".PadRight(20));
                sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", metrics[i].Sum));
                weightedSb?.Append($"Weighted {weightedMetrics[i].Name}: ".PadRight(20));
                weightedSb?.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", weightedMetrics[i].Sum));
                sb.AppendLine();
                weightedSb?.AppendLine();
            }

            weightedMetricsString = weightedSb?.ToString();
            return sb.ToString();
        }

        // Get a string representation of a confusion table.
        internal static string GetConfusionTableAsString(ConfusionMatrix confusionMatrix, bool isWeighted)
        {
            Contracts.CheckValue(confusionMatrix, nameof(confusionMatrix));

            string prefix = isWeighted ? "Weighted " : "";
            int numLabels = confusionMatrix.Counts?.Count ?? 0;

            int colWidth = numLabels == 2 ? 8 : 5;
            // DefaultIfEmpty keeps Max from throwing when the matrix has no classes.
            int maxNameLen = confusionMatrix.PredictedClassesIndicators.Select(name => name.Length).DefaultIfEmpty(0).Max();
            // If the names are too long to fit in the column header, we back off to using class indices
            // in the header. This will also require putting the indices in the row, but it's better than
            // the alternative of having ambiguous abbreviated column headers, or having a table potentially
            // too wide to fit in a console.
            bool useNumbersInHeader = maxNameLen > colWidth;

            int rowLabelLen = maxNameLen;
            int rowDigitLen = 0;
            if (useNumbersInHeader)
            {
                // The row label will also include the index, so a user can easily match against the header.
                // In such a case, a label like "Foo" would be presented as something like "5. Foo".
                rowDigitLen = Math.Max(confusionMatrix.PredictedClassesIndicators.Count - 1, 0).ToString().Length;
                Contracts.Assert(rowDigitLen >= 1);
                rowLabelLen += rowDigitLen + 2;
            }
            Contracts.Assert((rowDigitLen == 0) == !useNumbersInHeader);

            // The "PREDICTED" in the table, at length 9, dictates the amount of additional padding that will
            // be necessary on account of label names.
            int paddingLen = Math.Max(9, rowLabelLen);
            string pad = new string(' ', paddingLen - 9);
            string rowLabelFormat = null;
            if (useNumbersInHeader)
            {
                int namePadLen = paddingLen - (rowDigitLen + 2);
                rowLabelFormat = string.Format("{{0,{0}}}. {{1,{1}}} ||", rowDigitLen, namePadLen);
            }
            else
                rowLabelFormat = string.Format("{{1,{0}}} ||", paddingLen);

            var confusionTable = confusionMatrix.Counts;
            var sb = new StringBuilder();
            if (numLabels == 2 && confusionMatrix.IsBinary)
            {
                var positiveCaps = confusionMatrix.PredictedClassesIndicators[0].ToString().ToUpper();

                var numTruePos = confusionTable[0][0];
                var numFalseNeg = confusionTable[0][1];
                var numTrueNeg = confusionTable[1][1];
                var numFalsePos = confusionTable[1][0];
                sb.AppendFormat("{0}TEST {1} RATIO:\t{2:N4} ({3:F1}/({3:F1}+{4:F1}))", prefix, positiveCaps,
                    1.0 * (numTruePos + numFalseNeg) / (numTruePos + numTrueNeg + numFalseNeg + numFalsePos),
                    numTruePos + numFalseNeg, numFalsePos + numTrueNeg);
            }

            sb.AppendLine();
            sb.AppendFormat("{0}Confusion table", prefix);
            if (confusionMatrix.IsSampled)
                sb.AppendLine(" (sampled)");
            else
                sb.AppendLine();

            sb.AppendFormat("          {0}||", pad);
            for (int i = 0; i < numLabels; i++)
                sb.Append(numLabels > 2 ? "========" : "===========");
            sb.AppendLine();
            sb.AppendFormat("PREDICTED {0}||", pad);
            string format = string.Format(" {{{0},{1}}} |", useNumbersInHeader ? 0 : 1, colWidth);
            for (int i = 0; i < numLabels; i++)
                sb.AppendFormat(format, i, confusionMatrix.PredictedClassesIndicators[i]);
            sb.AppendLine(" Recall");
            sb.AppendFormat("TRUTH     {0}||", pad);
            for (int i = 0; i < numLabels; i++)
                sb.Append(numLabels > 2 ? "========" : "===========");

            sb.AppendLine();

            // Unweighted counts are whole numbers (N0); weighted counts get one decimal (F1).
            string format2 = string.Format(" {{0,{0}:{1}}} |", colWidth,
                string.IsNullOrWhiteSpace(prefix) ? "N0" : "F1");
            for (int i = 0; i < numLabels; i++)
            {
                sb.AppendFormat(rowLabelFormat, i, confusionMatrix.PredictedClassesIndicators[i]);
                for (int j = 0; j < numLabels; j++)
                    sb.AppendFormat(format2, confusionTable[i][j]);
                sb.AppendFormat(" {0,5:F4}", confusionMatrix.PerClassRecall[i]);
                sb.AppendLine();
            }
            sb.AppendFormat("          {0}||", pad);
            for (int i = 0; i < numLabels; i++)
                sb.Append(numLabels > 2 ? "========" : "===========");
            sb.AppendLine();
            sb.AppendFormat("Precision {0}||", pad);
            format = string.Format("{{0,{0}:N4}} |", colWidth + 1);
            for (int i = 0; i < numLabels; i++)
                sb.AppendFormat(format, confusionMatrix.PerClassPrecision[i]);

            sb.AppendLine();
            return sb.ToString();
        }

        /// <summary>
        /// Print the overall results to the Console. The overall data view should contain rows from all the folds being averaged.
        /// If filename is not null then also save the results to the specified file. The first row in the file is the averaged
        /// results, followed by the results of each fold.
        /// </summary>
        public static void PrintOverallMetrics(IHostEnvironment env, IChannel ch, string filename, IDataView overall, int numFolds)
        {
            var overallWithAvg = EvaluateUtils.GetOverallMetricsData(env, overall, numFolds, out var agg, out var weightedAgg);

            var sb = new StringBuilder();
            sb.AppendLine();
            sb.AppendLine("OVERALL RESULTS");
            sb.AppendLine("---------------------------------------");

            // Weighted metrics, when present, are printed before the unweighted ones.
            if (weightedAgg != null)
                sb.Append(GetOverallMetricsAsString(weightedAgg.Select(x => x.Sum).ToArray(), weightedAgg.Select(x => x.SumSq).ToArray(), numFolds, true, true, weightedAgg.Select(x => x.Name).ToList()));
            sb.Append(GetOverallMetricsAsString(agg.Select(x => x.Sum).ToArray(), agg.Select(x => x.SumSq).ToArray(), numFolds, false, true, agg.Select(x => x.Name).ToList()));
            sb.AppendLine("\n---------------------------------------");
            ch.Info(sb.ToString());

            if (!string.IsNullOrEmpty(filename))
            {
                using (var file = env.CreateOutputFile(filename))
                {
                    var saverArgs = new TextSaver.Arguments() { Dense = true, Silent = true };
                    DataSaverUtils.SaveDataView(ch, new TextSaver(env, saverArgs), overallWithAvg, file);
                }
            }
        }

        // Right-aligns str in a field of totalLength characters. A longer string is truncated to totalLength - 1
        // characters and padded with a trailing '.'. NOTE(review): this private helper is not referenced anywhere in this class.
        private static string PadLeft(string str, int totalLength)
        {
            if (str.Length > totalLength)
                return str.Substring(0, totalLength - 1).PadRight(totalLength, '.');
            return str.PadLeft(totalLength);
        }

        /// <summary>
        /// Searches for a warning dataview in the given dictionary, and if present, prints the warnings to the given channel. The warning dataview
        /// should contain a text column named "WarningText".
        /// </summary>
        public static void PrintWarnings(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            if (!metrics.TryGetValue(MetricKinds.Warnings, out var warnings))
                return;

            var warningTextColumn = warnings.Schema.GetColumnOrNull(MetricKinds.ColumnNames.WarningText);
            if (!warningTextColumn.HasValue || !(warningTextColumn.Value.Type is TextDataViewType))
                return;

            using (var cursor = warnings.GetRowCursor(warningTextColumn.Value))
            {
                var warning = default(ReadOnlyMemory<char>);
                var getter = cursor.GetGetter<ReadOnlyMemory<char>>(warningTextColumn.Value);
                while (cursor.MoveNext())
                {
                    getter(ref warning);
                    ch.Warning(warning.ToString());
                }
            }
        }

        /// <summary>
        ///  Save the given data view using text saver.
        /// </summary>
        public static void SavePerInstance(IHostEnvironment env, IChannel ch, string filename, IDataView data,
            bool dense = true, bool saveSchema = false)
        {
            using (var file = env.CreateOutputFile(filename))
            {
                DataSaverUtils.SaveDataView(ch,
                    new TextSaver(env, new TextSaver.Arguments() { OutputSchema = saveSchema, Dense = dense, Silent = true }),
                    data, file);
            }
        }

        /// <summary>
        /// Filter out the stratified results from overall and drop the stratification columns.
        /// </summary>
        public static IDataView GetNonStratifiedMetrics(IHostEnvironment env, IDataView data)
        {
            if (!data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int stratCol))
                return data;
            var type = data.Schema[stratCol].Type;
            env.Check(type.GetKeyCount() > 0, "Expected a known count key type stratification column");

            // Complement = true inverts the NA filter. Per this method's summary, this keeps the non-stratified rows.
            var filterArgs = new NAFilter.Arguments();
            filterArgs.Columns = new[] { MetricKinds.ColumnNames.StratCol };
            filterArgs.Complement = true;
            data = new NAFilter(env, filterArgs, data);

            bool found = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out int stratVal);
            env.Check(found, "If stratification column exist, data view must also contain a StratVal column");

            data = ColumnSelectingTransformer.CreateDrop(env, data, data.Schema[stratCol].Name, data.Schema[stratVal].Name);
            return data;
        }
    }
 
    /// <summary>
    /// This is a list of string constants denoting 'standard' metric kinds.
    /// </summary>
    [BestFriend]
    internal static class MetricKinds
    {
        /// <summary>
        /// This data view contains the confusion matrix for N-class classification. It has N rows, and each row has
        /// the following columns:
        /// * Count (vector indicating how many examples of this class were predicted as each one of the classes). This column
        /// should have metadata containing the class names.
        /// * (Optional) Weight (vector with the total weight of the examples of this class that were predicted as each one of the classes).
        /// </summary>
        public const string ConfusionMatrix = "ConfusionMatrix";

        /// <summary>
        /// This is a data view with 'global' dataset-wise metrics in its columns. It has one row containing the overall metrics,
        /// and optionally more rows for weighted metrics, and stratified metrics.
        /// </summary>
        public const string OverallMetrics = "OverallMetrics";

        /// <summary>
        /// This is a data view with precision recall data in its columns. It has four columns: Threshold, Precision, Recall and Fpr.
        /// </summary>
        public const string PrCurve = "PrCurve";

        /// <summary>
        /// This data view contains a single text column, with warnings about bad input values encountered by the evaluator during
        /// the aggregation of metrics. Each warning is in a separate row.
        /// </summary>
        public const string Warnings = "Warnings";

        /// <summary>
        /// Names for the columns in the data views output by evaluators.
        /// </summary>
        public sealed class ColumnNames
        {
            /// <summary>
            /// Name of the text column holding warning messages. Presumably this is the single text column
            /// of the <see cref="Warnings"/> data view (TODO confirm against the evaluators).
            /// </summary>
            public const string WarningText = "WarningText";

            /// <summary>
            /// Name of a column that presumably flags the rows of the <see cref="OverallMetrics"/> view that
            /// hold weighted metrics. This is an assumption; verify against the evaluators.
            /// </summary>
            public const string IsWeighted = "IsWeighted";

            /// <summary>
            /// Name of the count column, e.g. the per-class count vector of the <see cref="ConfusionMatrix"/> view.
            /// </summary>
            public const string Count = "Count";

            /// <summary>
            /// Name of the weight column, e.g. the optional per-class weight vector of the <see cref="ConfusionMatrix"/> view.
            /// </summary>
            public const string Weight = "Weight";

            /// <summary>
            /// Name of the stratification column. <c>EvaluateUtils.GetNonStratifiedMetrics</c> requires it to be a key
            /// type with a known count. That method treats rows where this column is missing as the overall
            /// (non-stratified) results.
            /// </summary>
            public const string StratCol = "StratCol";

            /// <summary>
            /// Name of the stratification value column. <c>EvaluateUtils.GetNonStratifiedMetrics</c> requires it whenever
            /// <see cref="StratCol"/> is present and drops both columns together.
            /// </summary>
            public const string StratVal = "StratVal";

            /// <summary>
            /// Name of the fold-index column. Note that the name contains a space ("Fold Index").
            /// </summary>
            public const string FoldIndex = "Fold Index";
        }
    }
}