From 593d846bc3d9460c66925d7d3281e67c1b2df5d1 Mon Sep 17 00:00:00 2001 From: Bernardo Ramos Date: Thu, 14 Sep 2023 01:13:08 +0000 Subject: [PATCH 1/2] use key and value from kv cache --- run.c | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/run.c b/run.c index efb254f..615ef38 100644 --- a/run.c +++ b/run.c @@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) { s->hb = calloc(p->hidden_dim, sizeof(float)); s->hb2 = calloc(p->hidden_dim, sizeof(float)); s->q = calloc(p->dim, sizeof(float)); - s->k = calloc(kv_dim, sizeof(float)); - s->v = calloc(kv_dim, sizeof(float)); s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); s->logits = calloc(p->vocab_size, sizeof(float)); s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); // ensure all mallocs went fine if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q - || !s->k || !s->v || !s->att || !s->logits || !s->key_cache - || !s->value_cache) { + || !s->att || !s->logits || !s->key_cache || !s->value_cache) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); } @@ -105,8 +102,6 @@ void free_run_state(RunState* s) { free(s->hb); free(s->hb2); free(s->q); - free(s->k); - free(s->v); free(s->att); free(s->logits); free(s->key_cache); @@ -256,6 +251,11 @@ float* forward(Transformer* transformer, int token, int pos) { // attention rmsnorm rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); + // key and value point to the kv cache + int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience + s->k = s->key_cache + loff + pos * kv_dim; + s->v = s->value_cache + loff + pos * kv_dim; + // qkv matmuls for this position matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim); @@ -278,13 +278,6 @@ float* forward(Transformer* transformer, int token, int pos) { } } - // save key,value at this time step (pos) to our kv cache - int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience - float* key_cache_row = s->key_cache + loff + pos * kv_dim; - float* value_cache_row = s->value_cache + loff + pos * kv_dim; - memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row)); - memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row)); - // multihead attention. iterate over all heads int h; #pragma omp parallel for private(h) From 411c5bd2db9a87e94e1bd1a6c7b7ca117adc4b01 Mon Sep 17 00:00:00 2001 From: Bernardo Ramos Date: Thu, 14 Sep 2023 07:14:45 +0000 Subject: [PATCH 2/2] reorganize variables --- run.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/run.c b/run.c index 615ef38..e1a4ec2 100644 --- a/run.c +++ b/run.c @@ -83,13 +83,13 @@ void malloc_run_state(RunState* s, Config* p) { s->hb = calloc(p->hidden_dim, sizeof(float)); s->hb2 = calloc(p->hidden_dim, sizeof(float)); s->q = calloc(p->dim, sizeof(float)); - s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); - s->logits = calloc(p->vocab_size, sizeof(float)); s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); + s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); + s->logits = calloc(p->vocab_size, sizeof(float)); // ensure all mallocs went fine if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q - || !s->att || !s->logits || !s->key_cache || !s->value_cache) { + || !s->key_cache || !s->value_cache || !s->att || !s->logits) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }