diff --git a/model.py b/model.py
index 09e6aa5..9e4ce22 100644
--- a/model.py
+++ b/model.py
@@ -17,7 +17,7 @@ class ModelArgs:
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = 32000
-    hidden_dim: int = (4 * 4096)
+    hidden_dim: Optional[int] = None
     multiple_of: int = 256  # MLP hidden layer size will be multiple of
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
@@ -167,8 +167,10 @@ class Attention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
         super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        if hidden_dim is None:
+            hidden_dim = 4 * dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)