|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Microsoft.ML.EntryPoints
{
/// <summary>
/// JSON converter for entry point variable markers (<see cref="IVarSerializationHelper"/>).
/// A plain variable is written as the string "$name"; a helper carrying a list of names
/// is written as an array of "$name" strings. Reading is not supported.
/// </summary>
internal class VarSerializer : JsonConverter
{
    /// <inheritdoc/>
    public override bool CanRead => false;

    /// <inheritdoc/>
    public override bool CanConvert(Type objectType)
        => typeof(IVarSerializationHelper).IsAssignableFrom(objectType);

    /// <inheritdoc/>
    public override object ReadJson(JsonReader reader, Type objectType, object existingValue, JsonSerializer serializer)
        => throw Contracts.ExceptNotImpl();

    /// <summary>
    /// Writes the "$"-prefixed reference(s) held by <paramref name="value"/>.
    /// </summary>
    public override void WriteJson(JsonWriter writer, object value, JsonSerializer serializer)
    {
        var helper = value as IVarSerializationHelper;
        Contracts.AssertValue(helper);

        if (helper.IsValue)
        {
            // A list of variables: emit one "$name" entry per referenced variable.
            var references = helper.Values.Select(item => "$" + item);
            serializer.Serialize(writer, references);
            return;
        }

        serializer.Serialize(writer, "$" + helper.VarName);
    }
}
/// <summary>
/// Common view over the variable marker classes that <see cref="VarSerializer"/> uses to
/// decide how a marker is written to JSON.
/// </summary>
internal interface IVarSerializationHelper
{
    /// <summary>The name of the single variable this marker refers to.</summary>
    string VarName { get; set; }

    /// <summary>
    /// True when the marker holds a list of variable names in <see cref="Values"/>
    /// rather than referring to the single variable <see cref="VarName"/>.
    /// </summary>
    bool IsValue { get; }

    /// <summary>The variable names held by the marker when <see cref="IsValue"/> is true.</summary>
    string[] Values { get; }
}
/// <summary>
/// Marker class for the arguments that can be used as variables
/// in an entry point graph. It is serialized as a single "$name" reference.
/// </summary>
[JsonConverter(typeof(VarSerializer))]
[BestFriend]
internal sealed class Var<T> : IVarSerializationHelper
{
    /// <summary>Creates a variable with a fresh GUID-based name.</summary>
    public Var()
    {
        Contracts.Assert(CheckType(typeof(T)));
        VarName = "Var_" + Guid.NewGuid().ToString("N");
    }

    /// <summary>The name this variable is referenced by in the graph.</summary>
    public string VarName { get; set; }

    // A plain variable never carries a list of names.
    bool IVarSerializationHelper.IsValue => false;
    string[] IVarSerializationHelper.Values => null;

    /// <summary>
    /// Whether <paramref name="type"/> can be carried by a graph variable: one of the supported
    /// graph types, or an array / string-keyed dictionary (possibly nested) of such a type.
    /// </summary>
    public static bool CheckType(Type type)
    {
        var candidate = type;
        while (true)
        {
            if (candidate.IsArray)
            {
                candidate = candidate.GetElementType();
            }
            else if (candidate.IsGenericType
                && candidate.GetGenericTypeDefinition() == typeof(Dictionary<,>)
                && candidate.GetGenericTypeArgumentsEx()[0] == typeof(string))
            {
                candidate = candidate.GetGenericTypeArgumentsEx()[1];
            }
            else
            {
                break;
            }
        }

        return candidate == typeof(IDataView)
            || candidate == typeof(IFileHandle)
            || candidate == typeof(PredictorModel)
            || candidate == typeof(TransformModel)
            || candidate == typeof(CommonInputs.IEvaluatorInput)
            || candidate == typeof(CommonOutputs.IEvaluatorOutput);
    }
}
/// <summary>
/// Marker class for the arguments that can be used as array output variables
/// in an entry point graph.
/// </summary>
[JsonConverter(typeof(VarSerializer))]
[BestFriend]
internal sealed class ArrayVar<T> : IVarSerializationHelper
{
    // Set only by the list constructor: the marker then serializes as a list of names.
    private readonly bool _isValue;
    private readonly string[] _values;

    /// <summary>Creates an array variable with a fresh GUID-based name.</summary>
    public ArrayVar()
    {
        Contracts.Assert(Var<T>.CheckType(typeof(T)));
        VarName = "Var_" + Guid.NewGuid().ToString("N");
    }

    /// <summary>
    /// Creates a marker that stands for the array formed from <paramref name="variables"/>;
    /// it serializes as the list of their names. <see cref="VarName"/> is left unset.
    /// </summary>
    public ArrayVar(params Var<T>[] variables)
    {
        Contracts.Assert(Var<T>.CheckType(typeof(T)));
        _values = Array.ConvertAll(variables, item => item.VarName);
        _isValue = true;
    }

    /// <summary>The name this variable is referenced by in the graph.</summary>
    public string VarName { get; set; }

    bool IVarSerializationHelper.IsValue => _isValue;
    string[] IVarSerializationHelper.Values => _values;

    /// <summary>A variable referring to element <paramref name="i"/>, named "VarName[i]".</summary>
    public Var<T> this[int i] => new Var<T> { VarName = $"{VarName}[{i}]" };
}
/// <summary>
/// Marker class for the arguments that can be used as dictionary output variables
/// in an entry point graph.
/// </summary>
[JsonConverter(typeof(VarSerializer))]
internal sealed class DictionaryVar<T> : IVarSerializationHelper
{
    /// <summary>Creates a dictionary variable with a fresh GUID-based name.</summary>
    public DictionaryVar()
    {
        Contracts.Assert(Var<T>.CheckType(typeof(T)));
        VarName = "Var_" + Guid.NewGuid().ToString("N");
    }

    /// <summary>The name this variable is referenced by in the graph.</summary>
    public string VarName { get; set; }

    // A dictionary variable is always a single named reference.
    bool IVarSerializationHelper.IsValue => false;
    string[] IVarSerializationHelper.Values => null;

    /// <summary>A variable referring to the entry for <paramref name="key"/>, named "VarName["key"]".</summary>
    public Var<T> this[string key] => new Var<T> { VarName = $"{VarName}[\"{key}\"]" };
}
/// <summary>
/// A descriptor of one 'variable' of the graph (input or output that is referenced as a $variable in the graph definition).
/// </summary>
[BestFriend]
internal sealed class EntryPointVariable
{
    private readonly IExceptionContext _ectx;

    /// <summary>The name of the variable.</summary>
    public readonly string Name;

    /// <summary>The declared type of the variable; expected to satisfy <see cref="IsValidType"/>.</summary>
    public readonly Type Type;

    /// <summary>
    /// The value. It starts out as null and is assigned once it is available.
    /// The type is one of the valid types according to <see cref="IsValidType"/>.
    /// </summary>
    public object Value { get; private set; }

    /// <summary>Set by <see cref="MarkUsage"/>(true): the variable is used as an input.</summary>
    public bool HasInputs { get; private set; }

    /// <summary>Set by <see cref="MarkUsage"/>(false): the variable is used as an output.</summary>
    public bool HasOutputs { get; private set; }

    /// <summary>Whether <see cref="SetValue"/> has been called.</summary>
    public bool IsValueSet { get; private set; }

    /// <summary>
    /// Whether the given type is a valid one to be a variable: an evaluator input/output interface,
    /// a data view, file handle, predictor model or transform model, or an array / string-keyed
    /// dictionary of valid types. <see cref="Optional{T}"/> types are rejected.
    /// </summary>
    public static bool IsValidType(Type variableType)
    {
        Contracts.CheckValue(variableType, nameof(variableType));

        // Option types should not be used to construct graph.
        if (variableType.IsGenericType && variableType.GetGenericTypeDefinition() == typeof(Optional<>))
            return false;

        if (variableType == typeof(CommonInputs.IEvaluatorInput))
            return true;
        if (variableType == typeof(CommonOutputs.IEvaluatorOutput))
            return true;

        var kind = TlcModule.GetDataType(variableType);
        if (kind == TlcModule.DataKind.Array)
        {
            if (!variableType.IsArray)
            {
                Contracts.Assert(false, "Unexpected type for array variable");
                return false;
            }
            return IsValidType(variableType.GetElementType());
        }

        if (kind == TlcModule.DataKind.Dictionary)
        {
            Contracts.Assert(variableType.IsGenericType && variableType.GetGenericTypeDefinition() == typeof(Dictionary<,>)
                && variableType.GetGenericTypeArgumentsEx()[0] == typeof(string));
            return IsValidType(variableType.GetGenericTypeArgumentsEx()[1]);
        }

        return kind == TlcModule.DataKind.DataView
            || kind == TlcModule.DataKind.FileHandle
            || kind == TlcModule.DataKind.PredictorModel
            || kind == TlcModule.DataKind.TransformModel;
    }

    /// <param name="ectx">Exception context for asserts; may be null.</param>
    /// <param name="name">Non-empty variable name.</param>
    /// <param name="type">Variable type; asserted to satisfy <see cref="IsValidType"/>.</param>
    public EntryPointVariable(IExceptionContext ectx, string name, Type type)
    {
        Contracts.AssertValueOrNull(ectx);
        _ectx = ectx;
        _ectx.AssertNonEmpty(name);
        Name = name;
        ectx.Assert(IsValidType(type));
        Type = type;
    }

    /// <summary>
    /// Set the value. It is only allowed once.
    /// </summary>
    public void SetValue(object value)
    {
        _ectx.AssertValueOrNull(value);
        _ectx.Assert(!IsValueSet);
        _ectx.Assert(value == null || Type.IsAssignableFrom(value.GetType()));
        Value = value;
        IsValueSet = true;
    }

    /// <summary>
    /// Records that the variable is used as an input (<paramref name="isInput"/> true)
    /// or as an output (false). Flags are only ever set, never cleared.
    /// </summary>
    public void MarkUsage(bool isInput)
    {
        if (isInput)
            HasInputs = true;
        else
            HasOutputs = true;
    }

    /// <summary>
    /// Creates a copy of this variable under <paramref name="newName"/>, carrying over
    /// both usage flags and the value state.
    /// </summary>
    public EntryPointVariable Clone(string newName)
    {
        var v = new EntryPointVariable(_ectx, newName, Type);
        // Copy both usage flags verbatim. Routing through MarkUsage(HasInputs) would drop
        // HasOutputs for a variable that is both produced and consumed (making it look like
        // an unproduced input), and would set HasOutputs on a variable that had neither flag.
        v.HasInputs = HasInputs;
        v.HasOutputs = HasOutputs;
        v.IsValueSet = IsValueSet;
        v.Value = Value;
        return v;
    }
}
/// <summary>
/// A collection of all known variables, with an interface to add new variables, get values based on names etc.
/// This is populated by individual nodes when they parse their respective JSON definitions, and then the values are updated
/// during the node execution.
/// </summary>
[BestFriend]
internal sealed class RunContext
{
    private readonly Dictionary<string, EntryPointVariable> _vars;
    private readonly IExceptionContext _ectx;
    // Incremented by GenerateId so every generated node ID gets a distinct sequence number.
    private int _idCount;

    /// <param name="ectx">Exception context used for checks and errors; may be null.</param>
    public RunContext(IExceptionContext ectx)
    {
        Contracts.AssertValueOrNull(ectx);
        _ectx = ectx;
        _vars = new Dictionary<string, EntryPointVariable>();
    }

    /// <summary>Looks up a variable by name.</summary>
    public bool TryGetVariable(string name, out EntryPointVariable v)
    {
        return _vars.TryGetValue(name, out v);
    }

    /// <summary>
    /// Resolves <paramref name="binding"/> against the known variables; returns null when the
    /// variable does not exist, otherwise whatever the binding extracts from it.
    /// </summary>
    public object GetValueOrNull(VariableBinding binding)
    {
        _ectx.AssertValue(binding);
        if (!TryGetVariable(binding.VariableName, out var v))
            return null;
        return binding.GetVariableValueOrNull(v);
    }

    /// <summary>
    /// Registers (or re-uses) the variable referenced by <paramref name="binding"/> as an input of
    /// element type <paramref name="type"/>. Indexed bindings imply an array variable and keyed
    /// bindings imply a string-keyed dictionary variable.
    /// </summary>
    /// <exception>Thrown when the variable already exists with a different type.</exception>
    public void AddInputVariable(VariableBinding binding, Type type)
    {
        _ectx.AssertValue(binding);
        _ectx.AssertValue(type);

        if (binding is ArrayIndexVariableBinding)
            type = type.MakeArrayType();
        else if (binding is DictionaryKeyVariableBinding)
            type = typeof(Dictionary<,>).MakeGenericType(typeof(string), type);

        if (!_vars.TryGetValue(binding.VariableName, out var v))
        {
            v = new EntryPointVariable(_ectx, binding.VariableName, type);
            _vars[binding.VariableName] = v;
        }
        else if (v.Type != type)
            throw _ectx.Except($"Variable '{v.Name}' is used as {v.Type} and as {type}");

        v.MarkUsage(true);
    }

    /// <summary>Removes the variable with the same name as <paramref name="variable"/>, if present.</summary>
    public void RemoveVariable(EntryPointVariable variable)
    {
        _ectx.CheckValue(variable, nameof(variable));
        _vars.Remove(variable.Name);
    }

    /// <summary>
    /// Returns true if added new variable, false if variable already exists.
    /// </summary>
    /// <exception>Thrown when the variable already exists with a different type.</exception>
    public Boolean AddOutputVariable(string name, Type type)
    {
        _ectx.AssertNonEmpty(name);
        _ectx.AssertValue(type);

        if (_vars.TryGetValue(name, out var v))
        {
            if (v.Type != type)
                throw _ectx.Except($"Variable '{v.Name}' is used as {v.Type} and as {type}");
            return false;
        }

        v = new EntryPointVariable(_ectx, name, type);
        _vars[name] = v;
        v.MarkUsage(false);
        return true;
    }

    /// <summary>
    /// Names of variables that are consumed as inputs but are neither produced by any node
    /// nor given a value.
    /// </summary>
    public string[] GetMissingInputs()
    {
        return _vars.Values.Where(x => x.HasInputs && !x.HasOutputs && !x.IsValueSet)
            .Select(x => x.Name)
            .ToArray();
    }

    /// <summary>
    /// Returns a new node ID of the form "Node_NNN_name", with spaces in <paramref name="name"/>
    /// replaced by underscores.
    /// </summary>
    public string GenerateId(string name)
    {
        return $"Node_{_idCount++:000}_{name.Replace(" ", "_")}";
    }

    /// <summary>
    /// Merges all variables of <paramref name="subGraphRunContext"/> into this context.
    /// Either all variables are added, or (on a name clash) none are.
    /// </summary>
    /// <exception>Thrown when a variable name already exists in this context.</exception>
    public void AddContextVariables(RunContext subGraphRunContext)
    {
        _ectx.CheckValue(subGraphRunContext, nameof(subGraphRunContext));

        // Validate before mutating so a duplicate does not leave this context partially merged.
        foreach (var name in subGraphRunContext._vars.Keys)
        {
            if (_vars.ContainsKey(name))
                throw _ectx.Except($"Duplicate variable '{name}' in subgraph.");
        }

        foreach (var kvp in subGraphRunContext._vars)
            _vars.Add(kvp.Key, kvp.Value);
    }

    /// <summary>
    /// Replaces the variable <paramref name="oldName"/> by a clone named <paramref name="newName"/>.
    /// </summary>
    /// <exception>Thrown when <paramref name="newName"/> exists or <paramref name="oldName"/> does not.</exception>
    public void RenameContextVariable(string oldName, string newName)
    {
        if (_vars.ContainsKey(newName))
            throw _ectx.Except($"Variable with name '{newName}' already exists in subgraph.");
        if (!_vars.TryGetValue(oldName, out var existing))
            throw _ectx.Except($"Variable with name '{oldName}' not found in subgraph.");

        _vars.Add(newName, existing.Clone(newName));
        _vars.Remove(oldName);
    }

    /// <summary>
    /// Creates and registers a new output variable of type <typeparamref name="T"/> named
    /// "{varPrefix}_{k}", where k is the smallest non-negative integer not already in use.
    /// </summary>
    public EntryPointVariable CreateTempOutputVar<T>(string varPrefix)
    {
        _ectx.CheckValue(varPrefix, nameof(varPrefix));

        string name;
        int id = 0;
        do
        {
            name = $"{varPrefix}_{id}";
            id++;
        }
        while (_vars.ContainsKey(name));

        var v = new EntryPointVariable(_ectx, name, typeof(T));
        _vars[name] = v;
        v.MarkUsage(false);
        return v;
    }
}
/// <summary>
/// A representation of one graph node.
/// </summary>
[BestFriend]
internal sealed class EntryPointNode
{
// The unique node ID, generated at compilation.
public readonly string Id;
// Host registered under this node's ID; used for this node's asserts and exceptions.
private readonly IHost _host;
// Catalog entry of the entry point this node invokes (name, input/output types, method, input kinds).
private readonly ComponentCatalog.EntryPointInfo _entryPoint;
// Accumulates the entry point's input object from JSON literals and, at run time, bound variable values.
private readonly InputBuilder _inputBuilder;
// Extracts named output fields from the entry point's result and serializes the output map.
private readonly OutputHelper _outputHelper;
// Reference to the global run context.
private RunContext _context;
// Mapping of input parameter names to a list of ParameterBindings. This list
// will contain a single element when a variable is directly assigned to an input
// parameter. When an input parameter is assigned an array or dictionary of variable
// values this list will contain an entry for each needed assignment.
private readonly Dictionary<string, List<ParameterBinding>> _inputBindingMap;
// For each parameter binding, the variable binding that supplies its value.
private readonly Dictionary<ParameterBinding, VariableBinding> _inputMap;
// Maps output field names to the names of the variables that receive them.
// Outputs are simple: we can neither bind an indexed/keyed output value to a variable
// nor bind a value to a variable's index/key slot.
private readonly Dictionary<string, string> _outputMap;
// Set to true at the end of Run().
public bool IsFinished { get; private set; }
// Wall-clock time measured around the last Run().
public TimeSpan RunTime { get; internal set; }
// NOTE(review): this pattern is unanchored and '*' also matches the empty string, so
// Match(...).Success alone is true for every input string.
private static readonly Regex _stageIdRegex = new Regex(@"[a-zA-Z0-9]*", RegexOptions.Compiled);
// Backing field for StageId.
private string _stageId;
/// <summary>
/// An alphanumeric string indicating the stage of a node.
/// Nodes sharing the same stage ID are a hint that they should be executed together whenever possible.
/// The setter throws when <see cref="IsStageIdValid"/> rejects the value.
/// </summary>
public string StageId
{
    get => _stageId;
    set
    {
        if (IsStageIdValid(value))
        {
            _stageId = value;
            return;
        }
        throw _host.Except("Stage ID must be alphanumeric.");
    }
}

/// <summary>
/// Hints that the output of this node should be checkpointed.
/// </summary>
public bool Checkpoint { get; set; }

private float _cost;

/// <summary>
/// The cost of running this node. NaN indicates unknown; negative values are rejected.
/// </summary>
public float Cost
{
    get => _cost;
    // NaN compares false with everything, so it is explicitly accepted here (it is not "< 0").
    set => _cost = value >= 0 || float.IsNaN(value) ? value : throw _host.Except("Cost cannot be negative.");
}
/// <summary>
/// Builds a node from its JSON definition. It resolves the entry point by name, then binds or sets every
/// input and checks that no required input is missing. It then reconciles column-name arguments passed
/// from the enclosing macro and registers the node's output variables in <paramref name="context"/>.
/// </summary>
/// <param name="env">Environment used to register this node's host and to look up the entry point.</param>
/// <param name="ch">Channel used for warnings while reconciling column arguments.</param>
/// <param name="context">Run context that receives this node's input and output variables.</param>
/// <param name="id">Unique node ID; also the name the host is registered under.</param>
/// <param name="entryPointName">Catalog name of the entry point to invoke.</param>
/// <param name="inputs">JSON object of input values / variable references, or null.</param>
/// <param name="outputs">JSON object mapping output names to variable references, or null.</param>
/// <param name="checkpoint">Initial value of <see cref="Checkpoint"/>.</param>
/// <param name="stageId">Initial value of <see cref="StageId"/>; validated by its setter.</param>
/// <param name="cost">Initial value of <see cref="Cost"/>; NaN means unknown.</param>
/// <param name="label">Optional label column name to reconcile with the node's argument.</param>
/// <param name="group">Optional group ID column name to reconcile with the node's argument.</param>
/// <param name="weight">Optional example weight column name to reconcile with the node's argument.</param>
/// <param name="name">Optional name column to reconcile with the node's argument.</param>
/// <exception>Thrown when the entry point is unknown, a required input is missing, or an input/output is invalid.</exception>
private EntryPointNode(IHostEnvironment env, IChannel ch, RunContext context,
string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
string stageId = "", float cost = float.NaN, string label = null, string group = null, string weight = null,
string name = null)
{
Contracts.AssertValue(env);
env.AssertNonEmpty(id);
// Register the host first so the remaining checks report through it.
_host = env.Register(id);
_host.AssertValue(context);
_host.AssertNonEmpty(entryPointName);
_host.AssertValueOrNull(inputs);
_host.AssertValueOrNull(outputs);
_context = context;
Id = id;
if (!env.ComponentCatalog.TryFindEntryPoint(entryPointName, out _entryPoint))
throw _host.Except($"Entry point '{entryPointName}' not found");
// Validate inputs.
_inputMap = new Dictionary<ParameterBinding, VariableBinding>();
_inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
_inputBuilder = new InputBuilder(_host, _entryPoint.InputType, env.ComponentCatalog);
// REVIEW: This logic should move out of Node eventually and be delegated to
// a class that can nest to handle Components with variables.
if (inputs != null)
{
foreach (var pair in inputs)
CheckAndSetInputValue(pair);
}
// Inputs bound to variables are not considered missing, even though the builder has no value for them yet.
var missing = _inputBuilder.GetMissingValues().Except(_inputBindingMap.Keys).ToArray();
if (missing.Length > 0)
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");
var inputInstance = _inputBuilder.GetInstance();
// Reconcile role column names passed down by a macro with the node's own arguments. The trailing
// types restrict a role to entry points whose input kinds include at least one of them.
SetColumnArgument(ch, inputInstance, "LabelColumnName", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
SetColumnArgument(ch, inputInstance, "RowGroupColumnName", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
SetColumnArgument(ch, inputInstance, "ExampleWeightColumnName", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
SetColumnArgument(ch, inputInstance, "NameColumn", name, "name");
// Validate outputs.
_outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
_outputMap = new Dictionary<string, string>();
if (outputs != null)
{
foreach (var pair in outputs)
CheckAndMarkOutputValue(pair);
}
// StageId and Cost are validated by their property setters.
Checkpoint = checkpoint;
StageId = stageId;
Cost = cost;
}
/// <summary>
/// Reconciles a column name supplied by the enclosing macro (<paramref name="colName"/>) with the
/// column name already present in the node's input argument <paramref name="argName"/>.
/// If the two are equal, the value is (re)applied to the input builder. If they differ, the node's own
/// value is kept and a warning is issued. Nothing happens when the input type has no such argument,
/// when <paramref name="colName"/> is empty, or when the entry point does not implement any of
/// <paramref name="inputKinds"/> (an empty list applies to every entry point that declares input kinds).
/// </summary>
private void SetColumnArgument(IChannel ch, object inputInstance, string argName, string colName, string columnRole, params Type[] inputKinds)
{
    Contracts.AssertValue(ch);
    ch.AssertValue(inputInstance);
    ch.AssertNonEmpty(argName);
    ch.AssertValueOrNull(colName);
    ch.AssertNonEmpty(columnRole);
    ch.AssertValueOrNull(inputKinds);

    var colField = _inputBuilder.GetFieldNameOrNull(argName);
    if (string.IsNullOrEmpty(colField))
        return;

    // {0} = column role, {1} = column requested by the macro, {2} = column set on the trainer node.
    const string warning = "Different {0} column specified in macro and in trainer: '{1}', '{2}'." +
        " Using column '{2}'. To use column '{1}' instead, please specify this name in" +
        " the trainer node arguments.";

    if (!string.IsNullOrEmpty(colName) && Utils.Size(_entryPoint.InputKinds) > 0 &&
        (Utils.Size(inputKinds) == 0 || _entryPoint.InputKinds.Intersect(inputKinds).Any()))
    {
        ch.AssertNonEmpty(colField);
        var colFieldType = _inputBuilder.GetFieldTypeOrNull(colField);
        ch.Assert(colFieldType == typeof(string));
        var inputColName = inputInstance.GetType().GetField(colField).GetValue(inputInstance);
        ch.Assert(inputColName is string || inputColName is Optional<string>);
        var str = inputColName is string plain ? plain : ((Optional<string>)inputColName).Value;
        if (colName != str)
        {
            // Report the resolved string, not the raw field object (which may be an Optional<string>).
            ch.Warning(warning, columnRole, colName, str);
        }
        else
            _inputBuilder.TrySetValue(colField, colName);
    }
}
/// <summary>
/// Creates a node for <paramref name="entryPointName"/> from an already-constructed arguments object
/// and explicit input/output variable bindings. The node receives a fresh ID from <paramref name="context"/>.
/// </summary>
/// <exception>Thrown when an argument is null/empty or the entry point is not in the catalog.</exception>
public static EntryPointNode Create(IHostEnvironment env,
    string entryPointName,
    object arguments,
    RunContext context,
    Dictionary<string, List<ParameterBinding>> inputBindingMap,
    Dictionary<ParameterBinding, VariableBinding> inputMap,
    Dictionary<string, string> outputMap,
    bool checkpoint = false,
    string stageId = "",
    float cost = float.NaN)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckNonEmpty(entryPointName, nameof(entryPointName));
    env.CheckValue(arguments, nameof(arguments));
    env.CheckValue(context, nameof(context));
    env.CheckValue(inputBindingMap, nameof(inputBindingMap));
    env.CheckValue(inputMap, nameof(inputMap));
    env.CheckValue(outputMap, nameof(outputMap));

    // Throw instead of only asserting, so an unknown name yields a descriptive error rather
    // than a null dereference of 'info' below.
    ComponentCatalog.EntryPointInfo info;
    if (!env.ComponentCatalog.TryFindEntryPoint(entryPointName, out info))
        throw env.Except($"Entry point '{entryPointName}' not found");

    var inputBuilder = new InputBuilder(env, info.InputType, env.ComponentCatalog);
    var outputHelper = new OutputHelper(env, info.OutputType);
    using (var ch = env.Start("Create EntryPointNode"))
    {
        return new EntryPointNode(env, ch, context, context.GenerateId(entryPointName), entryPointName,
            inputBuilder.GetJsonObject(arguments, inputBindingMap, inputMap),
            outputHelper.GetJsonObject(outputMap), checkpoint, stageId, cost);
    }
}
/// <summary>
/// Convenience overload that binds each input parameter (key of <paramref name="inputMap"/>)
/// directly to a single variable (value of <paramref name="inputMap"/>).
/// </summary>
/// <exception>Thrown when a required argument is null/empty or the entry point is not in <paramref name="catalog"/>.</exception>
public static EntryPointNode Create(
    IHostEnvironment env,
    string entryPointName,
    object arguments,
    ComponentCatalog catalog,
    RunContext context,
    Dictionary<string, string> inputMap,
    Dictionary<string, string> outputMap,
    bool checkpoint = false,
    string stageId = "",
    float cost = float.NaN)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckNonEmpty(entryPointName, nameof(entryPointName));
    env.CheckValue(catalog, nameof(catalog));
    env.CheckValue(inputMap, nameof(inputMap));
    // The remaining arguments are validated by the overload this delegates to.

    if (!catalog.TryFindEntryPoint(entryPointName, out _))
        throw env.Except($"Entry point '{entryPointName}' not found");

    var inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
    var inputParamBindingMap = new Dictionary<ParameterBinding, VariableBinding>();
    foreach (var kvp in inputMap)
    {
        var paramBinding = new SimpleParameterBinding(kvp.Key);
        inputBindingMap.Add(kvp.Key, new List<ParameterBinding>() { paramBinding });
        inputParamBindingMap.Add(paramBinding, new SimpleVariableBinding(kvp.Value));
    }

    return Create(env, entryPointName, arguments, context, inputBindingMap, inputParamBindingMap,
        outputMap, checkpoint, stageId, cost);
}
/// <summary>
/// Checks the given JSON object key-value pair is a valid EntryPoint input and
/// extracts out any variables that need to be populated. These variables will be
/// added to the EntryPoint context. Input parameters that are not set to variables
/// will be immediately set using the input builder instance.
/// </summary>
/// <exception>Thrown for unknown input names, invalid variable types, or arrays mixing variables and literals.</exception>
private void CheckAndSetInputValue(KeyValuePair<string, JToken> pair)
{
    var inputName = _inputBuilder.GetFieldNameOrNull(pair.Key);

    if (VariableBinding.IsBindingToken(pair.Value))
    {
        // The whole input is bound to a single variable.
        Type valueType = _inputBuilder.GetFieldTypeOrNull(pair.Key);
        if (valueType == null)
            throw _host.Except($"Unexpected input name: '{pair.Key}'");
        if (!EntryPointVariable.IsValidType(valueType))
            throw _host.Except($"Unexpected input variable type: {valueType}");

        var varBinding = VariableBinding.Create(_host, pair.Value.Value<string>());
        _context.AddInputVariable(varBinding, valueType);
        BindInput(inputName, new SimpleParameterBinding(inputName), varBinding);
    }
    else if (pair.Value is JArray array && array.Any(tok => VariableBinding.IsBindingToken(tok)))
    {
        // REVIEW: EntryPoint arrays and dictionaries containing
        // variables must ONLY contain variables right now.
        if (!array.All(tok => VariableBinding.IsBindingToken(tok)))
            throw _host.Except($"Input {pair.Key} may ONLY contain variables.");

        Type valueType = _inputBuilder.GetFieldTypeOrNull(pair.Key);
        if (valueType == null || !valueType.HasElementType)
            throw _host.Except($"Unexpected input name: '{pair.Key}'");
        var elementType = valueType.GetElementType();

        // Each array slot is bound to its own variable.
        int i = 0;
        foreach (var token in array)
        {
            var varBinding = VariableBinding.Create(_host, token.Value<string>());
            _context.AddInputVariable(varBinding, elementType);
            BindInput(inputName, new ArrayIndexParameterBinding(inputName, i++), varBinding);
        }
    }
    // REVIEW: Implement support for Dictionary of variable values. We need to differentiate
    // between a Dictionary and a Component here, and likely need to support nested components
    // all of which might have variables. Our current machinery only works at the 'Node' level.
    else
    {
        // This is not a variable.
        if (!_inputBuilder.TrySetValueJson(pair.Key, pair.Value))
            throw _host.Except($"Unexpected input: '{pair.Key}'");
    }
}

// Records that the slot paramBinding of input inputName is supplied by varBinding.
private void BindInput(string inputName, ParameterBinding paramBinding, VariableBinding varBinding)
{
    if (!_inputBindingMap.TryGetValue(inputName, out var bindings))
    {
        bindings = new List<ParameterBinding>();
        _inputBindingMap[inputName] = bindings;
    }
    bindings.Add(paramBinding);
    _inputMap[paramBinding] = varBinding;
}
/// <summary>
/// Checks the given JSON object key-value pair is a valid EntryPoint output.
/// Extracts out any variables that need to be populated and adds them to the
/// EntryPoint context.
/// </summary>
/// <exception>
/// Thrown when the value is not a plain variable reference, the output name is unknown,
/// or the output type cannot be a variable.
/// </exception>
private void CheckAndMarkOutputValue(KeyValuePair<string, JToken> pair)
{
    if (!VariableBinding.IsBindingToken(pair.Value))
        throw _host.Except("Only variables allowed as outputs");

    // Output variable. Indexed/keyed bindings are not allowed for outputs.
    var varBinding = VariableBinding.Create(_host, pair.Value.Value<string>());
    if (!(varBinding is SimpleVariableBinding))
        throw _host.Except($"Output '{pair.Key}' can only be bound to a variable");

    var valueType = _outputHelper.GetFieldType(pair.Key);
    if (valueType == null)
        throw _host.Except($"Unexpected output name: '{pair.Key}'");
    if (!EntryPointVariable.IsValidType(valueType))
        throw _host.Except($"Output '{pair.Key}' has invalid type");

    _context.AddOutputVariable(varBinding.VariableName, valueType);
    _outputMap[pair.Key] = varBinding.VariableName;
}
/// <summary>
/// Points every input slot currently bound to the variable <paramref name="oldName"/>
/// at <paramref name="newBinding"/> instead. The same binding instance is shared by all such slots.
/// </summary>
public void RenameInputVariable(string oldName, VariableBinding newBinding)
{
    // Collect the keys first: the dictionary must not be modified while it is enumerated.
    var affectedSlots = _inputMap
        .Where(entry => entry.Value.VariableName == oldName)
        .Select(entry => entry.Key)
        .ToList();

    foreach (var slot in affectedSlots)
        _inputMap[slot] = newBinding;
}
/// <summary>
/// Redirects the first output currently written to variable <paramref name="oldName"/> so that it is
/// written to <paramref name="newName"/>. When <paramref name="cascadeChanges"/> is set, the variable
/// is also renamed in the run context. Does nothing if no output uses <paramref name="oldName"/>.
/// </summary>
public void RenameOutputVariable(string oldName, string newName, bool cascadeChanges = false)
{
    // A default KeyValuePair has a null Key, which signals "no match" (dictionary keys are never null).
    string outputName = _outputMap.FirstOrDefault(entry => entry.Value == oldName).Key;
    if (outputName == null)
        return;

    _outputMap[outputName] = newName;
    if (cascadeChanges)
        _context.RenameContextVariable(oldName, newName);
}
/// <summary>
/// Renames every input and output variable of this node according to <paramref name="mapping"/>
/// (old name to new name). Names not yet in the mapping get a fresh GUID-based name, and that
/// choice is added to <paramref name="mapping"/> so other nodes sharing the variable stay consistent.
/// </summary>
public void RenameAllVariables(Dictionary<string, string> mapping)
{
    // Returns the new name for oldName, minting and recording one if it is not mapped yet.
    string Resolve(string oldName)
    {
        if (!mapping.TryGetValue(oldName, out var newName))
        {
            newName = new Var<IDataView>().VarName;
            mapping.Add(oldName, newName);
        }
        return newName;
    }

    // Snapshot the original names before renaming anything. One VariableBinding instance can
    // back several parameter slots (RenameInputVariable assigns a single binding to every
    // matching slot), and renaming it in place on the first visit would make later visits
    // look up the already-new name and mint yet another one.
    var bindings = _inputMap.Values.ToArray();
    var oldNames = bindings.Select(b => b.VariableName).ToArray();
    for (int i = 0; i < bindings.Length; i++)
        bindings[i].Rename(Resolve(oldNames[i]));

    // Materialize the keys so values can be reassigned safely.
    foreach (var outputName in _outputMap.Keys.ToList())
        _outputMap[outputName] = Resolve(_outputMap[outputName]);
}
/// <summary>
/// Returns true when <paramref name="str"/> is non-null and consists solely of ASCII letters and
/// digits. The empty string is valid, and is the default stage ID.
/// </summary>
private static bool IsStageIdValid(string str)
{
    if (str == null)
        return false;

    // _stageIdRegex is unanchored and '*' also matches the empty string, so Match(...).Success
    // alone is true for any input. Require the match starting at 0 to span the whole string.
    var match = _stageIdRegex.Match(str);
    return match.Success && match.Index == 0 && match.Length == str.Length;
}
/// <summary>
/// Serializes this node back to its JSON graph form: entry point name, inputs (literal values and
/// variable references), outputs, and the optional stage ID, checkpoint and cost fields.
/// The optional fields are only written when set.
/// </summary>
public JObject ToJson()
{
    var json = new JObject
    {
        [FieldNames.Name] = _entryPoint.Name,
        [FieldNames.Inputs] = _inputBuilder.GetJsonObject(_inputBuilder.GetInstance(), _inputBindingMap, _inputMap),
        [FieldNames.Outputs] = _outputHelper.GetJsonObject(_outputMap),
    };

    bool hasStage = !string.IsNullOrEmpty(StageId);
    if (hasStage)
        json[FieldNames.StageId] = StageId;

    if (Checkpoint)
        json[FieldNames.Checkpoint] = Checkpoint;

    bool hasCost = !float.IsNaN(Cost);
    if (hasCost)
        json[FieldNames.Cost] = Cost;

    return json;
}
/// <summary>
/// Whether the node can run right now: it has not finished yet, and every variable bound to a
/// non-optional input exists in the run context and has its value set.
/// </summary>
public bool CanStart()
{
    if (IsFinished)
        return false;

    var requiredBindings = _inputMap
        .Where(kv => !_inputBuilder.IsInputOptional(kv.Key.ParameterName))
        .Select(kv => kv.Value)
        .Distinct();

    foreach (var binding in requiredBindings)
    {
        if (!_context.TryGetVariable(binding.VariableName, out EntryPointVariable variable) || !variable.IsValueSet)
            return false;
    }
    return true;
}
/// <summary>
/// Executes the node. It first fills in all variable-bound inputs and then invokes the entry point.
/// For a macro, the generated nodes are stored in <see cref="MacroNodes"/>. Otherwise, each output
/// that is mapped to a variable is written into the run context. Records <see cref="RunTime"/> and
/// sets <see cref="IsFinished"/>.
/// </summary>
public void Run()
{
    _host.Assert(CanStart());
    var timer = Stopwatch.StartNew();

    // Set all remaining inputs from the variables they are bound to.
    foreach (var entry in _inputBindingMap)
    {
        bool assigned = _inputBuilder.TrySetValue(entry.Key, BuildParameterValue(entry.Value));
        _host.Assert(assigned);
    }
    _host.Assert(_inputBuilder.GetMissingValues().Length == 0);

    if (!IsMacro)
    {
        var result = _entryPoint.Method.Invoke(null, new object[] { _host, _inputBuilder.GetInstance() });
        foreach (var field in _outputHelper.ExtractValues(result))
        {
            // Outputs not mapped to a variable are simply dropped.
            if (!_outputMap.TryGetValue(field.Key, out string target))
                continue;

            bool found = _context.TryGetVariable(target, out EntryPointVariable variable);
            _host.Assert(found);
            variable.SetValue(field.Value);
        }
    }
    else
    {
        // Macros receive the node itself so they can build their expansion.
        var macroResult = (CommonOutputs.MacroOutput)_entryPoint.Method.Invoke(null, new object[] { _host, _inputBuilder.GetInstance(), this });
        _host.AssertValue(macroResult);
        _macroNodes = macroResult.Nodes;
    }

    timer.Stop();
    RunTime = timer.Elapsed;
    IsFinished = true;
}
/// <summary>Whether the entry point's output type derives from <see cref="CommonOutputs.MacroOutput"/>.</summary>
public bool IsMacro
{
    get { return _entryPoint.OutputType.IsSubclassOf(typeof(CommonOutputs.MacroOutput)); }
}

// Nodes produced by the macro; assigned by Run() when IsMacro is true.
private IEnumerable<EntryPointNode> _macroNodes;

/// <summary>Nodes produced by running this node as a macro; null until then.</summary>
public IEnumerable<EntryPointNode> MacroNodes { get { return _macroNodes; } }

/// <summary>The component catalog of this node's host.</summary>
public ComponentCatalog Catalog { get { return _host.ComponentCatalog; } }

/// <summary>The run context this node registers its variables in.</summary>
public RunContext Context { get { return _context; } }

/// <summary>Input parameter name to the parameter bindings that make it up.</summary>
public Dictionary<string, List<ParameterBinding>> InputBindingMap { get { return _inputBindingMap; } }

/// <summary>Parameter binding to the variable binding that supplies its value.</summary>
public Dictionary<ParameterBinding, VariableBinding> InputMap { get { return _inputMap; } }

/// <summary>Output field name to the name of the variable that receives it.</summary>
public Dictionary<string, string> OutputMap { get { return _outputMap; } }

/// <summary>Returns the node ID.</summary>
public override string ToString() { return Id; }
/// <summary>
/// Computes the value of one input parameter from its bindings. A single simple binding yields the
/// bound variable's value. Array-index bindings yield a new array whose i-th element is the value of
/// the i-th binding. Dictionary-key bindings are not implemented.
/// </summary>
private object BuildParameterValue(List<ParameterBinding> bindings)
{
    _host.AssertNonEmpty(bindings);
    var head = bindings.First();
    // All bindings of one parameter are expected to be of the same kind.
    _host.Assert(bindings.Skip(1).All(other => other.GetType().Equals(head.GetType())));

    if (head is SimpleParameterBinding)
    {
        _host.Assert(bindings.Count == 1);
        return _context.GetValueOrNull(_inputMap[head]);
    }

    if (head is ArrayIndexParameterBinding)
    {
        var elementType = _inputBuilder.GetFieldTypeOrNull(head.ParameterName).GetElementType();
        _host.AssertValue(elementType);
        var values = Array.CreateInstance(elementType, bindings.Count);
        for (int slot = 0; slot < bindings.Count; slot++)
            values.SetValue(_context.GetValueOrNull(_inputMap[bindings[slot]]), slot);
        return values;
    }

    if (head is DictionaryKeyParameterBinding)
    {
        // REVIEW: Implement dictionary support when needed;
        throw _host.ExceptNotImpl("Dictionary variable binding is not currently supported");
    }

    _host.Assert(false);
    throw _host.ExceptNotImpl("Unsupported ParameterBinding");
}
/// <summary>
/// Parses and validates every node definition in <paramref name="nodes"/>, creating one
/// <see cref="EntryPointNode"/> per entry and registering its variables in <paramref name="context"/>.
/// Unknown node fields produce a warning. The optional column names are forwarded to each node
/// for reconciliation with its own arguments.
/// </summary>
/// <exception>Thrown when a node is not a JSON object, lacks a name, or has malformed inputs/outputs.</exception>
public static List<EntryPointNode> ValidateNodes(IHostEnvironment env, RunContext context, JArray nodes,
    string label = null, string group = null, string weight = null, string name = null)
{
    Contracts.AssertValue(env);
    env.AssertValue(context);
    env.AssertValue(nodes);

    var result = new List<EntryPointNode>(nodes.Count);
    using (var ch = env.Start("Validating graph nodes"))
    {
        for (int i = 0; i < nodes.Count; i++)
        {
            var node = nodes[i] as JObject;
            if (node == null)
                throw env.Except("Unexpected node token: '{0}'", nodes[i]);

            // Report a missing/empty entry point name clearly instead of failing on a null token.
            string nodeName = node[FieldNames.Name]?.Value<string>();
            if (string.IsNullOrEmpty(nodeName))
                throw env.Except("Node {0} is missing a '{1}' value.", i, FieldNames.Name);

            var inputs = node[FieldNames.Inputs] as JObject;
            if (inputs == null && node[FieldNames.Inputs] != null)
                throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Inputs, node[FieldNames.Inputs]);

            var outputs = node[FieldNames.Outputs] as JObject;
            if (outputs == null && node[FieldNames.Outputs] != null)
                throw env.Except("Unexpected {0} token: '{1}'", FieldNames.Outputs, node[FieldNames.Outputs]);

            var id = context.GenerateId(nodeName);
            // Materialized once: the list is both tested and enumerated for the warning.
            var unexpectedFields = node.Properties().Where(
                x => x.Name != FieldNames.Name && x.Name != FieldNames.Inputs && x.Name != FieldNames.Outputs
                && x.Name != FieldNames.StageId && x.Name != FieldNames.Checkpoint && x.Name != FieldNames.Cost)
                .ToList();

            var stageId = node[FieldNames.StageId] == null ? "" : node[FieldNames.StageId].Value<string>();
            var checkpoint = node[FieldNames.Checkpoint] == null ? false : node[FieldNames.Checkpoint].Value<bool>();
            var cost = node[FieldNames.Cost] == null ? float.NaN : node[FieldNames.Cost].Value<float>();

            if (unexpectedFields.Count > 0)
            {
                // REVIEW: consider throwing an exception.
                ch.Warning("Node '{0}' has unexpected fields that are ignored: {1}", id, string.Join(", ", unexpectedFields.Select(x => x.Name)));
            }

            result.Add(new EntryPointNode(env, ch, context, id, nodeName, inputs, outputs, checkpoint, stageId, cost, label, group, weight, name));
        }
    }
    return result;
}
/// <summary>Makes this node use <paramref name="context"/> as its run context.</summary>
public void SetContext(RunContext context)
{
    _host.CheckValue(context, nameof(context));
    _context = context;
}

/// <summary>
/// Returns the variable binding of input parameter <paramref name="paramName"/>.
/// </summary>
/// <exception>
/// Thrown when the parameter does not exist, or when it is not backed by exactly one binding.
/// </exception>
public VariableBinding GetInputVariable(string paramName)
{
    if (!InputBindingMap.TryGetValue(paramName, out var parameterBindings))
        throw _host.Except($"Invalid parameter '{paramName}': parameter does not exist.");

    // Exactly one binding is required: several means an array/dictionary parameter, and an empty
    // list has nothing to return (indexing [0] would throw ArgumentOutOfRangeException).
    if (parameterBindings == null || parameterBindings.Count != 1)
        throw _host.Except($"Invalid parameter '{paramName}': only simple parameters are supported.");

    bool success = InputMap.TryGetValue(parameterBindings[0], out var variableBinding);
    _host.Assert(success && variableBinding != null);
    return variableBinding;
}

/// <summary>Returns the name of the variable that receives output <paramref name="paramName"/>.</summary>
/// <exception>Thrown when the output is not mapped to a variable.</exception>
public string GetOutputVariableName(string paramName)
{
    if (!OutputMap.TryGetValue(paramName, out var outputVarName))
        throw _host.Except($"Invalid parameter '{paramName}': parameter does not exist.");
    return outputVarName;
}
/// <summary>
/// Adds a new input parameter <paramref name="uniqueName"/> to this node, bound to a new input
/// variable of the same name and type <typeparamref name="T"/>. The variable's value is set when
/// <paramref name="value"/> is non-null.
/// </summary>
/// <returns>A <see cref="Var{T}"/> marker for the variable, and its binding.</returns>
/// <exception>Thrown when <paramref name="uniqueName"/> is already a parameter of this node.</exception>
public Tuple<Var<T>, VariableBinding> AddNewVariable<T>(string uniqueName, T value)
{
    // The parameter name must not collide with an existing one.
    if (InputBindingMap.ContainsKey(uniqueName))
        throw _host.Except($"Key {uniqueName} already exists in binding map.");

    var parameterBinding = new SimpleParameterBinding(uniqueName);
    InputBindingMap.Add(uniqueName, new List<ParameterBinding> { parameterBinding });

    VariableBinding variableBinding = new SimpleVariableBinding(uniqueName);
    Context.AddInputVariable(variableBinding, typeof(T));
    InputMap.Add(parameterBinding, variableBinding);

    bool hasValue = value != null;
    if (hasValue && Context.TryGetVariable(variableBinding.VariableName, out var created))
        created.SetValue(value);

    var marker = new Var<T> { VarName = variableBinding.VariableName };
    return Tuple.Create<Var<T>, VariableBinding>(marker, variableBinding);
}
}
/// <summary>
/// A graph of entry point nodes parsed from JSON, sharing a single <see cref="RunContext"/>.
/// Running a macro node appends the nodes it generates to the graph.
/// </summary>
[BestFriend]
internal sealed class EntryPointGraph
{
    private const string RegistrationName = "EntryPointGraph";
    private readonly IHost _host;
    private readonly RunContext _context;
    private readonly List<EntryPointNode> _nodes;

    /// <summary>Parses and validates <paramref name="nodes"/> into a new graph.</summary>
    public EntryPointGraph(IHostEnvironment env, JArray nodes)
    {
        Contracts.CheckValue(env, nameof(env));
        _host = env.Register(RegistrationName);
        _host.CheckValue(nodes, nameof(nodes));
        _context = new RunContext(_host);
        _nodes = EntryPointNode.ValidateNodes(_host, _context, nodes);
    }

    /// <summary>Whether at least one node can start right now.</summary>
    public bool HasRunnableNodes => _nodes.Any(x => x.CanStart());

    // NOTE: the following sequences are lazy views over the live node list; materialize them
    // before calling RunNode/AddNode while enumerating, since those append to the list.

    /// <summary>Macro nodes of the graph.</summary>
    public IEnumerable<EntryPointNode> Macros => _nodes.Where(x => x.IsMacro);

    /// <summary>Non-macro nodes of the graph.</summary>
    public IEnumerable<EntryPointNode> NonMacros => _nodes.Where(x => !x.IsMacro);

    /// <summary>All nodes of the graph.</summary>
    public IEnumerable<EntryPointNode> AllNodes => _nodes;

    /// <summary>The run context shared by all nodes.</summary>
    public RunContext Context => _context;

    /// <summary>Names of consumed variables that are neither produced nor given a value.</summary>
    public string[] GetMissingInputs()
    {
        return _context.GetMissingInputs();
    }

    /// <summary>Runs <paramref name="node"/>; if it is a macro, appends the nodes it generated.</summary>
    public void RunNode(EntryPointNode node)
    {
        _host.CheckValue(node, nameof(node));
        _host.Assert(_nodes.Contains(node));
        node.Run();
        if (node.IsMacro)
            _nodes.AddRange(node.MacroNodes);
    }

    /// <summary>Looks up a variable by name in the graph's run context.</summary>
    public bool TryGetVariable(string name, out EntryPointVariable v)
    {
        return _context.TryGetVariable(name, out v);
    }

    /// <summary>Returns the variable named <paramref name="name"/>, or null if it does not exist.</summary>
    public EntryPointVariable GetVariableOrNull(string name)
    {
        return TryGetVariable(name, out var variable) ? variable : null;
    }

    /// <summary>Adds <paramref name="node"/> to the graph and attaches it to the graph's run context.</summary>
    public void AddNode(EntryPointNode node)
    {
        _host.CheckValue(node, nameof(node));
        node.SetContext(_context);
        _nodes.Add(node);
    }
}
/// <summary>
/// Represents a delayed binding in a JSON graph to an <see cref="EntryPointVariable"/>.
/// The subclasses allow us to express that we either desire the variable itself,
/// or an array-indexed or dictionary-keyed value from the variable, assuming it is
/// of an Array or Dictionary type.
/// </summary>
[BestFriend]
internal abstract class VariableBinding
{
    /// <summary>The name of the bound variable (without the leading '$').</summary>
    public string VariableName { get; private set; }

    protected VariableBinding(string varName)
    {
        Contracts.AssertNonWhiteSpace(varName);
        VariableName = varName;
    }

    // A regex to validate an EntryPoint variable value accessor string. Valid EntryPoint variable names
    // can be any sequence of alphanumeric characters and underscores. They must start with a letter or underscore.
    // An EntryPoint variable can be followed with an array or dictionary specifier, which begins
    // with '[', contains either an integer or alphanumeric string, optionally wrapped in single-quotes,
    // followed with ']'.
    // NOTE(review): the pattern is not anchored, so Match/IsMatch also succeed when a "$name" reference
    // appears anywhere inside a longer string. Anchoring would be stricter but could reject inputs that
    // existing graphs rely on; confirm before changing.
    private static readonly Regex _variableRegex = new Regex(
        @"\$(?<Name>[a-zA-Z_][a-zA-Z0-9_]*)(\[(((?<NumericAccessor>[0-9]*))|(\'?(?<StringAccessor>[a-zA-Z0-9_]*)\'?))\])?",
        RegexOptions.Compiled);

    /// <summary>Extracts the bound value from <paramref name="variable"/>, or null.</summary>
    public abstract object GetVariableValueOrNull(EntryPointVariable variable);

    /// <summary>
    /// Parses a "$name", "$name[index]" or "$name[key]" reference into the matching binding subclass.
    /// </summary>
    /// <exception>Thrown when the string is not a variable reference or the array index is invalid.</exception>
    public static VariableBinding Create(IExceptionContext ectx, string jsonString)
    {
        Contracts.AssertValue(ectx);
        ectx.AssertNonWhiteSpace(jsonString);

        var match = _variableRegex.Match(jsonString);
        if (!match.Success)
            throw ectx.Except($"Unable to parse variable string '{jsonString}'");

        var name = match.Groups["Name"].Value;

        var numericAccessor = match.Groups["NumericAccessor"];
        if (numericAccessor.Success)
        {
            // "[0-9]*" also matches an empty accessor ("$x[]") and digit runs beyond int range;
            // report those as a parse error instead of a bare FormatException/OverflowException.
            if (!int.TryParse(numericAccessor.Value, System.Globalization.NumberStyles.None,
                    System.Globalization.CultureInfo.InvariantCulture, out int index))
                throw ectx.Except($"Invalid array index in variable string '{jsonString}'");
            return new ArrayIndexVariableBinding(name, index);
        }

        var stringAccessor = match.Groups["StringAccessor"];
        if (stringAccessor.Success)
            return new DictionaryKeyVariableBinding(name, stringAccessor.Value);

        return new SimpleVariableBinding(name);
    }

    /// <summary>Whether <paramref name="tok"/> is a non-null JSON value containing a variable reference.</summary>
    public static bool IsBindingToken(JToken tok)
    {
        var token = tok as JValue;
        return token?.Value != null && _variableRegex.IsMatch(token.Value<string>());
    }

    /// <summary>
    /// Verifies that the name of the graph variable is a valid one
    /// </summary>
    public static bool IsValidVariableName(IExceptionContext ectx, string variableName)
    {
        Contracts.AssertValue(ectx);
        ectx.AssertNonWhiteSpace(variableName);
        return _variableRegex.Match(variableName).Success;
    }

    /// <summary>Changes the bound variable's name in place.</summary>
    public void Rename(string newName)
    {
        Contracts.CheckNonWhiteSpace(newName, nameof(newName));
        VariableName = newName;
    }

    /// <summary>Returns the JSON string form of this binding.</summary>
    public abstract string ToJson();

    public override string ToString() => VariableName;
}
/// <summary>
/// Binding to a whole variable, written in JSON as <c>$name</c>.
/// </summary>
[BestFriend]
internal sealed class SimpleVariableBinding : VariableBinding
{
    public SimpleVariableBinding(string name)
        : base(name)
    {
    }

    /// <summary>Returns the variable's value as-is.</summary>
    public override object GetVariableValueOrNull(EntryPointVariable variable)
    {
        Contracts.AssertValue(variable);
        return variable.Value;
    }

    /// <summary>Serializes as <c>$name</c>.</summary>
    public override string ToJson() => "$" + VariableName;
}
/// <summary>
/// Binding to a single dictionary entry of a variable, written in JSON as <c>$name['key']</c>.
/// Value retrieval is not implemented yet.
/// </summary>
internal sealed class DictionaryKeyVariableBinding
    : VariableBinding
{
    /// <summary>The dictionary key to look up.</summary>
    public readonly string Key;

    public DictionaryKeyVariableBinding(string name, string key)
        : base(name)
    {
        Contracts.AssertNonWhiteSpace(key);
        Key = key;
    }

    /// <summary>
    /// Not supported: always throws via <c>Contracts.ExceptNotImpl</c>.
    /// </summary>
    public override object GetVariableValueOrNull(EntryPointVariable variable)
    {
        Contracts.AssertValue(variable);
        // REVIEW: Implement dictionary-based value retrieval.
        throw Contracts.ExceptNotImpl("Dictionary-based value retrieval is not supported.");
    }

    /// <summary>Serializes as <c>$name['key']</c>.</summary>
    public override string ToJson()
    {
        return $"${VariableName}['{Key}']";
    }
}
/// <summary>
/// Binding to one element of an array-valued variable, written in JSON as <c>$name[index]</c>.
/// </summary>
[BestFriend]
internal sealed class ArrayIndexVariableBinding : VariableBinding
{
    /// <summary>Non-negative position passed to <see cref="Array.GetValue(int)"/>.</summary>
    public readonly int Index;

    public ArrayIndexVariableBinding(string name, int index)
        : base(name)
    {
        Contracts.Assert(index >= 0);
        Index = index;
    }

    /// <summary>
    /// Returns element <see cref="Index"/> of the variable's value, or <c>null</c> when that value is not
    /// an <see cref="Array"/>. An out-of-range index is not guarded here; it surfaces from
    /// <see cref="Array.GetValue(int)"/>.
    /// </summary>
    public override object GetVariableValueOrNull(EntryPointVariable variable)
    {
        Contracts.AssertValue(variable, nameof(variable));
        if (variable.Value is Array array)
            return array.GetValue(Index);
        return null;
    }

    /// <summary>Serializes as <c>$name[index]</c>.</summary>
    public override string ToJson() => "$" + VariableName + "[" + Index + "]";
}
/// <summary>
/// The assignable (l-value) destination of a <see cref="VariableBinding"/>: a named entry point input
/// parameter or, in subclasses, one slot of an array or dictionary input parameter that has not been
/// constructed yet. For example, <c>"myVar": ["$var1", "$var2"]</c> yields two
/// <see cref="ArrayIndexParameterBinding"/> instances: (myVar, 0) and (myVar, 1).
/// </summary>
[BestFriend]
internal abstract class ParameterBinding
{
    /// <summary>Name of the entry point input parameter being bound.</summary>
    public readonly string ParameterName;

    protected ParameterBinding(string name)
    {
        Contracts.AssertNonWhiteSpace(name);
        ParameterName = name;
    }

    /// <summary>Returns <see cref="ParameterName"/>.</summary>
    public override string ToString()
    {
        return ParameterName;
    }
}
/// <summary>
/// Binds a variable to an entire input parameter. Instances are equal when their
/// parameter names match under ordinal comparison.
/// </summary>
[BestFriend]
internal sealed class SimpleParameterBinding : ParameterBinding
{
    public SimpleParameterBinding(string name)
        : base(name)
    {
    }

    public override bool Equals(object obj)
    {
        if (!(obj is SimpleParameterBinding other))
            return false;
        return other.ParameterName.Equals(ParameterName, StringComparison.Ordinal);
    }

    // The parameter name is the only identity component, matching Equals.
    public override int GetHashCode() => ParameterName.GetHashCode();
}
/// <summary>
/// Binds a variable to one keyed entry of a dictionary input parameter. Instances are equal when both
/// the parameter name and the key match under ordinal comparison.
/// </summary>
internal sealed class DictionaryKeyParameterBinding : ParameterBinding
{
    /// <summary>Key of the dictionary entry this binding fills.</summary>
    public readonly string Key;

    public DictionaryKeyParameterBinding(string name, string key)
        : base(name)
    {
        Contracts.AssertNonWhiteSpace(key);
        Key = key;
    }

    public override bool Equals(object obj)
        => obj is DictionaryKeyParameterBinding other
            && other.ParameterName.Equals(ParameterName, StringComparison.Ordinal)
            && other.Key.Equals(Key, StringComparison.Ordinal);

    public override int GetHashCode()
    {
        // Hash the same (name, key) pair that Equals compares.
        var identity = Tuple.Create(ParameterName, Key);
        return identity.GetHashCode();
    }
}
/// <summary>
/// Binds a variable to one positional slot of an array input parameter. Instances are equal when the
/// parameter names match ordinally and the indices are equal.
/// </summary>
[BestFriend]
internal sealed class ArrayIndexParameterBinding : ParameterBinding
{
    /// <summary>Non-negative slot position within the array parameter.</summary>
    public readonly int Index;

    public ArrayIndexParameterBinding(string name, int index)
        : base(name)
    {
        Contracts.Check(index >= 0);
        Index = index;
    }

    public override bool Equals(object obj)
    {
        var other = obj as ArrayIndexParameterBinding;
        return other != null
            && other.ParameterName.Equals(ParameterName, StringComparison.Ordinal)
            && other.Index == Index;
    }

    public override int GetHashCode()
    {
        // Hash the same (name, index) pair that Equals compares.
        var identity = Tuple.Create(ParameterName, Index);
        return identity.GetHashCode();
    }
}
}
|