File: TestCommon.cs
Web Access
Project: src\test\Microsoft.ML.TestFrameworkCommon\Microsoft.ML.TestFrameworkCommon.csproj (Microsoft.ML.TestFrameworkCommon)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.Data;
using Xunit;
 
namespace Microsoft.ML.TestFrameworkCommon
{
    public static class TestCommon
    {
        /// <summary>
        /// Builds the path of an output file called <paramref name="name"/> inside <paramref name="outDir"/>.
        /// </summary>
        /// <param name="outDir">Directory the file path is rooted at.</param>
        /// <param name="name">File name; may be null, empty or whitespace.</param>
        /// <returns>
        /// The combined path, or <c>null</c> when <paramref name="name"/> is null, empty or whitespace.
        /// </returns>
        public static string GetOutputPath(string outDir, string name)
        {
            bool hasName = !string.IsNullOrWhiteSpace(name);
            return hasName ? Path.Combine(outDir, name) : null;
        }
        /// <summary>
        /// Builds the path of an output file called <paramref name="name"/> inside the
        /// <paramref name="subDir"/> sub-directory of <paramref name="outDir"/>.
        /// </summary>
        /// <param name="outDir">Root output directory.</param>
        /// <param name="subDir">
        /// Sub-directory below <paramref name="outDir"/>. When null, empty or whitespace the
        /// two-argument overload is used instead.
        /// </param>
        /// <param name="name">File name; may be null, empty or whitespace.</param>
        /// <returns>
        /// The combined path, or <c>null</c> when <paramref name="name"/> is null, empty or whitespace.
        /// </returns>
        public static string GetOutputPath(string outDir, string subDir, string name)
        {
            bool hasSubDir = !string.IsNullOrWhiteSpace(subDir);
            if (!hasSubDir)
            {
                return GetOutputPath(outDir, name);
            }

            // The sub-directory is created before the name is inspected, so it exists
            // on disk even when no path is returned.
            EnsureOutputDir(subDir, outDir);

            // REVIEW: put the path in braces in case the path has spaces
            return string.IsNullOrWhiteSpace(name) ? null : Path.Combine(outDir, subDir, name);
        }
 
        /// <summary>
        /// Resolves the absolute path of the data file <paramref name="name"/> inside <paramref name="dataDir"/>.
        /// </summary>
        /// <param name="dataDir">Directory holding the data file.</param>
        /// <param name="name">File name; may be null, empty or whitespace.</param>
        /// <returns>
        /// The full path, or <c>null</c> when <paramref name="name"/> is null, empty or whitespace.
        /// </returns>
        public static string GetDataPath(string dataDir, string name)
        {
            if (!string.IsNullOrWhiteSpace(name))
            {
                string relative = Path.Combine(dataDir, name);
                return Path.GetFullPath(relative);
            }

            return null;
        }
        /// <summary>
        /// Resolves the absolute path of the data file <paramref name="name"/> inside the
        /// <paramref name="subDir"/> sub-directory of <paramref name="dataDir"/>.
        /// </summary>
        /// <param name="dataDir">Root data directory.</param>
        /// <param name="subDir">
        /// Sub-directory below <paramref name="dataDir"/>. When null, empty or whitespace the file is
        /// resolved directly under <paramref name="dataDir"/>, matching how the three-argument
        /// <c>GetOutputPath</c> treats a blank sub-directory.
        /// </param>
        /// <param name="name">File name; may be null, empty or whitespace.</param>
        /// <returns>
        /// The full path, or <c>null</c> when <paramref name="name"/> is null, empty or whitespace.
        /// </returns>
        public static string GetDataPath(string dataDir, string subDir, string name)
        {
            if (string.IsNullOrWhiteSpace(name))
                return null;

            // Path.Combine throws ArgumentNullException on a null segment, so a missing
            // sub-directory is skipped rather than passed through.
            string combined = string.IsNullOrWhiteSpace(subDir)
                ? Path.Combine(dataDir, name)
                : Path.Combine(dataDir, subDir, name);
            return Path.GetFullPath(combined);
        }
 
