35 references to Parent
Microsoft.ML.TorchSharp (35)
NasBert\NasBertTrainer.cs (28)
189Optimizer = BaseOptimizer.GetOptimizer(Parent.BertOptions, parameters); 193max_lr: Parent.BertOptions.LearningRate[0], 194total_steps: ((TrainingRowCount / Parent.Option.BatchSize) + 1) * Parent.Option.MaxEpoch, 195pct_start: Parent.BertOptions.WarmupRatio, 197div_factor: 1.0 / Parent.Option.StartLearningRateRatio, 198final_div_factor: Parent.Option.StartLearningRateRatio / Parent.Option.FinalLearningRateRatio); 207if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 208model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses); 210model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses); 218if (Parent.BertOptions.Sentence2ColumnName != default) 219return input.GetRowCursor(input.Schema[Parent.BertOptions.Sentence1ColumnName], input.Schema[Parent.BertOptions.Sentence2ColumnName], input.Schema[Parent.Option.LabelColumnName]); 221return input.GetRowCursor(input.Schema[Parent.BertOptions.Sentence1ColumnName], input.Schema[Parent.Option.LabelColumnName]); 226Sentence1Getter = cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.BertOptions.Sentence1ColumnName]); 227Sentence2Getter = Parent.BertOptions.Sentence2ColumnName != default ? cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.BertOptions.Sentence2ColumnName]) : default; 271if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 294if (Parent.BertOptions.TaskType == BertTaskType.TextClassification) 295loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor); 296else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 300loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor); 304loss = torch.nn.MSELoss(reduction: Parent.BertOptions.Reduction).forward(logits.squeeze(), targetsTensor);
NasBert\NerTrainer.cs (3)
216input.Schema[Parent.Option.LabelColumnName].GetKeyValues(ref keys); 217var labelCol = input.GetColumn<VBuffer<uint>>(Parent.Option.LabelColumnName); 225Parent.Option.NumberOfClasses = keys.Length + 1;
NasBert\SentenceSimilarityTrainer.cs (2)
148var labelCol = input.GetColumn<float>(Parent.Option.LabelColumnName); 157Parent.Option.NumberOfClasses = 1;
NasBert\TextClassificationTrainer.cs (2)
150var labelCol = input.GetColumn<uint>(Parent.Option.LabelColumnName); 160Parent.Option.NumberOfClasses = uniqueLabels.Count;