mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
Llama3 Support (WIP)
use -l 3 option
This commit is contained in:
parent
ed2253b306
commit
1be98e214d
@ -33,9 +33,9 @@ Learn more about the Llama2 models & architecture at Meta: [Llama 2 @ Meta](http
|
||||
|
||||
# Features & Milestones
|
||||
|
||||
#### Llama 3 Support
|
||||
#### Llama 3 Support WIP
|
||||
|
||||
Almost done - Coming Soonish (TM)...
|
||||
Should support inference, WIP, use -l 3 option...
|
||||
|
||||
#### L2E OS (Linux Kernel)
|
||||
|
||||
|
||||
140
run.c
140
run.c
@ -9,6 +9,12 @@
|
||||
|
||||
int buffertokens = 1; // output token buffer size
|
||||
int stats = 1; // extended status info
|
||||
int llamaver = 2; // llama version (default is 2, valid 2 & 3)
|
||||
float rope_sf = 10000.0; // Rope scaling factor, 10000.0 => llama2, 500000.0 > llama3
|
||||
int BOS = 1; // Beginning of Sentence token value, llama2 = 1 , llama3 = 128000
|
||||
int EOS = 2; // End of Sentence token value, llama2 = 2 , llama3 = 128009 (end of text)
|
||||
char system_template[1024]="";
|
||||
char user_template[1024]="";
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// L2E Humanoid : Linux Kernel Support Directives
|
||||
@ -550,7 +556,9 @@ float* forward(Transformer* transformer, int token, int pos) {
|
||||
// RoPE relative positional encoding: complex-valued rotate q and k in each head
|
||||
for (int i = 0; i < dim; i+=2) {
|
||||
int head_dim = i % head_size;
|
||||
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
|
||||
// L2E Addition
|
||||
float freq = 1.0f / powf(rope_sf, head_dim / (float)head_size);
|
||||
// END L2E Addition
|
||||
float val = pos * freq;
|
||||
float fcr = cosf(val);
|
||||
float fci = sinf(val);
|
||||
@ -738,8 +746,10 @@ void free_tokenizer(Tokenizer* t) {
|
||||
|
||||
char* decode(Tokenizer* t, int prev_token, int token) {
|
||||
char *piece = t->vocab[token];
|
||||
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
if (prev_token == 1 && piece[0] == ' ') { piece++; }
|
||||
// L2E Addition
|
||||
// following BOS (1) or (2) token, sentencepiece decoder strips any leading whitespace (see PR #89)
|
||||
if (prev_token == BOS && piece[0] == ' ') { piece++; }
|
||||
// END L2E Addition
|
||||
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
|
||||
// parse this and convert and return the actual byte
|
||||
unsigned char byte_val;
|
||||
@ -772,7 +782,7 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
|
||||
|
||||
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
|
||||
// encode the string text (input) into an upper-bound preallocated tokens[] array
|
||||
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
|
||||
// bos != 0 means prepend the BOS token, eos != 0 means append the EOS token
|
||||
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
|
||||
|
||||
if (t->sorted_vocab == NULL) {
|
||||
@ -793,17 +803,25 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
|
||||
// start at 0 tokens
|
||||
*n_tokens = 0;
|
||||
|
||||
// add optional BOS (=1) token, if desired
|
||||
if (bos) tokens[(*n_tokens)++] = 1;
|
||||
// L2E Addition
|
||||
// add optional BOS token, if desired
|
||||
if (bos) tokens[(*n_tokens)++] = BOS;
|
||||
// END L2E Addition
|
||||
|
||||
|
||||
// add_dummy_prefix is true by default
|
||||
// so prepend a dummy prefix token to the input string, but only if text != ""
|
||||
// TODO: pretty sure this isn't correct in the general case but I don't have the
|
||||
// energy to read more of the sentencepiece code to figure out what it's doing
|
||||
if (text[0] != '\0') {
|
||||
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
tokens[(*n_tokens)++] = dummy_prefix;
|
||||
|
||||
// L2E Addition
|
||||
if (llamaver == 2) {
|
||||
if (text[0] != '\0') {
|
||||
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
|
||||
tokens[(*n_tokens)++] = dummy_prefix;
|
||||
}
|
||||
}
|
||||
// END L2E Addition
|
||||
|
||||
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
|
||||
// Code point ↔ UTF-8 conversion
|
||||
@ -854,13 +872,16 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
|
||||
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
|
||||
}
|
||||
|
||||
// merge the best consecutive pair each iteration, according the scores in vocab_scores
|
||||
// L2E Addition
|
||||
// merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
|
||||
while (1) {
|
||||
float best_score = -1e10;
|
||||
int best_id = -1;
|
||||
int best_idx = -1;
|
||||
int best_merge = 0; // length of the best merge sequence (2 for pair, 3 for triple)
|
||||
|
||||
for (int i=0; i < (*n_tokens-1); i++) {
|
||||
// try to find the best pair or triple to merge
|
||||
for (int i = 0; i < (*n_tokens - 1); i++) {
|
||||
// check if we can merge the pair (tokens[i], tokens[i+1])
|
||||
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
|
||||
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
||||
@ -869,28 +890,45 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
|
||||
best_score = t->vocab_scores[id];
|
||||
best_id = id;
|
||||
best_idx = i;
|
||||
best_merge = 2;
|
||||
}
|
||||
|
||||
// check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
|
||||
if (i < (*n_tokens - 2)) {
|
||||
sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
|
||||
id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
|
||||
if (id != -1 && t->vocab_scores[id] > best_score) {
|
||||
// this merge triple exists in vocab! record its score and position
|
||||
best_score = t->vocab_scores[id];
|
||||
best_id = id;
|
||||
best_idx = i;
|
||||
best_merge = 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (best_idx == -1) {
|
||||
break; // we couldn't find any more pairs to merge, so we're done
|
||||
break; // we couldn't find any more pairs or triples to merge, so we're done
|
||||
}
|
||||
|
||||
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
|
||||
// merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
|
||||
tokens[best_idx] = best_id;
|
||||
// delete token at position best_idx+1, shift the entire sequence back 1
|
||||
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
|
||||
tokens[i] = tokens[i+1];
|
||||
// delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
|
||||
for (int i = best_idx + 1; i < (*n_tokens - best_merge + 1); i++) {
|
||||
tokens[i] = tokens[i + best_merge - 1];
|
||||
}
|
||||
(*n_tokens)--; // token length decreased
|
||||
(*n_tokens) -= (best_merge - 1); // token length decreased by the number of merged tokens minus one
|
||||
}
|
||||
|
||||
// add optional EOS (=2) token, if desired
|
||||
if (eos) tokens[(*n_tokens)++] = 2;
|
||||
// add optional EOS token, if desired
|
||||
if (eos) tokens[(*n_tokens)++] = EOS;
|
||||
|
||||
free(str_buffer);
|
||||
|
||||
}
|
||||
|
||||
// END L2E Addition
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// The Sampler, which takes logits and returns a sampled token
|
||||
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
|
||||
@ -1089,9 +1127,11 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
next = sample(sampler, logits);
|
||||
}
|
||||
pos++;
|
||||
|
||||
// data-dependent terminating condition: the BOS (=1) token delimits sequences
|
||||
if (next == 1) { break; }
|
||||
|
||||
// L2E Addition
|
||||
// data-dependent terminating condition: the BOS token delimits sequences
|
||||
if (next == BOS) { break; }
|
||||
// END L2E Addition
|
||||
|
||||
// print the token as string, decode it with the Tokenizer object
|
||||
char* piece = decode(tokenizer, token, next);
|
||||
@ -1141,18 +1181,46 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
|
||||
// buffers for reading the system prompt and user prompt from stdin
|
||||
// you'll notice they are soomewhat haphazardly and unsafely set atm
|
||||
// L2E Addition
|
||||
char system_prompt[512];
|
||||
char user_prompt[512];
|
||||
char rendered_prompt[1152];
|
||||
char rendered_prompt[2048];
|
||||
int num_prompt_tokens = 0;
|
||||
int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
|
||||
int* prompt_tokens = (int*)malloc(2048 * sizeof(int));
|
||||
// END L2E Addition
|
||||
int user_idx;
|
||||
|
||||
// start the main loop
|
||||
int8_t user_turn = 1; // user starts
|
||||
int next; // will store the next token in the sequence
|
||||
int token; // stores the current token to feed into the transformer
|
||||
int prev_token;
|
||||
// L2E Addition
|
||||
/* System and user prompt templates for llama 2 and llama 3
|
||||
Llama 2:
|
||||
System:
|
||||
[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]
|
||||
User:
|
||||
[INST] %s [/INST]
|
||||
|
||||
Llama 3:
|
||||
System:
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>
|
||||
User:
|
||||
<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n
|
||||
Assistant: (Starts Generating)
|
||||
<|start_header_id|>assistant<|end_header_id|>\n\n
|
||||
*/
|
||||
if (llamaver == 3) {
|
||||
BOS = 128000; // 128000 = <|begin_of_text|>
|
||||
EOS = 128009; // 128009 = <|eot_id|> , 128001 = <|end_of_text|>
|
||||
strcpy(system_template, "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>");
|
||||
strcpy(user_template, "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n");
|
||||
} else {
|
||||
int prev_token;
|
||||
strcpy(system_template,"[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]");
|
||||
strcpy(user_template, "[INST] %s [/INST]");
|
||||
}
|
||||
// END L2E Addition
|
||||
int pos = 0; // position in the sequence
|
||||
while (pos < steps) {
|
||||
|
||||
@ -1177,14 +1245,14 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
// otherwise get user prompt from stdin
|
||||
read_stdin("User: ", user_prompt, sizeof(user_prompt));
|
||||
}
|
||||
// L2E Addition
|
||||
// render user/system prompts into the Llama 2 Chat schema
|
||||
if (pos == 0 && system_prompt[0] != '\0') {
|
||||
char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
|
||||
sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
|
||||
} else {
|
||||
char user_template[] = "[INST] %s [/INST]";
|
||||
sprintf(rendered_prompt, user_template, user_prompt);
|
||||
}
|
||||
// END L2E Addition
|
||||
// encode the rendered prompt into tokens
|
||||
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
|
||||
user_idx = 0; // reset the user index
|
||||
@ -1200,22 +1268,25 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
|
||||
// otherwise use the next token sampled from previous turn
|
||||
token = next;
|
||||
}
|
||||
// EOS (=2) token ends the Assistant turn
|
||||
if (token == 2) { user_turn = 1; }
|
||||
// L2E Addition
|
||||
// EOS token ends the Assistant turn
|
||||
if (token == EOS) { user_turn = 1; }
|
||||
// End L2E Addition
|
||||
|
||||
// forward the transformer to get logits for the next token
|
||||
float* logits = forward(transformer, token, pos);
|
||||
next = sample(sampler, logits);
|
||||
pos++;
|
||||
|
||||
if (user_idx >= num_prompt_tokens && next != 2) {
|
||||
// L2E Addition
|
||||
if (user_idx >= num_prompt_tokens && next != EOS) {
|
||||
// the Assistant is responding, so print its output
|
||||
char* piece = decode(tokenizer, token, next);
|
||||
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
|
||||
fflush(stdout);
|
||||
}
|
||||
if (next == 2) { printf("\n"); }
|
||||
if (next == EOS) { printf("\n"); }
|
||||
}
|
||||
// End L2E Addition
|
||||
printf("\n");
|
||||
free(prompt_tokens);
|
||||
}
|
||||
@ -1254,7 +1325,8 @@ void error_usage() {
|
||||
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
|
||||
// L2E Addition
|
||||
fprintf(stderr, " -b <int> number of tokens to buffer, default 1. 0 = max_seq_len\n");
|
||||
fprintf(stderr, " -x <int> extended info / stats, default 1 = on. 0 = off\n");
|
||||
fprintf(stderr, " -x <int> extended info / stats, default 1 = on. 0 = off\n");
|
||||
fprintf(stderr, " -l <int> llama version / default 2 = llama2. 3 = llama3\n");
|
||||
// END L2E Addition
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
@ -1323,9 +1395,11 @@ int main(int argc, char *argv[]) {
|
||||
// L2E Addition
|
||||
else if (argv[i][1] == 'b') { buffertokens = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'x') { stats = atoi(argv[i + 1]); }
|
||||
else if (argv[i][1] == 'l') { llamaver = atoi(argv[i + 1]); }
|
||||
// END L2E Addition
|
||||
else { error_usage(); }
|
||||
}
|
||||
if (llamaver == 3){ rope_sf = 500000.0; }
|
||||
// L2E Addition
|
||||
#endif
|
||||
// END L2E Addition
|
||||
|
||||
Loading…
Reference in New Issue
Block a user