// File: SrCnnTransformBase.cs
// Project: src\src\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj (Microsoft.ML.TimeSeries)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using System.Threading;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms.TimeSeries
{
    /// <summary>
    /// Arguments shared by the SrCnn (spectral residual) time-series transforms.
    /// The values are forwarded to, and validated by, the <c>SrCnnTransformBase</c> constructor.
    /// </summary>
    internal abstract class SrCnnArgumentBase
    {
        /// <summary>Name of the input column.</summary>
        [Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src",
            SortOrder = 1, Purpose = SpecialPurpose.ColumnName)]
        public string Source;

        /// <summary>Name of the output column produced by the transform.</summary>
        [Argument(ArgumentType.Required, HelpText = "The name of the new column.",
            SortOrder = 2)]
        public string Name;

        /// <summary>Size of the sliding window used to compute the spectral residual.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing spectral residual", ShortName = "wnd",
            SortOrder = 3)]
        public int WindowSize = 24;

        /// <summary>Size of the initial window; 0 means no initial window is considered.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the initial window for computing. The default value is set to 0, which means there is no initial window considered.", ShortName = "iwnd",
            SortOrder = 4)]
        public int InitialWindowSize = 0;

        /// <summary>Number of points to the back of the training window. Must be positive.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The number of points to the back of training window.",
            ShortName = "backwnd", SortOrder = 5)]
        public int BackAddWindowSize = 5;

        /// <summary>Number of points used in prediction. Must be positive and at most <see cref="WindowSize"/>.</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The number of previous points used in prediction.",
            ShortName = "aheadwnd", SortOrder = 6)]
        public int LookaheadWindowSize = 5;

        /// <summary>
        /// Size of the sliding window used to generate the saliency map. Must be positive and at most <see cref="WindowSize"/>.
        /// The field name's spelling is kept as-is because it is part of the public argument surface.
        /// </summary>
        // NOTE(review): declared Required while also carrying a default value; confirm which is intended.
        [Argument(ArgumentType.Required, HelpText = "The size of sliding window to generate a saliency map for the series.",
            ShortName = "avgwnd", SortOrder = 7)]
        public int AvergingWindowSize = 3;

        /// <summary>
        /// Size of the sliding window used when scoring each data point. Must be positive and at most <see cref="WindowSize"/>.
        /// </summary>
        // NOTE(review): declared Required while also carrying a default value; confirm which is intended.
        [Argument(ArgumentType.Required, HelpText = "The size of sliding window to calculate the anomaly score for each data point.",
            ShortName = "jdgwnd", SortOrder = 8)]
        public int JudgementWindowSize = 21;

        /// <summary>Anomaly threshold: scores larger than this are treated as anomalies. Must lie in (0, 1).</summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold to determine anomaly, score larger than the threshold is considered as anomaly.",
            ShortName = "thre", SortOrder = 9)]
        public double Threshold = 0.3;
    }
 
    /// <summary>
    /// Base class for the SrCnn (spectral residual) sequential transforms.
    /// It validates and persists the SrCnn-specific window sizes and alert threshold, and emits a
    /// vector of <see cref="OutputLength"/> doubles per row. The slots are named "Alert", "Raw Score" and "Mag".
    /// </summary>
    /// <typeparam name="TInput">Type of the input column values.</typeparam>
    /// <typeparam name="TState">Concrete per-sequence state type.</typeparam>
    internal abstract class SrCnnTransformBase<TInput, TState> : SequentialTransformerBase<TInput, VBuffer<Double>, TState>
        where TState : SrCnnTransformBase<TInput, TState>.SrCnnStateBase, new()
    {
        /// <summary>Number of points to the back of the training window (always positive).</summary>
        internal int BackAddWindowSize { get; }

        /// <summary>Lookahead window size (positive and, when user-supplied, at most the window size).</summary>
        internal int LookaheadWindowSize { get; }

        /// <summary>Averaging window size (positive and, when user-supplied, at most the window size).</summary>
        internal int AvergingWindowSize { get; }

        /// <summary>Judgement window size (positive and, when user-supplied, at most the window size).</summary>
        internal int JudgementWindowSize { get; }

        /// <summary>Alert threshold. It is validated to lie in (0, 1) at construction and accepted in [0, 1] when loaded.</summary>
        internal double AlertThreshold { get; }

        /// <summary>Length of the output vector; always 3.</summary>
        internal int OutputLength { get; }

        /// <summary>
        /// Creates the transform from explicit settings and validates every SrCnn-specific argument.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Raised through <c>Host.CheckUserArg</c> when an argument is out of range. The exact exception type
        /// is determined by the ML.NET contracts helper.
        /// </exception>
        private protected SrCnnTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env,
            int backAddWindowSize, int lookaheadWindowSize, int averagingWindowSize, int judgementWindowSize, Double alertThreshold)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, outputColumnName, inputColumnName, new VectorDataViewType(NumberDataViewType.Double, 3))
        {
            // Each check below rejects 0, so the messages say "positive" rather than "non-negative".
            Host.CheckUserArg(backAddWindowSize > 0, nameof(SrCnnArgumentBase.BackAddWindowSize), "Must be positive");
            Host.CheckUserArg(lookaheadWindowSize > 0 && lookaheadWindowSize <= windowSize, nameof(SrCnnArgumentBase.LookaheadWindowSize), "Must be positive and not larger than window size");
            Host.CheckUserArg(averagingWindowSize > 0 && averagingWindowSize <= windowSize, nameof(SrCnnArgumentBase.AvergingWindowSize), "Must be positive and not larger than window size");
            Host.CheckUserArg(judgementWindowSize > 0 && judgementWindowSize <= windowSize, nameof(SrCnnArgumentBase.JudgementWindowSize), "Must be positive and not larger than window size");
            Host.CheckUserArg(alertThreshold > 0 && alertThreshold < 1, nameof(SrCnnArgumentBase.Threshold), "Must be in (0,1)");

            BackAddWindowSize = backAddWindowSize;
            LookaheadWindowSize = lookaheadWindowSize;
            AvergingWindowSize = averagingWindowSize;
            JudgementWindowSize = judgementWindowSize;
            AlertThreshold = alertThreshold;

            OutputLength = 3;
        }

        /// <summary>
        /// Restores the transform from a saved model. The base class reads its own payload first.
        /// </summary>
        private protected SrCnnTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx)
        {
            // *** Binary format (after the base class payload) ***
            // byte: BackAddWindowSize
            // byte: LookaheadWindowSize
            // byte: AvergingWindowSize
            // byte: JudgementWindowSize
            // double: AlertThreshold
            OutputLength = 3;

            // byte -> int is an implicit widening conversion.
            BackAddWindowSize = ctx.Reader.ReadByte();
            Host.CheckDecode(BackAddWindowSize > 0);

            LookaheadWindowSize = ctx.Reader.ReadByte();
            Host.CheckDecode(LookaheadWindowSize > 0);

            AvergingWindowSize = ctx.Reader.ReadByte();
            Host.CheckDecode(AvergingWindowSize > 0);

            JudgementWindowSize = ctx.Reader.ReadByte();
            Host.CheckDecode(JudgementWindowSize > 0);

            AlertThreshold = ctx.Reader.ReadDouble();
            Host.CheckDecode(AlertThreshold >= 0 && AlertThreshold <= 1);
        }

        /// <summary>Creates the transform from an argument object.</summary>
        private protected SrCnnTransformBase(SrCnnArgumentBase args, string name, IHostEnvironment env)
            : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name,
                  name, env, args.BackAddWindowSize, args.LookaheadWindowSize, args.AvergingWindowSize, args.JudgementWindowSize, args.Threshold)
        {
        }

        /// <summary>
        /// Writes the base class payload followed by the SrCnn settings, in the layout read by the loading constructor.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();

            Host.Assert(WindowSize > 0);
            // NOTE(review): the public constructor does not enforce this equality (SrCnnArgumentBase.InitialWindowSize
            // defaults to 0). Confirm that every caller passes InitialWindowSize == WindowSize.
            Host.Assert(InitialWindowSize == WindowSize);
            Host.Assert(BackAddWindowSize > 0);
            Host.Assert(LookaheadWindowSize > 0);
            Host.Assert(AvergingWindowSize > 0);
            Host.Assert(JudgementWindowSize > 0);
            Host.Assert(AlertThreshold >= 0 && AlertThreshold <= 1);

            // The window sizes are persisted as single bytes. Without these checks, the (byte) casts below would
            // silently truncate values above byte.MaxValue: 300 would reload as 44, and 256 as 0, which fails decoding.
            // Validate before anything is written so that no partial model is produced.
            CheckStorableAsByte(BackAddWindowSize, nameof(BackAddWindowSize));
            CheckStorableAsByte(LookaheadWindowSize, nameof(LookaheadWindowSize));
            CheckStorableAsByte(AvergingWindowSize, nameof(AvergingWindowSize));
            CheckStorableAsByte(JudgementWindowSize, nameof(JudgementWindowSize));

            base.SaveModel(ctx);
            ctx.Writer.Write((byte)BackAddWindowSize);
            ctx.Writer.Write((byte)LookaheadWindowSize);
            ctx.Writer.Write((byte)AvergingWindowSize);
            ctx.Writer.Write((byte)JudgementWindowSize);
            ctx.Writer.Write(AlertThreshold);
        }

        /// <summary>Throws if <paramref name="value"/> does not fit the single-byte field used in the saved model.</summary>
        private void CheckStorableAsByte(int value, string name)
        {
            Host.Check(value <= byte.MaxValue,
                $"{name} is {value}, but it must not exceed {byte.MaxValue} for the model to be saved.");
        }

        internal override IStatefulRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(Host, this, schema);

        /// <summary>
        /// Stateful row mapper. It feeds each input value through a <see cref="SrCnnStateBase"/> and exposes
        /// the single vector output column.
        /// </summary>
        internal sealed class Mapper : IStatefulRowMapper
        {
            private readonly IHost _host;
            private readonly SrCnnTransformBase<TInput, TState> _parent;
            private readonly DataViewSchema _parentSchema;
            private readonly int _inputColumnIndex;
            // Slot names for the output vector, in output order.
            private readonly VBuffer<ReadOnlyMemory<Char>> _slotNames;
            // Starts as the parent's shared state; replaced by a private clone in CloneState when shared.
            private SrCnnStateBase State { get; set; }

            /// <summary>
            /// Binds the mapper to <paramref name="inputSchema"/>. Throws if the input column is missing or is not <c>Single</c>.
            /// </summary>
            public Mapper(IHostEnvironment env, SrCnnTransformBase<TInput, TState> parent, DataViewSchema inputSchema)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Mapper));
                _host.CheckValue(inputSchema, nameof(inputSchema));
                _host.CheckValue(parent, nameof(parent));

                if (!inputSchema.TryGetColumnIndex(parent.InputColumnName, out _inputColumnIndex))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName);

                var colType = inputSchema[_inputColumnIndex].Type;
                if (colType != NumberDataViewType.Single)
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName, "Single", colType.ToString());

                _parent = parent;
                _parentSchema = inputSchema;
                _slotNames = new VBuffer<ReadOnlyMemory<char>>(_parent.OutputLength, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(),
                    "Mag".AsMemory()});

                State = (SrCnnStateBase)_parent.StateRef;
            }

            /// <summary>Describes the single output column: a Double vector of length OutputLength, with slot names.</summary>
            public DataViewSchema.DetachedColumn[] GetOutputColumns()
            {
                var meta = new DataViewSchema.Annotations.Builder();
                meta.AddSlotNames(_parent.OutputLength, GetSlotNames);
                var info = new DataViewSchema.DetachedColumn[1];
                info[0] = new DataViewSchema.DetachedColumn(_parent.OutputColumnName, new VectorDataViewType(NumberDataViewType.Double, _parent.OutputLength), meta.ToAnnotations());
                return info;
            }

            /// <summary>Copies the slot names into <paramref name="dst"/>.</summary>
            public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst) => _slotNames.CopyTo(ref dst, 0, _parent.OutputLength);

            /// <summary>The output column depends only on the input column, and only when the output is active.</summary>
            public Func<int, bool> GetDependencies(Func<int, bool> activeOutput)
            {
                if (activeOutput(0))
                    return col => col == _inputColumnIndex;
                else
                    return col => false;
            }

            void ICanSaveModel.Save(ModelSaveContext ctx) => _parent.SaveModel(ctx);

            /// <summary>Creates the getter for the single output column when it is active; no disposer is needed.</summary>
            public Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
            {
                disposer = null;
                var getters = new Delegate[1];
                if (activeOutput(0))
                    getters[0] = MakeGetter(input, State);

                return getters;
            }

            private delegate void ProcessData(ref TInput src, ref VBuffer<double> dst);

            // Reads the input value and runs it through the state to produce the output vector.
            private Delegate MakeGetter(DataViewRow input, SrCnnStateBase state)
            {
                _host.AssertValue(input);
                var srcGetter = input.GetGetter<TInput>(input.Schema[_inputColumnIndex]);
                ProcessData processData = _parent.WindowSize > 0 ?
                    (ProcessData)state.Process : state.ProcessWithoutBuffer;

                ValueGetter<VBuffer<double>> valueGetter = (ref VBuffer<double> dst) =>
                {
                    TInput src = default;
                    srcGetter(ref src);
                    processData(ref src, ref dst);
                };
                return valueGetter;
            }

            /// <summary>Creates a pinger that updates the state without producing output; null when the output is inactive.</summary>
            public Action<PingerArgument> CreatePinger(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
            {
                disposer = null;
                Action<PingerArgument> pinger = null;
                if (activeOutput(0))
                    pinger = MakePinger(input, State);

                return pinger;
            }

            private Action<PingerArgument> MakePinger(DataViewRow input, SrCnnStateBase state)
            {
                _host.AssertValue(input);
                var srcGetter = input.GetGetter<TInput>(input.Schema[_inputColumnIndex]);
                Action<PingerArgument> pinger = (PingerArgument args) =>
                {
                    if (args.DontConsumeSource)
                        return;

                    TInput src = default;
                    srcGetter(ref src);
                    state.UpdateState(ref src, args.RowPosition, _parent.WindowSize > 0);
                };
                return pinger;
            }

            /// <summary>
            /// Gives this mapper its own clone of the parent's state once more than one mapper has
            /// requested it (tracked by an atomic reference count).
            /// </summary>
            public void CloneState()
            {
                if (Interlocked.Increment(ref _parent.StateRefCount) > 1)
                {
                    State = (SrCnnStateBase)_parent.StateRef.Clone();
                }
            }

            public ITransformer GetTransformer()
            {
                return _parent;
            }
        }

        /// <summary>
        /// Per-sequence state for the SrCnn transforms. Output slots start as NaN, and the spectral residual
        /// computation is delegated to <see cref="SpectralResidual"/>.
        /// </summary>
        internal abstract class SrCnnStateBase : SequentialTransformerBase<TInput, VBuffer<Double>, TState>.StateBase
        {
            // Set from ParentTransform in InitializeStateCore.
            protected SrCnnTransformBase<TInput, TState> Parent;

            private protected SrCnnStateBase() { }

            private protected override void CloneCore(TState state)
            {
                base.CloneCore(state);
                Contracts.Assert(state is SrCnnStateBase);
            }

            private protected SrCnnStateBase(BinaryReader reader) : base(reader)
            {
            }

            internal override void Save(BinaryWriter writer)
            {
                base.Save(writer);
            }

            // Despite the name, this writes zeros (not NaN) into every output slot.
            private protected override void SetNaOutput(ref VBuffer<double> dst)
            {
                var outputLength = Parent.OutputLength;
                var editor = VBufferEditor.Create(ref dst, outputLength);

                for (int i = 0; i < outputLength; ++i)
                    editor.Values[i] = 0;

                dst = editor.Commit();
            }

            /// <summary>
            /// Fills the output with NaN and then lets <see cref="SpectralResidual"/> overwrite the slots it computes.
            /// </summary>
            public sealed override void TransformCore(ref TInput input, FixedSizeQueue<TInput> windowedBuffer, long iteration, ref VBuffer<double> dst)
            {
                var outputLength = Parent.OutputLength;

                var result = VBufferEditor.Create(ref dst, outputLength);
                result.Values.Fill(Double.NaN);

                SpectralResidual(input, windowedBuffer, ref result);

                dst = result.Commit();
            }

            private protected sealed override void InitializeStateCore(bool disk = false)
            {
                Parent = (SrCnnTransformBase<TInput, TState>)ParentTransform;
            }

            // Intentionally empty: no learning step is performed here.
            private protected override void LearnStateFromDataCore(FixedSizeQueue<TInput> data)
            {
            }

            /// <summary>
            /// Computes the output slots into <paramref name="result"/>. This is a no-op here, so the slots stay NaN
            /// unless a derived state overrides it (presumably every concrete state does; confirm).
            /// </summary>
            private protected virtual void SpectralResidual(TInput input, FixedSizeQueue<TInput> data, ref VBufferEditor<double> result)
            {
            }
        }
    }
}