File: Data\IEstimator.cs
Web Access
Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML;
 
/// <summary>
/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
/// This is more relaxed than the proper <see cref="DataViewSchema"/>, since it's only a subset of the columns,
/// and also since it doesn't specify exact <see cref="DataViewType"/>'s for vectors and keys.
/// </summary>
public sealed class SchemaShape : IReadOnlyList<SchemaShape.Column>
{
    private readonly Column[] _columns;
 
    private static readonly SchemaShape _empty = new SchemaShape(Enumerable.Empty<Column>());
 
    public int Count => _columns.Count();
 
    public Column this[int index] => _columns[index];
 
    public struct Column
    {
        public enum VectorKind
        {
            Scalar,
            Vector,
            VariableVector
        }
 
        /// <summary>
        /// The column name.
        /// </summary>
        public readonly string Name;
 
        /// <summary>
        /// The type of the column: scalar, fixed vector or variable vector.
        /// </summary>
        public readonly VectorKind Kind;
 
        /// <summary>
        /// The 'raw' type of column item: must be a primitive type or a structured type.
        /// </summary>
        public readonly DataViewType ItemType;
        /// <summary>
        /// The flag whether the column is actually a key. If yes, <see cref="ItemType"/> is representing
        /// the underlying primitive type.
        /// </summary>
        public readonly bool IsKey;
        /// <summary>
        /// The annotations that are present for this column.
        /// </summary>
        public readonly SchemaShape Annotations;
 
        [BestFriend]
        internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
        {
            Contracts.CheckNonEmpty(name, nameof(name));
            Contracts.CheckValueOrNull(annotations);
            Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
            Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
            Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
 
            Name = name;
            Kind = vecKind;
            ItemType = itemType;
            IsKey = isKey;
            Annotations = annotations ?? _empty;
        }
 
        /// <summary>
        /// Returns whether <paramref name="source"/> is a valid input, if this object represents a
        /// requirement.
        ///
        /// Namely, it returns true iff:
        ///  - The <see cref="Name"/>, <see cref="Kind"/>, <see cref="ItemType"/>, <see cref="IsKey"/> fields match.
        ///  - The columns of <see cref="Annotations"/> of <paramref name="source"/> is a superset of our <see cref="Annotations"/> columns.
        ///  - Each such annotation column is itself compatible with the input annotation column.
        /// </summary>
        [BestFriend]
        internal bool IsCompatibleWith(Column source)
        {
            Contracts.Check(source.IsValid, nameof(source));
            if (Name != source.Name)
                return false;
            if (Kind != source.Kind)
                return false;
            if (!ItemType.Equals(source.ItemType))
                return false;
            if (IsKey != source.IsKey)
                return false;
            foreach (var annotationCol in Annotations)
            {
                if (!source.Annotations.TryFindColumn(annotationCol.Name, out var inputAnnotationCol))
                    return false;
                if (!annotationCol.IsCompatibleWith(inputAnnotationCol))
                    return false;
            }
            return true;
        }
 
        [BestFriend]
        internal string GetTypeString()
        {
            string result = ItemType.ToString();
            if (IsKey)
                result = $"Key<{result}>";
            if (Kind == VectorKind.Vector)
                result = $"Vector<{result}>";
            else if (Kind == VectorKind.VariableVector)
                result = $"VarVector<{result}>";
            return result;
        }
 
        /// <summary>
        /// Return if this structure is not identical to the default value of <see cref="Column"/>. If true,
        /// it means this structure is initialized properly and therefore considered as valid.
        /// </summary>
        [BestFriend]
        internal bool IsValid => Name != null;
    }
 
    public SchemaShape(IEnumerable<Column> columns)
    {
        Contracts.CheckValue(columns, nameof(columns));
        _columns = columns.ToArray();
        Contracts.CheckParam(columns.All(c => c.IsValid), nameof(columns), "Some items are not initialized properly.");
    }
 
    /// <summary>
    /// Given a <paramref name="type"/>, extract the type parameters that describe this type
    /// as a <see cref="SchemaShape"/>'s column type.
    /// </summary>
    /// <param name="type">The actual column type to process.</param>
    /// <param name="vecKind">The vector kind of <paramref name="type"/>.</param>
    /// <param name="itemType">The item type of <paramref name="type"/>.</param>
    /// <param name="isKey">Whether <paramref name="type"/> (or its item type) is a key.</param>
    [BestFriend]
    internal static void GetColumnTypeShape(DataViewType type,
        out Column.VectorKind vecKind,
        out DataViewType itemType,
        out bool isKey)
    {
        if (type is VectorDataViewType vectorType)
        {
            if (vectorType.IsKnownSize)
            {
                vecKind = Column.VectorKind.Vector;
            }
            else
            {
                vecKind = Column.VectorKind.VariableVector;
            }
 
            itemType = vectorType.ItemType;
        }
        else
        {
            vecKind = Column.VectorKind.Scalar;
            itemType = type;
        }
 
        isKey = itemType is KeyDataViewType;
        if (isKey)
            itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
    }
 
