mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
Merge pull request #329 from atamurad/import_meta
Moved export_meta_llama_bin.py to new export.py
This commit is contained in:
commit
8a3ea7b433
@ -65,10 +65,10 @@ Quick note on sampling, the recommendation for ~best results is to sample with `
|
||||
## Meta's Llama 2 models
|
||||
|
||||
As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format.
|
||||
For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export_meta_llama_bin.py` file, e.g. for 7B model:
|
||||
For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export.py` file, e.g. for 7B model:
|
||||
|
||||
```bash
|
||||
python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin
|
||||
python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
|
||||
```
|
||||
|
||||
The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts. I would not attempt to run anything above 7B right now for two reasons: first, 13B+ currently doesn't work because of integer flow in pointer arithmetic, which is yet to be fixed, and second, even if it were fixed, this repo is doing float32 inference right now, so it would be fairly unusably slow. Once the export is done, we can run it:
|
||||
@ -316,7 +316,6 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
|
||||
## unsorted todos
|
||||
|
||||
- delete the export_meta_llama_bin.py and export_meta_llama_hf_bin.py files. instead, import both of these into a proper model.py Transformer instance, and then export using the export script as usual.
|
||||
- migrate the code to work with the new versions export and deprecate the original .bin files
|
||||
- support Llama 2 7B Chat models and tune run.c to Chat UI/UX
|
||||
- make it easier to add a new dataset with not too much pain
|
||||
|
||||
80
export.py
80
export.py
@ -19,6 +19,9 @@ import gzip
|
||||
import shutil
|
||||
import struct
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -30,7 +33,7 @@ from model import ModelArgs, Transformer
|
||||
|
||||
def serialize_fp32(file, tensor):
|
||||
""" writes one fp32 tensor to file that is open in wb mode """
|
||||
d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
|
||||
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
|
||||
b = struct.pack(f'{len(d)}f', *d)
|
||||
file.write(b)
|
||||
|
||||
@ -281,6 +284,71 @@ def load_checkpoint(checkpoint):
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def load_meta_model(model_path):
|
||||
params_path = os.path.join(model_path, 'params.json')
|
||||
with open(params_path) as f:
|
||||
params = json.load(f)
|
||||
print(params)
|
||||
|
||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||
|
||||
def concat_weights(models):
|
||||
state_dict = {}
|
||||
for name in list(models[0]):
|
||||
tensors = [model[name] for model in models]
|
||||
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||
state_dict[name] = tensors[0]
|
||||
continue
|
||||
is_axis_1 = (
|
||||
name.startswith('tok_embeddings.')
|
||||
or name.endswith('.attention.wo.weight')
|
||||
or name.endswith('.feed_forward.w2.weight')
|
||||
)
|
||||
axis = 1 if is_axis_1 else 0
|
||||
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||
for model in models:
|
||||
del model[name]
|
||||
return state_dict
|
||||
|
||||
state_dict = concat_weights(models)
|
||||
del models
|
||||
|
||||
# set ModelArgs
|
||||
config = ModelArgs()
|
||||
config.dim = params["dim"]
|
||||
config.n_layers = params["n_layers"]
|
||||
config.n_heads = params["n_heads"]
|
||||
config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
|
||||
config.multiple_of = params["multiple_of"]
|
||||
config.norm_eps = params["norm_eps"]
|
||||
|
||||
config.vocab_size = 32000
|
||||
config.max_seq_len = 2048
|
||||
|
||||
# create a new Transformer object and set weights
|
||||
model = Transformer(config)
|
||||
|
||||
model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
|
||||
model.norm.weight = nn.Parameter(state_dict['norm.weight'])
|
||||
|
||||
for layer in model.layers:
|
||||
i = layer.layer_id
|
||||
layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
|
||||
layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
|
||||
layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
|
||||
layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
|
||||
layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
|
||||
layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
|
||||
layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
|
||||
layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
|
||||
layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
|
||||
|
||||
# final classifier
|
||||
model.output.weight = nn.Parameter(state_dict['output.weight'])
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def load_hf_model(model_path):
|
||||
|
||||
try:
|
||||
@ -381,17 +449,19 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("filepath", type=str, help="the output filepath")
|
||||
parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
||||
parser.add_argument("--hf", type=str, help="huggingface model")
|
||||
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
||||
group.add_argument("--meta-llama", type=str, help="meta llama model path")
|
||||
group.add_argument("--hf", type=str, help="huggingface model path")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.checkpoint:
|
||||
model = load_checkpoint(args.checkpoint)
|
||||
elif args.meta_llama:
|
||||
model = load_meta_model(args.meta_llama)
|
||||
elif args.hf:
|
||||
model = load_hf_model(args.hf)
|
||||
else:
|
||||
parser.error("Input model missing: --checkpoint or --hf is required")
|
||||
|
||||
if model is None:
|
||||
parser.error("Can't load input model!")
|
||||
|
||||
@ -1,112 +0,0 @@
|
||||
"""
|
||||
This script exports the Llama 2 weights in llama2c.bin format.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
from model import precompute_freqs_cis
|
||||
|
||||
|
||||
def export(p, state_dict, filepath='model.bin'):
|
||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
||||
f = open(filepath, 'wb')
|
||||
|
||||
def serialize(key):
|
||||
print(f"writing {key}...")
|
||||
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
|
||||
f.write(memoryview(t))
|
||||
del state_dict[key]
|
||||
|
||||
# first write out the header
|
||||
hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
|
||||
p['vocab_size'] = 32000
|
||||
p['max_seq_len'] = 2048
|
||||
|
||||
n_kv_heads = p.get('n_kv_heads') or p['n_heads']
|
||||
header = struct.pack(
|
||||
'iiiiiii',
|
||||
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
|
||||
n_kv_heads, -p['vocab_size'], p['max_seq_len']
|
||||
)
|
||||
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
|
||||
# in the checkpoint and should be loaded.
|
||||
f.write(header)
|
||||
|
||||
# next write out the embedding weights
|
||||
print("writing tok_embeddings...")
|
||||
serialize('tok_embeddings.weight')
|
||||
|
||||
# now all the layers
|
||||
# attention weights
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
|
||||
# ffn weights
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
|
||||
for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
|
||||
|
||||
# final rmsnorm
|
||||
serialize('norm.weight')
|
||||
# freqs_cos, freqs_sin
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
|
||||
state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
|
||||
state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
|
||||
serialize('freqs_cos')
|
||||
serialize('freqs_sin')
|
||||
|
||||
# finally write the output weights
|
||||
serialize('output.weight')
|
||||
|
||||
f.close()
|
||||
print(f"wrote {filepath}")
|
||||
|
||||
|
||||
def concat_weights(models):
|
||||
state_dict = {}
|
||||
for name in list(models[0]):
|
||||
tensors = [model[name] for model in models]
|
||||
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||
state_dict[name] = tensors[0]
|
||||
continue
|
||||
is_axis_1 = (
|
||||
name.startswith('tok_embeddings.')
|
||||
or name.endswith('.attention.wo.weight')
|
||||
or name.endswith('.feed_forward.w2.weight')
|
||||
)
|
||||
axis = 1 if is_axis_1 else 0
|
||||
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||
for model in models:
|
||||
del model[name]
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_and_export(model_path, output_path):
|
||||
params_path = os.path.join(model_path, 'params.json')
|
||||
with open(params_path) as f:
|
||||
params = json.load(f)
|
||||
print(params)
|
||||
|
||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||
state_dict = concat_weights(models)
|
||||
del models
|
||||
export(params, state_dict, output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) == 1:
|
||||
print('[Llama model folder path] [output path]')
|
||||
exit()
|
||||
|
||||
model_path = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
load_and_export(model_path, output_path)
|
||||
Loading…
Reference in New Issue
Block a user