diff --git a/README.md b/README.md
index 4889816..f0eb9ec 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,32 @@ python tokenizer.py --tokenizer-model=/path/to/CodeLlama-7b-Instruct/tokenizer.m
 ./run codellama2_7b_instruct.bin -m chat -z /path/to/CodeLlama-7b-Instruct/tokenizer.bin
 ```
 
+## int8 quantization
+
+The (default) script [run.c](run.c), above, uses a float32 forward pass, where the entire calculation of the forward pass is kept in fp32. This is very easy to understand as far as reference code goes, but it has the following downsides: the model checkpoint files are very large (it takes 4 bytes per weight), and the forward pass is relatively slow. The (very) common inference optimization employed in practice is to quantize the model parameters to lower precision, giving up a little bit of correctness in return for smaller checkpoint sizes and faster forward passes (as most of the inference uses integer arithmetic). Empirically, LLMs can tolerate precisions as low as 4-bit (or even lower), but we use int8 here because it is a "safe" setting that gets us the benefits but doesn't sacrifice too much of the model accuracy. Only the weights that participate in matmuls are quantized. All the other parameters (e.g. especially the RMSNorm scale parameters) are kept in float32, because these layers are very sensitive. Now, if all you're after is a reduction in checkpoint size, you could quantize the weights, save the checkpoint, then dequantize them in run.c and do float32 inference as normal and call it a day. This is totally fine. But here, we go one step further (as is standard practice) and additionally quantize the activations in the forward pass. This requires us to dynamically quantize and dequantize between float32 and int8 at runtime, which adds overhead. But the benefit is that now the majority of the calculations (the matmuls especially!) use pure integer arithmetic, where both weights and activations enter as int8. This is where the speedups fundamentally come from. The version we use is the "Q8_0" quantization (llama.cpp terminology), where the 0 means that the weight quantization is symmetric around 0, quantizing to the range [-127, 127].
+
+The quantized forward pass is implemented in [runq.c](runq.c). To use it, we have to export the model in the quantized format. For example, the float32 version of Llama 2 7B was exported as:
+
+```
+python export.py llama2_7b.bin --meta-llama path/to/llama/model/7B
+```
+
+This creates a 26GB file, because each of the 7B parameters takes 4 bytes (fp32). To export it quantized, we instead use version 2 export:
+
+```
+python export.py llama2_7b_q80.bin --version 2 --meta-llama path/to/llama/model/7B
+```
+
+This runs for a few minutes, but now creates only a 6.7GB file. For exporting non-meta checkpoints you would use the --checkpoint arg instead of the --meta-llama arg (more docs on this later, below). Now let's inference them. I like to use OMP here because these are big models, so e.g. on my Linux box:
+
+```
+make runomp
+OMP_NUM_THREADS=64 ./run llama2_7b.bin -n 40
+OMP_NUM_THREADS=64 ./runq llama2_7b_q80.bin -n 40
+```
+
+This runs 40 steps just to get a timing. The float32 version runs at 4.6 tok/s for me, and the int8 version at 14 tok/s. So we achieved a 3X speedup while reducing the checkpoint size by 4X. However, the forward pass is quantized to int8, so the output is silently, very slightly lower quality.
+
 ## huggingface models
 
 We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
@@ -364,7 +390,6 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
 ## unsorted todos
 
 - add support in run.c of reading version 1+ files from export, later deprecate "version 0"
-- runq.c (int8 quantization) add
 - run.cu (CUDA) investigate and merge
 - add more tests inside [test.c](test.c)
 - add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping
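To make the "Q8_0" scheme described in the new README section concrete, here is a minimal sketch of symmetric, group-wise int8 quantization and dequantization. It is illustrative only, not an excerpt from the diff: the group size `GS`, the struct layout, and the function signatures are assumptions made for this sketch (runq.c defines its own equivalents).

```
#include <math.h>
#include <stdint.h>

#define GS 64 /* group size: an assumption for this sketch */

typedef struct {
    int8_t* q; /* quantized int8 values */
    float* s;  /* one float32 scale per group of GS values */
} QuantizedTensor;

/* Symmetric "Q8_0"-style quantization: each group of GS floats is scaled by its
   absolute maximum, so the int8 values land in [-127, 127] and 0.0f maps to 0. */
void quantize(QuantizedTensor *qx, float *x, int n) {
    const float q_max = 127.0f;
    for (int group = 0; group < n / GS; group++) {
        /* find the largest magnitude in this group */
        float wmax = 0.0f;
        for (int i = 0; i < GS; i++) {
            float val = fabsf(x[group * GS + i]);
            if (val > wmax) wmax = val;
        }
        /* one fp32 scale per group, kept around for dequantization */
        float scale = wmax / q_max;
        qx->s[group] = scale;
        for (int i = 0; i < GS; i++) {
            float quant_value = (scale != 0.0f) ? x[group * GS + i] / scale : 0.0f;
            qx->q[group * GS + i] = (int8_t) roundf(quant_value);
        }
    }
}

/* Dequantization is just the inverse: int8 value times its group's scale. */
void dequantize(QuantizedTensor *qx, float *x, int n) {
    for (int i = 0; i < n; i++) {
        x[i] = qx->q[i] * qx->s[i / GS];
    }
}
```

The matmuls then multiply quantized weights against quantized activations and rescale each group's integer dot product by the two fp32 scales; a sketch of that is given after the runq.c hunks below.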
diff --git a/runq.c b/runq.c
index 53a7c85..8c95a91 100644
--- a/runq.c
+++ b/runq.c
@@ -1,4 +1,4 @@
-/* Inference for Llama-2 Transformer model in pure C */
+/* Inference for Llama-2 Transformer model in pure C, int8 quantized forward pass. */
 
 #include <stdio.h>
 #include <stdlib.h>
@@ -176,7 +176,7 @@ QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) {
     QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
     for(int i=0; i<n; i++) {
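To complement the quantize/dequantize sketch above, here is a sketch of the kind of quantized matmul the README text describes, where int8 weights and int8 activations are multiplied with integer arithmetic and each group's accumulator is rescaled by the two fp32 group scales. The function name and signature are assumptions, it reuses the hypothetical `GS` and `QuantizedTensor` from the earlier sketch, and it is not a verbatim excerpt from runq.c.

```
/* Sketch of a quantized matmul xout = W @ x, where W is a (d,n) int8 matrix with
   per-group scales and x is a quantized activation vector of length n (assumed to
   be a multiple of GS). The int8*int8 products are accumulated in int32 within each
   group, then rescaled to float; this integer inner loop is where the speedup lives. */
void matmul_q8(float *xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
    for (int i = 0; i < d; i++) {
        float val = 0.0f;
        int in = i * n; /* offset of row i inside w */
        for (int j = 0; j < n; j += GS) {
            /* integer dot product over one group */
            int32_t ival = 0;
            for (int k = 0; k < GS; k++) {
                ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
            }
            /* rescale the group's accumulator by both fp32 scales */
            val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
        }
        xout[i] = val;
    }
}
```

The outer loop over the d output rows is the natural place for an OpenMP `#pragma omp parallel for`, which is how the multi-threaded numbers quoted in the README section scale across cores.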