// File: Reductions\Moment.cs
// Web Access
// Project: src\src\Microsoft.ML.Fairlearn\Microsoft.ML.Fairlearn.csproj (Microsoft.ML.Fairlearn)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.Data.Analysis;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.Fairlearn
{
    /// <summary>
    /// Generic moment.
    /// Modeled after the original Fairlearn <see href="https://github.com/fairlearn/fairlearn/blob/931963c40c0ba0cdd1a9e51c29adcc509da224a6/fairlearn/reductions/_moments/moment.py#L15">repo</see>.
    /// Our implementations of the reductions approach to fairness
    /// (<see href="https://arxiv.org/abs/1803.02453">agarwal2018reductions</see>)
    /// make use of Moment objects to describe both the optimization objective
    /// and the fairness constraints imposed on the solution.
    /// This is an abstract class for all such objects.
    /// </summary>
    public abstract class Moment
    {
        /// <summary>
        /// The label column passed to the most recent <c>LoadData</c> call.
        /// </summary>
        protected DataFrameColumn Y; //maybe lowercase this?

        /// <summary>
        /// Data frame built by <c>LoadData</c>; it holds the label under <c>"label"</c> and the
        /// sensitive feature under <c>"group_id"</c>. Not assigned until <c>LoadData</c> has run.
        /// </summary>
        public DataFrame Tags { get; private set; }

        /// <summary>
        /// The feature data passed to the most recent <c>LoadData</c> call.
        /// </summary>
        public IDataView X { get; protected set; }

        /// <summary>
        /// Number of samples; <c>LoadData</c> sets it to the length of the label column.
        /// </summary>
        public long TotalSamples { get; protected set; }

        /// <summary>
        /// The <c>"group_id"</c> column of <see cref="Tags"/>.
        /// Reads <see cref="Tags"/> directly, so it is only usable after <c>LoadData</c> has run.
        /// </summary>
        public DataFrameColumn SensitiveFeatureColumn => Tags["group_id"];

        /// <summary>
        /// Not assigned anywhere in this class.
        /// NOTE(review): the name looks like a misspelling of "GroupIds"; it is kept as-is because it is public.
        /// </summary>
        public string[] GroudIds;

        /// <summary>
        /// Creates an empty moment; call <c>LoadData</c> to populate it.
        /// </summary>
        public Moment()
        {
        }

        /// <summary>
        /// Load the data into the moment to generate the parity constraint.
        /// </summary>
        /// <remarks>
        /// The supplied columns are stored in <see cref="Tags"/> under the names <c>"label"</c> and <c>"group_id"</c>.
        /// No state on this instance is modified unless the tag frame was built successfully.
        /// </remarks>
        /// <param name="features">The feature set</param>
        /// <param name="label">The label</param>
        /// <param name="sensitiveFeature">The sensitive feature that contains the sensitive groups</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="label"/> or <paramref name="sensitiveFeature"/> is null.
        /// </exception>
        public virtual void LoadData(IDataView features, DataFrameColumn label, StringDataFrameColumn sensitiveFeature)
        {
            if (label is null)
            {
                throw new ArgumentNullException(nameof(label));
            }

            if (sensitiveFeature is null)
            {
                throw new ArgumentNullException(nameof(sensitiveFeature));
            }

            // Assemble the tag frame first: if adding either column throws, the
            // previously loaded state of this moment is left untouched.
            var tags = new DataFrame();
            tags["label"] = label;
            tags["group_id"] = sensitiveFeature;

            X = features;
            TotalSamples = label.Length;
            Y = label;
            Tags = tags;
        }

        /// <summary>
        /// Load the data into the moment by reading the label and sensitive feature columns from
        /// <paramref name="trainData"/>, then forwarding to the column-based <c>LoadData</c> overload.
        /// </summary>
        /// <param name="trainData">The data view that provides the features and both named columns</param>
        /// <param name="label">Name of the label column; it is read as <see cref="bool"/> values</param>
        /// <param name="sensitiveColumnName">Name of the sensitive feature column; it is read as <see cref="string"/> values</param>
        /// <exception cref="ArgumentNullException"><paramref name="trainData"/> is null.</exception>
        public virtual void LoadData(IDataView trainData, string label, string sensitiveColumnName)
        {
            if (trainData is null)
            {
                throw new ArgumentNullException(nameof(trainData));
            }

            var sensitiveFeature = DataFrameColumn.Create("group_id", trainData.GetColumn<string>(sensitiveColumnName));
            var labelColumn = DataFrameColumn.Create("label", trainData.GetColumn<bool>(label));
            this.LoadData(trainData, labelColumn, sensitiveFeature);
        }

        /// <summary>
        /// Calculate the degree to which constraints are currently violated by the predictor.
        /// </summary>
        /// <param name="yPred">Contains the predictions of the label</param>
        /// <returns>A data frame produced by the derived class; its layout is not defined here.</returns>
        public abstract DataFrame Gamma(PrimitiveDataFrameColumn<float> yPred);

        /// <summary>
        /// Implemented by derived classes.
        /// NOTE(review): presumably the bound applied to the constraints, as in the referenced Fairlearn moment; confirm against implementations.
        /// </summary>
        public abstract float Bound();

        /// <summary>
        /// Not implemented.
        /// </summary>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public float ProjectLambda()
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Implemented by derived classes.
        /// NOTE(review): presumably computes per-sample signed weights from <paramref name="lambdaVec"/>, as in the referenced Fairlearn moment; confirm against implementations.
        /// </summary>
        /// <param name="lambdaVec">Input data frame; its expected layout is defined by the derived class</param>
        public abstract DataFrameColumn SignedWeights(DataFrame lambdaVec);
    }
    /// <summary>
    /// A <see cref="Moment"/> that can be expressed as a weighted classification error.
    /// Declares no members of its own beyond those inherited from <see cref="Moment"/>.
    /// </summary>
    public abstract class ClassificationMoment : Moment
    {
    }
}