File: SsaAnomalyDetectionBase.cs
Web Access
Project: src\src\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj (Microsoft.ML.TimeSeries)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms.TimeSeries
{
    /// <summary>
    /// Selects how the deviation between an observed value and a predicted value is measured.
    /// The numeric value of each member is what gets written when a model is saved as a byte.
    /// </summary>
    public enum ErrorFunction : byte
    {
        /// <summary>Observed value minus predicted value.</summary>
        SignedDifference = 0,

        /// <summary>Absolute value of the difference between observed and predicted values.</summary>
        AbsoluteDifference = 1,

        /// <summary>Difference between observed and predicted values, divided by the predicted value.</summary>
        SignedProportion = 2,

        /// <summary>Absolute value of the difference divided by the predicted value.</summary>
        AbsoluteProportion = 3,

        /// <summary>Square of the difference between observed and predicted values.</summary>
        SquaredDifference = 4
    }
 
    /// <summary>
    /// Implements each <see cref="ErrorFunction"/> and maps an enum value to the matching implementation.
    /// </summary>
    internal static class ErrorFunctionUtils
    {
        /// <summary>
        /// Message used both as argument help and as the exception text for an unrecognized <see cref="ErrorFunction"/>.
        /// </summary>
        public const string ErrorFunctionHelpText = "The error function should be either (0) SignedDifference, (1) AbsoluteDifference, (2) SignedProportion" +
                                                     " (3) AbsoluteProportion or (4) SquaredDifference.";

        /// <summary>Returns <paramref name="actual"/> minus <paramref name="predicted"/>.</summary>
        public static double SignedDifference(double actual, Double predicted) => actual - predicted;

        /// <summary>Returns the magnitude of <paramref name="actual"/> minus <paramref name="predicted"/>.</summary>
        public static Double AbsoluteDifference(Double actual, Double predicted) => Math.Abs(actual - predicted);

        /// <summary>
        /// Returns the signed difference relative to <paramref name="predicted"/>, or 0 when <paramref name="predicted"/> is 0.
        /// </summary>
        public static Double SignedProportion(Double actual, Double predicted)
        {
            // A zero prediction cannot be used as a divisor; the proportion is reported as 0 instead.
            if (predicted == 0)
            {
                return 0;
            }

            return (actual - predicted) / predicted;
        }

        /// <summary>
        /// Returns the magnitude of the difference relative to <paramref name="predicted"/>, or 0 when <paramref name="predicted"/> is 0.
        /// </summary>
        public static Double AbsoluteProportion(Double actual, Double predicted)
        {
            // Same zero-divisor convention as SignedProportion.
            if (predicted == 0)
            {
                return 0;
            }

            return Math.Abs((actual - predicted) / predicted);
        }

        /// <summary>Returns the square of <paramref name="actual"/> minus <paramref name="predicted"/>.</summary>
        public static Double SquaredDifference(Double actual, Double predicted)
        {
            Double delta = actual - predicted;
            return delta * delta;
        }

        /// <summary>
        /// Looks up the delegate that implements <paramref name="errorFunction"/>.
        /// </summary>
        /// <param name="errorFunction">The error function to resolve.</param>
        /// <returns>A delegate taking (actual, predicted) and returning the error value.</returns>
        /// <remarks>Throws via <c>Contracts.Except</c> with <see cref="ErrorFunctionHelpText"/> for an unrecognized value.</remarks>
        public static Func<Double, Double, Double> GetErrorFunction(ErrorFunction errorFunction)
        {
            Func<Double, Double, Double> selected;
            switch (errorFunction)
            {
                case ErrorFunction.SignedDifference:
                    selected = SignedDifference;
                    break;
                case ErrorFunction.AbsoluteDifference:
                    selected = AbsoluteDifference;
                    break;
                case ErrorFunction.SignedProportion:
                    selected = SignedProportion;
                    break;
                case ErrorFunction.AbsoluteProportion:
                    selected = AbsoluteProportion;
                    break;
                case ErrorFunction.SquaredDifference:
                    selected = SquaredDifference;
                    break;
                default:
                    throw Contracts.Except(ErrorFunctionHelpText);
            }

            return selected;
        }
    }
 
    /// <summary>
    /// The wrapper to <see cref="SsaAnomalyDetectionBase"/> that implements the general anomaly detection transform based on Singular Spectrum modeling of the time-series.
    /// For the details of the Singular Spectrum Analysis (SSA), refer to http://arxiv.org/pdf/1206.6910.pdf.
    /// </summary>
    public class SsaAnomalyDetectionBaseWrapper : IStatefulTransformer, ICanSaveModel
    {
        /// <summary>
        /// Reports whether <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> is expected to succeed
        /// for a suitable schema. The value is read from the wrapped <see cref="InternalTransform"/>.
        /// </summary>
        bool ITransformer.IsRowToRowMapper
        {
            get
            {
                ITransformer inner = InternalTransform;
                return inner.IsRowToRowMapper;
            }
        }
 
        /// <summary>
        /// Clones the transformer so that a snapshot of its state can be taken.
        /// </summary>
        /// <returns>The result of calling <c>Clone</c> on the wrapped <see cref="InternalTransform"/>.</returns>
        IStatefulTransformer IStatefulTransformer.Clone()
        {
            return InternalTransform.Clone();
        }
 
        /// <summary>
        /// Propagates the schema: given an input schema, returns the schema the transformed data would have.
        /// The work is delegated to the wrapped <see cref="InternalTransform"/>.
        /// </summary>
        /// <param name="inputSchema">The schema of the prospective input data.</param>
        /// <returns>The output schema computed by the wrapped transform.</returns>
        public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
        {
            return InternalTransform.GetOutputSchema(inputSchema);
        }
 
        /// <summary>
        /// Builds a row-to-row mapper for the given input schema by forwarding to the wrapped
        /// <see cref="InternalTransform"/>. Per the <see cref="ITransformer"/> contract, an exception is expected
        /// when <see cref="ITransformer.IsRowToRowMapper"/> is <c>false</c> or the schema is unsuitable.
        /// </summary>
        /// <param name="inputSchema">The input schema for which the mapper is requested.</param>
        /// <returns>The row-to-row mapper produced by the wrapped transform.</returns>
        IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
        {
            ITransformer inner = InternalTransform;
            return inner.GetRowToRowMapper(inputSchema);
        }
 
        /// <summary>
        /// Counterpart of <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> that also supports saving state.
        /// The request is forwarded to the wrapped <see cref="InternalTransform"/>.
        /// </summary>
        /// <param name="inputSchema">The input schema for which the mapper is requested.</param>
        /// <returns>The stateful row-to-row mapper produced by the wrapped transform.</returns>
        public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema)
        {
            IStatefulTransformer inner = InternalTransform;
            return inner.GetStatefulRowToRowMapper(inputSchema);
        }
 
        /// <summary>
        /// Applies the transform to <paramref name="input"/> by forwarding to the wrapped <see cref="InternalTransform"/>.
        /// Because <see cref="IDataView"/> is lazy, no row is actually transformed here; only the schema is validated.
        /// </summary>
        /// <param name="input">The data to transform.</param>
        /// <returns>The transformed data view.</returns>
        public IDataView Transform(IDataView input)
        {
            return InternalTransform.Transform(input);
        }
 
        /// <summary>
        /// Saves the model into a repository by calling the overridable <see cref="SaveModel(ModelSaveContext)"/>.
        /// </summary>
        /// <param name="ctx">The context to save into.</param>
        void ICanSaveModel.Save(ModelSaveContext ctx)
        {
            SaveModel(ctx);
        }
 
        /// <summary>
        /// Default save logic: writes the wrapped transform through <c>SaveThis</c>. Derived classes may override.
        /// </summary>
        /// <param name="ctx">The context to save into.</param>
        private protected virtual void SaveModel(ModelSaveContext ctx)
        {
            InternalTransform.SaveThis(ctx);
        }
 
        /// <summary>
        /// Creates a stateful row mapper for <paramref name="schema"/> using the wrapped <see cref="InternalTransform"/>.
        /// </summary>
        /// <param name="schema">The schema the mapper is built for.</param>
        /// <returns>The row mapper created by the wrapped transform.</returns>
        internal IStatefulRowMapper MakeRowMapper(DataViewSchema schema)
        {
            return InternalTransform.MakeRowMapper(schema);
        }
 
        /// <summary>
        /// Creates an <see cref="IDataTransform"/> over <paramref name="input"/> using the wrapped <see cref="InternalTransform"/>.
        /// </summary>
        /// <param name="input">The data view to wrap.</param>
        /// <returns>The data transform created by the wrapped transform.</returns>
        internal IDataTransform MakeDataTransform(IDataView input)
        {
            return InternalTransform.MakeDataTransform(input);
        }
 
        /// <summary>
        /// Options specific to SSA anomaly detection, layered on top of <see cref="ArgumentsBase"/>.
        /// </summary>
        internal abstract class SsaOptions : ArgumentsBase
        {
            /// <summary>
            /// The inner (seasonal) window size used by SSA. The transform's constructor requires it to be at least 2.
            /// </summary>
            [Argument(ArgumentType.Required,
                HelpText = "The inner window size for SSA in [2, windowSize]",
                ShortName = "swnd",
                SortOrder = 11)]
            public int SeasonalWindowSize;

            /// <summary>
            /// The discount factor. The transform's constructor requires it to lie in [0, 1]. Defaults to 1.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "The discount factor in [0, 1]",
                ShortName = "disc",
                SortOrder = 12)]
            public float DiscountFactor = 1f;

            /// <summary>
            /// The function used to compute the error between the expected and the observed value.
            /// Defaults to <see cref="TimeSeries.ErrorFunction.SignedDifference"/>.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "The function used to compute the error between the expected and the observed value",
                ShortName = "err",
                SortOrder = 13)]
            public ErrorFunction ErrorFunction = ErrorFunction.SignedDifference;

            /// <summary>
            /// Whether the model is adaptive. Defaults to <c>false</c>.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce,
                HelpText = "The flag determing whether the model is adaptive",
                ShortName = "adp",
                SortOrder = 14)]
            public bool IsAdaptive;
        }
 
        /// <summary>
        /// The wrapped transform that every member of this wrapper forwards to. Assigned by both constructors.
        /// </summary>
        internal SsaAnomalyDetectionBase InternalTransform;
 
        /// <summary>
        /// Creates a wrapper around a new <see cref="SsaAnomalyDetectionBase"/> configured from <paramref name="options"/>.
        /// The wrapper passes itself to the inner transform as its parent.
        /// </summary>
        /// <param name="options">The SSA options.</param>
        /// <param name="name">The name forwarded to the inner transform.</param>
        /// <param name="env">The host environment.</param>
        internal SsaAnomalyDetectionBaseWrapper(SsaOptions options, string name, IHostEnvironment env)
        {
            InternalTransform = new SsaAnomalyDetectionBase(options: options, name: name, env: env, parent: this);
        }
 
        /// <summary>
        /// Creates a wrapper around a <see cref="SsaAnomalyDetectionBase"/> loaded from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="ctx">The context to load the inner transform from.</param>
        /// <param name="name">The name forwarded to the inner transform.</param>
        internal SsaAnomalyDetectionBaseWrapper(IHostEnvironment env, ModelLoadContext ctx, string name)
        {
            InternalTransform = new SsaAnomalyDetectionBase(env, ctx, name)
            {
                // The loading constructor of the inner transform does not receive the wrapper, yet the inner
                // transform's SaveModel dereferences Parent. Link the wrapper here so that saving a loaded
                // model through the inner transform does not fail on a null Parent.
                Parent = this
            };
        }
 
        /// <summary>
        /// The base class that implements the general anomaly detection transform based on Singular Spectrum modeling of the time-series.
        /// For the details of the Singular Spectrum Analysis (SSA), refer to http://arxiv.org/pdf/1206.6910.pdf.
        /// </summary>
        internal sealed class SsaAnomalyDetectionBase : SequentialAnomalyDetectionTransformBase<float, SsaAnomalyDetectionBase.State>
        {
            // The wrapper that owns this transform; SaveModel forwards to it.
            internal SsaAnomalyDetectionBaseWrapper Parent;
            // The inner (seasonal) window size for SSA; validated to be at least 2 on construction and on load.
            internal readonly int SeasonalWindowSize;
            // The discount factor; validated to be in [0, 1] on construction and on load.
            internal readonly Single DiscountFactor;
            // Passed to the model's Consume call as its second argument when points are fed to the model.
            internal readonly bool IsAdaptive;
            // The selected error function; persisted as a single byte.
            internal readonly ErrorFunction ErrorFunction;
            // The delegate resolved from ErrorFunction, used to compute the raw anomaly score.
            internal readonly Func<Double, Double, Double> ErrorFunc;
            // The SSA sequence model. Not readonly: State.CloneCore replaces it with a clone.
            internal SequenceModelerBase<Single, Single> Model;
 
            /// <summary>
            /// Builds the transform from user-supplied options: validates the SSA-specific options,
            /// creates the SSA model and initializes the state.
            /// </summary>
            /// <param name="options">The SSA options.</param>
            /// <param name="name">The name forwarded to the base class.</param>
            /// <param name="env">The host environment.</param>
            /// <param name="parent">The wrapper that owns this transform.</param>
            public SsaAnomalyDetectionBase(SsaOptions options, string name, IHostEnvironment env, SsaAnomalyDetectionBaseWrapper parent)
                : base(options.WindowSize, 0, options.Source, options.Name, name, env, options.Side, options.Martingale, options.AlertOn, options.PowerMartingaleEpsilon, options.AlertThreshold)
            {
                var seasonalWindowSize = options.SeasonalWindowSize;
                var discountFactor = options.DiscountFactor;
                var errorFunction = options.ErrorFunction;

                Host.CheckUserArg(seasonalWindowSize >= 2, nameof(options.SeasonalWindowSize), "Must be at least 2.");
                Host.CheckUserArg(discountFactor >= 0 && discountFactor <= 1, nameof(options.DiscountFactor), "Must be in the range [0, 1].");
                Host.CheckUserArg(Enum.IsDefined(typeof(ErrorFunction), errorFunction), nameof(options.ErrorFunction), ErrorFunctionUtils.ErrorFunctionHelpText);

                SeasonalWindowSize = seasonalWindowSize;
                DiscountFactor = discountFactor;
                ErrorFunction = errorFunction;
                ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction);
                IsAdaptive = options.IsAdaptive;

                // The master SSA model must exist before the state is initialized below.
                Model = new AdaptiveSingularSpectrumSequenceModelerInternal(
                    Host,
                    options.InitialWindowSize,
                    SeasonalWindowSize + 1,
                    SeasonalWindowSize,
                    DiscountFactor,
                    RankSelectionMethod.Exact,
                    null,
                    SeasonalWindowSize / 2,
                    false,
                    false);

                StateRef = new State();
                StateRef.InitState(WindowSize, InitialWindowSize, this, Host);
                Parent = parent;
            }
 
            /// <summary>
            /// Loads the transform from a model context, validating each decoded value as it is read.
            /// </summary>
            /// <param name="env">The host environment.</param>
            /// <param name="ctx">The context to read from.</param>
            /// <param name="name">The name forwarded to the base class.</param>
            public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
                : base(env, ctx, name)
            {
                // *** Binary format ***
                // <base>
                // int: _seasonalWindowSize
                // float: _discountFactor
                // byte: _errorFunction
                // bool: _isAdaptive
                // State: StateRef
                // AdaptiveSingularSpectrumSequenceModeler: _model

                Host.CheckDecode(InitialWindowSize == 0);

                SeasonalWindowSize = ctx.Reader.ReadInt32();
                Host.CheckDecode(SeasonalWindowSize >= 2);

                DiscountFactor = ctx.Reader.ReadSingle();
                Host.CheckDecode(DiscountFactor >= 0 && DiscountFactor <= 1);

                // The error function is stored as its underlying byte value.
                byte errorFunctionCode = ctx.Reader.ReadByte();
                Host.CheckDecode(Enum.IsDefined(typeof(ErrorFunction), errorFunctionCode));
                ErrorFunction = (ErrorFunction)errorFunctionCode;
                ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction);

                IsAdaptive = ctx.Reader.ReadBoolean();

                // The state is read from the stream before the model is loaded.
                StateRef = new State(ctx.Reader);

                ctx.LoadModel<SequenceModelerBase<Single, Single>, SignatureLoadModel>(env, out Model, "SSA");
                Host.CheckDecode(Model != null);

                // Initialize the state only once the model is available.
                StateRef.InitState(this, Host);
            }
 
            /// <summary>
            /// Computes the output schema for <paramref name="inputSchema"/> after checking that the input column
            /// exists and has type <see cref="NumberDataViewType.Single"/>.
            /// </summary>
            /// <param name="inputSchema">The schema of the prospective input data.</param>
            /// <returns>The schema obtained by transforming an empty data view that has the input schema.</returns>
            public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
            {
                Host.CheckValue(inputSchema, nameof(inputSchema));

                if (!inputSchema.TryGetColumnIndex(InputColumnName, out var inputIndex))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName);
                }

                var actualType = inputSchema[inputIndex].Type;
                if (actualType != NumberDataViewType.Single)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, "Single", actualType.ToString());
                }

                // Run the transform over an empty view so that only the schema is computed.
                var emptyInput = new EmptyDataView(Host, inputSchema);
                return Transform(emptyInput).Schema;
            }
 
            /// <summary>
            /// Saves by forwarding to the owning wrapper's <see cref="ICanSaveModel.Save(ModelSaveContext)"/>.
            /// </summary>
            /// <param name="ctx">The context to save into.</param>
            private protected override void SaveModel(ModelSaveContext ctx) => ((ICanSaveModel)Parent).Save(ctx);
 
            /// <summary>
            /// Writes this transform to <paramref name="ctx"/>: the base content, the SSA settings,
            /// the state and finally the model as a sub-model named "SSA".
            /// </summary>
            /// <param name="ctx">The context to save into.</param>
            internal void SaveThis(ModelSaveContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();

                // Invariants established by the constructors.
                Host.Assert(InitialWindowSize == 0);
                Host.Assert(SeasonalWindowSize >= 2);
                Host.Assert(DiscountFactor >= 0 && DiscountFactor <= 1);
                Host.Assert(Enum.IsDefined(typeof(ErrorFunction), ErrorFunction));
                Host.Assert(Model != null);

                // *** Binary format ***
                // <base>
                // int: _seasonalWindowSize
                // float: _discountFactor
                // byte: _errorFunction
                // bool: _isAdaptive
                // State: StateRef
                // AdaptiveSingularSpectrumSequenceModeler: _model

                base.SaveModel(ctx);

                var writer = ctx.Writer;
                writer.Write(SeasonalWindowSize);
                writer.Write(DiscountFactor);
                writer.Write((byte)ErrorFunction);
                writer.Write(IsAdaptive);
                StateRef.Save(writer);

                ctx.SaveModel(Model, "SSA");
            }
 
            /// <summary>
            /// The state of the SSA anomaly detector: the windowed buffers plus references to the
            /// owning detector and its SSA model.
            /// </summary>
            internal sealed class State : AnomalyDetectionStateBase
            {
                // The model of the owning detector; assigned in InitializeAnomalyDetector.
                private SequenceModelerBase<Single, Single> _model;
                // The owning detector; assigned in InitializeAnomalyDetector.
                private SsaAnomalyDetectionBase _parentAnomalyDetector;

                /// <summary>
                /// Creates an empty state.
                /// </summary>
                public State()
                {
                }

                /// <summary>
                /// Reads the state from <paramref name="reader"/>: the base content first, then the two buffers.
                /// </summary>
                /// <param name="reader">The reader to deserialize from.</param>
                internal State(BinaryReader reader)
                    : base(reader)
                {
                    WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
                    InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host);
                }

                /// <summary>
                /// Writes the state to <paramref name="writer"/> in the order the reading constructor expects.
                /// </summary>
                /// <param name="writer">The writer to serialize to.</param>
                internal override void Save(BinaryWriter writer)
                {
                    base.Save(writer);

                    TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer);
                    TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer);
                }

                /// <summary>
                /// Gives <paramref name="state"/> its own copies of the buffers. When this state has a model,
                /// the model of this state's detector is replaced by a clone and this state points at that clone.
                /// </summary>
                /// <param name="state">The state object being populated as the clone.</param>
                private protected override void CloneCore(State state)
                {
                    base.CloneCore(state);
                    Contracts.Assert(state is State);

                    state.WindowedBuffer = WindowedBuffer.Clone();
                    state.InitialWindowedBuffer = InitialWindowedBuffer.Clone();

                    if (_model == null)
                    {
                        return;
                    }

                    // Note: the cloned model is installed on this state's detector, not on the passed-in state.
                    var clonedModel = _parentAnomalyDetector.Model.Clone();
                    _parentAnomalyDetector.Model = clonedModel;
                    _model = clonedModel;
                }

                /// <summary>
                /// Intentionally does nothing.
                /// </summary>
                /// <param name="data">Unused.</param>
                private protected override void LearnStateFromDataCore(FixedSizeQueue<Single> data)
                {
                    // This method is empty because there is no need to implement a training logic here.
                }

                /// <summary>
                /// Caches the owning detector and its model in this state.
                /// </summary>
                private protected override void InitializeAnomalyDetector()
                {
                    var detector = (SsaAnomalyDetectionBase)Parent;
                    _parentAnomalyDetector = detector;
                    _model = detector.Model;
                }

                /// <summary>
                /// Computes the raw anomaly score of <paramref name="input"/> as the configured error between
                /// the input and the model's prediction for the next point.
                /// </summary>
                /// <param name="input">The observed value.</param>
                /// <param name="windowedBuffer">Not used by this implementation.</param>
                /// <param name="iteration">Not used by this implementation.</param>
                /// <returns>The error between the observed and the predicted value.</returns>
                private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue<Single> windowedBuffer, long iteration)
                {
                    // Ask the model for its forecast of the next point in the series.
                    Single forecast = 0;
                    _model.PredictNext(ref forecast);

                    // NOTE(review): -1 presumably marks that no previous position is recorded; confirm in the base class.
                    if (PreviousPosition == -1)
                    {
                        // Feed the current point to the model.
                        _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive);
                    }

                    // The error is the raw anomaly score.
                    return _parentAnomalyDetector.ErrorFunc(input, forecast);
                }

                /// <summary>
                /// Feeds <paramref name="input"/> to the model, passing along the detector's adaptive flag.
                /// </summary>
                /// <param name="input">The observed value.</param>
                public override void Consume(Single input)
                {
                    _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive);
                }
            }
        }
    }
}