Merge pull request #364 from karpathy/feature/int8_try2

int8 quantization attempt #2
Andrej 2023-10-09 12:54:35 -07:00 committed by GitHub
commit d9862069e7
4 changed files with 1135 additions and 9 deletions

Makefile

@@ -6,11 +6,13 @@ CC = gcc
.PHONY: run
run: run.c
	$(CC) -O3 -o run run.c -lm
	$(CC) -O3 -o runq runq.c -lm
# useful for a debug build, can then e.g. analyze with valgrind, example:
# $ valgrind --leak-check=full ./run out/model.bin -n 3
rundebug: run.c
	$(CC) -g -o run run.c -lm
	$(CC) -g -o runq runq.c -lm
# https://gcc.gnu.org/onlinedocs/gcc/Optimize-Options.html
# https://simonbyrne.github.io/notes/fastmath/
@@ -24,6 +26,7 @@ rundebug: run.c
.PHONY: runfast
runfast: run.c
	$(CC) -Ofast -o run run.c -lm
	$(CC) -Ofast -o runq runq.c -lm
# additionally compiles with OpenMP, allowing multithreaded runs
# make sure to also enable multiple threads when running, e.g.:
@@ -31,19 +34,23 @@ runfast: run.c
.PHONY: runomp
runomp: run.c
	$(CC) -Ofast -fopenmp -march=native run.c -lm -o run
	$(CC) -Ofast -fopenmp -march=native runq.c -lm -o runq
.PHONY: win64
win64:
	x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c
	x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o runq.exe -I. runq.c win.c
# compiles with gnu99 standard flags for amazon linux, coreos, etc. compatibility
.PHONY: rungnu
rungnu:
	$(CC) -Ofast -std=gnu11 -o run run.c -lm
	$(CC) -Ofast -std=gnu11 -o runq runq.c -lm
.PHONY: runompgnu
runompgnu:
	$(CC) -Ofast -fopenmp -std=gnu11 run.c -lm -o run
	$(CC) -Ofast -fopenmp -std=gnu11 runq.c -lm -o runq
# run all tests
.PHONY: test
@@ -66,3 +73,4 @@ testcc:
.PHONY: clean
clean:
	rm -f run
	rm -f runq

README.md

@@ -113,6 +113,32 @@ python tokenizer.py --tokenizer-model=/path/to/CodeLlama-7b-Instruct/tokenizer.m
./run codellama2_7b_instruct.bin -m chat -z /path/to/CodeLlama-7b-Instruct/tokenizer.bin
```
## int8 quantization
The (default) script [run.c](run.c), above, uses a float32 forward pass: the entire calculation is kept in fp32. This is very easy to understand as far as reference code goes, but it has two downsides: the model checkpoint files are very large (4 bytes per weight), and the forward pass is relatively slow. A very common inference optimization in practice is to quantize the model parameters to lower precision, trading a small amount of accuracy for smaller checkpoints and a faster forward pass (as most of the inference then uses integer arithmetic). Empirically, LLMs can tolerate precisions as low as 4-bit (or even lower), but we use int8 here because it is a "safe" setting that gets us the benefits without sacrificing too much model accuracy. Only the weights that participate in matmuls are quantized; all the other parameters (especially the RMSNorm scale parameters) are kept in float32, because these layers are very sensitive.

Now, if all you're after is a reduction in checkpoint size, you could quantize the weights, save the checkpoint, then dequantize them at load time in run.c and do float32 inference as normal, and call it a day. This is totally fine. But here, we go one step further (as is standard practice) and additionally quantize the activations in the forward pass. This requires us to dynamically quantize and dequantize between float32 and int8 at runtime, which adds overhead. The benefit is that the majority of the calculations (the matmuls especially!) now use pure integer arithmetic, where both weights and activations enter as int8. This is where the speedups fundamentally come from. The scheme we use is "Q8_0" quantization (llama.cpp terminology), where the 0 means the weight quantization is symmetric around 0, quantizing into the range [-127, 127], with one scale factor per group of weights.
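To make the scheme concrete, here is a minimal NumPy sketch of symmetric Q8_0 group quantization. It mirrors the idea behind the exporter's `quantize_q80` (which appears further down in export.py), but it is an illustration rather than the repo's code; the function names and the use of NumPy are choices made just for this example.

```
import numpy as np

def q80_quantize(w, group_size=64):
    assert w.size % group_size == 0
    # reshape into groups of `group_size` consecutive weights
    g = w.astype(np.float32).reshape(-1, group_size)
    # one scale per group, chosen so the largest |value| in the group maps to 127
    scale = np.abs(g).max(axis=1) / 127.0
    q = np.round(g / scale[:, None]).astype(np.int8)  # int8 values in [-127, 127]
    return q, scale

def q80_dequantize(q, scale, group_size=64):
    # approximate reconstruction: float ~= int8 * scale, per group
    return (q.reshape(-1, group_size).astype(np.float32) * scale[:, None]).reshape(-1)

# tiny demo: quantize a random weight tensor and check the round-trip error
w = np.random.randn(4096 * 64).astype(np.float32)
q, s = q80_quantize(w)
err = np.abs(q80_dequantize(q, s) - w).max()
print(f"max abs round-trip error: {err:.5f}")  # small, bounded by 0.5 * max scale
```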
The quantized forward pass is implemented in [runq.c](runq.c). To use it, we have to export the model in the quantized format. For example, the float32 version of Llama 2 7B was exported as:
```
python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
```
This creates a 26GB file, because each of the 7B parameters is stored as 4 bytes (fp32). To export it quantized, we instead use the version 2 export:
```
python export.py llama2_7b_q80.bin --version 2 --meta-llama path/to/llama/model/7B
```
This runs for a few minutes, but now creates only a 6.7GB file. To export non-meta checkpoints you would use the `--checkpoint` arg instead of the `--meta-llama` arg (more docs on this below). Now let's run inference on them. I like to use OMP here because these are big models, so e.g. on my Linux box:
```
make runomp
OMP_NUM_THREADS=64 ./run llama2_7b.bin -n 40
OMP_NUM_THREADS=64 ./runq llama2_7b_q80.bin -n 40
```
This runs 40 steps just to get a timing. The float32 version runs at 4.6 tok/s for me, and the int8 version at 14 tok/s. So we achieve a ~3X speedup while cutting the checkpoint size by ~4X. However, because the forward pass is quantized to int8, the outputs are silently very slightly lower quality.
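As a quick back-of-the-envelope check on those numbers (a sketch assuming the exporter's default group_size of 64; the exact figures also depend on the precise parameter count and GB vs GiB rounding):

```
bytes_fp32 = 4.0               # fp32: 4 bytes per weight
bytes_q80 = 1.0 + 4.0 / 64     # Q8_0: 1 int8 byte per weight + one fp32 scale per 64-weight group
print(bytes_q80)               # 1.0625 bytes per weight
print(bytes_fp32 / bytes_q80)  # ~3.8x smaller, roughly the ~4X quoted above
print(26 / 6.7)                # ~3.9x, consistent with the observed file sizes
```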
## huggingface models
We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
@@ -364,7 +390,6 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
## unsorted todos
- add support in run.c of reading version 1+ files from export, later deprecate "version 0"
- runq.c (int8 quantization) add
- run.cu (CUDA) investigate and merge
- add more tests inside [test.c](test.c)
- add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping

export.py

@@ -241,22 +241,16 @@ def version2_export(model, filepath, group_size=64):
    # now let's write out all the params that we are quantizing to Q8_0
    # note we skip classifier weights, which are shared with the embedding
    ew = []
    scales = []
    for i, w in enumerate(weights):
        # quantize this weight
        q, s, err = quantize_q80(w, group_size)
        # save the int8 weights to file
        serialize_int8(out_file, q) # save the tensor in int8
        scales.append(s) # we'll do all the scales after all the qs
        serialize_fp32(out_file, s) # save scale factors
        # logging
        ew.append((err, w.shape))
        print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
    # save the scaling factors in fp32 here
    # this is done to keep all the weights contiguous, making pointer arithmetic easier in C
    for s in scales:
        serialize_fp32(out_file, s)
    # print the highest error across all weights, should be very small, e.g. O(~0.001)
    ew.sort(reverse=True)
    print(f"max quantization group error across all weights: {ew[0][0]}")
@@ -496,7 +490,14 @@ def load_hf_model(model_path):
# API entrypoint
def model_export(model, filepath, version, dtype=torch.float32):
    # TODO: add dtype export support for other versions
    """
    Versions docs:
    v-1: huggingface export, i.e. intended for use outside of this repo, in HF
    v0: legacy llama2.c float format, DEPRECATED
    v1: float32 export
    v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
    # TODO: add dtype export support for other versions (?)
    """
    if version == 0:
        legacy_export(model, filepath)
    elif version == 1:
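For orientation, calling this entrypoint directly might look like the following. This is a sketch based only on the signature and the version docs above; it assumes `model` is an already-loaded Transformer (e.g. via `load_hf_model` above or a local checkpoint), and the output filenames just echo the README examples.

```
# float32 export (version 1, per the version docs above)
model_export(model, "llama2_7b.bin", version=1)
# int8 Q8_0 export (version 2), the quantized format runq.c consumes
model_export(model, "llama2_7b_q80.bin", version=2)
```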

runq.c (new file, 1092 lines)

File diff suppressed because it is too large.