File: Scorers\MulticlassClassificationScorer.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Newtonsoft.Json.Linq;
 
// Declares the scorer as loadable through SignatureDataScorer, using the Arguments class for its
// settings and the listed load names (including the multiclass score column kind).
[assembly: LoadableClass(typeof(MulticlassClassificationScorer),
    typeof(MulticlassClassificationScorer.Arguments), typeof(SignatureDataScorer),
    "Multi-Class Classifier Scorer", "MultiClassClassifierScorer", "MultiClassClassifier",
    "MultiClass", AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification)]

// Declares the scorer as loadable through SignatureLoadDataTransform (model deserialization), keyed by its loader signature.
[assembly: LoadableClass(typeof(MulticlassClassificationScorer), null, typeof(SignatureLoadDataTransform),
    "Multi-Class Classifier Scorer", MulticlassClassificationScorer.LoaderSignature)]

// Declares the nested label-name mapper as an ISchemaBindableMapper loadable through SignatureLoadModel.
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(MulticlassClassificationScorer.LabelNameBindableMapper), null, typeof(SignatureLoadModel),
    "Multi-Class Label-Name Mapper", MulticlassClassificationScorer.LabelNameBindableMapper.LoaderSignature)]
 
namespace Microsoft.ML.Data
{
    internal sealed class MulticlassClassificationScorer : PredictedLabelScorerBase
    {
        // REVIEW: consider outputting probabilities when multi-class classifiers distinguish
        // between scores and probabilities (using IDistributionPredictor)
        /// <summary>
        /// Arguments used to construct a <see cref="MulticlassClassificationScorer"/>.
        /// </summary>
        public sealed class Arguments : ScorerArgumentsBase
        {
            /// <summary>Name of the output score column (short name <c>scn</c>).</summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Score Column Name.", ShortName = "scn")]
            public string ScoreColumnName = AnnotationUtils.Const.ScoreValueKind.Score;

            /// <summary>Name of the output predicted-label column (short name <c>plcn</c>).</summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Predicted Label Column Name.", ShortName = "plcn")]
            public string PredictedLabelColumnName = DefaultColumnNames.PredictedLabel;
        }
 
        public const string LoaderSignature = "MultiClassScoreTrans";

        /// <summary>
        /// Version header written with and checked against serialized scorers.
        /// <c>verWrittenCur</c> history: 0x00010001 initial; 0x00010002 ISchemaBindableMapper;
        /// 0x00010003 ISchemaBindableMapper update (current, and the oldest version still readable).
        /// </summary>
        private static VersionInfo GetVersionInfo() =>
            new VersionInfo(
                modelSignature: "MULCLSCR",
                verWrittenCur: 0x00010003,
                verReadableCur: 0x00010003,
                verWeCanReadBack: 0x00010003,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(MulticlassClassificationScorer).Assembly.FullName);

        private const string RegistrationName = "MultiClassClassifierScore";
 