        /// <summary>
        /// Computes the output path for <paramref name="name"/> under
        /// <paramref name="outDir"/>/<paramref name="subDir"/> and deletes the file at that path.
        /// </summary>
        /// <param name="outDir">Root output directory.</param>
        /// <param name="subDir">Sub-directory below <paramref name="outDir"/>.</param>
        /// <param name="name">File name.</param>
        /// <returns>The path produced by <c>GetOutputPath</c>; nothing is deleted when it is null, empty or whitespace.</returns>
        public static string DeleteOutputPath(string outDir, string subDir, string name)
        {
            string target = GetOutputPath(outDir, subDir, name);
            if (string.IsNullOrWhiteSpace(target))
            {
                return target;
            }

            File.Delete(target);
            return target;
        }
        /// <summary>
        /// Computes the output path for <paramref name="name"/> under <paramref name="outDir"/>
        /// and deletes the file at that path.
        /// </summary>
        /// <param name="outDir">Output directory.</param>
        /// <param name="name">File name.</param>
        /// <returns>The path produced by <c>GetOutputPath</c>; nothing is deleted when it is null, empty or whitespace.</returns>
        public static string DeleteOutputPath(string outDir, string name)
        {
            string target = GetOutputPath(outDir, name);
            if (string.IsNullOrWhiteSpace(target))
            {
                return target;
            }

            File.Delete(target);
            return target;
        }
 
        /// <summary>
        /// Environment variable containing path to the test data and BaseLineOutput folders.
        /// </summary>
        public const string TestDataDirEnvVariable = "ML_TEST_DATADIR";

        /// <summary>
        /// Locates the repository root used by the tests.
        /// </summary>
        /// <returns>
        /// The value of the <see cref="TestDataDirEnvVariable"/> environment variable when it is set;
        /// otherwise the nearest directory at or above the application base directory that contains a
        /// <c>.git</c> directory, or <c>null</c> when no such directory exists.
        /// </returns>
        public static string GetRepoRoot()
        {
            string directory = Environment.GetEnvironmentVariable(TestDataDirEnvVariable);
            if (directory != null)
            {
                return directory;
            }
#if NETFRAMEWORK
            directory = AppDomain.CurrentDomain.BaseDirectory;
#else
            directory = AppContext.BaseDirectory;
#endif

            // The null check must come first: Path.Combine rejects a null segment, and
            // Directory.GetParent returns null once the filesystem root has been reached.
            while (directory != null && !Directory.Exists(Path.Combine(directory, ".git")))
            {
                directory = Directory.GetParent(directory)?.FullName;
            }

            return directory;
        }
 
        /// <summary>
        /// Verifies that two schemas describe the same columns: same count, and per index the same
        /// name, type, name-to-index mapping and text annotations.
        /// </summary>
        /// <param name="sch1">First schema.</param>
        /// <param name="sch2">Second schema.</param>
        /// <param name="exactTypes">
        /// When true types are compared with <c>Equals</c>; otherwise with <c>SameSizeAndItemType</c>.
        /// </param>
        /// <param name="keyNames">When true the "KeyValues" annotation is compared as well.</param>
        /// <returns>
        /// True when every check passes; false when an annotation check reports false.
        /// Other mismatches fail an xUnit assertion instead.
        /// </returns>
        public static bool CheckSameSchemas(DataViewSchema sch1, DataViewSchema sch2, bool exactTypes = true, bool keyNames = true)
        {
            Assert.True(sch1.Count == sch2.Count, $"column count mismatch: {sch1.Count} vs {sch2.Count}");

            for (int col = 0; col < sch1.Count; col++)
            {
                var column1 = sch1[col];
                var column2 = sch2[col];

                string name1 = column1.Name;
                string name2 = column2.Name;
                Assert.True(name1 == name2, $"column name mismatch at index {col}: {name1} vs {name2}");

                var type1 = column1.Type;
                var type2 = column2.Type;
                Assert.True(EqualTypes(type1, type2, exactTypes), $"column type mismatch at index {col}");

                // Both schemas must resolve the column name to the same index.
                bool found1 = sch1.TryGetColumnIndex(name1, out int index1);
                bool found2 = sch2.TryGetColumnIndex(name2, out int index2);

                Assert.True(found1, "TryGetColumnIndex unexpectedly failed");
                Assert.True(found2, "TryGetColumnIndex unexpectedly failed");
                Assert.True(index1 == index2, $"TryGetColumnIndex on '{name1}' produced different results: '{index1}' vs '{index2}'");

                ulong slotCount = 0;
                if (type1 is VectorDataViewType vectorType)
                    slotCount = (ulong)vectorType.Size;

                // "PurpleDragonScales" is a deliberately unknown annotation kind: both schemas
                // must handle it the same way. "SlotNames" is then compared for real.
                bool annotationsMatch =
                    CheckMetadataNames("PurpleDragonScales", 0, sch1, sch2, col, exactTypes, true)
                    && CheckMetadataNames("SlotNames", slotCount, sch1, sch2, col, exactTypes, true);
                if (!annotationsMatch)
                    return false;

                if (keyNames)
                {
                    ulong keyCount = 0;
                    if (type1.GetItemType() is KeyDataViewType keyType)
                        keyCount = keyType.Count;

                    if (!CheckMetadataNames("KeyValues", keyCount, sch1, sch2, col, exactTypes, false))
                        return false;
                }
            }

            return true;
        }
 
