Mirror of https://github.com/trholding/llama2.c.git (synced 2026-02-06 11:26:53 +00:00)
mark ModelArgs.hidden_dim as optional and calculate as previously if not provided
parent: 09db52c69e
commit: d7704bdeaa
model.py · 8 changed lines
@@ -17,7 +17,7 @@ class ModelArgs:
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = 32000
-    hidden_dim: int = (4 * 4096)
+    hidden_dim: Optional[int] = None
     multiple_of: int = 256  # MLP hidden layer size will be multiple of
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
@@ -167,8 +167,10 @@ class Attention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
         super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        if hidden_dim is None:
+            hidden_dim = 4 * dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
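For reference, a minimal standalone sketch of the fallback this commit introduces: when hidden_dim is left as None, FeedForward starts from the classic 4 * dim expansion, scales it by 2/3 (the SwiGLU convention), and rounds up to the nearest multiple_of. The helper name ffn_hidden_dim below is ours for illustration only and is not part of model.py.

from typing import Optional

def ffn_hidden_dim(dim: int, multiple_of: int, hidden_dim: Optional[int] = None) -> int:
    # Mirrors the fallback in FeedForward.__init__ when hidden_dim is None.
    if hidden_dim is None:
        hidden_dim = 4 * dim                    # classic 4x MLP expansion
        hidden_dim = int(2 * hidden_dim / 3)    # SwiGLU keeps roughly 2/3 of that
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # round up
    return hidden_dim

# For dim=4096, multiple_of=256 the fallback gives 11008, the same value the old
# hard-coded default (4 * 4096, then 2/3 and rounding) produced; an explicitly
# provided hidden_dim is passed through unchanged.
assert ffn_hidden_dim(4096, 256) == 11008
assert ffn_hidden_dim(4096, 256, hidden_dim=13824) == 13824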