absorb our rng state into the Sampler. I feel that this is correct because it makes our use of entropy very explicit and localized, and the sampler is now well-contained without any global state. Code is increasingly more beautiful.

This commit is contained in:
Andrej Karpathy 2023-08-22 03:22:56 +00:00
parent ac6cf8d6e8
commit d26a499207

53
run.c
View File

@ -455,7 +455,8 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
}
// create a temporary buffer that will store merge candidates of always two consecutive tokens
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
size_t str_len = 0;
// add_dummy_prefix is true by default
@ -559,22 +560,9 @@ typedef struct {
ProbIndex* probindex; // buffer used in top-p sampling
float temperature;
float topp;
unsigned long long rng_state;
} Sampler;
// rng should technically be a state variable of the Sampler
// leaving it global here for now for convenience, maybe move later
unsigned long long rng_seed;
unsigned int random_u32() {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
rng_seed ^= rng_seed >> 12;
rng_seed ^= rng_seed << 25;
rng_seed ^= rng_seed >> 27;
return (rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
}
float random_f32() { // random float32 in [0,1)
return (random_u32() >> 8) / 16777216.0f;
}
int sample_argmax(float* probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;
@ -588,13 +576,13 @@ int sample_argmax(float* probabilities, int n) {
return max_i;
}
int sample_mult(float* probabilities, int n) {
int sample_mult(float* probabilities, int n, float coin) {
// sample index from probabilities (they must sum to 1!)
float r = random_f32();
// coin is a random number in [0, 1), usually from random_f32()
float cdf = 0.0f;
for (int i = 0; i < n; i++) {
cdf += probabilities[i];
if (r < cdf) {
if (coin < cdf) {
return i;
}
}
@ -609,10 +597,11 @@ int compare(const void* a, const void* b) {
return 0;
}
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
// coin is a random number in [0, 1), usually from random_f32()
int n0 = 0;
// quicksort indices in descending order of probabilities
@ -640,7 +629,7 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
}
// sample from the truncated list
float r = random_f32() * cumulative_prob;
float r = coin * cumulative_prob;
float cdf = 0.0f;
for (int i = 0; i <= last_idx; i++) {
cdf += probindex[i].prob;
@ -651,10 +640,11 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
return probindex[last_idx].index; // in case of rounding errors
}
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp) {
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
sampler->vocab_size = vocab_size;
sampler->temperature = temperature;
sampler->topp = topp;
sampler->rng_state = rng_seed;
// buffer only used with nucleus sampling; may not need but it's ~small
sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
}
@ -663,6 +653,17 @@ void free_sampler(Sampler* sampler) {
free(sampler->probindex);
}
unsigned int random_u32(unsigned long long *state) {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
*state ^= *state >> 12;
*state ^= *state << 25;
*state ^= *state >> 27;
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}
float random_f32(unsigned long long *state) { // random float32 in [0,1)
return (random_u32(state) >> 8) / 16777216.0f;
}
int sample(Sampler* sampler, float* logits) {
// sample the token given the logits and some hyperparameters
int next;
@ -674,13 +675,15 @@ int sample(Sampler* sampler, float* logits) {
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(logits, sampler->vocab_size);
// flip a (float) coin (this is our source of entropy for sampling)
float coin = random_f32(&sampler->rng_state);
// we sample from this distribution to get the next token
if (sampler->topp <= 0 || sampler->topp >= 1) {
// simply sample from the predicted probability distribution
next = sample_mult(logits, sampler->vocab_size);
next = sample_mult(logits, sampler->vocab_size, coin);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex);
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
}
}
return next;
@ -775,9 +778,9 @@ int main(int argc, char *argv[]) {
char *tokenizer_path = "tokenizer.bin";
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
rng_seed = 0; // seed rng with time by default
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@ -813,7 +816,7 @@ int main(int argc, char *argv[]) {
// build the Sampler
Sampler sampler;
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp);
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
// run!
generate(&transformer, &tokenizer, &sampler, prompt, steps);