File: OnnxUtils.cs
Web Access
Project: src\src\Microsoft.ML.OnnxConverter\Microsoft.ML.OnnxConverter.csproj (Microsoft.ML.OnnxConverter)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;
namespace Microsoft.ML.Model.OnnxConverter
    /// <summary>
    /// Contains methods to create ONNX models in protocol buffer.
    /// </summary>
    internal static class OnnxUtils
        /// <summary>
        /// Populates <paramref name="typeProto"/> as a tensor type with the given element type and shape.
        /// </summary>
        /// <param name="typeProto">The type to fill in; its tensor type and shape are created when absent.</param>
        /// <param name="dataType">Element type stored in the tensor type.</param>
        /// <param name="dims">One entry per dimension. When null, no shape is written.</param>
        /// <param name="dimsParam">Optional flags indexed like <paramref name="dims"/>; a true entry makes that
        /// dimension symbolic (dim_param "None") instead of a fixed dim_value.</param>
        /// <returns>The same <paramref name="typeProto"/> instance.</returns>
        private static TypeProto MakeType(TypeProto typeProto, TensorProto.Types.DataType dataType,
            List<long> dims, List<bool> dimsParam)
        {
            Contracts.CheckValue(typeProto, nameof(typeProto));

            if (typeProto.TensorType == null)
                typeProto.TensorType = new TypeProto.Types.Tensor();

            typeProto.TensorType.ElemType = (int)dataType;
            if (dims == null)
                return typeProto;

            for (int index = 0; index < dims.Count; index++)
            {
                // The shape is only created once there is at least one dimension to put in it.
                if (typeProto.TensorType.Shape == null)
                    typeProto.TensorType.Shape = new TensorShapeProto();

                var d = new TensorShapeProto.Types.Dimension();

                // Exactly one of DimParam / DimValue may be assigned per dimension, so these must be exclusive.
                bool isSymbolic = dimsParam != null && dimsParam.Count > index && dimsParam[index];
                if (isSymbolic)
                    d.DimParam = "None";
                else
                    d.DimValue = dims[index];

                typeProto.TensorType.Shape.Dim.Add(d);
            }

            return typeProto;
        }
        /// <summary>
        /// Names <paramref name="value"/> and describes it as a tensor of the given element type and shape.
        /// </summary>
        /// <param name="value">The value info to fill in; must not be null.</param>
        /// <param name="name">Name assigned to the value; must be non-empty.</param>
        /// <param name="dataType">Tensor element type, forwarded to MakeType.</param>
        /// <param name="dims">Dimension sizes, forwarded to MakeType.</param>
        /// <param name="dimsParam">Symbolic-dimension flags, forwarded to MakeType.</param>
        /// <returns>The same <paramref name="value"/> instance.</returns>
        private static ValueInfoProto MakeValue(ValueInfoProto value, string name, TensorProto.Types.DataType dataType,
            List<long> dims, List<bool> dimsParam)
        {
            Contracts.CheckValue(value, nameof(value));
            Contracts.CheckNonEmpty(name, nameof(name));

            value.Name = name;

            // Reuse an existing type description when there is one; otherwise attach a fresh one.
            TypeProto valueType = value.Type;
            if (valueType == null)
            {
                valueType = new TypeProto();
                value.Type = valueType;
            }

            MakeType(valueType, dataType, dims, dimsParam);
            return value;
        }
        /// <summary>
        /// Creates an attribute that carries only a name; the typed overloads fill in the payload.
        /// </summary>
        /// <param name="key">Attribute name; must be non-empty.</param>
        private static AttributeProto MakeAttribute(string key)
        {
            Contracts.CheckNonEmpty(key, nameof(key));
            return new AttributeProto { Name = key };
        }
        /// <summary>
        /// Creates an integer attribute whose value is the numeric code of a tensor data type.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Data type whose enum value is stored.</param>
        private static AttributeProto MakeAttribute(string key, TensorProto.Types.DataType value)
        {
            AttributeProto result = MakeAttribute(key);
            result.I = (int)value;
            result.Type = AttributeProto.Types.AttributeType.Int;
            return result;
        }
        /// <summary>
        /// Creates a float attribute; the double is narrowed to single precision before it is stored.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Value to store.</param>
        private static AttributeProto MakeAttribute(string key, double value)
        {
            AttributeProto result = MakeAttribute(key);
            float narrowed = (float)value;
            result.F = narrowed;
            result.Type = AttributeProto.Types.AttributeType.Float;
            return result;
        }
        /// <summary>
        /// Creates a float-list attribute; each double is narrowed to single precision before it is stored.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Values to store, in order; must not be null.</param>
        private static AttributeProto MakeAttribute(string key, IEnumerable<double> value)
        {
            Contracts.CheckValue(value, nameof(value));

            AttributeProto result = MakeAttribute(key);
            result.Type = AttributeProto.Types.AttributeType.Floats;
            foreach (double item in value)
                result.Floats.Add((float)item);
            return result;
        }
        /// <summary>
        /// Creates a float-list attribute from single-precision values.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Values to store, in order; must not be null.</param>
        private static AttributeProto MakeAttribute(string key, IEnumerable<float> value)
        {
            Contracts.CheckValue(value, nameof(value));
            AttributeProto attribute = MakeAttribute(key);
            attribute.Type = AttributeProto.Types.AttributeType.Floats;
            // The values already have the stored element type, so no projection is needed.
            attribute.Floats.AddRange(value);
            return attribute;
        }
        /// <summary>
        /// Creates an integer attribute.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Value to store.</param>
        private static AttributeProto MakeAttribute(string key, long value)
        {
            AttributeProto result = MakeAttribute(key);
            result.I = value;
            result.Type = AttributeProto.Types.AttributeType.Int;
            return result;
        }
        /// <summary>
        /// Creates an integer-list attribute.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Values to store, in order; must not be null.</param>
        private static AttributeProto MakeAttribute(string key, IEnumerable<long> value)
        {
            Contracts.CheckValue(value, nameof(value));
            AttributeProto attribute = MakeAttribute(key);
            attribute.Type = AttributeProto.Types.AttributeType.Ints;
            // Without this the attribute would be typed as Ints but carry no values.
            attribute.Ints.AddRange(value);
            return attribute;
        }
        /// <summary>
        /// Creates a string attribute from already-encoded bytes.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Encoded string payload.</param>
        private static AttributeProto MakeAttribute(string key, ByteString value)
        {
            AttributeProto result = MakeAttribute(key);
            result.S = value;
            result.Type = AttributeProto.Types.AttributeType.String;
            return result;
        }
        /// <summary>
        /// Creates a string-list attribute from already-encoded byte strings.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Encoded strings to store, in order; must not be null.</param>
        private static AttributeProto MakeAttribute(string key, IEnumerable<ByteString> value)
        {
            Contracts.CheckValue(value, nameof(value));
            AttributeProto attribute = MakeAttribute(key);
            attribute.Type = AttributeProto.Types.AttributeType.Strings;
            // Without this the attribute would be typed as Strings but carry no values.
            attribute.Strings.AddRange(value);
            return attribute;
        }
        /// <summary>
        /// Creates a graph attribute.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Graph to store.</param>
        private static AttributeProto MakeAttribute(string key, GraphProto value)
        {
            AttributeProto result = MakeAttribute(key);
            result.G = value;
            result.Type = AttributeProto.Types.AttributeType.Graph;
            return result;
        }
        /// <summary>
        /// Creates a graph-list attribute.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Graphs to store, in order; must not be null.</param>
        private static AttributeProto MakeAttribute(string key, IEnumerable<GraphProto> value)
        {
            Contracts.CheckValue(value, nameof(value));
            AttributeProto attribute = MakeAttribute(key);
            attribute.Type = AttributeProto.Types.AttributeType.Graphs;
            // Without this the attribute would be typed as Graphs but carry no values.
            attribute.Graphs.AddRange(value);
            return attribute;
        }
        /// <summary>
        /// Creates an integer attribute that stores 1 for true and 0 for false.
        /// </summary>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Flag to encode.</param>
        private static AttributeProto MakeAttribute(string key, bool value)
        {
            long flag = value ? 1 : 0;
            return MakeAttribute(key, flag);
        }
        /// <summary>
        /// Creates a graph node for the given operator and wires up its input and output names.
        /// </summary>
        /// <param name="opType">Operator type; must be non-empty.</param>
        /// <param name="inputs">Names of the node's inputs, in order; must not be null.</param>
        /// <param name="outputs">Names of the node's outputs, in order; must not be null.</param>
        /// <param name="name">Node name; must be non-empty.</param>
        /// <param name="domain">Operator domain. Null selects the ONNX-ML domain "ai.onnx.ml";
        /// pass "" explicitly for the default ONNX domain.</param>
        /// <returns>The new node.</returns>
        public static NodeProto MakeNode(string opType, IEnumerable<string> inputs, IEnumerable<string> outputs, string name, string domain = null)
        {
            Contracts.CheckNonEmpty(opType, nameof(opType));
            Contracts.CheckValue(inputs, nameof(inputs));
            Contracts.CheckValue(outputs, nameof(outputs));
            Contracts.CheckNonEmpty(name, nameof(name));

            var node = new NodeProto();
            node.OpType = opType;
            // The validated names must actually be attached, otherwise the node is disconnected from the graph.
            node.Input.AddRange(inputs);
            node.Output.AddRange(outputs);
            node.Name = name;
            node.Domain = domain ?? "ai.onnx.ml";
            return node;
        }
        /// <summary>Appends a float attribute named <paramref name="argName"/> to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, double value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends a float-list attribute built from doubles to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable<double> value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends a float-list attribute to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable<float> value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends an integer-list attribute to <paramref name="node"/>, storing 1 for true and 0 for false.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable<bool> value)
        {
            IEnumerable<long> flags = value.Select(flag => flag ? 1L : 0L);
            node.Attribute.Add(MakeAttribute(argName, flags));
        }

        /// <summary>Appends an integer attribute to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, long value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends an integer-list attribute to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable<long> value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends a string attribute (UTF-8 encoded) to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, ReadOnlyMemory<char> value)
        {
            ByteString encoded = StringToByteString(value);
            node.Attribute.Add(MakeAttribute(argName, encoded));
        }

        /// <summary>Appends a string-list attribute (UTF-8 encoded) to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, string[] value)
        {
            IEnumerable<ByteString> encoded = StringToByteString(value);
            node.Attribute.Add(MakeAttribute(argName, encoded));
        }

        /// <summary>Appends a string-list attribute (UTF-8 encoded) to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable<ReadOnlyMemory<char>> value)
        {
            IEnumerable<ByteString> encoded = StringToByteString(value);
            node.Attribute.Add(MakeAttribute(argName, encoded));
        }

        /// <summary>Appends a string-list attribute (UTF-8 encoded) to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable<string> value)
        {
            IEnumerable<ByteString> encoded = StringToByteString(value);
            node.Attribute.Add(MakeAttribute(argName, encoded));
        }

        /// <summary>Appends a string attribute (UTF-8 encoded) to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, string value)
        {
            ByteString encoded = StringToByteString(value);
            node.Attribute.Add(MakeAttribute(argName, encoded));
        }

        /// <summary>Appends a graph attribute to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, GraphProto value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends a graph-list attribute to <paramref name="node"/>.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable<GraphProto> value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends an integer attribute to <paramref name="node"/>, storing 1 for true and 0 for false.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, bool value)
        {
            AttributeProto attribute = MakeAttribute(argName, value);
            node.Attribute.Add(attribute);
        }

        /// <summary>Appends an integer attribute holding the tensor data type code that corresponds to a CLR type.</summary>
        public static void NodeAddAttributes(NodeProto node, string argName, Type value)
        {
            TensorProto.Types.DataType dataType = ConvertToTensorProtoType(value);
            node.Attribute.Add(MakeAttribute(argName, dataType));
        }
        /// <summary>
        /// Maps a CLR type to the corresponding tensor element data type.
        /// </summary>
        /// <param name="rawType">One of bool, ReadOnlyMemory&lt;char&gt;, the 8/16/32/64-bit signed or unsigned
        /// integer types, float or double; must not be null.</param>
        /// <returns>The matching data type.</returns>
        /// <remarks>Any other type fails the contract check with an "Unsupported type" message.</remarks>
        private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
        {
            Contracts.CheckValue(rawType, nameof(rawType));

            var dataType = TensorProto.Types.DataType.Undefined;

            if (rawType == typeof(bool))
                dataType = TensorProto.Types.DataType.Bool;
            else if (rawType == typeof(ReadOnlyMemory<char>))
                dataType = TensorProto.Types.DataType.String;
            else if (rawType == typeof(sbyte))
                dataType = TensorProto.Types.DataType.Int8;
            else if (rawType == typeof(byte))
                dataType = TensorProto.Types.DataType.Uint8;
            else if (rawType == typeof(short))
                dataType = TensorProto.Types.DataType.Int16;
            else if (rawType == typeof(ushort))
                dataType = TensorProto.Types.DataType.Uint16;
            else if (rawType == typeof(int))
                dataType = TensorProto.Types.DataType.Int32;
            else if (rawType == typeof(uint))
                dataType = TensorProto.Types.DataType.Uint32;
            else if (rawType == typeof(long))
                dataType = TensorProto.Types.DataType.Int64;
            else if (rawType == typeof(ulong))
                dataType = TensorProto.Types.DataType.Uint64;
            else if (rawType == typeof(float))
                dataType = TensorProto.Types.DataType.Float;
            else if (rawType == typeof(double))
                dataType = TensorProto.Types.DataType.Double;
            else
            {
                // Only reached when no mapping exists; supported types must not hit this check.
                string msg = "Unsupported type: " + rawType.ToString();
                Contracts.Check(false, msg);
            }

            return dataType;
        }
        /// <summary>Encodes the characters of <paramref name="str"/> as UTF-8 bytes.</summary>
        private static ByteString StringToByteString(ReadOnlyMemory<char> str)
            => ByteString.CopyFromUtf8(str.ToString());

        /// <summary>Lazily encodes each element of <paramref name="str"/> as UTF-8 bytes.</summary>
        private static IEnumerable<ByteString> StringToByteString(IEnumerable<ReadOnlyMemory<char>> str)
            => str.Select(s => StringToByteString(s));

        /// <summary>Lazily encodes each element of <paramref name="str"/> as UTF-8 bytes.</summary>
        private static IEnumerable<ByteString> StringToByteString(IEnumerable<string> str)
            => str.Select(s => StringToByteString(s));

        /// <summary>Encodes <paramref name="str"/> as UTF-8 bytes.</summary>
        // CopyFromUtf8 wraps the encoded array directly instead of encoding and then copying it again.
        private static ByteString StringToByteString(string str) => ByteString.CopyFromUtf8(str);
        /// <summary>
        /// Describes one named tensor value of a model: its name, element type and shape.
        /// </summary>
        public sealed class ModelArgs
        {
            /// <summary>Name of the value.</summary>
            public readonly string Name;

            /// <summary>Element type of the tensor.</summary>
            public readonly TensorProto.Types.DataType DataType;

            /// <summary>Dimension sizes, one entry per dimension.</summary>
            public readonly List<long> Dims;

            /// <summary>Flags indexed like <see cref="Dims"/>; true marks a symbolic dimension.</summary>
            public readonly List<bool> DimParams;

            /// <summary>
            /// Stores the given references as-is; the lists are not copied.
            /// </summary>
            public ModelArgs(string name, TensorProto.Types.DataType dataType, List<long> dims, List<bool> dimParams)
            {
                DimParams = dimParams;
                Dims = dims;
                DataType = dataType;
                Name = name;
            }
        }
        /// <summary>
        /// Assembles a complete ONNX model from nodes, typed inputs/outputs/intermediate values and initializers.
        /// </summary>
        /// <param name="nodes">Graph nodes, in order; must not be null.</param>
        /// <param name="producerName">Producer name recorded in the model; must be non-empty.</param>
        /// <param name="name">Graph name; must be non-empty.</param>
        /// <param name="domain">Model domain; must be non-empty.</param>
        /// <param name="producerVersion">Producer version recorded in the model; must be non-empty.</param>
        /// <param name="modelVersion">Model version recorded in the model.</param>
        /// <param name="opSetVersion">Operator set version imported for the default ONNX domain.</param>
        /// <param name="inputs">Graph inputs; must not be null.</param>
        /// <param name="outputs">Graph outputs; must not be null.</param>
        /// <param name="intermediateValues">Value infos for intermediate tensors; must not be null.</param>
        /// <param name="initializers">Constant tensors attached to the graph; must not be null.</param>
        /// <returns>The populated model.</returns>
        public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
            string domain, string producerVersion, long modelVersion, int opSetVersion, List<ModelArgs> inputs,
            List<ModelArgs> outputs, List<ModelArgs> intermediateValues, List<TensorProto> initializers)
        {
            Contracts.CheckValue(nodes, nameof(nodes));
            Contracts.CheckValue(inputs, nameof(inputs));
            Contracts.CheckValue(outputs, nameof(outputs));
            Contracts.CheckValue(intermediateValues, nameof(intermediateValues));
            Contracts.CheckValue(initializers, nameof(initializers));
            Contracts.CheckNonEmpty(producerName, nameof(producerName));
            Contracts.CheckNonEmpty(name, nameof(name));
            Contracts.CheckNonEmpty(domain, nameof(domain));
            Contracts.CheckNonEmpty(producerVersion, nameof(producerVersion));

            var model = new ModelProto();
            model.Domain = domain;
            model.ProducerName = producerName;
            model.ProducerVersion = producerVersion;
            model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
            model.ModelVersion = modelVersion;
            // Each domain may be imported only once: version 2 belongs to the ONNX-ML operator domain,
            // while the caller-supplied version applies to the default ("") ONNX domain.
            model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 });
            model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = opSetVersion });
            model.Graph = new GraphProto();
            var graph = model.Graph;
            graph.Name = name;
            graph.Node.AddRange(nodes);

            foreach (var arg in inputs)
            {
                var val = new ValueInfoProto();
                graph.Input.Add(val);
                MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
            }

            foreach (var arg in outputs)
            {
                var val = new ValueInfoProto();
                graph.Output.Add(val);
                MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
            }

            foreach (var arg in intermediateValues)
            {
                var val = new ValueInfoProto();
                graph.ValueInfo.Add(val);
                MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
            }

            graph.Initializer.AddRange(initializers);

            return model;
        }
        /// <summary>
        /// Builds the <see cref="ModelArgs"/> (name, element type, shape) for a column of the given type.
        /// </summary>
        /// <param name="type">Column type; for vector types the item type determines the element type.</param>
        /// <param name="colName">Name given to the value; must be non-empty.</param>
        /// <param name="dims">Optional explicit shape, excluding the batch dimension. When supplied, this list is
        /// used directly and a leading batch dimension is inserted into it.</param>
        /// <param name="dimsParams">Symbolic-dimension flags used together with <paramref name="dims"/>;
        /// ignored when <paramref name="dims"/> is null.</param>
        /// <returns>The described value, whose first dimension is a batch dimension of -1.</returns>
        public static ModelArgs GetModelArgs(DataViewType type, string colName,
            List<long> dims = null, List<bool> dimsParams = null)
        {
            Contracts.CheckValue(type, nameof(type));
            Contracts.CheckNonEmpty(colName, nameof(colName));

            Type rawType;
            if (type is VectorDataViewType vectorType)
                rawType = vectorType.ItemType.RawType;
            else
                rawType = type.RawType;
            var dataType = ConvertToTensorProtoType(rawType);

            string name = colName;
            List<long> dimsLocal;
            List<bool> dimsParamLocal = null;
            if (dims != null)
            {
                dimsLocal = dims;
                dimsParamLocal = dimsParams;
            }
            else
            {
                dimsLocal = new List<long>();
                int valueCount = type.GetValueCount();
                if (valueCount == 0) // Unknown size.
                {
                    dimsLocal.Add(1);
                    dimsParamLocal = new List<bool>() { false, true }; // false for batch size, true for dims.
                }
                else if (valueCount == 1)
                    dimsLocal.Add(1);
                else if (valueCount > 1)
                {
                    var vec = (VectorDataViewType)type;
                    for (int i = 0; i < vec.Dimensions.Length; i++)
                        dimsLocal.Add(vec.Dimensions[i]);
                }
            }

            // Set batch size to -1. The ONNX docs state that if dim_param is used instead of dim_value,
            // the size of the dimension "is not statically constrained to a particular number".
            // "This is useful for declaring the interfaces that care about the number of dimensions, but not the exact size of each dimension".
            // ONNX tooling explains that if the dim value is negative, then it is treated as a dim_param instead of
            // a dim_value. This allows ML.NET to run 1 row at a time in a streaming fashion,
            // but allows the ONNX model the flexibility to be run in batch mode if that is desired.
            dimsLocal.Insert(0, -1);

            return new ModelArgs(name, dataType, dimsLocal, dimsParamLocal);
        }
        /// <summary>
        /// Makes an ONNX int64 scalar tensor from a native C# number.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="value">The single value stored in the tensor.</param>
        public static TensorProto MakeInt64(string name, long value)
        {
            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.Int64;
            // No dims are written: a scalar has an empty shape.
            tensor.Int64Data.Add(value);
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX int64 tensor. With <paramref name="dims"/> null the result is a vector (1-D tensor)
        /// whose length is the number of values; otherwise <paramref name="dims"/> is used as the shape.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="values">Tensor elements, in order; must not be null.</param>
        /// <param name="dims">Optional explicit shape.</param>
        public static TensorProto MakeInt64s(string name, IEnumerable<long> values, IEnumerable<long> dims = null)
        {
            Contracts.CheckValue(values, nameof(values));

            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.Int64;
            tensor.Int64Data.AddRange(values);
            if (dims != null)
                tensor.Dims.AddRange(dims);
            else
                tensor.Dims.Add(tensor.Int64Data.Count); // Use the stored count so values is enumerated only once.
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX scalar tensor for int32 and smaller integer types from a native C# number.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="type">CLR type that selects the tensor's element data type.</param>
        /// <param name="value">The single value, stored in the tensor's int32 data field.</param>
        public static TensorProto MakeInt32(string name, Type type, int value)
        {
            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)ConvertToTensorProtoType(type);
            // No dims are written: a scalar has an empty shape.
            tensor.Int32Data.Add(value);
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX tensor for int32 and smaller integer types. With <paramref name="dims"/> null the result
        /// is a vector (1-D tensor) whose length is the number of values; otherwise <paramref name="dims"/> is
        /// used as the shape.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="type">CLR type that selects the tensor's element data type.</param>
        /// <param name="values">Tensor elements, in order; must not be null.</param>
        /// <param name="dims">Optional explicit shape.</param>
        public static TensorProto MakeInt32s(string name, Type type, IEnumerable<int> values, IEnumerable<long> dims = null)
        {
            Contracts.CheckValue(values, nameof(values));

            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)ConvertToTensorProtoType(type);
            tensor.Int32Data.AddRange(values);
            if (dims != null)
                tensor.Dims.AddRange(dims);
            else
                tensor.Dims.Add(tensor.Int32Data.Count); // Use the stored count so values is enumerated only once.
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX scalar tensor for the ulong and uint integer types from a native C# number.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="isUint64">True for a uint64 tensor, false for a uint32 tensor.</param>
        /// <param name="value">The single value; both element types are stored in the uint64 data field.</param>
        public static TensorProto MakeUInt(string name, bool isUint64, ulong value)
        {
            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)ConvertToTensorProtoType(isUint64 ? typeof(ulong) : typeof(uint));
            // No dims are written: a scalar has an empty shape.
            tensor.Uint64Data.Add(value);
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX tensor for the ulong and uint integer types. With <paramref name="dims"/> null the result
        /// is a vector (1-D tensor) whose length is the number of values; otherwise <paramref name="dims"/> is
        /// used as the shape.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="isUint64">True for a uint64 tensor, false for a uint32 tensor.</param>
        /// <param name="values">Tensor elements, in order; must not be null. Both element types are stored in
        /// the uint64 data field.</param>
        /// <param name="dims">Optional explicit shape.</param>
        public static TensorProto MakeUInts(string name, bool isUint64, IEnumerable<ulong> values, IEnumerable<long> dims = null)
        {
            Contracts.CheckValue(values, nameof(values));

            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)ConvertToTensorProtoType(isUint64 ? typeof(ulong) : typeof(uint));
            tensor.Uint64Data.AddRange(values);
            if (dims != null)
                tensor.Dims.AddRange(dims);
            else
                tensor.Dims.Add(tensor.Uint64Data.Count); // Use the stored count so values is enumerated only once.
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX double scalar tensor from a native C# number.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="value">The single value stored in the tensor.</param>
        public static TensorProto MakeDouble(string name, double value)
        {
            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.Double;
            // No dims are written: a scalar has an empty shape.
            tensor.DoubleData.Add(value);
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX double tensor. With <paramref name="dims"/> null the result is a vector (1-D tensor)
        /// whose length is the number of values; otherwise <paramref name="dims"/> is used as the shape.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="values">Tensor elements, in order; must not be null.</param>
        /// <param name="dims">Optional explicit shape.</param>
        public static TensorProto MakeDoubles(string name, IEnumerable<double> values, IEnumerable<long> dims = null)
        {
            Contracts.CheckValue(values, nameof(values));

            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.Double;
            tensor.DoubleData.AddRange(values);
            if (dims != null)
                tensor.Dims.AddRange(dims);
            else
                tensor.Dims.Add(tensor.DoubleData.Count); // Use the stored count so values is enumerated only once.
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX float scalar tensor from a native C# number.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="value">The single value stored in the tensor.</param>
        public static TensorProto MakeFloat(string name, float value)
        {
            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.Float;
            // No dims are written: a scalar has an empty shape.
            tensor.FloatData.Add(value);
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX float tensor. With <paramref name="dims"/> null the result is a vector (1-D tensor)
        /// whose length is the number of values; otherwise <paramref name="dims"/> is used as the shape.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="values">Tensor elements, in order; must not be null.</param>
        /// <param name="dims">Optional explicit shape.</param>
        public static TensorProto MakeFloats(string name, IEnumerable<float> values, IEnumerable<long> dims = null)
        {
            Contracts.CheckValue(values, nameof(values));

            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.Float;
            tensor.FloatData.AddRange(values);
            if (dims != null)
                tensor.Dims.AddRange(dims);
            else
                tensor.Dims.Add(tensor.FloatData.Count); // Use the stored count so values is enumerated only once.
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX string scalar tensor from a native C# string.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="value">The single string, stored UTF-8 encoded.</param>
        public static TensorProto MakeString(string name, string value)
        {
            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.String;
            // No dims are written: a scalar has an empty shape.
            tensor.StringData.Add(StringToByteString(value));
            return tensor;
        }
        /// <summary>
        /// Makes an ONNX string tensor. With <paramref name="dims"/> null the result is a vector (1-D tensor)
        /// whose length is the number of values; otherwise <paramref name="dims"/> is used as the shape.
        /// </summary>
        /// <param name="name">Tensor name.</param>
        /// <param name="values">Tensor elements, in order, stored UTF-8 encoded; must not be null.</param>
        /// <param name="dims">Optional explicit shape.</param>
        public static TensorProto MakeStrings(string name, IEnumerable<string> values, IEnumerable<long> dims = null)
        {
            Contracts.CheckValue(values, nameof(values));

            var tensor = new TensorProto();
            tensor.Name = name;
            tensor.DataType = (int)TensorProto.Types.DataType.String;
            tensor.StringData.AddRange(StringToByteString(values));
            if (dims != null)
                tensor.Dims.AddRange(dims);
            else
                tensor.Dims.Add(tensor.StringData.Count); // Use the stored count so values is enumerated only once.
            return tensor;
        }