File: ExpressionLanguageTests\ExpressionLanguageTests.cs
Web Access
Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
#pragma warning disable 420 // volatile with Interlocked.CompareExchange
 
using System;
using System.CodeDom.Compiler;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.Conversion;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.Tests;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;
 
// Register the two test function providers defined below as loadable components of kind
// SignatureFunctionProvider, with load names "__test1" and "__test2" (matching their NameSpace values).
// The ExprLanguageTests constructor registers this assembly with Env.ComponentCatalog, which is
// presumably how the expression language finds these providers — confirm against ComponentCatalog.
[assembly: LoadableClass(typeof(TestFuncs1), null, typeof(SignatureFunctionProvider), "Test Functions 1", "__test1")]

[assembly: LoadableClass(typeof(TestFuncs2), null, typeof(SignatureFunctionProvider), "Test Functions 2", "__test2")]
 
namespace Microsoft.ML.Tests
{
    using BL = System.Boolean;
    using I4 = System.Int32;
    using I8 = System.Int64;
    using R4 = Single;
    using R8 = Double;
    using TX = ReadOnlyMemory<char>;
 
    /// <summary>
    /// Baseline tests for the expression language. Each test reads an embedded input resource holding one
    /// or more scripts. It runs every script through the parser, binder and compiler and, when inputs are
    /// supplied, evaluates the compiled delegate. It then writes a transcript to an output file and checks
    /// it with <c>CheckEquality</c>.
    /// </summary>
    public sealed partial class ExprLanguageTests : BaseTestBaseline
    {
        private const string ResourcePrefix = "Microsoft.ML.Tests.ExpressionLanguageTests.TestData.";

        // Serializes use of the static TestFuncs1.Writer hook. It must be static: xUnit creates a new
        // instance of the test class for every test, so a per-instance lock object would not protect
        // shared static state.
        private static readonly object _sync = new object();

        public ExprLanguageTests(ITestOutputHelper output)
            : base(output)
        {
            // Register this assembly (which carries the TestFuncs1/TestFuncs2 LoadableClass attributes)
            // with the environment's component catalog.
            Env.ComponentCatalog.RegisterAssembly(typeof(TestFuncs1).Assembly);
        }

        [Fact, TestCategory("Expr Language")]
        public void ExprParse()
        {
            // Code coverage test for the parser.
            Run("ExprParse");
        }

#if !NETFRAMEWORK
        // Bug in sin(x) in .Net Framework: sin(1e+30) returns 1e+30.
        [X64Fact("sin(1e+30) gives different value on x86."), TestCategory("Expr Language")]
        public void ExprBind()
        {
            // Code coverage test for the binder.
            Run("ExprBind");
        }
#endif

        [Fact, TestCategory("Expr Language")]
        public void ExprBindEx()
        {
            // Additional code coverage test for the binder.
            Run("ExprBindEx");
        }

        [Fact, TestCategory("Expr Language")]
        public void ExprCodeGen()
        {
            // Code coverage test for code gen.
            Run("ExprCodeGen");
        }

        [Fact, TestCategory("Expr Language")]
        public void ExprEval()
        {
            // Code coverage test for evaluation. Note that VS can't help us measure this one :-(.
            Run("ExprEval");
        }

        /// <summary>Returns the manifest resource name of the input script file for test <paramref name="name"/>.</summary>
        private string InResName(string name)
        {
            return ResourcePrefix + name + "Input.txt";
        }

        /// <summary>
        /// Reads the embedded resource <paramref name="resName"/> from this test assembly. If the resource does
        /// not exist, returns a "&lt;couldn't read ...&gt;" placeholder string instead of throwing.
        /// </summary>
        private string GetResText(string resName)
        {
            var stream = typeof(ExprLanguageTests).Assembly.GetManifestResourceStream(resName);
            if (stream == null)
                return string.Format("<couldn't read {0}>", resName);

            // Disposing the reader also disposes the underlying resource stream.
            using (var reader = new StreamReader(stream))
            {
                return reader.ReadToEnd();
            }
        }

