From 2a3011e3f95724ae8adba282655307e8a348f2ce Mon Sep 17 00:00:00 2001
From: Willy Chan
Date: Fri, 16 Jan 2026 16:25:12 -0800
Subject: [PATCH 1/5] initial cutile changes

---
 requirements.txt                               |  1 +
 scripts/generate_and_eval_single_sample.py     |  2 +-
 .../generate_and_eval_single_sample_modal.py   |  2 +-
 scripts/generate_samples.py                    |  2 +-
 src/kernelbench/eval.py                        |  8 +--
 src/kernelbench/prompt_constructor_toml.py     |  6 +-
 .../prompts/model_ex_add_cutile.py             | 61 +++++++++++++++++++
 src/kernelbench/prompts/prompts.toml           |  6 ++
 8 files changed, 78 insertions(+), 10 deletions(-)
 create mode 100644 src/kernelbench/prompts/model_ex_add_cutile.py

diff --git a/requirements.txt b/requirements.txt
index 07603a86..69f31e66 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@
 # we use latest PyTorch stable release
 torch==2.9.*
 triton==3.5.*
+cuda-tile
 # we shall upgrade torch for blackwell when it is stable

 transformers>=4.57.3
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index fce1b16f..9cf280b5 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -174,7 +174,7 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens", "cutile"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index 7308d228..282bb424 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -207,7 +207,7 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens", "cutile"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py
index 2c01ee8d..d80920fc 100644
--- a/scripts/generate_samples.py
+++ b/scripts/generate_samples.py
@@ -234,7 +234,7 @@ def main(config: GenerationConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens"}
+    supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens", "cutile"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
diff --git a/src/kernelbench/eval.py b/src/kernelbench/eval.py
index 47f59793..7da49a49 100644
--- a/src/kernelbench/eval.py
+++ b/src/kernelbench/eval.py
@@ -404,7 +404,7 @@
     device: Union[torch.device, int] = (
         torch.cuda.current_device() if torch.cuda.is_available() else None
     ),  # have to run on GPU
-    backend: str = "cuda",  # can be 'cuda', 'triton', 'tilelang', or 'cute'
+    backend: str = "cuda",  # can be 'cuda', 'triton', 'tilelang', 'cute', or 'cutile'
     precision: torch.dtype = torch.float32,

     # Guard against potential reward hacking [optional but ongoing enhancement]
@@ -420,7 +420,7 @@
     num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass
     num_perf_trials: run the evalutation many times to take the average
     device: GPU (cuda) device to run the evalutation on
-    backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute'
+    backend: str, one of 'cuda', 'triton', 'tilelang', 'cute', or 'cutile'
     precision: torch.dtype for computation (note: tilelang only supports fp16)
     timing_method: str, method to time kernel, see timing.py for more details
@@ -444,7 +444,7 @@
     # Backends that use tempfile approach and need CUDA_VISIBLE_DEVICES
     # TileLang, Triton, and CuTe all use tempfile for proper module loading
-    uses_tempfile = backend.lower() in ["triton", "tilelang", "cute"]
+    uses_tempfile = backend.lower() in ["triton", "tilelang", "cute", "cutile"]

     metadata = {}  # for storing result metadata
     metadata["hardware"] = torch.cuda.get_device_name(device=device)
@@ -496,7 +496,7 @@
     # add hash for later to distinguish between multi-turn kernels

     backend_lower = backend.lower()
-    if backend_lower in ["triton", "tilelang", "cute"]:
+    if backend_lower in ["triton", "tilelang", "cute", "cutile"]:
         # Use tempfile approach for triton, tilelang, and cute
         # These DSLs require proper module import for JIT decorators to work
         ModelNew, tempfile = load_custom_model_with_tempfile(
diff --git a/src/kernelbench/prompt_constructor_toml.py b/src/kernelbench/prompt_constructor_toml.py
index 4349a74d..0a24d0d8 100644
--- a/src/kernelbench/prompt_constructor_toml.py
+++ b/src/kernelbench/prompt_constructor_toml.py
@@ -141,7 +141,7 @@ def render_prompt_by_option(

     Args:
         prompts_toml: Path to the prompts.toml file
-        backend: The kernel backend (triton, cuda, cute, tilelang)
+        backend: The kernel backend (triton, cuda, cute, tilelang, cutile)
         option: The prompt option (zero_shot, one_shot, few_shot)
             - zero_shot: No examples (model learns from description only)
             - one_shot: Single example
@@ -196,7 +196,7 @@
     # Add backend-specific content to context
     context = {
         **context,
-        "backend": backend.upper() if backend in ["cuda", "cute"] else backend.capitalize(),
+        "backend": backend.upper() if backend in ["cuda", "cute"] else ("cuTile" if backend == "cutile" else backend.capitalize()),
         "backend_display": backend_display,
         "problem_statement": problem_statement,
         "instruction": instruction,
@@ -332,7 +332,7 @@ def get_prompt_for_backend(

     Args:
         ref_arch_src: The reference architecture source code
-        backend: The kernel backend (triton, cuda, cute, tilelang)
+        backend: The kernel backend (triton, cuda, cute, tilelang, cutile)
         option: The prompt option (zero_shot, one_shot, few_shot)
         precision: Optional precision (fp32, fp16, bf16) - defaults to fp32 if not provided
         include_hardware: When True, append hardware guidance blocks (requires gpu_name)
diff --git a/src/kernelbench/prompts/model_ex_add_cutile.py b/src/kernelbench/prompts/model_ex_add_cutile.py
new file mode 100644
index 00000000..ff85aba6
--- /dev/null
+++ b/src/kernelbench/prompts/model_ex_add_cutile.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+import cuda.tile as ct
+
+TILE_SIZE = 256
+
+
+@ct.kernel
+def add_kernel(a, b, result):
+    """
+    cuTile kernel for adding two dense tensors element-wise.
+    Each block processes TILE_SIZE elements.
+    """
+    block_id = ct.bid(0)
+    a_tile = ct.load(a, index=(block_id,), shape=(TILE_SIZE,))
+    b_tile = ct.load(b, index=(block_id,), shape=(TILE_SIZE,))
+    result_tile = a_tile + b_tile
+    ct.store(result, index=(block_id,), tile=result_tile)
+
+
+class ModelNew(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass using cuTile kernel for elementwise addition.
+
+        Args:
+            a: First input tensor on CUDA
+            b: Second input tensor on CUDA (same shape as a)
+
+        Returns:
+            Result tensor of a + b
+        """
+        assert a.is_cuda and b.is_cuda, "Tensors must be on CUDA."
+        a = a.contiguous()
+        b = b.contiguous()
+
+        # Store original shape for reshaping back
+        original_shape = a.shape
+
+        # Flatten tensors for 1D processing
+        a_flat = a.view(-1)
+        b_flat = b.view(-1)
+
+        # Allocate output tensor
+        result = torch.empty_like(a_flat)
+
+        # Calculate grid dimensions
+        n_elements = a_flat.shape[0]
+        grid = (ct.cdiv(n_elements, TILE_SIZE), 1, 1)
+
+        # Get current CUDA stream
+        stream = torch.cuda.current_stream()._as_parameter_
+
+        # Launch the kernel
+        ct.launch(stream, grid, add_kernel, (a_flat, b_flat, result))
+
+        # Reshape back to original shape
+        return result.view(original_shape)
\ No newline at end of file
diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml
index 2768aa11..54edb2d4 100644
--- a/src/kernelbench/prompts/prompts.toml
+++ b/src/kernelbench/prompts/prompts.toml
@@ -54,6 +54,12 @@ backend_display = "ThunderKittens kernels"
 one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py"
 # No few_shot_examples - will use one-shot when few_shot option is selected

+[backends.cutile]
+backend_display = "cuTile kernels"
+one_shot_new_arch = "src/prompts/model_new_ex_add_cutile.py"
+# No few_shot_examples - will use one-shot when few_shot option is selected
+# Note: cuTile requires CUDA Toolkit 13.1+
+
 # -------------------------------------------------------------------------
 # Precision: Precision-specific configuration
 # -------------------------------------------------------------------------

From a268f1e0ad7739e6359d91f67b7185c501f22881 Mon Sep 17 00:00:00 2001
From: Willy Chan
Date: Fri, 16 Jan 2026 16:33:41 -0800
Subject: [PATCH 2/5] typo in add example path

---
 src/kernelbench/prompts/prompts.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml
index 54edb2d4..2ad4f5b5 100644
--- a/src/kernelbench/prompts/prompts.toml
+++ b/src/kernelbench/prompts/prompts.toml
@@ -56,7 +56,7 @@ one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py"

 [backends.cutile]
 backend_display = "cuTile kernels"
-one_shot_new_arch = "src/prompts/model_new_ex_add_cutile.py"
+one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_cutile.py"
 # No few_shot_examples - will use one-shot when few_shot option is selected
 # Note: cuTile requires CUDA Toolkit 13.1+

From 73a18c4f4818f307246ee1be70b56a5b6c952388 Mon Sep 17 00:00:00 2001
From: Willy Chan
Date: Fri, 16 Jan 2026 16:35:41 -0800
Subject: [PATCH 3/5] fixed it for sure

---
 src/kernelbench/prompts/prompts.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml
index 2ad4f5b5..5f289380 100644
--- a/src/kernelbench/prompts/prompts.toml
+++ b/src/kernelbench/prompts/prompts.toml
@@ -56,7 +56,7 @@ one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py"

 [backends.cutile]
 backend_display = "cuTile kernels"
-one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_cutile.py"
+one_shot_new_arch = "src/kernelbench/prompts/model_ex_add_cutile.py"
 # No few_shot_examples - will use one-shot when few_shot option is selected
 # Note: cuTile requires CUDA Toolkit 13.1+

From ca5d9707e87ca3789d3ccc3c49b9330973a7caee Mon Sep 17 00:00:00 2001
From: Willy Chan
Date: Fri, 16 Jan 2026 16:46:28 -0800
Subject: [PATCH 4/5] pip installed cuda-tile

---
 pyproject.toml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index bed37150..4defd7d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ dependencies = [
     "torch==2.9.0",
     "transformers",
-    "datasets",
+    "datasets>=2.20.0",
     "modal",

     # helper
@@ -41,6 +41,7 @@ gpu = [
     "triton",
     "nvidia-cutlass-dsl",
     "tilelang",
+    "cuda-tile",
     "cupy-cuda12x",
     "nsight-python",
 ]

From ae91877964ed8bcaf70495d65b315ab86cf84308 Mon Sep 17 00:00:00 2001
From: Willy Chan
Date: Fri, 16 Jan 2026 17:07:48 -0800
Subject: [PATCH 5/5] removed tk stuff and added cutile static checker

---
 scripts/eval_from_generations.py               |  2 +-
 .../generate_and_eval_single_sample_modal.py   |  2 +-
 scripts/run_and_check.py                       |  2 +-
 src/kernelbench/kernel_static_checker.py       | 19 +++++++++++++++++++
 4 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py
index 247410f3..66c9083e 100644
--- a/scripts/eval_from_generations.py
+++ b/scripts/eval_from_generations.py
@@ -71,7 +71,7 @@
     )
     .uv_sync(uv_project_dir=REPO_TOP_DIR)
-    .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
+    .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
     .env({
         "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
         "PYTHONPATH": "/root/src:/root"
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index 282bb424..a3e38e62 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -105,7 +105,7 @@ def __repr__(self):
     )
     .uv_sync(uv_project_dir=REPO_TOP_DIR, extras=["gpu"])
-    .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
+    .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
     .env({
         "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
         "PYTHONPATH": "/root:/root/src"
diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py
index d253dd45..87f9f256 100644
--- a/scripts/run_and_check.py
+++ b/scripts/run_and_check.py
@@ -39,7 +39,7 @@ modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
     .apt_install("git", "gcc-10", "g++-10", "clang")
     .uv_sync(uv_project_dir=REPO_TOP_PATH)
-    .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
+    .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
     .env({
         "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
         "PYTHONPATH": "/root:/root/src:/root/scripts"
diff --git a/src/kernelbench/kernel_static_checker.py b/src/kernelbench/kernel_static_checker.py
index c8832a1a..1a2d87bc 100644
--- a/src/kernelbench/kernel_static_checker.py
+++ b/src/kernelbench/kernel_static_checker.py
@@ -269,6 +269,23 @@ def check_tilelang_impl(code: str) -> Tuple[bool, str]:
     return (False, "")


+# <========= CUTILE PYTHON CHECKS =========>
+# CuTile Python uses @ct.kernel decorator
+CUTILE_KERNEL_PATTERN = r"@ct\.kernel"
+
+def check_cutile_impl(code: str) -> Tuple[bool, str]:
+    """
+    Check for valid CuTile Python kernel implementation.
+
+    Requirements:
+    - Must have @ct.kernel decorator
+    """
+    code = _strip_comments(code)
+    if not re.search(CUTILE_KERNEL_PATTERN, code):
+        return (True, "Missing @ct.kernel decorator")
+    return (False, "")
+
+
 # =============================================================================
 # TIMING MANIPULATION CHECKS - Reward Hacking Patterns
 # From adversarial hack PR and DeepReinforce blog
@@ -559,6 +576,7 @@ def check_precision_downgrade(code: str, precision: str = "fp32") -> Tuple[bool,
     "tk_impl": check_tk_impl,
     "cute_impl": check_cute_impl,
     "tilelang_impl": check_tilelang_impl,
+    "cutile_impl": check_cutile_impl,
 }

 # Checks that require additional parameters beyond just code
@@ -583,6 +601,7 @@ def check_precision_downgrade(code: str, precision: str = "fp32") -> Tuple[bool,
     "cute": "cute_impl",
     "cutlass": "cute_impl",  # alias
     "tilelang": "tilelang_impl",
+    "cutile": "cutile_impl",
 }

 # These are optional checks (by user's decision) - flagged as warnings
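
The new static check in PATCH 5 is easy to exercise in isolation. Below is a minimal, self-contained sketch of its behavior; the `_strip_comments` stand-in is an assumption (the repo's helper may also handle strings and block comments), while the regex pattern and the (flagged, reason) return convention mirror the diff.

import re
from typing import Tuple

# Pattern copied from the patch; the stand-in below only strips '#' line
# comments, which is enough to show why a decorator mentioned in a comment
# does not satisfy the check.
CUTILE_KERNEL_PATTERN = r"@ct\.kernel"

def _strip_comments(code: str) -> str:
    # Simplified stand-in for the repo's helper (assumption)
    return "\n".join(line.split("#", 1)[0] for line in code.splitlines())

def check_cutile_impl(code: str) -> Tuple[bool, str]:
    """Return (True, reason) when the required @ct.kernel decorator is missing."""
    code = _strip_comments(code)
    if not re.search(CUTILE_KERNEL_PATTERN, code):
        return (True, "Missing @ct.kernel decorator")
    return (False, "")

good = "@ct.kernel\ndef add_kernel(a, b, result): ..."
bad = "# @ct.kernel appears only in a comment\ndef f(): ..."
print(check_cutile_impl(good))  # (False, '') -> passes
print(check_cutile_impl(bad))   # (True, 'Missing @ct.kernel decorator')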
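The one-shot example added in PATCH 1 can be smoke-tested like any other KernelBench ModelNew. A hedged usage sketch, assuming the `cuda-tile` package, a CUDA device, and `src/` on the Python path; the 1024x1024 shape is deliberate, since the example kernel does no bounds masking and the flattened length should be a multiple of TILE_SIZE (256).

import torch
from kernelbench.prompts.model_ex_add_cutile import ModelNew

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# 1024 * 1024 = 1048576 elements, an exact multiple of TILE_SIZE = 256
out = ModelNew()(a, b)
torch.testing.assert_close(out, a + b)
print("cuTile add matches the torch reference")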
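Finally, the `[backends.cutile]` table added in PATCH 1 (path repaired in PATCHES 2-3) is reached through the existing prompt constructor. A sketch, assuming `get_prompt_for_backend` keeps the signature documented in the diff; the problem file path below is hypothetical.

from kernelbench.prompt_constructor_toml import get_prompt_for_backend

# Hypothetical reference problem from the KernelBench dataset
ref_arch_src = open("KernelBench/level1/1_Square_matrix_multiplication_.py").read()

# "cutile" now resolves to the cuTile one-shot example registered in prompts.toml
prompt = get_prompt_for_backend(ref_arch_src, backend="cutile", option="one_shot")
print(prompt[:500])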