Merge pull request #326 from atamurad/import_hf

Added huggingface model loader/importer to export.py
Andrej 2023-08-20 21:53:17 -07:00 committed by GitHub
commit 801c68f5a1
2 changed files with 106 additions and 21 deletions

export.py (118 changed lines)

@@ -16,8 +16,9 @@ This script aspires to provide all of these conversions.
"""
import struct
import argparse
import torch
import numpy as np
import torch
from torch import nn
from model import ModelArgs, Transformer
@@ -72,6 +73,10 @@ def legacy_export(model, filepath):
# first write out the header
hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
p = model.params
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
# legacy format uses negative/positive vocab size as a shared classifier flag
if not shared_classifier:
p.vocab_size = -p.vocab_size
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
@@ -103,11 +108,14 @@ def legacy_export(model, filepath):
serialize_fp32(out_file, layer.feed_forward.w3.weight)
# final rmsnorm
serialize_fp32(out_file, model.norm.weight)
# note: no need to write final classifier weights due to weight sharing
# freqs_cis
serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
# final classifier weights
if not shared_classifier:
serialize_fp32(out_file, model.output.weight)
# write to binary file
out_file.close()
print(f"wrote {filepath}")
@@ -136,8 +144,8 @@ def version1_export(model, filepath):
n_kv_heads, p.vocab_size, p.max_seq_len)
out_file.write(header)
# 4) write some other flags
shared_classifier = 1 # we do share a classifier, write flag as a byte
out_file.write(struct.pack('B', shared_classifier))
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
out_file.write(struct.pack('B', int(shared_classifier)))
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
assert pad >= 0
out_file.write(b'\0' * pad)
@@ -156,6 +164,8 @@ def version1_export(model, filepath):
*[layer.feed_forward.w2.weight for layer in model.layers],
*[layer.feed_forward.w3.weight for layer in model.layers],
]
if not shared_classifier:
weights.append(model.output.weight)
for w in weights:
serialize_fp32(out_file, w)
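Only the lines touched by this change are shown above. For intuition, here is a self-contained sketch of the flag byte plus the pad-to-256 invariant that version1_export relies on; the real header also contains the fields written in the earlier steps not shown in this diff, and the dimensions below are example values only:

import io
import struct

buf = io.BytesIO()
buf.write(struct.pack('iiiiiii', 288, 768, 6, 6, 6, 32000, 256))  # example model dims
shared_classifier = True
buf.write(struct.pack('B', int(shared_classifier)))               # the flag is stored as a single byte
buf.write(b'\0' * (256 - buf.tell()))                              # zero-pad the header region
assert buf.tell() == 256                                           # weight data always starts at a fixed offset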
@@ -187,6 +197,9 @@ def version2_export(model, filepath, group_size=64):
*[layer.feed_forward.w2.weight for layer in model.layers],
*[layer.feed_forward.w3.weight for layer in model.layers],
]
shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
if not shared_classifier:
weights.append(model.output.weight)
for w in weights:
assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
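version2_export packs each weight in groups of group_size, which is why every tensor's numel must divide evenly by it. The quantization routine itself is not part of this diff; the sketch below is a generic symmetric per-group int8 scheme, shown only to illustrate why the grouping constraint exists, not the exact code used:

import torch

def quantize_per_group_int8(w, group_size=64):
    # flatten into (num_groups, group_size) and scale each group by its max magnitude
    w = w.float().reshape(-1, group_size)
    scale = w.abs().max(dim=1, keepdim=True).values / 127.0
    scale = scale.clamp(min=1e-10)             # avoid division by zero for all-zero groups
    q = torch.round(w / scale).to(torch.int8)  # int8 codes
    return q, scale.squeeze(1)                 # codes plus one float scale per group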
@@ -205,8 +218,7 @@ def version2_export(model, filepath, group_size=64):
n_kv_heads, p.vocab_size, p.max_seq_len)
out_file.write(header)
# 4) write some other flags
shared_classifier = 1 # we do share a classifier, write flag as a byte
out_file.write(struct.pack('B', shared_classifier))
out_file.write(struct.pack('B', int(shared_classifier)))
out_file.write(struct.pack('i', group_size)) # group size used for quantization
pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
assert pad >= 0
@@ -247,6 +259,77 @@ def version2_export(model, filepath, group_size=64):
out_file.close()
print(f"wrote {filepath}")
# -----------------------------------------------------------------------------
# Load / import functions
def load_checkpoint(checkpoint):
# load the provided model checkpoint
checkpoint_dict = torch.load(checkpoint, map_location='cpu')
gptconf = ModelArgs(**checkpoint_dict['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
model.eval()
return model
def load_hf_model(model_path):
try:
from transformers import AutoModelForCausalLM
except ImportError:
print("Error: transformers package is required to load huggingface models")
print("Please run `pip install transformers` to install it")
return None
# load HF model
hf_model = AutoModelForCausalLM.from_pretrained(model_path)
hf_dict = hf_model.state_dict()
# convert LlamaConfig to ModelArgs
config = ModelArgs()
config.dim = hf_model.config.hidden_size
config.n_layers = hf_model.config.num_hidden_layers
config.n_heads = hf_model.config.num_attention_heads
config.n_kv_heads = hf_model.config.num_attention_heads
config.vocab_size = hf_model.config.vocab_size
config.hidden_dim = hf_model.config.intermediate_size
config.norm_eps = hf_model.config.rms_norm_eps
config.max_seq_len = hf_model.config.max_position_embeddings
# create a new Transformer object and set weights
model = Transformer(config)
model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
# huggingface permutes WQ and WK, this function reverses it
def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
for layer in model.layers:
i = layer.layer_id
layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
# final classifier
model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
model.eval()
return model
# -----------------------------------------------------------------------------
# API entrypoint
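A note on permute_reverse inside load_hf_model above: it undoes the query/key weight permutation that Hugging Face's Llama conversion applies, so wq and wk land back in the layout model.py expects. A quick standalone check follows; the forward permute here is reproduced from that conversion script from memory, so treat its exact form as an assumption:

import torch

def permute(w, n_heads, dim1, dim2):
    # HF-style permutation of wq/wk rows (assumed form, mirroring convert_llama_weights_to_hf.py)
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def permute_reverse(w, n_heads, dim1, dim2):
    # same expression as the helper defined inside load_hf_model
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

w = torch.randn(512, 512)
assert torch.equal(permute_reverse(permute(w, 8, 512, 512), 8, 512, 512), w)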
@@ -267,21 +350,20 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("filepath", type=str, help="the output filepath")
parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file")
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
parser.add_argument("--hf", type=str, help="huggingface model")
parser.add_argument("--version", default=0, type=int, help="the version to export with")
args = parser.parse_args()
# load the provided model checkpoint
checkpoint_dict = torch.load(args.checkpoint, map_location='cpu')
gptconf = ModelArgs(**checkpoint_dict['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
model.eval()
if args.checkpoint:
model = load_checkpoint(args.checkpoint)
elif args.hf:
model = load_hf_model(args.hf)
else:
parser.error("Input model missing: --checkpoint or --hf is required")
if model is None:
parser.error("Can't load input model!")
# export
model_export(model, args.filepath, args.version)
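With these changes the script accepts either input source; typical invocations would look like the following (the filenames and model name are illustrative only):

python export.py stories15M.bin --checkpoint stories15M.pt
python export.py llama2_7b_q80.bin --version 2 --hf meta-llama/Llama-2-7b-hf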

model.py (9 changed lines)

@@ -17,6 +17,7 @@ class ModelArgs:
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = 32000
hidden_dim: Optional[int] = None
multiple_of: int = 256 # MLP hidden layer size will be multiple of
norm_eps: float = 1e-5
max_seq_len: int = 2048
@@ -166,8 +167,10 @@ class Attention(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
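The new fallback branch keeps the old SwiGLU sizing rule when hidden_dim is not given (e.g. when training from scratch), while imported Hugging Face checkpoints pass intermediate_size through directly. A worked example of the fallback path, assuming Llama-2-7B-like dim=4096 and multiple_of=256:

dim, multiple_of = 4096, 256
hidden_dim = 4 * dim                        # 16384
hidden_dim = int(2 * hidden_dim / 3)        # 10922
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
assert hidden_dim == 11008                  # matches Llama-2-7B's intermediate_size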
@@ -186,7 +189,7 @@ class TransformerBlock(nn.Module):
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
hidden_dim=args.hidden_dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)