// File: Roberta\QATrainer.cs
// Web Access
// Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using TorchSharp;
using System.Runtime.CompilerServices;
 
using static TorchSharp.torch;
using static TorchSharp.TensorExtensionMethods;
using Microsoft.ML.TorchSharp.Utils;
using Microsoft.ML;
using System.IO;
using static Microsoft.ML.TorchSharp.NasBert.NasBertTrainer;
using static Microsoft.ML.Data.AnnotationUtils;
using Microsoft.ML.TorchSharp.NasBert.Optimizers;
using Microsoft.ML.TorchSharp.Roberta;
using Microsoft.ML.TorchSharp.NasBert;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.TorchSharp.Extensions;
 
[assembly: LoadableClass(typeof(QATransformer), null, typeof(SignatureLoadModel),
    QATransformer.UserName, QATransformer.LoaderSignature)]
 
[assembly: LoadableClass(typeof(IRowMapper), typeof(QATransformer), null, typeof(SignatureLoadRowMapper),
    QATransformer.UserName, QATransformer.LoaderSignature)]
 
namespace Microsoft.ML.TorchSharp.Roberta
{
    public class QATrainer : IEstimator<QATransformer>
    {
        /// <summary>
        /// Configuration options for <see cref="QATrainer"/>. Extends <see cref="NasBertOptions"/>
        /// with the column names and settings specific to the question-answering task.
        /// </summary>
        public sealed class Options : NasBertOptions
        {
            /// <summary>
            /// Context Column Name
            /// </summary>
            public string ContextColumnName = DefaultColumnNames.Context;

            /// <summary>
            /// Question Column Name
            /// </summary>
            public string QuestionColumnName = DefaultColumnNames.Question;

            /// <summary>
            /// Answer Column Name for the training data
            /// </summary>
            public string TrainingAnswerColumnName = DefaultColumnNames.TrainingAnswer;

            /// <summary>
            /// Answer Column Name for the predicted answers
            /// </summary>
            public string PredictedAnswerColumnName = DefaultColumnNames.Answer;

            /// <summary>
            /// Answer Index Start Column Name
            /// </summary>
            public string AnswerIndexStartColumnName = DefaultColumnNames.AnswerIndex;

            /// <summary>
            /// Number of top predicted answers in question answering task.
            /// </summary>
            public int TopKAnswers = DefaultColumnNames.TopKAnswers;

            /// <summary>
            /// How often to log the loss.
            /// </summary>
            public int LogEveryNStep = 50;

            /// <summary>
            /// Initializes defaults that match the pre-trained RoBERTa-base encoder
            /// (768-dim embedding/encoder output) and select the question-answering task.
            /// </summary>
            public Options()
            {
                EncoderOutputDim = 768;
                EmbeddingDim = 768;
                PoolerDropout = 0;
                ModelType = BertModelType.Roberta;
                TaskType = BertTaskType.QuestionAnswering;
                // Single-value schedule: 1e-6 is used as the peak LR for OneCycleLR.
                LearningRate = new List<double>() { .000001 };
                WeightDecay = 0.01;
            }
        }
 
        /// <summary>
        /// Default column names (and the default top-k count) shared by the trainer
        /// constructors and the saved-model loader.
        /// </summary>
        internal static class DefaultColumnNames
        {
            public const string Context = "Context";
            public const string Question = "Question";
            public const string Answer = "Answer";
            public const string TrainingAnswer = "TrainingAnswer";
            public const string AnswerIndex = "AnswerStart";
            public const int TopKAnswers = 3;
            public const string Score = "Score";
        }
 
        // Host used for channels, progress reporting, and schema validation.
        private protected readonly IHost Host;
        // Effective trainer configuration (column names and hyper-parameters).
        internal readonly Options Option;
        // Location of the pre-trained RoBERTa encoder weights in the model store.
        private const string ModelUrl = "models/pretrained_Roberta_encoder.tsm";

        /// <summary>
        /// Initializes the trainer from a fully populated <see cref="Options"/> instance.
        /// Asserts that all required column names are set and the epoch count is positive.
        /// </summary>
        internal QATrainer(IHostEnvironment env, Options options)
        {
            Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(QATrainer));
            Contracts.Assert(options.MaxEpoch > 0);
            Contracts.AssertValue(options.ContextColumnName);
            Contracts.AssertValue(options.QuestionColumnName);
            Contracts.AssertValue(options.TrainingAnswerColumnName);
            Contracts.AssertValue(options.AnswerIndexStartColumnName);
            Contracts.AssertValue(options.ScoreColumnName);
            Contracts.AssertValue(options.PredictedAnswerColumnName);

            Option = options;
        }
 
