// File: DataLoadSave\DataOperationsCatalog.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
namespace Microsoft.ML
{
    /// <summary>
    /// Class used to create components that operate on data, but are not part of the model training pipeline.
    /// Includes components to load, save, cache, filter, shuffle, and split data.
    /// </summary>
    public sealed class DataOperationsCatalog : IInternalCatalog
    {
        IHostEnvironment IInternalCatalog.Environment => _env;
        private readonly IHostEnvironment _env;

        /// <summary>
        /// A pair of datasets, for the train and test set.
        /// </summary>
        public struct TrainTestData
        {
            /// <summary>
            /// Training set.
            /// </summary>
            public readonly IDataView TrainSet;
            /// <summary>
            /// Testing set.
            /// </summary>
            public readonly IDataView TestSet;
            /// <summary>
            /// Create pair of datasets.
            /// </summary>
            /// <param name="trainSet">Training set.</param>
            /// <param name="testSet">Testing set.</param>
            internal TrainTestData(IDataView trainSet, IDataView testSet)
            {
                TrainSet = trainSet;
                TestSet = testSet;
            }
        }

        /// <summary>
        /// Creates the catalog bound to the given host environment.
        /// </summary>
        /// <param name="env">The environment used to validate arguments and construct the returned data views.</param>
        internal DataOperationsCatalog(IHostEnvironment env)
        {
            Contracts.AssertValue(env);
            _env = env;
        }

        /// <summary>
        /// Create a new <see cref="IDataView"/> over an enumerable of the items of user-defined type.
        /// The user maintains ownership of the <paramref name="data"/> and the resulting data view will
        /// never alter the contents of the <paramref name="data"/>.
        /// Since <see cref="IDataView"/> is assumed to be immutable, the user is expected to support
        /// multiple enumerations of the <paramref name="data"/> that would return the same results, unless
        /// the user knows that the data will only be cursored once.
        ///
        /// One typical usage for streaming data view could be: create the data view that lazily loads data
        /// as needed, then apply pre-trained transformations to it and cursor through it for transformation
        /// results.
        /// </summary>
        /// <typeparam name="TRow">The user-defined item type.</typeparam>
        /// <param name="data">The enumerable data containing type <typeparamref name="TRow"/> to convert to a <see cref="IDataView"/>.</param>
        /// <param name="schemaDefinition">The optional schema definition of the data view to create. If <c>null</c>,
        /// the schema definition is inferred from <typeparamref name="TRow"/>.</param>
        /// <returns>The constructed <see cref="IDataView"/>.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[LoadFromEnumerable](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/LoadFromEnumerable.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, SchemaDefinition schemaDefinition = null)
            where TRow : class
        {
            _env.CheckValue(data, nameof(data));
            _env.CheckValueOrNull(schemaDefinition);
            return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition);
        }

