// File: PcaTransformer.cs
// Web Access
// Project: src\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj (Microsoft.ML.PCA)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(PrincipalComponentAnalysisTransformer.Summary, typeof(IDataTransform), typeof(PrincipalComponentAnalysisTransformer), typeof(PrincipalComponentAnalysisTransformer.Options), typeof(SignatureDataTransform),
    PrincipalComponentAnalysisTransformer.UserName, PrincipalComponentAnalysisTransformer.LoaderSignature, PrincipalComponentAnalysisTransformer.ShortName)]
 
[assembly: LoadableClass(PrincipalComponentAnalysisTransformer.Summary, typeof(IDataTransform), typeof(PrincipalComponentAnalysisTransformer), null, typeof(SignatureLoadDataTransform),
    PrincipalComponentAnalysisTransformer.UserName, PrincipalComponentAnalysisTransformer.LoaderSignature)]
 
[assembly: LoadableClass(PrincipalComponentAnalysisTransformer.Summary, typeof(PrincipalComponentAnalysisTransformer), null, typeof(SignatureLoadModel),
    PrincipalComponentAnalysisTransformer.UserName, PrincipalComponentAnalysisTransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(IRowMapper), typeof(PrincipalComponentAnalysisTransformer), null, typeof(SignatureLoadRowMapper),
    PrincipalComponentAnalysisTransformer.UserName, PrincipalComponentAnalysisTransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(void), typeof(PrincipalComponentAnalysisTransformer), null, typeof(SignatureEntryPointModule), PrincipalComponentAnalysisTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
    public sealed class PrincipalComponentAnalysisTransformer : OneToOneTransformerBase
    {
        /// <summary>
        /// Transform-level arguments for the PCA transform. <see cref="Rank"/>, <see cref="Oversampling"/>,
        /// <see cref="Center"/> and <see cref="Seed"/> are used by the SignatureDataTransform factory as
        /// fallbacks for any <see cref="Column"/> that leaves the corresponding setting unset.
        /// All defaults come from <see cref="PrincipalComponentAnalyzer.Defaults"/>.
        /// </summary>
        internal sealed class Options : TransformInputBase
        {
            // The column pairs to transform; at least one is required.
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;

            // Name of the column holding per-example weights.
            [Argument(ArgumentType.Multiple, HelpText = "The name of the weight column", ShortName = "weight", Purpose = SpecialPurpose.ColumnName)]
            public string ExampleWeightColumnName = PrincipalComponentAnalyzer.Defaults.WeightColumn;

            // Number of principal components to keep.
            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k")]
            public int Rank = PrincipalComponentAnalyzer.Defaults.Rank;

            // Number of extra random vectors used on top of Rank during training.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Oversampling parameter for randomized PCA training", ShortName = "over")]
            public int Oversampling = PrincipalComponentAnalyzer.Defaults.Oversampling;

            // Whether the data is centered before the components are computed.
            [Argument(ArgumentType.AtMostOnce, HelpText = "If enabled, data is centered to be zero mean")]
            public bool Center = PrincipalComponentAnalyzer.Defaults.EnsureZeroMean;

            // Seed for the random number generator used during training.
            [Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation")]
            public int Seed = PrincipalComponentAnalyzer.Defaults.Seed;
        }
 
        /// <summary>
        /// Per-column arguments for the PCA transform. The nullable settings override the transform-level
        /// values on <see cref="Options"/>; when left unset, the SignatureDataTransform factory falls back
        /// to the value on <see cref="Options"/>.
        /// </summary>
        internal class Column : OneToOneColumn
        {
            [Argument(ArgumentType.Multiple, HelpText = "The name of the weight column", ShortName = "weight")]
            public string WeightColumn;

            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k")]
            public int? Rank;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Oversampling parameter for randomized PCA training", ShortName = "over")]
            public int? Oversampling;

            [Argument(ArgumentType.AtMostOnce, HelpText = "If enabled, data is centered to be zero mean", ShortName = "center")]
            public bool? Center;

            [Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation", ShortName = "seed")]
            public int? Seed;

            /// <summary>
            /// Parses the short textual form of a column definition.
            /// </summary>
            /// <param name="str">The non-empty text to parse.</param>
            /// <returns>The parsed column, or null when the text cannot be parsed.</returns>
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);

                var column = new Column();
                return column.TryParse(str) ? column : null;
            }

            /// <summary>
            /// Writes the short textual form of this column into <paramref name="sb"/>.
            /// </summary>
            /// <param name="sb">The builder that receives the text.</param>
            /// <returns>
            /// False when any PCA-specific setting is present (the short form is not attempted in that case);
            /// otherwise the result of the base unparse.
            /// </returns>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);

                bool hasPcaSettings = !string.IsNullOrEmpty(WeightColumn)
                    || Rank.HasValue
                    || Oversampling.HasValue
                    || Center.HasValue
                    || Seed.HasValue;

                return !hasPcaSettings && TryUnparseCore(sb);
            }
        }
 
        /// <summary>
        /// Trained state for one column: the input dimension, the number of components, the
        /// eigenvectors (one array of length Dimension per component) and, when the data was
        /// centered, the projection of the mean onto each eigenvector.
        /// </summary>
        private sealed class TransformInfo
        {
            public readonly int Dimension;
            public readonly int Rank;

            public float[][] Eigenvectors;
            public float[] MeanProjected;

            // A vector of Rank singles; a new type instance is created on every access.
            public DataViewType OutputType => new VectorDataViewType(NumberDataViewType.Single, Rank);

            /// <summary>
            /// Creates an untrained info; Eigenvectors and MeanProjected are filled in later.
            /// </summary>
            /// <param name="rank">Number of components; must satisfy 0 &lt; rank &lt;= dim.</param>
            /// <param name="dim">Number of slots in the input vector.</param>
            public TransformInfo(int rank, int dim)
            {
                Contracts.CheckParam(0 < rank && rank <= dim, nameof(Rank), "Rank must be positive, and at most the dimension of untransformed data");
                Rank = rank;
                Dimension = dim;
            }

            /// <summary>
            /// Reads the state written by <see cref="Save"/>, rejecting inconsistent or non-finite data.
            /// </summary>
            public TransformInfo(ModelLoadContext ctx)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // int: Dimension
                // int: Rank
                // for i=0,..,Rank-1:
                //   float[]: the i'th eigenvector
                // int: the size of MeanProjected (0 if it is null)
                // float[]: MeanProjected

                var reader = ctx.Reader;
                Dimension = reader.ReadInt32();
                Rank = reader.ReadInt32();
                Contracts.CheckDecode(0 < Rank && Rank <= Dimension);

                var vectors = new float[Rank][];
                for (int component = 0; component < vectors.Length; component++)
                {
                    var vector = reader.ReadFloatArray(Dimension);
                    Contracts.CheckDecode(FloatUtils.IsFinite(vector));
                    vectors[component] = vector;
                }
                Eigenvectors = vectors;

                var projected = reader.ReadFloatArray();
                bool isValidMean = projected == null || (projected.Length == Rank && FloatUtils.IsFinite(projected));
                Contracts.CheckDecode(isValidMean);
                MeanProjected = projected;
            }

            /// <summary>
            /// Writes the state in the binary layout described below.
            /// </summary>
            internal void Save(ModelSaveContext ctx)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // int: Dimension
                // int: Rank
                // for i=0,..,Rank-1:
                //   float[]: the i'th eigenvector
                // int: the size of MeanProjected (0 if it is null)
                // float[]: MeanProjected

                Contracts.Assert(0 < Rank && Rank <= Dimension);
                var writer = ctx.Writer;
                writer.Write(Dimension);
                writer.Write(Rank);
                for (int component = 0; component < Rank; component++)
                {
                    Contracts.Assert(FloatUtils.IsFinite(Eigenvectors[component]));
                    // Only the first Dimension values are written, without a length prefix.
                    writer.WriteSinglesNoCount(Eigenvectors[component].AsSpan(0, Dimension));
                }
                Contracts.Assert(MeanProjected == null || (MeanProjected.Length == Rank && FloatUtils.IsFinite(MeanProjected)));
                writer.WriteSingleArray(MeanProjected);
            }

            /// <summary>
            /// Stores the dot product of <paramref name="mean"/> with every eigenvector in MeanProjected,
            /// or clears MeanProjected when <paramref name="mean"/> is null.
            /// </summary>
            public void ProjectMean(float[] mean)
            {
                Contracts.AssertValue(Eigenvectors);

                if (mean != null)
                {
                    var projected = MeanProjected = new float[Rank];
                    for (int component = 0; component < Rank; component++)
                        projected[component] = VectorUtils.DotProduct(Eigenvectors[component], mean);
                }
                else
                {
                    MeanProjected = null;
                }
            }
        }
 
        // Descriptive strings used when registering this component (see the LoadableClass attributes at the top of the file).
        internal const string Summary = "PCA is a dimensionality-reduction transform which computes the projection of a numeric vector onto a low-rank subspace.";
        internal const string UserName = "Principal Component Analysis Transform";
        internal const string ShortName = "Pca";

        internal const string LoaderSignature = "PcaTransform";

        // Version stamp written into saved models and checked when loading them.
        private static VersionInfo GetVersionInfo() =>
            new VersionInfo(
                modelSignature: "PCA FUNC",
                //verWrittenCur: 0x00010001, // Initial
                verWrittenCur: 0x00010002, // Got rid of writing float size in model context
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(PrincipalComponentAnalysisTransformer).Assembly.FullName);
 
        // Number of column pairs handled by this transformer.
        private readonly int _numColumns;
        // Input type plus input/weight column indices, resolved against the training schema. Only the training
        // constructor assigns this; the model-loading constructor leaves it null.
        private readonly Mapper.ColumnSchemaInfo[] _schemaInfos;
        // Per-column trained state (dimension, rank, eigenvectors and projected mean).
        private readonly TransformInfo[] _transformInfos;
 
        private const string RegistrationName = "Pca";
 
        /// <summary>
        /// Training constructor: resolves and validates every input column against the schema of
        /// <paramref name="input"/>, then trains all columns on that data.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">The training data.</param>
        /// <param name="columns">One set of options per column pair.</param>
        internal PrincipalComponentAnalysisTransformer(IHostEnvironment env, IDataView input, PrincipalComponentAnalyzer.ColumnOptions[] columns)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(PrincipalComponentAnalysisTransformer)), GetColumnPairs(columns))
        {
            Host.AssertNonEmpty(ColumnPairs);
            _numColumns = columns.Length;
            _transformInfos = new TransformInfo[_numColumns];
            _schemaInfos = new Mapper.ColumnSchemaInfo[_numColumns];

            for (int col = 0; col < _numColumns; col++)
            {
                var options = columns[col];

                var schemaInfo = new Mapper.ColumnSchemaInfo(ColumnPairs[col], input.Schema, options.WeightColumn);
                _schemaInfos[col] = schemaInfo;
                ValidatePcaInput(Host, options.InputColumnName, schemaInfo.InputType);

                // The dimension of the untransformed data is the slot count of the input vector.
                int dimension = schemaInfo.InputType.GetValueCount();
                _transformInfos[col] = new TransformInfo(options.Rank, dimension);
            }

            Train(columns, _transformInfos, input);
        }
 
        /// <summary>
        /// Model-loading constructor: the base class reads the column pairs, then one
        /// <see cref="TransformInfo"/> is read per column pair. No schema information is restored.
        /// </summary>
        private PrincipalComponentAnalysisTransformer(IHost host, ModelLoadContext ctx)
         : base(host, ctx)
        {
            Host.AssertValue(ctx);

            // *** Binary format ***
            // <prefix handled in static Create method>
            // <base>
            // transformInfos
            Host.AssertNonEmpty(ColumnPairs);

            int count = ColumnPairs.Length;
            var infos = new TransformInfo[count];
            for (int col = 0; col < count; col++)
                infos[col] = new TransformInfo(ctx);

            _numColumns = count;
            _transformInfos = infos;
        }
 
        // Factory method for SignatureLoadDataTransform.
        // Loads the transformer from ctx and applies it to input.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            var transformer = Create(env, ctx);
            return transformer.MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadRowMapper.
        // Loads the transformer from ctx and builds a row mapper bound to inputSchema.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
        {
            var transformer = Create(env, ctx);
            return transformer.MakeRowMapper(inputSchema);
        }
 
        // Factory method for SignatureDataTransform.
        // Converts the argument objects into per-column training options, trains a transformer on input and
        // returns it applied to input. Every setting a column leaves unset falls back to the transform-level
        // value on options. That includes the weight column: previously options.ExampleWeightColumnName was
        // never read, so a weight column given at the transform level was silently ignored.
        private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));
            env.CheckValue(options.Columns, nameof(options.Columns));
            var cols = options.Columns.Select(item => new PrincipalComponentAnalyzer.ColumnOptions(
                        item.Name,
                        item.Source,
                        // The default transform-level weight column is null, so nothing changes unless one is set.
                        item.WeightColumn ?? options.ExampleWeightColumnName,
                        item.Rank ?? options.Rank,
                        item.Oversampling ?? options.Oversampling,
                        item.Center ?? options.Center,
                        item.Seed ?? options.Seed)).ToArray();
            return new PrincipalComponentAnalysisTransformer(env, input, cols).MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadModel.
        // Checks the model header, consumes the version-specific prefix and hands the rest of the stream to the
        // loading constructor.
        private static PrincipalComponentAnalysisTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(nameof(PrincipalComponentAnalysisTransformer));
            host.CheckValue(ctx, nameof(ctx));

            ctx.CheckAtModel(GetVersionInfo());

            // Models written with the initial version start with the size of a float.
            bool hasFloatSizePrefix = ctx.Header.ModelVerWritten == 0x00010001;
            if (hasFloatSizePrefix)
            {
                int cbFloat = ctx.Reader.ReadInt32();
                env.CheckDecode(cbFloat == sizeof(float));
            }

            return new PrincipalComponentAnalysisTransformer(host, ctx);
        }
 
        // Writes the version info, the column pairs (via the base class) and then the trained state of every column.
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // <base>
            // transformInfos
            SaveColumns(ctx);
            foreach (TransformInfo info in _transformInfos)
                info.Save(ctx);
        }
        // Extracts the (output name, input name) pair of every column option, in order.
        private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(PrincipalComponentAnalyzer.ColumnOptions[] columns)
        {
            Contracts.CheckValue(columns, nameof(columns));

            var pairs = new (string outputColumnName, string inputColumnName)[columns.Length];
            for (int col = 0; col < pairs.Length; col++)
                pairs[col] = (columns[col].Name, columns[col].InputColumnName);
            return pairs;
        }
 
        // Trains every column with a randomized projection scheme and stores the resulting eigenvectors and
        // projected mean in transformInfos. Per column the steps are:
        //   1. draw a Gaussian random matrix omega with oversampledRank rows of length Dimension;
        //   2. project the data onto omega to get y (first pass over trainingData);
        //   3. orthonormalize the rows of y in place;
        //   4. project the data onto the orthonormalized rows to get b (second pass over trainingData);
        //   5. eigen-decompose the small oversampledRank x oversampledRank matrix of dot products of the rows of b
        //      and post-process b into the eigenvectors.
        // columns        - per-column training options (seed, oversampling, centering).
        // transformInfos - per-column state; Rank and Dimension are read, Eigenvectors and MeanProjected are written.
        // trainingData   - the data to train on; it is scanned twice.
        private void Train(PrincipalComponentAnalyzer.ColumnOptions[] columns, TransformInfo[] transformInfos, IDataView trainingData)
        {
            var y = new float[_numColumns][][];
            var omega = new float[_numColumns][][];
            var mean = new float[_numColumns][];
            var oversampledRank = new int[_numColumns];
            double totalMemoryUsageEstimate = 0;
            for (int iinfo = 0; iinfo < _numColumns; iinfo++)
            {
                // Without an explicit seed the host's random number generator is used.
                var rnd = columns[iinfo].Seed == null ? Host.Rand : new Random(columns[iinfo].Seed.Value);
                // Rank plus oversampling, capped at the input dimension.
                oversampledRank[iinfo] = Math.Min(transformInfos[iinfo].Rank + columns[iinfo].Oversampling, transformInfos[iinfo].Dimension);
 
                // Rough estimate in GB (bytes / 1e9) covering only the two big matrices, omega and y, each of
                // oversampledRank x Dimension floats. Smaller allocations are not counted.
                double colMemoryUsageEstimate = 2.0 * transformInfos[iinfo].Dimension * oversampledRank[iinfo] * sizeof(float) / 1e9;
                totalMemoryUsageEstimate += colMemoryUsageEstimate;
                if (colMemoryUsageEstimate > 2)
                {
                    using (var ch = Host.Start("Memory usage"))
                    {
                        ch.Info("Estimate memory usage for transforming column {1}: {0:G2} GB. If running out of memory, reduce rank and oversampling factor.",
                            colMemoryUsageEstimate, ColumnPairs[iinfo].inputColumnName);
                    }
                }
 
                y[iinfo] = new float[oversampledRank[iinfo]][];
                omega[iinfo] = new float[oversampledRank[iinfo]][];
                for (int i = 0; i < oversampledRank[iinfo]; i++)
                {
                    y[iinfo][i] = new float[transformInfos[iinfo].Dimension];
                    omega[iinfo][i] = new float[transformInfos[iinfo].Dimension];
                    for (int j = 0; j < transformInfos[iinfo].Dimension; j++)
                    {
                        omega[iinfo][i][j] = (float)Stats.SampleFromGaussian(rnd);
                    }
                }
 
                // A null mean entry marks a column that is not centered.
                if (columns[iinfo].EnsureZeroMean)
                    mean[iinfo] = new float[transformInfos[iinfo].Dimension];
            }
            if (totalMemoryUsageEstimate > 2)
            {
                using (var ch = Host.Start("Memory usage"))
                {
                    ch.Info("Estimate memory usage for all PCA transforms: {0:G2} GB. If running out of memory, reduce ranks and oversampling factors.",
                        totalMemoryUsageEstimate);
                }
            }
 
            // First pass over the data: y <- A * omega.
            Project(trainingData, mean, omega, y, transformInfos);
 
            for (int iinfo = 0; iinfo < transformInfos.Length; iinfo++)
            {
                //Orthonormalize Y in-place using stabilized Gram Schmidt algorithm
                //Ref: https://en.wikipedia.org/wiki/Gram-Schmidt#Algorithm
                for (var i = 0; i < oversampledRank[iinfo]; ++i)
                {
                    var v = y[iinfo][i];
                    // NOTE(review): a zero-norm vector makes this scale factor infinite - confirm whether
                    // rank-deficient input needs to be handled here.
                    VectorUtils.ScaleBy(v, 1 / VectorUtils.Norm(y[iinfo][i])); // normalize
 
                    // Make the next vectors in the queue orthogonal to the orthonormalized vectors
                    for (var j = i + 1; j < oversampledRank[iinfo]; ++j)
                        VectorUtils.AddMult(v, y[iinfo][j], -VectorUtils.DotProduct(v, y[iinfo][j])); //subtract the projection of y[j] on v
                }
            }
            var q = y; // q in QR decomposition
 
            var b = omega; // reuse the memory allocated by Omega
            // Second pass over the data: b <- A * q. The same mean arrays are passed again.
            Project(trainingData, mean, q, b, transformInfos);
 
            for (int iinfo = 0; iinfo < transformInfos.Length; iinfo++)
            {
                //Compute B2 = B' * B
                // Entry (i, j) is the dot product of b[i] and b[j]; only j >= i is computed and then mirrored.
                var b2 = new float[oversampledRank[iinfo] * oversampledRank[iinfo]];
                for (var i = 0; i < oversampledRank[iinfo]; ++i)
                {
                    for (var j = i; j < oversampledRank[iinfo]; ++j)
                        b2[i * oversampledRank[iinfo] + j] = b2[j * oversampledRank[iinfo] + i] = VectorUtils.DotProduct(b[iinfo][i], b[iinfo][j]);
                }
 
                float[] smallEigenvalues; // eigenvectors and eigenvalues of the small matrix B2.
                float[] smallEigenvectors;
 
                EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors);
                // PostProcess rewrites b[iinfo] in place and returns it, so Eigenvectors ends up with oversampledRank rows.
                transformInfos[iinfo].Eigenvectors = PostProcess(b[iinfo], smallEigenvalues, smallEigenvectors, transformInfos[iinfo].Dimension, oversampledRank[iinfo]);
                transformInfos[iinfo].ProjectMean(mean[iinfo]);
            }
        }
 
        //Project the covariance matrix A on to Omega: Y <- A * Omega
        //A = X' * X / n, where X = data - mean
        //Note that the covariance matrix is not computed explicitly
        //
        // trainingData   - the rows to accumulate over; only the input and weight columns recorded in _schemaInfos
        //                  are activated on the cursor.
        // mean           - per-column accumulator for the weighted mean; a null entry means that column is not
        //                  centered. Non-null entries are overwritten with the weighted mean of the accepted rows.
        // omega          - per-column vectors to project onto (read only).
        // y              - per-column outputs with the same shape as omega; overwritten.
        // transformInfos - not used by this method.
        //
        // Rows whose weight is negative or non-finite, or whose features contain non-finite values, are skipped.
        // Throws when a column ends up with no positive total weight.
        private void Project(IDataView trainingData, float[][] mean, float[][][] omega, float[][][] y, TransformInfo[] transformInfos)
        {
            Host.Assert(mean.Length == omega.Length && omega.Length == y.Length && y.Length == _numColumns);
            for (int i = 0; i < omega.Length; i++)
                Contracts.Assert(omega[i].Length == y[i].Length);

            // Reset the accumulators. Train calls this method twice with the same mean arrays, so the mean has to
            // be cleared together with y; otherwise the second pass adds its sums on top of the mean left by the
            // first pass and the result is off by a factor of (1 + 1 / totalWeight).
            for (int iinfo = 0; iinfo < y.Length; iinfo++)
            {
                for (int i = 0; i < y[iinfo].Length; i++)
                    Array.Clear(y[iinfo][i], 0, y[iinfo][i].Length);

                if (mean[iinfo] != null)
                    Array.Clear(mean[iinfo], 0, mean[iinfo].Length);
            }

            bool[] center = Enumerable.Range(0, mean.Length).Select(i => mean[i] != null).ToArray();

            Double[] totalColWeight = new Double[_numColumns];

            // Activate only the columns that are read: every input column and every weight column.
            bool[] activeColumns = new bool[trainingData.Schema.Count];
            foreach (var sInfo in _schemaInfos)
            {
                activeColumns[sInfo.InputIndex] = true;
                if (sInfo.WeightColumnIndex >= 0)
                    activeColumns[sInfo.WeightColumnIndex] = true;
            }

            var inputCols = trainingData.Schema.Where(x => activeColumns[x.Index]);
            using (var cursor = trainingData.GetRowCursor(inputCols))
            {
                // A null weight getter means the column has no weight column; its rows then use weight 1.
                var weightGetters = new ValueGetter<float>[_numColumns];
                var columnGetters = new ValueGetter<VBuffer<float>>[_numColumns];
                for (int iinfo = 0; iinfo < _numColumns; iinfo++)
                {
                    var sInfo = _schemaInfos[iinfo];
                    if (sInfo.WeightColumnIndex >= 0)
                        weightGetters[iinfo] = cursor.GetGetter<float>(cursor.Schema[sInfo.WeightColumnIndex]);
                    columnGetters[iinfo] = cursor.GetGetter<VBuffer<float>>(cursor.Schema[sInfo.InputIndex]);
                }

                var features = default(VBuffer<float>);
                while (cursor.MoveNext())
                {
                    for (int iinfo = 0; iinfo < _numColumns; iinfo++)
                    {
                        float weight = 1;
                        weightGetters[iinfo]?.Invoke(ref weight);
                        columnGetters[iinfo](ref features);

                        var featureValues = features.GetValues();
                        if (FloatUtils.IsFinite(weight) && weight >= 0 && (featureValues.Length == 0 || FloatUtils.IsFinite(featureValues)))
                        {
                            totalColWeight[iinfo] += weight;

                            // Accumulate the weighted sum of the rows; it is turned into the mean after the loop.
                            if (center[iinfo])
                                VectorUtils.AddMult(in features, mean[iinfo], weight);

                            // y[i] += weight * <omega[i], x> * x, i.e. the uncentered X' * X * omega[i].
                            for (int i = 0; i < omega[iinfo].Length; i++)
                                VectorUtils.AddMult(in features, y[iinfo][i], weight * VectorUtils.DotProductWithOffset(omega[iinfo][i], 0, in features));
                        }
                    }
                }

                for (int iinfo = 0; iinfo < _numColumns; iinfo++)
                {
                    if (totalColWeight[iinfo] <= 0)
                        throw Host.Except("Empty data in column '{0}'", ColumnPairs[iinfo].inputColumnName);
                }

                for (int iinfo = 0; iinfo < _numColumns; iinfo++)
                {
                    var invn = (float)(1 / totalColWeight[iinfo]);

                    for (var i = 0; i < omega[iinfo].Length; ++i)
                        VectorUtils.ScaleBy(y[iinfo][i], invn);

                    if (center[iinfo])
                    {
                        // Turn the weighted sum into the mean, then apply the centering correction:
                        // y[i] -= <omega[i], mean> * mean.
                        VectorUtils.ScaleBy(mean[iinfo], invn);
                        for (int i = 0; i < omega[iinfo].Length; i++)
                            VectorUtils.AddMult(mean[iinfo], y[iinfo][i], -VectorUtils.DotProduct(omega[iinfo][i], mean[iinfo]));
                    }
                }
            }
        }
 
        // Returns Y * eigenvectors / eigenvalues, computed in place.
        // y     - k rows of length d; overwritten and returned.
        // sigma - k eigenvalues; each is offset by 1e-6 before being inverted.
        // z     - k * k eigenvector coefficients; the coefficients for output row j are z[j * k .. j * k + k - 1].
        // d     - number of entries per row of y that are processed.
        // k     - number of rows of y.
        // REVIEW: improve
        private float[][] PostProcess(float[][] y, float[] sigma, float[] z, int d, int k)
        {
            const float epsilon = (float)1e-6;

            var inverseSigma = new float[k];
            for (int component = 0; component < k; component++)
                inverseSigma[component] = (float)1.0 / (epsilon + sigma[component]);

            // Scratch space for one position of y: all k mixed values are computed before any of them is written
            // back, because every output row reads every input row.
            var mixed = new float[k];
            for (int position = 0; position < d; position++)
            {
                for (int target = 0; target < k; target++)
                {
                    int zStart = target * k;
                    mixed[target] = 0;
                    for (int source = 0; source < k; source++)
                        mixed[target] += y[source][position] * z[zStart + source];
                }

                for (int target = 0; target < k; target++)
                    y[target][position] = inverseSigma[target] * mixed[target];
            }

            return y;
        }
 
        // Creates the row mapper that applies this transformer to rows of the given schema.
        private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
        {
            return new Mapper(this, schema);
        }
 
        // Checks that the source column at srcCol has a type the PCA transform accepts.
        private protected override void CheckInputColumn(DataViewSchema inputSchema, int col, int srcCol)
        {
            var source = inputSchema[srcCol];
            ValidatePcaInput(Host, source.Name, source.Type);
        }
 
        // Throws a schema-mismatch exception unless type is a vector of Single with a size greater than one.
        // ectx - the context used to create the exception.
        // name - the column name reported in the exception.
        // type - the column type to check.
        internal static void ValidatePcaInput(IExceptionContext ectx, string name, DataViewType type)
        {
            string inputSchema; // never assigned; it only supplies the parameter name for the exception

            if (type is VectorDataViewType vectorType
                && vectorType.Size > 1
                && vectorType.ItemType.Equals(NumberDataViewType.Single))
            {
                return;
            }

            throw ectx.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "known-size vector of Single of two or more items", type.ToString());
        }
 
        private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
        {
            /// <summary>
            /// The input column of one column pair, resolved against a schema: its index and type, plus the
            /// index of the optional weight column (-1 when there is none).
            /// </summary>
            public sealed class ColumnSchemaInfo
            {
                public DataViewType InputType { get; }
                public int InputIndex { get; }
                public int WeightColumnIndex { get; }

                /// <summary>
                /// Looks up the input column and, when given, the weight column in <paramref name="schema"/>.
                /// Throws when the weight column is missing or is not of type Single.
                /// </summary>
                public ColumnSchemaInfo((string outputColumnName, string inputColumnName) columnPair, DataViewSchema schema, string weightColumn = null)
                {
                    string inputName = columnPair.inputColumnName;
                    schema.TryGetColumnIndex(inputName, out int inputIndex);
                    InputIndex = inputIndex;
                    InputType = schema[inputName].Type;

                    if (weightColumn == null)
                    {
                        WeightColumnIndex = -1;
                        return;
                    }

                    if (!schema.TryGetColumnIndex(weightColumn, out int weightIndex))
                        throw Contracts.Except("Weight column '{0}' does not exist.", weightColumn);
                    Contracts.CheckParam(schema[weightIndex].Type == NumberDataViewType.Single, nameof(weightColumn));
                    WeightColumnIndex = weightIndex;
                }
            }
 
            private readonly PrincipalComponentAnalysisTransformer _parent;
            private readonly int _numColumns;

            /// <summary>
            /// Binds the transformer to <paramref name="inputSchema"/>, checking that every input column is a
            /// valid PCA input whose vector size equals the dimension the transformer was trained with.
            /// </summary>
            public Mapper(PrincipalComponentAnalysisTransformer parent, DataViewSchema inputSchema)
               : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _parent = parent;
                _numColumns = parent._numColumns;

                for (int col = 0; col < _numColumns; col++)
                {
                    var pair = _parent.ColumnPairs[col];
                    var actualType = new ColumnSchemaInfo(pair, inputSchema).InputType;
                    ValidatePcaInput(Host, pair.inputColumnName, actualType);

                    if (actualType.GetVectorSize() == _parent._transformInfos[col].Dimension)
                        continue;

                    // Report the type the transformer was trained with as the expected one.
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", pair.inputColumnName,
                        new VectorDataViewType(NumberDataViewType.Single, _parent._transformInfos[col].Dimension).ToString(), actualType.ToString());
                }
            }
 
            // Describes one output column per column pair: the output name with a vector-of-Single type of size
            // Rank and no annotations.
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                var columns = new DataViewSchema.DetachedColumn[_numColumns];
                for (int col = 0; col < columns.Length; col++)
                {
                    string outputName = _parent.ColumnPairs[col].outputColumnName;
                    DataViewType outputType = _parent._transformInfos[col].OutputType;
                    columns[col] = new DataViewSchema.DetachedColumn(outputName, outputType, null);
                }
                return columns;
            }
 
            // Builds the getter for output column iinfo: it reads the source vector into a buffer that is reused
            // across calls and projects it with that column's trained state. No disposer is needed.
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Contracts.AssertValue(input);
                Contracts.Assert(0 <= iinfo && iinfo < _numColumns);
                disposer = null;

                var sourceColumn = input.Schema[ColMapNewToOld[iinfo]];
                var readSource = input.GetGetter<VBuffer<float>>(sourceColumn);
                var buffer = default(VBuffer<float>);

                void Getter(ref VBuffer<float> dst)
                {
                    readSource(ref buffer);
                    TransformFeatures(Host, in buffer, ref dst, _parent._transformInfos[iinfo]);
                }

                return (ValueGetter<VBuffer<float>>)Getter;
            }
 
            // Writes the projection of src into dst: for each of the Rank components, the dot product of the
            // eigenvector with src minus the projected mean (when there is one).
            // Throws through ectx when src does not have the trained dimension.
            private static void TransformFeatures(IExceptionContext ectx, in VBuffer<float> src, ref VBuffer<float> dst, TransformInfo transformInfo)
            {
                ectx.Check(src.Length == transformInfo.Dimension);

                int rank = transformInfo.Rank;
                var editor = VBufferEditor.Create(ref dst, rank);
                for (int component = 0; component < rank; component++)
                {
                    float projection = VectorUtils.DotProductWithOffset(transformInfo.Eigenvectors[component], 0, in src);
                    float offset = transformInfo.MeanProjected == null ? 0 : transformInfo.MeanProjected[component];
                    editor.Values[component] = projection - offset;
                }

                dst = editor.Commit();
            }
 
            // ONNX export is always reported as supported.
            public bool CanSaveOnnx(OnnxContext ctx)
            {
                return true;
            }
 
            // Exports every column pair whose input column is known to ctx; for the others the output column is
            // removed from ctx instead.
            public void SaveAsOnnx(OnnxContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));

                for (int col = 0; col < _numColumns; col++)
                {
                    var (outputName, inputName) = _parent.ColumnPairs[col];

                    if (ctx.ContainsColumn(inputName))
                    {
                        var outputType = _parent._transformInfos[col].OutputType;
                        var dstVariableName = ctx.AddIntermediateVariable(outputType, outputName);
                        SaveAsOnnxCore(ctx, col, ctx.GetVariableName(inputName), dstVariableName);
                    }
                    else
                    {
                        ctx.RemoveColumn(outputName, false);
                    }
                }
            }
 
            // Emits the ONNX nodes for column iinfo: a Transpose of the input followed by a Gemm that multiplies
            // it with the principal components and subtracts the projected mean (beta = -1).
            private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
            {
                Host.CheckValue(ctx, nameof(ctx));

                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                TransformInfo info = _parent._transformInfos[iinfo];
                int rank = info.Rank;
                int dimension = info.Dimension;

                // When the transformer is loaded from a model file,
                // _schemaInfos does not exist. Infer the input type
                // from the transformInfo dimension.
                DataViewType inputType;
                if (_parent._schemaInfos != null)
                    inputType = _parent._schemaInfos[iinfo].InputType;
                else
                    inputType = new VectorDataViewType(NumberDataViewType.Single, dimension);

                // Flatten the first Rank eigenvectors row by row into a Rank x Dimension initializer.
                var flattened = new float[rank * dimension];
                for (int row = 0; row < rank; row++)
                    Array.Copy(info.Eigenvectors[row], 0, flattened, row * dimension, dimension);
                var pcaMatrix = ctx.AddInitializer(flattened, new long[] { rank, dimension }, "principalComponents");

                // The mean initializer stays all zeros when the transformer has no projected mean.
                var meanValues = new float[rank];
                if (info.MeanProjected != null)
                    Array.Copy(info.MeanProjected, meanValues, rank);
                var meanInitializer = ctx.AddInitializer(meanValues, new long[] { rank }, "meanVector");

                // NB: Hack
                // Currently ML.NET persists ONNX graphs in proto-buf 3 format but the Onnx runtime uses the proto-buf 2 format
                // There is an incompatibility between the two where proto-buf 3 does not include variables whose values are zero
                // In the Gemm node below, we want the srcVariableName matrix to be sent in without a transpose, so transA has to be zero
                // Due to the incompatibility, we get an exception from the Onnx runtime
                // To workaround this, we transpose the input data first with the Transpose operator and then use the Gemm operator with transA=1
                // This should be removed once incompatibility is fixed.
                const string transposeOp = "Transpose";
                var transposeOutput = ctx.AddIntermediateVariable(inputType, "TransposeOutput", true);
                ctx.CreateNode(transposeOp, srcVariableName, transposeOutput, ctx.GetNodeName(transposeOp), "");

                const string gemmOp = "Gemm";
                var gemmInputs = new[] { transposeOutput, pcaMatrix, meanInitializer };
                var gemmOutputs = new[] { dstVariableName };
                var gemmNode = ctx.CreateNode(gemmOp, gemmInputs, gemmOutputs, ctx.GetNodeName(gemmOp), "");
                gemmNode.AddAttribute("alpha", 1.0);
                gemmNode.AddAttribute("beta", -1.0);
                gemmNode.AddAttribute("transA", 1);
                gemmNode.AddAttribute("transB", 1);
            }
        }
 
        /// <summary>
        /// Entry point registered as "Transforms.PcaCalculator". Builds the transform over
        /// <c>input.Data</c> through <c>Create</c> and returns it both as the output data view
        /// and wrapped in a <see cref="TransformModelImpl"/>.
        /// </summary>
        /// <param name="env">The environment handed to <c>EntryPointUtils.CheckArgsAndCreateHost</c>.</param>
        /// <param name="input">The entry-point options; its <c>Data</c> member is the source data view.</param>
        /// <returns>A <see cref="CommonOutputs.TransformOutput"/> carrying the model and the transformed view.</returns>
        [TlcModule.EntryPoint(Name = "Transforms.PcaCalculator", Desc = Summary, UserName = UserName, ShortName = ShortName)]
        internal static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Options input)
        {
            // Validates the arguments and yields the host used for everything below.
            var host = EntryPointUtils.CheckArgsAndCreateHost(env, "Pca", input);
            var transformedData = PrincipalComponentAnalysisTransformer.Create(host, input, input.Data);

            var output = new CommonOutputs.TransformOutput();
            output.Model = new TransformModelImpl(host, transformedData, input.Data);
            output.OutputData = transformedData;
            return output;
        }
    }
 
    /// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*'/>
    public sealed class PrincipalComponentAnalyzer : IEstimator<PrincipalComponentAnalysisTransformer>
    {
        /// <summary>
        /// Default values used by the optional parameters of this estimator's constructors
        /// and of <see cref="ColumnOptions"/>.
        /// </summary>
        [BestFriend]
        internal static class Defaults
        {
            public const string WeightColumn = null;
            public const int Rank = 20;
            public const int Oversampling = 20;
            public const bool EnsureZeroMean = true;
            public const int Seed = 0;
        }

        /// <summary>
        /// Describes how the transformer handles one column pair.
        /// </summary>
        [BestFriend]
        internal sealed class ColumnOptions
        {
            /// <summary>
            /// Name of the column resulting from the transformation of <see cref="InputColumnName"/>.
            /// </summary>
            public readonly string Name;
            /// <summary>
            /// Name of column to transform.
            /// </summary>
            public readonly string InputColumnName;
            /// <summary>
            /// The name of the weight column.
            /// </summary>
            public readonly string WeightColumn;
            /// <summary>
            /// The number of components in the PCA.
            /// </summary>
            public readonly int Rank;
            /// <summary>
            /// Oversampling parameter for randomized PCA training.
            /// </summary>
            public readonly int Oversampling;
            /// <summary>
            /// If enabled, data is centered to be zero mean.
            /// </summary>
            public readonly bool EnsureZeroMean;
            /// <summary>
            /// The seed for random number generation.
            /// </summary>
            public readonly int? Seed;

            /// <summary>
            /// Describes how the transformer handles one column pair.
            /// </summary>
            /// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.
            /// Must not be <see langword="null"/>.</param>
            /// <param name="inputColumnName">Name of column to transform.
            /// If set to <see langword="null"/>, the value of the <paramref name="name"/> will be used as source.</param>
            /// <param name="weightColumn">The name of the weight column.</param>
            /// <param name="rank">The number of components in the PCA. Must be positive.</param>
            /// <param name="overSampling">Oversampling parameter for randomized PCA training. Must be non-negative.</param>
            /// <param name="ensureZeroMean">If enabled, data is centered to be zero mean.</param>
            /// <param name="seed">The random seed. If unspecified random state will be instead derived from the <see cref="MLContext"/>.</param>
            public ColumnOptions(string name,
                              string inputColumnName = null,
                              string weightColumn = Defaults.WeightColumn,
                              int rank = Defaults.Rank,
                              int overSampling = Defaults.Oversampling,
                              bool ensureZeroMean = Defaults.EnsureZeroMean,
                              int? seed = null)
            {
                // Validate up front and report failures under the constructor's own parameter
                // names, so the exception points at the argument the caller actually passed.
                Contracts.CheckValue(name, nameof(name));
                Contracts.CheckParam(overSampling >= 0, nameof(overSampling), "Oversampling must be non-negative.");
                Contracts.CheckParam(rank > 0, nameof(rank), "Rank must be positive.");

                Name = name;
                // The output column doubles as the source when no explicit input is given.
                InputColumnName = inputColumnName ?? name;
                WeightColumn = weightColumn;
                Rank = rank;
                Oversampling = overSampling;
                EnsureZeroMean = ensureZeroMean;
                Seed = seed;
            }
        }

        private readonly IHost _host;
        private readonly ColumnOptions[] _columns;

        /// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*'/>
        /// <param name="env">The environment to use.</param>
        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="inputColumnName">Name of the column to transform.
        /// If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
        /// <param name="weightColumn">The name of the weight column.</param>
        /// <param name="rank">The number of components in the PCA.</param>
        /// <param name="overSampling">Oversampling parameter for randomized PCA training.</param>
        /// <param name="ensureZeroMean">If enabled, data is centered to be zero mean.</param>
        /// <param name="seed">The seed for random number generation.</param>
        internal PrincipalComponentAnalyzer(IHostEnvironment env,
            string outputColumnName,
            string inputColumnName = null,
            string weightColumn = Defaults.WeightColumn, int rank = Defaults.Rank,
            int overSampling = Defaults.Oversampling, bool ensureZeroMean = Defaults.EnsureZeroMean,
            int? seed = null)
            : this(env, new ColumnOptions(outputColumnName, inputColumnName ?? outputColumnName, weightColumn, rank, overSampling, ensureZeroMean, seed))
        {
        }

        /// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*'/>
        /// <param name="env">The environment to use.</param>
        /// <param name="columns">The dataset columns to use, and their specific settings. Must not be <see langword="null"/>.</param>
        internal PrincipalComponentAnalyzer(IHostEnvironment env, params ColumnOptions[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(PrincipalComponentAnalyzer));
            // An explicit null array would otherwise only surface later as a
            // NullReferenceException when the columns are enumerated.
            _host.CheckValue(columns, nameof(columns));
            _columns = columns;
        }

        /// <summary>
        /// Trains and returns a <see cref="PrincipalComponentAnalysisTransformer"/>.
        /// </summary>
        /// <param name="input">The data passed to the transformer's constructor together with the configured columns.</param>
        public PrincipalComponentAnalysisTransformer Fit(IDataView input) => new PrincipalComponentAnalysisTransformer(_host, input, _columns);

        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">The shape of the incoming schema. Every configured input column must be
        /// present in it as a vector of <see cref="NumberDataViewType.Single"/>.</param>
        /// <returns>The input shape with each configured output column added, or replaced when a column
        /// of that name already exists.</returns>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            // Keyed by name so that an output column can overwrite an input column of the same name.
            var result = inputSchema.ToDictionary(x => x.Name);
            foreach (var colInfo in _columns)
            {
                // The source column has to exist...
                if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);

                // ...and has to be a vector whose items are Single.
                if (col.Kind != SchemaShape.Column.VectorKind.Vector || !col.ItemType.Equals(NumberDataViewType.Single))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);

                result[colInfo.Name] = new SchemaShape.Column(colInfo.Name,
                    SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
            }

            return new SchemaShape(result.Values);
        }
    }
}