        /// <summary>
        /// Convenience constructor that builds an <see cref="Options"/> instance from
        /// individual column-name and hyper-parameter arguments.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="contextColumnName">Name of the column holding the context passage.</param>
        /// <param name="questionColumnName">Name of the column holding the question text.</param>
        /// <param name="trainingAnswerColumnName">Name of the column holding the ground-truth answer text.</param>
        /// <param name="answerIndexColumnName">Name of the column holding the answer's start index within the context.</param>
        /// <param name="predictedAnswerColumnName">Name of the output column for predicted answers.</param>
        /// <param name="scoreColumnName">Name of the output column for prediction scores.</param>
        /// <param name="topk">Number of top answers to predict.</param>
        /// <param name="batchSize">Training batch size.</param>
        /// <param name="maxEpochs">Number of training epochs.</param>
        /// <param name="validationSet">Optional validation data.</param>
        /// <param name="architecture">Encoder architecture. NOTE(review): currently not forwarded
        /// into <see cref="Options"/> (only RoBERTa is supported) — confirm before relying on it.</param>
        internal QATrainer(IHostEnvironment env,
            string contextColumnName = DefaultColumnNames.Context,
            string questionColumnName = DefaultColumnNames.Question,
            string trainingAnswerColumnName = DefaultColumnNames.TrainingAnswer,
            string answerIndexColumnName = DefaultColumnNames.AnswerIndex,
            string predictedAnswerColumnName = DefaultColumnNames.Answer,
            string scoreColumnName = DefaultColumnNames.Score,
            int topk = DefaultColumnNames.TopKAnswers,
            int batchSize = 4,
            int maxEpochs = 10,
            IDataView validationSet = null,
            BertArchitecture architecture = BertArchitecture.Roberta) :
            this(env, new Options
            {
                ContextColumnName = contextColumnName,
                QuestionColumnName = questionColumnName,
                TrainingAnswerColumnName = trainingAnswerColumnName,
                AnswerIndexStartColumnName = answerIndexColumnName,
                PredictedAnswerColumnName = predictedAnswerColumnName,
                ScoreColumnName = scoreColumnName,
                TopKAnswers = topk,
                BatchSize = batchSize,
                MaxEpoch = maxEpochs,
                ValidationSet = validationSet
            })
        {
        }
 
        /// <summary>
        /// Trains the question-answering model on <paramref name="input"/> and returns
        /// the fitted <see cref="QATransformer"/>.
        /// </summary>
        public QATransformer Fit(IDataView input)
        {
            // Validate required columns (context, question, answer, answer index) up front.
            CheckInputSchema(SchemaShape.Create(input.Schema));

            QATransformer transformer = default;

            using (var ch = Host.Start("TrainModel"))
            using (var pch = Host.StartProgressChannel("Training model"))
            {
                var header = new ProgressHeader(new[] { "Loss" }, new[] { "Total Rows" });

                // Trainer ctor downloads/loads pre-trained weights and sets up the
                // optimizer and LR schedule.
                var trainer = new Trainer(this, ch, input);
                pch.SetHeader(header,
                    e =>
                    {
                        e.SetProgress(0, trainer.Updates, trainer.RowCount);
                        e.SetMetric(0, trainer.LossValue);
                    });

                for (int i = 0; i < Option.MaxEpoch; i++)
                {
                    ch.Trace($"Starting epoch {i}");
                    Host.CheckAlive();
                    trainer.Train(Host, input, pch);
                    ch.Trace($"Finished epoch {i}");
                }

                // Release native optimizer state; the model itself is handed to the transformer.
                trainer.Optimizer.Optimizer.Dispose();

                transformer = new QATransformer(Host, Option, trainer.Model);
                // Validates the trained transformer against the input schema (result discarded).
                transformer.GetOutputSchema(input.Schema);
            }
            return transformer;
        }
 
        internal class Trainer
        {
            // Model being trained (RoBERTa encoder + QA head).
            public RobertaModelForQA Model;
            // Compute device (CPU or CUDA) chosen from the host environment.
            public torch.Device Device;
            // Optimizer wrapper over the trainable parameters.
            public BaseOptimizer Optimizer;
            // One-cycle learning-rate schedule stepped once per batch.
            public optim.lr_scheduler.LRScheduler LearningRateScheduler;
            protected readonly QATrainer Parent;
            // Number of optimizer updates performed in the current epoch.
            public int Updates;
            // Last reported loss value (surfaced through the progress channel).
            public float LossValue;
            // Total number of training rows, used to size the LR schedule.
            public readonly int RowCount;
            private readonly IChannel _channel;
            public Tokenizer Tokenizer;
 
            /// <summary>
            /// Sets up the model (loading pre-trained encoder weights), device placement,
            /// tokenizer, optimizer, and the one-cycle learning-rate schedule.
            /// </summary>
            public Trainer(QATrainer parent, IChannel ch, IDataView input)
            {
                Parent = parent;
                Updates = 0;
                LossValue = 0;
                _channel = ch;

                // Get row count
                RowCount = GetRowCount(input);

                // Figure out if we are running on GPU or CPU.
                // (Previously this was called twice; once is sufficient.)
                Device = TorchUtils.InitializeDevice(Parent.Host);

                // Initialize the model and load pre-trained weights
                Model = new RobertaModelForQA(Parent.Option);

                Model.GetEncoder().load(GetModelPath());

                // Move to GPU if we are running there
                if (Device.type == DeviceType.CUDA)
                    Model.cuda();

                Tokenizer = TokenizerExtensions.GetInstance(ch);

                // Get the parameters that need optimization and set up the optimizer
                var parameters = Model.parameters().Where(p => p.requires_grad);
                Optimizer = BaseOptimizer.GetOptimizer(Parent.Option, parameters);
                // One-cycle schedule over the whole run: (batches per epoch, rounded up) * epochs.
                LearningRateScheduler = torch.optim.lr_scheduler.OneCycleLR(
                   Optimizer.Optimizer,
                   max_lr: Parent.Option.LearningRate[0],
                   total_steps: ((RowCount / Parent.Option.BatchSize) + 1) * Parent.Option.MaxEpoch,
                   pct_start: Parent.Option.WarmupRatio,
                   anneal_strategy: torch.optim.lr_scheduler.impl.OneCycleLR.AnnealStrategy.Linear,
                   div_factor: 1.0 / Parent.Option.StartLearningRateRatio,
                   final_div_factor: Parent.Option.StartLearningRateRatio / Parent.Option.FinalLearningRateRatio);
            }
 
            /// <summary>
            /// Counts the rows of <paramref name="input"/> by enumerating the
            /// answer-start-index column; used to size the learning-rate schedule.
            /// </summary>
            private protected int GetRowCount(IDataView input)
            {
                // Enumerate with LINQ instead of a manual loop with an unused loop variable.
                return input.GetColumn<int>(Parent.Option.AnswerIndexStartColumnName).Count();
            }
 
