// File: LambdaTransform.cs
// Web Access
// Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using System.Reflection;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
[assembly: LoadableClass(typeof(ITransformer), typeof(LambdaTransform), null, typeof(SignatureLoadModel), "", LambdaTransform.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    using Conditional = System.Diagnostics.ConditionalAttribute;
 
    /// <summary>
    /// Utility class for creating transforms easily. It also contains the save/load plumbing for
    /// custom transformers that are persisted under the <see cref="LoaderSignature"/> loader signature.
    /// </summary>
    [BestFriend]
    internal static class LambdaTransform
    {
        /// <summary>
        /// Loader signature used both in the assembly-level <c>LoadableClass</c> registration and in
        /// the <see cref="VersionInfo"/> written by <see cref="SaveCustomTransformer"/>.
        /// </summary>
        internal const string LoaderSignature = "CustomTransformer";

        /// <summary>
        /// Builds the version descriptor shared by <see cref="SaveCustomTransformer"/> and <see cref="Create"/>.
        /// </summary>
        /// <returns>The <see cref="VersionInfo"/> for the "CUSTOMXF" model signature.</returns>
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "CUSTOMXF",
                //verWrittenCur: 0x00010001,  // Initial
                verWrittenCur: 0x00010002,  // Added name of assembly in which the contractName is present
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(LambdaTransform).Assembly.FullName);
        }

        /// <summary>
        /// First written version whose payload contains the contract assembly name after the contract name.
        /// <see cref="Create"/> only reads the assembly name when the model was written with at least this version.
        /// </summary>
        private const uint VerAssemblyNameSaved = 0x00010002;

        /// <summary>
        /// Writes the data needed to re-create a custom transformer: the version info, then the contract
        /// name, then the name of the assembly that contains the contract.
        /// </summary>
        /// <param name="ectx">The exception context used for argument validation.</param>
        /// <param name="ctx">The model save context to write to. Must not be null.</param>
        /// <param name="contractName">The contract name of the custom mapping factory. Must not be null.</param>
        /// <param name="contractAssembly">The name of the assembly in which <paramref name="contractName"/>
        /// is present. Must not be null, since <see cref="Create"/> passes the stored value to
        /// <see cref="Assembly.Load(string)"/>.</param>
        internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveContext ctx, string contractName, string contractAssembly)
        {
            // Validate every argument before touching the context, so that a bad argument is reported
            // under its own parameter name and nothing has been written to ctx when the check fails.
            ectx.CheckValue(ctx, nameof(ctx));
            ectx.CheckValue(contractName, nameof(contractName));
            ectx.CheckValue(contractAssembly, nameof(contractAssembly));

            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // The order of these two writes must match the order of the reads in Create.
            ctx.SaveString(contractName);
            ctx.SaveString(contractAssembly);
        }

        /// <summary>
        /// Factory for SignatureLoadModel. Reads the contract name (and, for models written with
        /// <see cref="VerAssemblyNameSaved"/> or later, the contract assembly name), looks up the custom
        /// mapping factory registered under that contract and asks it to create the transformer.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="ctx">The model load context positioned at a model written by
        /// <see cref="SaveCustomTransformer"/>. Must not be null.</param>
        /// <returns>The transformer created by the factory found for the stored contract name.</returns>
        private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var contractName = ctx.LoadString();
            if (ctx.Header.ModelVerWritten >= VerAssemblyNameSaved)
            {
                // NOTE(review): the assembly name comes from the model file itself, so loading a model
                // loads whatever assembly that file names. Models should only be loaded from trusted sources.
                var contractAssembly = ctx.LoadString();
                Assembly assembly = Assembly.Load(contractAssembly);

                // Register the assembly with the catalog before the extension lookup below.
                env.ComponentCatalog.RegisterAssembly(assembly);
            }

            object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName);
            if (!(factoryObject is ICustomMappingFactory mappingFactory))
            {
                throw env.Except($"The class with contract '{contractName}' must derive from '{typeof(CustomMappingFactory<,>).FullName}' or from '{typeof(StatefulCustomMappingFactory<,,>).FullName}'.");
            }

            return mappingFactory.CreateTransformer(env, contractName);
        }

        /// <summary>
        /// This is a 'stateful non-savable' version of the map transform: the mapping function is guaranteed to be invoked once per
        /// every row of the data set, in sequence; one user-defined state object will be allocated per cursor and passed to the
        /// map function every time. If <typeparamref name="TSrc"/>, <typeparamref name="TDst"/>, or
        /// <typeparamref name="TState"/> implement the <see cref="IDisposable" /> interface, they will be disposed after use.
        /// </summary>
        /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the
        /// input <see cref="IDataView"/>.</typeparam>
        /// <typeparam name="TState">The type of the state object to allocate per cursor.</typeparam>
        /// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
        /// <param name="env">The host environment to use.</param>
        /// <param name="source">The input data to apply transformation to.</param>
        /// <param name="mapAction">The function that performs the transformation. The function should transform its <typeparamref name="TSrc"/>
        /// argument into its <typeparamref name="TDst"/> argument and can utilize the per-cursor <typeparamref name="TState"/> state.</param>
        /// <param name="initStateAction">The function that is called once per cursor to initialize state. Can be null.</param>
        /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TSrc"/> type.</param>
        /// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TDst"/> type.</param>
        /// <returns>A <see cref="StatefulFilterTransform{TSrc, TDst, TState}"/> over <paramref name="source"/>
        /// whose filter function invokes <paramref name="mapAction"/> and then always accepts the row.</returns>
        public static IDataView CreateMap<TSrc, TDst, TState>(IHostEnvironment env, IDataView source,
            Action<TSrc, TDst, TState> mapAction, Action<TState> initStateAction,
            SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
            where TSrc : class, new()
            where TDst : class, new()
            where TState : class, new()
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(source, nameof(source));
            env.CheckValue(mapAction, nameof(mapAction));
            env.CheckValueOrNull(initStateAction);
            env.CheckValueOrNull(inputSchemaDefinition);
            env.CheckValueOrNull(outputSchemaDefinition);

            // A map is expressed as a filter that never declines a row: run the mapping, then accept.
            return new StatefulFilterTransform<TSrc, TDst, TState>(env, source,
                (src, dst, state) =>
                {
                    mapAction(src, dst, state);
                    return true;
                }, initStateAction, inputSchemaDefinition, outputSchemaDefinition);
        }

        /// <summary>
        /// This creates a filter transform that can 'accept' or 'decline' any row of the data based on the contents of the row
        /// or state of the cursor.
        /// This is a 'stateful non-savable' version of the filter: the filter function is guaranteed to be invoked once per
        /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
        /// filter function every time.
        /// If <typeparamref name="TSrc"/> or <typeparamref name="TState"/> implement the <see cref="IDisposable" /> interface, they will be disposed after use.
        /// </summary>
        /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the
        /// input <see cref="IDataView"/>.</typeparam>
        /// <typeparam name="TState">The type of the state object to allocate per cursor.</typeparam>
        /// <param name="env">The host environment to use.</param>
        /// <param name="source">The input data to apply transformation to.</param>
        /// <param name="filterFunc">The user-defined function that determines whether to keep the row or discard it. First parameter
        /// is the current row's contents, the second parameter is the cursor-specific state object.</param>
        /// <param name="initStateAction">The function that is called once per cursor to initialize state. Can be null.</param>
        /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TSrc"/> type.</param>
        /// <returns>A <see cref="StatefulFilterTransform{TSrc, TDst, TState}"/> over <paramref name="source"/>,
        /// created with <see cref="object"/> as the destination type, whose filter function returns the result
        /// of <paramref name="filterFunc"/>.</returns>
        public static IDataView CreateFilter<TSrc, TState>(IHostEnvironment env, IDataView source,
            Func<TSrc, TState, bool> filterFunc, Action<TState> initStateAction, SchemaDefinition inputSchemaDefinition = null)
            where TSrc : class, new()
            where TState : class, new()
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(source, nameof(source));
            env.CheckValue(filterFunc, nameof(filterFunc));
            env.CheckValueOrNull(initStateAction);
            env.CheckValueOrNull(inputSchemaDefinition);

            // The destination argument (dst) is ignored: only the source row and the state reach filterFunc.
            return new StatefulFilterTransform<TSrc, object, TState>(env, source,
                (src, dst, state) => filterFunc(src, state), initStateAction, inputSchemaDefinition);
        }
    }
}