        /// <summary>
        /// Runs every script in the embedded resource <c>{name}Input.txt</c>. Writes a transcript to
        /// <c>{name}Output.txt</c> in the <c>ExprParser</c> output directory, then compares that file with
        /// <c>CheckEquality</c> (6 digits of precision).
        /// </summary>
        /// <remarks>
        /// Input format, as parsed here:
        /// <list type="bullet">
        /// <item>Scripts are separated by <c>$</c>.</item>
        /// <item>Each script starts with a comma-separated list of <see cref="InternalDataKind"/> names
        /// ending in <c>:</c>, which gives the lambda parameter types, followed by the lambda source.</item>
        /// <item>An optional <c>#</c> starts the evaluation inputs, one comma-separated row per line.</item>
        /// </list>
        /// </remarks>
        private void Run(string name)
        {
            string outDir = "ExprParser";

            string text = GetResText(InResName(name));
            string inName = name + "Input.txt";
            string outName = name + "Output.txt";
            string outNameAssem = name + "Output.Assem.txt";
            string outPath = DeleteOutputPath(outDir, outName);
            // Not written below; presumably kept for DeleteOutputPath's side effect of clearing a stale file.
            string outPathAssem = DeleteOutputPath(outDir, outNameAssem);

            using (var wr = OpenWriter(outPath))
            {
                var wrt = new IndentedTextWriter(wr, "  ");

                // Individual scripts are separated by $.
                // Inputs start after #.
                int count = 0;
                int ichLim = 0;
                int lineLim = 1;
                for (; ichLim < text.Length; ichLim++)
                {
                    int ichMin = ichLim;
                    int lineMin = lineLim;

                    // Advance to the next '$' (or end of text), counting lines as we go.
                    while (ichLim < text.Length && text[ichLim] != '$')
                    {
                        if (text[ichLim] == '\n')
                            lineLim++;
                        ichLim++;
                    }

                    // Skip leading whitespace so the reported start line is the script's first real line.
                    while (ichMin < ichLim && char.IsWhiteSpace(text[ichMin]))
                    {
                        if (text[ichMin] == '\n')
                            lineMin++;
                        ichMin++;
                    }

                    if (ichMin >= ichLim)
                        continue;

                    // Process the script.
                    count++;
                    string scriptName = string.Format("Script {0}, lines {1} to {2}", count, lineMin, lineLim);
                    wrt.WriteLine("===== Start {0} =====", scriptName);
                    var types = ParseTypes(text, ref ichMin, ichLim);
                    // The lambda source runs up to an optional '#' inside this script; everything after it is input rows.
                    int ichLimChars = text.IndexOf('#', ichMin);
                    if (ichLimChars < 0 || ichLimChars >= ichLim)
                        ichLimChars = ichLim;
                    else
                        Contracts.Assert(ichMin < ichLimChars && ichLimChars < ichLim);
                    CharCursor chars = new CharCursor(text, ichMin, ichLimChars);
                    Delegate del = null;
                    lock (_sync)
                    {
                        try
                        {
                            LambdaNode node;
                            List<Error> errors;
                            List<int> lineMap;
                            var perm = Utils.GetIdentityPermutation(types.Length);
                            using (wrt.Nest())
                            {
                                node = LambdaParser.Parse(out errors, out lineMap, chars, perm, types);
                            }
                            Check(node != null, "Null LambdaNode");
                            if (Utils.Size(errors) > 0)
                            {
                                DumpErrors(wrt, lineMap, lineMin, inName, "Parsing", errors);
                                goto LDone;
                            }

                            LambdaBinder.Run(Env, ref errors, node, msg => wr.WriteLine(msg));
                            if (Utils.Size(errors) > 0)
                            {
                                DumpErrors(wrt, lineMap, lineMin, inName, "Binding", errors);
                                goto LDone;
                            }
                            wrt.WriteLine("Binding succeeded. Output type: {0}", node.ResultType);

                            del = LambdaCompiler.Compile(out errors, node);
                            // Route the _dump test functions' output into this transcript; reset in finally.
                            Contracts.Assert(TestFuncs1.Writer == null);
                            TestFuncs1.Writer = wr;
                            if (Utils.Size(errors) > 0)
                            {
                                DumpErrors(wrt, lineMap, lineMin, inName, "Compiling", errors);
                                goto LDone;
                            }

                            wrt.WriteLine("Compiling succeeded.");
                            if (ichLimChars < ichLim)
                                Evaluate(wrt, del, node.ResultType, types, text, ichLimChars + 1, ichLim);
                        }
                        catch (Exception ex)
                        {
                            if (!ex.IsMarked())
                                wrt.WriteLine("Unknown Exception: {0}!", del != null ? del.GetMethodInfo().DeclaringType : (object)"<null>");
                            wrt.WriteLine("Exception: {0}", ex.Message);
                        }
                        finally
                        {
                            TestFuncs1.Writer = null;
                        }

LDone:
                        wrt.WriteLine("===== End {0} =====", scriptName);
                    }
                }
            }

            CheckEquality(outDir, outName, digitsOfPrecision: 6);

            Done();
        }

