// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
// Declares PriorTrainer as a loadable class under the binary-classifier and general trainer signatures.
// NOTE(review): this attribute's argument list is not closed in this view (no trailing arguments or ")]")
// before the #pragma below -- confirm against the complete file.
[assembly: LoadableClass(PriorTrainer.Summary, typeof(PriorTrainer), typeof(PriorTrainer.Options),
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
#pragma warning disable CS0618 // Still desired for backwards compatibility of models, for now.
// Declares the obsolete RandomModelParameters as a model loader keyed by its LoaderSignature.
[assembly: LoadableClass(typeof(RandomModelParameters), null, typeof(SignatureLoadModel),
"Random predictor", RandomModelParameters.LoaderSignature)]
#pragma warning restore CS0618
// Declares PriorModelParameters as a model loader keyed by its LoaderSignature.
[assembly: LoadableClass(typeof(PriorModelParameters), null, typeof(SignatureLoadModel),
"Prior predictor", PriorModelParameters.LoaderSignature)]
namespace Microsoft.ML.Trainers
/// <summary>
/// This is a faulty model parameters class that claims to be an <see cref="IValueMapperDist"/>. We maintain the
/// class for now for backwards compatibility of models (assuming there are any), but might consider removing it
/// in the future.
/// </summary>
[Obsolete("This class was fundamentally misdesigned, is incapable of implementing the interfaces it claims to implement, " +
"and it should not be used for any purposes beyond backwards compatibility of model loading.")]
internal sealed class RandomModelParameters :
// NOTE(review): the base-type list ends with a trailing comma in this view; further entries appear to be missing -- confirm.
IDistPredictorProducing<float, float>,
// Loader signature referenced by the assembly-level LoadableClass registration and by GetVersionInfo().
internal const string LoaderSignature = "RandomPredictor";
/// <summary>
/// Builds the version descriptor for serialized <see cref="RandomModelParameters"/> instances.
/// The written, readable and read-back versions are all the same value.
/// </summary>
private static VersionInfo GetVersionInfo()
{
    const uint version = 0x00010001;
    string assemblyName = typeof(RandomModelParameters).Assembly.FullName;

    return new VersionInfo(
        modelSignature: "RANDOMPR", // Unique 8-letter string.
        verWrittenCur: version,
        verReadableCur: version,
        verWeCanReadBack: version,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: assemblyName);
}
// Keep all the serializable state here.
// Seed supplied by the caller or read from the model stream; used to create _random.
private readonly int _seed;
// Lock object taken around every use of _random in PredictCore().
private readonly object _instanceLock;
// Generator created from _seed via RandomUtils.Create.
private readonly Random _random;
/// <summary>Return the type of prediction task.</summary>
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
// Input type reported through IValueMapper.InputType; set in both constructors to a vector of Single.
private readonly DataViewType _inputType;
DataViewType IValueMapper.InputType => _inputType;
// Both the score and the distribution (probability) outputs are reported as Single.
DataViewType IValueMapper.OutputType => NumberDataViewType.Single;
DataViewType IValueMapperDist.DistType => NumberDataViewType.Single;
/// <summary>
/// Creates a model that returns a uniform random probability, driven by a generator
/// seeded with <paramref name="seed"/>.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="seed">The random seed.</param>
internal RandomModelParameters(IHostEnvironment env, int seed)
    : base(env, LoaderSignature)
{
    _inputType = new VectorDataViewType(NumberDataViewType.Single);
    _instanceLock = new object();

    _seed = seed;
    _random = RandomUtils.Create(seed);
}
/// <summary>
/// Deserializing constructor: restores the predictor from the binary format.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="ctx">Load context positioned at this model's data.</param>
private RandomModelParameters(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, LoaderSignature, ctx)
{
    // *** Binary format ***
    // int: _seed
    int storedSeed = ctx.Reader.ReadInt32();

    _seed = storedSeed;
    _random = RandomUtils.Create(storedSeed);
    _instanceLock = new object();
    _inputType = new VectorDataViewType(NumberDataViewType.Single);
}
/// <summary>
/// Loader entry point: validates the arguments and the model header, then deserializes the predictor.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ctx">The load context to read from; must not be null.</param>
/// <returns>The deserialized <see cref="RandomModelParameters"/>.</returns>
internal static RandomModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));

    // Verify that the stream carries the signature/version this class declares before reading any fields.
    ctx.CheckAtModel(GetVersionInfo());

    return new RandomModelParameters(env, ctx);
}
/// <summary>
/// Save the predictor in the binary format.
/// </summary>
/// <param name="ctx">The save context to write the version info and the seed to.</param>
private protected override void SaveCore(ModelSaveContext ctx)
{
    base.SaveCore(ctx);
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // int: _seed
    // Must mirror the ReadInt32() performed by the loading constructor.
    ctx.Writer.Write(_seed);
}
/// <summary>
/// Returns the score mapper. Only <c>VBuffer&lt;float&gt;</c> input and <c>float</c> output are accepted.
/// </summary>
ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
{
    Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
    Contracts.Check(typeof(TOut) == typeof(float));

    // Go through Delegate to re-type the concrete mapper as the generic one the caller asked for.
    return (ValueMapper<TIn, TOut>)(Delegate)new ValueMapper<VBuffer<float>, float>(Map);
}
/// <summary>
/// Returns the score-and-probability mapper. Only <c>VBuffer&lt;float&gt;</c> input with <c>float</c>
/// score and <c>float</c> probability is accepted.
/// </summary>
ValueMapper<TIn, TOut, TDist> IValueMapperDist.GetMapper<TIn, TOut, TDist>()
{
    Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
    Contracts.Check(typeof(TOut) == typeof(float));
    Contracts.Check(typeof(TDist) == typeof(float));

    // Go through Delegate to re-type the concrete mapper as the generic one the caller asked for.
    return (ValueMapper<TIn, TOut, TDist>)(Delegate)new ValueMapper<VBuffer<float>, float, float>(MapDist);
}
/// <summary>
/// Draws the next raw score: a random value rescaled from [0, 1) to [-1, 1).
/// </summary>
private float PredictCore()
{
    double sample;

    // Predict can be called from different threads, so access to the shared
    // generator is serialized through the instance lock.
    lock (_instanceLock)
    {
        sample = _random.NextDouble();
    }

    // For binary classification, prediction may be in range [-infinity, infinity].
    // default class decision boundary is at 0.
    return (sample * 2 - 1).ToFloat();
}
/// <summary>Writes a freshly drawn score to <paramref name="dst"/>; <paramref name="src"/> is not read.</summary>
private void Map(in VBuffer<float> src, ref float dst)
    => dst = PredictCore();
/// <summary>
/// Writes a freshly drawn score and the probability derived from it as (score + 1) / 2;
/// <paramref name="src"/> is not read.
/// </summary>
private void MapDist(in VBuffer<float> src, ref float score, ref float prob)
{
    float raw = PredictCore();
    score = raw;
    prob = (raw + 1) / 2;
}
/// <summary>
/// The <see cref="IEstimator{TTransformer}"/> for predicting a target using a binary classification model.
/// </summary>
/// <remarks>
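/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [Prior](xref:Microsoft.ML.StandardTrainersCatalog.Prior(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String,System.String))
///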
/// [!include[io](~/../docs/samples/docs/api-reference/io-columns-binary-classification.md)]
/// ### Trainer Characteristics
/// | | |
/// | -- | -- |
/// | Machine learning task | Binary classification |
/// | Is normalization required? | No |
/// | Is caching required? | No |
/// | Required NuGet in addition to Microsoft.ML | None |
/// | Exportable to ONNX | Yes |
/// ### Training Algorithm Details
/// Learns the prior distribution for 0/1 class labels and outputs that.
/// Check the See Also section for links to usage examples.
/// ]]>
/// </format>
/// </remarks>
/// <seealso cref="Microsoft.ML.StandardTrainersCatalog.Prior(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String,System.String)"/>
public sealed class PriorTrainer : ITrainer<PriorModelParameters>,
ITrainerEstimator<BinaryPredictionTransformer<PriorModelParameters>, PriorModelParameters>
// Name under which both constructors register this trainer's host.
internal const string LoadNameValue = "PriorPredictor";
// Human-readable trainer name.
internal const string UserNameValue = "Prior Predictor";
// One-line description, also passed to the assembly-level LoadableClass registration.
internal const string Summary = "A trivial model for producing the prior based on the number of positive and negative examples.";
// Options type referenced by the LoadableClass registration; no members are visible in this view.
internal sealed class Options
// Label column name handed to RoleMappedData and the resulting transformer in Fit().
private readonly string _labelColumnName;
// Weight column name handed to RoleMappedData in Fit(); may be null.
private readonly string _weightColumnName;
// Host registered from the environment; used for argument checks and passed on to created objects.
private readonly IHost _host;
/// <summary> Return the type of prediction task.</summary>
PredictionKind ITrainer.PredictionKind => PredictionKind.BinaryClassification;
// Shared trainer info: neither normalization nor caching is requested.
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
/// <summary>
/// Auxiliary information about the trainer in terms of its capabilities
/// and requirements.
/// </summary>
public TrainerInfo Info => _info;
/// <summary>
/// Creates the trainer from an <see cref="Options"/> instance. Only argument validation and host
/// registration happen here; the label and weight column names are not assigned by this constructor.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="options">The trainer options; must not be null.</param>
internal PriorTrainer(IHostEnvironment env, Options options)
{
    Contracts.CheckValue(env, nameof(env));

    IHost host = env.Register(LoadNameValue);
    host.CheckValue(options, nameof(options));
    _host = host;
}
/// <summary>
/// Initializes PriorTrainer object.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="labelColumn">Name of the label column; must not be null.</param>
/// <param name="weightColunn">Name of the example-weight column, or null when no weight column is used.</param>
internal PriorTrainer(IHostEnvironment env, String labelColumn, String weightColunn = null)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(LoadNameValue);
    _host.CheckValue(labelColumn, nameof(labelColumn));

    _labelColumnName = labelColumn;
    // A null weight column name is stored as-is; no conditional is needed to preserve it.
    _weightColumnName = weightColunn;
}
/// <summary>
/// Trains and returns a <see cref="BinaryPredictionTransformer{PriorModelParameters}"/>.
/// </summary>
/// <param name="input">The data to train on; its schema is also passed to the returned transformer.</param>
public BinaryPredictionTransformer<PriorModelParameters> Fit(IDataView input)
{
    // No feature column is mapped: only the label (and optional weight) roles are assigned.
    var roles = new RoleMappedData(input, label: _labelColumnName, feature: null, weight: _weightColumnName);
    var context = new TrainContext(roles);

    ITrainer<PriorModelParameters> trainer = this;
    PriorModelParameters model = trainer.Train(context);

    return new BinaryPredictionTransformer<PriorModelParameters>(_host, model, input.Schema, featureColumn: null, labelColumn: _labelColumnName);
}
/// <summary>
/// Computes the (optionally weighted) share of positive labels in the training set and wraps it in a
/// <see cref="PriorModelParameters"/>.
/// </summary>
/// <param name="context">The training context; its training set must have a Boolean label column.</param>
/// <returns>
/// Model parameters whose probability is pos / (pos + neg), or <see cref="float.NaN"/> when the
/// accumulated weight is not positive (for example, when no row was counted).
/// </returns>
private PriorModelParameters Train(TrainContext context)
{
    _host.CheckValue(context, nameof(context));
    var data = context.TrainingSet;
    _host.CheckParam(data.Schema.Label.HasValue, nameof(data), "Missing Label column");
    var labelCol = data.Schema.Label.Value;
    _host.CheckParam(labelCol.Type == BooleanDataViewType.Instance, nameof(data), "Invalid type for Label column");

    double pos = 0;
    double neg = 0;

    // A weight column is only honored when it exists and is of type Single.
    bool useWeight = data.Schema.Weight?.Type == NumberDataViewType.Single;
    var cols = useWeight
        ? new DataViewSchema.Column[] { labelCol, data.Schema.Weight.Value }
        : new DataViewSchema.Column[] { labelCol };

    using (var cursor = data.Data.GetRowCursor(cols))
    {
        var getLab = cursor.GetGetter<bool>(labelCol);
        var getWeight = useWeight ? cursor.GetGetter<float>(data.Schema.Weight.Value) : null;
        bool lab = default;
        float weight = 1; // Stays 1 for every row when there is no weight getter.

        while (cursor.MoveNext())
        {
            getLab(ref lab);
            if (getWeight != null)
            {
                getWeight(ref weight);

                // Skip rows whose weight is not a positive finite number.
                // Testing both directions effectively ignores NaNs.
                if (!(0 < weight && weight < float.PositiveInfinity))
                    continue;
            }

            // Each row contributes to exactly one of the two totals.
            if (lab)
                pos += weight;
            else
                neg += weight;
        }
    }

    float prob = pos + neg > 0 ? (float)(pos / (pos + neg)) : float.NaN;
    return new PriorModelParameters(_host, prob);
}
// Explicit interface implementations: both forward to the private Train(TrainContext) method.
IPredictor ITrainer.Train(TrainContext context) => Train(context);
PriorModelParameters ITrainer<PriorModelParameters>.Train(TrainContext context) => Train(context);
// Builds a schema-shape column describing a vector of Single named featureColumn.
// NOTE(review): no caller of this helper is visible in this view -- confirm whether it is still needed.
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
=> new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
// Builds a schema-shape column describing a scalar of Single named labelColumn.
// NOTE(review): Train() requires the label column to be Boolean, while this helper declares Single;
// no caller is visible in this view -- confirm which type is intended.
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);
/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
/// <param name="inputSchema">The incoming schema shape; must not be null.</param>
/// <returns>
/// The input columns plus the Score, Probability and PredictedLabel columns; an input column with
/// one of those names is replaced.
/// </returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));

    var columnsByName = inputSchema.ToDictionary(x => x.Name);

    // Adds the column, overwriting any existing entry with the same name.
    void Put(SchemaShape.Column column) => columnsByName[column.Name] = column;

    Put(new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation())));
    Put(new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true))));
    Put(new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation())));

    return new SchemaShape(columnsByName.Values);
}
/// <summary>
/// Model parameters for <see cref="PriorTrainer"/>. Stores a single probability and the raw score
/// derived from it, and returns those same two values for every input.
/// </summary>
public sealed class PriorModelParameters :
IDistPredictorProducing<float, float>,
IValueMapperDist, ISingleCanSaveOnnx
// Loader signature referenced by the assembly-level LoadableClass registration and by GetVersionInfo().
internal const string LoaderSignature = "PriorPredictor";
/// <summary>
/// Builds the version descriptor for serialized <see cref="PriorModelParameters"/> instances.
/// The written, readable and read-back versions are all the same value.
/// </summary>
private static VersionInfo GetVersionInfo()
{
    const uint version = 0x00010001;
    string assemblyName = typeof(PriorModelParameters).Assembly.FullName;

    return new VersionInfo(
        modelSignature: "PRIORPRD",
        verWrittenCur: version,
        verReadableCur: version,
        verWeCanReadBack: version,
        loaderSignature: LoaderSignature,
        loaderAssemblyName: assemblyName);
}
// Probability of the positive class, as supplied to the constructor or read from the model stream.
private readonly float _prob;
// Raw score derived from _prob as 2 * _prob - 1.
private readonly float _raw;
// ONNX export is always reported as supported.
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
/// <summary>
/// Creates a model that always returns the given prior probability of the positive class,
/// together with the raw score derived from it.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="prob">The probability of the positive class.</param>
internal PriorModelParameters(IHostEnvironment env, float prob)
    : base(env, LoaderSignature)
{
    _inputType = new VectorDataViewType(NumberDataViewType.Single);

    _prob = prob;
    _raw = 2 * prob - 1; // This could be other functions -- logodds for instance
}
/// <summary>
/// Deserializing constructor: reads the stored probability and recomputes the raw score from it.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="ctx">Load context positioned at this model's data.</param>
private PriorModelParameters(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, LoaderSignature, ctx)
{
    // *** Binary format ***
    // Float: _prob
    float storedProb = ctx.Reader.ReadFloat();

    _prob = storedProb;
    _raw = 2 * storedProb - 1;
    _inputType = new VectorDataViewType(NumberDataViewType.Single);
}
/// <summary>
/// Loader entry point: validates the arguments and the model header, then deserializes the model parameters.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ctx">The load context to read from; must not be null.</param>
/// <returns>The deserialized <see cref="PriorModelParameters"/>.</returns>
internal static PriorModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));

    // Verify that the stream carries the signature/version this class declares before reading any fields.
    ctx.CheckAtModel(GetVersionInfo());

    return new PriorModelParameters(env, ctx);
}
/// <summary>
/// Saves the model parameters in the binary format.
/// </summary>
/// <param name="ctx">The save context to write the version info and the probability to.</param>
private protected override void SaveCore(ModelSaveContext ctx)
{
    base.SaveCore(ctx);
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // Float: _prob
    // Must mirror the ReadFloat() performed by the loading constructor; _raw is recomputed on load.
    ctx.Writer.Write(_prob);
}
/// <summary>
/// Emits the ONNX nodes that produce this model's constant score and probability outputs.
/// </summary>
/// <param name="ctx">The ONNX context that receives initializers, intermediate variables and nodes.</param>
/// <param name="outputs">Output variable names; at least three are required, and indices 1 and 2 are used
/// for the score and the probability respectively.</param>
/// <param name="labelColumn">Name of the variable fed to both inputs of the Xor node.</param>
/// <returns>Always <c>true</c>.</returns>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string labelColumn)
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) >= 3);
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
// outputs[0] is not written by this method.
string scoreVarName = outputs[1];
string probVarName = outputs[2];
// Constant initializers holding the stored probability and raw score.
var prob = ctx.AddInitializer(_prob, "probability");
var score = ctx.AddInitializer(_raw, "score");
// Xor of the label variable with itself, followed by Not and a Cast to the Single type.
// NOTE(review): presumably this builds an all-ones float tensor shaped like the label input, so that
// the Mul nodes below emit the constants once per row -- confirm against the ONNX operator semantics.
var xorOutput = ctx.AddIntermediateVariable(null, "XorOutput", true);
string opType = "Xor";
ctx.CreateNode(opType, new[] { labelColumn, labelColumn }, new[] { xorOutput }, ctx.GetNodeName(opType), "");
var notOutput = ctx.AddIntermediateVariable(null, "NotOutput", true);
opType = "Not";
ctx.CreateNode(opType, xorOutput, notOutput, ctx.GetNodeName(opType), "");
var castOutput = ctx.AddIntermediateVariable(null, "CastOutput", true);
opType = "Cast";
var node = ctx.CreateNode(opType, notOutput, castOutput, ctx.GetNodeName(opType), "");
// The Cast target type is the internal type corresponding to DataKind.Single.
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);
// Probability output = cast result * probability constant.
opType = "Mul";
ctx.CreateNode(opType, new[] { castOutput, prob }, new[] { probVarName }, ctx.GetNodeName(opType), "");
// Score output = cast result * score constant.
opType = "Mul";
ctx.CreateNode(opType, new[] { castOutput, score }, new[] { scoreVarName }, ctx.GetNodeName(opType), "");
return true;
/// <summary>Return the type of prediction task.</summary>
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
// Input type reported through IValueMapper.InputType; set in both constructors to a vector of Single.
private readonly DataViewType _inputType;
DataViewType IValueMapper.InputType => _inputType;
// Both the score and the distribution (probability) outputs are reported as Single.
DataViewType IValueMapper.OutputType => NumberDataViewType.Single;
DataViewType IValueMapperDist.DistType => NumberDataViewType.Single;
/// <summary>
/// Returns the score mapper. Only <c>VBuffer&lt;float&gt;</c> input and <c>float</c> output are accepted.
/// </summary>
ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
{
    Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
    Contracts.Check(typeof(TOut) == typeof(float));

    // Go through Delegate to re-type the concrete mapper as the generic one the caller asked for.
    return (ValueMapper<TIn, TOut>)(Delegate)new ValueMapper<VBuffer<float>, float>(Map);
}
/// <summary>
/// Returns the score-and-probability mapper. Only <c>VBuffer&lt;float&gt;</c> input with <c>float</c>
/// score and <c>float</c> probability is accepted.
/// </summary>
ValueMapper<TIn, TOut, TDist> IValueMapperDist.GetMapper<TIn, TOut, TDist>()
{
    Contracts.Check(typeof(TIn) == typeof(VBuffer<float>));
    Contracts.Check(typeof(TOut) == typeof(float));
    Contracts.Check(typeof(TDist) == typeof(float));

    // Go through Delegate to re-type the concrete mapper as the generic one the caller asked for.
    return (ValueMapper<TIn, TOut, TDist>)(Delegate)new ValueMapper<VBuffer<float>, float, float>(MapDist);
}
/// <summary>Writes the stored raw score to <paramref name="dst"/>; <paramref name="src"/> is not read.</summary>
private void Map(in VBuffer<float> src, ref float dst)
    => dst = _raw;
/// <summary>
/// Writes the stored raw score and probability to the outputs; <paramref name="src"/> is not read.
/// </summary>
private void MapDist(in VBuffer<float> src, ref float score, ref float prob)
{
    prob = _prob;
    score = _raw;
}