1 write to BertOptions
Microsoft.ML.TorchSharp (1)
NasBert\NasBertTrainer.cs (1)
171
BertOptions
= options as NasBertTrainer.NasBertOptions;
35 references to BertOptions
Microsoft.ML.TorchSharp (35)
NasBert\NasBertTrainer.cs (35)
172
Contracts.AssertValue(
BertOptions
.Sentence1ColumnName);
173
Contracts.Assert(
BertOptions
.TaskType != BertTaskType.None, "BertTaskType must be specified");
189
Optimizer = BaseOptimizer.GetOptimizer(Parent.
BertOptions
, parameters);
193
max_lr: Parent.
BertOptions
.LearningRate[0],
195
pct_start: Parent.
BertOptions
.WarmupRatio,
207
if (Parent.
BertOptions
.TaskType == BertTaskType.NamedEntityRecognition)
208
model = new NerModel(Parent.
BertOptions
, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
210
model = new ModelForPrediction(Parent.
BertOptions
, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
218
if (Parent.
BertOptions
.Sentence2ColumnName != default)
219
return input.GetRowCursor(input.Schema[Parent.
BertOptions
.Sentence1ColumnName], input.Schema[Parent.
BertOptions
.Sentence2ColumnName], input.Schema[Parent.Option.LabelColumnName]);
221
return input.GetRowCursor(input.Schema[Parent.
BertOptions
.Sentence1ColumnName], input.Schema[Parent.Option.LabelColumnName]);
226
Sentence1Getter = cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.
BertOptions
.Sentence1ColumnName]);
227
Sentence2Getter = Parent.
BertOptions
.Sentence2ColumnName != default ? cursor.GetGetter<ReadOnlyMemory<char>>(input.Schema[Parent.
BertOptions
.Sentence2ColumnName]) : default;
271
if (Parent.
BertOptions
.TaskType == BertTaskType.NamedEntityRecognition)
294
if (Parent.
BertOptions
.TaskType == BertTaskType.TextClassification)
295
loss = torch.nn.CrossEntropyLoss(reduction: Parent.
BertOptions
.Reduction).forward(logits, targetsTensor);
296
else if (Parent.
BertOptions
.TaskType == BertTaskType.NamedEntityRecognition)
300
loss = torch.nn.CrossEntropyLoss(reduction: Parent.
BertOptions
.Reduction).forward(logits, targetsTensor);
304
loss = torch.nn.MSELoss(reduction: Parent.
BertOptions
.Reduction).forward(logits.squeeze(), targetsTensor);
325
if (
BertOptions
.TaskType == BertTaskType.TextClassification)
341
else if (
BertOptions
.TaskType == BertTaskType.NamedEntityRecognition)
365
if (!inputSchema.TryFindColumn(
BertOptions
.Sentence1ColumnName, out var sentenceCol))
366
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence",
BertOptions
.Sentence1ColumnName);
368
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence",
BertOptions
.Sentence1ColumnName,
374
if (
BertOptions
.TaskType == BertTaskType.TextClassification)
381
if (
BertOptions
.Sentence2ColumnName != default)
383
if (!inputSchema.TryFindColumn(
BertOptions
.Sentence2ColumnName, out var sentenceCol2))
384
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2",
BertOptions
.Sentence2ColumnName);
386
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2",
BertOptions
.Sentence2ColumnName,
390
else if (
BertOptions
.TaskType == BertTaskType.NamedEntityRecognition)
402
if (!inputSchema.TryFindColumn(
BertOptions
.Sentence2ColumnName, out var sentenceCol2))
403
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2",
BertOptions
.Sentence2ColumnName);
405
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2",
BertOptions
.Sentence2ColumnName,