        /// <summary>
        /// Parses the comma-separated <see cref="InternalDataKind"/> names from <paramref name="ichMin"/> up to
        /// the next ':' into primitive column types. On return, <paramref name="ichMin"/> points just past the ':'.
        /// </summary>
        private DataViewType[] ParseTypes(string text, ref int ichMin, int ichLim)
        {
            int ichCol = text.IndexOf(':', ichMin);
            Contracts.Assert(ichMin < ichCol && ichCol < ichLim);
            string[] toks = text.Substring(ichMin, ichCol - ichMin).Split(',');
            var res = new DataViewType[toks.Length];
            for (int i = 0; i < toks.Length; i++)
            {
                InternalDataKind kind;
                bool tmp = Enum.TryParse(toks[i], out kind);
                Contracts.Assert(tmp);
                res[i] = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
            }
            ichMin = ichCol + 1;
            return res;
        }

        /// <summary>
        /// Invokes <paramref name="del"/> once for each non-empty line of
        /// <paramref name="text"/>[<paramref name="ichMin"/>, <paramref name="ichLim"/>) and writes each result on
        /// its own line. Each line holds one comma-separated value per parameter in <paramref name="types"/>.
        /// A lone <c>_</c> stands for empty text.
        /// </summary>
        private void Evaluate(IndentedTextWriter wrt, Delegate del, DataViewType typeRes, DataViewType[] types,
            string text, int ichMin, int ichLim)
        {
            Contracts.AssertValue(del);
            Contracts.AssertNonEmpty(types);
            var args = new object[types.Length];
            var getters = new Func<ReadOnlyMemory<char>, bool>[types.Length];
            for (int i = 0; i < getters.Length; i++)
                getters[i] = GetGetter(i, types[i], args);

            StringBuilder sb = new StringBuilder();
            Action<object> printer = GetPrinter(typeRes, sb);

            ReadOnlyMemory<char> chars = text.AsMemory().Slice(ichMin, ichLim - ichMin);
            for (bool more = true; more;)
            {
                // Always split on '\n', as the line counting in Run does. For CRLF input the trailing '\r' is
                // whitespace and TrimWhiteSpace removes it. The old code picked the separator from the OS,
                // which broke whenever the resource's line endings did not match the platform
                // (e.g. LF-only files on Windows).
                ReadOnlyMemory<char> line;
                more = ReadOnlyMemoryUtils.SplitOne(chars, '\n', out line, out chars);
                line = ReadOnlyMemoryUtils.TrimWhiteSpace(line);
                if (line.IsEmpty)
                    continue;

                // Note this "hack" to map _ to empty. It's easier than fully handling quoting and is sufficient
                // for these tests.
                var vals = ReadOnlyMemoryUtils.Split(line, new char[] { ',' })
                        .Select(x => ReadOnlyMemoryUtils.TrimWhiteSpace(x))
                        .Select(x => ReadOnlyMemoryUtils.EqualsStr("_", x) ? ReadOnlyMemory<char>.Empty : x)
                        .ToArray();

                Contracts.Assert(vals.Length == getters.Length);
                for (int i = 0; i < getters.Length; i++)
                {
                    if (!getters[i](vals[i]))
                        wrt.Write("*** Parsing {0} Failed *** ", vals[i]);
                }
                var res = del.DynamicInvoke(args);
                printer(res);
                wrt.WriteLine(sb);
            }
        }

