// File: Model\Pfa\PfaContext.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime;
using Newtonsoft.Json.Linq;
 
namespace Microsoft.ML.Model.Pfa
{
    /// <summary>
    /// A context for defining a restricted sort of PFA output.
    /// </summary>
    [BestFriend]
    internal sealed class PfaContext
    {
        /// <summary>The token that <see cref="Finalize()"/> writes under the <c>"input"</c> key.</summary>
        public JToken InputType { get; set; }
        /// <summary>The token that <see cref="Finalize()"/> writes under the <c>"output"</c> key.</summary>
        public JToken OutputType { get; set; }
        /// <summary>
        /// The closing action. <see cref="Finalize()"/> writes it as the entire <c>"action"</c> value when no
        /// let/set blocks have been added, and as the last element of the action array otherwise.
        /// </summary>
        public JToken Final { get; set; }
 
        // Names of all variables passed to AddVariables so far; decides whether a later use is a "let" or a "set".
        private readonly HashSet<string> _variables;
        // Cells added through AddCell; emitted under "cells" by Finalize.
        private readonly List<CellBlock> _cellBlocks;
        // The "let"/"set" blocks produced by AddVariables, in the order they were added.
        private readonly List<VariableBlock> _letSetBlocks;
        // Named functions added through AddFunc; emitted under "fcns" by Finalize.
        private readonly List<FuncBlock> _funcBlocks;
        // Type names registered through RegisterType.
        private readonly HashSet<string> _types;
        // Host obtained from the environment in the constructor; not read anywhere else in this class.
        private readonly IHost _host;
 
        /// <summary>
        /// A single variable-assignment step of the action: a keyword (this class passes <c>"let"</c> or
        /// <c>"set"</c>) together with the name/value pairs it applies to.
        /// </summary>
        private readonly struct VariableBlock
        {
            /// <summary>The property name the assignments are emitted under.</summary>
            public readonly string Type;

            /// <summary>The variable names and their value tokens, in caller order.</summary>
            public readonly KeyValuePair<string, JToken>[] Locals;

            public VariableBlock(string type, KeyValuePair<string, JToken>[] locals)
            {
                Locals = locals;
                Type = type;
            }

            /// <summary>
            /// Renders the block as an object with one property, named by <see cref="Type"/>, whose value
            /// is an object holding one property per entry of <see cref="Locals"/>.
            /// </summary>
            public JToken ToToken()
            {
                var assignments = new JObject();
                for (int i = 0; i < Locals.Length; i++)
                {
                    // Assigned through the indexer, in array order.
                    assignments[Locals[i].Key] = Locals[i].Value;
                }

                return new JObject { [Type] = assignments };
            }
        }
 
        /// <summary>
        /// A named cell: its type token and its initial-value token.
        /// </summary>
        private readonly struct CellBlock
        {
            /// <summary>The cell name; used by the owning context as the key under <c>"cells"</c>.</summary>
            public readonly string Name;

            /// <summary>The token emitted as the cell's <c>"type"</c>.</summary>
            public readonly JToken Type;

            /// <summary>The token emitted as the cell's <c>"init"</c>.</summary>
            public readonly JToken Init;

            public CellBlock(string name, JToken type, JToken init)
            {
                Init = init;
                Type = type;
                Name = name;
            }

            /// <summary>
            /// Renders the cell body as an object with <c>"type"</c> followed by <c>"init"</c>.
            /// The name is not part of the result.
            /// </summary>
            public JObject ToToken() => new JObject
            {
                ["type"] = Type,
                ["init"] = Init
            };
        }
 
        /// <summary>
        /// A function definition: parameter list, return type and body, plus the name it is stored under.
        /// </summary>
        private readonly struct FuncBlock
        {
            /// <summary>The function name; used by the owning context as the key under <c>"fcns"</c>.</summary>
            public readonly string Name;

            /// <summary>The array emitted as <c>"params"</c>.</summary>
            public readonly JArray Params;

            /// <summary>The token emitted as <c>"ret"</c>.</summary>
            public readonly JToken ReturnType;

            /// <summary>The token emitted as <c>"do"</c>.</summary>
            public readonly JToken Do;

            public FuncBlock(string name, JArray prms, JToken returnType, JToken doBlock)
            {
                Do = doBlock;
                ReturnType = returnType;
                Params = prms;
                Name = name;
            }

            /// <summary>
            /// Renders the function body as an object with <c>"params"</c>, <c>"ret"</c> and <c>"do"</c>,
            /// in that order. The name is not part of the result.
            /// </summary>
            public JObject ToToken() => new JObject
            {
                ["params"] = Params,
                ["ret"] = ReturnType,
                ["do"] = Do
            };
        }
 
