File: KeyDataViewType.cs
Web Access
Project: src\src\Microsoft.ML.DataView\Microsoft.ML.DataView.csproj (Microsoft.ML.DataView)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Diagnostics;
using Microsoft.ML.Internal.DataView;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Type representing categorical or enumerated values, most commonly used for
    /// the values of labels in multiclass classification models.
    /// </summary>
    /// <remarks>
    /// The underlying .NET type is one of the unsigned integer types. The default is
    /// <see cref="uint"/>, but it can also be <see cref="byte"/>,
    /// <see cref="ushort"/>, or <see cref="ulong"/>.
    /// Despite keys being numerical types, the information is not inherently numeric,
    /// so typically, arithmetic is not meaningful.
    ///
    /// Missing values are mapped to 0.
    ///
    /// The first non-missing value of the set is always <c>1</c>.
    ///
    /// The other values range up to the value of <see cref="Count"/>.
    ///
    /// For example, if you have a key value with a <see cref="Count"/> of 3, then
    /// the <see cref="uint"/> value <c>0</c> corresponds to missing key values, and
    /// each of the values <c>1</c>, <c>2</c>, and <c>3</c> is one of the valid values,
    /// and no other values are used.
    /// </remarks>
    public sealed class KeyDataViewType : PrimitiveDataViewType
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="KeyDataViewType"/> class.
        /// </summary>
        /// <remarks>
        /// An exception is raised if <paramref name="type"/> is not one of the supported unsigned integer
        /// types, or if <paramref name="count"/> is zero or larger than the maximum value of
        /// <paramref name="type"/>.
        /// </remarks>
        /// <param name="type">
        /// The underlying representation type. Should be one of <see cref="byte"/>, <see cref="ushort"/>,
        /// <see cref="uint"/> (the most common choice), or <see cref="ulong"/>.
        /// </param>
        /// <param name="count">
        /// The cardinality of the underlying set. This must not exceed the associated maximum value of the
        /// representation type. For example, if <paramref name="type"/> is <see cref="uint"/>, then this must not
        /// exceed <see cref="uint.MaxValue"/>.
        /// </param>
        public KeyDataViewType(Type type, ulong count)
            : base(type)
        {
            if (!IsValidDataType(type))
                throw Contracts.ExceptParam(nameof(type), $"Type is not valid, it must be {typeof(byte).Name}, {typeof(ushort).Name}, {typeof(uint).Name}, or {typeof(ulong).Name}.");
            // GetMaxInt is the largest value representable by the raw type; the cardinality must fit in it.
            if (count == 0 || GetMaxInt(type) < count)
                throw Contracts.ExceptParam(nameof(count), $"The cardinality of a {nameof(KeyDataViewType)} must not exceed {type.Name}.{nameof(uint.MaxValue)} " +
                    $"and must be strictly positive, but got {count}.");
            Count = count;
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="KeyDataViewType"/> class. This differs from the hypothetically more general
        /// <see cref="KeyDataViewType.KeyDataViewType(Type, ulong)"/> constructor by taking an <see cref="int"/> for
        /// <paramref name="count"/>, to more naturally facilitate the most common case that the key value is being used
        /// as an enumeration over an array or list of some form.
        /// </summary>
        /// <remarks>
        /// The sign of <paramref name="count"/> is validated before it is converted to <see cref="ulong"/>, so that a
        /// zero or negative value is reported as such rather than as the wrapped unsigned value.
        /// </remarks>
        /// <param name="type">
        /// The underlying representation type. Should be one of <see cref="byte"/>, <see cref="ushort"/>,
        /// <see cref="uint"/> (the most common choice), or <see cref="ulong"/>.
        /// </param>
        /// <param name="count">
        /// The cardinality of the underlying set. This must be strictly positive and must not exceed the associated
        /// maximum value of the representation type. For example, if <paramref name="type"/> is <see cref="byte"/>,
        /// then this must not exceed <see cref="byte.MaxValue"/>.
        /// </param>
        public KeyDataViewType(Type type, int count)
            : this(type, ToUnsignedCount(type, count))
        {
        }

        /// <summary>
        /// Validates the sign of a signed cardinality and converts it to <see cref="ulong"/> for the
        /// <see cref="KeyDataViewType.KeyDataViewType(Type, ulong)"/> constructor.
        /// </summary>
        private static ulong ToUnsignedCount(Type type, int count)
        {
            // Only check the sign when the type itself is acceptable. For an unacceptable type we fall
            // through and let the chained constructor report the type problem, so the type is still
            // diagnosed before the count.
            if (IsValidDataType(type))
                Contracts.CheckParam(0 < count, nameof(count), "The cardinality of a " + nameof(KeyDataViewType) + " must be strictly positive.");

            // Explicitly unchecked: on the fall-through path count may be negative, and the conversion
            // must not raise an overflow error of its own in a checked build.
            return unchecked((ulong)count);
        }

        /// <summary>
        /// Returns true iff the given type is valid for a <see cref="KeyDataViewType"/>. The valid ones are
        /// <see cref="byte"/>, <see cref="ushort"/>, <see cref="uint"/>, and <see cref="ulong"/>, that is, the unsigned
        /// integer types.
        /// </summary>
        public static bool IsValidDataType(Type type)
        {
            return type == typeof(uint) || type == typeof(ulong) || type == typeof(ushort) || type == typeof(byte);
        }

        /// <summary>
        /// Returns the maximum value of the given unsigned integer type, or <c>0</c> if the type is not one
        /// of the types accepted by <see cref="IsValidDataType(Type)"/>.
        /// </summary>
        private static ulong GetMaxInt(Type type)
        {
            if (type == typeof(uint))
                return uint.MaxValue;
            else if (type == typeof(ulong))
                return ulong.MaxValue;
            else if (type == typeof(ushort))
                return ushort.MaxValue;
            else if (type == typeof(byte))
                return byte.MaxValue;

            return 0;
        }

        /// <summary>
        /// Returns the string form of the <see cref="NumberDataViewType"/> that matches
        /// <see cref="DataViewType.RawType"/>, for use in <see cref="ToString"/>.
        /// </summary>
        private string GetRawTypeName()
        {
            if (RawType == typeof(uint))
                return NumberDataViewType.UInt32.ToString();
            else if (RawType == typeof(ulong))
                return NumberDataViewType.UInt64.ToString();
            else if (RawType == typeof(ushort))
                return NumberDataViewType.UInt16.ToString();
            else if (RawType == typeof(byte))
                return NumberDataViewType.Byte.ToString();

            // The constructor rejects every other raw type, so this is not expected to be reached.
            Debug.Fail("Invalid type");
            return null;
        }

        /// <summary>
        /// <see cref="Count"/> is the cardinality of the <see cref="KeyDataViewType"/>.
        /// </summary>
        /// <remarks>
        /// The typical legal values for data of this type are the missing value <c>0</c>, and the non-missing
        /// values ranging from <c>1</c> through <see cref="Count"/>, inclusive, being the enumeration into whatever
        /// set the key values are enumerated over.
        /// </remarks>
        public ulong Count { get; }

        /// <summary>
        /// Determine if this <see cref="KeyDataViewType"/> object is equal to another <see cref="DataViewType"/> instance.
        /// Checks if the other item is the type of <see cref="KeyDataViewType"/>, if the <see cref="DataViewType.RawType"/>
        /// is the same, and if the <see cref="Count"/> is the same.
        /// </summary>
        /// <param name="other">The other object to compare against.</param>
        /// <returns><see langword="true" /> if both objects are equal, otherwise <see langword="false"/>.</returns>
        public override bool Equals(DataViewType other)
        {
            if (other == this)
                return true;

            if (!(other is KeyDataViewType tmp))
                return false;

            return RawType == tmp.RawType && Count == tmp.Count;
        }

        /// <summary>
        /// Determine if this <see cref="KeyDataViewType"/> instance is equal to another object.
        /// The other object is equal only if it is a <see cref="DataViewType"/> for which
        /// <see cref="Equals(DataViewType)"/> returns <see langword="true"/>.
        /// </summary>
        /// <param name="other">The other object to compare against.</param>
        /// <returns><see langword="true" /> if both objects are equal, otherwise <see langword="false"/>.</returns>
        public override bool Equals(object other)
            => other is DataViewType tmp && Equals(tmp);

        /// <summary>
        /// Retrieves the hash code, which combines the hash codes of <see cref="DataViewType.RawType"/>
        /// and <see cref="Count"/>, the same two members compared by <see cref="Equals(DataViewType)"/>.
        /// </summary>
        /// <returns>An integer representing the hash code.</returns>
        public override int GetHashCode()
        {
            return Hashing.CombineHash(RawType.GetHashCode(), Count.GetHashCode());
        }

        /// <summary>
        /// The string representation of the <see cref="KeyDataViewType"/>.
        /// </summary>
        /// <remarks>
        /// The format is <c>Key&lt;RawTypeName, 0-N&gt;</c>, where <c>N</c> is <see cref="Count"/> minus one.
        /// </remarks>
        /// <returns>A formatted string.</returns>
        public override string ToString()
        {
            string rawTypeName = GetRawTypeName();
            return $"Key<{rawTypeName}, 0-{Count - 1}>";
        }
    }
}