// File: NasBert\NasBertTrainer.cs
// Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.TorchSharp.NasBert.Models;
using Microsoft.ML.Transforms;
using TorchSharp;
 
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using Microsoft.ML.TorchSharp.Utils;
using Microsoft.ML.TorchSharp.NasBert.Optimizers;
using Microsoft.ML.TorchSharp.Extensions;
using System.IO;
using System.CodeDom;
using System.Runtime.CompilerServices;
using TorchSharp.Modules;
using System.Diagnostics;
 
namespace Microsoft.ML.TorchSharp.NasBert
{
 
    public class NasBertTrainer
    {
        /// <summary>
        /// Options used to configure NAS-BERT based trainers.
        /// </summary>
        public class NasBertOptions : TorchSharpBaseTrainer.Options
        {
            /// <summary>
            /// Name of the column containing the first sentence.
            /// </summary>
            public string Sentence1ColumnName = "Sentence";

            /// <summary>
            /// Name of the column containing the second sentence, if any.
            /// </summary>
            public string Sentence2ColumnName = default;

            /// <summary>
            /// When true, the encoder parameters are not updated during training.
            /// </summary>
            public bool FreezeEncoder = false;

            /// <summary>
            /// When true, the transfer-module parameters are not updated during training.
            /// </summary>
            public bool FreezeTransfer = false;

            /// <summary>
            /// When true, layer-norm parameters are trained.
            /// </summary>
            public bool LayerNormTraining = false;

            /// <summary>
            /// When true, layer normalization is applied before each encoder block.
            /// </summary>
            public bool EncoderNormalizeBefore = true;

            /// <summary>
            /// General dropout rate; must lie in [0, 1).
            /// </summary>
            public double Dropout = 0.1;

            /// <summary>
            /// Dropout rate applied to attention weights; must lie in [0, 1).
            /// </summary>
            public double AttentionDropout = 0.1;

            /// <summary>
            /// Dropout rate applied after activation functions in FFN layers; must lie in [0, 1).
            /// </summary>
            public double ActivationDropout = 0.0;

            /// <summary>
            /// When true, dropout rates are chosen dynamically.
            /// </summary>
            public bool DynamicDropout = false;

            /// <summary>
            /// Dropout rate in the masked-language-model pooler layers; must lie in [0, 1).
            /// </summary>
            public double PoolerDropout = 0.0;

            /// <summary>
            /// Beta coefficients for the Adam optimizer.
            /// </summary>
            public IReadOnlyList<double> AdamBetas = new List<double> { 0.9, 0.999 };

            /// <summary>
            /// Epsilon term for the Adam optimizer.
            /// </summary>
            public double AdamEps = 1e-8;

            /// <summary>
            /// Gradient-norm clipping threshold in [0, +Inf); 0 disables norm clipping.
            /// </summary>
            public double ClipNorm = 5.0;

            /// <summary>
            /// Fraction of steps used for warmup by the polynomial decay scheduler.
            /// </summary>
            public double WarmupRatio = 0.06;

            /// <summary>
            /// Per-epoch learning rates: entry N applies to epoch N, and epochs beyond the
            /// list reuse the last entry. Interpretation may depend on the scheduler in use.
            /// </summary>
            public List<double> LearningRate = new List<double> { 1e-4 };

            /// <summary>
            /// The task to train for, which determines the model head.
            /// </summary>
            public BertTaskType TaskType = BertTaskType.None;

            /// <summary>
            /// Index numbers describing the model architecture. Fixed by the TorchSharp model.
            /// </summary>
            internal IReadOnlyList<int> Arches = new[] { 9, 11, 7, 0, 0, 0, 11, 11, 7, 0, 0, 0, 9, 7, 11, 0, 0, 0, 10, 7, 9, 0, 0, 0 };

            /// <summary>
            /// Maximum sample length in tokens. Set by the TorchSharp model.
            /// </summary>
            internal int MaxSequenceLength = 512;

            /// <summary>
            /// Embedding dimensionality; must be positive. Set by the TorchSharp model.
            /// </summary>
            internal int EmbeddingDim = 64;

