mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
Merge pull request #400 from kroggen/use-kv-cache
calculate key and value inside the kv cache
This commit is contained in:
commit
1fcdf04fbb
23
run.c
23
run.c
@@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) {
|
||||
s->hb = calloc(p->hidden_dim, sizeof(float));
|
||||
s->hb2 = calloc(p->hidden_dim, sizeof(float));
|
||||
s->q = calloc(p->dim, sizeof(float));
|
||||
s->k = calloc(kv_dim, sizeof(float));
|
||||
s->v = calloc(kv_dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
|
||||
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
|
||||
s->logits = calloc(p->vocab_size, sizeof(float));
|
||||
// ensure all mallocs went fine
|
||||
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|
||||
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|
||||
|| !s->value_cache) {
|
||||
|| !s->key_cache || !s->value_cache || !s->att || !s->logits) {
|
||||
fprintf(stderr, "malloc failed!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
@@ -105,8 +102,6 @@ void free_run_state(RunState* s) {
|
||||
free(s->hb);
|
||||
free(s->hb2);
|
||||
free(s->q);
|
||||
free(s->k);
|
||||
free(s->v);
|
||||
free(s->att);
|
||||
free(s->logits);
|
||||
free(s->key_cache);
|
||||
@@ -256,6 +251,11 @@ float* forward(Transformer* transformer, int token, int pos) {
|
||||
// attention rmsnorm
|
||||
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
|
||||
|
||||
// key and value point to the kv cache
|
||||
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
|
||||
s->k = s->key_cache + loff + pos * kv_dim;
|
||||
s->v = s->value_cache + loff + pos * kv_dim;
|
||||
|
||||
// qkv matmuls for this position
|
||||
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
|
||||
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
|
||||
@@ -278,13 +278,6 @@ float* forward(Transformer* transformer, int token, int pos) {
|
||||
}
|
||||
}
|
||||
|
||||
// save key,value at this time step (pos) to our kv cache
|
||||
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
|
||||
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
|
||||
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
|
||||
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
|
||||
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
|
||||
|
||||
// multihead attention. iterate over all heads
|
||||
int h;
|
||||
#pragma omp parallel for private(h)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user