diff --git a/.gitignore b/.gitignore index ebe8e61bd0..d2c130f80a 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,16 @@ pytensor-venv/ testing-report.html coverage.xml .coverage.* + +# ai +.ai/ +.claude/ +.cursor/ +.CLAUDE.md + +# gallery notebook downloaded data +doc/gallery/**/data/ + +# JupyterLab session artifacts +.jupyter/ +.jupyter_ystore.db \ No newline at end of file diff --git a/doc/gallery/transformers/tiny_transformer_llm.ipynb b/doc/gallery/transformers/tiny_transformer_llm.ipynb new file mode 100644 index 0000000000..f6fe713e65 --- /dev/null +++ b/doc/gallery/transformers/tiny_transformer_llm.ipynb @@ -0,0 +1,976 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(Tiny_Transformer_LLM)=\n", + "# A Tiny Transformer LLM in PyTensor\n", + "\n", + "Can you train a transformer language model end-to-end in PyTensor? Yes — and that is exactly what we do in this notebook. We build a small **decoder-only GPT** on character-level Tiny Shakespeare, train it with a hand-rolled **Adam** optimizer, and sample from it. 
Along the way we showcase what makes PyTensor distinctive for deep learning:\n", + "\n", + "* `pytensor.xtensor` **named dimensions** drive multi-head attention, so the `(batch, head, time, head_dim)` reshape gymnastics become self-documenting code,\n", + "* a **static graph** you can inspect with `pytensor.dprint`,\n", + "* symbolic **reverse-mode auto-diff** via `pytensor.grad`,\n", + "* `pytensor.shared` parameter state with **in-graph optimizer updates**, and\n", + "* `pytensor.scan` for **autoregressive generation** — the entire sampling loop, including all forward passes and all categorical draws, runs inside a single compiled call.\n", + "\n", + "The model is intentionally tiny (~100k parameters, a few thousand training steps) so the whole notebook runs on a laptop CPU in a couple of minutes.\n", + "\n", + "## Prepare notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:26.660089Z", + "iopub.status.busy": "2026-05-07T13:13:26.659993Z", + "iopub.status.idle": "2026-05-07T13:13:27.126450Z", + "shell.execute_reply": "2026-05-07T13:13:27.125234Z" + } + }, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "import time\n", + "import urllib.request\n", + "from dataclasses import dataclass\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import pytensor.xtensor as ptx\n", + "\n", + "\n", + "plt.style.use(\"seaborn-v0_8\")\n", + "\n", + "%config InlineBackend.figure_format = \"retina\"\n", + "\n", + "rng_np = np.random.default_rng(0)\n", + "floatX = pytensor.config.floatX\n", + "print(\"pytensor:\", pytensor.__version__, \"| floatX:\", floatX)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data — Tiny Shakespeare\n", + "\n", + "We download Andrej Karpathy's 1.1 MB **Tiny 
Shakespeare** corpus into a local cache directory and use a character-level vocabulary (~65 unique characters). To keep training fast on a laptop, we slice off only the first ~50,000 characters — plenty for the model to learn the shape, rhythm, and common words of Shakespearean English." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.129250Z", + "iopub.status.busy": "2026-05-07T13:13:27.128943Z", + "iopub.status.idle": "2026-05-07T13:13:27.134375Z", + "shell.execute_reply": "2026-05-07T13:13:27.133378Z" + } + }, + "outputs": [], + "source": [ + "DATA_DIR = Path(\"data\")\n", + "DATA_DIR.mkdir(parents=True, exist_ok=True)\n", + "DATA_FILE = DATA_DIR / \"tinyshakespeare.txt\"\n", + "DATA_URL = (\n", + " \"https://raw.githubusercontent.com/karpathy/char-rnn/\"\n", + " \"master/data/tinyshakespeare/input.txt\"\n", + ")\n", + "\n", + "if not DATA_FILE.exists():\n", + " print(\"Downloading\", DATA_URL)\n", + " urllib.request.urlretrieve(DATA_URL, DATA_FILE)\n", + "\n", + "full_text = DATA_FILE.read_text()\n", + "print(f\"Full corpus: {len(full_text):,} characters\")\n", + "\n", + "# Use a small slice so training is fast in a notebook.\n", + "text = full_text[:50_000]\n", + "print(f\"Used slice : {len(text):,} characters\")\n", + "print(\"\\n--- First 200 characters ---\\n\")\n", + "print(text[:200])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.136096Z", + "iopub.status.busy": "2026-05-07T13:13:27.135971Z", + "iopub.status.idle": "2026-05-07T13:13:27.143040Z", + "shell.execute_reply": "2026-05-07T13:13:27.141910Z" + } + }, + "outputs": [], + "source": [ + "chars = sorted(set(text))\n", + "vocab_size = len(chars)\n", + "stoi = {c: i for i, c in enumerate(chars)}\n", + "itos = {i: c for i, c in enumerate(chars)}\n", + "\n", + "def encode(s: str) -> np.ndarray:\n", + " return 
np.array([stoi[c] for c in s], dtype=\"int64\")\n", + "\n", + "def decode(arr: np.ndarray) -> str:\n", + " return \"\".join(itos[int(i)] for i in arr)\n", + "\n", + "data_ids = encode(text)\n", + "print(f\"vocab_size = {vocab_size}\")\n", + "print(f\"first 20 ids: {data_ids[:20].tolist()}\")\n", + "print(f\"round-trip : {decode(data_ids[:20])!r}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Mini-batches of `(context, next-char)` pairs\n", + "\n", + "The transformer reads a window of `block_size` characters and predicts the next character at every position, so a batch element is a pair `(x, y)` of `block_size`-length sequences where `y[t] = x[t + 1]`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.145023Z", + "iopub.status.busy": "2026-05-07T13:13:27.144920Z", + "iopub.status.idle": "2026-05-07T13:13:27.149354Z", + "shell.execute_reply": "2026-05-07T13:13:27.148415Z" + } + }, + "outputs": [], + "source": [ + "def get_batch(rng: np.random.Generator, batch_size: int, block_size: int):\n", + " \"\"\"Sample `batch_size` random windows of length `block_size + 1` from the corpus.\"\"\"\n", + " starts = rng.integers(0, len(data_ids) - block_size - 1, size=batch_size)\n", + " x = np.stack([data_ids[s : s + block_size] for s in starts])\n", + " y = np.stack([data_ids[s + 1 : s + 1 + block_size] for s in starts])\n", + " return x, y\n", + "\n", + "x_demo, y_demo = get_batch(rng_np, batch_size=2, block_size=16)\n", + "print(\"x[0]:\", repr(decode(x_demo[0])))\n", + "print(\"y[0]:\", repr(decode(y_demo[0])))\n", + "assert (x_demo[:, 1:] == y_demo[:, :-1]).all(), \"y is x shifted by one\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hyper-parameters and parameter init\n", + "\n", + "A tiny GPT-style transformer with roughly 100k parameters. 
Note that `block_size` is *static* in the PyTensor graph (it appears in the causal-mask shape), while `batch_size` is left symbolic so we can train on `B=32` and sample at `B=1` with the **same** `forward` code and shared parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.151290Z", + "iopub.status.busy": "2026-05-07T13:13:27.151119Z", + "iopub.status.idle": "2026-05-07T13:13:27.156816Z", + "shell.execute_reply": "2026-05-07T13:13:27.156272Z" + } + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class Config:\n", + " vocab_size: int\n", + " block_size: int = 32\n", + " n_embd: int = 64\n", + " n_head: int = 4\n", + " n_layer: int = 2\n", + "\n", + " @property\n", + " def head_dim(self) -> int:\n", + " assert self.n_embd % self.n_head == 0, \"n_embd must be divisible by n_head\"\n", + " return self.n_embd // self.n_head\n", + "\n", + "cfg = Config(vocab_size=vocab_size)\n", + "cfg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.158183Z", + "iopub.status.busy": "2026-05-07T13:13:27.158078Z", + "iopub.status.idle": "2026-05-07T13:13:27.167215Z", + "shell.execute_reply": "2026-05-07T13:13:27.166803Z" + } + }, + "outputs": [], + "source": [ + "def init_params(cfg: Config, rng: np.random.Generator) -> dict[str, np.ndarray]:\n", + " D = cfg.n_embd\n", + " p: dict[str, np.ndarray] = {}\n", + "\n", + " p[\"tok_embed\"] = rng.normal(0, 0.02, (cfg.vocab_size, D)).astype(floatX)\n", + " p[\"pos_embed\"] = rng.normal(0, 0.02, (cfg.block_size, D)).astype(floatX)\n", + "\n", + " for i in range(cfg.n_layer):\n", + " p[f\"l{i}.ln1_g\"] = np.ones(D, dtype=floatX)\n", + " p[f\"l{i}.ln1_b\"] = np.zeros(D, dtype=floatX)\n", + " p[f\"l{i}.W_qkv\"] = rng.normal(0, 0.02, (D, 3 * D)).astype(floatX)\n", + " p[f\"l{i}.W_proj\"] = rng.normal(0, 0.02, (D, D)).astype(floatX)\n", + " p[f\"l{i}.ln2_g\"] = np.ones(D, 
dtype=floatX)\n", + " p[f\"l{i}.ln2_b\"] = np.zeros(D, dtype=floatX)\n", + " p[f\"l{i}.W_fc\"] = rng.normal(0, 0.02, (D, 4 * D)).astype(floatX)\n", + " p[f\"l{i}.W_proj2\"] = rng.normal(0, 0.02, (4 * D, D)).astype(floatX)\n", + "\n", + " p[\"ln_f_g\"] = np.ones(D, dtype=floatX)\n", + " p[\"ln_f_b\"] = np.zeros(D, dtype=floatX)\n", + " return p\n", + "\n", + "# Wrap each numpy array in a `pytensor.shared` so the optimizer can update it in-graph.\n", + "init_arrays = init_params(cfg, np.random.default_rng(42))\n", + "params = {name: pytensor.shared(arr, name=name) for name, arr in init_arrays.items()}\n", + "\n", + "n_params = sum(arr.size for arr in init_arrays.values())\n", + "print(f\"{len(params)} shared variables, {n_params:,} total parameters\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The model with named-dimension attention\n", + "\n", + "Multi-head attention is the place in transformer code where axis bookkeeping bites the hardest: `(B, T, D)` becomes `(B, T, H, hd)` and then `(B, H, T, hd)`, and you have to remember every transpose. We build attention with [`pytensor.xtensor`](../../library/xtensor.rst), which gives every dimension a **name**, so each operation is driven by what the axis *means* rather than its index. The xtensor sub-graph is lowered to ordinary tensor ops at compile time, so the same Numba / JAX / C / MLX backends still apply.\n", + "\n", + "Layernorm, MLP, embeddings, and the residual stream stay as plain `pytensor.tensor`, where named dimensions add little. Every component is a Python function that **builds a symbolic graph** from its inputs; these functions run *once* at graph-construction time, so there is no per-batch Python overhead during training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.168815Z", + "iopub.status.busy": "2026-05-07T13:13:27.168689Z", + "iopub.status.idle": "2026-05-07T13:13:27.173978Z", + "shell.execute_reply": "2026-05-07T13:13:27.173296Z" + } + }, + "outputs": [], + "source": [ + "def layernorm(x, gamma, beta, eps=1e-5):\n", + " mu = x.mean(axis=-1, keepdims=True)\n", + " var = x.var(axis=-1, keepdims=True)\n", + " return (x - mu) / pt.sqrt(var + eps) * gamma + beta\n", + "\n", + "\n", + "def causal_self_attention(x, p, n_head, head_dim, block_size):\n", + " \"\"\"Multi-head causal self-attention with named dimensions.\n", + "\n", + " Inputs/outputs are plain `(B, T, D)` tensors. Internally we wrap them as\n", + " xtensors with named dims and reshape `W_qkv` to expose ('qkv', 'head', 'hd')\n", + " as named axes — from then on every op is pure semantics: contract over\n", + " 'embd', contract over 'hd' (head_dim), softmax over 'time_k'. 
The `assert`s\n", + " on `.dims` aren't just documentation: PyTensor xtensors carry their dim\n", + " names symbolically, so they're always available for introspection.\n", + " \"\"\"\n", + " D = n_head * head_dim\n", + " x_x = ptx.as_xtensor(x, dims=(\"batch\", \"time\", \"embd\"))\n", + " Wqkv = ptx.as_xtensor(\n", + " p[\"W_qkv\"].reshape((D, 3, n_head, head_dim)),\n", + " dims=(\"embd\", \"qkv\", \"head\", \"hd\"),\n", + " )\n", + " Wproj = ptx.as_xtensor(p[\"W_proj\"], dims=(\"embd\", \"embd_out\"))\n", + "\n", + " qkv = ptx.dot(x_x, Wqkv, dim=\"embd\")\n", + " assert qkv.dims == (\"batch\", \"time\", \"qkv\", \"head\", \"hd\")\n", + " q = qkv.isel(qkv=0).rename(time=\"time_q\")\n", + " k = qkv.isel(qkv=1).rename(time=\"time_k\")\n", + " v = qkv.isel(qkv=2).rename(time=\"time_k\")\n", + "\n", + " scale = np.sqrt(head_dim).astype(floatX)\n", + " scores = ptx.dot(q, k, dim=\"hd\") / scale\n", + " assert scores.dims == (\"batch\", \"time_q\", \"head\", \"time_k\")\n", + "\n", + " mask = ptx.as_xtensor(\n", + " pt.tril(pt.ones((block_size, block_size), dtype=\"bool\")),\n", + " dims=(\"time_q\", \"time_k\"),\n", + " )\n", + " scores = ptx.where(mask, scores, np.float64(-1e9))\n", + " attn = ptx.math.softmax(scores, dim=\"time_k\")\n", + "\n", + " out = ptx.dot(attn, v, dim=\"time_k\")\n", + " assert out.dims == (\"time_q\", \"batch\", \"head\", \"hd\")\n", + " out = ptx.stack(out, embd=(\"head\", \"hd\"))\n", + " assert out.dims == (\"time_q\", \"batch\", \"embd\")\n", + " out = ptx.dot(out, Wproj, dim=\"embd\").rename(time_q=\"time\", embd_out=\"embd\")\n", + " return out.transpose(\"batch\", \"time\", \"embd\").values\n", + "\n", + "\n", + "def mlp(x, p):\n", + " h = pt.tanh(x @ p[\"W_fc\"]) # tanh keeps the toy model simple\n", + " return h @ p[\"W_proj2\"]\n", + "\n", + "\n", + "def block(x, params, layer: int, n_head: int, head_dim: int, block_size: int):\n", + " p = {k.split(\".\", 1)[1]: v for k, v in params.items() if k.startswith(f\"l{layer}.\")}\n", + " 
x = x + causal_self_attention(\n", + " layernorm(x, p[\"ln1_g\"], p[\"ln1_b\"]), p, n_head, head_dim, block_size\n", + " )\n", + " x = x + mlp(layernorm(x, p[\"ln2_g\"], p[\"ln2_b\"]), p)\n", + " return x\n", + "\n", + "\n", + "def forward(tokens, params, cfg: Config):\n", + " \"\"\"Returns logits of shape (B, T, vocab_size). Uses tied embeddings.\"\"\"\n", + " h = params[\"tok_embed\"][tokens] + params[\"pos_embed\"]\n", + " for i in range(cfg.n_layer):\n", + " h = block(h, params, i, cfg.n_head, cfg.head_dim, cfg.block_size)\n", + " h = layernorm(h, params[\"ln_f_g\"], params[\"ln_f_b\"])\n", + " return h @ params[\"tok_embed\"].T # tied output head" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we build the symbolic logits graph and inspect it. The whole forward pass is a single `TensorVariable`, with the named-dimension attention sub-graph showing up as a wrapped xtensor block. The `lower_xtensor` rewrite (which runs automatically at compile time) replaces those xtensor ops with plain tensor ops." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.175559Z", + "iopub.status.busy": "2026-05-07T13:13:27.175439Z", + "iopub.status.idle": "2026-05-07T13:13:27.284973Z", + "shell.execute_reply": "2026-05-07T13:13:27.284421Z" + } + }, + "outputs": [], + "source": [ + "X_sym = pt.tensor(\"X\", shape=(None, cfg.block_size), dtype=\"int64\")\n", + "logits = forward(X_sym, params, cfg)\n", + "print(\"logits type:\", logits.type)\n", + "print(\"\\n--- forward graph (truncated) ---\")\n", + "logits.dprint(depth=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:27.286619Z", + "iopub.status.busy": "2026-05-07T13:13:27.286405Z", + "iopub.status.idle": "2026-05-07T13:13:30.359631Z", + "shell.execute_reply": "2026-05-07T13:13:30.359060Z" + } + }, + "outputs": [], + "source": [ + "# Sanity check: compile the forward pass and feed a real batch.\n", + "f_forward = pytensor.function([X_sym], logits)\n", + "x_batch, _ = get_batch(rng_np, batch_size=4, block_size=cfg.block_size)\n", + "out = f_forward(x_batch)\n", + "print(\"input shape :\", x_batch.shape)\n", + "print(\"logits shape:\", out.shape)\n", + "assert out.shape == (4, cfg.block_size, cfg.vocab_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loss and gradients\n", + "\n", + "We use cross-entropy averaged over `(B, T)`, computed with `pt.special.log_softmax` for numerical stability and advanced indexing to gather the log-prob assigned to the true target at every position.\n", + "\n", + "Because the forward graph contains xtensor ops (inside attention) and `pytensor.grad` runs *before* compile-time lowering, we explicitly lower the xtensor sub-graph first with `rewrite_graph(..., include=(\"lower_xtensor\",))`. 
From there one line of PyTensor gives us the gradients of the loss with respect to **every** parameter — as a *new symbolic graph*." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:30.361026Z", + "iopub.status.busy": "2026-05-07T13:13:30.360867Z", + "iopub.status.idle": "2026-05-07T13:13:30.435791Z", + "shell.execute_reply": "2026-05-07T13:13:30.435284Z" + } + }, + "outputs": [], + "source": [ + "from pytensor.graph.rewriting.utils import rewrite_graph\n", + "\n", + "Y_sym = pt.tensor(\"Y\", shape=(None, cfg.block_size), dtype=\"int64\")\n", + "\n", + "log_probs = pt.special.log_softmax(logits, axis=-1) # (B, T, V)\n", + "B_sym = X_sym.shape[0]\n", + "T_sym = X_sym.shape[1]\n", + "batch_idx = pt.arange(B_sym)[:, None] # (B, 1)\n", + "time_idx = pt.arange(T_sym)[None, :] # (1, T)\n", + "nll = -log_probs[batch_idx, time_idx, Y_sym] # (B, T)\n", + "loss = nll.mean()\n", + "\n", + "# Lower xtensor ops to plain tensor ops so `pytensor.grad` can pull back through them.\n", + "loss = rewrite_graph(loss, include=(\"lower_xtensor\",))\n", + "\n", + "flat_params = list(params.values())\n", + "grads = pytensor.grad(loss, flat_params)\n", + "print(f\"Loss is a {loss.type}\")\n", + "print(f\"pytensor.grad returned {len(grads)} gradient graphs (one per shared variable).\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adam optimizer\n", + "\n", + "The optimizer is just a function that returns a list of `(shared_var, new_value)` updates. PyTensor weaves those updates into the same compiled function as the forward and backward pass, so there is no Python overhead per step — the entire forward/backward/update is one C/Numba call." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:30.437214Z", + "iopub.status.busy": "2026-05-07T13:13:30.437105Z", + "iopub.status.idle": "2026-05-07T13:13:30.487164Z", + "shell.execute_reply": "2026-05-07T13:13:30.486661Z" + } + }, + "outputs": [], + "source": [ + "def adam_updates(\n", + " params, grads, lr=3e-3, b1=0.9, b2=0.999, eps=1e-8\n", + "):\n", + " t = pytensor.shared(np.array(0.0, dtype=floatX), name=\"adam_t\")\n", + " t_new = t + 1\n", + " updates = [(t, t_new)]\n", + " for p, g in zip(params, grads):\n", + " m = pytensor.shared(np.zeros_like(p.get_value()), name=p.name + \"_m\")\n", + " v = pytensor.shared(np.zeros_like(p.get_value()), name=p.name + \"_v\")\n", + " m_new = b1 * m + (1 - b1) * g\n", + " v_new = b2 * v + (1 - b2) * pt.square(g)\n", + " m_hat = m_new / (1 - b1 ** t_new)\n", + " v_hat = v_new / (1 - b2 ** t_new)\n", + " p_new = p - lr * m_hat / (pt.sqrt(v_hat) + eps)\n", + " updates += [(m, m_new), (v, v_new), (p, p_new)]\n", + " return updates\n", + "\n", + "updates = adam_updates(flat_params, grads, lr=3e-3)\n", + "print(f\"Adam adds {len(updates)} updates ({len(flat_params)} params x 3 + 1 step counter).\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compile the train step and train\n", + "\n", + "The first call to `pytensor.function` takes a few seconds while PyTensor optimises and compiles the graph. Each subsequent call is a pure C / Numba invocation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:30.488447Z", + "iopub.status.busy": "2026-05-07T13:13:30.488361Z", + "iopub.status.idle": "2026-05-07T13:13:44.104011Z", + "shell.execute_reply": "2026-05-07T13:13:44.103622Z" + } + }, + "outputs": [], + "source": [ + "t0 = time.time()\n", + "train_step = pytensor.function([X_sym, Y_sym], loss, updates=updates)\n", + "print(f\"Compilation took {time.time() - t0:.1f}s\")\n", + "\n", + "# Initial loss should be near ln(vocab_size) for a randomly initialised model.\n", + "x_init, y_init = get_batch(rng_np, batch_size=4, block_size=cfg.block_size)\n", + "initial_loss = float(train_step(x_init, y_init))\n", + "print(f\"Initial loss = {initial_loss:.3f} (random baseline = {np.log(vocab_size):.3f})\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:13:44.105474Z", + "iopub.status.busy": "2026-05-07T13:13:44.105383Z", + "iopub.status.idle": "2026-05-07T13:14:19.641951Z", + "shell.execute_reply": "2026-05-07T13:14:19.641513Z" + } + }, + "outputs": [], + "source": [ + "BATCH_SIZE = 32\n", + "N_STEPS = 1500\n", + "LOG_EVERY = 100\n", + "\n", + "losses = []\n", + "t0 = time.time()\n", + "for step in range(N_STEPS):\n", + " x_b, y_b = get_batch(rng_np, BATCH_SIZE, cfg.block_size)\n", + " losses.append(float(train_step(x_b, y_b)))\n", + " if step % LOG_EVERY == 0 or step == N_STEPS - 1:\n", + " recent = np.mean(losses[-LOG_EVERY:])\n", + " print(f\"step {step:>4d} recent-mean loss = {recent:.3f}\")\n", + "print(f\"\\nTotal training time: {time.time() - t0:.1f}s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:14:19.643434Z", + "iopub.status.busy": "2026-05-07T13:14:19.643321Z", + "iopub.status.idle": "2026-05-07T13:14:19.739221Z", + "shell.execute_reply": 
"2026-05-07T13:14:19.738756Z" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 4))\n", + "ax.plot(losses, lw=0.6, alpha=0.5, label=\"raw\")\n", + "win = 50\n", + "ax.plot(\n", + " np.arange(len(losses) - win + 1),\n", + " np.convolve(losses, np.ones(win) / win, mode=\"valid\"),\n", + " lw=2, label=f\"{win}-step rolling mean\",\n", + ")\n", + "ax.axhline(np.log(vocab_size), color=\"grey\", ls=\"--\", label=\"random baseline\")\n", + "ax.set_xlabel(\"training step\")\n", + "ax.set_ylabel(\"cross-entropy loss\")\n", + "ax.legend()\n", + "ax.set_title(\"Training curve\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling — the autoregressive loop lives inside `pytensor.scan`\n", + "\n", + "For inference we feed the model a context window and sample the next character from the predicted distribution over the vocabulary. PyTensor lets us push the entire autoregressive loop **into the compiled graph** with `pytensor.scan`: every forward pass, every categorical draw, all the context-window bookkeeping becomes a single C/Numba call. We use `ptx.random.shared_rng` so the RNG state advances **inside** the compiled function — each call returns an independent draw. This pattern matches the one in the [scan tutorial](../scan/scan_tutorial.ipynb).\n", + "\n", + "We keep the context window and the categorical sampling as **xtensor** ops: that way the sliding window is a named-dim `concat` rather than positional `pt.concatenate`, and `categorical` consumes a `probs` xtensor with a `vocab` core dim. 
`pytensor.scan` itself only carries plain tensor variables across iterations, so we cross the xtensor↔tensor boundary in four places: convert to tensor before passing into the scan, back to xtensor at the top of the body, back to tensor before returning, and back to xtensor as soon as we get the scan outputs out.\n", + "\n", + "We also turn this into a **while-scan**: `pytensor.scan` accepts a `(outputs, until(cond))` return from the body and stops as soon as `cond` becomes true. Here we stop on the first newline — Tiny Shakespeare uses `\\n` as a hard break between speakers — capped by `n_new_sym` as a safety net." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:14:19.741016Z", + "iopub.status.busy": "2026-05-07T13:14:19.740900Z", + "iopub.status.idle": "2026-05-07T13:14:22.515333Z", + "shell.execute_reply": "2026-05-07T13:14:22.514743Z" + } + }, + "outputs": [], + "source": [ + "from pytensor.scan import until\n", + "\n", + "EOS_ID = stoi[\"\\n\"]\n", + "rng_sym = ptx.random.shared_rng(seed=2026, name=\"rng_sample\")\n", + "\n", + "\n", + "def gen_step(context_t, rng):\n", + " \"\"\"One generation step inside the scan loop.\n", + "\n", + " `context_t` arrives as a length-`block_size` int64 *tensor* — scan only\n", + " carries tensor variables across iterations — and we immediately wrap it\n", + " as an xtensor with a `time` dim. 
The forward pass is tensor-native, so\n", + " we dip back to tensor land for that single call, then go back to\n", + " xtensor for the sample + window slide.\n", + " \"\"\"\n", + " context = ptx.as_xtensor(context_t, dims=(\"time\",))\n", + " assert context.dims == (\"time\",)\n", + "\n", + " logits_step = ptx.as_xtensor(\n", + " forward(context_t[None, :], params, cfg)[0, -1], dims=(\"vocab\",)\n", + " )\n", + " probs_step = ptx.math.softmax(logits_step, dim=\"vocab\")\n", + " next_rng, next_tok = rng.categorical(probs_step, core_dims=\"vocab\")\n", + " assert next_tok.dims == ()\n", + "\n", + " new_context = ptx.concat(\n", + " [context.isel(time=slice(1, None)), next_tok.expand_dims(\"time\")],\n", + " dim=\"time\",\n", + " )\n", + " assert new_context.dims == (\"time\",)\n", + "\n", + " return (\n", + " [new_context.values, next_tok.values, next_rng],\n", + " until(pt.eq(next_tok.values, EOS_ID)),\n", + " )\n", + "\n", + "\n", + "init_ctx_x = ptx.xtensor(\n", + " \"init_ctx\", dims=(\"time\",), shape=(cfg.block_size,), dtype=\"int64\"\n", + ")\n", + "n_new_sym = pt.iscalar(\"n_new\")\n", + "\n", + "ctx_seq_t, tok_seq_t, rng_final = pytensor.scan(\n", + " fn=gen_step,\n", + " outputs_info=[init_ctx_x.values, None, rng_sym],\n", + " n_steps=n_new_sym,\n", + " return_updates=False,\n", + ")\n", + "tok_seq = ptx.as_xtensor(tok_seq_t, dims=(\"step\",))\n", + "assert tok_seq.dims == (\"step\",)\n", + "\n", + "generate_fn = pytensor.function(\n", + " [init_ctx_x, n_new_sym],\n", + " tok_seq.values,\n", + " updates={rng_sym: rng_final},\n", + ")\n", + "\n", + "\n", + "def generate(prompt: str, n_new_tokens: int = 400) -> str:\n", + " \"\"\"Run autoregressive generation. 
The Python wrapper only handles text ↔ ids;\n", + " every forward pass and every sample lives inside the compiled scan, which\n", + " stops early at the first newline or after ``n_new_tokens`` steps.\n", + " \"\"\"\n", + " ids = list(encode(prompt)) or [EOS_ID]\n", + " init_ctx = np.full(cfg.block_size, EOS_ID, dtype=\"int64\")\n", + " init_ctx[-len(ids):] = ids[-cfg.block_size:]\n", + " sampled = generate_fn(init_ctx, n_new_tokens)\n", + " return prompt + decode(sampled)\n", + "\n", + "\n", + "print(generate(\"ROMEO:\\n\", n_new_tokens=400))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scaling up: a deeper model on more data\n", + "\n", + "The sections above used a 50 K-character slice and a small 2-layer model. Now we double both and train a 4-layer transformer on a 100 K-character slice (and double `block_size`), still on the default Numba backend. We then compare the two models head-to-head on a held-out slice of Shakespeare that neither has seen — the training losses on their own are not directly comparable because the corpora, vocabularies, and context windows differ.\n", + "\n", + "## Re-tokenize a 100 K-character slice and rebuild the model\n", + "\n", + "We use a fresh `Config` and a freshly initialised parameter dict so this section is independent of everything above." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_full = full_text[:100_000]\n", + "chars_full = sorted(set(text_full))\n", + "vocab_size_full = len(chars_full)\n", + "stoi_full = {c: i for i, c in enumerate(chars_full)}\n", + "itos_full = {i: c for i, c in enumerate(chars_full)}\n", + "data_full = np.array([stoi_full[c] for c in text_full], dtype=\"int64\")\n", + "\n", + "cfg_full = Config(\n", + " vocab_size=vocab_size_full,\n", + " block_size=64,\n", + " n_embd=64,\n", + " n_head=4,\n", + " n_layer=4,\n", + ")\n", + "print(f\"corpus slice: {len(data_full):,} chars, vocab: {vocab_size_full}\")\n", + "print(f\"cfg_full : {cfg_full}\")\n", + "\n", + "init_arrays_full = init_params(cfg_full, np.random.default_rng(7))\n", + "print(f\"params : {sum(a.size for a in init_arrays_full.values()):,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compile a fresh `train_step` and train\n", + "\n", + "We wrap the now-familiar pattern — fresh `shared` params, symbolic loss, lower the xtensor sub-graph, `pytensor.grad`, Adam updates, compile — into a small helper so this section reads top-to-bottom." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def build_train_step(cfg, init_arrays, *, lr=3e-3):\n", + " p = {n: pytensor.shared(arr.copy(), name=n) for n, arr in init_arrays.items()}\n", + " X = pt.tensor(\"X\", shape=(None, cfg.block_size), dtype=\"int64\")\n", + " Y = pt.tensor(\"Y\", shape=(None, cfg.block_size), dtype=\"int64\")\n", + " logits = forward(X, p, cfg)\n", + " log_probs = pt.special.log_softmax(logits, axis=-1)\n", + " bidx = pt.arange(X.shape[0])[:, None]\n", + " tidx = pt.arange(X.shape[1])[None, :]\n", + " loss = (-log_probs[bidx, tidx, Y]).mean()\n", + " # Lower xtensor ops before grad (see the smaller-model section above).\n", + " loss = rewrite_graph(loss, include=(\"lower_xtensor\",))\n", + " grads = pytensor.grad(loss, list(p.values()))\n", + " upd = adam_updates(list(p.values()), grads, lr=lr)\n", + " t0 = time.time()\n", + " fn = pytensor.function([X, Y], loss, updates=upd)\n", + " return fn, p, time.time() - t0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_FULL = 32\n", + "N_STEPS_FULL = 1500\n", + "\n", + "train_step_full, params_full, ct_numba = build_train_step(cfg_full, init_arrays_full)\n", + "print(f\"Compilation took {ct_numba:.1f}s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng_full = np.random.default_rng(123)\n", + "losses_full: list[float] = []\n", + "t0 = time.time()\n", + "for step in range(N_STEPS_FULL):\n", + " starts = rng_full.integers(0, len(data_full) - cfg_full.block_size - 1, size=BATCH_FULL)\n", + " xb = np.stack([data_full[s : s + cfg_full.block_size] for s in starts])\n", + " yb = np.stack([data_full[s + 1 : s + 1 + cfg_full.block_size] for s in starts])\n", + " losses_full.append(float(train_step_full(xb, yb)))\n", + " if step % 100 == 0 or step == N_STEPS_FULL - 1:\n", + " print(f\"step 
{step:>4d} recent-mean loss = {np.mean(losses_full[-100:]):.3f}\")\n", + "print(f\"\\nTotal training time: {time.time() - t0:.1f}s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(9, 4))\n", + "ax.plot(losses_full, alpha=0.4, label=\"raw\")\n", + "rolling = np.convolve(losses_full, np.ones(50) / 50, mode=\"valid\")\n", + "ax.plot(np.arange(len(rolling)) + 49, rolling, lw=2, label=\"50-step rolling mean\")\n", + "ax.axhline(np.log(cfg_full.vocab_size), ls=\"--\", color=\"gray\", label=\"random baseline\")\n", + "ax.set_xlabel(\"training step\")\n", + "ax.set_ylabel(\"cross-entropy loss\")\n", + "ax.set_title(\"Training curve — deeper model on 100 K characters\")\n", + "ax.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## A fair side-by-side: held-out perplexity\n", + "\n", + "The two training losses (≈ 1.65 vs ≈ 1.57) aren't directly comparable. The models saw different corpus slices (50 K vs 100 K characters), have different vocabularies (59 vs 61), and — most importantly — a different `block_size` (32 vs 64). Doubling the context window alone tends to drive per-token cross-entropy down, even at equal model quality.\n", + "\n", + "To get an honest comparison we evaluate both models on the same 10 K-character slice that **neither** of them saw during training (`full_text[100_000:110_000]`). For the small model we simply skip the handful of characters that aren't in its vocabulary." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def build_eval_fn(cfg, params_dict):\n", + " \"\"\"Compile a forward + cross-entropy function that reads from the shared params.\"\"\"\n", + " X = pt.tensor(\"X_eval\", shape=(None, cfg.block_size), dtype=\"int64\")\n", + " Y = pt.tensor(\"Y_eval\", shape=(None, cfg.block_size), dtype=\"int64\")\n", + " log_probs = pt.special.log_softmax(forward(X, params_dict, cfg), axis=-1)\n", + " bidx = pt.arange(X.shape[0])[:, None]\n", + " tidx = pt.arange(X.shape[1])[None, :]\n", + " loss = (-log_probs[bidx, tidx, Y]).mean()\n", + " loss = rewrite_graph(loss, include=(\"lower_xtensor\",))\n", + " return pytensor.function([X, Y], loss)\n", + "\n", + "\n", + "def held_out_loss(eval_fn, cfg, stoi_local, holdout_text):\n", + " \"\"\"Mean per-token cross-entropy over non-overlapping windows of `holdout_text`.\n", + " Characters absent from `stoi_local` are skipped (vocabularies can differ slightly).\"\"\"\n", + " ids = np.array([stoi_local[c] for c in holdout_text if c in stoi_local], dtype=\"int64\")\n", + " starts = np.arange(0, len(ids) - cfg.block_size - 1, cfg.block_size)\n", + " total, n = 0.0, 0\n", + " for i in range(0, len(starts), 64):\n", + " batch = starts[i : i + 64]\n", + " xb = np.stack([ids[s : s + cfg.block_size] for s in batch])\n", + " yb = np.stack([ids[s + 1 : s + 1 + cfg.block_size] for s in batch])\n", + " total += float(eval_fn(xb, yb)) * len(batch)\n", + " n += len(batch)\n", + " return total / n\n", + "\n", + "\n", + "holdout_text = full_text[100_000:110_000]\n", + "L_small = held_out_loss(build_eval_fn(cfg, params), cfg, stoi, holdout_text)\n", + "L_full = held_out_loss(build_eval_fn(cfg_full, params_full), cfg_full, stoi_full, holdout_text)\n", + "\n", + "print(f\"held-out CE (small, 2-layer, block 32): {L_small:.3f} ppl = {np.exp(L_small):.2f}\")\n", + "print(f\"held-out CE (large, 4-layer, block 64): {L_full :.3f} ppl = {np.exp(L_full 
):.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## A sample from the deeper model\n", + "\n", + "On held-out text the deeper model is genuinely better — about 0.10 nats per character lower cross-entropy, or **~10% lower perplexity** (5.69 vs 6.31). At this scale and training budget that is *not* enough to make the samples *look* coherent: both models still produce Shakespearean gibberish at the word level. What you can see in the sample below is firmer speaker-name structure and slightly more believable rhythm. We reuse the same `pytensor.scan` generator pattern — the entire autoregressive loop still lives inside one compiled call." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "EOS_ID_FULL = stoi_full[\"\\n\"]\n", + "rng_full_sym = ptx.random.shared_rng(seed=2027, name=\"rng_full\")\n", + "\n", + "\n", + "def gen_step_full(context_t, rng):\n", + " context = ptx.as_xtensor(context_t, dims=(\"time\",))\n", + " logits_step = ptx.as_xtensor(\n", + " forward(context_t[None, :], params_full, cfg_full)[0, -1], dims=(\"vocab\",)\n", + " )\n", + " probs_step = ptx.math.softmax(logits_step, dim=\"vocab\")\n", + " next_rng, next_tok = rng.categorical(probs_step, core_dims=\"vocab\")\n", + " new_context = ptx.concat(\n", + " [context.isel(time=slice(1, None)), next_tok.expand_dims(\"time\")],\n", + " dim=\"time\",\n", + " )\n", + " return (\n", + " [new_context.values, next_tok.values, next_rng],\n", + " until(pt.eq(next_tok.values, EOS_ID_FULL)),\n", + " )\n", + "\n", + "\n", + "init_ctx_full = ptx.xtensor(\n", + " \"init_ctx_full\", dims=(\"time\",), shape=(cfg_full.block_size,), dtype=\"int64\"\n", + ")\n", + "n_new_full = pt.iscalar(\"n_new_full\")\n", + "\n", + "_, tok_seq_full_t, rng_final_full = pytensor.scan(\n", + " fn=gen_step_full,\n", + " outputs_info=[init_ctx_full.values, None, rng_full_sym],\n", + " n_steps=n_new_full,\n", + " return_updates=False,\n", + 
")\n", + "tok_seq_full = ptx.as_xtensor(tok_seq_full_t, dims=(\"step\",))\n", + "\n", + "generate_full_fn = pytensor.function(\n", + " [init_ctx_full, n_new_full],\n", + " tok_seq_full.values,\n", + " updates={rng_full_sym: rng_final_full},\n", + ")\n", + "\n", + "\n", + "def generate_full(prompt: str, n_new_tokens: int = 400) -> str:\n", + " ids = list(np.array([stoi_full[c] for c in prompt], dtype=\"int64\")) or [EOS_ID_FULL]\n", + " init_ctx = np.full(cfg_full.block_size, EOS_ID_FULL, dtype=\"int64\")\n", + " init_ctx[-len(ids):] = ids[-cfg_full.block_size:]\n", + " sampled = generate_full_fn(init_ctx, n_new_tokens)\n", + " return prompt + \"\".join(itos_full[int(i)] for i in sampled)\n", + "\n", + "\n", + "print(generate_full(\"ROMEO:\\n\", n_new_tokens=400))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Takeaways\n", + "\n", + "* A full **decoder-only transformer** — embeddings, multi-head causal attention, MLP, layernorm, tied softmax head, cross-entropy loss, and Adam — fits in roughly 100 lines of PyTensor. No custom `Op`, no backend-specific code.\n", + "* `pytensor.xtensor` named dimensions turned the trickiest part of the model (multi-head attention) into self-documenting code. The named-dim sub-graph is lowered to plain tensor ops at compile time — we just had to call `rewrite_graph(loss, include=(\"lower_xtensor\",))` once before `pytensor.grad` so the gradient pass walks the lowered graph.\n", + "* `pytensor.grad` gave us back-propagation through the entire stack as a *new symbolic graph*. We optimised that graph and compiled it into a single C/Numba `train_step`.\n", + "* `pytensor.shared` plus `updates=` made Adam a 15-line helper. 
Optimizer state lives next to the parameters in the graph, not in Python.\n", + "* `pytensor.scan` let us push autoregressive generation entirely inside the compiled graph — the Python wrapper only translates between text and ids; every forward pass and every categorical draw lives inside one C/Numba call.\n", + "\n", + "## Where to go from here\n", + "\n", + "* Swap `mode=\"NUMBA\"` (default) for `mode=\"JAX\"` to train on a GPU/TPU.\n", + "* Replace `tanh` with a real GELU and try a longer training run on the full 1.1 M-character corpus.\n", + "* Use `pytensor.dprint(train_step)` to inspect the optimised post-rewrite graph and see how many ops PyTensor's rewriter fused away.\n", + "\n", + "## Authors\n", + "\n", + "* Authored by the PyTensor developers in May 2026.\n", + "\n", + "## Watermark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-07T13:16:56.512277Z", + "iopub.status.busy": "2026-05-07T13:16:56.512148Z", + "iopub.status.idle": "2026-05-07T13:16:56.518894Z", + "shell.execute_reply": "2026-05-07T13:16:56.518316Z" + } + }, + "outputs": [], + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w -p pytensor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{include} ../page_footer.md\n", + ":::\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytensor-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.14.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index 22de7642f1..c2ab4126f5 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -1,7 +1,14 @@ import pytensor.xtensor.rewriting from pytensor.xtensor import linalg, math, random, signal from pytensor.xtensor.math import dot, where -from pytensor.xtensor.shape import broadcast, concat, full_like, 
ones_like, zeros_like +from pytensor.xtensor.shape import ( + broadcast, + concat, + full_like, + ones_like, + stack, + zeros_like, +) from pytensor.xtensor.type import ( as_xtensor, xtensor,