diff --git a/export.py b/export.py
index e0f7a9b..d87a0d5 100644
--- a/export.py
+++ b/export.py
@@ -297,7 +297,9 @@ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float16):
         hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
         hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
 
-    hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
+    # llama2.c uses tied weights, so we reference embed_tokens.weight instead
+    #hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
+    hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
 
     # Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
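
Context for the change: llama2.c ties the output projection to the token-embedding matrix, so exporting llama_model.output.weight as a separate clone duplicates a tensor that should be shared, whereas referencing hf_state_dict['model.embed_tokens.weight'] keeps both entries pointing at one copy. A minimal sketch of a defensive variant of this logic, assuming the llama_model.output and llama_model.tok_embeddings attributes from llama2.c's model.py; the helper name lm_head_weight is hypothetical and not part of this patch:

import torch

# Hypothetical helper (a sketch, not part of the patch): choose the lm_head
# tensor defensively, in case a checkpoint was trained with untied weights.
def lm_head_weight(llama_model, hf_state_dict, dtype=torch.float16):
    if torch.equal(llama_model.output.weight, llama_model.tok_embeddings.weight):
        # Tied weights: reuse the tensor already stored for embed_tokens,
        # so the exported state dict carries a single shared copy.
        return hf_state_dict['model.embed_tokens.weight']
    # Untied checkpoint: export the output projection on its own.
    return llama_model.output.weight.clone().to(dtype)

On the HuggingFace side, sharing one tensor between the two keys matches tied-weight semantics, which the generated LlamaConfig can declare via tie_word_embeddings.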