diff --git a/CHANGELOG.md b/CHANGELOG.md index 723857d4..a638438b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ ## Latest Changes +### v0.6.7 (2026-06-13) + +**Added**: +- Triple backward and higher-order derivative support for + tensor products in Pytorch. +- Reintroduced symmetric contraction implementation for PyTorch. +- `torch.compile`, `torch.export` support for symmetric + contraction. + +**Fixed**: +- Some compilation issues for RocM. + +### v0.6.6 (2026-06-13) +Bugfix: added alternate URL for libtorch aarch64 download in +stable extension. + ### v0.6.5 (2026-03-22) This release brings `ir_mul` layout support for OpenEquivariance. Pass the parameter diff --git a/docs/supported_ops.rst b/docs/supported_ops.rst index bcc11955..b8d90203 100644 --- a/docs/supported_ops.rst +++ b/docs/supported_ops.rst @@ -105,15 +105,11 @@ See PyTorch usage details `here torch.Tensor: - """ - If ragged_inner == 0: - A is 3D, num_weights x num_features x M x K - B is batch_size x num_features x K - C is batch_size x num_features x M - If ragged_inner == 1: (needed for the backward pass) - A is batch_size x num_features x M - B is batch_size x num_features K - C is 3D, num_weights x num_features M x K - """ - shape = None - if ragged_inner == 0: - shape = (B.shape[0], B.shape[1], M) - elif ragged_inner == 1: - shape = (num_elements, B.shape[1], M, K) - - C = torch.zeros(shape, device="cuda", dtype=A.dtype) - self.internal.group_gemm( - A.contiguous().data_ptr(), - B.contiguous().data_ptr(), - C.data_ptr(), - ragged_counts.data_ptr(), - M, - K, - ragged_inner, - ) - return C - - @group_gemm.register_fake - def _(A, B, ragged_counts, M, K, ragged_inner): - if ragged_inner == 0: - return A.new_empty(B.shape[0], B.shape[1], M) - elif ragged_inner == 1: - return A.new_empty(num_elements, batch_size, M, K) - - self.group_gemm = group_gemm - - def setup_context(ctx, inputs, output): - ctx.A, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, ctx.ragged_inner = inputs - - def backward(ctx, grad_output): - grad_A, grad_B = None, None - - if ctx.ragged_inner == 0: - grad_A = group_gemm( - grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 1 - ) - grad_B = group_gemm( - ctx.A.transpose(2, 3), - grad_output, - ctx.ragged_counts, - ctx.K, - ctx.M, - 0, - ) - elif ctx.ragged_inner == 1: - grad_A = group_gemm( - grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 0 - ) - grad_B = group_gemm( - grad_output.transpose(2, 3), - ctx.A, - ctx.ragged_counts, - ctx.K, - ctx.M, - 0, - ) - - return grad_A, grad_B, None, None, None, None - - self.group_gemm.register_autograd(backward, setup_context=setup_context) - - def forward(self, weights, vectors, bincounts): - return self.group_gemm( - weights, vectors, bincounts, weights.shape[2], weights.shape[3], 0 - ) - - -# -------------------------------------------------------------------------- -# The following segment of code was copied from MACE's repo at https://github.com/ACEsuit/mace/blob/b5faaa076c49778fc17493edfecebcabeb960155/mace/tools/cg.py#L106 +from openequivariance._torch import extlib import collections from typing import Dict, Optional, Union, List @@ -229,6 +121,24 @@ def U_matrix_real( return out +class GroupMM: + def __init__(self, dtype, num_elements, batch_size): + self.num_elements = num_elements + self.batch_size = batch_size + + def forward(self, weights, vectors, bincounts): + return torch.ops.libtorch_tp_jit.group_gemm( + weights, + vectors, + bincounts, + self.num_elements, + self.batch_size, + weights.shape[2], + weights.shape[3], + 0, + ) + + @compile_mode("script") class Contraction(torch.nn.Module): def __init__( @@ -420,3 +330,82 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): ] outs_cat = torch.cat(outs, dim=-1)[inverse_perm] return outs_cat + + +def register_torch_fakes(): + @torch.library.register_fake("libtorch_tp_jit::group_gemm") + def fake_group_gemm(A, B, ragged_counts, num_W, batch_size, m, k, ragged_inner): + if ragged_inner == 0: + return A.new_empty(B.shape[0], batch_size, m) + else: + return A.new_empty(num_W, batch_size, m, k) + + +def register_autograd(): + op = torch.ops.libtorch_tp_jit.group_gemm + + def setup_context(ctx, inputs, output): + ( + ctx.A, + ctx.B, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.m, + ctx.k, + ctx.ragged_inner, + ) = inputs + + def backward(ctx, grad_output): + if ctx.ragged_inner == 0: + grad_A = op( + grad_output, + ctx.B, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.m, + ctx.k, + 1, + ) + grad_B = op( + ctx.A.transpose(2, 3), + grad_output, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.k, + ctx.m, + 0, + ) + else: + grad_A = op( + grad_output, + ctx.B, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.m, + ctx.k, + 0, + ) + grad_B = op( + grad_output.transpose(2, 3), + ctx.A, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.k, + ctx.m, + 0, + ) + return grad_A, grad_B, None, None, None, None, None, None + + torch.library.register_autograd( + "libtorch_tp_jit::group_gemm", backward, setup_context=setup_context + ) + + +if extlib.BUILT_EXTENSION: + register_torch_fakes() + register_autograd() diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py index 00edefcb..486334e9 100644 --- a/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py @@ -1,4 +1,4 @@ -from openequivariance._torch.symmetric_contraction.symmetric_contraction import ( +from openequivariance._torch.symmetric_contraction.SymmetricContraction import ( SymmetricContraction, ) diff --git a/openequivariance/openequivariance/extension/group_mm.hpp b/openequivariance/openequivariance/extension/group_mm.hpp new file mode 100644 index 00000000..a1d7be5a --- /dev/null +++ b/openequivariance/openequivariance/extension/group_mm.hpp @@ -0,0 +1,140 @@ +#pragma once + +#include +#include + +#ifdef CUDA_BACKEND + #include "cublas_v2.h" + #include + + struct BlasHandle { + cublasHandle_t handle; + BlasHandle() { + if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) + throw std::logic_error("CUBLAS initialization failed"); + } + ~BlasHandle() { cublasDestroy(handle); } + }; +#elif defined(HIP_BACKEND) + #include "rocblas/rocblas.h" + #include + + struct BlasHandle { + rocblas_handle handle; + BlasHandle() { + if (rocblas_create_handle(&handle) != rocblas_status_success) + throw std::logic_error("rocBLAS initialization failed"); + } + ~BlasHandle() { rocblas_destroy_handle(handle); } + }; +#endif + +inline BlasHandle& get_blas_handle() { + static BlasHandle handle; + return handle; +} + +template +void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw, + int64_t* ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) { + + auto& blas = get_blas_handle(); + T alpha = 1.0, beta = 0.0; + T* A_base = reinterpret_cast(A_raw); + T* B_base = reinterpret_cast(B_raw); + T* C_base = reinterpret_cast(C_raw); + + int64_t ragged_offset = 0; + for (int i = 0; i < num_W; i++) { + int M, K, N, lda, ldb, ldc, strideA, strideB, strideC; + T *A, *B, *C; +#ifdef CUDA_BACKEND + cublasOperation_t transa, transb; +#elif defined(HIP_BACKEND) + rocblas_operation transa, transb; +#endif + + if (ragged_inner == 0) { + M = m; K = k; N = static_cast(ragged_counts[i]); + A = A_base + (m * k * batch_size * i); + lda = k; strideA = M * K; + B = B_base + (k * batch_size * ragged_offset); + ldb = K * batch_size; strideB = K; + C = C_base + (m * batch_size * ragged_offset); + ldc = M * batch_size; strideC = M; +#ifdef CUDA_BACKEND + transa = CUBLAS_OP_T; transb = CUBLAS_OP_N; +#elif defined(HIP_BACKEND) + transa = rocblas_operation_transpose; transb = rocblas_operation_none; +#endif + } else { + M = k; K = static_cast(ragged_counts[i]); N = m; + A = B_base + (k * batch_size * ragged_offset); + lda = k * batch_size; strideA = M; + B = A_base + (m * batch_size * ragged_offset); + ldb = m * batch_size; strideB = N; + C = C_base + (m * k * batch_size * i); + ldc = k; strideC = M * N; +#ifdef CUDA_BACKEND + transa = CUBLAS_OP_N; transb = CUBLAS_OP_T; +#elif defined(HIP_BACKEND) + transa = rocblas_operation_none; transb = rocblas_operation_transpose; +#endif + } + ragged_offset += ragged_counts[i]; + + if (ragged_counts[i] > 0) { +#ifdef CUDA_BACKEND + cublasStatus_t stat; + if (std::is_same::value) { + stat = cublasSgemmStridedBatched(blas.handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else if (std::is_same::value) { + stat = cublasDgemmStridedBatched(blas.handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else { + throw std::logic_error("Unsupported datatype for grouped GEMM!"); + } + if (stat != CUBLAS_STATUS_SUCCESS) + throw std::logic_error("Grouped GEMM failed!"); +#elif defined(HIP_BACKEND) + rocblas_status stat; + if (std::is_same::value) { + stat = rocblas_sgemm_strided_batched(blas.handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else if (std::is_same::value) { + stat = rocblas_dgemm_strided_batched(blas.handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else { + throw std::logic_error("Unsupported datatype for grouped GEMM!"); + } + if (stat != rocblas_status_success) + throw std::logic_error("Grouped GEMM failed!"); +#endif + } + } +} diff --git a/openequivariance/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/openequivariance/extension/group_mm_cuda.hpp deleted file mode 100644 index 95f1412f..00000000 --- a/openequivariance/openequivariance/extension/group_mm_cuda.hpp +++ /dev/null @@ -1,143 +0,0 @@ -#pragma once - -#include "cublas_v2.h" -#include -#include -#include - -using namespace std; - -template -class GroupMMCUDA { - cublasStatus_t stat; - cublasHandle_t handle; - - int num_W; - int batch_size; - - T alpha; - T beta; - -public: - GroupMMCUDA(int num_W, int batch_size) : - num_W(num_W), - batch_size(batch_size), - alpha(1.0), - beta(0.0) { - stat = cublasCreate(&handle); - if (stat != CUBLAS_STATUS_SUCCESS) { - throw std::logic_error("CUBLAS initialization failed"); - } - } - - void group_gemm(void* A_raw, void* B_raw, void* C_raw, - int64_t* ragged_counts, int m, int k, int ragged_inner) { - /* - * Performs one of two grouped, batched GEMMs with a single ragged dimension: - * - * a) If ragged_inner = 0, multiplies each M x K row-major weight matrix A - * against B, where B is stored in column-major order with each matrix of - * dimensions K x [offset_diff]. Output has dimensions M x [offset_diff], - * stored in column-major order. - * b) If ragged_inner = 1, multiplies each M x [offset_diff] A matrix - * against each B K x [offset_diff] matrix transposed to produce a - * M x K matrix output. - */ - - T* A_base = reinterpret_cast(A_raw); - T* B_base = reinterpret_cast(B_raw); - T* C_base = reinterpret_cast(C_raw); - - int64_t ragged_offset = 0; - for(int i = 0; i < num_W; i++) { - int M, K, N, lda, ldb, ldc; - T *A, *B, *C; - - int strideA, strideB, strideC; - cublasOperation_t transa, transb; - - if(ragged_inner == 0) { - M = m; - K = k; - N = static_cast(ragged_counts[i]); - - A = A_base + (m * k * batch_size * i); - lda = k; strideA = M * K; - - B = B_base + (k * batch_size * ragged_offset); - ldb = K * batch_size; strideB = K; - - C = C_base + (m * batch_size * ragged_offset); - ldc = M * batch_size; strideC = M; - - transa = CUBLAS_OP_T; - transb = CUBLAS_OP_N; - } - else { - M = k; - K = static_cast(ragged_counts[i]); - N = m; - - A = B_base + (k * batch_size * ragged_offset); - lda = k * batch_size; strideA = M; - - B = A_base + (m * batch_size * ragged_offset); - ldb = m * batch_size; strideB = N; - - C = C_base + (m * k * batch_size * i); - ldc = k; strideC = M * N; - - transa = CUBLAS_OP_N; - transb = CUBLAS_OP_T; - } - ragged_offset += ragged_counts[i]; - - if(ragged_counts[i] > 0) { - if(std::is_same::value) { - stat = cublasSgemmStridedBatched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else if(std::is_same::value) { - stat = cublasDgemmStridedBatched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else { - throw std::logic_error("Unsupported datatype for grouped GEMM!"); - } - if (stat != CUBLAS_STATUS_SUCCESS) { - throw std::logic_error("Grouped GEMM failed!"); - } - } - } - } - - void group_gemm_intptr(uint64_t weights, - uint64_t vectors, uint64_t output, - uint64_t ragged_counts, int m, int k, int ragged_inner) { - - group_gemm( - reinterpret_cast(weights), - reinterpret_cast(vectors), - reinterpret_cast(output), - reinterpret_cast(ragged_counts), - m, k, ragged_inner); - } - - ~GroupMMCUDA() { - cublasDestroy(handle); - } -}; \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/group_mm_hip.hpp b/openequivariance/openequivariance/extension/group_mm_hip.hpp deleted file mode 100644 index 2e713cf4..00000000 --- a/openequivariance/openequivariance/extension/group_mm_hip.hpp +++ /dev/null @@ -1,129 +0,0 @@ -#pragma once - -#include "rocblas/rocblas.h" -#include -#include -#include - - -template -class GroupMMHIP { - rocblas_status stat; - rocblas_handle handle; - - int num_W; - int batch_size; - - T alpha; - T beta; - -public: - GroupMMHIP(int num_W, int batch_size) : - num_W(num_W), - batch_size(batch_size), - alpha(1.0), - beta(0.0) { - if(rocblas_create_handle(&handle) != rocblas_status_success) { - throw std::logic_error("rocBLAS initialization failed"); - } - } - - void group_gemm(void* A_raw, void* B_raw, void* C_raw, - int64_t* ragged_counts, int m, int k, int ragged_inner) { - - T* A_base = reinterpret_cast(A_raw); - T* B_base = reinterpret_cast(B_raw); - T* C_base = reinterpret_cast(C_raw); - - int64_t ragged_offset = 0; - for(int i = 0; i < num_W; i++) { - int M, K, N, lda, ldb, ldc; - T *A, *B, *C; - - int strideA, strideB, strideC; - rocblas_operation transa, transb; - - if(ragged_inner == 0) { - M = m; - K = k; - N = static_cast(ragged_counts[i]); - - A = A_base + (m * k * batch_size * i); - lda = k; strideA = M * K; - - B = B_base + (k * batch_size * ragged_offset); - ldb = K * batch_size; strideB = K; - - C = C_base + (m * batch_size * ragged_offset); - ldc = M * batch_size; strideC = M; - - transa = rocblas_operation_transpose; - transb = rocblas_operation_none; - } - else { - M = k; - K = static_cast(ragged_counts[i]); - N = m; - - A = B_base + (k * batch_size * ragged_offset); - lda = k * batch_size; strideA = M; - - B = A_base + (m * batch_size * ragged_offset); - ldb = m * batch_size; strideB = N; - - C = C_base + (m * k * batch_size * i); - ldc = k; strideC = M * N; - - transa = rocblas_operation_none; - transb = rocblas_operation_transpose; - } - ragged_offset += ragged_counts[i]; - - if(ragged_counts[i] > 0) { - if(std::is_same::value) { - stat = rocblas_sgemm_strided_batched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else if(std::is_same::value) { - stat = rocblas_dgemm_strided_batched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else { - throw std::logic_error("Unsupported datatype for grouped GEMM!"); - } - if (stat != rocblas_status_success) { - throw std::logic_error("Grouped GEMM failed!"); - } - } - } - } - - void group_gemm_intptr(uint64_t weights, - uint64_t vectors, uint64_t output, - uint64_t ragged_counts, int m, int k, int ragged_inner) { - group_gemm( - reinterpret_cast(weights), - reinterpret_cast(vectors), - reinterpret_cast(output), - reinterpret_cast(ragged_counts), - m, k, ragged_inner); - } - - ~GroupMMHIP() { - rocblas_destroy_handle(handle); - } -}; \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 173c0a2e..698a142f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -84,13 +84,6 @@ Stream get_current_stream() { namespace py=pybind11; PYBIND11_MODULE(libtorch_tp_jit, m) { - py::class_>(m, "GroupMM_F32") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - py::class_>(m, "GroupMM_F64") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - py::class_(m, "DeviceProp") .def(py::init()) .def_readonly("name", &DeviceProp::name) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 63fcc7f4..9ab6cd2f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -89,13 +89,6 @@ Stream get_current_stream() { #include "nanobind/stl/string.h" namespace nb = nanobind; NB_MODULE(EXTENSION_NAME, m) { - nb::class_>(m, "GroupMM_F32") - .def(nb::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - nb::class_>(m, "GroupMM_F64") - .def(nb::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - nb::class_(m, "DeviceProp") .def(nb::init()) .def_ro("name", &DeviceProp::name) diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index 8e9e43e5..ec7698de 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -14,24 +14,18 @@ #ifdef CUDA_BACKEND #include "backend_cuda.hpp" - #include "group_mm_cuda.hpp" using JITKernel = CUJITKernel; using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; #endif #ifdef HIP_BACKEND #include "backend_hip.hpp" - #include "group_mm_hip.hpp" using JITKernel = HIPJITKernel; using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; #endif +#include "group_mm.hpp" + #include "tensorproducts.hpp" #include "convolution.hpp" @@ -623,6 +617,40 @@ inline tuple jit_conv_double_backward( // =========================================================== +inline Tensor group_gemm( + Tensor A, Tensor B, Tensor ragged_counts, + int64_t num_W, int64_t batch_size, int64_t m, int64_t k, int64_t ragged_inner) { + TCHECK(A.scalar_type() == B.scalar_type(), "group_gemm: A and B must have the same dtype"); + TCHECK(ragged_counts.scalar_type() == kLong, "group_gemm: ragged_counts must be int64"); + + Tensor A_c = tensor_contiguous(A); + Tensor B_c = tensor_contiguous(B); + Tensor rc_c = tensor_contiguous(ragged_counts); + int64_t* rc_ptr = reinterpret_cast(data_ptr(rc_c)); + + Tensor C; + if (ragged_inner == 0) { + C = tensor_zeros_like(A, make_sizes({B.size(0), batch_size, m})); + } + else { + C = tensor_zeros_like(A, make_sizes({num_W, batch_size, m, k})); + } + + if (A.scalar_type() == kFloat) { + group_gemm_blas(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, + (int)num_W, (int)batch_size, (int)m, (int)k, (int)ragged_inner); + } else if (A.scalar_type() == kDouble) { + group_gemm_blas(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, + (int)num_W, (int)batch_size, (int)m, (int)k, (int)ragged_inner); + } else { + throw std::logic_error("group_gemm: unsupported dtype, expected float32 or float64"); + } + + return C; +} + +// =========================================================== + REGISTER_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { m.impl("jit_tp_forward", BOX(&jit_tp_forward)); m.impl("jit_tp_backward", BOX(&jit_tp_backward)); @@ -631,6 +659,8 @@ REGISTER_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { m.impl("jit_conv_forward", BOX(&jit_conv_forward)); m.impl("jit_conv_backward", BOX(&jit_conv_backward)); m.impl("jit_conv_double_backward", BOX(&jit_conv_double_backward)); + + m.impl("group_gemm", BOX(&group_gemm)); }; REGISTER_LIBRARY(libtorch_tp_jit, m) { @@ -641,4 +671,6 @@ REGISTER_LIBRARY(libtorch_tp_jit, m) { m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); + + m.def("group_gemm(Tensor A, Tensor B, Tensor ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) -> Tensor"); }; diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 1028c8cb..4711e9cc 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance" -version = "0.6.6" +version = "0.6.7" authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index af35dde4..99731dab 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.6.6" +version = "0.6.7" authors = [ { name="Austin Glover" }, diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 87e5e785..ec567a90 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -19,24 +19,16 @@ using json = json11::Json; #include #include "backend/backend_cuda.hpp" - #include "group_mm_cuda.hpp" using JITKernel = CUJITKernel; using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; using stream_t = cudaStream_t; #endif #ifdef HIP_BACKEND #include "backend/backend_hip.hpp" - #include "group_mm_hip.hpp" using JITKernel = HIPJITKernel; using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; - using stream_t = hipStream_t; + using stream_t = hipStream_t; #endif #include "tensorproducts.hpp" diff --git a/tests/symmetric_contraction_test.py b/tests/symmetric_contraction_test.py index fdf0861d..bbd105ac 100644 --- a/tests/symmetric_contraction_test.py +++ b/tests/symmetric_contraction_test.py @@ -1,3 +1,4 @@ +import collections from unittest.mock import patch import pytest @@ -12,13 +13,55 @@ MaceSymmetricContraction = mace_symmetric_contraction.SymmetricContraction -IRREPS_IN = o3.Irreps("2x0e + 2x1o") -IRREPS_OUT = o3.Irreps("2x0e + 2x1o") -CORRELATION = 2 -NUM_ELEMENTS = 4 -LABEL_VALUES = [0, 2, 3, 2, 0, 0, 2, 3, 2, 2] +SCConfig = collections.namedtuple( + "SCConfig", + [ + "irreps_in", + "irreps_out", + "correlation", + "num_elements", + "label_values", + "device", + ], +) + DEVICE = torch.device("cuda") +SC_CONFIGS = [ + SCConfig( + o3.Irreps("2x0e + 2x1o"), + o3.Irreps("2x0e + 2x1o"), + 2, + 4, + [0, 2, 3, 2, 0, 0, 2, 3, 2, 2], + DEVICE, + ), + SCConfig( + o3.Irreps("1x0e + 1x1o + 1x2e"), + o3.Irreps("1x0e + 1x1o"), + 3, + 3, + [0, 1, 2, 0, 1, 2, 0, 1], + DEVICE, + ), + SCConfig( + o3.Irreps("4x0e + 4x1o"), + o3.Irreps("4x0e"), + 2, + 5, + [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], + DEVICE, + ), +] + + +@pytest.fixture( + params=SC_CONFIGS, + ids=lambda cfg: f"{cfg.irreps_in}-corr{cfg.correlation}", +) +def symmetric_contraction_config(request): + return request.param + @pytest.fixture(params=[torch.float32, torch.float64], ids=["F32", "F64"]) def dtype(request): @@ -26,24 +69,28 @@ def dtype(request): @pytest.fixture -def labels(): - return torch.tensor(LABEL_VALUES, device=DEVICE, dtype=torch.long) +def labels(symmetric_contraction_config): + cfg = symmetric_contraction_config + return torch.tensor(cfg.label_values, device=cfg.device, dtype=torch.long) @pytest.fixture -def node_attrs(labels, dtype): - return F.one_hot(labels, num_classes=NUM_ELEMENTS).to(dtype=dtype) +def node_attrs(labels, dtype, symmetric_contraction_config): + return F.one_hot(labels, num_classes=symmetric_contraction_config.num_elements).to( + dtype=dtype + ) @pytest.fixture -def node_feats(dtype): - gen = torch.Generator(device=DEVICE) +def node_feats(dtype, symmetric_contraction_config): + cfg = symmetric_contraction_config + gen = torch.Generator(device=cfg.device) gen.manual_seed(2468) return torch.randn( - len(LABEL_VALUES), - IRREPS_IN.count((0, 1)), - IRREPS_IN.dim // IRREPS_IN.count((0, 1)), - device=DEVICE, + len(cfg.label_values), + cfg.irreps_in.count((0, 1)), + cfg.irreps_in.dim // cfg.irreps_in.count((0, 1)), + device=cfg.device, dtype=dtype, generator=gen, requires_grad=True, @@ -51,28 +98,27 @@ def node_feats(dtype): @pytest.fixture -def modules(dtype): +def modules(dtype, symmetric_contraction_config): + cfg = symmetric_contraction_config torch.manual_seed(12345) oeq_module = SymmetricContraction( - IRREPS_IN, - IRREPS_OUT, - correlation=CORRELATION, - num_elements=NUM_ELEMENTS, + cfg.irreps_in, + cfg.irreps_out, + correlation=cfg.correlation, + num_elements=cfg.num_elements, dtype=dtype, - ).to(DEVICE) + ).to(cfg.device) - # MACE's original e3nn implementation reads torch.get_default_dtype() - # during construction, so patch that lookup instead of mutating global state. with patch( "mace.modules.symmetric_contraction.torch.get_default_dtype", return_value=dtype, ): mace_module = MaceSymmetricContraction( - IRREPS_IN, - IRREPS_OUT, - correlation=CORRELATION, - num_elements=NUM_ELEMENTS, - ).to(device=DEVICE, dtype=dtype) + cfg.irreps_in, + cfg.irreps_out, + correlation=cfg.correlation, + num_elements=cfg.num_elements, + ).to(device=cfg.device, dtype=dtype) copy_matching_state(oeq_module, mace_module) return oeq_module, mace_module @@ -120,15 +166,16 @@ def random_like(tensor, seed): class TestSymmetricContraction: def test_matches_mace_forward_backward( - self, modules, node_feats, node_attrs, dtype + self, modules, node_feats, node_attrs, dtype, symmetric_contraction_config ): + cfg = symmetric_contraction_config oeq_module, mace_module = modules mace_node_feats = node_feats.detach().clone().requires_grad_() oeq_output = oeq_module(node_feats, node_attrs) mace_output = mace_module(mace_node_feats, node_attrs) - assert oeq_output.shape == (len(LABEL_VALUES), IRREPS_OUT.dim) + assert oeq_output.shape == (len(cfg.label_values), cfg.irreps_out.dim) torch.testing.assert_close(oeq_output, mace_output, **tolerance(dtype)) output_grad = random_like(oeq_output, seed=4321) @@ -144,7 +191,9 @@ def test_matches_mace_forward_backward( for oeq_grad, mace_grad in zip(oeq_grads, mace_grads): torch.testing.assert_close(oeq_grad, mace_grad, **tolerance(dtype)) - def test_matches_mace_double_backward(self, modules, node_feats, node_attrs, dtype): + def test_matches_mace_double_backward( + self, modules, node_feats, node_attrs, dtype, symmetric_contraction_config + ): oeq_module, mace_module = modules mace_node_feats = node_feats.detach().clone().requires_grad_() @@ -193,3 +242,14 @@ def test_matches_mace_double_backward(self, modules, node_feats, node_attrs, dty for oeq_grad, mace_grad in zip(oeq_second_grads, mace_second_grads): torch.testing.assert_close(oeq_grad, mace_grad, **tolerance(dtype)) + + def test_compile(self, modules, node_feats, node_attrs): + sc, _ = modules + ref = sc(node_feats, node_attrs) + assert torch.allclose(ref, torch.compile(sc)(node_feats, node_attrs), atol=1e-5) + + def test_export(self, modules, node_feats, node_attrs): + sc, _ = modules + ref = sc(node_feats, node_attrs) + exported = torch.export.export(sc, args=(node_feats, node_attrs), strict=False) + assert torch.allclose(ref, exported.module()(node_feats, node_attrs), atol=1e-5)