File: DataPipe\TestDataPipeBase.cs
Web Access
Project: src\test\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj (Microsoft.ML.TestFramework)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFrameworkCommon;
using Xunit;
 
namespace Microsoft.ML.RunTests
{
    public abstract partial class TestDataPipeBase : TestDataViewBase
    {
        /// <summary>
        /// File name of the Iris data set, exposed as a public constant for tests in this hierarchy.
        /// </summary>
        public const string IrisDataPath = "iris.data";
 
        /// <summary>
        /// 'Workout test' for an estimator.
        /// Checks the following traits:
        /// - the estimator is applicable to the validFitInput and validForFitNotValidForTransformInput, and not applicable to validTransformInput and invalidInput;
        /// - the fitted transformer is applicable to validFitInput and validTransformInput, and not applicable to invalidInput and validForFitNotValidForTransformInput;
        /// - fitted transformer can be saved and re-loaded into the transformer with the same behavior.
        /// - schema propagation for fitted transformer conforms to schema propagation of estimator.
        /// </summary>
        /// <param name="estimator">The estimator under test.</param>
        /// <param name="validFitInput">Data the estimator must fit on and the fitted transformer must transform.</param>
        /// <param name="validTransformInput">Optional data the fitted transformer must accept but the estimator must reject.</param>
        /// <param name="invalidInput">Optional data that both the estimator and the fitted transformer must reject.</param>
        /// <param name="validForFitNotValidForTransformInput">Optional data the estimator must accept but the fitted transformer must reject.</param>
        /// <param name="shouldDispose">When true, the fitted transformer is disposed (if it is disposable) once the test finishes.</param>
        protected void TestEstimatorCore(IEstimator<ITransformer> estimator,
            IDataView validFitInput, IDataView validTransformInput = null, IDataView invalidInput = null, IDataView validForFitNotValidForTransformInput = null, bool shouldDispose = false)
        {
            Contracts.AssertValue(estimator);
            Contracts.AssertValue(validFitInput);
            Contracts.AssertValueOrNull(validTransformInput);
            Contracts.AssertValueOrNull(invalidInput);
            Contracts.AssertValueOrNull(validForFitNotValidForTransformInput);

            // Requires 'action' to throw ArgumentOutOfRangeException or InvalidOperationException.
            // A TargetInvocationException (e.g. from reflection-based construction) is unwrapped to its
            // innermost exception, which must be one of those two types and must also satisfy IsMarked().
            Action<Action> mustFail = (Action action) =>
            {
                try
                {
                    action();
                    Assert.True(false, "Expected the action to throw, but it completed without an exception.");
                }
                catch (ArgumentOutOfRangeException) { }
                catch (InvalidOperationException) { }
                catch (TargetInvocationException ex)
                {
                    Exception e = ex;
                    while (e.InnerException != null)
                        e = e.InnerException;
                    Assert.True(e is ArgumentOutOfRangeException || e is InvalidOperationException,
                        $"Unexpected inner exception {e.GetType()}: {e.Message}");
                    Assert.True(e.IsMarked(), $"Inner exception {e.GetType()} was not marked: {e.Message}");
                }
            };

            // Schema propagation tests for estimator.
            var outSchemaShape = estimator.GetOutputSchema(SchemaShape.Create(validFitInput.Schema));
            if (validTransformInput != null)
            {
                mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(validTransformInput.Schema)));
                mustFail(() => estimator.Fit(validTransformInput));
            }

            if (invalidInput != null)
            {
                mustFail(() => estimator.GetOutputSchema(SchemaShape.Create(invalidInput.Schema)));
                mustFail(() => estimator.Fit(invalidInput));
            }

            if (validForFitNotValidForTransformInput != null)
            {
                // Both calls must succeed; any exception here fails the test.
                estimator.GetOutputSchema(SchemaShape.Create(validForFitNotValidForTransformInput.Schema));
                estimator.Fit(validForFitNotValidForTransformInput);
            }

