// File: ScoreColumnSelector.cs
// Web Access
// Project: src\src\Microsoft.ML.EntryPoints\Microsoft.ML.EntryPoints.csproj (Microsoft.ML.EntryPoints)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
namespace Microsoft.ML.EntryPoints
{
    internal static partial class ScoreModel
    {
        /// <summary>
        /// Arguments accepted by the <c>Transforms.ScoreColumnSelector</c> entry point
        /// (see <see cref="SelectColumns"/>). The input data itself comes from the
        /// <see cref="TransformInputBase"/> base class.
        /// </summary>
        public sealed class ScoreColumnSelectorInput : TransformInputBase
        {
            /// <summary>
            /// Names of additional columns to keep in the output besides the score columns.
            /// May be left <c>null</c>; <see cref="ShouldAddColumn"/> checks for that.
            /// </summary>
            [Argument(ArgumentType.Multiple, HelpText = "Extra columns to write", SortOrder = 2)]
            public string[] ExtraColumns;
        }
 
        /// <summary>
        /// Entry point that narrows the input data view to the visible columns belonging to the
        /// score column set with the highest id, plus any visible columns named in
        /// <see cref="ScoreColumnSelectorInput.ExtraColumns"/>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Entry-point arguments carrying the data view and the extra column names.</param>
        /// <returns>
        /// A transform output whose <c>OutputData</c> is the column-filtered view and whose
        /// <c>Model</c> wraps that view together with the original input data.
        /// </returns>
        [TlcModule.EntryPoint(Name = "Transforms.ScoreColumnSelector", Desc = "Selects only the last score columns and the extra columns specified in the arguments.", UserName = "Choose Columns By Indices")]
        public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, ScoreColumnSelectorInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(env, input);

            var data = input.Data;
            var schema = data.Schema;

            // Only the id of the latest score set is needed; the column index that carries it is discarded.
            var latestScoreSet = schema.GetMaxAnnotationKind(out int _, AnnotationUtils.Kinds.ScoreColumnSetId);

            // Keep, in schema order, every visible column that is either part of the latest
            // score set or explicitly requested by name.
            int[] keptIndices = Enumerable.Range(0, schema.Count)
                .Where(col => !schema[col].IsHidden)
                .Where(col => ShouldAddColumn(schema, col, input.ExtraColumns, latestScoreSet))
                .ToArray();

            var options = new ChooseColumnsByIndexTransform.Options() { Indices = keptIndices };
            var newView = new ChooseColumnsByIndexTransform(env, options, data);

            return new CommonOutputs.TransformOutput
            {
                Model = new TransformModelImpl(env, newView, data),
                OutputData = newView
            };
        }
 
        /// <summary>
        /// Decides whether the column at index <paramref name="i"/> should be selected.
        /// A column qualifies when its score-column-set-id annotation equals
        /// <paramref name="scoreSet"/>, or when its name appears in <paramref name="extraColumns"/>.
        /// </summary>
        /// <param name="schema">Schema that owns the column.</param>
        /// <param name="i">Index of the column within <paramref name="schema"/>.</param>
        /// <param name="extraColumns">Additional column names to accept; may be null.</param>
        /// <param name="scoreSet">Score column set id a column must carry to qualify.</param>
        /// <returns><c>true</c> if the column should be included; otherwise <c>false</c>.</returns>
        private static bool ShouldAddColumn(DataViewSchema schema, int i, string[] extraColumns, uint scoreSet)
        {
            uint columnScoreSet = 0;
            bool hasScoreSet = schema.TryGetAnnotation(AnnotationUtils.ScoreColumnSetIdType, AnnotationUtils.Kinds.ScoreColumnSetId, i, ref columnScoreSet);
            if (hasScoreSet && columnScoreSet == scoreSet)
                return true;

            var columnName = schema[i].Name;
            if (extraColumns == null)
                return false;

            // Exact string match against the requested extra columns.
            return Array.Exists(extraColumns, columnName.Equals);
        }
 
