File: Model\Pfa\PfaUtils.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Newtonsoft.Json.Linq;
 
namespace Microsoft.ML.Model.Pfa
{
    [BestFriend]
    internal static class PfaUtils
    {
        /// <summary>
        /// Adds a property to a JSON object and returns that object, so that calls can be chained.
        /// </summary>
        /// <param name="toEdit">The object to add to. When null, a new <see cref="JObject"/> is created and used instead</param>
        /// <param name="name">The name of the property to add</param>
        /// <param name="value">The value of the property to add</param>
        /// <returns>The object that received the property</returns>
        public static JObject AddReturn(this JObject toEdit, string name, JToken value)
        {
            Contracts.CheckValueOrNull(toEdit);
            Contracts.CheckValue(name, nameof(name));
            Contracts.CheckValue(value, nameof(value));

            // A null receiver is legal: it simply means "start a fresh object".
            JObject target = toEdit ?? new JObject();
            target.Add(name, value);
            return target;
        }
 
        /// <summary>
        /// Generic facility for calling a function. The result is an object with a single
        /// property, named after the function, whose value is the array of arguments.
        /// </summary>
        /// <param name="func">The function to call</param>
        /// <param name="prms">The parameters for the function</param>
        /// <returns>The call expression</returns>
        public static JObject Call(string func, params JToken[] prms)
        {
            Contracts.CheckNonWhiteSpace(func, nameof(func));
            Contracts.CheckValue(prms, nameof(prms));
            return new JObject { [func] = new JArray(prms) };
        }
 
        /// <summary>
        /// Builds a function reference: an object with the single property "fcn" set to the function name.
        /// </summary>
        /// <param name="func">The name of the function to reference</param>
        /// <returns>The function reference object</returns>
        public static JObject FuncRef(string func)
        {
            Contracts.CheckNonWhiteSpace(func, nameof(func));
            var reference = new JObject();
            return reference.AddReturn("fcn", func);
        }
 
        /// <summary>
        /// Builds a parameter declaration: an object with a single property mapping the parameter name to its type.
        /// </summary>
        /// <param name="name">The parameter name</param>
        /// <param name="type">The parameter type</param>
        /// <returns>The parameter declaration object</returns>
        public static JObject Param(string name, JToken type)
        {
            return new JObject { [name] = type };
        }
 
        /// <summary>
        /// Builds an "attr"/"path" expression with a one-element path.
        /// </summary>
        /// <param name="arrayOrMap">The token stored under "attr"</param>
        /// <param name="key">The single entry of the "path" array</param>
        /// <returns>The object with the "attr" and "path" properties</returns>
        public static JObject Index(JToken arrayOrMap, JToken key)
        {
            // Initializers run top to bottom, so "attr" is assigned before the path array is built.
            return new JObject
            {
                ["attr"] = arrayOrMap,
                ["path"] = new JArray { key }
            };
        }
 
        /// <summary>
        /// Builds a typed string literal: an object with "type" set to the string type token and "value" set to the text.
        /// </summary>
        /// <param name="str">The text of the literal</param>
        /// <returns>The literal object</returns>
        public static JObject String(string str)
        {
            Contracts.CheckValue(str, nameof(str));
            JObject literal = new JObject().AddReturn("type", Type.String);
            return literal.AddReturn("value", str);
        }
 
        /// <summary>
        /// Builds a loop expression with the properties "for", "while", "step" and "do", in that order.
        /// </summary>
        /// <param name="initBlock">The value of the "for" property</param>
        /// <param name="whileBlock">The value of the "while" property</param>
        /// <param name="stepBlock">The value of the "step" property</param>
        /// <param name="doBlock">The value of the "do" property</param>
        /// <returns>The loop expression</returns>
        public static JObject For(JObject initBlock, JObject whileBlock, JObject stepBlock, JObject doBlock)
        {
            return new JObject
            {
                ["for"] = initBlock,
                ["while"] = whileBlock,
                ["step"] = stepBlock,
                ["do"] = doBlock
            };
        }
 
        /// <summary>
        /// Builds a conditional expression with "if" and "then" properties, plus an "else" property when one is supplied.
        /// </summary>
        /// <param name="condition">The value of the "if" property</param>
        /// <param name="thenBlock">The value of the "then" property</param>
        /// <param name="elseBlock">The value of the "else" property; when null the property is omitted</param>
        /// <returns>The conditional expression</returns>
        public static JObject If(JToken condition, JToken thenBlock, JToken elseBlock)
        {
            var conditional = new JObject
            {
                ["if"] = condition,
                ["then"] = thenBlock
            };

            if (elseBlock == null)
                return conditional;

            conditional["else"] = elseBlock;
            return conditional;
        }
 