        /// <summary>
        /// Compares two vectors slot by slot with a predicate that does not need the slot index.
        /// </summary>
        /// <param name="v1">First vector.</param>
        /// <param name="v2">Second vector.</param>
        /// <param name="size">Expected length of both vectors, or 0 to skip the length-vs-size check.</param>
        /// <param name="fn">Predicate applied to each pair of values.</param>
        /// <returns>The result of the index-aware overload, with the index ignored.</returns>
        public static bool CompareVec<T>(in VBuffer<T> v1, in VBuffer<T> v2, int size, Func<T, T, bool> fn)
        {
            // Adapt the value-only predicate to the index-aware signature.
            Func<int, T, T, bool> indexed = (slot, left, right) => fn(left, right);
            return CompareVec(in v1, in v2, size, indexed);
        }
 
        /// <summary>
        /// Compares two vectors slot by slot, whether each is dense or sparse.
        /// </summary>
        /// <param name="v1">First vector.</param>
        /// <param name="v2">Second vector.</param>
        /// <param name="size">Expected length of both vectors, or 0 to skip the length-vs-size check.</param>
        /// <param name="fn">
        /// Predicate receiving the slot index and the two values at that slot. A slot that is not
        /// explicitly stored in a sparse vector is passed as <c>default(T)</c>.
        /// </param>
        /// <returns>
        /// False as soon as <paramref name="fn"/> returns false; true otherwise. Slots that are
        /// implicit in both vectors are never passed to <paramref name="fn"/>.
        /// </returns>
        public static bool CompareVec<T>(in VBuffer<T> v1, in VBuffer<T> v2, int size, Func<int, T, T, bool> fn)
        {
            Assert.True(size == 0 || v1.Length == size);
            Assert.True(size == 0 || v2.Length == size);
            Assert.True(v1.Length == v2.Length);

            var values1 = v1.GetValues();
            var values2 = v2.GetValues();

            bool bothDense = v1.IsDense && v2.IsDense;
            if (bothDense)
            {
                for (int slot = 0; slot < v1.Length; slot++)
                {
                    if (!fn(slot, values1[slot], values2[slot]))
                        return false;
                }
                return true;
            }

            Assert.True(!v1.IsDense || !v2.IsDense);
            var indices1 = v1.GetIndices();
            var indices2 = v2.GetIndices();

            // pos1/pos2 are positions in the value buffers; slot1/slot2 are the logical slots
            // they correspond to, with the vector length acting as the "exhausted" sentinel.
            int pos1 = 0;
            int pos2 = 0;
            while (true)
            {
                int slot1;
                if (v1.IsDense)
                    slot1 = pos1;
                else if (pos1 < indices1.Length)
                    slot1 = indices1[pos1];
                else
                    slot1 = v1.Length;

                int slot2;
                if (v2.IsDense)
                    slot2 = pos2;
                else if (pos2 < indices2.Length)
                    slot2 = indices2[pos2];
                else
                    slot2 = v2.Length;

                // Both sides exhausted at the same time: every stored slot has been compared.
                if (slot1 == slot2 && slot1 == v1.Length)
                    return true;

                // Whichever side owns the smaller slot contributes its stored value; the other
                // side contributes default(T). When the slots tie, both contribute.
                T left = default(T);
                T right = default(T);
                int current = slot1 <= slot2 ? slot1 : slot2;
                if (slot1 <= slot2)
                {
                    left = values1[pos1];
                    pos1++;
                }
                if (slot2 <= slot1)
                {
                    right = values2[pos2];
                    pos2++;
                }

                if (!fn(current, left, right))
                    return false;
            }
        }
 