            /// <summary>
            /// Number of encoder layers. Set by the TorchSharp model.
            /// </summary>
            internal int EncoderLayers = 24;

            /// <summary>
            /// Encoder output dimensionality (3 * EmbeddingDim); must be positive. Set by the TorchSharp model.
            /// </summary>
            internal int EncoderOutputDim = 192;

            /// <summary>
            /// Activation function used in general situations. Set by the TorchSharp model.
            /// </summary>
            internal string ActivationFunction = "gelu";

            /// <summary>
            /// Activation function used in pooler layers. Set by the TorchSharp model.
            /// </summary>
            internal string PoolerActivationFunction = "tanh";

            /// <summary>
            /// Reduction applied by the criterion loss function. Set by the TorchSharp model.
            /// </summary>
            internal torch.nn.Reduction Reduction = Reduction.Sum;

            // Which BERT model variant this options instance targets.
            internal BertModelType ModelType = BertModelType.NasBert;
        }
    }
 
    /// <summary>
    /// Base trainer for NAS-BERT based models. Provides tokenization of the input sentence
    /// column(s), row/batch tensor preparation, the optimizer plus one-cycle learning-rate
    /// schedule, per-task loss computation, and output-schema construction for the
    /// classification, named-entity-recognition and remaining (regression) task heads.
    /// </summary>
    public abstract class NasBertTrainer<TLabelCol, TTargetsCol> : TorchSharpBaseTrainer<TLabelCol, TTargetsCol>
    {

        // Typed alias of the options instance held by the base class.
        internal readonly NasBertTrainer.NasBertOptions BertOptions;

        internal NasBertTrainer(IHostEnvironment env, Options options) : base(env, options)
        {
            // NOTE(review): the 'as' cast yields null if a non-NasBertOptions instance is
            // passed, in which case the AssertValue below throws NullReferenceException
            // rather than a clean assertion - confirm all callers supply NasBertOptions.
            BertOptions = options as NasBertTrainer.NasBertOptions;
            Contracts.AssertValue(BertOptions.Sentence1ColumnName);
            Contracts.Assert(BertOptions.TaskType != BertTaskType.None, "BertTaskType must be specified");
        }

        /// <summary>
        /// Shared training-loop implementation for the NAS-BERT task trainers.
        /// </summary>
        private protected abstract class NasBertTrainerBase : TrainerBase
        {
            public Tokenizer Tokenizer;
            // Hides the base-class optimizer with the project wrapper created in the constructor.
            public new BaseOptimizer Optimizer;
            public new NasBertTrainer<TLabelCol, TTargetsCol> Parent => base.Parent as NasBertTrainer<TLabelCol, TTargetsCol>;
            public new NasBertModel Model;
            private protected ValueGetter<ReadOnlyMemory<char>> Sentence1Getter;
            // Null when no second sentence column is configured (see InitializeDataGetters).
            private protected ValueGetter<ReadOnlyMemory<char>> Sentence2Getter;

            public NasBertTrainerBase(TorchSharpBaseTrainer<TLabelCol, TTargetsCol> parent, IChannel ch, IDataView input, string modelUrl) : base(parent, ch, input, modelUrl)
            {
                // Get the parameters that need optimization and set up the optimizer
                var parameters = Model.parameters().Where(p => p.requires_grad);
                Optimizer = BaseOptimizer.GetOptimizer(Parent.BertOptions, parameters);
                base.Optimizer = Optimizer.Optimizer;
                // One-cycle schedule: ramp up for WarmupRatio of the run, then anneal linearly.
                // total_steps over-counts by one step per epoch when the row count divides the
                // batch size exactly; that only pads the tail of the schedule.
                LearningRateScheduler = torch.optim.lr_scheduler.OneCycleLR(
                   Optimizer.Optimizer,
                   max_lr: Parent.BertOptions.LearningRate[0],
                   total_steps: ((TrainingRowCount / Parent.Option.BatchSize) + 1) * Parent.Option.MaxEpoch,
                   pct_start: Parent.BertOptions.WarmupRatio,
                   anneal_strategy: torch.optim.lr_scheduler.impl.OneCycleLR.AnnealStrategy.Linear,
                   div_factor: 1.0 / Parent.Option.StartLearningRateRatio,
                   final_div_factor: Parent.Option.StartLearningRateRatio / Parent.Option.FinalLearningRateRatio);
            }

