// File: MissingValueIndicatorTransform.cs
// Web Access
// Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Registers MissingValueIndicatorTransform under the SignatureDataTransform signature, with
// MissingValueIndicatorTransform.Arguments as its argument type.
// NOTE(review): the string arguments look like an (empty) user-facing name followed by the load-name
// aliases the transform can be referred to by - confirm against LoadableClassAttribute.
[assembly: LoadableClass(typeof(MissingValueIndicatorTransform), typeof(MissingValueIndicatorTransform.Arguments), typeof(SignatureDataTransform),
    "", "MissingValueIndicatorTransform", "MissingValueTransform", "MissingTransform", "Missing")]

// Registers MissingValueIndicatorTransform under the SignatureLoadDataTransform signature (no argument type),
// using both the current LoaderSignature and the older "MissingFeatureFunction" name.
[assembly: LoadableClass(typeof(MissingValueIndicatorTransform), null, typeof(SignatureLoadDataTransform),
    "Missing Value Indicator Transform", MissingValueIndicatorTransform.LoaderSignature, "MissingFeatureFunction")]
 
namespace Microsoft.ML.Transforms
{
    internal sealed class MissingValueIndicatorTransform : OneToOneTransformBase
    {
        /// <summary>
        /// Per-column specification for <see cref="MissingValueIndicatorTransform"/>. Parsing and
        /// unparsing are delegated to the <see cref="OneToOneColumn"/> base class.
        /// </summary>
        public sealed class Column : OneToOneColumn
        {
            /// <summary>
            /// Builds a <see cref="Column"/> from its textual form.
            /// </summary>
            /// <param name="str">The non-empty column specification to parse.</param>
            /// <returns>The parsed column, or null when the base-class parser rejects <paramref name="str"/>.</returns>
            internal static Column Parse(string str)
            {
                Contracts.AssertNonEmpty(str);

                var column = new Column();
                return column.TryParse(str) ? column : null;
            }

            /// <summary>
            /// Attempts to write the textual form of this column into <paramref name="sb"/>.
            /// </summary>
            /// <param name="sb">The builder that receives the text; must not be null.</param>
            /// <returns>Whatever the base-class unparse routine reports.</returns>
            internal bool TryUnparse(StringBuilder sb)
            {
                Contracts.AssertValue(sb);

                bool succeeded = TryUnparseCore(sb);
                return succeeded;
            }
        }
 
        /// <summary>
        /// Arguments consumed by the public <see cref="MissingValueIndicatorTransform"/> constructor
        /// (the one corresponding to SignatureDataTransform).
        /// </summary>
        public sealed class Arguments
        {
            // The columns to transform. The attribute flags mark this argument as required and as one
            // that may be specified multiple times; "col" is its short name.
            [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", Name = "Column", ShortName = "col", SortOrder = 1)]
            public Column[] Columns;
        }
 
        /// <summary>
        /// The loader signature recorded in the model's version info and used when registering the
        /// transform for loading.
        /// </summary>
        public const string LoaderSignature = "MissingIndicatorFunction";

        /// <summary>
        /// Builds the version information that is written by <c>SaveModel</c> and verified by <c>Create</c>.
        /// </summary>
        /// <remarks>
        /// Version history:
        ///   0x00010001 - Initial
        ///   0x00010002 - Changed to OneToOneColumn (current)
        /// </remarks>
        private static VersionInfo GetVersionInfo()
        {
            string loaderAssembly = typeof(MissingValueIndicatorTransform).Assembly.FullName;

            // "MissingFeatureFunction" is the older loader name; it can be dropped once loading
            // from old code no longer matters.
            return new VersionInfo(
                modelSignature: "MISFEATF",
                verWrittenCur: 0x00010002,
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010002,
                loaderSignature: LoaderSignature,
                loaderSignatureAlt: "MissingFeatureFunction",
                loaderAssemblyName: loaderAssembly);
        }
 
        // Name passed to the base constructor and to env.Register when creating the host.
        private const string RegistrationName = "MissingIndicator";

        // Suffix appended to a source column/slot name to form the name of its indicator slot.
        private const string IndicatorSuffix = "_Indicator";

        // The output column types, parallel to Infos. Filled once by GetTypesAndMetadata in the
        // constructors and returned by GetColumnTypeCore.
        private readonly VectorDataViewType[] _types;
 
