using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Transforms.TimeSeries
/// <summary>
/// An <see cref="IRowToRowMapper"/> flavour used by stateful (time series) transforms.
/// It declares no members of its own.
/// </summary>
internal interface IStatefulRowToRowMapper : IRowToRowMapper
{
}
/// <summary>
/// An <see cref="ITransformer"/> that can hand out a stateful row-to-row mapper and can be cloned.
/// </summary>
internal interface IStatefulTransformer : ITransformer
{
    /// <summary>
    /// Counterpart of <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> that additionally
    /// supports a mechanism to save the state.
    /// </summary>
    /// <param name="inputSchema">The input schema for which the mapper is requested.</param>
    /// <returns>The row to row mapper.</returns>
    IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema);

    /// <summary>
    /// Creates a clone of the transformer. Used for taking a snapshot of the state, so that
    /// multiple time series can each be given their own state.
    /// </summary>
    /// <returns>The cloned transformer.</returns>
    IStatefulTransformer Clone();
}
/// <summary>
/// A <see cref="DataViewRow"/> that exposes a pinger delegate accepting a <see cref="PingerArgument"/>.
/// </summary>
internal abstract class StatefulRow : DataViewRow
{
    /// <summary>
    /// Gets the delegate to be invoked with a <see cref="PingerArgument"/> for this row.
    /// </summary>
    /// <returns>The pinger delegate of this row.</returns>
    public abstract Action<PingerArgument> GetPinger();
}
/// <summary>
/// An <see cref="IRowMapper"/> that can additionally clone its state and create a pinger.
/// </summary>
internal interface IStatefulRowMapper : IRowMapper
{
    /// <summary>
    /// Clones the state held by the mapper.
    /// NOTE(review): exact copy semantics are defined by the implementers - confirm there.
    /// </summary>
    void CloneState();

    /// <summary>
    /// Creates the delegate that is invoked with a <see cref="PingerArgument"/>.
    /// </summary>
    /// <param name="input">The input row the pinger works against.</param>
    /// <param name="activeOutput">
    /// Predicate over an integer index; presumably tells whether the output column at that index is
    /// active - confirm against implementers.
    /// </param>
    /// <param name="disposer">Receives an action for releasing whatever the pinger holds.</param>
    /// <returns>The pinger delegate.</returns>
    Action<PingerArgument> CreatePinger(DataViewRow input, Func<int, bool> activeOutput, out Action disposer);
}
/// <summary>
/// Bag of values passed to a pinger delegate on every ping.
/// </summary>
internal class PingerArgument
{
    /// <summary>Position of the row the ping refers to.</summary>
    public long RowPosition;

    /// <summary>Confidence level used by forecasting models; null when not specified.</summary>
    public float? ConfidenceLevel;

    /// <summary>Number of values to forecast; null when not specified.</summary>
    public int? Horizon;

