From 0aec0bbded10bbbfdf9cec3c209251965b17a0e8 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Mon, 1 Jun 2026 12:09:48 +0300 Subject: [PATCH] [REFACTOR][PYTHON] Revisit lifted support modules from tvm.contrib --- python/tvm/support/nvcc.py | 8 ++++---- src/runtime/cuda/cuda_module.cc | 2 +- src/target/rocm/llvm/codegen_amdgpu.cc | 2 +- src/tirx/transform/unsupported_dtype_legalize.cc | 10 +++++----- tests/python/{contrib => support}/test_ccache.py | 2 +- tests/python/{contrib => support}/test_popen_pool.py | 0 tests/python/{contrib => support}/test_util.py | 0 web/README.md | 2 +- 8 files changed, 13 insertions(+), 13 deletions(-) rename tests/python/{contrib => support}/test_ccache.py (98%) rename tests/python/{contrib => support}/test_popen_pool.py (100%) rename tests/python/{contrib => support}/test_util.py (100%) diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py index b985e74778f7..94dbd59ff66b 100644 --- a/python/tvm/support/nvcc.py +++ b/python/tvm/support/nvcc.py @@ -906,7 +906,7 @@ def callback_libdevice_path(arch): return "" -@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version") +@tvm_ffi.register_global_func("tvm.support.nvcc.get_compute_version") def get_target_compute_version(target=None): """Utility function to get compute capability of compilation target. @@ -1060,7 +1060,7 @@ def have_cudagraph(): return False -@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16") +@tvm_ffi.register_global_func("tvm.support.nvcc.supports_bf16") def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not @@ -1076,7 +1076,7 @@ def have_bf16(compute_version): return False -@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8") +@tvm_ffi.register_global_func("tvm.support.nvcc.supports_fp8") def have_fp8(compute_version): """Whether fp8 support is provided in the specified compute capability or not @@ -1094,7 +1094,7 @@ def have_fp8(compute_version): return False -@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp4") +@tvm_ffi.register_global_func("tvm.support.nvcc.supports_fp4") def have_fp4(compute_version): """Whether fp4 support is provided in the specified compute capability or not diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 9492f943a869..3f182afb8245 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -162,7 +162,7 @@ class CUDAModuleNode : public ffi::ModuleObj { auto fcompile = ffi::Function::GetGlobal("tvm_callback_cuda_compile"); TVM_FFI_CHECK(fcompile.has_value(), RuntimeError) << "fmt=='cuda' requires tvm_callback_cuda_compile to be registered. " - << "Import tvm.contrib.nvcc."; + << "Import tvm.support.nvcc."; return (*fcompile)(source).cast(); } diff --git a/src/target/rocm/llvm/codegen_amdgpu.cc b/src/target/rocm/llvm/codegen_amdgpu.cc index 2da399231e31..12a8aed79bd8 100644 --- a/src/target/rocm/llvm/codegen_amdgpu.cc +++ b/src/target/rocm/llvm/codegen_amdgpu.cc @@ -306,7 +306,7 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { auto flink = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_link"); TVM_FFI_ICHECK(flink.has_value()) - << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm"; + << "Require tvm_callback_rocm_link to exist, do import tvm.support.rocm"; TVMFFIByteArray arr; arr.data = &obj[0]; diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc b/src/tirx/transform/unsupported_dtype_legalize.cc index 558a3ca43788..0bd703358cd0 100644 --- a/src/tirx/transform/unsupported_dtype_legalize.cc +++ b/src/tirx/transform/unsupported_dtype_legalize.cc @@ -736,7 +736,7 @@ namespace transform { bool CheckDataTypeSupport(const Target& target, const std::string& support_func_name) { bool has_native_support = false; if (target->kind->name == "cuda") { - if (auto get_cv = tvm::ffi::Function::GetGlobal("tvm.contrib.nvcc.get_compute_version")) { + if (auto get_cv = tvm::ffi::Function::GetGlobal("tvm.support.nvcc.get_compute_version")) { std::string compute_version = (*get_cv)(target).cast(); if (auto check_support = tvm::ffi::Function::GetGlobal(support_func_name)) { has_native_support = (*check_support)(compute_version).cast(); @@ -750,7 +750,7 @@ Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto opt_target = f->GetAttr(tvm::attr::kTarget); if (opt_target.defined() && - CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) { + CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_bf16")) { return f; } return BF16ComputeLegalizer().Legalize(f); @@ -767,7 +767,7 @@ Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto opt_target = f->GetAttr(tvm::attr::kTarget); if (opt_target.defined() && - CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) { + CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_bf16")) { return f; } return BF16StorageLegalizer().Legalize(f); @@ -784,7 +784,7 @@ Pass FP8ComputeLegalize(ffi::String promote_dtype) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto opt_target = f->GetAttr(tvm::attr::kTarget); if (opt_target.defined() && - CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_fp8")) { + CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_fp8")) { return f; } return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f); @@ -801,7 +801,7 @@ Pass FP8StorageLegalize() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto opt_target = f->GetAttr(tvm::attr::kTarget); if (opt_target.defined() && - CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_fp8")) { + CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_fp8")) { return f; } return FP8StorageLegalizer().Legalize(f); diff --git a/tests/python/contrib/test_ccache.py b/tests/python/support/test_ccache.py similarity index 98% rename from tests/python/contrib/test_ccache.py rename to tests/python/support/test_ccache.py index 013b6896cbb0..f1f182562c82 100644 --- a/tests/python/contrib/test_ccache.py +++ b/tests/python/support/test_ccache.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Test contrib.cc with ccache""" +"""Test support.cc with ccache""" import os import shutil diff --git a/tests/python/contrib/test_popen_pool.py b/tests/python/support/test_popen_pool.py similarity index 100% rename from tests/python/contrib/test_popen_pool.py rename to tests/python/support/test_popen_pool.py diff --git a/tests/python/contrib/test_util.py b/tests/python/support/test_util.py similarity index 100% rename from tests/python/contrib/test_util.py rename to tests/python/support/test_util.py diff --git a/web/README.md b/web/README.md index 9b3cda1fb76c..9488389e9b17 100644 --- a/web/README.md +++ b/web/README.md @@ -43,7 +43,7 @@ make ``` This command will create the follow files: -- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link into. +- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.support.emcc` will link into. - `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing purposes. - `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by emscripten that can be fed into runtime.