2 writes to lm_head
Microsoft.ML.GenAI.LLaMA (2)
LlamaForCausalLM.cs (2)
Line 41: lm_head = nn.Linear(config.HiddenSize, config.VocabSize, hasBias: false, dtype: config.DType);
Line 45: lm_head = nn.Linear(config.HiddenSize, config.VocabSize, hasBias: false, dtype: config.DType);
3 references to lm_head
Microsoft.ML.GenAI.LLaMA (3)
LlamaForCausalLM.cs (3)
Line 54: this.lm_head.load_state_dict(embeddingWeight);
Line 56: this.lm_head.to(device: model.Embedding.weight!.device);
Line 64: var logits = this.lm_head.forward(outputs.LastHiddenState);
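Taken together, these sites show a weight-tying pattern: lm_head is constructed as a bias-free projection from HiddenSize to VocabSize, the token-embedding weights are loaded into it, it is moved to the embedding's device, and it projects the last hidden state to vocabulary logits. Sharing one VocabSize x HiddenSize matrix between the input embedding and the output projection avoids a second large parameter matrix. The sketch below shows how those calls combine, assuming a TorchSharp-style API as in the snippets above; TiedLmHeadSketch, TieToEmbedding, and the constructor parameters are hypothetical names introduced here for illustration, not members of the actual LlamaForCausalLM class.

```csharp
// Minimal sketch, not the actual Microsoft.ML.GenAI.LLaMA implementation.
// Assumes TorchSharp; class and method names are hypothetical.
using System.Collections.Generic;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

public sealed class TiedLmHeadSketch : nn.Module<Tensor, Tensor>
{
    private readonly Linear lm_head;

    public TiedLmHeadSketch(long hiddenSize, long vocabSize, ScalarType dtype)
        : base(nameof(TiedLmHeadSketch))
    {
        // Lines 41/45: bias-free projection from the hidden size to the vocabulary size.
        lm_head = nn.Linear(hiddenSize, vocabSize, hasBias: false, dtype: dtype);
        RegisterComponents();
    }

    // Lines 54/56: tie the output projection to the token-embedding weights
    // and keep it on the same device as the embedding.
    public void TieToEmbedding(Embedding embedding)
    {
        var embeddingWeight = new Dictionary<string, Tensor> { ["weight"] = embedding.weight! };
        lm_head.load_state_dict(embeddingWeight);
        lm_head.to(embedding.weight!.device);
    }

    // Line 64: project the last hidden state to vocabulary logits.
    public override Tensor forward(Tensor lastHiddenState) => lm_head.forward(lastHiddenState);
}
```

The embedding weight has shape (VocabSize, HiddenSize), which is exactly the weight shape of a bias-free Linear(HiddenSize, VocabSize), so load_state_dict can copy it across directly.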