File: KeyToVectorMapping.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Registration for SignatureDataTransform: builds the transform from command-line style Options.
// Registered under UserName plus the aliases "KeyToBinary" and "ToBinaryVector".
[assembly: LoadableClass(KeyToBinaryVectorMappingTransformer.Summary, typeof(IDataTransform), typeof(KeyToBinaryVectorMappingTransformer), typeof(KeyToBinaryVectorMappingTransformer.Options), typeof(SignatureDataTransform),
    "Key To Binary Vector Transform", KeyToBinaryVectorMappingTransformer.UserName, "KeyToBinary", "ToBinaryVector", DocName = "transform/KeyToBinaryVectorTransform.md")]

// Registration for SignatureLoadDataTransform: reloads the transform as an IDataTransform from a saved model.
[assembly: LoadableClass(KeyToBinaryVectorMappingTransformer.Summary, typeof(IDataTransform), typeof(KeyToBinaryVectorMappingTransformer), null, typeof(SignatureLoadDataTransform),
    "Key To Binary Vector Transform", KeyToBinaryVectorMappingTransformer.LoaderSignature)]

// Registration for SignatureLoadModel: reloads the transformer itself from a saved model.
[assembly: LoadableClass(KeyToBinaryVectorMappingTransformer.Summary, typeof(KeyToBinaryVectorMappingTransformer), null, typeof(SignatureLoadModel),
    KeyToBinaryVectorMappingTransformer.UserName, KeyToBinaryVectorMappingTransformer.LoaderSignature)]

// Registration for SignatureLoadRowMapper: reloads the transformer's row mapper from a saved model.
[assembly: LoadableClass(typeof(IRowMapper), typeof(KeyToBinaryVectorMappingTransformer), null, typeof(SignatureLoadRowMapper),
   KeyToBinaryVectorMappingTransformer.UserName, KeyToBinaryVectorMappingTransformer.LoaderSignature)]
 
