File: Model\Onnx\OnnxContext.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.Model.OnnxConverter
{
    /// <summary>
    /// The ONNX version an <see cref="OnnxContext"/> targets when a model is exported.
    /// </summary>
    [BestFriend]
    internal enum OnnxVersion
    {
        /// <summary>Target the stable ONNX version.</summary>
        Stable = 0,

        /// <summary>Target the experimental ONNX version.</summary>
        Experimental = 1,
    }
 
    /// <summary>
    /// A context for defining an ONNX output. The context internally contains the model-in-progress being built. This
    /// same context object is iteratively given to exportable components via the <see cref="ICanSaveOnnx"/> interface
    /// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is
    /// given to a component, all other components up to that component have already attempted to express themselves in
    /// this context, with their outputs possibly available in the ONNX graph.
    /// </summary>
    [BestFriend]
    internal abstract class OnnxContext
    {
        /// <summary>
        /// Generates a unique name for the node based on a prefix.
        /// </summary>
        /// <param name="prefix">The prefix for the node.</param>
        /// <returns>A name that has not yet been returned from this function, starting with <paramref name="prefix"/>.</returns>
        public abstract string GetNodeName(string prefix);

        /// <summary>
        /// Determines whether a string has already been used as an ONNX variable name somewhere in this context.
        /// </summary>
        /// <param name="variableName">The string to examine.</param>
        /// <returns>True if <paramref name="variableName"/> has been used to denote an ONNX variable; otherwise, false.</returns>
        public abstract bool IsVariableDefined(string variableName);

        /// <summary>
        /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
        /// safely call <see cref="GetVariableName(string)"/>.
        /// </summary>
        /// <param name="colName">The data view column name.</param>
        /// <returns>Whether the column is mapped in this context.</returns>
        public abstract bool ContainsColumn(string colName);

        /// <summary>
        /// Checks that the ONNX OpSet version targeted by this context satisfies the minimum OpSet version required
        /// by a transformer. How an unsatisfied requirement is reported is up to the implementation.
        /// </summary>
        /// <param name="thisTransformerMinumumOpSetVersion">The minimum ONNX OpSet version the transformer requires.</param>
        /// <param name="registerTransformerName">The name of the transformer whose requirement is being checked.</param>
        /// <remarks>
        /// The spelling of <paramref name="thisTransformerMinumumOpSetVersion"/> is kept as-is so that callers passing
        /// it as a named argument continue to compile.
        /// </remarks>
        public abstract void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName);

        /// <summary>
        /// Stops tracking a column.
        /// </summary>
        /// <param name="colName">Column name to stop tracking.</param>
        /// <param name="removeVariable">Remove the associated ONNX variable. This is useful in the event where an output
        /// variable is created through <see cref="AddIntermediateVariable(DataViewType, string, bool)"/> before realizing
        /// the transform cannot actually save as ONNX.</param>
        public abstract void RemoveColumn(string colName, bool removeVariable = false);

        /// <summary>
        /// Removes an ONNX variable. If <paramref name="removeColumn"/> is true then it also removes the tracking for
        /// the <see cref="IDataView"/> column associated with it.
        /// </summary>
        /// <param name="variableName">ONNX variable to remove. Note that this is an ONNX variable name, not an <see
        /// cref="IDataView"/> column name.</param>
        /// <param name="removeColumn">Whether to also stop tracking the <see cref="IDataView"/> column associated with
        /// <paramref name="variableName"/>.</param>
        public abstract void RemoveVariable(string variableName, bool removeColumn);

        /// <summary>
        /// Removes a variable from the input columns list. This function is used only by the ColumnSelectingTransformer.
        /// </summary>
        /// <param name="variableName">ONNX variable to remove.</param>
        public abstract void RemoveInputVariable(string variableName);

        /// <summary>
        /// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
        /// <see cref="IDataView"/>'s column names will map to a variable in the ONNX graph if the intermediate steps
        /// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps
        /// to the <see cref="IDataView"/> column name at a given point in the pipeline execution. Callers should
        /// probably confirm with <see cref="ContainsColumn(string)"/> whether a mapping for that data view column
        /// already exists.
        /// </summary>
        /// <param name="colName">The data view column name.</param>
        /// <returns>The ONNX variable name corresponding to that data view column.</returns>
        public abstract string GetVariableName(string colName);

        /// <summary>
        /// Establishes a new mapping from a data view column in the context, generating a unique ONNX variable name if
        /// necessary, and returns that newly allocated name.
        /// </summary>
        /// <param name="type">The data view type associated with this column name.</param>
        /// <param name="colName">The data view column name.</param>
        /// <param name="skip">Whether we should skip the process of establishing the mapping from data view column to
        /// ONNX variable name.</param>
        /// <returns>The name of the ONNX variable corresponding to <paramref name="colName"/>.</returns>
        public abstract string AddIntermediateVariable(DataViewType type, string colName, bool skip = false);

        /// <summary>
        /// Creates an ONNX node.
        /// </summary>
        /// <param name="opType">The name of the ONNX operator to apply.</param>
        /// <param name="inputs">The names of the variables as inputs.</param>
        /// <param name="outputs">The names of the variables to create as outputs,
        /// which ought to have been something returned from <see cref="AddIntermediateVariable(DataViewType, string, bool)"/>.</param>
        /// <param name="name">The name of the operator, which ought to be something returned from <see cref="GetNodeName(string)"/>.</param>
        /// <param name="domain">The domain of the ONNX operator, if non-default.</param>
        /// <returns>A node added to the in-progress ONNX graph, that attributes can be set on.</returns>
        public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
            IEnumerable<string> outputs, string name, string domain = null);

        /// <summary>
        /// Convenience alternative to <see cref="CreateNode(string, IEnumerable{string}, IEnumerable{string}, string, string)"/>
        /// for the case where there is exactly one input and output.
        /// </summary>
        /// <param name="opType">The name of the ONNX operator to apply.</param>
        /// <param name="input">The name of the variable as input.</param>
        /// <param name="output">The name of the variable as output,
        /// which ought to have been something returned from <see cref="OnnxContext.AddIntermediateVariable(DataViewType, string, bool)"/>.</param>
        /// <param name="name">The name of the operator, which ought to be something returned from <see cref="OnnxContext.GetNodeName(string)"/>.</param>
        /// <param name="domain">The domain of the ONNX operator, if non-default.</param>
        /// <returns>A node added to the in-progress ONNX graph, that attributes can be set on.</returns>
        public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
            => CreateNode(opType, new[] { input }, new[] { output }, name, domain);

        /// <summary>
        /// Gets the ONNX version targeted by this context. Only two values are currently allowed:
        /// <see cref="OnnxVersion.Stable"/> and <see cref="OnnxVersion.Experimental"/>.
        /// </summary>
        /// <returns>The targeted <see cref="OnnxVersion"/>.</returns>
        public abstract OnnxVersion GetOnnxVersion();

        /// <summary>
        /// Retrieves the shape of an ONNX variable. Returns null if no shape for the specified variable can be found.
        /// </summary>
        /// <param name="variableName">The ONNX name of the variable whose shape is retrieved.</param>
        /// <returns>The shape of the variable, or null if no shape for it can be found.</returns>
        public abstract List<long> RetrieveShapeOrNull(string variableName);

        /// <summary>
        /// Call this function to declare a global bool scalar.
        /// </summary>
        /// <param name="value">The boolean value which is going to be added into the ONNX graph.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(bool value, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global float scalar.
        /// </summary>
        /// <param name="value">The float number which is going to be added into the ONNX graph.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(float value, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global scalar of type Int32 or a smaller integer type.
        /// </summary>
        /// <param name="value">The integer value which is going to be added into the ONNX graph.</param>
        /// <param name="type">The type of integer to be added, e.g. typeof(short). Use this for all integer types Int32 and smaller.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(int value, Type type, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global string scalar.
        /// </summary>
        /// <param name="value">The string which is going to be added into the ONNX graph.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(string value, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global long scalar.
        /// </summary>
        /// <param name="value">The long number which is going to be added into the ONNX graph.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global double scalar.
        /// </summary>
        /// <param name="value">The double number which is going to be added into the ONNX graph.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(double value, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global ulong or uint scalar.
        /// </summary>
        /// <param name="value">The unsigned integer value which is going to be added into the ONNX graph.</param>
        /// <param name="isUint64">True if <paramref name="value"/> contains a ulong value and false if it contains a uint value.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(ulong value, bool isUint64, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global bool tensor.
        /// </summary>
        /// <param name="values">The boolean values which are going to be added into the ONNX graph.</param>
        /// <param name="dims">The shape of <paramref name="values"/>.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(IEnumerable<bool> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global float tensor.
        /// </summary>
        /// <param name="values">The floats which are going to be added into the ONNX graph.</param>
        /// <param name="dims">The shape of <paramref name="values"/>.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global tensor of type Int32 or a smaller integer type.
        /// </summary>
        /// <param name="values">The ints which are going to be added into the ONNX graph.</param>
        /// <param name="type">The type of ints which are going to be added into the ONNX graph, e.g. typeof(short).
        /// Use this for adding array initializers of integer types Int32 and smaller.</param>
        /// <param name="dims">The shape of <paramref name="values"/>.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(IEnumerable<int> values, Type type, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global string tensor.
        /// </summary>
        /// <param name="values">The strings which are going to be added into the ONNX graph.</param>
        /// <param name="dims">The shape of <paramref name="values"/>.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global long tensor.
        /// </summary>
        /// <param name="values">The longs which are going to be added into the ONNX graph.</param>
        /// <param name="dims">The shape of <paramref name="values"/>.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global double tensor.
        /// </summary>
        /// <param name="values">The doubles which are going to be added into the ONNX graph.</param>
        /// <param name="dims">The shape of <paramref name="values"/>.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(IEnumerable<double> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

        /// <summary>
        /// Call this function to declare a global ulong or uint tensor.
        /// </summary>
        /// <param name="values">The unsigned integers which are going to be added into the ONNX graph.</param>
        /// <param name="isUint64">Set to true if <paramref name="values"/> contain ulong values and false if they contain uint values.</param>
        /// <param name="dims">The shape of <paramref name="values"/>.</param>
        /// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
        /// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
        /// <returns>The initializer's ONNX name.</returns>
        public abstract string AddInitializer(IEnumerable<ulong> values, bool isUint64, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
    }
}