File: SlidingWindowTransformBase.cs
Web Access
Project: src\src\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj (Microsoft.ML.TimeSeries)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms.TimeSeries
{
    /// <summary>
    /// SlidingWindowTransformBase outputs a sliding window as a VBuffer from a series of any type.
    /// The VBuffer contains consecutive observations, optionally delayed with respect to the current one.
    /// Let y(t) denote a time series. For each time t, the transform returns the vector of values
    /// [y(t-d-l+1), y(t-d-l+2), ..., y(t-l-1), y(t-l)], where d is the size of the window
    /// and l is the delay (lag).
    /// </summary>
 
    internal abstract class SlidingWindowTransformBase<TInput> : SequentialTransformBase<TInput, VBuffer<TInput>, SlidingWindowTransformBase<TInput>.StateSlide>
    {
        /// <summary>
        /// Defines what should be done about the first rows.
        /// The selected option is read by <c>StateSlide.SetNaOutput</c> to choose the value
        /// used to fill the window it emits.
        /// </summary>
        /// <remarks>
        /// The underlying byte value is written by <c>SaveModel</c> and read back by the loading
        /// constructor, so the numeric values below are part of the binary model format.
        /// </remarks>
        public enum BeginOptions : byte
        {
            /// <summary>
            /// Fill first rows with NaN values.
            /// The actual fill value is the one obtained from <c>GetNAOrDefault</c> for the input column type.
            /// </summary>
            NaNValues = 0,
 
            /// <summary>
            /// Copy the first value of the series.
            /// Not implemented yet: <c>StateSlide.SetNaOutput</c> throws
            /// <see cref="NotImplementedException"/> when this option is selected.
            /// </summary>
            FirstValue = 1
        }
 
        /// <summary>
        /// User-facing arguments of the sliding window transform.
        /// The values are validated by the constructor of the enclosing transform.
        /// </summary>
        public sealed class Arguments : TransformInputBase
        {
            /// <summary>
            /// Name of the source column; forwarded to the base transform constructor.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "The name of the source column", ShortName = "src",
                SortOrder = 1, Purpose = SpecialPurpose.ColumnName)]
            public string Source;
 
            /// <summary>
            /// Name of the new column; forwarded to the base transform constructor.
            /// </summary>
            [Argument(ArgumentType.Required, HelpText = "The name of the new column",
                SortOrder = 2)]
            public string Name;
 
            /// <summary>
            /// Size of the sliding window. Must be at least 1.
            /// </summary>
            // NOTE(review): the help text mentions a moving average, but nothing in this file
            // computes one; the wording may have been copied from another transform - confirm.
            [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing the moving average", ShortName = "wnd", SortOrder = 3)]
            public int WindowSize = 2;
 
            /// <summary>
            /// Lag between the current observation and the last observation of the window.
            /// Must be non-negative; the combination Lag = 0 and WindowSize = 1 is rejected
            /// by the constructor because it would only copy the column.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Lag between current observation and last observation from the sliding window", ShortName = "l", SortOrder = 4)]
            public int Lag = 1;
 
            /// <summary>
            /// How the first rows of the produced series are populated. Defaults to <see cref="BeginOptions.NaNValues"/>.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Define how to populate the first rows of the produced series", SortOrder = 5)]
            public BeginOptions Begin = BeginOptions.NaNValues;
        }
 
        // Lag between the current observation and the last observation of the window.
        // Set from Arguments.Lag or read from the model; persisted by SaveModel.
        private readonly int _lag;
        // Policy used to populate the first rows; persisted by SaveModel as a single byte.
        private readonly BeginOptions _begin;
        // Fill value computed once by GetNaValue at construction time and reused by StateSlide.SetNaOutput.
        private readonly TInput _nanValue;
 
        /// <summary>
        /// Creates the transform from user-supplied arguments.
        /// </summary>
        /// <param name="args">Window size, lag, column names and begin option.</param>
        /// <param name="loaderSignature">Loader signature forwarded to the base transform.</param>
        /// <param name="env">Host environment forwarded to the base transform.</param>
        /// <param name="input">Input data view forwarded to the base transform.</param>
        /// <remarks>
        /// The base transform is given a buffer size of <c>WindowSize + Lag - 1</c>.
        /// Invalid arguments are reported through <c>Host.CheckUserArg</c> / <c>Host.ExceptUserArg</c>.
        /// Note that the base constructor runs before these checks, so it receives the unvalidated values.
        /// </remarks>
        protected SlidingWindowTransformBase(Arguments args, string loaderSignature, IHostEnvironment env, IDataView input)
            : base(args.WindowSize + args.Lag - 1, args.WindowSize + args.Lag - 1, args.Name, args.Source, loaderSignature, env, input)
        {
            Host.CheckUserArg(args.WindowSize >= 1, nameof(args.WindowSize), "Must be at least 1.");
            // Zero is accepted, so the requirement is "non-negative" rather than "positive".
            Host.CheckUserArg(args.Lag >= 0, nameof(args.Lag), "Must be non-negative.");
            if (args.Lag == 0 && args.WindowSize <= 1)
            {
                Host.Assert(args.WindowSize == 1);
                // Use the argument names (not their values) so the message reads "If Lag=0 and WindowSize=1, ...".
                throw Host.ExceptUserArg(nameof(args.Lag),
                    $"If {nameof(args.Lag)}=0 and {nameof(args.WindowSize)}=1, the transform just copies the column. Use {ColumnCopyingTransformer.LoaderSignature} transform instead.");
            }
            Host.CheckUserArg(Enum.IsDefined(typeof(BeginOptions), args.Begin), nameof(args.Begin), "Undefined value.");
            _lag = args.Lag;
            _begin = args.Begin;
            // Cache the fill value once; see GetNaValue.
            _nanValue = GetNaValue();
        }
 
        /// <summary>
        /// Creates the transform from a serialized model.
        /// </summary>
        /// <param name="env">Host environment forwarded to the base transform.</param>
        /// <param name="ctx">Model load context; the base content is read first, then the lag and the begin option.</param>
        /// <param name="loaderSignature">Loader signature forwarded to the base transform.</param>
        /// <param name="input">Input data view forwarded to the base transform.</param>
        /// <remarks>
        /// Inconsistent content is reported through <c>Host.CheckDecode</c>.
        /// </remarks>
        protected SlidingWindowTransformBase(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature, IDataView input)
            : base(env, ctx, loaderSignature, input)
        {
            // *** Binary format ***
            // <base>
            // Int32 lag
            // byte begin

            Host.CheckDecode(WindowSize >= 1);
            _lag = ctx.Reader.ReadInt32();
            Host.CheckDecode(_lag >= 0);
            // StateSlide emits vectors of WindowSize - _lag + 1 slots; reject a model whose lag
            // would make that length non-positive instead of failing later at transform time.
            Host.CheckDecode(_lag <= WindowSize);
            byte r = ctx.Reader.ReadByte();
            Host.CheckDecode(Enum.IsDefined(typeof(BeginOptions), r));
            _begin = (BeginOptions)r;
            // Cache the fill value once; see GetNaValue.
            _nanValue = GetNaValue();
        }
 
        /// <summary>
        /// Computes the value used to fill the output window when <see cref="BeginOptions.NaNValues"/> is selected.
        /// The input column is looked up by name in <c>OutputSchema</c> and its type is passed to
        /// <c>GetNAOrDefault</c>.
        /// </summary>
        /// <returns>The value returned by <c>GetNAOrDefault</c> for the input column type.</returns>
        private TInput GetNaValue()
        {
            var schema = OutputSchema;

            // Do not ignore a failed lookup: indexing the schema with an unset index would
            // either throw an unrelated error or silently pick the wrong column.
            bool found = schema.TryGetColumnIndex(InputColumnName, out int index);
            Host.Check(found, $"Input column '{InputColumnName}' was not found in the schema.");

            DataViewType columnType = schema[index].Type;

            // The result is stored by the constructors to avoid computing it each time a state is instantiated.
            return Data.Conversion.Conversions.DefaultInstance.GetNAOrDefault<TInput>(columnType);
        }
 
        /// <summary>
        /// Serializes the transform: the base content first, followed by the lag and the begin option.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));

            // Sanity checks on the state that is about to be persisted.
            Host.Assert(WindowSize >= 1);
            Host.Assert(_lag >= 0);
            Host.Assert(Enum.IsDefined(typeof(BeginOptions), _begin));

            ctx.CheckAtModel();

            // *** Binary format ***
            // <base>
            // Int32 lag
            // byte begin

            base.SaveModel(ctx);

            int lag = _lag;
            byte beginCode = (byte)_begin;
            ctx.Writer.Write(lag);
            ctx.Writer.Write(beginCode);
        }
 
        /// <summary>
        /// State of the sliding window transform. It produces vectors of
        /// <c>WindowSize - lag + 1</c> slots, either filled with a constant value
        /// (<c>SetNaOutput</c>) or copied from the windowed buffer (<c>TransformCore</c>).
        /// </summary>
        public sealed class StateSlide : StateBase
        {
            // Owning transform; assigned in InitializeStateCore.
            private SlidingWindowTransformBase<TInput> _parentSliding;

            /// <summary>
            /// Fills <paramref name="output"/> with the fill value selected by the parent's begin option.
            /// </summary>
            /// <param name="output">Buffer receiving the constant-filled window.</param>
            /// <exception cref="NotImplementedException">The begin option is <see cref="BeginOptions.FirstValue"/>.</exception>
            private protected override void SetNaOutput(ref VBuffer<TInput> output)
            {
                var owner = _parentSliding;
                int length = owner.WindowSize - owner._lag + 1;
                var editor = VBufferEditor.Create(ref output, length);

                if (owner._begin == BeginOptions.FirstValue)
                {
                    // REVIEW: will complete the implementation
                    // if the design looks good
                    throw new NotImplementedException();
                }

                // Every other option uses the fill value cached by the parent.
                TInput filler = owner._nanValue;
                for (int slot = 0; slot < length; slot++)
                    editor.Values[slot] = filler;

                output = editor.Commit();
            }

            /// <summary>
            /// Copies the window into <paramref name="output"/>. When the lag is zero, the buffered values
            /// are followed by the current <paramref name="input"/>; otherwise only buffered values are used.
            /// </summary>
            /// <param name="input">The current observation.</param>
            /// <param name="windowedBuffer">Queue of buffered observations, read by index from 0.</param>
            /// <param name="iteration">Not used by this implementation.</param>
            /// <param name="output">Buffer receiving the window.</param>
            private protected override void TransformCore(ref TInput input, FixedSizeQueue<TInput> windowedBuffer, long iteration, ref VBuffer<TInput> output)
            {
                var owner = _parentSliding;
                int length = owner.WindowSize - owner._lag + 1;
                var editor = VBufferEditor.Create(ref output, length);

                // With no lag, the last slot is reserved for the current observation,
                // so only WindowSize (= length - 1) values come from the buffer.
                bool appendCurrent = owner._lag == 0;
                int fromBuffer = appendCurrent ? owner.WindowSize : length;

                for (int slot = 0; slot < fromBuffer; slot++)
                    editor.Values[slot] = windowedBuffer[slot];

                if (appendCurrent)
                    editor.Values[owner.WindowSize] = input;

                output = editor.Commit();
            }

            /// <summary>
            /// Captures the parent transform as a <see cref="SlidingWindowTransformBase{TInput}"/>.
            /// </summary>
            private protected override void InitializeStateCore()
                => _parentSliding = (SlidingWindowTransformBase<TInput>)base.ParentTransform;

            /// <summary>
            /// Intentionally does nothing.
            /// </summary>
            /// <param name="data">The initial windowed buffer; ignored.</param>
            private protected override void LearnStateFromDataCore(FixedSizeQueue<TInput> data)
            {
                // This method is empty because there is no need for parameter learning from the initial windowed buffer for this transform.
            }
        }
    }
}