Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/support/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def callback_libdevice_path(arch):
return ""


@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version")
@tvm_ffi.register_global_func("tvm.support.nvcc.get_compute_version")
def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target.

Expand Down Expand Up @@ -1060,7 +1060,7 @@ def have_cudagraph():
return False


@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16")
@tvm_ffi.register_global_func("tvm.support.nvcc.supports_bf16")
def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not

Expand All @@ -1076,7 +1076,7 @@ def have_bf16(compute_version):
return False


@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8")
@tvm_ffi.register_global_func("tvm.support.nvcc.supports_fp8")
def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not

Expand All @@ -1094,7 +1094,7 @@ def have_fp8(compute_version):
return False


@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp4")
@tvm_ffi.register_global_func("tvm.support.nvcc.supports_fp4")
def have_fp4(compute_version):
"""Whether fp4 support is provided in the specified compute capability or not

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class CUDAModuleNode : public ffi::ModuleObj {
auto fcompile = ffi::Function::GetGlobal("tvm_callback_cuda_compile");
TVM_FFI_CHECK(fcompile.has_value(), RuntimeError)
<< "fmt=='cuda' requires tvm_callback_cuda_compile to be registered. "
<< "Import tvm.contrib.nvcc.";
<< "Import tvm.support.nvcc.";
return (*fcompile)(source).cast<ffi::Bytes>();
}

Expand Down
2 changes: 1 addition & 1 deletion src/target/rocm/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) {

auto flink = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_link");
TVM_FFI_ICHECK(flink.has_value())
<< "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm";
<< "Require tvm_callback_rocm_link to exist, do import tvm.support.rocm";

TVMFFIByteArray arr;
arr.data = &obj[0];
Expand Down
10 changes: 5 additions & 5 deletions src/tirx/transform/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ namespace transform {
bool CheckDataTypeSupport(const Target& target, const std::string& support_func_name) {
bool has_native_support = false;
if (target->kind->name == "cuda") {
if (auto get_cv = tvm::ffi::Function::GetGlobal("tvm.contrib.nvcc.get_compute_version")) {
if (auto get_cv = tvm::ffi::Function::GetGlobal("tvm.support.nvcc.get_compute_version")) {
std::string compute_version = (*get_cv)(target).cast<std::string>();
if (auto check_support = tvm::ffi::Function::GetGlobal(support_func_name)) {
has_native_support = (*check_support)(compute_version).cast<bool>();
Expand All @@ -750,7 +750,7 @@ Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) {
CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_bf16")) {
return f;
}
return BF16ComputeLegalizer().Legalize(f);
Expand All @@ -767,7 +767,7 @@ Pass BF16StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_bf16")) {
CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_bf16")) {
return f;
}
return BF16StorageLegalizer().Legalize(f);
Expand All @@ -784,7 +784,7 @@ Pass FP8ComputeLegalize(ffi::String promote_dtype) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_fp8")) {
CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_fp8")) {
return f;
}
return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
Expand All @@ -801,7 +801,7 @@ Pass FP8StorageLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
CheckDataTypeSupport(opt_target.value(), "tvm.contrib.nvcc.supports_fp8")) {
CheckDataTypeSupport(opt_target.value(), "tvm.support.nvcc.supports_fp8")) {
return f;
}
return FP8StorageLegalizer().Legalize(f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test contrib.cc with ccache"""
"""Test support.cc with ccache"""
Comment thread
cbalint13 marked this conversation as resolved.

import os
import shutil
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion web/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ make
```

This command will create the follow files:
- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link into.
- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.support.emcc` will link into.
- `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing purposes.
- `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by emscripten that can be fed into runtime.

Expand Down
Loading