From 11c49a6ef017f9b3e5cb7908c8edf66a6a0a7bb2 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 13 Jun 2026 14:32:44 -0700 Subject: [PATCH 1/8] Fixed to group_mm. --- .../_torch/extlib/__init__.py | 8 - ...contraction.py => SymmetricContraction.py} | 164 +++++------------- .../_torch/symmetric_contraction/__init__.py | 2 +- .../openequivariance/extension/group_mm.hpp | 136 +++++++++++++++ .../extension/group_mm_cuda.hpp | 143 --------------- .../extension/group_mm_hip.hpp | 129 -------------- .../extension/libtorch_tp_jit.cpp | 7 - .../extension/libtorch_tp_jit_stable.cpp | 7 - .../openequivariance/extension/torch_core.hpp | 50 +++++- 9 files changed, 227 insertions(+), 419 deletions(-) rename openequivariance/openequivariance/_torch/symmetric_contraction/{symmetric_contraction.py => SymmetricContraction.py} (72%) create mode 100644 openequivariance/openequivariance/extension/group_mm.hpp delete mode 100644 openequivariance/openequivariance/extension/group_mm_cuda.hpp delete mode 100644 openequivariance/openequivariance/extension/group_mm_hip.hpp diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index ad3878a2..64d3ce38 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -197,8 +197,6 @@ def torch_ext_so_path(): if BUILT_EXTENSION: from oeq_utilities import ( - GroupMM_F32, - GroupMM_F64, DeviceProp, GPUTimer, ) @@ -210,12 +208,6 @@ def _raise_import_error_helper(import_target: str): f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}" ) - def GroupMM_F32(*args, **kwargs): - _raise_import_error_helper("GroupMM_F32") - - def GroupMM_F64(*args, **kwargs): - _raise_import_error_helper("GroupMM_F64") - def DeviceProp(*args, **kwargs): _raise_import_error_helper("DeviceProp") diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py b/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py similarity index 72% rename from openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py rename to openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py index 655c13b0..4a6f903f 100644 --- a/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py @@ -1,115 +1,7 @@ # ruff: noqa : E402 import torch -from openequivariance._torch.extlib import GroupMM_F32, GroupMM_F64 - - -class GroupMM: - next_id = 0 - - def __init__(self, dtype, num_elements, batch_size): - self.id = GroupMM.next_id - self.num_elements = num_elements - GroupMM.next_id += 1 - - if dtype == torch.float32: - self.internal = GroupMM_F32(num_elements, batch_size) - else: - self.internal = GroupMM_F64(num_elements, batch_size) - - @torch.library.custom_op( - f"openequivariance::group_gemm{self.id}", - mutates_args=(), - device_types="cuda", - ) - def group_gemm( - A: torch.Tensor, - B: torch.Tensor, - ragged_counts: torch.Tensor, - M: int, - K: int, - ragged_inner: int, - ) -> torch.Tensor: - """ - If ragged_inner == 0: - A is 3D, num_weights x num_features x M x K - B is batch_size x num_features x K - C is batch_size x num_features x M - If ragged_inner == 1: (needed for the backward pass) - A is batch_size x num_features x M - B is batch_size x num_features K - C is 3D, num_weights x num_features M x K - """ - shape = None - if ragged_inner == 0: - shape = (B.shape[0], B.shape[1], M) - elif ragged_inner == 1: - shape = (num_elements, B.shape[1], M, K) - - C = torch.zeros(shape, device="cuda", dtype=A.dtype) - self.internal.group_gemm( - A.contiguous().data_ptr(), - B.contiguous().data_ptr(), - C.data_ptr(), - ragged_counts.data_ptr(), - M, - K, - ragged_inner, - ) - return C - - @group_gemm.register_fake - def _(A, B, ragged_counts, M, K, ragged_inner): - if ragged_inner == 0: - return A.new_empty(B.shape[0], B.shape[1], M) - elif ragged_inner == 1: - return A.new_empty(num_elements, batch_size, M, K) - - self.group_gemm = group_gemm - - def setup_context(ctx, inputs, output): - ctx.A, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, ctx.ragged_inner = inputs - - def backward(ctx, grad_output): - grad_A, grad_B = None, None - - if ctx.ragged_inner == 0: - grad_A = group_gemm( - grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 1 - ) - grad_B = group_gemm( - ctx.A.transpose(2, 3), - grad_output, - ctx.ragged_counts, - ctx.K, - ctx.M, - 0, - ) - elif ctx.ragged_inner == 1: - grad_A = group_gemm( - grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 0 - ) - grad_B = group_gemm( - grad_output.transpose(2, 3), - ctx.A, - ctx.ragged_counts, - ctx.K, - ctx.M, - 0, - ) - - return grad_A, grad_B, None, None, None, None - - self.group_gemm.register_autograd(backward, setup_context=setup_context) - - def forward(self, weights, vectors, bincounts): - return self.group_gemm( - weights, vectors, bincounts, weights.shape[2], weights.shape[3], 0 - ) - - -# -------------------------------------------------------------------------- -# The following segment of code was copied from MACE's repo at https://github.com/ACEsuit/mace/blob/b5faaa076c49778fc17493edfecebcabeb960155/mace/tools/cg.py#L106 +from openequivariance._torch import extlib import collections from typing import Dict, Optional, Union, List @@ -229,6 +121,19 @@ def U_matrix_real( return out +class GroupMM: + def __init__(self, dtype, num_elements, batch_size): + self.num_elements = num_elements + self.batch_size = batch_size + + def forward(self, weights, vectors, bincounts): + return torch.ops.libtorch_tp_jit.group_gemm( + weights, vectors, bincounts, + self.num_elements, self.batch_size, + weights.shape[2], weights.shape[3], 0, + ) + + @compile_mode("script") class Contraction(torch.nn.Module): def __init__( @@ -259,21 +164,17 @@ def __init__( )[-1] self.register_buffer(f"U_matrix_{nu}", U_matrix) - # Tensor contraction equations self.contractions_weighting = torch.nn.ModuleList() self.contractions_features = torch.nn.ModuleList() - # Create weight for product basis self.weights = torch.nn.ParameterList([]) self.groupMM = GroupMM(dtype, num_elements, self.num_features) self.num_equivariance = 2 * irrep_out.lmax + 1 for i in range(correlation, 0, -1): - # Shapes defining num_params = self.U_tensors(i).size()[-1] if i == correlation: - # Parameters for the product basis w = torch.nn.Parameter( torch.randn( (num_elements, num_params, self.num_features), dtype=dtype @@ -282,7 +183,6 @@ def __init__( ) self.weights_max = w else: - # Parameters for the product basis w = torch.nn.Parameter( torch.randn( (num_elements, num_params, self.num_features), dtype=dtype @@ -295,7 +195,6 @@ def __init__( self.weights = weights[:-1] self.weights_max = weights[-1] - # Permute the U matrices for i in range(correlation, 0, -1): U = self.U_tensors(i) num_params = U.shape[-1] @@ -334,7 +233,6 @@ def forward( c_tensor.view(s[0] * s[1], -1, s[-1]) * x.view(s[0] * s[1], 1, s[-1]), dim=2, ).view(s[:-1]) - # out = torch.bmm(c_tensor.view(s[0] * s[1], -1, s[-1]), x.view(s[0] * s[1], s[-1], 1)).view(s[:-1]) return out.view(out.shape[0], -1) @@ -420,3 +318,37 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): ] outs_cat = torch.cat(outs, dim=-1)[inverse_perm] return outs_cat + + +def register_torch_fakes(): + @torch.library.register_fake("libtorch_tp_jit::group_gemm") + def fake_group_gemm(A, B, ragged_counts, num_W, batch_size, m, k, ragged_inner): + if ragged_inner == 0: + return A.new_empty(B.shape[0], B.shape[1], m) + else: + return A.new_empty(num_W, batch_size, m, k) + + +def register_autograd(): + op = torch.ops.libtorch_tp_jit.group_gemm + + def setup_context(ctx, inputs, output): + ctx.A, ctx.B, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.m, ctx.k, ctx.ragged_inner = inputs + + def backward(ctx, grad_output): + if ctx.ragged_inner == 0: + grad_A = op(grad_output, ctx.B, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.m, ctx.k, 1) + grad_B = op(ctx.A.transpose(2, 3), grad_output, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.k, ctx.m, 0) + else: + grad_A = op(grad_output, ctx.B, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.m, ctx.k, 0) + grad_B = op(grad_output.transpose(2, 3), ctx.A, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.k, ctx.m, 0) + return grad_A, grad_B, None, None, None, None, None, None + + torch.library.register_autograd( + "libtorch_tp_jit::group_gemm", backward, setup_context=setup_context + ) + + +if extlib.BUILT_EXTENSION: + register_torch_fakes() + register_autograd() diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py index 00edefcb..486334e9 100644 --- a/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py @@ -1,4 +1,4 @@ -from openequivariance._torch.symmetric_contraction.symmetric_contraction import ( +from openequivariance._torch.symmetric_contraction.SymmetricContraction import ( SymmetricContraction, ) diff --git a/openequivariance/openequivariance/extension/group_mm.hpp b/openequivariance/openequivariance/extension/group_mm.hpp new file mode 100644 index 00000000..1e788bad --- /dev/null +++ b/openequivariance/openequivariance/extension/group_mm.hpp @@ -0,0 +1,136 @@ +#pragma once + +#include +#include + +#ifdef CUDA_BACKEND + #include "cublas_v2.h" + #include + + struct BlasHandle { + cublasHandle_t handle; + BlasHandle() { + if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) + throw std::logic_error("CUBLAS initialization failed"); + } + ~BlasHandle() { cublasDestroy(handle); } + }; +#elif defined(HIP_BACKEND) + #include "rocblas/rocblas.h" + #include + + struct BlasHandle { + rocblas_handle handle; + BlasHandle() { + if (rocblas_create_handle(&handle) != rocblas_status_success) + throw std::logic_error("rocBLAS initialization failed"); + } + ~BlasHandle() { rocblas_destroy_handle(handle); } + }; +#endif + +inline std::unique_ptr g_blas_handle = std::make_unique(); + +template +void group_gemm(void* A_raw, void* B_raw, void* C_raw, + int64_t* ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) { + + T alpha = 1.0, beta = 0.0; + T* A_base = reinterpret_cast(A_raw); + T* B_base = reinterpret_cast(B_raw); + T* C_base = reinterpret_cast(C_raw); + + int64_t ragged_offset = 0; + for (int i = 0; i < num_W; i++) { + int M, K, N, lda, ldb, ldc, strideA, strideB, strideC; + T *A, *B, *C; +#ifdef CUDA_BACKEND + cublasOperation_t transa, transb; +#elif defined(HIP_BACKEND) + rocblas_operation transa, transb; +#endif + + if (ragged_inner == 0) { + M = m; K = k; N = static_cast(ragged_counts[i]); + A = A_base + (m * k * batch_size * i); + lda = k; strideA = M * K; + B = B_base + (k * batch_size * ragged_offset); + ldb = K * batch_size; strideB = K; + C = C_base + (m * batch_size * ragged_offset); + ldc = M * batch_size; strideC = M; +#ifdef CUDA_BACKEND + transa = CUBLAS_OP_T; transb = CUBLAS_OP_N; +#elif defined(HIP_BACKEND) + transa = rocblas_operation_transpose; transb = rocblas_operation_none; +#endif + } else { + M = k; K = static_cast(ragged_counts[i]); N = m; + A = B_base + (k * batch_size * ragged_offset); + lda = k * batch_size; strideA = M; + B = A_base + (m * batch_size * ragged_offset); + ldb = m * batch_size; strideB = N; + C = C_base + (m * k * batch_size * i); + ldc = k; strideC = M * N; +#ifdef CUDA_BACKEND + transa = CUBLAS_OP_N; transb = CUBLAS_OP_T; +#elif defined(HIP_BACKEND) + transa = rocblas_operation_none; transb = rocblas_operation_transpose; +#endif + } + ragged_offset += ragged_counts[i]; + + if (ragged_counts[i] > 0) { +#ifdef CUDA_BACKEND + cublasStatus_t stat; + if (std::is_same::value) { + stat = cublasSgemmStridedBatched(g_blas_handle->handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else if (std::is_same::value) { + stat = cublasDgemmStridedBatched(g_blas_handle->handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else { + throw std::logic_error("Unsupported datatype for grouped GEMM!"); + } + if (stat != CUBLAS_STATUS_SUCCESS) + throw std::logic_error("Grouped GEMM failed!"); +#elif defined(HIP_BACKEND) + rocblas_status stat; + if (std::is_same::value) { + stat = rocblas_sgemm_strided_batched(g_blas_handle->handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else if (std::is_same::value) { + stat = rocblas_dgemm_strided_batched(g_blas_handle->handle, + transa, transb, M, N, K, + reinterpret_cast(&alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(&beta), + reinterpret_cast(C), ldc, strideC, + batch_size); + } else { + throw std::logic_error("Unsupported datatype for grouped GEMM!"); + } + if (stat != rocblas_status_success) + throw std::logic_error("Grouped GEMM failed!"); +#endif + } + } +} diff --git a/openequivariance/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/openequivariance/extension/group_mm_cuda.hpp deleted file mode 100644 index 95f1412f..00000000 --- a/openequivariance/openequivariance/extension/group_mm_cuda.hpp +++ /dev/null @@ -1,143 +0,0 @@ -#pragma once - -#include "cublas_v2.h" -#include -#include -#include - -using namespace std; - -template -class GroupMMCUDA { - cublasStatus_t stat; - cublasHandle_t handle; - - int num_W; - int batch_size; - - T alpha; - T beta; - -public: - GroupMMCUDA(int num_W, int batch_size) : - num_W(num_W), - batch_size(batch_size), - alpha(1.0), - beta(0.0) { - stat = cublasCreate(&handle); - if (stat != CUBLAS_STATUS_SUCCESS) { - throw std::logic_error("CUBLAS initialization failed"); - } - } - - void group_gemm(void* A_raw, void* B_raw, void* C_raw, - int64_t* ragged_counts, int m, int k, int ragged_inner) { - /* - * Performs one of two grouped, batched GEMMs with a single ragged dimension: - * - * a) If ragged_inner = 0, multiplies each M x K row-major weight matrix A - * against B, where B is stored in column-major order with each matrix of - * dimensions K x [offset_diff]. Output has dimensions M x [offset_diff], - * stored in column-major order. - * b) If ragged_inner = 1, multiplies each M x [offset_diff] A matrix - * against each B K x [offset_diff] matrix transposed to produce a - * M x K matrix output. - */ - - T* A_base = reinterpret_cast(A_raw); - T* B_base = reinterpret_cast(B_raw); - T* C_base = reinterpret_cast(C_raw); - - int64_t ragged_offset = 0; - for(int i = 0; i < num_W; i++) { - int M, K, N, lda, ldb, ldc; - T *A, *B, *C; - - int strideA, strideB, strideC; - cublasOperation_t transa, transb; - - if(ragged_inner == 0) { - M = m; - K = k; - N = static_cast(ragged_counts[i]); - - A = A_base + (m * k * batch_size * i); - lda = k; strideA = M * K; - - B = B_base + (k * batch_size * ragged_offset); - ldb = K * batch_size; strideB = K; - - C = C_base + (m * batch_size * ragged_offset); - ldc = M * batch_size; strideC = M; - - transa = CUBLAS_OP_T; - transb = CUBLAS_OP_N; - } - else { - M = k; - K = static_cast(ragged_counts[i]); - N = m; - - A = B_base + (k * batch_size * ragged_offset); - lda = k * batch_size; strideA = M; - - B = A_base + (m * batch_size * ragged_offset); - ldb = m * batch_size; strideB = N; - - C = C_base + (m * k * batch_size * i); - ldc = k; strideC = M * N; - - transa = CUBLAS_OP_N; - transb = CUBLAS_OP_T; - } - ragged_offset += ragged_counts[i]; - - if(ragged_counts[i] > 0) { - if(std::is_same::value) { - stat = cublasSgemmStridedBatched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else if(std::is_same::value) { - stat = cublasDgemmStridedBatched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else { - throw std::logic_error("Unsupported datatype for grouped GEMM!"); - } - if (stat != CUBLAS_STATUS_SUCCESS) { - throw std::logic_error("Grouped GEMM failed!"); - } - } - } - } - - void group_gemm_intptr(uint64_t weights, - uint64_t vectors, uint64_t output, - uint64_t ragged_counts, int m, int k, int ragged_inner) { - - group_gemm( - reinterpret_cast(weights), - reinterpret_cast(vectors), - reinterpret_cast(output), - reinterpret_cast(ragged_counts), - m, k, ragged_inner); - } - - ~GroupMMCUDA() { - cublasDestroy(handle); - } -}; \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/group_mm_hip.hpp b/openequivariance/openequivariance/extension/group_mm_hip.hpp deleted file mode 100644 index 2e713cf4..00000000 --- a/openequivariance/openequivariance/extension/group_mm_hip.hpp +++ /dev/null @@ -1,129 +0,0 @@ -#pragma once - -#include "rocblas/rocblas.h" -#include -#include -#include - - -template -class GroupMMHIP { - rocblas_status stat; - rocblas_handle handle; - - int num_W; - int batch_size; - - T alpha; - T beta; - -public: - GroupMMHIP(int num_W, int batch_size) : - num_W(num_W), - batch_size(batch_size), - alpha(1.0), - beta(0.0) { - if(rocblas_create_handle(&handle) != rocblas_status_success) { - throw std::logic_error("rocBLAS initialization failed"); - } - } - - void group_gemm(void* A_raw, void* B_raw, void* C_raw, - int64_t* ragged_counts, int m, int k, int ragged_inner) { - - T* A_base = reinterpret_cast(A_raw); - T* B_base = reinterpret_cast(B_raw); - T* C_base = reinterpret_cast(C_raw); - - int64_t ragged_offset = 0; - for(int i = 0; i < num_W; i++) { - int M, K, N, lda, ldb, ldc; - T *A, *B, *C; - - int strideA, strideB, strideC; - rocblas_operation transa, transb; - - if(ragged_inner == 0) { - M = m; - K = k; - N = static_cast(ragged_counts[i]); - - A = A_base + (m * k * batch_size * i); - lda = k; strideA = M * K; - - B = B_base + (k * batch_size * ragged_offset); - ldb = K * batch_size; strideB = K; - - C = C_base + (m * batch_size * ragged_offset); - ldc = M * batch_size; strideC = M; - - transa = rocblas_operation_transpose; - transb = rocblas_operation_none; - } - else { - M = k; - K = static_cast(ragged_counts[i]); - N = m; - - A = B_base + (k * batch_size * ragged_offset); - lda = k * batch_size; strideA = M; - - B = A_base + (m * batch_size * ragged_offset); - ldb = m * batch_size; strideB = N; - - C = C_base + (m * k * batch_size * i); - ldc = k; strideC = M * N; - - transa = rocblas_operation_none; - transb = rocblas_operation_transpose; - } - ragged_offset += ragged_counts[i]; - - if(ragged_counts[i] > 0) { - if(std::is_same::value) { - stat = rocblas_sgemm_strided_batched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else if(std::is_same::value) { - stat = rocblas_dgemm_strided_batched(handle, - transa, transb, - M, N, K, - reinterpret_cast(&alpha), - reinterpret_cast(A), lda, strideA, - reinterpret_cast(B), ldb, strideB, - reinterpret_cast(&beta), - reinterpret_cast(C), ldc, strideC, - batch_size); - } - else { - throw std::logic_error("Unsupported datatype for grouped GEMM!"); - } - if (stat != rocblas_status_success) { - throw std::logic_error("Grouped GEMM failed!"); - } - } - } - } - - void group_gemm_intptr(uint64_t weights, - uint64_t vectors, uint64_t output, - uint64_t ragged_counts, int m, int k, int ragged_inner) { - group_gemm( - reinterpret_cast(weights), - reinterpret_cast(vectors), - reinterpret_cast(output), - reinterpret_cast(ragged_counts), - m, k, ragged_inner); - } - - ~GroupMMHIP() { - rocblas_destroy_handle(handle); - } -}; \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 173c0a2e..698a142f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -84,13 +84,6 @@ Stream get_current_stream() { namespace py=pybind11; PYBIND11_MODULE(libtorch_tp_jit, m) { - py::class_>(m, "GroupMM_F32") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - py::class_>(m, "GroupMM_F64") - .def(py::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - py::class_(m, "DeviceProp") .def(py::init()) .def_readonly("name", &DeviceProp::name) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp index 63fcc7f4..9ab6cd2f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit_stable.cpp @@ -89,13 +89,6 @@ Stream get_current_stream() { #include "nanobind/stl/string.h" namespace nb = nanobind; NB_MODULE(EXTENSION_NAME, m) { - nb::class_>(m, "GroupMM_F32") - .def(nb::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - nb::class_>(m, "GroupMM_F64") - .def(nb::init()) - .def("group_gemm", &GroupMM::group_gemm_intptr); - nb::class_(m, "DeviceProp") .def(nb::init()) .def_ro("name", &DeviceProp::name) diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index 8e9e43e5..8138a863 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -14,24 +14,18 @@ #ifdef CUDA_BACKEND #include "backend_cuda.hpp" - #include "group_mm_cuda.hpp" using JITKernel = CUJITKernel; using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; #endif #ifdef HIP_BACKEND #include "backend_hip.hpp" - #include "group_mm_hip.hpp" using JITKernel = HIPJITKernel; using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; #endif +#include "group_mm.hpp" + #include "tensorproducts.hpp" #include "convolution.hpp" @@ -623,6 +617,42 @@ inline tuple jit_conv_double_backward( // =========================================================== +inline Tensor group_gemm( + Tensor A, Tensor B, Tensor ragged_counts, + int64_t num_W, int64_t batch_size, int64_t m, int64_t k, int64_t ragged_inner) { + TCHECK(A.scalar_type() == B.scalar_type(), "group_gemm: A and B must have the same dtype"); + TCHECK(ragged_counts.scalar_type() == kLong, "group_gemm: ragged_counts must be int64"); + + Tensor A_c = tensor_contiguous(A); + Tensor B_c = tensor_contiguous(B); + Tensor rc_c = tensor_contiguous(ragged_counts); + int64_t* rc_ptr = reinterpret_cast(data_ptr(rc_c)); + + Tensor C; + if (ragged_inner == 0) { + C = tensor_zeros_like(A, make_sizes({B.size(0), B.size(1), m})); + } + else { + C = tensor_zeros_like(A, make_sizes({num_W, batch_size, m, k})); + } + + if (A.scalar_type() == kFloat) { + group_gemm(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, + (int)num_W, (int)batch_size, (int)m, (int)k, (int)ragged_inner); + } + else if (A.scalar_type() == kDouble) { + group_gemm(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, + (int)num_W, (int)batch_size, (int)m, (int)k, (int)ragged_inner); + } + else { + throw std::logic_error("group_gemm: unsupported dtype, expected float32 or float64"); + } + + return C; +} + +// =========================================================== + REGISTER_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { m.impl("jit_tp_forward", BOX(&jit_tp_forward)); m.impl("jit_tp_backward", BOX(&jit_tp_backward)); @@ -631,6 +661,8 @@ REGISTER_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { m.impl("jit_conv_forward", BOX(&jit_conv_forward)); m.impl("jit_conv_backward", BOX(&jit_conv_backward)); m.impl("jit_conv_double_backward", BOX(&jit_conv_double_backward)); + + m.impl("group_gemm", BOX(&group_gemm)); }; REGISTER_LIBRARY(libtorch_tp_jit, m) { @@ -641,4 +673,6 @@ REGISTER_LIBRARY(libtorch_tp_jit, m) { m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); + + m.def("group_gemm(Tensor A, Tensor B, Tensor ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) -> Tensor"); }; From 895e7cb7cababc81b86c498d8aa835034e35c708 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 13 Jun 2026 14:57:18 -0700 Subject: [PATCH 2/8] More fixes. --- .../openequivariance/extension/group_mm.hpp | 2 +- .../openequivariance/extension/torch_core.hpp | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/openequivariance/openequivariance/extension/group_mm.hpp b/openequivariance/openequivariance/extension/group_mm.hpp index 1e788bad..e3110b36 100644 --- a/openequivariance/openequivariance/extension/group_mm.hpp +++ b/openequivariance/openequivariance/extension/group_mm.hpp @@ -32,7 +32,7 @@ inline std::unique_ptr g_blas_handle = std::make_unique(); template -void group_gemm(void* A_raw, void* B_raw, void* C_raw, +void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw, int64_t* ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) { T alpha = 1.0, beta = 0.0; diff --git a/openequivariance/openequivariance/extension/torch_core.hpp b/openequivariance/openequivariance/extension/torch_core.hpp index 8138a863..869d72e2 100644 --- a/openequivariance/openequivariance/extension/torch_core.hpp +++ b/openequivariance/openequivariance/extension/torch_core.hpp @@ -637,14 +637,12 @@ inline Tensor group_gemm( } if (A.scalar_type() == kFloat) { - group_gemm(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, + group_gemm_blas(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, (int)num_W, (int)batch_size, (int)m, (int)k, (int)ragged_inner); - } - else if (A.scalar_type() == kDouble) { - group_gemm(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, + } else if (A.scalar_type() == kDouble) { + group_gemm_blas(data_ptr(A_c), data_ptr(B_c), data_ptr(C), rc_ptr, (int)num_W, (int)batch_size, (int)m, (int)k, (int)ragged_inner); - } - else { + } else { throw std::logic_error("group_gemm: unsupported dtype, expected float32 or float64"); } From e3cac0b722d1ef175e2531e6771254e0a3d0cf5b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 13 Jun 2026 15:18:18 -0700 Subject: [PATCH 3/8] Compile and export tests passing. --- docs/supported_ops.rst | 14 +++--- .../SymmetricContraction.py | 2 +- .../openequivariance/extension/torch_core.hpp | 2 +- tests/export_test.py | 45 +++++++++++++++++++ 4 files changed, 52 insertions(+), 11 deletions(-) diff --git a/docs/supported_ops.rst b/docs/supported_ops.rst index bcc11955..b8d90203 100644 --- a/docs/supported_ops.rst +++ b/docs/supported_ops.rst @@ -105,15 +105,11 @@ See PyTorch usage details `here Date: Sat, 13 Jun 2026 15:40:53 -0700 Subject: [PATCH 4/8] More fixes. --- .../SymmetricContraction.py | 66 ++++++++++-- tests/export_test.py | 45 -------- tests/symmetric_contraction_test.py | 101 ++++++++++++------ 3 files changed, 128 insertions(+), 84 deletions(-) diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py b/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py index 4c09c57e..1c91865d 100644 --- a/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py @@ -128,9 +128,14 @@ def __init__(self, dtype, num_elements, batch_size): def forward(self, weights, vectors, bincounts): return torch.ops.libtorch_tp_jit.group_gemm( - weights, vectors, bincounts, - self.num_elements, self.batch_size, - weights.shape[2], weights.shape[3], 0, + weights, + vectors, + bincounts, + self.num_elements, + self.batch_size, + weights.shape[2], + weights.shape[3], + 0, ) @@ -333,15 +338,60 @@ def register_autograd(): op = torch.ops.libtorch_tp_jit.group_gemm def setup_context(ctx, inputs, output): - ctx.A, ctx.B, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.m, ctx.k, ctx.ragged_inner = inputs + ( + ctx.A, + ctx.B, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.m, + ctx.k, + ctx.ragged_inner, + ) = inputs def backward(ctx, grad_output): if ctx.ragged_inner == 0: - grad_A = op(grad_output, ctx.B, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.m, ctx.k, 1) - grad_B = op(ctx.A.transpose(2, 3), grad_output, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.k, ctx.m, 0) + grad_A = op( + grad_output, + ctx.B, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.m, + ctx.k, + 1, + ) + grad_B = op( + ctx.A.transpose(2, 3), + grad_output, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.k, + ctx.m, + 0, + ) else: - grad_A = op(grad_output, ctx.B, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.m, ctx.k, 0) - grad_B = op(grad_output.transpose(2, 3), ctx.A, ctx.ragged_counts, ctx.num_W, ctx.batch_size, ctx.k, ctx.m, 0) + grad_A = op( + grad_output, + ctx.B, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.m, + ctx.k, + 0, + ) + grad_B = op( + grad_output.transpose(2, 3), + ctx.A, + ctx.ragged_counts, + ctx.num_W, + ctx.batch_size, + ctx.k, + ctx.m, + 0, + ) return grad_A, grad_B, None, None, None, None, None, None torch.library.register_autograd( diff --git a/tests/export_test.py b/tests/export_test.py index 3074cc67..6aba9690 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -223,48 +223,3 @@ def test_aoti_cpp_inference(problem_and_irreps): print(e.stdout.decode(), file=sys.stderr) print(e.stderr.decode(), file=sys.stderr) assert False - - -# ------------------- Symmetric Contraction Tests -------------------- - -@pytest.fixture(scope="session") -def sc_and_inputs(): - from e3nn import o3 - from openequivariance._torch.symmetric_contraction import SymmetricContraction - - irreps_in = o3.Irreps("4x0e + 4x1o") - irreps_out = o3.Irreps("4x0e + 4x1o") - num_elements = 4 - batch_size = 256 - dtype = torch.float32 - - sc = SymmetricContraction( - irreps_in=irreps_in, - irreps_out=irreps_out, - correlation=2, - num_elements=num_elements, - dtype=dtype, - ).to("cuda") - - gen = torch.Generator(device="cuda") - gen.manual_seed(42) - x = torch.randn(batch_size, irreps_in.dim, dtype=dtype, device="cuda", generator=gen) - y = torch.zeros(batch_size, num_elements, dtype=dtype, device="cuda") - indices = torch.randint(num_elements, (batch_size,), generator=gen, device="cuda") - y[torch.arange(batch_size, device="cuda"), indices] = 1.0 - - return sc, (x, y) - - -def test_sc_compile(sc_and_inputs): - sc, inputs = sc_and_inputs - ref = sc(*inputs) - out = torch.compile(sc)(*inputs) - assert torch.allclose(ref, out, atol=1e-5) - - -def test_sc_export(sc_and_inputs): - sc, inputs = sc_and_inputs - ref = sc(*inputs) - exported = torch.export.export(sc, args=inputs, strict=False) - assert torch.allclose(ref, exported.module()(*inputs), atol=1e-5) diff --git a/tests/symmetric_contraction_test.py b/tests/symmetric_contraction_test.py index fdf0861d..69f324bb 100644 --- a/tests/symmetric_contraction_test.py +++ b/tests/symmetric_contraction_test.py @@ -1,3 +1,4 @@ +import collections from unittest.mock import patch import pytest @@ -12,13 +13,34 @@ MaceSymmetricContraction = mace_symmetric_contraction.SymmetricContraction -IRREPS_IN = o3.Irreps("2x0e + 2x1o") -IRREPS_OUT = o3.Irreps("2x0e + 2x1o") -CORRELATION = 2 -NUM_ELEMENTS = 4 -LABEL_VALUES = [0, 2, 3, 2, 0, 0, 2, 3, 2, 2] +SCConfig = collections.namedtuple( + "SCConfig", + [ + "irreps_in", + "irreps_out", + "correlation", + "num_elements", + "label_values", + "device", + ], +) + DEVICE = torch.device("cuda") +SC_CONFIGS = [ + SCConfig(o3.Irreps("2x0e + 2x1o"), o3.Irreps("2x0e + 2x1o"), 2, 4, [0, 2, 3, 2, 0, 0, 2, 3, 2, 2], DEVICE), + SCConfig(o3.Irreps("1x0e + 1x1o + 1x2e"), o3.Irreps("1x0e + 1x1o"), 3, 3, [0, 1, 2, 0, 1, 2, 0, 1], DEVICE), + SCConfig(o3.Irreps("4x0e + 4x1o"), o3.Irreps("4x0e"), 2, 5, [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], DEVICE), +] + + +@pytest.fixture( + params=SC_CONFIGS, + ids=lambda cfg: f"{cfg.irreps_in}-corr{cfg.correlation}", +) +def symmetric_contraction_config(request): + return request.param + @pytest.fixture(params=[torch.float32, torch.float64], ids=["F32", "F64"]) def dtype(request): @@ -26,24 +48,28 @@ def dtype(request): @pytest.fixture -def labels(): - return torch.tensor(LABEL_VALUES, device=DEVICE, dtype=torch.long) +def labels(symmetric_contraction_config): + cfg = symmetric_contraction_config + return torch.tensor(cfg.label_values, device=cfg.device, dtype=torch.long) @pytest.fixture -def node_attrs(labels, dtype): - return F.one_hot(labels, num_classes=NUM_ELEMENTS).to(dtype=dtype) +def node_attrs(labels, dtype, symmetric_contraction_config): + return F.one_hot(labels, num_classes=symmetric_contraction_config.num_elements).to( + dtype=dtype + ) @pytest.fixture -def node_feats(dtype): - gen = torch.Generator(device=DEVICE) +def node_feats(dtype, symmetric_contraction_config): + cfg = symmetric_contraction_config + gen = torch.Generator(device=cfg.device) gen.manual_seed(2468) return torch.randn( - len(LABEL_VALUES), - IRREPS_IN.count((0, 1)), - IRREPS_IN.dim // IRREPS_IN.count((0, 1)), - device=DEVICE, + len(cfg.label_values), + cfg.irreps_in.count((0, 1)), + cfg.irreps_in.dim // cfg.irreps_in.count((0, 1)), + device=cfg.device, dtype=dtype, generator=gen, requires_grad=True, @@ -51,28 +77,27 @@ def node_feats(dtype): @pytest.fixture -def modules(dtype): +def modules(dtype, symmetric_contraction_config): + cfg = symmetric_contraction_config torch.manual_seed(12345) oeq_module = SymmetricContraction( - IRREPS_IN, - IRREPS_OUT, - correlation=CORRELATION, - num_elements=NUM_ELEMENTS, + cfg.irreps_in, + cfg.irreps_out, + correlation=cfg.correlation, + num_elements=cfg.num_elements, dtype=dtype, - ).to(DEVICE) + ).to(cfg.device) - # MACE's original e3nn implementation reads torch.get_default_dtype() - # during construction, so patch that lookup instead of mutating global state. with patch( "mace.modules.symmetric_contraction.torch.get_default_dtype", return_value=dtype, ): mace_module = MaceSymmetricContraction( - IRREPS_IN, - IRREPS_OUT, - correlation=CORRELATION, - num_elements=NUM_ELEMENTS, - ).to(device=DEVICE, dtype=dtype) + cfg.irreps_in, + cfg.irreps_out, + correlation=cfg.correlation, + num_elements=cfg.num_elements, + ).to(device=cfg.device, dtype=dtype) copy_matching_state(oeq_module, mace_module) return oeq_module, mace_module @@ -120,15 +145,16 @@ def random_like(tensor, seed): class TestSymmetricContraction: def test_matches_mace_forward_backward( - self, modules, node_feats, node_attrs, dtype + self, modules, node_feats, node_attrs, dtype, symmetric_contraction_config ): + cfg = symmetric_contraction_config oeq_module, mace_module = modules mace_node_feats = node_feats.detach().clone().requires_grad_() oeq_output = oeq_module(node_feats, node_attrs) mace_output = mace_module(mace_node_feats, node_attrs) - assert oeq_output.shape == (len(LABEL_VALUES), IRREPS_OUT.dim) + assert oeq_output.shape == (len(cfg.label_values), cfg.irreps_out.dim) torch.testing.assert_close(oeq_output, mace_output, **tolerance(dtype)) output_grad = random_like(oeq_output, seed=4321) @@ -144,7 +170,9 @@ def test_matches_mace_forward_backward( for oeq_grad, mace_grad in zip(oeq_grads, mace_grads): torch.testing.assert_close(oeq_grad, mace_grad, **tolerance(dtype)) - def test_matches_mace_double_backward(self, modules, node_feats, node_attrs, dtype): + def test_matches_mace_double_backward( + self, modules, node_feats, node_attrs, dtype, symmetric_contraction_config + ): oeq_module, mace_module = modules mace_node_feats = node_feats.detach().clone().requires_grad_() @@ -193,3 +221,14 @@ def test_matches_mace_double_backward(self, modules, node_feats, node_attrs, dty for oeq_grad, mace_grad in zip(oeq_second_grads, mace_second_grads): torch.testing.assert_close(oeq_grad, mace_grad, **tolerance(dtype)) + + def test_compile(self, modules, node_feats, node_attrs): + sc, _ = modules + ref = sc(node_feats, node_attrs) + assert torch.allclose(ref, torch.compile(sc)(node_feats, node_attrs), atol=1e-5) + + def test_export(self, modules, node_feats, node_attrs): + sc, _ = modules + ref = sc(node_feats, node_attrs) + exported = torch.export.export(sc, args=(node_feats, node_attrs), strict=False) + assert torch.allclose(ref, exported.module()(node_feats, node_attrs), atol=1e-5) From 265c9ff73d01f525e33b8b291aa50bc82e3c1b05 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 13 Jun 2026 15:47:43 -0700 Subject: [PATCH 5/8] Updated the changelog. --- CHANGELOG.md | 16 ++++++++++++++++ openequivariance/pyproject.toml | 2 +- openequivariance_extjax/pyproject.toml | 2 +- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 723857d4..a638438b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ ## Latest Changes +### v0.6.7 (2026-06-13) + +**Added**: +- Triple backward and higher-order derivative support for + tensor products in Pytorch. +- Reintroduced symmetric contraction implementation for PyTorch. +- `torch.compile`, `torch.export` support for symmetric + contraction. + +**Fixed**: +- Some compilation issues for RocM. + +### v0.6.6 (2026-06-13) +Bugfix: added alternate URL for libtorch aarch64 download in +stable extension. + ### v0.6.5 (2026-03-22) This release brings `ir_mul` layout support for OpenEquivariance. Pass the parameter diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 1028c8cb..4711e9cc 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance" -version = "0.6.6" +version = "0.6.7" authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index af35dde4..99731dab 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.6.6" +version = "0.6.7" authors = [ { name="Austin Glover" }, From 25abd18f607c48bd604c508919ff4f04a9a2c2df Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 13 Jun 2026 18:52:45 -0700 Subject: [PATCH 6/8] Removed spurious diffs. --- .../SymmetricContraction.py | 7 +++++ tests/symmetric_contraction_test.py | 27 ++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py b/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py index 1c91865d..83c56933 100644 --- a/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/SymmetricContraction.py @@ -169,17 +169,21 @@ def __init__( )[-1] self.register_buffer(f"U_matrix_{nu}", U_matrix) + # Tensor contraction equations self.contractions_weighting = torch.nn.ModuleList() self.contractions_features = torch.nn.ModuleList() + # Create weight for product basis self.weights = torch.nn.ParameterList([]) self.groupMM = GroupMM(dtype, num_elements, self.num_features) self.num_equivariance = 2 * irrep_out.lmax + 1 for i in range(correlation, 0, -1): + # Shapes defining num_params = self.U_tensors(i).size()[-1] if i == correlation: + # Parameters for the product basis w = torch.nn.Parameter( torch.randn( (num_elements, num_params, self.num_features), dtype=dtype @@ -188,6 +192,7 @@ def __init__( ) self.weights_max = w else: + # Parameters for the product basis w = torch.nn.Parameter( torch.randn( (num_elements, num_params, self.num_features), dtype=dtype @@ -200,6 +205,7 @@ def __init__( self.weights = weights[:-1] self.weights_max = weights[-1] + # Permute the U matrices for i in range(correlation, 0, -1): U = self.U_tensors(i) num_params = U.shape[-1] @@ -238,6 +244,7 @@ def forward( c_tensor.view(s[0] * s[1], -1, s[-1]) * x.view(s[0] * s[1], 1, s[-1]), dim=2, ).view(s[:-1]) + # out = torch.bmm(c_tensor.view(s[0] * s[1], -1, s[-1]), x.view(s[0] * s[1], s[-1], 1)).view(s[:-1]) return out.view(out.shape[0], -1) diff --git a/tests/symmetric_contraction_test.py b/tests/symmetric_contraction_test.py index 69f324bb..bbd105ac 100644 --- a/tests/symmetric_contraction_test.py +++ b/tests/symmetric_contraction_test.py @@ -28,9 +28,30 @@ DEVICE = torch.device("cuda") SC_CONFIGS = [ - SCConfig(o3.Irreps("2x0e + 2x1o"), o3.Irreps("2x0e + 2x1o"), 2, 4, [0, 2, 3, 2, 0, 0, 2, 3, 2, 2], DEVICE), - SCConfig(o3.Irreps("1x0e + 1x1o + 1x2e"), o3.Irreps("1x0e + 1x1o"), 3, 3, [0, 1, 2, 0, 1, 2, 0, 1], DEVICE), - SCConfig(o3.Irreps("4x0e + 4x1o"), o3.Irreps("4x0e"), 2, 5, [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], DEVICE), + SCConfig( + o3.Irreps("2x0e + 2x1o"), + o3.Irreps("2x0e + 2x1o"), + 2, + 4, + [0, 2, 3, 2, 0, 0, 2, 3, 2, 2], + DEVICE, + ), + SCConfig( + o3.Irreps("1x0e + 1x1o + 1x2e"), + o3.Irreps("1x0e + 1x1o"), + 3, + 3, + [0, 1, 2, 0, 1, 2, 0, 1], + DEVICE, + ), + SCConfig( + o3.Irreps("4x0e + 4x1o"), + o3.Irreps("4x0e"), + 2, + 5, + [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], + DEVICE, + ), ] From c57013a23a8d401b25f56cc773adcb3fa339375c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 13 Jun 2026 20:36:28 -0700 Subject: [PATCH 7/8] Attempted CI bugfix. --- .../openequivariance/extension/group_mm.hpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/openequivariance/openequivariance/extension/group_mm.hpp b/openequivariance/openequivariance/extension/group_mm.hpp index e3110b36..a1d7be5a 100644 --- a/openequivariance/openequivariance/extension/group_mm.hpp +++ b/openequivariance/openequivariance/extension/group_mm.hpp @@ -29,12 +29,16 @@ }; #endif -inline std::unique_ptr g_blas_handle = std::make_unique(); +inline BlasHandle& get_blas_handle() { + static BlasHandle handle; + return handle; +} template void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw, int64_t* ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) { + auto& blas = get_blas_handle(); T alpha = 1.0, beta = 0.0; T* A_base = reinterpret_cast(A_raw); T* B_base = reinterpret_cast(B_raw); @@ -83,7 +87,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw, #ifdef CUDA_BACKEND cublasStatus_t stat; if (std::is_same::value) { - stat = cublasSgemmStridedBatched(g_blas_handle->handle, + stat = cublasSgemmStridedBatched(blas.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A), lda, strideA, @@ -92,7 +96,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw, reinterpret_cast(C), ldc, strideC, batch_size); } else if (std::is_same::value) { - stat = cublasDgemmStridedBatched(g_blas_handle->handle, + stat = cublasDgemmStridedBatched(blas.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A), lda, strideA, @@ -108,7 +112,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw, #elif defined(HIP_BACKEND) rocblas_status stat; if (std::is_same::value) { - stat = rocblas_sgemm_strided_batched(g_blas_handle->handle, + stat = rocblas_sgemm_strided_batched(blas.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A), lda, strideA, @@ -117,7 +121,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw, reinterpret_cast(C), ldc, strideC, batch_size); } else if (std::is_same::value) { - stat = rocblas_dgemm_strided_batched(g_blas_handle->handle, + stat = rocblas_dgemm_strided_batched(blas.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A), lda, strideA, From d97e790d4c7c249705e658b8c08d6aeef4c6c8d4 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 13 Jun 2026 20:46:18 -0700 Subject: [PATCH 8/8] Fixed errors. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 87e5e785..ec567a90 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -19,24 +19,16 @@ using json = json11::Json; #include #include "backend/backend_cuda.hpp" - #include "group_mm_cuda.hpp" using JITKernel = CUJITKernel; using GPU_Allocator = CUDA_Allocator; - - template - using GroupMM = GroupMMCUDA; using stream_t = cudaStream_t; #endif #ifdef HIP_BACKEND #include "backend/backend_hip.hpp" - #include "group_mm_hip.hpp" using JITKernel = HIPJITKernel; using GPU_Allocator = HIP_Allocator; - - template - using GroupMM = GroupMMHIP; - using stream_t = hipStream_t; + using stream_t = hipStream_t; #endif #include "tensorproducts.hpp"