            var transformer = estimator.Fit(validFitInput);
            ITransformer loadedTransformer = null;
            try
            {
                // Save and reload.
                string modelPath = GetOutputPath(FullTestName + "-model.zip");
                ML.Model.Save(transformer, validFitInput.Schema, modelPath);

                DataViewSchema loadedInputSchema;
                using (var fs = File.OpenRead(modelPath))
                    loadedTransformer = ML.Model.Load(fs, out loadedInputSchema);
                DeleteOutputPath(modelPath);

                // Validates both the original and the re-loaded transformer against one input.
                Action<IDataView> checkOnData = (IDataView data) =>
                {
                    var schema = transformer.GetOutputSchema(data.Schema);

                    // If it's a row to row mapper, then the output schema should be the same.
                    if (transformer.IsRowToRowMapper)
                    {
                        var mapper = transformer.GetRowToRowMapper(data.Schema);
                        Check(mapper.InputSchema == data.Schema, "InputSchemas were not identical to actual input schema");
                        TestCommon.CheckSameSchemas(schema, mapper.OutputSchema);
                    }
                    else
                    {
                        mustFail(() => transformer.GetRowToRowMapper(data.Schema));
                    }

                    // Loaded transformer needs to have the same schema propagation.
                    TestCommon.CheckSameSchemas(schema, loadedTransformer.GetOutputSchema(data.Schema));
                    // Loaded schema needs to have the same schema as data.
                    TestCommon.CheckSameSchemas(data.Schema, loadedInputSchema);

                    var scoredTrain = transformer.Transform(data);
                    var scoredTrain2 = loadedTransformer.Transform(data);

                    // The schema of the transformed data must match the schema provided by schema propagation.
                    TestCommon.CheckSameSchemas(schema, scoredTrain.Schema);

                    // The schema and data of the scored dataset must be identical between the loaded
                    // and the original transformer. This in turn means that the schema of the loaded
                    // transformer matches for Transform and GetOutputSchema calls.
                    TestCommon.CheckSameSchemas(scoredTrain.Schema, scoredTrain2.Schema);
                    CheckSameValues(scoredTrain, scoredTrain2, exactDoubles: false);
                };

                // Run on train data.
                checkOnData(validFitInput);

                if (validTransformInput != null)
                    checkOnData(validTransformInput);

                if (invalidInput != null)
                {
                    mustFail(() => transformer.GetOutputSchema(invalidInput.Schema));
                    mustFail(() => transformer.Transform(invalidInput));
                    mustFail(() => loadedTransformer.GetOutputSchema(invalidInput.Schema));
                    mustFail(() => loadedTransformer.Transform(invalidInput));
                }
                if (validForFitNotValidForTransformInput != null)
                {
                    mustFail(() => transformer.GetOutputSchema(validForFitNotValidForTransformInput.Schema));
                    mustFail(() => transformer.Transform(validForFitNotValidForTransformInput));
                    mustFail(() => loadedTransformer.GetOutputSchema(validForFitNotValidForTransformInput.Schema));
                    mustFail(() => loadedTransformer.Transform(validForFitNotValidForTransformInput));
                }

                // Schema verification between estimator and transformer.
                var scoredTrainSchemaShape = SchemaShape.Create(transformer.GetOutputSchema(validFitInput.Schema));
                CheckSameSchemaShape(outSchemaShape, scoredTrainSchemaShape);
            }
            finally
            {
                // Dispose even when a check above throws, so a failing test does not leak the transformers.
                (loadedTransformer as IDisposable)?.Dispose();
                if (shouldDispose)
                    (transformer as IDisposable)?.Dispose();
            }
        }
 
        /// <summary>
        /// Asserts that the schema shape promised by an estimator matches the shape actually delivered by the
        /// fitted transformer. The two must have the same column count and the same column names. Each
        /// same-named pair of columns must be compatible in both directions, including annotations.
        /// </summary>
        /// <param name="promised">Shape returned by the estimator's GetOutputSchema.</param>
        /// <param name="delivered">Shape built from the fitted transformer's output schema.</param>
        private void CheckSameSchemaShape(SchemaShape promised, SchemaShape delivered)
        {
            // Assert.Equal reports both counts on failure, unlike Assert.True(a == b).
            Assert.Equal(promised.Count, delivered.Count);
            var promisedCols = promised.OrderBy(x => x.Name);
            var deliveredCols = delivered.OrderBy(x => x.Name);

            foreach (var (p, d) in promisedCols.Zip(deliveredCols, (p, d) => (p, d)))
            {
                Assert.Equal(p.Name, d.Name);
                // We want the 'promised' metadata to be a superset of 'delivered'.
                Assert.True(d.IsCompatibleWith(p), $"Mismatch on {p.Name}: the delivered column differs from the promised one, or unexpected annotations were present.");
                // We also want 'delivered' to be a superset of 'promised'. Since the check above passed,
                // the only plausible cause of a failure here is something promised but not delivered.
                Assert.True(p.IsCompatibleWith(d), $"Mismatch on {p.Name}: something was promised in the annotations but not delivered.");
            }
        }
 
        // REVIEW: incorporate the testing for re-apply logic here?
        /// <summary>
        /// Create PipeDataLoader from the given args, save it, re-load it, verify that the data of
        /// the loaded pipe matches the original.
        /// * pathData defaults to breast-cancer.txt.
        /// * actLoader is invoked for extra validation (if non-null).
        /// Additional checks performed along the way:
        /// * when the loader is a LegacyCompositeDataLoader, its transforms are re-applied to the
        ///   innermost source view and the result must match the original pipe;
        /// * if logCurs is set, a full cursor pass is logged to a file and compared against its baseline;
        /// * the pipe is round-tripped through the text saver when it has at least one column. With
        ///   checkTranspose it also goes through the transpose saver. It always goes through the binary saver;
        /// * a pipe that cannot shuffle must yield the same rows from a randomized cursor as from a plain one;
        /// * if baselineSchema is set, the schema (with annotation values) of both the original and the
        ///   re-loaded pipe is compared against a baseline.
        /// Most mismatches are recorded via Failed()/Log rather than thrown.
        /// </summary>
        /// <returns>The originally created loader (pipe1).</returns>
        internal ILegacyDataLoader TestCore(string pathData, bool keepHidden, string[] argsPipe,
            Action<ILegacyDataLoader> actLoader = null, string suffix = "", string suffixBase = null, bool checkBaseline = true,
            bool forceDense = false, bool logCurs = false, bool roundTripText = true,
            bool checkTranspose = false, bool checkId = true, bool baselineSchema = true, int digitsOfPrecision = DigitsOfPrecision,
            NumberParseOption parseOption = NumberParseOption.Default)
        {
            Contracts.AssertValue(Env);

            MultiFileSource files;
            ILegacyDataLoader compositeLoader;
            var pipe1 = compositeLoader = CreatePipeDataLoader(_env, pathData, argsPipe, out files);

            actLoader?.Invoke(compositeLoader);

            // Re-apply pipe to the loader and check equality.
            var comp = compositeLoader as LegacyCompositeDataLoader;
            IDataView srcLoader = null;
            if (comp != null)
            {
                // Walk down the chain of transforms to reach the innermost (non-transform) source view.
                srcLoader = comp.View;
                while (srcLoader is IDataTransform)
                    srcLoader = ((IDataTransform)srcLoader).Source;
                var reappliedPipe = ApplyTransformUtils.ApplyAllTransformsToData(_env, comp.View, srcLoader);
                if (!CheckMetadataTypes(reappliedPipe.Schema))
                    Failed();

                if (!TestCommon.CheckSameSchemas(pipe1.Schema, reappliedPipe.Schema))
                    Failed();
                else if (!CheckSameValues(pipe1, reappliedPipe, checkId: checkId))
                    Failed();
            }

            if (logCurs)
            {
                // Cursor through every row while the environment's channel output is redirected to
                // the log file, then compare that file against its baseline.
                string name = TestName + suffix + "-CursLog" + ".txt";
                string pathLog = DeleteOutputPath("SavePipe", name);

                using (var writer = OpenWriter(pathLog))
                using (_env.RedirectChannelOutput(writer, writer))
                {
                    long count = 0;
                    using (var curs = pipe1.GetRowCursorForAllColumns())
                    {
                        while (curs.MoveNext())
                        {
                            count++;
                        }
                    }
                    writer.WriteLine("Cursored through {0} rows", count);
                }

                CheckEqualityNormalized("SavePipe", name, digitsOfPrecision: digitsOfPrecision, parseOption: parseOption);
            }

            // Serialize the pipe, load it back over the same files, and compare with the original.
            var pathModel = SavePipe(pipe1, suffix);
            var pipe2 = LoadPipe(pathModel, _env, files);
            if (!CheckMetadataTypes(pipe2.Schema))
                Failed();

            if (!TestCommon.CheckSameSchemas(pipe1.Schema, pipe2.Schema))
                Failed();
            else if (!CheckSameValues(pipe1, pipe2, checkId: checkId))
                Failed();

            if (pipe1.Schema.Count > 0)
            {
                // The text saver fails if there are no columns, so we cannot check in that case.
                if (!SaveLoadText(pipe1, _env, keepHidden, suffix, suffixBase, checkBaseline, forceDense,
                        roundTripText, digitsOfPrecision: digitsOfPrecision, parseOption: parseOption))
                {
                    Failed();
                }
                // The transpose saver likewise fails for the same reason.
                if (checkTranspose && !SaveLoadTransposed(pipe1, _env, suffix))
                    Failed();
            }
            if (!SaveLoad(pipe1, _env, suffix))
                Failed();

            // Check that the pipe doesn't shuffle when it cannot :).
            if (srcLoader != null)
            {
                // First we need to cache the data so it can be shuffled.
                var cachedData = new CacheDataView(_env, srcLoader, null);
                var newPipe = ApplyTransformUtils.ApplyAllTransformsToData(_env, comp.View, cachedData);
                if (!newPipe.CanShuffle)
                {
                    // A cursor requested with a Random must still produce the same row order as a plain cursor.
                    using (var c1 = newPipe.GetRowCursor(newPipe.Schema, new Random(123)))
                    using (var c2 = newPipe.GetRowCursorForAllColumns())
                    {
                        if (!CheckSameValues(c1, c2, true, true, true))
                            Failed();
                    }
                }

                // Join all filler threads of CacheDataView prior to the disposal of _wrt.
                // Otherwise they may write to a closed stream.
                cachedData.Wait();
            }

            // Baseline the schema, including metadata.
            if (baselineSchema)
            {
                // Both pipes write to the same baseline file name, so pipe2 is only checked once pipe1 matched.
                string name = TestName + suffix + "-Schema" + ".txt";
                string path = DeleteOutputPath("SavePipe", name);
                using (var writer = OpenWriter(path))
                {
                    ShowSchemaCommand.RunOnData(writer,
                        new ShowSchemaCommand.Arguments() { ShowMetadataValues = true, ShowSteps = true },
                        pipe1);
                }
                if (!CheckEquality("SavePipe", name, digitsOfPrecision: digitsOfPrecision, parseOption: parseOption))
                    Log("*** ShowSchema failed on pipe1");
                else
                {
                    path = DeleteOutputPath("SavePipe", name);
                    using (var writer = OpenWriter(path))
                    {
                        ShowSchemaCommand.RunOnData(writer,
                            new ShowSchemaCommand.Arguments() { ShowMetadataValues = true, ShowSteps = true },
                            pipe2);
                    }
                    if (!CheckEquality("SavePipe", name, digitsOfPrecision: digitsOfPrecision, parseOption: parseOption))
                        Log("*** ShowSchema failed on pipe2");
                }
            }

            // REVIEW: What about tests for ensuring that shuffling produces an actual shuffled version?

            return pipe1;
        }
 
        /// <summary>
        /// Builds the "Pipe" data loader described by <paramref name="argsPipe"/>, after checking that the
        /// arguments' custom unparsing round-trips.
        /// </summary>
        /// <param name="pathData">Input file. Null/empty selects the breast-cancer training set; "&lt;none&gt;" means no file.</param>
        /// <param name="files">Receives the file source passed to the loader.</param>
        private ILegacyDataLoader CreatePipeDataLoader(IHostEnvironment env, string pathData, string[] argsPipe, out MultiFileSource files)
        {
            VerifyArgParsing(env, argsPipe);

            string dataFile = string.IsNullOrEmpty(pathData)
                ? GetDataPath(TestDatasets.breastCancer.trainFilename)
                : pathData;
            files = new MultiFileSource(dataFile == "<none>" ? null : dataFile);

            var loaderComponent = new SubComponent<ILegacyDataLoader, SignatureDataLoader>("Pipe", argsPipe);
            ILegacyDataLoader loader = loaderComponent.CreateInstance(env, files);
            if (!CheckMetadataTypes(loader.Schema))
                Failed();

            return loader;
        }
 
        /// <summary>
        /// Parses <paramref name="strs"/> as composite-loader settings. Verifies that the loader and every
        /// transform round-trip through custom unparsing. Reports a failure if the settings do not parse.
        /// </summary>
        protected void VerifyArgParsing(IHostEnvironment env, string[] strs)
        {
            string combinedSettings = CmdParser.CombineSettings(strs);
            var loaderArgs = new LegacyCompositeDataLoader.Arguments();
            bool parsed = CmdParser.ParseArguments(Env, combinedSettings, loaderArgs);
            if (parsed)
            {
                VerifyCustArgs(env, loaderArgs.Loader);
                foreach (var transform in loaderArgs.Transforms)
                    VerifyCustArgs(env, transform.Value);
            }
            else
            {
                Fail("Parsing arguments failed!");
            }
        }
 
        /// <summary>
        /// Checks that the custom (unparsed) settings string of a command-line component round-trips. The
        /// factory's settings are parsed and rendered in custom form. That custom form is parsed again, and
        /// both results must expand to the same full settings string. Factories that are not command-line
        /// factories are reported as a failure.
        /// </summary>
        protected void VerifyCustArgs<TArg, TRes>(IHostEnvironment env, IComponentFactory<TArg, TRes> factory)
            where TRes : class
        {
            if (!(factory is ICommandLineComponentFactory commandLineFactory))
            {
                Fail($"TestDataPipeBase was called with a non command line loader or transform '{factory}'");
                return;
            }

            string settings = commandLineFactory.GetSettingsString();
            var classInfo = env.ComponentCatalog.GetLoadableClassInfo(commandLineFactory.Name, commandLineFactory.SignatureType);
            Assert.NotNull(classInfo);
            var defaults = classInfo.CreateArguments();
            var expandedFlags = SettingsFlags.Default | SettingsFlags.NoUnparse;

            // Parse the original settings, then render them both fully expanded and in custom form.
            var original = classInfo.CreateArguments();
            CmdParser.ParseArguments(Env, settings, original);
            string expandedOriginal = CmdParser.GetSettings(Env, original, defaults, expandedFlags);
            string custom = CmdParser.GetSettings(Env, original, defaults);

            // Parse the custom form back and expand it; it must describe exactly the same settings.
            var reparsed = classInfo.CreateArguments();
            CmdParser.ParseArguments(Env, custom, reparsed);
            string expandedReparsed = CmdParser.GetSettings(Env, reparsed, defaults, expandedFlags);

            if (expandedOriginal != expandedReparsed)
                Fail("Custom unparse failed on '{0}' starting with '{1}': '{2}' vs '{3}'", commandLineFactory.Name, settings, expandedOriginal, expandedReparsed);
        }
 
        /// <summary>
        /// Saves the savable columns of <paramref name="view"/> with a <see cref="TextSaver"/>. Optionally
        /// compares the output file against its baseline. When <paramref name="roundTrip"/> is true, it reloads
        /// the file using the loader settings emitted by the saver and checks the reloaded annotations, schema
        /// and values against the saved columns.
        /// </summary>
        /// <param name="hidden">When false, hidden columns are not saved.</param>
        /// <param name="suffixBase">Optional alternate suffix naming the baseline file to compare against.</param>
        /// <returns>
        /// False if a round-trip check failed; true otherwise. The baseline comparison reports through
        /// CheckEquality and does not affect the return value.
        /// </returns>
        protected bool SaveLoadText(IDataView view, IHostEnvironment env,
            bool hidden = true, string suffix = "", string suffixBase = null,
            bool checkBaseline = true, bool forceDense = false, bool roundTrip = true,
            bool outputSchema = true, bool outputHeader = true, int digitsOfPrecision = DigitsOfPrecision,
            NumberParseOption parseOption = NumberParseOption.Default)
        {
            TextSaver saver = new TextSaver(env, new TextSaver.Arguments() { Dense = forceDense, OutputSchema = outputSchema, OutputHeader = outputHeader });
            var schema = view.Schema;
            List<int> savable = new List<int>();
            for (int c = 0; c < schema.Count; ++c)
            {
                DataViewType type = schema[c].Type;
                if (saver.IsColumnSavable(type) && (hidden || !schema[c].IsHidden))
                    savable.Add(c);
            }

            string name = TestName + suffix + "-Data" + ".txt";
            string pathData = DeleteOutputPath("SavePipe", name);

            string argsLoader;
            using (var stream = File.Create(pathData))
                saver.SaveData(out argsLoader, stream, view, savable.ToArray());

            if (checkBaseline)
            {
                string nameBase = suffixBase != null ? TestName + suffixBase + "-Data" + ".txt" : name;
                CheckEquality("SavePipe", name, nameBase, digitsOfPrecision: digitsOfPrecision, parseOption: parseOption);
            }

            if (!roundTrip)
                return true;

            if (savable.Count < view.Schema.Count)
            {
                // Restrict the comparison to the subset of columns we were able to save.
                var chooseargs = new ChooseColumnsByIndexTransform.Options();
                chooseargs.Indices = savable.ToArray();
                view = new ChooseColumnsByIndexTransform(env, chooseargs, view);
            }

            // Start from sparse/quoting-enabled options and overlay the loader settings the saver emitted.
            var args = new TextLoader.Options() { AllowSparse = true, AllowQuoting = true };
            if (!CmdParser.ParseArguments(Env, argsLoader, args))
            {
                Fail("Couldn't parse the args '{0}' in '{1}'", argsLoader, pathData);
                return Failed();
            }

            var loadedData = ML.Data.LoadFromTextFile(pathData, options: args);
            // Stop here, as SaveLoad and SaveLoadTransposed do. Otherwise a detected failure would still yield 'true'.
            if (!CheckMetadataTypes(loadedData.Schema))
                return Failed();

            if (!TestCommon.CheckSameSchemas(view.Schema, loadedData.Schema, exactTypes: false, keyNames: false))
                return Failed();
            if (!CheckSameValues(view, loadedData, exactTypes: false, exactDoubles: false, checkId: false))
                return Failed();
            return true;
        }
 
        /// <summary>
        /// Serializes <paramref name="pipe"/> into a "SavePipe" output model file, storing it under the
        /// repository directory <paramref name="dir"/>. Returns the path of the written file.
        /// </summary>
        private protected string SavePipe(ILegacyDataLoader pipe, string suffix = "", string dir = "Pipeline")
        {
            string modelFileName = TestName + suffix + ".zip";
            string pathModel = DeleteOutputPath("SavePipe", modelFileName);

            using (var outputFile = Env.CreateOutputFile(pathModel))
            using (var writeStream = outputFile.CreateWriteStream())
            using (var repository = RepositoryWriter.CreateNew(writeStream, Env))
            {
                ModelSaveContext.SaveModel(repository, pipe, dir);
                // Commit explicitly while the underlying streams are still open.
                repository.Commit();
            }

            return pathModel;
        }
 
        /// <summary>
        /// Returns a new array holding the elements of <paramref name="args1"/> followed by those of <paramref name="args2"/>.
        /// </summary>
        protected string[] Concat(string[] args1, string[] args2)
        {
            var combined = new string[args1.Length + args2.Length];
            args1.CopyTo(combined, 0);
            args2.CopyTo(combined, args1.Length);
            return combined;
        }
 
        /// <summary>
        /// Loads the loader stored under the "Pipeline" directory of the model file at <paramref name="pathModel"/>.
        /// <paramref name="files"/> is passed as the extra load argument.
        /// </summary>
        private ILegacyDataLoader LoadPipe(string pathModel, IHostEnvironment env, IMultiStreamSource files)
        {
            using (var inputFile = Env.OpenInputFile(pathModel))
            using (var readStream = inputFile.OpenReadStream())
            using (var repository = RepositoryReader.Open(readStream, env))
            {
                ModelLoadContext.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env,
                    out ILegacyDataLoader loaded, repository, "Pipeline", files);
                return loaded;
            }
        }
 
        /// <summary>
        /// Sanity-checks the annotation schema of every column in <paramref name="sch"/>. Annotation
        /// names/types must be non-null and names must not repeat. The types found for SlotNames and
        /// KeyValues while enumerating must match the types looked up by name. Both kinds must show up in
        /// the enumeration whenever a lookup by name finds them.
        /// </summary>
        /// <returns>True if every column passes; otherwise the first problem is reported and false is returned.</returns>
        protected bool CheckMetadataTypes(DataViewSchema sch)
        {
            var seenNames = new HashSet<string>();
            foreach (var column in sch)
            {
                var annotationSchema = column.Annotations.Schema;
                // Expected types of the two well-known kinds; each is cleared once seen during enumeration.
                DataViewType pendingSlotNamesType = annotationSchema.GetColumnOrNull(AnnotationUtils.Kinds.SlotNames)?.Type;
                DataViewType pendingKeyValuesType = annotationSchema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;

                seenNames.Clear();
                foreach (var annotation in annotationSchema)
                {
                    if (annotation.Name == null || annotation.Type == null)
                    {
                        Fail("Null returned from GetMetadataTypes");
                        return Failed();
                    }
                    if (!seenNames.Add(annotation.Name))
                    {
                        Fail("Duplicate metadata type: {0}", annotation.Name);
                        return Failed();
                    }

                    if (annotation.Name == AnnotationUtils.Kinds.SlotNames)
                    {
                        if (pendingSlotNamesType == null || !pendingSlotNamesType.Equals(annotation.Type))
                        {
                            Fail("SlotNames types don't match");
                            return Failed();
                        }
                        pendingSlotNamesType = null;
                    }
                    else if (annotation.Name == AnnotationUtils.Kinds.KeyValues)
                    {
                        if (pendingKeyValuesType == null || !pendingKeyValuesType.Equals(annotation.Type))
                        {
                            Fail("KeyValues types don't match");
                            return Failed();
                        }
                        pendingKeyValuesType = null;
                    }
                }

                if (!Check(pendingSlotNamesType == null, "SlotNames not in GetMetadataTypes"))
                    return Failed();
                if (!Check(pendingKeyValuesType == null, "KeyValues not in GetMetadataTypes"))
                    return Failed();
            }

            return true;
        }
 
        /// <summary>
        /// Round-trips the savable columns of <paramref name="view"/> through the binary (.idv) format. Checks
        /// that the reloaded data has matching annotations, schema and values.
        /// </summary>
        /// <returns>True on success; false (via Failed) on the first failed check.</returns>
        protected bool SaveLoad(IDataView view, IHostEnvironment env, string suffix = "")
        {
            // No byte cap and at most 100 rows per block (presumably so multi-block files get exercised).
            var saver = new BinarySaver(env, new BinarySaver.Arguments
            {
                MaxBytesPerBlock = null,
                MaxRowsPerBlock = 100
            });

            DataViewSchema schema = view.Schema;
            List<int> savable = Enumerable.Range(0, schema.Count)
                .Where(index => saver.IsColumnSavable(schema[index].Type))
                .ToList();

            string pathData = DeleteOutputPath("SavePipe", TestName + suffix + "-Data" + ".idv");
            using (var stream = File.Create(pathData))
            {
                saver.SaveData(stream, view, savable.ToArray());
                Log("View saved in {0} bytes", stream.Length);
            }

            // Only the saved columns can be compared, so narrow the original view when some were skipped.
            if (savable.Count < schema.Count)
            {
                var chooseOptions = new ChooseColumnsByIndexTransform.Options { Indices = savable.ToArray() };
                view = new ChooseColumnsByIndexTransform(env, chooseOptions, view);
            }

            using (var loader = new BinaryLoader(env, new BinaryLoader.Arguments(), pathData))
            {
                bool roundTripped = CheckMetadataTypes(loader.Schema)
                    && TestCommon.CheckSameSchemas(view.Schema, loader.Schema)
                    && CheckSameValues(view, loader, checkId: false);
                if (!roundTripped)
                    return Failed();
            }
            return true;
        }
 
        /// <summary>
        /// Round-trips the savable columns of <paramref name="view"/> through the transposed (.tdv) format.
        /// Checks that the reloaded data has matching annotations, schema and values. If no column is savable
        /// by the transpose saver, it logs that and returns true without testing anything.
        /// </summary>
        /// <returns>True on success or skip; false (via Failed) on the first failed check.</returns>
        protected bool SaveLoadTransposed(IDataView view, IHostEnvironment env, string suffix = "")
        {
            // Row data is deliberately not written, forcing the loader through the re-transposition logic.
            var saver = new TransposeSaver(env, new TransposeSaver.Arguments { WriteRowData = false });

            DataViewSchema schema = view.Schema;
            List<int> savable = Enumerable.Range(0, schema.Count)
                .Where(index => saver.IsColumnSavable(schema[index].Type))
                .ToList();
            if (savable.Count == 0)
            {
                Log("No columns were savable in transposed saver, skipping");
                return true;
            }

            string pathData = DeleteOutputPath("SavePipe", TestName + suffix + "-Data" + ".tdv");
            using (var stream = File.Create(pathData))
            {
                saver.SaveData(stream, view, savable.ToArray());
                Log("View saved in {0} bytes", stream.Length);
            }

            // Only the saved columns can be compared, so narrow the original view when some were skipped.
            if (savable.Count < schema.Count)
            {
                var chooseOptions = new ChooseColumnsByIndexTransform.Options { Indices = savable.ToArray() };
                view = new ChooseColumnsByIndexTransform(env, chooseOptions, view);
            }

            // NOTE(review): unlike the BinaryLoader in SaveLoad, this loader is never disposed; confirm
            // whether TransposeLoader holds open file handles.
            var loader = new TransposeLoader(env, new TransposeLoader.Arguments(), new MultiFileSource(pathData));
            bool roundTripped = CheckMetadataTypes(loader.Schema)
                && TestCommon.CheckSameSchemas(view.Schema, loader.Schema)
                && CheckSameValues(view, loader, checkId: false);
            if (!roundTripped)
                return Failed();
            return true;
        }
    }
 
    public abstract partial class TestDataViewBase : BaseTestBaseline
    {
 
        /// <summary>
        /// Input row type for sentiment tests: a boolean label mapped to the "Label" column, plus the raw text.
        /// </summary>
        public class SentimentData
        {
            /// <summary>The sentiment label, exposed as the "Label" column.</summary>
            [ColumnName("Label")]
            public bool Sentiment;
            /// <summary>Raw text of the example (presumably the text being classified).</summary>
            public string SentimentText;
        }
 
        /// <summary>
        /// Output row type for sentiment tests: the predicted label (from the "PredictedLabel" column) and the score.
        /// </summary>
        public class SentimentPrediction
        {
            /// <summary>Predicted sentiment, read from the "PredictedLabel" column.</summary>
            [ColumnName("PredictedLabel")]
            public bool Sentiment;

            /// <summary>Score produced for the prediction; no explicit column name, so the field name is used.</summary>
            public float Score;
        }
 
        /// <summary>
        /// Reports a failed check. The failure itself must already have been recorded, so the test is no
        /// longer passing (debug-asserted). This method only yields <c>false</c>, so callers can write
        /// <c>return Failed();</c>.
        /// </summary>
        protected bool Failed()
        {
            Contracts.Assert(IsPassing == false);
            return false;
        }
 
        /// <summary>
        /// Compares two column types. With <paramref name="exactTypes"/> they must be equal. Otherwise they
        /// only need to pass <c>SameSizeAndItemType</c>.
        /// </summary>
        protected bool EqualTypes(DataViewType type1, DataViewType type2, bool exactTypes)
        {
            Contracts.AssertValue(type1);
            Contracts.AssertValue(type2);

            if (exactTypes)
                return type1.Equals(type2);

            // Relaxed comparison.
            return type1.SameSizeAndItemType(type2);
        }
 
        /// <summary>
        /// Compares the values of two views that have the same column count, using several cursor setups:
        /// 1. all columns active on both sides (also checking row-id collisions);
        /// 2. only even-indexed columns active on <paramref name="view2"/>;
        /// 3. only odd-indexed columns active on <paramref name="view2"/>;
        /// 4. a cursor over <paramref name="view1"/> checked against <paramref name="view2"/> itself.
        /// Each failing pass is reported through Check.
        /// </summary>
        /// <returns>True only if every pass compared equal.</returns>
        protected bool CheckSameValues(IDataView view1, IDataView view2, bool exactTypes = true, bool exactDoubles = true, bool checkId = true)
        {
            Contracts.Assert(view1.Schema.Count == view2.Schema.Count);

            bool all = true;
            bool tmp;

            using (var curs1 = view1.GetRowCursorForAllColumns())
            using (var curs2 = view2.GetRowCursorForAllColumns())
            {
                Check(curs1.Schema == view1.Schema, "Schema of view 1 and its cursor differed");
                Check(curs2.Schema == view2.Schema, "Schema of view 2 and its cursor differed");
                tmp = CheckSameValues(curs1, curs2, exactTypes, exactDoubles, checkId, true);
            }
            Check(tmp, "All same failed");
            all &= tmp;

            var view2EvenCols = view2.Schema.Where(col => (col.Index & 1) == 0);
            using (var curs1 = view1.GetRowCursorForAllColumns())
            using (var curs2 = view2.GetRowCursor(view2EvenCols))
            {
                Check(curs1.Schema == view1.Schema, "Schema of view 1 and its cursor differed");
                Check(curs2.Schema == view2.Schema, "Schema of view 2 and its cursor differed");
                tmp = CheckSameValues(curs1, curs2, exactTypes, exactDoubles, checkId, false);
            }
            Check(tmp, "Even same failed");
            all &= tmp;

            var view2OddCols = view2.Schema.Where(col => (col.Index & 1) != 0);
            using (var curs1 = view1.GetRowCursorForAllColumns())
            using (var curs2 = view2.GetRowCursor(view2OddCols))
            {
                Check(curs1.Schema == view1.Schema, "Schema of view 1 and its cursor differed");
                Check(curs2.Schema == view2.Schema, "Schema of view 2 and its cursor differed");
                tmp = CheckSameValues(curs1, curs2, exactTypes, exactDoubles, checkId, false);
            }
            Check(tmp, "Odd same failed");
            // Accumulate before 'tmp' is reused below; otherwise a failing odd-column pass is lost.
            all &= tmp;

            using (var curs1 = view1.GetRowCursorForAllColumns())
            {
                Check(curs1.Schema == view1.Schema, "Schema of view 1 and its cursor differed");
                tmp = CheckSameValues(curs1, view2, exactTypes, exactDoubles, checkId);
            }
            Check(tmp, "Single value same failed");

            all &= tmp;
            return all;
        }
 
        /// <summary>
        /// Walks two cursors in lockstep and checks that they produce the same number of rows
        /// and equal values in every column that is active in both cursors.
        /// </summary>
        /// <param name="curs1">Left cursor.</param>
        /// <param name="curs2">Right cursor; expected to have the same column count as <paramref name="curs1"/>.</param>
        /// <param name="exactTypes">Forwarded to <c>EqualTypes</c> when comparing the column types of the two cursors.</param>
        /// <param name="exactDoubles">When true, float/double values must match bit-for-bit; otherwise a small absolute tolerance is allowed.</param>
        /// <param name="checkId">When true, the row IDs of the two cursors are compared on every row.</param>
        /// <param name="checkIdCollisions">When true, the row IDs of <paramref name="curs1"/> are collected and any duplicate is reported as a failure.</param>
        /// <returns>
        /// True when all checks pass. On a mismatch, the problem is reported through <c>Fail</c>
        /// and the result of <c>Failed()</c> is returned.
        /// </returns>
        protected bool CheckSameValues(DataViewRowCursor curs1, DataViewRowCursor curs2, bool exactTypes, bool exactDoubles, bool checkId, bool checkIdCollisions = true)
        {
            Contracts.Assert(curs1.Schema.Count == curs2.Schema.Count);

            // Get the comparison delegates for each column. Only columns active in both cursors get one;
            // the remaining entries stay null and are skipped during the row walk.
            int colLim = curs1.Schema.Count;
            Func<bool>[] comps = new Func<bool>[colLim];
            for (int col = 0; col < colLim; col++)
            {
                var f1 = curs1.IsColumnActive(curs1.Schema[col]);
                var f2 = curs2.IsColumnActive(curs2.Schema[col]);

                if (f1 && f2)
                {
                    var type1 = curs1.Schema[col].Type;
                    var type2 = curs2.Schema[col].Type;
                    if (!EqualTypes(type1, type2, exactTypes))
                    {
                        Fail($"Different types {type1} and {type2}");
                        return Failed();
                    }
                    comps[col] = GetColumnComparer(curs1, curs2, col, type1, exactDoubles);
                }
            }

            // When IDs are compared, the ID getter of curs1 is reused for collision tracking as well.
            ValueGetter<DataViewRowId> idGetter = null;
            Func<bool> idComp = checkId ? GetIdComparer(curs1, curs2, out idGetter) : null;
            HashSet<DataViewRowId> idsSeen = null;
            if (checkIdCollisions && idGetter == null)
                idGetter = curs1.GetIdGetter();
            long idCollisions = 0;
            DataViewRowId id = default(DataViewRowId);

            for (; ; )
            {
                bool f1 = curs1.MoveNext();
                bool f2 = curs2.MoveNext();
                if (f1 != f2)
                {
                    if (f1)
                        Fail("Left has more rows at position: {0}", curs1.Position);
                    else
                        Fail("Right has more rows at position: {0}", curs2.Position);
                    return Failed();
                }

                if (!f1)
                {
                    // Both cursors are exhausted: succeed only if no duplicate IDs were seen.
                    if (idCollisions > 0)
                        Fail("{0} id collisions among {1} items", idCollisions, Utils.Size(idsSeen) + idCollisions);
                    return idCollisions == 0;
                }
                else if (checkIdCollisions)
                {
                    idGetter(ref id);
                    if (!Utils.Add(ref idsSeen, id))
                    {
                        if (idCollisions == 0)
                            Log("Id collision {0} at {1}, further collisions will not be logged", id, curs1.Position);
                        idCollisions++;
                    }
                }

                Contracts.Assert(curs1.Position == curs2.Position);

                for (int col = 0; col < colLim; col++)
                {
                    var comp = comps[col];
                    if (comp != null && !comp())
                    {
                        Fail("Different values in column {0} of row {1}", col, curs1.Position);
                        return Failed();
                    }
                }

                // The ID check is per row, not per column. Evaluate it exactly once per row, so that it
                // also runs for schemas with no columns and is not repeated for every column.
                if (idComp != null && !idComp())
                {
                    Fail("Different values in ID of row {0}", curs1.Position);
                    return Failed();
                }
            }
        }
 
        /// <summary>
        /// Compares a cursor against a data view by opening one all-columns cursor over <paramref name="view2"/>
        /// per column of <paramref name="curs1"/>. All cursors are advanced in lockstep, and column <c>i</c> of
        /// <paramref name="curs1"/> is compared against column <c>i</c> of the <c>i</c>-th cursor.
        /// </summary>
        /// <param name="curs1">Left cursor; all of its columns are expected to be active.</param>
        /// <param name="view2">Right data view; expected to have the same column count as <paramref name="curs1"/>.</param>
        /// <param name="exactTypes">Forwarded to <c>EqualTypes</c> when comparing column types.</param>
        /// <param name="exactDoubles">When true, float/double values must match bit-for-bit; otherwise a small absolute tolerance is allowed.</param>
        /// <param name="checkId">When true, the row ID of <paramref name="curs1"/> is compared against each per-column cursor's row ID.</param>
        /// <returns>
        /// True when all rows and values match. On a mismatch, the problem is reported through <c>Fail</c>
        /// and the result of <c>Failed()</c> is returned.
        /// </returns>
        protected bool CheckSameValues(DataViewRowCursor curs1, IDataView view2, bool exactTypes = true, bool exactDoubles = true, bool checkId = true)
        {
            Contracts.Assert(curs1.Schema.Count == view2.Schema.Count);

            // Get a cursor for each column.
            int colLim = curs1.Schema.Count;
            var cursors = new DataViewRowCursor[colLim];
            try
            {
                for (int col = 0; col < colLim; col++)
                {
                    // curs1 should have all columns active (for simplicity of the code here).
                    Contracts.Assert(curs1.IsColumnActive(curs1.Schema[col]));
                    cursors[col] = view2.GetRowCursorForAllColumns();
                }

                // Get the comparison delegates for each column.
                Func<bool>[] comps = new Func<bool>[colLim];
                // We have also one ID comparison delegate for each cursor.
                Func<bool>[] idComps = new Func<bool>[cursors.Length];
                for (int col = 0; col < colLim; col++)
                {
                    Contracts.Assert(cursors[col] != null);
                    var type1 = curs1.Schema[col].Type;
                    var type2 = cursors[col].Schema[col].Type;
                    if (!EqualTypes(type1, type2, exactTypes))
                    {
                        Fail("Different types");
                        return Failed();
                    }
                    comps[col] = GetColumnComparer(curs1, cursors[col], col, type1, exactDoubles);
                    // The out ID getter is not needed here; only the comparison delegate is kept.
                    idComps[col] = checkId ? GetIdComparer(curs1, cursors[col], out _) : null;
                }

                for (; ; )
                {
                    bool f1 = curs1.MoveNext();
                    for (int col = 0; col < colLim; col++)
                    {
                        bool f2 = cursors[col].MoveNext();
                        if (f1 != f2)
                        {
                            if (f1)
                                Fail("Left has more rows at position: {0}", curs1.Position);
                            else
                                // Report the position of the cursor that actually ran ahead (not a fixed index,
                                // which would be wrong and could go out of range for narrow schemas).
                                Fail("Right {0} has more rows at position: {1}", col, cursors[col].Position);
                            return Failed();
                        }
                    }

                    if (!f1)
                        return true;

                    for (int col = 0; col < colLim; col++)
                    {
                        Contracts.Assert(curs1.Position == cursors[col].Position);
                        var comp = comps[col];
                        if (comp != null && !comp())
                        {
                            Fail("Different values in column {0} of row {1}", col, curs1.Position);
                            return Failed();
                        }
                        comp = idComps[col];
                        if (comp != null && !comp())
                        {
                            Fail("Different values in ID values for column {0} cursor of row {1}", col, curs1.Position);
                            return Failed();
                        }
                    }
                }
            }
            finally
            {
                // Dispose whatever per-column cursors were created, even if creation stopped part-way.
                for (int col = 0; col < colLim; col++)
                {
                    var c = cursors[col];
                    if (c != null)
                        c.Dispose();
                }
            }
        }
 
        /// <summary>
        /// Builds a delegate that, on each call, reads the current row ID from both rows and reports whether they are equal.
        /// </summary>
        /// <param name="r1">First row.</param>
        /// <param name="r2">Second row.</param>
        /// <param name="idGetter">Receives the ID getter obtained from <paramref name="r1"/>, so the caller can reuse it.</param>
        /// <returns>A delegate returning true when the two current row IDs are equal.</returns>
        protected Func<bool> GetIdComparer(DataViewRow r1, DataViewRow r2, out ValueGetter<DataViewRowId> idGetter)
        {
            // An out parameter cannot be captured by a lambda, so keep a local copy for the closure.
            var leftIds = r1.GetIdGetter();
            idGetter = leftIds;
            var rightIds = r2.GetIdGetter();

            // Scratch values live in the closure and are reused on every invocation.
            var leftId = default(DataViewRowId);
            var rightId = default(DataViewRowId);

            Func<bool> compareIds = () =>
            {
                leftIds(ref leftId);
                rightIds(ref rightId);
                return leftId.Equals(rightId);
            };
            return compareIds;
        }
 
        /// <summary>
        /// Builds a delegate that, on each call, reads column <paramref name="col"/> from both rows and reports
        /// whether the two values are equal under the rule appropriate to <paramref name="type"/>.
        /// </summary>
        /// <param name="r1">First row.</param>
        /// <param name="r2">Second row.</param>
        /// <param name="col">Column index, used for both rows.</param>
        /// <param name="type">Column type; selects the value representation (scalar or vector item type).</param>
        /// <param name="exactDoubles">
        /// When true, float/double values must match bit-for-bit; otherwise they are compared with
        /// <c>EqualWithEpsSingle</c> / <c>EqualWithEpsDouble</c>.
        /// </param>
        /// <returns>
        /// A comparison delegate. For a scalar column whose raw type is not recognized, the delegate always returns true
        /// (such values are not compared).
        /// </returns>
        /// <exception cref="Exception">Thrown via <c>Contracts.Except</c> for a vector column whose item type is not recognized.</exception>
        protected Func<bool> GetColumnComparer(DataViewRow r1, DataViewRow r2, int col, DataViewType type, bool exactDoubles)
        {
            if (type is VectorDataViewType vectorType)
            {
                int size = vectorType.Size;
                Contracts.Assert(size >= 0);
                Type item = vectorType.ItemType.RawType;

                if (item == typeof(sbyte))
                    return GetComparerVec<sbyte>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(byte))
                    return GetComparerVec<byte>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(short))
                    return GetComparerVec<short>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(ushort))
                    return GetComparerVec<ushort>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(int))
                    return GetComparerVec<int>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(uint))
                    return GetComparerVec<uint>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(long))
                    return GetComparerVec<long>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(ulong))
                    return GetComparerVec<ulong>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(float))
                {
                    return exactDoubles
                        ? GetComparerVec<float>(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y))
                        : GetComparerVec<float>(r1, r2, col, size, EqualWithEpsSingle);
                }
                if (item == typeof(double))
                {
                    return exactDoubles
                        ? GetComparerVec<double>(r1, r2, col, size, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y))
                        : GetComparerVec<double>(r1, r2, col, size, EqualWithEpsDouble);
                }
                if (item == typeof(ReadOnlyMemory<char>))
                    return GetComparerVec<ReadOnlyMemory<char>>(r1, r2, col, size, (a, b) => a.Span.SequenceEqual(b.Span));
                if (item == typeof(bool))
                    return GetComparerVec<bool>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(TimeSpan))
                    return GetComparerVec<TimeSpan>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(DateTime))
                    return GetComparerVec<DateTime>(r1, r2, col, size, (x, y) => x == y);
                if (item == typeof(DateTimeOffset))
                    return GetComparerVec<DateTimeOffset>(r1, r2, col, size, (x, y) => x.Equals(y));
                if (item == typeof(DataViewRowId))
                    return GetComparerVec<DataViewRowId>(r1, r2, col, size, (x, y) => x.Equals(y));

                // Vectors of any other item type are rejected.
                throw Contracts.Except("Unknown type in GetColumnComparer: '{0}'", type);
            }

            Type raw = type.RawType;
            if (raw == typeof(sbyte))
                return GetComparerOne<sbyte>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(byte))
                return GetComparerOne<byte>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(short))
                return GetComparerOne<short>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(ushort))
                return GetComparerOne<ushort>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(int))
                return GetComparerOne<int>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(uint))
                return GetComparerOne<uint>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(long))
                return GetComparerOne<long>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(ulong))
                return GetComparerOne<ulong>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(float))
            {
                return exactDoubles
                    ? GetComparerOne<float>(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y))
                    : GetComparerOne<float>(r1, r2, col, EqualWithEpsSingle);
            }
            if (raw == typeof(double))
            {
                return exactDoubles
                    ? GetComparerOne<double>(r1, r2, col, (x, y) => FloatUtils.GetBits(x) == FloatUtils.GetBits(y))
                    : GetComparerOne<double>(r1, r2, col, EqualWithEpsDouble);
            }
            if (raw == typeof(ReadOnlyMemory<char>))
                return GetComparerOne<ReadOnlyMemory<char>>(r1, r2, col, (a, b) => a.Span.SequenceEqual(b.Span));
            if (raw == typeof(bool))
                return GetComparerOne<bool>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(TimeSpan))
                return GetComparerOne<TimeSpan>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(DateTime))
                return GetComparerOne<DateTime>(r1, r2, col, (x, y) => x == y);
            if (raw == typeof(DateTimeOffset))
                return GetComparerOne<DateTimeOffset>(r1, r2, col, (x, y) => x.Equals(y));
            if (raw == typeof(DataViewRowId))
                return GetComparerOne<DataViewRowId>(r1, r2, col, (x, y) => x.Equals(y));

            // Scalars of any other raw type are not compared: the delegate always reports equality.
            return () => true;
        }
 
        // Absolute tolerance used by EqualWithEpsDouble.
        private const double DoubleEps = 1e-9;

        /// <summary>
        /// Returns true when the two doubles have identical bit patterns, or when their absolute difference
        /// is strictly less than <see cref="DoubleEps"/>.
        /// </summary>
        private static bool EqualWithEpsDouble(double x, double y)
        {
            // Check bit patterns first: Inf - Inf and NaN - NaN yield NaN, so the tolerance test alone
            // would never accept two identical infinities or NaNs.
            if (FloatUtils.GetBits(x) == FloatUtils.GetBits(y))
                return true;
            return Math.Abs(x - y) < DoubleEps;
        }
 
        // Absolute tolerance used by EqualWithEpsSingle.
        private const float SingleEps = 1e-6f;

        /// <summary>
        /// Returns true when the two floats have identical bit patterns, or when their absolute difference
        /// is strictly less than <see cref="SingleEps"/>.
        /// </summary>
        private static bool EqualWithEpsSingle(float x, float y)
        {
            // Check bit patterns first: Inf - Inf and NaN - NaN yield NaN, so the tolerance test alone
            // would never accept two identical infinities or NaNs.
            if (FloatUtils.GetBits(x) == FloatUtils.GetBits(y))
                return true;
            return Math.Abs(x - y) < SingleEps;
        }
 
        /// <summary>
        /// Builds a delegate that, on each call, reads the scalar value of column <paramref name="col"/> from both rows
        /// and returns the result of <paramref name="fn"/> applied to the two values.
        /// </summary>
        /// <typeparam name="T">Raw value type of the column.</typeparam>
        /// <param name="r1">First row.</param>
        /// <param name="r2">Second row.</param>
        /// <param name="col">Column index, used for both rows.</param>
        /// <param name="fn">Equality predicate applied as <c>fn(valueFromR1, valueFromR2)</c>.</param>
        protected Func<bool> GetComparerOne<T>(DataViewRow r1, DataViewRow r2, int col, Func<T, T, bool> fn)
        {
            var leftGetter = r1.GetGetter<T>(r1.Schema[col]);
            var rightGetter = r2.GetGetter<T>(r2.Schema[col]);

            // Scratch values are captured by the local function and reused across invocations.
            T leftValue = default(T);
            T rightValue = default(T);

            bool Compare()
            {
                leftGetter(ref leftValue);
                rightGetter(ref rightValue);
                return fn(leftValue, rightValue);
            }

            return Compare;
        }
 
        /// <summary>
        /// Builds a delegate that, on each call, reads the vector value of column <paramref name="col"/> from both rows
        /// and compares them element-wise via <c>TestCommon.CompareVec</c> using <paramref name="fn"/>.
        /// </summary>
        /// <typeparam name="T">Item type of the vector column.</typeparam>
        /// <param name="r1">First row.</param>
        /// <param name="r2">Second row.</param>
        /// <param name="col">Column index, used for both rows.</param>
        /// <param name="size">Vector size, forwarded to <c>TestCommon.CompareVec</c>.</param>
        /// <param name="fn">Per-item equality predicate.</param>
        protected Func<bool> GetComparerVec<T>(DataViewRow r1, DataViewRow r2, int col, int size, Func<T, T, bool> fn)
        {
            var leftGetter = r1.GetGetter<VBuffer<T>>(r1.Schema[col]);
            var rightGetter = r2.GetGetter<VBuffer<T>>(r2.Schema[col]);

            // Buffers are captured by the local function and reused across invocations.
            VBuffer<T> leftVec = default(VBuffer<T>);
            VBuffer<T> rightVec = default(VBuffer<T>);

            bool Compare()
            {
                leftGetter(ref leftVec);
                rightGetter(ref rightVec);
                return TestCommon.CompareVec<T>(in leftVec, in rightVec, size, fn);
            }

            return Compare;
        }
 
        // Verifies the equality of the values returned by the single valued getters passed in as parameters.
        /// <summary>
        /// Reads one value from each scalar getter and asserts (xUnit <c>Assert.Equal</c>) that they are equal.
        /// </summary>
        /// <typeparam name="T">Value type produced by both getters.</typeparam>
        /// <param name="oneGetter">Getter whose value is used as the expected value.</param>
        /// <param name="oneNGetter">Getter whose value is used as the actual value.</param>
        protected void VerifyOneEquality<T>(ValueGetter<T> oneGetter, ValueGetter<T> oneNGetter)
        {
            var expected = default(T);
            var actual = default(T);

            oneGetter(ref expected);
            oneNGetter(ref actual);

            Assert.Equal(expected, actual);
        }
 
        // Verifies the equality of the values returned by the vector valued getters passed in as parameters using the provided compare function.
        /// <summary>
        /// Reads one vector from each getter and asserts (xUnit <c>Assert.True</c>) that
        /// <c>TestCommon.CompareVec</c> considers them equal under <paramref name="compare"/>.
        /// </summary>
        /// <typeparam name="T">Item type of the vectors.</typeparam>
        /// <param name="vecGetter">First vector getter.</param>
        /// <param name="vecNGetter">Second vector getter.</param>
        /// <param name="compare">Per-item comparison; receives an int (presumably the item index — confirm in TestCommon) and the two items.</param>
        /// <param name="size">Vector size, forwarded to <c>TestCommon.CompareVec</c>.</param>
        protected void VerifyVecEquality<T>(ValueGetter<VBuffer<T>> vecGetter, ValueGetter<VBuffer<T>> vecNGetter, Func<int, T, T, bool> compare, int size)
        {
            var first = default(VBuffer<T>);
            var second = default(VBuffer<T>);

            vecGetter(ref first);
            vecNGetter(ref second);

            bool same = TestCommon.CompareVec(in first, in second, size, compare);
            Assert.True(same);
        }
    }
}