            /// <summary>
            /// Creates the NAS-BERT module for the configured task and loads the pretrained
            /// encoder weights from <c>ModelUrl</c>; the task head starts from its initial state.
            /// </summary>
            private protected override Module CreateModule(IChannel ch, IDataView input)
            {
                Tokenizer = TokenizerExtensions.GetInstance(ch);
                EnglishRobertaTokenizer tokenizerModel = Tokenizer.RobertaModel();

                // NER uses a token-level head; all other tasks use the prediction head.
                NasBertModel model;
                if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
                    model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
                else
                    model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
                model.GetEncoder().load(GetModelPath(ModelUrl));
                Model = model;
                return model;
            }

            /// <summary>
            /// Opens a cursor over only the columns training reads: sentence 1, the optional
            /// sentence 2, and the label.
            /// </summary>
            private protected override DataViewRowCursor GetRowCursor(IDataView input)
            {
                if (Parent.BertOptions.Sentence2ColumnName != default)
                    return input.GetRowCursor(input.Schema[Parent.BertOptions.Sentence1ColumnName], input.Schema[Parent.BertOptions.Sentence2ColumnName], input.Schema[Parent.Option.LabelColumnName]);
                else
                    return input.GetRowCursor(input.Schema[Parent.BertOptions.Sentence1ColumnName], input.Schema[Parent.Option.LabelColumnName]);
            }

