mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
ok i can train and sample a model with a custom tokenizer
This commit is contained in:
parent
4c6f0af9ff
commit
b0cfa2458d
5
model.py
5
model.py
@ -11,12 +11,13 @@ from torch import nn
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
# default hyperparameters for the Llama 7B model
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
n_kv_heads: Optional[int] = None
|
||||
vocab_size: int = -1 # defined later by tokenizer
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
vocab_size: int = 32000
|
||||
multiple_of: int = 256 # MLP hidden layer size will be multiple of
|
||||
norm_eps: float = 1e-5
|
||||
max_seq_len: int = 2048
|
||||
dropout: float = 0.0
|
||||
|
||||
@ -9,6 +9,8 @@ import tiktoken
|
||||
from model import ModelArgs, Transformer
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from tinystories import get_tokenizer_model_path
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
out_dir = 'out' # ignored if init_from is not 'resume'
|
||||
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
|
||||
@ -51,7 +53,9 @@ if compile:
|
||||
model = torch.compile(model) # requires PyTorch 2.0 (optional)
|
||||
|
||||
# load the tokenizer
|
||||
enc = Tokenizer()
|
||||
assert checkpoint["config"]["dataset"] == "tinystories" # TODO: generalize
|
||||
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
|
||||
enc = Tokenizer(tokenizer_model=tokenizer_model)
|
||||
|
||||
# encode the beginning of the prompt
|
||||
if start.startswith('FILE:'):
|
||||
|
||||
@ -120,9 +120,7 @@ def train_vocab(vocab_size):
|
||||
|
||||
def process_shard(args, vocab_size):
|
||||
shard_id, shard = args
|
||||
tokenizer_model = None
|
||||
if vocab_size > 0:
|
||||
tokenizer_model = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
|
||||
tokenizer_model = get_tokenizer_model_path()
|
||||
enc = Tokenizer(tokenizer_model)
|
||||
with open(shard, "r") as f:
|
||||
data = json.load(f)
|
||||
@ -171,10 +169,12 @@ def pretokenize(vocab_size):
|
||||
class PretokDataset(torch.utils.data.IterableDataset):
|
||||
"""Loads pretokenized examples from disk and yields them as PyTorch tensors."""
|
||||
|
||||
def __init__(self, split, max_seq_len):
|
||||
def __init__(self, split, max_seq_len, vocab_size, vocab_source):
|
||||
super().__init__()
|
||||
self.split = split
|
||||
self.max_seq_len = max_seq_len
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_source = vocab_source
|
||||
|
||||
def __iter__(self):
|
||||
# get worker info within a DataLoader
|
||||
@ -186,8 +186,14 @@ class PretokDataset(torch.utils.data.IterableDataset):
|
||||
seed = 42 + worker_id + 1337 * rank
|
||||
rng = random.Random(seed)
|
||||
print(f"Created a PretokDataset with rng seed {seed}")
|
||||
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
||||
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
|
||||
if self.vocab_source == "llama2":
|
||||
# the .bin files are right along the .json files
|
||||
bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
||||
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
|
||||
elif self.vocab_source == "custom":
|
||||
# the .bin files are in tok{N} directory
|
||||
bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
|
||||
shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
|
||||
# train/test split. let's use only shard 0 for test split, rest train
|
||||
shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
|
||||
while True:
|
||||
@ -209,12 +215,25 @@ class PretokDataset(torch.utils.data.IterableDataset):
|
||||
y = chunk[1:]
|
||||
yield x, y
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# public interface functions
|
||||
|
||||
def get_tokenizer_model_path(vocab_size):
|
||||
"""
|
||||
Returns path to the sentencepiece tokenizer model for a given vocab size
|
||||
vocab_size = 0 designates the default Llama 2 tokenizer, in that case
|
||||
None is returned.
|
||||
"""
|
||||
if vocab_size == 0:
|
||||
return None
|
||||
else:
|
||||
return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
|
||||
|
||||
class Task:
|
||||
|
||||
@staticmethod
|
||||
def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
|
||||
ds = PretokDataset(split, max_seq_len)
|
||||
def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
|
||||
ds = PretokDataset(**dataset_kwargs)
|
||||
dl = torch.utils.data.DataLoader(
|
||||
ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
|
||||
)
|
||||
@ -223,6 +242,8 @@ class Task:
|
||||
y = y.to(device, non_blocking=True)
|
||||
yield x, y
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# CLI for constructing the dataset
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
|
||||
14
train.py
14
train.py
@ -47,6 +47,8 @@ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
# data
|
||||
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
||||
max_seq_len = 256
|
||||
vocab_source = "custom" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
|
||||
vocab_size = 512
|
||||
dataset = "tinystories" # tinystories|tinyshakespeare
|
||||
# model
|
||||
dim = 288
|
||||
@ -83,6 +85,10 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
|
||||
min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
||||
|
||||
# validating checks
|
||||
assert vocab_source in ["llama2", "custom"]
|
||||
assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
|
||||
|
||||
# various inits, derived attributes, I/O setup
|
||||
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
||||
if ddp:
|
||||
@ -128,6 +134,8 @@ iter_batches = partial(
|
||||
task.iter_batches,
|
||||
batch_size=batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
vocab_size=vocab_size,
|
||||
vocab_source=vocab_source,
|
||||
device=device,
|
||||
num_workers=0,
|
||||
)
|
||||
@ -142,7 +150,7 @@ model_args = dict(
|
||||
n_layers=n_layers,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_heads,
|
||||
vocab_size=32000,
|
||||
vocab_size=vocab_size,
|
||||
multiple_of=multiple_of,
|
||||
max_seq_len=max_seq_len,
|
||||
dropout=dropout,
|
||||
@ -206,7 +214,7 @@ def estimate_loss():
|
||||
out = {}
|
||||
model.eval()
|
||||
for split in ["train", "val"]:
|
||||
batch_iter = iter_batches(split)
|
||||
batch_iter = iter_batches(split=split)
|
||||
losses = torch.zeros(eval_iters) # keep on CPU
|
||||
for k in range(eval_iters):
|
||||
X, Y = next(batch_iter)
|
||||
@ -238,7 +246,7 @@ if wandb_log and master_process:
|
||||
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
||||
|
||||
# training loop
|
||||
train_batch_iter = iter_batches("train")
|
||||
train_batch_iter = iter_batches(split="train")
|
||||
X, Y = next(train_batch_iter) # fetch the very first batch
|
||||
t0 = time.time()
|
||||
local_iter_num = 0 # number of iterations in the lifetime of this process
|
||||
|
||||
Loading…
Reference in New Issue
Block a user