File: Scorers\ClusteringScorer.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Newtonsoft.Json.Linq;
 
[assembly: LoadableClass(typeof(ClusteringScorer), typeof(ClusteringScorer.Arguments), typeof(SignatureDataScorer),
    "Clustering Scorer", "ClusteringScorer", AnnotationUtils.Const.ScoreColumnKind.Clustering)]
 
[assembly: LoadableClass(typeof(ClusteringScorer), null, typeof(SignatureLoadDataTransform),
    "Clustering Scorer", ClusteringScorer.LoaderSignature)]
 
namespace Microsoft.ML.Data
{
    internal sealed class ClusteringScorer : PredictedLabelScorerBase
    {
        public sealed class Arguments : ScorerArgumentsBase
        {
        }
 
        public const string LoaderSignature = "ClusteringScoreTrans";
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "CLSTRSCR",
                //verWrittenCur: 0x00010001, // Initial
                //verWrittenCur: 0x00010002, // ISchemaBindableMapper
                verWrittenCur: 0x00010003, // ISchemaBindableMapper update
                verReadableCur: 0x00010003,
                verWeCanReadBack: 0x00010003,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(ClusteringScorer).Assembly.FullName);
        }
 
        private const string RegistrationName = "ClusteringScore";
 
        [BestFriend]
        internal ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
            : base(args, env, data, mapper, trainSchema, RegistrationName, AnnotationUtils.Const.ScoreColumnKind.Clustering,
                AnnotationUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType)
        {
        }
 
        private ClusteringScorer(IHostEnvironment env, ClusteringScorer transform, IDataView newSource)
            : base(env, transform, newSource, RegistrationName)
        {
        }
 
        private ClusteringScorer(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, ctx, input, OutputTypeMatches, GetPredColType)
        {
            // *** Binary format ***
            // <base info>
        }
 
        public static ClusteringScorer Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(RegistrationName);
            h.CheckValue(ctx, nameof(ctx));
            h.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());
 
            return h.Apply("Loading Model", ch => new ClusteringScorer(h, ctx, input));
        }
 
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            Contracts.AssertValue(ctx);
            ctx.SetVersionInfo(GetVersionInfo());
 
            // *** Binary format ***
            // <base info>
 
            base.SaveCore(ctx);
        }
 
        private protected override IDataTransform ApplyToDataCore(IHostEnvironment env, IDataView newSource)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValue(newSource, nameof(newSource));
 
            return new ClusteringScorer(env, this, newSource);
        }
 
        protected override Delegate GetPredictedLabelGetter(DataViewRow output, out Delegate scoreGetter)
        {
            Contracts.AssertValue(output);
            Contracts.Assert(output.Schema == Bindings.RowMapper.OutputSchema);
            Contracts.Assert(output.IsColumnActive(output.Schema[Bindings.ScoreColumnIndex]));
 
            ValueGetter<VBuffer<float>> mapperScoreGetter = output.GetGetter<VBuffer<float>>(output.Schema[Bindings.ScoreColumnIndex]);
 
            long cachedPosition = -1;
            VBuffer<float> score = default(VBuffer<float>);
            int keyCount = Bindings.PredColType is KeyDataViewType key ? key.GetCountAsInt32(Host) : 0;
            int scoreLength = keyCount;
 
            ValueGetter<uint> predFn =
                (ref uint dst) =>
                {
                    EnsureCachedPosition(ref cachedPosition, ref score, output, mapperScoreGetter);
                    Contracts.Check(score.Length == scoreLength);
                    int index = VectorUtils.ArgMin(in score);
                    if (index < 0)
                        dst = 0;
                    else
                        dst = (uint)index + 1;
                };
            ValueGetter<VBuffer<float>> scoreFn =
                (ref VBuffer<float> dst) =>
                {
                    EnsureCachedPosition(ref cachedPosition, ref score, output, mapperScoreGetter);
                    Contracts.Check(score.Length == scoreLength);
                    score.CopyTo(ref dst);
                };
 
            scoreGetter = scoreFn;
            return predFn;
        }
 
        private protected override JToken PredictedLabelPfa(string[] mapperOutputs)
        {
            Contracts.Assert(Utils.Size(mapperOutputs) == 1);
            return PfaUtils.Call("a.argmax", mapperOutputs[0]);
        }
 
        private static DataViewType GetPredColType(DataViewType scoreType, ISchemaBoundRowMapper mapper)
        {
            return new KeyDataViewType(typeof(uint), scoreType.GetVectorSize());
        }
 
        private static bool OutputTypeMatches(DataViewType scoreType)
        {
            return scoreType is VectorDataViewType vectorType
                && vectorType.IsKnownSize
                && vectorType.ItemType == NumberDataViewType.Single;
        }
    }
}