        /// <summary>
        /// Creates a context with no variables, cells, let/set blocks, functions or registered types.
        /// </summary>
        /// <param name="env">The environment used to register this context's host; checked with
        /// <c>Contracts.CheckValue</c> before anything else happens.</param>
        public PfaContext(IHostEnvironment env)
        {
            Contracts.CheckValue(env, nameof(env));

            // Every collection starts out empty.
            _types = new HashSet<string>();
            _variables = new HashSet<string>();
            _funcBlocks = new List<FuncBlock>();
            _letSetBlocks = new List<VariableBlock>();
            _cellBlocks = new List<CellBlock>();

            _host = env.Register(nameof(PfaContext));
        }
 
        /// <summary>
        /// Assembles everything added to this context into one JSON object.
        /// </summary>
        /// <returns>
        /// An object holding, in this order: <c>"input"</c>, <c>"output"</c>, <c>"cells"</c> (only when at
        /// least one cell was added), <c>"action"</c>, and <c>"fcns"</c> (only when at least one function
        /// was added).
        /// </returns>
        public JObject Finalize()
        {
            var document = new JObject
            {
                ["input"] = InputType,
                ["output"] = OutputType
            };

            // Cells are keyed by name; the section is left out entirely when there are none.
            if (_cellBlocks.Count > 0)
            {
                var cellsByName = new JObject();
                foreach (var c in _cellBlocks)
                    cellsByName[c.Name] = c.ToToken();
                document["cells"] = cellsByName;
            }

            // Without let/set blocks the action is just Final. Otherwise it is an array of the
            // blocks, in the order they were added, followed by Final.
            JToken action = Final;
            if (_letSetBlocks.Count > 0)
            {
                var steps = new JArray();
                foreach (var step in _letSetBlocks)
                    steps.Add(step.ToToken());
                steps.Add(Final);
                action = steps;
            }
            document["action"] = action;

            // Functions go last, keyed by name, and are likewise omitted when there are none.
            if (_funcBlocks.Count > 0)
            {
                var functionsByName = new JObject();
                foreach (var f in _funcBlocks)
                    functionsByName[f.Name] = f.ToToken();
                document["fcns"] = functionsByName;
            }

            return document;
        }
 
        /// <summary>
        /// Pairs a variable name with the token obtained by parsing <paramref name="token"/> with
        /// <c>JToken.Parse</c>. Does not modify the context.
        /// </summary>
        /// <param name="varName">The key of the resulting pair.</param>
        /// <param name="token">Text handed to <c>JToken.Parse</c>.</param>
        /// <returns>The pair of <paramref name="varName"/> and the parsed token.</returns>
        public KeyValuePair<string, JToken> CreatePair(string varName, string token)
        {
            JToken parsed = JToken.Parse(token);
            return new KeyValuePair<string, JToken>(varName, parsed);
        }
 
        /// <summary>
        /// Appends variable assignments to the action. Names this context has not seen before are
        /// emitted in a <c>"let"</c> block, names it has already seen in a <c>"set"</c> block. When a call
        /// mixes both, the <c>"let"</c> block is added first, followed by the <c>"set"</c> block.
        /// </summary>
        /// <param name="locals">
        /// The name/value pairs to assign. An empty array is a no-op. A null array (possible when a caller
        /// passes <c>null</c> explicitly to the <c>params</c> parameter) is rejected through
        /// <c>Contracts.CheckValue</c> instead of failing on the first dereference.
        /// </param>
        public void AddVariables(params KeyValuePair<string, JToken>[] locals)
        {
            Contracts.CheckValue(locals, nameof(locals));
            if (locals.Length == 0)
                return;

            // Split in a single pass. The known-variable set is not touched until the split is complete,
            // so a new name that appears twice in one call lands in "let" both times.
            var lets = new List<KeyValuePair<string, JToken>>(locals.Length);
            var sets = new List<KeyValuePair<string, JToken>>();
            foreach (var local in locals)
            {
                if (_variables.Contains(local.Key))
                    sets.Add(local);
                else
                    lets.Add(local);
            }

            // Only the new names need recording; the "set" names are already present.
            foreach (var local in lets)
                _variables.Add(local.Key);

            // We must do the lets first.
            if (lets.Count > 0)
                _letSetBlocks.Add(new VariableBlock("let", lets.ToArray()));
            if (sets.Count > 0)
                _letSetBlocks.Add(new VariableBlock("set", sets.ToArray()));
        }
 
