|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Transforms.TimeSeries
{
/// <summary>
/// The base class for sequential processing transforms. This class implements the basic sliding window buffering. The derived classes need to specify the transform logic,
/// the initialization logic and the learning logic via implementing the abstract methods TransformCore(), InitializeStateCore() and LearnStateFromDataCore(), respectively
/// </summary>
/// <typeparam name="TInput">The input type of the sequential processing.</typeparam>
/// <typeparam name="TOutput">The dst type of the sequential processing.</typeparam>
/// <typeparam name="TState">The dst type of the sequential processing.</typeparam>
internal abstract class SequentialTransformerBase<TInput, TOutput, TState> : IStatefulTransformer
where TState : SequentialTransformerBase<TInput, TOutput, TState>.StateBase, new()
{
public SequentialTransformerBase() { }
/// <summary>
/// The base class for encapsulating the State object for sequential processing. This class implements a windowed buffer.
/// </summary>
internal class StateBase
{
// Ideally this class should be private. However, due to the current constraints with the LambdaTransform, we need to have
// access to the state class when inheriting from SequentialTransformerBase.
private protected IHost Host;
/// <summary>
/// A reference to the parent transform that operates on the state object.
/// </summary>
protected SequentialTransformerBase<TInput, TOutput, TState> ParentTransform;
/// <summary>
/// The internal windowed buffer for buffering the values in the input sequence.
/// </summary>
private protected FixedSizeQueue<TInput> WindowedBuffer { get; set; }
/// <summary>
/// The buffer used to buffer the training data points.
/// </summary>
private protected FixedSizeQueue<TInput> InitialWindowedBuffer { get; set; }
private protected int WindowSize { get; private set; }
private protected int InitialWindowSize { get; private set; }
/// <summary>
/// Counts the number of rows observed by the transform so far.
/// </summary>
protected long RowCounter { get; private set; }
public StateBase()
{
}
protected long IncrementRowCounter()
{
RowCounter++;
return RowCounter;
}
protected long PreviousPosition;
private protected StateBase(BinaryReader reader)
{
WindowSize = reader.ReadInt32();
InitialWindowSize = reader.ReadInt32();
}
internal virtual void Save(BinaryWriter writer)
{
writer.Write(WindowSize);
writer.Write(InitialWindowSize);
}
/// <summary>
/// This method sets the window size and initializes the buffer only once.
/// Since the class needs to implement a default constructor, this methods provides a mechanism to initialize the window size and buffer.
/// </summary>
/// <param name="windowSize">The size of the windowed buffer</param>
/// <param name="initialWindowSize">The size of the windowed initial buffer used for training</param>
/// <param name="parentTransform">The parent transform of this state object</param>
/// <param name="host">The host</param>
public void InitState(int windowSize, int initialWindowSize, SequentialTransformerBase<TInput, TOutput, TState> parentTransform,
IHost host)
{
Contracts.CheckValue(host, nameof(host), "The host cannot be null.");
host.CheckValue(parentTransform, nameof(parentTransform));
host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative.");
host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative.");
Host = host;
WindowSize = windowSize;
InitialWindowSize = initialWindowSize;
ParentTransform = parentTransform;
WindowedBuffer = (WindowSize > 0) ? new FixedSizeQueue<TInput>(WindowSize) : new FixedSizeQueue<TInput>(1);
InitialWindowedBuffer = (InitialWindowSize > 0) ? new FixedSizeQueue<TInput>(InitialWindowSize) : new FixedSizeQueue<TInput>(1);
RowCounter = 0;
InitializeStateCore();
PreviousPosition = -1;
}
public void InitState(SequentialTransformerBase<TInput, TOutput, TState> parentTransform, IHost host)
{
Contracts.CheckValue(host, nameof(host), "The host cannot be null.");
host.CheckValue(parentTransform, nameof(parentTransform));
Host = host;
ParentTransform = parentTransform;
RowCounter = 0;
InitializeStateCore(true);
PreviousPosition = -1;
}
/// <summary>
/// This method implements the basic resetting mechanism for a state object and clears the buffer.
/// </summary>
public virtual void Reset()
{
Host.Assert(WindowedBuffer != null);
Host.Assert(InitialWindowedBuffer != null);
RowCounter = 0;
WindowedBuffer.Clear();
InitialWindowedBuffer.Clear();
PreviousPosition = -1;
}
public void UpdateState(ref TInput input, long rowPosition, bool buffer = true)
{
if (rowPosition > PreviousPosition)
{
PreviousPosition = rowPosition;
UpdateStateCore(ref input, buffer);
Consume(input);
}
}
public void UpdateStateCore(ref TInput input, bool buffer = true)
{
if (InitialWindowedBuffer.Count < InitialWindowSize)
{
InitialWindowedBuffer.AddLast(input);
if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize && buffer)
WindowedBuffer.AddLast(input);
}
else
{
if (buffer)
WindowedBuffer.AddLast(input);
IncrementRowCounter();
}
}
public void Process(ref TInput input, ref TOutput output1, ref TOutput output2, ref TOutput output3)
{
//Using prediction engine will not evaluate the below condition to true.
if (PreviousPosition == -1)
UpdateStateCore(ref input);
if (InitialWindowedBuffer.Count < InitialWindowSize)
{
SetNaOutput(ref output1);
if (InitialWindowedBuffer.Count == InitialWindowSize)
LearnStateFromDataCore(InitialWindowedBuffer);
}
else
{
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output1, ref output2, ref output3);
}
}
public void ProcessWithoutBuffer(ref TInput input, ref TOutput output1, ref TOutput output2, ref TOutput output3)
{
//Using prediction engine will not evaluate the below condition to true.
if (PreviousPosition == -1)
UpdateStateCore(ref input);
if (InitialWindowedBuffer.Count < InitialWindowSize)
{
SetNaOutput(ref output1);
if (InitialWindowedBuffer.Count == InitialWindowSize)
LearnStateFromDataCore(InitialWindowedBuffer);
}
else
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output1, ref output2, ref output3);
}
public void Process(ref TInput input, ref TOutput output)
{
//Using prediction engine will not evaluate the below condition to true.
if (PreviousPosition == -1)
UpdateStateCore(ref input);
if (InitialWindowedBuffer.Count < InitialWindowSize)
{
SetNaOutput(ref output);
if (InitialWindowedBuffer.Count == InitialWindowSize)
LearnStateFromDataCore(InitialWindowedBuffer);
}
else
{
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
}
}
public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
{
//Using prediction engine will not evaluate the below condition to true.
if (PreviousPosition == -1)
UpdateStateCore(ref input);
if (InitialWindowedBuffer.Count < InitialWindowSize)
{
SetNaOutput(ref output);
if (InitialWindowedBuffer.Count == InitialWindowSize)
LearnStateFromDataCore(InitialWindowedBuffer);
}
else
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
}
/// <summary>
/// The abstract method that specifies the NA value for the dst type.
/// </summary>
/// <returns></returns>
private protected virtual void SetNaOutput(ref TOutput dst) { }
/// <summary>
/// The abstract method that realizes the main logic for the transform.
/// </summary>
/// <param name="input">A reference to the input object.</param>
/// <param name="dst">A reference to the dst object.</param>
/// <param name="windowedBuffer">A reference to the windowed buffer.</param>
/// <param name="iteration">A long number that indicates the number of times TransformCore has been called so far (starting value = 0).</param>
public virtual void TransformCore(ref TInput input, FixedSizeQueue<TInput> windowedBuffer, long iteration, ref TOutput dst)
{
}
public virtual void Forecast(ref TOutput dst)
{
}
public virtual void ConfidenceIntervalLowerBound(ref TOutput dst)
{
}
public virtual void ConfidenceIntervalUpperBound(ref TOutput dst)
{
}
/// <summary>
/// The abstract method that realizes the main logic for the transform.
/// </summary>
/// <param name="input">A reference to the input object.</param>
/// <param name="dst1">A reference to the dst object.</param>
/// <param name="dst2"></param>
/// <param name="dst3"></param>
/// <param name="windowedBuffer">A reference to the windowed buffer.</param>
/// <param name="iteration">A long number that indicates the number of times TransformCore has been called so far (starting value = 0).</param>
private protected virtual void TransformCore(ref TInput input, FixedSizeQueue<TInput> windowedBuffer, long iteration,
ref TOutput dst1, ref TOutput dst2, ref TOutput dst3)
{
}
/// <summary>
/// The abstract method that realizes the logic for initializing the state object.
/// </summary>
private protected virtual void InitializeStateCore(bool disk = false)
{
}
/// <summary>
/// The abstract method that realizes the logic for learning the parameters and the initial state object from data.
/// </summary>
/// <param name="data">A queue of data points used for training</param>
private protected virtual void LearnStateFromDataCore(FixedSizeQueue<TInput> data)
{
}
public virtual void Consume(TInput value)
{
}
public TState Clone()
{
var clone = (TState)MemberwiseClone();
CloneCore(clone);
return clone;
}
private protected virtual void CloneCore(TState state)
{
state.WindowedBuffer = WindowedBuffer.Clone();
state.InitialWindowedBuffer = InitialWindowedBuffer.Clone();
}
}
internal readonly IHost Host;
/// <summary>
/// The window size for buffering.
/// </summary>
internal readonly int WindowSize;
/// <summary>
/// The number of datapoints from the beginning of the sequence that are used for learning the initial state.
/// </summary>
private protected int InitialWindowSize;
internal readonly string InputColumnName;
internal readonly string OutputColumnName;
internal readonly string ConfidenceLowerBoundColumn;
internal readonly string ConfidenceUpperBoundColumn;
private protected DataViewType OutputColumnType;
bool ITransformer.IsRowToRowMapper => false;
internal TState StateRef { get; set; }
public int StateRefCount;
/// <summary>
/// The main constructor for the sequential transform
/// </summary>
/// <param name="host">The host.</param>
/// <param name="windowSize">The size of buffer used for windowed buffering.</param>
/// <param name="initialWindowSize">The number of datapoints picked from the beginning of the series for training the transform parameters if needed.</param>
/// <param name="outputColumnName">The name of the dst column.</param>
/// <param name="inputColumnName">The name of the input column.</param>
/// <param name="outputColType"></param>
private protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize,
string outputColumnName, string inputColumnName, DataViewType outputColType)
{
Host = host;
Host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative.");
Host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative.");
// REVIEW: Very bad design. This base class is responsible for reporting errors on
// the arguments, but the arguments themselves are not derived form any base class.
Host.CheckNonEmpty(inputColumnName, nameof(PercentileThresholdTransform.Arguments.Source));
Host.CheckNonEmpty(outputColumnName, nameof(PercentileThresholdTransform.Arguments.Source));
InputColumnName = inputColumnName;
OutputColumnName = outputColumnName;
OutputColumnType = outputColType;
InitialWindowSize = initialWindowSize;
WindowSize = windowSize;
}
private protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize,
string outputColumnName, string confidenceLowerBoundColumn,
string confidenceUpperBoundColumn, string inputColumnName, DataViewType outputColType) :
this(host, windowSize, initialWindowSize, outputColumnName, inputColumnName, outputColType)
{
ConfidenceLowerBoundColumn = confidenceLowerBoundColumn;
ConfidenceUpperBoundColumn = confidenceUpperBoundColumn;
}
private protected SequentialTransformerBase(IHost host, ModelLoadContext ctx)
{
Host = host;
Host.CheckValue(ctx, nameof(ctx));
// *** Binary format ***
// int: _windowSize
// int: _initialWindowSize
// int (string ID): _sourceColumnName
// int (string ID): _outputColumnName
// ColumnType: _transform.Schema.GetColumnType(0)
var windowSize = ctx.Reader.ReadInt32();
Host.CheckDecode(windowSize >= 0);
var initialWindowSize = ctx.Reader.ReadInt32();
Host.CheckDecode(initialWindowSize >= 0);
var inputColumnName = ctx.LoadNonEmptyString();
var outputColumnName = ctx.LoadNonEmptyString();
InputColumnName = inputColumnName;
OutputColumnName = outputColumnName;
ConfidenceLowerBoundColumn = ctx.Reader.ReadString();
ConfidenceUpperBoundColumn = ctx.Reader.ReadString();
InitialWindowSize = initialWindowSize;
WindowSize = windowSize;
BinarySaver bs = new BinarySaver(Host, new BinarySaver.Arguments());
OutputColumnType = bs.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream);
}
void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
private protected virtual void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(InitialWindowSize >= 0);
Host.Assert(WindowSize >= 0);
// *** Binary format ***
// int: _windowSize
// int: _initialWindowSize
// int (string ID): _sourceColumnName
// int (string ID): _outputColumnName
// ColumnType: _transform.Schema.GetColumnType(0)
ctx.Writer.Write(WindowSize);
ctx.Writer.Write(InitialWindowSize);
ctx.SaveNonEmptyString(InputColumnName);
ctx.SaveNonEmptyString(OutputColumnName);
ctx.Writer.Write(ConfidenceLowerBoundColumn ?? string.Empty);
ctx.Writer.Write(ConfidenceUpperBoundColumn ?? string.Empty);
var bs = new BinarySaver(Host, new BinarySaver.Arguments());
bs.TryWriteTypeDescription(ctx.Writer.BaseStream, OutputColumnType, out int byteWritten);
}
public abstract DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
internal abstract IStatefulRowMapper MakeRowMapper(DataViewSchema schema);
internal SequentialDataTransform MakeDataTransform(IDataView input)
{
Host.CheckValue(input, nameof(input));
return new SequentialDataTransform(Host, this, input, MakeRowMapper(input.Schema));
}
public IDataView Transform(IDataView input) => MakeDataTransform(input);
public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema)
{
throw new InvalidOperationException("Not a RowToRowMapper.");
}
IRowToRowMapper IStatefulTransformer.GetStatefulRowToRowMapper(DataViewSchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
return new TimeSeriesRowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema));
}
internal virtual IStatefulTransformer Clone() => (SequentialTransformerBase<TInput, TOutput, TState>)MemberwiseClone();
IStatefulTransformer IStatefulTransformer.Clone() => Clone();
internal sealed class SequentialDataTransform : TransformBase, ITransformTemplate, IRowToRowMapper
{
private readonly IStatefulRowMapper _mapper;
private readonly SequentialTransformerBase<TInput, TOutput, TState> _parent;
private readonly IDataView _transform;
private readonly ColumnBindings _bindings;
private MetadataDispatcher Metadata { get; }
public SequentialDataTransform(IHost host, SequentialTransformerBase<TInput, TOutput, TState> parent,
IDataView input, IStatefulRowMapper mapper)
: base(parent.Host, input)
{
Metadata = new MetadataDispatcher(3);
_parent = parent;
_transform = CreateLambdaTransform(_parent.Host, input, _parent.InputColumnName,
_parent.OutputColumnName, _parent.ConfidenceLowerBoundColumn,
_parent.ConfidenceUpperBoundColumn, InitFunction, _parent.WindowSize > 0, _parent.OutputColumnType);
_mapper = mapper;
_bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns());
}
public void CloneStateInMapper() => _mapper.CloneState();
private static IDataView CreateLambdaTransform(IHost host, IDataView input, string inputColumnName,
string outputColumnName, string forecastingConfidenceIntervalMinOutputColumnName,
string forecastingConfidenceIntervalMaxOutputColumnName, Action<TState> initFunction, bool hasBuffer, DataViewType outputColTypeOverride)
{
var inputSchema = SchemaDefinition.Create(typeof(DataBox<TInput>));
inputSchema[0].ColumnName = inputColumnName;
SchemaDefinition outputSchema;
if (!string.IsNullOrEmpty(forecastingConfidenceIntervalMinOutputColumnName))
{
outputSchema = SchemaDefinition.Create(typeof(DataBoxForecastingWithConfidenceIntervals<TOutput>));
outputSchema[0].ColumnName = outputColumnName;
if (outputColTypeOverride != null)
outputSchema[0].ColumnType = outputSchema[1].ColumnType = outputSchema[2].ColumnType = outputColTypeOverride;
outputSchema[1].ColumnName = forecastingConfidenceIntervalMinOutputColumnName;
outputSchema[2].ColumnName = forecastingConfidenceIntervalMaxOutputColumnName;
Action<DataBox<TInput>, DataBoxForecastingWithConfidenceIntervals<TOutput>, TState> lambda;
if (hasBuffer)
lambda = MapFunction;
else
lambda = MapFunctionWithoutBuffer;
return LambdaTransform.CreateMap(host, input, lambda, initFunction, inputSchema, outputSchema);
}
else
{
outputSchema = SchemaDefinition.Create(typeof(DataBox<TOutput>));
outputSchema[0].ColumnName = outputColumnName;
if (outputColTypeOverride != null)
outputSchema[0].ColumnType = outputColTypeOverride;
Action<DataBox<TInput>, DataBox<TOutput>, TState> lambda;
if (hasBuffer)
lambda = MapFunction;
else
lambda = MapFunctionWithoutBuffer;
return LambdaTransform.CreateMap(host, input, lambda, initFunction, inputSchema, outputSchema);
}
}
private static void MapFunction(DataBox<TInput> input, DataBox<TOutput> output, TState state)
{
state.Process(ref input.Value, ref output.Value);
}
private static void MapFunction(DataBox<TInput> input, DataBoxForecastingWithConfidenceIntervals<TOutput> output, TState state)
{
state.Process(ref input.Value, ref output.Forecast, ref output.ConfidenceIntervalLowerBound, ref output.ConfidenceIntervalUpperBound);
}
private static void MapFunctionWithoutBuffer(DataBox<TInput> input, DataBox<TOutput> output, TState state)
{
state.ProcessWithoutBuffer(ref input.Value, ref output.Value);
}
private static void MapFunctionWithoutBuffer(DataBox<TInput> input, DataBoxForecastingWithConfidenceIntervals<TOutput> output, TState state)
{
state.ProcessWithoutBuffer(ref input.Value, ref output.Forecast, ref output.ConfidenceIntervalLowerBound, ref output.ConfidenceIntervalUpperBound);
}
private void InitFunction(TState state)
{
state.InitState(_parent.WindowSize, _parent.InitialWindowSize, _parent, _parent.Host);
}
public override bool CanShuffle { get { return false; } }
protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
{
var srcCursor = _transform.GetRowCursor(columnsNeeded, rand);
var clone = (SequentialDataTransform)MemberwiseClone();
clone.CloneStateInMapper();
return new Cursor(Host, clone, srcCursor);
}
protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
{
Host.AssertValue(predicate);
return false;
}
public override long? GetRowCount()
=> _transform.GetRowCount();
public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
=> new DataViewRowCursor[] { GetRowCursorCore(columnsNeeded, rand) };
private protected override void SaveModel(ModelSaveContext ctx)
{
(_parent as ICanSaveModel).Save(ctx);
}
IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
{
return new SequentialDataTransform(Contracts.CheckRef(env, nameof(env)).Register("SequentialDataTransform"), _parent, newSource, _mapper);
}
public DataViewSchema InputSchema => Source.Schema;
public override DataViewSchema OutputSchema => _bindings.Schema;
/// <summary>
/// Given a set of columns, return the input columns that are needed to generate those output columns.
/// </summary>
public IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
{
if (dependingColumns.Count() == 0)
return Enumerable.Empty<DataViewSchema.Column>();
return InputSchema;
}
DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
{
var active = RowCursorUtils.FromColumnsToPredicate(activeColumns, OutputSchema);
var getters = _mapper.CreateGetters(input, active, out Action disposer);
var pingers = _mapper.CreatePinger(input, active, out Action pingerDisposer);
return new RowImpl(_bindings, input, getters, pingers, disposer + pingerDisposer);
}
}
private sealed class RowImpl : StatefulRow
{
private readonly DataViewSchema _schema;
private readonly DataViewRow _input;
private readonly Delegate[] _getters;
private readonly Action<PingerArgument> _pinger;
private readonly Action _disposer;
private bool _disposed;
private readonly ColumnBindings _bindings;
public override DataViewSchema Schema => _schema;
public override long Position => _input.Position;
public override long Batch => _input.Batch;
public RowImpl(ColumnBindings bindings, DataViewRow input, Delegate[] getters, Action<PingerArgument> pinger, Action disposer)
{
Contracts.CheckValue(bindings, nameof(bindings));
Contracts.CheckValue(input, nameof(input));
_schema = bindings.Schema;
_input = input;
_getters = getters ?? new Delegate[0];
_pinger = pinger;
_disposer = disposer;
_bindings = bindings;
}
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
_disposer?.Invoke();
_disposed = true;
base.Dispose(disposing);
}
public override ValueGetter<DataViewRowId> GetIdGetter()
=> _input.GetIdGetter();
public override ValueGetter<T> GetGetter<T>(DataViewSchema.Column column)
{
bool isSrc;
int index = _bindings.MapColumnIndex(out isSrc, column.Index);
if (isSrc)
return _input.GetGetter<T>(_input.Schema[index]);
Contracts.CheckParam(index < _getters.Length, nameof(column), "Invalid col value in GetGetter");
Contracts.Check(IsColumnActive(column));
var fn = _getters[index] as ValueGetter<T>;
if (fn == null)
throw Contracts.Except("Unexpected TValue in GetGetter");
return fn;
}
public override Action<PingerArgument> GetPinger() =>
_pinger as Action<PingerArgument> ?? throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(PingerArgument));
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column)
{
int index = _bindings.MapColumnIndex(out bool isSrc, column.Index);
Contracts.Check(index < _getters.Length);
return _getters[index] != null;
}
}
/// <summary>
/// A wrapper around the cursor which replaces the schema.
/// </summary>
private sealed class Cursor : SynchronizedCursorBase
{
private readonly SequentialDataTransform _parent;
public Cursor(IHost host, SequentialDataTransform parent, DataViewRowCursor input)
: base(host, input)
{
Ch.Assert(input.Schema.Count == parent.OutputSchema.Count);
_parent = parent;
}
public override DataViewSchema Schema => _parent.OutputSchema;
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column)
{
Ch.Check(column.Index < Schema.Count, nameof(column));
return Input.IsColumnActive(column);
}
/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
{
Ch.Check(IsColumnActive(column), nameof(column));
return Input.GetGetter<TValue>(column);
}
}
}
/// <summary>
/// This class is a transform that can add any number of output columns, that depend on any number of input columns.
/// It does so with the help of an <see cref="IRowMapper"/>, that is given a schema in its constructor, and has methods
/// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns.
/// </summary>
internal sealed class TimeSeriesRowToRowMapperTransform : RowToRowTransformBase, IStatefulRowToRowMapper,
ITransformCanSaveOnnx, ITransformCanSavePfa
{
private readonly IStatefulRowMapper _mapper;
private readonly ColumnBindings _bindings;
public const string RegistrationName = "TimeSeriesRowToRowMapperTransform";
public const string LoaderSignature = "TimeSeriesRowToRowMapper";
private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "TS ROW MPPR",
verWrittenCur: 0x00010001, // Initial
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(TimeSeriesRowToRowMapperTransform).Assembly.FullName);
}
public override DataViewSchema OutputSchema => _bindings.Schema;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;
bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false;
public TimeSeriesRowToRowMapperTransform(IHostEnvironment env, IDataView input, IStatefulRowMapper mapper)
: base(env, RegistrationName, input)
{
Contracts.CheckValue(mapper, nameof(mapper));
_mapper = mapper;
_bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
}
public static DataViewSchema GetOutputSchema(DataViewSchema inputSchema, IRowMapper mapper)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contracts.CheckValue(mapper, nameof(mapper));
return new ColumnBindings(inputSchema, mapper.GetOutputColumns()).Schema;
}
private TimeSeriesRowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
: base(host, input)
{
// *** Binary format ***
// _mapper
ctx.LoadModel<IStatefulRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
_bindings = new ColumnBindings(input.Schema, _mapper.GetOutputColumns());
}
public static TimeSeriesRowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
h.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
h.CheckValue(input, nameof(input));
return h.Apply("Loading Model", ch => new TimeSeriesRowToRowMapperTransform(h, ctx, input));
}
private protected override void SaveModel(ModelSaveContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
// *** Binary format ***
// _mapper
ctx.SaveModel(_mapper, "Mapper");
}
/// <summary>
/// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
/// a predicate for the needed active input columns, and a predicate for the needed active
/// output columns.
/// </summary>
private bool[] GetActive(Func<int, bool> predicate, out Func<int, bool> predicateInput)
{
int n = _bindings.Schema.Count;
var active = Utils.BuildArray(n, predicate);
Contracts.Assert(active.Length == n);
var activeInput = _bindings.GetActiveInput(predicate);
Contracts.Assert(activeInput.Length == _bindings.InputSchema.Count);
// Get a predicate that determines which outputs are active.
var predicateOut = GetActiveOutputColumns(active);
// Now map those to active input columns.
var predicateIn = _mapper.GetDependencies(predicateOut);
// Combine the two sets of input columns.
predicateInput =
col => 0 <= col && col < activeInput.Length && (activeInput[col] || predicateIn(col));
return active;
}
/// <summary>
/// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
/// a predicate for the needed active input columns, and a predicate for the needed active
/// output columns.
/// </summary>
private IEnumerable<DataViewSchema.Column> GetActive(Func<int, bool> predicate)
{
Func<int, bool> predicateInput;
var active = GetActive(predicate, out predicateInput);
return _bindings.Schema.Where(col => predicateInput(col.Index));
}
private Func<int, bool> GetActiveOutputColumns(bool[] active)
{
Contracts.AssertValue(active);
Contracts.Assert(active.Length == _bindings.Schema.Count);
return
col =>
{
Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count);
return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]];
};
}
protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
{
Host.AssertValue(predicate, "predicate");
if (_bindings.AddedColumnIndices.Any(predicate))
return true;
return null;
}
protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
{
Func<int, bool> predicateInput;
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
var active = GetActive(predicate, out predicateInput);
var inputCols = Source.Schema.Where(x => predicateInput(x.Index));
return new Cursor(Host, Source.GetRowCursor(inputCols, rand), this, active);
}
public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
{
Host.CheckValueOrNull(rand);
Func<int, bool> predicateInput;
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
var active = GetActive(predicate, out predicateInput);
var inputCols = Source.Schema.Where(x => predicateInput(x.Index));
var inputs = Source.GetRowCursorSet(inputCols, n, rand);
Host.AssertNonEmpty(inputs);
if (inputs.Length == 1 && n > 1 && _bindings.AddedColumnIndices.Any(predicate))
inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n);
Host.AssertNonEmpty(inputs);
var cursors = new DataViewRowCursor[inputs.Length];
for (int i = 0; i < inputs.Length; i++)
cursors[i] = new Cursor(Host, inputs[i], this, active);
return cursors;
}
void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
if (_mapper is ISaveAsOnnx onnx)
{
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
onnx.SaveAsOnnx(ctx);
}
}
void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
if (_mapper is ISaveAsPfa pfa)
{
Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA.");
pfa.SaveAsPfa(ctx);
}
}
/// <summary>
/// Given a set of columns, return the input columns that are needed to generate those output columns.
/// </summary>
public IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
{
var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema);
return GetActive(predicate);
}
DataViewSchema IRowToRowMapper.InputSchema => Source.Schema;
DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
{
Host.CheckValue(input, nameof(input));
Host.CheckValue(activeColumns, nameof(activeColumns));
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");
using (var ch = Host.Start("GetEntireRow"))
{
var activeArr = Utils.BuildArray(OutputSchema.Count, activeColumns);
var pred = GetActiveOutputColumns(activeArr);
var getters = _mapper.CreateGetters(input, pred, out Action disp);
var pingers = _mapper.CreatePinger(input, pred, out Action pingerDisp);
return new StatefulRowImpl(input, this, OutputSchema, getters, pingers, disp + pingerDisp);
}
}
private sealed class StatefulRowImpl : StatefulRow
{
private readonly DataViewRow _input;
private readonly Delegate[] _getters;
private readonly Action<PingerArgument> _pinger;
private readonly Action _disposer;
private readonly TimeSeriesRowToRowMapperTransform _parent;
public override long Batch => _input.Batch;
public override long Position => _input.Position;
public override DataViewSchema Schema { get; }
public StatefulRowImpl(DataViewRow input, TimeSeriesRowToRowMapperTransform parent,
DataViewSchema schema, Delegate[] getters, Action<PingerArgument> pinger, Action disposer)
{
_input = input;
_parent = parent;
Schema = schema;
_getters = getters;
_pinger = pinger;
_disposer = disposer;
}
protected override void Dispose(bool disposing)
{
if (disposing)
_disposer?.Invoke();
}
/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
{
bool isSrc;
int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
if (isSrc)
return _input.GetGetter<TValue>(column);
var originFn = _getters[index];
Contracts.Assert(originFn != null);
var fn = originFn as ValueGetter<TValue>;
if (fn == null)
throw Contracts.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
$"expected type: '{originFn.GetType().GetGenericArguments().First()}'.");
return fn;
}
public override Action<PingerArgument> GetPinger() =>
_pinger as Action<PingerArgument> ?? throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(PingerArgument));
public override ValueGetter<DataViewRowId> GetIdGetter() => _input.GetIdGetter();
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column)
{
bool isSrc;
int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
if (isSrc)
return _input.IsColumnActive(_input.Schema[index]);
return _getters[index] != null;
}
}
private sealed class Cursor : SynchronizedCursorBase
{
private readonly Delegate[] _getters;
private readonly bool[] _active;
private readonly ColumnBindings _bindings;
private readonly Action _disposer;
private bool _disposed;
public override DataViewSchema Schema => _bindings.Schema;
public Cursor(IChannelProvider provider, DataViewRowCursor input, TimeSeriesRowToRowMapperTransform parent, bool[] active)
: base(provider, input)
{
var pred = parent.GetActiveOutputColumns(active);
_getters = parent._mapper.CreateGetters(input, pred, out _disposer);
_active = active;
_bindings = parent._bindings;
}
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column)
{
Ch.Check(column.Index < _bindings.Schema.Count);
return _active[column.Index];
}
/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
{
Ch.Check(IsColumnActive(column));
bool isSrc;
int index = _bindings.MapColumnIndex(out isSrc, column.Index);
if (isSrc)
return Input.GetGetter<TValue>(column);
Ch.AssertValue(_getters);
var getter = _getters[index];
Ch.AssertValue(getter);
if (getter is ValueGetter<TValue> fn)
return fn;
throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
$"expected type: '{getter.GetType().GetGenericArguments().First()}'.");
}
protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
_disposer?.Invoke();
_disposed = true;
base.Dispose(disposing);
}
}
}
}
|