        /// <summary>
        /// This bindable mapper facilitates the serialization and rebinding of the special bound
        /// mapper that attaches the label metadata to the slot names of the output score column.
        /// </summary>
        // REVIEW: It seems like the attachment of metadata should be solvable in a manner
        // less ridiculously verbose than this.
        public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveModel, IBindableCanSavePfa,
            IBindableCanSaveOnnx, IDisposable
        {
            // Cached method info used to invoke DecodeInit<T> with a T chosen at runtime during deserialization.
            private static readonly FuncInstanceMethodInfo1<LabelNameBindableMapper, object, Delegate> _decodeInitMethodInfo
                = FuncInstanceMethodInfo1<LabelNameBindableMapper, object, Delegate>.Create(target => target.DecodeInit<int>);

            public const string LoaderSignature = "LabelSlotNameMapper";
            // Model sub-directory under which the inner bindable mapper is saved and loaded.
            private const string _innerDir = "InnerMapper";
            // The wrapped bindable mapper; binding and serialization are delegated to it.
            private readonly ISchemaBindableMapper _bindable;
            // Type of the label-name vector attached to the score column.
            private readonly VectorDataViewType _type;
            // Annotation kind under which the label names are attached (e.g. slot names or training label values).
            private readonly string _metadataKind;
            // In an ideal world this would be a value getter of the appropriate type. However, it is awkward
            // to have this class be generic due to restrictions on loadable classes, so we instead pay the
            // price of a handful of runtime casts.
            // REVIEW: Worth it to have this be abstract, with a nested generic implementation?
            // That seems like a bit much...
            private readonly Delegate _getter;
            private readonly IHost _host;
            // Optional predicate evaluated at bind time; when it returns false, Bind returns the inner bound mapper unwrapped.
            private readonly Func<ISchemaBoundMapper, DataViewType, bool> _canWrap;

            internal ISchemaBindableMapper Bindable => _bindable;

            public VectorDataViewType Type => _type;
            // Export capabilities are taken from the wrapped mapper; false when it does not support the format.
            bool ICanSavePfa.CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;
            bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
 
            /// <summary>
            /// Version header for the serialized label-name mapper.
            /// <c>verWrittenCur</c> history: 0x00010001 initial; 0x00010002 added the metadata kind.
            /// Models written at 0x00010001 remain readable.
            /// </summary>
            private static VersionInfo GetVersionInfo() =>
                new VersionInfo(
                    modelSignature: "LABNAMBM",
                    verWrittenCur: 0x00010002,
                    verReadableCur: 0x00010002,
                    verWeCanReadBack: 0x00010001,
                    loaderSignature: LoaderSignature,
                    loaderAssemblyName: typeof(LabelNameBindableMapper).Assembly.FullName);

            /// <summary>First model version that serializes the annotation kind.</summary>
            private const int VersionAddedMetadataKind = 0x00010002;
 
            /// <summary>
            /// Creates a bindable mapper from a bound one by wrapping the bound mapper's own bindable.
            /// </summary>
            private LabelNameBindableMapper(IHostEnvironment env, ISchemaBoundMapper mapper, VectorDataViewType type, Delegate getter,
                string metadataKind, Func<ISchemaBoundMapper, DataViewType, bool> canWrap)
                : this(env, bindable: mapper.Bindable, type: type, getter: getter, metadataKind: metadataKind, canWrap: canWrap)
            {
            }

            /// <summary>
            /// Primary constructor. All arguments are required except <paramref name="canWrap"/>, which may be null.
            /// </summary>
            private LabelNameBindableMapper(IHostEnvironment env, ISchemaBindableMapper bindable, VectorDataViewType type, Delegate getter,
                string metadataKind, Func<ISchemaBoundMapper, DataViewType, bool> canWrap)
            {
                Contracts.AssertValue(env);
                _host = env.Register(LoaderSignature);

                // Debug-only sanity checks on the inputs.
                _host.AssertValue(bindable);
                _host.AssertValue(type);
                _host.AssertValue(getter);
                _host.AssertNonEmpty(metadataKind);
                _host.AssertValueOrNull(canWrap);

                _canWrap = canWrap;
                _metadataKind = metadataKind;
                _getter = getter;
                _type = type;
                _bindable = bindable;
            }
 
            /// <summary>
            /// Deserialization constructor. It reads, in order:
            /// <list type="number">
            /// <item>the inner bindable mapper from the <c>InnerMapper</c> sub-directory;</item>
            /// <item>the label-name type and value, as written by <see cref="BinarySaver"/>;</item>
            /// <item>for models at or above <see cref="VersionAddedMetadataKind"/>, the annotation kind.</item>
            /// </list>
            /// </summary>
            private LabelNameBindableMapper(IHost host, ModelLoadContext ctx)
            {
                Contracts.AssertValue(host);
                _host = host;
                _host.AssertValue(ctx);

                ctx.LoadModel<ISchemaBindableMapper, SignatureLoadModel>(_host, out _bindable, _innerDir);
                BinarySaver saver = new BinarySaver(_host, new BinarySaver.Arguments());
                DataViewType type;
                object value;
                // The stored label names must decode to a vector type with a non-null value.
                _host.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value));
                _type = type as VectorDataViewType;
                _host.CheckDecode(_type != null);
                _host.CheckDecode(value != null);
                // Rebuild a strongly typed getter over the decoded buffer, dispatching on the runtime item type.
                _getter = Utils.MarshalInvoke(_decodeInitMethodInfo, this, _type.ItemType.RawType, value);
                // Models older than VersionAddedMetadataKind did not store the kind; default to slot names for them.
                _metadataKind = ctx.Header.ModelVerReadable >= VersionAddedMetadataKind ?
                    ctx.LoadNonEmptyString() : AnnotationUtils.Kinds.SlotNames;
            }
 
