File: CmdLine\CmdLineReverseTest.cs
Web Access
Project: src\test\Microsoft.ML.Predictor.Tests\Microsoft.ML.Predictor.Tests.csproj (Microsoft.ML.Predictor.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Reflection;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.RunTests
{
    public class CmdLineReverseTests : BaseTestClass
    {
        /// <summary>
        /// Round-trip test for CmdParser.GetSettings: each <see cref="SimpleArg"/> is rendered to a
        /// settings string (through <c>SimpleArg.ToString</c>), parsed back with
        /// CmdParser.ParseArguments, and compared against the instance it was rendered from.
        /// </summary>
        [Fact]
        [TestCategory("Cmd Parsing")]
        public void ArgumentParseTest()
        {
            var context = new MLContext(seed: 42);

            // Parses a settings string into a brand-new SimpleArg and hands it back.
            SimpleArg Reparse(string settings)
            {
                var parsed = new SimpleArg();
                CmdParser.ParseArguments(context, settings, parsed);
                return parsed;
            }

            // Reads back the settings string held by a sub-component factory.
            string SettingsOf(IComponentFactory<SimpleArg> factory)
            {
                return ((ICommandLineComponentFactory)factory).GetSettingsString();
            }

            // Leaf argument whose text values are braces, spaces and a newline.
            var bracesArg = new SimpleArg()
            {
                required = -2,
                text1 = "}",
                text2 = "{",
                text3 = "  ",
                text4 = "\n",
            };

            // Leaf argument whose text values are punctuation, '=', a tab and backslashes.
            var symbolsArg = new SimpleArg()
            {
                required = -2,
                text1 = "{!@# $%^}&*{()}",
                text2 = "=",
                text3 = "\t",
                text4 = @"\\",
            };

            // Mid-level argument: quote/CRLF text values plus the two leaves as sub-components.
            var nestedArg = new SimpleArg()
            {
                required = -2,
                text1 = "\r\n",
                text2 = "\"",
                text3 = "\" \" ",
                text4 = "{/a=2 /b=3 /c=4}",
                sub1 = CreateComponentFactory("S1", bracesArg.ToString(context)),
                sub2 = CreateComponentFactory("S2", symbolsArg.ToString(context)),
            };

            // Top-level argument embedding all three of the above as sub-components.
            var rootArg = new SimpleArg()
            {
                required = -2,
                once = 2,
                text2 = "Testing",
                text3 = "a=7",
                sub1 = CreateComponentFactory("S1", bracesArg.ToString(context)),
                sub2 = CreateComponentFactory("S2", symbolsArg.ToString(context)),
                sub3 = CreateComponentFactory("S3", nestedArg.ToString(context)),
            };

            // The whole root argument survives the round trip ...
            Assert.Equal(rootArg, Reparse(rootArg.ToString(context)));

            // ... and so does each sub-component's settings string.
            Assert.Equal(bracesArg, Reparse(SettingsOf(rootArg.sub1)));
            Assert.Equal(symbolsArg, Reparse(SettingsOf(rootArg.sub2)));
            Assert.Equal(nestedArg, Reparse(SettingsOf(rootArg.sub3)));
        }
 
        /// <summary>
        /// Verifies that a calibrator trainer can be instantiated through a factory built by
        /// CmdParser.CreateComponentFactory, once for the first load name of every loadable class
        /// found for <see cref="SignatureCalibrator"/> and once for the first alias of every
        /// component registered for <see cref="ICalibratorTrainerFactory"/>.
        /// </summary>
        [Fact]
        [TestCategory("Cmd Parsing")]
        public void NewTest()
        {
            var ml = new MLContext(1);
            ml.AddStandardComponents();

            // Builds a factory for the given load name or alias and checks that it actually
            // produces a component, so the test has an explicit assertion.
            void AssertCreatesCalibrator(string name)
            {
                var factory = CmdParser.CreateComponentFactory(typeof(IComponentFactory<ICalibratorTrainer>), typeof(SignatureCalibrator), name);
                var calibrator = ((IComponentFactory<ICalibratorTrainer>)factory).CreateComponent(ml);
                Assert.NotNull(calibrator);
            }

            // FindLoadableClasses is generic on the signature type; MarshalInvoke is handed
            // typeof(SignatureCalibrator) in place of the <int> placeholder used to bind the method group.
            var findLoadableClassesMethodInfo = new FuncInstanceMethodInfo1<ComponentCatalog, ComponentCatalog.LoadableClassInfo[]>(ml.ComponentCatalog.FindLoadableClasses<int>);
            var classes = Utils.MarshalInvoke(findLoadableClassesMethodInfo, ml.ComponentCatalog, typeof(SignatureCalibrator));
            foreach (var cls in classes)
                AssertCreatesCalibrator(cls.LoadNames[0]);

            var components = ml.ComponentCatalog.GetAllComponents(typeof(ICalibratorTrainerFactory));
            foreach (var component in components)
                AssertCreatesCalibrator(component.Aliases[0]);
        }
 
        /// <summary>
        /// Marker delegate used only as a signature type: it is the SignatureType of the
        /// sub-component arguments on <see cref="SimpleArg"/> and is passed to the reflected
        /// factory-creation call in CreateComponentFactory. It is never invoked in this file.
        /// </summary>
        private delegate void SignatureSimpleComponent();
 
        /// <summary>
        /// Argument holder used by the round-trip test. It carries integer, string and
        /// sub-component fields and compares by value so parsed instances can be checked
        /// with Assert.Equal.
        /// </summary>
        private class SimpleArg
        {
            [Argument(ArgumentType.Required, ShortName = "r")]
            public int required = -1;

            [Argument(ArgumentType.AtMostOnce)]
            public int once = 1;

            [Argument(ArgumentType.LastOccurrenceWins)]
            public string text1 = "";

            [Argument(ArgumentType.AtMostOnce)]
            public string text2 = "";

            [Argument(ArgumentType.AtMostOnce)]
            public string text3 = "";

            [Argument(ArgumentType.AtMostOnce)]
            public string text4 = "";

            [Argument(ArgumentType.Multiple, SignatureType = typeof(SignatureSimpleComponent))]
            public IComponentFactory<SimpleArg> sub1 = CreateComponentFactory("sub1", "settings1");

            [Argument(ArgumentType.Multiple, SignatureType = typeof(SignatureSimpleComponent))]
            public IComponentFactory<SimpleArg> sub2 = CreateComponentFactory("sub2", "settings2");

            [Argument(ArgumentType.Multiple, SignatureType = typeof(SignatureSimpleComponent))]
            public IComponentFactory<SimpleArg> sub3 = CreateComponentFactory("sub3", "settings3");

            // REVIEW: include Subcomponent array for testing once it is supported
            //[Argument(ArgumentType.Multiple)]
            //public SubComponent[] sub4 = new SubComponent[] { new SubComponent("sub4", "settings4"), new SubComponent("sub5", "settings5") };

            /// <summary>
            /// Renders this instance as a settings string by delegating to CmdParser.GetSettings,
            /// which is the method of primary interest for this test. This is an overload taking
            /// the environment, not an override of <see cref="object.ToString()"/>.
            /// </summary>
            /// <param name="env">Host environment forwarded to CmdParser.GetSettings.</param>
            /// <returns>The settings string produced by CmdParser.GetSettings.</returns>
            public string ToString(IHostEnvironment env)
            {
                // A fresh SimpleArg is passed alongside this instance; presumably it supplies the
                // default values that GetSettings compares against (confirm in CmdParser).
                return CmdParser.GetSettings(env, this, new SimpleArg(), SettingsFlags.None);
            }

            /// <summary>
            /// Value equality over all argument fields. Sub-components are compared with
            /// ComponentFactoryEquals.
            /// </summary>
            /// <param name="obj">Object to compare with; null or a non-SimpleArg compares unequal.</param>
            /// <returns>True when every field matches, otherwise false.</returns>
            public override bool Equals(object obj)
            {
                // Type pattern instead of a hard cast: null and foreign types yield false
                // rather than a NullReferenceException / InvalidCastException.
                if (!(obj is SimpleArg arg))
                    return false;

                return arg.required == this.required
                    && arg.once == this.once
                    && arg.text1 == this.text1
                    && arg.text2 == this.text2
                    && arg.text3 == this.text3
                    && arg.text4 == this.text4
                    && ComponentFactoryEquals(arg.sub1, this.sub1)
                    && ComponentFactoryEquals(arg.sub2, this.sub2)
                    && ComponentFactoryEquals(arg.sub3, this.sub3);
            }

            /// <summary>
            /// Hash code consistent with <see cref="Equals(object)"/>: it is built only from fields
            /// that Equals compares, so equal instances always hash alike. The sub-component
            /// factories are left out, which is allowed because unequal objects may share a hash.
            /// </summary>
            public override int GetHashCode()
            {
                unchecked
                {
                    int hash = 17;
                    hash = hash * 31 + required;
                    hash = hash * 31 + once;
                    hash = hash * 31 + (text1 == null ? 0 : text1.GetHashCode());
                    hash = hash * 31 + (text2 == null ? 0 : text2.GetHashCode());
                    hash = hash * 31 + (text3 == null ? 0 : text3.GetHashCode());
                    hash = hash * 31 + (text4 == null ? 0 : text4.GetHashCode());
                    return hash;
                }
            }
        }
 
        /// <summary>
        /// Compares two component factories by name, signature type and settings string.
        /// </summary>
        /// <param name="left">First factory.</param>
        /// <param name="right">Second factory.</param>
        /// <returns>
        /// True only when both factories are <see cref="ICommandLineComponentFactory"/> instances
        /// with matching name, signature type and settings string; false otherwise, including
        /// when either one is null or not a command-line factory.
        /// </returns>
        private static bool ComponentFactoryEquals(IComponentFactory left, IComponentFactory right)
        {
            var lhs = left as ICommandLineComponentFactory;
            var rhs = right as ICommandLineComponentFactory;

            // Factories that are not command-line factories cannot be compared.
            if (lhs == null || rhs == null)
                return false;

            if (lhs.Name != rhs.Name)
                return false;

            if (lhs.SignatureType != rhs.SignatureType)
                return false;

            return lhs.GetSettingsString() == rhs.GetSettingsString();
        }
 
        // Reflection handle to the "CreateComponentFactory" method on CmdParser's non-public nested
        // type "ComponentFactoryFactory"; the test helper CreateComponentFactory invokes it with a
        // null target. Resolved once per type load. NOTE(review): if either name is renamed, the
        // lookup yields null and this initializer throws — confirm names against CmdParser.
        private static readonly MethodInfo _createComponentFactoryMethod = typeof(CmdParser)
            .GetNestedType("ComponentFactoryFactory", BindingFlags.NonPublic)
            .GetMethod("CreateComponentFactory");
 
        /// <summary>
        /// Creates the test class, forwarding the xUnit output helper to <see cref="BaseTestClass"/>.
        /// </summary>
        /// <param name="output">Test output helper supplied by xUnit.</param>
        public CmdLineReverseTests(ITestOutputHelper output) : base(output)
        {
        }
 
        /// <summary>
        /// Builds an <see cref="IComponentFactory{SimpleArg}"/> by invoking, through reflection,
        /// the method held in <c>_createComponentFactoryMethod</c>.
        /// </summary>
        /// <param name="name">Component name passed to the reflected method.</param>
        /// <param name="settings">Settings string; wrapped in a one-element array for the call.</param>
        /// <returns>The object returned by the reflected method, cast to the factory interface.</returns>
        private static IComponentFactory<SimpleArg> CreateComponentFactory(string name, string settings)
        {
            object[] invokeArgs =
            {
                typeof(IComponentFactory<SimpleArg>),
                typeof(SignatureSimpleComponent),
                name,
                new[] { settings },
            };

            // Null target: the reflected method is called without an instance.
            object created = _createComponentFactoryMethod.Invoke(null, invokeArgs);
            return (IComponentFactory<SimpleArg>)created;
        }
    }
}