File: DataLoadSave\EstimatorExtensions.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML
{
    /// <summary>
    /// Extension methods that allow chaining of estimator and transformer pipelines.
    /// </summary>
    public static class LearningPipelineExtensions
    {
        /// <summary>
        /// Create a new composite loader estimator, by appending another estimator to the end of this data loader estimator.
        /// </summary>
        public static CompositeLoaderEstimator<TSource, TTrans> Append<TSource, TTrans>(
            this IDataLoaderEstimator<TSource, IDataLoader<TSource>> start, IEstimator<TTrans> estimator)
            where TTrans : class, ITransformer
        {
            Contracts.CheckValue(start, nameof(start));
            Contracts.CheckValue(estimator, nameof(estimator));
 
            return new CompositeLoaderEstimator<TSource, ITransformer>(start).Append(estimator);
        }
 
        /// <summary>
        /// Create a new composite loader estimator, by appending an estimator to this data loader.
        /// </summary>
        public static CompositeLoaderEstimator<TSource, TTrans> Append<TSource, TTrans>(
            this IDataLoader<TSource> start, IEstimator<TTrans> estimator)
            where TTrans : class, ITransformer
        {
            Contracts.CheckValue(start, nameof(start));
            Contracts.CheckValue(estimator, nameof(estimator));
 
            return new TrivialLoaderEstimator<TSource, IDataLoader<TSource>>(start).Append(estimator);
        }
 
        /// <summary>
        /// Create a new estimator chain, by appending another estimator to the end of this estimator.
        /// </summary>
        public static EstimatorChain<TTrans> Append<TTrans>(
            this IEstimator<ITransformer> start, IEstimator<TTrans> estimator,
            TransformerScope scope = TransformerScope.Everything)
            where TTrans : class, ITransformer
        {
            Contracts.CheckValue(start, nameof(start));
            Contracts.CheckValue(estimator, nameof(estimator));
 
            if (start is EstimatorChain<ITransformer> est)
                return est.Append(estimator, scope);
 
            return new EstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
        }
 
        /// <summary>
        /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
        /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
        /// </summary>
        /// <param name="start">The starting estimator</param>
        /// <param name="env">The host environment to use for caching.</param>
 
        public static EstimatorChain<TTrans> AppendCacheCheckpoint<TTrans>(this IEstimator<TTrans> start, IHostEnvironment env)
            where TTrans : class, ITransformer
        {
            Contracts.CheckValue(start, nameof(start));
            return new EstimatorChain<ITransformer>().Append(start).AppendCacheCheckpoint(env);
        }
 
        /// <summary>
        /// Create a new composite loader, by appending a transformer to this data loader.
        /// </summary>
        public static CompositeDataLoader<TSource, TTrans> Append<TSource, TTrans>(this IDataLoader<TSource> loader, TTrans transformer)
            where TTrans : class, ITransformer
        {
            Contracts.CheckValue(loader, nameof(loader));
            Contracts.CheckValue(transformer, nameof(transformer));
 
            return new CompositeDataLoader<TSource, ITransformer>(loader).AppendTransformer(transformer);
        }
 
        /// <summary>
        /// Create a new transformer chain, by appending another transformer to the end of this transformer chain.
        /// </summary>
        public static TransformerChain<TTrans> Append<TTrans>(this ITransformer start, TTrans transformer)
            where TTrans : class, ITransformer
        {
            Contracts.CheckValue(start, nameof(start));
            Contracts.CheckValue(transformer, nameof(transformer));
 
            return new TransformerChain<TTrans>(start, transformer);
        }
 
        private sealed class DelegateEstimator<TTransformer> : IEstimator<TTransformer>
            where TTransformer : class, ITransformer
        {
            private readonly IEstimator<TTransformer> _est;
            private readonly Action<TTransformer> _onFit;
 
            public DelegateEstimator(IEstimator<TTransformer> estimator, Action<TTransformer> onFit)
            {
                Contracts.AssertValue(estimator);
                Contracts.AssertValue(onFit);
                _est = estimator;
                _onFit = onFit;
            }
 
            public TTransformer Fit(IDataView input)
            {
                var trans = _est.Fit(input);
                _onFit(trans);
                return trans;
            }
 
            public SchemaShape GetOutputSchema(SchemaShape inputSchema)
                => _est.GetOutputSchema(inputSchema);
        }
 
        /// <summary>
        /// Given an estimator, return a wrapping object that will call a delegate once <see cref="IEstimator{TTransformer}.Fit(IDataView)"/>
        /// is called. It is often important for an estimator to return information about what was fit, which is why the
        /// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method returns a specifically typed object, rather than just a general
        /// <see cref="ITransformer"/>. However, at the same time, <see cref="IEstimator{TTransformer}"/> are often formed into pipelines
        /// with many objects, so we may need to build a chain of estimators via <see cref="EstimatorChain{TLastTransformer}"/> where the
        /// estimator for which we want to get the transformer is buried somewhere in this chain. For that scenario, we can through this
        /// method attach a delegate that will be called once fit is called.
        /// </summary>
        /// <typeparam name="TTransformer">The type of <see cref="ITransformer"/> returned by <paramref name="estimator"/></typeparam>
        /// <param name="estimator">The estimator to wrap</param>
        /// <param name="onFit">The delegate that is called with the resulting <typeparamref name="TTransformer"/> instances once
        /// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> is called. Because <see cref="IEstimator{TTransformer}.Fit(IDataView)"/>
        /// may be called multiple times, this delegate may also be called multiple times.</param>
        /// <returns>A wrapping estimator that calls the indicated delegate whenever fit is called</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[OnFit](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/WithOnFitDelegate.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public static IEstimator<TTransformer> WithOnFitDelegate<TTransformer>(this IEstimator<TTransformer> estimator, Action<TTransformer> onFit)
            where TTransformer : class, ITransformer
        {
            Contracts.CheckValue(estimator, nameof(estimator));
            Contracts.CheckValue(onFit, nameof(onFit));
            return new DelegateEstimator<TTransformer>(estimator, onFit);
        }
 
        [BestFriend]
        internal static T[] AppendElement<T>(this T[] array, T element)
        {
            T[] result = new T[Utils.Size(array) + 1];
            Array.Copy(array, result, result.Length - 1);
            result[result.Length - 1] = element;
            return result;
        }
    }
}