big change: adding prompting. many LOC, but critical. ty @atamurad for the first draft, i ended up tuning it quite a bit.

This commit is contained in:
Andrej Karpathy 2023-07-28 04:12:54 +00:00
parent 568a651c45
commit b4bb47bb7b
3 changed files with 130 additions and 30 deletions

136
run.c
View File

@ -337,6 +337,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
}
// ----------------------------------------------------------------------------
// functions to sample the next token from the transformer's predicted distribution
int sample(float* probabilities, int n) {
// sample index from probabilities, they must sum to 1
float r = (float)rand() / (float)RAND_MAX;
@ -362,14 +365,76 @@ int argmax(float* v, int n) {
}
return max_i;
}
// ----------------------------------------------------------------------------
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
int str_lookup(char *str, char **vocab, int vocab_size) {
// find the first perfect match for str in vocab, return its index or -1 if not found
for (int i = 0; i < vocab_size; i++) {
if (strcmp(str, vocab[i]) == 0) {
return i;
}
}
return -1;
}
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
// a temporary buffer to merge two consecutive tokens
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
// first encode every individual byte in the input string
*n_tokens = 0; // the number of tokens
for (char *c = text; *c != '\0'; c++) {
sprintf(str_buffer, "%c", *c);
int id = str_lookup(str_buffer, vocab, vocab_size);
if (id == -1) { printf("not good\n"); exit(1);}
tokens[*n_tokens] = id;
(*n_tokens)++;
}
// merge the best consecutive pair each iteration, according the scores in vocab_scores
while (1) {
float best_score = -1e10;
int best_id = -1;
int best_idx = -1;
for (int i=0; i < (*n_tokens-1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, vocab, vocab_size);
if (id != -1 && vocab_scores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = vocab_scores[id];
best_id = id;
best_idx = i;
}
}
if (best_idx == -1) {
break; // we couldn't find any more pairs to merge, so we're done
}
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens[best_idx] = best_id;
// delete token at position best_idx+1, shift the entire sequence back 1
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
tokens[i] = tokens[i+1];
}
(*n_tokens)--; // token length decreased
}
free(str_buffer);
}
// ----------------------------------------------------------------------------
// utilities
long time_in_ms() {
struct timespec time;
clock_gettime(CLOCK_REALTIME, &time);
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
}
// ----------------------------------------------------------------------------
int main(int argc, char *argv[]) {
@ -377,9 +442,11 @@ int main(int argc, char *argv[]) {
char *checkpoint = NULL; // e.g. out/model.bin
float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
char *prompt = NULL; // prompt string
// 'checkpoint' is necessary arg
if (argc < 2) {
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt]\n", argv[0]);
return 1;
}
if (argc >= 2) {
@ -392,6 +459,9 @@ int main(int argc, char *argv[]) {
if (argc >= 4) {
steps = atoi(argv[3]);
}
if (argc >= 5) {
prompt = argv[4];
}
// seed rng with time. if you want deterministic behavior use temperature 0.0
srand((unsigned int)time(NULL));
@ -406,7 +476,7 @@ int main(int argc, char *argv[]) {
FILE *file = fopen(checkpoint, "rb");
if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
// read in the config header
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
int shared_weights = config.vocab_size > 0 ? 1 : 0;
config.vocab_size = abs(config.vocab_size);
@ -427,14 +497,18 @@ int main(int argc, char *argv[]) {
// read in the tokenizer.bin file
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
unsigned int max_token_length;
{
FILE *file = fopen("tokenizer.bin", "rb");
if (!file) { printf("Couldn't load tokenizer.bin\n"); return 1; }
if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
int len;
for (int i = 0; i < config.vocab_size; i++) {
if(fread(&len, sizeof(int), 1, file) != 1) { return 1; }
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { printf("failed read\n"); return 1;}
if (fread(&len, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
vocab[i] = (char *)malloc(len + 1);
if(fread(vocab[i], len, 1, file) != 1) { return 1; }
if (fread(vocab[i], len, 1, file) != 1) { printf("failed read\n"); return 1; }
vocab[i][len] = '\0'; // add the string terminating token
}
fclose(file);
@ -443,30 +517,44 @@ int main(int argc, char *argv[]) {
// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);
// the current position we are in
long start = 0; // used to time our code, only initialized after first iteration
int next;
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
int pos = 0;
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
// process the prompt, if any
int *prompt_tokens = NULL;
int num_prompt_tokens = 0;
if (prompt != NULL) {
prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
}
// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
while (pos < steps) {
// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);
// sample the next token
if(temperature == 0.0f) {
// greedy argmax sampling
next = argmax(state.logits, config.vocab_size);
if(pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
} else {
// apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
// we now want to sample from this distribution to get the next token
next = sample(state.logits, config.vocab_size);
// sample the next token
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = argmax(state.logits, config.vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
// we sample from this distribution to get the next token
next = sample(state.logits, config.vocab_size);
}
}
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
printf("%s", token_str);
@ -487,6 +575,8 @@ int main(int argc, char *argv[]) {
free_run_state(&state);
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
free(vocab);
free(vocab_scores);
if (prompt_tokens != NULL) free(prompt_tokens);
if (data != MAP_FAILED) munmap(data, file_size);
if (fd != -1) close(fd);
return 0;

Binary file not shown.

View File

@ -3,6 +3,7 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import struct
from logging import getLogger
from typing import List
@ -39,26 +40,35 @@ class Tokenizer:
return self.sp_model.decode(t)
def export(self):
tokens = []
# get all the tokens (postprocessed) and their scores as floats
tokens, scores = [], []
for i in range(self.n_words):
# decode the token and light postprocessing
t = self.sp_model.id_to_piece(i)
s = self.sp_model.get_score(i)
if i == self.bos_id:
t = '\n<s>\n'
elif i == self.eos_id:
t = '\n</s>\n'
elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
t = t.replace('', ' ') # sentencepiece uses this as the whitespace
t = t.replace('', ' ') # sentencepiece uses this character as whitespace
b = t.encode('utf-8') # bytes of this token, utf-8 encoded
tokens.append(t)
tokens.append(b)
scores.append(s)
# record the max token length
max_token_length = max(len(t) for t in tokens)
# write to a binary file
with open(TOKENIZER_BIN, 'wb') as f:
for token in tokens:
bytes = token.encode('utf-8')
f.write((len(bytes)).to_bytes(4, 'little')) # write length of bytes
f.write(bytes) # write token bytes
f.write(struct.pack("I", max_token_length))
for bytes, score in zip(tokens, scores):
f.write(struct.pack("fI", score, len(bytes)))
f.write(bytes)
if __name__ == "__main__":
t = Tokenizer()