2 types derived from BaseHead
Microsoft.ML.TorchSharp (2)
NasBert\Models\PredictionHead.cs (1)
11internal sealed class PredictionHead : BaseHead, torch.nn.IModule<torch.Tensor, torch.Tensor>
NasBert\Models\SequenceLabelHead.cs (1)
12internal sealed class SequenceLabelHead : BaseHead, torch.nn.IModule<torch.Tensor, torch.Tensor>
4 references to BaseHead
Microsoft.ML.TorchSharp (4)
NasBert\Models\BaseModel.cs (1)
20public abstract BaseHead GetHead();
NasBert\Models\ModelPrediction.cs (1)
15public override BaseHead GetHead() => PredictionHead;
NasBert\Models\NerModel.cs (1)
16public override BaseHead GetHead() => NerHead;
Roberta\Models\RobertaModelForQA.cs (1)
17public override BaseHead GetHead() => QAHead;