diff --git a/run.c b/run.c
index 912c3cb..477011c 100644
--- a/run.c
+++ b/run.c
@@ -337,6 +337,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
     matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
 }
 
+// ----------------------------------------------------------------------------
+// functions to sample the next token from the transformer's predicted distribution
+
 int sample(float* probabilities, int n) {
     // sample index from probabilities, they must sum to 1
     float r = (float)rand() / (float)RAND_MAX;
@@ -362,14 +365,76 @@ int argmax(float* v, int n) {
     }
     return max_i;
 }
+// ----------------------------------------------------------------------------
+// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
+
+int str_lookup(char *str, char **vocab, int vocab_size) {
+    // find the first perfect match for str in vocab, return its index or -1 if not found
+    for (int i = 0; i < vocab_size; i++) {
+        if (strcmp(str, vocab[i]) == 0) {
+            return i;
+        }
+    }
+    return -1;
+}
+
+void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
+
+    // a temporary buffer to merge two consecutive tokens
+    char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
+
+    // first encode every individual byte in the input string
+    *n_tokens = 0; // the number of tokens
+    for (char *c = text; *c != '\0'; c++) {
+        sprintf(str_buffer, "%c", *c);
+        int id = str_lookup(str_buffer, vocab, vocab_size);
+        if (id == -1) { printf("not good\n"); exit(1); }
+        tokens[*n_tokens] = id;
+        (*n_tokens)++;
+    }
+
+    // merge the best consecutive pair each iteration, according to the scores in vocab_scores
+    while (1) {
+        float best_score = -1e10;
+        int best_id = -1;
+        int best_idx = -1;
+
+        for (int i=0; i < (*n_tokens-1); i++) {
+            // check if we can merge the pair (tokens[i], tokens[i+1])
+            sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
+            int id = str_lookup(str_buffer, vocab, vocab_size);
+            if (id != -1 && vocab_scores[id] > best_score) {
+                // this merge pair exists in vocab! record its score and position
+                best_score = vocab_scores[id];
+                best_id = id;
+                best_idx = i;
+            }
+        }
+
+        if (best_idx == -1) {
+            break; // we couldn't find any more pairs to merge, so we're done
+        }
+
+        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
+        tokens[best_idx] = best_id;
+        // delete token at position best_idx+1, shift the entire sequence back 1
+        for (int i = best_idx+1; i < (*n_tokens-1); i++) {
+            tokens[i] = tokens[i+1];
+        }
+        (*n_tokens)--; // token length decreased
+    }
+
+    free(str_buffer);
+}
 
 // ----------------------------------------------------------------------------
-
+// utilities
 long time_in_ms() {
     struct timespec time;
     clock_gettime(CLOCK_REALTIME, &time);
     return time.tv_sec * 1000 + time.tv_nsec / 1000000;
 }
 
+// ----------------------------------------------------------------------------
 
 int main(int argc, char *argv[]) {
 
@@ -377,9 +442,11 @@ int main(int argc, char *argv[]) {
     char *checkpoint = NULL;  // e.g. out/model.bin
     float temperature = 0.9f; // e.g. 1.0, or 0.0
     int steps = 256;          // max number of steps to run for, 0: use seq_len
+    char *prompt = NULL;      // prompt string
 
+    // 'checkpoint' is necessary arg
     if (argc < 2) {
-        printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
+        printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt]\n", argv[0]);
         return 1;
     }
     if (argc >= 2) {
@@ -392,6 +459,9 @@ int main(int argc, char *argv[]) {
     if (argc >= 4) {
         steps = atoi(argv[3]);
     }
+    if (argc >= 5) {
+        prompt = argv[4];
+    }
 
     // seed rng with time. if you want deterministic behavior use temperature 0.0
     srand((unsigned int)time(NULL));
@@ -406,7 +476,7 @@ int main(int argc, char *argv[]) {
     FILE *file = fopen(checkpoint, "rb");
     if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
     // read in the config header
-    if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
+    if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
     // negative vocab size is hacky way of signaling unshared weights. bit yikes.
     int shared_weights = config.vocab_size > 0 ? 1 : 0;
     config.vocab_size = abs(config.vocab_size);
@@ -427,14 +497,18 @@ int main(int argc, char *argv[]) {
 
     // read in the tokenizer.bin file
     char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
+    float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
+    unsigned int max_token_length;
     {
         FILE *file = fopen("tokenizer.bin", "rb");
-        if (!file) { printf("Couldn't load tokenizer.bin\n"); return 1; }
+        if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
+        if (fread(&max_token_length, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
         int len;
         for (int i = 0; i < config.vocab_size; i++) {
+            if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { printf("failed read\n"); return 1; }
+            if (fread(&len, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
-            if(fread(&len, sizeof(int), 1, file) != 1) { return 1; }
             vocab[i] = (char *)malloc(len + 1);
-            if(fread(vocab[i], len, 1, file) != 1) { return 1; }
+            if (fread(vocab[i], len, 1, file) != 1) { printf("failed read\n"); return 1; }
             vocab[i][len] = '\0'; // add the string terminating token
         }
         fclose(file);
@@ -443,30 +517,44 @@ int main(int argc, char *argv[]) {
     // create and init the application RunState
     RunState state;
     malloc_run_state(&state, &config);
-
-    // the current position we are in
-    long start = 0; // used to time our code, only initialized after first iteration
-    int next;
-    int token = 1; // 1 = BOS token in Llama-2 sentencepiece
-    int pos = 0;
-    printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
+
+    // process the prompt, if any
+    int *prompt_tokens = NULL;
+    int num_prompt_tokens = 0;
+    if (prompt != NULL) {
+        prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
+        bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
+    }
+
+    // start the main loop
+    long start = 0;  // used to time our code, only initialized after first iteration
+    int next;        // will store the next token in the sequence
+    int token = 1;   // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
+    int pos = 0;     // position in the sequence
+    printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
     while (pos < steps) {
 
         // forward the transformer to get logits for the next token
         transformer(token, pos, &config, &state, &weights);
 
-        // sample the next token
-        if(temperature == 0.0f) {
-            // greedy argmax sampling
-            next = argmax(state.logits, config.vocab_size);
+        if(pos < num_prompt_tokens) {
+            // if we are still processing the input prompt, force the next prompt token
+            next = prompt_tokens[pos];
         } else {
-            // apply the temperature to the logits
-            for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
-            // apply softmax to the logits to get the probabilities for next token
-            softmax(state.logits, config.vocab_size);
-            // we now want to sample from this distribution to get the next token
-            next = sample(state.logits, config.vocab_size);
+            // sample the next token
+            if(temperature == 0.0f) {
+                // greedy argmax sampling
+                next = argmax(state.logits, config.vocab_size);
+            } else {
+                // apply the temperature to the logits
+                for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
+                // apply softmax to the logits to get the probabilities for next token
+                softmax(state.logits, config.vocab_size);
+                // we now want to sample from this distribution to get the next token
+                next = sample(state.logits, config.vocab_size);
+            }
         }
         printf("%s", vocab[next]);
         fflush(stdout);
 
         // advance forward
         token = next;
@@ -483,6 +571,8 @@ int main(int argc, char *argv[]) {
     free_run_state(&state);
     for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
     free(vocab);
+    free(vocab_scores);
+    if (prompt_tokens != NULL) free(prompt_tokens);
     if (data != MAP_FAILED) munmap(data, file_size);
     if (fd != -1) close(fd);
     return 0;
diff --git a/tokenizer.py b/tokenizer.py
--- a/tokenizer.py
+++ b/tokenizer.py
@@ -1,4 +1,5 @@
 # Taken from llama code and lightly modified
 import os
+import struct
 from sentencepiece import SentencePieceProcessor
 
@@ -34,23 +35,30 @@ class Tokenizer:
     def export(self):
-        tokens = []
+        tokens, scores = [], []
         for i in range(self.n_words):
             t = self.sp_model.id_to_piece(i)
+            s = self.sp_model.get_score(i)
             if i == self.bos_id:
                 t = '\n<s>\n'
             elif i == self.eos_id:
                 t = '\n</s>\n'
             elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
                 t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
-            t = t.replace('▁', ' ') # sentencepiece uses this as the whitespace
+            t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
+            b = t.encode('utf-8') # bytes of this token, utf-8 encoded
 
-            tokens.append(t)
+            tokens.append(b)
+            scores.append(s)
 
+        # record the max token length
+        max_token_length = max(len(t) for t in tokens)
+
+        # write to a binary file
         with open(TOKENIZER_BIN, 'wb') as f:
-            for token in tokens:
-                bytes = token.encode('utf-8')
-                f.write((len(bytes)).to_bytes(4, 'little')) # write length of bytes
-                f.write(bytes) # write token bytes
+            f.write(struct.pack("I", max_token_length))
+            for bytes, score in zip(tokens, scores):
+                f.write(struct.pack("fI", score, len(bytes)))
+                f.write(bytes)
 
 if __name__ == "__main__":
     t = Tokenizer()
     t.export()
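
A minimal usage sketch, assuming the usual llama2.c build and an illustrative checkpoint path model.bin. Note that this patch changes the tokenizer.bin format (a leading max_token_length field plus a per-token score), so the file must be regenerated with the updated tokenizer.py before running the C binary:

    python tokenizer.py      # rewrites tokenizer.bin in the new format
    gcc -O3 -o run run.c -lm
    ./run model.bin 0.9 256 "One day, Lily met a unicorn"

The optional fourth argument is the new prompt string: bpe_encode() turns it into tokens, and the main loop force-feeds those tokens (while pos < num_prompt_tokens) before sampling takes over.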
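For a quick sanity check of the new tokenizer.bin layout (uint32 max_token_length header, then per token: float32 score, uint32 length, raw bytes), a small hypothetical reader script along these lines round-trips the file; the "I" and "fI" struct formats mirror the ones export() writes with:

    # read_tokenizer.py (hypothetical): parse tokenizer.bin as written by export()
    import struct

    with open("tokenizer.bin", "rb") as f:
        (max_token_length,) = struct.unpack("I", f.read(4))
        tokens = []
        while True:
            header = f.read(struct.calcsize("fI"))  # 8 bytes: float score, uint32 length
            if not header:
                break  # clean end of file
            score, length = struct.unpack("fI", header)
            tokens.append((score, f.read(length)))

    # the stored header must equal the longest token actually present
    assert max_token_length == max(len(b) for _, b in tokens)
    print(f"{len(tokens)} tokens, max token length {max_token_length} bytes")

Both sides rely on native struct layout ("f" and "I" are each 4 bytes, no padding), which matches the C reader consuming sizeof(float) and sizeof(int) fields directly.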