File: LdSvm\LdSvmModelParameters.cs
Project: src\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj (Microsoft.ML.StandardTrainers)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
 
[assembly: LoadableClass(typeof(LdSvmModelParameters), null, typeof(SignatureLoadModel), "LDSVM binary predictor", LdSvmModelParameters.LoaderSignature)]
 
namespace Microsoft.ML.Trainers
{
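    /// <summary>
    /// Model parameters for the Local Deep SVM (LD-SVM) binary classification predictor.
    /// The model is a complete binary tree: every node carries a local linear predictor
    /// (w, biasW) gated by a tanh of its (thetaPrime, biasThetaPrime) response, and every
    /// internal node carries a routing hyperplane (theta, biasTheta) that decides which
    /// child an example descends into. The score is the sum of the gated local predictions
    /// along the root-to-leaf path.
    /// </summary>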
    public sealed class LdSvmModelParameters : ModelParametersBase<float>,
        IValueMapper,
        ICanSaveModel
    {
        internal const string LoaderSignature = "LDSVMBinaryPredictor";
 
        /// <summary>
        /// Version information to be saved in binary format.
        /// </summary>
        /// <returns>The <see cref="VersionInfo"/> used to validate the model on load.</returns>
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "LDSVM BC",
                verWrittenCur: 0x00010001,
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(LdSvmModelParameters).Assembly.FullName);
        }
 
        // Classifier Parameters
        private readonly int _numLeaf;
        private readonly float _sigma;
        private readonly VBuffer<float>[] _w;
        private readonly VBuffer<float>[] _thetaPrime;
        private readonly VBuffer<float>[] _theta;
        private readonly float[] _biasW;
        private readonly float[] _biasTheta;
        private readonly float[] _biasThetaPrime;
 
        /// <summary>
        /// Constructor. w, thetaPrime, theta must be dense <see cref="VBuffer{T}"/>s.
        /// Note that this takes over ownership of all such vectors.
        /// </summary>
        internal LdSvmModelParameters(IHostEnvironment env, VBuffer<float>[] w, VBuffer<float>[] thetaPrime, VBuffer<float>[] theta,
            float sigma, float[] biasW, float[] biasTheta, float[] biasThetaPrime, int treeDepth)
            : base(env, LoaderSignature)
        {
            // _numLeaf is a 32-bit signed integer, so treeDepth must stay below 31 to avoid overflow.
            Host.Assert(treeDepth > 0 && treeDepth < 31);
            int numLeaf = 1 << treeDepth;
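            // A complete binary tree with numLeaf leaves has numLeaf * 2 - 1 nodes in total and
            // numLeaf - 1 internal nodes: w and thetaPrime have one entry per node, theta one per
            // internal node.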
 
            Host.Assert(w.Length == numLeaf * 2 - 1);
            Host.Assert(w.All(v => v.IsDense));
            Host.Assert(w.All(v => v.Length == w[0].Length));
            Host.Assert(thetaPrime.Length == numLeaf * 2 - 1);
            Host.Assert(thetaPrime.All(v => v.IsDense));
            Host.Assert(thetaPrime.All(v => v.Length == thetaPrime[0].Length));
            Host.Assert(theta.Length == numLeaf - 1);
            Host.Assert(theta.All(v => v.IsDense));
            Host.Assert(theta.All(v => v.Length == theta[0].Length));
            Host.Assert(biasW.Length == numLeaf * 2 - 1);
            Host.Assert(biasTheta.Length == numLeaf - 1);
            Host.Assert(biasThetaPrime.Length == numLeaf * 2 - 1);
            Host.Assert((w[0].Length > 0) && (w[0].Length == thetaPrime[0].Length) && (w[0].Length == theta[0].Length));
 
            _numLeaf = numLeaf;
            _sigma = sigma;
            _w = w;
            _thetaPrime = thetaPrime;
            _theta = theta;
            _biasW = biasW;
            _biasTheta = biasTheta;
            _biasThetaPrime = biasThetaPrime;
 
            InputType = new VectorDataViewType(NumberDataViewType.Single, _w[0].Length);
 
            AssertValid();
        }
 
        private LdSvmModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, LoaderSignature, ctx)
        {
            // *** Binary format ***
            // int: _numLeaf
            // int: numFeatures
            // float: _sigma
            // (_numLeaf * 2 - 1) times: a vector in _w
            //      float[numFeatures]
            // (_numLeaf * 2 - 1) times: a vector in _thetaPrime
            //      float[numFeatures]
            // (_numLeaf - 1) times: a vector in _theta
            //      float[numFeatures]
            // float[_numLeaf * 2 - 1]: _biasW
            // float[_numLeaf - 1]: _biasTheta
            // float[_numLeaf * 2 - 1]: _biasThetaPrime
 
            _numLeaf = ctx.Reader.ReadInt32();
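            // The tree is complete, so the number of leaves must be a power of two (and at least 2).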
            Host.CheckDecode(_numLeaf > 1 && (_numLeaf & (_numLeaf - 1)) == 0);
            int numFeatures = ctx.Reader.ReadInt32();
            Host.CheckDecode(numFeatures > 0);
 
            _sigma = ctx.Reader.ReadFloat();
 
            _w = LoadVBufferArray(ctx, _numLeaf * 2 - 1, numFeatures);
            _thetaPrime = LoadVBufferArray(ctx, _numLeaf * 2 - 1, numFeatures);
            _theta = LoadVBufferArray(ctx, _numLeaf - 1, numFeatures);
            _biasW = ctx.Reader.ReadFloatArray(_numLeaf * 2 - 1);
            _biasTheta = ctx.Reader.ReadFloatArray(_numLeaf - 1);
            _biasThetaPrime = ctx.Reader.ReadFloatArray(_numLeaf * 2 - 1);
            WarnOnOldNormalizer(ctx, GetType(), Host);
 
            InputType = new VectorDataViewType(NumberDataViewType.Single, numFeatures);
 
            AssertValid();
        }
 
        private void AssertValid()
        {
            Host.Assert(_numLeaf > 1 && (_numLeaf & (_numLeaf - 1)) == 0); // Check that _numLeaf is a power of 2
            Host.Assert(_w.Length == _numLeaf * 2 - 1);
            Host.Assert(_w.All(v => v.IsDense));
            Host.Assert(_w.All(v => v.Length == _w[0].Length));
            Host.Assert(_thetaPrime.Length == _numLeaf * 2 - 1);
            Host.Assert(_thetaPrime.All(v => v.IsDense));
            Host.Assert(_thetaPrime.All(v => v.Length == _thetaPrime[0].Length));
            Host.Assert(_theta.Length == _numLeaf - 1);
            Host.Assert(_theta.All(v => v.IsDense));
            Host.Assert(_theta.All(v => v.Length == _theta[0].Length));
            Host.Assert(_biasW.Length == _numLeaf * 2 - 1);
            Host.Assert(_biasTheta.Length == _numLeaf - 1);
            Host.Assert(_biasThetaPrime.Length == _numLeaf * 2 - 1);
            Host.Assert((_w[0].Length > 0) && (_w[0].Length == _thetaPrime[0].Length) && (_w[0].Length == _theta[0].Length)); // numFeatures
            Host.Assert(InputType != null && InputType.GetVectorSize() == _w[0].Length);
        }
 
        /// <summary>
        /// Create method to instantiate a predictor.
        /// </summary>
        private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            return new LdSvmModelParameters(env, ctx);
        }
 
        private protected override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
 
        /// <summary>
        /// Save the predictor in binary format.
        /// </summary>
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());
 
            // *** Binary format ***
            // int: _numLeaf
            // int: numFeatures
            // float: _sigma
            // (_numLeaf * 2 - 1) times: a vector in _w
            //      float[numFeatures]
            // (_numLeaf * 2 - 1) times: a vector in _thetaPrime
            //      float[numFeatures]
            // (_numLeaf - 1) times: a vector in _theta
            //      float[numFeatures]
            // float[_numLeaf * 2 - 1]: _biasW
            // float[_numLeaf - 1]: _biasTheta
            // float[_numLeaf * 2 - 1]: _biasThetaPrime
 
            int numFeatures = _w[0].Length;
 
            ctx.Writer.Write(_numLeaf);
            ctx.Writer.Write(numFeatures);
            ctx.Writer.Write(_sigma);
 
            Host.Assert(_w.Length == _numLeaf * 2 - 1);
            SaveVBufferArray(ctx, _w);
            Host.Assert(_thetaPrime.Length == _numLeaf * 2 - 1);
            SaveVBufferArray(ctx, _thetaPrime);
            Host.Assert(_theta.Length == _numLeaf - 1);
            SaveVBufferArray(ctx, _theta);
 
            Host.Assert(_biasW.Length == _numLeaf * 2 - 1);
            ctx.Writer.WriteSinglesNoCount(_biasW.AsSpan());
            Host.Assert(_biasTheta.Length == _numLeaf - 1);
            ctx.Writer.WriteSinglesNoCount(_biasTheta.AsSpan());
            Host.Assert(_biasThetaPrime.Length == _numLeaf * 2 - 1);
            ctx.Writer.WriteSinglesNoCount(_biasThetaPrime.AsSpan());
        }
 
        /// <summary>
        /// Save an array of <see cref="VBuffer{T}"/> in binary format. The vectors must be dense.
        /// </summary>
        /// <param name="ctx">The context where we will save the vectors.</param>
        /// <param name="data">An array of vectors.</param>
        private void SaveVBufferArray(ModelSaveContext ctx, VBuffer<float>[] data)
        {
            if (data.Length == 0)
                return;
 
            int vectorLength = data[0].Length;
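            // Values are written without a per-vector count; the loader recovers the vector
            // length from the numFeatures field stored in the model header.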
            for (int i = 0; i < data.Length; i++)
            {
                var vector = data[i];
                Host.Assert(vector.IsDense);
                Host.Assert(vector.Length == vectorLength);
                ctx.Writer.WriteSinglesNoCount(vector.GetValues());
            }
        }
 
        /// <summary>
        /// Load an array of <see cref="VBuffer{T}"/> from binary format.
        /// </summary>
        /// <param name="ctx">The context from which to read the vectors.</param>
        /// <param name="length">The length of the array of vectors.</param>
        /// <param name="vectorLength">The length of each vector.</param>
        /// <returns>An array of vectors.</returns>
        private VBuffer<float>[] LoadVBufferArray(ModelLoadContext ctx, int length, int vectorLength)
        {
            Host.Assert(length >= 0);
            Host.Assert(vectorLength >= 0);
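            // Each vector was saved as exactly vectorLength consecutive floats, so it is read
            // back here as a dense buffer.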
 
            VBuffer<float>[] result = new VBuffer<float>[length];
 
            for (int i = 0; i < length; i++)
            {
                result[i] = new VBuffer<float>(vectorLength, ctx.Reader.ReadFloatArray(vectorLength));
                Host.Assert(result[i].IsDense);
                Host.Assert(result[i].Length == vectorLength);
            }
            return result;
        }
 
        /// <summary>
        /// Computes the raw margin (score) for the given feature vector by traversing the tree.
        /// </summary>
        private float Margin(in VBuffer<float> src)
        {
            double score = 0;
            double childIndicator;
            int current = 0;
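            // Walk from the root (index 0) toward a leaf. Each node on the path contributes
            // tanh(_sigma * (thetaPrime . x + biasThetaPrime)) * (w . x + biasW) to the score,
            // and the sign of (theta . x + biasTheta) selects which child to descend into.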
            while (current < _numLeaf - 1)
            {
                score += Math.Tanh(_sigma * (VectorUtils.DotProduct(in _thetaPrime[current], in src) + _biasThetaPrime[current])) *
                    (VectorUtils.DotProduct(in _w[current], in src) + _biasW[current]);
                childIndicator = VectorUtils.DotProduct(in _theta[current], in src) + _biasTheta[current];
                current = (childIndicator > 0) ? 2 * current + 1 : 2 * current + 2;
            }
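            // The leaf reached at the end of the walk contributes one final term.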
            score += Math.Tanh(_sigma * (VectorUtils.DotProduct(in _thetaPrime[current], in src) + _biasThetaPrime[current])) *
                    (VectorUtils.DotProduct(in _w[current], in src) + _biasW[current]);
            return (float)score;
        }
 
        public DataViewType InputType { get; }
 
        public DataViewType OutputType => NumberDataViewType.Single;
 
        ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
        {
            Host.Check(typeof(TIn) == typeof(VBuffer<float>));
            Host.Check(typeof(TOut) == typeof(float));
 
            ValueMapper<VBuffer<float>, float> del =
                (in VBuffer<float> src, ref float dst) =>
                {
                    Host.Check(src.Length == InputType.GetVectorSize());
                    dst = Margin(in src);
                };
            return (ValueMapper<TIn, TOut>)(Delegate)del;
        }
    }
}