File: DataLoadSave\CompositeLoaderEstimator.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Estimator that pairs a data loader estimator with a chain of estimators and, when fitted,
    /// yields a <see cref="CompositeDataLoader{TSource, TLastTransformer}"/>.
    /// This makes it possible to express a 'trainable smart data loader', which is not a very common pattern.
    /// </summary>
    /// <typeparam name="TSource">The type of the input accepted by the underlying loader estimator.</typeparam>
    /// <typeparam name="TLastTransformer">The transformer type produced by the last estimator of the chain.</typeparam>
    public sealed class CompositeLoaderEstimator<TSource, TLastTransformer> : IDataLoaderEstimator<TSource, CompositeDataLoader<TSource, TLastTransformer>>
        where TLastTransformer : class, ITransformer
    {
        // The loader estimator that is fitted first, and whose fitted loader supplies the data for the chain.
        private readonly IDataLoaderEstimator<TSource, IDataLoader<TSource>> _start;
        // The estimators fitted on the data produced by the fitted loader. Never null after construction.
        private readonly EstimatorChain<TLastTransformer> _estimatorChain;

        /// <summary>
        /// Creates a composite loader estimator from a starting loader estimator and an optional estimator chain.
        /// </summary>
        /// <param name="start">The loader estimator to begin with. Validated with <c>Contracts.CheckValue</c>.</param>
        /// <param name="estimatorChain">The estimator chain applied after loading. When null, a
        /// default-constructed <see cref="EstimatorChain{TLastTransformer}"/> is used instead.</param>
        public CompositeLoaderEstimator(IDataLoaderEstimator<TSource, IDataLoader<TSource>> start, EstimatorChain<TLastTransformer> estimatorChain = null)
        {
            Contracts.CheckValue(start, nameof(start));
            Contracts.CheckValueOrNull(estimatorChain);

            _start = start;
            _estimatorChain = estimatorChain != null
                ? estimatorChain
                : new EstimatorChain<TLastTransformer>();

            // REVIEW: it would be good to verify here that the estimator chain is able to consume
            // the loader's schema. Calling GetOutputSchema() for that purpose currently throws,
            // so the call stays disabled:
            // GetOutputSchema();
        }

        /// <summary>
        /// Fits the starting loader estimator on <paramref name="input"/>, loads the data from the same input
        /// with the fitted loader, and then fits the estimator chain on the loaded data.
        /// </summary>
        /// <param name="input">The source passed both to the loader estimator and to the fitted loader.</param>
        /// <returns>A composite data loader built from the fitted loader and the fitted transformer chain.</returns>
        public CompositeDataLoader<TSource, TLastTransformer> Fit(TSource input)
        {
            var fittedLoader = _start.Fit(input);
            var fittedChain = _estimatorChain.Fit(fittedLoader.Load(input));

            return new CompositeDataLoader<TSource, TLastTransformer>(fittedLoader, fittedChain);
        }

        /// <summary>
        /// Computes the output schema shape by taking the schema shape reported by the starting loader estimator
        /// and passing it through the estimator chain.
        /// </summary>
        /// <returns>The schema shape returned by the estimator chain.</returns>
        public SchemaShape GetOutputSchema()
            => _estimatorChain.GetOutputSchema(_start.GetOutputSchema());

        /// <summary>
        /// Produces a new loader estimator that uses the same starting loader estimator and the estimator chain
        /// obtained by appending <paramref name="estimator"/> to the chain of this loader estimator.
        /// </summary>
        /// <typeparam name="TNewTrans">The transformer type produced by the appended estimator.</typeparam>
        /// <param name="estimator">The estimator to append. Validated with <c>Contracts.CheckValue</c>.</param>
        /// <returns>The newly created composite loader estimator.</returns>
        public CompositeLoaderEstimator<TSource, TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator)
            where TNewTrans : class, ITransformer
        {
            Contracts.CheckValue(estimator, nameof(estimator));

            var extendedChain = _estimatorChain.Append(estimator);
            return new CompositeLoaderEstimator<TSource, TNewTrans>(_start, extendedChain);
        }
    }
}