From 36bf904c1869d8c31bab34b93cb17695fa2c2ced Mon Sep 17 00:00:00 2001 From: aidoge <123085890+ai-doge@users.noreply.github.com> Date: Wed, 26 Jul 2023 14:23:25 +0800 Subject: [PATCH 1/7] Refactor freqs_cis into freqs_cos and freqs_sin, and remove complex64 for ONNX export compatibility --- model.py | 53 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/model.py b/model.py index 88d24f6..cafbbd6 100644 --- a/model.py +++ b/model.py @@ -40,9 +40,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return freqs_cos, freqs_sin def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim @@ -51,17 +51,31 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) - def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, - freqs_cis: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + + # reshape xq and xk to match the complex representation + xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1) + xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1) + + # reshape freqs_cos and freqs_sin for broadcasting + freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) + freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) + + # apply rotation using real numbers + xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin + xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos + xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin + xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos + + # flatten last two dimensions + xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) + xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -103,7 +117,8 @@ class Attention(nn.Module): def forward( self, x: torch.Tensor, - freqs_cis: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, ): bsz, seqlen, _ = x.shape @@ -114,7 +129,7 @@ class Attention(nn.Module): xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) # RoPE relative positional embeddings - xq, xk = apply_rotary_emb(xq, xk, freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) # grouped multiquery attention: expand out keys and values xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -176,8 +191,8 @@ class TransformerBlock(nn.Module): self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - def forward(self, x, freqs_cis): - h = x + self.attention.forward(self.attention_norm(x), freqs_cis) + def forward(self, x, freqs_cos, freqs_sin): + h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin) out = h + self.feed_forward.forward(self.ffn_norm(h)) return out @@ -201,8 +216,9 @@ class Transformer(nn.Module): self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse - freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) - self.register_buffer("freqs_cis", freqs_cis, persistent=False) + freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) # init all weights self.apply(self._init_weights) @@ -223,10 +239,11 @@ class Transformer(nn.Module): _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) h = self.dropout(h) - freqs_cis = self.freqs_cis[:seqlen] + freqs_cos = self.freqs_cos[:seqlen] + freqs_sin = self.freqs_sin[:seqlen] for layer in self.layers: - h = layer(h, freqs_cis) + h = layer(h, freqs_cos, freqs_sin) h = self.norm(h) if targets is not None: From 72ba34c39245dc1b54c4725cdea9659e57e98424 Mon Sep 17 00:00:00 2001 From: Murilo Curti Date: Thu, 27 Jul 2023 21:39:09 -0300 Subject: [PATCH 2/7] fix: Use correct compiler for Win64 GCC in Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 360cd2f..ced0d89 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ runomp: run.c .PHONY: win64 win64: - x86_64-w64-mingw32-gcc-win32 -Ofast -D_WIN32 -o run.exe -I. run.c win.c + x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c # compiles with gnu99 standard flags for amazon linux, coreos, etc. compatibility .PHONY: rungnu From 7cbb47cc3671d55c45ac9a99b207fcff22f11616 Mon Sep 17 00:00:00 2001 From: aidoge <123085890+ai-doge@users.noreply.github.com> Date: Fri, 28 Jul 2023 11:07:36 +0800 Subject: [PATCH 3/7] update export_meta_llama_bin, get freqs_cos, freqs_sin independently. --- export_meta_llama_bin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/export_meta_llama_bin.py b/export_meta_llama_bin.py index 801077b..53ca652 100644 --- a/export_meta_llama_bin.py +++ b/export_meta_llama_bin.py @@ -55,10 +55,10 @@ def export(p, state_dict, filepath='model.bin'): # final rmsnorm serialize('norm.weight') - # freqs_cis - freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) - state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']] - state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']] + # freqs_cos, freqs_sin + freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) + state_dict['freqs_cis.real'] = freqs_cos[:p['max_seq_len']] + state_dict['freqs_cis.imag'] = freqs_sin[:p['max_seq_len']] serialize('freqs_cis.real') serialize('freqs_cis.imag') From 3418fed02b9837ee2d96d5809ed0499f3e9daf8e Mon Sep 17 00:00:00 2001 From: Daniil Tcelikin Date: Fri, 28 Jul 2023 14:53:00 +0300 Subject: [PATCH 4/7] added repository in readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9219633..9c68a77 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg - [llama2.go](https://github.com/haormj/llama2.go) by @haormj: a Go port of this project - [llama2.go](https://github.com/saracen/llama2.go) by @saracen: a Go port of this project - [llama2.c-android](https://github.com/Manuel030/llama2.c-android): by @Manuel030: adds Android binaries of this project +- [llama2.c-android-wrapper](https://github.com/celikin/llama2.c-android-wrapper): by @celikin: added JNI wrapper, PoC - [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @leloykun: a C++ port of this project ## unsorted todos From 3b446baeb3cfecd05027d692b37c6feee18f885e Mon Sep 17 00:00:00 2001 From: Leo Du Date: Mon, 31 Jul 2023 03:34:34 -0400 Subject: [PATCH 5/7] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0e801c5..4f10938 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg - [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @leloykun: a C++ port of this project - [llama2.js](https://github.com/epicure/llama2.js) by @epicure: a JavaScript port of this project - [llama2.zig](https://github.com/cgbur/llama2.zig) by @cgbur: A Zig port of this project +- [llama2.rs](https://github.com/leo-du/llama2.rs) by @leo-du: A Rust port of this project ## unsorted todos From 4c0a88249de0e7ec0ae8533f21c2ebbc4430d3b6 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Mon, 31 Jul 2023 14:59:11 +0200 Subject: [PATCH 6/7] add link to scala port --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0e801c5..1259c5e 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg - [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @leloykun: a C++ port of this project - [llama2.js](https://github.com/epicure/llama2.js) by @epicure: a JavaScript port of this project - [llama2.zig](https://github.com/cgbur/llama2.zig) by @cgbur: A Zig port of this project +- [llama2.scala](https://github.com/jrudolph/llama2.scala) by @jrudolph: a Scala port of this project ## unsorted todos From 883cda1a2ca65a07b18608175ce24a4a81b7a33d Mon Sep 17 00:00:00 2001 From: aidoge <123085890+ai-doge@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:31:43 +0800 Subject: [PATCH 7/7] fix freq_cos, freq_sin serialize --- export_meta_llama_bin.py | 8 ++++---- model.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/export_meta_llama_bin.py b/export_meta_llama_bin.py index 53ca652..41c1705 100644 --- a/export_meta_llama_bin.py +++ b/export_meta_llama_bin.py @@ -57,10 +57,10 @@ def export(p, state_dict, filepath='model.bin'): serialize('norm.weight') # freqs_cos, freqs_sin freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) - state_dict['freqs_cis.real'] = freqs_cos[:p['max_seq_len']] - state_dict['freqs_cis.imag'] = freqs_sin[:p['max_seq_len']] - serialize('freqs_cis.real') - serialize('freqs_cis.imag') + state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']] + state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']] + serialize('freqs_cos') + serialize('freqs_sin') # finally write the output weights serialize('output.weight') diff --git a/model.py b/model.py index cafbbd6..1600f5b 100644 --- a/model.py +++ b/model.py @@ -376,8 +376,8 @@ class Transformer(nn.Module): serialize(self.norm.weight) # note: no need to write final classifier weights due to weight sharing # freqs_cis - serialize(self.freqs_cis.real[:p.max_seq_len]) - serialize(self.freqs_cis.imag[:p.max_seq_len]) + serialize(self.freqs_cos[:p.max_seq_len]) + serialize(self.freqs_sin[:p.max_seq_len]) # write to binary file f.close()