File: Dracula\Featurizer.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(typeof(CountTargetEncodingFeaturizer), null, typeof(SignatureLoadModel),
    "Count Target Encoding Featurizer", CountTargetEncodingFeaturizer.RegistrationName)]
 
namespace Microsoft.ML.Transforms
{
    internal sealed class CountTargetEncodingFeaturizer : ICanSaveModel
    {
        private readonly IHost _host;
        private readonly int _labelBinCount;
        private readonly int _logOddsCount;
        private readonly MultiCountTableBase _countTables;
        public MultiCountTableBase MultiCountTable => _countTables;
 
        public float[] PriorCoef { get; }
        public float[] LaplaceScale { get; }
 
        internal const string RegistrationName = "CountTargetEncoder";
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "DRCTFEAT",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: RegistrationName,
                loaderAssemblyName: typeof(CountTargetEncodingFeaturizer).Assembly.FullName);
        }
 
        public int ColCount => _countTables.ColCount;
 
        public ReadOnlySpan<int> SlotCount => _countTables.SlotCount;
 
        public CountTargetEncodingFeaturizer(IHostEnvironment env, float[] priorCoef, float[] laplaceScale, long labelBinCount, MultiCountTableBase countTable)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            _host.CheckParam(labelBinCount > 1, nameof(labelBinCount), "Must be greater than 1");
 
            _labelBinCount = (int)labelBinCount;
            _logOddsCount = _labelBinCount == 2 ? 1 : _labelBinCount;
            NumFeatures = _labelBinCount + _logOddsCount + 1;
 
            PriorCoef = priorCoef;
 
            LaplaceScale = laplaceScale;
 
            _host.AssertValue(countTable);
            _countTables = countTable;
        }
 
        public CountTargetEncodingFeaturizer(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            _host.CheckValue(ctx, nameof(ctx));
 
            ctx.CheckAtModel(GetVersionInfo());
 
            // *** Binary format ***
            // int: label bin count
            // int: #features
            // float[]: prior coefficient
            // float[]: laplace scale
            // _countTables
 
            _labelBinCount = ctx.Reader.ReadInt32();
            _host.CheckDecode(_labelBinCount > 1);
            _logOddsCount = _labelBinCount == 2 ? 1 : _labelBinCount;
 
            NumFeatures = ctx.Reader.ReadInt32();
            _host.CheckDecode(NumFeatures == _labelBinCount + _logOddsCount + 1);
 
            // initialize count tables
            ctx.LoadModelOrNull<MultiCountTableBase, SignatureLoadModel>(_host, out _countTables, "CountTables");
            _host.AssertValue(_countTables);
 
            PriorCoef = ctx.Reader.ReadFloatArray(_countTables.ColCount);
            _host.CheckDecode(PriorCoef.All(x => x > 0));
 
            LaplaceScale = ctx.Reader.ReadFloatArray(_countTables.ColCount);
            _host.CheckDecode(LaplaceScale.All(x => x >= 0));
        }
 
        public void Save(ModelSaveContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());
 
            // *** Binary format ***
            // int: label bin count
            // int: #features
            // Single: prior coefficient
            // Single: laplace scale
            // _countTable
 
            ctx.Writer.Write(_labelBinCount);
            ctx.Writer.Write(NumFeatures);
            ctx.Writer.WriteSinglesNoCount(PriorCoef);
            ctx.Writer.WriteSinglesNoCount(LaplaceScale);
            ctx.SaveModel(_countTables, "CountTables");
        }
 
        public int NumFeatures { get; }
 
        public void GetFeatureNames(string[] classNames, ref VBuffer<ReadOnlyMemory<char>> featureNames)
        {
            if (classNames == null)
            {
                classNames = new string[_labelBinCount];
                for (int i = 0; i < _labelBinCount; i++)
                    classNames[i] = string.Format("Class{0:000}", i);
            }
 
            Contracts.Check(Utils.Size(classNames) == _labelBinCount, "incorrect class names");
 
            var editor = VBufferEditor.Create(ref featureNames, NumFeatures);
            for (int i = 0; i < _labelBinCount; i++)
                editor.Values[i] = $"{classNames[i]}_Count".AsMemory();
            for (int i = 0; i < _logOddsCount; i++)
                editor.Values[_labelBinCount + i] = $"{classNames[i]}_LogOdds".AsMemory();
            editor.Values[NumFeatures - 1] = "IsBackoff".AsMemory();
            featureNames = editor.Commit();
        }
 
        public void GetFeatures(int iCol, int iSlot, Random rand, long key, Span<float> features)
        {
            _host.Assert(features.Length == NumFeatures);
 
            // get counts from count table in the first _labelBinCount indices.
            var countsBuffer = features.Slice(0, _labelBinCount);
            var countTable = _countTables[iCol, iSlot];
            countTable.GetCounts(key, countsBuffer);
 
            // check if it's garbage and replace with garbage counts if true
            float sum = 0;
            foreach (var feat in countsBuffer)
                sum += feat;
            bool isGarbage = sum < countTable.GarbageThreshold;
            if (isGarbage)
            {
                int i = 0;
                foreach (var count in countTable.GarbageCounts)
                    countsBuffer[i++] = count;
            }
 
            sum = AddLaplacianNoisePerLabel(iCol, rand, countsBuffer);
 
            // add log odds in the next _logOddsCount indices.
            GenerateLogOdds(iCol, countTable, countsBuffer, features.Slice(_labelBinCount, _logOddsCount), sum);
            _host.Assert(FloatUtils.IsFinite(features));
 
            // Add the last feature: an indicator for isGarbage.
            features[NumFeatures - 1] = isGarbage ? 1 : 0;
        }
 
        // Assumes features are filled with the respective counts.
        // Adds laplacian noise if set and returns the sum of the counts.
        private float AddLaplacianNoisePerLabel(int iCol, Random rand, Span<float> counts)
        {
            _host.Assert(_labelBinCount == counts.Length);
 
            float sum = 0;
            for (int ifeat = 0; ifeat < _labelBinCount; ifeat++)
            {
                if (rand != null && LaplaceScale[iCol] > 0)
                    counts[ifeat] += Stats.SampleFromLaplacian(rand, 0, LaplaceScale[iCol]);
 
                // Clamp to zero when noise is too big and negative.
                if (counts[ifeat] < 0)
                    counts[ifeat] = 0;
 
                sum += counts[ifeat];
            }
 
            return sum;
        }
 
        // Fills _labelBinCount log odds features. One per class, or only one if 2 classes.
        private void GenerateLogOdds(int iCol, ICountTable countTable, Span<float> counts, Span<float> logOdds, float sum)
        {
            _host.Assert(counts.Length == _labelBinCount);
            _host.Assert(logOdds.Length == _logOddsCount);
 
            for (int i = 0; i < _logOddsCount; i++)
            {
                _host.Assert(counts[i] >= 0);
                if (counts[i] <= 0 && countTable.PriorFrequencies[i] <= 0 || countTable.PriorFrequencies[i] >= 1)
                    logOdds[i] = 0; // guarding against infinite log-odds
                else
                {
                    logOdds[i] = (float)Math.Log(
                        (counts[i] + PriorCoef[iCol] * countTable.PriorFrequencies[i]) /
                        (sum - counts[i] + PriorCoef[iCol] * (1 - countTable.PriorFrequencies[i])));
                }
            }
        }
    }
}