File: Commands\ScoreCommand.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Command;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
[assembly: LoadableClass(ScoreCommand.Summary, typeof(ScoreCommand), typeof(ScoreCommand.Arguments), typeof(SignatureCommand),
    "Score Predictor", "Score")]
 
namespace Microsoft.ML.Data
{
    using TScorerFactory = IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>;
 
    /// <summary>
    /// A data transform that is also a transform template. This is the type produced by the scorer
    /// factories (<c>TScorerFactory</c>) used in this file; it declares no members of its own.
    /// </summary>
    [BestFriend]
    internal interface IDataScorerTransform : IDataTransform, ITransformTemplate
    {
    }
 
    /// <summary>
    /// Signature for creating an <see cref="IDataScorerTransform"/>.
    /// </summary>
    /// <param name="data">The data containing the columns to score.</param>
    /// <param name="mapper">The mapper, already bound to the schema columns in <paramref name="data"/>.</param>
    /// <param name="trainSchema">A snapshot of the role mapped training schema as it existed at the
    /// point when <paramref name="mapper"/> was trained, or <c>null</c> if it is not available for
    /// some reason.</param>
    [BestFriend]
    internal delegate void SignatureDataScorer(IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema);
 
    /// <summary>
    /// Signature for creating an <see cref="ISchemaBindableMapper"/> from a predictor. It is used as the
    /// signature type when looking up a bindable mapper by the load name of a scorer.
    /// </summary>
    /// <param name="predictor">The predictor passed to the bindable mapper's creation method.</param>
    [BestFriend]
    internal delegate void SignatureBindableMapper(IPredictor predictor);
 
    /// <summary>
    /// Command that scores a data file. It obtains the predictor, training schema and loader through
    /// <c>LoadModelObjects</c>, appends a scorer and any post-processing transforms to the loader, and
    /// saves the selected columns to <see cref="Arguments.OutputDataFile"/>.
    /// </summary>
    internal sealed class ScoreCommand : DataCommand.ImplBase<ScoreCommand.Arguments>
    {
        /// <summary>
        /// Command line arguments of the score command.
        /// </summary>
        public sealed class Arguments : DataCommand.ArgumentsBase
        {
            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Column to use for features when scorer is not defined", ShortName = "feat")]
            public string FeatureColumn = DefaultColumnNames.Features;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Group column name", ShortName = "group")]
            public string GroupColumn = DefaultColumnNames.GroupId;

            [Argument(ArgumentType.Multiple,
                HelpText = "Input columns: Columns with custom kinds declared through key assignments, for example, col[Kind]=Name to assign column named 'Name' kind 'Kind'",
                Name = "CustomColumn", ShortName = "col", SortOrder = 10)]
            public KeyValuePair<string, string>[] CustomColumns;

            [Argument(ArgumentType.Multiple, HelpText = "Scorer to use", SignatureType = typeof(SignatureDataScorer))]
            public TScorerFactory Scorer;

            [Argument(ArgumentType.Multiple, HelpText = "The data saver to use", SignatureType = typeof(SignatureDataSaver))]
            public IComponentFactory<IDataSaver> Saver;

            [Argument(ArgumentType.LastOccurrenceWins, HelpText = "File to save the data", ShortName = "dout")]
            public string OutputDataFile;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to include hidden columns", ShortName = "keep")]
            public bool KeepHidden;

            [Argument(ArgumentType.Multiple, HelpText = "Post processing transform", ShortName = "pxf", SignatureType = typeof(SignatureDataTransform))]
            public KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] PostTransform;

            // Nullable on purpose: when left unset, RunCore derives the default from the saver type
            // and from whether any OutputColumns were requested.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output all columns or just scores", ShortName = "all")]
            public bool? OutputAllColumns;

            [Argument(ArgumentType.Multiple, HelpText = "What columns to output beyond score columns, if outputAllColumns=-.",
                Name = "OutputColumn", ShortName = "outCol")]
            public string[] OutputColumns;
        }

        internal const string Summary = "Scores a data file.";

