File: NerTests.cs
Web Access
Project: src\test\Microsoft.ML.TorchSharp.Tests\Microsoft.ML.TorchSharp.Tests.csproj (Microsoft.ML.TorchSharp.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.TorchSharp;
using Microsoft.ML.TorchSharp.NasBert;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.TorchSharp.Tests
{
    public class NerTests : TestDataPipeBase
    {
        public NerTests(ITestOutputHelper output) : base(output)
        {
        }

        private class TestSingleSentenceData
        {
            public string Sentence;
            public string[] Label;
        }

        private class Label
        {
            public string Key { get; set; }
        }

        /// <summary>
        /// Builds the key vocabulary passed to <c>MapValueToKey</c> via <c>keyData</c>.
        /// The "0" tag used in the sample sentences is deliberately not listed here.
        /// NOTE(review): such tokens presumably map to the missing key value — confirm against MapValueToKey docs.
        /// </summary>
        private IDataView LoadLabelKeys()
        {
            return ML.Data.LoadFromEnumerable(new[]
            {
                new Label { Key = "PERSON" },
                new Label { Key = "CITY" },
                new Label { Key = "COUNTRY" },
                new Label { Key = "B_WORK_OF_ART" },
                new Label { Key = "WORK_OF_ART" },
                new Label { Key = "B_NORP" },
            });
        }

        /// <summary>
        /// Builds the small in-memory training set shared by the CPU tests.
        /// Each row pairs a space-separated sentence with one tag per word.
        /// </summary>
        private IDataView LoadSampleSentences()
        {
            return ML.Data.LoadFromEnumerable(new List<TestSingleSentenceData>
            {
                new()
                {
                    Sentence = "Alice and Bob live in the liechtenstein",
                    Label = new string[]{"PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY" }
                },
                new()
                {
                    Sentence = "Alice and Bob live in the USA",
                    Label = new string[]{"PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY"}
                },
                new()
                {
                    Sentence = "WW II Landmarks on the Great Earth of China : Eternal Memories of Taihang Mountain",
                    Label = new string[]{"B_WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART" }
                },
                new()
                {
                    Sentence = "This campaign broke through the Japanese army 's blockade to reach base areas behind enemy lines , stirring up anti-Japanese spirit throughout the nation and influencing the situation of the anti-fascist war of the people worldwide .",
                    Label = new string[]{"0", "0", "0", "0", "0", "B_NORP", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "B_NORP", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" }
                }
            });
        }

        /// <summary>
        /// Walks every row of <paramref name="output"/> and asserts that the key-typed label vector
        /// (schema index 2) and the key-typed prediction vector (schema index 3) have the same length.
        /// </summary>
        /// <remarks>
        /// Columns are read by index because <c>MapKeyToValue("outputColumn")</c> reuses the name
        /// "outputColumn". A lookup by name would return the text column at index 4 instead of the
        /// key column at index 3.
        /// </remarks>
        private static void AssertPredictionLengthsMatchLabels(IDataView output)
        {
            using var cursor = output.GetRowCursorForAllColumns();

            var labelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[2]);
            var predictedLabelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[3]);

            VBuffer<uint> labelData = default;
            VBuffer<uint> predictedLabelData = default;

            while (cursor.MoveNext())
            {
                labelGetter(ref labelData);
                predictedLabelGetter(ref predictedLabelData);

                // The model must emit exactly one prediction per labeled word.
                Assert.Equal(labelData.Length, predictedLabelData.Length);
            }
        }

        [Fact]
        public void TestSimpleNer()
        {
            var labels = LoadLabelKeys();
            var dataView = LoadSampleSentences();

            var chain = new EstimatorChain<ITransformer>();
            var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
               .Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "outputColumn"))
               .Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));

            var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
            Assert.Equal(3, estimatorSchema.Count);
            Assert.Equal("outputColumn", estimatorSchema[2].Name);
            Assert.Equal(TextDataViewType.Instance, estimatorSchema[2].ItemType);

            var transformer = estimator.Fit(dataView);
            try
            {
                var transformerSchema = transformer.GetOutputSchema(dataView.Schema);
                Assert.Equal(5, transformerSchema.Count);
                Assert.Equal("outputColumn", transformerSchema[4].Name);

                AssertPredictionLengthsMatchLabels(transformer.Transform(dataView));

                TestEstimatorCore(estimator, dataView, shouldDispose: true);
            }
            finally
            {
                // Release the trained model even when an assertion above fails.
                transformer.Dispose();
            }
        }

        [Fact]
        public void TestSimpleNerOptions()
        {
            var labels = LoadLabelKeys();

            var options = new NerTrainer.NerOptions
            {
                PredictionColumnName = "outputColumn"
            };

            var dataView = LoadSampleSentences();

            var chain = new EstimatorChain<ITransformer>();
            var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
               .Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(options))
               .Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));

            var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
            Assert.Equal(3, estimatorSchema.Count);
            Assert.Equal("outputColumn", estimatorSchema[2].Name);
            Assert.Equal(TextDataViewType.Instance, estimatorSchema[2].ItemType);

            var transformer = estimator.Fit(dataView);
            try
            {
                var transformerSchema = transformer.GetOutputSchema(dataView.Schema);
                Assert.Equal(5, transformerSchema.Count);
                Assert.Equal("outputColumn", transformerSchema[4].Name);

                AssertPredictionLengthsMatchLabels(transformer.Transform(dataView));

                TestEstimatorCore(estimator, dataView, shouldDispose: true);
            }
            finally
            {
                // Release the trained model even when an assertion above fails.
                transformer.Dispose();
            }
        }

        [Fact(Skip = "Needs to be on a comp with GPU or will take a LONG time.")]
        public void TestNERLargeFileGpu()
        {
            ML.FallbackToCpu = false;
            ML.GpuDeviceId = 0;

            var labelFilePath = GetDataPath("ner-key-info.txt");
            var labels = ML.Data.LoadFromTextFile(labelFilePath, new TextLoader.Column[]
                {
                    new TextLoader.Column("Key", DataKind.String, 0)
                }
            );

            var dataFilePath = GetDataPath("ner-conll2012_english_v4_train.txt");
            var dataView = TextLoader.Create(ML, new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("Sentence", DataKind.String, 0),
                    // Every column after the sentence is one per-word tag; the count varies per row.
                    new TextLoader.Column("Label", DataKind.String, new TextLoader.Range[]
                    {
                        new TextLoader.Range(1, null) { VariableEnd = true, AutoEnd = false }
                    })
                },
                HasHeader = false,
                Separators = new char[] { '\t' },
                MaxRows = 75187 // Loads the full dataset (75187 rows); lower this for quicker training.
            }, new MultiFileSource(dataFilePath));

            var trainTest = ML.Data.TrainTestSplit(dataView);

            var options = new NerTrainer.NerOptions
            {
                PredictionColumnName = "outputColumn"
            };

            var chain = new EstimatorChain<ITransformer>();
            var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
               .Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(options))
               .Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));

            var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
            Assert.Equal(3, estimatorSchema.Count);
            Assert.Equal("outputColumn", estimatorSchema[2].Name);
            Assert.Equal(TextDataViewType.Instance, estimatorSchema[2].ItemType);

            var transformer = estimator.Fit(trainTest.TrainSet);
            try
            {
                var transformerSchema = transformer.GetOutputSchema(dataView.Schema);
                Assert.Equal(5, transformerSchema.Count);
                Assert.Equal("outputColumn", transformerSchema[4].Name);

                var output = transformer.Transform(trainTest.TestSet);

                // The cursor is scoped inside the try, so it is disposed before the transformer it reads from.
                using var cursor = output.GetRowCursorForAllColumns();

                var labelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[2]);
                var predictedLabelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[3]);

                VBuffer<uint> labelData = default;
                VBuffer<uint> predictedLabelData = default;

                long correct = 0;
                long total = 0;

                while (cursor.MoveNext())
                {
                    labelGetter(ref labelData);
                    predictedLabelGetter(ref predictedLabelData);

                    Assert.Equal(labelData.Length, predictedLabelData.Length);

                    // Token-level accuracy. Key 0 is compared like any other key value.
                    for (var i = 0; i < labelData.Length; i++)
                    {
                        if (labelData.GetItemOrDefault(i) == predictedLabelData.GetItemOrDefault(i))
                        {
                            correct++;
                        }
                        total++;
                    }
                }

                var accuracy = (double)correct / total;
                Assert.True(accuracy > .80, $"Token accuracy {accuracy} ({correct}/{total}) is not above 0.80.");
            }
            finally
            {
                transformer.Dispose();
            }
        }
    }
}