2 types derived from BaseHead
Microsoft.ML.TorchSharp (2)
NasBert\Models\PredictionHead.cs (1)
11
internal sealed class PredictionHead :
BaseHead
, torch.nn.IModule<torch.Tensor, torch.Tensor>
NasBert\Models\SequenceLabelHead.cs (1)
12
internal sealed class SequenceLabelHead :
BaseHead
, torch.nn.IModule<torch.Tensor, torch.Tensor>
4 references to BaseHead
Microsoft.ML.TorchSharp (4)
NasBert\Models\BaseModel.cs (1)
20
public abstract
BaseHead
GetHead();
NasBert\Models\ModelPrediction.cs (1)
15
public override
BaseHead
GetHead() => PredictionHead;
NasBert\Models\NerModel.cs (1)
16
public override
BaseHead
GetHead() => NerHead;
Roberta\Models\RobertaModelForQA.cs (1)
17
public override
BaseHead
GetHead() => QAHead;