mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
Changed code so that lm_head and token_embed are tied
This commit is contained in:
parent
f38055dfb6
commit
fc11cc387b
@@ -297,7 +297,9 @@ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float16):
|
||||
hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
|
||||
|
||||
hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
|
||||
# llama2.c uses tied weights, so lm_head.weight references model.embed_tokens.weight instead
|
||||
#hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
|
||||
hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
|
||||
|
||||
|
||||
# Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user