diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt index 0e04348c..c7121b8d 100644 --- a/.github/workflows/requirements_cuda_ci.txt +++ b/.github/workflows/requirements_cuda_ci.txt @@ -1,5 +1,5 @@ numpy==2.2.5 -torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 +torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128 pytest==8.3.5 ninja==1.11.1.4 nanobind==2.10.2 diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 39888491..6c903205 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -29,13 +29,15 @@ jobs: sudo apt-get update sudo apt install nvidia-cuda-toolkit pip install -r .github/workflows/requirements_cuda_ci.txt - pip install -e "./openequivariance" + pip install -e "./openequivariance[jax]" - name: Test CUDA extension build via import run: | - pytest \ - tests/import_test.py::test_extension_built \ - tests/import_test.py::test_torch_extension_built + pytest tests/import_test.py + + export OEQ_JIT_EXTENSION=1 + + pytest tests/import_test.py - name: Test JAX extension build run: | diff --git a/.gitignore b/.gitignore index 64fcaa8d..fe768ba9 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__ # working folders dist build +cbuild outputs/* visualization/* figures/* @@ -40,4 +41,5 @@ paper_benchmarks_v2 paper_benchmarks_v3 get_node.sh -*.egg-info \ No newline at end of file +*.egg-info +_version.py \ No newline at end of file diff --git a/openequivariance/CMakeLists.txt b/openequivariance/CMakeLists.txt new file mode 100644 index 00000000..43d6a4f6 --- /dev/null +++ b/openequivariance/CMakeLists.txt @@ -0,0 +1,161 @@ +cmake_minimum_required(VERSION 3.15...3.30) +project(openequivariance_stable_ext) + +find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) + +# Download LibTorch +include(FetchContent) + +FetchContent_Declare( + libtorch 
+ URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.10.0%2Bcpu.zip" +) + +message(STATUS "Downloading LibTorch...") +FetchContent_MakeAvailable(libtorch) + +set(LIBTORCH_INCLUDE_DIR "${libtorch_SOURCE_DIR}/include") +set(LIBTORCH_LIB_DIR "${libtorch_SOURCE_DIR}/lib") +find_library(TORCH_CPU_LIB NAMES torch_cpu PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH) +find_library(C10_LIB NAMES c10 PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH) + +message(STATUS "LibTorch Include: ${LIBTORCH_INCLUDE_DIR}") +message(STATUS "LibTorch Lib: ${LIBTORCH_LIB_DIR}") + +message(STATUS "Torch CPU Library: ${TORCH_CPU_LIB}") +message(STATUS "Torch C10 Library: ${C10_LIB}") + +# Setup Nanobind +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT +) +message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") + +find_package(nanobind CONFIG REQUIRED) + +set(EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/extension") +set(EXT_BACKEND_DIR "${EXT_DIR}/backend") +set(EXT_JSON_DIR "${EXT_DIR}/json11") + +# Source files +set(OEQ_SOURCES + ${EXT_DIR}/libtorch_tp_jit_stable.cpp + ${EXT_JSON_DIR}/json11.cpp +) + +set(OEQ_INSTALL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib") + +function(add_stable_extension target_name backend_define link_libraries) + # Create nanobind extension + nanobind_add_module(${target_name} NB_STATIC ${OEQ_SOURCES}) + + set_target_properties(${target_name} PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + ) + + # Enforce CXX11 ABI to match LibTorch + target_compile_definitions(${target_name} PRIVATE + ${backend_define}=1 + _GLIBCXX_USE_CXX11_ABI=1 + INCLUDE_NB_EXTENSION + ) + + target_include_directories(${target_name} PRIVATE + ${EXT_DIR} + ${EXT_BACKEND_DIR} + ${EXT_JSON_DIR} + ${LIBTORCH_INCLUDE_DIR} + ) + target_link_libraries(${target_name} PRIVATE + ${TORCH_CPU_LIB} + ${C10_LIB} + 
${link_libraries}
+    )
+
+    install(TARGETS ${target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}")
+
+    # AOTI C++ library (identical except without nanobind and without INCLUDE_NB_EXTENSION)
+    set(aoti_target_name ${target_name}_aoti)
+    add_library(${aoti_target_name} SHARED ${OEQ_SOURCES})
+
+    set_target_properties(${aoti_target_name} PROPERTIES
+        CXX_STANDARD 17
+        CXX_STANDARD_REQUIRED ON
+        POSITION_INDEPENDENT_CODE ON
+    )
+
+    target_compile_definitions(${aoti_target_name} PRIVATE
+        ${backend_define}=1
+        _GLIBCXX_USE_CXX11_ABI=1
+    )
+
+    target_include_directories(${aoti_target_name} PRIVATE
+        ${EXT_DIR}
+        ${EXT_BACKEND_DIR}
+        ${EXT_JSON_DIR}
+        ${LIBTORCH_INCLUDE_DIR}
+    )
+    target_link_libraries(${aoti_target_name} PRIVATE
+        ${TORCH_CPU_LIB}
+        ${C10_LIB}
+        ${link_libraries}
+    )
+
+    install(TARGETS ${aoti_target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}")
+endfunction()
+
+find_package(CUDAToolkit QUIET)
+find_package(hip QUIET)
+
+if(CUDAToolkit_FOUND)
+    message(STATUS "Building stable extension with CUDA backend.")
+
+    add_library(cuda_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp)
+
+    target_include_directories(cuda_stub_lib PRIVATE
+        ${LIBTORCH_INCLUDE_DIR}
+    )
+
+    set_target_properties(cuda_stub_lib PROPERTIES
+        OUTPUT_NAME "torch_cuda"
+        POSITION_INDEPENDENT_CODE ON
+        CXX_STANDARD 17
+    )
+
+    set(CUDA_LINK_LIBS
+        CUDA::cudart
+        CUDA::cuda_driver
+        CUDA::nvrtc
+        cuda_stub_lib
+    )
+    add_stable_extension(oeq_stable_cuda CUDA_BACKEND "${CUDA_LINK_LIBS}")
+endif()
+
+if(hip_FOUND)
+    message(STATUS "Building stable extension with HIP backend.")
+
+    add_library(hip_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp)
+
+    target_include_directories(hip_stub_lib PRIVATE
+        ${LIBTORCH_INCLUDE_DIR}
+    )
+
+    set_target_properties(hip_stub_lib PROPERTIES
+        OUTPUT_NAME "torch_hip"
+        POSITION_INDEPENDENT_CODE ON
+        CXX_STANDARD 17
+    )
+
+    set(HIP_LINK_LIBS
+        hiprtc
+        hip_stub_lib
+    )
+    add_stable_extension(oeq_stable_hip HIP_BACKEND "${HIP_LINK_LIBS}")
+endif()
+
+if(NOT
CUDAToolkit_FOUND AND NOT hip_FOUND) + message(WARNING "Neither CUDAToolkit nor HIP was found. The stable extension will not be built.") +endif() \ No newline at end of file diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index a842a7c9..7fc0b0f9 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -61,8 +61,7 @@ def _check_package_editable(): LINKED_LIBPYTHON_ERROR, BUILT_EXTENSION, BUILT_EXTENSION_ERROR, - TORCH_COMPILE, - TORCH_COMPILE_ERROR, + USE_PRECOMPILED_EXTENSION, ) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 3885604f..35640b25 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -228,6 +228,7 @@ def register_autocast(): ) -register_torch_fakes() -register_autograd() -register_autocast() +if extlib.BUILT_EXTENSION: + register_torch_fakes() + register_autograd() + register_autocast() diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 5788f2f0..e1c0e742 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -6,6 +6,7 @@ from openequivariance._torch.extlib import ( postprocess_kernel, DeviceProp, + BUILT_EXTENSION, ) from openequivariance.core.ConvolutionBase import ( @@ -403,9 +404,10 @@ def register_autocast(): ) -register_torch_fakes() -register_autograd() -register_autocast() +if BUILT_EXTENSION: + register_torch_fakes() + register_autograd() + register_autocast() # ================================================================== diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 
a7b4b865..be4113ec 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -4,6 +4,7 @@ import warnings import sysconfig from pathlib import Path +from packaging.version import Version import torch @@ -14,46 +15,56 @@ BUILT_EXTENSION = False BUILT_EXTENSION_ERROR = None -TORCH_COMPILE = False -TORCH_COMPILE_ERROR = None - LINKED_LIBPYTHON = False LINKED_LIBPYTHON_ERROR = None -torch_module, generic_module = None, None -postprocess_kernel = lambda kernel: kernel # noqa : E731 +extension_module = None -try: - python_lib_dir = sysconfig.get_config_var("LIBDIR") - major, minor = sys.version_info.major, sys.version_info.minor - python_lib_name = f"python{major}.{minor}" +assert torch.version.cuda or torch.version.hip, ( + "Only CUDA and HIP backends are supported" +) - libpython_so = os.path.join(python_lib_dir, f"lib{python_lib_name}.so") - libpython_a = os.path.join(python_lib_dir, f"lib{python_lib_name}.a") - if not (os.path.exists(libpython_so) or os.path.exists(libpython_a)): - raise FileNotFoundError( - f"libpython not found, tried {libpython_so} and {libpython_a}" - ) - LINKED_LIBPYTHON = True -except Exception as e: - LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" +def postprocess_kernel(kernel): + if torch.version.hip: + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") + return kernel -if BUILT_EXTENSION: - import openequivariance._torch.extlib.generic_module +def load_jit_extension(): + global \ + BUILT_EXTENSION, \ + BUILT_EXTENSION_ERROR, \ + LINKED_LIBPYTHON, \ + LINKED_LIBPYTHON_ERROR, \ + extension_module - generic_module = openequivariance._torch.extlib.generic_module + # Locate libpython (required for AOTI) + try: + python_lib_dir = 
sysconfig.get_config_var("LIBDIR") + major, minor = sys.version_info.major, sys.version_info.minor + python_lib_name = f"python{major}.{minor}" + + libpython_so = os.path.join(python_lib_dir, f"lib{python_lib_name}.so") + libpython_a = os.path.join(python_lib_dir, f"lib{python_lib_name}.a") + if not (os.path.exists(libpython_so) or os.path.exists(libpython_a)): + raise FileNotFoundError( + f"libpython not found, tried {libpython_so} and {libpython_a}" + ) + + LINKED_LIBPYTHON = True + except Exception as e: + LINKED_LIBPYTHON_ERROR = f"Error linking libpython:\n{e}\nSysconfig variables:\n{sysconfig.get_config_vars()}" -elif torch.version.cuda or torch.version.hip: try: from torch.utils.cpp_extension import library_paths, include_paths extra_cflags = ["-O3"] - generic_sources = ["generic_module.cpp"] torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] - include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) + include_dirs, extra_link_args = (["backend"], ["-Wl,--no-as-needed"]) if LINKED_LIBPYTHON: extra_link_args.pop() @@ -81,18 +92,8 @@ extra_link_args.extend(["-lhiprtc"]) torch_libs = library_paths("cuda")[0] extra_link_args.append("-Wl,-rpath," + torch_libs) - - def postprocess(kernel): - kernel = kernel.replace("__syncwarp();", "__threadfence_block();") - kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") - kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") - return kernel - - postprocess_kernel = postprocess - extra_cflags.append("-DHIP_BACKEND") - generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] include_dirs = [ oeq_root + "/extension/" + d for d in include_dirs @@ -102,60 +103,100 @@ def postprocess(kernel): warnings.simplefilter("ignore") try: - torch_module = torch.utils.cpp_extension.load( + extension_module = torch.utils.cpp_extension.load( "libtorch_tp_jit", torch_sources, extra_cflags=extra_cflags, 
extra_include_paths=include_dirs, extra_ldflags=extra_link_args, ) - torch.ops.load_library(torch_module.__file__) - TORCH_COMPILE = True + torch.ops.load_library(extension_module.__file__) + BUILT_EXTENSION = True except Exception as e: # If compiling torch fails (e.g. low gcc version), we should fall back to the # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). - TORCH_COMPILE_ERROR = e - - generic_module = torch.utils.cpp_extension.load( - "generic_module", - generic_sources, - extra_cflags=extra_cflags, - extra_include_paths=include_dirs, - extra_ldflags=extra_link_args, - ) - if "generic_module" not in sys.modules: - sys.modules["generic_module"] = generic_module + BUILT_EXTENSION_ERROR = e + except Exception as e: + BUILT_EXTENSION_ERROR = f"Error JIT-compiling OpenEquivariance Extension: {e}" - if not TORCH_COMPILE: - warnings.warn( - "Could not compile integrated PyTorch wrapper. Falling back to Pybind11" - + f", but JITScript, compile fullgraph, and export will fail.\n {TORCH_COMPILE_ERROR}" - ) + +def load_precompiled_extension(): + global BUILT_EXTENSION, BUILT_EXTENSION_ERROR, LINKED_LIBPYTHON, extension_module + LINKED_LIBPYTHON = ( + True # Doesn't actually use libpython, just set this as true anyway + ) + try: + if torch.version.cuda: + import openequivariance._torch.extlib.oeq_stable_cuda as extension_module + elif torch.version.hip: + import openequivariance._torch.extlib.oeq_stable_hip as extension_module + + torch.ops.load_library(extension_module.__file__) BUILT_EXTENSION = True except Exception as e: - BUILT_EXTENSION_ERROR = f"Error building OpenEquivariance Extension: {e}" -else: - BUILT_EXTENSION_ERROR = "OpenEquivariance extension build not attempted" + BUILT_EXTENSION_ERROR = ( + f"Error loading precompiled OpenEquivariance Extension: {e}" + ) + + +USE_PRECOMPILED_EXTENSION = True +WARNING_MESSAGE = "" + +if os.getenv("OEQ_JIT_EXTENSION", "0") == "1": + WARNING_MESSAGE += "Environment variable 
OEQ_JIT_EXTENSION=1 is set.\n" + USE_PRECOMPILED_EXTENSION = False + +if Version(torch.__version__) <= Version("2.9.9"): + WARNING_MESSAGE += f"PyTorch version {torch.__version__} is < 2.10, minimum required for precompiled extension. Please upgrade to 2.10.\n" + USE_PRECOMPILED_EXTENSION = False +if torch.version.hip: + WARNING_MESSAGE += "HIP does not support precompiled extension yet.\n" + USE_PRECOMPILED_EXTENSION = False -def _raise_import_error_helper(import_target: str): - if not BUILT_EXTENSION: - raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") +if not os.path.exists( + os.path.join(os.path.dirname(__file__), "liboeq_stable_cuda_aoti.so") +): + WARNING_MESSAGE += "Precompiled extension shared object not found.\n" + USE_PRECOMPILED_EXTENSION = False + + +if USE_PRECOMPILED_EXTENSION: + load_precompiled_extension() +else: + WARNING_MESSAGE += "For these reasons, falling back to JIT compilation of OpenEquivariance extension, which may hang. If waiting for >5 minutes, clear ~/.cache/torch_extensions or address the conditions above.\n" + warnings.warn(WARNING_MESSAGE, stacklevel=3) + load_jit_extension() def torch_ext_so_path(): - return torch_module.__file__ + if not USE_PRECOMPILED_EXTENSION: + return extension_module.__file__ + else: + dirname = os.path.dirname(extension_module.__file__) + if torch.version.cuda: + return os.path.join(dirname, "liboeq_stable_cuda_aoti.so") + elif torch.version.hip: + return os.path.join(dirname, "liboeq_stable_hip_aoti.so") + +sys.modules["oeq_utilities"] = extension_module if BUILT_EXTENSION: - from generic_module import ( - GroupMM_F32, - GroupMM_F64, + from oeq_utilities import ( + # GroupMM_F32, + # GroupMM_F64, DeviceProp, GPUTimer, ) else: + def _raise_import_error_helper(import_target: str): + if not BUILT_EXTENSION: + raise ImportError( + f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}" + ) + def GroupMM_F32(*args, **kwargs): _raise_import_error_helper("GroupMM_F32") diff 
--git a/openequivariance/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/backend/backend_cuda.hpp similarity index 100% rename from openequivariance/openequivariance/extension/util/backend_cuda.hpp rename to openequivariance/openequivariance/extension/backend/backend_cuda.hpp diff --git a/openequivariance/openequivariance/extension/util/backend_hip.hpp b/openequivariance/openequivariance/extension/backend/backend_hip.hpp similarity index 100% rename from openequivariance/openequivariance/extension/util/backend_hip.hpp rename to openequivariance/openequivariance/extension/backend/backend_hip.hpp diff --git a/openequivariance/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp deleted file mode 100644 index b0996991..00000000 --- a/openequivariance/openequivariance/extension/generic_module.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include -#include -#include -#include -#include - -#ifdef CUDA_BACKEND - #include "backend_cuda.hpp" - #include "group_mm_cuda.hpp" - using JITKernel = CUJITKernel; - using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; -#endif - -#ifdef HIP_BACKEND - #include "backend_hip.hpp" - #include "group_mm_hip.hpp" - using JITKernel = HIPJITKernel; - using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; -#endif - -#include "tensorproducts.hpp" -#include "convolution.hpp" - -using namespace std; -namespace py=pybind11; - -PYBIND11_MODULE(generic_module, m) { - py::class_>(m, "GroupMM_F32") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - py::class_>(m, "GroupMM_F64") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - - py::class_(m, "DeviceProp") - .def(py::init()) - .def_readonly("name", &DeviceProp::name) - .def_readonly("warpsize", &DeviceProp::warpsize) - .def_readonly("major", &DeviceProp::major) - .def_readonly("minor", &DeviceProp::minor) - 
.def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount) - .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); - - py::class_(m, "GPUTimer") - .def(py::init<>()) - .def("start", &GPUTimer::start) - .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) - .def("clear_L2_cache", &GPUTimer::clear_L2_cache); -} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 6216909f..698a142f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -1,603 +1,101 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "json11/json11.hpp" #include #ifdef CUDA_BACKEND #include - #include "backend_cuda.hpp" - #include "group_mm_cuda.hpp" - using JITKernel = CUJITKernel; - using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; - - inline Stream get_current_stream() { - return c10::cuda::getCurrentCUDAStream(); - } #endif #ifdef HIP_BACKEND #include - #include "backend_hip.hpp" - #include "group_mm_hip.hpp" - using JITKernel = HIPJITKernel; - using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; - - inline Stream get_current_stream() { - return c10::hip::getCurrentHIPStream(); - } #endif -#include "tensorproducts.hpp" -#include "convolution.hpp" - -using namespace std; -using json = json11::Json; - #include -#include -#include #include #include +#include +#include -torch::Dtype enum_to_torch_dtype(int64_t i){ - switch(i) { - case 1: return torch::kFloat; - case 2: return torch::kDouble; - case 3: return torch::kInt; - case 4: return torch::kLong; - case 5: return torch::kUInt8; - } - throw logic_error("Unsupported tensor datatype!"); -} - -inline void check_tensor(const torch::Tensor &tensor, - std::initializer_list expected_shape, - torch::Dtype 
expected_dtype, - std::string tensor_name) { - TORCH_CHECK(tensor.sizes() == expected_shape, - "Shape mismatch for tensor '", tensor_name, - "'. Expected: ", torch::IntArrayRef(expected_shape), - ". Got: ", tensor.sizes()); - TORCH_CHECK(tensor.device().is_cuda(), "Tensor '", tensor_name, "' is not on the GPU."); - TORCH_CHECK(tensor.dtype() == expected_dtype, "Dtype mismatch for tensor '", tensor_name, "'. Expected: ", expected_dtype, ". Got: ", tensor.dtype()); -} - -inline void* data_ptr(const torch::Tensor &tensor) { - if(tensor.dtype() == torch::kFloat) - return reinterpret_cast(tensor.data_ptr()); - else if(tensor.dtype() == torch::kDouble) - return reinterpret_cast(tensor.data_ptr()); - else if(tensor.dtype() == torch::kLong) - return reinterpret_cast(tensor.data_ptr()); - else if(tensor.dtype() == torch::kByte) - return reinterpret_cast(tensor.data_ptr()); // Replaces kUInt8 - else if(tensor.dtype() == torch::kInt) - return reinterpret_cast(tensor.data_ptr()); - else - throw logic_error("Unsupported tensor datatype!"); -} - -std::unordered_map parse_json_config(const json &j_obj) { - std::unordered_map result; - for (const auto &kv : j_obj.object_items()) { - result[kv.first] = static_cast(kv.second.number_value()); - } - return result; -} - -struct KernelProp { - int64_t L1_dim, L2_dim, L3_dim, weight_numel; - bool shared_weights; - torch::Dtype irrep_dtype; - torch::Dtype weight_dtype; - - int64_t workspace_size; // Convolution only - bool deterministic; - torch::Dtype idx_dtype; - torch::Dtype workspace_dtype; - - KernelProp() : - L1_dim(0), L2_dim(0), L3_dim(0), weight_numel(0), - shared_weights(false), - irrep_dtype(torch::kFloat), weight_dtype(torch::kFloat), - workspace_size(0), deterministic(false), - idx_dtype(torch::kInt), workspace_dtype(torch::kByte) {} - - KernelProp( - std::unordered_map &kernel_dims, bool is_convolution): - L1_dim(kernel_dims.at("L1_dim")), - L2_dim(kernel_dims.at("L2_dim")), - L3_dim(kernel_dims.at("L3_dim")), - 
weight_numel(kernel_dims.at("weight_numel")), - shared_weights(kernel_dims.at("shared_weights")), - irrep_dtype(enum_to_torch_dtype(kernel_dims.at("irrep_dtype"))), - weight_dtype(enum_to_torch_dtype(kernel_dims.at("weight_dtype"))), - workspace_dtype(torch::kByte) { - if(is_convolution) { - workspace_size = kernel_dims.at("workspace_size"); - deterministic = kernel_dims.at("deterministic"); - idx_dtype = enum_to_torch_dtype(kernel_dims.at("idx_dtype")); - } - } -}; - -std::unordered_map>, - KernelProp - >> tp_cache; - -std::unordered_map>, - KernelProp - >> conv_cache; - -std::mutex mut; +using Tensor = torch::Tensor; +using Dtype = torch::Dtype; -std::pair*, KernelProp> - compile_tp_with_caching(const torch::Tensor &json_bytes, - int64_t hash) { - { - const std::lock_guard lock(mut); - auto it = tp_cache.find(hash); - if (it == tp_cache.end()) { - torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); - std::string json_payload( - reinterpret_cast(cpu_tensor.data_ptr()), - cpu_tensor.numel() - ); +constexpr Dtype kFloat = torch::kFloat; +constexpr Dtype kDouble = torch::kDouble; +constexpr Dtype kInt = torch::kInt; +constexpr Dtype kLong = torch::kLong; +constexpr Dtype kByte = torch::kByte; - std::string err; - json root = json::parse(json_payload, err); - if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); +#define TCHECK TORCH_CHECK +#define BOX(x) x +#define REGISTER_LIBRARY_IMPL TORCH_LIBRARY_IMPL +#define REGISTER_LIBRARY TORCH_LIBRARY - std::string kernel_src = root["kernel"].string_value(); - auto forward_cfg = parse_json_config(root["forward_config"]); - auto backward_cfg = parse_json_config(root["backward_config"]); - auto dbackward_cfg = parse_json_config(root["double_backward_config"]); - auto kernel_prop_map = parse_json_config(root["kernel_prop"]); +#include "torch_core.hpp" - auto jit_tp_impl = std::make_unique>( - kernel_src, - forward_cfg, - backward_cfg, - dbackward_cfg, - kernel_prop_map); - - 
tp_cache.insert({hash, - std::make_pair(std::move(jit_tp_impl), - KernelProp(kernel_prop_map, false))}); - it = tp_cache.find(hash); - } - return {it->second.first.get(), it->second.second}; - } +Tensor tensor_to_cpu_contiguous(const Tensor &tensor) { + return tensor.to(torch::kCPU).contiguous(); } -std::pair*, KernelProp> - compile_conv_with_caching(const torch::Tensor &json_bytes, - int64_t hash) { - { - const std::lock_guard lock(mut); - auto it = conv_cache.find(hash); - if (it == conv_cache.end()) { - torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); - std::string json_payload( - reinterpret_cast(cpu_tensor.data_ptr()), - cpu_tensor.numel() - ); - - std::string err; - json root = json::parse(json_payload, err); - if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); - - std::string kernel_src = root["kernel"].string_value(); - auto forward_cfg = parse_json_config(root["forward_config"]); - auto backward_cfg = parse_json_config(root["backward_config"]); - auto dbackward_cfg = parse_json_config(root["double_backward_config"]); - auto kernel_prop_map = parse_json_config(root["kernel_prop"]); - - auto jit_conv_impl = std::make_unique>( - kernel_src, - forward_cfg, - backward_cfg, - dbackward_cfg, - kernel_prop_map); - - conv_cache.insert({hash, - std::make_pair(std::move(jit_conv_impl), - KernelProp(kernel_prop_map, true))}); - it = conv_cache.find(hash); - } - return {it->second.first.get(), it->second.second}; - } +Tensor tensor_contiguous(const Tensor &tensor) { + return tensor.contiguous(); } -// --------------------- Tensor Products -------------------------- - -torch::Tensor jit_tp_forward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - int64_t L3_dim) { - - auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t num_batch = L1_in.size(0); - - check_tensor(L1_in, {num_batch, k.L1_dim}, 
k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); - - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - else - check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L3_out = torch::empty({num_batch, k.L3_dim}, L1_in.options()); - - at::Tensor L1_contig = L1_in.contiguous(); - at::Tensor L2_contig = L2_in.contiguous(); - at::Tensor W_contig = W.contiguous(); - - jit_kernel->exec_tensor_product( - num_batch, - data_ptr(L1_contig), - data_ptr(L2_contig), - data_ptr(L3_out), - data_ptr(W_contig), - stream - ); - - return L3_out; +Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes) { + return torch::empty(sizes, ref.options()); } -tuple jit_tp_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad) { - - auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t num_batch = L1_in.size(0); - - check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); - - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - else - check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L1_grad = torch::empty(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - - if(k.shared_weights) - W_grad.zero_(); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - - jit_kernel->backward( - num_batch, - data_ptr(L1_in_contig), data_ptr(L1_grad), - 
data_ptr(L2_in_contig), data_ptr(L2_grad), - data_ptr(W_contig), data_ptr(W_grad), - data_ptr(L3_grad_contig), - stream - ); - - return tuple(L1_grad, L2_grad, W_grad); +Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes) { + return torch::zeros(sizes, ref.options()); } -tuple jit_tp_double_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad, - torch::Tensor L1_dgrad, - torch::Tensor L2_dgrad, - torch::Tensor W_dgrad) { - - auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t num_batch = L1_in.size(0); - - check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); - check_tensor(L1_dgrad, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); - check_tensor(L2_dgrad, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); - - if (k.shared_weights){ - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); - } else { - check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {num_batch, k.weight_numel}, k.weight_dtype, "W_dgrad"); - } - - torch::Tensor L1_grad = torch::empty(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - torch::Tensor L3_dgrad = torch::empty(L3_grad.sizes(), L3_grad.options()); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - - torch::Tensor L1_dgrad_contig = L1_dgrad.contiguous(); - torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous(); - torch::Tensor 
W_dgrad_contig = W_dgrad.contiguous(); - - if(k.shared_weights) { - W_grad.zero_(); - TORCH_CHECK(W.dim() == 1); - } - - jit_kernel->double_backward( - num_batch, - data_ptr(L1_in_contig), data_ptr(L2_in_contig), - data_ptr(W_contig), data_ptr(L3_grad_contig), - data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), - data_ptr(W_dgrad_contig), - data_ptr(L1_grad), data_ptr(L2_grad), - data_ptr(W_grad), data_ptr(L3_dgrad), - stream - ); - - return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +void tensor_zero_(Tensor &tensor) { + tensor.zero_(); } - -// ========================= Convolution ================================== - -torch::Tensor jit_conv_forward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - int64_t L3_dim, - torch::Tensor rows, - torch::Tensor cols, - torch::Tensor workspace, - torch::Tensor transpose_perm) { - - auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t nnz = rows.size(0); - const int64_t node_count = L1_in.size(0); - - check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); - check_tensor(rows, {nnz}, k.idx_dtype, "rows"); - check_tensor(cols, {nnz}, k.idx_dtype, "cols"); - - if (k.deterministic){ - check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); - } else { - at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_forward"); - } - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - else - check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L3_out = torch::zeros({node_count, k.L3_dim}, L1_in.options()); - - torch::Tensor L1_contig = L1_in.contiguous(); - torch::Tensor L2_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor 
rows_contig = rows.contiguous(); - torch::Tensor cols_contig = cols.contiguous(); - torch::Tensor workspace_contig = workspace.contiguous(); - - jit_kernel->exec_conv( - data_ptr(L1_contig), - data_ptr(L2_contig), - data_ptr(W_contig), - data_ptr(L3_out), - data_ptr(rows_contig), - data_ptr(cols_contig), - nnz, node_count, - data_ptr(workspace_contig), - stream); - - return L3_out; +void alert_not_deterministic(const char *name) { + at::globalContext().alertNotDeterministic(name); } -tuple jit_conv_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad, - torch::Tensor rows, - torch::Tensor cols, - torch::Tensor workspace, - torch::Tensor transpose_perm) { - - auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t nnz = rows.size(0); - const int64_t node_count = L1_in.size(0); - - check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); - check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); - check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); - check_tensor(rows, {nnz}, k.idx_dtype, "rows"); - check_tensor(cols, {nnz}, k.idx_dtype, "cols"); +const uint8_t *tensor_data_ptr_u8(const Tensor &tensor) { + return tensor.data_ptr(); +} - if (k.deterministic){ - check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); - } else { - at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_backward"); - } - - if (k.shared_weights) - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); +void *data_ptr(const Tensor &tensor) { + if (tensor.dtype() == torch::kFloat) + return reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kDouble) + return reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kLong) + return 
reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kByte) + return reinterpret_cast(tensor.data_ptr()); + else if (tensor.dtype() == torch::kInt) + return reinterpret_cast(tensor.data_ptr()); else - check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); - - torch::Tensor L1_grad = torch::zeros(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::zeros(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - - torch::Tensor rows_contig = rows.contiguous(); - torch::Tensor cols_contig = cols.contiguous(); - torch::Tensor workspace_contig = workspace.contiguous(); - torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); - - if(k.shared_weights) - W_grad.zero_(); - - jit_kernel->backward( - data_ptr(L1_in_contig), data_ptr(L1_grad), - data_ptr(L2_in_contig), data_ptr(L2_grad), - data_ptr(W_contig), data_ptr(W_grad), - data_ptr(L3_grad_contig), - data_ptr(rows_contig), data_ptr(cols_contig), - nnz, node_count, - data_ptr(workspace_contig), - data_ptr(transpose_perm_contig), - stream); - - return tuple(L1_grad, L2_grad, W_grad); + throw std::logic_error("Unsupported tensor datatype!"); } -tuple jit_conv_double_backward( - torch::Tensor json_bytes, int64_t hash, - torch::Tensor L1_in, - torch::Tensor L2_in, - torch::Tensor W, - torch::Tensor L3_grad, - torch::Tensor L1_dgrad, - torch::Tensor L2_dgrad, - torch::Tensor W_dgrad, - torch::Tensor rows, - torch::Tensor cols, - torch::Tensor workspace, - torch::Tensor transpose_perm) { - - auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); - Stream stream = get_current_stream(); - - const int64_t nnz = rows.size(0); - const int64_t node_count = L1_in.size(0); - - check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, 
"L1_in"); - check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); - check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); - check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); - check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); - check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); - check_tensor(rows, {nnz}, k.idx_dtype, "rows"); - check_tensor(cols, {nnz}, k.idx_dtype, "cols"); - - if (k.deterministic) { - check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); - } else { - at::globalContext().alertNotDeterministic("OpenEquivariance_conv_atomic_double_backward"); - } - - if (k.shared_weights) { - check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); - } - else { - check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); - check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); - } - - torch::Tensor L1_grad = torch::zeros(L1_in.sizes(), L1_in.options()); - torch::Tensor L2_grad = torch::zeros(L2_in.sizes(), L2_in.options()); - torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); - torch::Tensor L3_dgrad = torch::zeros(L3_grad.sizes(), L3_grad.options()); - - torch::Tensor L1_in_contig = L1_in.contiguous(); - torch::Tensor L2_in_contig = L2_in.contiguous(); - torch::Tensor W_contig = W.contiguous(); - torch::Tensor L3_grad_contig = L3_grad.contiguous(); - torch::Tensor L1_dgrad_contig = L1_dgrad.contiguous(); - torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous(); - torch::Tensor W_dgrad_contig = W_dgrad.contiguous(); - - torch::Tensor rows_contig = rows.contiguous(); - torch::Tensor cols_contig = cols.contiguous(); - torch::Tensor workspace_contig = workspace.contiguous(); - torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); - - if(k.shared_weights) - W_grad.zero_(); - - jit_kernel->double_backward( - data_ptr(L1_in_contig), 
data_ptr(L2_in_contig), - data_ptr(W_contig), data_ptr(L3_grad_contig), - data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), - data_ptr(W_dgrad_contig), - data_ptr(L1_grad), data_ptr(L2_grad), - data_ptr(W_grad), data_ptr(L3_dgrad), - data_ptr(rows_contig), data_ptr(cols_contig), - nnz, node_count, - data_ptr(workspace_contig), data_ptr(transpose_perm_contig), - stream - ); - - return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +Stream get_current_stream() { +#ifdef CUDA_BACKEND + return c10::cuda::getCurrentCUDAStream(); +#endif +#ifdef HIP_BACKEND + return c10::hip::getCurrentHIPStream(); +#endif } -// =========================================================== - -TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { - m.impl("jit_tp_forward", &jit_tp_forward); - m.impl("jit_tp_backward", &jit_tp_backward); - m.impl("jit_tp_double_backward", &jit_tp_double_backward); - - m.impl("jit_conv_forward", &jit_conv_forward); - m.impl("jit_conv_backward", &jit_conv_backward); - m.impl("jit_conv_double_backward", &jit_conv_double_backward); -}; - -TORCH_LIBRARY(libtorch_tp_jit, m) { - m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim) -> Tensor"); - m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); - m.def("jit_tp_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); - - m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); - m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); - m.def("jit_conv_double_backward(Tensor json_bytes, int 
hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); -}; - -PYBIND11_MODULE(libtorch_tp_jit, m) {} \ No newline at end of file +namespace py=pybind11; +PYBIND11_MODULE(libtorch_tp_jit, m) { + py::class_(m, "DeviceProp") + .def(py::init()) + .def_readonly("name", &DeviceProp::name) + .def_readonly("warpsize", &DeviceProp::warpsize) + .def_readonly("major", &DeviceProp::major) + .def_readonly("minor", &DeviceProp::minor) + .def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount) + .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + py::class_(m, "GPUTimer") + .def(py::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); +} diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp new file mode 100644 index 00000000..5ddcb7f5 --- /dev/null +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -0,0 +1,109 @@ +#define USE_CUDA + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef HIP_BACKEND + #include +#endif + +using Tensor = torch::stable::Tensor; +using Dtype = torch::headeronly::ScalarType; + +constexpr Dtype kFloat = torch::headeronly::ScalarType::Float; +constexpr Dtype kDouble = torch::headeronly::ScalarType::Double; +constexpr Dtype kInt = torch::headeronly::ScalarType::Int; +constexpr Dtype kLong = torch::headeronly::ScalarType::Long; +constexpr Dtype kByte = torch::headeronly::ScalarType::Byte; + +#define TCHECK STD_TORCH_CHECK +#define BOX(x) TORCH_BOX(x) +#define REGISTER_LIBRARY_IMPL STABLE_TORCH_LIBRARY_IMPL +#define REGISTER_LIBRARY STABLE_TORCH_LIBRARY + +#include 
"torch_core.hpp" + +Tensor tensor_to_cpu_contiguous(const Tensor &tensor) { + torch::stable::Device device(torch::headeronly::DeviceType::CPU); + return torch::stable::contiguous(torch::stable::to(tensor, device)); +} + +Tensor tensor_contiguous(const Tensor &tensor) { + return torch::stable::contiguous(tensor); +} + +Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes) { + auto sizes_ref = torch::headeronly::IntHeaderOnlyArrayRef(sizes.data(), sizes.size()); + return torch::stable::new_empty(ref, sizes_ref); +} + +Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes) { + auto sizes_ref = torch::headeronly::IntHeaderOnlyArrayRef(sizes.data(), sizes.size()); + Tensor out = torch::stable::new_empty(ref, sizes_ref); + torch::stable::zero_(out); + return out; +} + +void tensor_zero_(Tensor &tensor) { + torch::stable::zero_(tensor); +} + +void alert_not_deterministic(const char *name) { + (void)name; +} + +const uint8_t *tensor_data_ptr_u8(const Tensor &tensor) { + return static_cast(tensor.data_ptr()); +} + +void *data_ptr(const Tensor &tensor) { + return tensor.data_ptr(); +} + +Stream get_current_stream() { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + + #ifdef CUDA_BACKEND + return static_cast(stream_ptr); + #elif defined(HIP_BACKEND) + return static_cast(stream_ptr); + #endif +} + +#ifdef CUDA_BACKEND + #define EXTENSION_NAME oeq_stable_cuda +#endif +#ifdef HIP_BACKEND + #define EXTENSION_NAME oeq_stable_hip +#endif + +#ifdef INCLUDE_NB_EXTENSION + #include "nanobind/nanobind.h" + namespace nb = nanobind; + NB_MODULE(EXTENSION_NAME, m) { + nb::class_(m, "DeviceProp") + .def(nb::init()) + .def_ro("name", &DeviceProp::name) + .def_ro("warpsize", &DeviceProp::warpsize) + .def_ro("major", &DeviceProp::major) + .def_ro("minor", &DeviceProp::minor) + .def_ro("multiprocessorCount", 
&DeviceProp::multiprocessorCount) + .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + nb::class_(m, "GPUTimer") + .def(nb::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); + } +#endif \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/stubs/stream.cpp b/openequivariance/openequivariance/extension/stubs/stream.cpp new file mode 100644 index 00000000..5e39739d --- /dev/null +++ b/openequivariance/openequivariance/extension/stubs/stream.cpp @@ -0,0 +1,7 @@ +#include + +extern "C" { + AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream) { + return 0; + } +} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/test/CMakeLists.txt b/openequivariance/openequivariance/extension/test/CMakeLists.txt index 3209d412..d869beca 100644 --- a/openequivariance/openequivariance/extension/test/CMakeLists.txt +++ b/openequivariance/openequivariance/extension/test/CMakeLists.txt @@ -1,9 +1,9 @@ cmake_minimum_required(VERSION 3.5 FATAL_ERROR) -project(test_oeq_jitscript_load) +project(test_oeq_aoti_load) find_package(Torch REQUIRED) -add_executable(load_jitscript load_jitscript.cpp) -target_link_libraries(load_jitscript "${TORCH_LIBRARIES}") -target_link_libraries(load_jitscript -Wl,--no-as-needed "${OEQ_EXTLIB}") -set_property(TARGET load_jitscript PROPERTY CXX_STANDARD 17) \ No newline at end of file +add_executable(load_aoti load_aoti.cpp) +target_link_libraries(load_aoti "${TORCH_LIBRARIES}") +target_link_libraries(load_aoti -Wl,--no-as-needed "${OEQ_EXTLIB}") +set_property(TARGET load_aoti PROPERTY CXX_STANDARD 17) diff --git a/openequivariance/openequivariance/extension/test/load_aoti.cpp b/openequivariance/openequivariance/extension/test/load_aoti.cpp new file mode 100644 index 00000000..0ffc9679 --- /dev/null +++ 
b/openequivariance/openequivariance/extension/test/load_aoti.cpp @@ -0,0 +1,63 @@ +#include +#include +#include + +#include +#include + +/* +* This program takes in two AOTInductor model packages +* that execute a tensor product in FP32 precision. +* The first package is compiled from e3nn, the second is +* OEQ's compiled package. The program checks that the +* two outputs are comparable. +*/ + +int main(int argc, const char* argv[]) { + if (argc != 7) { + std::cerr << "usage: load_aoti " + << " " + << " " + << " " + << " " + << " " + << " " + << std::endl; + + return 1; + } + + c10::InferenceMode mode; + + int64_t L1_dim = std::stoi(argv[3]); + int64_t L2_dim = std::stoi(argv[4]); + int64_t weight_numel = std::stoi(argv[5]); + int64_t batch_size = std::stoi(argv[6]); + + std::vector inputs; + inputs.push_back(torch::randn({batch_size, L1_dim}, at::kCUDA)); + inputs.push_back(torch::randn({batch_size, L2_dim}, at::kCUDA)); + inputs.push_back(torch::randn({batch_size, weight_numel}, at::kCUDA)); + + try { + torch::inductor::AOTIModelPackageLoader module_e3nn(argv[1]); + torch::inductor::AOTIModelPackageLoader module_oeq(argv[2]); + + std::vector output_e3nn = module_e3nn.run(inputs); + std::vector output_oeq = module_oeq.run(inputs); + + for (size_t i = 0; i < output_e3nn.size(); i++) { + if(at::allclose(output_e3nn[i], output_oeq[i], 1e-5, 1e-5)) { + return 0; + } + else { + std::cerr << "torch.allclose returned FALSE comparing model outputs." 
<< std::endl; + return 1; + } + } + } + catch (const c10::Error& e) { + std::cerr << "error loading script module" << std::endl; + return 1; + } +} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/test/load_jitscript.cpp b/openequivariance/openequivariance/extension/test/load_jitscript.cpp deleted file mode 100644 index 6ea7982b..00000000 --- a/openequivariance/openequivariance/extension/test/load_jitscript.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include - -#include -#include - -/* -* This program takes in two JITScript modules that execute -* a tensor product in FP32 precision. -* The first module is compiled from e3nn, the second is -* OEQ's compiled module. The program checks that the -* two outputs are comparable. -*/ - -int main(int argc, const char* argv[]) { - if (argc != 7) { - std::cerr << "usage: load_jitscript " - << " " - << " " - << " " - << " " - << " " - << " " - << std::endl; - - return 1; - } - - int64_t L1_dim = std::stoi(argv[3]); - int64_t L2_dim = std::stoi(argv[4]); - int64_t weight_numel = std::stoi(argv[5]); - int64_t batch_size = std::stoi(argv[6]); - - torch::Device device(torch::kCUDA); - std::vector inputs; - inputs.push_back(torch::randn({batch_size, L1_dim}, device)); - inputs.push_back(torch::randn({batch_size, L2_dim}, device)); - inputs.push_back(torch::randn({batch_size, weight_numel}, device)); - - torch::jit::script::Module module_e3nn, module_oeq; - try { - module_e3nn = torch::jit::load(argv[1]); - module_oeq = torch::jit::load(argv[2]); - } - catch (const c10::Error& e) { - std::cerr << "error loading script module" << std::endl; - return 1; - } - - module_e3nn.to(device); - module_oeq.to(device); - - at::Tensor output_e3nn = module_e3nn.forward(inputs).toTensor(); - at::Tensor output_oeq = module_oeq.forward(inputs).toTensor(); - - if(at::allclose(output_e3nn, output_oeq, 1e-5, 1e-5)) { - return 0; - } - else { - std::cerr << "torch.allclose returned FALSE comparing model outputs." 
<< std::endl; - return 1; - } -} \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp new file mode 100644 index 00000000..8e9e43e5 --- /dev/null +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -0,0 +1,644 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "json11/json11.hpp" + +#ifdef CUDA_BACKEND + #include "backend_cuda.hpp" + #include "group_mm_cuda.hpp" + using JITKernel = CUJITKernel; + using GPU_Allocator = CUDA_Allocator; + + template + using GroupMM = GroupMMCUDA; +#endif + +#ifdef HIP_BACKEND + #include "backend_hip.hpp" + #include "group_mm_hip.hpp" + using JITKernel = HIPJITKernel; + using GPU_Allocator = HIP_Allocator; + + template + using GroupMM = GroupMMHIP; +#endif + +#include "tensorproducts.hpp" +#include "convolution.hpp" + +using namespace std; +using json = json11::Json; + +Dtype enum_to_torch_dtype(int64_t i); + +Tensor tensor_to_cpu_contiguous(const Tensor &tensor); +Tensor tensor_contiguous(const Tensor &tensor); +Tensor tensor_empty_like(const Tensor &ref, const std::vector &sizes); +Tensor tensor_zeros_like(const Tensor &ref, const std::vector &sizes); +void tensor_zero_(Tensor &tensor); + +void alert_not_deterministic(const char *name); +Stream get_current_stream(); + +const uint8_t *tensor_data_ptr_u8(const Tensor &tensor); +void *data_ptr(const Tensor &tensor); + +inline std::string shape_to_string(std::initializer_list shape) { + std::ostringstream oss; + oss << "["; + size_t i = 0; + for (int64_t dim : shape) { + if (i > 0) { + oss << ", "; + } + oss << dim; + ++i; + } + oss << "]"; + return oss.str(); +} + +inline std::string tensor_sizes_str(const Tensor &tensor) { + std::ostringstream oss; + oss << "["; + int64_t dims = tensor.dim(); + for (int64_t i = 0; i < dims; ++i) { + if (i > 0) { + oss << ", "; + } + oss << tensor.size(i); + } + 
oss << "]"; + return oss.str(); +} + +inline std::vector tensor_sizes_vec(const Tensor &tensor) { + int64_t dims = tensor.dim(); + std::vector sizes; + sizes.reserve(static_cast(dims)); + for (int64_t i = 0; i < dims; ++i) { + sizes.push_back(tensor.size(i)); + } + return sizes; +} + +inline std::vector make_sizes(std::initializer_list sizes) { + return std::vector(sizes); +} + +inline Dtype enum_to_torch_dtype(int64_t i) { + switch (i) { + case 1: return kFloat; + case 2: return kDouble; + case 3: return kInt; + case 4: return kLong; + case 5: return kByte; + } + throw logic_error("Unsupported tensor datatype!"); +} + +inline void check_tensor(const Tensor &tensor, + std::initializer_list expected_shape, + Dtype expected_dtype, + std::string tensor_name) { + bool shape_ok = (tensor.dim() == static_cast(expected_shape.size())); + if (shape_ok) { + int64_t i = 0; + for (int64_t dim : expected_shape) { + if (tensor.size(i) != dim) { + shape_ok = false; + break; + } + ++i; + } + } + + TCHECK(shape_ok, + "Shape mismatch for tensor '", tensor_name, + "'. Expected: ", shape_to_string(expected_shape), + ". Got: ", tensor_sizes_str(tensor)); + TCHECK(tensor.is_cuda(), "Tensor '", tensor_name, "' is not on the GPU."); + TCHECK(tensor.scalar_type() == expected_dtype, + "Dtype mismatch for tensor '", tensor_name, + "'. Expected: ", static_cast(expected_dtype), + ". 
Got: ", static_cast(tensor.scalar_type())); +} + +inline std::unordered_map parse_json_config(const json &j_obj) { + std::unordered_map result; + for (const auto &kv : j_obj.object_items()) { + result[kv.first] = static_cast(kv.second.number_value()); + } + return result; +} + +struct KernelProp { + int64_t L1_dim, L2_dim, L3_dim, weight_numel; + bool shared_weights; + Dtype irrep_dtype; + Dtype weight_dtype; + + int64_t workspace_size; // Convolution only + bool deterministic; + Dtype idx_dtype; + Dtype workspace_dtype; + + KernelProp() : + L1_dim(0), L2_dim(0), L3_dim(0), weight_numel(0), + shared_weights(false), + irrep_dtype(kFloat), weight_dtype(kFloat), + workspace_size(0), deterministic(false), + idx_dtype(kInt), workspace_dtype(kByte) {} + + KernelProp( + std::unordered_map &kernel_dims, bool is_convolution) : + L1_dim(kernel_dims.at("L1_dim")), + L2_dim(kernel_dims.at("L2_dim")), + L3_dim(kernel_dims.at("L3_dim")), + weight_numel(kernel_dims.at("weight_numel")), + shared_weights(kernel_dims.at("shared_weights")), + irrep_dtype(enum_to_torch_dtype(kernel_dims.at("irrep_dtype"))), + weight_dtype(enum_to_torch_dtype(kernel_dims.at("weight_dtype"))), + workspace_dtype(kByte) { + if (is_convolution) { + workspace_size = kernel_dims.at("workspace_size"); + deterministic = kernel_dims.at("deterministic"); + idx_dtype = enum_to_torch_dtype(kernel_dims.at("idx_dtype")); + } + } +}; + +inline std::unordered_map>, + KernelProp + >> tp_cache; + +inline std::unordered_map>, + KernelProp + >> conv_cache; + +inline std::mutex mut; + +inline std::pair*, KernelProp> + compile_tp_with_caching(const Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard lock(mut); + auto it = tp_cache.find(hash); + if (it == tp_cache.end()) { + Tensor cpu_tensor = tensor_to_cpu_contiguous(json_bytes); + std::string json_payload( + reinterpret_cast(tensor_data_ptr_u8(cpu_tensor)), + cpu_tensor.numel() + ); + + std::string err; + json root = json::parse(json_payload, err); + if 
(!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_tp_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + tp_cache.insert({hash, + std::make_pair(std::move(jit_tp_impl), + KernelProp(kernel_prop_map, false))}); + it = tp_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; + } +} + +inline std::pair*, KernelProp> + compile_conv_with_caching(const Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard lock(mut); + auto it = conv_cache.find(hash); + if (it == conv_cache.end()) { + Tensor cpu_tensor = tensor_to_cpu_contiguous(json_bytes); + std::string json_payload( + reinterpret_cast(tensor_data_ptr_u8(cpu_tensor)), + cpu_tensor.numel() + ); + + std::string err; + json root = json::parse(json_payload, err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_conv_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + conv_cache.insert({hash, + std::make_pair(std::move(jit_conv_impl), + KernelProp(kernel_prop_map, true))}); + it = conv_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; + } +} + +// --------------------- Tensor Products -------------------------- + +inline Tensor 
jit_tp_forward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + int64_t L3_dim) { + + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t num_batch = L1_in.size(0); + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L3_out = tensor_empty_like(L1_in, make_sizes({num_batch, k.L3_dim})); + + Tensor L1_contig = tensor_contiguous(L1_in); + Tensor L2_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + + jit_kernel->exec_tensor_product( + num_batch, + data_ptr(L1_contig), + data_ptr(L2_contig), + data_ptr(L3_out), + data_ptr(W_contig), + stream + ); + + return L3_out; +} + +inline tuple jit_tp_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad) { + + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t num_batch = L1_in.size(0); + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L1_grad = tensor_empty_like(L1_in, tensor_sizes_vec(L1_in)); + Tensor L2_grad = tensor_empty_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + + if (k.shared_weights) + tensor_zero_(W_grad); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor 
W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + + jit_kernel->backward( + num_batch, + data_ptr(L1_in_contig), data_ptr(L1_grad), + data_ptr(L2_in_contig), data_ptr(L2_grad), + data_ptr(W_contig), data_ptr(W_grad), + data_ptr(L3_grad_contig), + stream + ); + + return tuple(L1_grad, L2_grad, W_grad); +} + +inline tuple jit_tp_double_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad, + Tensor L1_dgrad, + Tensor L2_dgrad, + Tensor W_dgrad) { + + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t num_batch = L1_in.size(0); + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {num_batch, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + + Tensor L1_grad = tensor_empty_like(L1_in, tensor_sizes_vec(L1_in)); + Tensor L2_grad = tensor_empty_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + Tensor L3_dgrad = tensor_empty_like(L3_grad, tensor_sizes_vec(L3_grad)); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + + Tensor L1_dgrad_contig = tensor_contiguous(L1_dgrad); + Tensor L2_dgrad_contig = tensor_contiguous(L2_dgrad); + Tensor W_dgrad_contig = 
tensor_contiguous(W_dgrad); + + if (k.shared_weights) { + tensor_zero_(W_grad); + TCHECK(W.dim() == 1); + } + + jit_kernel->double_backward( + num_batch, + data_ptr(L1_in_contig), data_ptr(L2_in_contig), + data_ptr(W_contig), data_ptr(L3_grad_contig), + data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), + data_ptr(W_dgrad_contig), + data_ptr(L1_grad), data_ptr(L2_grad), + data_ptr(W_grad), data_ptr(L3_dgrad), + stream + ); + + return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +} + + +// ========================= Convolution ================================== + +inline Tensor jit_conv_forward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + int64_t L3_dim, + Tensor rows, + Tensor cols, + Tensor workspace, + Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t nnz = rows.size(0); + const int64_t node_count = L1_in.size(0); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } else { + alert_not_deterministic("OpenEquivariance_conv_atomic_forward"); + } + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L3_out = tensor_zeros_like(L1_in, make_sizes({node_count, k.L3_dim})); + + Tensor L1_contig = tensor_contiguous(L1_in); + Tensor L2_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor rows_contig = tensor_contiguous(rows); + Tensor cols_contig = tensor_contiguous(cols); + Tensor workspace_contig = tensor_contiguous(workspace); + + 
jit_kernel->exec_conv( + data_ptr(L1_contig), + data_ptr(L2_contig), + data_ptr(W_contig), + data_ptr(L3_out), + data_ptr(rows_contig), + data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), + stream); + + return L3_out; +} + +inline tuple jit_conv_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad, + Tensor rows, + Tensor cols, + Tensor workspace, + Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t nnz = rows.size(0); + const int64_t node_count = L1_in.size(0); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } else { + alert_not_deterministic("OpenEquivariance_conv_atomic_backward"); + } + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + + Tensor L1_grad = tensor_zeros_like(L1_in, tensor_sizes_vec(L1_in)); + Tensor L2_grad = tensor_zeros_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + + Tensor rows_contig = tensor_contiguous(rows); + Tensor cols_contig = tensor_contiguous(cols); + Tensor workspace_contig = tensor_contiguous(workspace); + Tensor transpose_perm_contig = tensor_contiguous(transpose_perm); + + 
if (k.shared_weights) + tensor_zero_(W_grad); + + jit_kernel->backward( + data_ptr(L1_in_contig), data_ptr(L1_grad), + data_ptr(L2_in_contig), data_ptr(L2_grad), + data_ptr(W_contig), data_ptr(W_grad), + data_ptr(L3_grad_contig), + data_ptr(rows_contig), data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), + data_ptr(transpose_perm_contig), + stream); + + return tuple(L1_grad, L2_grad, W_grad); +} + +inline tuple jit_conv_double_backward( + Tensor json_bytes, int64_t hash, + Tensor L1_in, + Tensor L2_in, + Tensor W, + Tensor L3_grad, + Tensor L1_dgrad, + Tensor L2_dgrad, + Tensor W_dgrad, + Tensor rows, + Tensor cols, + Tensor workspace, + Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); + Stream stream = get_current_stream(); + + const int64_t nnz = rows.size(0); + const int64_t node_count = L1_in.size(0); + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) { + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } else { + alert_not_deterministic("OpenEquivariance_conv_atomic_double_backward"); + } + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + + Tensor L1_grad = tensor_zeros_like(L1_in, tensor_sizes_vec(L1_in)); + 
Tensor L2_grad = tensor_zeros_like(L2_in, tensor_sizes_vec(L2_in)); + Tensor W_grad = tensor_empty_like(W, tensor_sizes_vec(W)); + Tensor L3_dgrad = tensor_zeros_like(L3_grad, tensor_sizes_vec(L3_grad)); + + Tensor L1_in_contig = tensor_contiguous(L1_in); + Tensor L2_in_contig = tensor_contiguous(L2_in); + Tensor W_contig = tensor_contiguous(W); + Tensor L3_grad_contig = tensor_contiguous(L3_grad); + Tensor L1_dgrad_contig = tensor_contiguous(L1_dgrad); + Tensor L2_dgrad_contig = tensor_contiguous(L2_dgrad); + Tensor W_dgrad_contig = tensor_contiguous(W_dgrad); + + Tensor rows_contig = tensor_contiguous(rows); + Tensor cols_contig = tensor_contiguous(cols); + Tensor workspace_contig = tensor_contiguous(workspace); + Tensor transpose_perm_contig = tensor_contiguous(transpose_perm); + + if (k.shared_weights) + tensor_zero_(W_grad); + + jit_kernel->double_backward( + data_ptr(L1_in_contig), data_ptr(L2_in_contig), + data_ptr(W_contig), data_ptr(L3_grad_contig), + data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), + data_ptr(W_dgrad_contig), + data_ptr(L1_grad), data_ptr(L2_grad), + data_ptr(W_grad), data_ptr(L3_dgrad), + data_ptr(rows_contig), data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), data_ptr(transpose_perm_contig), + stream + ); + + return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +} + +// =========================================================== + +REGISTER_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { + m.impl("jit_tp_forward", BOX(&jit_tp_forward)); + m.impl("jit_tp_backward", BOX(&jit_tp_backward)); + m.impl("jit_tp_double_backward", BOX(&jit_tp_double_backward)); + + m.impl("jit_conv_forward", BOX(&jit_conv_forward)); + m.impl("jit_conv_backward", BOX(&jit_conv_backward)); + m.impl("jit_conv_double_backward", BOX(&jit_conv_double_backward)); +}; + +REGISTER_LIBRARY(libtorch_tp_jit, m) { + m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim) -> Tensor"); + 
m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); + m.def("jit_tp_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); + + m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); + m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); + m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); +}; diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index a0ddd618..a8fb6c2c 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools", "setuptools-scm"] -build-backend = "setuptools.build_meta" +requires = ["setuptools", "setuptools-scm", "scikit-build-core", "nanobind"] +build-backend = "scikit_build_core.build" [project] name = "openequivariance" @@ -60,21 +60,26 @@ dev = [ ] jax = [ - "nanobind", "scikit-build-core", - "setuptools-scm" + "setuptools-scm", + "nanobind" ] [tool.setuptools.packages.find] include = ["openequivariance*"] -[tool.setuptools_scm] -root = ".." 
- [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", ] [tool.ruff] -lint.ignore = ["E741"] \ No newline at end of file +lint.ignore = ["E741"] + +[tool.scikit-build] +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +sdist.include = ["openequivariance/_version.py"] + +[tool.setuptools_scm] +write_to = "openequivariance/_version.py" +root = ".." \ No newline at end of file diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 90eafe6c..c815c39f 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -58,8 +58,8 @@ set(OEQ_JAX_SOURCES set(OEQ_JAX_HEADERS ${HEADER_DIR}/convolution.hpp ${HEADER_DIR}/tensorproducts.hpp - ${HEADER_DIR}/util/backend_cuda.hpp - ${HEADER_DIR}/util/backend_hip.hpp + ${HEADER_DIR}/backend/backend_cuda.hpp + ${HEADER_DIR}/backend/backend_hip.hpp ${HEADER_DIR}/json11/json11.hpp ) diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index 74f9627d..c0e3aec8 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,8 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.2.1" +dynamic = ["version"] + authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, @@ -41,7 +42,12 @@ issues = "https://github.com/PASSIONLab/OpenEquivariance/issues" JAX_HIP = {env="JAX_HIP", default="0"} XLA_DIRECT_DOWNLOAD = {env="XLA_DIRECT_DOWNLOAD", default="0"} +[tool.scikit-build] +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +sdist.include = ["openequivariance_extjax/_version.py"] + [tool.setuptools_scm] +write_to = "openequivariance_extjax/_version.py" root = ".." 
[tool.pytest.ini_options] diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index bf3fd69e..87e5e785 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -18,7 +18,7 @@ using json = json11::Json; #include #include - #include "util/backend_cuda.hpp" + #include "backend/backend_cuda.hpp" #include "group_mm_cuda.hpp" using JITKernel = CUJITKernel; using GPU_Allocator = CUDA_Allocator; @@ -29,7 +29,7 @@ using json = json11::Json; #endif #ifdef HIP_BACKEND - #include "util/backend_hip.hpp" + #include "backend/backend_hip.hpp" #include "group_mm_hip.hpp" using JITKernel = HIPJITKernel; using GPU_Allocator = HIP_Allocator; diff --git a/tests/export_test.py b/tests/export_test.py index efdaf865..6aba9690 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1,11 +1,18 @@ import torch import pytest import tempfile +import subprocess +import shutil +import os +import sys +import importlib.resources import numpy as np import openequivariance as oeq from torch_geometric import EdgeIndex +from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + @pytest.fixture(scope="session") def problem_and_irreps(): @@ -128,3 +135,91 @@ def test_aoti(tp_and_inputs): aoti_result = aoti_model(*inputs) assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5) + + +def test_aoti_cpp_inference(problem_and_irreps): + assert oeq.LINKED_LIBPYTHON, oeq.LINKED_LIBPYTHON_ERROR + problem, X_ir, Y_ir, _ = problem_and_irreps + cmake_prefix_path = torch.utils.cmake_prefix_path + torch_ext_so_path = oeq.torch_ext_so_path() + + gen = torch.Generator(device="cuda") + gen.manual_seed(0) + batch_size = 1000 + + # Create models + oeq_tp = oeq.TensorProduct(problem).to("cuda") + e3nn_tp = E3NNTensorProduct(problem).e3nn_tp.to("cuda") + + # Prepare inputs for export + X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) + Y = torch.rand(batch_size, 
            pytest.fail("AOTI C++ inference subprocess failed; see captured output above")
test_torch_extension_built(): - from openequivariance import TORCH_COMPILE, TORCH_COMPILE_ERROR - - assert TORCH_COMPILE_ERROR is None - assert TORCH_COMPILE