Changed code so that lm_head and token_embed are tied

Nicky Pochinkov 2023-09-16 18:10:36 +01:00
parent f38055dfb6
commit fc11cc387b


@@ -297,7 +297,9 @@ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float16):
         hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
         hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
-    hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
+    # llama2.c uses tied weights, so we reference the embed_tokens.weights instead
+    #hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
+    hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
     # Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
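For context (not part of the commit): assigning 'lm_head.weight' to the same object as 'model.embed_tokens.weight' means the exported state dict carries a single shared matrix for both the embedding and the output head. Below is a minimal sketch of that aliasing, using made-up dimensions (32000 x 4096) purely for illustration; on the transformers side, the matching LlamaConfig would typically also set tie_word_embeddings=True so the tie is re-established when the checkpoint is loaded.

import torch

# Hypothetical dimensions, for illustration only (not taken from the commit).
vocab_size, dim = 32000, 4096

hf_state_dict = {}
hf_state_dict['model.embed_tokens.weight'] = torch.empty(vocab_size, dim, dtype=torch.float16)
# Same assignment as in the diff: both keys now alias one tensor.
hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']

# The two entries share storage, so the export holds one matrix, not two copies.
assert (hf_state_dict['lm_head.weight'].data_ptr()
        == hf_state_dict['model.embed_tokens.weight'].data_ptr())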