From f0c745319e002c7fda78305273cae14a0d950beb Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 13 Feb 2026 22:19:34 -0800 Subject: [PATCH 01/28] Making progress. --- openequivariance/CMakeLists.txt | 62 ++ .../extension/libtorch_tp_jit.cpp | 626 ++--------------- .../extension/libtorch_tp_jit_stable.cpp | 104 +++ .../openequivariance/extension/torch_core.hpp | 650 ++++++++++++++++++ 4 files changed, 880 insertions(+), 562 deletions(-) create mode 100644 openequivariance/CMakeLists.txt create mode 100644 openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp create mode 100644 openequivariance/openequivariance/extension/torch_core.hpp diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt new file mode 100644 index 0000000..472bc2d --- /dev/null +++ b/openequivariance/CMakeLists.txt @@ -0,0 +1,62 @@ +cmake_minimum_required(VERSION 3.15...3.30) +project(openequivariance_stable_ext LANGUAGES CXX) + +set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.10.0%2Bcpu.zip") +set(LIBTORCH_ZIP "${CMAKE_BINARY_DIR}/libtorch-shared-with-deps-2.10.0+cpu.zip") +set(LIBTORCH_EXTRACT_DIR "${CMAKE_BINARY_DIR}/libtorch_extract") +set(LIBTORCH_INCLUDE_DIR "${LIBTORCH_EXTRACT_DIR}/libtorch/include") + +file(DOWNLOAD ${LIBTORCH_URL} ${LIBTORCH_ZIP} SHOW_PROGRESS) +file(MAKE_DIRECTORY ${LIBTORCH_EXTRACT_DIR}) +file(ARCHIVE_EXTRACT + INPUT ${LIBTORCH_ZIP} + DESTINATION ${LIBTORCH_EXTRACT_DIR} + PATTERNS "libtorch/include/*" +) + +add_custom_target(libtorch_headers ALL DEPENDS ${LIBTORCH_INCLUDE_DIR}) + +set(EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/extension") +set(EXT_UTIL_DIR "${EXT_DIR}/util") +set(EXT_JSON_DIR "${EXT_DIR}/json11") + +set(OEQ_SOURCES + ${EXT_DIR}/libtorch_tp_jit_stable.cpp + ${EXT_JSON_DIR}/json11.cpp +) + +function(add_stable_extension target_name backend_define) + add_library(${target_name} MODULE ${OEQ_SOURCES}) + set_target_properties(${target_name} PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + ) + target_compile_definitions(${target_name} PRIVATE ${backend_define}=1) + target_include_directories(${target_name} PRIVATE + ${LIBTORCH_INCLUDE_DIR} + ${EXT_DIR} + ${EXT_UTIL_DIR} + ${EXT_JSON_DIR} + ) + add_dependencies(${target_name} libtorch_headers) +endfunction() + +find_package(CUDAToolkit QUIET) +find_package(hip QUIET) + +if(CUDAToolkit_FOUND) + message(STATUS "Building stable extension with CUDA backend.") + add_stable_extension(libtorch_tp_jit_stable_cuda CUDA_BACKEND) + target_link_libraries(libtorch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver) +endif() + +if(hip_FOUND) + message(STATUS "Building stable extension with HIP backend.") + add_stable_extension(libtorch_tp_jit_stable_hip HIP_BACKEND) + target_link_libraries(libtorch_tp_jit_stable_hip PRIVATE hiprtc) +endif() + +if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND) + message(FATAL_ERROR "Neither CUDAToolkit nor HIP was found. Cannot build the stable extension.") +endif() diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 6216909..bcd2c7b 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -1,603 +1,105 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "json11/json11.hpp" #include #ifdef CUDA_BACKEND #include - #include "backend_cuda.hpp" - #include "group_mm_cuda.hpp" - using JITKernel = CUJITKernel; - using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; - - inline Stream get_current_stream() { - return c10::cuda::getCurrentCUDAStream(); - } #endif #ifdef HIP_BACKEND #include - #include "backend_hip.hpp" - #include "group_mm_hip.hpp" - using JITKernel = HIPJITKernel; - using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; - - inline Stream get_current_stream() { - return c10::hip::getCurrentHIPStream(); - } #endif -#include "tensorproducts.hpp" -#include "convolution.hpp" - -using namespace std; -using json = json11::Json; - #include -#include -#include #include #include +#include +#include -torch::Dtype enum_to_torch_dtype(int64_t i){ - switch(i) { - case 1: return torch::kFloat; - case 2: return torch::kDouble; - case 3: return torch::kInt; - case 4: return torch::kLong; - case 5: return torch::kUInt8; - } - throw logic_error("Unsupported tensor datatype!"); -} - -inline void check_tensor(const torch::Tensor &tensor, - std::initializer_list expected_shape, - torch::Dtype expected_dtype, - std::string tensor_name) { - TORCH_CHECK(tensor.sizes() == expected_shape, - "Shape mismatch for tensor '", tensor_name, - "'. Expected: ", torch::IntArrayRef(expected_shape), - ". Got: ", tensor.sizes()); - TORCH_CHECK(tensor.device().is_cuda(), "Tensor '", tensor_name, "' is not on the GPU."); - TORCH_CHECK(tensor.dtype() == expected_dtype, "Dtype mismatch for tensor '", tensor_name, "'. Expected: ", expected_dtype, ". Got: ", tensor.dtype()); -} - -inline void* data_ptr(const torch::Tensor &tensor) { - if(tensor.dtype() == torch::kFloat) - return reinterpret_cast(tensor.data_ptr()); - else if(tensor.dtype() == torch::kDouble) - return reinterpret_cast(tensor.data_ptr()); - else if(tensor.dtype() == torch::kLong) - return reinterpret_cast(tensor.data_ptr()); - else if(tensor.dtype() == torch::kByte) - return reinterpret_cast(tensor.data_ptr()); // Replaces kUInt8 - else if(tensor.dtype() == torch::kInt) - return reinterpret_cast(tensor.data_ptr()); - else - throw logic_error("Unsupported tensor datatype!"); -} - -std::unordered_map parse_json_config(const json &j_obj) { - std::unordered_map result; - for (const auto &kv : j_obj.object_items()) { - result[kv.first] = static_cast(kv.second.number_value()); - } - return result; -} - -struct KernelProp { - int64_t L1_dim, L2_dim, L3_dim, weight_numel; - bool shared_weights; - torch::Dtype irrep_dtype; - torch::Dtype weight_dtype; - - int64_t workspace_size; // Convolution only - bool deterministic; - torch::Dtype idx_dtype; - torch::Dtype workspace_dtype; - - KernelProp() : - L1_dim(0), L2_dim(0), L3_dim(0), weight_numel(0), - shared_weights(false), - irrep_dtype(torch::kFloat), weight_dtype(torch::kFloat), - workspace_size(0), deterministic(false), - idx_dtype(torch::kInt), workspace_dtype(torch::kByte) {} - - KernelProp( - std::unordered_map &kernel_dims, bool is_convolution): - L1_dim(kernel_dims.at("L1_dim")), - L2_dim(kernel_dims.at("L2_dim")), - L3_dim(kernel_dims.at("L3_dim")), - weight_numel(kernel_dims.at("weight_numel")), - shared_weights(kernel_dims.at("shared_weights")), - irrep_dtype(enum_to_torch_dtype(kernel_dims.at("irrep_dtype"))), - weight_dtype(enum_to_torch_dtype(kernel_dims.at("weight_dtype"))), - workspace_dtype(torch::kByte) { - if(is_convolution) { - workspace_size = kernel_dims.at("workspace_size"); - deterministic = kernel_dims.at("deterministic"); - idx_dtype = enum_to_torch_dtype(kernel_dims.at("idx_dtype")); - } - } -}; - -std::unordered_map>, - KernelProp - >> tp_cache; - -std::unordered_map>, - KernelProp - >> conv_cache; - -std::mutex mut; +using Tensor = torch::Tensor; +using Dtype = torch::Dtype; -std::pair*, KernelProp> - compile_tp_with_caching(const torch::Tensor &json_bytes, - int64_t hash) { - { - const std::lock_guard lock(mut); - auto it = tp_cache.find(hash); - if (it == tp_cache.end()) { - torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); - std::string json_payload( - reinterpret_cast(cpu_tensor.data_ptr()), - cpu_tensor.numel() - ); +constexpr Dtype kFloat = torch::kFloat; +constexpr Dtype kDouble = torch::kDouble; +constexpr Dtype kInt = torch::kInt; +constexpr Dtype kLong = torch::kLong; +constexpr Dtype kByte = torch::kByte; - std::string err; - json root = json::parse(json_payload, err); - if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); +#define CHECK TORCH_CHECK +#define BOX(x) x +#define REGISTER_LIBRARY_IMPL TORCH_LIBRARY_IMPL +#define REGISTER_LIBRARY TORCH_LIBRARY - std::string kernel_src = root["kernel"].string_value(); - auto forward_cfg = parse_json_config(root["forward_config"]); - auto backward_cfg = parse_json_config(root["backward_config"]); - auto dbackward_cfg = parse_json_config(root["double_backward_config"]); - auto kernel_prop_map = parse_json_config(root["kernel_prop"]); +#include "torch_core.hpp" - auto jit_tp_impl = std::make_unique>( - kernel_src, - forward_cfg, - backward_cfg, - dbackward_cfg, - kernel_prop_map); - - tp_cache.insert({hash, - std::make_pair(std::move(jit_tp_impl), - KernelProp(kernel_prop_map, false))}); - it = tp_cache.find(hash); - } - return {it->second.first.get(), it->second.second}; - } +Tensor tensor_to_cpu_contiguous(const Tensor &tensor) { + return tensor.to(torch::kCPU).contiguous(); } -std::pair*, KernelProp> - compile_conv_with_caching(const torch::Tensor &json_bytes, - int64_t hash) { - { - const std::lock_guard lock(mut); - auto it = conv_cache.find(hash); - if (it == conv_cache.end()) { - torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); - std::string json_payload( - reinterpret_cast(cpu_tensor.data_ptr()), - cpu_tensor.numel() - ); - - std::string err; - json root = json::parse(json_payload, err); - if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); - - std::string kernel_src = root["kernel"].string_value(); - auto forward_cfg = parse_json_config(root["forward_config"]); - auto backward_cfg = parse_json_config(root["backward_config"]); - auto dbackward_cfg = parse_json_config(root["double_backward_config"]); - auto kernel_prop_map = parse_json_config(root["kernel_prop"]); - - auto jit_conv_impl = std::make_unique>( - kernel_src, - forward_cfg, - backward_cfg, - dbackward_cfg, - kernel_prop_map); - - conv_cache.insert({hash, - std::make_pair(std::move(jit_conv_impl), - KernelProp(kernel_prop_map, true))}); - it = conv_cache.find(hash); - } - return {it->second.first.get(), it->second.second}; - } +Tensor tensor_contiguous(const Tensor &tensor) { + return tensor.contiguous(); } -// --------------------- Tensor Products -------------------------- - -torch::Tensor jit_tp_forward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - int64_t L3_dim) { - - auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t num_batch = L1_in.size(0); - - check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); - - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - else - check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L3_out = torch::empty({num_batch, k.L3_dim}, L1_in.options()); - - at::Tensor L1_contig = L1_in.contiguous(); - at::Tensor L2_contig = L2_in.contiguous(); - at::Tensor W_contig = W.contiguous(); - - jit_kernel->exec_tensor_product( - num_batch, - data_ptr(L1_contig), - data_ptr(L2_contig), - data_ptr(L3_out), - data_ptr(W_contig), - stream - ); - - return L3_out; +Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes) { + return torch::empty(sizes, ref.options()); } -tuple jit_tp_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad) { - - auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t num_batch = L1_in.size(0); - - check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); - - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - else - check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L1_grad = torch::empty(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - - if(k.shared_weights) - W_grad.zero_(); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - - jit_kernel->backward( - num_batch, - data_ptr(L1_in_contig), data_ptr(L1_grad), - data_ptr(L2_in_contig), data_ptr(L2_grad), - data_ptr(W_contig), data_ptr(W_grad), - data_ptr(L3_grad_contig), - stream - ); - - return tuple(L1_grad, L2_grad, W_grad); +Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes) { + return torch::zeros(sizes, ref.options()); } -tuple jit_tp_double_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad, - torch::Tensor L1_dgrad, - torch::Tensor L2_dgrad, - torch::Tensor W_dgrad) { - - auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t num_batch = L1_in.size(0); - - check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); - check_tensor(L1_dgrad, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); - check_tensor(L2_dgrad, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); - - if (k.shared_weights){ - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); - } else { - check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {num_batch, k.weight_numel}, k.weight_dtype, "W_dgrad"); - } - - torch::Tensor L1_grad = torch::empty(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - torch::Tensor L3_dgrad = torch::empty(L3_grad.sizes(), L3_grad.options()); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - - torch::Tensor L1_dgrad_contig = L1_dgrad.contiguous(); - torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous(); - torch::Tensor W_dgrad_contig = W_dgrad.contiguous(); - - if(k.shared_weights) { - W_grad.zero_(); - TORCH_CHECK(W.dim() == 1); - } - - jit_kernel->double_backward( - num_batch, - data_ptr(L1_in_contig), data_ptr(L2_in_contig), - data_ptr(W_contig), data_ptr(L3_grad_contig), - data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), - data_ptr(W_dgrad_contig), - data_ptr(L1_grad), data_ptr(L2_grad), - data_ptr(W_grad), data_ptr(L3_dgrad), - stream - ); - - return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +void tensor_zero_(Tensor &tensor) { + tensor.zero_(); } +Dtype tensor_dtype(const Tensor &tensor) { + return tensor.dtype(); +} -// ========================= Convolution ================================== - -torch::Tensor jit_conv_forward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - int64_t L3_dim, - torch::Tensor rows, - torch::Tensor cols, - torch::Tensor workspace, - torch::Tensor transpose_perm) { - - auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t nnz = rows.size(0); - const int64_t node_count = L1_in.size(0); - - check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); - check_tensor(rows, {nnz}, k.idx_dtype, "rows"); - check_tensor(cols, {nnz}, k.idx_dtype, "cols"); - - if (k.deterministic){ - check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); - } else { - at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_forward"); - } - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - else - check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L3_out = torch::zeros({node_count, k.L3_dim}, L1_in.options()); - - torch::Tensor L1_contig = L1_in.contiguous(); - torch::Tensor L2_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor rows_contig = rows.contiguous(); - torch::Tensor cols_contig = cols.contiguous(); - torch::Tensor workspace_contig = workspace.contiguous(); +bool tensor_is_cuda(const Tensor &tensor) { + return tensor.device().is_cuda(); +} - jit_kernel->exec_conv( - data_ptr(L1_contig), - data_ptr(L2_contig), - data_ptr(W_contig), - data_ptr(L3_out), - data_ptr(rows_contig), - data_ptr(cols_contig), - nnz, node_count, - data_ptr(workspace_contig), - stream); +int64_t tensor_dim(const Tensor &tensor) { + return tensor.dim(); +} - return L3_out; +int64_t tensor_size(const Tensor &tensor, int64_t dim) { + return tensor.size(dim); } -tuple jit_conv_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad, - torch::Tensor rows, - torch::Tensor cols, - torch::Tensor workspace, - torch::Tensor transpose_perm) { - - auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); +int64_t tensor_numel(const Tensor &tensor) { + return tensor.numel(); +} - const int64_t nnz = rows.size(0); - const int64_t node_count = L1_in.size(0); +void alert_not_deterministic(const char *name) { + at::globalContext().alertNotDeterministic(name); +} - check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); - check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); - check_tensor(rows, {nnz}, k.idx_dtype, "rows"); - check_tensor(cols, {nnz}, k.idx_dtype, "cols"); +const uint8_t *tensor_data_ptr_u8(const Tensor &tensor) { + return tensor.data_ptr(); +} - if (k.deterministic){ - check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); - } else { - at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_backward"); - } - - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); +void *data_ptr(const Tensor &tensor) { + if (tensor.dtype() == torch::kFloat) + return reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kDouble) + return reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kLong) + return reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kByte) + return reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kInt) + return reinterpret_cast(tensor.data_ptr()); else - check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L1_grad = torch::zeros(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::zeros(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - - torch::Tensor rows_contig = rows.contiguous(); - torch::Tensor cols_contig = cols.contiguous(); - torch::Tensor workspace_contig = workspace.contiguous(); - torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); - - if(k.shared_weights) - W_grad.zero_(); - - jit_kernel->backward( - data_ptr(L1_in_contig), data_ptr(L1_grad), - data_ptr(L2_in_contig), data_ptr(L2_grad), - data_ptr(W_contig), data_ptr(W_grad), - data_ptr(L3_grad_contig), - data_ptr(rows_contig), data_ptr(cols_contig), - nnz, node_count, - data_ptr(workspace_contig), - data_ptr(transpose_perm_contig), - stream); - - return tuple(L1_grad, L2_grad, W_grad); + throw std::logic_error("Unsupported tensor datatype!"); } -tuple jit_conv_double_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad, - torch::Tensor L1_dgrad, - torch::Tensor L2_dgrad, - torch::Tensor W_dgrad, - torch::Tensor rows, - torch::Tensor cols, - torch::Tensor workspace, - torch::Tensor transpose_perm) { - - auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t nnz = rows.size(0); - const int64_t node_count = L1_in.size(0); - - check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); - check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); - check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); - check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); - check_tensor(rows, {nnz}, k.idx_dtype, "rows"); - check_tensor(cols, {nnz}, k.idx_dtype, "cols"); - - if (k.deterministic) { - check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); - } else { - at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_double_backward"); - } - - if (k.shared_weights) { - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); - } - else { - check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); - } - - torch::Tensor L1_grad = torch::zeros(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::zeros(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - torch::Tensor L3_dgrad = torch::zeros(L3_grad.sizes(), L3_grad.options()); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - torch::Tensor L1_dgrad_contig = L1_dgrad.contiguous(); - torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous(); - torch::Tensor W_dgrad_contig = W_dgrad.contiguous(); - - torch::Tensor rows_contig = rows.contiguous(); - torch::Tensor cols_contig = cols.contiguous(); - torch::Tensor workspace_contig = workspace.contiguous(); - torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); - - if(k.shared_weights) - W_grad.zero_(); - - jit_kernel->double_backward( - data_ptr(L1_in_contig), data_ptr(L2_in_contig), - data_ptr(W_contig), data_ptr(L3_grad_contig), - data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), - data_ptr(W_dgrad_contig), - data_ptr(L1_grad), data_ptr(L2_grad), - data_ptr(W_grad), data_ptr(L3_dgrad), - data_ptr(rows_contig), data_ptr(cols_contig), - nnz, node_count, - data_ptr(workspace_contig), data_ptr(transpose_perm_contig), - stream - ); - - return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +Stream get_current_stream() { +#ifdef CUDA_BACKEND + return c10::cuda::getCurrentCUDAStream(); +#endif +#ifdef HIP_BACKEND + return c10::hip::getCurrentHIPStream(); +#endif } -// =========================================================== - -TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { - m.impl("jit_tp_forward", &jit_tp_forward); - m.impl("jit_tp_backward", &jit_tp_backward); - m.impl("jit_tp_double_backward", &jit_tp_double_backward); - - m.impl("jit_conv_forward", &jit_conv_forward); - m.impl("jit_conv_backward", &jit_conv_backward); - m.impl("jit_conv_double_backward", &jit_conv_double_backward); -}; - -TORCH_LIBRARY(libtorch_tp_jit, m) { - m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim) -> Tensor"); - m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); - m.def("jit_tp_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); - - m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); - m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); - m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); -}; - -PYBIND11_MODULE(libtorch_tp_jit, m) {} \ No newline at end of file +PYBIND11_MODULE(libtorch_tp_jit, m) {} diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp new file mode 100644 index 0000000..4118f9e --- /dev/null +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -0,0 +1,104 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef HIP_BACKEND + #include +#endif + +using Tensor = torch::stable::Tensor; +using Dtype = torch::headeronly::ScalarType; + +constexpr Dtype kFloat = torch::headeronly::kFloat; +constexpr Dtype kDouble = torch::headeronly::kDouble; +constexpr Dtype kInt = torch::headeronly::kInt; +constexpr Dtype kLong = torch::headeronly::kLong; +constexpr Dtype kByte = torch::headeronly::kUInt8; + +#define CHECK STD_TORCH_CHECK +#define BOX(x) TORCH_BOX(x) +#define REGISTER_LIBRARY_IMPL STABLE_TORCH_LIBRARY_IMPL +#define REGISTER_LIBRARY STABLE_TORCH_LIBRARY + +#include "torch_core.hpp" + +Tensor tensor_to_cpu_contiguous(const Tensor &tensor) { + torch::stable::Device device(torch::headeronly::DeviceType::CPU); + return torch::stable::contiguous(torch::stable::to(tensor, device)); +} + +Tensor tensor_contiguous(const Tensor &tensor) { + return torch::stable::contiguous(tensor); +} + +Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes) { + auto sizes_ref = torch::headeronly::IntHeaderOnlyArrayRef(sizes.data(), sizes.size()); + return torch::stable::new_empty(ref, sizes_ref); +} + +Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes) { + auto sizes_ref = torch::headeronly::IntHeaderOnlyArrayRef(sizes.data(), sizes.size()); + Tensor out = torch::stable::new_empty(ref, sizes_ref); + torch::stable::zero_(out); + return out; +} + +void tensor_zero_(Tensor &tensor) { + torch::stable::zero_(tensor); +} + +Dtype tensor_dtype(const Tensor &tensor) { + return tensor.scalar_type(); +} + +bool tensor_is_cuda(const Tensor &tensor) { + return tensor.is_cuda(); +} + +int64_t tensor_dim(const Tensor &tensor) { + return tensor.dim(); +} + +int64_t tensor_size(const Tensor &tensor, int64_t dim) { + return tensor.size(dim); +} + +int64_t tensor_numel(const Tensor &tensor) { + return tensor.numel(); +} + +void alert_not_deterministic(const char *name) { + (void)name; +} + +const uint8_t *tensor_data_ptr_u8(const Tensor &tensor) { + return static_cast(tensor.data_ptr()); +} + +void *data_ptr(const Tensor &tensor) { + Dtype dtype = tensor.scalar_type(); + if (dtype == kFloat || dtype == kDouble || dtype == kLong || dtype == kByte || dtype == kInt) + return tensor.data_ptr(); + throw std::logic_error("Unsupported tensor datatype!"); +} + +Stream get_current_stream() { +#ifdef CUDA_BACKEND + void *stream_ptr = nullptr; + auto device_index = torch::stable::accelerator::getCurrentDeviceIndex(); + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); + return static_cast(stream_ptr); +#endif +#ifdef HIP_BACKEND + return c10::hip::getCurrentHIPStream(); +#endif +} + diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp new file mode 100644 index 0000000..27df009 --- /dev/null +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -0,0 +1,650 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "json11/json11.hpp" + +#ifdef CUDA_BACKEND + #include "backend_cuda.hpp" + #include "group_mm_cuda.hpp" + using JITKernel = CUJITKernel; + using GPU_Allocator = CUDA_Allocator; + + template + using GroupMM = GroupMMCUDA; +#endif + +#ifdef HIP_BACKEND + #include "backend_hip.hpp" + #include "group_mm_hip.hpp" + using JITKernel = HIPJITKernel; + using GPU_Allocator = HIP_Allocator; + + template + using GroupMM = GroupMMHIP; +#endif + +#include "tensorproducts.hpp" +#include "convolution.hpp" + +using namespace std; +using json = json11::Json; + +Dtype enum_to_torch_dtype(int64_t i); + +Tensor tensor_to_cpu_contiguous(const Tensor &tensor); +Tensor tensor_contiguous(const Tensor &tensor); +Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes); +Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes); +void tensor_zero_(Tensor &tensor); + +Dtype tensor_dtype(const Tensor &tensor); +bool tensor_is_cuda(const Tensor &tensor); +int64_t tensor_dim(const Tensor &tensor); +int64_t tensor_size(const Tensor &tensor, int64_t dim); +int64_t tensor_numel(const Tensor &tensor); + +void alert_not_deterministic(const char *name); +Stream get_current_stream(); + +const uint8_t *tensor_data_ptr_u8(const Tensor &tensor); +void *data_ptr(const Tensor &tensor); + +inline std::string shape_to_string(std::initializer_list shape) { + std::ostringstream oss; + oss << "["; + size_t i = 0; + for (int64_t dim : shape) { + if (i > 0) { + oss << ", "; + } + oss << dim; + ++i; + } + oss << "]"; + return oss.str(); +} + +inline std::string tensor_sizes_str(const Tensor &tensor) { + std::ostringstream oss; + oss << "["; + int64_t dims = tensor_dim(tensor); + for (int64_t i = 0; i < dims; ++i) { + if (i > 0) { + oss << ", "; + } + oss << tensor_size(tensor, i); + } + oss << "]"; + return oss.str(); +} + +inline std::vector tensor_sizes_vec(const Tensor &tensor) { + int64_t dims = tensor_dim(tensor); + std::vector sizes; + sizes.reserve(static_cast(dims)); + for (int64_t i = 0; i < dims; ++i) { + sizes.push_back(tensor_size(tensor, i)); + } + return sizes; +} + +inline std::vector make_sizes(std::initializer_list sizes) { + return std::vector(sizes); +} + +inline Dtype enum_to_torch_dtype(int64_t i) { + switch (i) { + case 1: return kFloat; + case 2: return kDouble; + case 3: return kInt; + case 4: return kLong; + case 5: return kByte; + } + throw logic_error("Unsupported tensor datatype!"); +} + +inline void check_tensor(const Tensor &tensor, + std::initializer_list expected_shape, + Dtype expected_dtype, + std::string tensor_name) { + bool shape_ok = (tensor_dim(tensor) == static_cast(expected_shape.size())); + if (shape_ok) { + int64_t i = 0; + for (int64_t dim : expected_shape) { + if (tensor_size(tensor, i) != dim) { + shape_ok = false; + break; + } + ++i; + } + } + + CHECK(shape_ok, + "Shape mismatch for tensor '", tensor_name, + "'. Expected: ", shape_to_string(expected_shape), + ". Got: ", tensor_sizes_str(tensor)); + CHECK(tensor_is_cuda(tensor), "Tensor '", tensor_name, "' is not on the GPU."); + CHECK(tensor_dtype(tensor) == expected_dtype, + "Dtype mismatch for tensor '", tensor_name, + "'. Expected: ", static_cast(expected_dtype), + ". Got: ", static_cast(tensor_dtype(tensor))); +} + +inline std::unordered_map parse_json_config(const json &j_obj) { + std::unordered_map result; + for (const auto &kv : j_obj.object_items()) { + result[kv.first] = static_cast(kv.second.number_value()); + } + return result; +} + +struct KernelProp { + int64_t L1_dim, L2_dim, L3_dim, weight_numel; + bool shared_weights; + Dtype irrep_dtype; + Dtype weight_dtype; + + int64_t workspace_size; // Convolution only + bool deterministic; + Dtype idx_dtype; + Dtype workspace_dtype; + + KernelProp() : + L1_dim(0), L2_dim(0), L3_dim(0), weight_numel(0), + shared_weights(false), + irrep_dtype(kFloat), weight_dtype(kFloat), + workspace_size(0), deterministic(false), + idx_dtype(kInt), workspace_dtype(kByte) {} + + KernelProp( + std::unordered_map &kernel_dims, bool is_convolution) : + L1_dim(kernel_dims.at("L1_dim")), + L2_dim(kernel_dims.at("L2_dim")), + L3_dim(kernel_dims.at("L3_dim")), + weight_numel(kernel_dims.at("weight_numel")), + shared_weights(kernel_dims.at("shared_weights")), + irrep_dtype(enum_to_torch_dtype(kernel_dims.at("irrep_dtype"))), + weight_dtype(enum_to_torch_dtype(kernel_dims.at("weight_dtype"))), + workspace_dtype(kByte) { + if (is_convolution) { + workspace_size = kernel_dims.at("workspace_size"); + deterministic = kernel_dims.at("deterministic"); + idx_dtype = enum_to_torch_dtype(kernel_dims.at("idx_dtype")); + } + } +}; + +inline std::unordered_map>, + KernelProp + >> tp_cache; + +inline std::unordered_map>, + KernelProp + >> conv_cache; + +inline std::mutex mut; + +inline std::pair*, KernelProp> + compile_tp_with_caching(const Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard lock(mut); + auto it = tp_cache.find(hash); + if (it == tp_cache.end()) { + Tensor cpu_tensor = tensor_to_cpu_contiguous(json_bytes); + std::string json_payload( + reinterpret_cast(tensor_data_ptr_u8(cpu_tensor)), + tensor_numel(cpu_tensor) + ); + + std::string err; + json root = json::parse(json_payload, err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_tp_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + tp_cache.insert({hash, + std::make_pair(std::move(jit_tp_impl), + KernelProp(kernel_prop_map, false))}); + it = tp_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; + } +} + +inline std::pair*, KernelProp> + compile_conv_with_caching(const Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard lock(mut); + auto it = conv_cache.find(hash); + if (it == conv_cache.end()) { + Tensor cpu_tensor = tensor_to_cpu_contiguous(json_bytes); + std::string json_payload( + reinterpret_cast(tensor_data_ptr_u8(cpu_tensor)), + tensor_numel(cpu_tensor) + ); + + std::string err; + json root = json::parse(json_payload, err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_conv_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + conv_cache.insert({hash, + std::make_pair(std::move(jit_conv_impl), + KernelProp(kernel_prop_map, true))}); + it = conv_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; + } +} + +// --------------------- Tensor Products -------------------------- + +inline Tensor jit_tp_forward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + int64_t L3_dim) { + + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t num_batch = tensor_size(L1_in, 0); + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L3_out = tensor_empty_like(L1_in, make_sizes({num_batch, k.L3_dim})); + + Tensor L1_contig = tensor_contiguous(L1_in); + Tensor L2_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + + jit_kernel->exec_tensor_product( + num_batch, + data_ptr(L1_contig), + data_ptr(L2_contig), + data_ptr(L3_out), + data_ptr(W_contig), + stream + ); + + return L3_out; +} + +inline tuple jit_tp_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad) { + + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t num_batch = tensor_size(L1_in, 0); + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L1_grad = tensor_empty_like(L1_in, tensor_sizes_vec(L1_in)); + Tensor L2_grad = tensor_empty_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + + if (k.shared_weights) + tensor_zero_(W_grad); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + + jit_kernel->backward( + num_batch, + data_ptr(L1_in_contig), data_ptr(L1_grad), + data_ptr(L2_in_contig), data_ptr(L2_grad), + data_ptr(W_contig), data_ptr(W_grad), + data_ptr(L3_grad_contig), + stream + ); + + return tuple(L1_grad, L2_grad, W_grad); +} + +inline tuple jit_tp_double_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad, + Tensor L1_dgrad, + Tensor L2_dgrad, + Tensor W_dgrad) { + + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t num_batch = tensor_size(L1_in, 0); + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {num_batch, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + + Tensor L1_grad = tensor_empty_like(L1_in, tensor_sizes_vec(L1_in)); + Tensor L2_grad = tensor_empty_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + Tensor L3_dgrad = tensor_empty_like(L3_grad, tensor_sizes_vec(L3_grad)); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + + Tensor L1_dgrad_contig = tensor_contiguous(L1_dgrad); + Tensor L2_dgrad_contig = tensor_contiguous(L2_dgrad); + Tensor W_dgrad_contig = tensor_contiguous(W_dgrad); + + if (k.shared_weights) { + tensor_zero_(W_grad); + CHECK(tensor_dim(W) == 1); + } + + jit_kernel->double_backward( + num_batch, + data_ptr(L1_in_contig), data_ptr(L2_in_contig), + data_ptr(W_contig), data_ptr(L3_grad_contig), + data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), + data_ptr(W_dgrad_contig), + data_ptr(L1_grad), data_ptr(L2_grad), + data_ptr(W_grad), data_ptr(L3_dgrad), + stream + ); + + return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +} + + +// ========================= Convolution ================================== + +inline Tensor jit_conv_forward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + int64_t L3_dim, + Tensor rows, + Tensor cols, + Tensor workspace, + Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t nnz = tensor_size(rows, 0); + const int64_t node_count = tensor_size(L1_in, 0); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } else { + alert_not_deterministic("OpenEquivariance_conv_atomic_forward"); + } + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L3_out = tensor_zeros_like(L1_in, make_sizes({node_count, k.L3_dim})); + + Tensor L1_contig = tensor_contiguous(L1_in); + Tensor L2_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor rows_contig = tensor_contiguous(rows); + Tensor cols_contig = tensor_contiguous(cols); + Tensor workspace_contig = tensor_contiguous(workspace); + + jit_kernel->exec_conv( + data_ptr(L1_contig), + data_ptr(L2_contig), + data_ptr(W_contig), + data_ptr(L3_out), + data_ptr(rows_contig), + data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), + stream); + + return L3_out; +} + +inline tuple jit_conv_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad, + Tensor rows, + Tensor cols, + Tensor workspace, + Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t nnz = tensor_size(rows, 0); + const int64_t node_count = tensor_size(L1_in, 0); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } else { + alert_not_deterministic("OpenEquivariance_conv_atomic_backward"); + } + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L1_grad = tensor_zeros_like(L1_in, tensor_sizes_vec(L1_in)); + Tensor L2_grad = tensor_zeros_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + + Tensor rows_contig = tensor_contiguous(rows); + Tensor cols_contig = tensor_contiguous(cols); + Tensor workspace_contig = tensor_contiguous(workspace); + Tensor transpose_perm_contig = tensor_contiguous(transpose_perm); + + if (k.shared_weights) + tensor_zero_(W_grad); + + jit_kernel->backward( + data_ptr(L1_in_contig), data_ptr(L1_grad), + data_ptr(L2_in_contig), data_ptr(L2_grad), + data_ptr(W_contig), data_ptr(W_grad), + data_ptr(L3_grad_contig), + data_ptr(rows_contig), data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), + data_ptr(transpose_perm_contig), + stream); + + return tuple(L1_grad, L2_grad, W_grad); +} + +inline tuple jit_conv_double_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad, + Tensor L1_dgrad, + Tensor L2_dgrad, + Tensor W_dgrad, + Tensor rows, + Tensor cols, + Tensor workspace, + Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t nnz = tensor_size(rows, 0); + const int64_t node_count = tensor_size(L1_in, 0); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } else { + alert_not_deterministic("OpenEquivariance_conv_atomic_double_backward"); + } + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + + Tensor L1_grad = tensor_zeros_like(L1_in, tensor_sizes_vec(L1_in)); + Tensor L2_grad = tensor_zeros_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + Tensor L3_dgrad = tensor_zeros_like(L3_grad, tensor_sizes_vec(L3_grad)); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + Tensor L1_dgrad_contig = tensor_contiguous(L1_dgrad); + Tensor L2_dgrad_contig = tensor_contiguous(L2_dgrad); + Tensor W_dgrad_contig = tensor_contiguous(W_dgrad); + + Tensor rows_contig = tensor_contiguous(rows); + Tensor cols_contig = tensor_contiguous(cols); + Tensor workspace_contig = tensor_contiguous(workspace); + Tensor transpose_perm_contig = tensor_contiguous(transpose_perm); + + if (k.shared_weights) + tensor_zero_(W_grad); + + jit_kernel->double_backward( + data_ptr(L1_in_contig), data_ptr(L2_in_contig), + data_ptr(W_contig), data_ptr(L3_grad_contig), + data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), + data_ptr(W_dgrad_contig), + data_ptr(L1_grad), data_ptr(L2_grad), + data_ptr(W_grad), data_ptr(L3_dgrad), + data_ptr(rows_contig), data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), data_ptr(transpose_perm_contig), + stream + ); + + return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +} + +// =========================================================== + +REGISTER_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { + m.impl("jit_tp_forward", BOX(&jit_tp_forward)); + m.impl("jit_tp_backward", BOX(&jit_tp_backward)); + m.impl("jit_tp_double_backward", BOX(&jit_tp_double_backward)); + + m.impl("jit_conv_forward", BOX(&jit_conv_forward)); + m.impl("jit_conv_backward", BOX(&jit_conv_backward)); + m.impl("jit_conv_double_backward", BOX(&jit_conv_double_backward)); +}; + +REGISTER_LIBRARY(libtorch_tp_jit, m) { + m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim) -> Tensor"); + m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); + m.def("jit_tp_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); + + m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); + m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); + m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); +}; From 5583c8ed2ecf027b72ed168b7271e8a5f3935605 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 13 Feb 2026 22:51:53 -0800 Subject: [PATCH 02/28] More progress. --- .../extension/libtorch_tp_jit_stable.cpp | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 4118f9e..9a182fd 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include @@ -16,11 +15,11 @@ using Tensor = torch::stable::Tensor; using Dtype = torch::headeronly::ScalarType; -constexpr Dtype kFloat = torch::headeronly::kFloat; -constexpr Dtype kDouble = torch::headeronly::kDouble; -constexpr Dtype kInt = torch::headeronly::kInt; -constexpr Dtype kLong = torch::headeronly::kLong; -constexpr Dtype kByte = torch::headeronly::kUInt8; +constexpr Dtype kFloat = torch::headeronly::ScalarType::Float; +constexpr Dtype kDouble = torch::headeronly::ScalarType::Double; +constexpr Dtype kInt = torch::headeronly::ScalarType::Int; +constexpr Dtype kLong = torch::headeronly::ScalarType::Long; +constexpr Dtype kByte = torch::headeronly::ScalarType::Byte; #define CHECK STD_TORCH_CHECK #define BOX(x) TORCH_BOX(x) @@ -90,15 +89,13 @@ void *data_ptr(const Tensor &tensor) { } Stream get_current_stream() { -#ifdef CUDA_BACKEND - void *stream_ptr = nullptr; - auto device_index = torch::stable::accelerator::getCurrentDeviceIndex(); + int32_t device_index; + StreamOpaque* stream_ptr; + TORCH_ERROR_CODE_CHECK( - aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); - return static_cast(stream_ptr); -#endif -#ifdef HIP_BACKEND - return c10::hip::getCurrentHIPStream(); -#endif + aoti_torch_get_current_device_index(&device_index)) + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_stream(device_index, &stream_ptr)) + return (Stream) stream_ptr; } From cb65ed627052c6343559ac41cd6c1111b1b003dd Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 14 Feb 2026 00:30:13 -0800 Subject: [PATCH 03/28] Saving temporarily. --- openequivariance/CMakeLists.txt | 16 ++++++++++++---- .../extension/libtorch_tp_jit.cpp | 2 +- .../extension/libtorch_tp_jit_stable.cpp | 11 ++++------- .../openequivariance/extension/torch_core.hpp | 10 +++++----- openequivariance/pyproject.toml | 4 ++-- 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index 472bc2d..0f9945e 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -5,6 +5,7 @@ set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with set(LIBTORCH_ZIP "${CMAKE_BINARY_DIR}/libtorch-shared-with-deps-2.10.0+cpu.zip") set(LIBTORCH_EXTRACT_DIR "${CMAKE_BINARY_DIR}/libtorch_extract") set(LIBTORCH_INCLUDE_DIR "${LIBTORCH_EXTRACT_DIR}/libtorch/include") +set(LIBTORCH_LIB_DIR "${LIBTORCH_EXTRACT_DIR}/libtorch/lib") file(DOWNLOAD ${LIBTORCH_URL} ${LIBTORCH_ZIP} SHOW_PROGRESS) file(MAKE_DIRECTORY ${LIBTORCH_EXTRACT_DIR}) @@ -14,6 +15,12 @@ file(ARCHIVE_EXTRACT PATTERNS "libtorch/include/*" ) +file(ARCHIVE_EXTRACT + INPUT ${LIBTORCH_ZIP} + DESTINATION ${LIBTORCH_EXTRACT_DIR} + PATTERNS "libtorch/lib/*" +) + add_custom_target(libtorch_headers ALL DEPENDS ${LIBTORCH_INCLUDE_DIR}) set(EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/extension") @@ -39,6 +46,7 @@ function(add_stable_extension target_name backend_define) ${EXT_UTIL_DIR} ${EXT_JSON_DIR} ) + target_link_directories(${target_name} PRIVATE ${LIBTORCH_LIB_DIR}) add_dependencies(${target_name} libtorch_headers) endfunction() @@ -47,14 +55,14 @@ find_package(hip QUIET) if(CUDAToolkit_FOUND) message(STATUS "Building stable extension with CUDA backend.") - add_stable_extension(libtorch_tp_jit_stable_cuda CUDA_BACKEND) - target_link_libraries(libtorch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver) + add_stable_extension(torch_tp_jit_stable_cuda CUDA_BACKEND) + target_link_libraries(torch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver) endif() if(hip_FOUND) message(STATUS "Building stable extension with HIP backend.") - add_stable_extension(libtorch_tp_jit_stable_hip HIP_BACKEND) - target_link_libraries(libtorch_tp_jit_stable_hip PRIVATE hiprtc) + add_stable_extension(torch_tp_jit_stable_hip HIP_BACKEND) + target_link_libraries(torch_tp_jit_stable_hip PRIVATE hiprtc) endif() if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index bcd2c7b..64ecede 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -23,7 +23,7 @@ constexpr Dtype kInt = torch::kInt; constexpr Dtype kLong = torch::kLong; constexpr Dtype kByte = torch::kByte; -#define CHECK TORCH_CHECK +#define TCHECK TORCH_CHECK #define BOX(x) x #define REGISTER_LIBRARY_IMPL TORCH_LIBRARY_IMPL #define REGISTER_LIBRARY TORCH_LIBRARY diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 9a182fd..b6794e0 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -21,7 +21,7 @@ constexpr Dtype kInt = torch::headeronly::ScalarType::Int; constexpr Dtype kLong = torch::headeronly::ScalarType::Long; constexpr Dtype kByte = torch::headeronly::ScalarType::Byte; -#define CHECK STD_TORCH_CHECK +#define TCHECK STD_TORCH_CHECK #define BOX(x) TORCH_BOX(x) #define REGISTER_LIBRARY_IMPL STABLE_TORCH_LIBRARY_IMPL #define REGISTER_LIBRARY STABLE_TORCH_LIBRARY @@ -53,9 +53,9 @@ void tensor_zero_(Tensor &tensor) { torch::stable::zero_(tensor); } -Dtype tensor_dtype(const Tensor &tensor) { +/*Dtype tensor_dtype(const Tensor &tensor) { return tensor.scalar_type(); -} +}*/ bool tensor_is_cuda(const Tensor &tensor) { return tensor.is_cuda(); @@ -82,10 +82,7 @@ const uint8_t *tensor_data_ptr_u8(const Tensor &tensor) { } void *data_ptr(const Tensor &tensor) { - Dtype dtype = tensor.scalar_type(); - if (dtype == kFloat || dtype == kDouble || dtype == kLong || dtype == kByte || dtype == kInt) - return tensor.data_ptr(); - throw std::logic_error("Unsupported tensor datatype!"); + return tensor.data_ptr(); } Stream get_current_stream() { diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index 27df009..fd59fac 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -46,7 +46,7 @@ Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes); Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes); void tensor_zero_(Tensor &tensor); -Dtype tensor_dtype(const Tensor &tensor); +//Dtype tensor_dtype(const Tensor &tensor); bool tensor_is_cuda(const Tensor &tensor); int64_t tensor_dim(const Tensor &tensor); int64_t tensor_size(const Tensor &tensor, int64_t dim); @@ -128,15 +128,15 @@ inline void check_tensor(const Tensor &tensor, } } - CHECK(shape_ok, + TCHECK(shape_ok, "Shape mismatch for tensor '", tensor_name, "'. Expected: ", shape_to_string(expected_shape), ". Got: ", tensor_sizes_str(tensor)); - CHECK(tensor_is_cuda(tensor), "Tensor '", tensor_name, "' is not on the GPU."); - CHECK(tensor_dtype(tensor) == expected_dtype, + TCHECK(tensor_is_cuda(tensor), "Tensor '", tensor_name, "' is not on the GPU."); + /*TCHECK(tensor_dtype(tensor) == expected_dtype, "Dtype mismatch for tensor '", tensor_name, "'. Expected: ", static_cast(expected_dtype), - ". Got: ", static_cast(tensor_dtype(tensor))); + ". Got: ", static_cast(tensor_dtype(tensor)));*/ } inline std::unordered_map parse_json_config(const json &j_obj) { diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index a0ddd61..789038c 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools", "setuptools-scm"] -build-backend = "setuptools.build_meta" +requires = ["setuptools", "setuptools-scm", "scikit-build-core"] +build-backend = "scikit_build_core.build" [project] name = "openequivariance" From 796bf3174b7279dfbe313556dc675ea98159aa0a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 14 Feb 2026 00:38:02 -0800 Subject: [PATCH 04/28] Fixed JIT issue. --- openequivariance/openequivariance/extension/libtorch_tp_jit.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 64ecede..b2cab3d 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -50,7 +50,7 @@ void tensor_zero_(Tensor &tensor) { tensor.zero_(); } -Dtype tensor_dtype(const Tensor &tensor) { +caffe2::TypeMeta tensor_dtype(const Tensor &tensor) { return tensor.dtype(); } From d147a17e7c0fd3005aa73f73c4cb25c3d0298473 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 14 Feb 2026 01:00:06 -0800 Subject: [PATCH 05/28] Managed to build stable extension. --- openequivariance/CMakeLists.txt | 4 ++++ openequivariance/openequivariance/extension/torch_core.hpp | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index 0f9945e..a3947eb 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -57,12 +57,16 @@ if(CUDAToolkit_FOUND) message(STATUS "Building stable extension with CUDA backend.") add_stable_extension(torch_tp_jit_stable_cuda CUDA_BACKEND) target_link_libraries(torch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver) + + install(TARGETS torch_tp_jit_stable_cuda LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") endif() if(hip_FOUND) message(STATUS "Building stable extension with HIP backend.") add_stable_extension(torch_tp_jit_stable_hip HIP_BACKEND) target_link_libraries(torch_tp_jit_stable_hip PRIVATE hiprtc) + + install(TARGETS torch_tp_jit_stable_hip LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") endif() if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND) diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index fd59fac..36a634a 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -405,7 +405,7 @@ inline tuple jit_tp_double_backward( if (k.shared_weights) { tensor_zero_(W_grad); - CHECK(tensor_dim(W) == 1); + TCHECK(tensor_dim(W) == 1); } jit_kernel->double_backward( From 092932bbe688ee9f69f5000fd4149c0a8ba2b13f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 14 Feb 2026 18:40:59 -0800 Subject: [PATCH 06/28] Made some further changes. --- openequivariance/CMakeLists.txt | 63 +++++++++++-------- .../extension/libtorch_tp_jit_stable.cpp | 7 ++- .../openequivariance/extension/torch_core.hpp | 18 ++++++ openequivariance/pyproject.toml | 2 +- 4 files changed, 61 insertions(+), 29 deletions(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index a3947eb..ddaffb5 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -1,53 +1,66 @@ cmake_minimum_required(VERSION 3.15...3.30) project(openequivariance_stable_ext LANGUAGES CXX) -set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.10.0%2Bcpu.zip") -set(LIBTORCH_ZIP "${CMAKE_BINARY_DIR}/libtorch-shared-with-deps-2.10.0+cpu.zip") -set(LIBTORCH_EXTRACT_DIR "${CMAKE_BINARY_DIR}/libtorch_extract") -set(LIBTORCH_INCLUDE_DIR "${LIBTORCH_EXTRACT_DIR}/libtorch/include") -set(LIBTORCH_LIB_DIR "${LIBTORCH_EXTRACT_DIR}/libtorch/lib") - -file(DOWNLOAD ${LIBTORCH_URL} ${LIBTORCH_ZIP} SHOW_PROGRESS) -file(MAKE_DIRECTORY ${LIBTORCH_EXTRACT_DIR}) -file(ARCHIVE_EXTRACT - INPUT ${LIBTORCH_ZIP} - DESTINATION ${LIBTORCH_EXTRACT_DIR} - PATTERNS "libtorch/include/*" +find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) + +# Download LibTorch +include(FetchContent) + +FetchContent_Declare( + libtorch + URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.10.0%2Bcpu.zip" ) -file(ARCHIVE_EXTRACT - INPUT ${LIBTORCH_ZIP} - DESTINATION ${LIBTORCH_EXTRACT_DIR} - PATTERNS "libtorch/lib/*" +message(STATUS "Downloading LibTorch...") +FetchContent_MakeAvailable(libtorch) + +set(LIBTORCH_INCLUDE_DIR "${libtorch_SOURCE_DIR}/include") +set(LIBTORCH_LIB_DIR "${libtorch_SOURCE_DIR}/lib") + +message(STATUS "LibTorch Include: ${LIBTORCH_INCLUDE_DIR}") +message(STATUS "LibTorch Lib: ${LIBTORCH_LIB_DIR}") + +# Setup Nanobind +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT ) +message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") -add_custom_target(libtorch_headers ALL DEPENDS ${LIBTORCH_INCLUDE_DIR}) +find_package(nanobind CONFIG REQUIRED) set(EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/extension") set(EXT_UTIL_DIR "${EXT_DIR}/util") set(EXT_JSON_DIR "${EXT_DIR}/json11") +# Source files set(OEQ_SOURCES ${EXT_DIR}/libtorch_tp_jit_stable.cpp ${EXT_JSON_DIR}/json11.cpp ) function(add_stable_extension target_name backend_define) - add_library(${target_name} MODULE ${OEQ_SOURCES}) + nanobind_add_module(${target_name} NB_STATIC ${OEQ_SOURCES}) + set_target_properties(${target_name} PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON POSITION_INDEPENDENT_CODE ON ) - target_compile_definitions(${target_name} PRIVATE ${backend_define}=1) + + # CRITICAL: Manually enforce CXX11 ABI to match the downloaded LibTorch + target_compile_definitions(${target_name} PRIVATE + ${backend_define}=1 + _GLIBCXX_USE_CXX11_ABI=1 + ) + target_include_directories(${target_name} PRIVATE - ${LIBTORCH_INCLUDE_DIR} ${EXT_DIR} ${EXT_UTIL_DIR} ${EXT_JSON_DIR} - ) - target_link_directories(${target_name} PRIVATE ${LIBTORCH_LIB_DIR}) - add_dependencies(${target_name} libtorch_headers) + ${LIBTORCH_INCLUDE_DIR} + ${LIBTORCH_API_INCLUDE_DIR} + ) endfunction() find_package(CUDAToolkit QUIET) @@ -57,7 +70,7 @@ if(CUDAToolkit_FOUND) message(STATUS "Building stable extension with CUDA backend.") add_stable_extension(torch_tp_jit_stable_cuda CUDA_BACKEND) target_link_libraries(torch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver) - + install(TARGETS torch_tp_jit_stable_cuda LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") endif() @@ -71,4 +84,4 @@ endif() if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND) message(FATAL_ERROR "Neither CUDAToolkit nor HIP was found. Cannot build the stable extension.") -endif() +endif() \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index b6794e0..8cf28ce 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -7,6 +7,7 @@ #include #include #include +#include "nanobind/nanobind.h" #ifdef HIP_BACKEND #include @@ -93,6 +94,6 @@ Stream get_current_stream() { aoti_torch_get_current_device_index(&device_index)) TORCH_ERROR_CODE_CHECK( aoti_torch_get_current_stream(device_index, &stream_ptr)) - return (Stream) stream_ptr; -} - + //return (Stream) stream_ptr; + return (Stream) 0; +} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index 36a634a..f36a280 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -37,6 +37,7 @@ using namespace std; using json = json11::Json; +namespace nb = nanobind; Dtype enum_to_torch_dtype(int64_t i); @@ -648,3 +649,20 @@ REGISTER_LIBRARY(libtorch_tp_jit, m) { m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); }; + +NB_MODULE(litbtorch_tp_jit_stable, m) { + nb::class_(m, "DeviceProp") + .def(nb::init()) + .def_ro("name", &DeviceProp::name) + .def_ro("warpsize", &DeviceProp::warpsize) + .def_ro("major", &DeviceProp::major) + .def_ro("minor", &DeviceProp::minor) + .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount) + .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + nb::class_(m, "GPUTimer") + .def(nb::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); +} \ No newline at end of file diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 789038c..c1f5899 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "setuptools-scm", "scikit-build-core"] +requires = ["setuptools", "setuptools-scm", "scikit-build-core", "nanobind"] build-backend = "scikit_build_core.build" [project] From f580cf16843c2df3156cc7c92f5059baa5332b08 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 14 Feb 2026 20:14:23 -0800 Subject: [PATCH 07/28] Fixed the dynamic versioning bugs. --- .gitignore | 3 +- openequivariance/openequivariance/__init__.py | 4 +- .../openequivariance/_torch/TensorProduct.py | 10 +- .../_torch/TensorProductConv.py | 10 +- .../_torch/extlib/__init__.py | 170 ++++++++---------- .../extension/generic_module.cpp | 55 ------ .../extension/libtorch_tp_jit.cpp | 2 - openequivariance/pyproject.toml | 17 +- openequivariance_extjax/pyproject.toml | 8 +- tests/import_test.py | 7 - 10 files changed, 102 insertions(+), 184 deletions(-) delete mode 100644 openequivariance/openequivariance/extension/generic_module.cpp diff --git a/.gitignore b/.gitignore index 64fcaa8..07455d1 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,5 @@ paper_benchmarks_v2 paper_benchmarks_v3 get_node.sh -*.egg-info \ No newline at end of file +*.egg-info +_version.py \ No newline at end of file diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index a842a7c..60ca8b1 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -60,9 +60,7 @@ def _check_package_editable(): LINKED_LIBPYTHON, LINKED_LIBPYTHON_ERROR, BUILT_EXTENSION, - BUILT_EXTENSION_ERROR, - TORCH_COMPILE, - TORCH_COMPILE_ERROR, + BUILT_EXTENSION_ERROR ) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 3885604..8a524ac 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -1,5 +1,5 @@ from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance import TPProblem +from openequivariance import TPProblem from openequivariance._torch import extlib import torch from openequivariance.core.utils import torch_to_oeq_dtype @@ -227,7 +227,7 @@ def register_autocast(): "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 ) - -register_torch_fakes() -register_autograd() -register_autocast() +if extlib.BUILT_EXTENSION: + register_torch_fakes() + register_autograd() + register_autocast() diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 5788f2f..848b7ac 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -6,6 +6,7 @@ from openequivariance._torch.extlib import ( postprocess_kernel, DeviceProp, + BUILT_EXTENSION ) from openequivariance.core.ConvolutionBase import ( @@ -14,7 +15,7 @@ ) from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance._torch.TensorProduct import TensorProduct -from openequivariance import TPProblem +from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance._torch.utils import ( reorder_torch, @@ -403,9 +404,10 @@ def register_autocast(): ) -register_torch_fakes() -register_autograd() -register_autocast() +if BUILT_EXTENSION: + register_torch_fakes() + register_autograd() + register_autocast() # ================================================================== diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index a7b4b86..8cf2741 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -14,15 +14,13 @@ BUILT_EXTENSION = False BUILT_EXTENSION_ERROR = None -TORCH_COMPILE = False -TORCH_COMPILE_ERROR = None - LINKED_LIBPYTHON = False LINKED_LIBPYTHON_ERROR = None -torch_module, generic_module = None, None +extension_module = None postprocess_kernel = lambda kernel: kernel # noqa : E731 +# Locate libpython (required for AOTI) try: python_lib_dir = sysconfig.get_config_var("LIBDIR") major, minor = sys.version_info.major, sys.version_info.minor @@ -39,122 +37,94 @@ except Exception as e: LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" +assert torch.version.cuda or torch.version.hip, "Only CUDA and HIP backends are supported" -if BUILT_EXTENSION: - import openequivariance._torch.extlib.generic_module - - generic_module = openequivariance._torch.extlib.generic_module +try: + from torch.utils.cpp_extension import library_paths, include_paths -elif torch.version.cuda or torch.version.hip: - try: - from torch.utils.cpp_extension import library_paths, include_paths + extra_cflags = ["-O3"] + torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] - extra_cflags = ["-O3"] - generic_sources = ["generic_module.cpp"] - torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] + include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) - include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) + if LINKED_LIBPYTHON: + extra_link_args.pop() + extra_link_args.extend( + [ + f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", + f"-L{python_lib_dir}", + f"-l{python_lib_name}", + ], + ) + if torch.version.cuda: + extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) - if LINKED_LIBPYTHON: - extra_link_args.pop() - extra_link_args.extend( - [ - f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", - f"-L{python_lib_dir}", - f"-l{python_lib_name}", - ], - ) - if torch.version.cuda: - extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) - - try: - torch_libs, cuda_libs = library_paths("cuda") - extra_link_args.append("-Wl,-rpath," + torch_libs) - extra_link_args.append("-L" + cuda_libs) - if os.path.exists(cuda_libs + "/stubs"): - extra_link_args.append("-L" + cuda_libs + "/stubs") - except Exception as e: - getLogger().info(str(e)) - - extra_cflags.append("-DCUDA_BACKEND") - elif torch.version.hip: - extra_link_args.extend(["-lhiprtc"]) - torch_libs = library_paths("cuda")[0] + try: + torch_libs, cuda_libs = library_paths("cuda") extra_link_args.append("-Wl,-rpath," + torch_libs) - - def postprocess(kernel): - kernel = kernel.replace("__syncwarp();", "__threadfence_block();") - kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") - kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") - return kernel - - postprocess_kernel = postprocess - - extra_cflags.append("-DHIP_BACKEND") - - generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] - torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] - include_dirs = [ - oeq_root + "/extension/" + d for d in include_dirs - ] + include_paths("cuda") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - try: - torch_module = torch.utils.cpp_extension.load( - "libtorch_tp_jit", - torch_sources, - extra_cflags=extra_cflags, - extra_include_paths=include_dirs, - extra_ldflags=extra_link_args, - ) - torch.ops.load_library(torch_module.__file__) - TORCH_COMPILE = True - except Exception as e: - # If compiling torch fails (e.g. low gcc version), we should fall back to the - # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). - TORCH_COMPILE_ERROR = e - - generic_module = torch.utils.cpp_extension.load( - "generic_module", - generic_sources, + extra_link_args.append("-L" + cuda_libs) + if os.path.exists(cuda_libs + "/stubs"): + extra_link_args.append("-L" + cuda_libs + "/stubs") + except Exception as e: + getLogger().info(str(e)) + + extra_cflags.append("-DCUDA_BACKEND") + elif torch.version.hip: + extra_link_args.extend(["-lhiprtc"]) + torch_libs = library_paths("cuda")[0] + extra_link_args.append("-Wl,-rpath," + torch_libs) + + def postprocess(kernel): + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") + return kernel + + postprocess_kernel = postprocess + + extra_cflags.append("-DHIP_BACKEND") + + torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] + include_dirs = [ + oeq_root + "/extension/" + d for d in include_dirs + ] + include_paths("cuda") + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + try: + extension_module = torch.utils.cpp_extension.load( + "libtorch_tp_jit", + torch_sources, extra_cflags=extra_cflags, extra_include_paths=include_dirs, extra_ldflags=extra_link_args, ) - if "generic_module" not in sys.modules: - sys.modules["generic_module"] = generic_module - - if not TORCH_COMPILE: - warnings.warn( - "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" - + f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}" - ) - BUILT_EXTENSION = True - except Exception as e: - BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}" -else: - BUILT_EXTENSION_ERROR = "OpenEquivariance extension build not attempted" - - -def _raise_import_error_helper(import_target: str): - if not BUILT_EXTENSION: - raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") + torch.ops.load_library(extension_module.__file__) + BUILT_EXTENSION = True + except Exception as e: + # If compiling torch fails (e.g. low gcc version), we should fall back to the + # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). + BUILT_EXTENSION_ERROR = e +except Exception as e: + BUILT_EXTENSION_ERROR = f"Error JIT-compiling OpenEquivariance Extension: {e}" def torch_ext_so_path(): - return torch_module.__file__ + return extension_module.__file__ if BUILT_EXTENSION: - from generic_module import ( - GroupMM_F32, - GroupMM_F64, + from extension_module import ( + #GroupMM_F32, + #GroupMM_F64, DeviceProp, GPUTimer, ) else: + def _raise_import_error_helper(import_target: str): + if not BUILT_EXTENSION: + raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") def GroupMM_F32(*args, **kwargs): _raise_import_error_helper("GroupMM_F32") diff --git a/openequivariance/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp deleted file mode 100644 index b099699..0000000 --- a/openequivariance/openequivariance/extension/generic_module.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include -#include -#include -#include -#include - -#ifdef CUDA_BACKEND - #include "backend_cuda.hpp" - #include "group_mm_cuda.hpp" - using JITKernel = CUJITKernel; - using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; -#endif - -#ifdef HIP_BACKEND - #include "backend_hip.hpp" - #include "group_mm_hip.hpp" - using JITKernel = HIPJITKernel; - using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; -#endif - -#include "tensorproducts.hpp" -#include "convolution.hpp" - -using namespace std; -namespace py=pybind11; - -PYBIND11_MODULE(generic_module, m) { - py::class_>(m, "GroupMM_F32") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - py::class_>(m, "GroupMM_F64") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - - py::class_(m, "DeviceProp") - .def(py::init()) - .def_readonly("name", &DeviceProp::name) - .def_readonly("warpsize", &DeviceProp::warpsize) - .def_readonly("major", &DeviceProp::major) - .def_readonly("minor", &DeviceProp::minor) - .def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount) - .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); - - py::class_(m, "GPUTimer") - .def(py::init<>()) - .def("start", &GPUTimer::start) - .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) - .def("clear_L2_cache", &GPUTimer::clear_L2_cache); -} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index b2cab3d..ce9d8d2 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -101,5 +101,3 @@ Stream get_current_stream() { return c10::hip::getCurrentHIPStream(); #endif } - -PYBIND11_MODULE(libtorch_tp_jit, m) {} diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index c1f5899..060a792 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ "setuptools", "ninja", "jinja2", - "numpy" + "numpy", + "nanobind" ] readme = "README.md" @@ -60,7 +61,6 @@ dev = [ ] jax = [ - "nanobind", "scikit-build-core", "setuptools-scm" ] @@ -68,13 +68,18 @@ jax = [ [tool.setuptools.packages.find] include = ["openequivariance*"] -[tool.setuptools_scm] -root = ".." - [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", ] [tool.ruff] -lint.ignore = ["E741"] \ No newline at end of file +lint.ignore = ["E741"] + +[tool.scikit-build] +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +sdist.include = ["openequivariance/_version.py"] + +[tool.setuptools_scm] +write_to = "openequivariance/_version.py" +root = ".." \ No newline at end of file diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index 74f9627..c0e3aec 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,8 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.2.1" +dynamic = ["version"] + authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, @@ -41,7 +42,12 @@ issues = "https://github.com/PASSIONLab/OpenEquivariance/issues" JAX_HIP = {env="JAX_HIP", default="0"} XLA_DIRECT_DOWNLOAD = {env="XLA_DIRECT_DOWNLOAD", default="0"} +[tool.scikit-build] +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +sdist.include = ["openequivariance_extjax/_version.py"] + [tool.setuptools_scm] +write_to = "openequivariance_extjax/_version.py" root = ".." [tool.pytest.ini_options] diff --git a/tests/import_test.py b/tests/import_test.py index a472841..bf26af3 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -14,10 +14,3 @@ def test_extension_built(): assert BUILT_EXTENSION_ERROR is None assert BUILT_EXTENSION - - -def test_torch_extension_built(): - from openequivariance import TORCH_COMPILE, TORCH_COMPILE_ERROR - - assert TORCH_COMPILE_ERROR is None - assert TORCH_COMPILE From ad3cb6208e3d79a87ec5521d9da6674b7d30cd42 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 14 Feb 2026 20:47:34 -0800 Subject: [PATCH 08/28] Back to a working state. --- .../_torch/extlib/__init__.py | 23 ++++++++----------- .../extension/libtorch_tp_jit.cpp | 18 +++++++++++++++ .../extension/libtorch_tp_jit_stable.cpp | 18 +++++++++++++++ .../openequivariance/extension/torch_core.hpp | 18 --------------- 4 files changed, 46 insertions(+), 31 deletions(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 8cf2741..cc27f1a 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -18,7 +18,14 @@ LINKED_LIBPYTHON_ERROR = None extension_module = None -postprocess_kernel = lambda kernel: kernel # noqa : E731 + +assert torch.version.cuda or torch.version.hip, "Only CUDA and HIP backends are supported" +def postprocess_kernel(kernel): + if torch.version.hip: + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") + return kernel # Locate libpython (required for AOTI) try: @@ -37,8 +44,6 @@ except Exception as e: LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" -assert torch.version.cuda or torch.version.hip, "Only CUDA and HIP backends are supported" - try: from torch.utils.cpp_extension import library_paths, include_paths @@ -73,15 +78,6 @@ extra_link_args.extend(["-lhiprtc"]) torch_libs = library_paths("cuda")[0] extra_link_args.append("-Wl,-rpath," + torch_libs) - - def postprocess(kernel): - kernel = kernel.replace("__syncwarp();", "__threadfence_block();") - kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") - kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") - return kernel - - postprocess_kernel = postprocess - extra_cflags.append("-DHIP_BACKEND") torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] @@ -113,9 +109,10 @@ def postprocess(kernel): def torch_ext_so_path(): return extension_module.__file__ +sys.modules["oeq_utilities"] = extension_module if BUILT_EXTENSION: - from extension_module import ( + from oeq_utilities import ( #GroupMM_F32, #GroupMM_F64, DeviceProp, diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index ce9d8d2..f432ae4 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -101,3 +101,21 @@ Stream get_current_stream() { return c10::hip::getCurrentHIPStream(); #endif } + +namespace py=pybind11; +PYBIND11_MODULE(libtorch_tp_jit, m) { + py::class_(m, "DeviceProp") + .def(py::init()) + .def_readonly("name", &DeviceProp::name) + .def_readonly("warpsize", &DeviceProp::warpsize) + .def_readonly("major", &DeviceProp::major) + .def_readonly("minor", &DeviceProp::minor) + .def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount) + .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + py::class_(m, "GPUTimer") + .def(py::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); +} diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 8cf28ce..6d63c8a 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -96,4 +96,22 @@ Stream get_current_stream() { aoti_torch_get_current_stream(device_index, &stream_ptr)) //return (Stream) stream_ptr; return (Stream) 0; +} + +namespace nb = nanobind; +NB_MODULE(litbtorch_tp_jit_stable, m) { + nb::class_(m, "DeviceProp") + .def(nb::init()) + .def_ro("name", &DeviceProp::name) + .def_ro("warpsize", &DeviceProp::warpsize) + .def_ro("major", &DeviceProp::major) + .def_ro("minor", &DeviceProp::minor) + .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount) + .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + nb::class_(m, "GPUTimer") + .def(nb::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); } \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index f36a280..36a634a 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -37,7 +37,6 @@ using namespace std; using json = json11::Json; -namespace nb = nanobind; Dtype enum_to_torch_dtype(int64_t i); @@ -649,20 +648,3 @@ REGISTER_LIBRARY(libtorch_tp_jit, m) { m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); }; - -NB_MODULE(litbtorch_tp_jit_stable, m) { - nb::class_(m, "DeviceProp") - .def(nb::init()) - .def_ro("name", &DeviceProp::name) - .def_ro("warpsize", &DeviceProp::warpsize) - .def_ro("major", &DeviceProp::major) - .def_ro("minor", &DeviceProp::minor) - .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount) - .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); - - nb::class_(m, "GPUTimer") - .def(nb::init<>()) - .def("start", &GPUTimer::start) - .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) - .def("clear_L2_cache", &GPUTimer::clear_L2_cache); -} \ No newline at end of file From 66514690585a1057b197e743299e0a09e966e6ec Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 14 Feb 2026 21:35:29 -0800 Subject: [PATCH 09/28] Ready to begin the import testing process. --- .../_torch/extlib/__init__.py | 134 ++++++++++-------- 1 file changed, 78 insertions(+), 56 deletions(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index cc27f1a..28ec7a6 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -4,6 +4,7 @@ import warnings import sysconfig from pathlib import Path +from packaging.version import Version import torch @@ -20,6 +21,7 @@ extension_module = None assert torch.version.cuda or torch.version.hip, "Only CUDA and HIP backends are supported" + def postprocess_kernel(kernel): if torch.version.hip: kernel = kernel.replace("__syncwarp();", "__threadfence_block();") @@ -44,66 +46,86 @@ def postprocess_kernel(kernel): except Exception as e: LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" -try: - from torch.utils.cpp_extension import library_paths, include_paths - - extra_cflags = ["-O3"] - torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] +def jit_compile_extension(): + global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, extension_module + try: + from torch.utils.cpp_extension import library_paths, include_paths - include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) + extra_cflags = ["-O3"] + torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] - if LINKED_LIBPYTHON: - extra_link_args.pop() - extra_link_args.extend( - [ - f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", - f"-L{python_lib_dir}", - f"-l{python_lib_name}", - ], - ) - if torch.version.cuda: - extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) + include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) - try: - torch_libs, cuda_libs = library_paths("cuda") - extra_link_args.append("-Wl,-rpath," + torch_libs) - extra_link_args.append("-L" + cuda_libs) - if os.path.exists(cuda_libs + "/stubs"): - extra_link_args.append("-L" + cuda_libs + "/stubs") - except Exception as e: - getLogger().info(str(e)) - - extra_cflags.append("-DCUDA_BACKEND") - elif torch.version.hip: - extra_link_args.extend(["-lhiprtc"]) - torch_libs = library_paths("cuda")[0] - extra_link_args.append("-Wl,-rpath," + torch_libs) - extra_cflags.append("-DHIP_BACKEND") - - torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] - include_dirs = [ - oeq_root + "/extension/" + d for d in include_dirs - ] + include_paths("cuda") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - try: - extension_module = torch.utils.cpp_extension.load( - "libtorch_tp_jit", - torch_sources, - extra_cflags=extra_cflags, - extra_include_paths=include_dirs, - extra_ldflags=extra_link_args, + if LINKED_LIBPYTHON: + extra_link_args.pop() + extra_link_args.extend( + [ + f"-Wl,--no-as-needed,-rpath,{python_lib_dir}", + f"-L{python_lib_dir}", + f"-l{python_lib_name}", + ], ) - torch.ops.load_library(extension_module.__file__) - BUILT_EXTENSION = True - except Exception as e: - # If compiling torch fails (e.g. low gcc version), we should fall back to the - # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). - BUILT_EXTENSION_ERROR = e -except Exception as e: - BUILT_EXTENSION_ERROR = f"Error JIT-compiling OpenEquivariance Extension: {e}" + if torch.version.cuda: + extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"]) + + try: + torch_libs, cuda_libs = library_paths("cuda") + extra_link_args.append("-Wl,-rpath," + torch_libs) + extra_link_args.append("-L" + cuda_libs) + if os.path.exists(cuda_libs + "/stubs"): + extra_link_args.append("-L" + cuda_libs + "/stubs") + except Exception as e: + getLogger().info(str(e)) + + extra_cflags.append("-DCUDA_BACKEND") + elif torch.version.hip: + extra_link_args.extend(["-lhiprtc"]) + torch_libs = library_paths("cuda")[0] + extra_link_args.append("-Wl,-rpath," + torch_libs) + extra_cflags.append("-DHIP_BACKEND") + + torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] + include_dirs = [ + oeq_root + "/extension/" + d for d in include_dirs + ] + include_paths("cuda") + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + try: + extension_module = torch.utils.cpp_extension.load( + "libtorch_tp_jit", + torch_sources, + extra_cflags=extra_cflags, + extra_include_paths=include_dirs, + extra_ldflags=extra_link_args, + ) + torch.ops.load_library(extension_module.__file__) + BUILT_EXTENSION = True + except Exception as e: + # If compiling torch fails (e.g. low gcc version), we should fall back to the + # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). + BUILT_EXTENSION_ERROR = e + except Exception as e: + BUILT_EXTENSION_ERROR = f"Error JIT-compiling OpenEquivariance Extension: {e}" + +def use_precompiled_extension(): + global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, extension_module + try: + if torch.version.cuda: + import openequivariance._torch.extlib.libtorch_tp_jit_stable as extension_module + elif torch.version.hip: + import openequivariance._torch.extlib.libtorch_tp_jit_stable_hip as extension_module + + torch.ops.load_library(extension_module.__file__) + BUILT_EXTENSION = True + except Exception as e: + BUILT_EXTENSION_ERROR = f"Error loading precompiled OpenEquivariance Extension: {e}" + +if Version(torch.__version__) > Version("2.9.0"): + use_precompiled_extension() +else: + jit_compile_extension() def torch_ext_so_path(): From 991f19b16d8208108d3eadd026c19fcb7d494e1c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 15 Feb 2026 14:13:44 -0800 Subject: [PATCH 10/28] Temporary save. --- .gitignore | 1 + openequivariance/CMakeLists.txt | 13 +++++++++++-- .../openequivariance/_torch/extlib/__init__.py | 6 +++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 07455d1..fe768ba 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__ # working folders dist build +cbuild outputs/* visualization/* figures/* diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index ddaffb5..c57425c 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -16,10 +16,15 @@ FetchContent_MakeAvailable(libtorch) set(LIBTORCH_INCLUDE_DIR "${libtorch_SOURCE_DIR}/include") set(LIBTORCH_LIB_DIR "${libtorch_SOURCE_DIR}/lib") +find_library(TORCH_CPU_LIB NAMES torch_cpu PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH) +find_library(C10_LIB NAMES c10 PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH) message(STATUS "LibTorch Include: ${LIBTORCH_INCLUDE_DIR}") message(STATUS "LibTorch Lib: ${LIBTORCH_LIB_DIR}") +message(STATUS "Torch CPU Library: ${TORCH_CPU_LIB}") +message(STATUS "Torch C10 Library: ${C10_LIB}") + # Setup Nanobind execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir @@ -60,7 +65,11 @@ function(add_stable_extension target_name backend_define) ${EXT_JSON_DIR} ${LIBTORCH_INCLUDE_DIR} ${LIBTORCH_API_INCLUDE_DIR} - ) + ) + target_link_libraries(${target_name} PRIVATE + ${TORCH_CPU_LIB} + ${C10_LIB} + ) endfunction() find_package(CUDAToolkit QUIET) @@ -69,7 +78,7 @@ find_package(hip QUIET) if(CUDAToolkit_FOUND) message(STATUS "Building stable extension with CUDA backend.") add_stable_extension(torch_tp_jit_stable_cuda CUDA_BACKEND) - target_link_libraries(torch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver) + target_link_libraries(torch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver CUDA::nvrtc) install(TARGETS torch_tp_jit_stable_cuda LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") endif() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 28ec7a6..04f8356 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -113,16 +113,16 @@ def use_precompiled_extension(): global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, extension_module try: if torch.version.cuda: - import openequivariance._torch.extlib.libtorch_tp_jit_stable as extension_module + import openequivariance._torch.extlib.torch_tp_jit_stable_cuda as extension_module elif torch.version.hip: - import openequivariance._torch.extlib.libtorch_tp_jit_stable_hip as extension_module + import openequivariance._torch.extlib.torch_tp_jit_stable_hip as extension_module torch.ops.load_library(extension_module.__file__) BUILT_EXTENSION = True except Exception as e: BUILT_EXTENSION_ERROR = f"Error loading precompiled OpenEquivariance Extension: {e}" -if Version(torch.__version__) > Version("2.9.0"): +if Version(torch.__version__) > Version("2.9.9"): use_precompiled_extension() else: jit_compile_extension() From a430a022c3d86603c2cac9e5ebde775fff50892e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 15 Feb 2026 15:05:12 -0800 Subject: [PATCH 11/28] Fixed some more details about the C++ backend. --- .../extension/libtorch_tp_jit.cpp | 20 -------- .../extension/libtorch_tp_jit_stable.cpp | 30 ++++-------- .../openequivariance/extension/torch_core.hpp | 48 ++++++++----------- 3 files changed, 30 insertions(+), 68 deletions(-) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index f432ae4..698a142 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -50,26 +50,6 @@ void tensor_zero_(Tensor &tensor) { tensor.zero_(); } -caffe2::TypeMeta tensor_dtype(const Tensor &tensor) { - return tensor.dtype(); -} - -bool tensor_is_cuda(const Tensor &tensor) { - return tensor.device().is_cuda(); -} - -int64_t tensor_dim(const Tensor &tensor) { - return tensor.dim(); -} - -int64_t tensor_size(const Tensor &tensor, int64_t dim) { - return tensor.size(dim); -} - -int64_t tensor_numel(const Tensor &tensor) { - return tensor.numel(); -} - void alert_not_deterministic(const char *name) { at::globalContext().alertNotDeterministic(name); } diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 6d63c8a..9607c87 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -54,26 +55,6 @@ void tensor_zero_(Tensor &tensor) { torch::stable::zero_(tensor); } -/*Dtype tensor_dtype(const Tensor &tensor) { - return tensor.scalar_type(); -}*/ - -bool tensor_is_cuda(const Tensor &tensor) { - return tensor.is_cuda(); -} - -int64_t tensor_dim(const Tensor &tensor) { - return tensor.dim(); -} - -int64_t tensor_size(const Tensor &tensor, int64_t dim) { - return tensor.size(dim); -} - -int64_t tensor_numel(const Tensor &tensor) { - return tensor.numel(); -} - void alert_not_deterministic(const char *name) { (void)name; } @@ -98,8 +79,15 @@ Stream get_current_stream() { return (Stream) 0; } +#ifdef CUDA_BACKEND + #define EXTENSION_NAME torch_tp_jit_stable_cuda +#endif +#ifdef HIP_BACKEND + #define EXTENSION_NAME torch_tp_jit_stable_hip +#endif + namespace nb = nanobind; -NB_MODULE(litbtorch_tp_jit_stable, m) { +NB_MODULE(EXTENSION_NAME, m) { nb::class_(m, "DeviceProp") .def(nb::init()) .def_ro("name", &DeviceProp::name) diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index 36a634a..8e9e43e 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -46,12 +46,6 @@ Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes); Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes); void tensor_zero_(Tensor &tensor); -//Dtype tensor_dtype(const Tensor &tensor); -bool tensor_is_cuda(const Tensor &tensor); -int64_t tensor_dim(const Tensor &tensor); -int64_t tensor_size(const Tensor &tensor, int64_t dim); -int64_t tensor_numel(const Tensor &tensor); - void alert_not_deterministic(const char *name); Stream get_current_stream(); @@ -76,23 +70,23 @@ inline std::string shape_to_string(std::initializer_list shape) { inline std::string tensor_sizes_str(const Tensor &tensor) { std::ostringstream oss; oss << "["; - int64_t dims = tensor_dim(tensor); + int64_t dims = tensor.dim(); for (int64_t i = 0; i < dims; ++i) { if (i > 0) { oss << ", "; } - oss << tensor_size(tensor, i); + oss << tensor.size(i); } oss << "]"; return oss.str(); } inline std::vector tensor_sizes_vec(const Tensor &tensor) { - int64_t dims = tensor_dim(tensor); + int64_t dims = tensor.dim(); std::vector sizes; sizes.reserve(static_cast(dims)); for (int64_t i = 0; i < dims; ++i) { - sizes.push_back(tensor_size(tensor, i)); + sizes.push_back(tensor.size(i)); } return sizes; } @@ -116,11 +110,11 @@ inline void check_tensor(const Tensor &tensor, std::initializer_list expected_shape, Dtype expected_dtype, std::string tensor_name) { - bool shape_ok = (tensor_dim(tensor) == static_cast(expected_shape.size())); + bool shape_ok = (tensor.dim() == static_cast(expected_shape.size())); if (shape_ok) { int64_t i = 0; for (int64_t dim : expected_shape) { - if (tensor_size(tensor, i) != dim) { + if (tensor.size(i) != dim) { shape_ok = false; break; } @@ -132,11 +126,11 @@ inline void check_tensor(const Tensor &tensor, "Shape mismatch for tensor '", tensor_name, "'. Expected: ", shape_to_string(expected_shape), ". Got: ", tensor_sizes_str(tensor)); - TCHECK(tensor_is_cuda(tensor), "Tensor '", tensor_name, "' is not on the GPU."); - /*TCHECK(tensor_dtype(tensor) == expected_dtype, + TCHECK(tensor.is_cuda(), "Tensor '", tensor_name, "' is not on the GPU."); + TCHECK(tensor.scalar_type() == expected_dtype, "Dtype mismatch for tensor '", tensor_name, "'. Expected: ", static_cast(expected_dtype), - ". Got: ", static_cast(tensor_dtype(tensor)));*/ + ". Got: ", static_cast(tensor.scalar_type())); } inline std::unordered_map parse_json_config(const json &j_obj) { @@ -207,7 +201,7 @@ inline std::pair*, KernelProp> Tensor cpu_tensor = tensor_to_cpu_contiguous(json_bytes); std::string json_payload( reinterpret_cast(tensor_data_ptr_u8(cpu_tensor)), - tensor_numel(cpu_tensor) + cpu_tensor.numel() ); std::string err; @@ -246,7 +240,7 @@ inline std::pair*, KernelProp> Tensor cpu_tensor = tensor_to_cpu_contiguous(json_bytes); std::string json_payload( reinterpret_cast(tensor_data_ptr_u8(cpu_tensor)), - tensor_numel(cpu_tensor) + cpu_tensor.numel() ); std::string err; @@ -287,7 +281,7 @@ inline Tensor jit_tp_forward( auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); - const int64_t num_batch = tensor_size(L1_in, 0); + const int64_t num_batch = L1_in.size(0); check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -325,7 +319,7 @@ inline tuple jit_tp_backward( auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); - const int64_t num_batch = tensor_size(L1_in, 0); + const int64_t num_batch = L1_in.size(0); check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -373,7 +367,7 @@ inline tuple jit_tp_double_backward( auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); - const int64_t num_batch = tensor_size(L1_in, 0); + const int64_t num_batch = L1_in.size(0); check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -405,7 +399,7 @@ inline tuple jit_tp_double_backward( if (k.shared_weights) { tensor_zero_(W_grad); - TCHECK(tensor_dim(W) == 1); + TCHECK(W.dim() == 1); } jit_kernel->double_backward( @@ -439,8 +433,8 @@ inline Tensor jit_conv_forward( auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); - const int64_t nnz = tensor_size(rows, 0); - const int64_t node_count = tensor_size(L1_in, 0); + const int64_t nnz = rows.size(0); + const int64_t node_count = L1_in.size(0); check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -495,8 +489,8 @@ inline tuple jit_conv_backward( auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); - const int64_t nnz = tensor_size(rows, 0); - const int64_t node_count = tensor_size(L1_in, 0); + const int64_t nnz = rows.size(0); + const int64_t node_count = L1_in.size(0); check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -564,8 +558,8 @@ inline tuple jit_conv_double_backward( auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); - const int64_t nnz = tensor_size(rows, 0); - const int64_t node_count = tensor_size(L1_in, 0); + const int64_t nnz = rows.size(0); + const int64_t node_count = L1_in.size(0); check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); From 254ac479a59e7bb86a66473716202aacc2e278ff Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 15 Feb 2026 23:46:20 -0800 Subject: [PATCH 12/28] Even more things working. --- openequivariance/CMakeLists.txt | 26 ++++++++++++++++--- .../_torch/extlib/__init__.py | 2 +- .../{util => backend}/backend_cuda.hpp | 0 .../{util => backend}/backend_hip.hpp | 0 .../extension/libtorch_tp_jit.cpp | 1 + .../extension/libtorch_tp_jit_stable.cpp | 20 +++++++------- .../extension/stubs/libtorch_cuda.cpp | 3 +++ openequivariance_extjax/CMakeLists.txt | 4 +-- openequivariance_extjax/src/libjax_tp_jit.cpp | 4 +-- 9 files changed, 42 insertions(+), 18 deletions(-) rename openequivariance/openequivariance/extension/{util => backend}/backend_cuda.hpp (100%) rename openequivariance/openequivariance/extension/{util => backend}/backend_hip.hpp (100%) create mode 100644 openequivariance/openequivariance/extension/stubs/libtorch_cuda.cpp diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index c57425c..cf619db 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -35,7 +35,7 @@ message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") find_package(nanobind CONFIG REQUIRED) set(EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/extension") -set(EXT_UTIL_DIR "${EXT_DIR}/util") +set(EXT_BACKEND_DIR "${EXT_DIR}/backend") set(EXT_JSON_DIR "${EXT_DIR}/json11") # Source files @@ -61,10 +61,9 @@ function(add_stable_extension target_name backend_define) target_include_directories(${target_name} PRIVATE ${EXT_DIR} - ${EXT_UTIL_DIR} + ${EXT_BACKEND_DIR} ${EXT_JSON_DIR} ${LIBTORCH_INCLUDE_DIR} - ${LIBTORCH_API_INCLUDE_DIR} ) target_link_libraries(${target_name} PRIVATE ${TORCH_CPU_LIB} @@ -77,8 +76,27 @@ find_package(hip QUIET) if(CUDAToolkit_FOUND) message(STATUS "Building stable extension with CUDA backend.") + + add_library(cuda_stub_lib SHARED ${EXT_DIR}/stubs/libtorch_cuda.cpp) + + target_include_directories(cuda_stub_lib PRIVATE + ${LIBTORCH_INCLUDE_DIR} + ) + + set_target_properties(cuda_stub_lib PROPERTIES + OUTPUT_NAME "torch_cuda" + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) + add_stable_extension(torch_tp_jit_stable_cuda CUDA_BACKEND) - target_link_libraries(torch_tp_jit_stable_cuda PRIVATE CUDA::cudart CUDA::cuda_driver CUDA::nvrtc) + + target_link_libraries(torch_tp_jit_stable_cuda PRIVATE + CUDA::cudart + CUDA::cuda_driver + CUDA::nvrtc + cuda_stub_lib + ) install(TARGETS torch_tp_jit_stable_cuda LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") endif() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 04f8356..e1adc99 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -54,7 +54,7 @@ def jit_compile_extension(): extra_cflags = ["-O3"] torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] - include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) + include_dirs, extra_link_args = (["backend"], ["-Wl,--no-as-needed"]) if LINKED_LIBPYTHON: extra_link_args.pop() diff --git a/openequivariance/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/backend/backend_cuda.hpp similarity index 100% rename from openequivariance/openequivariance/extension/util/backend_cuda.hpp rename to openequivariance/openequivariance/extension/backend/backend_cuda.hpp diff --git a/openequivariance/openequivariance/extension/util/backend_hip.hpp b/openequivariance/openequivariance/extension/backend/backend_hip.hpp similarity index 100% rename from openequivariance/openequivariance/extension/util/backend_hip.hpp rename to openequivariance/openequivariance/extension/backend/backend_hip.hpp diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 698a142..3eb5c57 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -13,6 +13,7 @@ #include #include #include +#include using Tensor = torch::Tensor; using Dtype = torch::Dtype; diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 9607c87..2f61c26 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -1,3 +1,5 @@ +#define USE_CUDA + #include #include #include @@ -68,15 +70,15 @@ void *data_ptr(const Tensor &tensor) { } Stream get_current_stream() { - int32_t device_index; - StreamOpaque* stream_ptr; - - TORCH_ERROR_CODE_CHECK( - aoti_torch_get_current_device_index(&device_index)) - TORCH_ERROR_CODE_CHECK( - aoti_torch_get_current_stream(device_index, &stream_ptr)) - //return (Stream) stream_ptr; - return (Stream) 0; + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + + #ifdef CUDA_BACKEND + return static_cast(stream_ptr); + #elif defined(HIP_BACKEND) + return static_cast(stream_ptr); + #endif } #ifdef CUDA_BACKEND diff --git a/openequivariance/openequivariance/extension/stubs/libtorch_cuda.cpp b/openequivariance/openequivariance/extension/stubs/libtorch_cuda.cpp new file mode 100644 index 0000000..9a69eb8 --- /dev/null +++ b/openequivariance/openequivariance/extension/stubs/libtorch_cuda.cpp @@ -0,0 +1,3 @@ +#include + +extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); \ No newline at end of file diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 90eafe6..c815c39 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -58,8 +58,8 @@ set(OEQ_JAX_SOURCES set(OEQ_JAX_HEADERS ${HEADER_DIR}/convolution.hpp ${HEADER_DIR}/tensorproducts.hpp - ${HEADER_DIR}/util/backend_cuda.hpp - ${HEADER_DIR}/util/backend_hip.hpp + ${HEADER_DIR}/backend/backend_cuda.hpp + ${HEADER_DIR}/backend/backend_hip.hpp ${HEADER_DIR}/json11/json11.hpp ) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index bf3fd69..87e5e78 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -18,7 +18,7 @@ using json = json11::Json; #include #include - #include "util/backend_cuda.hpp" + #include "backend/backend_cuda.hpp" #include "group_mm_cuda.hpp" using JITKernel = CUJITKernel; using GPU_Allocator = CUDA_Allocator; @@ -29,7 +29,7 @@ using json = json11::Json; #endif #ifdef HIP_BACKEND - #include "util/backend_hip.hpp" + #include "backend/backend_hip.hpp" #include "group_mm_hip.hpp" using JITKernel = HIPJITKernel; using GPU_Allocator = HIP_Allocator; From 3ea9aefb6281cc131890ce8b528492ff33eeee7b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 16 Feb 2026 21:02:13 -0800 Subject: [PATCH 13/28] Ready to test on HIP. --- openequivariance/CMakeLists.txt | 17 +++++++++++++++-- .../stubs/{libtorch_cuda.cpp => stream.cpp} | 0 2 files changed, 15 insertions(+), 2 deletions(-) rename openequivariance/openequivariance/extension/stubs/{libtorch_cuda.cpp => stream.cpp} (100%) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index cf619db..b26353b 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -77,7 +77,7 @@ find_package(hip QUIET) if(CUDAToolkit_FOUND) message(STATUS "Building stable extension with CUDA backend.") - add_library(cuda_stub_lib SHARED ${EXT_DIR}/stubs/libtorch_cuda.cpp) + add_library(cuda_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp) target_include_directories(cuda_stub_lib PRIVATE ${LIBTORCH_INCLUDE_DIR} @@ -103,8 +103,21 @@ endif() if(hip_FOUND) message(STATUS "Building stable extension with HIP backend.") + + add_library(hip_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp) + + target_include_directories(hip_stub_lib PRIVATE + ${LIBTORCH_INCLUDE_DIR} + ) + + set_target_properties(hip_stub_lib PROPERTIES + OUTPUT_NAME "torch_hip" + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) + add_stable_extension(torch_tp_jit_stable_hip HIP_BACKEND) - target_link_libraries(torch_tp_jit_stable_hip PRIVATE hiprtc) + target_link_libraries(torch_tp_jit_stable_hip PRIVATE hiprtc hip_stub_lib) install(TARGETS torch_tp_jit_stable_hip LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") endif() diff --git a/openequivariance/openequivariance/extension/stubs/libtorch_cuda.cpp b/openequivariance/openequivariance/extension/stubs/stream.cpp similarity index 100% rename from openequivariance/openequivariance/extension/stubs/libtorch_cuda.cpp rename to openequivariance/openequivariance/extension/stubs/stream.cpp From 1c030d12b99c83a6caa8506518fe7293c8fcb992 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 16 Feb 2026 21:03:41 -0800 Subject: [PATCH 14/28] Minor comment fix. --- openequivariance/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index b26353b..cfe1eee 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -53,7 +53,7 @@ function(add_stable_extension target_name backend_define) POSITION_INDEPENDENT_CODE ON ) - # CRITICAL: Manually enforce CXX11 ABI to match the downloaded LibTorch + # Enforce CXX11 ABI to match LibTorch target_compile_definitions(${target_name} PRIVATE ${backend_define}=1 _GLIBCXX_USE_CXX11_ABI=1 From e7eeb3225266da9fe3b1ba0396572821190c4c31 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 20 Feb 2026 20:56:11 -0800 Subject: [PATCH 15/28] Working on AOTI update. --- .../extension/test/CMakeLists.txt | 10 +- .../extension/test/load_aoti.cpp | 63 +++++++++++++ .../extension/test/load_jitscript.cpp | 62 ------------- tests/export_test.py | 93 +++++++++++++++++++ 4 files changed, 161 insertions(+), 67 deletions(-) create mode 100644 openequivariance/openequivariance/extension/test/load_aoti.cpp delete mode 100644 openequivariance/openequivariance/extension/test/load_jitscript.cpp diff --git a/openequivariance/openequivariance/extension/test/CMakeLists.txt b/openequivariance/openequivariance/extension/test/CMakeLists.txt index 3209d41..d869bec 100644 --- a/openequivariance/openequivariance/extension/test/CMakeLists.txt +++ b/openequivariance/openequivariance/extension/test/CMakeLists.txt @@ -1,9 +1,9 @@ cmake_minimum_required(VERSION 3.5 FATAL_ERROR) -project(test_oeq_jitscript_load) +project(test_oeq_aoti_load) find_package(Torch REQUIRED) -add_executable(load_jitscript load_jitscript.cpp) -target_link_libraries(load_jitscript "${TORCH_LIBRARIES}") -target_link_libraries(load_jitscript -Wl,--no-as-needed "${OEQ_EXTLIB}") -set_property(TARGET load_jitscript PROPERTY CXX_STANDARD 17) \ No newline at end of file +add_executable(load_aoti load_aoti.cpp) +target_link_libraries(load_aoti "${TORCH_LIBRARIES}") +target_link_libraries(load_aoti -Wl,--no-as-needed "${OEQ_EXTLIB}") +set_property(TARGET load_aoti PROPERTY CXX_STANDARD 17) diff --git a/openequivariance/openequivariance/extension/test/load_aoti.cpp b/openequivariance/openequivariance/extension/test/load_aoti.cpp new file mode 100644 index 0000000..0ffc967 --- /dev/null +++ b/openequivariance/openequivariance/extension/test/load_aoti.cpp @@ -0,0 +1,63 @@ +#include +#include +#include + +#include +#include + +/* +* This program takes in two JITScript modules that execute +* a tensor product in FP32 precision. +* The first module is compiled from e3nn, the second is +* OEQ's compiled module. The program checks that the +* two outputs are comparable. +*/ + +int main(int argc, const char* argv[]) { + if (argc != 7) { + std::cerr << "usage: load_aoti " + << " " + << " " + << " " + << " " + << " " + << " " + << std::endl; + + return 1; + } + + c10::InferenceMode mode; + + int64_t L1_dim = std::stoi(argv[3]); + int64_t L2_dim = std::stoi(argv[4]); + int64_t weight_numel = std::stoi(argv[5]); + int64_t batch_size = std::stoi(argv[6]); + + std::vector inputs; + inputs.push_back(torch::randn({batch_size, L1_dim}, at::kCUDA)); + inputs.push_back(torch::randn({batch_size, L2_dim}, at::kCUDA)); + inputs.push_back(torch::randn({batch_size, weight_numel}, at::kCUDA)); + + try { + torch::inductor::AOTIModelPackageLoader module_e3nn(argv[1]); + torch::inductor::AOTIModelPackageLoader module_oeq(argv[2]); + + std::vector output_e3nn = module_e3nn.run(inputs); + std::vector output_oeq = module_oeq.run(inputs); + + for (size_t i = 0; i < output_e3nn.size(); i++) { + if(at::allclose(output_e3nn[i], output_oeq[i], 1e-5, 1e-5)) { + return 0; + } + else { + std::cerr << "torch.allclose returned FALSE comparing model outputs." << std::endl; + return 1; + } + } + } + catch (const c10::Error& e) { + std::cerr << "error loading script module" << std::endl; + return 1; + } +} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/test/load_jitscript.cpp b/openequivariance/openequivariance/extension/test/load_jitscript.cpp deleted file mode 100644 index 6ea7982..0000000 --- a/openequivariance/openequivariance/extension/test/load_jitscript.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include - -#include -#include - -/* -* This program takes in two JITScript modules that execute -* a tensor product in FP32 precision. -* The first module is compiled from e3nn, the second is -* OEQ's compiled module. The program checks that the -* two outputs are comparable. -*/ - -int main(int argc, const char* argv[]) { - if (argc != 7) { - std::cerr << "usage: load_jitscript " - << " " - << " " - << " " - << " " - << " " - << " " - << std::endl; - - return 1; - } - - int64_t L1_dim = std::stoi(argv[3]); - int64_t L2_dim = std::stoi(argv[4]); - int64_t weight_numel = std::stoi(argv[5]); - int64_t batch_size = std::stoi(argv[6]); - - torch::Device device(torch::kCUDA); - std::vector inputs; - inputs.push_back(torch::randn({batch_size, L1_dim}, device)); - inputs.push_back(torch::randn({batch_size, L2_dim}, device)); - inputs.push_back(torch::randn({batch_size, weight_numel}, device)); - - torch::jit::script::Module module_e3nn, module_oeq; - try { - module_e3nn = torch::jit::load(argv[1]); - module_oeq = torch::jit::load(argv[2]); - } - catch (const c10::Error& e) { - std::cerr << "error loading script module" << std::endl; - return 1; - } - - module_e3nn.to(device); - module_oeq.to(device); - - at::Tensor output_e3nn = module_e3nn.forward(inputs).toTensor(); - at::Tensor output_oeq = module_oeq.forward(inputs).toTensor(); - - if(at::allclose(output_e3nn, output_oeq, 1e-5, 1e-5)) { - return 0; - } - else { - std::cerr << "torch.allclose returned FALSE comparing model outputs." << std::endl; - return 1; - } -} \ No newline at end of file diff --git a/tests/export_test.py b/tests/export_test.py index efdaf86..5fd0145 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1,11 +1,17 @@ import torch import pytest import tempfile +import subprocess +import shutil +import os +import sys +import importlib.resources import numpy as np import openequivariance as oeq from torch_geometric import EdgeIndex +from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct @pytest.fixture(scope="session") def problem_and_irreps(): @@ -128,3 +134,90 @@ def test_aoti(tp_and_inputs): aoti_result = aoti_model(*inputs) assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5) + + +def test_aoti_cpp_inference(problem_and_irreps): + assert oeq.LINKED_LIBPYTHON, oeq.LINKED_LIBPYTHON_ERROR + problem, X_ir, Y_ir, _ = problem_and_irreps + cmake_prefix_path = torch.utils.cmake_prefix_path + torch_ext_so_path = oeq.torch_ext_so_path() + + gen = torch.Generator(device="cuda") + gen.manual_seed(0) + batch_size = 1000 + + # Create models + oeq_tp = oeq.TensorProduct(problem).to("cuda") + e3nn_tp = E3NNTensorProduct(problem).e3nn_tp.to("cuda") + + # Prepare inputs for export + X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) + W = torch.rand(batch_size, problem.weight_numel, device="cuda", generator=gen) + inputs = (X, Y, W) + + with ( + tempfile.TemporaryDirectory() as tmpdir, + tempfile.NamedTemporaryFile(suffix=".pt2") as oeq_file, + tempfile.NamedTemporaryFile(suffix=".pt2") as e3nn_file, + ): + # Export and compile with AOTI + exported_oeq = torch.export.export(oeq_tp, args=inputs, strict=False) + torch._inductor.aoti_compile_and_package( + exported_oeq, package_path=oeq_file.name + ) + + exported_e3nn = torch.export.export(e3nn_tp, args=inputs, strict=False) + torch._inductor.aoti_compile_and_package( + exported_e3nn, package_path=e3nn_file.name + ) + + test_path = importlib.resources.files("openequivariance") / "extension" / "test" + build_dir = os.path.join(tmpdir, "build") + os.makedirs(build_dir, exist_ok=True) + + for item in test_path.iterdir(): + shutil.copy(item, tmpdir) + + try: + subprocess.run( + [ + "cmake", + "..", + "-DCMAKE_BUILD_TYPE=Release", + "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, + "-DOEQ_EXTLIB=" + torch_ext_so_path, + ], + cwd=build_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + subprocess.run( + ["make"], + cwd=build_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + subprocess.run( + [ + "./load_aoti", + e3nn_file.name, + oeq_file.name, + str(X_ir.dim), + str(Y_ir.dim), + str(problem.weight_numel), + str(batch_size), + ], + cwd=build_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + except subprocess.CalledProcessError as e: + print(e.stdout.decode(), file=sys.stderr) + print(e.stderr.decode(), file=sys.stderr) + assert False From 174bd44126a563d98e1e24fc94f6c51066b944da Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 20 Feb 2026 22:19:15 -0800 Subject: [PATCH 16/28] AOTI loading works. --- openequivariance/CMakeLists.txt | 64 +++++++++++++---- .../_torch/extlib/__init__.py | 71 ++++++++++++------- .../extension/libtorch_tp_jit_stable.cpp | 42 +++++------ tests/export_test.py | 1 + 4 files changed, 117 insertions(+), 61 deletions(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index cfe1eee..ec56c80 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -44,7 +44,10 @@ set(OEQ_SOURCES ${EXT_JSON_DIR}/json11.cpp ) -function(add_stable_extension target_name backend_define) +set(OEQ_INSTALL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") + +function(add_stable_extension target_name backend_define link_libraries) + # Create nanobind extension nanobind_add_module(${target_name} NB_STATIC ${OEQ_SOURCES}) set_target_properties(${target_name} PROPERTIES @@ -56,7 +59,8 @@ function(add_stable_extension target_name backend_define) # Enforce CXX11 ABI to match LibTorch target_compile_definitions(${target_name} PRIVATE ${backend_define}=1 - _GLIBCXX_USE_CXX11_ABI=1 + _GLIBCXX_USE_CXX11_ABI=1 + INCLUDE_NB_EXTENSION ) target_include_directories(${target_name} PRIVATE @@ -68,7 +72,39 @@ function(add_stable_extension target_name backend_define) target_link_libraries(${target_name} PRIVATE ${TORCH_CPU_LIB} ${C10_LIB} - ) + ${link_libraries} + ) + + install(TARGETS ${target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}") + + # AOTI C++ library (identical except without nanobind and without INCLUDE_NB_EXTENSION) + set(aoti_target_name ${target_name}_aoti) + add_library(${aoti_target_name} SHARED ${OEQ_SOURCES}) + + set_target_properties(${aoti_target_name} PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + ) + + target_compile_definitions(${aoti_target_name} PRIVATE + ${backend_define}=1 + _GLIBCXX_USE_CXX11_ABI=1 + ) + + target_include_directories(${aoti_target_name} PRIVATE + ${EXT_DIR} + ${EXT_BACKEND_DIR} + ${EXT_JSON_DIR} + ${LIBTORCH_INCLUDE_DIR} + ) + target_link_libraries(${aoti_target_name} PRIVATE + ${TORCH_CPU_LIB} + ${C10_LIB} + ${link_libraries} + ) + + install(TARGETS ${aoti_target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}") endfunction() find_package(CUDAToolkit QUIET) @@ -89,16 +125,13 @@ if(CUDAToolkit_FOUND) CXX_STANDARD 17 ) - add_stable_extension(torch_tp_jit_stable_cuda CUDA_BACKEND) - - target_link_libraries(torch_tp_jit_stable_cuda PRIVATE - CUDA::cudart - CUDA::cuda_driver + set(CUDA_LINK_LIBS + CUDA::cudart + CUDA::cuda_driver CUDA::nvrtc - cuda_stub_lib + cuda_stub_lib ) - - install(TARGETS torch_tp_jit_stable_cuda LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") + add_stable_extension(oeq_stable_cuda CUDA_BACKEND "${CUDA_LINK_LIBS}") endif() if(hip_FOUND) @@ -116,10 +149,11 @@ if(hip_FOUND) CXX_STANDARD 17 ) - add_stable_extension(torch_tp_jit_stable_hip HIP_BACKEND) - target_link_libraries(torch_tp_jit_stable_hip PRIVATE hiprtc hip_stub_lib) - - install(TARGETS torch_tp_jit_stable_hip LIBRARY DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") + set(HIP_LINK_LIBS + hiprtc + hip_stub_lib + ) + add_stable_extension(torch_stable_hip HIP_BACKEND "${HIP_LINK_LIBS}") endif() if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index e1adc99..56ac0d0 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -29,25 +29,28 @@ def postprocess_kernel(kernel): kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") return kernel -# Locate libpython (required for AOTI) -try: - python_lib_dir = sysconfig.get_config_var("LIBDIR") - major, minor = sys.version_info.major, sys.version_info.minor - python_lib_name = f"python{major}.{minor}" - - libpython_so = os.path.join(python_lib_dir, f"lib{python_lib_name}.so") - libpython_a = os.path.join(python_lib_dir, f"lib{python_lib_name}.a") - if not (os.path.exists(libpython_so) or os.path.exists(libpython_a)): - raise FileNotFoundError( - f"libpython not found, tried {libpython_so} and {libpython_a}" - ) - - LINKED_LIBPYTHON = True -except Exception as e: - LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" - -def jit_compile_extension(): - global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, extension_module + +def load_jit_extension(): + global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, LINKED_LIBPYTHON, LINKED_LIBPYTHON_ERROR, extension_module + + # Locate libpython (required for AOTI) + try: + python_lib_dir = sysconfig.get_config_var("LIBDIR") + major, minor = sys.version_info.major, sys.version_info.minor + python_lib_name = f"python{major}.{minor}" + + libpython_so = os.path.join(python_lib_dir, f"lib{python_lib_name}.so") + libpython_a = os.path.join(python_lib_dir, f"lib{python_lib_name}.a") + if not (os.path.exists(libpython_so) or os.path.exists(libpython_a)): + raise FileNotFoundError( + f"libpython not found, tried {libpython_so} and {libpython_a}" + ) + + LINKED_LIBPYTHON = True + except Exception as e: + LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" + + try: from torch.utils.cpp_extension import library_paths, include_paths @@ -109,27 +112,43 @@ def jit_compile_extension(): except Exception as e: BUILT_EXTENSION_ERROR = f"Error JIT-compiling OpenEquivariance Extension: {e}" -def use_precompiled_extension(): - global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, extension_module +def load_precompiled_extension(): + global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, LINKED_LIBPYTHON, extension_module + LINKED_LIBPYTHON = True # Doesn't actually use libpython, just set this as true anyway try: if torch.version.cuda: - import openequivariance._torch.extlib.torch_tp_jit_stable_cuda as extension_module + import openequivariance._torch.extlib.oeq_stable_cuda as extension_module elif torch.version.hip: - import openequivariance._torch.extlib.torch_tp_jit_stable_hip as extension_module + import openequivariance._torch.extlib.oeq_stable_hip as extension_module torch.ops.load_library(extension_module.__file__) BUILT_EXTENSION = True except Exception as e: BUILT_EXTENSION_ERROR = f"Error loading precompiled OpenEquivariance Extension: {e}" + +USE_PRECOMPILED_EXTENSION = False +if os.getenv("OEQ_JIT_EXTENSION", "0") != "1" \ + and Version(torch.__version__) > Version("2.9.9") \ + and torch.cuda.is_available() and torch.version.cuda: + USE_PRECOMPILED_EXTENSION = True + if Version(torch.__version__) > Version("2.9.9"): - use_precompiled_extension() + load_precompiled_extension() else: - jit_compile_extension() + load_jit_extension() def torch_ext_so_path(): - return extension_module.__file__ + if not USE_PRECOMPILED_EXTENSION: + return extension_module.__file__ + else: + dirname = os.path.dirname(extension_module.__file__) + if torch.version.cuda: + return os.path.join(dirname, "liboeq_stable_cuda_aoti.so") + elif torch.version.hip: + return os.path.join(dirname, "liboeq_stable_hip_aoti.so") + sys.modules["oeq_utilities"] = extension_module diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 2f61c26..5ddcb7f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -10,7 +10,6 @@ #include #include #include -#include "nanobind/nanobind.h" #ifdef HIP_BACKEND #include @@ -82,26 +81,29 @@ Stream get_current_stream() { } #ifdef CUDA_BACKEND - #define EXTENSION_NAME torch_tp_jit_stable_cuda + #define EXTENSION_NAME oeq_stable_cuda #endif #ifdef HIP_BACKEND - #define EXTENSION_NAME torch_tp_jit_stable_hip + #define EXTENSION_NAME oeq_stable_hip #endif -namespace nb = nanobind; -NB_MODULE(EXTENSION_NAME, m) { - nb::class_(m, "DeviceProp") - .def(nb::init()) - .def_ro("name", &DeviceProp::name) - .def_ro("warpsize", &DeviceProp::warpsize) - .def_ro("major", &DeviceProp::major) - .def_ro("minor", &DeviceProp::minor) - .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount) - .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); - - nb::class_(m, "GPUTimer") - .def(nb::init<>()) - .def("start", &GPUTimer::start) - .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) - .def("clear_L2_cache", &GPUTimer::clear_L2_cache); -} \ No newline at end of file +#ifdef INCLUDE_NB_EXTENSION + #include "nanobind/nanobind.h" + namespace nb = nanobind; + NB_MODULE(EXTENSION_NAME, m) { + nb::class_(m, "DeviceProp") + .def(nb::init()) + .def_ro("name", &DeviceProp::name) + .def_ro("warpsize", &DeviceProp::warpsize) + .def_ro("major", &DeviceProp::major) + .def_ro("minor", &DeviceProp::minor) + .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount) + .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + nb::class_(m, "GPUTimer") + .def(nb::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); + } +#endif \ No newline at end of file diff --git a/tests/export_test.py b/tests/export_test.py index 5fd0145..c2d1ffa 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -187,6 +187,7 @@ def test_aoti_cpp_inference(problem_and_irreps): "-DCMAKE_BUILD_TYPE=Release", "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, "-DOEQ_EXTLIB=" + torch_ext_so_path, + "-DCMAKE_CXX_COMPILER=g++" ], cwd=build_dir, check=True, From f5fdd1596eb105245489f18a7e56447d68b8675f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 20 Feb 2026 22:35:35 -0800 Subject: [PATCH 17/28] Added careful conditions about when to use a precompiled extension. --- openequivariance/CMakeLists.txt | 2 +- openequivariance/openequivariance/__init__.py | 3 +- .../_torch/extlib/__init__.py | 28 +++++++++++++++---- openequivariance/pyproject.toml | 3 +- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index ec56c80..e7d42bd 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -157,5 +157,5 @@ if(hip_FOUND) endif() if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND) - message(FATAL_ERROR "Neither CUDAToolkit nor HIP was found. Cannot build the stable extension.") + message(WARNING "Neither CUDAToolkit nor HIP was found. The stable extension will not be built.") endif() \ No newline at end of file diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 60ca8b1..7fc0b0f 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -60,7 +60,8 @@ def _check_package_editable(): LINKED_LIBPYTHON, LINKED_LIBPYTHON_ERROR, BUILT_EXTENSION, - BUILT_EXTENSION_ERROR + BUILT_EXTENSION_ERROR, + USE_PRECOMPILED_EXTENSION, ) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 56ac0d0..4c4b8a2 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -127,15 +127,31 @@ def load_precompiled_extension(): BUILT_EXTENSION_ERROR = f"Error loading precompiled OpenEquivariance Extension: {e}" -USE_PRECOMPILED_EXTENSION = False -if os.getenv("OEQ_JIT_EXTENSION", "0") != "1" \ - and Version(torch.__version__) > Version("2.9.9") \ - and torch.cuda.is_available() and torch.version.cuda: - USE_PRECOMPILED_EXTENSION = True +USE_PRECOMPILED_EXTENSION = True +WARNING_MESSAGE = "" -if Version(torch.__version__) > Version("2.9.9"): +if os.getenv("OEQ_JIT_EXTENSION", "0") == "1": + WARNING_MESSAGE += "Environment variable OEQ_JIT_EXTENSION=1 is set.\n" + USE_PRECOMPILED_EXTENSION = False + +if Version(torch.__version__) <= Version("2.9.9"): + WARNING_MESSAGE += f"PyTorch version {torch.__version__} is < 2.10, minimum required for precompiled extension. Please upgrade.\n" + USE_PRECOMPILED_EXTENSION = False + +if torch.version.hip: + WARNING_MESSAGE += "HIP does not support precompiled extension yet.\n" + USE_PRECOMPILED_EXTENSION = False + +if not os.path.exists(os.path.join(os.path.dirname(__file__), "liboeq_stable_cuda_aoti.so")): + WARNING_MESSAGE += "Precompiled extension shared object not found.\n" + USE_PRECOMPILED_EXTENSION = False + + +if USE_PRECOMPILED_EXTENSION: load_precompiled_extension() else: + WARNING_MESSAGE += "Falling back to JIT compilation of OpenEquivariance extension, which may hang. If this happens, clear ./cache/torch_extensions and try again.\n" + warnings.warn(WARNING_MESSAGE) load_jit_extension() diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 060a792..3b18272 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -17,8 +17,7 @@ dependencies = [ "setuptools", "ninja", "jinja2", - "numpy", - "nanobind" + "numpy" ] readme = "README.md" From c951dec0da0f8dbcaf5ed294efc19c907d1c6144 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 20 Feb 2026 22:42:32 -0800 Subject: [PATCH 18/28] Added detailed warning messages. --- openequivariance/openequivariance/_torch/extlib/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 4c4b8a2..4bf9277 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -150,8 +150,8 @@ def load_precompiled_extension(): if USE_PRECOMPILED_EXTENSION: load_precompiled_extension() else: - WARNING_MESSAGE += "Falling back to JIT compilation of OpenEquivariance extension, which may hang. If this happens, clear ./cache/torch_extensions and try again.\n" - warnings.warn(WARNING_MESSAGE) + WARNING_MESSAGE += "For these reasons, falling back to JIT compilation of OpenEquivariance extension, which may hang. If this happens, clear ~/.cache/torch_extensions or address the conditions above.\n" + warnings.warn(WARNING_MESSAGE, stacklevel=3) load_jit_extension() From bce13f3447d0e3972d00295cfd7d2654e45ec067 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 20 Feb 2026 22:45:03 -0800 Subject: [PATCH 19/28] Updated CI. --- .github/workflows/verify_extension_build.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 3988849..935b0e8 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -37,6 +37,12 @@ jobs: tests/import_test.py::test_extension_built \ tests/import_test.py::test_torch_extension_built + export OEQ_JIT_EXTENSION=1 + + pytest \ + tests/import_test.py::test_extension_built \ + tests/import_test.py::test_torch_extension_built + - name: Test JAX extension build run: | XLA_DIRECT_DOWNLOAD=1 pip install -e "./openequivariance_extjax" --no-build-isolation \ No newline at end of file From 028fe2ffcd01f3cfcf74d5e4997813cedc40bfe5 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 00:09:11 -0800 Subject: [PATCH 20/28] Tried defining symbol. --- .../openequivariance/extension/libtorch_tp_jit_stable.cpp | 3 ++- openequivariance/openequivariance/extension/stubs/stream.cpp | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 5ddcb7f..d75028d 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #ifdef HIP_BACKEND #include @@ -31,6 +30,8 @@ constexpr Dtype kByte = torch::headeronly::ScalarType::Byte; #include "torch_core.hpp" +AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); + Tensor tensor_to_cpu_contiguous(const Tensor &tensor) { torch::stable::Device device(torch::headeronly::DeviceType::CPU); return torch::stable::contiguous(torch::stable::to(tensor, device)); diff --git a/openequivariance/openequivariance/extension/stubs/stream.cpp b/openequivariance/openequivariance/extension/stubs/stream.cpp index 9a69eb8..0bcad99 100644 --- a/openequivariance/openequivariance/extension/stubs/stream.cpp +++ b/openequivariance/openequivariance/extension/stubs/stream.cpp @@ -1,3 +1,5 @@ #include -extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); \ No newline at end of file +AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream) { + return 0; +} \ No newline at end of file From 98f925720d7e3e197529545daf434e13b33d56a9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 15:26:18 -0800 Subject: [PATCH 21/28] Avoided symbol name mangling. --- openequivariance/CMakeLists.txt | 2 +- .../openequivariance/extension/libtorch_tp_jit_stable.cpp | 3 +-- .../openequivariance/extension/stubs/stream.cpp | 6 ++++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt index e7d42bd..43d6a4f 100644 --- a/openequivariance/CMakeLists.txt +++ b/openequivariance/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.15...3.30) -project(openequivariance_stable_ext LANGUAGES CXX) +project(openequivariance_stable_ext) find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index d75028d..5ddcb7f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #ifdef HIP_BACKEND #include @@ -30,8 +31,6 @@ constexpr Dtype kByte = torch::headeronly::ScalarType::Byte; #include "torch_core.hpp" -AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); - Tensor tensor_to_cpu_contiguous(const Tensor &tensor) { torch::stable::Device device(torch::headeronly::DeviceType::CPU); return torch::stable::contiguous(torch::stable::to(tensor, device)); diff --git a/openequivariance/openequivariance/extension/stubs/stream.cpp b/openequivariance/openequivariance/extension/stubs/stream.cpp index 0bcad99..5e39739 100644 --- a/openequivariance/openequivariance/extension/stubs/stream.cpp +++ b/openequivariance/openequivariance/extension/stubs/stream.cpp @@ -1,5 +1,7 @@ #include -AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream) { - return 0; +extern "C" { + AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream) { + return 0; + } } \ No newline at end of file From 7fb9fc44445eef228ac75d77aeebb300f1e37d13 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 15:45:50 -0800 Subject: [PATCH 22/28] Ruff. --- .../openequivariance/_torch/TensorProduct.py | 3 +- .../_torch/TensorProductConv.py | 4 +- .../_torch/extlib/__init__.py | 47 +++++++++++++------ tests/export_test.py | 3 +- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 8a524ac..35640b2 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -1,5 +1,5 @@ from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance import TPProblem +from openequivariance import TPProblem from openequivariance._torch import extlib import torch from openequivariance.core.utils import torch_to_oeq_dtype @@ -227,6 +227,7 @@ def register_autocast(): "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 ) + if extlib.BUILT_EXTENSION: register_torch_fakes() register_autograd() diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 848b7ac..e1c0e74 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -6,7 +6,7 @@ from openequivariance._torch.extlib import ( postprocess_kernel, DeviceProp, - BUILT_EXTENSION + BUILT_EXTENSION, ) from openequivariance.core.ConvolutionBase import ( @@ -15,7 +15,7 @@ ) from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance._torch.TensorProduct import TensorProduct -from openequivariance import TPProblem +from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance._torch.utils import ( reorder_torch, diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 4bf9277..17da355 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -20,7 +20,10 @@ extension_module = None -assert torch.version.cuda or torch.version.hip, "Only CUDA and HIP backends are supported" +assert torch.version.cuda or torch.version.hip, ( + "Only CUDA and HIP backends are supported" +) + def postprocess_kernel(kernel): if torch.version.hip: @@ -31,7 +34,12 @@ def postprocess_kernel(kernel): def load_jit_extension(): - global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, LINKED_LIBPYTHON, LINKED_LIBPYTHON_ERROR, extension_module + global \ + BUILT_EXTENSION, \ + BUILT_EXTENSION_ERROR, \ + LINKED_LIBPYTHON, \ + LINKED_LIBPYTHON_ERROR, \ + extension_module # Locate libpython (required for AOTI) try: @@ -50,7 +58,6 @@ def load_jit_extension(): except Exception as e: LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" - try: from torch.utils.cpp_extension import library_paths, include_paths @@ -112,9 +119,12 @@ def load_jit_extension(): except Exception as e: BUILT_EXTENSION_ERROR = f"Error JIT-compiling OpenEquivariance Extension: {e}" + def load_precompiled_extension(): global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, LINKED_LIBPYTHON, extension_module - LINKED_LIBPYTHON = True # Doesn't actually use libpython, just set this as true anyway + LINKED_LIBPYTHON = ( + True # Doesn't actually use libpython, just set this as true anyway + ) try: if torch.version.cuda: import openequivariance._torch.extlib.oeq_stable_cuda as extension_module @@ -124,30 +134,34 @@ def load_precompiled_extension(): torch.ops.load_library(extension_module.__file__) BUILT_EXTENSION = True except Exception as e: - BUILT_EXTENSION_ERROR = f"Error loading precompiled OpenEquivariance Extension: {e}" + BUILT_EXTENSION_ERROR = ( + f"Error loading precompiled OpenEquivariance Extension: {e}" + ) USE_PRECOMPILED_EXTENSION = True -WARNING_MESSAGE = "" +WARNING_MESSAGE = "" if os.getenv("OEQ_JIT_EXTENSION", "0") == "1": - WARNING_MESSAGE += "Environment variable OEQ_JIT_EXTENSION=1 is set.\n" + WARNING_MESSAGE += "Environment variable OEQ_JIT_EXTENSION=1 is set.\n" USE_PRECOMPILED_EXTENSION = False if Version(torch.__version__) <= Version("2.9.9"): - WARNING_MESSAGE += f"PyTorch version {torch.__version__} is < 2.10, minimum required for precompiled extension. Please upgrade.\n" + WARNING_MESSAGE += f"PyTorch version {torch.__version__} is < 2.10, minimum required for precompiled extension. Please upgrade.\n" USE_PRECOMPILED_EXTENSION = False -if torch.version.hip: - WARNING_MESSAGE += "HIP does not support precompiled extension yet.\n" +if torch.version.hip: + WARNING_MESSAGE += "HIP does not support precompiled extension yet.\n" USE_PRECOMPILED_EXTENSION = False -if not os.path.exists(os.path.join(os.path.dirname(__file__), "liboeq_stable_cuda_aoti.so")): +if not os.path.exists( + os.path.join(os.path.dirname(__file__), "liboeq_stable_cuda_aoti.so") +): WARNING_MESSAGE += "Precompiled extension shared object not found.\n" USE_PRECOMPILED_EXTENSION = False -if USE_PRECOMPILED_EXTENSION: +if USE_PRECOMPILED_EXTENSION: load_precompiled_extension() else: WARNING_MESSAGE += "For these reasons, falling back to JIT compilation of OpenEquivariance extension, which may hang. If this happens, clear ~/.cache/torch_extensions or address the conditions above.\n" @@ -170,15 +184,18 @@ def torch_ext_so_path(): if BUILT_EXTENSION: from oeq_utilities import ( - #GroupMM_F32, - #GroupMM_F64, + # GroupMM_F32, + # GroupMM_F64, DeviceProp, GPUTimer, ) else: + def _raise_import_error_helper(import_target: str): if not BUILT_EXTENSION: - raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") + raise ImportError( + f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}" + ) def GroupMM_F32(*args, **kwargs): _raise_import_error_helper("GroupMM_F32") diff --git a/tests/export_test.py b/tests/export_test.py index c2d1ffa..6aba969 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -13,6 +13,7 @@ from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + @pytest.fixture(scope="session") def problem_and_irreps(): X_ir, Y_ir, Z_ir = oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e") @@ -187,7 +188,7 @@ def test_aoti_cpp_inference(problem_and_irreps): "-DCMAKE_BUILD_TYPE=Release", "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, "-DOEQ_EXTLIB=" + torch_ext_so_path, - "-DCMAKE_CXX_COMPILER=g++" + "-DCMAKE_CXX_COMPILER=g++", ], cwd=build_dir, check=True, From 6014ee308d0ecb74a3bbfd208c23fce19c9a5f59 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 15:55:32 -0800 Subject: [PATCH 23/28] Removed accelerator.h from original library. --- openequivariance/openequivariance/extension/libtorch_tp_jit.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 3eb5c57..698a142 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -13,7 +13,6 @@ #include #include #include -#include using Tensor = torch::Tensor; using Dtype = torch::Dtype; From 2df54cae55b2282ee413ec4c01174c310d5c35ce Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 15:58:23 -0800 Subject: [PATCH 24/28] Updated warning message. --- openequivariance/openequivariance/_torch/extlib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 17da355..c765af9 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -164,7 +164,7 @@ def load_precompiled_extension(): if USE_PRECOMPILED_EXTENSION: load_precompiled_extension() else: - WARNING_MESSAGE += "For these reasons, falling back to JIT compilation of OpenEquivariance extension, which may hang. If this happens, clear ~/.cache/torch_extensions or address the conditions above.\n" + WARNING_MESSAGE += "For these reasons, falling back to JIT compilation of OpenEquivariance extension, which may hang. If waiting for >5 minutes, clear ~/.cache/torch_extensions or address the conditions above.\n" warnings.warn(WARNING_MESSAGE, stacklevel=3) load_jit_extension() From 0deafb9683b538f9ffa4c088ee9cb7d51de15318 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 15:59:47 -0800 Subject: [PATCH 25/28] Updated upgrade message. --- openequivariance/openequivariance/_torch/extlib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index c765af9..be4113e 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -147,7 +147,7 @@ def load_precompiled_extension(): USE_PRECOMPILED_EXTENSION = False if Version(torch.__version__) <= Version("2.9.9"): - WARNING_MESSAGE += f"PyTorch version {torch.__version__} is < 2.10, minimum required for precompiled extension. Please upgrade.\n" + WARNING_MESSAGE += f"PyTorch version {torch.__version__} is < 2.10, minimum required for precompiled extension. Please upgrade to 2.10.\n" USE_PRECOMPILED_EXTENSION = False if torch.version.hip: From e4503d008615e0dcbc608b8c910022fa4bf7865d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 16:02:58 -0800 Subject: [PATCH 26/28] Updated PyTorch version in CI. --- .github/workflows/requirements_cuda_ci.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt index 0e04348..c7121b8 100644 --- a/.github/workflows/requirements_cuda_ci.txt +++ b/.github/workflows/requirements_cuda_ci.txt @@ -1,5 +1,5 @@ numpy==2.2.5 -torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 +torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128 pytest==8.3.5 ninja==1.11.1.4 nanobind==2.10.2 From c60d6573eae4814feeb06960c488dbdb94777dca Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 16:15:28 -0800 Subject: [PATCH 27/28] Updated CI script. --- .github/workflows/verify_extension_build.yml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 935b0e8..9053c4c 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -33,15 +33,11 @@ jobs: - name: Test CUDA extension build via import run: | - pytest \ - tests/import_test.py::test_extension_built \ - tests/import_test.py::test_torch_extension_built + pytest tests/import_test.py export OEQ_JIT_EXTENSION=1 - pytest \ - tests/import_test.py::test_extension_built \ - tests/import_test.py::test_torch_extension_built + pytest tests/import_test.py - name: Test JAX extension build run: | From 583b6596227fc86ca66fcbd2b409dce27f6264a0 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 21 Feb 2026 16:26:12 -0800 Subject: [PATCH 28/28] Updated JAX dependency list. --- .github/workflows/verify_extension_build.yml | 2 +- openequivariance/pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 9053c4c..6c90320 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -29,7 +29,7 @@ jobs: sudo apt-get update sudo apt install nvidia-cuda-toolkit pip install -r .github/workflows/requirements_cuda_ci.txt - pip install -e "./openequivariance" + pip install -e "./openequivariance[jax]" - name: Test CUDA extension build via import run: | diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 3b18272..a8fb6c2 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -61,7 +61,8 @@ dev = [ jax = [ "scikit-build-core", - "setuptools-scm" + "setuptools-scm", + "nanobind" ] [tool.setuptools.packages.find]