        /// <summary>
        /// Compares two column types, either strictly or by size and item type only.
        /// </summary>
        /// <param name="type1">First type; asserted non-null.</param>
        /// <param name="type2">Second type; asserted non-null.</param>
        /// <param name="exactTypes">
        /// When true uses <c>Equals</c>; otherwise uses <c>SameSizeAndItemType</c>.
        /// </param>
        /// <returns>True when the types match under the selected comparison.</returns>
        public static bool EqualTypes(DataViewType type1, DataViewType type2, bool exactTypes)
        {
            Assert.NotNull(type1);
            Assert.NotNull(type2);

            if (exactTypes)
            {
                return type1.Equals(type2);
            }

            return type1.SameSizeAndItemType(type2);
        }
 
        /// <summary>
        /// Equivalent to calling Equals(DataViewType) for non-vector types. For vector types,
        /// returns true when both vector types have the same size and the same item type.
        /// </summary>
        /// <param name="columnType">The type the comparison is made from.</param>
        /// <param name="other">The type to compare against; null never matches.</param>
        private static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
        {
            if (other == null)
                return false;

            if (columnType.Equals(other))
                return true;

            // For vector types, we don't care about the factoring of the dimensions:
            // only the item type and the total size have to agree.
            if (columnType is VectorDataViewType thisVector && other is VectorDataViewType thatVector)
            {
                return thisVector.ItemType.Equals(thatVector.ItemType)
                    && thisVector.Size == thatVector.Size;
            }

            return false;
        }
 
        /// <summary>
        /// Looks up the index of the column called <paramref name="name"/>.
        /// </summary>
        /// <param name="schema">Schema to search.</param>
        /// <param name="name">Column name.</param>
        /// <param name="col">Receives the column index, or -1 when <c>GetColumnOrNull</c> yields no column.</param>
        /// <returns>True when <paramref name="col"/> is non-negative.</returns>
        private static bool TryGetColumnIndex(this DataViewSchema schema, string name, out int col)
        {
            int? index = schema.GetColumnOrNull(name)?.Index;
            if (index is int found)
            {
                col = found;
            }
            else
            {
                col = -1;
            }

            return col >= 0;
        }
 
