File: UnitTests\ScoreSchemaTest.cs
Web Access
Project: src\test\Microsoft.ML.Core.Tests\Microsoft.ML.Core.Tests.csproj (Microsoft.ML.Core.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using Microsoft.ML.Data;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.RunTests
{
    public class ScoreSchemaTest : TestDataViewBase
    {
        public ScoreSchemaTest(ITestOutputHelper helper)
            : base(helper)
        {
        }

        /// <summary>
        /// Builds a text buffer containing {"Key-0", ..., "Key-{keyCount-1}"}.
        /// A <paramref name="keyCount"/> of zero yields an empty buffer.
        /// </summary>
        /// <param name="keyCount">Number of key names to generate.</param>
        /// <returns>The committed buffer of generated key names.</returns>
        private static VBuffer<ReadOnlyMemory<char>> GenerateKeyNames(int keyCount)
        {
            VBuffer<ReadOnlyMemory<char>> buffer = default;

            // Fill the buffer in place through an editor sized to keyCount.
            var bufferEditor = VBufferEditor.Create(ref buffer, keyCount);
            for (int i = 0; i < keyCount; ++i)
            {
                // Plain interpolation: wrapping this in string.Format would re-parse the
                // result as a format string, which is redundant and brace-unsafe.
                bufferEditor.Values[i] = $"Key-{i}".AsMemory();
            }

            return bufferEditor.Commit();
        }

        /// <summary>
        /// Creates a sequence-classification prediction schema and checks its single column:
        /// it is named PredictedLabel and is a key type matching the expected count and raw type.
        /// </summary>
        /// <param name="expectedScoreColumnType">Key type passed to the factory and expected back.</param>
        /// <param name="keyNames">Key names passed to the factory (may be empty).</param>
        /// <returns>The predicted-label column, for further annotation checks.</returns>
        private static DataViewSchema.Column CreateAndCheckPredictedLabelColumn(
            KeyDataViewType expectedScoreColumnType, VBuffer<ReadOnlyMemory<char>> keyNames)
        {
            var sequenceSchema = ScoreSchemaFactory.CreateSequencePredictionSchema(expectedScoreColumnType,
                AnnotationUtils.Const.ScoreColumnKind.SequenceClassification, keyNames);

            // Output schema should only contain one column, which is the predicted label.
            var scoreColumn = Assert.Single(sequenceSchema);
            Assert.Equal(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, scoreColumn.Name);

            // The column must be a key type with the same cardinality and storage type.
            var actualScoreColumnType = Assert.IsAssignableFrom<KeyDataViewType>(scoreColumn.Type);
            Assert.Equal(expectedScoreColumnType.Count, actualScoreColumnType.Count);
            Assert.Equal(expectedScoreColumnType.RawType, actualScoreColumnType.RawType);

            return scoreColumn;
        }

        /// <summary>
        /// Checks that annotation column <paramref name="index"/> is a text annotation named
        /// <paramref name="expectedName"/> whose value is <paramref name="expectedValue"/>.
        /// </summary>
        private static void AssertTextAnnotation(DataViewSchema.Annotations annotations, int index,
            string expectedName, string expectedValue)
        {
            var annotationColumn = annotations.Schema[index];
            Assert.Equal(expectedName, annotationColumn.Name);
            Assert.Equal(TextDataViewType.Instance, annotationColumn.Type);

            var getter = annotations.GetGetter<ReadOnlyMemory<char>>(annotationColumn);
            ReadOnlyMemory<char> value = default;
            getter(ref value);
            Assert.Equal(expectedValue, value.ToString());
        }

        [Fact]
        public void SequencePredictorSchemaTest()
        {
            int keyCount = 10;
            var expectedScoreColumnType = new KeyDataViewType(typeof(uint), keyCount);
            VBuffer<ReadOnlyMemory<char>> keyNames = GenerateKeyNames(keyCount);

            var scoreColumn = CreateAndCheckPredictedLabelColumn(expectedScoreColumnType, keyNames);

            // Because keyNames is not empty, there should be three metadata fields:
            // KeyValues, ScoreColumnKind and ScoreValueKind, in that order.
            var scoreMetadata = scoreColumn.Annotations;
            Assert.Equal(3, scoreMetadata.Schema.Count);

            // KeyValues: a text vector whose length and contents match the supplied key names.
            var keyValuesColumn = scoreMetadata.Schema[0];
            Assert.Equal(AnnotationUtils.Kinds.KeyValues, keyValuesColumn.Name);
            var keyValuesType = Assert.IsAssignableFrom<VectorDataViewType>(keyValuesColumn.Type);
            Assert.Equal(keyNames.Length, keyValuesType.Size);
            Assert.Equal(TextDataViewType.Instance, keyValuesType.ItemType);

            var keyNamesGetter = scoreMetadata.GetGetter<VBuffer<ReadOnlyMemory<char>>>(keyValuesColumn);
            VBuffer<ReadOnlyMemory<char>> actualKeyNames = default;
            keyNamesGetter(ref actualKeyNames);
            Assert.Equal(keyNames.Length, actualKeyNames.Length);
            Assert.Equal(keyNames.DenseValues(), actualKeyNames.DenseValues());

            AssertTextAnnotation(scoreMetadata, 1, AnnotationUtils.Kinds.ScoreColumnKind,
                AnnotationUtils.Const.ScoreColumnKind.SequenceClassification);
            AssertTextAnnotation(scoreMetadata, 2, AnnotationUtils.Kinds.ScoreValueKind,
                AnnotationUtils.Const.ScoreValueKind.PredictedLabel);
        }

        [Fact]
        public void SequencePredictorSchemaWithoutKeyNamesMetadataTest()
        {
            int keyCount = 10;
            var expectedScoreColumnType = new KeyDataViewType(typeof(uint), keyCount);
            VBuffer<ReadOnlyMemory<char>> keyNames = GenerateKeyNames(0);

            var scoreColumn = CreateAndCheckPredictedLabelColumn(expectedScoreColumnType, keyNames);

            // Because keyNames is empty, there should be no KeyValues annotation; only the
            // ScoreColumnKind and ScoreValueKind fields remain, in that order.
            var scoreMetadata = scoreColumn.Annotations;
            Assert.Equal(2, scoreMetadata.Schema.Count);

            AssertTextAnnotation(scoreMetadata, 0, AnnotationUtils.Kinds.ScoreColumnKind,
                AnnotationUtils.Const.ScoreColumnKind.SequenceClassification);
            AssertTextAnnotation(scoreMetadata, 1, AnnotationUtils.Kinds.ScoreValueKind,
                AnnotationUtils.Const.ScoreValueKind.PredictedLabel);
        }
    }
}