File: EntryPoints\TransformModelImpl.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.EntryPoints
{
    /// <summary>
    /// This encapsulates zero or more transform models. It does this by recording
    /// the initial schema, together with the sequence of transforms applied to that
    /// schema.
    /// </summary>
    [BestFriend]
    internal sealed class TransformModelImpl : TransformModel
    {
        // The cached schema of the root of the _chain.
        private readonly DataViewSchema _schemaRoot;
 
        /// <summary>
        /// This contains the transforms to save instantiated on an <see cref="IDataView"/> with
        /// appropriate initial schema. Note that the "root" of this is typically either
        /// an empty <see cref="IDataView"/> or a <see cref="BinaryLoader"/> with no rows. However, other root
        /// types are possible, since we don't insist on this when loading a model
        /// from a zip file. However, whenever we save, we force a <see cref="BinaryLoader"/> to
        /// be serialized for the root.
        /// </summary>
        private readonly IDataView _chain;
 
        /// <summary>
        /// The input schema that this transform model was originally instantiated on.
        /// Note that the schema may have columns that aren't needed by this transform model.
        /// If an <see cref="IDataView"/> exists with this schema, then applying this transform model to it
        /// shouldn't fail because of column type issues.
        /// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
        /// however that doing so may cause issues for composing transform models. For example,
        /// if transform model A needs column X and model B needs Y, that is NOT produced by A,
        /// then trimming A's input schema would cause composition to fail.
        /// </summary>
        internal override DataViewSchema InputSchema => _schemaRoot;
 
        /// <summary>
        /// The resulting schema once applied to this model. The <see cref="InputSchema"/> might have
        /// columns that are not needed by this transform and these columns will be seen in the
        /// <see cref="OutputSchema"/> produced by this transform.
        /// </summary>
        internal override DataViewSchema OutputSchema => _chain.Schema;
 
        /// <summary>
        /// Create a TransformModel containing the transforms from "result" back to "input".
        /// </summary>
        public TransformModelImpl(IHostEnvironment env, IDataView result, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(result, nameof(result));
            env.CheckValue(input, nameof(input));
 
            var root = new EmptyDataView(env, input.Schema);
            _schemaRoot = root.Schema;
            _chain = ApplyTransformUtils.ApplyAllTransformsToData(env, result, root, input);
        }
 
        private TransformModelImpl(IHostEnvironment env, DataViewSchema schemaRoot, IDataView chain)
        {
            Contracts.AssertValue(env);
            env.AssertValue(schemaRoot);
            env.AssertValue(chain);
            _schemaRoot = schemaRoot;
            _chain = chain;
        }
 
        /// <summary>
        /// Create a TransformModel containing the given (optional) transforms applied to the
        /// given root schema.
        /// </summary>
        public TransformModelImpl(IHostEnvironment env, DataViewSchema schemaRoot, IDataTransform[] xfs)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(schemaRoot, nameof(schemaRoot));
            env.CheckValueOrNull(xfs);
 
            IDataView view = new EmptyDataView(env, schemaRoot);
            _schemaRoot = view.Schema;
 
            if (Utils.Size(xfs) > 0)
            {
                foreach (var xf in xfs)
                {
                    env.AssertValue(xf, "xfs", "Transforms should not be null");
                    view = ApplyTransformUtils.ApplyTransformToData(env, xf, view);
                }
            }
 
            _chain = view;
        }
 
        /// <summary>
        /// Load a transform model.
        /// </summary>
        public TransformModelImpl(IHostEnvironment env, Stream stream)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(stream, nameof(stream));
 
            // REVIEW: Should this preserve the "tags" for the transforms?
            using (var ch = env.Start("Loading transform model"))
            {
                _chain = ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(null), extractInnerPipe: true);
            }
 
            // Find the root schema.
            for (IDataView view = _chain; ;)
            {
                var xf = view as IDataTransform;
                if (xf == null)
                {
                    _schemaRoot = view.Schema;
                    break;
                }
                view = xf.Source;
                env.AssertValue(view);
            }
        }
 
        /// <summary>
        /// Apply this transform model to the given input data.
        /// </summary>
        internal override IDataView Apply(IHostEnvironment env, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
            return ApplyTransformUtils.ApplyAllTransformsToData(env, _chain, input);
        }
 
        /// <summary>
        /// Apply this transform model to the given input transform model to produce a composite transform model.
        /// </summary>
        internal override TransformModel Apply(IHostEnvironment env, TransformModel input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
 
            IDataView view;
            DataViewSchema schemaRoot = input.InputSchema;
            var mod = input as TransformModelImpl;
            if (mod != null)
                view = ApplyTransformUtils.ApplyAllTransformsToData(env, _chain, mod._chain);
            else
            {
                view = new EmptyDataView(env, schemaRoot);
                view = input.Apply(env, view);
                view = Apply(env, view);
            }
 
            return new TransformModelImpl(env, schemaRoot, view);
        }
 
        /// <summary>
        /// Save this transform model.
        /// </summary>
        internal override void Save(IHostEnvironment env, Stream stream)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(stream, nameof(stream));
 
            using (var ch = env.Start("Saving transform model"))
            {
                using (var rep = RepositoryWriter.CreateNew(stream, ch))
                {
                    ch.Trace("Saving root schema and transformations");
                    TrainUtils.SaveDataPipe(env, rep, _chain, blankLoader: true);
                    rep.Commit();
                }
            }
        }
 
        internal override IRowToRowMapper AsRowToRowMapper(IExceptionContext ectx)
        {
            return
                CompositeRowToRowMapper.IsCompositeRowToRowMapper(_chain)
                    ? new CompositeRowToRowMapper(ectx, _chain, _schemaRoot)
                    : null;
        }
 
        private sealed class CompositeRowToRowMapper : IRowToRowMapper
        {
            private readonly IDataView _chain;
            private readonly DataViewSchema _rootSchema;
            private readonly IExceptionContext _ectx;
 
            public DataViewSchema Schema => _chain.Schema;
 
            public DataViewSchema OutputSchema => Schema;
 
            public CompositeRowToRowMapper(IExceptionContext ectx, IDataView chain, DataViewSchema rootSchema)
            {
                Contracts.CheckValue(ectx, nameof(ectx));
                _ectx = ectx;
                _ectx.CheckValue(chain, nameof(chain));
                _ectx.CheckValue(rootSchema, nameof(rootSchema));
 
                _chain = chain;
                _rootSchema = rootSchema;
            }
 
            public static bool IsCompositeRowToRowMapper(IDataView chain)
            {
                var transform = chain as IDataTransform;
                while (transform != null)
                {
                    if (!(transform is IRowToRowMapper))
                        return false;
                    transform = transform.Source as IDataTransform;
                }
                return true;
            }
 
            /// <summary>
            /// Given a set of columns, return the input columns that are needed to generate those output columns.
            /// </summary>
            IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
            {
                _ectx.Assert(IsCompositeRowToRowMapper(_chain));
 
                var transform = _chain as IDataTransform;
                var cols = dependingColumns;
                while (transform != null)
                {
                    var mapper = transform as IRowToRowMapper;
                    _ectx.AssertValue(mapper);
                    cols = mapper.GetDependencies(cols);
                    transform = transform.Source as IDataTransform;
                }
                return cols;
            }
 
            public DataViewSchema InputSchema => _rootSchema;
 
            DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
            {
                _ectx.Assert(IsCompositeRowToRowMapper(_chain));
                _ectx.AssertValue(input);
                _ectx.AssertValue(activeColumns);
 
                _ectx.Check(input.Schema == InputSchema, "Schema of input row must be the same as the schema the mapper is bound to");
 
                var mappers = new List<IRowToRowMapper>();
                var actives = new List<IEnumerable<DataViewSchema.Column>>();
                var transform = _chain as IDataTransform;
                var activeCur = activeColumns;
                while (transform != null)
                {
                    var mapper = transform as IRowToRowMapper;
                    _ectx.AssertValue(mapper);
                    mappers.Add(mapper);
                    actives.Add(activeCur);
                    activeCur = mapper.GetDependencies(activeCur);
                    transform = transform.Source as IDataTransform;
                }
 
                mappers.Reverse();
                actives.Reverse();
                var row = input;
                for (int i = 0; i < mappers.Count; i++)
                    row = mappers[i].GetRow(row, actives[i]);
 
                return row;
            }
        }
    }
}