Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f0c7453
Making progress.
vbharadwaj-bk Feb 14, 2026
5583c8e
More progress.
vbharadwaj-bk Feb 14, 2026
cb65ed6
Saving temporarily.
vbharadwaj-bk Feb 14, 2026
796bf31
Fixed JIT issue.
vbharadwaj-bk Feb 14, 2026
d147a17
Managed to build stable extension.
vbharadwaj-bk Feb 14, 2026
092932b
Made some further changes.
vbharadwaj-bk Feb 15, 2026
f580cf1
Fixed the dynamic versioning bugs.
vbharadwaj-bk Feb 15, 2026
ad3cb62
Back to a working state.
vbharadwaj-bk Feb 15, 2026
6651469
Ready to begin the import testing process.
vbharadwaj-bk Feb 15, 2026
991f19b
Temporary save.
vbharadwaj-bk Feb 15, 2026
a430a02
Fixed some more details about the C++ backend.
vbharadwaj-bk Feb 15, 2026
254ac47
Even more things working.
vbharadwaj-bk Feb 16, 2026
3ea9aef
Ready to test on HIP.
vbharadwaj-bk Feb 17, 2026
1c030d1
Minor comment fix.
vbharadwaj-bk Feb 17, 2026
e7eeb32
Working on AOTI update.
vbharadwaj-bk Feb 21, 2026
174bd44
AOTI loading works.
vbharadwaj-bk Feb 21, 2026
f5fdd15
Added careful conditions about when to use a precompiled extension.
vbharadwaj-bk Feb 21, 2026
c951dec
Added detailed warning messages.
vbharadwaj-bk Feb 21, 2026
bce13f3
Updated CI.
vbharadwaj-bk Feb 21, 2026
028fe2f
Tried defining symbol.
vbharadwaj-bk Feb 21, 2026
98f9257
Avoided symbol name mangling.
vbharadwaj-bk Feb 21, 2026
7fb9fc4
Ruff.
vbharadwaj-bk Feb 21, 2026
6014ee3
Removed accelerator.h from original library.
vbharadwaj-bk Feb 21, 2026
2df54ca
Updated warning message.
vbharadwaj-bk Feb 21, 2026
0deafb9
Updated upgrade message.
vbharadwaj-bk Feb 21, 2026
e4503d0
Updated PyTorch version in CI.
vbharadwaj-bk Feb 22, 2026
c60d657
Updated CI script.
vbharadwaj-bk Feb 22, 2026
583b659
Updated JAX dependency list.
vbharadwaj-bk Feb 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/requirements_cuda_ci.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy==2.2.5
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128
pytest==8.3.5
ninja==1.11.1.4
nanobind==2.10.2
Expand Down
10 changes: 6 additions & 4 deletions .github/workflows/verify_extension_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ jobs:
sudo apt-get update
sudo apt install nvidia-cuda-toolkit
pip install -r .github/workflows/requirements_cuda_ci.txt
pip install -e "./openequivariance"
pip install -e "./openequivariance[jax]"

- name: Test CUDA extension build via import
run: |
pytest \
tests/import_test.py::test_extension_built \
tests/import_test.py::test_torch_extension_built
pytest tests/import_test.py

export OEQ_JIT_EXTENSION=1

pytest tests/import_test.py