        /// <summary>
        /// Create a new <see cref="IDataView"/> over an enumerable of the items of user-defined type using the provided <see cref="DataViewSchema"/>,
        /// which might contain more information about the schema than the type can capture.
        /// </summary>
        /// <remarks>
        /// The user maintains ownership of the <paramref name="data"/> and the resulting data view will
        /// never alter the contents of the <paramref name="data"/>.
        /// Since <see cref="IDataView"/> is assumed to be immutable, the user is expected to support
        /// multiple enumerations of the <paramref name="data"/> that would return the same results, unless
        /// the user knows that the data will only be cursored once.
        /// One typical usage for streaming data view could be: create the data view that lazily loads data
        /// as needed, then apply pre-trained transformations to it and cursor through it for transformation
        /// results.
        /// One practical usage of this would be to supply the feature column names through the <see cref="DataViewSchema.Annotations"/>.
        /// </remarks>
        /// <typeparam name="TRow">The user-defined item type.</typeparam>
        /// <param name="data">The enumerable data containing type <typeparamref name="TRow"/> to convert to an <see cref="IDataView"/>.</param>
        /// <param name="schema">The schema of the returned <see cref="IDataView"/>.</param>
        /// <returns>An <see cref="IDataView"/> with the given <paramref name="schema"/>.</returns>
        public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, DataViewSchema schema)
            where TRow : class
        {
            _env.CheckValue(data, nameof(data));
            _env.CheckValue(schema, nameof(schema));
            return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schema);
        }

        /// <summary>
        /// Convert an <see cref="IDataView"/> into a strongly-typed <see cref="IEnumerable{TRow}"/>.
        /// </summary>
        /// <typeparam name="TRow">The user-defined item type.</typeparam>
        /// <param name="data">The underlying data view.</param>
        /// <param name="reuseRowObject">Whether to return the same object on every row, or allocate a new one per row.</param>
        /// <param name="ignoreMissingColumns">Whether to ignore the case when a requested column is not present in the data view.</param>
        /// <param name="schemaDefinition">Optional user-provided schema definition. If it is not present, the schema is inferred from the definition of <typeparamref name="TRow"/>.</param>
        /// <returns>The <see cref="IEnumerable{TRow}"/> that holds the data in <paramref name="data"/>. It can be enumerated multiple times.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[CreateEnumerable](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/DataViewEnumerable.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IEnumerable<TRow> CreateEnumerable<TRow>(IDataView data, bool reuseRowObject,
            bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
            where TRow : class, new()
        {
            _env.CheckValue(data, nameof(data));
            _env.CheckValueOrNull(schemaDefinition);

            var engine = new PipeEngine<TRow>(_env, data, ignoreMissingColumns, schemaDefinition);
            return engine.RunPipe(reuseRowObject);
        }

        /// <summary>
        /// Take an approximate bootstrap sample of <paramref name="input"/>.
        /// </summary>
        /// <remarks>
        /// This sampler is a streaming version of <a href="https://en.wikipedia.org/wiki/Bootstrapping_(statistics)">bootstrap resampling</a>.
        /// Instead of taking the whole dataset into memory and resampling, <see cref="BootstrapSample"/> streams through the dataset and
        /// uses a <a href="https://en.wikipedia.org/wiki/Poisson_distribution">Poisson</a>(1) distribution to select the number of times a
        /// given row will be added to the sample. The <paramref name="complement"/> parameter allows for the creation of a bootstrap sample
        /// and complementary out-of-bag sample by using the same <paramref name="seed"/>.
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="seed">The random seed. If unspecified, the random state will be instead derived from the <see cref="MLContext"/>.</param>
        /// <param name="complement">Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.
        /// Can be used to create a complementary pair of samples by using the same seed.</param>
        /// <returns>An <see cref="IDataView"/> containing the sampled (or, with <paramref name="complement"/>, out-of-bag) rows.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[BootstrapSample](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/BootstrapSample.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView BootstrapSample(IDataView input,
            int? seed = null,
            bool complement = BootstrapSamplingTransformer.Defaults.Complement)
        {
            _env.CheckValue(input, nameof(input));
            return new BootstrapSamplingTransformer(
                _env,
                input,
                complement: complement,
                seed: (uint?)seed,
                shuffleInput: false,
                poolSize: 0);
        }

        /// <summary>
        /// Creates a lazy in-memory cache of <paramref name="input"/>.
        /// </summary>
        /// <remarks>
        /// Caching happens per-column. A column is only cached when it is first accessed.
        /// In addition, <paramref name="columnsToPrefetch"/> are considered 'always needed', so these columns
        /// will be cached the first time any data is requested.
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="columnsToPrefetch">The columns that must be cached whenever anything is cached. An empty array or null
        /// value means that columns are cached upon their first access.</param>
        /// <returns>A caching <see cref="IDataView"/> over <paramref name="input"/>.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[Cache](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/Cache.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView Cache(IDataView input, params string[] columnsToPrefetch)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckValueOrNull(columnsToPrefetch);

            int[] prefetch = new int[Utils.Size(columnsToPrefetch)];
            for (int i = 0; i < prefetch.Length; i++)
            {
                if (!input.Schema.TryGetColumnIndex(columnsToPrefetch[i], out prefetch[i]))
                    throw _env.ExceptSchemaMismatch(nameof(columnsToPrefetch), "prefetch", columnsToPrefetch[i]);
            }
            return new CacheDataView(_env, input, prefetch);
        }

        /// <summary>
        /// Filter the dataset by the values of a numeric column.
        /// </summary>
        /// <remarks>
        /// Keep only those rows that satisfy the range condition: the value of column <paramref name="columnName"/>
        /// must be between <paramref name="lowerBound"/> (inclusive) and <paramref name="upperBound"/> (exclusive).
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="columnName">The name of a column to use for filtering.</param>
        /// <param name="lowerBound">The inclusive lower bound.</param>
        /// <param name="upperBound">The exclusive upper bound. Must be greater than <paramref name="lowerBound"/>.</param>
        /// <returns>An <see cref="IDataView"/> containing only the rows whose value falls in the range.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[FilterRowsByColumn](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/FilterRowsByColumn.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView FilterRowsByColumn(IDataView input, string columnName, double lowerBound = double.NegativeInfinity, double upperBound = double.PositiveInfinity)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckNonEmpty(columnName, nameof(columnName));
            // The check requires upperBound to be strictly greater than lowerBound; the message states exactly that.
            _env.CheckParam(lowerBound < upperBound, nameof(upperBound), "Must be greater than lowerBound");

            // Report a missing column as a schema mismatch on 'columnName' (consistent with Cache),
            // rather than letting the schema indexer report an unrelated parameter name.
            if (!input.Schema.TryGetColumnIndex(columnName, out int columnIndex))
                throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName);

            var type = input.Schema[columnIndex].Type;
            if (!(type is NumberDataViewType))
                throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "number", type.ToString());
            return new RangeFilter(_env, input, columnName, lowerBound, upperBound, false);
        }

        /// <summary>
        /// Filter the dataset by the values of a <see cref="KeyDataViewType"/> column.
        /// </summary>
        /// <remarks>
        /// Keep only those rows that satisfy the range condition: the value of a key column <paramref name="columnName"/>
        /// (treated as a fraction of the entire key range) must be between <paramref name="lowerBound"/> (inclusive) and <paramref name="upperBound"/> (exclusive).
        /// This filtering is useful if the <paramref name="columnName"/> is a key column obtained by some 'stable randomization',
        /// for example, hashing.
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="columnName">The name of a column to use for filtering.</param>
        /// <param name="lowerBound">The inclusive lower bound, in [0, 1].</param>
        /// <param name="upperBound">The exclusive upper bound, in [0, 1] and no less than <paramref name="lowerBound"/>.</param>
        /// <returns>An <see cref="IDataView"/> containing only the rows whose key fraction falls in the range.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[FilterRowsByKeyColumnFraction](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/FilterRowsByKeyColumnFraction.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView FilterRowsByKeyColumnFraction(IDataView input, string columnName, double lowerBound = 0, double upperBound = 1)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckNonEmpty(columnName, nameof(columnName));
            _env.CheckParam(0 <= lowerBound && lowerBound <= 1, nameof(lowerBound), "Must be in [0, 1]");
            _env.CheckParam(0 <= upperBound && upperBound <= 1, nameof(upperBound), "Must be in [0, 1]");
            _env.CheckParam(lowerBound <= upperBound, nameof(upperBound), "Must be no less than lowerBound");

            // Report a missing column as a schema mismatch on 'columnName' (consistent with Cache).
            if (!input.Schema.TryGetColumnIndex(columnName, out int columnIndex))
                throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName);

            var type = input.Schema[columnIndex].Type;
            if (type.GetKeyCount() == 0)
                throw _env.ExceptSchemaMismatch(nameof(columnName), "filter", columnName, "KeyType", type.ToString());
            return new RangeFilter(_env, input, columnName, lowerBound, upperBound, false);
        }

        /// <summary>
        /// Drop rows where any column in <paramref name="columns"/> contains a missing value.
        /// </summary>
        /// <param name="input">The input data.</param>
        /// <param name="columns">Name of the columns to filter on. If a row has a missing value in any of
        /// these columns, it will be dropped from the dataset. At least one column must be given.</param>
        /// <returns>An <see cref="IDataView"/> without the rows that had missing values in <paramref name="columns"/>.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[FilterRowsByMissingValues](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/FilterRowsByMissingValues.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView FilterRowsByMissingValues(IDataView input, params string[] columns)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckUserArg(Utils.Size(columns) > 0, nameof(columns));

            return new NAFilter(_env, input, complement: false, columns);
        }

        /// <summary>
        /// Shuffle the rows of <paramref name="input"/>.
        /// </summary>
        /// <remarks>
        /// <see cref="ShuffleRows"/> will shuffle the rows of any input <see cref="IDataView"/> using a streaming approach.
        /// In order to not load the entire dataset in memory, a pool of <paramref name="shufflePoolSize"/> rows will be used
        /// to randomly select rows to output. The pool is constructed from the first <paramref name="shufflePoolSize"/> rows
        /// in <paramref name="input"/>. Rows will then be randomly yielded from the pool and replaced with the next row from <paramref name="input"/>
        /// until all the rows have been yielded, resulting in a new <see cref="IDataView"/> of the same size as <paramref name="input"/>
        /// but with the rows in a randomized order.
        /// If the <see cref="IDataView.CanShuffle"/> property of <paramref name="input"/> is true, then it will also be read into the
        /// pool in a random order, offering two sources of randomness.
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="seed">The random seed. If unspecified, the random state will be instead derived from the <see cref="MLContext"/>.</param>
        /// <param name="shufflePoolSize">The number of rows to hold in the pool. Setting this to 1 will turn off pool shuffling and
        /// <see cref="ShuffleRows"/> will only perform a shuffle by reading <paramref name="input"/> in a random order.</param>
        /// <param name="shuffleSource">If <see langword="false"/>, the transform will not attempt to read <paramref name="input"/> in a random order and only use
        /// pooling to shuffle. This parameter has no effect if the <see cref="IDataView.CanShuffle"/> property of <paramref name="input"/> is <see langword="false"/>.
        /// </param>
        /// <returns>An <see cref="IDataView"/> with the rows of <paramref name="input"/> in a shuffled order.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[ShuffleRows](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/ShuffleRows.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView ShuffleRows(IDataView input,
            int? seed = null,
            int shufflePoolSize = RowShufflingTransformer.Defaults.PoolRows,
            bool shuffleSource = !RowShufflingTransformer.Defaults.PoolOnly)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckUserArg(shufflePoolSize > 0, nameof(shufflePoolSize), "Must be positive");

            var options = new RowShufflingTransformer.Options
            {
                PoolRows = shufflePoolSize,
                PoolOnly = !shuffleSource,
                ForceShuffle = true,
                ForceShuffleSeed = seed
            };

            return new RowShufflingTransformer(_env, options, input);
        }

        /// <summary>
        /// Skip <paramref name="count"/> rows in <paramref name="input"/>.
        /// </summary>
        /// <remarks>
        /// Skips the first <paramref name="count"/> rows from <paramref name="input"/> and returns an <see cref="IDataView"/> with all other rows.
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="count">Number of rows to skip. Must be greater than zero.</param>
        /// <returns>An <see cref="IDataView"/> with all rows of <paramref name="input"/> after the first <paramref name="count"/>.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[SkipRows](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/SkipRows.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView SkipRows(IDataView input, long count)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");

            var options = new SkipTakeFilter.SkipOptions()
            {
                Count = count
            };

            return new SkipTakeFilter(_env, options, input);
        }

        /// <summary>
        /// Take <paramref name="count"/> rows from <paramref name="input"/>.
        /// </summary>
        /// <remarks>
        /// Returns an <see cref="IDataView"/> with the first <paramref name="count"/> rows from <paramref name="input"/>.
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="count">Number of rows to take. Must be greater than zero.</param>
        /// <returns>An <see cref="IDataView"/> with the first <paramref name="count"/> rows of <paramref name="input"/>.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[TakeRows](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/TakeRows.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView TakeRows(IDataView input, long count)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckUserArg(count > 0, nameof(count), "Must be greater than zero.");

            var options = new SkipTakeFilter.TakeOptions()
            {
                Count = count
            };

            return new SkipTakeFilter(_env, options, input);
        }

        /// <summary>
        /// Split the dataset into the train set and test set according to the given fraction.
        /// Respects the <paramref name="samplingKeyColumnName"/> if provided.
        /// </summary>
        /// <param name="data">The dataset to split.</param>
        /// <param name="testFraction">The fraction of data to go into the test set. Must be strictly between 0 and 1.</param>
        /// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
        /// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
        /// Note that when performing a Ranking Experiment, the <paramref name="samplingKeyColumnName"/> must be the GroupId column.
        /// If <see langword="null"/> no row grouping will be performed.</param>
        /// <param name="seed">Seed for the random number generator used to select rows for the train-test split.</param>
        /// <returns>The train and test data views; the temporary split column is dropped from both.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[TrainTestSplit](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/TrainTestSplit.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string samplingKeyColumnName = null, int? seed = null)
        {
            _env.CheckValue(data, nameof(data));
            _env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
            _env.CheckValueOrNull(samplingKeyColumnName);

            var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true);

            // Both filters use identical range options; the train set is the exact complement of the test set.
            var trainFilter = new RangeFilter(_env, new RangeFilter.Options()
            {
                Column = splitColumn,
                Min = 0,
                Max = testFraction,
                Complement = true
            }, data);
            var testFilter = new RangeFilter(_env, new RangeFilter.Options()
            {
                Column = splitColumn,
                Min = 0,
                Max = testFraction,
                Complement = false
            }, data);

            var trainDV = ColumnSelectingTransformer.CreateDrop(_env, trainFilter, splitColumn);
            var testDV = ColumnSelectingTransformer.CreateDrop(_env, testFilter, splitColumn);

            return new TrainTestData(trainDV, testDV);
        }

        /// <summary>
        /// Split the dataset into cross-validation folds of train set and test set.
        /// Respects the <paramref name="samplingKeyColumnName"/> if provided.
        /// </summary>
        /// <param name="data">The dataset to split.</param>
        /// <param name="numberOfFolds">Number of cross-validation folds. Must be more than 1.</param>
        /// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
        /// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
        /// Note that when performing a Ranking Experiment, the <paramref name="samplingKeyColumnName"/> must be the GroupId column.
        /// If <see langword="null"/> no row grouping will be performed.</param>
        /// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
        /// <returns>One train/test pair per fold, in fold order.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[CrossValidationSplit](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/CrossValidationSplit.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IReadOnlyList<TrainTestData> CrossValidationSplit(IDataView data, int numberOfFolds = 5, string samplingKeyColumnName = null, int? seed = null)
        {
            _env.CheckValue(data, nameof(data));
            _env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1");
            _env.CheckValueOrNull(samplingKeyColumnName);
            var splitColumn = CreateSplitColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true);
            // Materialize the lazily-produced folds so the caller gets an indexable list.
            return new List<TrainTestData>(CrossValidationSplit(_env, data, splitColumn, numberOfFolds));
        }

        /// <summary>
        /// Splits the data based on the splitColumn, and drops that column as it is only
        /// intended to be used for splitting the data, and shouldn't be part of the output schema.
        /// </summary>
        /// <remarks>
        /// Arguments are validated immediately when this method is called; the folds themselves are
        /// produced lazily as the returned sequence is enumerated.
        /// </remarks>
        /// <param name="env">The host environment.</param>
        /// <param name="data">The data view containing <paramref name="splitColumn"/>.</param>
        /// <param name="splitColumn">Name of the column whose values (in [0, 1]) assign rows to folds.</param>
        /// <param name="numberOfFolds">Number of folds to produce.</param>
        /// <returns>One train/test pair per fold.</returns>
        internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment env, IDataView data, string splitColumn, int numberOfFolds = 5)
        {
            // Validate here, outside the iterator, so bad arguments fail at the call site instead of
            // on first enumeration.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(splitColumn, nameof(splitColumn));

            return CrossValidationSplitCore(env, data, splitColumn, numberOfFolds);
        }

        /// <summary>
        /// Lazily produces the train/test pair for each fold; arguments are assumed already validated.
        /// </summary>
        private static IEnumerable<TrainTestData> CrossValidationSplitCore(IHostEnvironment env, IDataView data, string splitColumn, int numberOfFolds)
        {
            for (int fold = 0; fold < numberOfFolds; fold++)
            {
                double min = (double)fold / numberOfFolds;
                double max = (double)(fold + 1) / numberOfFolds;
                // Folds are half-open ranges [min, max), except the last one, which is [min, 1] so that
                // the maximum split value is still covered. Including max on every fold would place a
                // value lying exactly on an interior boundary k/numberOfFolds in the test set of two folds.
                bool includeMax = fold == numberOfFolds - 1;

                var trainFilter = new RangeFilter(env, new RangeFilter.Options
                {
                    Column = splitColumn,
                    Min = min,
                    Max = max,
                    Complement = true,
                    IncludeMin = true,
                    IncludeMax = includeMax,
                }, data);

                var testFilter = new RangeFilter(env, new RangeFilter.Options
                {
                    Column = splitColumn,
                    Min = min,
                    Max = max,
                    Complement = false,
                    IncludeMin = true,
                    IncludeMax = includeMax
                }, data);

                var trainDV = ColumnSelectingTransformer.CreateDrop(env, trainFilter, splitColumn);
                var testDV = ColumnSelectingTransformer.CreateDrop(env, testFilter, splitColumn);

                yield return new TrainTestData(trainDV, testDV);
            }
        }

        /// <summary>
        /// Based on the input samplingKeyColumn creates a new splitColumn that will be used by the callers to apply a RangeFilter that will produce train-test splits
        /// or cross-validation splits.
        ///
        /// Notice that the new splitColumn might get dropped by the callers of this method after using it, as it wasn't part of
        /// the input DataView schema.
        /// </summary>
        /// <remarks>
        /// How the split column is built:
        /// - no sampling column: a generated number column (<see cref="GenerateNumberTransform"/>);
        /// - a column type not accepted by RangeFilter: the column is hashed to 30 bits (DateTime, DateTimeOffset and
        ///   TimeSpan values are first converted to Int64, since the hashing estimator does not handle them);
        /// - a Single or Double column: min-max normalized copy;
        /// - any other RangeFilter-compatible column: plain copy.
        /// </remarks>
        /// <param name="env">IHostEnvironment of the caller</param>
        /// <param name="data">DataView that should contain the "samplingKeyColumn". The new splitColumn will be added to this DataView.</param>
        /// <param name="samplingKeyColumn">Name of the column that will be used as base of the new splitColumn.
        /// Notice that in other places in the code the samplingKeyColumn, and/or the splitColumn this method creates,
        /// are referred to as "SamplingKeyColumn", "StratificationColumn", "SplitColumn", "GroupPreservationColumn" or similar names. </param>
        /// <param name="seed">The seed that might be used by the transformers that will create the new splitColumn</param>
        /// <param name="fallbackInEnvSeed">If <paramref name="seed"/> is null, whether to fall back to the seed of <paramref name="env"/>.
        /// If <paramref name="seed"/> is null and this is false (or the environment exposes no seed), no seed is used.</param>
        /// <returns>The name of the new column</returns>
        [BestFriend]
        internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            Contracts.CheckValueOrNull(samplingKeyColumn);

            var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");
            int? seedToUse;

            if (seed.HasValue)
            {
                seedToUse = seed.Value;
            }
            else if (fallbackInEnvSeed && env is IHostEnvironmentInternal seededEnv)
            {
                // Only environments implementing the internal interface carry a seed; others fall
                // through to "no seed" instead of failing with an InvalidCastException.
                seedToUse = seededEnv.Seed;
            }
            else
            {
                seedToUse = null;
            }

            // We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number.
            if (samplingKeyColumn == null)
            {
                data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
            }
            else
            {
                // If samplingKeyColumn was provided we will make a new column based on it, but using a temporary
                // name, as it might be dropped elsewhere in the code

                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex))
                    throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

                var type = data.Schema[samplingColIndex].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
                {
                    var hashInputColumnName = samplingKeyColumn;
                    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                    var itemType = type.GetItemType();
                    if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                    {
                        data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
                        hashInputColumnName = splitColumnName;
                    }

                    var columnOptions =
                        seedToUse.HasValue ?
                        new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) :
                        new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, combine: true);
                    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    if (type != NumberDataViewType.Single && type != NumberDataViewType.Double)
                    {
                        data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data);
                    }
                    else
                    {
                        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data);
                    }
                }
            }

            return splitColumnName;
        }
    }
}