File: AutoML\AutoMLExperimentExtension.cs
Web Access
Project: src\src\Microsoft.ML.Fairlearn\Microsoft.ML.Fairlearn.csproj (Microsoft.ML.Fairlearn)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.Data.Analysis;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;
using Microsoft.ML.SearchSpace;
 
namespace Microsoft.ML.Fairlearn.AutoML
{
    /// <summary>
    /// Internal holder for the gridLimit value used to conduct a grid search.
    /// Wrapping the value in a class lets it be passed into the AutoMLExperiment as a singleton.
    /// </summary>
    internal class GridLimit
    {
        // Explicit storage for the grid limit; a fresh instance starts at 0.
        private float _value;

        /// <summary>
        /// Gets or sets the grid limit value.
        /// </summary>
        public float Value
        {
            get { return _value; }
            set { _value = value; }
        }
    }
    /// <summary>
    /// Extension methods that add more options to the Fairlearn grid search experiment.
    /// </summary>
    public static class AutoMLExperimentExtension
    {
        /// <summary>
        /// Registers <paramref name="moment"/> as a singleton in the experiment's service collection.
        /// </summary>
        /// <param name="experiment">The experiment to configure.</param>
        /// <param name="moment">The moment instance to register.</param>
        /// <returns>The same <paramref name="experiment"/> instance, so calls can be chained.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="experiment"/> or <paramref name="moment"/> is null.
        /// </exception>
        public static AutoMLExperiment SetBinaryClassificationMoment(this AutoMLExperiment experiment, ClassificationMoment moment)
        {
            if (experiment is null)
            {
                throw new ArgumentNullException(nameof(experiment));
            }

            if (moment is null)
            {
                throw new ArgumentNullException(nameof(moment));
            }

            experiment.ServiceCollection.AddSingleton(moment);

            return experiment;
        }

        /// <summary>
        /// Registers a <see cref="GridLimit"/> singleton holding <paramref name="gridLimit"/> and
        /// sets the experiment's tuner to <see cref="CostFrugalWithLambdaTunerFactory"/>.
        /// </summary>
        /// <param name="experiment">The experiment to configure.</param>
        /// <param name="gridLimit">The grid limit value to store in the registered <see cref="GridLimit"/>.</param>
        /// <returns>The same <paramref name="experiment"/> instance, so calls can be chained.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="experiment"/> is null.</exception>
        public static AutoMLExperiment SetGridLimit(this AutoMLExperiment experiment, float gridLimit)
        {
            if (experiment is null)
            {
                throw new ArgumentNullException(nameof(experiment));
            }

            experiment.ServiceCollection.AddSingleton(new GridLimit { Value = gridLimit });
            experiment.SetTuner<CostFrugalWithLambdaTunerFactory>();

            return experiment;
        }

        /// <summary>
        /// Configures the experiment for Fairlearn binary classification: registers a
        /// <see cref="ClassificationMoment"/> factory, installs a <see cref="GridSearchTrailRunner"/>
        /// as the trial runner and selects the random search tuner.
        /// </summary>
        /// <param name="experiment">The experiment to configure.</param>
        /// <param name="labelColumn">Name of the label column; it is read from the training data as <see cref="bool"/>.</param>
        /// <param name="predictedColumn">Not used by the current implementation.</param>
        /// <param name="sensitiveColumnName">Name of the sensitive feature column; it is read from the training data as <see cref="string"/>.</param>
        /// <param name="exampleWeightColumnName">Not used by the current implementation.</param>
        /// <param name="gridLimit">Forwarded to <c>Utilities.GenerateBinaryClassificationLambdaSearchSpace</c>.</param>
        /// <param name="negativeAllowed">Forwarded to <c>Utilities.GenerateBinaryClassificationLambdaSearchSpace</c>.</param>
        /// <returns>The same <paramref name="experiment"/> instance, so calls can be chained.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="experiment"/>, <paramref name="labelColumn"/> or
        /// <paramref name="sensitiveColumnName"/> is null.
        /// </exception>
        public static AutoMLExperiment SetBinaryClassificationMetricWithFairLearn(
            this AutoMLExperiment experiment,
            string labelColumn,
            string predictedColumn,
            string sensitiveColumnName,
            string exampleWeightColumnName,
            float gridLimit = 10f,
            bool negativeAllowed = true)
        {
            if (experiment is null)
            {
                throw new ArgumentNullException(nameof(experiment));
            }

            // The column names are captured by the lambdas below and only used once the
            // services are resolved, so validate them here to fail at the call site instead.
            if (labelColumn is null)
            {
                throw new ArgumentNullException(nameof(labelColumn));
            }

            if (sensitiveColumnName is null)
            {
                throw new ArgumentNullException(nameof(sensitiveColumnName));
            }

            experiment.ServiceCollection.AddSingleton<ClassificationMoment>((serviceProvider) =>
            {
                var datasetManager = serviceProvider.GetRequiredService<TrainValidateDatasetManager>();
                var moment = new UtilityParity();
                var context = serviceProvider.GetRequiredService<MLContext>();
                var trainData = datasetManager.LoadTrainDataset(context, new TrialSettings
                {
                    Parameter = Parameter.CreateNestedParameter(),
                });
                var sensitiveFeature = DataFrameColumn.Create("group_id", trainData.GetColumn<string>(sensitiveColumnName));
                var label = DataFrameColumn.Create("label", trainData.GetColumn<bool>(labelColumn));
                moment.LoadData(trainData, label, sensitiveFeature);

                // The lambda search space is derived from the loaded moment, so it is added
                // to the experiment from inside this factory rather than up front.
                var lambdaSearchSpace = Utilities.GenerateBinaryClassificationLambdaSearchSpace(moment, gridLimit, negativeAllowed);
                experiment.AddSearchSpace("_lambda_search_space", lambdaSearchSpace);

                return moment;
            });

            experiment.SetTrialRunner((serviceProvider) =>
            {
                var context = serviceProvider.GetRequiredService<MLContext>();
                var moment = serviceProvider.GetRequiredService<ClassificationMoment>();
                var datasetManager = serviceProvider.GetRequiredService<TrainValidateDatasetManager>();
                var pipeline = serviceProvider.GetRequiredService<SweepablePipeline>();
                return new GridSearchTrailRunner(context, datasetManager, labelColumn, sensitiveColumnName, pipeline, moment);
            });

            experiment.SetRandomSearchTuner();

            return experiment;
        }
    }
}