        /// <summary>
        /// Returns a function that parses text into the raw type of <paramref name="dst"/> and stores the
        /// boxed value in <paramref name="args"/>[<paramref name="i"/>]. The function returns whether
        /// parsing succeeded. Text values are stored as-is. Returns null (after a debug assert) for
        /// unsupported kinds.
        /// </summary>
        private Func<ReadOnlyMemory<char>, bool> GetGetter(int i, DataViewType dst, object[] args)
        {
            switch (dst.GetRawKind())
            {
                case InternalDataKind.BL:
                    return
                        src =>
                        {
                            bool v;
                            bool tmp = Conversions.DefaultInstance.TryParse(in src, out v);
                            args[i] = v;
                            return tmp;
                        };
                case InternalDataKind.I4:
                    return
                        src =>
                        {
                            int v;
                            bool tmp = Conversions.DefaultInstance.TryParse(in src, out v);
                            args[i] = v;
                            return tmp;
                        };
                case InternalDataKind.I8:
                    return
                        src =>
                        {
                            long v;
                            bool tmp = Conversions.DefaultInstance.TryParse(in src, out v);
                            args[i] = v;
                            return tmp;
                        };
                case InternalDataKind.R4:
                    return
                        src =>
                        {
                            float v;
                            bool tmp = Conversions.DefaultInstance.TryParse(in src, out v);
                            args[i] = v;
                            return tmp;
                        };
                case InternalDataKind.R8:
                    return
                        src =>
                        {
                            double v;
                            bool tmp = Conversions.DefaultInstance.TryParse(in src, out v);
                            args[i] = v;
                            return tmp;
                        };
                case InternalDataKind.TX:
                    return
                        src =>
                        {
                            args[i] = src;
                            return true;
                        };
            }

            Contracts.Assert(false);
            return null;
        }

        /// <summary>
        /// Returns an action that formats a boxed result of <paramref name="dst"/>'s raw type into
        /// <paramref name="sb"/>. Text results are mapped with <c>TextSaverUtils.MapText</c> using a tab
        /// separator. Returns null (after a debug assert) for unsupported kinds.
        /// </summary>
        /// <remarks>
        /// NOTE(review): Evaluate reuses one builder for every row. This relies on the conversions replacing
        /// the builder's contents rather than appending to them — confirm against Conversions/MapText.
        /// </remarks>
        private Action<object> GetPrinter(DataViewType dst, StringBuilder sb)
        {
            switch (dst.GetRawKind())
            {
                case InternalDataKind.BL:
                    return
                        src =>
                        {
                            var v = (bool)src;
                            Conversions.DefaultInstance.Convert(in v, ref sb);
                        };
                case InternalDataKind.I4:
                    return
                        src =>
                        {
                            var v = (int)src;
                            Conversions.DefaultInstance.Convert(in v, ref sb);
                        };
                case InternalDataKind.I8:
                    return
                        src =>
                        {
                            var v = (long)src;
                            Conversions.DefaultInstance.Convert(in v, ref sb);
                        };
                case InternalDataKind.R4:
                    return
                        src =>
                        {
                            var v = (Single)src;
                            Conversions.DefaultInstance.Convert(in v, ref sb);
                        };
                case InternalDataKind.R8:
                    return
                        src =>
                        {
                            var v = (Double)src;
                            Conversions.DefaultInstance.Convert(in v, ref sb);
                        };
                case InternalDataKind.TX:
                    return
                        src =>
                        {
                            var v = (ReadOnlyMemory<char>)src;
                            TextSaverUtils.MapText(v.Span, ref sb, '\t');
                        };
            }

            Contracts.Assert(false);
            return null;
        }

        /// <summary>
        /// Writes each error, indented one level, as
        /// <c>file(lineMin,colMin)-(lineLim,colLim): error: message</c>. Positions are computed from the
        /// error token's span via <paramref name="lineMap"/>, offset by the script's first line
        /// <paramref name="lineMin"/>.
        /// </summary>
        private void DumpErrors(IndentedTextWriter wrt, List<int> lineMap, int lineMin,
            string fileName, string phase, List<Error> errors)
        {
            Contracts.AssertValue(wrt);
            Contracts.AssertValue(lineMap);
            Contracts.AssertNonEmpty(phase);
            Contracts.AssertNonEmpty(errors);

            using (wrt.Nest())
            {
                foreach (var err in errors)
                {
                    var tok = err.Token;
                    Contracts.AssertValue(tok);
                    var pos = new LambdaParser.SourcePos(lineMap, tok.Span, lineMin);
                    wrt.Write("{0}({1},{2})-({3},{4}): ",
                        fileName, pos.LineMin, pos.ColumnMin, pos.LineLim, pos.ColumnLim);
                    wrt.Write("error: ");
                    wrt.WriteLine(err.GetMessage());
                }
            }
        }
    }
 
