// File: Transforms\MetadataDispatcher.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Base class for handling the schema metadata API. A derived class creates and registers
    /// per-column ColInfo instances, then calls Seal. Creation and registration are only allowed
    /// before Seal; the public metadata query methods are only allowed after Seal.
    /// </summary>
    internal abstract class MetadataDispatcherBase
    {
        // Set by Seal. While false, the query methods throw; once true, CreateInfo and
        // RegisterColumn throw.
        private bool _sealed;

        /// <summary>
        /// Information for a column.
        /// </summary>
        protected sealed class ColInfo
        {
            // The source schema to pass through metadata from. May be null, indicating none.
            public readonly DataViewSchema SchemaSrc;
            // The source column index to pass through metadata from.
            public readonly int IndexSrc;
            // The metadata kind predicate indicating the kinds of metadata to pass through
            // from the source schema column. May be null, indicating all.
            public readonly Func<string, int, bool> FilterSrc;

            // The metadata getters. Never null; this is a private snapshot that is never
            // handed out directly (see Getters).
            private readonly GetterInfo[] _getters;

            /// <summary>
            /// The number of getters held by this instance.
            /// </summary>
            public int GetterCount { get { return _getters.Length; } }

            /// <summary>
            /// The getters held by this instance. Exposed through an iterator so that callers
            /// cannot get at (and mutate) the underlying array.
            /// </summary>
            public IEnumerable<GetterInfo> Getters
            {
                get
                {
                    foreach (var g in _getters)
                        yield return g;
                }
            }

            /// <summary>
            /// Creates a ColInfo. The <paramref name="getters"/> sequence, if not null, is copied;
            /// a null sequence results in no getters.
            /// </summary>
            public ColInfo(DataViewSchema schemaSrc, int indexSrc, Func<string, int, bool> filterSrc,
                IEnumerable<GetterInfo> getters = null)
                : this(getters != null ? getters.ToArray() : Array.Empty<GetterInfo>(), schemaSrc, indexSrc, filterSrc)
            {
            }

            // Takes ownership of the given array without copying it. The getters parameter comes
            // first so this overload can never be confused with the public constructor.
            private ColInfo(GetterInfo[] getters, DataViewSchema schemaSrc, int indexSrc, Func<string, int, bool> filterSrc)
            {
                SchemaSrc = schemaSrc;
                IndexSrc = indexSrc;
                FilterSrc = filterSrc;
                _getters = getters;
            }

            /// <summary>
            /// Returns a new ColInfo with the same source schema, source index and filter as this one,
            /// but holding the given getters instead of the current ones. Returns this instance
            /// unchanged when <paramref name="getters"/> is null. Throws if any getter is null.
            /// </summary>
            public ColInfo UpdateGetters(IEnumerable<GetterInfo> getters)
            {
                if (getters == null)
                    return this;

                // Materialize exactly once, so that the items that are validated are the very
                // same items that get stored, even when the sequence is lazily evaluated.
                var snapshot = getters.ToArray();
                Contracts.CheckParam(!snapshot.Any(g => g == null), nameof(getters), "Invalid getter info");
                return new ColInfo(snapshot, SchemaSrc, IndexSrc, FilterSrc);
            }
        }

        /// <summary>
        /// Base class for metadata getters.
        /// </summary>
        protected abstract class GetterInfo
        {
            // The metadata kind.
            public readonly string Kind;
            // The metadata type.
            public readonly DataViewType Type;

            /// <summary>
            /// Validates that <paramref name="kind"/> is not white space and that
            /// <paramref name="type"/> is not null, then records both.
            /// </summary>
            protected GetterInfo(string kind, DataViewType type)
            {
                Contracts.CheckNonWhiteSpace(kind, nameof(kind), "Invalid metadata kind");
                Contracts.CheckValue(type, nameof(type));
                Kind = kind;
                Type = type;
            }
        }

        /// <summary>
        /// Strongly typed base class for metadata getters. Introduces the abstract Get method.
        /// </summary>
        protected abstract class GetterInfo<TValue> : GetterInfo
        {
            protected GetterInfo(string kind, DataViewType type)
                : base(kind, type)
            {
            }

            /// <summary>
            /// Fetches the metadata value for the column at <paramref name="index"/> into
            /// <paramref name="value"/>.
            /// </summary>
            public abstract void Get(int index, ref TValue value);
        }

        /// <summary>
        /// A delegate based metadata getter.
        /// </summary>
        protected sealed class GetterInfoDelegate<TValue> : GetterInfo<TValue>
        {
            // The delegate invoked by Get. Never null.
            public readonly AnnotationUtils.AnnotationGetter<TValue> Getter;

            /// <summary>
            /// Throws if the raw type of <paramref name="type"/> is not <typeparamref name="TValue"/>
            /// or if <paramref name="getter"/> is null.
            /// </summary>
            public GetterInfoDelegate(string kind, DataViewType type, AnnotationUtils.AnnotationGetter<TValue> getter)
                : base(kind, type)
            {
                Contracts.Check(type.RawType == typeof(TValue), "Incompatible types");
                Contracts.CheckValue(getter, nameof(getter));
                Getter = getter;
            }

            public override void Get(int index, ref TValue value)
            {
                Getter(index, ref value);
            }
        }

        /// <summary>
        /// A primitive value based metadata getter.
        /// </summary>
        protected sealed class GetterInfoPrimitive<TValue> : GetterInfo<TValue>
        {
            // The constant metadata value handed back by Get. TValue is Type.RawType.
            public readonly TValue Value;

            /// <summary>
            /// Throws if the raw type of <paramref name="type"/> is not <typeparamref name="TValue"/>.
            /// </summary>
            public GetterInfoPrimitive(string kind, DataViewType type, TValue value)
                : base(kind, type)
            {
                Contracts.Check(type.RawType == typeof(TValue), "Incompatible types");
                Value = value;
            }

            public override void Get(int index, ref TValue value)
            {
                // The column index is irrelevant: the same stored value is returned for any index.
                value = Value;
            }
        }

        // One slot per column. A null slot means no ColInfo has been registered for that column.
        private readonly ColInfo[] _infos;

        /// <summary>
        /// The number of columns.
        /// </summary>
        protected int ColCount { get { return _infos.Length; } }

        /// <summary>
        /// Creates a dispatcher for <paramref name="colCount"/> columns, none of which has
        /// a ColInfo registered yet. Throws if <paramref name="colCount"/> is negative.
        /// </summary>
        protected MetadataDispatcherBase(int colCount)
        {
            Contracts.CheckParam(colCount >= 0, nameof(colCount));
            _infos = new ColInfo[colCount];
        }

        /// <summary>
        /// Create a ColInfo with the indicated information and no GetterInfos. This doesn't
        /// register a column, only creates a ColInfo. Note that multiple columns can share
        /// the same ColInfo, if desired. Simply call RegisterColumn multiple times, passing
        /// the same ColInfo but different index values. This can only be called before Seal is called.
        /// Throws if <paramref name="schemaSrc"/> is given and <paramref name="indexSrc"/> is not a valid
        /// column index in it, or if <paramref name="filterSrc"/> is given without a source schema.
        /// </summary>
        protected ColInfo CreateInfo(DataViewSchema schemaSrc = null, int indexSrc = -1,
            Func<string, int, bool> filterSrc = null)
        {
            Contracts.Check(!_sealed, "MetadataDispatcher sealed");
            Contracts.Check(schemaSrc == null || (0 <= indexSrc && indexSrc < schemaSrc.Count), "indexSrc out of range");
            Contracts.Check(filterSrc == null || schemaSrc != null, "filterSrc should be null if schemaSrc is null");
            return new ColInfo(schemaSrc, indexSrc, filterSrc);
        }

        /// <summary>
        /// Register the given ColInfo as the metadata handling information for the given
        /// column index. Throws if the given column index already has a ColInfo registered for it.
        /// This can only be called before Seal is called.
        /// </summary>
        protected void RegisterColumn(int index, ColInfo info)
        {
            Contracts.Check(!_sealed, "MetadataDispatcher sealed");
            Contracts.CheckValue(info, nameof(info));
            Contracts.CheckParam(0 <= index && index < _infos.Length, nameof(index), "Out of range");
            Contracts.CheckParam(_infos[index] == null, nameof(index), "Column already registered");
            _infos[index] = info;
        }

        /// <summary>
        /// Seals this dispatcher from further column registrations. This must be called before any
        /// metadata methods are called, otherwise an exception is thrown.
        /// </summary>
        protected void Seal()
        {
            _sealed = true;
        }

        /// <summary>
        /// Returns the ColInfo registered for the given column index, if there is one. This may be called
        /// before or after Seal is called. Throws if <paramref name="index"/> is out of range.
        /// </summary>
        protected ColInfo GetColInfoOrNull(int index)
        {
            Contracts.CheckParam(0 <= index && index < _infos.Length, nameof(index));
            return _infos[index];
        }

        /// <summary>
        /// Gets the metadata kinds and types for the given column index. The result is empty when
        /// no ColInfo is registered for the column.
        /// This can only be called after Seal is called.
        /// </summary>
        public IEnumerable<KeyValuePair<string, DataViewType>> GetMetadataTypes(int index)
        {
            // These checks live here, rather than in the iterator below, so that they run
            // when this method is called instead of when enumeration starts.
            Contracts.Check(_sealed, "MetadataDispatcher not sealed");

            var info = GetColInfoOrNull(index);
            if (info == null)
                return Enumerable.Empty<KeyValuePair<string, DataViewType>>();
            return GetTypesCore(index, info);
        }

        // Yields the kinds/types of the explicit getters first, then those of the source column
        // that pass the filter and are not shadowed by an explicit getter of the same kind.
        private IEnumerable<KeyValuePair<string, DataViewType>> GetTypesCore(int index, ColInfo info)
        {
            Contracts.Assert(_sealed);
            Contracts.AssertValue(info);

            // Kinds already produced by explicit getters. Only tracked when there is a source
            // schema, since that is the only case in which a kind could be produced twice.
            HashSet<string> kinds = null;
            if (info.GetterCount > 0)
            {
                if (info.SchemaSrc != null)
                    kinds = new HashSet<string>();
                foreach (var g in info.Getters)
                {
                    yield return new KeyValuePair<string, DataViewType>(g.Kind, g.Type);
                    kinds?.Add(g.Kind);
                }
            }

            if (info.SchemaSrc == null)
                yield break;

            // Pass through from base, with filtering.
            foreach (var kvp in info.SchemaSrc[info.IndexSrc].Annotations.Schema.Select(c => new KeyValuePair<string, DataViewType>(c.Name, c.Type)))
            {
                if (kinds != null && kinds.Contains(kvp.Key))
                    continue;
                if (info.FilterSrc != null && !info.FilterSrc(kvp.Key, index))
                    continue;
                yield return kvp;
            }
        }

        /// <summary>
        /// Gets the metadata type for the given metadata kind and column index, if there is one.
        /// An explicit getter of the given kind takes precedence over the source column.
        /// This can only be called after Seal is called.
        /// </summary>
        public DataViewType GetMetadataTypeOrNull(string kind, int index)
        {
            Contracts.Check(_sealed, "MetadataDispatcher not sealed");

            var info = GetColInfoOrNull(index);
            if (info == null)
                return null;

            foreach (var g in info.Getters)
            {
                if (g.Kind == kind)
                    return g.Type;
            }

            // No explicit getter: fall back to the source column, subject to the filter.
            if (info.SchemaSrc == null)
                return null;
            if (info.FilterSrc != null && !info.FilterSrc(kind, index))
                return null;
            return info.SchemaSrc[info.IndexSrc].Annotations.Schema.GetColumnOrNull(kind)?.Type;
        }

        /// <summary>
        /// Gets the metadata for the given metadata kind and column index. Throws if there isn't any.
        /// An explicit getter of the given kind takes precedence over the source column; if that getter
        /// is not a getter of <typeparamref name="TValue"/>, this throws rather than falling back.
        /// This can only be called after Seal is called.
        /// </summary>
        public void GetMetadata<TValue>(IExceptionContext ectx, string kind, int index, ref TValue value)
        {
            ectx.Check(_sealed, "MetadataDispatcher not sealed");
            ectx.Check(0 <= index && index < _infos.Length);

            var info = _infos[index];
            if (info == null)
                throw ectx.ExceptGetAnnotation();

            foreach (var g in info.Getters)
            {
                if (g.Kind == kind)
                {
                    var getter = g as GetterInfo<TValue>;
                    if (getter == null)
                        throw ectx.ExceptGetAnnotation();
                    getter.Get(index, ref value);
                    return;
                }
            }

            // No explicit getter: pass through from the source column, unless there is no source
            // or the filter rejects this kind.
            if (info.SchemaSrc == null || (info.FilterSrc != null && !info.FilterSrc(kind, index)))
                throw ectx.ExceptGetAnnotation();
            info.SchemaSrc[info.IndexSrc].Annotations.GetValue(kind, ref value);
        }
    }
 
    /// <summary>
    /// For handling the schema metadata API. Call one of the BuildMetadata methods to get
    /// a builder for a particular column. Wrap the return in a using statement. Disposing the builder
    /// records the metadata for the column. Call Seal() once all metadata is constructed.
    /// </summary>
    [BestFriend]
    internal sealed class MetadataDispatcher : MetadataDispatcherBase
    {
        /// <summary>
        /// Creates a dispatcher for <paramref name="colCount"/> columns.
        /// </summary>
        public MetadataDispatcher(int colCount)
            : base(colCount)
        {
        }

        /// <summary>
        /// Start building metadata for a column that doesn't pass through any metadata from
        /// a source column.
        /// </summary>
        public Builder BuildMetadata(int index)
        {
            return new Builder(this, index);
        }

        /// <summary>
        /// Start building metadata for a column that passes through all metadata from
        /// a source column.
        /// </summary>
        public Builder BuildMetadata(int index, DataViewSchema schemaSrc, int indexSrc)
        {
            Contracts.CheckValue(schemaSrc, nameof(schemaSrc));
            return new Builder(this, index, schemaSrc, indexSrc);
        }

        /// <summary>
        /// Start building metadata for a column that passes through metadata of certain kinds from
        /// a source column. The kinds that are passed through are those for which
        /// <paramref name="filterSrc"/> returns true.
        /// </summary>
        public Builder BuildMetadata(int index, DataViewSchema schemaSrc, int indexSrc, Func<string, int, bool> filterSrc)
        {
            Contracts.CheckValue(schemaSrc, nameof(schemaSrc));
            return new Builder(this, index, schemaSrc, indexSrc, filterSrc);
        }

        /// <summary>
        /// Start building metadata for a column that passes through metadata of the given kind from
        /// a source column.
        /// </summary>
        public Builder BuildMetadata(int index, DataViewSchema schemaSrc, int indexSrc, string kindSrc)
        {
            Contracts.CheckValue(schemaSrc, nameof(schemaSrc));
            Contracts.CheckNonWhiteSpace(kindSrc, nameof(kindSrc));
            return new Builder(this, index, schemaSrc, indexSrc, (k, i) => k == kindSrc);
        }

        /// <summary>
        /// Start building metadata for a column that passes through metadata of the given kinds from
        /// a source column. At least two kinds must be given, and none may be null or white space.
        /// </summary>
        public Builder BuildMetadata(int index, DataViewSchema schemaSrc, int indexSrc, params string[] kindsSrc)
        {
            Contracts.CheckValue(schemaSrc, nameof(schemaSrc));
            Contracts.CheckParam(Utils.Size(kindsSrc) >= 2, nameof(kindsSrc));
            Contracts.CheckParam(!kindsSrc.Any(k => string.IsNullOrWhiteSpace(k)), nameof(kindsSrc));

            // Copy into a set so the filter does not observe later changes to the caller's array.
            var set = new HashSet<string>(kindsSrc);
            return new Builder(this, index, schemaSrc, indexSrc, (k, i) => set.Contains(k));
        }

        /// <summary>
        /// Seals this dispatcher. Re-exposes the protected base method publicly.
        /// </summary>
        public new void Seal()
        {
            base.Seal();
        }

        /// <summary>
        /// The builder for metadata for a particular column.
        /// </summary>
        public sealed class Builder : IDisposable
        {
            // The index of the column being built.
            private readonly int _index;
            // The owning dispatcher. Set to null by Dispose, which is how "disposed" is detected.
            private MetadataDispatcher _md;
            // The getter-less column information created at construction. Null after Dispose.
            private ColInfo _info;
            // The getters added so far. Null until the first one is added, and again after Dispose.
            private List<GetterInfo> _getters;

            /// <summary>
            /// This should really be private to MetadataDispatcher, but C#'s accessibility model doesn't
            /// allow restricting to an outer class.
            /// </summary>
            internal Builder(MetadataDispatcher md, int index,
                DataViewSchema schemaSrc = null, int indexSrc = -1, Func<string, int, bool> filterSrc = null)
            {
                Contracts.CheckValue(md, nameof(md));
                Contracts.CheckParam(0 <= index && index < md.ColCount, nameof(index));

                _index = index;
                _md = md;
                _info = _md.CreateInfo(schemaSrc, indexSrc, filterSrc);

                // A column that already has metadata registered cannot be built again.
                var tmp = _md.GetColInfoOrNull(_index);
                Contracts.Check(tmp == null, "Duplicate building of metadata");
            }

            /// <summary>
            /// Add metadata of the given kind. When requested, the metadata is fetched by calling the given delegate.
            /// Throws if the builder is disposed, if an argument is invalid, if the raw type of
            /// <paramref name="type"/> is not <typeparamref name="TValue"/>, or if metadata of this kind
            /// was already added to this builder.
            /// </summary>
            public void AddGetter<TValue>(string kind, DataViewType type,
                AnnotationUtils.AnnotationGetter<TValue> getter)
            {
                Contracts.Check(_md != null, "Builder disposed");
                // The GetterInfo constructor rejects white space kinds, so reject them up front.
                Contracts.CheckNonWhiteSpace(kind, nameof(kind));
                Contracts.CheckValue(type, nameof(type));
                Contracts.CheckValue(getter, nameof(getter));
                Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type), "Given type doesn't match type parameter");

                CheckNotDuplicate(kind);
                Utils.Add(ref _getters, new GetterInfoDelegate<TValue>(kind, type, getter));
            }

            /// <summary>
            /// Add metadata of the given kind, with the given value.
            /// Throws if the builder is disposed, if an argument is invalid, if the raw type of
            /// <paramref name="type"/> is not <typeparamref name="TValue"/>, if <paramref name="type"/>
            /// is not a primitive type, or if metadata of this kind was already added to this builder.
            /// </summary>
            public void AddPrimitive<TValue>(string kind, DataViewType type, TValue value)
            {
                Contracts.Check(_md != null, "Builder disposed");
                // The GetterInfo constructor rejects white space kinds, so reject them up front.
                Contracts.CheckNonWhiteSpace(kind, nameof(kind));
                Contracts.CheckValue(type, nameof(type));
                Contracts.CheckParam(type.RawType == typeof(TValue), nameof(type), "Given type doesn't match type parameter");
                Contracts.CheckParam(type is PrimitiveDataViewType, nameof(type), "Must be a primitive type");

                CheckNotDuplicate(kind);
                Utils.Add(ref _getters, new GetterInfoPrimitive<TValue>(kind, type, value));
            }

            // Throws if a getter of the given kind has already been added to this builder.
            private void CheckNotDuplicate(string kind)
            {
                if (_getters != null && _getters.Any(g => g.Kind == kind))
                    throw Contracts.Except("Duplicate specification of metadata");
            }

            /// <summary>
            /// Close out the builder. This registers the metadata with the dispatcher.
            /// Calling this more than once has no further effect. Note that the registration itself
            /// throws if the dispatcher has been sealed or the column has been registered in the meantime.
            /// </summary>
            public void Dispose()
            {
                if (_md == null)
                    return;

                Contracts.Assert(_info != null);

                // Clear all state before registering, so the builder counts as disposed
                // even if the registration below throws.
                var md = _md;
                _md = null;
                var info = _info;
                _info = null;
                var getters = _getters;
                _getters = null;

                if (Utils.Size(getters) > 0)
                    info = info.UpdateGetters(getters);

                md.RegisterColumn(_index, info);
            }
        }
    }
}