File: OnnxTypeParser.cs
Web Access
Project: src\src\Microsoft.ML.OnnxTransformer\Microsoft.ML.OnnxTransformer.csproj (Microsoft.ML.OnnxTransformer)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
 
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms.Onnx
{
    internal static class OnnxTypeParser
    {
        /// <summary>
        /// Derive the corresponding <see cref="Type"/> for ONNX tensor's element type specified by <paramref name="elementType"/>.
        /// The corresponding <see cref="Type"/> should match the type system in ONNXRuntime's C# APIs.
        /// This function is used when determining the corresponding <see cref="Type"/> of <see cref="OnnxCSharpToProtoWrapper.TypeProto"/>.
        /// </summary>
        /// <param name="elementType">ONNX's tensor element type.</param>
        public static Type GetNativeScalarType(int elementType)
        {
            var dataType = (OnnxCSharpToProtoWrapper.TensorProto.Types.DataType)elementType;
            Type scalarType = null;
            switch (dataType)
            {
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool:
                    scalarType = typeof(System.Boolean);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8:
                    scalarType = typeof(System.SByte);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8:
                    scalarType = typeof(System.Byte);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16:
                    scalarType = typeof(System.Int16);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16:
                    scalarType = typeof(System.UInt16);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32:
                    scalarType = typeof(System.Int32);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32:
                    scalarType = typeof(System.UInt32);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64:
                    scalarType = typeof(System.Int64);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64:
                    scalarType = typeof(System.UInt64);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double:
                    scalarType = typeof(System.Double);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float:
                    scalarType = typeof(System.Single);
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String:
                    scalarType = typeof(string);
                    break;
                default:
                    throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString());
            }
            return scalarType;
        }
 
        /// <summary>
        /// Derive the corresponding <see cref="Type"/> for ONNX variable typed to <paramref name="typeProto"/>.
        /// The corresponding <see cref="Type"/> should match the type system in ONNXRuntime's C# APIs.
        /// </summary>
        /// <param name="typeProto">ONNX variable's type.</param>
        public static Type GetNativeType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
        {
            if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
            {
                if (typeProto.TensorType.Shape == null || typeProto.TensorType.Shape.Dim.Count == 0)
                {
                    return GetNativeScalarType(typeProto.TensorType.ElemType);
                }
                else
                {
                    Type tensorType = typeof(VBuffer<>);
                    Type elementType = GetNativeScalarType(typeProto.TensorType.ElemType);
                    return tensorType.MakeGenericType(elementType);
                }
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
            {
                var enumerableType = typeof(IEnumerable<>);
                var elementType = GetNativeType(typeProto.SequenceType.ElemType);
                return enumerableType.MakeGenericType(elementType);
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
            {
                var dictionaryType = typeof(IDictionary<,>);
                Type keyType = GetNativeScalarType(typeProto.MapType.KeyType);
                Type valueType = GetNativeType(typeProto.MapType.ValueType);
                return dictionaryType.MakeGenericType(keyType, valueType);
            }
            return null;
        }
 
        /// <summary>
        /// Derive the corresponding <see cref="DataViewType"/> for ONNX tensor's element type specified by <paramref name="elementType"/>.
        /// </summary>
        /// <param name="elementType">ONNX's tensor element type.</param>
        public static DataViewType GetScalarDataViewType(int elementType)
        {
            var dataType = (OnnxCSharpToProtoWrapper.TensorProto.Types.DataType)elementType;
            DataViewType scalarType = null;
            switch (dataType)
            {
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool:
                    scalarType = BooleanDataViewType.Instance;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8:
                    scalarType = NumberDataViewType.SByte;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8:
                    scalarType = NumberDataViewType.Byte;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16:
                    scalarType = NumberDataViewType.Int16;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16:
                    scalarType = NumberDataViewType.UInt16;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32:
                    scalarType = NumberDataViewType.Int32;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32:
                    scalarType = NumberDataViewType.UInt32;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64:
                    scalarType = NumberDataViewType.Int64;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64:
                    scalarType = NumberDataViewType.UInt64;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float:
                    scalarType = NumberDataViewType.Single;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double:
                    scalarType = NumberDataViewType.Double;
                    break;
                case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String:
                    scalarType = TextDataViewType.Instance;
                    break;
                default:
                    throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString());
            }
            return scalarType;
        }
 
        /// <summary>
        /// Parse the dimension information of a single tensor axis. Note that 2-D ONNX tensors have two axes.
        /// </summary>
        /// <param name="dim">ONNX's tensor dimension.</param>
        public static int GetDimValue(OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension dim)
        {
            int value = 0;
            switch (dim.ValueCase)
            {
                case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue:
                    // The vector length in ML.NET is typed to 32-bit integer, so the check below is added for perverting overflowing.
                    if (dim.DimValue > int.MaxValue)
                        throw Contracts.ExceptParamValue(dim.DimValue, nameof(dim), $"Dimension {dim} in ONNX tensor cannot exceed the maximum of 32-bit signed integer.");
                    // Variable-length dimension is translated to 0.
                    value = dim.DimValue > 0 ? (int)dim.DimValue : 0;
                    break;
                case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam:
                    // Variable-length dimension is translated to 0.
                    break;
                case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.None:
                    // Empty dimension is translated to 0.
                    break;
            }
            return value;
        }
 
        /// <summary>
        /// Parse the shape information of a tensor.
        /// </summary>
        /// <param name="tensorShapeProto">ONNX's tensor shape.</param>
        public static IEnumerable<int> GetTensorDims(Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper.TensorShapeProto tensorShapeProto)
        {
            if (tensorShapeProto == null)
                // Scalar has null dimensionality.
                return null;
 
            List<int> dims = new List<int>();
            foreach (var d in tensorShapeProto.Dim)
            {
                var dimValue = GetDimValue(d);
                dims.Add(dimValue);
            }
 
            // In ONNX, the first dimension refers to the batch size. If that is set to -1, it means OnnxRuntime can do inferencing in batches on
            // multiple rows at once. In ML.NET, a vector is considered to be of known size if the dimensions are all greater than zero
            // Leaving the batch size at -1 causes all Onnx vectors to be considered to be of unknown size. Therefore, if the first dimension is -1,
            // we need to fix up the shape. But GetDimValue above converts any dimension < 0 to be 0. We need that behavior for dimensions other than
            // the first dimension. So we check only the first dimension here and fix it up. (The '<=' comparison below is there to make sure that
            // this holds even if the behavior of GetDimValue changes).
            if ((dims.Count > 0) && (dims[0] <= 0))
                dims[0] = 1;
 
            return dims;
        }
 
        /// <summary>
        /// Derive the corresponding <see cref="DataViewType"/> for ONNX variable typed to <paramref name="typeProto"/>.
        /// The returned <see cref="DataViewType.RawType"/> should match the type system in ONNXRuntime's C# APIs.
        /// </summary>
        /// <param name="typeProto">ONNX variable's type.</param>
        public static DataViewType GetDataViewType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
        {
            var oneOfFieldName = typeProto.ValueCase.ToString();
            if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
            {
                if (typeProto.TensorType.Shape.Dim.Count == 0)
                    // ONNX scalar is a tensor without shape information; that is,
                    // ONNX scalar's shape is an empty list.
                    return GetScalarDataViewType(typeProto.TensorType.ElemType);
                else
                {
                    var shape = GetTensorDims(typeProto.TensorType.Shape);
                    if (shape == null)
                        // Scalar has null shape.
                        return GetScalarDataViewType(typeProto.TensorType.ElemType);
                    else if (shape.Count() != 0 && shape.Aggregate((x, y) => x * y) > 0)
                        // Known shape tensor.
                        return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), shape.ToArray());
                    else
                        // Tensor with unknown shape.
                        return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), 0);
                }
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
            {
                var elemTypeProto = typeProto.SequenceType.ElemType;
                var elemType = GetNativeType(elemTypeProto);
                return new OnnxSequenceType(elemType);
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
            {
                var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
                var valueType = GetNativeType(typeProto.MapType.ValueType);
                return new OnnxMapType(keyType, valueType);
            }
            else
                throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
        }
 
        /// <summary>
        /// Class which store casting functions used in <see cref="GetDataViewValueCasterAndResultedType(OnnxCSharpToProtoWrapper.TypeProto, out Type)"/>.
        /// </summary>
        private class CastHelper
        {
            public static T CastTo<T>(object o) => (T)o;
 
            public static IEnumerable<TDst> CastOnnxSequenceToIEnumerable<TSrc, TDst>(IEnumerable<TSrc> o, Func<TSrc, object> caster)
            {
                // Since now we're disposing the NamedOnnxValue objects
                // after running inference on each output, we need
                // to copy (enumerate) the output through ".ToList()"
                // else, if our users try the keep the past sequence
                // outputs of their OnnxTransformer, they would
                // end up with empty sequences.
                return o.Select(v => (TDst)caster(v)).ToList();
            }
        }
 
        /// <summary>
        /// Create a <see cref="Func{T, TResult}"/> to map a <see cref="NamedOnnxValue"/> to the associated .NET <see langword="object"/>.
        /// The resulted .NET object's actual type is <paramref name="resultedType"/>.
        /// The returned <see cref="DataViewType.RawType"/> should match the type system in ONNXRuntime's C# APIs.
        /// </summary>
        /// <param name="typeProto">ONNX variable's type.</param>
        /// <param name="resultedType">C# type of <paramref name="typeProto"/>.</param>
        public static Func<NamedOnnxValue, object> GetDataViewValueCasterAndResultedType(OnnxCSharpToProtoWrapper.TypeProto typeProto, out Type resultedType)
        {
            var oneOfFieldName = typeProto.ValueCase.ToString();
            if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
            {
                var shape = GetTensorDims(typeProto.TensorType.Shape);
 
                if (shape == null)
                {
                    // Entering this scope means that an ONNX scalar is found. Note that ONNX scalar is typed to tensor without a shape.
 
                    // Get tensor element type.
                    var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType;
 
                    // Access the first element as a scalar.
                    var accessInfo = typeof(Tensor<>).GetMethod(nameof(Tensor<int>.GetValue));
                    var accessSpecialized = accessInfo.MakeGenericMethod(type);
 
                    // NamedOnnxValue to scalar.
                    Func<NamedOnnxValue, object> caster = (NamedOnnxValue value) =>
                    {
                        var scalar = accessSpecialized.Invoke(value, new object[] { 0 });
                        return scalar;
                    };
 
                    resultedType = type;
 
                    return caster;
                }
                else
                {
                    // Entering this scope means an ONNX tensor is found.
 
                    var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType;
                    var methodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsTensor));
                    var methodSpecialized = methodInfo.MakeGenericMethod(type);
 
                    // NamedOnnxValue to Tensor.
                    Func<NamedOnnxValue, object> caster = (NamedOnnxValue value) => methodSpecialized.Invoke(value, new object[] { });
 
                    resultedType = typeof(Tensor<>).MakeGenericType(type);
 
                    return caster;
                }
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
            {
                // Now, we see a Sequence in ONNX. If its element type is T, the variable produced by
                // ONNXRuntime would be typed to IEnumerable<T>.
 
                // Find a proper caster (a function which maps NamedOnnxValue to a .NET object) for the element in
                // the ONNX sequence. Note that ONNX sequence is typed to IEnumerable<NamedOnnxValue>, so we need
                // to convert NamedOnnxValue to a proper type such as IDictionary<>.
                var elementCaster = GetDataViewValueCasterAndResultedType(typeProto.SequenceType.ElemType, out Type elementType);
 
                // Set the .NET type which corresponds to the first input argument, typeProto.
                resultedType = typeof(IEnumerable<>).MakeGenericType(elementType);
 
                // Create the element's caster to map IEnumerable<NamedOnnxValue> produced by ONNXRuntime to
                // IEnumerable<elementType>.
                var methodInfo = typeof(CastHelper).GetMethod(nameof(CastHelper.CastOnnxSequenceToIEnumerable));
                var methodSpecialized = methodInfo.MakeGenericMethod(typeof(NamedOnnxValue), elementType);
 
                // Use element-level caster to create sequence caster.
                Func<NamedOnnxValue, object> caster = (NamedOnnxValue value) =>
                {
                    var enumerable = value.AsEnumerable<NamedOnnxValue>();
                    return methodSpecialized.Invoke(null, new object[] { enumerable, elementCaster });
                };
 
                return caster;
            }
            else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
            {
                // Entering this scope means a ONNX Map (equivalent to IDictionary<>) will be produced.
 
                var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
                var valueType = GetNativeType(typeProto.MapType.ValueType);
 
                // The resulted type of the object returned by the caster below.
                resultedType = typeof(IDictionary<,>).MakeGenericType(keyType, valueType);
 
                // Create a method to convert NamedOnnxValue to IDictionary<keyValue, valueType>.
                var asDictionaryMethodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsDictionary));
                var asDictionaryMethod = asDictionaryMethodInfo.MakeGenericMethod(keyType, valueType);
 
                // Create a caster to convert NamedOnnxValue to IDictionary<keyValue, valueType>.
                Func<NamedOnnxValue, object> caster = (NamedOnnxValue value) =>
                {
                    return asDictionaryMethod.Invoke(value, new object[] { });
                };
 
                return caster;
            }
            else
                throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
        }
    }
 
}