            /// <summary>
            /// Binds the sentence getters to the cursor; Sentence2Getter stays null when no
            /// second sentence column is configured.
            /// </summary>
            private protected override void InitializeDataGetters(IDataView input, DataViewRowCursor cursor)
            {
                Sentence1Getter = cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.BertOptions.Sentence1ColumnName]);
                Sentence2Getter = Parent.BertOptions.Sentence2ColumnName != default ? cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.BertOptions.Sentence2ColumnName]) : default;
            }

            /// <summary>
            /// Runs a forward pass on a validation batch and counts correct predictions via the
            /// task-specific GetPredictions/GetTargets/GetNumCorrect hooks.
            /// NOTE(review): the logits/predictions/targets tensors are not explicitly disposed
            /// here - presumably an enclosing dispose scope reclaims them; verify to rule out a leak.
            /// </summary>
            private protected override void RunModelAndUpdateValidationStats(ref Tensor inputTensor, ref Tensor targetsTensor, ref int numCorrect)
            {
                var logits = Model.forward(inputTensor);
                var predictions = GetPredictions(logits);
                var targetss = GetTargets(targetsTensor);
                numCorrect = GetNumCorrect(predictions, targetss);
            }

            /// <summary>
            /// Pads the per-row token tensors into one rectangular batch tensor using the
            /// tokenizer's pad index.
            /// NOTE(review): the 'device' parameter is ignored in favor of the member 'Device' -
            /// confirm callers always pass the same device.
            /// </summary>
            private protected override torch.Tensor PrepareBatchTensor(ref List<Tensor> inputTensors, Device device)
            {
                return DataUtils.CollateTokens(inputTensors, Tokenizer.RobertaModel().PadIndex, device: Device);
            }

            /// <summary>
            /// Tokenizes the current row into a single tensor:
            /// [init(0), sentence1 tokens] or [init(0), sentence1 tokens, separator(2), sentence2 tokens],
            /// truncated to 512 tokens (the model's MaxSequenceLength).
            /// </summary>
            private protected override torch.Tensor PrepareRowTensor(ref TLabelCol target)
            {
                ReadOnlyMemory<char> sentence1 = default;
                Sentence1Getter(ref sentence1);
                Tensor t;
                if (Sentence2Getter == default)
                {
                    t = torch.tensor((new[] { 0 /* InitToken */ }).Concat(Tokenizer.EncodeToConverted(sentence1.ToString())).ToList(), device: Device);
                }
                else
                {

                    ReadOnlyMemory<char> sentence2 = default;
                    Sentence2Getter(ref sentence2);

                    t = torch.tensor((new[] { 0 /* InitToken */ }).Concat(Tokenizer.EncodeToConverted(sentence1.ToString()))
                        .Concat(new[] { 2 /* SeparatorToken */ }).Concat(Tokenizer.EncodeToConverted(sentence2.ToString())).ToList(), device: Device);
                }

                // Clamp to the model's maximum sequence length (512).
                if (t.NumberOfElements > 512)
                    t = t.slice(0, 0, 512, 1);

                return t;
            }

            /// <summary>
            /// Forward and backward pass for one training batch. For NER a boolean token mask
            /// (true at positions below each sequence's un-padded length) is passed to the model.
            /// Loss: cross-entropy for text classification; cross-entropy over flattened
            /// (batch * tokens, classes) logits for NER; MSE otherwise.
            /// </summary>
            private protected override void RunModelAndBackPropagate(ref List<Tensor> inputTensors, ref Tensor targetsTensor)
            {
                Tensor logits = default;
                if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
                {
                    // Record each sequence's token count (pre-padding), one per batch row.
                    int[,] lengthArray = new int[inputTensors.Count, 1];
                    for (int i = 0; i < inputTensors.Count; i++)
                    {
                        lengthArray[i, 0] = (int)inputTensors[i].shape[0];
                    }
                    Tensor lengths = torch.tensor(lengthArray, device: Device);

                    var inputTensor = PrepareBatchTensor(ref inputTensors, device: Device);
                    // True at real-token positions; the hard-coded 512 matches MaxSequenceLength.
                    var tokenMask = torch.arange(512).expand(lengths.numel(), 512).to(lengths.device) < lengths;

                    logits = Model.forward(inputTensor, tokenMask: tokenMask);

                }
                else
                {
                    var inputTensor = PrepareBatchTensor(ref inputTensors, device: Device);

                    logits = Model.forward(inputTensor);
                }

                torch.Tensor loss;
                if (Parent.BertOptions.TaskType == BertTaskType.TextClassification)
                    loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor);
                else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
                {
                    // Flatten to (batch * tokens, classes) vs. (batch * tokens) for token-level loss.
                    targetsTensor = targetsTensor.@long().view(-1);
                    logits = logits.view(-1, logits.size(-1));
                    loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor);
                }
                else
                {
                    loss = torch.nn.MSELoss(reduction: Parent.BertOptions.Reduction).forward(logits.squeeze(), targetsTensor);
                }

                loss.backward();
            }

            /// <summary>
            /// Steps the optimizer, then the learning-rate scheduler (once per batch).
            /// </summary>
            private protected override void OptimizeStep()
            {
                Optimizer.Step();
                LearningRateScheduler.step();
            }
        }

        /// <summary>
        /// Describes the columns the fitted transformer will add: prediction key + score vector
        /// for text classification, a variable-length prediction vector for NER, and a scalar
        /// score column for the remaining tasks.
        /// </summary>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            CheckInputSchema(inputSchema);

            var outColumns = inputSchema.ToDictionary(x => x.Name);

            if (BertOptions.TaskType == BertTaskType.TextClassification)
            {

                var metadata = new List<SchemaShape.Column>();
                metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                    TextDataViewType.Instance, false));

                // Get label column for score column annotations. Already verified it exists.
                inputSchema.TryFindColumn(Option.LabelColumnName, out var labelCol);

                outColumns[Option.PredictionColumnName] = new SchemaShape.Column(Option.PredictionColumnName, SchemaShape.Column.VectorKind.Scalar,
                        NumberDataViewType.UInt32, true, new SchemaShape(metadata.ToArray()));

                outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
                    NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
            }
            else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
            {
                var metadata = new List<SchemaShape.Column>();
                metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                    TextDataViewType.Instance, false));

                // Get label column for score column annotations. Already verified it exists.
                inputSchema.TryFindColumn(Option.LabelColumnName, out var labelCol);

                // One predicted label per token, so the output vector length varies per row.
                outColumns[Option.PredictionColumnName] = new SchemaShape.Column(Option.PredictionColumnName, SchemaShape.Column.VectorKind.VariableVector,
                        NumberDataViewType.UInt32, true, new SchemaShape(metadata.ToArray()));
            }
            else
            {
                outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Scalar,
                    NumberDataViewType.Single, false);
            }

            return new SchemaShape(outColumns.Values);
        }

        /// <summary>
        /// Validates the input schema: sentence columns must be text; the label must be UInt32
        /// for classification/NER and Single for the remaining tasks, which also require the
        /// second sentence column.
        /// </summary>
        private protected override void CheckInputSchema(SchemaShape inputSchema)
        {
            // Verify that all required input columns are present, and are of the same type.
            if (!inputSchema.TryFindColumn(BertOptions.Sentence1ColumnName, out var sentenceCol))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence", BertOptions.Sentence1ColumnName);
            if (sentenceCol.ItemType != TextDataViewType.Instance)
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence", BertOptions.Sentence1ColumnName,
                    TextDataViewType.Instance.ToString(), sentenceCol.GetTypeString());

            if (!inputSchema.TryFindColumn(Option.LabelColumnName, out var labelCol))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName);

            if (BertOptions.TaskType == BertTaskType.TextClassification)
            {
                if (labelCol.ItemType != NumberDataViewType.UInt32)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
                        NumberDataViewType.UInt32.ToString(), labelCol.GetTypeString());


                if (BertOptions.Sentence2ColumnName != default)
                {
                    if (!inputSchema.TryFindColumn(BertOptions.Sentence2ColumnName, out var sentenceCol2))
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName);
                    if (sentenceCol2.ItemType != TextDataViewType.Instance)
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName,
                            TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString());
                }
            }
            else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
            {
                if (labelCol.ItemType != NumberDataViewType.UInt32)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
                        NumberDataViewType.UInt32.ToString(), labelCol.GetTypeString());
            }
            else
            {
                if (labelCol.ItemType != NumberDataViewType.Single)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
                        NumberDataViewType.Single.ToString(), labelCol.GetTypeString());

                // NOTE(review): this branch unconditionally requires a second sentence column -
                // presumably sentence-pair regression; confirm single-sentence regression is unsupported.
                if (!inputSchema.TryFindColumn(BertOptions.Sentence2ColumnName, out var sentenceCol2))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName);
                if (sentenceCol2.ItemType != TextDataViewType.Instance)
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName,
                        TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString());
            }
        }
    }
 
    /// <summary>
    /// Base transformer produced by <see cref="NasBertTrainer{TLabelCol, TTargetsCol}"/>. Stores
    /// the sentence-column bindings, validates and produces schemas at transform time, and
    /// supplies the row mapper that tokenizes each row and runs the model without gradients.
    /// </summary>
    public abstract class NasBertTransformer<TLabelCol, TTargetsCol> : TorchSharpBaseTransformer<TLabelCol, TTargetsCol>
    {
        internal readonly NasBertTrainer.NasBertOptions BertOptions;


        public readonly SchemaShape.Column SentenceColumn;
        // Default-valued (Name == null) when no second sentence column was configured.
        public readonly SchemaShape.Column SentenceColumn2;

        internal NasBertTransformer(IHostEnvironment env, NasBertTrainer.NasBertOptions options, NasBertModel model, DataViewSchema.DetachedColumn labelColumn)
           : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NasBertTransformer<TLabelCol, TTargetsCol>)), options, model, labelColumn)
        {
            BertOptions = options;
            SentenceColumn = new SchemaShape.Column(options.Sentence1ColumnName, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false);
            SentenceColumn2 = options.Sentence2ColumnName == default ? default : new SchemaShape.Column(options.Sentence2ColumnName, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false);
        }

        /// <summary>
        /// Adds the task-dependent output columns: prediction key + score vector for text
        /// classification, otherwise a scalar score column.
        /// NOTE(review): unlike the trainer's GetOutputSchema there is no NER branch here -
        /// confirm NER transformers override this or do not rely on it.
        /// </summary>
        private protected override SchemaShape GetOutputSchemaCore(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            CheckInputSchema(inputSchema);

            var outColumns = inputSchema.ToDictionary(x => x.Name);
            if (BertOptions.TaskType == BertTaskType.TextClassification)
            {
                var labelAnnotationsColumn = new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.SlotNames].Type, false);
                var predLabelMetadata = new SchemaShape(new SchemaShape.Column[] { labelAnnotationsColumn }
                    .Concat(AnnotationUtils.GetTrainerOutputAnnotation()));

                outColumns[Options.PredictionColumnName] = new SchemaShape.Column(Options.PredictionColumnName, SchemaShape.Column.VectorKind.Scalar,
                    NumberDataViewType.UInt32, true, predLabelMetadata);

                outColumns[ScoreColumnName] = new SchemaShape.Column(ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
                       NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelAnnotationsColumn)));
            }
            else
            {
                outColumns[ScoreColumnName] = new SchemaShape.Column(ScoreColumnName, SchemaShape.Column.VectorKind.Scalar,
                       NumberDataViewType.Single, false);
            }

            return new SchemaShape(outColumns.Values);
        }

        /// <summary>
        /// Verifies the sentence column(s) exist in the input and are compatible with the
        /// column shapes captured at construction time.
        /// </summary>
        private protected override void CheckInputSchema(SchemaShape inputSchema)
        {
            // Verify that all required input columns are present, and are of the same type.
            if (!inputSchema.TryFindColumn(SentenceColumn.Name, out var sentenceCol))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence", SentenceColumn.Name);
            if (!SentenceColumn.IsCompatibleWith(sentenceCol))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence", SentenceColumn.Name,
                    SentenceColumn.GetTypeString(), sentenceCol.GetTypeString());

            // NOTE(review): if TaskType is SentenceRegression while Sentence2ColumnName is null,
            // SentenceColumn2 is default and SentenceColumn2.Name is null below - verify that
            // combination cannot occur.
            if (BertOptions.Sentence2ColumnName != default || BertOptions.TaskType == BertTaskType.SentenceRegression)
            {
                if (!inputSchema.TryFindColumn(SentenceColumn2.Name, out var sentenceCol2))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", SentenceColumn2.Name);
                if (!SentenceColumn2.IsCompatibleWith(sentenceCol2))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", SentenceColumn2.Name,
                        SentenceColumn2.GetTypeString(), sentenceCol2.GetTypeString());
            }
        }

        private protected abstract override void SaveModel(ModelSaveContext ctx);


        /// <summary>
        /// Saves the base model state followed by the sentence column names.
        /// </summary>
        private protected new void SaveBaseModel(ModelSaveContext ctx, VersionInfo versionInfo)
        {
            base.SaveBaseModel(ctx, versionInfo);

            // *** Binary format ***
            // int: id of sentence 1 column name
            // int: id of sentence 2 column name

            ctx.SaveNonEmptyString(BertOptions.Sentence1ColumnName);
            // Sentence 2 may legitimately be absent, hence SaveStringOrNull.
            ctx.SaveStringOrNull(BertOptions.Sentence2ColumnName);
        }

        private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => GetRowMapper(this, schema);

        /// <summary>
        /// Row mapper that tokenizes the sentence column(s) of each row, runs the model once
        /// per row (cached by cursor position), and describes the task-dependent output columns.
        /// </summary>
        private protected abstract class NasBertMapper : TorchSharpBaseMapper
        {
            private protected new NasBertTransformer<TLabelCol, TTargetsCol> Parent => base.Parent as NasBertTransformer<TLabelCol, TTargetsCol>;

            // Reflection helper used to instantiate GetLabelAnnotations<T> for the label's raw item type.
            private static readonly FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate> _makeLabelAnnotationGetter
                = FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate>.Create(target => target.GetLabelAnnotations<int>);
            // Token 0 begins every sequence; token 2 separates the two sentences of a pair.
            internal static readonly int[] InitTokenArray = new[] { 0 /* InitToken */ };
            internal static readonly int[] SeperatorTokenArray = new[] { 2 /* SeperatorToken */ };

            public NasBertMapper(TorchSharpBaseTransformer<TLabelCol, TTargetsCol> parent, DataViewSchema inputSchema) :
                base(parent, inputSchema)
            {
            }

            /// <summary>
            /// Records the input-column indices of the sentence column(s) the mapper reads.
            /// </summary>
            private protected override void AddInputColumnIndices(DataViewSchema inputSchema)
            {
                if (inputSchema.TryGetColumnIndex(Parent.BertOptions.Sentence1ColumnName, out var col))
                    InputColIndices.Add(col);

                if (Parent.BertOptions.Sentence2ColumnName != default || Parent.BertOptions.TaskType == BertTaskType.SentenceRegression)
                    if (inputSchema.TryGetColumnIndex(Parent.BertOptions.Sentence2ColumnName, out col))
                        InputColIndices.Add(col);
            }

            /// <summary>
            /// Builds the output column definitions: PredictedLabel key + Score vector for text
            /// classification, a variable-length PredictedLabel key vector for NER, and a single
            /// scalar Score column otherwise.
            /// </summary>
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                if (Parent.BertOptions.TaskType == BertTaskType.TextClassification)
                {
                    var info = new DataViewSchema.DetachedColumn[2];
                    var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
                    // Instantiates GetLabelAnnotations<T> with the label key's raw item type.
                    var getter = Microsoft.ML.Internal.Utilities.Utils.MarshalInvoke(_makeLabelAnnotationGetter, this, keyType.ItemType.RawType, Parent.LabelColumn);


                    var meta = new DataViewSchema.Annotations.Builder();
                    meta.Add(AnnotationUtils.Kinds.ScoreColumnKind, TextDataViewType.Instance, (ref ReadOnlyMemory<char> value) => { value = AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification.AsMemory(); });
                    meta.Add(AnnotationUtils.Kinds.ScoreColumnSetId, AnnotationUtils.ScoreColumnSetIdType, GetScoreColumnSetId(InputSchema));
                    meta.Add(AnnotationUtils.Kinds.ScoreValueKind, TextDataViewType.Instance, (ref ReadOnlyMemory<char> value) => { value = AnnotationUtils.Const.ScoreValueKind.Score.AsMemory(); });
                    meta.Add(AnnotationUtils.Kinds.TrainingLabelValues, keyType, getter);
                    meta.Add(AnnotationUtils.Kinds.SlotNames, keyType, getter);

                    var labelBuilder = new DataViewSchema.Annotations.Builder();
                    labelBuilder.Add(AnnotationUtils.Kinds.KeyValues, keyType, getter);

                    info[0] = new DataViewSchema.DetachedColumn(Parent.Options.PredictionColumnName, new KeyDataViewType(typeof(uint), Parent.Options.NumberOfClasses), labelBuilder.ToAnnotations());

                    info[1] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, Parent.Options.NumberOfClasses), meta.ToAnnotations());
                    return info;
                }
                else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
                {
                    var info = new DataViewSchema.DetachedColumn[1];
                    var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
                    var getter = Microsoft.ML.Internal.Utilities.Utils.MarshalInvoke(_makeLabelAnnotationGetter, this, keyType.ItemType.RawType, Parent.LabelColumn);

                    var labelBuilder = new DataViewSchema.Annotations.Builder();
                    labelBuilder.Add(AnnotationUtils.Kinds.KeyValues, keyType, getter);

                    // NOTE(review): the key count is NumberOfClasses - 1 here, unlike the
                    // classification branch above - presumably one class is reserved; confirm.
                    info[0] = new DataViewSchema.DetachedColumn(Parent.Options.PredictionColumnName, new VectorDataViewType(new KeyDataViewType(typeof(uint), Parent.Options.NumberOfClasses - 1)), labelBuilder.ToAnnotations());

                    return info;
                }
                else
                {
                    var info = new DataViewSchema.DetachedColumn[1];

                    info[0] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, NumberDataViewType.Single);
                    return info;
                }
            }

            // Returns a getter over the label column's KeyValues annotation as a VBuffer<T>.
            private Delegate GetLabelAnnotations<T>(DataViewSchema.DetachedColumn labelCol)
            {
                return labelCol.Annotations.GetGetter<VBuffer<T>>(labelCol.Annotations.Schema[AnnotationUtils.Kinds.KeyValues]);
            }

            // Concrete task mappers provide their own getters; this base overload must never run.
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
                => throw new NotImplementedException("This should never be called!");

            /// <summary>
            /// Caches the model's output tensor for the current cursor row.
            /// </summary>
            private protected class BertTensorCacher : TorchSharpBaseMapper.TensorCacher<Tensor>
            {
                public override void DisposeCore()
                {
                    Result?.Dispose();
                }
            }

            private protected override TensorCacher GetTensorCacher()
            {
                return new BertTensorCacher();
            }

            /// <summary>
            /// Reads the sentence(s) for the current row and converts them to token ids:
            /// [init(0), sentence1] or [init(0), sentence1, separator(2), sentence2].
            /// </summary>
            private IList<int> PrepInputTokens(ref ReadOnlyMemory<char> sentence1, ref ReadOnlyMemory<char> sentence2, ref ValueGetter<ReadOnlyMemory<char>> getSentence1, ref ValueGetter<ReadOnlyMemory<char>> getSentence2, Tokenizer tokenizer)
            {
                getSentence1(ref sentence1);
                if (getSentence2 == default)
                {
                    List<int> newList = new List<int>(tokenizer.EncodeToConverted(sentence1.ToString()));
                    // 0 Is the init token and must be at the beginning.
                    newList.Insert(0, 0);
                    return newList;
                }
                else
                {
                    getSentence2(ref sentence2);
                    return InitTokenArray.Concat(tokenizer.EncodeToConverted(sentence1.ToString()))
                                              .Concat(SeperatorTokenArray).Concat(tokenizer.EncodeToConverted(sentence2.ToString())).ToList();
                }
            }

            /// <summary>
            /// Runs the model on one tokenized row with gradients disabled, truncating to 512
            /// tokens and reshaping to a batch of one.
            /// </summary>
            private Tensor PrepAndRunModel(IList<int> tokens)
            {
                using (torch.no_grad())
                {
                    var inputTensor = torch.tensor(tokens, device: Parent.Device);
                    if (inputTensor.NumberOfElements > 512)
                        inputTensor = inputTensor.slice(0, 0, 512, 1);
                    inputTensor = inputTensor.reshape(1, inputTensor.NumberOfElements);
                    return (Parent.Model as NasBertModel).forward(inputTensor);
                }
            }

            /// <summary>
            /// Re-runs the model only when the cursor has moved to a new row; otherwise the
            /// cached result is reused. The result is moved to the outer dispose scope so it
            /// outlives this call.
            /// </summary>
            private protected void UpdateCacheIfNeeded(long position, TensorCacher outputCache, ref ReadOnlyMemory<char> sentence1, ref ReadOnlyMemory<char> sentence2, ref ValueGetter<ReadOnlyMemory<char>> getSentence1, ref ValueGetter<ReadOnlyMemory<char>> getSentence2, Tokenizer tokenizer)
            {
                var cache = outputCache as BertTensorCacher;
                if (outputCache.Position != position)
                {
                    cache.Result?.Dispose();
                    cache.Result = PrepAndRunModel(PrepInputTokens(ref sentence1, ref sentence2, ref getSentence1, ref getSentence2, tokenizer));
                    cache.Result.MoveToOuterDisposeScope();
                    cache.Position = position;
                }
            }

            private protected override void SaveModel(ModelSaveContext ctx) => Parent.SaveModel(ctx);
        }
    }
}