        /// <summary>
        /// Builds a "cast" statement with two cases, one for the map form and one for the array form
        /// of a vector with the given item type.
        /// </summary>
        /// <param name="itemType">The type of the item in the vector</param>
        /// <param name="src">The token we are casting</param>
        /// <param name="asMapName">The name for the token as it will appear in the <paramref name="mapDo"/></param>
        /// <param name="mapDo">The map case expression</param>
        /// <param name="asArrName">The name for the token as it will appear in the <paramref name="arrDo"/></param>
        /// <param name="arrDo">The array case expression</param>
        /// <returns>The cast/case expression</returns>
        public static JObject VectorCase(JToken itemType, JToken src, string asMapName, JToken mapDo, string asArrName, JToken arrDo)
        {
            // The map case is built first, then the array case; they appear in that order in "cases".
            JObject mapCase = new JObject()
                .AddReturn("as", Type.Map(itemType))
                .AddReturn("named", asMapName)
                .AddReturn("do", mapDo);

            JObject arrCase = new JObject()
                .AddReturn("as", Type.Array(itemType))
                .AddReturn("named", asArrName)
                .AddReturn("do", arrDo);

            var cases = new JArray { mapCase, arrCase };
            return new JObject().AddReturn("cast", src).AddReturn("cases", cases);
        }
 
        /// <summary>
        /// Builds a cell reference: an object with the single property "cell" set to the cell name.
        /// </summary>
        /// <param name="name">The name of the cell</param>
        /// <returns>The cell reference object</returns>
        public static JObject Cell(string name)
        {
            Contracts.CheckNonWhiteSpace(name, nameof(name));
            var cellRef = new JObject();
            return cellRef.AddReturn("cell", name);
        }
 
        public static class Type
        {
            // Tokens holding the type names of the primitive types as they are written into the
            // generated JSON. Each one is created from its string literal through the implicit
            // string-to-JToken conversion.
            public static readonly JToken Int = "int";
            public static readonly JToken Long = "long";
            public static readonly JToken Float = "float";
            public static readonly JToken Double = "double";
            // Note that the emitted name is "boolean", not "bool".
            public static readonly JToken Bool = "boolean";
            public static readonly JToken String = "string";
            public static readonly JToken Null = "null";
 
            /// <summary>
            /// Builds a map type declaration: "type" is "map" and "values" is the given value type.
            /// </summary>
            /// <param name="valueType">The type of the values in the map</param>
            /// <returns>The map type declaration</returns>
            public static JToken Map(JToken valueType)
            {
                Contracts.CheckValue(valueType, nameof(valueType));
                return new JObject
                {
                    ["type"] = "map",
                    ["values"] = valueType
                };
            }
 
            /// <summary>
            /// Builds an array type declaration: "type" is "array" and "items" is the given item type.
            /// </summary>
            /// <param name="itemType">The type of the items in the array</param>
            /// <returns>The array type declaration</returns>
            public static JToken Array(JToken itemType)
            {
                Contracts.CheckValue(itemType, nameof(itemType));
                return new JObject
                {
                    ["type"] = "array",
                    ["items"] = itemType
                };
            }
 
            /// <summary>
            /// Builds a union type declaration, which is a JSON array of the member types.
            /// </summary>
            /// <param name="types">The member types; at least two are required</param>
            /// <returns>The union type declaration</returns>
            public static JToken Union(params JToken[] types)
            {
                bool hasEnoughTypes = Utils.Size(types) >= 2;
                Contracts.CheckParam(hasEnoughTypes, nameof(types), "Union must have at least two types");

                var union = new JArray(types);
                return union;
            }
 
            /// <summary>
            /// Builds the vector type for an item type: the union of the map form and the array form.
            /// </summary>
            /// <param name="itemType">The type of the items in the vector</param>
            /// <returns>The union type declaration</returns>
            public static JToken Vector(JToken itemType)
            {
                Contracts.CheckValue(itemType, nameof(itemType));

                // The map form is built before the array form, and listed first in the union.
                JToken mapForm = Map(itemType);
                JToken arrayForm = Array(itemType);
                return Union(mapForm, arrayForm);
            }
 
            /// <summary>
            /// Maps a column type to the type token used for it, or returns null when no mapping exists.
            /// </summary>
            /// <param name="type">The column type; vector types are mapped through their item type</param>
            /// <returns>The type token, or null if the type (or a vector's item type) has no mapping</returns>
            public static JToken PfaTypeOrNullForColumnType(DataViewType type)
            {
                Contracts.CheckValue(type, nameof(type));

                if (!(type is VectorDataViewType vectorType))
                    return PfaTypeOrNullCore(type);

                // A vector column is declared as an array of the mapped item type. Only the array
                // form is produced here, not the map/array union that Vector builds.
                // NOTE(review): confirm that leaving out the map (sparse) form is intentional.
                JToken itemType = PfaTypeOrNullCore(vectorType.ItemType);
                return itemType == null ? null : Array(itemType);
            }
 
