|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Linq;
using System.Reflection;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.EntryPoints;
[BestFriend]
internal static class EntryPointUtils
{
    // Generic-method handle for IsValueWithinRange<T>. The int instantiation is only what the handle is
    // built from; the public IsValueWithinRange passes range.Type to Utils.MarshalInvoke, which presumably
    // re-instantiates the method for that type argument.
    private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo
        = new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>);

    /// <summary>
    /// Checks <paramref name="obj"/> against every bound set on <paramref name="range"/>.
    /// <c>Min</c>/<c>Max</c> are inclusive bounds, <c>Inf</c>/<c>Sup</c> are exclusive bounds,
    /// and bounds that are null are ignored.
    /// </summary>
    /// <typeparam name="T">The type the bounds are compared in. <paramref name="obj"/> must be a boxed
    /// <typeparamref name="T"/> or an <see cref="Optional{T}"/> wrapping one; otherwise the cast throws.</typeparam>
    private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj)
    {
        // Optional<T> values are unwrapped so the bounds apply to the carried value.
        T val = obj is Optional<T> asOptional ? asOptional.Value : (T)obj;

        return
            (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) &&
            (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) &&
            (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) &&
            (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0);
    }

    /// <summary>
    /// Returns whether <paramref name="val"/> satisfies the bounds declared by <paramref name="range"/>.
    /// </summary>
    /// <param name="range">The range attribute read from an input field. It may be converted in place to
    /// double bounds (see remarks).</param>
    /// <param name="val">The field value: either a plain boxed value or an <see cref="Optional{T}"/>.</param>
    /// <remarks>
    /// A boxed double cannot be unboxed as float. So when the range was declared with float bounds but the
    /// value is a double (plain or wrapped in an Optional), the range is first converted to double via
    /// <c>CastToDouble</c>.
    /// </remarks>
    public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val)
    {
        Contracts.AssertValue(range);
        Contracts.AssertValue(val);

        // The generic helper unwraps Optional<T>, so an Optional<double> needs the same float->double
        // range conversion as a bare double. Otherwise the helper would attempt (float)obj and throw.
        if (range.Type == typeof(float) && (val is double || val is Optional<double>))
            range.CastToDouble();

        return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val);
    }

    /// <summary>
    /// Performs checks on an EntryPoint input class equivalent to the checks that are done
    /// when parsing a JSON EntryPoint graph.
    ///
    /// Call this method from EntryPoint methods to ensure that range and required checks are performed
    /// in a consistent manner when EntryPoints are created directly from code.
    /// </summary>
    /// <param name="ectx">Exception context used to create the thrown exceptions.</param>
    /// <param name="args">The input object. Its public and non-public instance fields carrying an
    /// <see cref="ArgumentAttribute"/> (and not marked <c>CmdLineOnly</c>) are validated.</param>
    /// <remarks>
    /// Throws (via <paramref name="ectx"/>) when: <paramref name="args"/> is null; an <c>Optional&lt;&gt;</c>
    /// field is null; a required field still holds its default / implicit value; or a field with a
    /// <see cref="TlcModule.RangeAttribute"/> holds a non-null value outside the declared bounds.
    /// </remarks>
    public static void CheckInputArgs(IExceptionContext ectx, object args)
    {
        ectx.CheckValue(args, nameof(args));

        foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
        {
            var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault()
                as ArgumentAttribute;
            // Fields without an ArgumentAttribute, or visible only on the command line, are skipped.
            if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly)
                continue;

            var fieldVal = fieldInfo.GetValue(args);
            var fieldType = fieldInfo.FieldType;
            bool isOptional = fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>);

            // Optionals are either left in their Implicit constructed state or
            // a new Explicit optional is constructed. They should never be set
            // to null.
            if (isOptional && fieldVal == null)
                throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name);

            if (attr.IsRequired)
            {
                bool equalToDefault;
                if (isOptional)
                    equalToDefault = !((Optional)fieldVal).IsExplicit;
                else if (fieldType.IsValueType)
                    // Static object.Equals: for Nullable<T> fields the boxed default is null, so calling
                    // .Equals on it directly would throw NullReferenceException.
                    equalToDefault = object.Equals(Activator.CreateInstance(fieldType), fieldVal);
                else
                    equalToDefault = fieldVal == null;

                if (equalToDefault)
                    throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name);
            }

            var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault()
                as TlcModule.RangeAttribute;
            if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal))
                throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name);
        }
    }

    /// <summary>
    /// Registers a child host named <paramref name="hostName"/> on <paramref name="env"/>, validates
    /// <paramref name="input"/> with <see cref="CheckInputArgs"/>, and returns the new host.
    /// </summary>
    /// <param name="env">The host environment; must not be null.</param>
    /// <param name="hostName">Name passed to <c>env.Register</c>.</param>
    /// <param name="input">The EntryPoint input object to validate; must not be null.</param>
    public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input)
    {
        Contracts.CheckValue(env, nameof(env));
        var host = env.Register(hostName);
        host.CheckValue(input, nameof(input));
        CheckInputArgs(host, input);
        return host;
    }

    /// <summary>
    /// Searches for the given column name in the schema. This method applies a
    /// common policy that throws an exception if the column is not found
    /// and the column name was explicitly specified. If the column is not found
    /// and the column name was not explicitly specified, it returns null.
    /// </summary>
    /// <param name="ectx">Optional exception context (may be null).</param>
    /// <param name="schema">The schema to search.</param>
    /// <param name="value">The column name. An empty name always yields null.</param>
    /// <returns>The column name if it exists in <paramref name="schema"/>; otherwise null (unless the
    /// name was explicit, in which case an exception is thrown).</returns>
    public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
    {
        Contracts.CheckValueOrNull(ectx);
        ectx.CheckValue(schema, nameof(schema));
        ectx.CheckValue(value, nameof(value));

        string columnName = value.Value;
        if (columnName == "")
            return null;

        if (schema.GetColumnOrNull(columnName) == null)
        {
            if (value.IsExplicit)
                throw ectx.Except("Column '{0}' not found", columnName);
            return null;
        }
        return columnName;
    }

    /// <summary>
    /// Converts EntryPoint Optional{T} types into nullable types, with the
    /// implicit value being converted to the null value.
    /// </summary>
    /// <remarks>
    /// Only <c>IsExplicit</c> is consulted: an implicit Optional maps to null even if it carries a
    /// non-default value.
    /// </remarks>
    public static T? AsNullable<T>(this Optional<T> opt) where T : struct
        => opt.IsExplicit ? opt.Value : (T?)null;
}
|