File: Scorers\FeatureContributionCalculation.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
 
[assembly: LoadableClass(typeof(IDataScorerTransform), typeof(FeatureContributionScorer), typeof(FeatureContributionScorer.Arguments),
    typeof(SignatureDataScorer), "Feature Contribution Scorer", "fcc", "fct", "FeatureContributionCalculationScorer", AnnotationUtils.Const.ScoreColumnKind.FeatureContribution)]
 
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionScorer), typeof(FeatureContributionScorer.Arguments),
    typeof(SignatureBindableMapper), "Feature Contribution Mapper", "fcc", "fct", AnnotationUtils.Const.ScoreColumnKind.FeatureContribution)]
 
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionScorer), null, typeof(SignatureLoadModel),
    "Feature Contribution Mapper", FeatureContributionScorer.MapperLoaderSignature)]
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Used only by the command line API for scoring and calculation of feature contribution.
    /// </summary>
    internal sealed class FeatureContributionScorer
    {
        // Apparently, loader signature is limited in length to 24 characters.
        internal const string MapperLoaderSignature = "FCCBindable";
 
        internal sealed class Arguments : ScorerArgumentsBase
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Number of top contributions", SortOrder = 1)]
            public int Top = 10;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Number of bottom contributions", SortOrder = 2)]
            public int Bottom = 10;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution should be normalized", ShortName = "norm", SortOrder = 3)]
            public bool Normalize = true;
 
            [Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution in string key-value format", ShortName = "str", SortOrder = 4)]
            public bool Stringify = false;
 
            // REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it.
        }
 
        // Factory method for SignatureDataScorer.
        private static IDataScorerTransform Create(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(mapper, nameof(mapper));
            if (args.Top < 0)
                throw env.Except($"Number of top contribution must be non negative");
            if (args.Bottom < 0)
                throw env.Except($"Number of bottom contribution must be non negative");
 
            var contributionMapper = mapper as RowMapper;
            env.CheckParam(mapper != null, nameof(mapper), "Unexpected mapper");
 
            var scorer = ScoreUtils.GetScorerComponent(env, contributionMapper);
            var scoredPipe = scorer.CreateComponent(env, data, contributionMapper, trainSchema);
            return scoredPipe;
        }
 
        // Factory method for SignatureBindableMapper.
        private static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(predictor, nameof(predictor));
            var pred = predictor as IFeatureContributionMapper;
            env.CheckParam(pred != null, nameof(predictor), "Predictor doesn't support getting feature contributions");
            return new BindableMapper(env, pred, args.Top, args.Bottom, args.Normalize, args.Stringify);
        }
 
        // Factory method for SignatureLoadModel.
        private static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
            => new BindableMapper(env, ctx);
 
        /// <summary>
        /// Holds the definition of the getters for the FeatureContribution column. It also contains the generic mapper that is used to score the Predictor.
        /// This is only used by the command line API.
        /// </summary>
        private sealed class BindableMapper : ISchemaBindableMapper, ICanSaveModel, IPredictor
        {
            private static readonly FuncInstanceMethodInfo1<BindableMapper, DataViewRow, int, Delegate> _getValueGetterMethodInfo
                = FuncInstanceMethodInfo1<BindableMapper, DataViewRow, int, Delegate>.Create(target => target.GetValueGetter<int>);
 
            private readonly int _topContributionsCount;
            private readonly int _bottomContributionsCount;
            private readonly bool _normalize;
            private readonly IHostEnvironment _env;
 
            public readonly IFeatureContributionMapper Predictor;
            public readonly ISchemaBindableMapper GenericMapper;
            public readonly bool Stringify;
 
            private static VersionInfo GetVersionInfo()
            {
                return new VersionInfo(
                    modelSignature: "FCC SCBI",
                    verWrittenCur: 0x00010001, // Initial
                    verReadableCur: 0x00010001,
                    verWeCanReadBack: 0x00010001,
                    loaderSignature: MapperLoaderSignature,
                    loaderAssemblyName: typeof(FeatureContributionScorer).Assembly.FullName);
            }
 
            public PredictionKind PredictionKind => Predictor.PredictionKind;
 
            public BindableMapper(IHostEnvironment env, IFeatureContributionMapper predictor, int topContributionsCount, int bottomContributionsCount, bool normalize, bool stringify)
            {
                Contracts.CheckValue(env, nameof(env));
                _env = env;
                _env.CheckValue(predictor, nameof(predictor));
                if (topContributionsCount < 0)
                    throw env.Except($"Number of top contribution must be non negative");
                if (bottomContributionsCount < 0)
                    throw env.Except($"Number of bottom contribution must be non negative");
 
                _topContributionsCount = topContributionsCount;
                _bottomContributionsCount = bottomContributionsCount;
                _normalize = normalize;
                Stringify = stringify;
                Predictor = predictor;
 
                GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);
            }
 
            public BindableMapper(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                _env = env;
                _env.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());
 
                // *** Binary format ***
                // IFeatureContributionMapper: Predictor
                // int: topContributionsCount
                // int: bottomContributionsCount
                // bool: normalize
                // bool: stringify
 
                ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
                GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);
                _topContributionsCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 <= _topContributionsCount);
                _bottomContributionsCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 <= _bottomContributionsCount);
                _normalize = ctx.Reader.ReadBoolByte();
                Stringify = ctx.Reader.ReadBoolByte();
            }
 
            public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
            {
                Contracts.CheckValue(env, nameof(env));
                env.CheckValue(schema, nameof(schema));
                CheckSchemaValid(env, schema, Predictor);
                return new RowMapper(env, this, schema);
            }
 
            void ICanSaveModel.Save(ModelSaveContext ctx)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                ctx.SetVersionInfo(GetVersionInfo());
 
                // *** Binary format ***
                // IFeatureContributionMapper: Predictor
                // int: topContributionsCount
                // int: bottomContributionsCount
                // bool: normalize
                // bool: stringify
 
                ctx.SaveModel(Predictor, ModelFileUtils.DirPredictor);
                Contracts.Assert(0 <= _topContributionsCount);
                ctx.Writer.Write(_topContributionsCount);
                Contracts.Assert(0 <= _bottomContributionsCount);
                ctx.Writer.Write(_bottomContributionsCount);
                ctx.Writer.WriteBoolByte(_normalize);
                ctx.Writer.WriteBoolByte(Stringify);
            }
 
            public Delegate GetTextContributionGetter(DataViewRow input, int colSrc, VBuffer<ReadOnlyMemory<char>> slotNames)
            {
                Contracts.CheckValue(input, nameof(input));
                Contracts.Check(0 <= colSrc && colSrc < input.Schema.Count);
                var typeSrc = input.Schema[colSrc].Type;
 
                Func<DataViewRow, int, VBuffer<ReadOnlyMemory<char>>, ValueGetter<ReadOnlyMemory<char>>> del = GetTextValueGetter<int>;
                var meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeSrc.RawType);
                return (Delegate)meth.Invoke(this, new object[] { input, colSrc, slotNames });
            }
 
            public Delegate GetContributionGetter(DataViewRow input, int colSrc)
            {
                Contracts.CheckValue(input, nameof(input));
                Contracts.Check(0 <= colSrc && colSrc < input.Schema.Count);
 
                var typeSrc = input.Schema[colSrc].Type;
 
                // REVIEW: Assuming Feature contributions will be VBuffer<float>.
                // For multiclass LR it needs to be(VBuffer<float>[].
                return Utils.MarshalInvoke(_getValueGetterMethodInfo, this, typeSrc.RawType, input, colSrc);
            }
 
            private ReadOnlyMemory<char> GetSlotName(int index, VBuffer<ReadOnlyMemory<char>> slotNames)
            {
                _env.Assert(0 <= index && index < slotNames.Length);
                var slotName = slotNames.GetItemOrDefault(index);
                return slotName.IsEmpty
                    ? new ReadOnlyMemory<char>($"f{index}".ToCharArray())
                    : slotName;
            }
 
            private ValueGetter<ReadOnlyMemory<char>> GetTextValueGetter<TSrc>(DataViewRow input, int colSrc, VBuffer<ReadOnlyMemory<char>> slotNames)
            {
                Contracts.AssertValue(input);
                Contracts.AssertValue(Predictor);
 
                var featureGetter = input.GetGetter<TSrc>(input.Schema[colSrc]);
                var map = Predictor.GetFeatureContributionMapper<TSrc, VBuffer<float>>(_topContributionsCount, _bottomContributionsCount, _normalize);
 
                var features = default(TSrc);
                var contributions = default(VBuffer<float>);
                return
                    (ref ReadOnlyMemory<char> dst) =>
                    {
                        featureGetter(ref features);
                        map(in features, ref contributions);
                        var editor = VBufferEditor.CreateFromBuffer(ref contributions);
                        var indices = contributions.IsDense ? Utils.GetIdentityPermutation(contributions.Length) : editor.Indices;
                        var values = editor.Values;
                        var count = values.Length;
                        var sb = new StringBuilder();
                        GenericSpanSortHelper<float>.Sort(values, indices, 0, count);
                        for (var i = 0; i < count; i++)
                        {
                            var val = values[i];
                            var ind = indices[i];
                            var name = GetSlotName(ind, slotNames);
                            sb.AppendFormat("{0}: {1}, ", name, val);
                        }
 
                        if (sb.Length > 0)
                        {
                            _env.Assert(sb.Length >= 2);
                            sb.Remove(sb.Length - 2, 2);
                        }
 
                        dst = new ReadOnlyMemory<char>(sb.ToString().ToCharArray());
                    };
            }
 
            private ValueGetter<VBuffer<float>> GetValueGetter<TSrc>(DataViewRow input, int colSrc)
            {
                Contracts.AssertValue(input);
                Contracts.AssertValue(Predictor);
 
                var featureGetter = input.GetGetter<TSrc>(input.Schema[colSrc]);
 
                // REVIEW: Scorer can call Sparsification\Norm routine.
 
                var map = Predictor.GetFeatureContributionMapper<TSrc, VBuffer<float>>(_topContributionsCount, _bottomContributionsCount, _normalize);
                var features = default(TSrc);
                return
                    (ref VBuffer<float> dst) =>
                    {
                        featureGetter(ref features);
                        map(in features, ref dst);
                    };
            }
 
            private static void CheckSchemaValid(IExceptionContext ectx, RoleMappedSchema schema,
                IFeatureContributionMapper predictor)
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(schema);
                ectx.AssertValue(predictor);
 
                // REVIEW: Check that Features column is present and is of correct size and item type.
            }
        }
 
        /// <summary>
        /// Maps a schema from input columns to output columns. Keeps track of the input columns that are needed for the mapping.
        /// </summary>
        private sealed class RowMapper : ISchemaBoundRowMapper
        {
            private readonly IHostEnvironment _env;
            private readonly ISchemaBoundRowMapper _genericRowMapper;
            private readonly BindableMapper _parent;
            private readonly DataViewSchema _outputSchema;
            private readonly DataViewSchema _outputGenericSchema;
            private readonly VBuffer<ReadOnlyMemory<char>> _slotNames;
 
            public RoleMappedSchema InputRoleMappedSchema { get; }
 
            public DataViewSchema InputSchema => InputRoleMappedSchema.Schema;
            private DataViewSchema.Column FeatureColumn => InputRoleMappedSchema.Feature.Value;
 
            public DataViewSchema OutputSchema { get; }
 
            public ISchemaBindableMapper Bindable => _parent;
 
            public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema schema)
            {
                Contracts.AssertValue(env);
                _env = env;
                _env.AssertValue(schema);
                _env.AssertValue(parent);
                _env.Assert(schema.Feature.HasValue);
                _parent = parent;
                InputRoleMappedSchema = schema;
                var genericMapper = parent.GenericMapper.Bind(_env, schema);
                _genericRowMapper = genericMapper as ISchemaBoundRowMapper;
                var featureSize = FeatureColumn.Type.GetVectorSize();
 
                if (parent.Stringify)
                {
                    var builder = new DataViewSchema.Builder();
                    builder.AddColumn(DefaultColumnNames.FeatureContributions, TextDataViewType.Instance, null);
                    _outputSchema = builder.ToSchema();
                    if (FeatureColumn.HasSlotNames(featureSize))
                        FeatureColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref _slotNames);
                    else
                        _slotNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(featureSize);
                }
                else
                {
                    var metadataBuilder = new DataViewSchema.Annotations.Builder();
                    if (InputSchema[FeatureColumn.Index].HasSlotNames(featureSize))
                        metadataBuilder.AddSlotNames(featureSize, (ref VBuffer<ReadOnlyMemory<char>> value) =>
                            FeatureColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref value));
 
                    var schemaBuilder = new DataViewSchema.Builder();
                    var featureContributionType = new VectorDataViewType(NumberDataViewType.Single, ((VectorDataViewType)FeatureColumn.Type).Dimensions);
                    schemaBuilder.AddColumn(DefaultColumnNames.FeatureContributions, featureContributionType, metadataBuilder.ToAnnotations());
                    _outputSchema = schemaBuilder.ToSchema();
                }
 
                _outputGenericSchema = _genericRowMapper.OutputSchema;
                OutputSchema = new ZipBinding(new DataViewSchema[] { _outputGenericSchema, _outputSchema, }).OutputSchema;
            }
 
            /// <summary>
            /// Returns the input columns needed for the requested output columns.
            /// </summary>
            IEnumerable<DataViewSchema.Column> ISchemaBoundRowMapper.GetDependenciesForNewColumns(IEnumerable<DataViewSchema.Column> dependingColumns)
            {
                if (dependingColumns.Count() == 0)
                    return Enumerable.Empty<DataViewSchema.Column>();
 
                return Enumerable.Repeat(FeatureColumn, 1);
            }
 
            DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
            {
                Contracts.AssertValue(input);
                Contracts.AssertValue(activeColumns);
                var totalColumnsCount = 1 + _outputGenericSchema.Count;
                var getters = new Delegate[totalColumnsCount];
 
                if (activeColumns.Select(c => c.Index).Contains(_outputGenericSchema.Count))
                {
                    getters[totalColumnsCount - 1] = _parent.Stringify
                        ? _parent.GetTextContributionGetter(input, FeatureColumn.Index, _slotNames)
                        : _parent.GetContributionGetter(input, FeatureColumn.Index);
                }
 
                var genericRow = _genericRowMapper.GetRow(input, activeColumns);
                for (var i = 0; i < _outputGenericSchema.Count; i++)
                {
                    if (genericRow.IsColumnActive(genericRow.Schema[i]))
                        getters[i] = RowCursorUtils.GetGetterAsDelegate(genericRow, i);
                }
 
                return new SimpleRow(OutputSchema, genericRow, getters);
            }
 
            public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
            {
                yield return RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name);
            }
        }
    }
}