diff --git a/model.py b/model.py
index f7edbb6..7329d6c 100644
--- a/model.py
+++ b/model.py
@@ -11,12 +11,13 @@ from torch import nn
 
 @dataclass
 class ModelArgs:
+    # default hyperparameters for the Llama 7B model
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
-    vocab_size: int = -1  # defined later by tokenizer
-    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    vocab_size: int = 32000
+    multiple_of: int = 256  # MLP hidden layer size will be multiple of this
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
     dropout: float = 0.0
diff --git a/sample.py b/sample.py
index 040bc14..93c9407 100644
--- a/sample.py
+++ b/sample.py
@@ -9,6 +9,8 @@ import tiktoken
 from model import ModelArgs, Transformer
 from tokenizer import Tokenizer
 
+from tinystories import get_tokenizer_model_path
+
 # -----------------------------------------------------------------------------
 out_dir = 'out' # ignored if init_from is not 'resume'
 start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
@@ -51,7 +53,9 @@ if compile:
     model = torch.compile(model) # requires PyTorch 2.0 (optional)
 
 # load the tokenizer
-enc = Tokenizer()
+assert checkpoint["config"]["dataset"] == "tinystories" # TODO: generalize
+tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
+enc = Tokenizer(tokenizer_model=tokenizer_model)
 
 # encode the beginning of the prompt
 if start.startswith('FILE:'):
diff --git a/tinystories.py b/tinystories.py
index d41f8fc..278c817 100644
--- a/tinystories.py
+++ b/tinystories.py
@@ -120,9 +120,7 @@ def train_vocab(vocab_size):
 
 def process_shard(args, vocab_size):
     shard_id, shard = args
-    tokenizer_model = None
-    if vocab_size > 0:
-        tokenizer_model = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
+    tokenizer_model = get_tokenizer_model_path(vocab_size)
     enc = Tokenizer(tokenizer_model)
     with open(shard, "r") as f:
         data = json.load(f)
@@ -171,10 +169,12 @@ def pretokenize(vocab_size):
 class PretokDataset(torch.utils.data.IterableDataset):
     """Loads pretokenized examples from disk and yields them as PyTorch tensors."""
 
-    def __init__(self, split, max_seq_len):
+    def __init__(self, split, max_seq_len, vocab_size, vocab_source):
         super().__init__()
         self.split = split
         self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.vocab_source = vocab_source
 
     def __iter__(self):
         # get worker info within a DataLoader
@@ -186,8 +186,14 @@ class PretokDataset(torch.utils.data.IterableDataset):
         seed = 42 + worker_id + 1337 * rank
         rng = random.Random(seed)
         print(f"Created a PretokDataset with rng seed {seed}")
-        data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
-        shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
+        if self.vocab_source == "llama2":
+            # the .bin files are right alongside the .json files
+            bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
+            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
+        elif self.vocab_source == "custom":
+            # the .bin files are in the tok{N} directory
+            bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
+            shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
         # train/test split. let's use only shard 0 for test split, rest train
         shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
         while True:
@@ -209,12 +215,25 @@ class PretokDataset(torch.utils.data.IterableDataset):
                 y = chunk[1:]
                 yield x, y
 
+# -----------------------------------------------------------------------------
+# public interface functions
+
+def get_tokenizer_model_path(vocab_size):
+    """
+    Returns path to the sentencepiece tokenizer model for a given vocab size.
+    vocab_size = 0 designates the default Llama 2 tokenizer, in that case
+    None is returned.
+    """
+    if vocab_size == 0:
+        return None
+    else:
+        return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
 
 class Task:
 
     @staticmethod
-    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
-        ds = PretokDataset(split, max_seq_len)
+    def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
+        ds = PretokDataset(**dataset_kwargs)
         dl = torch.utils.data.DataLoader(
             ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
         )
@@ -223,6 +242,8 @@ class Task:
             y = y.to(device, non_blocking=True)
             yield x, y
 
+# -----------------------------------------------------------------------------
+# CLI for constructing the dataset
 
 if __name__ == "__main__":
     """
diff --git a/train.py b/train.py
index dbf0b24..662afcf 100644
--- a/train.py
+++ b/train.py
@@ -47,6 +47,8 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
 # data
 batch_size = 128  # if gradient_accumulation_steps > 1, this is the micro-batch size
 max_seq_len = 256
+vocab_source = "custom"  # llama2|custom; use the Llama 2 vocab from Meta, or a custom trained one
+vocab_size = 512
 dataset = "tinystories"  # tinystories|tinyshakespeare
 # model
 dim = 288
@@ -83,6 +85,10 @@ config = {k: globals()[k] for k in config_keys}  # will be useful for logging
 lr_decay_iters = max_iters  # should be ~= max_iters per Chinchilla
 min_lr = 0.0  # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
 
+# validation checks
+assert vocab_source in ["llama2", "custom"]
+assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
+
 # various inits, derived attributes, I/O setup
 ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
 if ddp:
@@ -128,6 +134,8 @@ iter_batches = partial(
     task.iter_batches,
     batch_size=batch_size,
     max_seq_len=max_seq_len,
+    vocab_size=vocab_size,
+    vocab_source=vocab_source,
     device=device,
     num_workers=0,
 )
@@ -142,7 +150,7 @@ model_args = dict(
     n_layers=n_layers,
     n_heads=n_heads,
     n_kv_heads=n_heads,
-    vocab_size=32000,
+    vocab_size=vocab_size,
     multiple_of=multiple_of,
     max_seq_len=max_seq_len,
     dropout=dropout,
@@ -206,7 +214,7 @@ def estimate_loss():
     out = {}
     model.eval()
     for split in ["train", "val"]:
-        batch_iter = iter_batches(split)
+        batch_iter = iter_batches(split=split)
         losses = torch.zeros(eval_iters)  # keep on CPU
         for k in range(eval_iters):
             X, Y = next(batch_iter)
@@ -238,7 +246,7 @@ if wandb_log and master_process:
     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
 
 # training loop
-train_batch_iter = iter_batches("train")
+train_batch_iter = iter_batches(split="train")
 X, Y = next(train_batch_iter)  # fetch the very first batch
 t0 = time.time()
 local_iter_num = 0  # number of iterations in the lifetime of this process
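
A minimal usage sketch (not part of the patch) of how the new pieces fit together, assuming the diff above is applied and the pretokenized .bin shards already exist under DATA_CACHE_DIR; the batch size, sequence length, and vocab size below are illustrative values only:

    from tinystories import Task, get_tokenizer_model_path

    # vocab_size=0 selects the default Llama 2 tokenizer (the function returns None);
    # any other value resolves to the custom sentencepiece model DATA_CACHE_DIR/tok{N}.model
    print(get_tokenizer_model_path(vocab_size=0))    # None
    print(get_tokenizer_model_path(vocab_size=512))  # .../tok512.model

    # dataset-specific arguments now flow through **dataset_kwargs into PretokDataset
    batch_iter = Task.iter_batches(
        batch_size=8,
        device="cpu",
        num_workers=0,
        split="train",
        max_seq_len=256,
        vocab_size=512,
        vocab_source="custom",
    )
    X, Y = next(batch_iter)  # LongTensors of shape (8, 256)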