Mirror of https://github.com/trholding/llama2.c.git (synced 2026-02-06 11:26:53 +00:00)

Merge remote-tracking branch 'upstream/master'
Commit f7a7ed94c8

.github/workflows/build.yml (vendored, 83 lines changed)
@ -4,10 +4,12 @@ on:
push:
branches:
- master
paths: ['.github/workflows/**', '**/Makefile', '**/*.c', '**/*.h']
paths: ['.github/workflows/**', '**/Makefile', '**/*.c', '**/*.h', '**/*.py']
pull_request:
types: [opened, synchronize, reopened]
paths: ['**/Makefile', '**/*.c', '**/*.h']
paths: ['**/Makefile', '**/*.c', '**/*.h', '**/*.py']
# for manual triggering
workflow_dispatch:

env:
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}

@ -15,7 +17,7 @@ env:
jobs:
# check basic builds to avoid breaking changes
ubuntu-focal-make:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest

steps:
- name: Clone

@ -28,6 +30,16 @@ jobs:
sudo apt-get update
sudo apt-get install build-essential -y

- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"

- name: Pip setup
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

- name: Build
id: make_build
run: |

@ -38,6 +50,10 @@ jobs:
run: |
make runfast

- name: Test with pytest
run: |
pytest

macOS-latest-make:
runs-on: macos-latest

@ -52,6 +68,21 @@ jobs:
run: |
brew update

- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"

- name: Pip setup
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

- name: Build clang
id: make_build_clang
run: |
make run CC=clang

- name: Build
id: make_build
run: |

@ -62,15 +93,17 @@ jobs:
run: |
make runfast

- name: Build clang
id: make_build_clang
run: |
make run CC=clang
- name: Test with pytest
run: pytest


windows-latest-make:
runs-on: windows-latest

strategy:
fail-fast: false # necessary, otherwise the matrix breaks
matrix:
arch:
- amd64

@ -90,11 +123,30 @@ jobs:
with:
arch: ${{ matrix.arch }}

- name: Set up Python 3.10
if: matrix.arch != 'amd64_arm64'
uses: actions/setup-python@v3
with:
python-version: "3.10"

- name: Pip setup
if: matrix.arch != 'amd64_arm64'
run: |
python -m pip install --upgrade pip
if (Test-Path requirements.txt) {
pip install -r requirements.txt
}

- name: Build ${{ matrix.arch }}
id: build_msvc
run: |
.\build_msvc.bat

# cross-compiled, cannot be run on host
- name: Test with pytest
if: matrix.arch != 'amd64_arm64'
run: pytest

windows-latest-mingw:
runs-on: windows-latest

@ -122,3 +174,20 @@ jobs:
id: build_mingw
run: |
make win64

- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"

- name: Pip setup
shell: powershell
run: |
python -m pip install --upgrade pip
if (Test-Path requirements.txt) {
pip install -r requirements.txt
}

- name: Test with pytest
shell: powershell
run: pytest
Makefile (10 lines changed)

@ -93,6 +93,16 @@ cosmorun:
zip run.com out/model.bin
zip run.com tokenizer.bin

# run all tests
.PHONY: test
test:
pytest

# run only tests for run.c C implementation (is a bit faster if only C code changed)
.PHONY: testc
testc:
pytest -k runc

.PHONY: clean
clean:
rm -f run
export_meta_llama_hf_bin.py (new file, 113 lines)
@ -0,0 +1,113 @@
|
||||
"""
|
||||
This script exports the Llama 2 weights in llama2c.bin format.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
from model import precompute_freqs_cis
|
||||
|
||||
|
||||
def export(p, state_dict, filepath='model.bin'):
|
||||
"""export the model weights in fp32 into .bin file to be read from C"""
|
||||
f = open(filepath, 'wb')
|
||||
|
||||
def serialize(key):
|
||||
print(f"writing {key}...")
|
||||
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
|
||||
f.write(memoryview(t))
|
||||
del state_dict[key]
|
||||
|
||||
# first write out the header
|
||||
hidden_dim = state_dict['model.layers.0.mlp.gate_proj.weight'].shape[0]
|
||||
p['vocab_size'] = 32000
|
||||
p['max_seq_len'] = 2048
|
||||
|
||||
n_kv_heads = p.get('n_kv_heads') or p['n_heads']
|
||||
header = struct.pack(
|
||||
'iiiiiii',
|
||||
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
|
||||
n_kv_heads, -p['vocab_size'], p['max_seq_len']
|
||||
)
|
||||
# NOTE ABOVE: negative vocab_size indicates that the classifier weights are present
|
||||
# in the checkpoint and should be loaded.
|
||||
f.write(header)
|
||||
|
||||
# next write out the embedding weights
|
||||
print("writing tok_embeddings...")
|
||||
serialize('model.embed_tokens.weight')
|
||||
|
||||
# now all the layers
|
||||
# attention weights
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.input_layernorm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.q_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.k_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.v_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.o_proj.weight')
|
||||
# ffn weights
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.post_attention_layernorm.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.gate_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.down_proj.weight')
|
||||
for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.up_proj.weight')
|
||||
|
||||
# final rmsnorm
|
||||
serialize('model.norm.weight')
|
||||
# freqs_cos, freqs_sin
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
|
||||
state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
|
||||
state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
|
||||
# check if this requires additional conversion
|
||||
serialize('freqs_cos')
|
||||
serialize('freqs_sin')
|
||||
|
||||
# finally write the output weights
|
||||
serialize('lm_head.weight')
|
||||
|
||||
f.close()
|
||||
print(f"wrote {filepath}")
|
||||
|
||||
|
||||
def concat_weights(models):
|
||||
state_dict = {}
|
||||
for name in list(models[0]):
|
||||
tensors = [model[name] for model in models]
|
||||
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||
state_dict[name] = tensors[0]
|
||||
continue
|
||||
is_axis_1 = (
|
||||
name.startswith('model.embed_tokens.weight')
|
||||
or name.endswith('.self_attn.o_proj.weight')
|
||||
or name.endswith('.mlp.down_proj.weight')
|
||||
)
|
||||
axis = 1 if is_axis_1 else 0
|
||||
state_dict[name] = torch.cat(tensors, dim=axis)
|
||||
for model in models:
|
||||
del model[name]
|
||||
return state_dict
|
||||
|
||||
|
||||
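As an illustrative aside (Python, not part of the file above): concat_weights joins model-parallel shards along dim=1 for embed_tokens, o_proj and down_proj, presumably because those weights are sharded along their second dimension in the consolidated.*.pth files, and along dim=0 for everything else. A minimal sketch of the two cases, with made-up shapes:

import torch

shard_a = torch.zeros(2, 3)  # stand-in for one model-parallel shard
shard_b = torch.ones(2, 3)   # stand-in for the other shard

stacked_rows = torch.cat([shard_a, shard_b], dim=0)  # (4, 3): axis 0 case
widened_cols = torch.cat([shard_a, shard_b], dim=1)  # (2, 6): axis 1 case
print(stacked_rows.shape, widened_cols.shape)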
def load_and_export(model_path, output_path):
|
||||
params_path = os.path.join(model_path, 'params.json')
|
||||
with open(params_path) as f:
|
||||
params = json.load(f)
|
||||
print(params)
|
||||
|
||||
model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
|
||||
models = [torch.load(p, map_location='cpu') for p in model_paths]
|
||||
state_dict = concat_weights(models)
|
||||
del models
|
||||
export(params, state_dict, output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) == 1:
|
||||
print('[Llama model folder path] [output path]')
|
||||
exit()
|
||||
|
||||
model_path = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
load_and_export(model_path, output_path)
|
||||
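As a sanity check on the header layout written by export() above, a reader of the .bin file can unpack the first seven int32 values in the same order as the struct.pack call; a negative vocab_size signals that separate classifier (lm_head) weights follow, per the note in the script. A minimal sketch (the dict keys are just labels chosen here):

import struct

def read_header(path="model.bin"):
    with open(path, "rb") as f:
        dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
            struct.unpack("iiiiiii", f.read(7 * 4))
    shared_classifier = vocab_size > 0  # negative value => unshared lm_head weights
    return dict(dim=dim, hidden_dim=hidden_dim, n_layers=n_layers, n_heads=n_heads,
                n_kv_heads=n_kv_heads, vocab_size=abs(vocab_size),
                max_seq_len=max_seq_len, shared_classifier=shared_classifier)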
model.py (8 lines changed)

@ -11,12 +11,13 @@ from torch import nn

@dataclass
class ModelArgs:
# default hyperparameters for the Llama 7B model
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
vocab_size: int = 32000
multiple_of: int = 256 # MLP hidden layer size will be multiple of
norm_eps: float = 1e-5
max_seq_len: int = 2048
dropout: float = 0.0

@ -93,6 +94,7 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
assert args.n_heads % self.n_kv_heads == 0
model_parallel_size = 1
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size

@ -317,7 +319,7 @@ class Transformer(nn.Module):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
logits = self(idx_cond)
logits = logits[:, -1, :] # crop to just the final time step
if temperature == 0.0:
# "sample" the single most likely index
requirements.txt

@ -2,7 +2,6 @@ numpy==1.23.5
pytest==7.4.0
Requests==2.31.0
sentencepiece==0.1.99
tiktoken==0.3.3
torch==2.0.1
tqdm==4.64.1
wandb==0.15.5
run.c (362 lines changed)
@ -1,23 +1,5 @@
|
||||
/*
|
||||
Inference for Llama-2 Transformer model in pure C.
|
||||
|
||||
Example compile: (see README for more details)
|
||||
$ gcc -O3 -o run run.c -lm
|
||||
|
||||
Then run with:
|
||||
$ ./run
|
||||
|
||||
Example compile for a portable executable with embedded model:
|
||||
$ cosmocc -O3 -Ofast -funsafe-math-optimizations -ffast-math -D COSMO_BLINK \
|
||||
-D COSMO_METAL -D COSMO_ZIP -o run.com run.c -lm
|
||||
|
||||
Add checkpoint and tokenizer model to executable:
|
||||
$ zip run.com out/model.bin
|
||||
$ zip run.com tokenizer.bin
|
||||
|
||||
Then copy to any system (Linux,Win,Mac),(x86_64,ARM64) and run with:
|
||||
$ ./run.com
|
||||
*/
|
||||
/* Inference for Llama-2 Transformer model in pure C
|
||||
The Llama 2 Everywhere fork */
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Actually Portable Executable Format Preprocessor Directives
|
||||
@ -76,6 +58,7 @@ __static_yoink("zipos");
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <ctype.h>
|
||||
#include <time.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
@ -105,18 +88,18 @@ typedef struct {
|
||||
// weights for rmsnorms
|
||||
float* rms_att_weight; // (layer, dim) rmsnorm weights
|
||||
float* rms_ffn_weight; // (layer, dim)
|
||||
// weights for matmuls
|
||||
float* wq; // (layer, dim, dim)
|
||||
float* wk; // (layer, dim, dim)
|
||||
float* wv; // (layer, dim, dim)
|
||||
float* wo; // (layer, dim, dim)
|
||||
// weights for matmuls. note dim == n_heads * head_size
|
||||
float* wq; // (layer, dim, n_heads * head_size)
|
||||
float* wk; // (layer, dim, n_kv_heads * head_size)
|
||||
float* wv; // (layer, dim, n_kv_heads * head_size)
|
||||
float* wo; // (layer, n_heads * head_size, dim)
|
||||
// weights for ffn
|
||||
float* w1; // (layer, hidden_dim, dim)
|
||||
float* w2; // (layer, dim, hidden_dim)
|
||||
float* w3; // (layer, hidden_dim, dim)
|
||||
// final rmsnorm
|
||||
float* rms_final_weight; // (dim,)
|
||||
// freq_cis for RoPE relatively positional embeddings
|
||||
// freq_cis for RoPE relative positional embeddings (not used anymore)
|
||||
float* freq_cis_real; // (seq_len, head_size/2)
|
||||
float* freq_cis_imag; // (seq_len, head_size/2)
|
||||
// (optional) classifier weights for the logits, on the last layer
|
||||
@ -148,24 +131,25 @@ typedef struct {
|
||||
|
||||
void malloc_run_state(RunState* s, Config* p) {
|
||||
// we calloc instead of malloc to keep valgrind happy
|
||||
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
||||
s->x = calloc(p->dim, sizeof(float));
|
||||
s->xb = calloc(p->dim, sizeof(float));
|
||||
s->xb2 = calloc(p->dim, sizeof(float));
|
||||
s->hb = calloc(p->hidden_dim, sizeof(float));
|
||||
s->hb2 = calloc(p->hidden_dim, sizeof(float));
|
||||
s->q = calloc(p->dim, sizeof(float));
|
||||
s->k = calloc(p->dim, sizeof(float));
|
||||
s->v = calloc(p->dim, sizeof(float));
|
||||
s->k = calloc(kv_dim, sizeof(float));
|
||||
s->v = calloc(kv_dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
|
||||
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
||||
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
|
||||
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
// ensure all mallocs went fine
|
||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
||||
|| !s->value_cache || !s->probindex) {
|
||||
printf("malloc failed!\n");
|
||||
fprintf(stderr, "malloc failed!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
@ -189,20 +173,20 @@ void free_run_state(RunState* s) {
|
||||
// ----------------------------------------------------------------------------
|
||||
// initialization: read from checkpoint
|
||||
|
||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
|
||||
float* ptr = f;
|
||||
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
|
||||
int head_size = p->dim / p->n_heads;
|
||||
w->token_embedding_table = ptr;
|
||||
ptr += p->vocab_size * p->dim;
|
||||
w->rms_att_weight = ptr;
|
||||
ptr += p->n_layers * p->dim;
|
||||
w->wq = ptr;
|
||||
ptr += p->n_layers * p->dim * p->dim;
|
||||
ptr += p->n_layers * p->dim * (p->n_heads * head_size);
|
||||
w->wk = ptr;
|
||||
ptr += p->n_layers * p->dim * p->dim;
|
||||
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wv = ptr;
|
||||
ptr += p->n_layers * p->dim * p->dim;
|
||||
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
|
||||
w->wo = ptr;
|
||||
ptr += p->n_layers * p->dim * p->dim;
|
||||
ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
|
||||
w->rms_ffn_weight = ptr;
|
||||
ptr += p->n_layers * p->dim;
|
||||
w->w1 = ptr;
|
||||
@ -214,7 +198,6 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int sha
|
||||
w->rms_final_weight = ptr;
|
||||
ptr += p->dim;
|
||||
w->freq_cis_real = ptr;
|
||||
int head_size = p->dim / p->n_heads;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
w->freq_cis_imag = ptr;
|
||||
ptr += p->seq_len * head_size / 2;
|
||||
@ -298,6 +281,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
// a few convenience variables
|
||||
float *x = s->x;
|
||||
int dim = p->dim;
|
||||
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
|
||||
int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
|
||||
int hidden_dim = p->hidden_dim;
|
||||
int head_size = dim / p->n_heads;
|
||||
|
||||
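The kv_dim / kv_mul bookkeeping above is what enables grouped-query attention: keys and values have only n_kv_heads heads, and each key/value head is shared by kv_mul query heads. A quick worked example in Python, with illustrative numbers (not taken from any particular checkpoint):

dim, n_heads, n_kv_heads, n_layers, seq_len = 4096, 32, 8, 32, 2048

head_size = dim // n_heads            # 128
kv_dim = dim * n_kv_heads // n_heads  # 1024: width of each k/v vector
kv_mul = n_heads // n_kv_heads        # 4 query heads share each k/v head

# the kv cache shrinks by the same factor versus full multi-head attention
full_cache = n_layers * seq_len * dim     # floats per cache without GQA
gqa_cache = n_layers * seq_len * kv_dim   # floats per cache with GQA
print(head_size, kv_dim, kv_mul, full_cache // gqa_cache)  # 128 1024 4 4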
@ -305,10 +290,6 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
float* content_row = &(w->token_embedding_table[token * dim]);
|
||||
memcpy(x, content_row, dim*sizeof(*x));
|
||||
|
||||
// pluck out the "pos" row of freq_cis_real and freq_cis_imag
|
||||
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
|
||||
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;
|
||||
|
||||
// forward all the layers
|
||||
for(int l = 0; l < p->n_layers; l++) {
|
||||
|
||||
@ -317,29 +298,32 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
|
||||
// qkv matmuls for this position
|
||||
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
|
||||
matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
|
||||
matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);
|
||||
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
||||
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
|
||||
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k in each head
|
||||
for (int i = 0; i < dim; i+=2) {
|
||||
float q0 = s->q[i];
|
||||
float q1 = s->q[i+1];
|
||||
float k0 = s->k[i];
|
||||
float k1 = s->k[i+1];
|
||||
float fcr = freq_cis_real_row[(i % head_size) / 2];
|
||||
float fci = freq_cis_imag_row[(i % head_size) / 2];
|
||||
s->q[i] = q0 * fcr - q1 * fci;
|
||||
s->q[i+1] = q0 * fci + q1 * fcr;
|
||||
s->k[i] = k0 * fcr - k1 * fci;
|
||||
s->k[i+1] = k0 * fci + k1 * fcr;
|
||||
int head_dim = i % head_size;
|
||||
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
|
||||
float val = pos * freq;
|
||||
float fcr = cosf(val);
|
||||
float fci = sinf(val);
|
||||
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
|
||||
for (int v = 0; v < rotn; v++) {
|
||||
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
|
||||
float v0 = vec[i];
|
||||
float v1 = vec[i+1];
|
||||
vec[i] = v0 * fcr - v1 * fci;
|
||||
vec[i+1] = v0 * fci + v1 * fcr;
|
||||
}
|
||||
}
|
||||
|
||||
// save key,value at this time step (pos) to our kv cache
|
||||
int loff = l * p->seq_len * dim; // kv cache layer offset for convenience
|
||||
float* key_cache_row = s->key_cache + loff + pos * dim;
|
||||
float* value_cache_row = s->value_cache + loff + pos * dim;
|
||||
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
|
||||
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
|
||||
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
|
||||
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
|
||||
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
|
||||
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
|
||||
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
|
||||
|
||||
// multihead attention. iterate over all heads
|
||||
int h;
|
||||
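The RoPE hunk above no longer reads precomputed freq_cis tables; it recomputes cos/sin on the fly from the head-relative index and rotates (even, odd) pairs of q and k. A small Python sketch of the same per-pair rotation, purely illustrative and operating on a plain list:

import math

def rope_rotate(vec, pos, head_size):
    # rotate consecutive (even, odd) pairs of a query/key vector in place
    for i in range(0, len(vec), 2):
        head_dim = i % head_size
        freq = 1.0 / (10000.0 ** (head_dim / head_size))
        val = pos * freq
        fcr, fci = math.cos(val), math.sin(val)
        v0, v1 = vec[i], vec[i + 1]
        vec[i] = v0 * fcr - v1 * fci
        vec[i + 1] = v0 * fci + v1 * fcr
    return vec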
@ -354,7 +338,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
// iterate over all timesteps, including the current one
|
||||
for (int t = 0; t <= pos; t++) {
|
||||
// get the key vector for this head and at this timestep
|
||||
float* k = s->key_cache + loff + t * dim + h * head_size;
|
||||
float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
||||
// calculate the attention score as the dot product of q and k
|
||||
float score = 0.0f;
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
@ -373,7 +357,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
memset(xb, 0, head_size * sizeof(float));
|
||||
for (int t = 0; t <= pos; t++) {
|
||||
// get the value vector for this head and at this timestep
|
||||
float* v = s->value_cache + loff + t * dim + h * head_size;
|
||||
float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
|
||||
// get the attention weight for this timestep
|
||||
float a = att[t];
|
||||
// accumulate the weighted value into xb
|
||||
@ -387,7 +371,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
|
||||
|
||||
// residual connection back into x
|
||||
accum(x, s->xb2, dim);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
x[i] += s->xb2[i];
|
||||
}
|
||||
|
||||
// ffn rmsnorm
|
||||
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
|
||||
@ -411,7 +397,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
|
||||
|
||||
// residual connection
|
||||
accum(x, s->xb, dim);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
x[i] += s->xb[i];
|
||||
}
|
||||
}
|
||||
|
||||
// final rmsnorm
|
||||
@ -424,29 +412,87 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
// ----------------------------------------------------------------------------
|
||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||
|
||||
int str_lookup(char *str, char **vocab, int vocab_size) {
|
||||
// find the first perfect match for str in vocab, return its index or -1 if not found
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
if (strcmp(str, vocab[i]) == 0) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
typedef struct {
|
||||
char *str;
|
||||
int id;
|
||||
} TokenIndex;
|
||||
|
||||
int compare_tokens(const void *a, const void *b) {
|
||||
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
|
||||
}
|
||||
|
||||
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
|
||||
TokenIndex tok = { .str = str }; // acts as the key to search for
|
||||
TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
return res != NULL ? res->id : -1;
|
||||
}
|
||||
|
||||
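The linear scan above is replaced by a one-time sort of the vocab plus binary search, turning each lookup from O(vocab_size) into O(log vocab_size). The same idea sketched in Python with the standard bisect module in place of qsort/bsearch:

import bisect

def build_sorted_vocab(vocab):
    # pair each token string with its original id, then sort by string
    return sorted((s, i) for i, s in enumerate(vocab))

def str_lookup(s, sorted_vocab):
    j = bisect.bisect_left(sorted_vocab, (s, -1))
    if j < len(sorted_vocab) and sorted_vocab[j][0] == s:
        return sorted_vocab[j][1]
    return -1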
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
||||
|
||||
// a temporary buffer to merge two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||
// sort vocabulary
|
||||
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
sorted_vocab[i].str = vocab[i];
|
||||
sorted_vocab[i].id = i;
|
||||
}
|
||||
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
|
||||
|
||||
// first encode every individual byte in the input string
|
||||
*n_tokens = 0; // the number of tokens
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator, +2 for UTF8 (in case max_token_length is 1)
|
||||
size_t str_len = 0;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
|
||||
*n_tokens = 1; // the number of tokens
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
// Code point ↔ UTF-8 conversion
|
||||
// First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
|
||||
// U+0000 U+007F 0xxxxxxx
|
||||
// U+0080 U+07FF 110xxxxx 10xxxxxx
|
||||
// U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
|
||||
// U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
|
||||
// process the raw (UTF-8) byte sequence of the input string
|
||||
for (char *c = text; *c != '\0'; c++) {
|
||||
sprintf(str_buffer, "%c", *c);
|
||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||
if (id == -1) { printf("not good\n"); exit(EXIT_FAILURE); }
|
||||
tokens[*n_tokens] = id;
|
||||
(*n_tokens)++;
|
||||
|
||||
// reset buffer if the current byte is ASCII or a leading byte
|
||||
// 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
|
||||
// 0x80 is 10000000
|
||||
// in UTF-8, all continuation bytes start with "10" in first two bits
|
||||
// so in English this is: "if this byte is not a continuation byte"
|
||||
if ((*c & 0xC0) != 0x80) {
|
||||
// this byte must be either a leading byte (11...) or an ASCII char (0x...)
|
||||
// => reset our location, as we're starting a new UTF-8 codepoint
|
||||
str_len = 0;
|
||||
}
|
||||
|
||||
// append the current byte to the buffer
|
||||
str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
|
||||
str_buffer[str_len] = '\0';
|
||||
|
||||
// while the next character is a continuation byte, continue appending
|
||||
// but if there are too many of them, just stop to avoid overrunning str_buffer size.
|
||||
if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// ok c+1 is not a continuation byte, so we've read in a full codepoint
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
|
||||
if (id != -1) {
|
||||
// we found this codepoint in vocab, add it as a token
|
||||
tokens[(*n_tokens)++] = id;
|
||||
} else {
|
||||
// byte_fallback encoding: just encode each byte as a token
|
||||
// +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
|
||||
// so the individual bytes only start at index 3
|
||||
for (int i=0; i < str_len; i++) {
|
||||
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
|
||||
}
|
||||
}
|
||||
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
|
||||
}
|
||||
|
||||
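The loop above groups raw bytes into UTF-8 codepoints using the top two bits of each byte, and falls back to per-byte tokens (offset by 3, past <unk>, <s>, </s>) when a codepoint is not in the vocab. An illustrative Python sketch of the same two rules, assuming the same +3 byte-token convention:

def is_continuation(byte):
    # UTF-8 continuation bytes look like 10xxxxxx
    return (byte & 0xC0) == 0x80

def byte_fallback_tokens(codepoint_bytes):
    # first 3 vocab entries are <unk>, <s>, </s>, so raw bytes start at id 3
    return [b + 3 for b in codepoint_bytes]

print(is_continuation(0x80), is_continuation(0x41))  # True False
print(byte_fallback_tokens(b"\xe2\x82\xac"))          # [229, 133, 175] for the Euro sign bytes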
// merge the best consecutive pair each iteration, according to the scores in vocab_scores
|
||||
@ -458,7 +504,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
for (int i=0; i < (*n_tokens-1); i++) {
|
||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
|
||||
if (id != -1 && vocab_scores[id] > best_score) {
|
||||
// this merge pair exists in vocab! record its score and position
|
||||
best_score = vocab_scores[id];
|
||||
@ -481,6 +527,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
|
||||
}
|
||||
|
||||
free(str_buffer);
|
||||
free(sorted_vocab);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@ -547,17 +594,24 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
// tokens that exceed probability topp. This way we never sample tokens that
|
||||
// have very low probabilities and are less likely to go "off the rails".
|
||||
|
||||
int n0 = 0;
|
||||
// quicksort indices in descending order of probabilities
|
||||
// values smaller than (1 - topp) / (n - 1) cannot be part of the result
|
||||
// so for efficiency we crop these out as candidates before sorting
|
||||
const float cutoff = (1.0f - topp) / (n - 1);
|
||||
for (int i = 0; i < n; i++) {
|
||||
probindex[i].index = i;
|
||||
probindex[i].prob = probabilities[i];
|
||||
if (probabilities[i] >= cutoff) {
|
||||
probindex[n0].index = i;
|
||||
probindex[n0].prob = probabilities[i];
|
||||
n0++;
|
||||
}
|
||||
}
|
||||
qsort(probindex, n, sizeof(ProbIndex), compare);
|
||||
qsort(probindex, n0, sizeof(ProbIndex), compare);
|
||||
|
||||
// truncate the list where cumulative probability exceeds topp
|
||||
float cumulative_prob = 0.0f;
|
||||
int last_idx = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
int last_idx = n0 - 1; // in case of rounding errors consider all elements
|
||||
for (int i = 0; i < n0; i++) {
|
||||
cumulative_prob += probindex[i].prob;
|
||||
if (cumulative_prob > topp) {
|
||||
last_idx = i;
|
||||
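The cutoff above is a cheap pre-filter: per the comment, any token with probability below (1 - topp) / (n - 1) can never be part of the top-p result, so it is dropped before sorting. A quick worked example with assumed values:

topp, n = 0.9, 32000             # a typical top-p and a Llama-sized vocab
cutoff = (1.0 - topp) / (n - 1)  # ~3.1e-06
print(cutoff)
# almost all of the 32000 candidates fall below this, so the qsort call
# only has to sort a small handful of surviving tokens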
@ -581,13 +635,28 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
// ----------------------------------------------------------------------------
|
||||
// int main
|
||||
|
||||
void error_usage() {
|
||||
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
|
||||
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
|
||||
fprintf(stderr, "Options:\n");
|
||||
fprintf(stderr, " -t <float> temperature, default 1.0\n");
|
||||
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling. default 0.9\n");
|
||||
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
|
||||
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
|
||||
fprintf(stderr, " -b <int> number of tokens to buffer, default 1. 0 = max_seq_len\n");
|
||||
fprintf(stderr, " -i <string> input prompt\n");
|
||||
fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
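The options listed by error_usage() map one-to-one onto the flag/value pairs parsed in main() below. A hypothetical invocation (model and tokenizer file names are placeholders), driven from Python the same way the test suite drives the binary:

import subprocess

cmd = ["./run", "stories15M.bin",    # placeholder checkpoint
       "-t", "0.8", "-p", "0.9",     # temperature, top-p
       "-n", "256", "-b", "4",       # steps, output buffer tokens
       "-i", "Once upon a time",     # prompt
       "-z", "tokenizer.bin"]        # optional custom tokenizer
subprocess.run(cmd, check=True)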
int main(int argc, char *argv[]) {
|
||||
|
||||
// default inits
|
||||
char *checkpoint = NULL; // e.g. out/model.bin
|
||||
char *tokenizer = "tokenizer.bin";
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling
|
||||
rng_seed = (unsigned int)time(NULL); // seed rng with time by default
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
rng_seed = 0; // 0 means: seed the rng with time(NULL) below
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
int buffertokens = 1; // output token buffer size
|
||||
@ -604,32 +673,24 @@ int main(int argc, char *argv[]) {
|
||||
prompt=promptbuffer; // Set prompt
|
||||
#else
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
if (argc < 2) {
|
||||
printf("Usage: %s <checkpoint_file> \n", argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
if (argc >= 2) { checkpoint = argv[1]; }
|
||||
for (int i = 2; i < argc; i++) {
|
||||
// do some basic validation - add rng_seed and other checks
|
||||
switch (argv[i][0]) {
|
||||
case '-':
|
||||
switch (argv[i][1]) {
|
||||
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
|
||||
case 't': if (i + 1 < argc) { temperature = atof(argv[++i]); } break;
|
||||
case 'p': if (i + 1 < argc) { topp = atof(argv[++i]); } break;
|
||||
case 's': if (i + 1 < argc) { rng_seed = atoi(argv[++i]); } break;
|
||||
case 'n': if (i + 1 < argc) { steps = atoi(argv[++i]); } break;
|
||||
case 'b': if (i + 1 < argc) { buffertokens = atoi(argv[++i]); } break;
|
||||
case 'i': if (i + 1 < argc) { prompt = argv[++i]; } break;
|
||||
default: printf("Invalid option: %s\n", argv[i]);
|
||||
exit(EXIT_FAILURE);
|
||||
} break;
|
||||
default:
|
||||
printf("Usage: %s <checkpoint_file> -t [temperature] -p [top-p] -s [seed] -n [steps] -b [buffertokens] -p [prompt] \n", argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
|
||||
for (int i = 2; i < argc; i+=2) {
|
||||
// do some basic validation
|
||||
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
|
||||
if (argv[i][0] != '-') { error_usage(); } // must start with dash
|
||||
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
|
||||
// read in the args
|
||||
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
|
||||
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'b') { buffertokens = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
|
||||
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
|
||||
else { error_usage(); }
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
|
||||
|
||||
// read in the model.bin file
|
||||
Config config;
|
||||
@ -639,7 +700,7 @@ int main(int argc, char *argv[]) {
|
||||
ssize_t file_size; // size of the checkpoint file in bytes
|
||||
{
|
||||
FILE *file = fopen(checkpoint, "rb");
|
||||
if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
|
||||
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; }
|
||||
// read in the config header
|
||||
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
||||
@ -651,16 +712,16 @@ int main(int argc, char *argv[]) {
|
||||
fclose(file);
|
||||
// memory map the Transformer weights into the data pointer
|
||||
fd = open(checkpoint, O_RDONLY); // open in read only mode
|
||||
if (fd == -1) { printf("open failed!\n"); return 1; }
|
||||
if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; }
|
||||
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
|
||||
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
|
||||
if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; }
|
||||
float* weights_ptr = data + sizeof(Config)/sizeof(float);
|
||||
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
|
||||
}
|
||||
// right now we cannot run for more than config.seq_len steps
|
||||
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
|
||||
|
||||
// read in the tokenizer.bin file
|
||||
// read in the tokenizer .bin file
|
||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
||||
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
||||
unsigned int max_token_length;
|
||||
@ -669,16 +730,16 @@ int main(int argc, char *argv[]) {
|
||||
// we read the embedded tokenizer.bin from within the executable
|
||||
FILE *file = fopen("/zip/tokenizer.bin", "rb");
|
||||
#else
|
||||
FILE *file = fopen("tokenizer.bin", "rb");
|
||||
FILE *file = fopen(tokenizer, "rb");
|
||||
#endif
|
||||
if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
|
||||
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
|
||||
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
int len;
|
||||
for (int i = 0; i < config.vocab_size; i++) {
|
||||
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { printf("failed read\n"); return 1;}
|
||||
if (fread(&len, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;}
|
||||
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
vocab[i] = (char *)malloc(len + 1);
|
||||
if (fread(vocab[i], len, 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||
if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
|
||||
vocab[i][len] = '\0'; // add the string terminating token
|
||||
}
|
||||
fclose(file);
|
||||
@ -692,7 +753,7 @@ int main(int argc, char *argv[]) {
|
||||
int *prompt_tokens = NULL;
|
||||
int num_prompt_tokens = 0;
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int));
|
||||
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
|
||||
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
|
||||
}
|
||||
|
||||
@ -703,8 +764,8 @@ int main(int argc, char *argv[]) {
|
||||
int pos = 0; // position in the sequence
|
||||
int bufferflush = 1; // token counter for flushing buffer
|
||||
static char outbuff[4096 * (6 + 2)] ; // buffersize is context length * average size of subwords + margin
|
||||
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
|
||||
|
||||
// Todo: we can do buffering without setvbuff, implement that
|
||||
// setvbuf is used to buffer output into outbuff instead of flushing to screen directly
|
||||
if (setvbuf(stdout, outbuff, _IOFBF, sizeof(outbuff)) != 0) {
|
||||
puts("Error: Buffer allocation!"); exit(EXIT_FAILURE);
|
||||
@ -715,6 +776,7 @@ int main(int argc, char *argv[]) {
|
||||
// forward the transformer to get logits for the next token
|
||||
transformer(token, pos, &config, &state, &weights);
|
||||
|
||||
// advance the state state machine
|
||||
if(pos < num_prompt_tokens) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
@ -729,7 +791,7 @@ int main(int argc, char *argv[]) {
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(state.logits, config.vocab_size);
|
||||
// we sample from this distribution to get the next token
|
||||
if (topp <= 0) {
|
||||
if (topp <= 0 || topp >= 1) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample(state.logits, config.vocab_size);
|
||||
} else {
|
||||
@ -738,24 +800,40 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
|
||||
printf("%s", token_str);
|
||||
// flush output to screen after the defined number of buffertokens have accumulated
|
||||
if (bufferflush==pos) { fflush(stdout); bufferflush+=buffertokens; }
|
||||
|
||||
// advance forward
|
||||
token = next;
|
||||
pos++;
|
||||
// init our timer here because the first iteration is slow due to memmap
|
||||
|
||||
// data-dependent terminating condition: the BOS (1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
unsigned char byte_val;
|
||||
if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) {
|
||||
// ok, this token is a raw byte token; be careful to only print printable chars or whitespace
|
||||
// some of the other bytes can be various control codes, backspace, etc. => skip
|
||||
if (isprint(byte_val) || isspace(byte_val)) {
|
||||
char byte_piece[2];
|
||||
byte_piece[0] = byte_val;
|
||||
byte_piece[1] = '\0';
|
||||
printf("%s", byte_piece);
|
||||
}
|
||||
} else {
|
||||
printf("%s", token_str);
|
||||
}
|
||||
if (bufferflush==pos) { fflush(stdout); bufferflush+=buffertokens; }
|
||||
token = next;
|
||||
|
||||
// init the timer here because the first iteration can be slower
|
||||
if (start == 0) { start = time_in_ms(); }
|
||||
}
|
||||
|
||||
// report achieved tok/s
|
||||
long end = time_in_ms();
|
||||
printf("\nachieved tok/s: %f\n", (steps-1) / (double)(end-start)*1000);
|
||||
printf("\n");
|
||||
fflush(stdout); // this could live inside the `if (next == 1) break;` branch above, with the newline print moved in front of the achieved tok/s report
|
||||
// report achieved tok/s (pos-1 because the timer starts after first iteration)
|
||||
if (pos > 1) {
|
||||
long end = time_in_ms();
|
||||
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
|
||||
}
|
||||
|
||||
// memory and file handles cleanup
|
||||
free_run_state(&state);
|
||||
|
||||
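One of the new decode-time details above is that some vocab entries are literal byte tokens spelled like '<0x0A>'; run.c parses them with sscanf and prints the byte only if it is printable or whitespace. A rough, illustrative Python equivalent of that check:

import re

def render_token(token_str):
    m = re.fullmatch(r"<0x([0-9A-Fa-f]{2})>", token_str)
    if m:
        ch = chr(int(m.group(1), 16))
        # skip control codes, backspace, etc., like the isprint/isspace check in C
        return ch if (ch.isprintable() or ch.isspace()) else ""
    return token_str

print(repr(render_token("<0x0A>")), render_token("hello"))  # '\n' hello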
run.ipynb (new file, 130 lines)
@ -0,0 +1,130 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "HLdoj4cz-xal"
|
||||
},
|
||||
"source": [
|
||||
"# Run.c\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/karpathy/llama2.c/blob/master/run.ipynb)\n",
|
||||
"\n",
|
||||
"More details can be found in the [README.md](README.md) ."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Une3Ozlnu1B7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Clone Project\n",
|
||||
"\n",
|
||||
"!git clone https://github.com/karpathy/llama2.c.git\n",
|
||||
"%cd llama2.c"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Build\n",
|
||||
"\n",
|
||||
"!make runfast"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "thm0ZBrtSgoC"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Pick Your Model\n",
|
||||
"\n",
|
||||
"#@markdown Choose model\n",
|
||||
"model = \"stories15M\" #@param [\"stories15M\", \"stories42M\", \"stories110M\"]\n",
|
||||
"\n",
|
||||
"download_url = \"\"\n",
|
||||
"\n",
|
||||
"if(model == \"stories15M\"):\n",
|
||||
" download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin\"\n",
|
||||
"if(model == \"stories42M\"):\n",
|
||||
" download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin\"\n",
|
||||
"if(model == \"stories110M\"):\n",
|
||||
" download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin\"\n",
|
||||
"\n",
|
||||
"print(f\"download_url: {download_url}\")\n",
|
||||
"\n",
|
||||
"!wget $download_url\n",
|
||||
"\n",
|
||||
"model_file = model + \".bin\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "OgAc3KjuT-NM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Generate Stories\n",
|
||||
"\n",
|
||||
"# Generate args\n",
|
||||
"max_token = 256 #@param {type:\"slider\", min:32, max:1024, step:32}\n",
|
||||
"temperature = 0.8 #@param {type:\"slider\", min:0.0, max:1, step:0.05}\n",
|
||||
"top_p = 0.9 #@param {type:\"slider\", min:0.0, max:1.0, step:0.05}\n",
|
||||
"prompt = \"One day, Lily met a Shoggoth\" #@param {type:\"string\"}\n",
|
||||
"\n",
|
||||
"print(f\"model: {model_file}, max_token: {max_token}, temperature: {temperature}, top_p: {top_p}, prompt: {prompt}\")\n",
|
||||
"print(f\"----------------------------\\n\")\n",
|
||||
"\n",
|
||||
"cmd = f'./run {model_file} -t {temperature} -p {top_p} -n {max_token} -i \"{prompt}\"'\n",
|
||||
"!{cmd}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Run Meta's Llama 2 models\n",
|
||||
"\n",
|
||||
"#@markdown input your huggingface [access token](https://huggingface.co/settings/tokens) to download Meta's Llama 2 models.\n",
|
||||
"\n",
|
||||
"from huggingface_hub import snapshot_download\n",
|
||||
"\n",
|
||||
"token = \"replace your huggingface access token\" #@param {type:\"string\"}\n",
|
||||
"path = snapshot_download(repo_id=\"meta-llama/Llama-2-7b\",cache_dir=\"Llama-2-7b\", use_auth_token=token)\n",
|
||||
"\n",
|
||||
"!python export_meta_llama_bin.py $path llama2_7b.bin\n",
|
||||
"\n",
|
||||
"print(\"./run llama2_7b.bin\\n\")\n",
|
||||
"!./run llama2_7b.bin"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"private_outputs": true,
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
sample.py (24 lines changed)
@ -5,17 +5,19 @@ import os
|
||||
import pickle
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
import tiktoken
|
||||
from model import ModelArgs, Transformer
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from tinystories import get_tokenizer_model_path
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
out_dir = 'out' # ignored if init_from is not 'resume'
|
||||
checkpoint = 'out/ckpt.pt'
|
||||
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
||||
num_samples = 1 # number of samples to draw
|
||||
max_new_tokens = 100 # number of tokens generated in each sample
|
||||
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
|
||||
top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
|
||||
tokenizer = "" # override the tokenizer model path
|
||||
seed = 1337
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
|
||||
#dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
|
||||
@ -33,11 +35,10 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
|
||||
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
||||
|
||||
# init from a model saved in a specific directory
|
||||
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
gptconf = ModelArgs(**checkpoint['model_args'])
|
||||
checkpoint_dict = torch.load(checkpoint, map_location=device)
|
||||
gptconf = ModelArgs(**checkpoint_dict['model_args'])
|
||||
model = Transformer(gptconf)
|
||||
state_dict = checkpoint['model']
|
||||
state_dict = checkpoint_dict['model']
|
||||
unwanted_prefix = '_orig_mod.'
|
||||
for k,v in list(state_dict.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
@ -51,7 +52,16 @@ if compile:
|
||||
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
||||
|
||||
# load the tokenizer
|
||||
enc = Tokenizer()
|
||||
vocab_source = checkpoint_dict.get("vocab_source", "llama2")
|
||||
vocab_size = gptconf.vocab_size
|
||||
if tokenizer:
|
||||
# a specific tokenizer is provided, use it
|
||||
tokenizer_model = tokenizer
|
||||
else:
|
||||
# let's try to find the tokenizer model automatically. bit gross here...
|
||||
query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
|
||||
tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
|
||||
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
||||
|
||||
# encode the beginning of the prompt
|
||||
if start.startswith('FILE:'):
|
||||
|
||||
test_all.py (92 lines changed)
@ -4,37 +4,71 @@ $ pytest
|
||||
"""
|
||||
import os
|
||||
import pytest # pip install pytest
|
||||
import requests
|
||||
import subprocess
|
||||
|
||||
|
||||
import torch
|
||||
from model import ModelArgs, Transformer
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
def test_argmax_inference():
|
||||
"""
|
||||
Only the simplest test for now: run inference with temperature 0
|
||||
(for determinism) in both C and PyTorch, and see that the sampled tokens
|
||||
are the same.
|
||||
"""
|
||||
test_ckpt_dir = "out" # TODO create a dummy test checkpoint for this?
|
||||
# -----------------------------------------------------------------------------
|
||||
# test utilities
|
||||
|
||||
# run C version
|
||||
model_path = os.path.join(test_ckpt_dir, "model.bin")
|
||||
command = ["./run", model_path, "0.0"]
|
||||
proc = subprocess.Popen(command, stdout=subprocess.PIPE)
|
||||
c_tokens = []
|
||||
for line in proc.stdout:
|
||||
token = int(line.decode('utf-8').strip())
|
||||
c_tokens.append(token)
|
||||
proc.wait()
|
||||
#print(c_tokens)
|
||||
test_ckpt_dir = "test"
|
||||
|
||||
# run PyTorch version
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
ckpt_path = os.path.join(test_ckpt_dir, "ckpt.pt")
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
gptconf = ModelArgs(**checkpoint['model_args'])
|
||||
def download_file(url, filename):
|
||||
print(f"Downloading {url} to {filename}")
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status() # Raise an HTTPError on bad status code
|
||||
with open(filename, 'wb') as file:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
file.write(chunk)
|
||||
|
||||
def attempt_download_files():
|
||||
os.makedirs(test_ckpt_dir, exist_ok=True)
|
||||
root_url = "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K"
|
||||
need = ["stories260K.bin", "stories260K.pt", "tok512.bin", "tok512.model"]
|
||||
for file in need:
|
||||
url = root_url + '/' + file #os.path.join inserts \\ on windows
|
||||
filename = os.path.join(test_ckpt_dir, file)
|
||||
if not os.path.exists(filename):
|
||||
download_file(url, filename)
|
||||
|
||||
expected_stdout = b'Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, red ball. She wanted to play with it, but it was too high.\nLily\'s mom said, "Lily, let\'s go to the park." Lily was sad and didn\'t know what to do. She said, "I want to play with your ball, but I can\'t find it."\nLily was sad and didn\'t know what to do. She said, "I\'m sorry, Lily. I didn\'t know what to do."\nLily didn\'t want to help her mom, so she'
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# actual tests
|
||||
|
||||
def test_runc():
|
||||
""" Forwards a model against a known-good desired outcome in run.c for 200 steps"""
|
||||
attempt_download_files()
|
||||
|
||||
model_path = os.path.join(test_ckpt_dir, "stories260K.bin")
|
||||
tokenizer_path = os.path.join(test_ckpt_dir, "tok512.bin")
|
||||
command = ["./run", model_path, "-z", tokenizer_path, "-t", "0.0", "-n", "200"]
|
||||
with open('err.txt', mode='wb') as fe:
|
||||
with open('stdout.txt', mode='wb') as fo:
|
||||
proc = subprocess.Popen(command, stdout=fo, stderr=fe) #pipe in windows terminal does funny things like replacing \n with \r\n
|
||||
proc.wait()
|
||||
|
||||
with open('stdout.txt', mode='r') as f:
|
||||
stdout = f.read()
|
||||
# strip the very last \n that is added by run.c for aesthetic reasons
|
||||
stdout = stdout[:-1].encode('ascii')
|
||||
|
||||
assert stdout == expected_stdout
|
||||
|
||||
def test_python():
|
||||
""" Forwards a model against a known-good desired outcome in sample.py for 200 steps"""
|
||||
attempt_download_files()
|
||||
|
||||
device = "cpu" # stories260K is small enough to just breeze through it on CPU
|
||||
checkpoint = os.path.join(test_ckpt_dir, "stories260K.pt")
|
||||
checkpoint_dict = torch.load(checkpoint, map_location=device)
|
||||
gptconf = ModelArgs(**checkpoint_dict['model_args'])
|
||||
model = Transformer(gptconf)
|
||||
state_dict = checkpoint['model']
|
||||
state_dict = checkpoint_dict['model']
|
||||
unwanted_prefix = '_orig_mod.'
|
||||
for k,v in list(state_dict.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
@ -44,10 +78,12 @@ def test_argmax_inference():
|
||||
model.to(device)
|
||||
x = torch.tensor([[1]], dtype=torch.long, device=device) # 1 is BOS
|
||||
with torch.inference_mode():
|
||||
y = model.generate(x, max_new_tokens=gptconf.max_seq_len, temperature=0.0)
|
||||
y = model.generate(x, max_new_tokens=200, temperature=0.0)
|
||||
pt_tokens = y[0].tolist()
|
||||
pt_tokens = pt_tokens[1:] # remove BOS
|
||||
#print(pt_tokens)
|
||||
|
||||
# compare
|
||||
assert c_tokens == pt_tokens
|
||||
tokenizer_model = os.path.join(test_ckpt_dir, "tok512.model")
|
||||
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
||||
text = enc.decode(pt_tokens)
|
||||
text = text.encode('ascii') # turn into bytes
|
||||
|
||||
assert text == expected_stdout
|
||||
|
||||
tinyshakespeare.py (deleted, 140 lines)

@ -1,140 +0,0 @@
|
||||
"""
|
||||
Download, preprocess and serve the TinyShakespeare dataset as a DataLoader.
|
||||
|
||||
Follows the same interface as the TinyStories dataset.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
DATA_CACHE_DIR = "data"
|
||||
|
||||
def download_file(url: str, fname: str, chunk_size=1024):
|
||||
"""Helper function to download a file from a given url"""
|
||||
resp = requests.get(url, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
with open(fname, "wb") as file, tqdm(
|
||||
desc=fname,
|
||||
total=total,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar:
|
||||
for data in resp.iter_content(chunk_size=chunk_size):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
def download():
    """Downloads the dataset to disk."""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyShakespeare dataset, unless it's already downloaded
    data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    data_filename = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    print("Download done.")


def pretokenize():
    enc = Tokenizer()

    data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.txt")

    all_tokens = []
    with open(data_file, "r") as f:
        for line in f:
            text = line.strip()
            tokens = enc.encode(text, bos=True, eos=False)
            all_tokens.extend(tokens)
    all_tokens = np.array(all_tokens, dtype=np.uint16)
    print(f"Total tokens: {len(all_tokens)}")
    with open(data_file.replace(".txt", ".bin"), "wb") as f:
        f.write(all_tokens.tobytes())
    print(f"Saved {data_file.replace('.txt', '.bin')}")
    print("Done.")


class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors."""

    def __init__(self, split, max_seq_len):
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len

    def __iter__(self):
        # get worker info within a DataLoader
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # combine the worker_id and worker_rank to create a unique seed for rng
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")
        data_file = os.path.join(DATA_CACHE_DIR, "tinyshakespeare.bin")
        m_all = np.memmap(data_file, dtype=np.uint16, mode="r")

        # split out 10% of the data for validation
        split_ix = int(len(m_all) * 0.9)
        if self.split == "train":
            m = m_all[:split_ix]
        else:
            m = m_all[split_ix:]

        num_batches = len(m) // self.max_seq_len
        num_batches -= 1  # drop the last partial batch
        assert num_batches > 0, "this split is way too small? investigate."

        while True:
            ixs = list(range(num_batches))
            rng.shuffle(ixs)
            for ix in ixs:
                start = ix * self.max_seq_len
                end = start + self.max_seq_len + 1
                # calling .astype will copy the data into a new numpy array, now in RAM
                chunk = torch.from_numpy((m[start:end]).astype(np.int64))
                x = chunk[:-1]
                y = chunk[1:]
                yield x, y


class ShakespeareTask:

    @staticmethod
    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
        ds = PretokDataset(split, max_seq_len)
        dl = torch.utils.data.DataLoader(
            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
        )
        for x, y in dl:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            yield x, y


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"])
|
||||
    args = parser.parse_args()

    # depending on the stage call the appropriate function
    fun = {
        "download": download,
        "pretokenize": pretokenize,
    }
    fun[args.stage]()
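For orientation (not part of the diff): a minimal sketch of how the ShakespeareTask interface above would be consumed, assuming the module is importable as tinyshakespeare and the download and pretokenize stages have already produced tinyshakespeare.bin.

# sketch only: example values, not from the commit
from tinyshakespeare import ShakespeareTask

batch_iter = ShakespeareTask.iter_batches(
    split="train", batch_size=32, max_seq_len=256, device="cpu", num_workers=0
)
x, y = next(batch_iter)  # LongTensors of shape (32, 256); y is x shifted by one token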
147
tinystories.py
@ -9,6 +9,7 @@ import os
import random
from typing import List
from concurrent.futures import ProcessPoolExecutor
from functools import partial

import numpy as np
import requests
@ -37,7 +38,7 @@ def download_file(url: str, fname: str, chunk_size=1024):


def download():
    """Downloads the dataset to disk."""
    """Downloads the TinyStories dataset to DATA_CACHE_DIR"""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
@ -66,10 +67,61 @@ def download():
    print(f"Number of shards: {len(shard_filenames)}")
    print(f"Example story:\n{data[0]}")


def train_vocab(vocab_size):
    """
    Trains a custom sentencepiece tokenizer on the TinyStories dataset.
    The custom tokenizer files will be saved in DATA_CACHE_DIR/tok{N} directories,
    where N is the vocab size. This is also where the pretok .bin files will go.
    """
    assert vocab_size > 0, "Vocab size must be positive"

    # output file prefix path for sentencepiece
    prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")

    # how many shards we'll use for vocab training, kept low for efficiency
    num_shards = 10

    # 1) export a large chunk of text as a single text file tiny.txt
    tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
    with open(tiny_file, "w") as of:
        for shard in tqdm(shard_filenames[:num_shards]):
            with open(shard, "r") as f:
                data = json.load(f)
            for example in data:
                text = example["story"]
                text = text.strip()
                of.write(text + "\n")
    print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")

    # 2) run the train_vocab.sh script that trains the sentencepiece model
    print("Will now train the vocab with:")
    cmd = f"bash train_vocab.sh {tiny_file} {prefix} {vocab_size}"
    print(cmd)
    print("OK? [y/N] ")
    dec = input()
    if dec.lower() != "y":
        print("Exiting...")
        return
    os.system(cmd)

    # 3) optional cleanup, ask the user if they'd like to delete tiny.txt
    dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
    if dec.lower() == "y":
        os.remove(tiny_file)
        print(f"Deleted {tiny_file}")

    print(f"Trained tokenizer is in {prefix}.model")
    print("Done.")


def process_shard(args):
def process_shard(args, vocab_size):
    shard_id, shard = args
    enc = Tokenizer()
    tokenizer_model = get_tokenizer_model_path(vocab_size)
    enc = Tokenizer(tokenizer_model)
    with open(shard, "r") as f:
        data = json.load(f)
    all_tokens = []
@ -80,31 +132,49 @@ def process_shard(args):
        all_tokens.extend(tokens)
    # convert to uint16 nparray
    all_tokens = np.array(all_tokens, dtype=np.uint16)
    # write to disk
    tokenized_filename = shard.replace(".json", ".bin")
    # calculate the output filename
    if vocab_size == 0:
        # if we're using Llama 2, just save the tokenized file in the same dir
        tokenized_filename = shard.replace(".json", ".bin")
    else:
        # save .bin files into a new tok{N} directory
        bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
        shard_basename = os.path.basename(shard)
        bin_basename = shard_basename.replace(".json", ".bin")
        tokenized_filename = os.path.join(bin_dir, bin_basename)
    # write the bytes
    with open(tokenized_filename, "wb") as f:
        f.write(all_tokens.tobytes())
    print(f"Saved {tokenized_filename}")
    # calculate the average sequence length (they are separated by BOS=1)
    avg_seq_len = all_tokens.size / ((all_tokens == 1).sum())
    print(f"Saved {tokenized_filename}, average seqlen: {avg_seq_len:.2f}")


def pretokenize():
def pretokenize(vocab_size):
    # iterate the shards and tokenize all of them one by one
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if vocab_size > 0:
        # .bin files will be saved into tok{N} directory, create it once here
        bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
        os.makedirs(bin_dir, exist_ok=True)

    # process all the shards in a process pool
    fun = partial(process_shard, vocab_size=vocab_size)
    with ProcessPoolExecutor() as executor:
        executor.map(process_shard, enumerate(shard_filenames))
        executor.map(fun, enumerate(shard_filenames))
    print("Done.")


class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors."""

    def __init__(self, split, max_seq_len):
    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.vocab_source = vocab_source

    def __iter__(self):
        # get worker info within a DataLoader
@ -116,10 +186,17 @@ class PretokDataset(torch.utils.data.IterableDataset):
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")
        data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
        shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
        if self.vocab_source == "llama2":
            # the .bin files are right along the .json files
            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        elif self.vocab_source == "custom":
            # the .bin files are in tok{N} directory
            bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
        # train/test split. let's use only shard 0 for test split, rest train
        shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
        assert len(shard_filenames) > 0, f"No bin files found in {bin_dir}"
        while True:
            rng.shuffle(shard_filenames)
            for shard in shard_filenames:
@ -139,12 +216,25 @@ class PretokDataset(torch.utils.data.IterableDataset):
                y = chunk[1:]
                yield x, y

# -----------------------------------------------------------------------------
# public interface functions

def get_tokenizer_model_path(vocab_size):
    """
    Returns path to the sentencepiece tokenizer model for a given vocab size
    vocab_size = 0 designates the default Llama 2 tokenizer, in that case
    None is returned.
    """
    if vocab_size == 0:
        return None
    else:
        return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")

class Task:

    @staticmethod
    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
        ds = PretokDataset(split, max_seq_len)
    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
        ds = PretokDataset(**dataset_kwargs)
        dl = torch.utils.data.DataLoader(
            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
        )
@ -153,16 +243,33 @@ class Task:
        y = y.to(device, non_blocking=True)
        yield x, y

# -----------------------------------------------------------------------------
# CLI for constructing the dataset

if __name__ == "__main__":
    """
    These stages are designed to be run in order.

    To tokenize data with the Llama 2 tokenizer:
    python tinystories.py download
    python tinystories.py pretokenize

    To tokenize data with a custom tokenizer we train ourselves with sentencepiece, e.g.:
    python tinystories.py download
    python tinystories.py train_vocab --vocab_size=2048
    python tinystories.py pretokenize --vocab_size=2048
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"])
    parser.add_argument("stage", type=str, choices=["download", "pretokenize", "train_vocab"])
    parser.add_argument("--vocab_size", type=int, default=0, help="pretokenization vocab size. 0 = use Llama 2 tokenizer.")
    args = parser.parse_args()

    # depending on the stage call the appropriate function
    fun = {
        "download": download,
        "pretokenize": pretokenize,
    }
    fun[args.stage]()

    if args.stage == "download":
        download()
    elif args.stage == "train_vocab":
        train_vocab(vocab_size=args.vocab_size)
    elif args.stage == "pretokenize":
        pretokenize(vocab_size=args.vocab_size)
    else:
        raise ValueError(f"Unknown stage {args.stage}")
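A note for orientation (a sketch, assuming tinystories.py is importable and pretokenization has been run; values are examples, not from the commit): Task.iter_batches now forwards its extra keyword arguments straight into PretokDataset, so callers spell out split, max_seq_len, vocab_size and vocab_source explicitly, and get_tokenizer_model_path maps a vocab size to the matching .model file.

# sketch only: example values
from tinystories import Task, get_tokenizer_model_path

print(get_tokenizer_model_path(0))     # None -> use the default Llama 2 tokenizer
print(get_tokenizer_model_path(2048))  # DATA_CACHE_DIR/tok2048.model

batch_iter = Task.iter_batches(
    batch_size=32,
    device="cpu",
    num_workers=0,
    split="train",
    max_seq_len=256,
    vocab_size=2048,
    vocab_source="custom",
)
x, y = next(batch_iter)  # LongTensors of shape (32, 256)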
BIN
tokenizer.bin
Binary file not shown.
23
tokenizer.py
@ -4,20 +4,19 @@

import os
import struct
from logging import getLogger
import argparse
from typing import List

from sentencepiece import SentencePieceProcessor

TOKENIZER_MODEL = "tokenizer.model"  # the llama sentencepiece tokenizer model
TOKENIZER_BIN = "tokenizer.bin"  # binary version of the tokenizer for inference in C

class Tokenizer:
    def __init__(self):
        model_path = TOKENIZER_MODEL
    def __init__(self, tokenizer_model=None):
        model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        #print(f"Loaded SentencePiece model from {model_path}")
        self.model_path = model_path

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
@ -52,24 +51,28 @@ class Tokenizer:
                t = '\n<s>\n'
            elif i == self.eos_id:
                t = '\n</s>\n'
            elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
                t = chr(int(t[3:5], 16))  # e.g. make '<0x01>' into '\x01'
            t = t.replace('▁', ' ')  # sentencepiece uses this character as whitespace
            b = t.encode('utf-8')  # bytes of this token, utf-8 encoded

            tokens.append(b)
            scores.append(s)

        # record the max token length
        max_token_length = max(len(t) for t in tokens)

        # write to a binary file
        with open(TOKENIZER_BIN, 'wb') as f:
        # the tokenizer.bin file is the same as .model file, but .bin
        tokenizer_bin = self.model_path.replace('.model', '.bin')
        with open(tokenizer_bin, 'wb') as f:
            f.write(struct.pack("I", max_token_length))
            for bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(bytes)))
                f.write(bytes)

if __name__ == "__main__":
    t = Tokenizer()
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ")
    args = parser.parse_args()

    t = Tokenizer(args.tokenizer_model)
    t.export()
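A usage note (paths are examples, not from the commit): with this change the exporter writes the .bin next to whichever .model it loads, so the default and custom vocabularies go through the same code path.

python tokenizer.py                        # exports tokenizer.model -> tokenizer.bin
python tokenizer.py -t data/tok2048.model  # exports a custom model -> data/tok2048.bin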
22
train.py
@ -29,7 +29,6 @@ from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from tinystories import Task
from tinyshakespeare import ShakespeareTask

# -----------------------------------------------------------------------------
# I/O
@ -47,11 +46,13 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
# data
batch_size = 128  # if gradient_accumulation_steps > 1, this is the micro-batch size
max_seq_len = 256
dataset = "tinystories"  # tinystories|tinyshakespeare
vocab_source = "llama2"  # llama2|custom; use Llama 2 vocab from Meta, or custom trained
vocab_size = 32000  # the Llama 2 tokenizer has 32K tokens
# model
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = 6
multiple_of = 32
dropout = 0.0
# adamw optimizer
@ -83,6 +84,10 @@ config = {k: globals()[k] for k in config_keys}  # will be useful for logging
lr_decay_iters = max_iters  # should be ~= max_iters per Chinchilla
min_lr = 0.0  # minimum learning rate, should be ~= learning_rate/10 per Chinchilla

# validating checks
assert vocab_source in ["llama2", "custom"]
assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"

# various inits, derived attributes, I/O setup
ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
if ddp:
@ -123,11 +128,12 @@ ctx = (
)

# task-specific setup
task = {'tinystories': Task, 'tinyshakespeare': ShakespeareTask}[dataset]
iter_batches = partial(
    task.iter_batches,
    Task.iter_batches,
    batch_size=batch_size,
    max_seq_len=max_seq_len,
    vocab_size=vocab_size,
    vocab_source=vocab_source,
    device=device,
    num_workers=0,
)
@ -141,8 +147,8 @@ model_args = dict(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_heads,
    vocab_size=32000,
    n_kv_heads=n_kv_heads,
    vocab_size=vocab_size,
    multiple_of=multiple_of,
    max_seq_len=max_seq_len,
    dropout=dropout,
@ -206,7 +212,7 @@ def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        batch_iter = iter_batches(split)
        batch_iter = iter_batches(split=split)
        losses = torch.zeros(eval_iters)  # keep on CPU
        for k in range(eval_iters):
            X, Y = next(batch_iter)
@ -238,7 +244,7 @@ if wandb_log and master_process:
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# training loop
train_batch_iter = iter_batches("train")
train_batch_iter = iter_batches(split="train")
X, Y = next(train_batch_iter)  # fetch the very first batch
t0 = time.time()
local_iter_num = 0  # number of iterations in the lifetime of this process
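For illustration (hypothetical values, not part of the commit), a configuration that satisfies the new checks when training on a self-trained vocabulary:

# sketch only: example override values
vocab_source = "custom"  # must be "llama2" or "custom"
vocab_size = 2048        # 32000 is only required when vocab_source == "llama2"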
126
train_vocab.sh
Executable file
@ -0,0 +1,126 @@
#!/bin/bash

# Trains a sentencepiece tokenizer model on a bunch of given data, my best
# effort attempt to replicate how Meta trained their Llama 2 tokenizer.

# usage: $ train_vocab.sh <input> <model_prefix> <vocab_size>
# example:
# ./train_vocab.sh tiny.txt tokenizer_tiny 1024
# requirements:
# install https://github.com/google/sentencepiece

# check if the correct number of arguments are provided
if [ $# -ne 3 ]; then
    echo "Usage: $0 <input> <model_prefix> <vocab_size>"
    exit 1
fi

# assign command-line arguments to variables
input=$1
model_prefix=$2
vocab_size=$3

# check if input file exists
if [ ! -f "$input" ]; then
    echo "Usage: $0 <input> <model_prefix> <vocab_size>"
    echo "input '$input' not found."
    exit 1
fi

# check if vocab_size is a positive integer
if ! [[ "$vocab_size" =~ ^[0-9]+$ ]] || [ "$vocab_size" -lt 1 ]; then
    echo "Usage: $0 <input> <model_prefix> <vocab_size>"
echo "vocab_size size must be a positive integer."
|
||||
    exit 1
fi

# Print the processed inputs
echo "Input: $input"
echo "Model Prefix: $model_prefix"
echo "Vocabulary Size: $vocab_size"

# train a sentencepiece tokenizer model
# Llama 2 config can be printed as follows:

# import sentencepiece.sentencepiece_model_pb2
# mp = sentencepiece.sentencepiece_model_pb2.ModelProto()
# mp.ParseFromString(open("tokenizer.model", "rb").read())
# print(mp.trainer_spec)
# print(mp.normalizer_spec)

# this gives:

# trainer_spec {
#   input: "/large_experiments/theorem/datasets/MERGED/all.test1.merged"
#   model_prefix: "spm_model_32k_200M_charcov099995_allowWSO__v2"
#   model_type: BPE
#   vocab_size: 32000
#   self_test_sample_size: 0
#   input_format: "text"
#   character_coverage: 0.9999499917030334
#   input_sentence_size: 200000000
#   seed_sentencepiece_size: 1000000
#   shrinking_factor: 0.75
#   num_threads: 80
#   num_sub_iterations: 2
#   max_sentence_length: 4192
#   shuffle_input_sentence: true
#   max_sentencepiece_length: 16
#   split_by_unicode_script: true
#   split_by_whitespace: true
#   split_by_number: true
#   treat_whitespace_as_suffix: false
#   split_digits: true
#   allow_whitespace_only_pieces: true
#   vocabulary_output_piece_score: true
#   hard_vocab_limit: true
#   use_all_vocab: false
#   byte_fallback: true
#   required_chars: ""
#   unk_id: 0
#   bos_id: 1
#   eos_id: 2
#   pad_id: -1
#   unk_surface: " \342\201\207 "
#   unk_piece: "<unk>"
#   bos_piece: "<s>"
#   eos_piece: "</s>"
#   pad_piece: "<pad>"
#   train_extremely_large_corpus: false
#   enable_differential_privacy: false
#   differential_privacy_noise_level: 0.0
#   differential_privacy_clipping_threshold: 0
# }
# normalizer_spec {
#   name: "identity"
#   precompiled_charsmap: ""
#   add_dummy_prefix: true
#   remove_extra_whitespaces: false
#   normalization_rule_tsv: ""
# }

# let's now use spm_train to train this exact model
# options docs: https://github.com/google/sentencepiece/blob/master/doc/options.md
# we'll depart from a few settings:
# character_coverage -> 1.0

# other important notes:
# --split_digits = true, per the paper
# --allow_whitespace_only_pieces is true, default in spm is false
# --byte_fallback is true, default in spm is false
# --normalization_rule_name is identity, default in spm is nmt_nfkc

spm_train --input="$input" \
          --model_prefix="$model_prefix" \
          --model_type=bpe \
          --vocab_size="$vocab_size" \
          --self_test_sample_size=0 \
          --input_format="text" \
          --character_coverage=1.0 \
          --num_threads="$(nproc)" \
          --split_digits=true \
          --allow_whitespace_only_pieces=true \
          --byte_fallback=true \
          --unk_surface=" \342\201\207 " \
          --normalization_rule_name=identity \
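After the script finishes, a quick sanity check of the trained model could look like this (a sketch; the data/tok2048 prefix is only an example):

# sketch only: verify a freshly trained sentencepiece model
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor(model_file="data/tok2048.model")  # example path
print(sp.vocab_size())                # should match the requested vocab_size
print(sp.encode("Once upon a time"))  # token ids under the new vocab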