File: Expression\Printer.cs
Web Access
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.CodeDom.Compiler;
using System.Globalization;
using System.IO;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Transforms
{
    using BL = System.Boolean;
    using I4 = System.Int32;
    using I8 = System.Int64;
    using R4 = Single;
    using R8 = Double;
    using TX = ReadOnlyMemory<char>;
 
    /// <summary>
    /// Simple pretty-printing visitor that renders an expression parse tree as text.
    /// Optionally appends each node's inferred type (":type") and/or prints a node's
    /// computed <c>ExprValue</c> in place of its sub-expression.
    /// </summary>
    internal sealed class NodePrinter : PreVisitor
    {
        // Text values longer than this many characters are truncated when printed.
        private const int MaxTextLength = 100;
        // Characters kept from a truncated text value; leaves room for the "..." suffix.
        private const int TruncatedTextLength = MaxTextLength - 3;
        // Delimited lists with at most this many items are printed in full; longer lists print
        // the first (MaxListItems - 1) items, an ellipsis, and the last item.
        private const int MaxListItems = 6;

        private readonly bool _showTypes;
        private readonly bool _showValues;
        private readonly IndentedTextWriter _wrt;

        private NodePrinter(IndentedTextWriter wrt, bool showTypes, bool showValues)
        {
            Contracts.AssertValue(wrt);

            _showTypes = showTypes;
            _showValues = showValues;
            _wrt = wrt;
        }

        /// <summary>
        /// Public entry point for pretty-printing expression parse trees.
        /// </summary>
        /// <param name="node">Root of the tree to print.</param>
        /// <param name="writer">Destination for the text; it is neither flushed nor disposed here.</param>
        /// <param name="showTypes">When true, append ':' and the node's <c>ExprType</c> to typed nodes.</param>
        /// <param name="showValues">When true, print a node's non-null <c>ExprValue</c> instead of its sub-expression.</param>
        public static void Print(Node node, TextWriter writer, bool showTypes = false, bool showValues = false)
        {
            Contracts.AssertValue(node);
            Contracts.AssertValue(writer);

            var wrt = new IndentedTextWriter(writer, "  ");
            var printer = new NodePrinter(wrt, showTypes, showValues);
            node.Accept(printer);
        }

        /// <summary>
        /// Returns whether the left operand of an operator with precedence <paramref name="precOp"/>
        /// must be parenthesized, given the operand's own precedence <paramref name="precLeft"/>.
        /// </summary>
        private bool NeedParensLeft(Precedence precLeft, Precedence precOp)
        {
            if (precLeft < precOp)
                return true;
            if (precLeft > precOp)
                return false;

            // Equal precedence: only power is right associative, so only (a ^ b) ^ c needs
            // parentheses on the left.
            return precOp == Precedence.Power;
        }

        /// <summary>
        /// Returns whether the right operand of an operator with precedence <paramref name="precOp"/>
        /// must be parenthesized, given the operand's own precedence <paramref name="precRight"/>.
        /// </summary>
        private bool NeedParensRight(Precedence precOp, Precedence precRight)
        {
            if (precOp == Precedence.Postfix)
            {
                // Indexing is the only postfix operator, and it never
                // needs parens around the right operand.
                return false;
            }

            if (precOp > precRight)
                return true;
            if (precOp < precRight)
                return false;

            // Equal precedence: every operator except (right associative) power needs
            // parentheses on the right, e.g. a - (b - c).
            return precOp != Precedence.Power;
        }

        /// <summary>
        /// Returns the binding precedence of an expression node, used to decide where
        /// parentheses are required when printing it as an operand.
        /// </summary>
        private Precedence GetPrec(Node node)
        {
            Contracts.Assert(node is ExprNode);

            switch (node.Kind)
            {
                case NodeKind.Conditional:
                    // Conditionals nest inside other expressions (e.g. as binary operands);
                    // without this case they would hit the assert below.
                    return Precedence.Conditional;

                case NodeKind.BinaryOp:
                    return GetPrec(node.AsBinaryOp.Op);

                case NodeKind.UnaryOp:
                    return Precedence.PrefixUnary;

                case NodeKind.Compare:
                    return Precedence.Compare;

                case NodeKind.Call:
                case NodeKind.With:
                    return Precedence.Primary;

                case NodeKind.Ident:
                case NodeKind.BoolLit:
                case NodeKind.NumLit:
                case NodeKind.StrLit:
                    return Precedence.Atomic;

                default:
                    Contracts.Assert(false, "Unexpected node kind in GetPrec - should only see ExprNode kinds");
                    return Precedence.None;
            }
        }

        /// <summary>Returns the precedence of a binary operator.</summary>
        private Precedence GetPrec(BinaryOp op)
        {
            switch (op)
            {
                case BinaryOp.Or:
                    return Precedence.Or;
                case BinaryOp.And:
                    return Precedence.And;
                case BinaryOp.Add:
                case BinaryOp.Sub:
                    return Precedence.Add;
                case BinaryOp.Mul:
                case BinaryOp.Div:
                case BinaryOp.Mod:
                    return Precedence.Mul;
                case BinaryOp.Power:
                    return Precedence.Power;
                case BinaryOp.Error:
                    return Precedence.None;
                default:
                    Contracts.Assert(false);
                    return Precedence.None;
            }
        }

        /// <summary>Returns the source text (with surrounding spaces) for a binary operator.</summary>
        private string GetString(BinaryOp op)
        {
            switch (op)
            {
                case BinaryOp.Or:
                    return " or ";
                case BinaryOp.And:
                    return " and ";
                case BinaryOp.Add:
                    return " + ";
                case BinaryOp.Sub:
                    return " - ";
                case BinaryOp.Mul:
                    return " * ";
                case BinaryOp.Div:
                    return " / ";
                case BinaryOp.Mod:
                    return " % ";
                case BinaryOp.Power:
                    return " ^ ";
                case BinaryOp.Error:
                    return " <err> ";
                default:
                    Contracts.Assert(false);
                    return " <bad> ";
            }
        }

        /// <summary>Returns the source text for a prefix unary operator.</summary>
        private string GetString(UnaryOp op)
        {
            switch (op)
            {
                case UnaryOp.Not:
                    return "not ";
                case UnaryOp.Minus:
                    return "-";
                default:
                    Contracts.Assert(false);
                    return "<bad> ";
            }
        }

        /// <summary>Returns the source text (with surrounding spaces) for a comparison token.</summary>
        private string GetString(TokKind tidCompare)
        {
            switch (tidCompare)
            {
                case TokKind.Equ:
                    return " = ";
                case TokKind.EquEqu:
                    return " == ";
                case TokKind.LssGrt:
                    return " <> ";
                case TokKind.BngEqu:
                    return " != ";
                case TokKind.Lss:
                    return " < ";
                case TokKind.LssEqu:
                    return " <= ";
                case TokKind.GrtEqu:
                    return " >= ";
                case TokKind.Grt:
                    return " > ";

                default:
                    Contracts.Assert(false);
                    return " <bad> ";
            }
        }

        /// <summary>
        /// When value display is enabled and the node carries an <c>ExprValue</c>, prints that
        /// value (plus type, if enabled) and returns true; otherwise prints nothing and returns false.
        /// </summary>
        private bool TryShowValue(ExprNode node)
        {
            if (!_showValues)
                return false;
            if (node.ExprValue == null)
                return false;

            ShowValueCore(node);
            ShowType(node);

            return true;
        }

        /// <summary>Prints the node's <c>ExprValue</c>, unboxed according to its <c>ExprType</c>.</summary>
        private void ShowValueCore(ExprNode node)
        {
            Contracts.AssertValue(node);
            Contracts.AssertValue(node.ExprValue);

            var value = node.ExprValue;
            switch (node.ExprType)
            {
                case ExprTypeKind.I4:
                    Show((I4)value);
                    break;
                case ExprTypeKind.I8:
                    Show((I8)value);
                    break;
                case ExprTypeKind.R4:
                    Show((R4)value);
                    break;
                case ExprTypeKind.R8:
                    Show((R8)value);
                    break;
                case ExprTypeKind.BL:
                    Show((BL)value);
                    break;
                case ExprTypeKind.TX:
                    Show((TX)value);
                    break;
                default:
                    Contracts.Assert(false, "Unknown type");
                    break;
            }
        }

        // Numbers are formatted with the invariant culture so the printed expression does not
        // depend on the caller's locale (decimal separator, negative sign).
        private void Show(I4 x)
        {
            _wrt.Write(x.ToString(CultureInfo.InvariantCulture));
        }

        private void Show(I8 x)
        {
            _wrt.Write(x.ToString(CultureInfo.InvariantCulture));
        }

        private void Show(R4 x)
        {
            if (R4.IsNaN(x))
                _wrt.Write("NA");
            else
                _wrt.Write(x.ToString("R", CultureInfo.InvariantCulture));
        }

        private void Show(R8 x)
        {
            if (R8.IsNaN(x))
                _wrt.Write("NA");
            else
                _wrt.Write(x.ToString("R", CultureInfo.InvariantCulture));
        }

        private void Show(BL x)
        {
            if (!x)
                _wrt.Write("false");
            else
                _wrt.Write("true");
        }

        /// <summary>
        /// Prints text as a double-quoted literal. Text longer than <see cref="MaxTextLength"/>
        /// characters is cut to <see cref="TruncatedTextLength"/> characters followed by "...".
        /// </summary>
        private void Show(TX str)
        {
            int len = str.Length > MaxTextLength ? TruncatedTextLength : str.Length;

            _wrt.Write('"');
            // Only the (possibly truncated) prefix is written.
            foreach (var ch in str.Span.Slice(0, len))
            {
                // Replace problematic characters with space.
                // REVIEW: Which characters should we replace?
                if (ch < ' ' || ch == '"')
                    _wrt.Write(' ');
                else
                    _wrt.Write(ch);
            }
            if (len < str.Length)
                _wrt.Write("...");
            _wrt.Write('"');
        }

        /// <summary>Appends ":type" for an expression node when types are shown and the node has one.</summary>
        private void ShowType(ExprNode node)
        {
            if (!_showTypes)
                return;
            if (node.IsNone)
                return;

            _wrt.Write(':');
            _wrt.Write(node.ExprType.ToString());
        }

        /// <summary>Appends ":type" for a lambda parameter when types are shown and it has one.</summary>
        private void ShowType(ParamNode node)
        {
            if (!_showTypes)
                return;
            if (node.ExprType == ExprTypeKind.None)
                return;

            _wrt.Write(':');
            _wrt.Write(node.ExprType.ToString());
        }

        public override void Visit(BoolLitNode node)
        {
            Contracts.AssertValue(node);
            _wrt.Write(node.Value ? "true" : "false");
            ShowType(node);
        }

        public override void Visit(StrLitNode node)
        {
            Contracts.AssertValue(node);
            Show(node.Value);
            ShowType(node);
        }

        public override void Visit(NumLitNode node)
        {
            Contracts.AssertValue(node);
            _wrt.Write(node.Value.ToString());
            ShowType(node);
        }

        public override void Visit(NameNode node)
        {
            Contracts.AssertValue(node);
            _wrt.Write(node.Value);
        }

        public override void Visit(IdentNode node)
        {
            Contracts.AssertValue(node);

            if (TryShowValue(node))
                return;
            _wrt.Write(node.Value);
            ShowType(node);
        }

        public override void Visit(ParamNode node)
        {
            Contracts.AssertValue(node);
            _wrt.Write(node.Name);
            ShowType(node);
        }

        /// <summary>
        /// Prints "x => body" for a single parameter, otherwise "(a, b, ...) => body".
        /// </summary>
        public override void Visit(LambdaNode node)
        {
            Contracts.AssertValue(node);
            if (node.Vars.Length == 1)
                node.Vars[0].Accept(this);
            else
            {
                _wrt.Write('(');
                var pre = "";
                foreach (var v in node.Vars)
                {
                    _wrt.Write(pre);
                    v.Accept(this);
                    // Every parameter after the first is comma separated.
                    pre = ", ";
                }
                _wrt.Write(")");
            }
            _wrt.Write(" => ");
            node.Expr.Accept(this);
        }

        public override void Visit(UnaryOpNode node)
        {
            Contracts.AssertValue(node);

            if (TryShowValue(node))
                return;

            Precedence prec = GetPrec(node.Arg);
            _wrt.Write(GetString(node.Op));
            if (prec < Precedence.PrefixUnary)
                _wrt.Write('(');
            node.Arg.Accept(this);
            if (prec < Precedence.PrefixUnary)
                _wrt.Write(')');
            ShowType(node);
        }

        public override void Visit(BinaryOpNode node)
        {
            Contracts.AssertValue(node);

            if (TryShowValue(node))
                return;

            Precedence prec = GetPrec(node.Op);
            bool parensLeft = NeedParensLeft(GetPrec(node.Left), prec);
            bool parensRight = NeedParensRight(prec, GetPrec(node.Right));

            if (parensLeft)
                _wrt.Write('(');
            node.Left.Accept(this);
            if (parensLeft)
                _wrt.Write(')');

            _wrt.Write(GetString(node.Op));

            if (parensRight)
                _wrt.Write('(');
            node.Right.Accept(this);
            if (parensRight)
                _wrt.Write(')');

            ShowType(node);
        }

        /// <summary>Prints "cond ? left : right", parenthesizing the condition when required.</summary>
        public override void Visit(ConditionalNode node)
        {
            Contracts.AssertValue(node);

            if (TryShowValue(node))
                return;

            // The condition must be parenthesized when it binds no tighter than the conditional
            // itself: an unparenthesized nested conditional there would re-associate to the right.
            bool parensCond = GetPrec(node.Cond) <= Precedence.Conditional;

            if (parensCond)
                _wrt.Write('(');
            node.Cond.Accept(this);
            if (parensCond)
                _wrt.Write(')');

            _wrt.Write(" ? ");
            node.Left.Accept(this);
            _wrt.Write(" : ");
            node.Right.Accept(this);

            ShowType(node);
        }

        /// <summary>
        /// Prints a (possibly chained) comparison, emitting each delimiter as either the node's
        /// lax or strict operator spelling. Operands of comparison precedence or lower are parenthesized.
        /// </summary>
        public override void Visit(CompareNode node)
        {
            Contracts.AssertValue(node);

            if (TryShowValue(node))
                return;

            TokKind tidLax = node.TidLax;
            TokKind tidStrict = node.TidStrict;
            string strLax = GetString(tidLax);
            string strStrict = GetString(tidStrict);

            var items = node.Operands.Items;
            for (int i = 0; i < items.Length; i++)
            {
                if (i > 0)
                {
                    // Delimiter i - 1 sits between operands i - 1 and i.
                    var tid = node.Operands.Delimiters[i - 1].Kind;
                    Contracts.Assert(tid == tidLax || tid == tidStrict);
                    _wrt.Write(tid == tidLax ? strLax : strStrict);
                }

                var arg = items[i];
                bool parens = GetPrec(arg) <= Precedence.Compare;
                if (parens)
                    _wrt.Write('(');
                arg.Accept(this);
                if (parens)
                    _wrt.Write(')');
            }

            ShowType(node);
        }

        public override void Visit(CallNode node)
        {
            Contracts.AssertValue(node);

            if (TryShowValue(node))
                return;

            if (node.NameSpace != null)
            {
                node.NameSpace.Accept(this);
                _wrt.Write('.');
            }
            node.Head.Accept(this);
            _wrt.Write('(');
            node.Args.Accept(this);
            _wrt.Write(')');
            ShowType(node);
        }

        /// <summary>
        /// Prints list items. Without delimiters the items are written back to back; with
        /// delimiters they are comma separated, and lists longer than <see cref="MaxListItems"/>
        /// are abbreviated to the leading items, "...", and the final item.
        /// </summary>
        public override void Visit(ListNode node)
        {
            Contracts.AssertValue(node);

            int count = node.Items.Length;
            if (count == 0)
                return;

            if (node.Delimiters == null)
            {
                foreach (var child in node.Items)
                    child.Accept(this);
            }
            else if (count <= MaxListItems)
            {
                node.Items[0].Accept(this);
                for (int i = 1; i < count; i++)
                {
                    _wrt.Write(", ");
                    node.Items[i].Accept(this);
                }
            }
            else
            {
                for (int i = 0; i < MaxListItems - 1; i++)
                {
                    node.Items[i].Accept(this);
                    _wrt.Write(", ");
                }
                _wrt.Write("..., ");
                node.Items[count - 1].Accept(this);
            }
        }

        public override void Visit(WithNode node)
        {
            Contracts.AssertValue(node);

            if (TryShowValue(node))
                return;

            _wrt.Write("with(");
            node.Local.Accept(this);
            _wrt.Write("; ");
            node.Body.Accept(this);
            _wrt.Write(")");

            ShowType(node);
        }

        public override void Visit(WithLocalNode node)
        {
            Contracts.AssertValue(node);

            _wrt.Write(node.Name);
            _wrt.Write(" = ");
            node.Value.Accept(this);
        }
    }
}