        /// <summary>
        /// Creates the transform from user-supplied arguments. This is the public constructor
        /// corresponding to SignatureDataTransform.
        /// </summary>
        /// <param name="env">The host environment, forwarded to the base class.</param>
        /// <param name="args">The transform arguments; a null reference is rejected by Contracts.CheckRef.</param>
        /// <param name="input">The input data view, forwarded to the base class together with TestIsFloatItem.</param>
        public MissingValueIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input)
            : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Columns,
                input, TestIsFloatItem)
        {
            // Sanity checks: there is at least one column info, and exactly one per requested column.
            Host.AssertNonEmpty(Infos);
            Host.Assert(Utils.Size(args.Columns) == Infos.Length);

            _types = GetTypesAndMetadata();
        }
 
        /// <summary>
        /// Loading constructor, invoked by <see cref="Create"/> after the version info and the leading
        /// sizeof(float) field have been read and validated. The remaining state is read by the base
        /// constructor from <paramref name="ctx"/>.
        /// </summary>
        private MissingValueIndicatorTransform(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, ctx, input, TestIsFloatItem)
        {
            Host.AssertValue(ctx);

            // *** Binary format ***
            // <prefix handled in static Create method>
            // <base>
            Host.AssertNonEmpty(Infos);

            // Output types and slot-name metadata are derived from the loaded column infos.
            _types = GetTypesAndMetadata();
        }
 
