In [ ]:
"""
Train nanochat on Modal: Build Your Own ChatGPT from Scratch
Complete nanochat pipeline - from tokenizer training to a functional ChatGPT clone.
Pipeline stages:
1. Tokenizer Training - Custom BPE tokenizer (65K vocab)
2. Base Model Pretraining - GPT on FineWeb dataset
3. Midtraining - Conversation tokens and tool use
4. Supervised Fine-tuning - Task-specific training
5. Reinforcement Learning - Optional RL on GSM8K
6. Comprehensive Evaluation - CORE, ARC, GSM8K, HumanEval, MMLU
7. Inference - Chat CLI and Web UI
GPU Requirements:
- Recommended: 4-8x A100 80GB (~4 hours, ~$96)
- Minimum: 1x A100 80GB (8x longer)
- Testing: 1x A100 40GB (reduced batch sizes)
Setup:
1. Clone nanochat: git clone https://github.com/karpathy/nanochat.git
2. Optional - Set up secrets for WandB and HuggingFace:
modal secret create nanochat-secrets WANDB_API_KEY=your_key HUGGINGFACE_TOKEN=your_token
Usage:
modal run TrainNanochatModal.py # Full pipeline
modal run TrainNanochatModal.py::main --num-data-shards=8 --depth=12 # Quick test
modal run TrainNanochatModal.py::chat_cli --source=sft --prompt="Why is the sky blue?"
modal run TrainNanochatModal.py::chat_web --source=sft # Web UI
"""
""" Train nanochat on Modal: Build Your Own ChatGPT from Scratch Complete nanochat pipeline - from tokenizer training to a functional ChatGPT clone. Pipeline stages: 1. Tokenizer Training - Custom BPE tokenizer (65K vocab) 2. Base Model Pretraining - GPT on FineWeb dataset 3. Midtraining - Conversation tokens and tool use 4. Supervised Fine-tuning - Task-specific training 5. Reinforcement Learning - Optional RL on GSM8K 6. Comprehensive Evaluation - CORE, ARC, GSM8K, HumanEval, MMLU 7. Inference - Chat CLI and Web UI GPU Requirements: - Recommended: 4-8x A100 80GB (~4 hours, ~$96) - Minimum: 1x A100 80GB (8x longer) - Testing: 1x A100 40GB (reduced batch sizes) Setup: 1. Clone nanochat: git clone https://github.com/karpathy/nanochat.git 2. Optional - Set up secrets for WandB: modal secret create nanochat-secrets WANDB_API_KEY=your_key HUGGINGFACE_TOKEN=your_token Usage: modal run TrainNanochatModal.py # Full pipeline modal run TrainNanochatModal.py::main --num-data-shards=8 --depth=12 # Quick test modal run TrainNanochatModal.py::chat_cli --source=sft --prompt="Why is the sky blue?" modal run TrainNanochatModal.py::chat_web --source=sft # Web UI """
In [ ]:
Copied!
from modal import App, Image as ModalImage, Volume, Secret
============================================================================= CONFIGURATION
In [ ]:
MINUTES = 60
HOURS = 60 * 60
In [ ]:
GPU_TYPE = "a100-80gb"
GPU_TYPE = "a100-80gb"
In [ ]:
Copied!
# Multi-GPU configuration (nanochat supports 1-8 GPUs)
NUM_GPUS_BASE = 4
NUM_GPUS_MID = 4
NUM_GPUS_SFT = 4
NUM_GPUS_RL = 4
NUM_GPUS_EVAL = 4
NUM_GPUS_TOKENIZER = 1
NUM_GPUS_INFERENCE = 1
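These counts are combined with GPU_TYPE into Modal's "type:count" gpu= spec string when each function is declared below. A quick illustration, using only names from this file:

# Illustration: how the functions below assemble their gpu= argument.
print(f"{GPU_TYPE}:{NUM_GPUS_BASE}")       # -> "a100-80gb:4" (training stages)
print(f"{GPU_TYPE}:{NUM_GPUS_INFERENCE}")  # -> "a100-80gb:1" (inference)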
In [ ]:
WANDB_PROJECT_DEFAULT = "nanochat-modal"
BASE_DIR = "/data/.cache/nanochat"
============================================================================= MODAL APP AND VOLUMES
In [ ]:
app = App("nanochat-training")
In [ ]:
data_volume = Volume.from_name("nanochat-data", create_if_missing=True)
checkpoint_volume = Volume.from_name("nanochat-checkpoints", create_if_missing=True)
In [ ]:
VOLUME_CONFIG = {
"/data": data_volume,
"/checkpoints": checkpoint_volume,
}
============================================================================= SECRETS SETUP
In [ ]:
try:
nanochat_secret = Secret.from_dotenv()
print("Loaded secrets from .env file")
except Exception:
try:
nanochat_secret = Secret.from_name("nanochat-secrets")
print("Loaded secrets from Modal")
except Exception:
nanochat_secret = None
print("No secrets found - WandB logging disabled")
============================================================================= CONTAINER IMAGE
In [ ]:
NANOCHAT_IMAGE = (
ModalImage.from_registry("nvidia/cuda:12.8.1-devel-ubuntu24.04", add_python="3.11")
.apt_install("git", "build-essential", "curl", "wget", "unzip")
.run_commands(
"curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y",
"echo 'source $HOME/.cargo/env' >> $HOME/.bashrc",
)
.run_commands(
"curl -LsSf https://astral.sh/uv/install.sh | sh",
"echo 'export PATH=\"$HOME/.cargo/bin:$PATH\"' >> $HOME/.bashrc",
)
.add_local_dir(local_path="nanochat", remote_path="/root/nanochat", copy=True)
.workdir("/root/nanochat")
.run_commands(
"bash -c 'source $HOME/.cargo/env && uv sync && uv run maturin develop --release --manifest-path rustbpe/Cargo.toml'"
)
.env(
{
"OMP_NUM_THREADS": "1",
"NANOCHAT_BASE_DIR": "/data/.cache/nanochat",
"HF_HOME": "/data/.cache/huggingface",
}
)
)
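If a pipeline stage later fails at import time, it is cheaper to catch the problem during the image build. A sketch that extends the image with a smoke test, assuming the maturin step above produces a module importable as rustbpe (inferred from the Cargo manifest path, not confirmed by this file):

# Optional smoke test (assumption: the extension module is named `rustbpe`).
# Fails the image build early if the Rust BPE extension did not compile correctly.
NANOCHAT_IMAGE = NANOCHAT_IMAGE.run_commands(
    "bash -c 'cd /root/nanochat && uv run python -c \"import rustbpe\"'"
)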
============================================================================= HELPER FUNCTIONS
In [ ]:
def setup_base_dir():
"""Create base directory structure."""
import os
os.makedirs(BASE_DIR, exist_ok=True)
os.makedirs(f"{BASE_DIR}/base_data", exist_ok=True)
os.makedirs(f"{BASE_DIR}/tokenizer", exist_ok=True)
os.makedirs(f"{BASE_DIR}/checkpoints", exist_ok=True)
os.makedirs(f"{BASE_DIR}/eval_bundle", exist_ok=True)
os.makedirs(f"{BASE_DIR}/report", exist_ok=True)
def setup_base_dir(): """Create base directory structure.""" import os os.makedirs(BASE_DIR, exist_ok=True) os.makedirs(f"{BASE_DIR}/base_data", exist_ok=True) os.makedirs(f"{BASE_DIR}/tokenizer", exist_ok=True) os.makedirs(f"{BASE_DIR}/checkpoints", exist_ok=True) os.makedirs(f"{BASE_DIR}/eval_bundle", exist_ok=True) os.makedirs(f"{BASE_DIR}/report", exist_ok=True)
In [ ]:
Copied!
def setup_secrets():
"""Set up environment variables from secrets."""
import os
if "WANDB_API_KEY" in os.environ:
print("WandB API key found")
else:
print("WandB API key not found - logging disabled")
if "HUGGINGFACE_TOKEN" in os.environ:
os.environ["HF_TOKEN"] = os.environ["HUGGINGFACE_TOKEN"]
print("HuggingFace token found")
else:
print("HuggingFace token not found")
def setup_secrets(): """Set up environment variables from secrets.""" import os if "WANDB_API_KEY" in os.environ: print("WandB API key found") else: print("WandB API key not found - logging disabled") if "HUGGINGFACE_TOKEN" in os.environ: os.environ["HF_TOKEN"] = os.environ["HUGGINGFACE_TOKEN"] print("HuggingFace token found") else: print("HuggingFace token not found")
In [ ]:
Copied!
def run_torchrun_command(script: str, num_gpus: int, extra_args: list | None = None):
    """Run a nanochat script with torchrun for multi-GPU training."""
    import subprocess
    extra_args = extra_args or []
    cmd = f"cd /root/nanochat && uv run torchrun --standalone --nproc_per_node={num_gpus} -m {script}"
    if extra_args:
        # nanochat scripts take their own flags after the `--` separator
        cmd += " -- " + " ".join(extra_args)
print(f"Running: {cmd}")
result = subprocess.run(["bash", "-c", cmd], capture_output=False, text=True)
if result.returncode != 0:
raise RuntimeError(f"Command failed with code {result.returncode}")
return result
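For reference, this is the shape of command the helper produces; everything after -- goes to the nanochat script itself:

# run_torchrun_command("scripts.base_train", 4, ["--depth=20"]) executes:
#   cd /root/nanochat && uv run torchrun --standalone --nproc_per_node=4 \
#     -m scripts.base_train -- --depth=20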
============================================================================= STAGE 1: DATASET DOWNLOAD
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
volumes=VOLUME_CONFIG,
timeout=2 * HOURS,
)
def download_dataset(num_shards: int = 240):
"""
Download FineWeb dataset shards from HuggingFace.
Each shard is ~250M characters (~100MB compressed).
- Full speedrun: 240 shards (~60B characters, ~24GB)
- Testing: 8 shards (~2B characters, ~800MB)
"""
import subprocess
setup_base_dir()
print("=" * 80)
print(f"DOWNLOADING FINEWEB DATASET - {num_shards} SHARDS")
print("=" * 80)
print(f"Total data: ~{num_shards * 250 / 1000:.1f}B characters (~{num_shards * 100 / 1024:.1f}GB)")
print()
result = subprocess.run(
["bash", "-c", f"cd /root/nanochat && uv run python -m nanochat.dataset -n {num_shards}"],
capture_output=True,
text=True,
)
print(result.stdout)
if result.stderr:
print("STDERR:", result.stderr)
if result.returncode != 0:
raise RuntimeError(f"Dataset download failed with code {result.returncode}")
data_volume.commit()
print("\n" + "=" * 80)
print(f"Downloaded {num_shards} shards successfully")
print("=" * 80)
return {
"status": "completed",
"num_shards": num_shards,
"data_dir": f"{BASE_DIR}/base_data",
}
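This stage can also be run on its own to pre-warm the data volume; modal run maps function parameters to CLI flags:

# Fetch a small test set (~2B characters) without running the rest of the pipeline:
#   modal run TrainNanochatModal.py::download_dataset --num-shards=8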
============================================================================= STAGE 2: TOKENIZER TRAINING
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_TOKENIZER}",
volumes=VOLUME_CONFIG,
timeout=2 * HOURS,
)
def train_tokenizer(
max_chars: int = 2_000_000_000,
vocab_size: int = 65536,
doc_cap: int = 10000,
):
"""
Train a custom BPE tokenizer on FineWeb data.
Training takes 30-60 minutes on a single GPU.
"""
import subprocess
setup_base_dir()
print("=" * 80)
print("TRAINING CUSTOM BPE TOKENIZER")
print("=" * 80)
print(f"Max characters: {max_chars:,}")
print(f"Vocabulary size: {vocab_size:,}")
print(f"Document cap: {doc_cap:,}")
print()
cmd = f"cd /root/nanochat && uv run python -m scripts.tok_train --max_chars={max_chars} --vocab_size={vocab_size} --doc_cap={doc_cap}"
print(f"Running: {cmd}")
result = subprocess.run(["bash", "-c", cmd], capture_output=False, text=True)
if result.returncode != 0:
raise RuntimeError(f"Tokenizer training failed with code {result.returncode}")
data_volume.commit()
checkpoint_volume.commit()
print("\n" + "=" * 80)
print("Tokenizer training completed")
print(f"Tokenizer saved to {BASE_DIR}/tokenizer/")
print("=" * 80)
return {
"status": "completed",
"max_chars": max_chars,
"vocab_size": vocab_size,
"tokenizer_dir": f"{BASE_DIR}/tokenizer",
}
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_TOKENIZER}",
volumes=VOLUME_CONFIG,
timeout=30 * MINUTES,
)
def evaluate_tokenizer():
"""Evaluate the trained tokenizer."""
import subprocess
print("=" * 80)
print("EVALUATING TOKENIZER")
print("=" * 80)
result = subprocess.run(
["bash", "-c", "cd /root/nanochat && uv run python -m scripts.tok_eval"],
capture_output=False,
text=True
)
if result.returncode != 0:
raise RuntimeError(f"Tokenizer evaluation failed with code {result.returncode}")
print("\n" + "=" * 80)
print("Tokenizer evaluation completed")
print("=" * 80)
return {"status": "completed"}
============================================================================= STAGE 3: BASE MODEL PRETRAINING
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_BASE}",
volumes=VOLUME_CONFIG,
secrets=[nanochat_secret] if nanochat_secret else [],
timeout=8 * HOURS,
)
def train_base_model(
depth: int = 20,
device_batch_size: int = 32,
max_iterations: int = -1,
wandb_run: str = "dummy",
):
"""
Pretrain the base GPT model on FineWeb.
Model sizes: depth=20 (561M params), depth=26 (1B params)
Training duration: ~2-3 hours on 8 GPUs, ~16-24 hours on 1 GPU
"""
import subprocess
import os
setup_base_dir()
setup_secrets()
    eval_bundle_path = f"{BASE_DIR}/eval_bundle"
    # setup_base_dir() already created this directory, so check for contents
    # rather than existence (an exists() check here would always be True).
    if not os.listdir(eval_bundle_path):
        print("Downloading eval bundle...")
        subprocess.run(
            [
                "curl",
                "-L",
                "-o",
                "eval_bundle.zip",
                "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip",
            ],
            check=True,
        )
        subprocess.run(["unzip", "-q", "eval_bundle.zip"], check=True)
        # merge the unzipped contents into the pre-existing target directory
        subprocess.run(["bash", "-c", f"mv eval_bundle/* {eval_bundle_path}/"], check=True)
        subprocess.run(["rm", "-rf", "eval_bundle.zip", "eval_bundle"], check=True)
        data_volume.commit()  # persist the bundle so later stages can reuse it
print("=" * 80)
print("PRETRAINING BASE MODEL ON FINEWEB")
print("=" * 80)
print(f"Model depth: {depth}")
print(f"Estimated parameters: {depth * depth * 64 * 12 // 1_000_000}M")
print(f"Device batch size: {device_batch_size}")
print(f"Number of GPUs: {NUM_GPUS_BASE}")
print(f"WandB run: {wandb_run}")
print()
extra_args = [
f"--depth={depth}",
f"--device_batch_size={device_batch_size}",
f"--run={wandb_run}",
]
if max_iterations > 0:
extra_args.append(f"--num_iterations={max_iterations}")
run_torchrun_command("scripts.base_train", NUM_GPUS_BASE, extra_args)
checkpoint_volume.commit()
print("\n" + "=" * 80)
print("Base model training completed")
print(f"Checkpoints saved to {BASE_DIR}/checkpoints/base/")
print("=" * 80)
return {
"status": "completed",
"depth": depth,
"device_batch_size": device_batch_size,
"num_gpus": NUM_GPUS_BASE,
"checkpoint_dir": f"{BASE_DIR}/checkpoints/base",
}
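The depth-to-size figures in the docstring follow from the estimate printed above: roughly 12 · depth · model_dim² transformer weights plus input and output embeddings over the 65,536-token vocabulary, with model_dim = 64 · depth. A quick check of the two documented configurations:

# Rough parameter count for the two depths mentioned in the docstring.
for depth in (20, 26):
    model_dim = 64 * depth
    est = 12 * depth * model_dim**2 + 2 * 65536 * model_dim
    print(depth, f"~{est / 1e6:.0f}M")  # 20 -> ~561M, 26 -> ~1082M (~1B)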
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_EVAL}",
volumes=VOLUME_CONFIG,
timeout=1 * HOURS,
)
def evaluate_base_model(max_per_task: int = 500):
"""Evaluate base model on CORE benchmark."""
print("=" * 80)
print("EVALUATING BASE MODEL - CORE METRIC")
print("=" * 80)
extra_args = []
if max_per_task > 0:
extra_args.append(f"--core_metric_max_per_task={max_per_task}")
run_torchrun_command("scripts.base_eval", NUM_GPUS_EVAL, extra_args)
checkpoint_volume.commit()
print("\n" + "=" * 80)
print("Base model evaluation completed")
print("=" * 80)
return {"status": "completed"}
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_EVAL}",
volumes=VOLUME_CONFIG,
timeout=1 * HOURS,
)
def evaluate_base_loss():
"""Evaluate base model validation loss (bits per byte)."""
print("=" * 80)
print("EVALUATING BASE MODEL - VALIDATION LOSS")
print("=" * 80)
run_torchrun_command("scripts.base_loss", NUM_GPUS_EVAL)
checkpoint_volume.commit()
print("\n" + "=" * 80)
print("Base loss evaluation completed")
print("=" * 80)
return {"status": "completed"}
============================================================================= STAGE 4: MIDTRAINING
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_MID}",
volumes=VOLUME_CONFIG,
secrets=[nanochat_secret] if nanochat_secret else [],
timeout=2 * HOURS,
)
def train_mid_model(
device_batch_size: int = 32,
wandb_run: str = "dummy",
):
"""
Midtrain the model on conversation data.
Teaches conversation tokens, tool use, and multiple choice format.
Duration: ~30-45 minutes on 8 GPUs
"""
setup_secrets()
print("=" * 80)
print("MIDTRAINING - TEACHING CONVERSATION TOKENS")
print("=" * 80)
print(f"Device batch size: {device_batch_size}")
print(f"Number of GPUs: {NUM_GPUS_MID}")
print()
extra_args = [
f"--device_batch_size={device_batch_size}",
f"--run={wandb_run}",
]
run_torchrun_command("scripts.mid_train", NUM_GPUS_MID, extra_args)
checkpoint_volume.commit()
print("\n" + "=" * 80)
print("Midtraining completed")
print(f"Checkpoints saved to {BASE_DIR}/checkpoints/mid/")
print("=" * 80)
return {
"status": "completed",
"checkpoint_dir": f"{BASE_DIR}/checkpoints/mid",
}
============================================================================= STAGE 5: SUPERVISED FINE-TUNING (SFT)
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_SFT}",
volumes=VOLUME_CONFIG,
secrets=[nanochat_secret] if nanochat_secret else [],
timeout=2 * HOURS,
)
def train_sft_model(
device_batch_size: int = 4,
num_epochs: int = 1,
wandb_run: str = "dummy",
source: str = "mid",
):
"""
Supervised fine-tuning on task-specific data.
Trains on MMLU, ARC, GSM8K, HumanEval, and SmolTalk.
Duration: ~30-45 minutes on 8 GPUs
"""
setup_secrets()
print("=" * 80)
print("SUPERVISED FINE-TUNING")
print("=" * 80)
print(f"Source: {source}")
print(f"Device batch size: {device_batch_size}")
print(f"Number of GPUs: {NUM_GPUS_SFT}")
print(f"Epochs: {num_epochs}")
print()
extra_args = [
f"--device_batch_size={device_batch_size}",
f"--num_epochs={num_epochs}",
f"--run={wandb_run}",
f"--source={source}",
]
run_torchrun_command("scripts.chat_sft", NUM_GPUS_SFT, extra_args)
checkpoint_volume.commit()
print("\n" + "=" * 80)
print("SFT completed")
print(f"Checkpoints saved to {BASE_DIR}/checkpoints/sft/")
print("=" * 80)
return {
"status": "completed",
"checkpoint_dir": f"{BASE_DIR}/checkpoints/sft",
}
============================================================================= STAGE 6: REINFORCEMENT LEARNING (OPTIONAL)
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_RL}",
volumes=VOLUME_CONFIG,
secrets=[nanochat_secret] if nanochat_secret else [],
timeout=2 * HOURS,
)
def train_rl_model(
device_batch_size: int = 8,
num_epochs: int = 1,
wandb_run: str = "dummy",
source: str = "sft",
):
"""
Reinforcement learning on GSM8K (optional).
Uses GRPO/REINFORCE to improve math reasoning.
Duration: ~30-45 minutes on 8 GPUs
"""
setup_secrets()
print("=" * 80)
print("REINFORCEMENT LEARNING ON GSM8K")
print("=" * 80)
print(f"Source: {source}")
print(f"Device batch size: {device_batch_size}")
print(f"Number of GPUs: {NUM_GPUS_RL}")
print(f"Epochs: {num_epochs}")
print()
extra_args = [
f"--device_batch_size={device_batch_size}",
f"--num_epochs={num_epochs}",
f"--run={wandb_run}",
f"--source={source}",
]
run_torchrun_command("scripts.chat_rl", NUM_GPUS_RL, extra_args)
checkpoint_volume.commit()
print("\n" + "=" * 80)
print("RL training completed")
print(f"Checkpoints saved to {BASE_DIR}/checkpoints/rl/")
print("=" * 80)
return {
"status": "completed",
"checkpoint_dir": f"{BASE_DIR}/checkpoints/rl",
}
============================================================================= STAGE 7: EVALUATION
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_EVAL}",
volumes=VOLUME_CONFIG,
timeout=2 * HOURS,
)
def evaluate_chat_model(
source: str = "sft",
tasks: str = "all",
):
"""
Evaluate the chat model on benchmark tasks.
Available tasks: ARC-Easy, ARC-Challenge, GSM8K, HumanEval, MMLU, ChatCORE
"""
print("=" * 80)
print(f"EVALUATING CHAT MODEL - {source.upper()}")
print("=" * 80)
print(f"Tasks: {tasks}")
print()
extra_args = ["-i", source]
if tasks != "all":
extra_args.extend(["-a", tasks])
run_torchrun_command("scripts.chat_eval", NUM_GPUS_EVAL, extra_args)
checkpoint_volume.commit()
print("\n" + "=" * 80)
print(f"Evaluation of {source} model completed")
print("=" * 80)
return {
"status": "completed",
"source": source,
"tasks": tasks,
}
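Single benchmarks can be targeted instead of the full suite; the tasks value is forwarded to scripts.chat_eval via its -a flag:

# Evaluate only GSM8K on the SFT checkpoint (mirrors the RL-stage eval in main()):
#   modal run TrainNanochatModal.py::evaluate_chat_model --source=sft --tasks=GSM8K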
============================================================================= STAGE 8: INFERENCE
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_INFERENCE}",
volumes=VOLUME_CONFIG,
timeout=1 * HOURS,
)
def chat_cli(
source: str = "sft",
prompt: str = "",
temperature: float = 0.6,
top_k: int = 50,
):
"""Chat with the model via command line interface."""
import subprocess
print("=" * 80)
print(f"CHAT CLI - {source.upper()} MODEL")
print("=" * 80)
cmd = f"cd /root/nanochat && uv run python -m scripts.chat_cli -i {source} -t {temperature} -k {top_k}"
if prompt:
escaped_prompt = prompt.replace('"', '\\"')
cmd += f' -p "{escaped_prompt}"'
print(f"Running: {cmd}")
result = subprocess.run(["bash", "-c", cmd], capture_output=False, text=True)
if result.returncode != 0:
raise RuntimeError(f"Chat CLI failed with code {result.returncode}")
return {
"status": "completed",
"source": source,
"prompt": prompt,
}
In [ ]:
@app.function(
image=NANOCHAT_IMAGE,
gpu=f"{GPU_TYPE}:{NUM_GPUS_INFERENCE}",
volumes=VOLUME_CONFIG,
timeout=4 * HOURS,
max_containers=2,
)
def chat_web(
source: str = "sft",
port: int = 8000,
temperature: float = 0.8,
top_k: int = 50,
max_tokens: int = 512,
):
"""Serve the chat model via a web UI."""
import subprocess
print("=" * 80)
print(f"STARTING WEB UI - {source.upper()} MODEL")
print("=" * 80)
print(f"Port: {port}")
print(f"Temperature: {temperature}")
print(f"Top-k: {top_k}")
print(f"Max tokens: {max_tokens}")
print()
cmd = f"cd /root/nanochat && uv run python -m scripts.chat_web -i {source} -p {port} -t {temperature} -k {top_k} -m {max_tokens} --host 0.0.0.0"
print(f"Running: {cmd}")
print("\n" + "=" * 80)
print(f"Web UI will be available at: http://localhost:{port}")
print("=" * 80)
print()
result = subprocess.run(["bash", "-c", cmd], capture_output=False, text=True)
if result.returncode != 0:
raise RuntimeError(f"Web server failed with code {result.returncode}")
return {
"status": "completed",
"source": source,
"port": port,
}
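The server binds inside a Modal container, so the port it listens on is not directly reachable from a local browser. One option is Modal's tunneling helper; a minimal sketch (serve_with_tunnel is a hypothetical helper; modal.forward is Modal's tunnel API and only works inside a running container):

# Hypothetical variant: wrap the server launch in a tunnel to get a public URL.
import modal

def serve_with_tunnel(cmd: str, port: int = 8000):
    import subprocess
    with modal.forward(port) as tunnel:  # exposes the container port over HTTPS
        print(f"Web UI available at: {tunnel.url}")
        subprocess.run(["bash", "-c", cmd], check=True)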
============================================================================= MAIN PIPELINE
In [ ]:
@app.local_entrypoint()
def main(
run_download: bool = True,
run_tokenizer: bool = True,
run_base: bool = True,
run_mid: bool = True,
run_sft: bool = True,
run_rl: bool = False,
run_eval: bool = True,
run_inference: bool = True,
num_data_shards: int = 240,
depth: int = 20,
device_batch_size_base: int = 32,
device_batch_size_sft: int = 4,
wandb_run: str = "dummy",
):
"""
Run the complete nanochat pipeline from scratch.
Pipeline stages:
1. Download FineWeb dataset
2. Train + evaluate tokenizer
3. Train + evaluate base model
4. Train + evaluate mid model
5. Train + evaluate SFT model
6. (Optional) Train + evaluate RL model
7. Run final inference test
Configuration modes:
- Full Speedrun (4h, $96): num_data_shards=240, depth=20
- Quick Test (1h, $24): num_data_shards=8, depth=12
- GPT-2 Grade (12h, $288): num_data_shards=450, depth=26
"""
print("=" * 80)
print("NANOCHAT TRAINING PIPELINE")
print("=" * 80)
print(f"Mode: {'Speedrun' if num_data_shards >= 240 else 'Quick Test'}")
print(f"Data shards: {num_data_shards}")
print(f"Model depth: {depth}")
print(f"WandB run: {wandb_run}")
print("=" * 80)
print()
if run_download:
print("Stage 1/8: Downloading dataset...")
download_dataset.remote(num_shards=num_data_shards)
if run_tokenizer:
print("\nStage 2/8: Training tokenizer...")
train_tokenizer.remote()
print("Evaluating tokenizer...")
evaluate_tokenizer.remote()
if run_base:
print("\nStage 3/8: Training base model...")
train_base_model.remote(
depth=depth, device_batch_size=device_batch_size_base, wandb_run=wandb_run
)
if run_eval:
print("Evaluating base model (CORE)...")
evaluate_base_model.remote()
print("Evaluating base model (loss)...")
evaluate_base_loss.remote()
if run_mid:
print("\nStage 4/8: Midtraining (conversation tokens)...")
train_mid_model.remote(
device_batch_size=device_batch_size_base, wandb_run=wandb_run
)
if run_eval:
print("Evaluating mid model...")
evaluate_chat_model.remote(source="mid")
if run_sft:
print("\nStage 5/8: Supervised fine-tuning...")
train_sft_model.remote(
device_batch_size=device_batch_size_sft, wandb_run=wandb_run, source="mid"
)
if run_eval:
print("Evaluating SFT model...")
evaluate_chat_model.remote(source="sft")
if run_rl:
print("\nStage 6/8: Reinforcement learning...")
train_rl_model.remote(wandb_run=wandb_run)
if run_eval:
print("Evaluating RL model...")
evaluate_chat_model.remote(source="rl", tasks="GSM8K")
if run_inference:
print("\nStage 7/8: Testing inference...")
final_source = "rl" if run_rl else "sft"
chat_cli.remote(source=final_source, prompt="Why is the sky blue?")
print("\n" + "=" * 80)
print("PIPELINE COMPLETED")
print("=" * 80)
print()
print("Next steps:")
print("1. Chat via CLI: modal run TrainNanochatModal.py::chat_cli --source=sft")
print("2. Launch Web UI: modal run TrainNanochatModal.py::chat_web --source=sft")
print("3. Run more evals: modal run TrainNanochatModal.py::evaluate_chat_model --source=sft")
print()