"""
|
|
Download, preprocess and serve the TinyStories dataset as a DataLoader.
|
|
"""
|
|
|
|
import argparse
|
|
import glob
|
|
import json
|
|
import os
|
|
import random
|
|
from typing import List
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
import numpy as np
|
|
import requests
|
|
import torch
|
|
import torch.distributed as dist
|
|
from tqdm import tqdm
|
|
|
|
from tokenizer import Tokenizer
|
|
|
|
DATA_CACHE_DIR = "data"
|
|
|
|
def download_file(url: str, fname: str, chunk_size=1024):
|
|
"""Helper function to download a file from a given url"""
|
|
resp = requests.get(url, stream=True)
|
|
total = int(resp.headers.get("content-length", 0))
|
|
with open(fname, "wb") as file, tqdm(
|
|
desc=fname,
|
|
total=total,
|
|
unit="iB",
|
|
unit_scale=True,
|
|
unit_divisor=1024,
|
|
) as bar:
|
|
for data in resp.iter_content(chunk_size=chunk_size):
|
|
size = file.write(data)
|
|
bar.update(size)
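
# Usage note (a minimal sketch, not part of the pipeline): download_file streams the
# response in chunk_size-byte pieces so large archives never have to fit in memory,
# while tqdm tracks progress against the Content-Length header (total is 0 if the
# server omits it). The URL below is a placeholder, not one this script uses.
#
#   download_file("https://example.com/archive.tar.gz", "archive.tar.gz")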
|
|
|
|
|
|
def download():
|
|
"""Downloads the dataset to disk."""
|
|
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
|
|
|
|
# download the TinyStories dataset, unless it's already downloaded
|
|
data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
|
|
data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
|
|
if not os.path.exists(data_filename):
|
|
print(f"Downloading {data_url} to {data_filename}...")
|
|
download_file(data_url, data_filename)
|
|
else:
|
|
print(f"{data_filename} already exists, skipping download...")
|
|
|
|
# unpack the tar.gz file into all the data shards (json files)
|
|
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
|
|
if not os.path.exists(data_dir):
|
|
os.makedirs(data_dir, exist_ok=True)
|
|
print(f"Unpacking {data_filename}...")
|
|
os.system(f"tar -xzf {data_filename} -C {data_dir}")
|
|
else:
|
|
print(f"{data_dir} already exists, skipping unpacking...")
|
|
|
|
# print a single example just for debugging and such
|
|
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
|
|
with open(shard_filenames[0], "r") as f:
|
|
data = json.load(f)
|
|
print("Download done.")
|
|
print(f"Number of shards: {len(shard_filenames)}")
|
|
print(f"Example story:\n{data[0]}")


def pretokenize():
    enc = Tokenizer()

    def process_shard(shard):
        with open(shard, "r") as f:
            data = json.load(f)
        all_tokens = []
        for example in tqdm(data):
            text = example["story"]
            text = text.strip()  # get rid of leading/trailing whitespace
            tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
            all_tokens.extend(tokens)
        # convert to uint16 nparray
        all_tokens = np.array(all_tokens, dtype=np.uint16)
        # write to disk
        tokenized_filename = shard.replace(".json", ".bin")
        with open(tokenized_filename, "wb") as f:
            f.write(all_tokens.tobytes())
        print(f"Saved {tokenized_filename}")

    # iterate the shards and tokenize all of them one by one
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

    # process all the shards in a threadpool
    with ThreadPoolExecutor(max_workers=8) as executor:
        executor.map(process_shard, shard_filenames)

    print("Done.")


class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors."""

    def __init__(self, split, max_seq_len):
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len

    def __iter__(self):
        # get worker info within a DataLoader
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # combine the worker_id and worker_rank to create a unique seed for rng
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")
        data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
        shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
        # train/test split. let's use only shard 0 for test split, rest train
        shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
        while True:
            rng.shuffle(shard_filenames)
            for shard in shard_filenames:
                # open the dataset for reading but keep it on disk with memmap
                m = np.memmap(shard, dtype=np.uint16, mode="r")
                num_batches = len(m) // self.max_seq_len
                num_batches -= 1  # drop the last partial batch
                assert num_batches > 0, "this shard is way too small? investigate."
                ixs = list(range(num_batches))
                rng.shuffle(ixs)
                for ix in ixs:
                    start = ix * self.max_seq_len
                    end = start + self.max_seq_len + 1
                    # calling .astype will copy the data into a new numpy array, now in RAM
                    chunk = torch.from_numpy((m[start:end]).astype(np.int64))
                    x = chunk[:-1]
                    y = chunk[1:]
                    yield x, y
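
# A minimal sanity check (sketch): the dataset is an infinite iterator, so take a
# couple of samples by hand rather than looping over it. x and y are int64 tensors
# of length max_seq_len, with y shifted one token to the right relative to x.
#
#   ds = PretokDataset(split="test", max_seq_len=256)
#   it = iter(ds)
#   x, y = next(it)
#   assert x.shape == y.shape == (256,)
#   assert (x[1:] == y[:-1]).all()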


class Task:

    @staticmethod
    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
        ds = PretokDataset(split, max_seq_len)
        dl = torch.utils.data.DataLoader(
            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
        )
        for x, y in dl:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            yield x, y
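
# How a training loop would typically consume this (a sketch; the real loop lives in
# the training script, and "cuda", the batch size, and num_steps are illustrative
# values, not settings defined here). Because PretokDataset loops forever, next()
# never raises StopIteration.
#
#   batch_iter = Task.iter_batches("train", batch_size=32, max_seq_len=256, device="cuda")
#   for step in range(num_steps):
#       x, y = next(batch_iter)
#       # forward / backward / optimizer step on (x, y) goes here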


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # only stages with a matching function below are accepted
    parser.add_argument("stage", type=str, choices=["download", "pretokenize"])
    args = parser.parse_args()

    # depending on the stage call the appropriate function
    fun = {
        "download": download,
        "pretokenize": pretokenize,
    }
    fun[args.stage]()
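
# Typical invocation, mirroring the two stages above (a sketch, assuming the module
# keeps its usual tinystories.py name in the repo; run it from the repo root so the
# relative "data" cache directory lands where the training script expects it):
#
#   python tinystories.py download      # fetch and unpack TinyStories_all_data
#   python tinystories.py pretokenize   # write a .bin of uint16 token ids per shard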