Mark ModelArgs.hidden_dim as Optional and, when it is not provided, compute it as before (from 4 * dim)

This commit is contained in:
atamyrat 2023-08-21 03:40:34 +03:00
parent 09db52c69e
commit d7704bdeaa

View File

@@ -17,7 +17,7 @@ class ModelArgs:
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = 32000
-    hidden_dim: int = (4 * 4096)
+    hidden_dim: Optional[int] = None
multiple_of: int = 256 # MLP hidden layer size will be multiple of
norm_eps: float = 1e-5
max_seq_len: int = 2048
@@ -167,8 +167,10 @@ class Attention(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        if hidden_dim is None:
+            hidden_dim = 4 * dim
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)