From 98758c4fa2f5bab704d7331c9b4d97132ca6f145 Mon Sep 17 00:00:00 2001 From: Ben van Werkhoven Date: Fri, 5 Jun 2026 22:17:21 +0200 Subject: [PATCH 1/2] issue #356, replace pycuda with cuda-python as default --- kernel_tuner/core.py | 94 ++++++++++++++++---------------- test/test_core.py | 105 +++++++++++++----------------------- test/test_cuda_functions.py | 33 ++++++++++++ test/test_energy.py | 4 +- test/test_observers.py | 8 +-- test/test_runners.py | 14 ++--- test/test_util_functions.py | 2 +- 7 files changed, 129 insertions(+), 131 deletions(-) diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 69d952541..0c2e4e3d0 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -18,6 +18,7 @@ def _get_cupy(): from kernel_tuner.accuracy import Tunable from kernel_tuner.observers.observer import BenchmarkObserver, ContinuousObserver, OutputObserver, PrologueObserver from kernel_tuner.observers.tegra import TegraObserver +from kernel_tuner.backends.backend import GPUBackend try: import torch @@ -230,6 +231,24 @@ def instantiate_observer(observer, args): raise TypeError(f"Invalid observer: {observer!r} does not extend BenchmarkObserver") +def _select_default_cuda_backend(lang): + """ Select default CUDA backend, looks for which backends are installed. 
""" + # First try cuda-python (nvcuda) + from kernel_tuner.backends.nvcuda import CudaFunctions, driver + if driver: + return CudaFunctions + # Then try Cupy + if _get_cupy(): + from kernel_tuner.backends.cupy import CupyFunctions + return CupyFunctions + # Then try PyCUDA + from kernel_tuner.backends.pycuda import PyCudaFunctions, pycuda_available + if pycuda_available: + return PyCudaFunctions + # Ran out of options + raise RuntimeError("Error: CUDA selected/detected, but missing CUDA dependencies, please run 'pip install cuda-python', or install cupy or pycuda.") + + class DeviceInterface(object): """Class that offers a High-Level Device Interface to the rest of the Kernel Tuner.""" @@ -286,67 +305,46 @@ def __init__( observer_args = dict(device=device, platform=platform, compiler=compiler, lang=lang) observers = [instantiate_observer(ob, observer_args) for ob in observers] - if lang.upper() == "CUDA": + backend_options = dict(compiler_options=compiler_options, iterations=iterations) + + # first check for explicitly selected backends + if lang.upper() == "PYCUDA": from kernel_tuner.backends.pycuda import PyCudaFunctions - dev = PyCudaFunctions( - device, - compiler_options=compiler_options, - iterations=iterations, - observers=observers, - ) + backend = PyCudaFunctions elif lang.upper() == "CUPY": from kernel_tuner.backends.cupy import CupyFunctions - dev = CupyFunctions( - device, - compiler_options=compiler_options, - iterations=iterations, - observers=observers, - ) + backend = CupyFunctions elif lang.upper() == "NVCUDA": from kernel_tuner.backends.nvcuda import CudaFunctions - dev = CudaFunctions( - device, - compiler_options=compiler_options, - iterations=iterations, - observers=observers, - ) + backend = CudaFunctions + elif lang.upper() == "CUDA": + # Select default CUDA backend, based on availability + backend = _select_default_cuda_backend(lang) elif lang.upper() == "OPENCL": from kernel_tuner.backends.opencl import OpenCLFunctions - dev = 
OpenCLFunctions( - device, - platform, - compiler_options=compiler_options, - iterations=iterations, - observers=observers, - ) - elif lang.upper() in ["C", "FORTRAN"]: - from kernel_tuner.backends.compiler import CompilerFunctions - dev = CompilerFunctions( - compiler=compiler, - compiler_options=compiler_options, - iterations=iterations, - observers=observers, - ) + backend = OpenCLFunctions + backend_options["platform"] = platform elif lang.upper() == "HIP": from kernel_tuner.backends.hip import HipFunctions - dev = HipFunctions( - device, - compiler_options=compiler_options, - iterations=iterations, - observers=observers, - ) + backend = HipFunctions + elif lang.upper() in ["C", "FORTRAN"]: + from kernel_tuner.backends.compiler import CompilerFunctions + backend = CompilerFunctions + backend_options["compiler"] = compiler + backend_options["observers"] = observers elif lang.upper() == "HYPERTUNER": from kernel_tuner.backends.hypertuner import HypertunerFunctions - dev = HypertunerFunctions( - iterations=iterations, - compiler_options=compiler_options - ) + backend = HypertunerFunctions self.requires_warmup = False else: raise NotImplementedError( "Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet" ) - self.dev = dev + + if issubclass(backend, GPUBackend): + backend_options["device"] = device + backend_options["observers"] = observers + self.dev = backend(**backend_options) # look for NVMLObserver and TegraObserver in observers, if present, enable special tunable parameters through nvml/tegra self.use_nvml = False @@ -381,9 +379,9 @@ def __init__( self.iterations = iterations self.lang = lang - self.units = dev.units - self.name = dev.name - self.max_threads = dev.max_threads + self.units = self.dev.units + self.name = self.dev.name + self.max_threads = self.dev.max_threads if not quiet: print("Using: " + self.dev.name) diff --git a/test/test_core.py b/test/test_core.py index aefe70cf8..67c8b7c6a 100644 --- 
a/test/test_core.py +++ b/test/test_core.py @@ -11,7 +11,7 @@ from kernel_tuner import core from kernel_tuner.interface import Options -from .context import skip_if_no_pycuda +from .context import skip_if_no_cuda mock_config = {"return_value.compile.return_value": "compile", @@ -52,84 +52,51 @@ def env(): yield dev, instance -@skip_if_no_pycuda -def test_default_verify_function(env): - - # gpu_args = dev.ready_argument_list(args) - # func = dev.compile_kernel(instance, verbose) - - dev, instance = env - args = instance.arguments - verbose = True - - # 1st case, correct answer but not enough items in the list - answer = [args[1] + args[2]] - try: - core._default_verify_function(instance, answer, args, 1e-6, verbose) - print("Expected a TypeError to be raised") - assert False - except TypeError as expected_error: - print(str(expected_error)) - assert "The length of argument list and provided results do not match." == str(expected_error) - except Exception: - print("Expected a TypeError to be raised") - assert False - - # 2nd case, answer is of wrong type - answer = [np.ubyte([12]), None, None, None] - try: - core._default_verify_function(instance, answer, args, 1e-6, verbose) - # dev.check_kernel_output(func, gpu_args, instance, answer, 1e-6, None, verbose) - print("Expected a TypeError to be raised") - assert False - except TypeError as expected_error: - print(str(expected_error)) - assert "Element 0" in str(expected_error) - except Exception: - print("Expected a TypeError to be raised") - assert False - - instance.delete_temp_files() - assert True +@skip_if_no_cuda +def test_check_kernel_output(): + kernel_string = """ + __global__ void copy(float *out, float *in, int n) { + int i = threadIdx.x + blockDim.x * blockIdx.x; + if (i < n) + out[i] = in[i]; + } + """ + + kernel_name = "copy" + lang = "CUDA" -@patch('kernel_tuner.backends.pycuda.PyCudaFunctions') -def test_check_kernel_output(dev_func_interface): - dev_func_interface.configure_mock(**mock_config) + # 
Create the object under test (DeviceInterface) + kernel_source = core.KernelSource(kernel_name, kernel_string, lang, None) + dev = core.DeviceInterface(kernel_source) - dev = core.DeviceInterface(core.KernelSource("name", "", lang="CUDA")) - dfi = dev.dev + # Setup GPU args + n = np.int32(2000) + input_data = np.random.random(n).astype(np.float32) + output_data = np.zeros_like(input_data) + args = [output_data, input_data, n] + gpu_args = dev.dev.ready_argument_list(args) - answer = [np.zeros(4).astype(np.float32)] - instance = core.KernelInstance("name", None, "kernel_string", "temp_files", (256, 1, 1), (1, 1, 1), {}, answer) - wrong = [np.array([1, 2, 3, 4]).astype(np.float32)] - atol = 1e-6 - dev.check_kernel_output('func', answer, instance, answer, atol, None, True) + # Create kernel instance and compile GPU kernel + class FakeOptions(dict): + def __getattr__(self, name): + if not name in self: + return None + return self[name] - dfi.refresh_memory.assert_called() - dfi.run_kernel.assert_called_once_with('func', answer, (256, 1, 1), (1, 1, 1)) + kernel_options = FakeOptions(dict(kernel_name=kernel_name, arguments=args, problem_size=n)) + instance = dev.create_kernel_instance(kernel_source, kernel_options, {}, True) + func = dev.compile_kernel(instance, True) - print(dfi.mock_calls) + # Run check_kernel_output + # As the kernel only copies the data this should complete without throwing + # an exception + answer = [input_data, None, None] + dev.check_kernel_output(func, gpu_args, instance, answer, 1e-6, None, True) - assert dfi.refresh_memory.called == 1 - assert dfi.memcpy_dtoh.called == 1 - for name, args, _ in dfi.mock_calls: - if name == 'memcpy_dtoh': - assert all(args[0] == answer[0]) - assert all(args[1] == answer[0]) - # the following call to check_kernel_output is expected to fail because - # the answer is non-zero, while the memcpy_dtoh function on the Mocked object - # obviously does not result in the result_host array containing anything - # non-zero 
- try: - dev.check_kernel_output('func', wrong, instance, wrong, atol, None, True) - print("check_kernel_output failed to throw an exception") - assert False - except Exception: - assert True def test_default_verify_function_arrays(): diff --git a/test/test_cuda_functions.py b/test/test_cuda_functions.py index 1fe509c12..d142c5028 100644 --- a/test/test_cuda_functions.py +++ b/test/test_cuda_functions.py @@ -102,3 +102,36 @@ def test_compile_include(): def test_tune_kernel(env): result, _ = tune_kernel(*env, lang="nvcuda", verbose=True) assert len(result) > 0 + +@skip_if_no_cuda +def test_copy_constant_memory_args(): + kernel_string = """ + __constant__ float my_constant_data[100]; + __global__ void copy_data_kernel(float* output) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < 100) { + output[idx] = my_constant_data[idx]; + } + } + """ + + kernel_name = "copy_data_kernel" + kernel_sources = KernelSource(kernel_name, kernel_string, "NVCUDA") + kernel_instance = KernelInstance(kernel_name, kernel_sources, kernel_string, [], None, None, dict(), []) + dev = nvcuda.CudaFunctions(0) + kernel = dev.compile(kernel_instance) + + my_constant_data = np.full(100, 23).astype(np.float32) + cmem_args = {'my_constant_data': my_constant_data} + dev.copy_constant_memory_args(cmem_args) + + output = np.full(100, 0).astype(np.float32) + gpu_args = dev.ready_argument_list([output]) + + threads = (100, 1, 1) + grid = (1, 1, 1) + dev.run_kernel(kernel, gpu_args, threads, grid) + + dev.memcpy_dtoh(output, gpu_args[0]) + + assert (my_constant_data == output).all() diff --git a/test/test_energy.py b/test/test_energy.py index 187ac1cdc..82d885169 100644 --- a/test/test_energy.py +++ b/test/test_energy.py @@ -2,11 +2,11 @@ from kernel_tuner.energy import energy -from .context import skip_if_no_pycuda, skip_if_no_pynvml +from .context import skip_if_no_cuda, skip_if_no_pynvml cache_filename = os.path.dirname(os.path.realpath(__file__)) + 
"/synthetic_fp32_cache_NVIDIA_RTX_A4000.json" -@skip_if_no_pycuda +@skip_if_no_cuda @skip_if_no_pynvml def test_create_power_frequency_model(): diff --git a/test/test_observers.py b/test/test_observers.py index 69a94a9bb..1ad8f8454 100644 --- a/test/test_observers.py +++ b/test/test_observers.py @@ -21,7 +21,7 @@ from .test_runners import env # noqa: F401 -@skip_if_no_pycuda +@skip_if_no_cuda @skip_if_no_pynvml def test_nvml_observer(env): nvmlobserver = NVMLObserver(["nvml_energy", "temperature"]) @@ -33,7 +33,7 @@ def test_nvml_observer(env): assert "temperature" in result[0] assert result[0]["temperature"] > 0 -@skip_if_no_pycuda +@skip_if_no_cuda def test_custom_observer(env): env[-1]["block_size_x"] = [128] @@ -55,7 +55,7 @@ def __init__(self, args): def get_results(self): return {"observer_args": self.observer_args} - + result, _ = kernel_tuner.tune_kernel(*env_compiler, observers=[lambda args: MyObserver(args)], compiler_options=["-fopenmp"]) # Check if the observer has correctly received the lang option @@ -63,7 +63,7 @@ def get_results(self): @skip_if_no_pycuda def test_register_observer_pycuda(env): - result, _ = kernel_tuner.tune_kernel(*env, observers=[RegisterObserver()], lang='CUDA') + result, _ = kernel_tuner.tune_kernel(*env, observers=[RegisterObserver()], lang='PYCUDA') assert "num_regs" in result[0] assert result[0]["num_regs"] > 0 diff --git a/test/test_runners.py b/test/test_runners.py index 37a049d6d..a55a16df6 100644 --- a/test/test_runners.py +++ b/test/test_runners.py @@ -8,7 +8,7 @@ from kernel_tuner.interface import Options, _device_options, _kernel_options, _tuning_options from kernel_tuner.runners.sequential import SequentialRunner -from .context import skip_if_no_pycuda +from .context import skip_if_no_cuda cache_filename = os.path.dirname( os.path.realpath(__file__)) + "/test_cache_file.json" @@ -38,7 +38,7 @@ def env(): return ["vector_add", kernel_string, size, args, tune_params] -@skip_if_no_pycuda +@skip_if_no_cuda def 
test_sequential_runner_alt_block_size_names(env): kernel_string = """__global__ void vector_add(float *c, float *a, float *b, int n) { @@ -71,7 +71,7 @@ def test_sequential_runner_alt_block_size_names(env): assert len(result) == len(tune_params["block_dim_x"]) -@skip_if_no_pycuda +@skip_if_no_cuda def test_smem_args(env): result, _ = tune_kernel(*env, smem_args=dict(size="block_size_x*4"), @@ -86,7 +86,7 @@ def test_smem_args(env): assert len(result) == len(tune_params["block_size_x"]) -@skip_if_no_pycuda +@skip_if_no_cuda def test_build_cache(env): if not os.path.isfile(cache_filename): result, _ = tune_kernel(*env, @@ -157,7 +157,7 @@ def test_restrictions(env): assert len(result) == 6 -@skip_if_no_pycuda +@skip_if_no_cuda def test_time_keeping(env): kernel_name, kernel_string, size, args, tune_params = env answer = [args[1] + args[2], None, None, None] @@ -223,7 +223,7 @@ def test_random_sample(env): assert v['time'] > 0.0 and v['time'] < 1.0 -@skip_if_no_pycuda +@skip_if_no_cuda def test_interface_handles_compile_failures(env): kernel_name, kernel_string, size, args, tune_params = env @@ -260,7 +260,7 @@ def test_interface_handles_compile_failures(env): assert isinstance(failed_config["time"], util.CompilationFailedConfig) -@skip_if_no_pycuda +@skip_if_no_cuda def test_runner(env): kernel_name, kernel_source, problem_size, arguments, tune_params = env diff --git a/test/test_util_functions.py b/test/test_util_functions.py index ba8727765..c0fe6d270 100644 --- a/test/test_util_functions.py +++ b/test/test_util_functions.py @@ -276,7 +276,7 @@ def test_detect_language3(): @skip_if_no_pycuda def test_get_device_interface1(): - lang = "CUDA" + lang = "PYCUDA" dev = core.DeviceInterface(core.KernelSource("", "", lang=lang)) assert isinstance(dev, core.DeviceInterface) assert isinstance(dev.dev, pycuda.PyCudaFunctions) From 547998b9974ca0ef7ff4ac1530a5c7ee5687bbae Mon Sep 17 00:00:00 2001 From: Ben van Werkhoven Date: Fri, 5 Jun 2026 22:23:03 +0200 Subject: [PATCH 
2/2] update pyproject toml file --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e67fc3f7f..17986d671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,9 +94,9 @@ priority = "explicit" hip-python = { version = "^6.3.3.540.31", source = "testpypi", optional = true } # Note: when released, switch this package to pypi and remove tool.poetry.source and move this to [project.optional-dependencies] [project.optional-dependencies] -cuda = ["pycuda>=2025.1", "nvidia-ml-py>=12.535.108", "pynvml>=11.4.1"] # Attention: if pycuda is changed here, also change `session.install("pycuda")` in the Noxfile +cuda = ["cuda-python>=12.6.0", "nvidia-ml-py>=12.535.108", "pynvml>=11.4.1"] # Attention: if cuda-python is changed here, also update the corresponding `session.install(...)` call in the Noxfile opencl = ["pyopencl"] # Attention: if pyopencl is changed here, also change `session.install("pyopencl")` in the Noxfile -cuda_opencl = ["pycuda>=2024.1", "pyopencl"] # Attention: if pycuda is changed here, also change `session.install("pycuda")` in the Noxfile +cuda_opencl = ["cuda-python>=12.6.0", "pyopencl"] # Attention: if cuda-python is changed here, also update the corresponding `session.install(...)` call in the Noxfile hip = ["hip-python"] tutorial = ["jupyter>=1.0.0", "matplotlib>=3.5.0", "nvidia-ml-py>=12.535.108"]