mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
Merge branch 'master' into feature/int8_try2
This commit is contained in:
commit
1f8af82130
17
README.md
17
README.md
@ -4,6 +4,8 @@
|
||||
<img src="assets/llama_cute.jpg" width="300" height="300" alt="Cute Llama">
|
||||
</p>
|
||||
|
||||
Have you ever wanted to inference a baby [Llama 2](https://ai.meta.com/llama/) model in pure C? No? Well, now you can!
|
||||
|
||||
Train the Llama 2 LLM architecture in PyTorch then inference it with one simple 700-line C file ([run.c](run.c)). You might think that you need many billion parameter LLMs to do anything useful, but in fact very small LLMs can have surprisingly strong performance if you make the domain narrow enough (ref: [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) paper). This repo is a "fullstack" train + inference solution for Llama 2 LLM, with focus on minimalism and simplicity.
|
||||
|
||||
As the architecture is identical, you can also load and inference Meta's Llama 2 models. However, the current code only inferences models in fp32, so you will most likely not be able to productively load models larger than 7B. Work on model quantization is currently ongoing.
|
||||
@ -14,7 +16,7 @@ Please note that this repo started recently as a fun weekend project: I took my
|
||||
|
||||
[](https://colab.research.google.com/github/karpathy/llama2.c/blob/master/run.ipynb)
|
||||
|
||||
First, navigate to the folder when you keep your projects and clone this repository to this folder:
|
||||
First, navigate to the folder where you keep your projects and clone this repository to this folder:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/karpathy/llama2.c.git
|
||||
@ -109,8 +111,9 @@ Chat with Code Llama Instruct:
|
||||
python export.py codellama2_7b_instruct.bin --meta-llama /path/to/CodeLlama-7b-Instruct
|
||||
python tokenizer.py --tokenizer-model=/path/to/CodeLlama-7b-Instruct/tokenizer.model
|
||||
./run codellama2_7b_instruct.bin -m chat -z /path/to/CodeLlama-7b-Instruct/tokenizer.bin
|
||||
```
|
||||
|
||||
## hugginface models
|
||||
## huggingface models
|
||||
|
||||
We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
|
||||
|
||||
@ -311,6 +314,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
- [llama2-rs](https://github.com/danielgrittner/llama2-rs) by @[danielgrittner](https://github.com/danielgrittner): a Rust port of this project
|
||||
- [llama2.rs](https://github.com/lintian06/llama2.rs) by @[lintian06](https://github.com/lintian06): A Rust port of this project
|
||||
- [pecca.rs](https://github.com/rahoua/pecca-rs) by @[rahoua](https://github.com/rahoua): A Rust port leveraging [ndarray](https://github.com/rust-ndarray/ndarray), supports BLAS.
|
||||
- [llama2.rs](https://github.com/flaneur2020/llama2.rs) by @[flaneur2020](https://github.com/flaneur2020): A Rust port of this project.
|
||||
- Go
|
||||
- [go-llama2](https://github.com/tmc/go-llama2) by @[tmc](https://github.com/tmc): a Go port of this project
|
||||
- [llama2.go](https://github.com/nikolaydubina/llama2.go) by @[nikolaydubina](https://github.com/nikolaydubina): a Go port of this project
|
||||
@ -323,6 +327,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
- [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @[leloykun](https://github.com/leloykun): a C++ port of this project
|
||||
- JavaScript
|
||||
- [llama2.js](https://github.com/epicure/llama2.js) by @[epicure](https://github.com/epicure): a JavaScript port of this project
|
||||
- [llamajs](https://github.com/agershun/llamajs) by @[agershun](https://github.com/agershun): a JavaScript port of this project
|
||||
- [llama2.ts](https://github.com/wizzard0/llama2.ts) by @[oleksandr_now](https://twitter.com/oleksandr_now): a TypeScript port of this project. Full Llama2-7B capable.
|
||||
- [llama2.c-emscripten](https://github.com/gohai/llama2.c-emscripten) by @[gohai](https://github.com/gohai): Emscripten (JavaScript) port, based on @ggerganov's initial prototype
|
||||
- Zig
|
||||
@ -343,8 +348,16 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
|
||||
- [llama2.cs](https://github.com/trrahul/llama2.cs) by @[trrahul](https://github.com/trrahul): a C# port of this project
|
||||
- Dart
|
||||
- [llama2.dart](https://github.com/yiminghan/llama2.dart) by @[yiminghan](https://github.com/yiminghan/llama2.dart): one-file dart port of this project, works with Flutter!
|
||||
- Web
|
||||
- [llama2c-web](https://github.com/dmarcos/llama2.c-web) by @[dmarcos](https://github.com/dmarcos): Super simple way to build unmodified llama2.c to WASM and run it in the browser. [Demo](https://diegomarcos.com/llama2.c-web/)
|
||||
- WebAssembly
|
||||
- [icpp-llm](https://github.com/icppWorld/icpp-llm): LLMs for the Internet Computer
|
||||
- Fortran
|
||||
- [llama2.f90](https://github.com/rbitr/llama2.f90): a Fortran port of this project
|
||||
- Mojo
|
||||
- [llama2.🔥](https://github.com/tairov/llama2.mojo) by @[tairov](https://github.com/tairov): pure Mojo port of this project
|
||||
- OCaml
|
||||
- [llama2.ml](https://github.com/jackpeck/llama2.ml) by @[jackpeck](https://github.com/jackpeck): an OCaml port of this project
|
||||
- [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @[trholding](https://github.com/trholding): Standalone, Bootable & Portable Binary Llama 2
|
||||
- [llama2.c-zh - Bilingual Chinese and English](https://github.com/chenyangMl/llama2.c-zh) by @[chenyangMl](https://github.com/chenyangMl): Expand tokenizer to support training and inference in both Chinese and English
|
||||
|
||||
|
||||
100
export.py
100
export.py
@ -259,6 +259,96 @@ def version2_export(model, filepath, group_size=64):
|
||||
out_file.close()
|
||||
print(f"wrote {filepath}")
|
||||
|
||||
def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
|
||||
""" Generate the pytorch_model.bin state_dict and config.json for HuggingFace """
|
||||
|
||||
try:
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
except ImportError:
|
||||
print("Error: transformers package is required to load huggingface models")
|
||||
print("Please run `pip install transformers` to install it")
|
||||
return None
|
||||
|
||||
# Generate LlamaModel state_dict
|
||||
hf_state_dict = {}
|
||||
|
||||
# Sometimes we have repeated key values for the heads
|
||||
dim = llama_model.params.dim
|
||||
num_key_value_heads = llama_model.params.n_kv_heads
|
||||
n_rep = llama_model.params.n_heads // num_key_value_heads
|
||||
key_value_dim = dim // n_rep
|
||||
|
||||
# HuggingFace needs the weights permuted.
|
||||
# See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
|
||||
def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
|
||||
return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
# Transfer weights from llama model to the HF state dictionary format
|
||||
hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
|
||||
hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)
|
||||
|
||||
# Add each layer's weights to the HF state dictionary
|
||||
for i, layer in enumerate(llama_model.layers):
|
||||
layer_id = layer.layer_id
|
||||
hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
|
||||
hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
|
||||
|
||||
# llama2.c usually uses tied weights -> reference the embed_tokens.weights instead
|
||||
hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
|
||||
|
||||
# We check that the embeddings are tied, else use manual output weights
|
||||
_embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight)
|
||||
if not _embeddings_are_tied:
|
||||
hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
|
||||
|
||||
|
||||
# Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
|
||||
|
||||
# Extract necessary attributes from llama.c model
|
||||
vocab_size = llama_model.params.vocab_size
|
||||
hidden_size = llama_model.params.dim
|
||||
intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0]
|
||||
num_hidden_layers = llama_model.params.n_layers
|
||||
num_attention_heads = llama_model.params.n_heads
|
||||
num_key_value_heads = llama_model.params.n_kv_heads
|
||||
max_position_embeddings = llama_model.params.max_seq_len
|
||||
rms_norm_eps = llama_model.params.norm_eps
|
||||
|
||||
# TODO check values for:
|
||||
# pretraining_tp, initializer_range, use_cache,
|
||||
# rope_theta, and rope_scaling.
|
||||
|
||||
config = LlamaConfig(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
tie_word_embeddings=_embeddings_are_tied,
|
||||
# Manual
|
||||
architectures=["LlamaForCausalLM"],
|
||||
hidden_act="silu",
|
||||
)
|
||||
|
||||
|
||||
# Save files in directory filepath
|
||||
# First make the directory if it doesn't exist
|
||||
os.makedirs(filepath, exist_ok=True)
|
||||
|
||||
# Save the state dictionary in .bin format, and config as .json
|
||||
torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin"))
|
||||
config.save_pretrained(filepath)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Load / import functions
|
||||
@ -399,12 +489,14 @@ def load_hf_model(model_path):
|
||||
# -----------------------------------------------------------------------------
|
||||
# API entrypoint
|
||||
|
||||
def model_export(model, filepath, version):
|
||||
def model_export(model, filepath, version, dtype=torch.float32):
|
||||
"""
|
||||
Versions docs:
|
||||
v-1:huggingface export, i.e. intended for use outside of this repo, in HF
|
||||
v0: legacy llama2.c float format, DEPRECATED
|
||||
v1: float32 export
|
||||
v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
|
||||
# TODO: add dtype export support for other versions (?)
|
||||
"""
|
||||
if version == 0:
|
||||
legacy_export(model, filepath)
|
||||
@ -412,6 +504,8 @@ def model_export(model, filepath, version):
|
||||
version1_export(model, filepath)
|
||||
elif version == 2:
|
||||
version2_export(model, filepath)
|
||||
elif version == -1:
|
||||
hf_export(model, filepath, dtype)
|
||||
else:
|
||||
raise ValueError(f"unknown version {version}")
|
||||
|
||||
@ -451,11 +545,13 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("filepath", type=str, help="the output filepath")
|
||||
parser.add_argument("--version", default=0, type=int, help="the version to export with")
|
||||
parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
|
||||
group.add_argument("--meta-llama", type=str, help="meta llama model path")
|
||||
group.add_argument("--hf", type=str, help="huggingface model path")
|
||||
args = parser.parse_args()
|
||||
dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
|
||||
|
||||
if args.checkpoint:
|
||||
model = load_checkpoint(args.checkpoint)
|
||||
@ -468,4 +564,4 @@ if __name__ == "__main__":
|
||||
parser.error("Can't load input model!")
|
||||
|
||||
# export
|
||||
model_export(model, args.filepath, args.version)
|
||||
model_export(model, args.filepath, args.version, args.dtype)
|
||||
|
||||
47
run.c
47
run.c
@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) {
|
||||
s->hb = calloc(p->hidden_dim, sizeof(float));
|
||||
s->hb2 = calloc(p->hidden_dim, sizeof(float));
|
||||
s->q = calloc(p->dim, sizeof(float));
|
||||
s->k = calloc(kv_dim, sizeof(float));
|
||||
s->v = calloc(kv_dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
// ensure all mallocs went fine
|
||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
||||
|| !s->value_cache) {
|
||||
|| !s->key_cache || !s->value_cache || !s->att || !s->logits) {
|
||||
fprintf(stderr, "malloc failed!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
@ -105,8 +102,6 @@ void free_run_state(RunState* s) {
|
||||
free(s->hb);
|
||||
free(s->hb2);
|
||||
free(s->q);
|
||||
free(s->k);
|
||||
free(s->v);
|
||||
free(s->att);
|
||||
free(s->logits);
|
||||
free(s->key_cache);
|
||||
@ -115,26 +110,28 @@ void free_run_state(RunState* s) {
|
||||
|
||||
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
||||
int head_size = p->dim / p->n_heads;
|
||||
// make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
|
||||
unsigned long long n_layers = p->n_layers;
|
||||
w->token_embedding_table = ptr;
|
||||
ptr += p->vocab_size * p->dim;
|
||||
w->rms_att_weight = ptr;
|
||||
ptr += p->n_layers * p->dim;
|
||||
ptr += n_layers * p->dim;
|
||||
w->wq = ptr;
|
||||
ptr += p->n_layers * p->dim * (p->n_heads * head_size);
|
||||
ptr += n_layers * p->dim * (p->n_heads * head_size);
|
||||
w->wk = ptr;
|
||||
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wv = ptr;
|
||||
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wo = ptr;
|
||||
ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
|
||||
ptr += n_layers * (p->n_heads * head_size) * p->dim;
|
||||
w->rms_ffn_weight = ptr;
|
||||
ptr += p->n_layers * p->dim;
|
||||
ptr += n_layers * p->dim;
|
||||
w->w1 = ptr;
|
||||
ptr += p->n_layers * p->dim * p->hidden_dim;
|
||||
ptr += n_layers * p->dim * p->hidden_dim;
|
||||
w->w2 = ptr;
|
||||
ptr += p->n_layers * p->hidden_dim * p->dim;
|
||||
ptr += n_layers * p->hidden_dim * p->dim;
|
||||
w->w3 = ptr;
|
||||
ptr += p->n_layers * p->dim * p->hidden_dim;
|
||||
ptr += n_layers * p->dim * p->hidden_dim;
|
||||
w->rms_final_weight = ptr;
|
||||
ptr += p->dim;
|
||||
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
|
||||
@ -249,11 +246,16 @@ float* forward(Transformer* transformer, int token, int pos) {
|
||||
memcpy(x, content_row, dim*sizeof(*x));
|
||||
|
||||
// forward all the layers
|
||||
for(int l = 0; l < p->n_layers; l++) {
|
||||
for(unsigned long long l = 0; l < p->n_layers; l++) {
|
||||
|
||||
// attention rmsnorm
|
||||
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
|
||||
|
||||
// key and value point to the kv cache
|
||||
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
|
||||
s->k = s->key_cache + loff + pos * kv_dim;
|
||||
s->v = s->value_cache + loff + pos * kv_dim;
|
||||
|
||||
// qkv matmuls for this position
|
||||
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
|
||||
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
||||
@ -276,13 +278,6 @@ float* forward(Transformer* transformer, int token, int pos) {
|
||||
}
|
||||
}
|
||||
|
||||
// save key,value at this time step (pos) to our kv cache
|
||||
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
|
||||
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
|
||||
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
|
||||
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
|
||||
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
|
||||
|
||||
// multihead attention. iterate over all heads
|
||||
int h;
|
||||
#pragma omp parallel for private(h)
|
||||
@ -754,7 +749,7 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
// forward the transformer to get logits for the next token
|
||||
float* logits = forward(transformer, token, pos);
|
||||
|
||||
// advance the state state machine
|
||||
// advance the state machine
|
||||
if (pos < num_prompt_tokens - 1) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos + 1];
|
||||
|
||||
@ -88,7 +88,7 @@ def train_vocab(vocab_size):
|
||||
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
|
||||
|
||||
print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
|
||||
with open(tiny_file, "w") as of:
|
||||
with open(tiny_file, "w", encoding="utf-8") as of:
|
||||
for shard in tqdm(shard_filenames[:num_shards]):
|
||||
with open(shard, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user