- name: Test JAX extension build
run: |
Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __pycache__
# working folders
dist
build
cbuild
outputs/*
visualization/*
figures/*
Expand Down Expand Up @@ -40,4 +41,5 @@ paper_benchmarks_v2
paper_benchmarks_v3

get_node.sh
*.egg-info
*.egg-info
_version.py
161 changes: 161 additions & 0 deletions openequivariance/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
cmake_minimum_required(VERSION 3.15...3.30)
project(openequivariance_stable_ext)

find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)

# Download LibTorch
include(FetchContent)

FetchContent_Declare(
libtorch
URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.10.0%2Bcpu.zip"
)

message(STATUS "Downloading LibTorch...")
FetchContent_MakeAvailable(libtorch)

set(LIBTORCH_INCLUDE_DIR "${libtorch_SOURCE_DIR}/include")
set(LIBTORCH_LIB_DIR "${libtorch_SOURCE_DIR}/lib")
find_library(TORCH_CPU_LIB NAMES torch_cpu PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH)
find_library(C10_LIB NAMES c10 PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH)

message(STATUS "LibTorch Include: ${LIBTORCH_INCLUDE_DIR}")
message(STATUS "LibTorch Lib: ${LIBTORCH_LIB_DIR}")

message(STATUS "Torch CPU Library: ${TORCH_CPU_LIB}")
message(STATUS "Torch C10 Library: ${C10_LIB}")

# Setup Nanobind
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT
)
message(STATUS "nanobind cmake directory: ${nanobind_ROOT}")

find_package(nanobind CONFIG REQUIRED)

set(EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/extension")
set(EXT_BACKEND_DIR "${EXT_DIR}/backend")
set(EXT_JSON_DIR "${EXT_DIR}/json11")

# Source files
set(OEQ_SOURCES
${EXT_DIR}/libtorch_tp_jit_stable.cpp
${EXT_JSON_DIR}/json11.cpp
)

set(OEQ_INSTALL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib")

function(add_stable_extension target_name backend_define link_libraries)
# Create nanobind extension
nanobind_add_module(${target_name} NB_STATIC ${OEQ_SOURCES})

set_target_properties(${target_name} PROPERTIES
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON
)

# Enforce CXX11 ABI to match LibTorch
target_compile_definitions(${target_name} PRIVATE
${backend_define}=1
_GLIBCXX_USE_CXX11_ABI=1
INCLUDE_NB_EXTENSION
)

target_include_directories(${target_name} PRIVATE
${EXT_DIR}
${EXT_BACKEND_DIR}
${EXT_JSON_DIR}
${LIBTORCH_INCLUDE_DIR}
)
target_link_libraries(${target_name} PRIVATE
${TORCH_CPU_LIB}
${C10_LIB}
${link_libraries}
)

install(TARGETS ${target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}")

# AOTI C++ library (identical except without nanobind and without INCLUDE_NB_EXTENSION)
set(aoti_target_name ${target_name}_aoti)
add_library(${aoti_target_name} SHARED ${OEQ_SOURCES})

set_target_properties(${aoti_target_name} PROPERTIES
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON
)

target_compile_definitions(${aoti_target_name} PRIVATE
${backend_define}=1
_GLIBCXX_USE_CXX11_ABI=1
)

target_include_directories(${aoti_target_name} PRIVATE
${EXT_DIR}
${EXT_BACKEND_DIR}
${EXT_JSON_DIR}
${LIBTORCH_INCLUDE_DIR}
)
target_link_libraries(${aoti_target_name} PRIVATE
${TORCH_CPU_LIB}
${C10_LIB}
${link_libraries}
)

install(TARGETS ${aoti_target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}")
endfunction()

find_package(CUDAToolkit QUIET)
find_package(hip QUIET)

if(CUDAToolkit_FOUND)
message(STATUS "Building stable extension with CUDA backend.")

add_library(cuda_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp)

target_include_directories(cuda_stub_lib PRIVATE
${LIBTORCH_INCLUDE_DIR}
)

set_target_properties(cuda_stub_lib PROPERTIES
OUTPUT_NAME "torch_cuda"
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD 17
)

set(CUDA_LINK_LIBS
CUDA::cudart
CUDA::cuda_driver
CUDA::nvrtc
cuda_stub_lib
)
add_stable_extension(oeq_stable_cuda CUDA_BACKEND "${CUDA_LINK_LIBS}")
endif()

if(hip_FOUND)
message(STATUS "Building stable extension with HIP backend.")

add_library(hip_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp)

target_include_directories(hip_stub_lib PRIVATE
${LIBTORCH_INCLUDE_DIR}
)

set_target_properties(hip_stub_lib PROPERTIES
OUTPUT_NAME "torch_hip"
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD 17
)

set(HIP_LINK_LIBS
hiprtc
hip_stub_lib
)
add_stable_extension(torch_stable_hip HIP_BACKEND "${HIP_LINK_LIBS}")
endif()

if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND)
message(WARNING "Neither CUDAToolkit nor HIP was found. The stable extension will not be built.")
endif()
3 changes: 1 addition & 2 deletions openequivariance/openequivariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def _check_package_editable():
LINKED_LIBPYTHON_ERROR,
BUILT_EXTENSION,
BUILT_EXTENSION_ERROR,
TORCH_COMPILE,
TORCH_COMPILE_ERROR,
USE_PRECOMPILED_EXTENSION,
)


Expand Down
7 changes: 4 additions & 3 deletions openequivariance/openequivariance/_torch/TensorProduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def register_autocast():
)


register_torch_fakes()
register_autograd()
register_autocast()
if extlib.BUILT_EXTENSION:
register_torch_fakes()
register_autograd()
register_autocast()
8 changes: 5 additions & 3 deletions openequivariance/openequivariance/_torch/TensorProductConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from openequivariance._torch.extlib import (
postprocess_kernel,
DeviceProp,
BUILT_EXTENSION,
)

from openequivariance.core.ConvolutionBase import (
Expand Down Expand Up @@ -403,9 +404,10 @@ def register_autocast():
)


register_torch_fakes()
register_autograd()
register_autocast()
if BUILT_EXTENSION:
register_torch_fakes()
register_autograd()
register_autocast()


# ==================================================================
Expand Down
Loading