        /// <summary>
        /// Creates the command and verifies that both an input model file and an output data file
        /// were specified.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="args">The command arguments.</param>
        public ScoreCommand(IHostEnvironment env, Arguments args)
            : base(env, args, nameof(ScoreCommand))
        {
            Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile), "The input model file is required.");
            Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile), nameof(ImplOptions.OutputDataFile), "The output data file is required.");
            Utils.CheckOptionalUserDirectory(ImplOptions.OutputDataFile, nameof(ImplOptions.OutputDataFile));
        }

        /// <summary>
        /// Runs the command inside a "Score" channel.
        /// </summary>
        public override void Run()
        {
            using (var ch = Host.Start("Score"))
            {
                RunCore(ch);
            }
        }

        /// <summary>
        /// Builds the scoring pipeline, optionally saves it as a model, chooses the data saver and the
        /// columns to write, and saves the scored data to the output file.
        /// </summary>
        /// <param name="ch">The channel used for tracing, warnings and checks.</param>
        private void RunCore(IChannel ch)
        {
            Host.AssertValue(ch);

            ch.Trace("Creating loader");

            LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
            ch.AssertValue(predictor);
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            ch.Trace("Creating pipeline");
            var scorer = ImplOptions.Scorer;
            ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line.");
            var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory);
            ch.AssertValue(bindable);

            // REVIEW: We probably ought to prefer role mappings from the training schema.
            string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                nameof(ImplOptions.FeatureColumn), ImplOptions.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
            var schema = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
            var mapper = bindable.Bind(Host, schema);

            // No scorer on the command line: pick the default one for the bound mapper.
            if (scorer == null)
                scorer = ScoreUtils.GetScorerComponent(Host, mapper);

            loader = LegacyCompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
                (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));

            loader = LegacyCompositeDataLoader.Create(Host, loader, ImplOptions.PostTransform);

            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                ch.Trace("Saving the data pipe");
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }

            ch.Trace("Creating saver");
            IDataSaver writer;
            if (ImplOptions.Saver == null)
            {
                // No saver specified: choose text or binary from the output file extension.
                var ext = Path.GetExtension(ImplOptions.OutputDataFile);
                var isText = ext == ".txt" || ext == ".tlc";
                if (isText)
                {
                    writer = new TextSaver(Host, new TextSaver.Arguments());
                }
                else
                {
                    writer = new BinarySaver(Host, new BinarySaver.Arguments());
                }
            }
            else
            {
                writer = ImplOptions.Saver.CreateComponent(Host);
            }
            ch.Assert(writer != null);

            // Test against BinarySaver, the IDataSaver used for binary output. The previous test was
            // against System.IO.BinaryWriter, a type unrelated to the savers created above, so the
            // "binary output writes all columns by default" rule below never took effect.
            var outputIsBinary = writer is BinarySaver;

            // All columns are written when explicitly requested, or by default for binary output
            // when the user expressed no preference at all (no flag and no explicit column list).
            bool outputAllColumns =
                ImplOptions.OutputAllColumns == true
                || (ImplOptions.OutputAllColumns == null && Utils.Size(ImplOptions.OutputColumns) == 0 && outputIsBinary);

            // Consumed by ShouldAddColumn to decide whether the "Label", "Name" and "Names" columns
            // are added alongside the score columns.
            bool outputNamesAndLabels =
                ImplOptions.OutputAllColumns == true || Utils.Size(ImplOptions.OutputColumns) == 0;

            if (ImplOptions.OutputAllColumns == true && Utils.Size(ImplOptions.OutputColumns) != 0)
                ch.Warning(nameof(ImplOptions.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(ImplOptions.OutputColumns) + " specified.");

            if (!outputAllColumns && Utils.Size(ImplOptions.OutputColumns) != 0)
            {
                // Fail early on explicitly requested columns that do not exist in the output schema.
                foreach (var outCol in ImplOptions.OutputColumns)
                {
                    if (!loader.Schema.TryGetColumnIndex(outCol, out int _))
                        throw ch.ExceptUserArg(nameof(Arguments.OutputColumns), "Column '{0}' not found.", outCol);
                }
            }

            uint maxScoreId = 0;
            if (!outputAllColumns)
                maxScoreId = loader.Schema.GetMaxAnnotationKind(out int _, AnnotationUtils.Kinds.ScoreColumnSetId);
            ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based
            var cols = new List<int>();
            for (int i = 0; i < loader.Schema.Count; i++)
            {
                if (!ImplOptions.KeepHidden && loader.Schema[i].IsHidden)
                    continue;
                if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
                    continue;
                var type = loader.Schema[i].Type;
                if (writer.IsColumnSavable(type))
                    cols.Add(i);
                else
                {
                    ch.Warning("The column '{0}' will not be written as it has unsavable column type.",
                        loader.Schema[i].Name);
                }
            }

            ch.Check(cols.Count > 0, "No valid columns to save");

            ch.Trace("Scoring and saving data");
            using (var file = Host.CreateOutputFile(ImplOptions.OutputDataFile))
            using (var stream = file.CreateWriteStream())
                writer.SaveData(stream, loader, cols.ToArray());
        }

        /// <summary>
        /// Whether a column should be added, assuming it's not hidden
        /// (i.e.: this doesn't check whether the column is hidden).
        /// </summary>
        /// <param name="schema">The schema containing the column.</param>
        /// <param name="i">The index of the column in <paramref name="schema"/>.</param>
        /// <param name="scoreSet">The score column set ID whose columns are always added.</param>
        /// <param name="outputNamesAndLabels">Whether columns named "Label", "Name" or "Names" are added.</param>
        /// <returns><c>true</c> if the column belongs to the score set <paramref name="scoreSet"/>, is one of
        /// the name/label columns (when enabled), or is listed in <see cref="Arguments.OutputColumns"/>.</returns>
        private bool ShouldAddColumn(DataViewSchema schema, int i, uint scoreSet, bool outputNamesAndLabels)
        {
            uint scoreSetId = 0;
            if (schema.TryGetAnnotation(AnnotationUtils.ScoreColumnSetIdType, AnnotationUtils.Kinds.ScoreColumnSetId, i, ref scoreSetId)
                && scoreSetId == scoreSet)
            {
                return true;
            }
            if (outputNamesAndLabels)
            {
                switch (schema[i].Name)
                {
                    case "Label":
                    case "Name":
                    case "Names":
                        return true;
                    default:
                        break;
                }
            }
            if (ImplOptions.OutputColumns != null && Array.FindIndex(ImplOptions.OutputColumns, schema[i].Name.Equals) >= 0)
                return true;
            return false;
        }
    }
 
    /// <summary>
    /// Helpers for creating scorer transforms, scorer component factories and schema bindable mappers
    /// from predictors.
    /// </summary>
    [BestFriend]
    internal static class ScoreUtils
    {
        /// <summary>
        /// Creates a scorer for <paramref name="predictor"/> over <paramref name="data"/>, using the default
        /// scorer component for the mapper bound to the schema of <paramref name="data"/>.
        /// </summary>
        /// <param name="predictor">The predictor to score with.</param>
        /// <param name="data">The role mapped data to score.</param>
        /// <param name="env">The host environment.</param>
        /// <param name="trainSchema">The role mapped training schema, or <c>null</c>.</param>
        public static IDataScorerTransform GetScorer(IPredictor predictor, RoleMappedData data, IHostEnvironment env, RoleMappedSchema trainSchema)
        {
            // Validate up front, as the other overload does, so that a null 'data' is reported as an
            // argument error instead of a NullReferenceException when it is dereferenced below.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValue(data, nameof(data));
            env.CheckValueOrNull(trainSchema);

            var sc = GetScorerComponentAndMapper(predictor, null, data.Schema, env, null, out var mapper);
            return sc.CreateComponent(env, data.Data, mapper, trainSchema);
        }

        /// <summary>
        /// Creates a scorer for <paramref name="predictor"/> over <paramref name="input"/>. The role mapped
        /// schema is built from the given feature, group and custom column names; when
        /// <paramref name="scorer"/> is <c>null</c>, the default scorer component for the bound mapper is used.
        /// </summary>
        /// <param name="scorer">The scorer factory to use, or <c>null</c>.</param>
        /// <param name="predictor">The predictor to score with.</param>
        /// <param name="input">The data to score.</param>
        /// <param name="featureColName">The name of the feature column, or <c>null</c>.</param>
        /// <param name="groupColName">The name of the group column, or <c>null</c>.</param>
        /// <param name="customColumns">Additional role-to-column-name assignments, or <c>null</c>.</param>
        /// <param name="env">The host environment.</param>
        /// <param name="trainSchema">The role mapped training schema, or <c>null</c>.</param>
        /// <param name="mapperFactory">An optional factory used to create the bindable mapper.</param>
        public static IDataScorerTransform GetScorer(
            TScorerFactory scorer,
            IPredictor predictor,
            IDataView input,
            string featureColName,
            string groupColName,
            IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> customColumns,
            IHostEnvironment env,
            RoleMappedSchema trainSchema,
            IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValueOrNull(scorer);
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValue(input, nameof(input));
            env.CheckValueOrNull(featureColName);
            env.CheckValueOrNull(groupColName);
            env.CheckValueOrNull(customColumns);
            env.CheckValueOrNull(trainSchema);

            var schema = new RoleMappedSchema(input.Schema, label: null, feature: featureColName, group: groupColName, custom: customColumns, opt: true);
            var sc = GetScorerComponentAndMapper(predictor, scorer, schema, env, mapperFactory, out var mapper);
            return sc.CreateComponent(env, input, mapper, trainSchema);
        }

        /// <summary>
        /// Determines the scorer component factory (if the given one is null), and creates the schema bound mapper.
        /// </summary>
        /// <param name="predictor">The predictor to bind.</param>
        /// <param name="scorerFactory">The scorer factory to return as is when not <c>null</c>.</param>
        /// <param name="schema">The role mapped schema the mapper is bound to.</param>
        /// <param name="env">The host environment.</param>
        /// <param name="mapperFactory">An optional factory used to create the bindable mapper.</param>
        /// <param name="mapper">Receives the mapper bound to <paramref name="schema"/>.</param>
        private static TScorerFactory GetScorerComponentAndMapper(
            IPredictor predictor,
            TScorerFactory scorerFactory,
            RoleMappedSchema schema,
            IHostEnvironment env,
            IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory,
            out ISchemaBoundMapper mapper)
        {
            Contracts.AssertValue(env);

            var bindable = GetSchemaBindableMapper(env, predictor, mapperFactory, scorerFactory as ICommandLineComponentFactory);
            env.AssertValue(bindable);
            mapper = bindable.Bind(env, schema);
            if (scorerFactory != null)
                return scorerFactory;
            return GetScorerComponent(env, mapper);
        }

        /// <summary>
        /// Determine the default scorer for a schema bound mapper. This looks for text-valued ScoreColumnKind
        /// metadata on the first column of the mapper. If that text is found and maps to a scorer loadable class,
        /// that component is used. Otherwise, the GenericScorer is used.
        /// </summary>
        /// <param name="environment">The host environment.</param>
        /// <param name="mapper">The schema bound mapper to get the default scorer.</param>
        /// <param name="suffix">An optional suffix to append to the default column names.</param>
        public static TScorerFactory GetScorerComponent(
            IHostEnvironment environment,
            ISchemaBoundMapper mapper,
            string suffix = null)
        {
            Contracts.CheckValue(environment, nameof(environment));
            // The mapper is dereferenced immediately below, so validate it rather than only assert it.
            environment.CheckValue(mapper, nameof(mapper));

            ComponentCatalog.LoadableClassInfo info = null;
            ReadOnlyMemory<char> scoreKind = default;
            if (mapper.OutputSchema.Count > 0 &&
                mapper.OutputSchema.TryGetAnnotation(TextDataViewType.Instance, AnnotationUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) &&
                !scoreKind.IsEmpty)
            {
                var loadName = scoreKind.ToString();
                info = environment.ComponentCatalog.GetLoadableClassInfo<SignatureDataScorer>(loadName);
                // Ignore loadable classes that cannot produce an IDataScorerTransform.
                if (info == null || !typeof(IDataScorerTransform).IsAssignableFrom(info.Type))
                    info = null;
            }

            Func<IHostEnvironment, IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> factoryFunc;
            if (info == null)
            {
                factoryFunc = (env, data, innerMapper, trainSchema) =>
                    new GenericScorer(
                        env,
                        new GenericScorer.Arguments() { Suffix = suffix },
                        data,
                        innerMapper,
                        trainSchema);
            }
            else
            {
                factoryFunc = (env, data, innerMapper, trainSchema) =>
                {
                    object args = info.CreateArguments();
                    // The suffix can only be applied when the arguments derive from ScorerArgumentsBase.
                    if (args is ScorerArgumentsBase scorerArgs)
                    {
                        scorerArgs.Suffix = suffix;
                    }
                    return (IDataScorerTransform)info.CreateInstance(
                        env,
                        args,
                        new object[] { data, innerMapper, trainSchema });
                };
            }

            return ComponentFactoryUtils.CreateFromFunction(factoryFunc);
        }

        /// <summary>
        /// Given a predictor, an optional mapper factory, and an optional scorer factory settings,
        /// produces a compatible ISchemaBindableMapper.
        /// First, it tries to instantiate the bindable mapper using the mapper factory.
        /// Next, it tries to instantiate the bindable mapper using the <paramref name="scorerFactorySettings"/>
        /// (this will only succeed if there's a registered BindableMapper creation method with load name equal to the one
        /// of the scorer).
        /// If the above fails, it checks whether the predictor implements <see cref="ISchemaBindableMapper"/>
        /// directly.
        /// If this also isn't true, it will create a 'matching' standard mapper.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="predictor">The predictor to create a bindable mapper for.</param>
        /// <param name="mapperFactory">An optional factory that, when supplied, creates the mapper.</param>
        /// <param name="scorerFactorySettings">Optional scorer settings used to look up a bindable mapper.</param>
        public static ISchemaBindableMapper GetSchemaBindableMapper(
            IHostEnvironment env,
            IPredictor predictor,
            IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory = null,
            ICommandLineComponentFactory scorerFactorySettings = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValueOrNull(mapperFactory);
            env.CheckValueOrNull(scorerFactorySettings);

            // if the mapperFactory was supplied, use it
            if (mapperFactory != null)
                return mapperFactory.CreateComponent(env, predictor);

            // See if we can instantiate a mapper using scorer arguments.
            if (scorerFactorySettings != null && TryCreateBindableFromScorer(env, predictor, scorerFactorySettings, out var bindable))
                return bindable;

            // The easy case is that the predictor implements the interface.
            bindable = predictor as ISchemaBindableMapper;
            if (bindable != null)
                return bindable;

            // Use one of the standard wrappers.
            if (predictor is IValueMapperDist)
                return new SchemaBindableBinaryPredictorWrapper(predictor);

            return new SchemaBindablePredictorWrapper(predictor);
        }

        /// <summary>
        /// Tries to create a bindable mapper for <paramref name="predictor"/> from a component registered
        /// under the same load name as <paramref name="scorerSettings"/>.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="predictor">The predictor passed to the mapper's creation method.</param>
        /// <param name="scorerSettings">The scorer settings supplying the load name and settings string.</param>
        /// <param name="bindable">Receives the created mapper on success.</param>
        /// <returns>The result of the component catalog's attempt to create the instance.</returns>
        private static bool TryCreateBindableFromScorer(IHostEnvironment env, IPredictor predictor,
            ICommandLineComponentFactory scorerSettings, out ISchemaBindableMapper bindable)
        {
            Contracts.AssertValue(env);
            env.AssertValue(predictor);
            env.AssertValue(scorerSettings);

            // Try to find a mapper factory method with the same loadname as the scorer settings.
            return ComponentCatalog.TryCreateInstance<ISchemaBindableMapper, SignatureBindableMapper>(
                env, out bindable, scorerSettings.Name, scorerSettings.GetSettingsString(), predictor);
        }
    }
}