File: ProduceIdTransform.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Registers ProduceIdTransform for construction from an Arguments object under the
// SignatureDataTransform signature, loadable by the names "ProduceIdTransform" and "ProduceId".
// This path corresponds to the public ProduceIdTransform(IHostEnvironment, Arguments, IDataView) constructor.
[assembly: LoadableClass(ProduceIdTransform.Summary, typeof(ProduceIdTransform), typeof(ProduceIdTransform.Arguments), typeof(SignatureDataTransform),
    "", "ProduceIdTransform", "ProduceId")]

// Registers ProduceIdTransform for deserialization from a saved model under the
// SignatureLoadDataTransform signature, keyed by LoaderSignature (no Arguments type).
// This path corresponds to the static ProduceIdTransform.Create(IHostEnvironment, ModelLoadContext, IDataView).
[assembly: LoadableClass(ProduceIdTransform.Summary, typeof(ProduceIdTransform), null, typeof(SignatureLoadDataTransform),
    "Produce ID Transform", ProduceIdTransform.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// Adds a column that holds the row ID reported by the input cursor (column type
    /// <see cref="RowIdDataViewType"/>). This can be useful for diagnostic purposes.
    ///
    /// The IDs come from the input cursor, not from the data values. So if you save data to
    /// some other file and then apply this transform to the reloaded data view, the produced
    /// IDs may differ. Most transforms, by contrast, produce results based on the data alone.
    /// </summary>
    internal sealed class ProduceIdTransform : RowToRowTransformBase
    {
        /// <summary>
        /// User-facing arguments for <see cref="ProduceIdTransform"/>.
        /// </summary>
        public sealed class Arguments
        {
            [Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column to produce", ShortName = "col", SortOrder = 1)]
            public string Column = "Id";
        }

        /// <summary>
        /// Column bindings for this transform. The output schema is the input schema plus exactly
        /// one added column, whose type is <see cref="RowIdDataViewType"/>.
        /// </summary>
        private sealed class Bindings : ColumnBindingsBase
        {
            public Bindings(DataViewSchema input, bool user, string name)
                : base(input, user, name)
            {
                Contracts.Assert(InfoCount == 1);
            }

            protected override DataViewType GetColumnTypeCore(int iinfo)
            {
                Contracts.Assert(iinfo == 0);
                return RowIdDataViewType.Instance;
            }

            /// <summary>
            /// Reads the bindings written by <see cref="Save"/> and rebinds them to <paramref name="input"/>.
            /// </summary>
            public static Bindings Create(ModelLoadContext ctx, DataViewSchema input)
            {
                Contracts.AssertValue(ctx);
                Contracts.AssertValue(input);

                // *** Binary format ***
                // int: id of output column name
                string name = ctx.LoadNonEmptyString();
                return new Bindings(input, true, name);
            }

            /// <summary>
            /// Writes the output column name, which is the only persisted state.
            /// </summary>
            internal void Save(ModelSaveContext ctx)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // int: id of output column name
                ctx.SaveNonEmptyString(GetColumnNameCore(0));
            }

            /// <summary>
            /// Maps a predicate over output column indices to a predicate over input column indices.
            /// Indices outside the input schema are reported as inactive.
            /// </summary>
            public Func<int, bool> GetDependencies(Func<int, bool> predicate)
            {
                Contracts.AssertValue(predicate);

                var active = GetActiveInput(predicate);
                Contracts.Assert(active.Length == Input.Count);
                return col => 0 <= col && col < active.Length && active[col];
            }
        }

        internal const string Summary = "Produces a new column with the row ID.";
        internal const string LoaderSignature = "ProduceIdTransform";

        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "PR ID XF",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(ProduceIdTransform).Assembly.FullName);
        }

        private readonly Bindings _bindings;

        public override DataViewSchema OutputSchema => _bindings.AsSchema;

        // Adding the ID column does not affect shufflability; defer to the source.
        public override bool CanShuffle => Source.CanShuffle;

        /// <summary>
        /// Creates the transform from user arguments.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="args">Arguments; <see cref="Arguments.Column"/> must be non-whitespace.</param>
        /// <param name="input">Input data view.</param>
        public ProduceIdTransform(IHostEnvironment env, Arguments args, IDataView input)
            : base(env, LoaderSignature, input)
        {
            Host.CheckValue(args, nameof(args));
            Host.CheckNonWhiteSpace(args.Column, nameof(args.Column));

            _bindings = new Bindings(input.Schema, true, args.Column);
        }

        private ProduceIdTransform(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            Host.AssertValue(ctx);

            // *** Binary format ***
            // bindings
            _bindings = Bindings.Create(ctx, Source.Schema);
        }

        /// <summary>
        /// Loads the transform from a saved model, after verifying the model's version info.
        /// </summary>
        public static ProduceIdTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(LoaderSignature);
            h.CheckValue(ctx, nameof(ctx));
            h.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());
            return h.Apply("Loading Model", ch => new ProduceIdTransform(h, ctx, input));
        }

        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // bindings
            _bindings.Save(ctx);
        }

        /// <summary>
        /// Shared cursor setup: computes which input columns must be read to serve
        /// <paramref name="columnsNeeded"/>, and whether the added ID column is requested.
        /// </summary>
        /// <param name="columnsNeeded">Requested output columns.</param>
        /// <param name="idActive">Set to true when the added ID column is among the requested columns.</param>
        /// <returns>The input columns the source cursor(s) must make active.</returns>
        private IEnumerable<DataViewSchema.Column> GetInputColumns(IEnumerable<DataViewSchema.Column> columnsNeeded, out bool idActive)
        {
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var inputPred = _bindings.GetDependencies(predicate);
            idActive = predicate(_bindings.MapIinfoToCol(0));
            return Source.Schema.Where(x => inputPred(x.Index));
        }

        protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.AssertValueOrNull(rand);

            var inputCols = GetInputColumns(columnsNeeded, out bool active);
            var input = Source.GetRowCursor(inputCols, rand);
            return new Cursor(Host, _bindings, input, active);
        }

        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);

            var inputCols = GetInputColumns(columnsNeeded, out bool active);
            DataViewRowCursor[] cursors = Source.GetRowCursorSet(inputCols, n, rand);
            // Wrap each source cursor so that it also exposes the ID column.
            for (int c = 0; c < cursors.Length; ++c)
                cursors[c] = new Cursor(Host, _bindings, cursors[c], active);
            return cursors;
        }

        // Returns null, i.e. this transform states no preference about parallel cursors
        // (presumably interpreted by the base class; confirm against RowToRowTransformBase).
        protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
        {
            Host.AssertValue(predicate, nameof(predicate));
            return null;
        }

        /// <summary>
        /// Cursor that passes through the input columns and serves the added column
        /// from the input cursor's ID getter.
        /// </summary>
        private sealed class Cursor : SynchronizedCursorBase
        {
            private readonly Bindings _bindings;
            private readonly bool _active;

            public override DataViewSchema Schema => _bindings.AsSchema;

            public Cursor(IChannelProvider provider, Bindings bindings, DataViewRowCursor input, bool active)
                : base(provider, input)
            {
                Ch.CheckValue(bindings, nameof(bindings));
                _bindings = bindings;
                _active = active;
            }

            /// <summary>
            /// Returns whether the given column is active in this row.
            /// </summary>
            /// <exception>Throws a parameter exception if the column index is outside [0, ColumnCount).</exception>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                // Check both bounds: a negative index would otherwise reach MapColumnIndex unchecked.
                Ch.CheckParam(0 <= column.Index && column.Index < _bindings.ColumnCount, nameof(column));
                int index = _bindings.MapColumnIndex(out bool isSrc, column.Index);
                if (isSrc)
                    return Input.IsColumnActive(Input.Schema[index]);
                Ch.Assert(index == 0);
                return _active;
            }

            /// <summary>
            /// Returns a value getter delegate to fetch the value of the given column from the row.
            /// This throws if the column index is out of range, if the column is not active in this row,
            /// or if the type <typeparamref name="TValue"/> differs from this column's type.
            /// </summary>
            /// <typeparam name="TValue"> is the column's content type.</typeparam>
            /// <param name="column"> is the output column whose getter should be returned.</param>
            public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
            {
                Ch.CheckParam(0 <= column.Index && column.Index < _bindings.ColumnCount, nameof(column));
                Ch.CheckParam(IsColumnActive(column), nameof(column));
                int index = _bindings.MapColumnIndex(out bool isSrc, column.Index);
                if (isSrc)
                    return Input.GetGetter<TValue>(Input.Schema[index]);
                Ch.Assert(index == 0);

                // The added column is backed directly by the input cursor's ID getter.
                Delegate idGetter = Input.GetIdGetter();
                Ch.AssertValue(idGetter);
                if (idGetter is ValueGetter<TValue> fn)
                    return fn;
                throw Ch.Except($"Invalid TValue in GetGetter: '{typeof(TValue)}', " +
                        $"expected type: '{idGetter.GetType().GetGenericArguments().First()}'.");
            }
        }
    }
}