1 write to _parent
Microsoft.ML.TorchSharp (1)
Roberta\QATrainer.cs (1)
715_parent = parent;
19 references to _parent
Microsoft.ML.TorchSharp (19)
Roberta\QATrainer.cs (19)
746info[0] = new DataViewSchema.DetachedColumn(_parent.Options.PredictedAnswerColumnName, new VectorDataViewType(TextDataViewType.Instance)); 748info[1] = new DataViewSchema.DetachedColumn(_parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single), meta.ToAnnotations()); 781getContext = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.ContextColumnName]); 782getQuestion = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.QuestionColumnName]); 808getContext = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.ContextColumnName]); 809getQuestion = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[_parent.Options.QuestionColumnName]); 832TensorCacher outputCacher = new TensorCacher(_parent.Options.TopKAnswers); 834_parent.Model.eval(); 857var contextTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(context.ToString())); 859var questionTokenId = _parent.Tokenizer.RobertaModel().ConvertIdsToOccurrenceRanks(_parent.Tokenizer.EncodeToIds(question.ToString())); 861var srcTensor = torch.tensor((new[] { 0 /* InitToken */ }).Concat(questionTokenId).Concat(new[] { 2 /* SeparatorToken */ }).Concat(contextTokenId).ToList(), device: _parent.Device); 877return _parent.Model.forward(inputTensor).MoveToOuterDisposeScope(); 915_parent.Model.eval(); 920var topKSpans = MetricUtils.ComputeTopKSpansWithScore(logits, _parent.Options.TopKAnswers, questionLength, contextLength); 927outputCache.PredictedAnswersBuffer[index] = new ReadOnlyMemory<char>(_parent.Tokenizer.Decode(_parent.Tokenizer.RobertaModel().ConvertOccurrenceRanksToIds(contextIds).ToArray().AsSpan(predictStart - questionLength - 2, predictEnd - predictStart).ToArray()).Trim().ToCharArray()); 937private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);