// File: RandomFourierFeaturizing.cs
// Web Access
// Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(ApproximatedKernelTransformer.Summary, typeof(IDataTransform), typeof(ApproximatedKernelTransformer), typeof(ApproximatedKernelTransformer.Options), typeof(SignatureDataTransform),
    "Random Fourier Features Transform", "RffTransform", "Rff")]
 
[assembly: LoadableClass(ApproximatedKernelTransformer.Summary, typeof(IDataTransform), typeof(ApproximatedKernelTransformer), null, typeof(SignatureLoadDataTransform),
    "Random Fourier Features Transform", ApproximatedKernelTransformer.LoaderSignature)]
 
[assembly: LoadableClass(ApproximatedKernelTransformer.Summary, typeof(ApproximatedKernelTransformer), null, typeof(SignatureLoadModel),
    "Random Fourier Features Transform", ApproximatedKernelTransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(IRowMapper), typeof(ApproximatedKernelTransformer), null, typeof(SignatureLoadRowMapper),
    "Random Fourier Features Transform", ApproximatedKernelTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// <see cref="ITransformer"/> resulting from fitting an <see cref="ApproximatedKernelMappingEstimator"/>.
    /// </summary>
    public sealed class ApproximatedKernelTransformer : OneToOneTransformerBase
    {
        /// <summary>
        /// Argument object consumed by the factory method for SignatureDataTransform.
        /// The values declared here are the transform-wide settings: that factory uses them for every
        /// column whose own per-column setting is left null.
        /// </summary>
        internal sealed class Options
        {
            // The columns to transform; required.
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;

            // Number of random Fourier features; defaults to the estimator's default rank.
            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of random Fourier features to create", ShortName = "dim")]
            public int NewDim = ApproximatedKernelMappingEstimator.Defaults.Rank;

            // Factory for the kernel whose feature space is approximated; Gaussian unless overridden.
            [Argument(ArgumentType.Multiple, HelpText = "Which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureKernelBase))]
            public IComponentFactory<KernelBase> MatrixGenerator = new GaussianKernel.Options();

            // Whether each random frequency yields a cos and a sin feature instead of a single cos feature.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Create two features for every random Fourier frequency? (one for cos and one for sin)")]
            public bool UseSin = ApproximatedKernelMappingEstimator.Defaults.UseCosAndSinBases;

            // Optional seed; null means no explicit seed was supplied.
            [Argument(ArgumentType.LastOccurrenceWins,
                HelpText = "The seed of the random number generator for generating the new features (if unspecified, " +
                "the global random is used)")]
            public int? Seed;
        }
 
        /// <summary>
        /// Per-column arguments. Every setting is nullable; the factory method for SignatureDataTransform
        /// substitutes the corresponding value from <see cref="Options"/> for any setting left null.
        /// </summary>
        internal sealed class Column : OneToOneColumn
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of random Fourier features to create", ShortName = "dim")]
            public int? NewDim;

            [Argument(ArgumentType.Multiple, HelpText = "which kernel to use?", ShortName = "kernel", SignatureType = typeof(SignatureKernelBase))]
            public IComponentFactory<KernelBase> MatrixGenerator;

            [Argument(ArgumentType.AtMostOnce, HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)")]
            public bool? UseSin;

            [Argument(ArgumentType.LastOccurrenceWins,
                HelpText = "The seed of the random number generator for generating the new features (if unspecified, " +
                           "the global random is used)")]
            public int? Seed;

            // Builds a Column from its textual form; yields null when the text cannot be parsed.
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);

                var parsed = new Column();
                return parsed.TryParse(str) ? parsed : null;
            }

            // Writes the short textual form into sb. This is refused (returns false) as soon as any
            // per-column override is set.
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);

                bool hasOverride = NewDim.HasValue
                    || MatrixGenerator != null
                    || UseSin.HasValue
                    || Seed.HasValue;

                return !hasOverride && TryUnparseCore(sb);
            }
        }
 
        /// <summary>
        /// Per-column state of the transform: the source and target dimensionality, the random Fourier
        /// matrix, the optional rotation terms, and the random generator state needed to rebuild them.
        /// Only the generator state and the kernel's number generator are persisted; the matrix and the
        /// rotation terms are regenerated from that state when a model is loaded.
        /// </summary>
        private sealed class TransformInfo
        {
            // Number of random Fourier frequencies (rows of the matrix).
            public readonly int NewDim;
            // Number of input features (meaningful columns of the matrix).
            public readonly int SrcDim;

            // The matrix containing the random Fourier vectors. Row i starts at i * RoundUp(SrcDim, _cfltAlign).
            public readonly AlignedArray RndFourierVectors;

            // The random rotations. Null when both cos and sin bases are used.
            public readonly AlignedArray RotationTerms;

            private readonly FourierRandomNumberGeneratorBase _matrixGenerator;
            private readonly bool _useSin;
            private readonly TauswortheHybrid _rand;
            // State of _rand captured before any coefficient is drawn; this is what gets saved.
            private readonly TauswortheHybrid.State _state;

            /// <summary>
            /// Creates the state for a freshly trained column.
            /// </summary>
            /// <param name="host">Host used for argument checks and, when no seed is given, as the source of randomness.</param>
            /// <param name="column">Column settings (rank, bases, seed, kernel).</param>
            /// <param name="d">Number of input features.</param>
            /// <param name="avgDist">Distance statistic handed to the kernel when it creates its number generator.</param>
            public TransformInfo(IHost host, ApproximatedKernelMappingEstimator.ColumnOptions column, int d, float avgDist)
            {
                Contracts.AssertValue(host);

                SrcDim = d;
                NewDim = column.Rank;
                host.CheckUserArg(NewDim > 0, nameof(column.Rank));
                _useSin = column.UseCosAndSinBases;
                var seed = column.Seed;
                _rand = seed.HasValue ? RandomUtils.Create(seed) : RandomUtils.Create(host.Rand);
                // Capture the state before drawing anything so that loading can replay the same sequence.
                _state = _rand.GetState();

                var generator = column.Generator;
                _matrixGenerator = generator.GetRandomNumberGenerator(avgDist);

                int roundedUpD = RoundUp(NewDim, _cfltAlign);
                int roundedUpNumFeatures = RoundUp(SrcDim, _cfltAlign);
                RndFourierVectors = new AlignedArray(roundedUpD * roundedUpNumFeatures, CpuMathUtils.GetVectorAlignment());
                RotationTerms = _useSin ? null : new AlignedArray(roundedUpD, CpuMathUtils.GetVectorAlignment());

                InitializeFourierCoefficients(roundedUpNumFeatures);
            }

            /// <summary>
            /// Restores the state of one column from a saved model and regenerates its coefficients.
            /// </summary>
            /// <param name="env">Environment used for decode checks and sub-model loading.</param>
            /// <param name="ctx">Load context positioned at this column's data.</param>
            /// <param name="directoryName">Name of the sub-model holding the kernel's number generator.</param>
            public TransformInfo(IHostEnvironment env, ModelLoadContext ctx, string directoryName)
            {
                // Assert through Contracts: the object being validated cannot serve as its own assertion context.
                Contracts.AssertValue(env);

                // *** Binary format ***
                // int: d (number of untransformed features)
                // int: NewDim (number of transformed features)
                // bool: UseSin
                // uint[4]: the seeds for the pseudo random number generator.

                SrcDim = ctx.Reader.ReadInt32();
                // A non-positive dimension can only come from a corrupt model; reject it here instead of
                // letting it reach RoundUp and the matrix allocation below.
                env.CheckDecode(SrcDim > 0);

                NewDim = ctx.Reader.ReadInt32();
                env.CheckDecode(NewDim > 0);

                _useSin = ctx.Reader.ReadBoolByte();

                var length = ctx.Reader.ReadInt32();
                env.CheckDecode(length == 4);
                _state = TauswortheHybrid.State.Load(ctx.Reader);
                _rand = new TauswortheHybrid(_state);

                env.CheckDecode(ctx.Repository != null &&
                    ctx.LoadModelOrNull<FourierRandomNumberGeneratorBase, SignatureLoadModel>(env, out _matrixGenerator, directoryName));

                // initialize the transform matrix
                int roundedUpD = RoundUp(NewDim, _cfltAlign);
                int roundedUpNumFeatures = RoundUp(SrcDim, _cfltAlign);
                RndFourierVectors = new AlignedArray(roundedUpD * roundedUpNumFeatures, CpuMathUtils.GetVectorAlignment());
                RotationTerms = _useSin ? null : new AlignedArray(roundedUpD, CpuMathUtils.GetVectorAlignment());
                InitializeFourierCoefficients(roundedUpNumFeatures);
            }

            /// <summary>
            /// Writes this column's state to the model, in the layout read by the loading constructor.
            /// </summary>
            /// <param name="ctx">Save context to write to.</param>
            /// <param name="directoryName">Name of the sub-model that receives the kernel's number generator.</param>
            internal void Save(ModelSaveContext ctx, string directoryName)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // int: d (number of untransformed features)
                // int: NewDim (number of transformed features)
                // bool: UseSin
                // uint[4]: the seeds for the pseudo random number generator.

                ctx.Writer.Write(SrcDim);
                ctx.Writer.Write(NewDim);
                ctx.Writer.WriteBoolByte(_useSin);
                ctx.Writer.Write(4); // fake array length; the loader checks that it reads back as 4
                _state.Save(ctx.Writer);
                ctx.SaveModel(_matrixGenerator, directoryName);
            }

            // Fills the first SrcDim entries of each of the NewDim rows with draws from the kernel's
            // generator. rowSize is the (aligned) row stride; entries past SrcDim in a row are not written.
            private void GetDDimensionalFeatureMapping(int rowSize)
            {
                Contracts.Assert(rowSize >= SrcDim);

                for (int i = 0; i < NewDim; i++)
                {
                    for (int j = 0; j < SrcDim; j++)
                        RndFourierVectors[i * rowSize + j] = _matrixGenerator.Next(_rand);
                }
            }

            // Fills the first colSize rotation terms with (NextSingle() - 0.5) * PI.
            private void GetDRotationTerms(int colSize)
            {
                for (int i = 0; i < colSize; ++i)
                    RotationTerms[i] = (_rand.NextSingle() - (float)0.5) * (float)Math.PI;
            }

            // Draws all random coefficients. The order (matrix first, then rotation terms) must not change:
            // both draw from _rand, and a loaded model replays the same sequence from the saved state.
            private void InitializeFourierCoefficients(int rowSize)
            {
                GetDDimensionalFeatureMapping(rowSize);

                if (!_useSin)
                    GetDRotationTerms(NewDim);
            }
        }
 
        // User-facing description of the transform; referenced by the LoadableClass registrations.
        internal const string Summary = "This transform maps numeric vectors to a random low-dimensional feature space. It is useful when data has non-linear features, "
            + "since the transform is designed so that the inner products of the transformed data are approximately equal to those in the feature space of a user specified "
            + "shift-invariant kernel.";

        // Loader signature shared by the load-time LoadableClass registrations and the version info.
        internal const string LoaderSignature = "RffTransform";
 
        // Version stamp written into saved models and checked against when loading.
        // Written-version history:
        //   0x00010001 - Initial
        //   0x00010002 - Get rid of writing float size in model context
        private static VersionInfo GetVersionInfo()
            => new VersionInfo(
                modelSignature: "RFF FUNC",
                verWrittenCur: 0x00010002,
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(ApproximatedKernelTransformer).Assembly.FullName);
 
        // One entry per column pair, in the same order as ColumnPairs.
        private readonly TransformInfo[] _transformInfos;

        // Alignment expressed in floats: the vector alignment divided by sizeof(float)
        // (presumably GetVectorAlignment returns bytes - TODO confirm).
        private static readonly int _cfltAlign = CpuMathUtils.GetVectorAlignment() / sizeof(float);
 
        // Returns null when the type is acceptable as input (a known-size vector of Single);
        // otherwise returns the text describing what was expected.
        private static string TestColumnType(DataViewType type)
        {
            bool isSupported = type is VectorDataViewType vector
                && vector.IsKnownSize
                && vector.ItemType == NumberDataViewType.Single;

            return isSupported ? null : "Expected known-size vector of Single";
        }
 
        // Projects the column options onto the (output name, input name) pairs expected by the base class.
        private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ApproximatedKernelMappingEstimator.ColumnOptions[] columns)
        {
            Contracts.CheckValue(columns, nameof(columns));

            var pairs = new (string outputColumnName, string inputColumnName)[columns.Length];
            for (int i = 0; i < pairs.Length; i++)
                pairs[i] = (columns[i].Name, columns[i].InputColumnName);
            return pairs;
        }
 
        // Validates one input column against what this transformer was fitted on: it must be a
        // known-size vector of Single whose size equals the recorded source dimensionality.
        private protected override void CheckInputColumn(DataViewSchema inputSchema, int col, int srcCol)
        {
            var actualType = inputSchema[srcCol].Type;

            string mismatch = TestColumnType(actualType);
            if (mismatch != null)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName,
                    mismatch, actualType.ToString());
            }

            int expectedDim = _transformInfos[col].SrcDim;
            if (expectedDim != actualType.GetVectorSize())
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName,
                    new VectorDataViewType(NumberDataViewType.Single, expectedDim).ToString(), actualType.ToString());
            }
        }
 
        // Training constructor: computes the per-column distance statistic from the input data and
        // then builds one TransformInfo per column.
        internal ApproximatedKernelTransformer(IHostEnvironment env, IDataView input, ApproximatedKernelMappingEstimator.ColumnOptions[] columns)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ApproximatedKernelTransformer)), GetColumnPairs(columns))
        {
            var avgDistances = GetAvgDistances(columns, input);

            var infos = new TransformInfo[columns.Length];
            for (int iinfo = 0; iinfo < infos.Length; iinfo++)
            {
                var column = columns[iinfo];

                // The result is not checked here; GetAvgDistances above already threw for missing columns.
                input.Schema.TryGetColumnIndex(column.InputColumnName, out int srcCol);
                int srcDim = input.Schema[srcCol].Type.GetValueCount();

                var columnHost = Host.Register(string.Format("column{0}", iinfo));
                infos[iinfo] = new TransformInfo(columnHost, column, srcDim, avgDistances[iinfo]);
            }
            _transformInfos = infos;
        }
 
        // Returns the smallest multiple of cfltAlign that is greater than or equal to cflt.
        private static int RoundUp(int cflt, int cfltAlign)
        {
            Contracts.Assert(0 < cflt);
            // The alignment is expected to be a positive power of two.
            Contracts.Assert(0 < cfltAlign && (cfltAlign & (cfltAlign - 1)) == 0);

            // Count the alignment-sized chunks needed to hold cflt items, then scale back up.
            int chunkCount = (cflt + cfltAlign - 1) / cfltAlign;
            int rounded = chunkCount * cfltAlign;
            return rounded;
        }
 
        // Computes, for every column, the distance statistic that is later handed to the kernel's
        // random number generator: the median of kernel distances between pairs of sampled rows.
        // The value falls back to 1 when at most one row was sampled or when the median is 0.
        // Throws a schema-mismatch exception when an input column is missing or has an unsupported type.
        // Returns one value per entry of columns, in the same order.
        private float[] GetAvgDistances(ApproximatedKernelMappingEstimator.ColumnOptions[] columns, IDataView input)
        {
            var avgDistances = new float[columns.Length];
            const int reservoirSize = 5000;
            var activeColumns = new List<DataViewSchema.Column>();
            int[] srcCols = new int[columns.Length];
            // Resolve and validate every input column up front, and collect them so that the cursor
            // below is requested with exactly those columns.
            for (int i = 0; i < columns.Length; i++)
            {
                if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].inputColumnName, out int srcCol))
                    throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].inputColumnName);
                var type = input.Schema[srcCol].Type;
                string reason = TestColumnType(type);
                if (reason != null)
                    throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].inputColumnName, reason, type.ToString());
                srcCols[i] = srcCol;
                activeColumns.Add(input.Schema[srcCol]);
            }
            var reservoirSamplers = new ReservoirSamplerWithReplacement<VBuffer<float>>[columns.Length];
            using (var cursor = input.GetRowCursor(activeColumns))
            {
                // One sampler per column, each fed by that column's getter on the shared cursor.
                for (int i = 0; i < columns.Length; i++)
                {
                    // A column with an explicit seed gets its own generator; otherwise the host's is used.
                    var rng = columns[i].Seed.HasValue ? RandomUtils.Create(columns[i].Seed.Value) : Host.Rand;
                    var srcType = input.Schema[srcCols[i]].Type;
                    if (srcType is VectorDataViewType)
                    {
                        var get = cursor.GetGetter<VBuffer<float>>(cursor.Schema[srcCols[i]]);
                        reservoirSamplers[i] = new ReservoirSamplerWithReplacement<VBuffer<float>>(rng, reservoirSize, get);
                    }
                    else
                    {
                        // Scalar input wrapped as a length-1 vector.
                        // NOTE(review): TestColumnType above only accepts vector types, so this branch
                        // does not look reachable - confirm before relying on it.
                        var getOne = cursor.GetGetter<float>(cursor.Schema[srcCols[i]]);
                        float val = 0;
                        ValueGetter<VBuffer<float>> get =
                            (ref VBuffer<float> dst) =>
                            {
                                getOne(ref val);
                                dst = new VBuffer<float>(1, new[] { val });
                            };
                        reservoirSamplers[i] = new ReservoirSamplerWithReplacement<VBuffer<float>>(rng, reservoirSize, get);
                    }
                }

                // Single pass over the data: every sampler is offered every row.
                while (cursor.MoveNext())
                {
                    for (int i = 0; i < columns.Length; i++)
                        reservoirSamplers[i].Sample();
                }
                // Lock is called on every sampler before its contents are read
                // (presumably it finalizes the sample - TODO confirm against ReservoirSamplerWithReplacement).
                for (int i = 0; i < columns.Length; i++)
                    reservoirSamplers[i].Lock();

                for (int iinfo = 0; iinfo < columns.Length; iinfo++)
                {
                    var instanceCount = reservoirSamplers[iinfo].NumSampled;

                    // If the number of pairs is at most the maximum reservoir size / 2, we go over all the pairs,
                    // so we get all the examples. Otherwise, get a sample with replacement.
                    VBuffer<float>[] res;
                    int resLength;
                    if (instanceCount < reservoirSize && instanceCount * (instanceCount - 1) <= reservoirSize)
                    {
                        res = reservoirSamplers[iinfo].GetCache();
                        resLength = reservoirSamplers[iinfo].Size;
                        Contracts.Assert(resLength == instanceCount);
                    }
                    else
                    {
                        res = reservoirSamplers[iinfo].GetSample().ToArray();
                        resLength = res.Length;
                    }

                    // If the dataset contains only one valid Instance, then we can't learn anything anyway, so just return 1.
                    if (instanceCount <= 1)
                        avgDistances[iinfo] = 1;
                    else
                    {
                        float[] distances;
                        var generator = columns[iinfo].Generator;
                        // If the number of pairs is at most the maximum reservoir size / 2, go over all the pairs.
                        if (resLength < reservoirSize)
                        {
                            // n * (n - 1) / 2 unordered pairs (i, j) with i < j.
                            distances = new float[instanceCount * (instanceCount - 1) / 2];
                            int count = 0;
                            for (int i = 0; i < instanceCount; i++)
                            {
                                for (int j = i + 1; j < instanceCount; j++)
                                    distances[count++] = generator.Distance(in res[i], in res[j]);
                            }
                            Host.Assert(count == distances.Length);
                        }
                        else
                        {
                            // Pair up consecutive sampled items: (0, 1), (2, 3), ...
                            distances = new float[reservoirSize / 2];
                            for (int i = 0; i < reservoirSize - 1; i += 2)
                                distances[i / 2] = generator.Distance(in res[i], in res[i + 1]);
                        }

                        // If by chance, in the random permutation all the pairs are the same instance we return 1.
                        float median = MathUtils.GetMedianInPlace(distances, distances.Length);
                        avgDistances[iinfo] = median == 0 ? 1 : median;
                    }
                }
                return avgDistances;
            }
        }
 
        // Factory method for SignatureLoadDataTransform.
        // Loads the transformer from the model context and applies it to the given data.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            var transformer = Create(env, ctx);
            return transformer.MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadRowMapper.
        // Loads the transformer from the model context and binds a row mapper to the given schema.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
        {
            var transformer = Create(env, ctx);
            return transformer.MakeRowMapper(inputSchema);
        }
 
        // Loading constructor: the base class reads the column pairs, then one TransformInfo is
        // restored per pair from the sub-model named "MatrixGenerator<index>".
        private ApproximatedKernelTransformer(IHost host, ModelLoadContext ctx)
            : base(host, ctx)
        {
            // *** Binary format ***
            // <prefix handled in static Create method>
            // <base>
            // transformInfos
            var infos = new TransformInfo[ColumnPairs.Length];
            for (int iinfo = 0; iinfo < infos.Length; iinfo++)
            {
                string generatorDir = string.Format("MatrixGenerator{0}", iinfo);
                infos[iinfo] = new TransformInfo(Host, ctx, generatorDir);
            }
            _transformInfos = infos;
        }
 
        // Factory method for SignatureDataTransform.
        // Resolves every per-column setting against the transform-wide options, trains the
        // transformer on the input and returns it applied to that same input.
        private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(options.MatrixGenerator, nameof(options.MatrixGenerator));
            env.CheckValue(input, nameof(input));

            env.CheckValue(options.Columns, nameof(options.Columns));
            var resolved = new ApproximatedKernelMappingEstimator.ColumnOptions[options.Columns.Length];
            using (var ch = env.Start("ValidateArgs"))
            {
                for (int i = 0; i < resolved.Length; i++)
                {
                    var item = options.Columns[i];

                    // A per-column value wins; the transform-wide one is the fallback.
                    var rank = item.NewDim ?? options.NewDim;
                    var useCosAndSin = item.UseSin ?? options.UseSin;
                    var inputName = item.Source ?? item.Name;
                    var kernel = (item.MatrixGenerator ?? options.MatrixGenerator).CreateComponent(env);
                    var seed = item.Seed ?? options.Seed;

                    resolved[i] = new ApproximatedKernelMappingEstimator.ColumnOptions(
                        item.Name, rank, useCosAndSin, inputName, kernel, seed);
                }
            }

            var transformer = new ApproximatedKernelTransformer(env, input, resolved);
            return transformer.MakeDataTransform(input);
        }
 
        // Factory method for SignatureLoadModel.
        // Verifies the model header, consumes the prefix that only version 0x00010001 models carry,
        // and hands the rest of the stream to the loading constructor.
        private static ApproximatedKernelTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(nameof(ApproximatedKernelTransformer));

            host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            bool hasFloatSizePrefix = ctx.Header.ModelVerWritten == 0x00010001;
            if (hasFloatSizePrefix)
            {
                // Initial-version models start with the size of a float.
                int floatSize = ctx.Reader.ReadInt32();
                env.CheckDecode(floatSize == sizeof(float));
            }

            return new ApproximatedKernelTransformer(host, ctx);
        }
 
        // Writes the model: version info, the base class's column pairs, then each column's
        // TransformInfo with its generator stored in the sub-model "MatrixGenerator<index>".
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));

            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // <base>
            // transformInfos
            SaveColumns(ctx);

            int iinfo = 0;
            while (iinfo < _transformInfos.Length)
            {
                string generatorDir = string.Format("MatrixGenerator{0}", iinfo);
                _transformInfos[iinfo].Save(ctx, generatorDir);
                iinfo++;
            }
        }
 
        // Creates the row mapper that applies this transformer to rows of the given schema.
        private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
        {
            return new Mapper(this, schema);
        }
 
        // Row mapper of the transformer: for every column pair it produces a getter that multiplies the
        // input vector by the random Fourier matrix and turns each product into cos (and optionally sin)
        // features.
        private sealed class Mapper : OneToOneMapperBase
        {
            private readonly DataViewType[] _srcTypes;
            private readonly int[] _srcCols;
            private readonly DataViewType[] _types;
            private readonly ApproximatedKernelTransformer _parent;

            public Mapper(ApproximatedKernelTransformer parent, DataViewSchema inputSchema)
                : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _parent = parent;

                int columnCount = _parent.ColumnPairs.Length;
                _types = new DataViewType[columnCount];
                _srcTypes = new DataViewType[columnCount];
                _srcCols = new int[columnCount];

                for (int iinfo = 0; iinfo < columnCount; iinfo++)
                {
                    inputSchema.TryGetColumnIndex(_parent.ColumnPairs[iinfo].inputColumnName, out _srcCols[iinfo]);
                    _srcTypes[iinfo] = inputSchema[_srcCols[iinfo]].Type;

                    // Without rotation terms both cos and sin are emitted, so the output is twice as wide.
                    var info = _parent._transformInfos[iinfo];
                    int outputSize = info.RotationTerms == null ? info.NewDim * 2 : info.NewDim;
                    _types[iinfo] = new VectorDataViewType(NumberDataViewType.Single, outputSize);
                }
            }

            // Describes the output columns: the configured output name with the type computed in the constructor.
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                int columnCount = _parent.ColumnPairs.Length;
                var outputs = new DataViewSchema.DetachedColumn[columnCount];
                for (int iinfo = 0; iinfo < columnCount; iinfo++)
                {
                    string outputName = _parent.ColumnPairs[iinfo].outputColumnName;
                    outputs[iinfo] = new DataViewSchema.DetachedColumn(outputName, _types[iinfo], null);
                }
                return outputs;
            }

            // Picks the getter that matches the source column's type. Nothing needs disposing.
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Contracts.AssertValue(input);
                Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
                disposer = null;

                return _srcTypes[iinfo] is VectorDataViewType
                    ? GetterFromVectorType(input, iinfo)
                    : GetterFromFloatType(input, iinfo);
            }

            // Getter for vector-typed sources. The scratch buffers are allocated once per getter and
            // reused for every row.
            private ValueGetter<VBuffer<float>> GetterFromVectorType(DataViewRow input, int iinfo)
            {
                var getSrc = input.GetGetter<VBuffer<float>>(input.Schema[_srcCols[iinfo]]);
                var src = default(VBuffer<float>);

                int alignment = CpuMathUtils.GetVectorAlignment();
                var featuresAligned = new AlignedArray(RoundUp(_srcTypes[iinfo].GetValueCount(), _cfltAlign), alignment);
                var productAligned = new AlignedArray(RoundUp(_parent._transformInfos[iinfo].NewDim, _cfltAlign), alignment);

                ValueGetter<VBuffer<float>> getter = (ref VBuffer<float> dst) =>
                {
                    getSrc(ref src);
                    TransformFeatures(in src, ref dst, _parent._transformInfos[iinfo], featuresAligned, productAligned);
                };
                return getter;
            }

            // Getter for scalar sources: the value is placed in a reusable length-1 vector and then
            // transformed like any other vector.
            private ValueGetter<VBuffer<float>> GetterFromFloatType(DataViewRow input, int iinfo)
            {
                var getSrc = input.GetGetter<float>(input.Schema[_srcCols[iinfo]]);
                var src = default(float);

                int alignment = CpuMathUtils.GetVectorAlignment();
                var featuresAligned = new AlignedArray(RoundUp(1, _cfltAlign), alignment);
                var productAligned = new AlignedArray(RoundUp(_parent._transformInfos[iinfo].NewDim, _cfltAlign), alignment);

                var oneDimensionalVector = new VBuffer<float>(1, new float[] { 0 });

                ValueGetter<VBuffer<float>> getter = (ref VBuffer<float> dst) =>
                {
                    getSrc(ref src);
                    VBufferEditor.CreateFromBuffer(ref oneDimensionalVector).Values[0] = src;
                    TransformFeatures(in oneDimensionalVector, ref dst, _parent._transformInfos[iinfo], featuresAligned, productAligned);
                };
                return getter;
            }

            // Computes the random Fourier features of src into dst.
            // With rotation terms: dst[i] = cos(w_i . x + b_i) * sqrt(2 / NewDim), NewDim values.
            // Without: dst[2i] = cos(w_i . x) * sqrt(1 / NewDim) and dst[2i + 1] = sin(w_i . x) * sqrt(1 / NewDim).
            private void TransformFeatures(in VBuffer<float> src, ref VBuffer<float> dst, TransformInfo transformInfo,
                AlignedArray featuresAligned, AlignedArray productAligned)
            {
                Host.Check(src.Length == transformInfo.SrcDim, "column does not have the expected dimensionality.");

                int rank = transformInfo.NewDim;
                bool cosOnly = transformInfo.RotationTerms != null;
                int newDstLength = cosOnly ? rank : 2 * rank;
                float scale = cosOnly ? MathUtils.Sqrt(2.0f / rank) : MathUtils.Sqrt(1.0f / rank);

                if (!src.IsDense)
                {
                    // This overload of MatTimesSrc ignores the values in slots that are not in src.Indices, so there is
                    // no need to zero them out.
                    var srcValues = src.GetValues();
                    var srcIndices = src.GetIndices();
                    featuresAligned.CopyFrom(srcIndices, srcValues, 0, 0, srcValues.Length, zeroItems: false);
                    CpuMathUtils.MatrixTimesSource(transformInfo.RndFourierVectors, srcIndices, featuresAligned, 0, 0,
                        srcValues.Length, productAligned, rank);
                }
                else
                {
                    featuresAligned.CopyFrom(src.GetValues());
                    CpuMathUtils.MatrixTimesSource(false, transformInfo.RndFourierVectors, featuresAligned, productAligned,
                        rank);
                }

                var dstEditor = VBufferEditor.Create(ref dst, newDstLength);
                if (cosOnly)
                {
                    for (int i = 0; i < rank; i++)
                    {
                        var dotProduct = productAligned[i];
                        dstEditor.Values[i] = (float)MathUtils.Cos(dotProduct + transformInfo.RotationTerms[i]) * scale;
                    }
                }
                else
                {
                    for (int i = 0; i < rank; i++)
                    {
                        var dotProduct = productAligned[i];
                        dstEditor.Values[2 * i] = (float)MathUtils.Cos(dotProduct) * scale;
                        dstEditor.Values[2 * i + 1] = (float)MathUtils.Sin(dotProduct) * scale;
                    }
                }

                dst = dstEditor.Commit();
            }
        }
    }
 
    /// <summary>
    /// Maps vector columns to a low-dimensional feature space.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    ///
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | Yes |
    /// | Input column data type | Known-sized vector of <xref:System.Single> |
    /// | Output column data type | Known-sized vector of <xref:System.Single> |
    /// | Exportable to ONNX | No |
    ///
    /// The resulting <xref:Microsoft.ML.Transforms.ApproximatedKernelTransformer> creates a new column, named as specified in
    /// the output column name parameters, where each input vector is mapped to a feature space where inner products
    /// approximate one of two shift-invariant kernel functions: The Gaussian kernel, or the Laplacian kernel.
    /// By mapping features to a space that approximate non-linear kernels, linear methods can be used to approximate
    /// more complex kernel SVM models.
    /// This mapping is based on the paper [Random Features for Large-Scale Kernel Machines](http://pages.cs.wisc.edu/~brecht/papers/07.rah.rec.nips.pdf)
    /// by Rahimi and Recht.
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]></format>
    /// </remarks>
    /// <seealso cref="KernelExpansionCatalog.ApproximatedKernelMap(TransformsCatalog, string, string, int, bool, KernelBase, int?)"/>
    public sealed class ApproximatedKernelMappingEstimator : IEstimator<ApproximatedKernelTransformer>
    {
        // Default settings, also used as the defaults of ApproximatedKernelTransformer.Options.
        [BestFriend]
        internal static class Defaults
        {
            // Default number of random Fourier features to create.
            public const int Rank = 1000;
            // By default only cos bases are used (one feature per random frequency).
            public const bool UseCosAndSinBases = false;
        }
 
        /// <summary>
        /// Describes how the transformer handles one column pair.
        /// </summary>
        [BestFriend]
        internal sealed class ColumnOptions
        {
            /// <summary>
            /// Name of the output column, i.e. the result of transforming <see cref="InputColumnName"/>.
            /// </summary>
            public readonly string Name;
            /// <summary>
            /// Name of the column that is transformed.
            /// </summary>
            public readonly string InputColumnName;
            /// <summary>
            /// The kernel (fourier generator) to use.
            /// </summary>
            public readonly KernelBase Generator;
            /// <summary>
            /// How many random Fourier features are created.
            /// </summary>
            public readonly int Rank;
            /// <summary>
            /// When <see langword="true"/>, both cos and sin basis functions are used, giving two features per
            /// random Fourier frequency. When <see langword="false"/>, only cos bases are used.
            /// </summary>
            public readonly bool UseCosAndSinBases;
            /// <summary>
            /// Seed of the random number generator used to generate the new features (if unspecified, the global random is used).
            /// </summary>
            public readonly int? Seed;

            /// <summary>
            /// Creates the settings for one column pair.
            /// </summary>
            /// <param name="name">Name of the output column, i.e. the result of transforming <paramref name="inputColumnName"/>.</param>
            /// <param name="rank">How many random Fourier features are created. Must be positive.</param>
            /// <param name="useCosAndSinBases">When <see langword="true"/>, both cos and sin basis functions are used, giving two
            /// features per random Fourier frequency. When <see langword="false"/>, only cos bases are used.</param>
            /// <param name="inputColumnName">Name of the column to transform. When <see langword="null"/>, <paramref name="name"/> is used.</param>
            /// <param name="generator">The kernel (fourier generator) to use. When <see langword="null"/>, a <see cref="GaussianKernel"/> is used.</param>
            /// <param name="seed">Seed of the random number generator used to generate the new features (if unspecified, the global random is used).</param>
            public ColumnOptions(string name, int rank, bool useCosAndSinBases, string inputColumnName = null, KernelBase generator = null, int? seed = null)
            {
                Contracts.CheckUserArg(rank > 0, nameof(rank), "must be positive.");

                Name = name;
                InputColumnName = inputColumnName ?? name;
                Rank = rank;
                UseCosAndSinBases = useCosAndSinBases;
                Seed = seed;

                if (generator == null)
                    generator = new GaussianKernel();
                Generator = generator;
            }
        }
 
        private readonly IHost _host;
        private readonly ColumnOptions[] _columns;
 
        /// <summary>
        /// Convinence constructor for simple one column case.
        /// </summary>
        /// <param name="env">Host Environment.</param>
        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
        /// <param name="rank">The number of random Fourier features to create.</param>
        /// <param name="useCosAndSinBases">Create two features for every random Fourier frequency? (one for cos and one for sin).</param>
        internal ApproximatedKernelMappingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int rank = Defaults.Rank, bool useCosAndSinBases = Defaults.UseCosAndSinBases)
            : this(env, new ColumnOptions(outputColumnName, rank, useCosAndSinBases, inputColumnName ?? outputColumnName))
        {
        }
 
        internal ApproximatedKernelMappingEstimator(IHostEnvironment env, params ColumnOptions[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(ApproximatedKernelMappingEstimator));
            _columns = columns;
        }
 
        /// <summary>
        /// Trains and returns a <see cref="ApproximatedKernelTransformer"/>.
        /// </summary>
        public ApproximatedKernelTransformer Fit(IDataView input) => new ApproximatedKernelTransformer(_host, input, _columns);
 
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.ToDictionary(x => x.Name);
            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
                if (col.ItemType.RawType != typeof(float) || col.Kind != SchemaShape.Column.VectorKind.Vector)
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
 
                result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
            }
 
            return new SchemaShape(result.Values);
        }
    }
}