2 writes to lm_head
Microsoft.ML.GenAI.LLaMA (2)
LlamaForCausalLM.cs (2)
41
lm_head = nn.Linear(config.HiddenSize, config.VocabSize, hasBias: false, dtype: config.DType);
45
lm_head = nn.Linear(config.HiddenSize, config.VocabSize, hasBias: false, dtype: config.DType);
3 references to lm_head
Microsoft.ML.GenAI.LLaMA (3)
LlamaForCausalLM.cs (3)
54
this.lm_head.load_state_dict(embeddingWeight);
56
this.lm_head.to(device: model.Embedding.weight!.device);
64
var logits = this.lm_head.forward(outputs.LastHiddenState);