        /// <summary>
        /// Adds a cell to the context. Arguments are validated in this order: <paramref name="name"/> is
        /// checked for null, then for uniqueness, then <paramref name="type"/> and <paramref name="init"/>
        /// are checked for null.
        /// </summary>
        /// <param name="name">The cell name; must not already be present (see <see cref="ContainsCell(string)"/>).</param>
        /// <param name="type">The cell's type token.</param>
        /// <param name="init">The cell's initial-value token.</param>
        public void AddCell(string name, JToken type, JToken init)
        {
            Contracts.CheckValue(name, nameof(name));

            // Cell names are unique within a context.
            bool isDuplicate = ContainsCell(name);
            if (isDuplicate)
                throw Contracts.ExceptParam(nameof(name), $"Cell {name} already exists");

            Contracts.CheckValue(type, nameof(type));
            Contracts.CheckValue(init, nameof(init));

            var cell = new CellBlock(name, type, init);
            _cellBlocks.Add(cell);
        }
 
        /// <summary>
        /// Adds a named function to the context. It is emitted under <c>"fcns"</c>, keyed by
        /// <paramref name="name"/>, when <see cref="Finalize()"/> is called.
        /// </summary>
        /// <param name="name">
        /// The function name. Checked for null here, as <see cref="AddCell"/> does for cell names, so the
        /// failure surfaces at the call site rather than when the document is assembled.
        /// </param>
        /// <param name="prms">The function's parameter array.</param>
        /// <param name="returnType">The function's return-type token.</param>
        /// <param name="doBlock">The function's body token.</param>
        /// <remarks>
        /// Adding a name that is already present is still accepted, exactly as before. Callers that need
        /// uniqueness can consult <see cref="ContainsFunc(string)"/> first.
        /// </remarks>
        public void AddFunc(string name, JArray prms, JToken returnType, JToken doBlock)
        {
            Contracts.CheckValue(name, nameof(name));
            _funcBlocks.Add(new FuncBlock(name, prms, returnType, doBlock));
        }
 
        /// <summary>
        /// Builds the JSON for an anonymous function block from its parameters, return type and body.
        /// Nothing is added to the context.
        /// </summary>
        /// <param name="prms">The function's parameter array.</param>
        /// <param name="returnType">The function's return-type token.</param>
        /// <param name="doBlock">The function's body token.</param>
        /// <returns>The object produced by <c>FuncBlock.ToToken</c> for these values.</returns>
        public static JObject CreateFuncBlock(JArray prms, JToken returnType, JToken doBlock)
        {
            // The name is a placeholder: ToToken does not include it in the rendered object.
            var anonymous = new FuncBlock("foo", prms, returnType, doBlock);
            return anonymous.ToToken();
        }
 
        /// <summary>Whether a cell named <paramref name="name"/> has been added to this context.</summary>
        public bool ContainsCell(string name)
        {
            return _cellBlocks.Exists(cell => cell.Name == name);
        }
 
        /// <summary>Whether a variable named <paramref name="name"/> has been added through <see cref="AddVariables"/>.</summary>
        public bool ContainsVar(string name)
        {
            return _variables.Contains(name);
        }
 
        /// <summary>Whether a function named <paramref name="name"/> has been added through <see cref="AddFunc"/>.</summary>
        public bool ContainsFunc(string name)
        {
            return _funcBlocks.Exists(func => func.Name == name);
        }
 
        /// <summary>Whether <paramref name="name"/> has been registered through <see cref="RegisterType(string)"/>.</summary>
        public bool ContainsType(string name)
        {
            return _types.Contains(name);
        }
 
        /// <summary>
        /// Records a type name as declared. PFA has no separate type declarations: a type is declared
        /// inline as part of a variable declaration, so when a record type is used several times, exactly
        /// one usage must carry the full declaration and the others refer to it by name.
        ///
        /// This method tells the caller which case applies. A return of <c>true</c> means the name was
        /// not registered before, and the caller must emit the full type declaration in the PFA it
        /// generates. A return of <c>false</c> means the type has already been declared and can simply
        /// be referenced by name.
        /// </summary>
        /// <param name="name">The type to register</param>
        /// <returns><c>true</c> if this name was not already registered; otherwise <c>false</c></returns>
        /// <seealso cref="ContainsType(string)"/>
        public bool RegisterType(string name) => _types.Add(name);
    }
}