        /// <summary>
        /// Loads a <see cref="MissingValueIndicatorTransform"/> from a model context and binds it to
        /// <paramref name="input"/>.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="ctx">The model load context; must not be null and must match this transform's version info.</param>
        /// <param name="input">The input data view; must not be null.</param>
        /// <returns>The loaded transform.</returns>
        public static MissingValueIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);
            host.CheckValue(ctx, nameof(ctx));
            host.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            return host.Apply("Loading Model",
                channel =>
                {
                    // *** Binary format ***
                    // int: sizeof(float)
                    // <remainder handled in ctors>
                    int storedFloatSize = ctx.Reader.ReadInt32();
                    channel.CheckDecode(storedFloatSize == sizeof(float));
                    return new MissingValueIndicatorTransform(host, ctx, input);
                });
        }
 
        /// <summary>
        /// Writes this transform to <paramref name="ctx"/>: the version info, then the size of a float,
        /// then the base-class state.
        /// </summary>
        /// <param name="ctx">The model save context; must not be null.</param>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(float)
            // <base>
            const int floatSize = sizeof(float);
            ctx.Writer.Write(floatSize);
            SaveBase(ctx);
        }
 
        /// <summary>
        /// Computes the output type of every column and registers slot-names metadata getters where
        /// applicable, then seals the metadata.
        /// </summary>
        /// <remarks>
        /// A non-vector source yields a vector of two floats and always gets slot names. A vector source
        /// yields the source dimensions with an extra dimension of 2 added, and gets slot names only when
        /// it has a known size and carries text slot names of matching size.
        /// </remarks>
        /// <returns>The output types, parallel to Infos.</returns>
        private VectorDataViewType[] GetTypesAndMetadata()
        {
            var metadata = Metadata;
            int columnCount = Infos.Length;
            var outputTypes = new VectorDataViewType[columnCount];

            for (int col = 0; col < columnCount; col++)
            {
                var srcType = Infos[col].TypeSrc;

                // Guard against the doubled feature count overflowing an int.
                Host.Check(srcType.GetValueCount() < int.MaxValue / 2);

                bool emitSlotNames;
                if (srcType is VectorDataViewType srcVector)
                {
                    outputTypes[col] = new VectorDataViewType(NumberDataViewType.Single, srcVector.Dimensions.Add(2));

                    // Slot names are produced iff the source has (valid) slot names.
                    emitSlotNames =
                        srcVector.IsKnownSize &&
                        Source.Schema[Infos[col].Source].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type is VectorDataViewType namesType &&
                        namesType.Size == srcVector.Size &&
                        namesType.ItemType is TextDataViewType;
                }
                else
                {
                    outputTypes[col] = new VectorDataViewType(NumberDataViewType.Single, 2);
                    emitSlotNames = true;
                }

                if (!emitSlotNames)
                    continue;

                // Register the slot-names getter for this column.
                using (var builder = metadata.BuildMetadata(col))
                {
                    builder.AddGetter<VBuffer<ReadOnlyMemory<char>>>(AnnotationUtils.Kinds.SlotNames,
                        AnnotationUtils.GetNamesType(outputTypes[col].Size), GetSlotNames);
                }
            }

            metadata.Seal();
            return outputTypes;
        }
 
        /// <summary>
        /// Returns the precomputed output type of the column at index <paramref name="iinfo"/>.
        /// </summary>
        protected override DataViewType GetColumnTypeCore(int iinfo)
        {
            Host.Assert(iinfo >= 0 && iinfo < Infos.Length);

            VectorDataViewType outputType = _types[iinfo];
            return outputType;
        }
 
        /// <summary>
        /// Slot-names getter for output column <paramref name="iinfo"/>. Every source slot contributes two
        /// names: its own name, followed by that name with <c>IndicatorSuffix</c> appended.
        /// </summary>
        /// <param name="iinfo">Index of the column in Infos.</param>
        /// <param name="dst">Buffer that receives the slot names.</param>
        private void GetSlotNames(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
        {
            Host.Assert(0 <= iinfo && iinfo < Infos.Length);

            int dstCount = _types[iinfo].Size;
            if (dstCount == 0)
                throw AnnotationUtils.ExceptGetAnnotation();

            var editor = VBufferEditor.Create(ref dst, dstCount);

            if (Infos[iinfo].TypeSrc is VectorDataViewType srcVector)
            {
                Host.Assert(srcVector.IsKnownSize);
                Host.Assert(dstCount == 2 * srcVector.Size);

                // REVIEW: Do we need to verify that there is metadata or should we just call GetMetadata?
                var srcAnnotations = Source.Schema[Infos[iinfo].Source].Annotations;
                var namesType = srcAnnotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type as VectorDataViewType;
                bool namesUsable =
                    namesType != null &&
                    namesType.Size == srcVector.Size &&
                    namesType.ItemType is TextDataViewType;
                if (!namesUsable)
                    throw AnnotationUtils.ExceptGetAnnotation();

                var srcNames = default(VBuffer<ReadOnlyMemory<char>>);
                srcAnnotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref srcNames);

                // We both assert and check. If this fails, there is a bug somewhere (possibly in this code,
                // but more likely in the implementation of Base). On the other hand, we don't want to
                // proceed if we've received garbage.
                Host.Check(srcNames.Length == srcVector.Size, "Unexpected slot name vector size");

                var builder = new StringBuilder();
                int next = 0;
                foreach (var entry in srcNames.Items(all: true))
                {
                    Host.Assert(0 <= next && next < dstCount);
                    Host.Assert(next % 2 == 0);

                    // Unnamed source slots fall back to "[<source slot index>]".
                    builder.Clear();
                    if (entry.Value.IsEmpty)
                        builder.Append('[').Append(next / 2).Append(']');
                    else
                        builder.AppendMemory(entry.Value);

                    // Build "<name>_Indicator" once; the plain name is a prefix slice of the same string.
                    int baseLength = builder.Length;
                    builder.Append(IndicatorSuffix);
                    var withSuffix = builder.ToString().AsMemory();

                    editor.Values[next] = withSuffix.Slice(0, baseLength);
                    editor.Values[next + 1] = withSuffix;
                    next += 2;
                }
                Host.Assert(next == dstCount);
            }
            else
            {
                // Non-vector source: the two slots are named after the source column itself.
                Host.Assert(_types[iinfo].Size == 2);
                var srcName = Source.Schema[Infos[iinfo].Source].Name;
                editor.Values[0] = srcName.AsMemory();
                editor.Values[1] = (srcName + IndicatorSuffix).AsMemory();
            }

            dst = editor.Commit();
        }
 
        /// <summary>
        /// Creates the value getter for output column <paramref name="iinfo"/>. The returned delegate is a
        /// <c>ValueGetter&lt;VBuffer&lt;float&gt;&gt;</c> for both vector and non-vector sources.
        /// </summary>
        /// <param name="ch">Channel; may be null.</param>
        /// <param name="input">The input row the source getter is bound to.</param>
        /// <param name="iinfo">Index of the column in Infos.</param>
        /// <param name="disposer">Always set to null.</param>
        protected override Delegate GetGetterCore(IChannel ch, DataViewRow input, int iinfo, out Action disposer)
        {
            Host.AssertValueOrNull(ch);
            Host.AssertValue(input);
            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
            disposer = null;

            if (Infos[iinfo].TypeSrc is VectorDataViewType)
            {
                // Vector source: read into dst, then expand dst in place.
                var getVector = GetSrcGetter<VBuffer<float>>(input, iinfo);
                ValueGetter<VBuffer<float>> vectorGetter =
                    (ref VBuffer<float> dst) =>
                    {
                        getVector(ref dst);
                        FillValues(Host, ref dst);
                    };
                return vectorGetter;
            }

            // Non-vector source: read the single value and expand it into a two-slot vector.
            var getScalar = GetSrcGetter<float>(input, iinfo);
            ValueGetter<VBuffer<float>> scalarGetter =
                (ref VBuffer<float> dst) =>
                {
                    float value = 0f;
                    getScalar(ref value);
                    FillValues(value, ref dst);
                    Host.Assert(dst.Length == 2);
                };
            return scalarGetter;
        }
 
        /// <summary>
        /// Expands a single value into a two-slot vector: slot 0 carries the value and slot 1 carries
        /// the missing indicator.
        /// </summary>
        /// <remarks>
        /// Zero produces a length-2 vector with no explicit values; NaN produces a single explicit 1 at
        /// index 1; any other value is stored as a single explicit entry at index 0.
        /// </remarks>
        /// <param name="input">The source value.</param>
        /// <param name="result">The buffer that receives the two-slot vector.</param>
        private static void FillValues(float input, ref VBuffer<float> result)
        {
            if (input == 0)
            {
                VBufferUtils.Resize(ref result, 2, 0);
                return;
            }

            bool isMissing = float.IsNaN(input);

            var editor = VBufferEditor.Create(ref result, 2, 1);
            editor.Values[0] = isMissing ? 1f : input;
            editor.Indices[0] = isMissing ? 1 : 0;
            result = editor.Commit();
        }
 
        // This converts in place.
        /// <summary>
        /// Vector overload: rewrites <paramref name="buffer"/> so that its logical length doubles. A source
        /// value at slot i that is neither zero nor NaN is stored at index 2 * i; a NaN at slot i is
        /// replaced by an explicit 1 at index 2 * i + 1; zeros are dropped.
        /// </summary>
        /// <param name="ectx">Exception context used for the size check and the asserts.</param>
        /// <param name="buffer">On entry the source vector; on exit the expanded vector.</param>
        private static void FillValues(IExceptionContext ectx, ref VBuffer<float> buffer)
        {
            int size = buffer.Length;
            // Doubling the length must not overflow an int.
            ectx.Check(0 <= size && size < int.MaxValue / 2);

            // Capture the source values before the editor is created for the doubled length.
            // NOTE(review): the asserts below (iivDst <= ivSrc / iivSrc) suggest editor.Values may share
            // storage with 'values', so writes must never get ahead of reads - confirm against
            // VBufferEditor.Create.
            var values = buffer.GetValues();
            var editor = VBufferEditor.Create(ref buffer, size * 2, values.Length);
            // Number of explicit entries written to the destination so far.
            int iivDst = 0;
            if (buffer.IsDense)
            {
                // Currently, it's dense. We always produce sparse.

                for (int ivSrc = 0; ivSrc < values.Length; ivSrc++)
                {
                    ectx.Assert(iivDst <= ivSrc);
                    var val = values[ivSrc];
                    // Zeros are left implicit in the output.
                    if (val == 0)
                        continue;
                    if (float.IsNaN(val))
                    {
                        // Missing value: set the indicator slot (odd index) to 1.
                        editor.Values[iivDst] = 1;
                        editor.Indices[iivDst] = 2 * ivSrc + 1;
                    }
                    else
                    {
                        // Regular value: keep it in the value slot (even index).
                        editor.Values[iivDst] = val;
                        editor.Indices[iivDst] = 2 * ivSrc;
                    }
                    iivDst++;
                }
            }
            else
            {
                // Currently, it's sparse.

                var indices = buffer.GetIndices();
                // Previous source index, used only to assert that indices are strictly increasing.
                int ivPrev = -1;
                for (int iivSrc = 0; iivSrc < values.Length; iivSrc++)
                {
                    ectx.Assert(iivDst <= iivSrc);
                    var val = values[iivSrc];
                    // Explicitly stored zeros are dropped from the output.
                    if (val == 0)
                        continue;
                    int iv = indices[iivSrc];
                    ectx.Assert(ivPrev < iv && iv < size);
                    ivPrev = iv;
                    if (float.IsNaN(val))
                    {
                        // Missing value: set the indicator slot (odd index) to 1.
                        editor.Values[iivDst] = 1;
                        editor.Indices[iivDst] = 2 * iv + 1;
                    }
                    else
                    {
                        // Regular value: keep it in the value slot (even index).
                        editor.Values[iivDst] = val;
                        editor.Indices[iivDst] = 2 * iv;
                    }
                    iivDst++;
                }
            }

            // Only the first iivDst entries were written; commit just those.
            ectx.Assert(0 <= iivDst && iivDst <= values.Length);
            buffer = editor.CommitTruncated(iivDst);
        }
    }
}