    /// <summary>
    /// Create a schema shape out of the fully defined schema.
    /// </summary>
    [BestFriend]
    internal static SchemaShape Create(DataViewSchema schema)
    {
        Contracts.CheckValue(schema, nameof(schema));
        var cols = new List<Column>();
 
        for (int iCol = 0; iCol < schema.Count; iCol++)
        {
            if (!schema[iCol].IsHidden)
            {
                // First create the annotations.
                var mCols = new List<Column>();
                foreach (var annotationColumn in schema[iCol].Annotations.Schema)
                {
                    GetColumnTypeShape(annotationColumn.Type, out var mVecKind, out var mItemType, out var mIsKey);
                    mCols.Add(new Column(annotationColumn.Name, mVecKind, mItemType, mIsKey));
                }
                var annotations = mCols.Count > 0 ? new SchemaShape(mCols) : _empty;
                // Next create the single column.
                GetColumnTypeShape(schema[iCol].Type, out var vecKind, out var itemType, out var isKey);
                cols.Add(new Column(schema[iCol].Name, vecKind, itemType, isKey, annotations));
            }
        }
        return new SchemaShape(cols);
    }
 
    /// <summary>
    /// Returns if there is a column with a specified <paramref name="name"/> and if so stores it in <paramref name="column"/>.
    /// </summary>
    [BestFriend]
    internal bool TryFindColumn(string name, out Column column)
    {
        Contracts.CheckValue(name, nameof(name));
        column = _columns.FirstOrDefault(x => x.Name == name);
        return column.IsValid;
    }
 
    public IEnumerator<Column> GetEnumerator() => ((IEnumerable<Column>)_columns).GetEnumerator();
 
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
 
    // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
    // as an input to another schema shape. I started writing, but realized that there's more than one way to check for
    // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}
 
/// <summary>
/// The 'data loader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// </summary>
/// <typeparam name="TSource">The type of input the loader takes.</typeparam>
public interface IDataLoader<in TSource> : ICanSaveModel
{
    /// <summary>
    /// Produce the data view from the specified input.
    /// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
    /// </summary>
    IDataView Load(TSource input);
 
    /// <summary>
    /// The output schema of the loader.
    /// </summary>
    DataViewSchema GetOutputSchema();
}
 
/// <summary>
/// Sometimes we need to 'fit' an <see cref="IDataLoader{TIn}"/>.
/// A DataLoader estimator is the object that does it.
/// </summary>
public interface IDataLoaderEstimator<in TSource, out TLoader>
    where TLoader : IDataLoader<TSource>
{
    // REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
    // yet, so why complicate matters?
 
    /// <summary>
    /// Train and return a data loader.
    /// </summary>
    TLoader Fit(TSource input);
 
    /// <summary>
    /// The 'promise' of the output schema.
    /// It will be used for schema propagation.
    /// </summary>
    SchemaShape GetOutputSchema();
}
 
/// <summary>
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'.
/// </summary>
public interface ITransformer : ICanSaveModel
{
    /// <summary>
    /// Schema propagation for transformers.
    /// Returns the output schema of the data, if the input schema is like the one provided.
    /// </summary>
    DataViewSchema GetOutputSchema(DataViewSchema inputSchema);
 
    /// <summary>
    /// Take the data in, make transformations, output the data.
    /// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
    /// </summary>
    IDataView Transform(IDataView input);
 
    /// <summary>
    /// Whether a call to <see cref="GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
    /// appropriate schema.
    /// </summary>
    bool IsRowToRowMapper { get; }
 
    /// <summary>
    /// Constructs a row-to-row mapper based on an input schema. If <see cref="IsRowToRowMapper"/>
    /// is <c>false</c>, then an exception should be thrown. If the input schema is in any way
    /// unsuitable for constructing the mapper, an exception should likewise be thrown.
    /// </summary>
    /// <param name="inputSchema">The input schema for which we should get the mapper.</param>
    /// <returns>The row to row mapper.</returns>
    IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema);
}
 
[BestFriend]
internal interface ITransformerWithDifferentMappingAtTrainingTime : ITransformer
{
    IDataView TransformForTrainingPipeline(IDataView input);
}
 
/// <summary>
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
/// a transformer.
/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="DataViewSchema"/>.
/// </summary>
public interface IEstimator<out TTransformer>
    where TTransformer : ITransformer
{
    /// <summary>
    /// Train and return a transformer.
    /// </summary>
    TTransformer Fit(IDataView input);
 
    /// <summary>
    /// Schema propagation for estimators.
    /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
    /// </summary>
    SchemaShape GetOutputSchema(SchemaShape inputSchema);
}