diff --git a/export.py b/export.py
index ffcb506..d909c9f 100644
--- a/export.py
+++ b/export.py
@@ -16,8 +16,9 @@ This script aspires to provide all of these conversions.
 """
 import struct
 import argparse
-import torch
 import numpy as np
+import torch
+from torch import nn
 
 from model import ModelArgs, Transformer
 
@@ -72,6 +73,10 @@ def legacy_export(model, filepath):
     # first write out the header
     hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
     p = model.params
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    # legacy format uses negative/positive vocab size as a shared classifier flag
+    if not shared_classifier:
+        p.vocab_size = -p.vocab_size
     n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
     header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
                          n_kv_heads, p.vocab_size, p.max_seq_len)
@@ -103,11 +108,14 @@ def legacy_export(model, filepath):
         serialize_fp32(out_file, layer.feed_forward.w3.weight)
     # final rmsnorm
     serialize_fp32(out_file, model.norm.weight)
-    # note: no need to write final classifier weights due to weight sharing
     # freqs_cis
     serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
     serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
 
+    # final classifier weights
+    if not shared_classifier:
+        serialize_fp32(out_file, model.output.weight)
+
     # write to binary file
     out_file.close()
     print(f"wrote {filepath}")
@@ -136,8 +144,8 @@ def version1_export(model, filepath):
                          n_kv_heads, p.vocab_size, p.max_seq_len)
     out_file.write(header)
     # 4) write some other flags
-    shared_classifier = 1 # we do share a classifier, write flag as a byte
-    out_file.write(struct.pack('B', shared_classifier))
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    out_file.write(struct.pack('B', int(shared_classifier)))
     pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
     assert pad >= 0
     out_file.write(b'\0' * pad)
@@ -156,6 +164,8 @@ def version1_export(model, filepath):
         *[layer.feed_forward.w2.weight for layer in model.layers],
         *[layer.feed_forward.w3.weight for layer in model.layers],
     ]
+    if not shared_classifier:
+        weights.append(model.output.weight)
     for w in weights:
         serialize_fp32(out_file, w)
 
@@ -187,6 +197,9 @@ def version2_export(model, filepath, group_size=64):
         *[layer.feed_forward.w2.weight for layer in model.layers],
         *[layer.feed_forward.w3.weight for layer in model.layers],
     ]
+    shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
+    if not shared_classifier:
+        weights.append(model.output.weight)
     for w in weights:
         assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
 
@@ -205,8 +218,7 @@ def version2_export(model, filepath, group_size=64):
                          n_kv_heads, p.vocab_size, p.max_seq_len)
     out_file.write(header)
     # 4) write some other flags
-    shared_classifier = 1 # we do share a classifier, write flag as a byte
-    out_file.write(struct.pack('B', shared_classifier))
+    out_file.write(struct.pack('B', int(shared_classifier)))
     out_file.write(struct.pack('i', group_size)) # group size used for quantization
     pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
     assert pad >= 0
     out_file.write(b'\0' * pad)
@@ -247,6 +259,77 @@ def version2_export(model, filepath, group_size=64):
     out_file.close()
     print(f"wrote {filepath}")
 
+
+# -----------------------------------------------------------------------------
+# Load / import functions
+
+def load_checkpoint(checkpoint):
+
+    # load the provided model checkpoint
+    checkpoint_dict = torch.load(checkpoint, map_location='cpu')
+    gptconf = ModelArgs(**checkpoint_dict['model_args'])
+    model = Transformer(gptconf)
+    state_dict = checkpoint_dict['model']
+    unwanted_prefix = '_orig_mod.'
+    for k,v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict, strict=False)
+    model.eval()
+    return model
+
+def load_hf_model(model_path):
+
+    try:
+        from transformers import AutoModelForCausalLM
+    except ImportError:
+        print("Error: transformers package is required to load huggingface models")
+        print("Please run `pip install transformers` to install it")
+        return None
+
+    # load HF model
+    hf_model = AutoModelForCausalLM.from_pretrained(model_path)
+    hf_dict = hf_model.state_dict()
+
+    # convert LlamaConfig to ModelArgs
+    config = ModelArgs()
+    config.dim = hf_model.config.hidden_size
+    config.n_layers = hf_model.config.num_hidden_layers
+    config.n_heads = hf_model.config.num_attention_heads
+    config.n_kv_heads = hf_model.config.num_attention_heads
+    config.vocab_size = hf_model.config.vocab_size
+    config.hidden_dim = hf_model.config.intermediate_size
+    config.norm_eps = hf_model.config.rms_norm_eps
+    config.max_seq_len = hf_model.config.max_position_embeddings
+
+    # create a new Transformer object and set weights
+    model = Transformer(config)
+
+    model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
+    model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
+
+    # huggingface permutes WQ and WK, this function reverses it
+    def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
+        return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+    for layer in model.layers:
+        i = layer.layer_id
+        layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
+        layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
+        layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
+        layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
+        layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
+        layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
+        layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
+        layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
+        layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
+
+    # final classifier
+    model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
+    model.eval()
+    return model
+
+
 # -----------------------------------------------------------------------------
 # API entrypoint
 
@@ -267,21 +350,20 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
     parser.add_argument("filepath", type=str, help="the output filepath")
-    parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file")
+    parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
+    parser.add_argument("--hf", type=str, help="huggingface model")
     parser.add_argument("--version", default=0, type=int, help="the version to export with")
     args = parser.parse_args()
 
-    # load the provided model checkpoint
-    checkpoint_dict = torch.load(args.checkpoint, map_location='cpu')
-    gptconf = ModelArgs(**checkpoint_dict['model_args'])
-    model = Transformer(gptconf)
-    state_dict = checkpoint_dict['model']
-    unwanted_prefix = '_orig_mod.'
-    for k,v in list(state_dict.items()):
-        if k.startswith(unwanted_prefix):
-            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-    model.load_state_dict(state_dict, strict=False)
-    model.eval()
+    if args.checkpoint:
+        model = load_checkpoint(args.checkpoint)
+    elif args.hf:
+        model = load_hf_model(args.hf)
+    else:
+        parser.error("Input model missing: --checkpoint or --hf is required")
+
+    if model is None:
+        parser.error("Can't load input model!")
 
     # export
     model_export(model, args.filepath, args.version)
diff --git a/model.py b/model.py
index 044712f..9e4ce22 100644
--- a/model.py
+++ b/model.py
@@ -17,6 +17,7 @@ class ModelArgs:
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = 32000
+    hidden_dim: Optional[int] = None
     multiple_of: int = 256  # MLP hidden layer size will be multiple of
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
@@ -166,8 +167,10 @@ class Attention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
         super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        if hidden_dim is None:
+            hidden_dim = 4 * dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
@@ -186,7 +189,7 @@ class TransformerBlock(nn.Module):
         self.attention = Attention(args)
         self.feed_forward = FeedForward(
             dim=args.dim,
-            hidden_dim=4 * args.dim,
+            hidden_dim=args.hidden_dim,
             multiple_of=args.multiple_of,
             dropout=args.dropout,
         )
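
Note on reading the flag back (not part of the patch): in the legacy v0 format the shared-classifier flag is carried implicitly in the sign of vocab_size, while v1/v2 write it as an explicit byte inside the zero-padded 256-byte header. A minimal consumer-side sketch for v0, assuming the header layout written above; "model.bin" is a placeholder path:

    import struct

    with open("model.bin", "rb") as f:
        # legacy v0 header: seven int32 values, same format string as the exporter
        dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
            struct.unpack('iiiiiii', f.read(7 * 4))
        # a negative vocab_size means the final classifier weights are stored
        # separately (after freqs_cos/freqs_sin) instead of reusing tok_embeddings
        shared_classifier = vocab_size > 0
        vocab_size = abs(vocab_size)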
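
For reference, typical invocations of the updated CLI would look like the following (the checkpoint path and Hugging Face model name are placeholders, not values from this patch):

    # export a llama2.c training checkpoint in the legacy v0 format
    python export.py model.bin --version 0 --checkpoint out/ckpt.pt

    # export a Hugging Face Llama model in the quantized v2 format
    python export.py model_q.bin --version 2 --hf meta-llama/Llama-2-7b-hf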