            /// <summary>
            /// Maps a non-vector column type to its type token, based on the raw .NET type.
            /// Returns null for non-primitive types and for raw types not handled below.
            /// </summary>
            /// <param name="itemType">The type to map; must not be null</param>
            /// <returns>The type token, or null if there is no mapping</returns>
            private static JToken PfaTypeOrNullCore(DataViewType itemType)
            {
                Contracts.AssertValue(itemType);

                bool isPrimitive = itemType is PrimitiveDataViewType;
                if (!isPrimitive)
                    return null;

                if (itemType is KeyDataViewType keyType)
                {
                    // Keys will retain the property that they are just numbers,
                    // with 0 representing missing. Long is used only when the count
                    // is not positive and the raw type is ulong.
                    bool useLong = keyType.Count <= 0 && keyType.RawType == typeof(ulong);
                    return useLong ? Long : Int;
                }

                System.Type raw = itemType.RawType;

                if (raw == typeof(sbyte) || raw == typeof(byte) || raw == typeof(short)
                    || raw == typeof(ushort) || raw == typeof(int))
                {
                    return Int;
                }

                if (raw == typeof(uint) || raw == typeof(long) || raw == typeof(ulong))
                    return Long;

                // REVIEW: float should really map to Float. But, for the
                // sake of the POC, we use double since all the PFA convenience
                // libraries operate over doubles.
                if (raw == typeof(float) || raw == typeof(double))
                    return Double;

                if (raw == typeof(bool))
                    return Bool;

                if (raw == typeof(System.ReadOnlyMemory<char>) || raw == typeof(string))
                    return String;

                return null;
            }
 
            /// <summary>
            /// Produces the default value token for a primitive type, or null when the raw type is not handled.
            /// </summary>
            /// <param name="itemType">The primitive type whose default is wanted</param>
            /// <returns>
            /// 0 for key types and integer raw types, 0.0 for float and double, false for bool,
            /// an empty string literal object for text, and null otherwise
            /// </returns>
            public static JToken DefaultTokenOrNull(PrimitiveDataViewType itemType)
            {
                Contracts.CheckValue(itemType, nameof(itemType));

                if (itemType is KeyDataViewType)
                    return 0;

                System.Type raw = itemType.RawType;

                bool isSmallInteger = raw == typeof(sbyte) || raw == typeof(byte)
                    || raw == typeof(short) || raw == typeof(ushort) || raw == typeof(int);
                if (isSmallInteger || raw == typeof(uint) || raw == typeof(long) || raw == typeof(ulong))
                    return 0;

                // REVIEW: float should really get a float default. But, for the
                // sake of the POC, we use double since all the PFA convenience
                // libraries operate over doubles.
                if (raw == typeof(float) || raw == typeof(double))
                    return 0.0;

                if (raw == typeof(bool))
                    return false;

                // Text defaults to the typed string literal built by the enclosing class's String method.
                if (raw == typeof(System.ReadOnlyMemory<char>) || raw == typeof(string))
                    return String("");

                return null;
            }
        }
 
        /// <summary>
        /// This ensures that there is a function formatted as "count_type" (for example, "count_double"),
        /// that takes either a map or array and returns the number of items in that map or array.
        /// The function is only added if the context does not already contain one with that name.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <param name="itemType">The item type this will operate on</param>
        /// <returns>The name to call the function by, which is the function name prefixed with "u."</returns>
        public static string EnsureCount(this PfaContext ctx, JToken itemType)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            Contracts.CheckValue(itemType, nameof(itemType));

            string name = "count_" + itemType.ToString();
            string refname = "u." + name;

            if (!ctx.ContainsFunc(name))
            {
                // Single parameter "a" of the vector type; the body dispatches on its map or array form.
                JArray parameters = new JArray(Param("a", Type.Vector(itemType)));
                JObject body = VectorCase(itemType, "a", "ma", Call("map.len", "ma"), "aa", Call("a.len", "aa"));
                ctx.AddFunc(name, parameters, Type.Int, body);
            }

            return refname;
        }
 
