From de005474d37d0cde1356739b8c79ebe7b42b5973 Mon Sep 17 00:00:00 2001
From: atamyrat
Date: Mon, 21 Aug 2023 14:13:47 +0300
Subject: [PATCH 1/3] Added load_meta_model() to export.py

---
 export.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 75 insertions(+), 5 deletions(-)

diff --git a/export.py b/export.py
index e486a81..a60d7cf 100644
--- a/export.py
+++ b/export.py
@@ -19,6 +19,9 @@ import gzip
 import shutil
 import struct
 import argparse
+import json
+from pathlib import Path
+
 import numpy as np
 import torch
 from torch import nn
@@ -30,7 +33,7 @@ from model import ModelArgs, Transformer

 def serialize_fp32(file, tensor):
     """ writes one fp32 tensor to file that is open in wb mode """
-    d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
+    d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
     b = struct.pack(f'{len(d)}f', *d)
     file.write(b)

@@ -281,6 +284,71 @@ def load_checkpoint(checkpoint):
     model.eval()
     return model

+def load_meta_model(model_path):
+    params_path = os.path.join(model_path, 'params.json')
+    with open(params_path) as f:
+        params = json.load(f)
+    print(params)
+
+    model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
+    models = [torch.load(p, map_location='cpu') for p in model_paths]
+
+    def concat_weights(models):
+        state_dict = {}
+        for name in list(models[0]):
+            tensors = [model[name] for model in models]
+            if len(tensors) == 1 or len(tensors[0].shape) == 1:
+                state_dict[name] = tensors[0]
+                continue
+            is_axis_1 = (
+                name.startswith('tok_embeddings.')
+                or name.endswith('.attention.wo.weight')
+                or name.endswith('.feed_forward.w2.weight')
+            )
+            axis = 1 if is_axis_1 else 0
+            state_dict[name] = torch.cat(tensors, dim=axis)
+            for model in models:
+                del model[name]
+        return state_dict
+
+    state_dict = concat_weights(models)
+    del models
+
+    # set ModelArgs
+    config = ModelArgs()
+    config.dim = params["dim"]
+    config.n_layers = params["n_layers"]
+    config.n_heads = params["n_heads"]
+    config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
+    config.multiple_of = params["multiple_of"]
+    config.norm_eps = params["norm_eps"]
+
+    config.vocab_size = 32000
+    config.max_seq_len = 2048
+
+    # create a new Transformer object and set weights
+    model = Transformer(config)
+
+    model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
+    model.norm.weight = nn.Parameter(state_dict['norm.weight'])
+
+    for layer in model.layers:
+        i = layer.layer_id
+        layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
+        layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
+        layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
+        layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
+        layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
+        layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
+        layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
+        layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
+        layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
+
+    # final classifier
+    model.output.weight = nn.Parameter(state_dict['output.weight'])
+    model.eval()
+    return model
+

 def load_hf_model(model_path):
     try:
@@ -381,17 +449,19 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("filepath", type=str, help="the output filepath")
-    parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
-    parser.add_argument("--hf", type=str, help="huggingface model")
     parser.add_argument("--version", default=0, type=int, help="the version to export with")
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
+    group.add_argument("--meta-llama", type=str, help="meta llama model path")
+    group.add_argument("--hf", type=str, help="huggingface model path")
     args = parser.parse_args()

     if args.checkpoint:
         model = load_checkpoint(args.checkpoint)
+    elif args.meta_llama:
+        model = load_meta_model(args.meta_llama)
     elif args.hf:
         model = load_hf_model(args.hf)
-    else:
-        parser.error("Input model missing: --checkpoint or --hf is required")

     if model is None:
         parser.error("Can't load input model!")

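A note on the `serialize_fp32()` change in the patch above: numpy has no bfloat16 dtype, so calling `.numpy()` directly on a half-precision checkpoint tensor raises a `TypeError`, whereas casting to float32 in torch first works for any floating dtype. The patch itself does not say why the line changed, so the bf16 motivation is an assumption; below is a minimal, self-contained sketch of the difference using a made-up 4-element tensor rather than real checkpoint weights.

```python
# Minimal sketch: why casting in torch before .numpy() matters for bf16 tensors.
# The tensor here is made up; in export.py the data would be real model weights.
import numpy as np
import torch

t = torch.randn(4, dtype=torch.bfloat16)

try:
    # old code path: .numpy() first, then astype -- numpy cannot represent bfloat16
    d = t.detach().cpu().view(-1).numpy().astype(np.float32)
except TypeError as e:
    print(f"old path fails on bf16: {e}")

# new code path: cast to float32 in torch, then hand a plain fp32 array to numpy
d = t.detach().cpu().view(-1).to(torch.float32).numpy()
print(d.dtype, d.shape)  # float32 (4,)
```
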
From 36a78af5e16e68117d4e1235199938f024e2226c Mon Sep 17 00:00:00 2001
From: atamyrat
Date: Mon, 21 Aug 2023 14:19:56 +0300
Subject: [PATCH 2/3] tested load_meta_model() in export.py, deleting old export_meta_llama_bin.py file

---
 export_meta_llama_bin.py | 112 ---------------------------------------
 1 file changed, 112 deletions(-)
 delete mode 100644 export_meta_llama_bin.py

diff --git a/export_meta_llama_bin.py b/export_meta_llama_bin.py
deleted file mode 100644
index 4e42197..0000000
--- a/export_meta_llama_bin.py
+++ /dev/null
@@ -1,112 +0,0 @@
-"""
-This script exports the Llama 2 weights in llama2c.bin format.
-"""
-import os
-import sys
-import struct
-from pathlib import Path
-import json
-
-import torch
-
-from model import precompute_freqs_cis
-
-
-def export(p, state_dict, filepath='model.bin'):
-    """export the model weights in fp32 into .bin file to be read from C"""
-    f = open(filepath, 'wb')
-
-    def serialize(key):
-        print(f"writing {key}...")
-        t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
-        f.write(memoryview(t))
-        del state_dict[key]
-
-    # first write out the header
-    hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
-    p['vocab_size'] = 32000
-    p['max_seq_len'] = 2048
-
-    n_kv_heads = p.get('n_kv_heads') or p['n_heads']
-    header = struct.pack(
-        'iiiiiii',
-        p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
-        n_kv_heads, -p['vocab_size'], p['max_seq_len']
-    )
-    # NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
-    # in the checkpoint and should be loaded.
-    f.write(header)
-
-    # next write out the embedding weights
-    print("writing tok_embeddings...")
-    serialize('tok_embeddings.weight')
-
-    # now all the layers
-    # attention weights
-    for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
-    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
-    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
-    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
-    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
-    # ffn weights
-    for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
-    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
-    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
-    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
-
-    # final rmsnorm
-    serialize('norm.weight')
-    # freqs_cos, freqs_sin
-    freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
-    state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
-    state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
-    serialize('freqs_cos')
-    serialize('freqs_sin')
-
-    # finally write the output weights
-    serialize('output.weight')
-
-    f.close()
-    print(f"wrote {filepath}")
-
-
-def concat_weights(models):
-    state_dict = {}
-    for name in list(models[0]):
-        tensors = [model[name] for model in models]
-        if len(tensors) == 1 or len(tensors[0].shape) == 1:
-            state_dict[name] = tensors[0]
-            continue
-        is_axis_1 = (
-            name.startswith('tok_embeddings.')
-            or name.endswith('.attention.wo.weight')
-            or name.endswith('.feed_forward.w2.weight')
-        )
-        axis = 1 if is_axis_1 else 0
-        state_dict[name] = torch.cat(tensors, dim=axis)
-        for model in models:
-            del model[name]
-    return state_dict
-
-
-def load_and_export(model_path, output_path):
-    params_path = os.path.join(model_path, 'params.json')
-    with open(params_path) as f:
-        params = json.load(f)
-    print(params)
-
-    model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
-    models = [torch.load(p, map_location='cpu') for p in model_paths]
-    state_dict = concat_weights(models)
-    del models
-    export(params, state_dict, output_path)
-
-
-if __name__ == '__main__':
-    if len(sys.argv) == 1:
-        print('[Llama model folder path] [output path]')
-        exit()
-
-    model_path = sys.argv[1]
-    output_path = sys.argv[2]
-    load_and_export(model_path, output_path)

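For context on the file the deleted script produced: its header is seven 4-byte ints (dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len), and, per the NOTE in the deleted code, a negative vocab_size signals that separate classifier weights follow. This matches the legacy header layout that run.c reads at the start of a .bin file. Below is a small reader sketch for that header; `read_header` is a hypothetical helper added here for illustration, not a function that exists in the repo.

```python
# Illustrative reader for the 7-int header written by the (now deleted) exporter.
# read_header() is a hypothetical helper, not part of the llama2.c codebase.
import struct

def read_header(path):
    with open(path, "rb") as f:
        dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
            struct.unpack("iiiiiii", f.read(7 * 4))
    # the exporter wrote -vocab_size when the final classifier weights are stored
    # separately in the file (i.e. not shared with the token embedding)
    shared_classifier = vocab_size > 0
    return {
        "dim": dim,
        "hidden_dim": hidden_dim,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "n_kv_heads": n_kv_heads,
        "vocab_size": abs(vocab_size),
        "max_seq_len": max_seq_len,
        "shared_classifier": shared_classifier,
    }

# e.g. print(read_header("llama2_7b.bin"))
```
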
From 61c26d5392010495196a1525e104e4e7cdc7aadc Mon Sep 17 00:00:00 2001
From: atamyrat
Date: Mon, 21 Aug 2023 14:24:01 +0300
Subject: [PATCH 3/3] Updated README to replace export_meta_llama_bin.py script with export.py

---
 README.md | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index ff15005..24edca7 100644
--- a/README.md
+++ b/README.md
@@ -65,10 +65,10 @@ Quick note on sampling, the recommendation for ~best results is to sample with `
 
 ## Meta's Llama 2 models
 
 As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format.
-For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export_meta_llama_bin.py` file, e.g. for 7B model:
+For this we need to install the python dependencies (`pip install -r requirements.txt`) and then use the `export.py` file, e.g. for the 7B model:
 
 ```bash
-python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin
+python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
 ```
 
 The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that, despite efforts, the larger exports are still problematic. I would not attempt to run anything above 7B right now for two reasons: first, 13B+ currently doesn't work because of integer overflow in pointer arithmetic, which is yet to be fixed, and second, even if it were fixed, this repo is doing float32 inference right now, so it would be fairly unusably slow. Once the export is done, we can run it:
@@ -316,7 +316,6 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
 
 ## unsorted todos
 
-- delete the export_meta_llama_bin.py and export_meta_llama_hf_bin.py files. instead, import both of these into a proper model.py Transformer instance, and then export using the export script as usual.
 - migrate the code to work with the new versions export and deprecate the original .bin files
 - support Llama 2 7B Chat models and tune run.c to Chat UI/UX
 - make it easier to add a new dataset with not too much pain
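As a rough sanity check on the ~26GB figure quoted in the README hunk above: Llama 2 7B has roughly 6.74 billion parameters (an approximate, commonly cited count, not a value taken from this repo), and the export stores each weight as a 4-byte float32, so the arithmetic works out as sketched below.

```python
# Back-of-the-envelope size of the float32 export for Llama 2 7B.
# 6.74e9 is an approximate parameter count, not a value read from the checkpoint.
n_params = 6.74e9
bytes_per_weight = 4  # float32
total_bytes = n_params * bytes_per_weight
print(f"{total_bytes / 1e9:.1f} GB ({total_bytes / 2**30:.1f} GiB)")  # ~27.0 GB (~25.1 GiB)
```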