Merge pull request #329 from atamurad/import_meta

Moved export_meta_llama_bin.py to new export.py
This commit is contained in:
Andrej 2023-08-21 07:34:32 -07:00 committed by GitHub
commit 8a3ea7b433
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 120 deletions

View File

@ -65,10 +65,10 @@ Quick note on sampling, the recommendation for ~best results is to sample with `
## Meta's Llama 2 models
As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format.
For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export_meta_llama_bin.py` file, e.g. for 7B model:
For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export.py` file, e.g. for 7B model:
```bash
python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin
python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
```
The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts. I would not attempt to run anything above 7B right now for two reasons: first, 13B+ currently doesn't work because of integer flow in pointer arithmetic, which is yet to be fixed, and second, even if it were fixed, this repo is doing float32 inference right now, so it would be fairly unusably slow. Once the export is done, we can run it:
@ -316,7 +316,6 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
## unsorted todos
- delete the export_meta_llama_bin.py and export_meta_llama_hf_bin.py files. instead, import both of these into a proper model.py Transformer instance, and then export using the export script as usual.
- migrate the code to work with the new versions export and deprecate the original .bin files
- support Llama 2 7B Chat models and tune run.c to Chat UI/UX
- make it easier to add a new dataset with not too much pain

View File

@ -19,6 +19,9 @@ import gzip
import shutil
import struct
import argparse
import json
from pathlib import Path
import numpy as np
import torch
from torch import nn
@ -30,7 +33,7 @@ from model import ModelArgs, Transformer
def serialize_fp32(file, tensor):
""" writes one fp32 tensor to file that is open in wb mode """
d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
b = struct.pack(f'{len(d)}f', *d)
file.write(b)
@ -281,6 +284,71 @@ def load_checkpoint(checkpoint):
model.eval()
return model
def load_meta_model(model_path):
params_path = os.path.join(model_path, 'params.json')
with open(params_path) as f:
params = json.load(f)
print(params)
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
models = [torch.load(p, map_location='cpu') for p in model_paths]
def concat_weights(models):
state_dict = {}
for name in list(models[0]):
tensors = [model[name] for model in models]
if len(tensors) == 1 or len(tensors[0].shape) == 1:
state_dict[name] = tensors[0]
continue
is_axis_1 = (
name.startswith('tok_embeddings.')
or name.endswith('.attention.wo.weight')
or name.endswith('.feed_forward.w2.weight')
)
axis = 1 if is_axis_1 else 0
state_dict[name] = torch.cat(tensors, dim=axis)
for model in models:
del model[name]
return state_dict
state_dict = concat_weights(models)
del models
# set ModelArgs
config = ModelArgs()
config.dim = params["dim"]
config.n_layers = params["n_layers"]
config.n_heads = params["n_heads"]
config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
config.multiple_of = params["multiple_of"]
config.norm_eps = params["norm_eps"]
config.vocab_size = 32000
config.max_seq_len = 2048
# create a new Transformer object and set weights
model = Transformer(config)
model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
model.norm.weight = nn.Parameter(state_dict['norm.weight'])
for layer in model.layers:
i = layer.layer_id
layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
# final classifier
model.output.weight = nn.Parameter(state_dict['output.weight'])
model.eval()
return model
def load_hf_model(model_path):
try:
@ -381,17 +449,19 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("filepath", type=str, help="the output filepath")
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
parser.add_argument("--hf", type=str, help="huggingface model")
parser.add_argument("--version", default=0, type=int, help="the version to export with")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
group.add_argument("--meta-llama", type=str, help="meta llama model path")
group.add_argument("--hf", type=str, help="huggingface model path")
args = parser.parse_args()
if args.checkpoint:
model = load_checkpoint(args.checkpoint)
elif args.meta_llama:
model = load_meta_model(args.meta_llama)
elif args.hf:
model = load_hf_model(args.hf)
else:
parser.error("Input model missing: --checkpoint or --hf is required")
if model is None:
parser.error("Can't load input model!")

View File

@ -1,112 +0,0 @@
"""
This script exports the Llama 2 weights in llama2c.bin format.
"""
import os
import sys
import struct
from pathlib import Path
import json
import torch
from model import precompute_freqs_cis
def export(p, state_dict, filepath='model.bin'):
"""export the model weights in fp32 into .bin file to be read from C"""
f = open(filepath, 'wb')
def serialize(key):
print(f"writing {key}...")
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
f.write(memoryview(t))
del state_dict[key]
# first write out the header
hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
p['vocab_size'] = 32000
p['max_seq_len'] = 2048
n_kv_heads = p.get('n_kv_heads') or p['n_heads']
header = struct.pack(
'iiiiiii',
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
n_kv_heads, -p['vocab_size'], p['max_seq_len']
)
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
# in the checkpoint and should be loaded.
f.write(header)
# next write out the embedding weights
print("writing tok_embeddings...")
serialize('tok_embeddings.weight')
# now all the layers
# attention weights
for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
# ffn weights
for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
# final rmsnorm
serialize('norm.weight')
# freqs_cos, freqs_sin
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
serialize('freqs_cos')
serialize('freqs_sin')
# finally write the output weights
serialize('output.weight')
f.close()
print(f"wrote {filepath}")
def concat_weights(models):
state_dict = {}
for name in list(models[0]):
tensors = [model[name] for model in models]
if len(tensors) == 1 or len(tensors[0].shape) == 1:
state_dict[name] = tensors[0]
continue
is_axis_1 = (
name.startswith('tok_embeddings.')
or name.endswith('.attention.wo.weight')
or name.endswith('.feed_forward.w2.weight')
)
axis = 1 if is_axis_1 else 0
state_dict[name] = torch.cat(tensors, dim=axis)
for model in models:
del model[name]
return state_dict
def load_and_export(model_path, output_path):
params_path = os.path.join(model_path, 'params.json')
with open(params_path) as f:
params = json.load(f)
print(params)
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
models = [torch.load(p, map_location='cpu') for p in model_paths]
state_dict = concat_weights(models)
del models
export(params, state_dict, output_path)
if __name__ == '__main__':
if len(sys.argv) == 1:
print('[Llama model folder path] [output path]')
exit()
model_path = sys.argv[1]
output_path = sys.argv[2]
load_and_export(model_path, output_path)