1 write to BertOptions
Microsoft.ML.TorchSharp (1)
NasBert\NasBertTrainer.cs (1)
171BertOptions = options as NasBertTrainer.NasBertOptions;
35 references to BertOptions
Microsoft.ML.TorchSharp (35)
NasBert\NasBertTrainer.cs (35)
172Contracts.AssertValue(BertOptions.Sentence1ColumnName); 173Contracts.Assert(BertOptions.TaskType != BertTaskType.None, "BertTaskType must be specified"); 189Optimizer = BaseOptimizer.GetOptimizer(Parent.BertOptions, parameters); 193max_lr: Parent.BertOptions.LearningRate[0], 195pct_start: Parent.BertOptions.WarmupRatio, 207if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 208model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses); 210model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses); 218if (Parent.BertOptions.Sentence2ColumnName != default) 219return input.GetRowCursor(input.Schema[Parent.BertOptions.Sentence1ColumnName], input.Schema[Parent.BertOptions.Sentence2ColumnName], input.Schema[Parent.Option.LabelColumnName]); 221return input.GetRowCursor(input.Schema[Parent.BertOptions.Sentence1ColumnName], input.Schema[Parent.Option.LabelColumnName]); 226Sentence1Getter = cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.BertOptions.Sentence1ColumnName]); 227Sentence2Getter = Parent.BertOptions.Sentence2ColumnName != default ? cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.BertOptions.Sentence2ColumnName]) : default; 271if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 294if (Parent.BertOptions.TaskType == BertTaskType.TextClassification) 295loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor); 296else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 300loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor); 304loss = torch.nn.MSELoss(reduction: Parent.BertOptions.Reduction).forward(logits.squeeze(), targetsTensor); 325if (BertOptions.TaskType == BertTaskType.TextClassification) 341else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 365if (!inputSchema.TryFindColumn(BertOptions.Sentence1ColumnName, out var sentenceCol)) 366throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence", BertOptions.Sentence1ColumnName); 368throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence", BertOptions.Sentence1ColumnName, 374if (BertOptions.TaskType == BertTaskType.TextClassification) 381if (BertOptions.Sentence2ColumnName != default) 383if (!inputSchema.TryFindColumn(BertOptions.Sentence2ColumnName, out var sentenceCol2)) 384throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName); 386throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName, 390else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition) 402if (!inputSchema.TryFindColumn(BertOptions.Sentence2ColumnName, out var sentenceCol2)) 403throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName); 405throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName,