File: Utilities\ApplyTransformUtils.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Utilities to rebind data transforms
    /// </summary>
    [BestFriend]
    internal static class ApplyTransformUtils
    {
        /// <summary>
        /// Attempt to apply the data transform to a different data view source.
        /// If the transform in question implements <see cref="ITransformTemplate"/>, <see cref="ITransformTemplate.ApplyToData"/>
        /// is called. Otherwise, the transform is serialized into a byte array and then deserialized.
        /// </summary>
        /// <param name="env">The host to use</param>
        /// <param name="transform">The transform to apply.</param>
        /// <param name="newSource">The data view to apply the transform to.</param>
        /// <returns>The resulting data view.</returns>
        public static IDataTransform ApplyTransformToData(IHostEnvironment env, IDataTransform transform, IDataView newSource)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(transform, nameof(transform));
            env.CheckValue(newSource, nameof(newSource));
            var rebindable = transform as ITransformTemplate;
            if (rebindable != null)
                return rebindable.ApplyToData(env, newSource);
 
            // Revert to serialization.
            using (var stream = new MemoryStream())
            {
                using (var rep = RepositoryWriter.CreateNew(stream, env))
                {
                    ModelSaveContext.SaveModel(rep, transform, "model");
                    rep.Commit();
                }
 
                stream.Position = 0;
                using (var rep = RepositoryReader.Open(stream, env))
                {
                    IDataTransform newData;
                    ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(env,
                        out newData, rep, "model", newSource);
                    return newData;
                }
            }
        }
 
        /// <summary>
        /// Walks back the Source chain of the <see cref="IDataTransform"/> up to the <paramref name="oldSource"/>
        /// (or <see cref="ILegacyDataLoader"/> if <paramref name="oldSource"/> is <c>null</c>),
        /// and reapplies all transforms in the chain, to produce the same chain but bound to the different data.
        /// It is valid to have no transforms: in this case the result will be equal to <paramref name="newSource"/>
        /// If <paramref name="oldSource"/> is specified and not found in the pipe, an exception is thrown.
        /// </summary>
        /// <param name="env">The environment to use.</param>
        /// <param name="chain">The end of the chain.</param>
        /// <param name="newSource">The new data to attach the chain to.</param>
        /// <param name="oldSource">The 'old source' of the pipe, that doesn't need to be reapplied. If null, all transforms are reapplied.</param>
        /// <returns>The resulting data view.</returns>
        public static IDataView ApplyAllTransformsToData(IHostEnvironment env, IDataView chain, IDataView newSource, IDataView oldSource = null)
        {
            // REVIEW: have a variation that would selectively apply transforms?
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(chain, nameof(chain));
            env.CheckValue(newSource, nameof(newSource));
            env.CheckValueOrNull(oldSource);
 
            // Backtrack the chain until we reach a chain start or a non-transform.
            // REVIEW: we 'unwrap' the composite data loader here and step through its pipeline.
            // It's probably more robust to make CompositeDataLoader not even be an IDataView, this
            // would force the user to do the right thing and unwrap on his end.
            var cdl = chain as LegacyCompositeDataLoader;
            if (cdl != null)
                chain = cdl.View;
 
            var transforms = new List<IDataTransform>();
            IDataTransform xf;
            while ((xf = chain as IDataTransform) != null)
            {
                if (chain == oldSource)
                    break;
                transforms.Add(xf);
                chain = xf.Source;
 
                cdl = chain as LegacyCompositeDataLoader;
                if (cdl != null)
                    chain = cdl.View;
            }
            transforms.Reverse();
 
            env.Check(oldSource == null || chain == oldSource, "Source data not found in the chain");
 
            IDataView newChain = newSource;
            foreach (var transform in transforms)
                newChain = ApplyTransformToData(env, transform, newChain);
 
            return newChain;
        }
    }
}