Mirror of https://github.com/trholding/llama2.c.git
Merge pull request #326 from atamurad/import_hf
Added huggingface model loader/importer to export.py
This commit is contained in: commit 801c68f5a1
export.py (118 changed lines)
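For orientation: the changes below extend export.py's CLI so the input model can come either from a llama2.c checkpoint (--checkpoint) or from the Hugging Face hub (--hf). A typical invocation after this PR would look like `python export.py llama2_7b.bin --version 0 --hf meta-llama/Llama-2-7b-hf`; the output path and model id here are illustrative, not taken from the diff.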
@@ -16,8 +16,9 @@ This script aspires to provide all of these conversions.
 """
 import struct
 import argparse
-import torch
 import numpy as np
+import torch
+from torch import nn

 from model import ModelArgs, Transformer

@@ -72,6 +73,10 @@ def legacy_export(model, filepath):
     # first write out the header
     hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
     p = model.params
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    # legacy format uses negative/positive vocab size as a shared classifier flag
+    if not shared_classifier:
+        p.vocab_size = -p.vocab_size
     n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
     header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
                          n_kv_heads, p.vocab_size, p.max_seq_len)
@@ -103,11 +108,14 @@ def legacy_export(model, filepath):
         serialize_fp32(out_file, layer.feed_forward.w3.weight)
     # final rmsnorm
     serialize_fp32(out_file, model.norm.weight)
-    # note: no need to write final classifier weights due to weight sharing
     # freqs_cis
     serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
     serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])

+    # final classifier weights
+    if not shared_classifier:
+        serialize_fp32(out_file, model.output.weight)
+
     # write to binary file
     out_file.close()
     print(f"wrote {filepath}")
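As a reading aid: the legacy header is the seven int32s packed by struct.pack('iiiiiii', ...) above, and a negative vocab_size now signals a non-shared classifier. A minimal sketch of how a consumer could decode that convention (the file name is a placeholder; this is not code from the repo):

import struct

with open("model.bin", "rb") as f:  # placeholder path to a legacy-format export
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
        struct.unpack("iiiiiii", f.read(7 * 4))
shared_classifier = vocab_size > 0   # the sign carries the flag
vocab_size = abs(vocab_size)         # the actual vocabulary size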
@@ -136,8 +144,8 @@ def version1_export(model, filepath):
                          n_kv_heads, p.vocab_size, p.max_seq_len)
     out_file.write(header)
     # 4) write some other flags
-    shared_classifier = 1 # we do share a classifier, write flag as a byte
-    out_file.write(struct.pack('B', shared_classifier))
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    out_file.write(struct.pack('B', int(shared_classifier)))
     pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
     assert pad >= 0
     out_file.write(b'\0' * pad)
@@ -156,6 +164,8 @@ def version1_export(model, filepath):
         *[layer.feed_forward.w2.weight for layer in model.layers],
         *[layer.feed_forward.w3.weight for layer in model.layers],
     ]
+    if not shared_classifier:
+        weights.append(model.output.weight)
     for w in weights:
         serialize_fp32(out_file, w)

@@ -187,6 +197,9 @@ def version2_export(model, filepath, group_size=64):
         *[layer.feed_forward.w2.weight for layer in model.layers],
         *[layer.feed_forward.w3.weight for layer in model.layers],
     ]
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    if not shared_classifier:
+        weights.append(model.output.weight)
     for w in weights:
         assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"

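The assert above guards the exporter's grouped int8 quantization: every exported tensor must split evenly into groups of group_size values, because each group is scaled independently. A small illustrative sketch of that constraint, assuming a one-scale-per-group scheme; this is not the exporter's own quantization code:

import torch

group_size = 64
w = torch.randn(288, 768)                        # 288*768 = 221184, a multiple of 64
assert w.numel() % group_size == 0
groups = w.float().reshape(-1, group_size)       # this reshape is what requires divisibility
scales = groups.abs().max(dim=1).values / 127.0  # one scale per group
q = torch.round(groups / scales[:, None]).to(torch.int8)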
@@ -205,8 +218,7 @@ def version2_export(model, filepath, group_size=64):
                          n_kv_heads, p.vocab_size, p.max_seq_len)
     out_file.write(header)
     # 4) write some other flags
-    shared_classifier = 1 # we do share a classifier, write flag as a byte
-    out_file.write(struct.pack('B', shared_classifier))
+    out_file.write(struct.pack('B', int(shared_classifier)))
     out_file.write(struct.pack('i', group_size)) # group size used for quantization
     pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
     assert pad >= 0
@@ -247,6 +259,77 @@ def version2_export(model, filepath, group_size=64):
     out_file.close()
     print(f"wrote {filepath}")

+
+# -----------------------------------------------------------------------------
+# Load / import functions
+
+def load_checkpoint(checkpoint):
+
+    # load the provided model checkpoint
+    checkpoint_dict = torch.load(checkpoint, map_location='cpu')
+    gptconf = ModelArgs(**checkpoint_dict['model_args'])
+    model = Transformer(gptconf)
+    state_dict = checkpoint_dict['model']
+    unwanted_prefix = '_orig_mod.'
+    for k,v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict, strict=False)
+    model.eval()
+    return model
+
+def load_hf_model(model_path):
+
+    try:
+        from transformers import AutoModelForCausalLM
+    except ImportError:
+        print("Error: transformers package is required to load huggingface models")
+        print("Please run `pip install transformers` to install it")
+        return None
+
+    # load HF model
+    hf_model = AutoModelForCausalLM.from_pretrained(model_path)
+    hf_dict = hf_model.state_dict()
+
+    # convert LlamaConfig to ModelArgs
+    config = ModelArgs()
+    config.dim = hf_model.config.hidden_size
+    config.n_layers = hf_model.config.num_hidden_layers
+    config.n_heads = hf_model.config.num_attention_heads
+    config.n_kv_heads = hf_model.config.num_attention_heads
+    config.vocab_size = hf_model.config.vocab_size
+    config.hidden_dim = hf_model.config.intermediate_size
+    config.norm_eps = hf_model.config.rms_norm_eps
+    config.max_seq_len = hf_model.config.max_position_embeddings
+
+    # create a new Transformer object and set weights
+    model = Transformer(config)
+
+    model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
+    model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
+
+    # huggingface permutes WQ and WK, this function reverses it
+    def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
+        return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+    for layer in model.layers:
+        i = layer.layer_id
+        layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
+        layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
+        layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
+        layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
+        layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
+        layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
+        layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
+        layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
+        layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
+
+    # final classifier
+    model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
+    model.eval()
+    return model
+
+
 # -----------------------------------------------------------------------------
 # API entrypoint

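The permute_reverse helper added in load_hf_model undoes the head-wise permutation that the Hugging Face conversion script applies to the q_proj/k_proj weights (it stores the rotary pairs in a different order). A small self-check sketch; the forward permute below mirrors the HF convention as an assumption and is not part of this PR:

import torch

n_heads, dim = 6, 288   # toy sizes; head_dim = dim // n_heads must be even

def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
    # assumed HF-style permutation of a (dim1, dim2) projection weight
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def permute_reverse(w, n_heads=n_heads, dim1=dim, dim2=dim):
    # same expression as the helper defined inside load_hf_model above
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

w = torch.randn(dim, dim)
assert torch.equal(permute_reverse(permute(w)), w)   # the round trip is the identity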
@@ -267,21 +350,20 @@ if __name__ == "__main__":

     parser = argparse.ArgumentParser()
     parser.add_argument("filepath", type=str, help="the output filepath")
-    parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file")
+    parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
+    parser.add_argument("--hf", type=str, help="huggingface model")
     parser.add_argument("--version", default=0, type=int, help="the version to export with")
     args = parser.parse_args()

-    # load the provided model checkpoint
-    checkpoint_dict = torch.load(args.checkpoint, map_location='cpu')
-    gptconf = ModelArgs(**checkpoint_dict['model_args'])
-    model = Transformer(gptconf)
-    state_dict = checkpoint_dict['model']
-    unwanted_prefix = '_orig_mod.'
-    for k,v in list(state_dict.items()):
-        if k.startswith(unwanted_prefix):
-            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-    model.load_state_dict(state_dict, strict=False)
-    model.eval()
+    if args.checkpoint:
+        model = load_checkpoint(args.checkpoint)
+    elif args.hf:
+        model = load_hf_model(args.hf)
+    else:
+        parser.error("Input model missing: --checkpoint or --hf is required")
+
+    if model is None:
+        parser.error("Can't load input model!")

     # export
     model_export(model, args.filepath, args.version)
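Taken together, the new entrypoint also means the loaders can be used programmatically. A hypothetical sketch, assuming export.py is importable from the working directory and using placeholder names for the model id and output path:

from export import load_hf_model, model_export

model = load_hf_model("meta-llama/Llama-2-7b-hf")  # placeholder Hugging Face model id
if model is not None:
    model_export(model, "llama2_7b.bin", 0)        # same call the __main__ block makes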
model.py (9 changed lines)
@@ -17,6 +17,7 @@ class ModelArgs:
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = 32000
+    hidden_dim: Optional[int] = None
     multiple_of: int = 256  # MLP hidden layer size will be multiple of
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
@@ -166,8 +167,10 @@ class Attention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
         super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        if hidden_dim is None:
+            hidden_dim = 4 * dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
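With this change, hidden_dim is used directly when it is set (for example the intermediate_size of an imported Hugging Face model) and only falls back to the old SwiGLU sizing rule when it is None. A worked example of the fallback arithmetic, using 7B-scale numbers purely for illustration:

dim, multiple_of = 4096, 256
hidden_dim = 4 * dim                                                         # 16384
hidden_dim = int(2 * hidden_dim / 3)                                         # 10922
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)   # rounds up to 11008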
@@ -186,7 +189,7 @@ class TransformerBlock(nn.Module):
         self.attention = Attention(args)
         self.feed_forward = FeedForward(
             dim=args.dim,
-            hidden_dim=4 * args.dim,
+            hidden_dim=args.hidden_dim,
             multiple_of=args.multiple_of,
             dropout=args.dropout,
         )