1 write to Host
Microsoft.ML.TorchSharp (1)
TorchSharpBaseTrainer.cs (1)
89
Host
= Contracts.CheckRef(env, nameof(env)).Register(nameof(TorchSharpBaseTrainer));
20 references to Host
Microsoft.ML.TorchSharp (20)
NasBert\NasBertTrainer.cs (11)
319
Host
.CheckValue(inputSchema, nameof(inputSchema));
366
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "sentence", BertOptions.Sentence1ColumnName);
368
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "sentence", BertOptions.Sentence1ColumnName,
372
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName);
377
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
384
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName);
386
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName,
393
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
399
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
403
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName);
405
throw
Host
.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", BertOptions.Sentence2ColumnName,
TorchSharpBaseTrainer.cs (9)
103
using (var ch =
Host
.Start("TrainModel"))
104
using (var pch =
Host
.StartProgressChannel("Training model"))
112
Host
.CheckAlive();
113
trainer.Train(
Host
, input);
120
transformer = CreateTransformer(
Host
, Option, trainer.Model, new DataViewSchema.DetachedColumn(labelCol.Value));
159
Device = TorchUtils.InitializeDevice(Parent.
Host
);
171
var destDir = Path.Combine(((IHostEnvironmentInternal)Parent.
Host
).TempFilePath, "mlnet");
179
using (var ch = (Parent.
Host
as IHostEnvironment).Start("Ensuring model file is present."))
181
var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(Parent.
Host
, ch, modelUrl, destFileName, destDir, timeout);