    /// <summary>When true, signals time series models to not fetch source values in getters.</summary>
    public bool DontConsumeSource;
}
/// <summary>
/// A class that runs the previously trained model (and the preceding transform pipeline) on the
/// in-memory data, one example at a time.
/// This can also be used with trained pipelines that do not end with a predictor: in this case, the
/// 'prediction' will be just the outcome of all the transformations.
/// </summary>
/// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
/// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
public sealed class TimeSeriesPredictionEngine<TSrc, TDst> : PredictionEngineBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
/// <summary>
/// Delegate invoked by the Predict overloads with a <see cref="PingerArgument"/>;
/// assigned in PredictionEngineCore from the result of CreatePinger.
/// </summary>
private Action<PingerArgument> _pinger;
/// <summary>
/// Row position that is passed as <see cref="PingerArgument.RowPosition"/> on every ping.
/// </summary>
private long _rowPosition;
/// <summary>
/// The transformer accepted by TransformerChecker; GetRowToRowMapper builds its mappers from it.
/// </summary>
private ITransformer InputTransformer { get; set; }
/// <summary>
/// Checkpoints <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/>, with its updated state, to a file on disk.
/// </summary>
/// <param name="env">Usually <see cref="MLContext"/>.</param>
/// <param name="modelPath">Path to the file on disk where the updated model needs to be saved.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// ]]>
/// </format>
/// </example>
public void CheckPoint(IHostEnvironment env, string modelPath)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckParam(!string.IsNullOrEmpty(modelPath), nameof(modelPath));

    // The stream overload does the actual saving; this overload only owns the file handle.
    using (FileStream output = File.Create(modelPath))
    {
        CheckPoint(env, output);
    }
}
/// <summary>
/// Checkpoints <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/>, with its updated state, to a <see cref="Stream"/>.
/// </summary>
/// <param name="env">Usually <see cref="MLContext"/>.</param>
/// <param name="stream">Stream where the updated model needs to be saved.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// ]]>
/// </format>
/// </example>
public void CheckPoint(IHostEnvironment env, Stream stream)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckParam(stream != null, nameof(stream));

    if (Transformer is ITransformerChainAccessor chain)
    {
        // A chain is re-wrapped from its transformers and scopes and saved exactly once.
        new TransformerChain<ITransformer>(chain.Transformers, chain.Scopes).SaveTo(env, stream);
    }
    else
    {
        // Only a non-chain transformer is saved directly; saving here as well for a chain would
        // write the model into the same stream a second time.
        Transformer.SaveTo(env, stream);
    }
}
/// <summary>
/// Produces the transformer the engine will actually run: every <see cref="IStatefulTransformer"/>
/// is replaced by its <see cref="IStatefulTransformer.Clone"/>, everything else is reused as is.
/// </summary>
/// <param name="transformer">A single transformer or a transformer chain.</param>
/// <returns>
/// For a chain, a new <see cref="TransformerChain{TLastTransformer}"/> over the (partially cloned)
/// transformers and a copy of the scopes; otherwise the clone of a stateful transformer, or the
/// transformer itself when it is not stateful.
/// </returns>
private static ITransformer CloneTransformers(ITransformer transformer)
{
    if (transformer is ITransformerChainAccessor chain)
    {
        // Project into fresh arrays instead of copying and then overwriting while enumerating.
        ITransformer[] clonedTransformers = chain.Transformers
            .Select(xf => (xf is IStatefulTransformer stateful) ? stateful.Clone() : xf)
            .ToArray();
        TransformerScope[] clonedScopes = chain.Scopes.ToArray();

        return new TransformerChain<ITransformer>(clonedTransformers, clonedScopes);
    }

    return (transformer is IStatefulTransformer single) ? single.Clone() : transformer;
}
/// <summary>
/// Constructor for creating a time series specific prediction engine. It allows the time series model to be
/// updated with the observations seen at prediction time via <see cref="CheckPoint(IHostEnvironment, string)"/>.
/// The engine runs on the result of CloneTransformers, not on <paramref name="transformer"/> itself.
/// </summary>
/// <param name="env">Usually <see cref="MLContext"/>.</param>
/// <param name="transformer">The time series pipeline.</param>
/// <param name="ignoreMissingColumns">Forwarded to the base engine.</param>
/// <param name="inputSchemaDefinition">Input schema definition. Default is null.</param>
/// <param name="outputSchemaDefinition">Output schema definition. Default is null.</param>
public TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
    SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
    : base(env, CloneTransformers(transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}
/// <summary>
/// Constructor for creating a time series specific prediction engine from <see cref="PredictionEngineOptions"/>.
/// It allows the time series model to be updated with the observations seen at prediction time via
/// <see cref="CheckPoint(IHostEnvironment, string)"/>.
/// The engine runs on the result of CloneTransformers, not on <paramref name="transformer"/> itself.
/// </summary>
/// <param name="env">Usually <see cref="MLContext"/>.</param>
/// <param name="transformer">The time series pipeline.</param>
/// <param name="options">Options whose individual values are forwarded to the base engine.</param>
internal TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer, PredictionEngineOptions options)
    : base(env, CloneTransformers(transformer), options.IgnoreMissingColumns, options.InputSchemaDefinition,
        options.OutputSchemaDefinition, options.OwnsTransformer)
{
}
/// <summary>
/// Obtains the output row of <paramref name="mapper"/> for <paramref name="input"/>, walking through the
/// inner mappers of a <see cref="CompositeRowToRowMapper"/>, and collects every <see cref="StatefulRow"/>
/// produced on the way into <paramref name="rows"/>.
/// </summary>
/// <param name="input">The row fed into the mapper.</param>
/// <param name="mapper">A single mapper or a composite of mappers.</param>
/// <param name="activeColumns">The columns that must be active in the returned row.</param>
/// <param name="rows">Receives the stateful rows that were created; each row is added once.</param>
/// <returns>The row produced by the (last) mapper.</returns>
internal DataViewRow GetStatefulRows(DataViewRow input, IRowToRowMapper mapper, IEnumerable<DataViewSchema.Column> activeColumns, List<StatefulRow> rows)
{
    Contracts.CheckValue(input, nameof(input));
    Contracts.CheckValue(mapper, nameof(mapper));
    Contracts.CheckValue(activeColumns, nameof(activeColumns));
    Contracts.CheckValue(rows, nameof(rows));

    IRowToRowMapper[] innerMappers = mapper is CompositeRowToRowMapper compositeMapper
        ? compositeMapper.InnerMappers
        : Array.Empty<IRowToRowMapper>();

    if (innerMappers.Length == 0)
    {
        // Leaf mapper: every requested column index must already be active in the input row.
        var activeIndices = new HashSet<int>(activeColumns.Select(c => c.Index));
        for (int c = 0; c < input.Schema.Count; ++c)
        {
            var column = input.Schema[c];
            if (activeIndices.Contains(c) && !input.IsColumnActive(column))
            {
                throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{column.Name}' active but it was not.");
            }
        }

        var row = mapper.GetRow(input, activeColumns);
        if (row is StatefulRow statefulRow)
        {
            // Without this registration CreatePinger never sees a row and the pinger stays a no-op.
            rows.Add(statefulRow);
        }

        return row;
    }

    // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
    // what we need from them. The last one will just have the input, but the rest will need to be
    // computed based on the dependencies of the next one in the chain.
    var deps = new IEnumerable<DataViewSchema.Column>[innerMappers.Length];
    deps[deps.Length - 1] = activeColumns;
    for (int i = deps.Length - 1; i >= 1; --i)
    {
        deps[i - 1] = innerMappers[i].GetDependencies(deps[i]);
    }

    DataViewRow result = input;
    for (int i = 0; i < innerMappers.Length; ++i)
    {
        // Every recursive call ends in the leaf branch above, which already registers a stateful
        // row; registering the returned row here again would only add a duplicate.
        result = GetStatefulRows(result, innerMappers[i], deps[i], rows);
    }

    return result;
}
/// <summary>
/// Combines the pingers of all <paramref name="rows"/> into a single multicast delegate.
/// </summary>
/// <param name="rows">The stateful rows whose pingers should be combined.</param>
/// <returns>
/// A delegate that does nothing when <paramref name="rows"/> is empty; otherwise the combination of
/// every row's <see cref="StatefulRow.GetPinger"/> result, in list order.
/// </returns>
internal Action<PingerArgument> CreatePinger(List<StatefulRow> rows)
{
    if (rows.Count > 0)
    {
        Action<PingerArgument> combined = null;
        for (int i = 0; i < rows.Count; i++)
        {
            combined += rows[i].GetPinger();
        }

        return combined;
    }

    return ignored => { };
}
/// <summary>
/// Builds the output row for <paramref name="inputRow"/>, remembers the combined pinger of all stateful
/// rows created on the way, and hands back the typed output row together with its disposer.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="inputRow">The row wrapping the user supplied example.</param>
/// <param name="mapper">The mapper obtained for the transformer.</param>
/// <param name="ignoreMissingColumns">Forwarded to the typed cursorable.</param>
/// <param name="outputSchemaDefinition">Forwarded to the typed cursorable.</param>
/// <param name="disposer">Receives the dispose action of the output row.</param>
/// <param name="outputRow">Receives the output row readable as <typeparamref name="TDst"/>.</param>
private protected override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow<TSrc> inputRow,
    IRowToRowMapper mapper, bool ignoreMissingColumns, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs<TDst> outputRow)
{
    var statefulRows = new List<StatefulRow>();
    DataViewRow outputRowLocal = GetStatefulRows(inputRow, mapper, mapper.OutputSchema, statefulRows);
    var cursorable = TypedCursorable<TDst>.Create(env, new EmptyDataView(env, mapper.OutputSchema), ignoreMissingColumns, outputSchemaDefinition);

    // The pinger is what later pushes row position / horizon / confidence into the stateful rows.
    _pinger = CreatePinger(statefulRows);
    disposer = outputRowLocal.Dispose;
    outputRow = cursorable.GetRow(outputRowLocal);
}
/// <summary>
/// Tells whether <paramref name="transformer"/> can be used by this engine: it, or every member of
/// the chain it represents, must be a row to row mapper or an <see cref="IStatefulTransformer"/>.
/// </summary>
/// <param name="transformer">A single transformer or a transformer chain.</param>
/// <returns>True when the transformer (or every chain member) qualifies.</returns>
private bool IsRowToRowMapper(ITransformer transformer)
{
    bool Qualifies(ITransformer candidate) => candidate.IsRowToRowMapper || candidate is IStatefulTransformer;

    if (!(transformer is ITransformerChainAccessor chain))
    {
        return Qualifies(transformer);
    }

    foreach (var link in chain.Transformers)
    {
        if (!Qualifies(link))
        {
            return false;
        }
    }

    return true;
}
/// <summary>
/// Creates the row to row mapper for <see cref="InputTransformer"/>. Stateful transformers contribute
/// their stateful mapper, all others their regular mapper; a chain yields a
/// <see cref="CompositeRowToRowMapper"/> in which each mapper is fed the output schema of the previous one.
/// </summary>
/// <param name="inputSchema">The schema of the input data.</param>
/// <returns>The mapper for the whole transformer.</returns>
private IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
{
    Contracts.CheckValue(inputSchema, nameof(inputSchema));
    Contracts.Check(IsRowToRowMapper(InputTransformer), nameof(GetRowToRowMapper) +
        " method called despite " + nameof(IsRowToRowMapper) + " being false. or transformer not being " + nameof(IStatefulTransformer));

    // Picks the stateful mapper when available. The regular mapper must only be the fallback,
    // otherwise it would replace the stateful one and the state would never be updated.
    IRowToRowMapper MapperFor(ITransformer xf, DataViewSchema sourceSchema) =>
        (xf is IStatefulTransformer stateful)
            ? stateful.GetStatefulRowToRowMapper(sourceSchema)
            : xf.GetRowToRowMapper(sourceSchema);

    if (!(InputTransformer is ITransformerChainAccessor chain))
    {
        return MapperFor(InputTransformer, inputSchema);
    }

    var transformers = chain.Transformers;
    var mappers = new IRowToRowMapper[transformers.Length];
    DataViewSchema schema = inputSchema;
    for (int i = 0; i < mappers.Length; ++i)
    {
        mappers[i] = MapperFor(transformers[i], schema);
        schema = mappers[i].OutputSchema;
    }

    return new CompositeRowToRowMapper(inputSchema, mappers);
}
/// <summary>
/// Validates <paramref name="transformer"/>, stores it in <see cref="InputTransformer"/> and returns
/// the factory that creates the row to row mapper for a given input schema.
/// </summary>
/// <param name="ectx">The exception context used for the checks.</param>
/// <param name="transformer">The transformer to validate; must be a row to row mapper or stateful.</param>
/// <returns>A delegate bound to <see cref="GetRowToRowMapper(DataViewSchema)"/>.</returns>
private protected override Func<DataViewSchema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
{
    ectx.CheckValue(transformer, nameof(transformer));
    ectx.CheckParam(IsRowToRowMapper(transformer), nameof(transformer), "Must be a row to row mapper or " + nameof(IStatefulTransformer));

    InputTransformer = transformer;

    Func<DataViewSchema, IRowToRowMapper> mapperFactory = GetRowToRowMapper;
    return mapperFactory;
}
/// <summary>
/// Performs prediction. In the case of a forecasting only task <paramref name="example"/> can be left as null.
/// If <paramref name="example"/> is not null then it is used to update forecasting models with the new observation.
/// For anomaly detection the model is always updated with <paramref name="example"/>.
/// When both <paramref name="example"/> and <paramref name="prediction"/> are null nothing happens.
/// </summary>
/// <param name="example">Input to the prediction engine.</param>
/// <param name="prediction">Forecasting/Prediction from the engine.</param>
/// <param name="horizon">Used to indicate the number of values to forecast.</param>
/// <param name="confidenceLevel">Used in forecasting model for confidence.</param>
public void Predict(TSrc example, ref TDst prediction, int? horizon = null, float? confidenceLevel = null)
{
    bool hasExample = example != null;
    bool wantsPrediction = prediction != null;
    if (!hasExample && !wantsPrediction)
    {
        return;
    }

    if (hasExample)
    {
        // Extract the values that need to be propagated to all the models.
        // NOTE(review): ExtractValues/FillValues are expected to be inherited from the base
        // prediction engine - confirm against the base class.
        ExtractValues(example);
    }

    // Update state. Without an example, signal all time series models to not fetch src values in getters.
    _pinger(new PingerArgument()
    {
        RowPosition = _rowPosition,
        DontConsumeSource = !hasExample,
        ConfidenceLevel = confidenceLevel,
        Horizon = horizon
    });

    if (wantsPrediction)
    {
        // Predict. In the forecasting only case the expectation is that the user has asked for
        // forecasting columns and hence will not trigger a getter that needs an input.
        FillValues(prediction);
    }

    if (hasExample)
    {
        // Only a consumed observation moves the row position forward; otherwise every ping
        // would report the same position.
        _rowPosition++;
    }
}
/// <summary>
/// Performs prediction. In the case of a forecasting only task <paramref name="example"/> can be left as null.
/// If <paramref name="example"/> is not null then it is used to update forecasting models with the new observation.
/// For anomaly detection the model is always updated with <paramref name="example"/>.
/// </summary>
/// <param name="example">Input to the prediction engine.</param>
/// <param name="prediction">Forecasting/Prediction from the engine.</param>
public override void Predict(TSrc example, ref TDst prediction)
{
    // Forwards to the overload with optional horizon/confidence; spelling the arguments out makes
    // it obvious that this is not a call to itself.
    Predict(example, ref prediction, horizon: null, confidenceLevel: null);
}
/// <summary>
/// Performs prediction into a freshly created <typeparamref name="TDst"/>. In the case of a forecasting only
/// task <paramref name="example"/> can be left as null. If <paramref name="example"/> is not null then it is
/// used to update forecasting models with the new observation.
/// </summary>
/// <param name="example">Input to the prediction engine.</param>
/// <param name="horizon">Number of values to forecast.</param>
/// <param name="confidenceLevel">Confidence level for forecasting.</param>
/// <returns>Prediction/Forecasting after the model has been updated with <paramref name="example"/>.</returns>
public TDst Predict(TSrc example, int? horizon = null, float? confidenceLevel = null)
{
    var prediction = new TDst();
    Predict(example, ref prediction, horizon, confidenceLevel);
    return prediction;
}
/// <summary>
/// Forecasting only task: predicts into a freshly created <typeparamref name="TDst"/> without an input example.
/// </summary>
/// <param name="horizon">Number of values to forecast.</param>
/// <param name="confidenceLevel">Confidence level for forecasting.</param>
/// <returns>The forecast.</returns>
public TDst Predict(int? horizon = null, float? confidenceLevel = null)
{
    var forecast = new TDst();
    Predict(null, ref forecast, horizon, confidenceLevel);
    return forecast;
}
public static class PredictionFunctionExtensions
/// <summary>
/// Creates a <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/> for a time series pipeline.
/// The engine updates the state of the time series model with observations seen at prediction phase and
/// allows checkpointing the model.
/// </summary>
/// <typeparam name="TSrc">Class describing input schema to the model.</typeparam>
/// <typeparam name="TDst">Class describing the output schema of the prediction.</typeparam>
/// <param name="transformer">The time series pipeline in the form of a <see cref="ITransformer"/>.</param>
/// <param name="env">Usually <see cref="MLContext"/></param>
/// <param name="ignoreMissingColumns">To ignore missing columns. Default is false.</param>
/// <param name="inputSchemaDefinition">Input schema definition. Default is null.</param>
/// <param name="outputSchemaDefinition">Output schema definition. Default is null.</param>
/// <p>Example code can be found by searching for <i>TimeSeriesPredictionEngine</i> in <a href='https://github.com/dotnet/machinelearning'>ML.NET.</a></p>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// ]]>
/// </format>
/// </example>
public static TimeSeriesPredictionEngine<TSrc, TDst> CreateTimeSeriesEngine<TSrc, TDst>(this ITransformer transformer, IHostEnvironment env,
    bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
    where TSrc : class
    where TDst : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(transformer, nameof(transformer));

    var engine = new TimeSeriesPredictionEngine<TSrc, TDst>(
        env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
    return engine;
}
/// <summary>
/// Creates a <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/> for a time series pipeline.
/// The engine updates the state of the time series model with observations seen at prediction phase and
/// allows checkpointing the model.
/// </summary>
/// <typeparam name="TSrc">Class describing input schema to the model.</typeparam>
/// <typeparam name="TDst">Class describing the output schema of the prediction.</typeparam>
/// <param name="transformer">The time series pipeline in the form of a <see cref="ITransformer"/>.</param>
/// <param name="env">Usually <see cref="MLContext"/></param>
/// <param name="options">Advanced configuration options.</param>
/// <p>Example code can be found by searching for <i>TimeSeriesPredictionEngine</i> in <a href='https://github.com/dotnet/machinelearning'>ML.NET.</a></p>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// ]]>
/// </format>
/// </example>
public static TimeSeriesPredictionEngine<TSrc, TDst> CreateTimeSeriesEngine<TSrc, TDst>(this ITransformer transformer, IHostEnvironment env,
    PredictionEngineOptions options)
    where TSrc : class
    where TDst : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    // Validate the transformer here too, as the schema-definition overload does.
    env.CheckValue(transformer, nameof(transformer));
    // The constructor reads the options before any other check could run, so guard them here.
    env.CheckValue(options, nameof(options));

    return new TimeSeriesPredictionEngine<TSrc, TDst>(env, transformer, options);
}