// File: DataLoadSave\TransformWrapper.cs
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Data.DataLoadSave;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
    /// <summary>
    /// This is a shim class that presents the legacy <see cref="IDataTransform"/> interface as an <see cref="ITransformer"/>.
    /// Note that there are some important differences in usage that make this shim somewhat non-seamless, so the goal
    /// is to remove it gradually as <see cref="IDataTransform"/>-based code is retired.
    /// </summary>
    internal sealed class TransformWrapper : ITransformer
    {
        internal const string LoaderSignature = "TransformWrapper";

        private readonly IHost _host;

        // The wrapped transform. The constructor guarantees it is an IDataTransform, so it is stored
        // with that static type instead of being re-cast at every use site.
        private readonly IDataTransform _xf;

        /// <summary>
        /// Wraps <paramref name="xf"/> so that it can be used as an <see cref="ITransformer"/>.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="xf">The data view to wrap. Must be non-null and implement <see cref="IDataTransform"/>.</param>
        public TransformWrapper(IHostEnvironment env, IDataView xf)
        {
            Contracts.CheckValue(env, nameof(env));
            // Validate null before type: a null xf would otherwise fail the type test with an uninformative error.
            Contracts.CheckValue(xf, nameof(xf));
            Contracts.CheckParam(xf is IDataTransform, nameof(xf), "Expected an IDataTransform.");
            _host = env.Register(nameof(TransformWrapper));
            _xf = (IDataTransform)xf;
        }

        /// <summary>
        /// Computes the output schema by re-applying the wrapped transform to an empty data view
        /// that has <paramref name="inputSchema"/>.
        /// </summary>
        /// <param name="inputSchema">The schema of the data the transform would be applied to.</param>
        /// <returns>The schema of the transformed empty view.</returns>
        public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var dv = new EmptyDataView(_host, inputSchema);
            var output = ApplyTransformUtils.ApplyTransformToData(_host, _xf, dv);
            return output.Schema;
        }

        // Saving is intentionally unsupported for this shim.
        void ICanSaveModel.Save(ModelSaveContext ctx) => throw _host.Except("Saving is not permitted.");

        /// <summary>
        /// Re-applies the wrapped transform on top of <paramref name="input"/>.
        /// </summary>
        public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyTransformToData(_host, _xf, input);

        // Reports whether the wrapped transform instance itself implements IRowToRowMapper.
        bool ITransformer.IsRowToRowMapper => _xf is IRowToRowMapper;

        IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var applied = ApplyTransformUtils.ApplyTransformToData(_host, _xf, new EmptyDataView(_host, inputSchema));
            // Test the actual result of re-application. The previous `as` + `is` pair only detected null,
            // and it threw without a message.
            if (!(applied is IRowToRowMapper mapper))
                throw _host.Except("The wrapped transform does not implement IRowToRowMapper.");
            return new CompositeRowToRowMapper(inputSchema, new[] { mapper });
        }
    }
    /// <summary>
    /// Base estimator for trainable transforms that are exposed through <see cref="TransformWrapper"/>.
    /// Derived classes implement <see cref="Fit(IDataView)"/>, and this base class derives the output
    /// schema shape from it.
    /// </summary>
    internal abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper>
    {
        // Host used for argument validation and for building empty data views.
        private protected readonly IHost Host;

        /// <param name="host">Host for this estimator. Must not be null.</param>
        private protected TrainedWrapperEstimatorBase(IHost host)
        {
            Contracts.CheckValue(host, nameof(host));
            Host = host;
        }

        /// <summary>
        /// Trains on <paramref name="input"/> and returns the fitted transform as a <see cref="TransformWrapper"/>.
        /// </summary>
        public abstract TransformWrapper Fit(IDataView input);

        /// <summary>
        /// Infers the output schema shape. It builds a placeholder schema from <paramref name="inputSchema"/>
        /// via <c>FakeSchemaFactory</c>, fits on an empty view with that schema, and then asks the fitted
        /// transformer for its output schema.
        /// </summary>
        /// <param name="inputSchema">The shape of the input data. Must not be null.</param>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            var placeholderSchema = FakeSchemaFactory.Create(inputSchema);
            var emptyInput = new EmptyDataView(Host, placeholderSchema);
            TransformWrapper fitted = Fit(emptyInput);
            var resultSchema = fitted.GetOutputSchema(placeholderSchema);
            return SchemaShape.Create(resultSchema);
        }
    }