        /// <summary>
        /// Checks that the annotation <paramref name="kind"/> on column <paramref name="col"/> is
        /// equivalent in both schemas: both absent, or both present with matching type, size and text values.
        /// </summary>
        /// <param name="kind">Annotation name to look up.</param>
        /// <param name="size">Expected vector size of the annotation.</param>
        /// <param name="sch1">First schema.</param>
        /// <param name="sch2">Second schema.</param>
        /// <param name="col">Column index inspected in both schemas.</param>
        /// <param name="exactTypes">Forwarded to <c>EqualTypes</c> for the annotation types.</param>
        /// <param name="mustBeText">
        /// When true a non-text annotation fails an assertion; when false a non-text annotation is
        /// accepted without comparing its values.
        /// </param>
        /// <returns>True whenever the method returns; mismatches fail an xUnit assertion.</returns>
        private static bool CheckMetadataNames(string kind, ulong size, DataViewSchema sch1, DataViewSchema sch2, int col, bool exactTypes, bool mustBeText)
        {
            var values1 = default(VBuffer<ReadOnlyMemory<char>>);
            var values2 = default(VBuffer<ReadOnlyMemory<char>>);

            var annotationType1 = sch1[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
            var annotationType2 = sch2[col].Annotations.Schema.GetColumnOrNull(kind)?.Type;
            bool missing1 = annotationType1 == null;
            bool missing2 = annotationType2 == null;
            Assert.False(missing1 != missing2, $"Different null-ness of {kind} metadata types");

            if (missing1)
            {
                // Neither schema declares the annotation, so fetching it must fail on both.
                Assert.True(CheckMetadataCallFailure(kind, sch1, col, ref values1));
                Assert.True(CheckMetadataCallFailure(kind, sch2, col, ref values2));

                return true;
            }

            Assert.False(size > int.MaxValue, $"{nameof(KeyDataViewType)}.{nameof(KeyDataViewType.Count)} is larger than int.MaxValue");
            Assert.True(EqualTypes(annotationType1, annotationType2, exactTypes), $"Different {kind} metadata types: {annotationType1} vs {annotationType2}");

            bool isText = annotationType1.GetItemType() is TextDataViewType;
            if (!isText)
            {
                if (mustBeText)
                {
                    Assert.False(mustBeText, $"Unexpected {kind} metadata type");
                }
                else
                {
                    return true;
                }
            }

            Assert.True((int)size == annotationType1.GetVectorSize(), $"{kind} metadata type wrong size: {annotationType1.GetVectorSize()} vs {size}");

            sch1[col].Annotations.GetValue(kind, ref values1);
            sch2[col].Annotations.GetValue(kind, ref values2);
            Assert.True(CompareVec(in values1, in values2, (int)size, (a, b) => a.Span.SequenceEqual(b.Span)), $"Different {kind} metadata values");

            return true;
        }
 
        /// <summary>
        /// Confirms that fetching the annotation <paramref name="kind"/> from column
        /// <paramref name="col"/> fails in the expected way.
        /// </summary>
        /// <param name="kind">Annotation name to fetch.</param>
        /// <param name="sch">Schema to query.</param>
        /// <param name="col">Column index.</param>
        /// <param name="names">Buffer handed to <c>GetValue</c>.</param>
        /// <returns>
        /// True only when <c>GetValue</c> throws an <see cref="InvalidOperationException"/> whose message
        /// is exactly "Invalid call to 'GetValue'"; false when the call succeeds or the message differs.
        /// Other exception types are not caught.
        /// </returns>
        private static bool CheckMetadataCallFailure(string kind, DataViewSchema sch, int col, ref VBuffer<ReadOnlyMemory<char>> names)
        {
            try
            {
                sch[col].Annotations.GetValue(kind, ref names);
            }
            catch (InvalidOperationException ex)
            {
                return ex.Message == "Invalid call to 'GetValue'";
            }

            // The call succeeded, which is not the failure the caller expects.
            return false;
        }
 
        /// <summary>
        /// Returns the item type of a vector type, or <paramref name="columnType"/> itself for any other type.
        /// </summary>
        private static DataViewType GetItemType(this DataViewType columnType)
        {
            DataViewType itemType = null;
            if (columnType is VectorDataViewType vector)
            {
                itemType = vector.ItemType;
            }

            return itemType ?? columnType;
        }
 
        /// <summary>
        /// Returns the size of a vector type, or 0 for any other type.
        /// </summary>
        private static int GetVectorSize(this DataViewType columnType)
        {
            if (columnType is VectorDataViewType vector)
            {
                return vector.Size;
            }

            return 0;
        }
 
        /// <summary>
        /// Creates the <paramref name="subDir"/> directory below <paramref name="outDir"/>.
        /// </summary>
        /// <param name="subDir">Sub-directory to create (note: this is the first argument).</param>
        /// <param name="outDir">Parent directory.</param>
        private static void EnsureOutputDir(string subDir, string outDir)
        {
            string target = Path.Combine(outDir, subDir);
            Directory.CreateDirectory(target);
        }
    }
}