File: Training\TrainerInputBase.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Trainers
{
    /// <summary>
    /// The base class for all trainer inputs.
    /// </summary>
    [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
    public abstract class TrainerInputBase
    {
        private protected TrainerInputBase() { }
 
        /// <summary>
        /// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
        /// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
        /// method.
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        internal IDataView TrainingData;
 
        /// <summary>
        /// Column to use for features.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string FeatureColumnName = DefaultColumnNames.Features;
 
        /// <summary>
        /// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
 
        /// <summary>
        /// Whether trainer should cache input training data. Used only in entry-points, since the intended API mechanism
        /// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
        /// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
        /// </summary>
        [BestFriend]
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Whether trainer should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        internal CachingOptions Caching = CachingOptions.Auto;
    }
 
    /// <summary>
    /// The base class for all trainer inputs that support a Label column.
    /// </summary>
    [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
    public abstract class TrainerInputBaseWithLabel : TrainerInputBase
    {
        private protected TrainerInputBaseWithLabel() { }
 
        /// <summary>
        /// Column to use for labels.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string LabelColumnName = DefaultColumnNames.Label;
    }
 
    // REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
    /// <summary>
    /// The base class for all trainer inputs that support a weight column.
    /// </summary>
    [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
    public abstract class TrainerInputBaseWithWeight : TrainerInputBaseWithLabel
    {
        private protected TrainerInputBaseWithWeight() { }
 
        /// <summary>
        /// Column to use for example weight.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string ExampleWeightColumnName = null;
    }
 
    /// <summary>
    /// The base class for all unsupervised trainer inputs that support a weight column.
    /// </summary>
    [TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
    public abstract class UnsupervisedTrainerInputBaseWithWeight : TrainerInputBase
    {
        private protected UnsupervisedTrainerInputBaseWithWeight() { }
 
        /// <summary>
        /// Column to use for example weight.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string ExampleWeightColumnName = null;
    }
 
    /// <summary>
    /// The base class for all trainer inputs that support a group column.
    /// </summary>
    [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
    public abstract class TrainerInputBaseWithGroupId : TrainerInputBaseWithWeight
    {
        private protected TrainerInputBaseWithGroupId() { }
 
        /// <summary>
        /// Column to use for example groupId.
        /// </summary>
        [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
        public string RowGroupColumnName = null;
    }
}