File: Common\BLEUAlgorithm.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI.Evaluation.NLP\Microsoft.Extensions.AI.Evaluation.NLP.csproj (Microsoft.Extensions.AI.Evaluation.NLP)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Shared.Diagnostics;
 
namespace Microsoft.Extensions.AI.Evaluation.NLP.Common;
 
/// <summary>
/// Helper methods for calculating the BLEU score.
/// See <see href="https://en.wikipedia.org/wiki/BLEU">BLEU on Wikipedia</see> or
/// <see href="https://github.com/nltk/nltk/blob/develop/nltk/translate/bleu_score.py">NLTK implementation</see>
/// for more details.
/// </summary>
internal static class BLEUAlgorithm
{
    internal static int ClosestRefLength(IEnumerable<IEnumerable<string>> references, int hypLength)
    {
        if (!references.Any())
        {
            return 0;
        }
 
        int closestRefLength = 0;
        int smallestDiff = int.MaxValue;
        foreach (var reference in references)
        {
            int refLength = reference.Count();
            int diff = Math.Abs(refLength - hypLength);
            if (diff < smallestDiff ||
               (diff == smallestDiff && refLength < closestRefLength))
            {
                smallestDiff = diff;
                closestRefLength = refLength;
            }
        }
 
        return closestRefLength;
    }
 
    internal static double BrevityPenalty(int closestRefLength, int hypLength)
    {
        if (hypLength <= 0)
        {
            return 0.0;
        }
 
        if (closestRefLength <= 0 || hypLength > closestRefLength)
        {
            return 1.0;
        }
 
        return Math.Exp(1 - ((double)closestRefLength / hypLength));
    }
 
    internal static RationalNumber ModifiedPrecision(IEnumerable<IEnumerable<string>> references, IEnumerable<string> hypothesis, int n = 1)
    {
        if (n <= 0)
        {
            Throw.ArgumentOutOfRangeException(nameof(n), $"`{nameof(n)}` must be greater than zero.");
        }
 
        if (!references.Any() || !hypothesis.Any())
        {
            return RationalNumber.Zero;
        }
 
        var hyp = hypothesis.CreateNGrams(n);
        var hypCounts = new MatchCounter<NGram<string>>(hyp);
 
        Dictionary<NGram<string>, int> maxCounts = [];
 
        foreach (var rf in references)
        {
            IEnumerable<NGram<string>> refGrams = rf.CreateNGrams(n);
            var refCounts = new MatchCounter<NGram<string>>(refGrams);
 
            foreach (var ct in refCounts)
            {
                if (maxCounts.TryGetValue(ct.Key, out int val))
                {
                    maxCounts[ct.Key] = Math.Max(val, ct.Value);
                }
                else
                {
                    maxCounts[ct.Key] = ct.Value;
                }
            }
        }
 
        Dictionary<NGram<string>, int> clippedCounts = [];
        foreach (var h in hypCounts)
        {
            if (maxCounts.TryGetValue(h.Key, out var v))
            {
                clippedCounts[h.Key] = Math.Min(h.Value, v);
            }
            else
            {
                // If the hypothesis n-gram is not in any reference, it is clipped to 0.
                clippedCounts[h.Key] = 0;
            }
        }
 
        int numerator = clippedCounts.Values.Sum();
        int denominator = Math.Max(1, hypCounts.Sum());
 
        return new RationalNumber(numerator, denominator);
    }
 
    /// <summary>
    /// Generate an n-sized array of equal weights that sum to 1.0.
    /// </summary>
    /// <param name="n">Number of weights to return.</param>
    /// <returns>Array of equal sized values that sum to 1.0.</returns>
    internal static double[] EqualWeights(int n)
    {
        if (n <= 0)
        {
            Throw.ArgumentOutOfRangeException(nameof(n), $"'{nameof(n)}' must be greater than zero.");
        }
 
        double[] weights = new double[n];
        for (int i = 0; i < n; i++)
        {
            weights[i] = 1.0 / n;
        }
 
        return weights;
    }
 
    internal static readonly double[] DefaultBLEUWeights = EqualWeights(4);
 
    internal static double SentenceBLEU(IEnumerable<IEnumerable<string>> references, IEnumerable<string> hypothesis,
        double[]? weights = null, Func<RationalNumber[], int, double[]>? smoothingFunction = null)
    {
        if (references == null || !references.Any())
        {
            Throw.ArgumentNullException(nameof(references), $"'{nameof(references)}' cannot be null or empty.");
        }
 
        if (hypothesis == null || !hypothesis.Any())
        {
            Throw.ArgumentNullException(nameof(hypothesis), $"'{nameof(hypothesis)}' cannot be null or empty.");
        }
 
        if (weights is null)
        {
            weights = DefaultBLEUWeights;
        }
 
        if (weights.Length == 0)
        {
            Throw.ArgumentNullException(nameof(weights), $"'{nameof(weights)}' cannot be empty.");
        }
 
        var precisionValues = new RationalNumber[weights.Length];
        for (int i = 0; i < weights.Length; i++)
        {
            int n = i + 1;
            RationalNumber prec = ModifiedPrecision(references, hypothesis, n);
 
            if (i == 0 && prec.Numerator == 0)
            {
                // If the precision for unigrams (n == 1) is zero, the there can be no higher order matches and BLEU score is zero.
                return 0.0;
            }
 
            precisionValues[i] = prec;
        }
 
        int hypLen = hypothesis.Count();
        int closestRefLength = ClosestRefLength(references, hypLen);
        double brevityPenalty = BrevityPenalty(closestRefLength, hypLen);
 
        if (smoothingFunction == null)
        {
            smoothingFunction = SmoothingFunction.Method0;
        }
 
        double[] smoothedValues = smoothingFunction(precisionValues, hypLen);
 
        double score = 0.0;
        for (int i = 0; i < weights.Length; i++)
        {
            if (smoothedValues[i] > 0)
            {
                score += weights[i] * Math.Log(smoothedValues[i]);
            }
        }
 
        return brevityPenalty * Math.Exp(score);
    }
 
}