File: ComputeLRTrainingStdThroughHal.cs
Web Access
Project: src\src\Microsoft.ML.Mkl.Components\Microsoft.ML.Mkl.Components.csproj (Microsoft.ML.Mkl.Components)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
 
namespace Microsoft.ML.Trainers
{
    using MklOls = OlsTrainer.Mkl;
 
    /// <summary>
    /// Computes the standard errors of logistic regression weights by inverting the Hessian
    /// with the MKL packed Cholesky routines (<c>pptrf</c> followed by <c>pptri</c>).
    /// </summary>
    public sealed class ComputeLRTrainingStdThroughMkl : ComputeLogisticRegressionStandardDeviation
    {
        /// <summary>
        /// Computes the standard deviation (standard error) of each of the non-zero training weights,
        /// which is needed to further calculate the p-value and z-score of each weight.
        /// </summary>
        /// <param name="hessian">
        /// The Hessian of the selected parameters, in packed, row-major, lower-triangular storage:
        /// entry (i, j) with j &lt;= i lives at index <c>i * (i + 1) / 2 + j</c>. Its length must be
        /// <c>numSelectedParams * (numSelectedParams + 1) / 2</c>.
        /// Note: the caller's array is overwritten in place by the MKL routines.
        /// </param>
        /// <param name="weightIndices">
        /// The positions, within the full weight vector, of the selected parameters. They are used as the
        /// indices of the returned sparse buffer. Presumably <see langword="null"/> when every weight is
        /// selected (dense result) — TODO confirm with callers.
        /// </param>
        /// <param name="numSelectedParams">
        /// The number of selected parameters, i.e. the dimension of <paramref name="hessian"/>. Entry 0
        /// (presumably the bias) is excluded from the L2 adjustment.
        /// </param>
        /// <param name="currentWeightsCount">The logical length of the returned vector (the weights plus the bias).</param>
        /// <param name="ch">The <see cref="IChannel"/> used for messaging.</param>
        /// <param name="l2Weight">
        /// The L2Weight used for training. (Supply the same one that got used during training.)
        /// A value of zero skips the ridge variance adjustment.
        /// </param>
        /// <returns>
        /// A buffer of length <paramref name="currentWeightsCount"/> holding, at the positions given by
        /// <paramref name="weightIndices"/>, the standard error of each selected parameter.
        /// </returns>
        /// <exception cref="NotSupportedException">The MKL native library or one of its dependencies cannot be loaded.</exception>
        public override VBuffer<float> ComputeStandardDeviation(double[] hessian, int[] weightIndices, int numSelectedParams, int currentWeightsCount, IChannel ch, float l2Weight)
        {
            Contracts.AssertValue(ch);
            Contracts.AssertValue(hessian, nameof(hessian));
            Contracts.Assert(numSelectedParams > 0);
            Contracts.Assert(currentWeightsCount > 0);
            // The packed lower triangle of a numSelectedParams x numSelectedParams matrix.
            Contracts.Assert(hessian.Length == numSelectedParams * (numSelectedParams + 1) / 2);
            // Zero is a legitimate L2 weight: the variance adjustment below is simply skipped.
            Contracts.Assert(l2Weight >= 0);

            // Apply Cholesky Decomposition to find the inverse of the Hessian.
            double[] invHessian = null;
            try
            {
                // First, find the Cholesky decomposition LL' of the Hessian.
                MklOls.Pptrf(MklOls.Layout.RowMajor, MklOls.UpLo.Lo, numSelectedParams, hessian);
                // Note that hessian is already modified at this point. It is no longer the original Hessian,
                // but instead represents the Cholesky decomposition L.
                // Also note that the following routine is supposed to consume the Cholesky decomposition L instead
                // of the original information matrix.
                MklOls.Pptri(MklOls.Layout.RowMajor, MklOls.UpLo.Lo, numSelectedParams, hessian);
                // At this point, hessian should contain the inverse of the original Hessian matrix.
                // Swap hessian with invHessian to avoid confusion in the following context.
                Utils.Swap(ref hessian, ref invHessian);
                Contracts.Assert(hessian == null);
            }
            catch (DllNotFoundException)
            {
                throw ch.ExceptNotSupp("The MKL library (MklImports.dll) or one of its dependencies is missing.");
            }

            float[] stdErrorValues = new float[numSelectedParams];
            // Entry 0 never receives the L2 adjustment, so its standard error is final right away.
            stdErrorValues[0] = (float)Math.Sqrt(invHessian[0]);

            for (int i = 1; i < numSelectedParams; i++)
            {
                // Initialize with the diagonal of the inverse Hessian (packed index of (i, i)).
                stdErrorValues[i] = (float)invHessian[i * (i + 1) / 2 + i];
            }

            if (l2Weight > 0)
            {
                // Iterate through all entries of inverse Hessian to make adjustment to variance.
                // A discussion on ridge regularized LR coefficient covariance matrix can be found here:
                // http://www.aloki.hu/pdf/0402_171179.pdf (Equations 11 and 25)
                // http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf (Section "Significance testing in ridge logistic regression")
                // Each stored lower-triangle entry (iRow, iCol) is charged to variance iRow and, by symmetry of the
                // inverse, also to variance iCol (once only on the diagonal). Net effect for every i >= 1:
                // variance[i] -= l2Weight * sum_k invHessian(i, k)^2, i.e. the diagonal of l2Weight * H^-1 * H^-1.
                // Row 0 is skipped and column 0 is never charged, so entry 0 stays unadjusted.
                int ioffset = 1;
                for (int iRow = 1; iRow < numSelectedParams; iRow++)
                {
                    for (int iCol = 0; iCol <= iRow; iCol++)
                    {
                        var entry = (float)invHessian[ioffset++];
                        var adjustment = l2Weight * entry * entry;
                        stdErrorValues[iRow] -= adjustment;

                        if (0 < iCol && iCol < iRow)
                            stdErrorValues[iCol] -= adjustment;
                    }
                }

                Contracts.Assert(ioffset == invHessian.Length);
            }

            // Convert the (adjusted) variances into standard errors.
            for (int i = 1; i < numSelectedParams; i++)
                stdErrorValues[i] = (float)Math.Sqrt(stdErrorValues[i]);

            // currentWeights vector size is Weights2 + the bias
            return new VBuffer<float>(currentWeightsCount, numSelectedParams, stdErrorValues, weightIndices);
        }
    }
}