File: Transforms\NormalizeUtils.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
using Newtonsoft.Json.Linq;
 
[assembly: LoadableClass(typeof(void), typeof(Normalize), null, typeof(SignatureEntryPointModule), "Normalize")]
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Signature for a repository based loader of an <see cref="IColumnFunction"/>.
    /// </summary>
    [BestFriend]
    internal delegate void SignatureLoadColumnFunction(ModelLoadContext ctx, IHost host, DataViewType typeSrc);
 
    internal interface IColumnFunctionBuilder
    {
        /// <summary>
        /// Trains on the current value.
        /// </summary>
        /// <returns>True if it can use more values for training.</returns>
        bool ProcessValue();
 
        /// <summary>
        /// Finishes training and returns a column function.
        /// </summary>
        IColumnFunction CreateColumnFunction();
    }
 
    /// <summary>
    /// Interface to define an aggregate function over values
    /// </summary>
    [BestFriend]
    internal interface IColumnAggregator<T>
    {
        /// <summary>
        /// Updates the aggregate function with a value
        /// </summary>
        void ProcessValue(in T val);
 
        /// <summary>
        /// Finishes the aggregation
        /// </summary>
        void Finish();
    }
 
    [BestFriend]
    internal interface IColumnFunction : ICanSaveModel
    {
        Delegate GetGetter(DataViewRow input, int icol);
 
        void AttachMetadata(MetadataDispatcher.Builder bldr, DataViewType typeSrc);
 
        JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
 
        bool CanSaveOnnx(OnnxContext ctx);
 
        bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);
 
        NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams();
    }
 
    /// <summary>
    /// This contains entry-point definitions related to <see cref="NormalizeTransform"/>.
    /// </summary>
    [BestFriend]
    internal static class Normalize
    {
        [TlcModule.EntryPoint(Name = "Transforms.MinMaxNormalizer", Desc = NormalizeTransform.MinMaxNormalizerSummary, UserName = NormalizeTransform.MinMaxNormalizerUserName, ShortName = NormalizeTransform.MinMaxNormalizerShortName)]
        public static CommonOutputs.TransformOutput MinMax(IHostEnvironment env, NormalizeTransform.MinMaxArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("MinMax");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            var xf = NormalizeTransform.Create(host, input, input.Data);
            return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf };
        }
 
        [TlcModule.EntryPoint(Name = "Transforms.MeanVarianceNormalizer", Desc = NormalizeTransform.MeanVarNormalizerSummary, UserName = NormalizeTransform.MeanVarNormalizerUserName, ShortName = NormalizeTransform.MeanVarNormalizerShortName)]
        public static CommonOutputs.TransformOutput MeanVar(IHostEnvironment env, NormalizeTransform.MeanVarArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("MeanVar");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            var xf = NormalizeTransform.Create(host, input, input.Data);
            return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf };
        }
 
        [TlcModule.EntryPoint(Name = "Transforms.LogMeanVarianceNormalizer", Desc = NormalizeTransform.LogMeanVarNormalizerSummary, UserName = NormalizeTransform.LogMeanVarNormalizerUserName, ShortName = NormalizeTransform.LogMeanVarNormalizerShortName)]
        public static CommonOutputs.TransformOutput LogMeanVar(IHostEnvironment env, NormalizeTransform.LogMeanVarArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("LogMeanVar");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            var xf = NormalizeTransform.Create(host, input, input.Data);
            return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf };
        }
 
        [TlcModule.EntryPoint(Name = "Transforms.BinNormalizer", Desc = NormalizeTransform.BinNormalizerSummary, UserName = NormalizeTransform.BinNormalizerUserName, ShortName = NormalizeTransform.BinNormalizerShortName)]
        public static CommonOutputs.TransformOutput Bin(IHostEnvironment env, NormalizeTransform.BinArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Bin");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            var xf = NormalizeTransform.Create(host, input, input.Data);
            return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf };
        }
 
        [TlcModule.EntryPoint(Name = "Transforms.RobustScalingNormalizer", Desc = NormalizeTransform.RobustScalingNormalizerSummary, UserName = NormalizeTransform.RobustScalingNormalizerUserName, ShortName = NormalizeTransform.RobustScalingNormalizerShortName)]
        public static CommonOutputs.TransformOutput RobustScaling(IHostEnvironment env, NormalizeTransform.RobustScalingArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("RobustScaling");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
 
            var xf = NormalizeTransform.Create(host, input, input.Data);
            return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf };
        }
 
        [TlcModule.EntryPoint(Name = "Transforms.ConditionalNormalizer", Desc = "Normalize the columns only if needed", UserName = "Normalize If Needed")]
        public static CommonOutputs.MacroOutput<CommonOutputs.TransformOutput> IfNeeded(
            IHostEnvironment env,
            NormalizeTransform.MinMaxArguments input,
            EntryPointNode node)
        {
            var schema = input.Data.Schema;
            var columnsToNormalize = new List<NormalizeTransform.AffineColumn>();
            foreach (var column in input.Columns)
            {
                if (!schema.TryGetColumnIndex(column.Source, out int col))
                    throw env.ExceptUserArg(nameof(input.Columns), $"Column '{column.Source}' does not exist.");
                if (!schema[col].IsNormalized())
                    columnsToNormalize.Add(column);
            }
 
            var entryPoints = new List<EntryPointNode>();
            if (columnsToNormalize.Count == 0)
            {
                var entryPointNode = EntryPointNode.Create(env, "Transforms.NoOperation", new NopTransform.NopInput(), node.Context, node.InputBindingMap, node.InputMap, node.OutputMap);
                entryPoints.Add(entryPointNode);
            }
            else
            {
                input.Columns = columnsToNormalize.ToArray();
                var entryPointNode = EntryPointNode.Create(env, "Transforms.MinMaxNormalizer", input, node.Context, node.InputBindingMap, node.InputMap, node.OutputMap);
                entryPoints.Add(entryPointNode);
            }
 
            return new CommonOutputs.MacroOutput<CommonOutputs.TransformOutput>() { Nodes = entryPoints };
        }
    }
}