File: UnitTests\TestLoss.cs
Web Access
Project: src\test\Microsoft.ML.Core.Tests\Microsoft.ML.Core.Tests.csproj (Microsoft.ML.Core.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.RunTests
{
    /// <summary>
    /// Tests of the scalar loss functions (<see cref="HingeLoss"/>, <see cref="ExpLoss"/>,
    /// <see cref="SquaredLoss"/>). Each test compares the loss value and the update step
    /// (the negated derivative) to hand-computed expectations, and, where the loss is
    /// differentiable, cross-checks the derivative against a finite-difference estimate.
    /// </summary>
    public class TestLoss : BaseTestClass
    {
        private const string _category = "Loss";

        /// <summary>Step size (relative or absolute, whichever is larger) for the finite-difference check.</summary>
        private const float _epsilon = 1e-4f;

        /// <summary>Tolerance used when comparing a loss value or update step to its expected value.</summary>
        private const double _valueTolerance = 0.00001;

        /// <summary>Tolerance used when comparing the derivative to its finite-difference estimate.</summary>
        private const double _finiteDifferenceTolerance = 0.1;

        public TestLoss(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// A small helper for comparing a loss's computations to expected values.
        /// </summary>
        /// <param name="lossFunc">The training loss.</param>
        /// <param name="label">The ideal labeled output.</param>
        /// <param name="output">The actual output.</param>
        /// <param name="expectedLoss">The expected value of this loss, given
        /// <c>label</c> and <c>output</c></param>
        /// <param name="expectedUpdate">The expected value of the update
        /// step (the negated derivative), given <c>label</c> and <c>output</c></param>
        /// <param name="differentiable">Whether the loss function is differentiable
        /// w.r.t. the output in the vicinity of the output value. When <c>false</c>,
        /// the finite-difference check of the derivative is skipped.</param>
        private void TestHelper(IScalarLoss lossFunc, double label, double output, double expectedLoss, double expectedUpdate, bool differentiable = true)
        {
            // The loss API works on floats; convert once so that every computation below
            // (including the finite-difference step) uses exactly the values that were evaluated.
            float labelValue = (float)label;
            float outputValue = (float)output;

            double loss = lossFunc.Loss(outputValue, labelValue);
            float derivative = lossFunc.Derivative(outputValue, labelValue);
            Assert.Equal(expectedLoss, loss, _valueTolerance);
            Assert.Equal(expectedUpdate, -derivative, _valueTolerance);

            if (!differentiable)
            {
                return;
            }

            // In principle, the update should be the negative of the first derivative of the loss.
            // Use a simple finite difference method to see if it's in the right ballpark.
            // The perturbed point is always strictly greater than the output: taking the max of a
            // relative and an absolute step handles zero and negative outputs as well.
            float almostOutput = Math.Max(outputValue * (1 + _epsilon), outputValue + _epsilon);
            double almostLoss = lossFunc.Loss(almostOutput, labelValue);

            // Measure the step between the two float points that were actually evaluated, rather
            // than against the double-precision argument, which may differ from its float cast.
            double step = (double)almostOutput - outputValue;
            Assert.Equal((almostLoss - loss) / step, derivative, _finiteDifferenceTolerance);
        }

        /// <summary>
        /// Checks the hinge loss. The expected values below match max(0, 1 - y * output),
        /// where y is +1 for label 1 and -1 for label 0.
        /// </summary>
        [Fact]
        public void LossHinge()
        {
            var loss = new HingeLoss();
            // Positive examples.
            TestHelper(loss, 1, 2, 0, 0);
            // At or near the hinge point the finite-difference check is skipped.
            TestHelper(loss, 1, 1, 0, 0, false);
            TestHelper(loss, 1, 0.99, 0.01, 1, false);
            TestHelper(loss, 1, 0.5, 0.5, 1);
            // Negative examples.
            TestHelper(loss, 0, 0.5, 1.5, -1);
            TestHelper(loss, 0, -0.5, 0.5, -1);
            // At the hinge point the finite-difference check is skipped.
            TestHelper(loss, 0, -1, 0, 0, false);
            TestHelper(loss, 0, -2, 0, 0);
        }

        /// <summary>
        /// Checks the exponential loss built from default options. The expected values below
        /// match exp(-y * output), where y is +1 for label 1 and -1 for label 0.
        /// </summary>
        [Fact]
        public void LossExponential()
        {
            ExpLoss.Options options = new ExpLoss.Options();
            ExpLoss loss = new ExpLoss(options);
            TestHelper(loss, 1, 3, Math.Exp(-3), Math.Exp(-3));
            TestHelper(loss, 0, 3, Math.Exp(3), -Math.Exp(3));
            TestHelper(loss, 0, -3, Math.Exp(-3), -Math.Exp(-3));
        }

        /// <summary>
        /// Checks the squared loss. The expected values below match (output - label)^2,
        /// with an update step of 2 * (label - output).
        /// </summary>
        [Fact]
        public void LossSquared()
        {
            SquaredLoss loss = new SquaredLoss();
            TestHelper(loss, 3, 2, 1, 2);
            TestHelper(loss, 3, 4, 1, -2);
            TestHelper(loss, 3, 5, 4, -4);
            TestHelper(loss, 0, -3, 9, 6);
            TestHelper(loss, -3, -5, 4, 4);
        }
    }
}