            /// <summary>
            /// Converts a deserialized label-name value into a <c>ValueGetter&lt;VBuffer&lt;T&gt;&gt;</c>. Each call
            /// copies the stored buffer into the destination. This is invoked through <c>_decodeInitMethodInfo</c>,
            /// with <typeparamref name="T"/> set to the runtime item type.
            /// </summary>
            private Delegate DecodeInit<T>(object value)
            {
                _host.CheckDecode(value is VBuffer<T>);
                var labelNames = (VBuffer<T>)value;
                return (ValueGetter<VBuffer<T>>)((ref VBuffer<T> dst) => labelNames.CopyTo(ref dst));
            }
 
            /// <summary>
            /// Factory matching <see cref="SignatureLoadModel"/>. It validates the arguments and the version header,
            /// then deserializes a <see cref="LabelNameBindableMapper"/> from <paramref name="ctx"/>.
            /// </summary>
            private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                IHost host = env.Register(LoaderSignature);
                host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // byte[]: A chunk of data saving both the type and value of the label names, as saved by the BinarySaver.
                // int: string id of the metadata kind
                return host.Apply("Loading Model", ch => new LabelNameBindableMapper(host, ctx));
            }
 
            /// <summary>
            /// Serializes this mapper in three parts: the inner bindable mapper (under the <c>InnerMapper</c>
            /// sub-directory), the label names, and the annotation kind. The write order must match the
            /// order read by the deserialization constructor.
            /// </summary>
            void ICanSaveModel.Save(ModelSaveContext ctx)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();
                ctx.SetVersionInfo(GetVersionInfo());

                // *** Binary format ***
                // byte[]: A chunk of data saving both the type and value of the label names, as saved by the BinarySaver.
                // int: string id of the metadata kind

                ctx.SaveModel(_bindable, _innerDir);
                // Dispatch SaveCore<T> on the runtime item type of the label-name vector.
                Utils.MarshalActionInvoke(SaveCore<int>, _type.ItemType.RawType, ctx);
                ctx.SaveNonEmptyString(_metadataKind);
            }
 
            /// <summary>
            /// Writes the label names (type and value) to the model stream with <see cref="BinarySaver"/>.
            /// Throws if <see cref="BinarySaver"/> cannot serialize the item type.
            /// </summary>
            private void SaveCore<T>(ModelSaveContext ctx)
            {
                Contracts.Assert(_type.ItemType.RawType == typeof(T));
                Contracts.Assert(_getter is ValueGetter<VBuffer<T>>);

                VBuffer<T> labelNames = default;
                ((ValueGetter<VBuffer<T>>)_getter)(ref labelNames);

                var saver = new BinarySaver(_host, new BinarySaver.Arguments());
                bool written = saver.TryWriteTypeAndValue<VBuffer<T>>(ctx.Writer.BaseStream, _type, ref labelNames, out _);
                if (!written)
                    throw _host.Except("We do not know how to serialize label names of type '{0}'", _type.ItemType);
            }
 
            /// <summary>
            /// Returns a copy of this mapper that wraps <paramref name="inner"/> instead of the current inner mapper.
            /// The label names, annotation kind and wrapping predicate are unchanged.
            /// </summary>
            internal ISchemaBindableMapper Clone(ISchemaBindableMapper inner)
                => new LabelNameBindableMapper(_host, inner, _type, _getter, _metadataKind, _canWrap);
 
            /// <summary>
            /// Delegates PFA export to the wrapped mapper. The label names only affect schema annotations.
            /// </summary>
            void IBindableCanSavePfa.SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] outputNames)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                Contracts.CheckValue(schema, nameof(schema));
                bool canSave = ((ICanSavePfa)this).CanSavePfa;
                Contracts.Check(canSave, "Cannot be saved as PFA");

                Contracts.Assert(_bindable is IBindableCanSavePfa);
                var inner = (IBindableCanSavePfa)_bindable;
                inner.SaveAsPfa(ctx, schema, outputNames);
            }

            /// <summary>
            /// Delegates ONNX export to the wrapped mapper and returns its result.
            /// </summary>
            bool IBindableCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                Contracts.CheckValue(schema, nameof(schema));
                bool canSave = ((ICanSaveOnnx)this).CanSaveOnnx(ctx);
                Contracts.Check(canSave, "Cannot be saved as ONNX.");

                Contracts.Assert(_bindable is IBindableCanSaveOnnx);
                var inner = (IBindableCanSaveOnnx)_bindable;
                return inner.SaveAsOnnx(ctx, schema, outputNames);
            }
 
            /// <summary>
            /// Binds the inner mapper to <paramref name="schema"/> and wraps the result with label-name annotations.
            /// If the wrapping predicate rejects the freshly bound mapper, that mapper is returned unwrapped.
            /// </summary>
            ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
            {
                ISchemaBoundMapper bound = _bindable.Bind(env, schema);

                bool shouldWrap = _canWrap == null || _canWrap(bound, _type);
                if (!shouldWrap)
                    return bound;

                Contracts.Assert(bound is ISchemaBoundRowMapper);
                var rowMapper = (ISchemaBoundRowMapper)bound;
                return Utils.MarshalInvoke(CreateBound<int>, _type.ItemType.RawType, env, rowMapper, _type, _getter, _metadataKind, _canWrap);
            }
 
            /// <summary>
            /// Wraps <paramref name="mapper"/> so that its score column carries the label names returned by
            /// <paramref name="getter"/> as an annotation of kind <paramref name="metadataKind"/>.
            /// <paramref name="getter"/> must be a <c>ValueGetter&lt;VBuffer&lt;T&gt;&gt;</c>.
            /// </summary>
            internal static ISchemaBoundMapper CreateBound<T>(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorDataViewType type, Delegate getter,
                string metadataKind, Func<ISchemaBoundMapper, DataViewType, bool> canWrap)
            {
                Contracts.AssertValue(env);
                env.AssertValue(mapper);
                env.AssertValue(type);
                env.AssertValue(getter);
                env.Assert(getter is ValueGetter<VBuffer<T>>);
                env.AssertNonEmpty(metadataKind);
                env.AssertValueOrNull(canWrap);

                var typedGetter = (ValueGetter<VBuffer<T>>)getter;
                return new Bound<T>(env, mapper, type, typedGetter, metadataKind, canWrap);
            }
 
            /// <summary>
            /// Bound counterpart of <see cref="LabelNameBindableMapper"/>. Row computation is delegated entirely to the
            /// wrapped <see cref="ISchemaBoundRowMapper"/>. The only difference is the output schema: there, the score
            /// column additionally carries the label names as an annotation.
            /// </summary>
            /// <typeparam name="T">Item type of the label-name vector.</typeparam>
            private sealed class Bound<T> : ISchemaBoundRowMapper
            {
                private readonly IHost _host;
                /// <summary>The mapper we are wrapping.</summary>
                private readonly ISchemaBoundRowMapper _mapper;
                private readonly VectorDataViewType _labelNameType;
                private readonly string _metadataKind;
                private readonly ValueGetter<VBuffer<T>> _labelNameGetter;
                // Lazily initialized by the Bindable property.
                private LabelNameBindableMapper _bindable;
                private readonly Func<ISchemaBoundMapper, DataViewType, bool> _canWrap;

                public RoleMappedSchema InputRoleMappedSchema => _mapper.InputRoleMappedSchema;
                public DataViewSchema InputSchema => _mapper.InputSchema;
                public DataViewSchema OutputSchema { get; }

                /// <summary>
                /// The bindable mapper corresponding to this bound mapper, created on first access. If several threads
                /// race, the instance published first by <c>Interlocked.CompareExchange</c> is returned to all of them.
                /// </summary>
                public ISchemaBindableMapper Bindable
                {
                    get
                    {
                        return _bindable ??
                            Interlocked.CompareExchange(ref _bindable,
                                new LabelNameBindableMapper(_host, _mapper, _labelNameType, _labelNameGetter, _metadataKind, _canWrap), null) ??
                            _bindable;
                    }
                }

                /// <summary>
                /// This is the constructor called for the initial wrapping. It throws if <paramref name="mapper"/>
                /// does not output a score column.
                /// </summary>
                public Bound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorDataViewType type, ValueGetter<VBuffer<T>> getter,
                    string metadataKind, Func<ISchemaBoundMapper, DataViewType, bool> canWrap)
                {
                    Contracts.CheckValue(env, nameof(env));
                    _host = env.Register(LoaderSignature);
                    _host.CheckValue(mapper, nameof(mapper));
                    _host.CheckValue(type, nameof(type));
                    _host.CheckValue(getter, nameof(getter));
                    _host.CheckNonEmpty(metadataKind, nameof(metadataKind));
                    _host.CheckValueOrNull(canWrap);

                    _mapper = mapper;

                    // The label names are attached to the score column, so the wrapped mapper must produce one.
                    // Report the failure through the registered host, like the checks above.
                    if (!mapper.OutputSchema.TryGetColumnIndex(AnnotationUtils.Const.ScoreValueKind.Score, out int scoreIdx))
                        throw _host.ExceptParam(nameof(mapper), "Mapper did not have a '{0}' column", AnnotationUtils.Const.ScoreValueKind.Score);

                    _labelNameType = type;
                    _labelNameGetter = getter;
                    _metadataKind = metadataKind;
                    _canWrap = canWrap;

                    OutputSchema = DecorateOutputSchema(mapper.OutputSchema, scoreIdx, _labelNameType, _labelNameGetter, _metadataKind);
                }

                /// <summary>
                /// Append label names to score column as its metadata. The schema is rebuilt column by column,
                /// because annotations are read-only. The result has the same column order as
                /// <paramref name="partialSchema"/>. For the score column, any existing annotation of kind
                /// <paramref name="labelNameKind"/> is replaced.
                /// </summary>
                private static DataViewSchema DecorateOutputSchema(DataViewSchema partialSchema, int scoreColumnIndex, VectorDataViewType labelNameType,
                    ValueGetter<VBuffer<T>> labelNameGetter, string labelNameKind)
                {
                    var builder = new DataViewSchema.Builder();
                    // Sequentially add columns so that the order of them is not changed comparing with the schema in the mapper
                    // that computes score column.
                    for (int i = 0; i < partialSchema.Count; ++i)
                    {
                        var meta = new DataViewSchema.Annotations.Builder();
                        if (i == scoreColumnIndex)
                        {
                            // Add label names for score column.
                            meta.Add(partialSchema[i].Annotations, selector: s => s != labelNameKind);
                            meta.Add(labelNameKind, labelNameType, labelNameGetter);
                        }
                        else
                        {
                            // Copy all existing metadata because this transform only affects score column.
                            meta.Add(partialSchema[i].Annotations, selector: s => true);
                        }
                        // Instead of appending extra metadata to the existing score column, we create new one because
                        // metadata is read-only.
                        builder.AddColumn(partialSchema[i].Name, partialSchema[i].Type, meta.ToAnnotations());
                    }
                    return builder.ToSchema();
                }

                /// <summary>
                /// Given a set of columns, return the input columns that are needed to generate those output columns.
                /// </summary>
                IEnumerable<DataViewSchema.Column> ISchemaBoundRowMapper.GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns)
                    => _mapper.GetDependenciesForNewColumns(dependingColumns);

                public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles() => _mapper.GetInputColumnRoles();

                DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
                {
                    var innerRow = _mapper.GetRow(input, activeColumns);
                    return new RowImpl(innerRow, OutputSchema);
                }

                /// <summary>
                /// A row that forwards everything to the wrapped row, but reports the decorated schema as its own.
                /// </summary>
                private sealed class RowImpl : WrappingRow
                {
                    private readonly DataViewSchema _schema;

                    // The schema is of course the only difference from the wrapped row.
                    public override DataViewSchema Schema => _schema;

                    public RowImpl(DataViewRow row, DataViewSchema schema)
                        : base(row)
                    {
                        Contracts.AssertValue(row);
                        Contracts.AssertValue(schema);

                        _schema = schema;
                    }

                    /// <summary>
                    /// Returns whether the given column is active, as reported by the wrapped row.
                    /// </summary>
                    public override bool IsColumnActive(DataViewSchema.Column column) => Input.IsColumnActive(column);

                    /// <summary>
                    /// Returns the value getter for <paramref name="column"/>, obtained from the wrapped row.
                    /// </summary>
                    /// <typeparam name="TValue">The column's content type.</typeparam>
                    /// <param name="column">The output column whose getter should be returned.</param>
                    public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => Input.GetGetter<TValue>(column);
                }
            }
 
            #region IDisposable Support
            private bool _disposed;

            /// <summary>
            /// Disposes the wrapped bindable mapper if it implements <see cref="IDisposable"/>.
            /// Calls made after a successful dispose do nothing.
            /// </summary>
            public void Dispose()
            {
                // This class has no base class, so there is no base Dispose to chain to.
                if (!_disposed)
                {
                    (_bindable as IDisposable)?.Dispose();
                    _disposed = true;
                }
            }
            #endregion
        }
 
        /// <summary>
        /// This function performs a number of checks on the inputs and, if appropriate and possible, will produce
        /// a mapper with slot names on the output score column properly mapped. If this is not possible for any
        /// reason, it will just return the input bound mapper.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="mapper">The bound mapper that may be wrapped.</param>
        /// <param name="trainSchema">The training schema. It may be <c>null</c>, in which case no wrapping happens.</param>
        /// <returns>Either a wrapper that annotates the score column with the training label's key values, or
        /// <paramref name="mapper"/> itself.</returns>
        private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(mapper, nameof(mapper));
            env.CheckValueOrNull(trainSchema);

            // The idea is that we will take the key values from the train schema label, and present
            // them as slot name metadata. But there are a number of conditions for this to actually
            // happen, so we test those here. If any of them is not met, the mapper is returned unchanged.

            if (trainSchema?.Label == null)
                return mapper; // We don't even have a label identified in a training schema.

            // Only row mappers can be wrapped. WrapCore asserts this, and GetTypesForWrapping rejects any other
            // mapper anyway, so bail out here rather than trip the debug assertion.
            if (!(mapper is ISchemaBoundRowMapper))
                return mapper;

            var keyType = trainSchema.Label.Value.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
            if (keyType == null)
                return mapper; // The training label has no vector of key values to use as names.

            // Great!! All checks pass.
            return Utils.MarshalInvoke(WrapCore<int>, keyType.ItemType.RawType, env, mapper, trainSchema);
        }
 
        /// <summary>
        /// Decides whether <see cref="LabelNameBindableMapper"/> can or should wrap <paramref name="mapper"/> to attach
        /// training label values to its score column. This does not throw: when wrapping is not possible, the
        /// caller keeps the original "unwrapped" bound mapper.
        /// </summary>
        /// <param name="mapper">The mapper we are seeing if we can wrap</param>
        /// <param name="labelNameType">The type of the label names from the metadata (either
        /// originating from the key value metadata of the training label column, or deserialized
        /// from the model of a bindable mapper)</param>
        /// <returns>Returns <c>true</c> when all of these hold:
        /// <list type="bullet">
        /// <item>the score column has no training-label-values annotation yet;</item>
        /// <item><paramref name="labelNameType"/> is a vector;</item>
        /// <item>its size matches the score vector.</item>
        /// </list>
        /// In that case <see cref="LabelNameBindableMapper.CreateBound{T}"/> is expected to succeed.</returns>
        internal static bool CanWrapTrainingLabels(ISchemaBoundMapper mapper, DataViewType labelNameType)
        {
            if (!GetTypesForWrapping(mapper, labelNameType, AnnotationUtils.Kinds.TrainingLabelValues, out var scoreType))
                return false;
            if (!(labelNameType is VectorDataViewType labelVectorType))
                return false;
            // One label value per score slot.
            return labelVectorType.Size == scoreType.GetVectorSize();
        }
 
        /// <summary>
        /// Shared precondition for <see cref="CanWrapTrainingLabels"/> and <see cref="CanWrapSlotNames"/>. It succeeds
        /// when <paramref name="mapper"/> is a row mapper that outputs a score column, and that column does not already
        /// carry an annotation of kind <paramref name="metaKind"/>.
        /// </summary>
        /// <param name="mapper">The bound mapper to inspect.</param>
        /// <param name="labelNameType">The label-name type; it is only asserted to be non-null here.</param>
        /// <param name="metaKind">The annotation kind the wrapper would add to the score column.</param>
        /// <param name="scoreType">On success, the type of the score column; otherwise <c>null</c>.</param>
        /// <returns>Whether the mapper is eligible for wrapping.</returns>
        internal static bool GetTypesForWrapping(ISchemaBoundMapper mapper, DataViewType labelNameType, string metaKind, out DataViewType scoreType)
        {
            Contracts.AssertValue(mapper);
            Contracts.AssertValue(labelNameType);
            scoreType = null;

            if (!(mapper is ISchemaBoundRowMapper))
                return false; // We could cover this case, but it is of no practical worth as far as I see, so I decline to do so.

            // A single lookup gives both the existence check and the column itself.
            var scoreCol = mapper.OutputSchema.GetColumnOrNull(AnnotationUtils.Const.ScoreValueKind.Score);
            if (!scoreCol.HasValue)
                return false; // The mapper doesn't even publish a score column to attach the metadata to.
            if (scoreCol.Value.Annotations.Schema.GetColumnOrNull(metaKind)?.Type != null)
                return false; // The mapper publishes a score column, and already produces its own metakind.
            scoreType = scoreCol.Value.Type;
            return true;
        }
 
        /// <summary>
        /// Decides whether <see cref="LabelNameBindableMapper"/> can or should wrap <paramref name="mapper"/> to attach
        /// slot names to its score column. This does not throw: when wrapping is not possible, the caller keeps the
        /// original "unwrapped" bound mapper.
        /// </summary>
        /// <param name="mapper">The mapper we are seeing if we can wrap</param>
        /// <param name="labelNameType">The type of the label names from the metadata (either
        /// originating from the key value metadata of the training label column, or deserialized
        /// from the model of a bindable mapper)</param>
        /// <returns>Returns <c>true</c> when all of these hold:
        /// <list type="bullet">
        /// <item>the score column has no slot names yet;</item>
        /// <item><paramref name="labelNameType"/> is a text vector;</item>
        /// <item>its size matches the score vector.</item>
        /// </list>
        /// In that case <see cref="LabelNameBindableMapper.CreateBound{T}"/> is expected to succeed.</returns>
        internal static bool CanWrapSlotNames(ISchemaBoundMapper mapper, DataViewType labelNameType)
        {
            if (!GetTypesForWrapping(mapper, labelNameType, AnnotationUtils.Kinds.SlotNames, out var scoreType))
                return false;
            if (!(labelNameType is VectorDataViewType labelVectorType))
                return false;
            // Slot names must be text, one per score slot.
            return labelVectorType.Size == scoreType.GetVectorSize() && labelVectorType.ItemType == TextDataViewType.Instance;
        }
 
        /// <summary>
        /// Wraps <paramref name="mapper"/> so that its score column carries the key values of the training label.
        /// They are attached first as training-label values and then as slot names, each only where the
        /// corresponding <c>CanWrap*</c> check allows it.
        /// </summary>
        /// <typeparam name="T">Item type of the label's key-value vector.</typeparam>
        internal static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
        {
            Contracts.AssertValue(env);
            env.AssertValue(mapper);
            env.AssertValue(trainSchema);
            env.Assert(mapper is ISchemaBoundRowMapper);

            // Key values from the training schema label will map to annotations of the score output.
            DataViewType keyValuesType = trainSchema.Label.Value.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
            env.AssertValue(keyValuesType);
            env.Assert(keyValuesType is VectorDataViewType);
            var keyValuesVectorType = keyValuesType as VectorDataViewType;

            // Expose the label's key values through a plain getter.
            ValueGetter<VBuffer<T>> getter = (ref VBuffer<T> value) => trainSchema.Label.Value.GetKeyValues(ref value);

            ISchemaBoundMapper result = mapper;
            if (CanWrapTrainingLabels(result, keyValuesType))
                result = LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)result, keyValuesVectorType, getter, AnnotationUtils.Kinds.TrainingLabelValues, CanWrapTrainingLabels);
            if (CanWrapSlotNames(result, keyValuesType))
                result = LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)result, keyValuesVectorType, getter, AnnotationUtils.Kinds.SlotNames, CanWrapSlotNames);
            return result;
        }
 
        /// <summary>
        /// Creates a scorer over <paramref name="data"/>. Before being handed to the base class,
        /// <paramref name="mapper"/> is passed through <see cref="WrapIfNeeded"/>, which may annotate the score
        /// column with the key values of the training label from <paramref name="trainSchema"/>.
        /// </summary>
        [BestFriend]
        internal MulticlassClassificationScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
            : base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification,
                args.ScoreColumnName, OutputTypeMatches, GetPredColType, args.PredictedLabelColumnName)
        {
        }

        // Re-applies an existing scorer to a new source (used by ApplyToDataCore).
        private MulticlassClassificationScorer(IHostEnvironment env, MulticlassClassificationScorer transform, IDataView newSource)
            : base(env, transform, newSource, RegistrationName)
        {
        }

        // Deserialization constructor (used by Create); all serialized state is handled by the base class.
        private MulticlassClassificationScorer(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, ctx, input, OutputTypeMatches, GetPredColType)
        {
            // *** Binary format ***
            // <base info>
        }
 
        /// <summary>
        /// Loader entry point matching <see cref="SignatureLoadDataTransform"/>. It validates the arguments and the
        /// version header, then deserializes the scorer on top of <paramref name="input"/>.
        /// </summary>
        private static MulticlassClassificationScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            IHost host = env.Register(RegistrationName);
            host.CheckValue(ctx, nameof(ctx));
            host.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            return host.Apply("Loading Model", ch => new MulticlassClassificationScorer(host, ctx, input));
        }
 
        /// <summary>
        /// Writes this scorer's version header. All remaining state is written by the base class.
        /// </summary>
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            Contracts.AssertValue(ctx);
            // The version info must be set before the base class writes its data.
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // <base info>

            base.SaveCore(ctx);
        }
 
        /// <summary>
        /// Returns a copy of this scorer applied to <paramref name="newSource"/>.
        /// </summary>
        private protected override IDataTransform ApplyToDataCore(IHostEnvironment env, IDataView newSource)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValue(newSource, nameof(newSource));

            var reapplied = new MulticlassClassificationScorer(env, this, newSource);
            return reapplied;
        }
 
        /// <summary>
        /// Builds the getters for the predicted label (returned) and the score (<paramref name="scoreGetter"/>).
        /// <list type="bullet">
        /// <item>Both getters share one score buffer, filled through <c>EnsureCachedPosition</c>. Judging by its
        /// <c>cachedPosition</c> argument, it is expected to refetch only when the row changes.</item>
        /// <item>Both getters throw if the score length differs from the predicted label's key count.</item>
        /// <item>The predicted label is the 1-based index of the maximum score, or 0 when
        /// <c>VectorUtils.ArgMax</c> returns a negative index.</item>
        /// </list>
        /// </summary>
        protected override Delegate GetPredictedLabelGetter(DataViewRow output, out Delegate scoreGetter)
        {
            Host.AssertValue(output);
            Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema);
            Host.Assert(output.IsColumnActive(output.Schema[Bindings.ScoreColumnIndex]));

            var scoreColumn = Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex];
            ValueGetter<VBuffer<float>> mapperScoreGetter = output.GetGetter<VBuffer<float>>(scoreColumn);

            long cachedPosition = -1;
            VBuffer<float> score = default;
            int expectedLength = Bindings.PredColType.GetKeyCountAsInt32(Host);

            // Brings the shared score buffer up to date for the current row and validates its length.
            void RefreshScore()
            {
                EnsureCachedPosition(ref cachedPosition, ref score, output, mapperScoreGetter);
                Host.Check(score.Length == expectedLength);
            }

            ValueGetter<uint> predFn = (ref uint dst) =>
            {
                RefreshScore();
                int argMax = VectorUtils.ArgMax(in score);
                dst = argMax < 0 ? 0u : (uint)argMax + 1;
            };
            ValueGetter<VBuffer<float>> scoreFn = (ref VBuffer<float> dst) =>
            {
                RefreshScore();
                score.CopyTo(ref dst);
            };

            scoreGetter = scoreFn;
            return predFn;
        }
 
        /// <summary>
        /// PFA expression for the predicted label: <c>a.argmax</c> applied to the single mapper output.
        /// </summary>
        private protected override JToken PredictedLabelPfa(string[] mapperOutputs)
        {
            Contracts.Assert(Utils.Size(mapperOutputs) == 1);
            string scoreOutput = mapperOutputs[0];
            return PfaUtils.Call("a.argmax", scoreOutput);
        }
 
        /// <summary>
        /// The predicted label is a <see cref="uint"/> key whose cardinality equals the length of the score vector.
        /// </summary>
        private static DataViewType GetPredColType(DataViewType scoreType, ISchemaBoundRowMapper mapper)
        {
            int classCount = scoreType.GetVectorSize();
            return new KeyDataViewType(typeof(uint), classCount);
        }

        /// <summary>
        /// Multiclass scores must be a vector of <see cref="float"/> with a known size.
        /// </summary>
        private static bool OutputTypeMatches(DataViewType scoreType)
        {
            if (!(scoreType is VectorDataViewType vectorType))
                return false;
            return vectorType.IsKnownSize && vectorType.ItemType == NumberDataViewType.Single;
        }
    }
}