mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
big change: adding prompting. many LOC, but critical. ty @atamurad for the first draft, i ended up tuning it quite a bit.
This commit is contained in:
parent
568a651c45
commit
b4bb47bb7b
136
run.c
136
run.c
@ -337,6 +337,9 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
|
||||
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// functions to sample the next token from the transformer's predicted distribution
|
||||
|
||||
int sample(float* probabilities, int n) {
|
||||
// sample index from probabilities, they must sum to 1
|
||||
float r = (float)rand() / (float)RAND_MAX;
|
||||
@ -362,14 +365,76 @@ int argmax(float* v, int n) {
|
||||
}
|
||||
return max_i;
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
|
||||
|
||||
int str_lookup(char *str, char **vocab, int vocab_size) {
|
||||
// find the first perfect match for str in vocab, return its index or -1 if not found
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
if (strcmp(str, vocab[i]) == 0) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
|
||||
|
||||
// a temporary buffer to merge two consecutive tokens
|
||||
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
|
||||
|
||||
// first encode every individual byte in the input string
|
||||
*n_tokens = 0; // the number of tokens
|
||||
for (char *c = text; *c != '\0'; c++) {
|
||||
sprintf(str_buffer, "%c", *c);
|
||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||
if (id == -1) { printf("not good\n"); exit(1);}
|
||||
tokens[*n_tokens] = id;
|
||||
(*n_tokens)++;
|
||||
}
|
||||
|
||||
// merge the best consecutive pair each iteration, according the scores in vocab_scores
|
||||
while (1) {
|
||||
float best_score = -1e10;
|
||||
int best_id = -1;
|
||||
int best_idx = -1;
|
||||
|
||||
for (int i=0; i < (*n_tokens-1); i++) {
|
||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, vocab, vocab_size);
|
||||
if (id != -1 && vocab_scores[id] > best_score) {
|
||||
// this merge pair exists in vocab! record its score and position
|
||||
best_score = vocab_scores[id];
|
||||
best_id = id;
|
||||
best_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_idx == -1) {
|
||||
break; // we couldn't find any more pairs to merge, so we're done
|
||||
}
|
||||
|
||||
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
|
||||
tokens[best_idx] = best_id;
|
||||
// delete token at position best_idx+1, shift the entire sequence back 1
|
||||
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
|
||||
tokens[i] = tokens[i+1];
|
||||
}
|
||||
(*n_tokens)--; // token length decreased
|
||||
}
|
||||
|
||||
free(str_buffer);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// utilities
|
||||
long time_in_ms() {
|
||||
struct timespec time;
|
||||
clock_gettime(CLOCK_REALTIME, &time);
|
||||
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
|
||||
@ -377,9 +442,11 @@ int main(int argc, char *argv[]) {
|
||||
char *checkpoint = NULL; // e.g. out/model.bin
|
||||
float temperature = 0.9f; // e.g. 1.0, or 0.0
|
||||
int steps = 256; // max number of steps to run for, 0: use seq_len
|
||||
char *prompt = NULL; // prompt string
|
||||
|
||||
// 'checkpoint' is necessary arg
|
||||
if (argc < 2) {
|
||||
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
|
||||
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt]\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
if (argc >= 2) {
|
||||
@ -392,6 +459,9 @@ int main(int argc, char *argv[]) {
|
||||
if (argc >= 4) {
|
||||
steps = atoi(argv[3]);
|
||||
}
|
||||
if (argc >= 5) {
|
||||
prompt = argv[4];
|
||||
}
|
||||
|
||||
// seed rng with time. if you want deterministic behavior use temperature 0.0
|
||||
srand((unsigned int)time(NULL));
|
||||
@ -406,7 +476,7 @@ int main(int argc, char *argv[]) {
|
||||
FILE *file = fopen(checkpoint, "rb");
|
||||
if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
|
||||
// read in the config header
|
||||
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
|
||||
int shared_weights = config.vocab_size > 0 ? 1 : 0;
|
||||
config.vocab_size = abs(config.vocab_size);
|
||||
@ -427,14 +497,18 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// read in the tokenizer.bin file
|
||||
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
|
||||
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
|
||||
unsigned int max_token_length;
|
||||
{
|
||||
FILE *file = fopen("tokenizer.bin", "rb");
|
||||
if (!file) { printf("Couldn't load tokenizer.bin\n"); return 1; }
|
||||
if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
|
||||
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||
int len;
|
||||
for (int i = 0; i < config.vocab_size; i++) {
|
||||
if(fread(&len, sizeof(int), 1, file) != 1) { return 1; }
|
||||
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { printf("failed read\n"); return 1;}
|
||||
if (fread(&len, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||
vocab[i] = (char *)malloc(len + 1);
|
||||
if(fread(vocab[i], len, 1, file) != 1) { return 1; }
|
||||
if (fread(vocab[i], len, 1, file) != 1) { printf("failed read\n"); return 1; }
|
||||
vocab[i][len] = '\0'; // add the string terminating token
|
||||
}
|
||||
fclose(file);
|
||||
@ -443,30 +517,44 @@ int main(int argc, char *argv[]) {
|
||||
// create and init the application RunState
|
||||
RunState state;
|
||||
malloc_run_state(&state, &config);
|
||||
|
||||
// the current position we are in
|
||||
long start = 0; // used to time our code, only initialized after first iteration
|
||||
int next;
|
||||
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
|
||||
int pos = 0;
|
||||
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
|
||||
|
||||
// process the prompt, if any
|
||||
int *prompt_tokens = NULL;
|
||||
int num_prompt_tokens = 0;
|
||||
if (prompt != NULL) {
|
||||
prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
|
||||
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
|
||||
}
|
||||
|
||||
// start the main loop
|
||||
long start = 0; // used to time our code, only initialized after first iteration
|
||||
int next; // will store the next token in the sequence
|
||||
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
|
||||
int pos = 0; // position in the sequence
|
||||
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
|
||||
while (pos < steps) {
|
||||
|
||||
// forward the transformer to get logits for the next token
|
||||
transformer(token, pos, &config, &state, &weights);
|
||||
|
||||
// sample the next token
|
||||
if(temperature == 0.0f) {
|
||||
// greedy argmax sampling
|
||||
next = argmax(state.logits, config.vocab_size);
|
||||
if(pos < num_prompt_tokens) {
|
||||
// if we are still processing the input prompt, force the next prompt token
|
||||
next = prompt_tokens[pos];
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(state.logits, config.vocab_size);
|
||||
// we now want to sample from this distribution to get the next token
|
||||
next = sample(state.logits, config.vocab_size);
|
||||
// sample the next token
|
||||
if (temperature == 0.0f) {
|
||||
// greedy argmax sampling: take the token with the highest probability
|
||||
next = argmax(state.logits, config.vocab_size);
|
||||
} else {
|
||||
// apply the temperature to the logits
|
||||
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(state.logits, config.vocab_size);
|
||||
// we sample from this distribution to get the next token
|
||||
next = sample(state.logits, config.vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
|
||||
printf("%s", token_str);
|
||||
@ -487,6 +575,8 @@ int main(int argc, char *argv[]) {
|
||||
free_run_state(&state);
|
||||
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
|
||||
free(vocab);
|
||||
free(vocab_scores);
|
||||
if (prompt_tokens != NULL) free(prompt_tokens);
|
||||
if (data != MAP_FAILED) munmap(data, file_size);
|
||||
if (fd != -1) close(fd);
|
||||
return 0;
|
||||
|
||||
BIN
tokenizer.bin
BIN
tokenizer.bin
Binary file not shown.
24
tokenizer.py
24
tokenizer.py
@ -3,6 +3,7 @@
|
||||
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||
|
||||
import os
|
||||
import struct
|
||||
from logging import getLogger
|
||||
from typing import List
|
||||
|
||||
@ -39,26 +40,35 @@ class Tokenizer:
|
||||
return self.sp_model.decode(t)
|
||||
|
||||
def export(self):
|
||||
tokens = []
|
||||
|
||||
# get all the tokens (postprocessed) and their scores as floats
|
||||
tokens, scores = [], []
|
||||
for i in range(self.n_words):
|
||||
|
||||
# decode the token and light postprocessing
|
||||
t = self.sp_model.id_to_piece(i)
|
||||
s = self.sp_model.get_score(i)
|
||||
if i == self.bos_id:
|
||||
t = '\n<s>\n'
|
||||
elif i == self.eos_id:
|
||||
t = '\n</s>\n'
|
||||
elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
|
||||
t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
|
||||
t = t.replace('▁', ' ') # sentencepiece uses this as the whitespace
|
||||
t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
|
||||
b = t.encode('utf-8') # bytes of this token, utf-8 encoded
|
||||
|
||||
tokens.append(t)
|
||||
tokens.append(b)
|
||||
scores.append(s)
|
||||
|
||||
# record the max token length
|
||||
max_token_length = max(len(t) for t in tokens)
|
||||
|
||||
# write to a binary file
|
||||
with open(TOKENIZER_BIN, 'wb') as f:
|
||||
for token in tokens:
|
||||
bytes = token.encode('utf-8')
|
||||
f.write((len(bytes)).to_bytes(4, 'little')) # write length of bytes
|
||||
f.write(bytes) # write token bytes
|
||||
f.write(struct.pack("I", max_token_length))
|
||||
for bytes, score in zip(tokens, scores):
|
||||
f.write(struct.pack("fI", score, len(bytes)))
|
||||
f.write(bytes)
|
||||
|
||||
if __name__ == "__main__":
|
||||
t = Tokenizer()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user