        /// <summary>
        /// Ensures that there is a "hasChars" function, a string to bool function for determining
        /// whether a string has content. The function is only added if the context does not already have it.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <returns>The name to call the function by, "u.hasChars"</returns>
        public static string EnsureHasChars(this PfaContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));

            const string name = "hasChars";
            const string refname = "u." + name;

            if (!ctx.ContainsFunc(name))
            {
                // Body: the length of "str" is greater than zero.
                JArray parameters = new JArray(Param("str", Type.String));
                JObject body = Call(">", Call("s.len", "str"), 0);
                ctx.AddFunc(name, parameters, Type.Bool, body);
            }

            return refname;
        }
 
        /// <summary>
        /// Ensures that there is a "newArray" function taking an int "size" and returning an array of
        /// doubles. For a size of zero the body yields an empty array literal; otherwise it calls itself
        /// with size // 2, concatenates that result with itself, and appends 0.0 when the low bit of
        /// size is set. The intent appears to be an array of "size" zeros.
        /// The function is only added if the context does not already have it.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <returns>The name to call the function by, "u.newArray"</returns>
        public static string EnsureNewArray(this PfaContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            // This name must be distinct from the "hasChars" name registered by EnsureHasChars.
            // Sharing it would make whichever of the two is ensured second silently resolve
            // to the other's function.
            const string name = "newArray";
            const string refname = "u." + name;
            if (ctx.ContainsFunc(name))
                return refname;
            var arrType = Type.Array(Type.Double);

            // A null receiver makes each AddReturn chain below start a fresh object.
            JObject jobj = null;

            JArray elseBlock = new JArray();
            // let halfsize = newArray(size // 2)
            elseBlock.Add(jobj.AddReturn("let", jobj.AddReturn("halfsize",
                Call(refname, Call("//", "size", 2)))));
            // let fullsize = halfsize concatenated with itself
            elseBlock.Add(jobj.AddReturn("let", jobj.AddReturn("fullsize",
                Call("a.concat", "halfsize", "halfsize"))));
            // When size is odd, one more element is needed.
            elseBlock.Add(If(
                Call("==", Call("&", "size", 1), 1),
                Call("a.append", "fullsize", 0.0), "fullsize"));

            ctx.AddFunc(name, new JArray(Param("size", Type.Int)), arrType,
                If(Call("==", "size", 0), jobj.AddReturn("type", arrType).AddReturn("value", new JArray()),
                elseBlock));
            return refname;
        }
 
        /// <summary>
        /// Ensures that there is an "add_type" function applying the "+" operator to two values of the item type.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <param name="itemType">The item type this will operate on</param>
        /// <returns>The name to call the function by</returns>
        public static string EnsureAdd(this PfaContext ctx, JToken itemType)
        {
            return EnsureOpCore(ctx, "add", "+", itemType);
        }

        /// <summary>
        /// Ensures that there is a "sub_type" function applying the "-" operator to two values of the item type.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <param name="itemType">The item type this will operate on</param>
        /// <returns>The name to call the function by</returns>
        public static string EnsureSub(this PfaContext ctx, JToken itemType)
        {
            return EnsureOpCore(ctx, "sub", "-", itemType);
        }

        /// <summary>
        /// Ensures that there is a "mul_type" function applying the "*" operator to two values of the item type.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <param name="itemType">The item type this will operate on</param>
        /// <returns>The name to call the function by</returns>
        public static string EnsureMul(this PfaContext ctx, JToken itemType)
        {
            return EnsureOpCore(ctx, "mul", "*", itemType);
        }

        /// <summary>
        /// Ensures that there is a "div_type" function applying the "/" operator to two values of the item type.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <param name="itemType">The item type this will operate on</param>
        /// <returns>The name to call the function by</returns>
        public static string EnsureDiv(this PfaContext ctx, JToken itemType)
        {
            return EnsureOpCore(ctx, "div", "/", itemType);
        }
 
        /// <summary>
        /// Ensures that there is a two-parameter function named "prefix_type" whose body applies a
        /// binary operator to its parameters "a" and "b". The function is only added if the context
        /// does not already contain one with that name.
        /// </summary>
        /// <param name="ctx">The context to check for the existence of this</param>
        /// <param name="funcPrefix">The part of the function name before the underscore</param>
        /// <param name="binOp">The operator to call on the two parameters</param>
        /// <param name="itemType">The type of both parameters; its string form is the name suffix</param>
        /// <param name="returnType">The return type of the function; defaults to <paramref name="itemType"/> when null</param>
        /// <returns>The name to call the function by, which is the function name prefixed with "u."</returns>
        private static string EnsureOpCore(PfaContext ctx, string funcPrefix, string binOp, JToken itemType, JToken returnType = null)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            Contracts.AssertNonEmpty(funcPrefix);
            Contracts.AssertNonEmpty(binOp);
            Contracts.CheckValue(itemType, nameof(itemType));
            Contracts.CheckValueOrNull(returnType);
            if (returnType == null)
                returnType = itemType;

            string name = funcPrefix + "_" + itemType.ToString();
            string refname = "u." + name;

            if (!ctx.ContainsFunc(name))
            {
                JArray parameters = new JArray(Param("a", itemType), Param("b", itemType));
                ctx.AddFunc(name, parameters, returnType, Call(binOp, "a", "b"));
            }

            return refname;
        }
    }
}