File: Scorers\GenericScorer.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
 
[assembly: LoadableClass(typeof(GenericScorer), typeof(GenericScorer.Arguments), typeof(SignatureDataScorer),
    "Generic Scorer", GenericScorer.LoadName, "Generic")]
 
[assembly: LoadableClass(typeof(GenericScorer), null, typeof(SignatureLoadDataTransform),
    "Generic Scorer", GenericScorer.LoaderSignature)]
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// This class is a scorer that passes through all the ISchemaBound columns without adding any "derived columns".
    /// It also passes through all metadata (except for possibly changing the score column kind), and adds the
    /// score set id metadata.
    /// </summary>
 
    [BestFriend]
    internal sealed class GenericScorer : RowToRowScorerBase, ITransformCanSavePfa, ITransformCanSaveOnnx
    {
        public const string LoadName = "GenericScorer";
 
        public sealed class Arguments : ScorerArgumentsBase
        {
        }
 
        private sealed class Bindings : BindingsBase
        {
            /// <summary>
            /// The one and only constructor for Bindings.
            /// </summary>
            private Bindings(DataViewSchema input, ISchemaBoundRowMapper mapper, string suffix, bool user)
                : base(input, mapper, suffix, user)
            {
                Contracts.Assert(DerivedColumnCount == 0);
            }
 
            /// <summary>
            /// Create the bindings given the input schema, bound mapper, and column name suffix.
            /// </summary>
            public static Bindings Create(DataViewSchema input, ISchemaBoundRowMapper mapper, string suffix, bool user = true)
            {
                Contracts.AssertValue(input);
                Contracts.AssertValue(mapper);
                Contracts.AssertValueOrNull(suffix);
                // We don't actually depend on this invariant, but if this assert fires it means the bindable
                // did the wrong thing.
                Contracts.Assert(mapper.InputSchema == input);
 
                return new Bindings(input, mapper, suffix, user);
            }
 
            /// <summary>
            /// Create the bindings given the env, bindable, input schema, column roles, and column name suffix.
            /// </summary>
            private static Bindings Create(IHostEnvironment env, ISchemaBindableMapper bindable, DataViewSchema input,
                IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> roles, string suffix, bool user = true)
            {
                Contracts.AssertValue(env);
                Contracts.AssertValue(bindable);
                Contracts.AssertValue(input);
                Contracts.AssertValue(roles);
                Contracts.AssertValueOrNull(suffix);
 
                var mapper = bindable.Bind(env, new RoleMappedSchema(input, roles));
                // We don't actually depend on this invariant, but if this assert fires it means the bindable
                // did the wrong thing.
                Contracts.Assert(mapper.InputRoleMappedSchema.Schema == input);
 
                var rowMapper = mapper as ISchemaBoundRowMapper;
                Contracts.Check(rowMapper != null, "Predictor expected to be a RowMapper!");
 
                return Create(input, rowMapper, suffix, user);
            }
 
            /// <summary>
            /// Create a new Bindings from this one, but based on a potentially different schema.
            /// Used by the ITransformTemplate.ApplyToData implementation.
            /// </summary>
            public Bindings ApplyToSchema(IHostEnvironment env, DataViewSchema input)
            {
                Contracts.AssertValue(input);
                Contracts.AssertValue(env);
 
                var bindable = RowMapper.Bindable;
                var roles = RowMapper.GetInputColumnRoles();
                string suffix = Suffix;
                return Create(env, bindable, input, roles, suffix);
            }
 
            /// <summary>
            /// Deserialize the bindings, given the env, bindable and input schema.
            /// </summary>
            public static Bindings Create(ModelLoadContext ctx,
                IHostEnvironment env, ISchemaBindableMapper bindable, DataViewSchema input)
            {
                Contracts.AssertValue(ctx);
 
                // *** Binary format ***
                // <base info>
                string suffix;
                var roles = LoadBaseInfo(ctx, out suffix);
 
                return Create(env, bindable, input, roles, suffix, user: false);
            }
 
            internal override void SaveModel(ModelSaveContext ctx)
            {
                Contracts.AssertValue(ctx);
 
                // *** Binary format ***
                // <base info>
                SaveBase(ctx);
            }
        }
 
        public const string LoaderSignature = "GenericScoreTransform";
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "GNRICSCR",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(GenericScorer).Assembly.FullName);
        }
 
        private const string RegistrationName = "GenericScore";
 
        private readonly Bindings _bindings;
        private protected override BindingsBase GetBindings() => _bindings;
 
        public override DataViewSchema OutputSchema { get; }
 
        bool ICanSavePfa.CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;
 
        bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
 
        /// <summary>
        /// The <see cref="SignatureDataScorer"/> entry point for creating a <see cref="GenericScorer"/>.
        /// </summary>
        [BestFriend]
        internal GenericScorer(IHostEnvironment env, ScorerArgumentsBase args, IDataView data,
            ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
            : base(env, data, RegistrationName, Contracts.CheckRef(mapper, nameof(mapper)).Bindable)
        {
            Host.CheckValue(args, nameof(args));
            Host.AssertValue(data, "data");
            Host.AssertValue(mapper, "mapper");
 
            var rowMapper = mapper as ISchemaBoundRowMapper;
            Host.CheckParam(rowMapper != null, nameof(mapper), "mapper should implement ISchemaBoundRowMapper");
            _bindings = Bindings.Create(data.Schema, rowMapper, args.Suffix);
            OutputSchema = _bindings.AsSchema;
        }
 
        /// <summary>
        /// Constructor for <see cref="ApplyToDataCore"/> method.
        /// </summary>
        private GenericScorer(IHostEnvironment env, GenericScorer transform, IDataView data)
            : base(env, data, RegistrationName, transform.Bindable)
        {
            _bindings = transform._bindings.ApplyToSchema(env, data.Schema);
            OutputSchema = _bindings.AsSchema;
        }
 
        /// <summary>
        /// Constructor for deserialization.
        /// </summary>
        private GenericScorer(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, ctx, input)
        {
            Contracts.AssertValue(ctx);
            _bindings = Bindings.Create(ctx, host, Bindable, input.Schema);
            OutputSchema = _bindings.AsSchema;
        }
 
        /// <summary>
        /// <see cref="SignatureLoadDataTransform"/> entry point - for deserialization.
        /// </summary>
        public static GenericScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(RegistrationName);
 
            h.CheckValue(ctx, nameof(ctx));
            h.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());
 
            return h.Apply("Loading Model", ch => new GenericScorer(h, ctx, input));
        }
 
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            Contracts.AssertValue(ctx);
            ctx.SetVersionInfo(GetVersionInfo());
            _bindings.SaveModel(ctx);
        }
 
        void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSavePfa);
            var pfaBindable = (IBindableCanSavePfa)Bindable;
 
            var schema = _bindings.RowMapper.InputRoleMappedSchema;
            Host.Assert(_bindings.DerivedColumnCount == 0);
            string[] outColNames = new string[_bindings.InfoCount];
            for (int iinfo = 0; iinfo < _bindings.InfoCount; ++iinfo)
                outColNames[iinfo] = _bindings.GetColumnName(_bindings.MapIinfoToCol(iinfo));
 
            pfaBindable.SaveAsPfa(ctx, schema, outColNames);
        }
 
        void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            var onnxBindable = (IBindableCanSaveOnnx)Bindable;
 
            var schema = _bindings.RowMapper.InputRoleMappedSchema;
            Host.Assert(_bindings.DerivedColumnCount == 0);
            string[] outVariableNames = new string[_bindings.InfoCount];
            for (int iinfo = 0; iinfo < _bindings.InfoCount; ++iinfo)
            {
                int colIndex = _bindings.MapIinfoToCol(iinfo);
                string colName = _bindings.GetColumnName(colIndex);
                colName = ctx.AddIntermediateVariable(_bindings.GetColumnType(colIndex), colName);
                outVariableNames[iinfo] = colName;
            }
 
            if (!onnxBindable.SaveAsOnnx(ctx, schema, outVariableNames))
            {
                foreach (var name in outVariableNames)
                    ctx.RemoveVariable(name, true);
            }
        }
 
        protected override bool WantParallelCursors(Func<int, bool> predicate)
        {
            Host.AssertValue(predicate, "predicate");
 
            // Prefer parallel cursors iff some of our columns are active, otherwise, don't care.
            return _bindings.AnyNewColumnsActive(predicate);
        }
 
        private protected override IDataTransform ApplyToDataCore(IHostEnvironment env, IDataView newSource)
        {
            Host.CheckValue(env, nameof(env));
            Host.CheckValue(newSource, nameof(newSource));
 
            return new GenericScorer(env, this, newSource);
        }
 
        protected override Delegate[] GetGetters(DataViewRow output, Func<int, bool> predicate)
        {
            Host.Assert(_bindings.DerivedColumnCount == 0);
            Host.AssertValue(output);
            Host.AssertValue(predicate);
            Host.Assert(output.Schema == _bindings.RowMapper.OutputSchema);
 
            return GetGettersFromRow(output, predicate);
        }
    }
}