    /// <summary>
    /// Function provider for the <c>__test1</c> namespace used by the expression-language baseline tests.
    /// Several entries are deliberately ambiguous or malformed (see <c>_aa</c> and <c>_bad</c>) to exercise
    /// the binder's overload resolution and error handling.
    /// </summary>
    public sealed class TestFuncs1 : IFunctionProvider
    {
        // REVIEW: This is a temporary hack to baseline test the _dump functions.
        // Should probably figure out a proper way to do this.
        // ExprLanguageTests.Run points this at its output writer while a script is compiled/evaluated and
        // resets it to null afterwards. When it is null, the _dump functions write to Console.Out.
        internal static TextWriter Writer;

        private static volatile TestFuncs1 _instance;

        /// <summary>Lazily created shared instance. Concurrent first calls race, but only one instance is published.</summary>
        public static TestFuncs1 Instance
        {
            get
            {
                if (_instance == null)
                    Interlocked.CompareExchange(ref _instance, new TestFuncs1(), null);
                return _instance;
            }
        }

        private static TextWriter OutWriter => Writer ?? Console.Out;

        public string NameSpace { get { return "__test1"; } }

        /// <summary>Returns the overloads registered for <paramref name="name"/>, or null if the name is unknown.</summary>
        public MethodInfo[] Lookup(string name)
        {
            switch (name)
            {
                // This one should be ambiguous when invoked on an I4.
                case "_aa":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<I8, I8>(A),
                        FunctionProviderUtils.Fn<R4, R4>(A));

                // These names are also defined by TestFuncs2, with different overloads.
                case "_ab":
                case "_ac":
                case "_ad":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<I8, I8>(A));

                case "_var":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<I4, BL, R4[], R4>(Var));

