File: OnnxContextImpl.cs
Web Access
Project: src\src\Microsoft.ML.OnnxConverter\Microsoft.ML.OnnxConverter.csproj (Microsoft.ML.OnnxConverter)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Model.OnnxConverter
{
    /// <summary>
    /// A context for building an ONNX model. It tracks the graph's nodes, inputs, outputs,
    /// intermediate values and initializers, plus the mapping from IDataView column names to
    /// ONNX variable names. <see cref="MakeModel"/> assembles these into a model proto.
    /// </summary>
    internal sealed class OnnxContextImpl : OnnxContext
    {
        private const int CurrentOpSetVersion = 12;
        private const int MinimumOpSetVersion = 9;
        private readonly List<OnnxCSharpToProtoWrapper.NodeProto> _nodes;
        private readonly List<OnnxUtils.ModelArgs> _inputs;
        private readonly List<OnnxCSharpToProtoWrapper.TensorProto> _initializers;
        private readonly List<OnnxUtils.ModelArgs> _intermediateValues;
        private readonly List<OnnxUtils.ModelArgs> _outputs;
        // The map from IDataView column names to ONNX variable names.
        private readonly Dictionary<string, string> _columnNameMap;
        // All existing ONNX variable names. New variables must not already exist in this set.
        private readonly HashSet<string> _variableNames;
        // All existing node names. New node names must not already exist in this set.
        private readonly HashSet<string> _nodeNames;
        private readonly string _name;
        private readonly string _producerName;
        private readonly IHost _host;
        private readonly string _domain;
        private readonly string _producerVersion;
        private readonly long _modelVersion;
        private readonly OnnxVersion _onnxVersion;
        private readonly int _opSetVersion;

        /// <summary>
        /// Creates an empty ONNX export context.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="name">Graph name; must not be null.</param>
        /// <param name="producerName">Producer name recorded in the model.</param>
        /// <param name="producerVersion">Producer version recorded in the model.</param>
        /// <param name="modelVersion">Model version recorded in the model.</param>
        /// <param name="domain">Model domain; must not be null.</param>
        /// <param name="onnxVersion">Returned by <see cref="GetOnnxVersion"/>.</param>
        /// <param name="opSetVersion">Target OpSet; must lie in [MinimumOpSetVersion, CurrentOpSetVersion].</param>
        public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
            string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion, int opSetVersion = CurrentOpSetVersion)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(OnnxContext));
            _host.CheckValue(name, nameof(name));
            // Previously this re-checked `name` under the `domain` label, so a null domain slipped through.
            _host.CheckValue(domain, nameof(domain));

            if (opSetVersion > CurrentOpSetVersion)
                throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is higher than the current most updated OpSet version {CurrentOpSetVersion}");
            if (opSetVersion < MinimumOpSetVersion)
                throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is lower than the minimum required OpSet version {MinimumOpSetVersion}");

            _nodes = new List<OnnxCSharpToProtoWrapper.NodeProto>();
            _intermediateValues = new List<OnnxUtils.ModelArgs>();
            _inputs = new List<OnnxUtils.ModelArgs>();
            _initializers = new List<OnnxCSharpToProtoWrapper.TensorProto>();
            _outputs = new List<OnnxUtils.ModelArgs>();
            _columnNameMap = new Dictionary<string, string>();
            _variableNames = new HashSet<string>();
            _nodeNames = new HashSet<string>();
            _name = name;
            _producerName = producerName;
            _producerVersion = producerVersion;
            _modelVersion = modelVersion;
            _domain = domain;
            _onnxVersion = onnxVersion;
            _opSetVersion = opSetVersion;
        }

        public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);

        public override bool IsVariableDefined(string variableName) => _variableNames.Contains(variableName);

        /// <summary>
        /// Stops tracking a column. If removeVariable is true then it also removes the intermediate
        /// value info recorded for the variable associated with it. This is useful in the event where
        /// an output variable is created before realizing the transform cannot actually save as ONNX.
        /// The variable's name stays reserved in the set of existing variable names.
        /// </summary>
        /// <param name="colName">IDataView column name to stop tracking. Untracked names are ignored.</param>
        /// <param name="removeVariable">Remove associated ONNX intermediate value at the same time.</param>
        public override void RemoveColumn(string colName, bool removeVariable)
        {
            _host.CheckNonEmpty(colName, nameof(colName));

            // Look the mapping up once. An untracked column has no variable to remove (the old code
            // threw KeyNotFoundException here, but only when intermediate values existed).
            if (removeVariable && _columnNameMap.TryGetValue(colName, out string variableName))
                RemoveIntermediateValue(variableName);

            _columnNameMap.Remove(colName);
        }

        /// <summary>
        /// Removes an ONNX variable. The IDataView column mapped to it always stops being tracked,
        /// and the variable name is released so it can be reused.
        /// </summary>
        /// <param name="variableName">ONNX variable to remove; must be mapped from some column.</param>
        /// <param name="removeColumn">Also remove the intermediate value info recorded for the variable.</param>
        public override void RemoveVariable(string variableName, bool removeColumn)
        {
            _host.CheckNonEmpty(variableName, nameof(variableName));
            if (!_columnNameMap.ContainsValue(variableName))
                throw _host.ExceptParam(nameof(variableName), $"Could not find '{variableName}' declared in ONNX graph");

            if (removeColumn)
                RemoveIntermediateValue(variableName);

            string columnName = _columnNameMap.Single(kvp => kvp.Value == variableName).Key;

            // _variableNames holds ONNX variable names, not column names. The two differ whenever
            // AddVariable had to uniquify the name, so the variable name is what must be released.
            _host.Assert(_variableNames.Contains(variableName));

            _columnNameMap.Remove(columnName);
            _variableNames.Remove(variableName);
        }

        /// <summary>
        /// Removes the first intermediate value whose name is <paramref name="variableName"/>, if any.
        /// </summary>
        private void RemoveIntermediateValue(string variableName)
        {
            int index = _intermediateValues.FindIndex(val => val.Name == variableName);
            if (index >= 0)
                _intermediateValues.RemoveAt(index);
        }

        /// <summary>
        /// Generates a unique name for the node based on a prefix. The name is not reserved until a
        /// node with it is added (see <see cref="CreateNode"/>).
        /// </summary>
        public override string GetNodeName(string prefix)
        {
            _host.CheckNonEmpty(prefix, nameof(prefix));
            return GetUniqueName(prefix, _nodeNames.Contains);
        }

        /// <summary>
        /// Throws if the OpSet version requested for this context is lower than the minimum
        /// OpSet version required by the named transformer.
        /// </summary>
        public override void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName)
        {
            if (_opSetVersion < thisTransformerMinumumOpSetVersion)
                throw _host.ExceptParam(nameof(thisTransformerMinumumOpSetVersion), $"Requested OpSet version {_opSetVersion} is lower than {registerTransformerName}'s minimum OpSet version requirement: {thisTransformerMinumumOpSetVersion}");
        }

        /// <summary>
        /// Adds a node to the node list of the graph and reserves its name.
        /// </summary>
        /// <param name="node">Node to add; its name must not already be in use.</param>
        private void AddNode(OnnxCSharpToProtoWrapper.NodeProto node)
        {
            _host.CheckValue(node, nameof(node));
            _host.Assert(!_nodeNames.Contains(node.Name));

            _nodeNames.Add(node.Name);
            _nodes.Add(node);
        }

        /// <summary>
        /// Creates a node with the given operator, inputs and outputs, adds it to the graph,
        /// and returns a wrapper around it.
        /// </summary>
        public override OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
            IEnumerable<string> outputs, string name, string domain = null)
        {
            _host.CheckNonEmpty(opType, nameof(opType));
            _host.CheckValue(inputs, nameof(inputs));
            _host.CheckValue(outputs, nameof(outputs));
            _host.CheckNonEmpty(name, nameof(name));

            var innerNode = OnnxUtils.MakeNode(opType, inputs, outputs, name, domain);
            AddNode(innerNode);
            return new OnnxNodeImpl(innerNode);
        }

        /// <summary>
        /// Generates a unique name based on a prefix. Returns <paramref name="prefix"/> itself when
        /// <paramref name="pred"/> reports it unused; otherwise returns prefix followed by the
        /// smallest non-negative integer that makes it unused.
        /// </summary>
        /// <param name="prefix">Base name.</param>
        /// <param name="pred">Returns true when a candidate name is already taken.</param>
        private string GetUniqueName(string prefix, Func<string, bool> pred)
        {
            _host.CheckNonEmpty(prefix, nameof(prefix));
            _host.CheckValue(pred, nameof(pred));

            if (!pred(prefix))
                return prefix;

            int suffix = 0;
            while (pred(prefix + suffix))
                suffix++;
            return prefix + suffix;
        }

        /// <summary>
        /// Retrieves the variable name that maps to the IDataView column name at a
        /// given point in the pipeline execution. The column is expected to be tracked.
        /// </summary>
        /// <returns>Column Name mapping.</returns>
        public override string GetVariableName(string colName)
        {
            _host.CheckNonEmpty(colName, nameof(colName));
            _host.Assert(_columnNameMap.ContainsKey(colName));

            return _columnNameMap[colName];
        }

        /// <summary>
        /// Retrieves the variable name that maps to the IDataView column name at a
        /// given point in the pipeline execution, or null if the column is not tracked.
        /// </summary>
        /// <returns>Column Name mapping, or null.</returns>
        public string TryGetVariableName(string colName)
        {
            _host.CheckNonEmpty(colName, nameof(colName));
            return _columnNameMap.TryGetValue(colName, out string variableName) ? variableName : null;
        }

        /// <summary>
        /// Generates a unique variable name based on the IDataView column name if
        /// there is a collision between names in the pipeline at any point. It then maps the column to it
        /// (replacing any previous mapping for that column) and reserves the name.
        /// </summary>
        /// <param name="colName">IDataView column name.</param>
        /// <param name="makeUniqueName">Whether a unique name should be chosen for this variable.</param>
        /// <returns>Unique variable name.</returns>
        public string AddVariable(string colName, bool makeUniqueName = true)
        {
            _host.CheckNonEmpty(colName, nameof(colName));
            string variableName = makeUniqueName ? GetUniqueName(colName, _variableNames.Contains) : colName;
            _columnNameMap[colName] = variableName;
            _variableNames.Add(variableName);
            return variableName;
        }

        /// <summary>
        /// Adds an intermediate column to the list and returns the ONNX variable name chosen for it.
        /// When <paramref name="skip"/> is true the variable is registered but no value info is recorded.
        /// </summary>
        public override string AddIntermediateVariable(DataViewType type, string colName, bool skip = false)
        {
            // Validate before registering the name so a failed call does not leave a dangling variable.
            if (!skip)
                _host.CheckValue(type, nameof(type));

            colName = AddVariable(colName);
            // Let the runtime figure the shape.
            if (!skip)
                _intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName));
            return colName;
        }

        /// <summary>
        /// Adds an output variable to the list. The variable must already be defined.
        /// </summary>
        public void AddOutputVariable(DataViewType type, string variableName, List<long> dim = null)
        {
            _host.CheckValue(type, nameof(type));
            _host.CheckParam(IsVariableDefined(variableName), nameof(variableName));
            _outputs.Add(OnnxUtils.GetModelArgs(type, variableName, dim));
        }

        /// <summary>
        /// Adds an input variable to the list, allocating a (possibly uniquified) variable name for the column.
        /// </summary>
        public void AddInputVariable(DataViewType type, string colName)
        {
            _host.CheckValue(type, nameof(type));
            _host.CheckValue(colName, nameof(colName));

            colName = AddVariable(colName);
            _inputs.Add(OnnxUtils.GetModelArgs(type, colName));
        }

        /// <summary>
        /// Removes the graph input backing the given column, along with its variable and column mapping.
        /// </summary>
        public override void RemoveInputVariable(string colName)
        {
            var variableName = TryGetVariableName(colName);
            _host.CheckValue(variableName, nameof(variableName));

            RemoveVariable(variableName, true);
            _inputs.Remove(_inputs.Single(modelArg => modelArg.Name == variableName));
        }

        /// <summary>
        /// Retrieve the shape of an ONNX variable. Returns null if no shape for the specified variable can be found.
        /// Inputs are searched first, then intermediate values, then outputs.
        /// </summary>
        /// <param name="variableName">The ONNX name of the returned shape</param>
        /// <returns>The shape of the retrieved variable</returns>
        public override List<long> RetrieveShapeOrNull(string variableName)
        {
            foreach (var arg in _inputs)
                if (arg.Name == variableName)
                    return arg.Dims;

            foreach (var arg in _intermediateValues)
                if (arg.Name == variableName)
                    return arg.Dims;

            foreach (var arg in _outputs)
                if (arg.Name == variableName)
                    return arg.Dims;

            return null;
        }

        /// <summary>
        /// Verifies that a tensor shape, when given, describes exactly as many elements as supplied.
        /// The product is seeded with 1, so an empty <paramref name="dims"/> (a rank-0 tensor) matches a
        /// single value instead of throwing as the unseeded Aggregate did.
        /// </summary>
        private void CheckTensorSize<T>(IEnumerable<T> values, IEnumerable<long> dims)
        {
            if (dims != null)
                _host.Check(dims.Aggregate(1L, (x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
        }

        /// <summary>
        /// Adds a constant boolean tensor into the graph, stored as an int32 initializer (1 or 0).
        /// </summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(bool value, string name = null, bool makeUniqueName = true)
        {
            name = AddVariable(name ?? "bool", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeInt32(name, typeof(bool), value ? 1 : 0));
            return name;
        }

        /// <summary>Adds a constant float tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(float value, string name = null, bool makeUniqueName = true)
        {
            name = AddVariable(name ?? "float", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeFloat(name, value));
            return name;
        }

        /// <summary>Adds a constant int32-backed tensor of the given element type into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(int value, Type type, string name = null, bool makeUniqueName = true)
        {
            name = AddVariable(name ?? "int32", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeInt32(name, type, value));
            return name;
        }

        /// <summary>Adds a constant string tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(string value, string name = null, bool makeUniqueName = true)
        {
            name = AddVariable(name ?? "string", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeString(name, value));
            return name;
        }

        /// <summary>Adds a constant int64 tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(long value, string name = null, bool makeUniqueName = true)
        {
            name = AddVariable(name ?? "int64", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeInt64(name, value));
            return name;
        }

        /// <summary>Adds a constant double tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(double value, string name = null, bool makeUniqueName = true)
        {
            name = AddVariable(name ?? "double", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeDouble(name, value));
            return name;
        }

        /// <summary>Adds a constant unsigned tensor into the graph (uint64 or uint32 per <paramref name="isUint64"/>).</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(ulong value, bool isUint64, string name = null, bool makeUniqueName = true)
        {
            name = AddVariable(name ?? "uint64", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeUInt(name, isUint64, value));
            return name;
        }

        /// <summary>Adds a constant boolean tensor (stored as int32 values) into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(IEnumerable<bool> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));
            CheckTensorSize(values, dims);

            name = AddVariable(name ?? "bools", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeInt32s(name, typeof(bool), values.Select(v => Convert.ToInt32(v)), dims));
            return name;
        }

        /// <summary>Adds a constant float tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));
            CheckTensorSize(values, dims);

            name = AddVariable(name ?? "floats", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeFloats(name, values, dims));
            return name;
        }

        /// <summary>Adds a constant int32-backed tensor of the given element type into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(IEnumerable<int> values, Type type, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));
            CheckTensorSize(values, dims);

            name = AddVariable(name ?? "int32s", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeInt32s(name, type, values, dims));
            return name;
        }

        /// <summary>Adds a constant string tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));
            CheckTensorSize(values, dims);

            name = AddVariable(name ?? "strings", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeStrings(name, values, dims));
            return name;
        }

        /// <summary>Adds a constant int64 tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));
            CheckTensorSize(values, dims);

            name = AddVariable(name ?? "int64s", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeInt64s(name, values, dims));
            return name;
        }

        /// <summary>Adds a constant double tensor into the graph.</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(IEnumerable<double> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));
            CheckTensorSize(values, dims);

            name = AddVariable(name ?? "doubles", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeDoubles(name, values, dims));
            return name;
        }

        /// <summary>Adds a constant unsigned tensor into the graph (uint64 or uint32 per <paramref name="isUint64"/>).</summary>
        /// <returns>The ONNX variable name of the initializer.</returns>
        public override string AddInitializer(IEnumerable<ulong> values, bool isUint64, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));
            CheckTensorSize(values, dims);

            name = AddVariable(name ?? "uints", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeUInts(name, isUint64, values, dims));
            return name;
        }

        /// <summary>
        /// Makes the ONNX model based on the context.
        /// </summary>
        public OnnxCSharpToProtoWrapper.ModelProto MakeModel()
            => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _opSetVersion, _inputs, _outputs, _intermediateValues, _initializers);

        /// <summary>
        /// Returns the <see cref="OnnxVersion"/> this context was created with. Presumably
        /// <c>Experimental</c> allows features that are not part of the official ONNX standard,
        /// while <c>Stable</c> restricts export to official ONNX features (TODO: confirm against
        /// the enum's definition).
        /// </summary>
        public override OnnxVersion GetOnnxVersion() => _onnxVersion;
    }
}