        /// <summary>
        /// Arguments accepted by the <c>Transforms.BinaryPredictionScoreColumnsRenamer</c> entry point
        /// (see <see cref="RenameBinaryPredictionScoreColumns"/>). The input data itself comes from the
        /// <see cref="TransformInputBase"/> base class.
        /// </summary>
        public sealed class RenameBinaryPredictionScoreColumnsInput : TransformInputBase
        {
            /// <summary>
            /// The predictor model used in scoring. The entry point reads its prediction kind
            /// and label information from it. Declared as a required argument.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "The predictor model used in scoring", SortOrder = 2)]
            public PredictorModel PredictorModel;
        }
 
        /// <summary>
        /// Entry point that, for a binary classification predictor exposing exactly two label names,
        /// renames each visible column of the latest score column set to
        /// <c>&lt;column&gt;.&lt;second label name&gt;</c>. Columns whose score value kind is
        /// PredictedLabel are left untouched. In every other situation the data is passed through
        /// via <c>NopTransform.CreateIfNeeded</c>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Entry-point arguments carrying the scored data and the predictor model.</param>
        /// <returns>
        /// A transform output whose <c>OutputData</c> is the resulting view and whose <c>Model</c>
        /// wraps that view together with the original input data.
        /// </returns>
        [TlcModule.EntryPoint(Name = "Transforms.BinaryPredictionScoreColumnsRenamer", Desc = "For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class.", UserName = "Rename Binary Prediction Score Columns")]
        public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(IHostEnvironment env,
            RenameBinaryPredictionScoreColumnsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ScoreModel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // Fallback used whenever there is nothing to rename.
            CommonOutputs.TransformOutput PassThrough()
            {
                var newView = NopTransform.CreateIfNeeded(env, input.Data);
                return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, newView, input.Data), OutputData = newView };
            }

            if (input.PredictorModel.Predictor.PredictionKind != PredictionKind.BinaryClassification)
                return PassThrough();

            // The label type is not needed here, only the label names.
            var labelNames = input.PredictorModel.GetLabelInfo(host, out DataViewType _);
            if (labelNames == null || labelNames.Length != 2)
                return PassThrough();

            // The second label name is treated as the positive class.
            var positiveClass = labelNames[1];

            // Collect (new name, original name) pairs for the score columns to rename.
            var schema = input.Data.Schema;
            var latestScoreSet = schema.GetMaxAnnotationKind(out int _, AnnotationUtils.Kinds.ScoreColumnSetId);
            var renames = new List<(string name, string source)>();
            for (int col = 0; col < schema.Count; col++)
            {
                if (schema[col].IsHidden || !ShouldAddColumn(schema, col, null, latestScoreSet))
                    continue;

                // Do not rename the PredictedLabel column.
                // NOTE(review): the entry-point description says PredictedLabel is renamed too,
                // but this loop skips it - confirm which one is intended.
                ReadOnlyMemory<char> valueKind = default;
                bool isPredictedLabel =
                    schema.TryGetAnnotation(TextDataViewType.Instance, AnnotationUtils.Kinds.ScoreValueKind, col, ref valueKind)
                    && ReadOnlyMemoryUtils.EqualsStr(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, valueKind);
                if (isPredictedLabel)
                    continue;

                var source = schema[col].Name;
                renames.Add((source + "." + positiveClass, source));
            }

            // A rename is expressed as "copy to the new name, then drop the original".
            // NOTE(review): renames may be empty when no column qualifies; confirm that the
            // copy and drop transforms accept empty column lists.
            var copied = new ColumnCopyingTransformer(env, renames.ToArray()).Transform(input.Data);
            var originals = renames.Select(pair => pair.source).ToArray();
            var renamed = ColumnSelectingTransformer.CreateDrop(env, copied, originals);
            return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, renamed, input.Data), OutputData = renamed };
        }
    }
}