File: Transforms\ValueToKeyMappingTransformerImpl.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms
{
    public sealed partial class ValueToKeyMappingTransformer
    {
        /// <summary>
        /// These are objects shared by both the scalar and vector implementations of <see cref="Trainer"/>
        /// to accumulate individual scalar objects, and facilitate the creation of a <see cref="TermMap"/>.
        /// </summary>
        private abstract class Builder
        {
            private static readonly FuncStaticMethodInfo1<PrimitiveDataViewType, bool, Builder> _createCoreMethodInfo
                = new FuncStaticMethodInfo1<PrimitiveDataViewType, bool, Builder>(CreateCore<int>);
 
            /// <summary>
            /// The item type we are building into a term map.
            /// </summary>
            public readonly PrimitiveDataViewType ItemType;
 
            /// <summary>
            /// The number of items that would be in the map if created right now.
            /// </summary>
            public abstract int Count { get; }
 
            protected Builder(PrimitiveDataViewType type)
            {
                Contracts.AssertValue(type);
                ItemType = type;
            }
 
            public static Builder Create(DataViewType type, ValueToKeyMappingEstimator.KeyOrdinality sortOrder)
            {
                Contracts.AssertValue(type);
                Contracts.Assert(type is VectorDataViewType || type is PrimitiveDataViewType);
                // Right now we have only two. This "public" interface externally looks like it might
                // accept any value, but currently the internal implementations of Builder are split
                // along this being a purely binary option, for now (though this can easily change
                // with mot implementations of Builder).
                Contracts.Assert(sortOrder == ValueToKeyMappingEstimator.KeyOrdinality.ByOccurrence || sortOrder == ValueToKeyMappingEstimator.KeyOrdinality.ByValue);
                bool sorted = sortOrder == ValueToKeyMappingEstimator.KeyOrdinality.ByValue;
 
                PrimitiveDataViewType itemType = type.GetItemType() as PrimitiveDataViewType;
                Contracts.AssertValue(itemType);
                if (itemType is TextDataViewType)
                    return new TextImpl(sorted);
                return Utils.MarshalInvoke(_createCoreMethodInfo, itemType.RawType, itemType, sorted);
            }
 
            private static Builder CreateCore<T>(PrimitiveDataViewType type, bool sorted)
                where T : IEquatable<T>, IComparable<T>
            {
                Contracts.AssertValue(type);
                Contracts.Assert(type.RawType == typeof(T));
 
                // If this is a type with NA values, we should ignore those NA values for the purpose
                // of building our term dictionary. For the other types (practically, only the UX types),
                // we should ignore nothing.
                InPredicate<T> mapsToMissing;
                if (!Data.Conversion.Conversions.DefaultInstance.TryGetIsNAPredicate(type, out mapsToMissing))
                    mapsToMissing = (in T val) => false;
                return new Impl<T>(type, mapsToMissing, sorted);
            }
 
            /// <summary>
            /// Called at the end of training, to get the final mapper object.
            /// </summary>
            public abstract TermMap Finish();
 
            /// <summary>
            /// Handling for the "terms" arg.
            /// </summary>
            /// <param name="terms">The input terms argument</param>
            /// <param name="ch">The channel against which to report errors and warnings</param>
            public abstract void ParseAddTermArg(ref ReadOnlyMemory<char> terms, IChannel ch);
 
            /// <summary>
            /// Handling for the "term" arg.
            /// </summary>
            /// <param name="terms">The input terms argument</param>
            /// <param name="ch">The channel against which to report errors and warnings</param>
            public abstract void ParseAddTermArg(string[] terms, IChannel ch);
 
            private sealed class TextImpl : Builder<ReadOnlyMemory<char>>
            {
                private readonly NormStr.Pool _pool;
                private readonly bool _sorted;
 
                public override int Count
                {
                    get { return _pool.Count; }
                }
 
                public TextImpl(bool sorted)
                    : base(TextDataViewType.Instance)
                {
                    _pool = new NormStr.Pool();
                    _sorted = sorted;
                }
 
                public override bool TryAdd(in ReadOnlyMemory<char> val)
                {
                    if (val.IsEmpty)
                        return false;
                    int count = _pool.Count;
                    return _pool.Add(val).Id == count;
                }
 
                public override TermMap Finish()
                {
                    if (!_sorted || _pool.Count <= 1)
                        return new TermMap.TextImpl(_pool);
                    // REVIEW: Should write a Sort method in NormStr.Pool to make sorting more memory efficient.
                    var perm = Utils.GetIdentityPermutation(_pool.Count);
                    Comparison<int> comp = (i, j) => _pool.GetNormStrById(i).Value.Span.CompareTo(_pool.GetNormStrById(j).Value.Span, StringComparison.Ordinal);
                    Array.Sort(perm, comp);
 
                    var sortedPool = new NormStr.Pool();
                    for (int i = 0; i < perm.Length; ++i)
                    {
                        var nstr = sortedPool.Add(_pool.GetNormStrById(perm[i]).Value);
                        Contracts.Assert(nstr.Id == i);
                        Contracts.Assert(i == 0 || sortedPool.GetNormStrById(i - 1).Value.Span.CompareTo(sortedPool.GetNormStrById(i).Value.Span, StringComparison.Ordinal) < 0);
                    }
                    Contracts.Assert(sortedPool.Count == _pool.Count);
                    return new TermMap.TextImpl(sortedPool);
                }
            }
 
            /// <summary>
            /// The sorted builder outputs things so that the keys are in sorted order.
            /// </summary>
            private sealed class Impl<T> : Builder<T>
                where T : IEquatable<T>, IComparable<T>
            {
                // Because we can't know the actual mapping till we finish.
                private readonly HashArray<T> _values;
                private readonly InPredicate<T> _mapsToMissing;
                private readonly bool _sort;
 
                public override int Count
                {
                    get { return _values.Count; }
                }
 
                /// <summary>
                /// Instantiates.
                /// </summary>
                /// <param name="type">The type we are mapping</param>
                /// <param name="mapsToMissing">This indicates whether a given value will map
                /// to the missing value. If this returns true for a value then we do not attempt
                /// to store it in the map.</param>
                /// <param name="sort">Indicates whether to sort mapping IDs by input values.</param>
                public Impl(PrimitiveDataViewType type, InPredicate<T> mapsToMissing, bool sort)
                    : base(type)
                {
                    Contracts.Assert(type.RawType == typeof(T));
                    Contracts.AssertValue(mapsToMissing);
 
                    _values = new HashArray<T>();
                    _mapsToMissing = mapsToMissing;
                    _sort = sort;
                }
 
                public override bool TryAdd(in T val)
                {
                    return !_mapsToMissing(in val) && _values.TryAdd(val);
                }
 
                public override TermMap Finish()
                {
                    if (_sort)
                        _values.Sort();
                    return new TermMap.HashArrayImpl<T>(ItemType, _values);
                }
            }
        }
 
        private abstract class Builder<T> : Builder
        {
            protected Builder(PrimitiveDataViewType type)
                : base(type)
            {
            }
 
            /// <summary>
            /// Ensures that the item is in the set. Returns true iff it added the item.
            /// </summary>
            /// <param name="val">The value to consider</param>
            public abstract bool TryAdd(in T val);
 
            /// <summary>
            /// Handling for the "terms" arg.
            /// </summary>
            /// <param name="terms">The input terms argument</param>
            /// <param name="ch">The channel against which to report errors and warnings</param>
            public override void ParseAddTermArg(ref ReadOnlyMemory<char> terms, IChannel ch)
            {
                T val;
                var tryParse = Data.Conversion.Conversions.DefaultInstance.GetTryParseConversion<T>(ItemType);
                for (bool more = true; more;)
                {
                    ReadOnlyMemory<char> term;
                    more = ReadOnlyMemoryUtils.SplitOne(terms, ',', out term, out terms);
                    term = ReadOnlyMemoryUtils.TrimSpaces(term);
                    if (term.IsEmpty)
                        ch.Warning("Empty strings ignored in 'terms' specification");
                    else if (!tryParse(in term, out val))
                        throw ch.Except($"Item '{term}' in 'terms' specification could not be parsed as '{ItemType}'");
                    else if (!TryAdd(in val))
                        ch.Warning($"Duplicate item '{term}' ignored in 'terms' specification", term);
                }
 
                if (Count == 0)
                    throw ch.ExceptUserArg(nameof(Options.Term), "Nothing parsed as '{0}'", ItemType);
            }
 
            /// <summary>
            /// Handling for the "term" arg.
            /// </summary>
            /// <param name="terms">The input terms argument</param>
            /// <param name="ch">The channel against which to report errors and warnings</param>
            public override void ParseAddTermArg(string[] terms, IChannel ch)
            {
                T val;
                var tryParse = Data.Conversion.Conversions.DefaultInstance.GetTryParseConversion<T>(ItemType);
                foreach (var sterm in terms)
                {
                    ReadOnlyMemory<char> term = sterm.AsMemory();
                    term = ReadOnlyMemoryUtils.TrimSpaces(term);
                    if (term.IsEmpty)
                        ch.Warning("Empty strings ignored in 'term' specification");
                    else if (!tryParse(in term, out val))
                        ch.Warning("Item '{0}' ignored in 'term' specification since it could not be parsed as '{1}'", term, ItemType);
                    else if (!TryAdd(in val))
                        ch.Warning("Duplicate item '{0}' ignored in 'term' specification", term);
                }
 
                if (Count == 0)
                    throw ch.ExceptUserArg(nameof(Options.Term), "Nothing parsed as '{0}'", ItemType);
            }
        }
 
        /// <summary>
        /// The trainer is an object that given an <see cref="Builder"/> instance, maps a particular
        /// input, whether it be scalar or vector, into this and allows us to continue training on it.
        /// </summary>
        private abstract class Trainer
        {
            private readonly Builder _bldr;
            private int _remaining;
 
            public int Count { get { return _bldr.Count; } }
 
            private Trainer(Builder bldr, int max)
            {
                Contracts.AssertValue(bldr);
                Contracts.Assert(max >= 0);
                _bldr = bldr;
                _remaining = max;
            }
 
            /// <summary>
            /// Creates an instance of <see cref="Trainer"/> appropriate for the type at a given
            /// row and column.
            /// </summary>
            /// <param name="row">The row to fetch from</param>
            /// <param name="col">The column to get the getter from</param>
            /// <param name="count">The maximum count of items to map</param>
            /// <param name="autoConvert">Whether we attempt to automatically convert
            /// the input type to the desired type</param>
            /// <param name="bldr">The builder we add items to</param>
            /// <returns>An associated training pipe</returns>
            public static Trainer Create(DataViewRow row, int col, bool autoConvert, int count, Builder bldr)
            {
                Contracts.AssertValue(row);
                var schema = row.Schema;
                Contracts.Assert(0 <= col && col < schema.Count);
                Contracts.Assert(count > 0);
                Contracts.AssertValue(bldr);
 
                var type = schema[col].Type;
                Contracts.Assert(autoConvert || bldr.ItemType == type.GetItemType());
                // Auto conversion should only be possible when the type is text.
                Contracts.Assert(type is TextDataViewType || !autoConvert);
                if (type is VectorDataViewType)
                    return Utils.MarshalInvoke(CreateVec<int>, bldr.ItemType.RawType, row, col, count, bldr);
                return Utils.MarshalInvoke(CreateOne<int>, bldr.ItemType.RawType, row, col, autoConvert, count, bldr);
            }
 
            private static Trainer CreateOne<T>(DataViewRow row, int col, bool autoConvert, int count, Builder bldr)
            {
                Contracts.AssertValue(row);
                Contracts.AssertValue(bldr);
                Contracts.Assert(bldr is Builder<T>);
                var bldrT = (Builder<T>)bldr;
 
                ValueGetter<T> inputGetter;
                if (autoConvert)
                    inputGetter = RowCursorUtils.GetGetterAs<T>(bldr.ItemType, row, col);
                else
                    inputGetter = row.GetGetter<T>(row.Schema[col]);
 
                return new ImplOne<T>(inputGetter, count, bldrT);
            }
 
            private static Trainer CreateVec<T>(DataViewRow row, int col, int count, Builder bldr)
            {
                Contracts.AssertValue(row);
                Contracts.AssertValue(bldr);
                Contracts.Assert(bldr is Builder<T>);
                var bldrT = (Builder<T>)bldr;
 
                var inputGetter = row.GetGetter<VBuffer<T>>(row.Schema[col]);
                return new ImplVec<T>(inputGetter, count, bldrT);
            }
 
            /// <summary>
            /// Indicates to the <see cref="Trainer"/> that we have reached a new row and should consider
            /// what to do with these values. Returns false if we have determined that it is no longer necessary
            /// to call this train, because we've already accumulated the maximum number of values.
            /// </summary>
            public abstract bool ProcessRow();
 
            /// <summary>
            /// Returns a <see cref="TermMap"/> over the items in this column. Note that even if this
            /// was trained over a vector valued column, the particular implementation returned here
            /// should be a mapper over the item type.
            /// </summary>
            public TermMap Finish()
            {
                return _bldr.Finish();
            }
 
            private sealed class ImplOne<T> : Trainer
            {
                private readonly ValueGetter<T> _getter;
                private T _val;
                private new readonly Builder<T> _bldr;
 
                public ImplOne(ValueGetter<T> getter, int max, Builder<T> bldr)
                    : base(bldr, max)
                {
                    Contracts.AssertValue(getter);
                    Contracts.AssertValue(bldr);
                    _getter = getter;
                    _bldr = bldr;
                }
 
                public sealed override bool ProcessRow()
                {
                    Contracts.Assert(_remaining >= 0);
                    if (_remaining <= 0)
                        return false;
                    _getter(ref _val);
                    return !_bldr.TryAdd(in _val) || --_remaining > 0;
                }
            }
 
            private sealed class ImplVec<T> : Trainer
            {
                private readonly ValueGetter<VBuffer<T>> _getter;
                private VBuffer<T> _val;
                private new readonly Builder<T> _bldr;
                private bool _addedDefaultFromSparse;
 
                public ImplVec(ValueGetter<VBuffer<T>> getter, int max, Builder<T> bldr)
                    : base(bldr, max)
                {
                    Contracts.AssertValue(getter);
                    Contracts.AssertValue(bldr);
                    _getter = getter;
                    _bldr = bldr;
                }
 
                private bool AccumAndDecrement(in T val)
                {
                    Contracts.Assert(_remaining > 0);
                    return !_bldr.TryAdd(in val) || --_remaining > 0;
                }
 
                public sealed override bool ProcessRow()
                {
                    Contracts.Assert(_remaining >= 0);
                    if (_remaining <= 0)
                        return false;
                    _getter(ref _val);
                    var values = _val.GetValues();
                    if (_val.IsDense || _addedDefaultFromSparse)
                    {
                        for (int i = 0; i < values.Length; ++i)
                        {
                            if (!AccumAndDecrement(in values[i]))
                                return false;
                        }
                        return true;
                    }
                    // The vector is sparse, and we have not yet tried adding the implicit default value.
                    // Because sparse vectors are supposed to be functionally the same as dense vectors
                    // and also because the order in which we see items matters for the final mapping, we
                    // must first add the first explicit entries where indices[i]==i, then add the default
                    // immediately following that, then continue with the remainder of the items. Note
                    // that the builder is taking care of the case where default maps to missing. Also
                    // note that this code is called on at most one row per column, so we aren't terribly
                    // excited about the slight inefficiency of that first if check.
                    Contracts.Assert(!_val.IsDense && !_addedDefaultFromSparse);
                    T def = default(T);
                    var valIndices = _val.GetIndices();
                    for (int i = 0; i < values.Length; ++i)
                    {
                        if (!_addedDefaultFromSparse && valIndices[i] != i)
                        {
                            _addedDefaultFromSparse = true;
                            if (!AccumAndDecrement(in def))
                                return false;
                        }
                        if (!AccumAndDecrement(in values[i]))
                            return false;
                    }
                    if (!_addedDefaultFromSparse)
                    {
                        _addedDefaultFromSparse = true;
                        if (!AccumAndDecrement(in def))
                            return false;
                    }
                    return true;
                }
            }
        }
 
        private enum MapType : byte
        {
            Text = 0,
            Codec = 1,
        }
 
        /// <summary>
        /// Given this instance, bind it to a particular input column. This allows us to service
        /// requests on the input dataset. This should throw an error if we attempt to bind this
        /// to the wrong type of item.
        /// </summary>
        private static BoundTermMap Bind(IHostEnvironment env, DataViewSchema schema, TermMap unbound, ColInfo[] infos, bool[] textMetadata, int iinfo)
        {
            env.Assert(0 <= iinfo && iinfo < infos.Length);
 
            var info = infos[iinfo];
            var inType = info.TypeSrc.GetItemType();
            if (!inType.Equals(unbound.ItemType))
            {
                throw env.Except("Could not apply a map over type '{0}' to column '{1}' since it has type '{2}'",
                    unbound.ItemType, info.Name, inType);
            }
            return BoundTermMap.Create(env, schema, unbound, infos, textMetadata, iinfo);
        }
 
        /// <summary>
        /// A map is an object capable of creating the association from an input type, to an output
        /// type. The input type, whatever it is, must have <see cref="ItemType"/> as its input item
        /// type, and will produce either <see cref="OutputType"/>, or a vector type with that output
        /// type if the input was a vector.
        ///
        /// Note that instances of this class can be shared among multiple <see cref="ValueToKeyMappingTransformer"/>
        /// instances. To associate this with a particular transform, use the <see cref="Bind"/> method.
        ///
        /// These are the immutable and serializable analogs to the <see cref="Builder"/> used in
        /// training.
        /// </summary>
        [BestFriend]
        internal abstract class TermMap
        {
            /// <summary>
            /// The item type of the input type, that is, either the input type or,
            /// if a vector, the item type of that type.
            /// </summary>
            public readonly PrimitiveDataViewType ItemType;
 
            /// <summary>
            /// The output item type. This will always be of known cardinality. Its count is always
            /// equal to <see cref="Count"/>, unless <see cref="Count"/> is 0 in which case this has
            /// key count of 1, since a count of 0 would indicate an unbound key. If we ever improve
            /// key types so they are capable of distinguishing between the set they index being
            /// empty vs. of unknown or unbound cardinality, this should change.
            /// </summary>
            public readonly KeyDataViewType OutputType;
 
            /// <summary>
            /// The number of items in the map.
            /// </summary>
            public readonly int Count;
 
            protected TermMap(PrimitiveDataViewType type, int count)
            {
                Contracts.AssertValue(type);
                Contracts.Assert(count >= 0);
                ItemType = type;
                Count = count;
                OutputType = new KeyDataViewType(typeof(uint), Count == 0 ? 1 : Count);
            }
 
            internal abstract void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory);
 
            internal static TermMap Load(ModelLoadContext ctx, IHostEnvironment ectx, CodecFactory codecFactory)
            {
                // *** Binary format ***
                // byte: map type code
                // <remainer> ...
 
                MapType mtype = (MapType)ctx.Reader.ReadByte();
                ectx.CheckDecode(Enum.IsDefined(typeof(MapType), mtype));
                switch (mtype)
                {
                    case MapType.Text:
                        // Binary format defined by this method.
                        return TextImpl.Create(ctx, ectx);
                    case MapType.Codec:
                        // *** Binary format ***
                        // codec parameterization: the codec
                        // int: number of terms
                        // value codec block: the terms written in the codec-defined binary format
                        IValueCodec codec;
                        if (!codecFactory.TryReadCodec(ctx.Reader.BaseStream, out codec))
                            throw ectx.Except("Unrecognized codec read");
                        ectx.CheckDecode(codec.Type is PrimitiveDataViewType);
                        int count = ctx.Reader.ReadInt32();
                        ectx.CheckDecode(count >= 0);
                        return Utils.MarshalInvoke(LoadCodecCore<int>, codec.Type.RawType, ctx, ectx, codec, count);
                    default:
                        ectx.Assert(false);
                        throw ectx.Except("Unrecognized type '{0}'", mtype);
                }
            }
 
            private static TermMap LoadCodecCore<T>(ModelLoadContext ctx, IExceptionContext ectx, IValueCodec codec, int count)
                where T : IEquatable<T>, IComparable<T>
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(ctx);
                ectx.AssertValue(codec);
                ectx.Assert(codec is IValueCodec<T>);
                ectx.Assert(codec.Type is PrimitiveDataViewType);
                ectx.Assert(count >= 0);
 
                IValueCodec<T> codecT = (IValueCodec<T>)codec;
 
                var values = new HashArray<T>();
                if (count > 0)
                {
                    using (var reader = codecT.OpenReader(ctx.Reader.BaseStream, count))
                    {
                        T item = default(T);
                        for (int i = 0; i < count; i++)
                        {
                            reader.MoveNext();
                            reader.Get(ref item);
                            int index = values.Add(item);
                            ectx.Assert(0 <= index && index <= i);
                            if (index != i)
                                throw ectx.Except("Duplicate items at positions {0} and {1}", index, i);
                        }
                    }
                }
 
                return new HashArrayImpl<T>((PrimitiveDataViewType)codec.Type, values);
            }
 
            internal abstract void WriteTextTerms(TextWriter writer);
 
            internal sealed class TextImpl : TermMap<ReadOnlyMemory<char>>
            {
                private readonly NormStr.Pool _pool;
 
                /// <summary>
                /// A pool based text mapping implementation.
                /// </summary>
                /// <param name="pool">The string pool</param>
                public TextImpl(NormStr.Pool pool)
                    : base(TextDataViewType.Instance, pool.Count)
                {
                    Contracts.AssertValue(pool);
                    _pool = pool;
                }
 
                public static TextImpl Create(ModelLoadContext ctx, IExceptionContext ectx)
                {
                    // *** Binary format ***
                    // int: number of terms
                    // int[]: term string ids
 
                    // Note that this binary format as read here diverges from the save format
                    // insofar that the save format contains the "I am text" code, which by the
                    // time we reach here, we have already read.
 
                    var pool = new NormStr.Pool();
                    int cstr = ctx.Reader.ReadInt32();
                    ectx.CheckDecode(cstr >= 0);
 
                    for (int istr = 0; istr < cstr; istr++)
                    {
                        var nstr = pool.Add(ctx.LoadNonEmptyString());
                        ectx.CheckDecode(nstr.Id == istr);
                    }
 
                    // The way we "train" the termMap, they shouldn't contain the empty string.
                    ectx.CheckDecode(pool.Get("") == null);
 
                    return new TextImpl(pool);
                }
 
                internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
                {
                    // *** Binary format ***
                    // byte: map type code, in this case 'Text' (0)
                    // int: number of terms
                    // int[]: term string ids
 
                    ctx.Writer.Write((byte)MapType.Text);
                    host.Assert(_pool.Count >= 0);
                    host.CheckDecode(_pool.Get("") == null);
                    ctx.Writer.Write(_pool.Count);
 
                    int id = 0;
                    foreach (var nstr in _pool)
                    {
                        host.Assert(nstr.Id == id);
                        ctx.SaveNonEmptyString(nstr.Value);
                        id++;
                    }
                }
 
                private void KeyMapper(in ReadOnlyMemory<char> src, ref uint dst)
                {
                    var nstr = _pool.Get(src);
                    if (nstr == null)
                        dst = 0;
                    else
                        dst = (uint)nstr.Id + 1;
                }
 
                public override ValueMapper<ReadOnlyMemory<char>, uint> GetKeyMapper()
                {
                    return KeyMapper;
                }
 
                public override void GetTerms(ref VBuffer<ReadOnlyMemory<char>> dst)
                {
                    var editor = VBufferEditor.Create(ref dst, _pool.Count);
                    int slot = 0;
                    foreach (var nstr in _pool)
                    {
                        Contracts.Assert(0 <= nstr.Id && nstr.Id < editor.Values.Length);
                        Contracts.Assert(nstr.Id == slot);
                        editor.Values[nstr.Id] = nstr.Value;
                        slot++;
                    }
 
                    dst = editor.Commit();
                }
 
                internal override void WriteTextTerms(TextWriter writer)
                {
                    writer.WriteLine("# Number of terms = {0}", Count);
                    foreach (var nstr in _pool)
                        writer.WriteLine("{0}\t{1}", nstr.Id, nstr.Value);
                }
            }
 
            internal sealed class HashArrayImpl<T> : TermMap<T>
                where T : IEquatable<T>, IComparable<T>
            {
                // One of the two must exist. If we need one we can initialize it
                // from the other.
                private readonly HashArray<T> _values;
 
                public HashArrayImpl(PrimitiveDataViewType itemType, HashArray<T> values)
                    // Note: The caller shouldn't pass a null HashArray.
                    : base(itemType, values.Count)
                {
                    Contracts.AssertValue(values);
                    _values = values;
                }
 
                internal override void Save(ModelSaveContext ctx, IHostEnvironment host, CodecFactory codecFactory)
                {
                    // *** Binary format ***
                    // byte: map type code, in this case 'Codec'
                    // codec parameterization: the codec
                    // int: number of terms
                    // value codec block: the terms written in the codec-defined binary format
 
                    IValueCodec codec;
                    if (!codecFactory.TryGetCodec(ItemType, out codec))
                        throw host.Except("We do not know how to serialize terms of type '{0}'", ItemType);
                    ctx.Writer.Write((byte)MapType.Codec);
                    host.Assert(codec.Type.Equals(ItemType));
                    host.Assert(codec.Type is PrimitiveDataViewType);
                    codecFactory.WriteCodec(ctx.Writer.BaseStream, codec);
                    IValueCodec<T> codecT = (IValueCodec<T>)codec;
                    ctx.Writer.Write(_values.Count);
                    using (var writer = codecT.OpenWriter(ctx.Writer.BaseStream))
                    {
                        for (int i = 0; i < _values.Count; ++i)
                        {
                            T val = _values.GetItem(i);
                            writer.Write(in val);
                        }
                        writer.Commit();
                    }
                }
 
                public override ValueMapper<T, uint> GetKeyMapper()
                {
                    return
                        (in T src, ref uint dst) =>
                        {
                            int val;
                            if (_values.TryGetIndex(src, out val))
                                dst = (uint)val + 1;
                            else
                                dst = 0;
                        };
                }
 
                public override void GetTerms(ref VBuffer<T> dst)
                {
                    if (Count == 0)
                    {
                        VBufferUtils.Resize(ref dst, 0);
                        return;
                    }
                    var editor = VBufferEditor.Create(ref dst, Count);
                    Contracts.AssertValue(_values);
                    Contracts.Assert(_values.Count == Count);
                    _values.CopyTo(editor.Values);
                    dst = editor.Commit();
                }
 
                internal override void WriteTextTerms(TextWriter writer)
                {
                    writer.WriteLine("# Number of terms of type '{0}' = {1}", ItemType, Count);
                    StringBuilder sb = null;
                    var stringMapper = Data.Conversion.Conversions.DefaultInstance.GetStringConversion<T>(ItemType);
                    for (int i = 0; i < _values.Count; ++i)
                    {
                        T val = _values.GetItem(i);
                        stringMapper(in val, ref sb);
                        writer.WriteLine("{0}\t{1}", i, sb.ToString());
                    }
                }
            }
        }
 
        internal abstract class TermMap<T> : TermMap
        {
            protected TermMap(PrimitiveDataViewType type, int count)
                : base(type, count)
            {
                Contracts.Assert(ItemType.RawType == typeof(T));
            }
 
            public abstract ValueMapper<T, uint> GetKeyMapper();
 
            public abstract void GetTerms(ref VBuffer<T> dst);
        }
 
        private static void GetTextTerms<T>(in VBuffer<T> src, ValueMapper<T, StringBuilder> stringMapper, ref VBuffer<ReadOnlyMemory<char>> dst)
        {
            // REVIEW: This convenience function is not optimized. For non-string
            // types, creating a whole bunch of string objects on the heap is one that is
            // fraught with risk. Ideally we'd have some sort of "copying" text buffer builder
            // but for now we'll see if this implementation suffices.
 
            // This utility function is not intended for use when we already have text!
            Contracts.Assert(typeof(T) != typeof(ReadOnlyMemory<char>));
 
            StringBuilder sb = null;
 
            // We'd obviously have to adjust this a bit, if we ever had sparse metadata vectors.
            // The way the term map metadata getters are structured right now, this is impossible.
            Contracts.Assert(src.IsDense);
 
            var editor = VBufferEditor.Create(ref dst, src.Length);
            var srcValues = src.GetValues();
            for (int i = 0; i < srcValues.Length; ++i)
            {
                stringMapper(in srcValues[i], ref sb);
                editor.Values[i] = sb.ToString().AsMemory();
            }
            dst = editor.Commit();
        }
 
        /// <summary>
        /// A mapper bound to a particular transform, and a particular column. These wrap
        /// a <see cref="TermMap"/>, and facilitate mapping that object to the inputs of
        /// a particular column, providing both values and metadata.
        /// </summary>
        private abstract class BoundTermMap
        {
            public readonly TermMap Map;
 
            private readonly int _iinfo;
            private readonly bool _inputIsVector;
            private readonly IHostEnvironment _host;
            private readonly bool[] _textMetadata;
            private readonly ColInfo[] _infos;
            private readonly DataViewSchema _schema;
 
            private bool IsTextMetadata { get { return _textMetadata[_iinfo]; } }
 
            private BoundTermMap(IHostEnvironment env, DataViewSchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo)
            {
                _host = env;
                //assert me.
                _textMetadata = textMetadata;
                _infos = infos;
                _schema = schema;
                _host.AssertValue(map);
                _host.Assert(0 <= iinfo && iinfo < infos.Length);
                var info = infos[iinfo];
                _host.Assert(info.TypeSrc.GetItemType().Equals(map.ItemType));
 
                Map = map;
                _iinfo = iinfo;
                _inputIsVector = info.TypeSrc is VectorDataViewType;
            }
 
            public static BoundTermMap Create(IHostEnvironment host, DataViewSchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo)
            {
                host.AssertValue(map);
                host.Assert(0 <= iinfo && iinfo < infos.Length);
                var info = infos[iinfo];
                host.Assert(info.TypeSrc.GetItemType().Equals(map.ItemType));
 
                return Utils.MarshalInvoke(CreateCore<int>, map.ItemType.RawType, host, schema, map, infos, textMetadata, iinfo);
            }
 
            public static BoundTermMap CreateCore<T>(IHostEnvironment env, DataViewSchema schema, TermMap map, ColInfo[] infos, bool[] textMetadata, int iinfo)
            {
                TermMap<T> mapT = (TermMap<T>)map;
                if (mapT.ItemType is KeyDataViewType)
                    return new KeyImpl<T>(env, schema, mapT, infos, textMetadata, iinfo);
                return new Impl<T>(env, schema, mapT, infos, textMetadata, iinfo);
            }
 
            public abstract Delegate GetMappingGetter(DataViewRow row);
 
            /// <summary>
            /// Allows us to optionally register metadata. It is also perfectly legal for
            /// this to do nothing, which corresponds to there being no metadata.
            /// </summary>
            public abstract void AddMetadata(DataViewSchema.Annotations.Builder builder);
 
            /// <summary>
            /// Writes out all terms we map to a text writer, with one line per mapped term.
            /// The line should have the format mapped key value, then a tab, then the term
            /// that is mapped. The writer should not be closed, as it will be used to write
            /// all term maps. We should write <see cref="TermMap.Count"/> terms.
            /// </summary>
            /// <param name="writer">The writer to which we write terms</param>
            public virtual void WriteTextTerms(TextWriter writer)
            {
                Map.WriteTextTerms(writer);
            }
 
            private abstract class Base<T> : BoundTermMap
            {
                protected readonly TermMap<T> TypedMap;
 
                public Base(IHostEnvironment env, DataViewSchema schema, TermMap<T> map, ColInfo[] infos, bool[] textMetadata, int iinfo)
                    : base(env, schema, map, infos, textMetadata, iinfo)
                {
                    TypedMap = map;
                }
 
                /// <summary>
                /// Returns what the default value maps to.
                /// </summary>
                private static uint MapDefault(ValueMapper<T, uint> map)
                {
                    T src = default(T);
                    uint dst = 0;
                    map(in src, ref dst);
                    return dst;
                }
 
                public override Delegate GetMappingGetter(DataViewRow input)
                {
                    // When constructing the getter, there are a few cases we have to consider:
                    // If scalar then it's just a straightforward mapping.
                    // If vector, then we have to detect whether the mapping happens to be
                    // sparsity preserving or not, that is, if the default value maps to the
                    // default (missing) key. For some types this will always be true, but it
                    // could also be true if we happened to never see the default value in
                    // training.
 
                    if (!_inputIsVector)
                    {
                        ValueMapper<T, uint> map = TypedMap.GetKeyMapper();
                        var info = _infos[_iinfo];
                        T src = default(T);
                        Contracts.Assert(!(info.TypeSrc is VectorDataViewType));
                        var inputColumn = input.Schema[info.InputColumnName];
                        _host.Assert(input.IsColumnActive(inputColumn));
                        var getSrc = input.GetGetter<T>(inputColumn);
                        ValueGetter<uint> retVal =
                            (ref uint dst) =>
                            {
                                getSrc(ref src);
                                map(in src, ref dst);
                            };
                        return retVal;
                    }
                    else
                    {
                        // It might be tempting to move "map" and "info" out of both blocks and into
                        // the main block of the function, but please don't do that. The implicit
                        // classes created by the compiler to hold the non-vector and vector lambdas
                        // will have an indirect wrapping class to hold "map" and "info". This is
                        // bad, especially since "map" is very frequently called.
                        ValueMapper<T, uint> map = TypedMap.GetKeyMapper();
                        var info = _infos[_iinfo];
                        // First test whether default maps to default. If so this is sparsity preserving.
                        var inputColumn = input.Schema[info.InputColumnName];
                        _host.Assert(input.IsColumnActive(inputColumn));
                        var getSrc = input.GetGetter<VBuffer<T>>(inputColumn);
                        VBuffer<T> src = default(VBuffer<T>);
                        ValueGetter<VBuffer<uint>> retVal;
                        // REVIEW: Consider whether possible or reasonable to not use a builder here.
                        var bldr = new BufferBuilder<uint>(U4Adder.Instance);
                        int cv = info.TypeSrc.GetVectorSize();
                        uint defaultMapValue = MapDefault(map);
                        uint dstItem = default(uint);
 
                        if (defaultMapValue == 0)
                        {
                            // Sparsity preserving.
                            retVal =
                                (ref VBuffer<uint> dst) =>
                                {
                                    getSrc(ref src);
                                    int cval = src.Length;
                                    if (cv != 0 && cval != cv)
                                        throw _host.Except("Column '{0}': TermTransform expects {1} slots, but got {2}", info.Name, cv, cval);
                                    if (cval == 0)
                                    {
                                        // REVIEW: Should the VBufferBuilder be changed so that it can
                                        // build vectors of length zero?
                                        VBufferUtils.Resize(ref dst, cval);
                                        return;
                                    }
 
                                    bldr.Reset(cval, dense: false);
 
                                    var values = src.GetValues();
                                    var indices = src.GetIndices();
                                    int count = values.Length;
                                    for (int islot = 0; islot < count; islot++)
                                    {
                                        map(in values[islot], ref dstItem);
                                        if (dstItem != 0)
                                        {
                                            int slot = !src.IsDense ? indices[islot] : islot;
                                            bldr.AddFeature(slot, dstItem);
                                        }
                                    }
 
                                    bldr.GetResult(ref dst);
                                };
                        }
                        else
                        {
                            retVal =
                                (ref VBuffer<uint> dst) =>
                                {
                                    getSrc(ref src);
                                    int cval = src.Length;
                                    if (cv != 0 && cval != cv)
                                        throw _host.Except("Column '{0}': TermTransform expects {1} slots, but got {2}", info.Name, cv, cval);
                                    if (cval == 0)
                                    {
                                        // REVIEW: Should the VBufferBuilder be changed so that it can
                                        // build vectors of length zero?
                                        VBufferUtils.Resize(ref dst, cval);
                                        return;
                                    }
 
                                    // Despite default not mapping to default, it's very possible the result
                                    // might still be sparse, for example, the source vector could be full of
                                    // unrecognized items.
                                    bldr.Reset(cval, dense: false);
 
                                    var values = src.GetValues();
                                    if (src.IsDense)
                                    {
                                        for (int slot = 0; slot < src.Length; ++slot)
                                        {
                                            map(in values[slot], ref dstItem);
                                            if (dstItem != 0)
                                                bldr.AddFeature(slot, dstItem);
                                        }
                                    }
                                    else
                                    {
                                        var indices = src.GetIndices();
                                        int nextExplicitSlot = indices.Length == 0 ? src.Length : indices[0];
                                        int islot = 0;
                                        for (int slot = 0; slot < src.Length; ++slot)
                                        {
                                            if (nextExplicitSlot == slot)
                                            {
                                                // This was an explicitly defined value.
                                                _host.Assert(islot < values.Length);
                                                map(in values[islot], ref dstItem);
                                                if (dstItem != 0)
                                                    bldr.AddFeature(slot, dstItem);
                                                nextExplicitSlot = ++islot == indices.Length ? src.Length : indices[islot];
                                            }
                                            else
                                            {
                                                _host.Assert(slot < nextExplicitSlot);
                                                // This is a non-defined implicit default value. No need to attempt a remap
                                                // since we already have it.
                                                bldr.AddFeature(slot, defaultMapValue);
                                            }
                                        }
                                    }
                                    bldr.GetResult(ref dst);
                                };
                        }
                        return retVal;
                    }
                }
 
                public override void AddMetadata(DataViewSchema.Annotations.Builder builder)
                {
                    if (TypedMap.Count == 0)
                        return;
                    if (IsTextMetadata && !(TypedMap.ItemType is TextDataViewType))
                    {
                        var conv = Data.Conversion.Conversions.DefaultInstance;
                        var stringMapper = conv.GetStringConversion<T>(TypedMap.ItemType);
 
                        ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter =
                            (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                            {
                                // No buffer sharing convenient here.
                                VBuffer<T> dstT = default;
                                TypedMap.GetTerms(ref dstT);
                                GetTextTerms(in dstT, stringMapper, ref dst);
                            };
                        builder.AddKeyValues(TypedMap.OutputType.GetCountAsInt32(_host), TextDataViewType.Instance, getter);
                    }
                    else
                    {
                        ValueGetter<VBuffer<T>> getter =
                            (ref VBuffer<T> dst) =>
                            {
                                TypedMap.GetTerms(ref dst);
                            };
                        builder.AddKeyValues(TypedMap.OutputType.GetCountAsInt32(_host), TypedMap.ItemType, getter);
                    }
                }
            }
 
            /// <summary>
            /// The key-typed version is the same as <see cref="BoundTermMap.Impl{T}"/>, except the metadata
            /// is based off a subset of the key values metadata.
            /// </summary>
            private sealed class KeyImpl<T> : Base<T>
            {
                private static readonly FuncInstanceMethodInfo1<KeyImpl<T>, DataViewType, DataViewSchema.Annotations.Builder, bool> _addMetadataCoreMethodInfo
                    = FuncInstanceMethodInfo1<KeyImpl<T>, DataViewType, DataViewSchema.Annotations.Builder, bool>.Create(target => target.AddMetadataCore<int>);
 
                private static readonly FuncInstanceMethodInfo1<KeyImpl<T>, PrimitiveDataViewType, TextWriter, bool> _writeTextTermsCoreMethodInfo
                    = FuncInstanceMethodInfo1<KeyImpl<T>, PrimitiveDataViewType, TextWriter, bool>.Create(target => target.WriteTextTermsCore<int>);
 
                public KeyImpl(IHostEnvironment env, DataViewSchema schema, TermMap<T> map, ColInfo[] infos, bool[] textMetadata, int iinfo)
                    : base(env, schema, map, infos, textMetadata, iinfo)
                {
                    _host.Assert(TypedMap.ItemType is KeyDataViewType);
                }
 
                public override void AddMetadata(DataViewSchema.Annotations.Builder builder)
                {
                    if (TypedMap.Count == 0)
                        return;
 
                    _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol);
                    VectorDataViewType srcMetaType = _schema[srcCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
                    if (srcMetaType == null || srcMetaType.Size != TypedMap.ItemType.GetKeyCountAsInt32(_host) ||
                        TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(_addMetadataCoreMethodInfo, this, srcMetaType.ItemType.RawType, srcMetaType.ItemType, builder))
                    {
                        // No valid input key-value metadata. Back off to the base implementation.
                        base.AddMetadata(builder);
                    }
                }
 
                private bool AddMetadataCore<TMeta>(DataViewType srcMetaType, DataViewSchema.Annotations.Builder builder)
                {
                    _host.AssertValue(srcMetaType);
                    _host.Assert(srcMetaType.RawType == typeof(TMeta));
                    _host.AssertValue(builder);
                    var srcType = TypedMap.ItemType as KeyDataViewType;
                    _host.AssertValue(srcType);
                    var dstType = new KeyDataViewType(typeof(uint), srcType.Count);
                    var convInst = Data.Conversion.Conversions.DefaultInstance;
                    ValueMapper<T, uint> conv;
                    bool identity;
                    // If we can't convert this type to U4, don't try to pass along the metadata.
                    if (!convInst.TryGetStandardConversion<T, uint>(srcType, dstType, out conv, out identity))
                        return false;
                    _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol);
 
                    ValueGetter<VBuffer<TMeta>> getter =
                        (ref VBuffer<TMeta> dst) =>
                        {
                            VBuffer<TMeta> srcMeta = default(VBuffer<TMeta>);
                            _schema[srcCol].GetKeyValues(ref srcMeta);
                            _host.Assert(srcMeta.Length == srcType.GetCountAsInt32(_host));
 
                            VBuffer<T> keyVals = default(VBuffer<T>);
                            TypedMap.GetTerms(ref keyVals);
                            var editor = VBufferEditor.Create(ref dst, TypedMap.OutputType.GetCountAsInt32(_host));
                            uint convKeyVal = 0;
                            foreach (var pair in keyVals.Items(all: true))
                            {
                                T keyVal = pair.Value;
                                conv(in keyVal, ref convKeyVal);
                                // The builder for the key values should not have any missings.
                                _host.Assert(0 < convKeyVal && convKeyVal <= srcMeta.Length);
                                srcMeta.GetItemOrDefault((int)(convKeyVal - 1), ref editor.Values[pair.Key]);
                            }
                            dst = editor.Commit();
                        };
 
                    int keyCount = TypedMap.OutputType.GetCountAsInt32(_host);
                    if (IsTextMetadata && !(srcMetaType is TextDataViewType))
                    {
                        var stringMapper = convInst.GetStringConversion<TMeta>(srcMetaType);
                        ValueGetter<VBuffer<ReadOnlyMemory<char>>> mgetter =
                            (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                            {
                                var tempMeta = default(VBuffer<TMeta>);
                                getter(ref tempMeta);
                                Contracts.Assert(tempMeta.IsDense);
                                GetTextTerms(in tempMeta, stringMapper, ref dst);
                                _host.Assert(dst.Length == keyCount);
                            };
                        builder.AddKeyValues(keyCount, TextDataViewType.Instance, mgetter);
                    }
                    else
                    {
                        ValueGetter<VBuffer<TMeta>> mgetter =
                            (ref VBuffer<TMeta> dst) =>
                            {
                                getter(ref dst);
                                _host.Assert(dst.Length == keyCount);
                            };
                        builder.AddKeyValues(keyCount, (PrimitiveDataViewType)srcMetaType.GetItemType(), mgetter);
                    }
                    return true;
                }
 
                public override void WriteTextTerms(TextWriter writer)
                {
                    if (TypedMap.Count == 0)
                        return;
 
                    _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol);
                    VectorDataViewType srcMetaType = _schema[srcCol].Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
                    if (srcMetaType == null || srcMetaType.Size != TypedMap.ItemType.GetKeyCountAsInt32(_host) ||
                        TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(_writeTextTermsCoreMethodInfo, this, srcMetaType.ItemType.RawType, srcMetaType.ItemType, writer))
                    {
                        // No valid input key-value metadata. Back off to the base implementation.
                        base.WriteTextTerms(writer);
                    }
                }
 
                private bool WriteTextTermsCore<TMeta>(PrimitiveDataViewType srcMetaType, TextWriter writer)
                {
                    _host.AssertValue(srcMetaType);
                    _host.Assert(srcMetaType.RawType == typeof(TMeta));
                    var srcType = TypedMap.ItemType as KeyDataViewType;
                    _host.AssertValue(srcType);
                    var dstType = new KeyDataViewType(typeof(uint), srcType.Count);
                    var convInst = Data.Conversion.Conversions.DefaultInstance;
                    ValueMapper<T, uint> conv;
                    bool identity;
                    // If we can't convert this type to U4, don't try.
                    if (!convInst.TryGetStandardConversion<T, uint>(srcType, dstType, out conv, out identity))
                        return false;
                    _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol);
 
                    VBuffer<TMeta> srcMeta = default(VBuffer<TMeta>);
                    _schema[srcCol].GetKeyValues(ref srcMeta);
                    if (srcMeta.Length != srcType.GetCountAsInt32(_host))
                        return false;
 
                    VBuffer<T> keyVals = default(VBuffer<T>);
                    TypedMap.GetTerms(ref keyVals);
                    TMeta metaVal = default(TMeta);
                    uint convKeyVal = 0;
                    StringBuilder sb = null;
                    var keyStringMapper = convInst.GetStringConversion<T>(TypedMap.ItemType);
                    var metaStringMapper = convInst.GetStringConversion<TMeta>(srcMetaType);
 
                    writer.WriteLine("# Number of terms of key '{0}' indexing '{1}' value = {2}",
                        TypedMap.ItemType, srcMetaType, TypedMap.Count);
                    foreach (var pair in keyVals.Items(all: true))
                    {
                        T keyVal = pair.Value;
                        conv(in keyVal, ref convKeyVal);
                        // The key mapping will not have admitted missing keys.
                        _host.Assert(0 < convKeyVal && convKeyVal <= srcMeta.Length);
                        srcMeta.GetItemOrDefault((int)(convKeyVal - 1), ref metaVal);
                        keyStringMapper(in keyVal, ref sb);
                        writer.Write("{0}\t{1}", pair.Key, sb.ToString());
                        metaStringMapper(in metaVal, ref sb);
                        writer.WriteLine("\t{0}", sb.ToString());
                    }
                    return true;
                }
            }
 
            private sealed class Impl<T> : Base<T>
            {
                public Impl(IHostEnvironment env, DataViewSchema schema, TermMap<T> map, ColInfo[] infos, bool[] textMetadata, int iinfo)
                    : base(env, schema, map, infos, textMetadata, iinfo)
                {
                }
            }
        }
    }
}