// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
namespace Microsoft.ML
{
/// <summary>
/// Base class for the trainer catalogs.
/// </summary>
public abstract class TrainCatalogBase : IInternalCatalog
{
IHostEnvironment IInternalCatalog.Environment => Environment;
[BestFriend]
private protected IHostEnvironment Environment { get; }
/// <summary>
/// Results for a specific cross-validation fold.
/// </summary>
[BestFriend]
private protected struct CrossValidationResult
{
/// <summary>
/// Model trained during the cross-validation fold.
/// </summary>
public readonly ITransformer Model;
/// <summary>
/// Scored test set with <see cref="Model"/> for this fold.
/// </summary>
public readonly IDataView Scores;
/// <summary>
/// Fold number.
/// </summary>
public readonly int Fold;
public CrossValidationResult(ITransformer model, IDataView scores, int fold)
{
Model = model;
Scores = scores;
Fold = fold;
}
}
/// <summary>
/// Results of running cross-validation.
/// </summary>
/// <typeparam name="T">Type of metric class.</typeparam>
public sealed class CrossValidationResult<T> where T : class
{
/// <summary>
/// Metrics for this cross-validation fold.
/// </summary>
public readonly T Metrics;
/// <summary>
/// Model trained during the cross-validation fold.
/// </summary>
public readonly ITransformer Model;
/// <summary>
/// The scored hold-out set for this fold.
/// </summary>
public readonly IDataView ScoredHoldOutSet;
/// <summary>
/// Fold number.
/// </summary>
public readonly int Fold;
internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold)
{
Model = model;
Metrics = metrics;
ScoredHoldOutSet = scores;
Fold = fold;
}
}
/// <summary>
/// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
/// Return each model and each scored test dataset.
/// </summary>
[BestFriend]
private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
int numFolds, string samplingKeyColumn, int? seed = null)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckValue(estimator, nameof(estimator));
Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
Environment.CheckValueOrNull(samplingKeyColumn);
var splitColumn = DataOperationsCatalog.CreateSplitColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true);
var result = new CrossValidationResult[numFolds];
int fold = 0;
// Sequential per-fold training.
// REVIEW: we could have a parallel implementation here. We would need to
// spawn off a separate host per fold in that case.
foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, splitColumn, numFolds))
{
var model = estimator.Fit(split.TrainSet);
IDataView scoredTest;
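// TransformerChain<T> is invariant in T, so a chain constructed with a more derived type
// argument cannot be cast to TransformerChain<ITransformer> directly; Unsafe.As below
// reinterprets the reference instead, once the shape has been verified.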
if (IsCastableToTransformerChainOfITransformer(model))
scoredTest = (Unsafe.As<TransformerChain<ITransformer>>(model)).Transform(split.TestSet, TransformerScope.Everything);
else
scoredTest = model.Transform(split.TestSet);
result[fold] = new CrossValidationResult(model, scoredTest, fold);
fold++;
}
return result;
}
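// Walks the inheritance chain looking for TransformerChain`1 by type name. A direct type test
// against TransformerChain<ITransformer> would miss chains constructed with a more derived
// type argument, because generic classes are invariant.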
private static bool IsCastableToTransformerChainOfITransformer(object o)
{
var type = o.GetType();
while (!type!.FullName!.StartsWith("Microsoft.ML.Data.TransformerChain`1[", StringComparison.Ordinal))
{
type = type!.BaseType;
if (type is null)
{
return false;
}
}
return true;
}
[BestFriend]
private protected TrainCatalogBase(IHostEnvironment env, string registrationName)
{
Contracts.CheckValue(env, nameof(env));
env.CheckNonEmpty(registrationName, nameof(registrationName));
Environment = env;
}
/// <summary>
/// Subclasses of <see cref="TrainCatalogBase"/> will provide little "extension method" hookable objects
/// (for example, something like <see cref="BinaryClassificationCatalog.Trainers"/>). User code will only
/// interact with these objects by invoking the extension methods. The actual component code can work
/// through <see cref="CatalogUtils"/> to get more "hidden" information from this object,
/// for example, the environment.
/// </summary>
public abstract class CatalogInstantiatorBase : IInternalCatalog
{
IHostEnvironment IInternalCatalog.Environment => Owner.GetEnvironment();
[BestFriend]
internal TrainCatalogBase Owner { get; }
[BestFriend]
private protected CatalogInstantiatorBase(TrainCatalogBase catalog)
{
Owner = catalog;
}
}
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of binary classification components,
/// such as trainers and calibrators.
/// </summary>
public sealed class BinaryClassificationCatalog : TrainCatalogBase
{
/// <summary>
/// The list of trainers for performing binary classification.
/// </summary>
public BinaryClassificationTrainers Trainers { get; }
internal BinaryClassificationCatalog(IHostEnvironment env)
: base(env, nameof(BinaryClassificationCatalog))
{
Calibrators = new CalibratorsCatalog(this);
Trainers = new BinaryClassificationTrainers(this);
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of binary classification trainers.
/// </summary>
public sealed class BinaryClassificationTrainers : CatalogInstantiatorBase
{
internal BinaryClassificationTrainers(BinaryClassificationCatalog catalog)
: base(catalog)
{
}
}
/// <summary>
/// Evaluates scored binary classification data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <param name="probabilityColumnName">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="scoreColumnName"/>.</param>
/// <param name="predictedLabelColumnName">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these calibrated outputs.</returns>
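/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext`, `trainData`, and `testData` are illustrative names rather
/// than part of this API, and any trainer that produces calibrated scores could stand in here.
/// ```csharp
/// var mlContext = new MLContext(seed: 0);
/// var model = mlContext.BinaryClassification.Trainers.SdcaLogisticRegression().Fit(trainData);
/// var scoredData = model.Transform(testData);
/// var metrics = mlContext.BinaryClassification.Evaluate(scoredData);
/// Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F3}, log-loss: {metrics.LogLoss:F3}");
/// ```
/// ]]>
/// </format>
/// </example>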
public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score,
string probabilityColumnName = DefaultColumnNames.Probability, string predictedLabelColumnName = DefaultColumnNames.PredictedLabel)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
Environment.CheckNonEmpty(probabilityColumnName, nameof(probabilityColumnName));
Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName));
var eval = new BinaryClassifierEvaluator(Environment, new BinaryClassifierEvaluator.Arguments() { });
return eval.Evaluate(data, labelColumnName, scoreColumnName, probabilityColumnName, predictedLabelColumnName);
}
/// <summary>
/// Evaluates scored binary classification data, without probability-based metrics.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <param name="predictedLabelColumnName">The name of the predicted label column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score,
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName));
var eval = new BinaryClassifierEvaluator(Environment, new BinaryClassifierEvaluator.Arguments() { });
return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName);
}
/// <summary>
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return a <see cref="BinaryClassificationMetrics"/> object, which
/// does not include probability-based metrics, for each sub-model. Each sub-model is evaluated on the cross-validation fold that it did not see during training.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">The label column (for evaluation).</param>
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
/// If <see langword="null"/> no row grouping will be performed.</param>
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public IReadOnlyList<CrossValidationResult<BinaryClassificationMetrics>> CrossValidateNonCalibrated(
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
string samplingKeyColumnName = null, int? seed = null)
{
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
return result.Select(x => new CrossValidationResult<BinaryClassificationMetrics>(x.Model,
EvaluateNonCalibrated(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
}
/// <summary>
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return a <see cref="CalibratedBinaryClassificationMetrics"/> object, which
/// includes probability-based metrics, for each sub-model. Each sub-model is evaluated on the cross-validation fold that it did not see during training.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">The label column (for evaluation).</param>
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
/// If <see langword="null"/> no row grouping will be performed.</param>
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
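/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext`, `data`, and `pipeline` are illustrative names for the
/// context, the full dataset, and an estimator that produces calibrated scores.
/// ```csharp
/// var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, numberOfFolds: 5);
/// // Averaging the per-fold metric gives a more robust estimate than a single train/test split.
/// var averageAuc = cvResults.Average(fold => fold.Metrics.AreaUnderRocCurve);
/// ```
/// ]]>
/// </format>
/// </example>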
public IReadOnlyList<CrossValidationResult<CalibratedBinaryClassificationMetrics>> CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
string samplingKeyColumnName = null, int? seed = null)
{
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
return result.Select(x => new CrossValidationResult<CalibratedBinaryClassificationMetrics>(x.Model,
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
}
/// <summary>
/// Modifies the threshold of an existing model and returns the modified model.
/// </summary>
/// <typeparam name="TModel">The type of the model parameters.</typeparam>
/// <param name="model">Existing model to modify threshold.</param>
/// <param name="threshold">New threshold.</param>
/// <returns>New model with modified threshold.</returns>
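/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch, assuming `model` is a trained `BinaryPredictionTransformer<TModel>`
/// and `mlContext` is an `MLContext`:
/// ```csharp
/// // Raising the threshold flags fewer examples as positive, trading recall for precision.
/// var stricterModel = mlContext.BinaryClassification.ChangeModelThreshold(model, threshold: 0.7f);
/// ```
/// ]]>
/// </format>
/// </example>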
public BinaryPredictionTransformer<TModel> ChangeModelThreshold<TModel>(BinaryPredictionTransformer<TModel> model, float threshold)
where TModel : class
{
if (model.Threshold == threshold)
return model;
return new BinaryPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn);
}
/// <summary>
/// The list of calibrators for performing binary classification.
/// </summary>
public CalibratorsCatalog Calibrators { get; }
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of binary classification calibrators.
/// </summary>
public sealed class CalibratorsCatalog : CatalogInstantiatorBase
{
internal CalibratorsCatalog(BinaryClassificationCatalog catalog)
: base(catalog)
{
}
/// <summary>
/// Adds a probability column by training a naive binning-based calibrator.
/// </summary>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="scoreColumnName">The name of the score column.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[NaiveCalibrator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/Calibrators/Naive.cs)]
/// ]]>
/// </format>
/// </example>
public NaiveCalibratorEstimator Naive(
string labelColumnName = DefaultColumnNames.Label,
string scoreColumnName = DefaultColumnNames.Score)
{
return new NaiveCalibratorEstimator(Owner.GetEnvironment(), labelColumnName, scoreColumnName);
}
/// <summary>
/// Adds a probability column by training a <a href="https://en.wikipedia.org/wiki/Platt_scaling">Platt calibrator</a>.
/// </summary>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="scoreColumnName">The name of the score column.</param>
/// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[PlattCalibrator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/Calibrators/Platt.cs)]
/// ]]>
/// </format>
/// </example>
public PlattCalibratorEstimator Platt(
string labelColumnName = DefaultColumnNames.Label,
string scoreColumnName = DefaultColumnNames.Score,
string exampleWeightColumnName = null)
{
return new PlattCalibratorEstimator(Owner.GetEnvironment(), labelColumnName, scoreColumnName, exampleWeightColumnName);
}
/// <summary>
/// Adds a probability column by specifying a <a href="https://en.wikipedia.org/wiki/Platt_scaling">Platt calibrator</a>.
/// </summary>
/// <param name="slope">The slope in the function of the exponent of the sigmoid.</param>
/// <param name="offset">The offset in the function of the exponent of the sigmoid.</param>
/// <param name="scoreColumnName">The name of the score column.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FixedPlattCalibrator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/Calibrators/FixedPlatt.cs)]
/// ]]>
/// </format>
/// </example>
public FixedPlattCalibratorEstimator Platt(
double slope,
double offset,
string scoreColumnName = DefaultColumnNames.Score)
{
return new FixedPlattCalibratorEstimator(Owner.GetEnvironment(), slope, offset, scoreColumnName);
}
/// <summary>
/// Adds a probability column by training a pool adjacent violators (PAV) calibrator.
/// </summary>
/// <remarks>
/// The calibrator finds a stepwise constant function (using the pool adjacent violators algorithm, also known as PAV) that minimizes the squared error.
/// This is also known as <a href="https://en.wikipedia.org/wiki/Isotonic_regression">isotonic regression</a>.
/// </remarks>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="scoreColumnName">The name of the score column.</param>
/// <param name="exampleWeightColumnName">The name of the example weight column (optional).</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[PairAdjacentViolators](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/Calibrators/Isotonic.cs)]
/// ]]>
/// </format>
/// </example>
public IsotonicCalibratorEstimator Isotonic(
string labelColumnName = DefaultColumnNames.Label,
string scoreColumnName = DefaultColumnNames.Score,
string exampleWeightColumnName = null)
{
return new IsotonicCalibratorEstimator(Owner.GetEnvironment(), labelColumnName, scoreColumnName, exampleWeightColumnName);
}
}
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of clustering components,
/// such as trainers.
/// </summary>
public sealed class ClusteringCatalog : TrainCatalogBase
{
/// <summary>
/// The list of trainers for performing clustering.
/// </summary>
public ClusteringTrainers Trainers { get; }
/// <summary>
/// The clustering context.
/// </summary>
internal ClusteringCatalog(IHostEnvironment env)
: base(env, nameof(ClusteringCatalog))
{
Trainers = new ClusteringTrainers(this);
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of clustering trainers.
/// </summary>
public sealed class ClusteringTrainers : CatalogInstantiatorBase
{
internal ClusteringTrainers(ClusteringCatalog catalog)
: base(catalog)
{
}
}
/// <summary>
/// Evaluates scored clustering data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <param name="labelColumnName">The name of the optional label column in <paramref name="data"/>.
/// If present, the <see cref="ClusteringMetrics.NormalizedMutualInformation"/> metric will be computed.</param>
/// <param name="featureColumnName">The name of the optional features column in <paramref name="data"/>.
/// If present, the <see cref="ClusteringMetrics.DaviesBouldinIndex"/> metric will be computed.</param>
/// <returns>The evaluation result.</returns>
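/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext` and `scoredData` are illustrative names, with `scoredData`
/// standing for the output of a trained clustering model's `Transform` call.
/// ```csharp
/// // Supplying the features column additionally computes the Davies-Bouldin index.
/// var metrics = mlContext.Clustering.Evaluate(scoredData, featureColumnName: "Features");
/// Console.WriteLine($"Average distance: {metrics.AverageDistance:F3}, DBI: {metrics.DaviesBouldinIndex:F3}");
/// ```
/// ]]>
/// </format>
/// </example>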
public ClusteringMetrics Evaluate(IDataView data,
string labelColumnName = null,
string scoreColumnName = DefaultColumnNames.Score,
string featureColumnName = null)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
if (featureColumnName != null)
Environment.CheckNonEmpty(featureColumnName, nameof(featureColumnName), "The features column name should be non-empty if you want to calculate the Dbi metric.");
if (labelColumnName != null)
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName), "The label column name should be non-empty if you want to calculate the Nmi metric.");
var eval = new ClusteringEvaluator(Environment, new ClusteringEvaluator.Arguments() { CalculateDbi = !string.IsNullOrEmpty(featureColumnName) });
return eval.Evaluate(data, scoreColumnName, labelColumnName, featureColumnName);
}
/// <summary>
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">Optional label column for evaluation (clustering tasks may not always have a label).</param>
/// <param name="featuresColumnName">Optional features column for evaluation (needed for calculating Dbi metric)</param>
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
/// If <see langword="null"/> no row grouping will be performed.</param>
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
public IReadOnlyList<CrossValidationResult<ClusteringMetrics>> CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = null, string featuresColumnName = null,
string samplingKeyColumnName = null, int? seed = null)
{
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
return result.Select(x => new CrossValidationResult<ClusteringMetrics>(x.Model,
Evaluate(x.Scores, labelColumnName: labelColumnName, featureColumnName: featuresColumnName), x.Scores, x.Fold)).ToArray();
}
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of multiclass classification components,
/// such as trainers.
/// </summary>
public sealed class MulticlassClassificationCatalog : TrainCatalogBase
{
/// <summary>
/// The list of trainers for performing multiclass classification.
/// </summary>
public MulticlassClassificationTrainers Trainers { get; }
internal MulticlassClassificationCatalog(IHostEnvironment env)
: base(env, nameof(MulticlassClassificationCatalog))
{
Trainers = new MulticlassClassificationTrainers(this);
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of multiclass classification trainers.
/// </summary>
public sealed class MulticlassClassificationTrainers : CatalogInstantiatorBase
{
internal MulticlassClassificationTrainers(MulticlassClassificationCatalog catalog)
: base(catalog)
{
}
}
/// <summary>
/// Evaluates scored multiclass classification data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <param name="predictedLabelColumnName">The name of the predicted label column in <paramref name="data"/>.</param>
/// <param name="topKPredictionCount">If given a positive value, the <see cref="MulticlassClassificationMetrics.TopKAccuracy"/> will be filled with
/// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within
/// the top-K values as being scored "correctly."</param>
/// <returns>The evaluation results for these outputs.</returns>
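/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext` and `scoredData` are illustrative names.
/// ```csharp
/// // Request top-3 accuracy in addition to the standard metrics.
/// var metrics = mlContext.MulticlassClassification.Evaluate(scoredData, topKPredictionCount: 3);
/// Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy:F3}, top-3 accuracy: {metrics.TopKAccuracy:F3}");
/// ```
/// ]]>
/// </format>
/// </example>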
public MulticlassClassificationMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score,
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int topKPredictionCount = 0)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName));
Environment.CheckUserArg(topKPredictionCount >= 0, nameof(topKPredictionCount), "Must be non-negative");
var args = new MulticlassClassificationEvaluator.Arguments() { };
if (topKPredictionCount > 0)
args.OutputTopKAcc = topKPredictionCount;
var eval = new MulticlassClassificationEvaluator(Environment, args);
return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName);
}
/// <summary>
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">The label column (for evaluation).</param>
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
/// If <see langword="null"/> no row grouping will be performed.</param>
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public IReadOnlyList<CrossValidationResult<MulticlassClassificationMetrics>> CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
string samplingKeyColumnName = null, int? seed = null)
{
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
return result.Select(x => new CrossValidationResult<MulticlassClassificationMetrics>(x.Model,
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
}
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of regression components,
/// such as trainers and evaluators.
/// </summary>
public sealed class RegressionCatalog : TrainCatalogBase
{
/// <summary>
/// The list of trainers for performing regression.
/// </summary>
public RegressionTrainers Trainers { get; }
internal RegressionCatalog(IHostEnvironment env)
: base(env, nameof(RegressionCatalog))
{
Trainers = new RegressionTrainers(this);
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of regression trainers.
/// </summary>
public sealed class RegressionTrainers : CatalogInstantiatorBase
{
internal RegressionTrainers(RegressionCatalog catalog)
: base(catalog)
{
}
}
/// <summary>
/// Evaluates scored regression data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these outputs.</returns>
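/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext` and `scoredData` are illustrative names.
/// ```csharp
/// var metrics = mlContext.Regression.Evaluate(scoredData);
/// Console.WriteLine($"RMSE: {metrics.RootMeanSquaredError:F3}, R-squared: {metrics.RSquared:F3}");
/// ```
/// ]]>
/// </format>
/// </example>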
public RegressionMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
var eval = new RegressionEvaluator(Environment, new RegressionEvaluator.Arguments() { });
return eval.Evaluate(data, labelColumnName, scoreColumnName);
}
/// <summary>
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">The label column (for evaluation).</param>
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
/// If <see langword="null"/> no row grouping will be performed.</param>
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
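/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext`, `data`, and `pipeline` are illustrative names.
/// ```csharp
/// var cvResults = mlContext.Regression.CrossValidate(data, pipeline, numberOfFolds: 5);
/// // One common pattern: keep the model from the fold with the best hold-out metric.
/// var bestModel = cvResults.OrderByDescending(fold => fold.Metrics.RSquared).First().Model;
/// ```
/// ]]>
/// </format>
/// </example>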
public IReadOnlyList<CrossValidationResult<RegressionMetrics>> CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
string samplingKeyColumnName = null, int? seed = null)
{
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model,
Evaluate(x.Scores, labelColumnName), x.Scores, x.Fold)).ToArray();
}
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of ranking components,
/// such as trainers and evaluators.
/// </summary>
public sealed class RankingCatalog : TrainCatalogBase
{
/// <summary>
/// The list of trainers for performing ranking.
/// </summary>
public RankingTrainers Trainers { get; }
internal RankingCatalog(IHostEnvironment env)
: base(env, nameof(RankingCatalog))
{
Trainers = new RankingTrainers(this);
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of ranking trainers.
/// </summary>
public sealed class RankingTrainers : CatalogInstantiatorBase
{
internal RankingTrainers(RankingCatalog catalog)
: base(catalog)
{
}
}
/// <summary>
/// Evaluates scored ranking data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
/// <param name="rowGroupColumnName">The name of the groupId column in <paramref name="data"/>.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these outputs.</returns>
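/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext` and `scoredData` are illustrative names, and `scoredData`
/// is expected to carry label, group id, and score columns.
/// ```csharp
/// var metrics = mlContext.Ranking.Evaluate(scoredData);
/// // NDCG is reported at successive truncation levels; index 0 is NDCG@1.
/// Console.WriteLine($"NDCG@1: {metrics.NormalizedDiscountedCumulativeGains[0]:F3}");
/// ```
/// ]]>
/// </format>
/// </example>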
public RankingMetrics Evaluate(IDataView data,
string labelColumnName = DefaultColumnNames.Label,
string rowGroupColumnName = DefaultColumnNames.GroupId,
string scoreColumnName = DefaultColumnNames.Score) => Evaluate(data, null, labelColumnName, rowGroupColumnName, scoreColumnName);
/// <summary>
/// Evaluates scored ranking data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="options">Options to control the evaluation result.</param>
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
/// <param name="rowGroupColumnName">The name of the groupId column in <paramref name="data"/>.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <returns>The evaluation results for these outputs.</returns>
public RankingMetrics Evaluate(IDataView data,
RankingEvaluatorOptions options,
string labelColumnName = DefaultColumnNames.Label,
string rowGroupColumnName = DefaultColumnNames.GroupId,
string scoreColumnName = DefaultColumnNames.Score)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
Environment.CheckNonEmpty(rowGroupColumnName, nameof(rowGroupColumnName));
var eval = new RankingEvaluator(Environment, options ?? new RankingEvaluatorOptions() { });
return eval.Evaluate(data, labelColumnName, rowGroupColumnName, scoreColumnName);
}
/// <summary>
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="rowGroupColumnName"/>if provided.
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">The label column (for evaluation).</param>
/// <param name="rowGroupColumnName">The name of the groupId column in <paramref name="data"/>, which is used to group rows.
/// This column will automatically be used as SamplingKeyColumn when splitting the data for Cross Validation,
/// as this is required by the ranking algorithms
/// If <see langword="null"/> no row grouping will be performed. </param>
/// <param name="seed"> Seed for the random number generator used to select rows for cross-validation folds.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
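/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext`, `data`, and `rankingPipeline` are illustrative names.
/// Since the group id column doubles as the sampling key, all rows of a query group land in the same fold.
/// ```csharp
/// var cvResults = mlContext.Ranking.CrossValidate(data, rankingPipeline, numberOfFolds: 3);
/// var averageNdcg = cvResults.Average(fold => fold.Metrics.NormalizedDiscountedCumulativeGains[0]);
/// ```
/// ]]>
/// </format>
/// </example>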
public IReadOnlyList<CrossValidationResult<RankingMetrics>> CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
string rowGroupColumnName = DefaultColumnNames.GroupId, int? seed = null)
{
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
var result = CrossValidateTrain(data, estimator, numberOfFolds, rowGroupColumnName, seed);
return result.Select(x => new CrossValidationResult<RankingMetrics>(x.Model,
Evaluate(x.Scores, labelColumnName, rowGroupColumnName), x.Scores, x.Fold)).ToArray();
}
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of anomaly detection components,
/// such as trainers and evaluators.
/// </summary>
public sealed class AnomalyDetectionCatalog : TrainCatalogBase
{
/// <summary>
/// The list of trainers for anomaly detection.
/// </summary>
public AnomalyDetectionTrainers Trainers { get; }
internal AnomalyDetectionCatalog(IHostEnvironment env)
: base(env, nameof(AnomalyDetectionCatalog))
{
Trainers = new AnomalyDetectionTrainers(this);
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of anomaly detection trainers.
/// </summary>
public sealed class AnomalyDetectionTrainers : CatalogInstantiatorBase
{
internal AnomalyDetectionTrainers(AnomalyDetectionCatalog catalog)
: base(catalog)
{
}
}
/// <summary>
/// Evaluates scored anomaly detection data.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
/// <param name="predictedLabelColumnName">The name of the predicted label column in <paramref name="data"/>.</param>
/// <param name="falsePositiveCount">The number of false positives to compute the <see cref="AnomalyDetectionMetrics.DetectionRateAtFalsePositiveCount"/> metric. </param>
/// <returns>Evaluation results.</returns>
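/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch; `mlContext` and `scoredData` are illustrative names.
/// ```csharp
/// // The detection rate is computed at the requested false-positive budget.
/// var metrics = mlContext.AnomalyDetection.Evaluate(scoredData, falsePositiveCount: 5);
/// Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F3}, detection rate: {metrics.DetectionRateAtFalsePositiveCount:F3}");
/// ```
/// ]]>
/// </format>
/// </example>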
public AnomalyDetectionMetrics Evaluate(IDataView data, string labelColumnName = DefaultColumnNames.Label, string scoreColumnName = DefaultColumnNames.Score,
string predictedLabelColumnName = DefaultColumnNames.PredictedLabel, int falsePositiveCount = 10)
{
Environment.CheckValue(data, nameof(data));
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
Environment.CheckNonEmpty(predictedLabelColumnName, nameof(predictedLabelColumnName));
var args = new AnomalyDetectionEvaluator.Arguments();
args.K = falsePositiveCount;
var eval = new AnomalyDetectionEvaluator(Environment, args);
return eval.Evaluate(data, labelColumnName, scoreColumnName, predictedLabelColumnName);
}
/// <summary>
/// Creates a new <see cref="AnomalyPredictionTransformer{TModel}"/> with the specified <paramref name="threshold"/>.
/// If the provided <paramref name="threshold"/> is the same as the <paramref name="model"/> threshold it simply returns <paramref name="model"/>.
/// Note that by default the threshold is 0.5 and valid scores range from 0 to 1.
/// </summary>
/// <param name="model">A trained <see cref="AnomalyPredictionTransformer{TModel}"/>.</param>
/// <param name="threshold">The new threshold value that will be used to determine the label of a data point
/// based on the predicted score by the model.</param>
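/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// A minimal usage sketch, assuming `model` is a trained `AnomalyPredictionTransformer<TModel>`
/// and `mlContext` is an `MLContext`:
/// ```csharp
/// // Lowering the threshold below the 0.5 default flags more examples as anomalies.
/// var sensitiveModel = mlContext.AnomalyDetection.ChangeModelThreshold(model, threshold: 0.3f);
/// ```
/// ]]>
/// </format>
/// </example>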
public AnomalyPredictionTransformer<TModel> ChangeModelThreshold<TModel>(AnomalyPredictionTransformer<TModel> model, float threshold)
where TModel : class
{
if (model.Threshold == threshold)
return model;
return new AnomalyPredictionTransformer<TModel>(Environment, model.Model, model.TrainSchema, model.FeatureColumnName, threshold, model.ThresholdColumn);
}
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of forecasting components.
/// </summary>
public sealed class ForecastingCatalog : TrainCatalogBase
{
/// <summary>
/// The list of trainers for performing forecasting.
/// </summary>
public Forecasters Trainers { get; }
internal ForecastingCatalog(IHostEnvironment env) : base(env, nameof(ForecastingCatalog))
{
Trainers = new Forecasters(this);
}
/// <summary>
/// Class used by <see cref="MLContext"/> to create instances of forecasting trainers.
/// </summary>
public sealed class Forecasters : CatalogInstantiatorBase
{
internal Forecasters(ForecastingCatalog catalog)
: base(catalog)
{
}
}
}
}