// File: DataLoadSave\FakeSchema.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Internal.Utilities;
 
namespace Microsoft.ML.Data.DataLoadSave
{
    /// <summary>
    /// Builds a stand-in <see cref="DataViewSchema"/> from a <see cref="SchemaShape"/>.
    /// Fixed-size vector columns are given a size of 10, variable-size vector columns are created
    /// with a size argument of 0, key columns are given a count of 10, and every annotation is
    /// backed by a getter that produces a default value.
    /// </summary>
    [BestFriend]
    internal static class FakeSchemaFactory
    {
        // Generic method handles, instantiated per raw type at run time through Utils.MarshalInvoke.
        private static readonly FuncStaticMethodInfo1<Delegate> _getDefaultVectorGetterMethodInfo = new FuncStaticMethodInfo1<Delegate>(GetDefaultVectorGetter<int>);
        private static readonly FuncStaticMethodInfo1<Delegate> _getDefaultGetterMethodInfo = new FuncStaticMethodInfo1<Delegate>(GetDefaultGetter<int>);

        // Size used for every fixed-size vector type and for every default vector value.
        private const int AllVectorSizes = 10;
        // Count used for every key type.
        private const int AllKeySizes = 10;

        /// <summary>
        /// Creates a schema with one column per entry of <paramref name="shape"/>, in the same order.
        /// Each column carries one annotation per entry of the shape column's annotations.
        /// </summary>
        /// <param name="shape">The schema shape to turn into a schema.</param>
        /// <returns>The schema produced by the builder once all columns are added.</returns>
        public static DataViewSchema Create(SchemaShape shape)
        {
            var schemaBuilder = new DataViewSchema.Builder();

            for (int col = 0; col < shape.Count; col++)
            {
                SchemaShape.Column column = shape[col];

                // Annotations first: each one gets a concrete type and a default-value getter.
                var annotationsBuilder = new DataViewSchema.Annotations.Builder();
                var annotationShape = column.Annotations;
                for (int ann = 0; ann < annotationShape.Count; ann++)
                {
                    SchemaShape.Column annotation = annotationShape[ann];
                    DataViewType annotationType = MakeColumnType(annotation);
                    Delegate defaultGetter = CreateDefaultGetter(annotationType);
                    annotationsBuilder.Add(annotation.Name, annotationType, defaultGetter);
                }

                DataViewType columnType = MakeColumnType(column);
                var annotations = annotationsBuilder.ToAnnotations();
                schemaBuilder.AddColumn(column.Name, columnType, annotations);
            }

            return schemaBuilder.ToSchema();
        }

        /// <summary>
        /// Picks the default-value getter matching <paramref name="type"/>: the vector getter
        /// (instantiated on the item's raw type) for vector types, the scalar getter
        /// (instantiated on the type's raw type) for everything else.
        /// </summary>
        private static Delegate CreateDefaultGetter(DataViewType type)
        {
            if (type is VectorDataViewType vectorType)
                return Utils.MarshalInvoke(_getDefaultVectorGetterMethodInfo, vectorType.ItemType.RawType);

            return Utils.MarshalInvoke(_getDefaultGetterMethodInfo, type.RawType);
        }

        /// <summary>
        /// Turns a shape column into a concrete type: the item type is wrapped in a key type when
        /// the column is a key, and the result is wrapped in a vector type when the column is a
        /// fixed or variable vector.
        /// </summary>
        private static DataViewType MakeColumnType(SchemaShape.Column column)
        {
            DataViewType elementType = column.ItemType;
            if (column.IsKey)
            {
                var underlying = (PrimitiveDataViewType)elementType;
                elementType = new KeyDataViewType(underlying.RawType, AllKeySizes);
            }

            switch (column.Kind)
            {
                case SchemaShape.Column.VectorKind.VariableVector:
                    return new VectorDataViewType((PrimitiveDataViewType)elementType, 0);
                case SchemaShape.Column.VectorKind.Vector:
                    return new VectorDataViewType((PrimitiveDataViewType)elementType, AllVectorSizes);
                default:
                    // Any other kind: the (possibly key-wrapped) item type is used as is.
                    return elementType;
            }
        }

        /// <summary>
        /// Returns a getter that overwrites its argument with a <see cref="VBuffer{T}"/> built from
        /// (<see cref="AllVectorSizes"/>, 0, null, null).
        /// </summary>
        private static Delegate GetDefaultVectorGetter<TValue>()
        {
            ValueGetter<VBuffer<TValue>> vectorGetter = (ref VBuffer<TValue> dst) =>
            {
                dst = new VBuffer<TValue>(AllVectorSizes, 0, null, null);
            };
            return vectorGetter;
        }

        /// <summary>
        /// Returns a getter that overwrites its argument with <c>default</c> of <typeparamref name="TValue"/>.
        /// </summary>
        private static Delegate GetDefaultGetter<TValue>()
        {
            ValueGetter<TValue> scalarGetter = (ref TValue dst) =>
            {
                dst = default;
            };
            return scalarGetter;
        }
    }
}