mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
absorb our rng state into the Sampler. I feel that this is correct because it makes our use of entropy very explicit and localized, and the sampler is now well-contained without any global state. Code is increasingly more beautiful.
This commit is contained in:
parent
ac6cf8d6e8
commit
d26a499207
53
run.c
53
run.c
@ -455,7 +455,8 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
|
||||
}
|
||||
|
||||
// create a temporary buffer that will store merge candidates of always two consecutive tokens
|
||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
|
||||
// *2 for concat, +1 for null terminator, +2 for UTF8 (in case max_token_length is 1)
|
||||
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
|
||||
size_t str_len = 0;
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
@ -559,22 +560,9 @@ typedef struct {
|
||||
ProbIndex* probindex; // buffer used in top-p sampling
|
||||
float temperature;
|
||||
float topp;
|
||||
unsigned long long rng_state;
|
||||
} Sampler;
|
||||
|
||||
// rng should technically be a state variable of the Sampler
|
||||
// leaving it global here for now for convenience, maybe move later
|
||||
unsigned long long rng_seed;
|
||||
unsigned int random_u32() {
|
||||
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||
rng_seed ^= rng_seed >> 12;
|
||||
rng_seed ^= rng_seed << 25;
|
||||
rng_seed ^= rng_seed >> 27;
|
||||
return (rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
|
||||
}
|
||||
float random_f32() { // random float32 in [0,1)
|
||||
return (random_u32() >> 8) / 16777216.0f;
|
||||
}
|
||||
|
||||
int sample_argmax(float* probabilities, int n) {
|
||||
// return the index that has the highest probability
|
||||
int max_i = 0;
|
||||
@ -588,13 +576,13 @@ int sample_argmax(float* probabilities, int n) {
|
||||
return max_i;
|
||||
}
|
||||
|
||||
int sample_mult(float* probabilities, int n) {
|
||||
int sample_mult(float* probabilities, int n, float coin) {
|
||||
// sample index from probabilities (they must sum to 1!)
|
||||
float r = random_f32();
|
||||
// coin is a random number in [0, 1), usually from random_f32()
|
||||
float cdf = 0.0f;
|
||||
for (int i = 0; i < n; i++) {
|
||||
cdf += probabilities[i];
|
||||
if (r < cdf) {
|
||||
if (coin < cdf) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
@ -609,10 +597,11 @@ int compare(const void* a, const void* b) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
|
||||
// top-p sampling (or "nucleus sampling") samples from the smallest set of
|
||||
// tokens that exceed probability topp. This way we never sample tokens that
|
||||
// have very low probabilities and are less likely to go "off the rails".
|
||||
// coin is a random number in [0, 1), usually from random_f32()
|
||||
|
||||
int n0 = 0;
|
||||
// quicksort indices in descending order of probabilities
|
||||
@ -640,7 +629,7 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
}
|
||||
|
||||
// sample from the truncated list
|
||||
float r = random_f32() * cumulative_prob;
|
||||
float r = coin * cumulative_prob;
|
||||
float cdf = 0.0f;
|
||||
for (int i = 0; i <= last_idx; i++) {
|
||||
cdf += probindex[i].prob;
|
||||
@ -651,10 +640,11 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
|
||||
return probindex[last_idx].index; // in case of rounding errors
|
||||
}
|
||||
|
||||
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) {
|
||||
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
|
||||
sampler->vocab_size = vocab_size;
|
||||
sampler->temperature = temperature;
|
||||
sampler->topp = topp;
|
||||
sampler->rng_state = rng_seed;
|
||||
// buffer only used with nucleus sampling; may not need but it's ~small
|
||||
sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
|
||||
}
|
||||
@ -663,6 +653,17 @@ void free_sampler(Sampler* sampler) {
|
||||
free(sampler->probindex);
|
||||
}
|
||||
|
||||
// Advance the caller-owned generator state in place and return the next
// pseudo-random unsigned value derived from it.
//   state: pointer to the generator state; read and overwritten on every call.
// NOTE(review): a state of 0 stays 0 through the xor/shift steps below (and the
// function then returns 0), so the seed presumably must be non-zero — confirm
// at the place where the state is initialised.
unsigned int random_u32(unsigned long long *state) {
|
||||
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
|
||||
*state ^= *state >> 12;
|
||||
*state ^= *state << 25;
|
||||
*state ^= *state >> 27;
|
||||
// multiply the updated state by the xorshift* constant and discard the low
// 32 bits of the product; the stored state itself is not multiplied
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
|
||||
}
|
||||
// Draw one value from random_u32() using the given state (which is advanced as
// a side effect) and convert it to a float.
//   state: pointer to the generator state, forwarded unchanged to random_u32().
float random_f32(unsigned long long *state) { // random float32 in [0,1)
|
||||
// drop the low 8 bits of the draw, then divide by 16777216 (2^24)
return (random_u32(state) >> 8) / 16777216.0f;
|
||||
}
|
||||
|
||||
int sample(Sampler* sampler, float* logits) {
|
||||
// sample the token given the logits and some hyperparameters
|
||||
int next;
|
||||
@ -674,13 +675,15 @@ int sample(Sampler* sampler, float* logits) {
|
||||
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
|
||||
// apply softmax to the logits to get the probabilities for next token
|
||||
softmax(logits, sampler->vocab_size);
|
||||
// flip a (float) coin (this is our source of entropy for sampling)
|
||||
float coin = random_f32(&sampler->rng_state);
|
||||
// we sample from this distribution to get the next token
|
||||
if (sampler->topp <= 0 || sampler->topp >= 1) {
|
||||
// simply sample from the predicted probability distribution
|
||||
next = sample_mult(logits, sampler->vocab_size);
|
||||
next = sample_mult(logits, sampler->vocab_size, coin);
|
||||
} else {
|
||||
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
||||
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex);
|
||||
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
|
||||
}
|
||||
}
|
||||
return next;
|
||||
@ -775,9 +778,9 @@ int main(int argc, char *argv[]) {
|
||||
char *tokenizer_path = "tokenizer.bin";
|
||||
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
|
||||
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
|
||||
rng_seed = 0; // seed rng with time by default
|
||||
int steps = 256; // number of steps to run for
|
||||
char *prompt = NULL; // prompt string
|
||||
unsigned long long rng_seed = 0; // seed rng with time by default
|
||||
|
||||
// poor man's C argparse so we can override the defaults above from the command line
|
||||
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
|
||||
@ -813,7 +816,7 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// build the Sampler
|
||||
Sampler sampler;
|
||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
|
||||
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
|
||||
|
||||
// run!
|
||||
generate(&transformer, &tokenizer, &sampler, prompt, steps);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user