namespace Microsoft.ML.Transforms
{
    /// <summary>
    /// <see cref="ITransformer"/> resulting from fitting a <see cref="KeyToBinaryVectorMappingEstimator"/>.
    /// </summary>
    /// <remarks>
    /// A key value <c>k</c> is written as the binary representation of <c>k - 1</c>, most significant bit first,
    /// using a fixed number of slots per key; the missing key value (0) is written as all ones.
    /// </remarks>
    public sealed class KeyToBinaryVectorMappingTransformer : OneToOneTransformerBase
    {
        /// <summary>
        /// Options used when the transform is created through <see cref="SignatureDataTransform"/>.
        /// </summary>
        internal sealed class Options
        {
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
                Name = "Column", ShortName = "col", SortOrder = 1)]
            public KeyToVectorMappingTransformer.Column[] Columns;
        }

        internal const string Summary = "Converts a key column to a binary encoded vector.";
        internal const string UserName = "KeyToBinaryVectorTransform";
        internal const string LoaderSignature = "KeyToBinaryTransform";

        /// <summary>
        /// Version information written into, and checked when reading back, saved models.
        /// </summary>
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "KEY2BINR",
                verWrittenCur: 0x00000001, // Initial
                verReadableCur: 0x00000001,
                verWeCanReadBack: 0x00000001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(KeyToBinaryVectorMappingTransformer).Assembly.FullName);
        }

        private const string RegistrationName = "KeyToBinary";

        /// <summary>
        /// The names of the output and input column pairs on which the transformation is performed.
        /// </summary>
        internal IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();

        /// <summary>
        /// Returns <see langword="null"/> when the item type of <paramref name="type"/> is a key with a non-zero
        /// cardinality; otherwise returns a description of the expected type, for use in error messages.
        /// </summary>
        private string TestIsKey(DataViewType type)
        {
            if (type.GetItemType().GetKeyCount() > 0)
                return null;
            return "key type of known cardinality";
        }

        /// <summary>
        /// Verifies that input column <paramref name="srcCol"/> of <paramref name="inputSchema"/> can feed column
        /// pair <paramref name="col"/>, throwing a schema-mismatch exception when it is not a suitable key column.
        /// </summary>
        private protected override void CheckInputColumn(DataViewSchema inputSchema, int col, int srcCol)
        {
            var type = inputSchema[srcCol].Type;
            string reason = TestIsKey(type);
            if (reason != null)
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, reason, type.ToString());
        }

        internal KeyToBinaryVectorMappingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
            : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns)
        {
        }

        /// <summary>
        /// Writes the version info followed by the column pairs; no other state is persisted.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));

            // *** Binary format ***
            // <prefix handled in static Create method>
            // <base>
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());
            SaveColumns(ctx);
        }

        // Factory method for SignatureLoadModel.
        private static KeyToBinaryVectorMappingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            return new KeyToBinaryVectorMappingTransformer(host, ctx);
        }

        private KeyToBinaryVectorMappingTransformer(IHost host, ModelLoadContext ctx)
            : base(host, ctx)
        {
        }

        private static IDataTransform Create(IHostEnvironment env, IDataView input, params (string outputColumnName, string inputColumnName)[] columns) =>
            new KeyToBinaryVectorMappingTransformer(env, columns).MakeDataTransform(input);

        // Factory method for SignatureDataTransform.
        private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(options, nameof(options));
            env.CheckValue(input, nameof(input));
            env.CheckValue(options.Columns, nameof(options.Columns));

            // A column definition without an explicit source reads from the column of the same name.
            var cols = new (string outputColumnName, string inputColumnName)[options.Columns.Length];
            for (int i = 0; i < cols.Length; i++)
            {
                var item = options.Columns[i];
                cols[i] = (item.Name, item.Source ?? item.Name);
            }
            return new KeyToBinaryVectorMappingTransformer(env, cols).MakeDataTransform(input);
        }

        // Factory method for SignatureLoadDataTransform.
        private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
            => Create(env, ctx).MakeDataTransform(input);

        // Factory method for SignatureLoadRowMapper.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
            => Create(env, ctx).MakeRowMapper(inputSchema);

        private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

        /// <summary>
        /// Row mapper that computes the binary-encoded output columns for a particular input schema.
        /// </summary>
        private sealed class Mapper : OneToOneMapperBase
        {
            /// <summary>
            /// Output name, input name and input type of one column pair.
            /// </summary>
            private sealed class ColInfo
            {
                public readonly string Name;
                public readonly string InputColumnName;
                public readonly DataViewType TypeSrc;

                public ColInfo(string name, string inputColumnName, DataViewType type)
                {
                    Name = name;
                    InputColumnName = inputColumnName;
                    TypeSrc = type;
                }
            }

            private readonly KeyToBinaryVectorMappingTransformer _parent;
            private readonly ColInfo[] _infos;
            // Output type of each column pair.
            private readonly VectorDataViewType[] _types;
            // Number of output slots used to encode one key value of each column pair.
            private readonly int[] _bitsPerKey;

            public Mapper(KeyToBinaryVectorMappingTransformer parent, DataViewSchema inputSchema)
                : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _parent = parent;
                _infos = CreateInfos(inputSchema);
                _types = new VectorDataViewType[_parent.ColumnPairs.Length];
                _bitsPerKey = new int[_parent.ColumnPairs.Length];
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    // IbitHigh(count) + 1 bits (IbitHigh presumably being the zero-based index of the highest set
                    // bit) hold every encoded valid value k - 1 < count. The additional bit keeps the top slot 0
                    // for valid values, so the all-ones pattern used for missing values cannot collide with them.
                    _bitsPerKey[i] = Utils.IbitHigh((uint)_infos[i].TypeSrc.GetItemType().GetKeyCount()) + 2;
                    Host.Assert(0 < _bitsPerKey[i] && _bitsPerKey[i] <= sizeof(ulong) * 8);
                    int srcValueCount = _infos[i].TypeSrc.GetValueCount();
                    if (srcValueCount == 1)
                        // Output is a single vector computed as the sum of the output indicator vectors.
                        _types[i] = new VectorDataViewType(NumberDataViewType.Single, _bitsPerKey[i]);
                    else
                        // Output is the concatenation of the multiple output indicator vectors.
                        _types[i] = new VectorDataViewType(NumberDataViewType.Single, srcValueCount, _bitsPerKey[i]);
                }
            }

            /// <summary>
            /// Resolves every input column in <paramref name="inputSchema"/>, validates it, and captures its type.
            /// Throws a schema-mismatch exception when an input column is missing or not a suitable key.
            /// </summary>
            private ColInfo[] CreateInfos(DataViewSchema inputSchema)
            {
                Host.AssertValue(inputSchema);
                var infos = new ColInfo[_parent.ColumnPairs.Length];
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colSrc))
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName);
                    var type = inputSchema[colSrc].Type;
                    _parent.CheckInputColumn(inputSchema, i, colSrc);
                    infos[i] = new ColInfo(_parent.ColumnPairs[i].outputColumnName, _parent.ColumnPairs[i].inputColumnName, type);
                }
                return infos;
            }

            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
                for (int i = 0; i < _parent.ColumnPairs.Length; i++)
                {
                    var builder = new DataViewSchema.Annotations.Builder();
                    AddMetadata(i, builder);

                    result[i] = new DataViewSchema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.ToAnnotations());
                }
                return result;
            }

            /// <summary>
            /// Adds the output annotations of column pair <paramref name="iinfo"/>: slot names when the source key
            /// carries text key values whose count equals the key cardinality, and, for scalar sources only, the
            /// <see cref="AnnotationUtils.Kinds.IsNormalized"/> flag.
            /// </summary>
            private void AddMetadata(int iinfo, DataViewSchema.Annotations.Builder builder)
            {
                InputSchema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol);
                var inputMetadata = InputSchema[srcCol].Annotations;
                var srcType = _infos[iinfo].TypeSrc;

                // See if the source has usable key names.
                VectorDataViewType typeNames = null;
                if (inputMetadata.Schema.TryGetColumnIndex(AnnotationUtils.Kinds.KeyValues, out int metaKeyValuesCol))
                    typeNames = inputMetadata.Schema[metaKeyValuesCol].Type as VectorDataViewType;
                // Compare as ulong: a key cardinality above int.MaxValue simply fails to match instead of being
                // forced through an int conversion.
                if (typeNames == null || !typeNames.IsKnownSize || !(typeNames.ItemType is TextDataViewType) ||
                    (ulong)typeNames.Size != srcType.GetItemType().GetKeyCount())
                {
                    typeNames = null;
                }

                if (srcType is PrimitiveDataViewType)
                {
                    if (typeNames != null)
                    {
                        ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                        {
                            GenerateBitSlotName(iinfo, ref dst);
                        };
                        // For a scalar source the output has exactly one slot per bit.
                        builder.AddSlotNames(_types[iinfo].Size, getter);
                    }

                    ValueGetter<bool> normalizeGetter = (ref bool dst) =>
                    {
                        dst = true;
                    };
                    builder.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, normalizeGetter);
                }
                else if (typeNames != null && _types[iinfo].IsKnownSize)
                {
                    ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                    {
                        GetSlotNames(iinfo, ref dst);
                    };
                    var slotNamesType = new VectorDataViewType(TextDataViewType.Instance, _types[iinfo].Dimensions);
                    builder.Add(AnnotationUtils.Kinds.SlotNames, slotNamesType, getter);
                }
            }

            /// <summary>
            /// Fills <paramref name="dst"/> with one name per bit of a single key encoding, most significant first:
            /// <c>Bit{n-1}</c>, ..., <c>Bit0</c>, where <c>n</c> is the number of bits per key of
            /// column pair <paramref name="iinfo"/>.
            /// </summary>
            private void GenerateBitSlotName(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
            {
                const string slotNamePrefix = "Bit";
                var bldr = new BufferBuilder<ReadOnlyMemory<char>>(TextCombiner.Instance);
                bldr.Reset(_bitsPerKey[iinfo], true);
                for (int i = 0; i < _bitsPerKey[iinfo]; i++)
                    bldr.AddFeature(i, (slotNamePrefix + (_bitsPerKey[iinfo] - i - 1)).AsMemory());

                bldr.GetResult(ref dst);
            }

            /// <summary>
            /// Fills <paramref name="dst"/> with slot names for a vector source: for every source slot, one name per
            /// bit of the form <c>{slot}.Bit{n}</c>, where <c>{slot}</c> is the source slot name, or <c>[i]</c>
            /// when the source has no name for slot <c>i</c>.
            /// </summary>
            private void GetSlotNames(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
            {
                Host.Assert(0 <= iinfo && iinfo < _infos.Length);
                Host.Assert(_types[iinfo].IsKnownSize);

                // Variable size should have thrown (by the caller).
                var typeSrc = _infos[iinfo].TypeSrc;
                var srcVectorSize = typeSrc.GetVectorSize();
                Host.Assert(srcVectorSize > 1);

                // Get the source slot names, defaulting to empty text.
                var namesSlotSrc = default(VBuffer<ReadOnlyMemory<char>>);

                var inputMetadata = InputSchema[_infos[iinfo].InputColumnName].Annotations;
                VectorDataViewType typeSlotSrc = null;
                if (inputMetadata != null)
                    typeSlotSrc = inputMetadata.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
                if (typeSlotSrc != null && typeSlotSrc.Size == srcVectorSize && typeSlotSrc.ItemType is TextDataViewType)
                {
                    inputMetadata.GetValue(AnnotationUtils.Kinds.SlotNames, ref namesSlotSrc);
                    Host.Check(namesSlotSrc.Length == srcVectorSize);
                }
                else
                    namesSlotSrc = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(srcVectorSize);

                int slotLim = _types[iinfo].Size;
                Host.Assert(slotLim == (long)srcVectorSize * _bitsPerKey[iinfo]);

                var editor = VBufferEditor.Create(ref dst, slotLim);

                var sb = new StringBuilder();
                int slot = 0;
                VBuffer<ReadOnlyMemory<char>> bits = default;
                GenerateBitSlotName(iinfo, ref bits);
                foreach (var kvpSlot in namesSlotSrc.Items(all: true))
                {
                    Contracts.Assert(slot == (long)kvpSlot.Key * _bitsPerKey[iinfo]);
                    sb.Clear();
                    if (!kvpSlot.Value.IsEmpty)
                        sb.AppendMemory(kvpSlot.Value);
                    else
                        sb.Append('[').Append(kvpSlot.Key).Append(']');
                    sb.Append('.');

                    // Reuse the "{slot}." prefix for every bit name of this source slot.
                    int len = sb.Length;
                    foreach (var key in bits.GetValues())
                    {
                        sb.Length = len;
                        sb.AppendMemory(key);
                        editor.Values[slot++] = sb.ToString().AsMemory();
                    }
                }
                Host.Assert(slot == slotLim);

                dst = editor.Commit();
            }

            /// <summary>
            /// Returns the getter for column pair <paramref name="iinfo"/>: the scalar getter for a non-vector
            /// source, otherwise the concatenating vector getter.
            /// </summary>
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _infos.Length);
                disposer = null;

                var info = _infos[iinfo];
                if (!(info.TypeSrc is VectorDataViewType vectorType))
                    return MakeGetterOne(input, iinfo);
                return MakeGetterInd(input, iinfo, vectorType);
            }

            /// <summary>
            /// This is for the scalar case: the output holds the encoding of the single key value.
            /// </summary>
            private ValueGetter<VBuffer<float>> MakeGetterOne(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);
                Host.Assert(_infos[iinfo].TypeSrc is KeyDataViewType);

                int bitsPerKey = _bitsPerKey[iinfo];
                Host.Assert(bitsPerKey == _types[iinfo].Size);

                int dstLength = _types[iinfo].Size;
                Host.Assert(dstLength > 0);
                input.Schema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol);
                Host.Assert(srcCol >= 0);
                var getSrc = RowCursorUtils.GetGetterAs<uint>(NumberDataViewType.UInt32, input, srcCol);
                var src = default(uint);
                var bldr = new BufferBuilder<float>(R4Adder.Instance);
                return
                    (ref VBuffer<float> dst) =>
                    {
                        getSrc(ref src);
                        bldr.Reset(bitsPerKey, false);
                        EncodeValueToBinary(bldr, src, bitsPerKey, 0);
                        bldr.GetResult(ref dst);

                        Contracts.Assert(dst.Length == bitsPerKey);
                    };
            }

            /// <summary>
            /// This is for the indicator case - vector input and outputs should be concatenated.
            /// Source slot <c>i</c> is encoded into output slots <c>[i * bitsPerKey, (i + 1) * bitsPerKey)</c>.
            /// </summary>
            private ValueGetter<VBuffer<float>> MakeGetterInd(DataViewRow input, int iinfo, VectorDataViewType typeSrc)
            {
                Host.AssertValue(input);
                Host.AssertValue(typeSrc);
                Host.Assert(typeSrc.ItemType is KeyDataViewType);

                int cv = typeSrc.Size;
                Host.Assert(cv >= 0);
                input.Schema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol);
                Host.Assert(srcCol >= 0);
                var getSrc = RowCursorUtils.GetVecGetterAs<uint>(NumberDataViewType.UInt32, input, srcCol);
                var src = default(VBuffer<uint>);
                var bldr = new BufferBuilder<float>(R4Adder.Instance);
                int bitsPerKey = _bitsPerKey[iinfo];
                return
                    (ref VBuffer<float> dst) =>
                    {
                        getSrc(ref src);
                        // A known-size source must keep its declared length; cv == 0 means variable size.
                        Host.Check(src.Length == cv || cv == 0);
                        bldr.Reset(src.Length * bitsPerKey, false);

                        // Presumably DenseValues yields 0 for implicit sparse slots, which then encode as missing.
                        int index = 0;
                        foreach (uint value in src.DenseValues())
                        {
                            EncodeValueToBinary(bldr, value, bitsPerKey, index * bitsPerKey);
                            index++;
                        }

                        bldr.GetResult(ref dst);

                        Contracts.Assert(dst.Length == src.Length * bitsPerKey);
                    };
            }

            /// <summary>
            /// Writes the encoding of key <paramref name="value"/> into <paramref name="bitsToConsider"/> consecutive
            /// slots of <paramref name="bldr"/> starting at <paramref name="startIndex"/>, most significant bit first.
            /// A valid key <c>k</c> is encoded as <c>k - 1</c>; the missing key (0) is encoded as all ones.
            /// </summary>
            private void EncodeValueToBinary(BufferBuilder<float> bldr, uint value, int bitsToConsider, int startIndex)
            {
                Contracts.Assert(0 < bitsToConsider && bitsToConsider <= sizeof(ulong) * 8);
                Contracts.Assert(startIndex >= 0);

                // Work in 64 bits: a key count of 2^31 or more needs 33 bits per key, and in C# shifting a uint by 32
                // is a shift by 0, which would corrupt the most significant slot. Missing (0) maps to all ones at
                // every width; for widths up to 32 bits the output is unchanged from a 32-bit encoding.
                ulong encoded = value == 0 ? ulong.MaxValue : value - 1UL;
                while (bitsToConsider > 0)
                    bldr.AddFeature(startIndex++, (encoded >> --bitsToConsider) & 1UL);
            }
        }
    }
 
    /// <summary>
    /// Estimator for <see cref="KeyToBinaryVectorMappingTransformer"/>. Converts key types to their corresponding binary representation of the original value.
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    ///
    /// ###  Estimator Characteristics
    /// |  |  |
    /// | -- | -- |
    /// | Does this estimator need to look at the data to train its parameters? | No |
    /// | Input column data type | [key](xref:Microsoft.ML.Data.KeyDataViewType) or a known-size vector of keys. |
    /// | Output column data type | A known-size vector of [System.Single](xref:System.Single). |
    /// | Exportable to ONNX | No |
    ///
    /// Check the See Also section for links to usage examples.
    /// ]]></format>
    /// </remarks>
    /// <seealso cref="ConversionsCatalog.MapKeyToBinaryVector(TransformsCatalog.ConversionTransforms, string, string)"/>
    public sealed class KeyToBinaryVectorMappingEstimator : TrivialEstimator<KeyToBinaryVectorMappingTransformer>
    {
        /// <summary>
        /// Creates an estimator that binary-encodes every (output, input) column pair in <paramref name="columns"/>.
        /// </summary>
        internal KeyToBinaryVectorMappingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
            : this(env, new KeyToBinaryVectorMappingTransformer(env, columns))
        {
        }

        /// <summary>
        /// Creates an estimator for a single column; when <paramref name="inputColumnName"/> is
        /// <see langword="null"/>, the input column has the same name as the output column.
        /// </summary>
        internal KeyToBinaryVectorMappingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null)
            : this(env, new KeyToBinaryVectorMappingTransformer(env, (outputColumnName, inputColumnName ?? outputColumnName)))
        {
        }

        private KeyToBinaryVectorMappingEstimator(IHostEnvironment env, KeyToBinaryVectorMappingTransformer transformer)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToBinaryVectorMappingEstimator)), transformer)
        {
        }

        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// Thrown (as a schema mismatch) when an input column is missing or is not of key type.
        /// </exception>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.ToDictionary(x => x.Name);
            foreach (var colInfo in Transformer.Columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.inputColumnName, out var col))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName);
                // The transformer accepts only keys, scalar or vector. ItemType describes the element type (the
                // vector shape is carried by Kind), so key-ness must be checked through IsKey.
                if (!col.IsKey)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName, "key type", col.ItemType.ToString());

                var metadata = new List<SchemaShape.Column>();
                // Slot names are produced when the key carries text key values and the output size is known.
                if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var keyMeta))
                    if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && keyMeta.ItemType is TextDataViewType)
                        metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false));
                // Only a scalar key source produces the IsNormalized annotation.
                if (col.Kind == SchemaShape.Column.VectorKind.Scalar)
                    metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
                result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName,
                    col.Kind == SchemaShape.Column.VectorKind.VariableVector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Vector,
                    NumberDataViewType.Single, false, new SchemaShape(metadata));
            }

            return new SchemaShape(result.Values);
        }
    }
 
}