// File: TestSparseDataView.cs
// Project: src\test\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj (Microsoft.ML.TestFramework)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Data;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.RunTests
{
    /// <summary>
    /// Tests that loading in-memory objects whose single column is a sparse
    /// (<see cref="VBuffer{T}"/>) or dense (array) vector produces a data view that reads back
    /// correctly. Each check goes through both a row cursor and <c>CreateEnumerable</c>.
    /// </summary>
    public sealed class TestSparseDataView : TestDataViewBase
    {
        private const string Cat = "DataView";

        public TestSparseDataView(ITestOutputHelper obj) : base(obj)
        {
        }

        /// <summary>
        /// Row type with a single dense vector column <c>X</c>.
        /// </summary>
        private class DenseExample<T>
        {
            // The declared size must equal the number of elements each row actually carries.
            // GenericDenseDataView always supplies 3-element arrays.
            [VectorType(3)]
            public T[] X;
        }

        /// <summary>
        /// Row type with a single sparse vector column <c>X</c> of logical length 5.
        /// </summary>
        private class SparseExample<T>
        {
            [VectorType(5)]
            public VBuffer<T> X;
        }

        /// <summary>
        /// Runs the sparse round-trip check for each supported item type.
        /// </summary>
        [Fact]
        [TestCategory(Cat)]
        public void SparseDataView()
        {
            GenericSparseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f });
            GenericSparseDataView(new int[] { 1, 2, 3 }, new int[] { 1, 10, 100 });
            GenericSparseDataView(new bool[] { true, true, true }, new bool[] { false, false, false });
            GenericSparseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 });
            GenericSparseDataView(new ReadOnlyMemory<char>[] { "a".AsMemory(), "b".AsMemory(), "c".AsMemory() },
                                  new ReadOnlyMemory<char>[] { "aa".AsMemory(), "bb".AsMemory(), "cc".AsMemory() });
        }

        /// <summary>
        /// Builds two sparse rows of logical length 5. Each row stores the 3 values from
        /// <paramref name="v1"/> or <paramref name="v2"/> at fixed indices. The method then
        /// verifies that the loaded data view reads them back.
        /// </summary>
        /// <param name="v1">Explicit values for the first row (expected to have 3 items).</param>
        /// <param name="v2">Explicit values for the second row (expected to have 3 items).</param>
        private void GenericSparseDataView<T>(T[] v1, T[] v2)
        {
            var inputs = new[]
            {
                new SparseExample<T>() { X = new VBuffer<T>(5, 3, v1, new int[] { 0, 2, 4 }) },
                new SparseExample<T>() { X = new VBuffer<T>(5, 3, v2, new int[] { 0, 1, 3 }) }
            };

            var env = new MLContext(1);
            var data = env.Data.LoadFromEnumerable(inputs);

            // Sparse rows keep their logical length (5) but expose only the 3 explicit values.
            VerifyDataView<SparseExample<T>, T>(env, data, inputs.Length, 5, 3);
        }

        /// <summary>
        /// Runs the dense round-trip check for each supported item type.
        /// </summary>
        [Fact]
        [TestCategory(Cat)]
        public void DenseDataView()
        {
            GenericDenseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f });
            GenericDenseDataView(new int[] { 1, 2, 3 }, new int[] { 1, 10, 100 });
            GenericDenseDataView(new bool[] { true, true, true }, new bool[] { false, false, false });
            GenericDenseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 });
            GenericDenseDataView(new ReadOnlyMemory<char>[] { "a".AsMemory(), "b".AsMemory(), "c".AsMemory() },
                                 new ReadOnlyMemory<char>[] { "aa".AsMemory(), "bb".AsMemory(), "cc".AsMemory() });
        }

        /// <summary>
        /// Builds two dense rows from <paramref name="v1"/> and <paramref name="v2"/> and
        /// verifies that the loaded data view reads them back.
        /// </summary>
        /// <param name="v1">Values for the first row (expected to have 3 items).</param>
        /// <param name="v2">Values for the second row (expected to have 3 items).</param>
        private void GenericDenseDataView<T>(T[] v1, T[] v2)
        {
            var inputs = new[]
            {
                new DenseExample<T>() { X = v1 },
                new DenseExample<T>() { X = v2 }
            };

            var env = new MLContext(1);
            var data = env.Data.LoadFromEnumerable(inputs);

            // In a dense row, the logical length and the explicit value count are both 3.
            VerifyDataView<DenseExample<T>, T>(env, data, inputs.Length, 3, 3);
        }

        /// <summary>
        /// Checks that <paramref name="data"/> yields <paramref name="expectedRows"/> rows through
        /// both a row cursor and <c>CreateEnumerable</c>. It also checks that column 0 of every row
        /// reads back as a vector of logical length <paramref name="expectedLength"/> with
        /// <paramref name="expectedValueCount"/> explicitly stored values.
        /// </summary>
        /// <typeparam name="TRow">Row type used to enumerate <paramref name="data"/> back into objects.</typeparam>
        /// <typeparam name="T">Item type of the vector in column 0.</typeparam>
        private static void VerifyDataView<TRow, T>(MLContext env, IDataView data, int expectedRows, int expectedLength, int expectedValueCount)
            where TRow : class, new()
        {
            var value = default(VBuffer<T>);
            int cursorRows = 0;
            using (var cursor = data.GetRowCursorForAllColumns())
            {
                var getter = cursor.GetGetter<VBuffer<T>>(cursor.Schema[0]);
                while (cursor.MoveNext())
                {
                    getter(ref value);
                    Assert.Equal(expectedLength, value.Length);
                    Assert.Equal(expectedValueCount, value.GetValues().Length);
                    cursorRows++;
                }
            }
            Assert.Equal(expectedRows, cursorRows);

            // foreach disposes the enumerator when iteration ends. A manually obtained
            // enumerator that is never disposed would leak whatever it holds.
            int enumeratedRows = 0;
            foreach (var row in env.Data.CreateEnumerable<TRow>(data, false)) // false: do not reuse row objects
            {
                enumeratedRows++;
            }
            Assert.Equal(expectedRows, enumeratedRows);
        }
    }
}