File: AutoMLExperiment\IDatasetManager.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#nullable enable
 
using Microsoft.ML.SearchSpace;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// Marker interface for dataset managers. It declares no methods or properties; it is used by <see cref="AutoMLExperiment"/> and other components
    /// to retrieve the instance of the actual dataset manager from containers.
    /// </summary>
    /// <remarks>
    /// The concrete contracts are declared by the derived interfaces <see cref="ICrossValidateDatasetManager"/> and <see cref="ITrainValidateDatasetManager"/>.
    /// </remarks>
    public interface IDatasetManager
    {
    }
 
    /// <summary>
    /// Interface for cross validate dataset manager.
    /// </summary>
    public interface ICrossValidateDatasetManager : IDatasetManager
    {
        /// <summary>
        /// Cross validate fold.
        /// </summary>
        int Fold { get; set; }

        /// <summary>
        /// The dataset to cross validate.
        /// </summary>
        IDataView Dataset { get; set; }

        /// <summary>
        /// The dataset column used for grouping rows. Can be null.
        /// </summary>
        string? SamplingKeyColumnName { get; set; }
    }
 
    /// <summary>
    /// Interface for train-validate dataset manager, which provides a train dataset and a validate dataset.
    /// </summary>
    public interface ITrainValidateDatasetManager : IDatasetManager
    {
        /// <summary>
        /// Load the train dataset.
        /// </summary>
        /// <param name="context">MLContext.</param>
        /// <param name="settings">trial settings. Can be null.</param>
        /// <returns>train dataset.</returns>
        IDataView LoadTrainDataset(MLContext context, TrialSettings? settings);

        /// <summary>
        /// Load the validate dataset.
        /// </summary>
        /// <param name="context">MLContext.</param>
        /// <param name="settings">trial settings. Can be null.</param>
        /// <returns>validate dataset.</returns>
        IDataView LoadValidateDataset(MLContext context, TrialSettings? settings);
    }
 
    /// <summary>
    /// <see cref="ITrainValidateDatasetManager"/> implementation backed by a fixed train dataset and a fixed validate dataset.
    /// The train dataset can be subsampled per trial through a ratio stored in <see cref="TrialSettings.Parameter"/>.
    /// </summary>
    internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager
    {
        // Lower bound on the number of rows requested when subsampling.
        // fix issue https://github.com/dotnet/machinelearning-modelbuilder/issues/2734
        private const long MinimumSubSampledRowCount = 10;

        // Number of rows in the train dataset. Only valid once _isInitialized is true.
        private ulong _rowCount;
        private readonly IDataView _trainDataset;
        private readonly IDataView _validateDataset;
        private readonly string _subSamplingKey = "TrainValidateDatasetSubsamplingKey";
        private bool _isInitialized = false;

        /// <summary>
        /// Create a train-validate dataset manager.
        /// </summary>
        /// <param name="trainDataset">train dataset.</param>
        /// <param name="validateDataset">validate dataset.</param>
        /// <param name="subSamplingKey">key of the subsampling ratio in the trial parameter. If null, the default key is used.</param>
        public TrainValidateDatasetManager(IDataView trainDataset, IDataView validateDataset, string? subSamplingKey = null)
        {
            _trainDataset = trainDataset;
            _validateDataset = validateDataset;
            _subSamplingKey = subSamplingKey ?? _subSamplingKey;
        }

        /// <summary>
        /// Key under which the subsampling ratio is looked up in the trial parameter.
        /// </summary>
        public string SubSamplingKey => _subSamplingKey;

        /// <summary>
        /// Load Train Dataset. If <see cref="TrialSettings.Parameter"/> contains an entry named after this class which in turn
        /// contains <see cref="SubSamplingKey"/> with a ratio smaller than 1, then the train dataset will be subsampled.
        /// </summary>
        /// <param name="context">MLContext.</param>
        /// <param name="settings">trial settings. If null, return entire train dataset.</param>
        /// <returns>train dataset.</returns>
        public IDataView LoadTrainDataset(MLContext context, TrialSettings? settings)
        {
            var managerKey = nameof(TrainValidateDatasetManager);
            if (settings?.Parameter.ContainsKey(managerKey) is true
                && settings.Parameter[managerKey] is Parameter parameter
                && parameter.ContainsKey(_subSamplingKey))
            {
                var subSampleRatio = parameter[_subSamplingKey].AsType<double>();
                if (subSampleRatio < 1.0)
                {
                    // The row count is only needed when subsampling, so the pass over the
                    // train dataset is deferred until the first trial that asks for it.
                    EnsureRowCount();

                    var count = (long)(subSampleRatio * _rowCount);
                    if (count < MinimumSubSampledRowCount)
                    {
                        // take at least 10 rows to avoid empty dataset
                        count = MinimumSubSampledRowCount;
                    }

                    return context.Data.TakeRows(_trainDataset, count);
                }
            }

            return _trainDataset;
        }

        /// <summary>
        /// Load Validate Dataset. The validate dataset is returned as is; <paramref name="context"/> and <paramref name="settings"/> are not used.
        /// </summary>
        /// <param name="context">MLContext.</param>
        /// <param name="settings">trial settings.</param>
        /// <returns>validate dataset.</returns>
        public IDataView LoadValidateDataset(MLContext context, TrialSettings? settings)
        {
            return _validateDataset;
        }

        /// <summary>
        /// Count the rows of the train dataset once and cache the result in <see cref="_rowCount"/>.
        /// </summary>
        private void EnsureRowCount()
        {
            if (_isInitialized)
            {
                return;
            }

            _rowCount = DatasetDimensionsUtil.CountRows(_trainDataset, ulong.MaxValue);
            _isInitialized = true;
        }
    }
 
    /// <summary>
    /// <see cref="ICrossValidateDatasetManager"/> implementation that stores the values supplied at construction.
    /// </summary>
    internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager
    {
        /// <inheritdoc/>
        public int Fold { get; set; }

        /// <inheritdoc/>
        public IDataView Dataset { get; set; }

        /// <inheritdoc/>
        public string? SamplingKeyColumnName { get; set; }

        /// <summary>
        /// Create a cross validate dataset manager.
        /// </summary>
        /// <param name="dataset">the dataset to cross validate.</param>
        /// <param name="fold">cross validate fold.</param>
        /// <param name="samplingKeyColumnName">the dataset column used for grouping rows. Can be null.</param>
        public CrossValidateDatasetManager(IDataView dataset, int fold, string? samplingKeyColumnName = null)
        {
            SamplingKeyColumnName = samplingKeyColumnName;
            Fold = fold;
            Dataset = dataset;
        }
    }
}