                case "_ba":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<I4>(B),
                        FunctionProviderUtils.Fn<I4, I4>(B),
                        FunctionProviderUtils.Fn<I4, I4, I4>(B));

                // Deliberately bad overloads: a non-primitive return type, a string parameter, an instance
                // method and an internal method.
                case "_bad":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<object>(X),
                        FunctionProviderUtils.Fn<string, I4>(X),
                        FunctionProviderUtils.Fn<I4, I4>(X),
                        ((Func<I8, I8>)(X)).GetMethodInfo(),
                        FunctionProviderUtils.Fn<R4, R4>(X));

                case "_fa":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<BL, BL>(F));

                case "_dump":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<I4, I4>(Dump),
                        FunctionProviderUtils.Fn<I8, I8>(Dump),
                        FunctionProviderUtils.Fn<R4, R4>(Dump),
                        FunctionProviderUtils.Fn<R8, R8>(Dump),
                        FunctionProviderUtils.Fn<BL, BL>(Dump),
                        FunctionProviderUtils.Fn<TX, TX>(Dump),
                        FunctionProviderUtils.Fn<TX, I4, I4>(Dump),
                        FunctionProviderUtils.Fn<TX, I8, I8>(Dump),
                        FunctionProviderUtils.Fn<TX, R4, R4>(Dump),
                        FunctionProviderUtils.Fn<TX, R8, R8>(Dump),
                        FunctionProviderUtils.Fn<TX, BL, BL>(Dump),
                        FunctionProviderUtils.Fn<TX, TX, TX>(Dump));

                case "_chars":
                    return FunctionProviderUtils.Ret(FunctionProviderUtils.Fn<TX, TX>(DumpChars));
            }

            return null;
        }

        public object ResolveToConstant(string name, MethodInfo meth, object[] values)
        {
            switch (name)
            {
                case "_bad":
                    // Intentionally wrong: returns a boxed int no matter which "_bad" overload was resolved,
                    // although several of them do not return I4 (object, I8, R4). This tests handling of
                    // buggy IFunctionProvider implementations.
                    return 3;
            }

            return null;
        }

        public static I8 A(I8 a)
        {
            return a * 2;
        }

        public static R4 A(R4 a)
        {
            return -a;
        }

        /// <summary>
        /// For testing variable-arg functions. This selects the element in c indicated by a.
        /// If b is true, it negates the result. Returns NaN when a is out of range.
        /// </summary>
        public static R4 Var(I4 a, BL b, R4[] c)
        {
            if (a < 0 || a >= c.Length)
                return R4.NaN;
            R4 res = c[a];
            if (b)
                res = -res;
            return res;
        }

        // The B overloads return their arity + 1 so tests can see which overload was bound.
        public static I4 B()
        {
            return 1;
        }

        public static I4 B(I4 a)
        {
            return 2;
        }

        public static I4 B(I4 a, I4 b)
        {
            return 3;
        }

        public static object X()
        {
            return null;
        }

        public static I4 X(string a)
        {
            return 41;
        }

        public static I4 X(I4 a)
        {
            return a;
        }

        // Deliberately an instance method (see "_bad" in Lookup).
        public I8 X(I8 a)
        {
            return a;
        }

        // Deliberately internal (see "_bad" in Lookup).
        internal static R4 X(R4 a)
        {
            return a;
        }

        public static BL F(BL a)
        {
            return a;
        }

        /// <summary>Writes <c>ExprDump: {a}</c> to the current output writer and returns <paramref name="a"/>.</summary>
        public static T Dump<T>(T a)
        {
            OutWriter.WriteLine("ExprDump: {0}", a);
            return a;
        }

        /// <summary>
        /// Writes <paramref name="a"/> using <paramref name="fmt"/> as a composite format string and returns
        /// <paramref name="a"/>. A malformed format string throws <see cref="FormatException"/>.
        /// </summary>
        public static T Dump<T>(TX fmt, T a)
        {
            OutWriter.WriteLine(fmt.ToString(), a);
            return a;
        }

        /// <summary>
        /// Writes the UTF-16 code units of <paramref name="a"/> as space-separated 4-digit hex values and returns
        /// <paramref name="a"/>.
        /// </summary>
        public static TX DumpChars(TX a)
        {
            // Read the span once instead of re-fetching it on every iteration.
            var span = a.Span;
            var sb = new StringBuilder();
            for (int ich = 0; ich < span.Length; ich++)
                sb.AppendFormat("{0:X4} ", (short)span[ich]);
            OutWriter.WriteLine("ExprDumpChars: {0}", sb);
            return a;
        }
    }
 
    /// <summary>
    /// Function provider for the <c>__test2</c> namespace. It defines <c>_ab</c>, <c>_ac</c> and <c>_ad</c>,
    /// names that <see cref="TestFuncs1"/> also defines, with different overload sets and distinguishable
    /// results. This is presumably so the tests can observe which provider's overload the binder picks.
    /// </summary>
    public sealed class TestFuncs2 : IFunctionProvider
    {
        private static volatile TestFuncs2 _instance;

        /// <summary>
        /// Lazily created shared instance; the constructor is private, so this is the only way to get one.
        /// </summary>
        public static TestFuncs2 Instance
        {
            get
            {
                if (_instance == null)
                    Interlocked.CompareExchange(ref _instance, new TestFuncs2(), null);
                return _instance;
            }
        }

        public string NameSpace { get { return "__test2"; } }

        private TestFuncs2()
        {
        }

        /// <summary>Returns the overloads registered for <paramref name="name"/>, or null if the name is unknown.</summary>
        public MethodInfo[] Lookup(string name)
        {
            switch (name)
            {
                case "_ab":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<R4, R4>(A));
                case "_ac":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<I8, I8>(A));
                case "_ad":
                    return FunctionProviderUtils.Ret(
                        FunctionProviderUtils.Fn<I4, I4>(A));
            }

            return null;
        }

        /// <summary>No function in this provider folds to a constant.</summary>
        public object ResolveToConstant(string name, MethodInfo meth, object[] values)
        {
            return null;
        }

        public static I4 A(I4 a)
        {
            return a * 3 * 10;
        }

        public static I8 A(I8 a)
        {
            return a * 2 * 10;
        }

        public static R4 A(R4 a)
        {
            return -a * 10;
        }
    }
}