            /// <summary>
            /// Ensures the pre-trained RoBERTa encoder weights are present in the local
            /// temp cache (downloading them if necessary) and returns the local file path.
            /// </summary>
            private string GetModelPath()
            {
                var destDir = Path.Combine(((IHostEnvironmentInternal)Parent.Host).TempFilePath, "mlnet");
                var destFileName = ModelUrl.Split('/').Last();

                Directory.CreateDirectory(destDir);

                string relativeFilePath = Path.Combine(destDir, destFileName);

                // Allow up to ten minutes for the download before giving up.
                int timeout = 10 * 60 * 1000;
                using (var ch = (Parent.Host as IHostEnvironment).Start("Ensuring model file is present."))
                {
                    var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(Parent.Host, ch, ModelUrl, destFileName, destDir, timeout);
                    ensureModel.Wait();
                    var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result);
                    if (errorResult != null)
                    {
                        // Removed dead locals (directory/name) that were computed from
                        // errorResult.FileName but never used.
                        throw ch.Except($"{errorMessage}\nmodel file could not be downloaded!");
                    }
                }

                return relativeFilePath;
            }
 
            /// <summary>
            /// Runs one epoch of training over <paramref name="input"/>, reading the
            /// context/question/answer/answer-index columns and performing one optimizer
            /// update per batch.
            /// </summary>
            public void Train(IHost host, IDataView input, IProgressChannel pch)
            {
                // Get the cursor over only the columns we read. The cursor is
                // IDisposable; previously it was never disposed (resource leak).
                using DataViewRowCursor cursor = input.GetRowCursor(input.Schema[Parent.Option.ContextColumnName], input.Schema[Parent.Option.QuestionColumnName], input.Schema[Parent.Option.TrainingAnswerColumnName], input.Schema[Parent.Option.AnswerIndexStartColumnName]);

                var contextGetter = cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.Option.ContextColumnName]);
                var questionGetter = cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.Option.QuestionColumnName]);
                var answerGetter = cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.Option.TrainingAnswerColumnName]);
                var answerIndexGetter = cursor.GetGetter<int>(input.Schema[Parent.Option.AnswerIndexStartColumnName]);

                var cursorValid = true;
                Updates = 0;

                List<Tensor> inputTensors = new List<Tensor>(Parent.Option.BatchSize);
                List<Tensor> targetTensors = new List<Tensor>(Parent.Option.BatchSize);

                while (cursorValid)
                {
                    // Re-seed per batch so runs with the same host seed are reproducible.
                    if (host is IHostEnvironmentInternal hostInternal)
                    {
                        torch.random.manual_seed(hostInternal.Seed + Updates ?? 1);
                        torch.cuda.manual_seed(hostInternal.Seed + Updates ?? 1);
                    }
                    else
                    {
                        torch.random.manual_seed(1);
                        torch.cuda.manual_seed(1);
                    }
                    cursorValid = TrainStep(host, cursor, contextGetter, questionGetter, answerGetter, answerIndexGetter, ref inputTensors, ref targetTensors, pch);
                }
            }
 
            /// <summary>
            /// Collects up to one batch of rows from the cursor, runs a forward/backward
            /// pass, and steps the optimizer and LR schedule. Returns false once the
            /// cursor is exhausted (after training on any final partial batch).
            /// </summary>
            private bool TrainStep(IHost host,
                DataViewRowCursor cursor,
                ValueGetter<ReadOnlyMemory<char>> contextGetter,
                ValueGetter<ReadOnlyMemory<char>> questionGetter,
                ValueGetter<ReadOnlyMemory<char>> answerGetter,
                ValueGetter<int> answerIndexGetter,
                ref List<Tensor> inputTensors,
                ref List<Tensor> targetTensors,
                IProgressChannel pch)
            {
                // Make sure list is clear before use
                inputTensors.Clear();
                targetTensors.Clear();

                // Tensors created below are freed when this scope is disposed.
                using var disposeScope = torch.NewDisposeScope();
                var cursorValid = true;
                Tensor srcTensor = default;
                Tensor targetTensor = default;

                host.CheckAlive();

                for (int i = 0; i < Parent.Option.BatchSize && cursorValid; i++)
                {
                    host.CheckAlive();
                    cursorValid = cursor.MoveNext();
                    if (cursorValid)
                    {
                        (srcTensor, targetTensor, bool valid) = PrepareData(contextGetter, questionGetter, answerGetter, answerIndexGetter);
                        if (valid)
                        {
                            inputTensors.Add(srcTensor);
                            targetTensors.Add(targetTensor);
                        }
                        else
                            // Row could not be aligned to token positions; retry the
                            // same batch slot with the next row.
                            i--;
                    }
                    else
                    {
                        inputTensors.TrimExcess();
                        targetTensors.TrimExcess();
                        // Cursor exhausted with nothing collected: no batch to train on.
                        // (A non-empty partial batch falls through and is still trained.)
                        if (inputTensors.Count() == 0)
                            return cursorValid;
                    }
                }

                Updates++;
                host.CheckAlive();
                Model.train();
                Optimizer.zero_grad();

                srcTensor = PrepareBatchTensor(ref inputTensors, device: Device, Tokenizer.RobertaModel().PadIndex);
                targetTensor = PrepareBatchTensor(ref targetTensors, device: Device, 0);
                var logits = Model.forward(srcTensor);  //[batchsize, maxseqlen, 2]
                // Split the last dimension into separate start/end position logits.
                var splitLogits = logits.split(1, dim: -1);
                var startLogits = splitLogits[0].squeeze(-1).contiguous();  //[batchsize, maxseqlen]
                var endLogits = splitLogits[1].squeeze(-1).contiguous();  //[batchsize, maxseqlen]

                var targetsLong = targetTensor.@long();
                var splitTargets = targetsLong.split(1, dim: -1);
                var startTargets = splitTargets[0].squeeze(-1).contiguous();  //[batchsize]
                var endTargets = splitTargets[1].squeeze(-1).contiguous();  //[batchsize]

                // Standard extractive-QA objective: average of start- and end-position
                // cross-entropy losses.
                torch.Tensor lossStart = torch.nn.CrossEntropyLoss(reduction: Parent.Option.Reduction).forward(startLogits, startTargets);
                torch.Tensor lossEnd = torch.nn.CrossEntropyLoss(reduction: Parent.Option.Reduction).forward(endLogits, endTargets);

                var loss = ((lossStart + lossEnd) / 2);

                loss.backward();

                Optimizer.Step();
                LearningRateScheduler.step();
                host.CheckAlive();

                if (Updates % Parent.Option.LogEveryNStep == 0)
                {
                    pch.Checkpoint(loss.ToDouble(), Updates);
                    _channel.Info($"Row: {Updates}, Loss: {loss.ToDouble()}");
                }

                return cursorValid;
            }
 
            /// <summary>
            /// Pads and stacks the collected per-row tensors into one batch tensor.
            /// </summary>
            /// <param name="inputTensors">Per-row tensors to collate.</param>
            /// <param name="device">Device on which the batch tensor should live.</param>
            /// <param name="padIndex">Token id used to pad shorter sequences.</param>
            private torch.Tensor PrepareBatchTensor(ref List<Tensor> inputTensors, Device device, int padIndex)
            {
                // Use the 'device' parameter instead of silently ignoring it in favor of
                // the Device field (callers currently pass the same value).
                return DataUtils.CollateTokens(inputTensors, padIndex, device: device);
            }
 
            /// <summary>
            /// Reads one row and builds the model input tensor
            /// ([init] question [separator] context) and the target tensor
            /// (start/end answer token positions). Returns hasMapping=false when the
            /// answer's character span cannot be mapped onto token positions, in which
            /// case the row is skipped by the caller.
            /// NOTE(review): the first tuple element is named "image" — almost certainly
            /// a copy/paste leftover; it is the tokenized source sequence.
            /// </summary>
            private (Tensor image, Tensor Label, bool hasMapping) PrepareData(ValueGetter<ReadOnlyMemory<char>> contextGetter, ValueGetter<ReadOnlyMemory<char>> questionGetter, ValueGetter<ReadOnlyMemory<char>> answerGetter, ValueGetter<int> answerIndexGetter)
            {
                using (var _ = torch.NewDisposeScope())
                {
                    ReadOnlyMemory<char> context = default;
                    ReadOnlyMemory<char> question = default;
                    ReadOnlyMemory<char> answer = default;
                    int answerIndex = default;

                    contextGetter(ref context);
                    questionGetter(ref question);
                    answerGetter(ref answer);
                    answerIndexGetter(ref answerIndex);

                    var contextString = context.ToString();
                    var contextTokens = Tokenizer.EncodeToTokens(contextString, out string normalized);
                    var contextToken = contextTokens.Select(t => t.Value).ToArray();
                    var contextTokenId = Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(contextTokens.Select(t => t.Id).ToArray());

                    // Map each context character position to its token index.
                    var mapping = AlignAnswerPosition(contextToken, contextString);
                    if (mapping == null)
                    {
                        return (null, null, false);
                    }
                    var questionTokenId = Tokenizer.EncodeToConverted(question.ToString());

                    var answerEnd = answerIndex + answer.Length - 1;
                    if (!mapping.ContainsKey(answerIndex) || !mapping.ContainsKey(answerEnd))
                    {
                        return (null, null, false);
                    }
                    // Offset answer token positions by the question length plus the two
                    // special tokens (init + separator) that precede the context.
                    var targetList = new List<int> { mapping[answerIndex] + 2 + questionTokenId.Count, mapping[answerEnd] + 2 + questionTokenId.Count };

                    var srcTensor = torch.tensor((new[] { 0 /* InitToken */ }).Concat(questionTokenId).Concat(new[] { 2 /* SeparatorToken */ }).Concat(contextTokenId).ToList(), device: Device);
                    // If the end of the answer goes beyond the 512 tokens then set answer start/end index to 0
                    if (targetList[1] > 511)
                    {
                        targetList[0] = 0;
                        targetList[1] = 0;
                    }
                    var labelTensor = torch.tensor(targetList, device: Device);

                    // Truncate the input to the model's 512-token maximum sequence length.
                    if (srcTensor.NumberOfElements > 512)
                        srcTensor = srcTensor.slice(0, 0, 512, 1);

                    // Move both tensors out of the local dispose scope so they survive
                    // for the caller's batch.
                    return (srcTensor.MoveToOuterDisposeScope(), labelTensor.MoveToOuterDisposeScope(), true);
                }
            }
 
            /// <summary>
            /// Builds a map from each character position in <paramref name="text"/> to the
            /// index of the token that covers it, accounting for whitespace, surrogate
            /// pairs, unsupported characters, and normalizer quote rewrites. Throws
            /// <see cref="DataMisalignedException"/> when text and tokens cannot be aligned.
            /// </summary>
            private Dictionary<int, int> AlignAnswerPosition(IReadOnlyList<string> tokens, string text)
            {
                EnglishRobertaTokenizer robertaModel = Tokenizer as EnglishRobertaTokenizer;
                Debug.Assert(robertaModel is not null);

                var mapping = new Dictionary<int, int>();
                // Number of surrogate-pair "extra" chars seen so far; subtracted so map
                // keys are in logical (code-point) positions rather than UTF-16 indices.
                int surrogateDeduce = 0;
                // i: char index into text; j: char index into current token; tid: token index.
                for (var (i, j, tid) = (0, 0, 0); i < text.Length && tid < tokens.Count;)
                {
                    // Move to a new token
                    if (j >= tokens[tid].Length)
                    {
                        ++tid;
                        j = 0;
                    }
                    // There are a few UTF-32 chars in corpus, which is considered one char in position
                    else if (i + 1 < text.Length && char.IsSurrogatePair(text[i], text[i + 1]))
                    {
                        i += 2;
                        ++surrogateDeduce;
                    }
                    // White spaces are not included in tokens
                    else if (char.IsWhiteSpace(text[i]))
                    {
                        ++i;
                    }
                    // Chars not included in tokenizer will not appear in tokens
                    else if (!robertaModel.IsSupportedChar(text[i]))
                    {
                        mapping[i - surrogateDeduce] = tid;
                        ++i;
                    }
                    // "\\\"", "``" and "''" converted to "\"" in normalizer
                    else if (i + 1 < text.Length && tokens[tid][j] == '"'
                        && ((text[i] == '`' && text[i + 1] == '`')
                         || (text[i] == '\'' && text[i + 1] == '\'')
                         || (text[i] == '\\' && text[i + 1] == '"')))
                    {
                        mapping[i - surrogateDeduce] = mapping[i + 1 - surrogateDeduce] = tid;
                        i += 2;
                        j += 1;
                    }
                    // Normal match
                    else if (text[i] == tokens[tid][j])
                    {
                        mapping[i - surrogateDeduce] = tid;
                        ++i;
                        ++j;
                    }
                    // There are a few real \u0120 chars in the corpus, so this rule has to be later than text[i] == tokens[tid][j].
                    else if (tokens[tid][j] == '\u0120' && j == 0)
                    {
                        ++j;
                    }
                    else
                    {
                        throw new DataMisalignedException("unmatched!");
                    }
                }

                return mapping;
            }
        }
 
        /// <summary>
        /// Returns the schema produced after training: the input columns plus the
        /// predicted-answer column (variable-length text vector) and the score column
        /// (variable-length float vector with score annotations).
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            CheckInputSchema(inputSchema);

            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Score-column annotations; built with a collection initializer for
            // consistency with QATransformer.GetOutputSchema.
            var scoreMetadata = new List<SchemaShape.Column>
            {
                new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar,
                    TextDataViewType.Instance, false),
                new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar,
                    TextDataViewType.Instance, false),
                new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar,
                    NumberDataViewType.UInt32, true)
            };

            outColumns[Option.PredictedAnswerColumnName] = new SchemaShape.Column(Option.PredictedAnswerColumnName, SchemaShape.Column.VectorKind.VariableVector,
                TextDataViewType.Instance, false);

            outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.VariableVector,
                NumberDataViewType.Single, false, new SchemaShape(scoreMetadata.ToArray()));

            return new SchemaShape(outColumns.Values);
        }
 
        /// <summary>
        /// Verifies that all required input columns exist and have the expected types:
        /// scalar text for context/question/answer, scalar Int32 for the answer index.
        /// </summary>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            // The three text columns share an identical check; factor it out.
            void CheckTextColumn(string userName, string columnName)
            {
                if (!inputSchema.TryFindColumn(columnName, out var column))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), userName, columnName);
                if (column.Kind != SchemaShape.Column.VectorKind.Scalar || column.ItemType.RawType != typeof(ReadOnlyMemory<char>))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), userName, columnName,
                        TextDataViewType.Instance.ToString(), column.GetTypeString());
            }

            CheckTextColumn("Context", Option.ContextColumnName);
            CheckTextColumn("Question", Option.QuestionColumnName);
            CheckTextColumn("TrainingAnswer", Option.TrainingAnswerColumnName);

            // The answer-start column must be a scalar Int32.
            if (!inputSchema.TryFindColumn(Option.AnswerIndexStartColumnName, out var answerIndexColumn))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "AnswerIndex", Option.AnswerIndexStartColumnName);
            if (answerIndexColumn.Kind != SchemaShape.Column.VectorKind.Scalar || answerIndexColumn.ItemType.RawType != typeof(int))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "AnswerIndex", Option.AnswerIndexStartColumnName,
                    NumberDataViewType.Int32.ToString(), answerIndexColumn.GetTypeString());
        }
    }
 
    public class QATransformer : RowToRowTransformerBase, IDisposable
    {
        // Compute device (CPU or CUDA) chosen at construction time.
        private protected readonly Device Device;
        // Trained model used for scoring; set to eval mode in the constructor.
        private protected RobertaModelForQA Model;
        internal readonly QATrainer.Options Options;

        internal const string LoadName = "QATrainer";
        internal const string UserName = "QA Trainer";
        internal const string ShortName = "QA";
        internal const string Summary = "Question and Answer";
        internal const string LoaderSignature = "QATRAIN";

        public Tokenizer Tokenizer;
        private bool _disposedValue;
 
        /// <summary>
        /// Wraps a trained <see cref="RobertaModelForQA"/> for inference: puts the model
        /// in eval mode, moves it to the selected device, and initializes the tokenizer.
        /// </summary>
        internal QATransformer(IHostEnvironment env, QATrainer.Options options, RobertaModelForQA model)
           : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(QATransformer)))
        {
            Device = TorchUtils.InitializeDevice(env);

            Options = options;

            Model = model;
            // Inference only: disable dropout etc.
            Model.eval();

            if (Device.type == DeviceType.CUDA)
                Model.cuda();
            using (var ch = Host.Start("Initialize Tokenizer"))
                Tokenizer = TokenizerExtensions.GetInstance(ch);
        }
 
        /// <summary>
        /// Returns the output schema: the input columns plus the predicted-answer column
        /// (variable-length text vector) and the score column (variable-length float
        /// vector carrying score annotations).
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            CheckInputSchema(inputSchema);

            var resultColumns = inputSchema.ToDictionary(c => c.Name);

            // Annotations attached to the score column.
            var scoreAnnotations = new[]
            {
                new SchemaShape.Column(Kinds.ScoreColumnKind, SchemaShape.Column.VectorKind.Scalar,
                    TextDataViewType.Instance, false),
                new SchemaShape.Column(Kinds.ScoreValueKind, SchemaShape.Column.VectorKind.Scalar,
                    TextDataViewType.Instance, false),
                new SchemaShape.Column(Kinds.ScoreColumnSetId, SchemaShape.Column.VectorKind.Scalar,
                    NumberDataViewType.UInt32, true)
            };

            resultColumns[Options.PredictedAnswerColumnName] = new SchemaShape.Column(
                Options.PredictedAnswerColumnName, SchemaShape.Column.VectorKind.VariableVector,
                TextDataViewType.Instance, false);

            resultColumns[Options.ScoreColumnName] = new SchemaShape.Column(
                Options.ScoreColumnName, SchemaShape.Column.VectorKind.VariableVector,
                NumberDataViewType.Single, false, new SchemaShape(scoreAnnotations));

            return new SchemaShape(resultColumns.Values);
        }
 
        /// <summary>
        /// Verifies the context and question columns exist and are scalar text columns,
        /// mirroring <see cref="QATrainer"/>'s schema check.
        /// </summary>
        private void CheckInputSchema(SchemaShape inputSchema)
        {
            if (!inputSchema.TryFindColumn(Options.ContextColumnName, out var contextCol))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Context", Options.ContextColumnName);
            // Also reject vector-of-text columns (previously only the item type was checked,
            // so a text vector would pass here and fail later during scoring).
            if (contextCol.Kind != SchemaShape.Column.VectorKind.Scalar || contextCol.ItemType != TextDataViewType.Instance)
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Context", Options.ContextColumnName,
                    TextDataViewType.Instance.ToString(), contextCol.GetTypeString());

            if (!inputSchema.TryFindColumn(Options.QuestionColumnName, out var questionCol))
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Question", Options.QuestionColumnName);
            if (questionCol.Kind != SchemaShape.Column.VectorKind.Scalar || questionCol.ItemType != TextDataViewType.Instance)
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "Question", Options.QuestionColumnName,
                    TextDataViewType.Instance.ToString(), questionCol.GetTypeString());
        }
 
        // Version/signature metadata identifying the serialized QATransformer format
        // to the model loader.
        private static VersionInfo GetVersionInfo() =>
            new VersionInfo(
                modelSignature: "QA-ANSWR",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(QATransformer).Assembly.FullName);
 
        /// <summary>
        /// Serializes the transformer. The write order below must exactly match the
        /// read order in the SignatureLoadModel factory (Create).
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.AssertValue(ctx);
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: id of the context column name
            // int: id of the question column name
            // int: id of the PredictedAnswer column name
            // int: id of the Score name
            // int: topk
            // BinaryStream: TS Model

            ctx.SaveNonEmptyString(Options.ContextColumnName);
            ctx.SaveNonEmptyString(Options.QuestionColumnName);
            ctx.SaveNonEmptyString(Options.PredictedAnswerColumnName);
            ctx.SaveNonEmptyString(Options.ScoreColumnName);
            ctx.Writer.Write(Options.TopKAnswers);

            // TorchSharp model weights are stored as an embedded binary stream.
            ctx.SaveBinaryStream("TSModel", w =>
            {
                Model.save(w);
            });
        }
 
        // Builds the per-row mapper that performs the actual QA scoring.
        private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new QAMapper(this, schema);

        // Factory method for SignatureLoadRowMapper: rehydrates the transformer from the
        // model context, then builds a mapper over the given input schema.
        private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
            => Create(env, ctx).MakeRowMapper(inputSchema);
 
        // Factory method for SignatureLoadModel.
        /// <summary>
        /// Factory method for SignatureLoadModel. The read order below must exactly
        /// match the write order in <see cref="SaveModel"/>.
        /// </summary>
        private static QATransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: id of the context column name
            // int: id of the question column name
            // int: id of the PredictedAnswer column name
            // int: id of the Score name
            // int: topk
            // BinaryStream: TS Model

            var options = new QATrainer.Options()
            {
                ContextColumnName = ctx.LoadString(),
                QuestionColumnName = ctx.LoadString(),
                PredictedAnswerColumnName = ctx.LoadString(),
                ScoreColumnName = ctx.LoadString(),
                TopKAnswers = ctx.Reader.ReadInt32()
            };

            // The channel is IDisposable; previously it was started and never disposed.
            using var ch = env.Start("Load Model");

            var model = new RobertaModelForQA(options);

            if (!ctx.TryLoadBinaryStream("TSModel", r => model.load(r)))
                throw env.ExceptDecode();

            return new QATransformer(env, options, model);
        }
 
        /// <summary>
        /// Row mapper that runs the RoBERTa QA model over the Context and Question columns and
        /// emits two output columns: the top-K predicted answer strings and their scores.
        /// Results for a row are computed once per cursor position and cached in a
        /// <see cref="TensorCacher"/> so both getters share a single forward pass.
        /// </summary>
        private class QAMapper : MapperBase
        {
            private readonly QATransformer _parent;

            // Indices of the input columns (context and question) that this mapper reads.
            private readonly HashSet<int> _inputColIndices;

            public QAMapper(QATransformer parent, DataViewSchema inputSchema) :
                base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(QAMapper)), inputSchema, parent)
            {
                _parent = parent;
                _inputColIndices = new HashSet<int>();

                if (inputSchema.TryGetColumnIndex(parent.Options.ContextColumnName, out var col))
                    _inputColIndices.Add(col);
                if (inputSchema.TryGetColumnIndex(parent.Options.QuestionColumnName, out col))
                    _inputColIndices.Add(col);

                // Seed both the CPU and CUDA RNGs for reproducibility; default to 1 when the
                // host environment does not supply a seed.
                long seed = (Host is IHostEnvironmentInternal hostInternal ? hostInternal.Seed : null) ?? 1;
                torch.random.manual_seed(seed);
                torch.cuda.manual_seed(seed);
            }

            /// <summary>
            /// Declares the two output columns: a text vector of predicted answers and a float
            /// vector of scores carrying the standard scorer annotations.
            /// </summary>
            protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
            {
                // Score column annotations (kind, set id, value kind) follow the standard scorer
                // conventions so downstream components recognize the column.
                var meta = new DataViewSchema.Annotations.Builder();
                meta.Add(AnnotationUtils.Kinds.ScoreColumnKind, TextDataViewType.Instance, (ref ReadOnlyMemory<char> value) => { value = AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification.AsMemory(); });
                meta.Add(AnnotationUtils.Kinds.ScoreColumnSetId, AnnotationUtils.ScoreColumnSetIdType, GetScoreColumnSetId(InputSchema));
                meta.Add(AnnotationUtils.Kinds.ScoreValueKind, TextDataViewType.Instance, (ref ReadOnlyMemory<char> value) => { value = AnnotationUtils.Const.ScoreValueKind.Score.AsMemory(); });

                return new[]
                {
                    new DataViewSchema.DetachedColumn(_parent.Options.PredictedAnswerColumnName, new VectorDataViewType(TextDataViewType.Instance)),
                    new DataViewSchema.DetachedColumn(_parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single), meta.ToAnnotations()),
                };
            }

            // Returns a getter producing a score-column-set id one greater than any already
            // present in the schema.
            private ValueGetter<uint> GetScoreColumnSetId(DataViewSchema schema)
            {
                var max = schema.GetMaxAnnotationKind(out _, AnnotationUtils.Kinds.ScoreColumnSetId);
                uint id = checked(max + 1);
                return (ref uint dst) => dst = id;
            }

            // Not used: getters are created through CreateGetters/CreateGetter so both output
            // columns can share one cached forward pass per row.
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
                => throw new NotImplementedException("This should never be called!");

            // Dispatches to the proper getter factory for the requested output column.
            private Delegate CreateGetter(DataViewRow input, int iinfo, TensorCacher outputCacher)
            {
                // BUGFIX: dispose the channel (IChannel is IDisposable) instead of leaking it.
                using var ch = Host.Start("Make Getter");
                if (iinfo == 0)
                    return MakePredictedAnswerGetter(input, ch, outputCacher);
                else
                    return MakeScoreGetter(input, ch, outputCacher);
            }

            // Builds the Score-column getter: refreshes the per-row cache if needed, then copies
            // the cached scores into the destination buffer.
            private Delegate MakeScoreGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
            {
                ReadOnlyMemory<char> context = default;
                ReadOnlyMemory<char> question = default;

                var getContext = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.ContextColumnName]);
                var getQuestion = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.QuestionColumnName]);

                ValueGetter<VBuffer<float>> score = (ref VBuffer<float> dst) =>
                {
                    using var disposeScope = torch.NewDisposeScope();
                    UpdateCacheIfNeeded(input.Position, outputCacher, ref context, ref question, ref getContext, ref getQuestion);

                    var editor = VBufferEditor.Create(ref dst, outputCacher.ScoresBuffer.Length);
                    for (var i = 0; i < outputCacher.ScoresBuffer.Length; i++)
                        editor.Values[i] = outputCacher.ScoresBuffer[i];
                    dst = editor.Commit();
                };

                return score;
            }

            // Builds the PredictedAnswer-column getter: refreshes the per-row cache if needed,
            // then copies the cached answer strings into the destination buffer.
            private Delegate MakePredictedAnswerGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
            {
                ReadOnlyMemory<char> context = default;
                ReadOnlyMemory<char> question = default;

                var getContext = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.ContextColumnName]);
                var getQuestion = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.QuestionColumnName]);

                ValueGetter<VBuffer<ReadOnlyMemory<char>>> predictedAnswer = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                {
                    using var disposeScope = torch.NewDisposeScope();
                    UpdateCacheIfNeeded(input.Position, outputCacher, ref context, ref question, ref getContext, ref getQuestion);

                    var editor = VBufferEditor.Create(ref dst, outputCacher.PredictedAnswersBuffer.Length);
                    for (var i = 0; i < outputCacher.PredictedAnswersBuffer.Length; i++)
                        editor.Values[i] = outputCacher.PredictedAnswersBuffer[i];
                    dst = editor.Commit();
                };

                return predictedAnswer;
            }

            public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
            {
                Host.AssertValue(input);
                Contracts.Assert(input.Schema == base.InputSchema);

                // One cache per cursor: both getters share the result of a single forward pass.
                var outputCacher = new TensorCacher(_parent.Options.TopKAnswers);

                // BUGFIX: dispose the channel (IChannel is IDisposable) instead of leaking it.
                using var ch = Host.Start("Make Getters");
                _parent.Model.eval();

                int n = OutputColumns.Value.Length;
                var result = new Delegate[n];
                for (int i = 0; i < n; i++)
                {
                    if (activeOutput(i))
                        result[i] = CreateGetter(input, i, outputCacher);
                }

                disposer = () => outputCacher.Dispose();
                return result;
            }

            // Tokenizes the context and question and concatenates them as
            // [InitToken] question [SeparatorToken] context, truncated to the 512-token model
            // limit and reshaped into a batch of one.
            private Tensor PrepInputTensors(ref ReadOnlyMemory<char> context, ref ReadOnlyMemory<char> question, ValueGetter<ReadOnlyMemory<char>> contextGetter, ValueGetter<ReadOnlyMemory<char>> questionGetter, out int contextLength, out int questionLength, out int[] contextIds)
            {
                contextGetter(ref context);
                questionGetter(ref question);

                var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(context.ToString()));
                var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(question.ToString()));

                var srcTensor = torch.tensor((new[] { 0 /* InitToken */ }).Concat(questionTokenId).Concat(new[] { 2 /* SeparatorToken */ }).Concat(contextTokenId).ToList(), device: _parent.Device);

                // The model accepts at most 512 tokens; truncate the tail beyond that.
                if (srcTensor.NumberOfElements > 512)
                    srcTensor = srcTensor.slice(0, 0, 512, 1);

                contextLength = contextTokenId.Count;
                questionLength = questionTokenId.Count;
                contextIds = contextTokenId.ToArray();

                return srcTensor.reshape(1, srcTensor.NumberOfElements);
            }

            // Runs the forward pass inside its own dispose scope, keeping only the logits alive.
            private Tensor PrepAndRunModel(Tensor inputTensor)
            {
                using (torch.NewDisposeScope())
                {
                    return _parent.Model.forward(inputTensor).MoveToOuterDisposeScope();
                }
            }

            // Per-cursor cache of the model outputs for the most recently processed row, so the
            // PredictedAnswer and Score getters trigger only one forward pass per row.
            private protected class TensorCacher : IDisposable
            {
                // Row position the buffers currently correspond to; -1 means "not yet filled".
                public long Position;

                // Maximum number of answers (top-K) the buffers can hold.
                public int MaxLength;
                public ReadOnlyMemory<char>[] PredictedAnswersBuffer;
                public Single[] ScoresBuffer;

                public TensorCacher(int maxLength)
                {
                    Position = -1;
                    MaxLength = maxLength;
                    PredictedAnswersBuffer = new ReadOnlyMemory<char>[maxLength];
                    ScoresBuffer = new float[maxLength];
                }

                private bool _isDisposed;

                public void Dispose()
                {
                    if (_isDisposed)
                        return;

                    _isDisposed = true;
                }
            }

            // Runs the model for the row at `position` (unless the cache already holds it) and
            // fills the answer/score buffers with the top-K spans.
            private protected void UpdateCacheIfNeeded(long position, TensorCacher outputCache, ref ReadOnlyMemory<char> context, ref ReadOnlyMemory<char> question, ref ValueGetter<ReadOnlyMemory<char>> getContext, ref ValueGetter<ReadOnlyMemory<char>> getQuestion)
            {
                if (outputCache.Position == position)
                    return;

                var inputTensor = PrepInputTensors(ref context, ref question, getContext, getQuestion, out int contextLength, out int questionLength, out int[] contextIds);
                _parent.Model.eval();
                using (torch.no_grad())
                {
                    var logits = PrepAndRunModel(inputTensor);

                    var topKSpans = MetricUtils.ComputeTopKSpansWithScore(logits, _parent.Options.TopKAnswers, questionLength, contextLength);
                    int index = 0;
                    foreach (var topKSpan in topKSpans)
                    {
                        var predictStart = topKSpan.start;
                        var predictEnd = topKSpan.end;
                        var score = topKSpan.score;
                        // Span indices are offsets into the full (init + question + separator + context)
                        // sequence; subtract the question length plus the two special tokens to map
                        // back into the context token ids before decoding the answer text.
                        outputCache.PredictedAnswersBuffer[index] = new ReadOnlyMemory<char>(_parent.Tokenizer.Decode(_parent.Tokenizer.RobertaModel().ConvertOccurrenceRanksToIds(contextIds).ToArray().AsSpan(predictStart - questionLength - 2, predictEnd - predictStart).ToArray()).Trim().ToCharArray());
                        outputCache.ScoresBuffer[index++] = score;
                    }

                    logits.Dispose();
                }
                outputCache.Position = position;
            }

            private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

            // Only the context/question input columns are required, and only when at least one of
            // the two output columns is requested.
            private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
            {
                return col => (activeOutput(0) || activeOutput(1)) && _inputColIndices.Any(i => i == col);
            }
        }
 
        /// <summary>
        /// Releases the resources held by this transformer.
        /// </summary>
        /// <param name="disposing">
        /// True when called from <see cref="Dispose()"/>; false when called from the finalizer,
        /// in which case managed state (the TorchSharp model) must not be touched.
        /// </param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposedValue)
                return;

            if (disposing)
            {
                // Managed state: release the TorchSharp model's native resources.
                Model.Dispose();
                Model = null;
            }

            // BUGFIX: mark the object disposed on both paths. Previously the flag was set only
            // when disposing was true, so the finalizer path never recorded disposal.
            _disposedValue = true;
        }
 
        /// <summary>
        /// Finalizer: routes to <see cref="Dispose(bool)"/> with disposing=false so managed
        /// state is not touched on the finalizer thread.
        /// </summary>
        ~QATransformer()
        {
            // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
            Dispose(disposing: false);
        }
 
        /// <summary>
        /// Releases the TorchSharp model held by this transformer and suppresses finalization.
        /// </summary>
        public void Dispose()
        {
            // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }
    }
}