Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
## Latest Changes

### v0.6.7 (2026-06-13)

**Added**:
- Triple backward and higher-order derivative support for
tensor products in Pytorch.
- Reintroduced symmetric contraction implementation for PyTorch.
- `torch.compile`, `torch.export` support for symmetric
contraction.

**Fixed**:
- Some compilation issues for RocM.

### v0.6.6 (2026-06-13)
Bugfix: added alternate URL for libtorch aarch64 download in
stable extension.

### v0.6.5 (2026-03-22)
This release brings `ir_mul` layout support for
OpenEquivariance. Pass the parameter
Expand Down
14 changes: 5 additions & 9 deletions docs/supported_ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,11 @@ See PyTorch usage details `here <https://docs.pytorch.org/docs/stable/notes/cuda
Symmetric Contraction (Beta)
----------------------------

We have recently added beta support for symmetric
contraction acceleration. This primitive:

- Is specific to MACE
- Requires e3nn as a dependency.
- Currently has no support for compile / export

As a result, we do not expose it in the package
toplevel. You can use our implementation by running
We have beta support for symmetric
contraction acceleration, which is used by MACE. This primitive
requires e3nn installed as a dependency. As a result, we do not
expose it in the package toplevel. You can use our implementation
by running

.. code-block::

Expand Down
8 changes: 0 additions & 8 deletions openequivariance/openequivariance/_torch/extlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,6 @@ def torch_ext_so_path():

if BUILT_EXTENSION:
from oeq_utilities import (
GroupMM_F32,
GroupMM_F64,
DeviceProp,
GPUTimer,
)
Expand All @@ -210,12 +208,6 @@ def _raise_import_error_helper(import_target: str):
f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}"
)

def GroupMM_F32(*args, **kwargs):
_raise_import_error_helper("GroupMM_F32")

def GroupMM_F64(*args, **kwargs):
_raise_import_error_helper("GroupMM_F64")

def DeviceProp(*args, **kwargs):
_raise_import_error_helper("DeviceProp")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,115 +1,7 @@
# ruff: noqa : E402
import torch

from openequivariance._torch.extlib import GroupMM_F32, GroupMM_F64


class GroupMM:
next_id = 0

def __init__(self, dtype, num_elements, batch_size):
self.id = GroupMM.next_id
self.num_elements = num_elements
GroupMM.next_id += 1

if dtype == torch.float32:
self.internal = GroupMM_F32(num_elements, batch_size)
else:
self.internal = GroupMM_F64(num_elements, batch_size)

@torch.library.custom_op(
f"openequivariance::group_gemm{self.id}",
mutates_args=(),
device_types="cuda",
)
def group_gemm(
A: torch.Tensor,
B: torch.Tensor,
ragged_counts: torch.Tensor,
M: int,
K: int,
ragged_inner: int,
) -> torch.Tensor:
"""
If ragged_inner == 0:
A is 3D, num_weights x num_features x M x K
B is batch_size x num_features x K
C is batch_size x num_features x M
If ragged_inner == 1: (needed for the backward pass)
A is batch_size x num_features x M
B is batch_size x num_features K
C is 3D, num_weights x num_features M x K
"""
shape = None
if ragged_inner == 0:
shape = (B.shape[0], B.shape[1], M)
elif ragged_inner == 1:
shape = (num_elements, B.shape[1], M, K)

C = torch.zeros(shape, device="cuda", dtype=A.dtype)
self.internal.group_gemm(
A.contiguous().data_ptr(),
B.contiguous().data_ptr(),
C.data_ptr(),
ragged_counts.data_ptr(),
M,
K,
ragged_inner,
)
return C

@group_gemm.register_fake
def _(A, B, ragged_counts, M, K, ragged_inner):
if ragged_inner == 0:
return A.new_empty(B.shape[0], B.shape[1], M)
elif ragged_inner == 1:
return A.new_empty(num_elements, batch_size, M, K)

self.group_gemm = group_gemm

def setup_context(ctx, inputs, output):
ctx.A, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, ctx.ragged_inner = inputs

def backward(ctx, grad_output):
grad_A, grad_B = None, None

if ctx.ragged_inner == 0:
grad_A = group_gemm(
grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 1
)
grad_B = group_gemm(
ctx.A.transpose(2, 3),
grad_output,
ctx.ragged_counts,
ctx.K,
ctx.M,
0,
)
elif ctx.ragged_inner == 1:
grad_A = group_gemm(
grad_output, ctx.B, ctx.ragged_counts, ctx.M, ctx.K, 0
)
grad_B = group_gemm(
grad_output.transpose(2, 3),
ctx.A,
ctx.ragged_counts,
ctx.K,
ctx.M,
0,
)

return grad_A, grad_B, None, None, None, None

self.group_gemm.register_autograd(backward, setup_context=setup_context)

def forward(self, weights, vectors, bincounts):
return self.group_gemm(
weights, vectors, bincounts, weights.shape[2], weights.shape[3], 0
)


# --------------------------------------------------------------------------
# The following segment of code was copied from MACE's repo at https://github.com/ACEsuit/mace/blob/b5faaa076c49778fc17493edfecebcabeb960155/mace/tools/cg.py#L106
from openequivariance._torch import extlib

import collections
from typing import Dict, Optional, Union, List
Expand Down Expand Up @@ -229,6 +121,24 @@ def U_matrix_real(
return out


class GroupMM:
def __init__(self, dtype, num_elements, batch_size):
self.num_elements = num_elements
self.batch_size = batch_size

def forward(self, weights, vectors, bincounts):
return torch.ops.libtorch_tp_jit.group_gemm(
weights,
vectors,
bincounts,
self.num_elements,
self.batch_size,
weights.shape[2],
weights.shape[3],
0,
)


@compile_mode("script")
class Contraction(torch.nn.Module):
def __init__(
Expand Down Expand Up @@ -420,3 +330,82 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
]
outs_cat = torch.cat(outs, dim=-1)[inverse_perm]
return outs_cat


def register_torch_fakes():
@torch.library.register_fake("libtorch_tp_jit::group_gemm")
def fake_group_gemm(A, B, ragged_counts, num_W, batch_size, m, k, ragged_inner):
if ragged_inner == 0:
return A.new_empty(B.shape[0], batch_size, m)
else:
return A.new_empty(num_W, batch_size, m, k)


def register_autograd():
op = torch.ops.libtorch_tp_jit.group_gemm

def setup_context(ctx, inputs, output):
(
ctx.A,
ctx.B,
ctx.ragged_counts,
ctx.num_W,
ctx.batch_size,
ctx.m,
ctx.k,
ctx.ragged_inner,
) = inputs

def backward(ctx, grad_output):
if ctx.ragged_inner == 0:
grad_A = op(
grad_output,
ctx.B,
ctx.ragged_counts,
ctx.num_W,
ctx.batch_size,
ctx.m,
ctx.k,
1,
)
grad_B = op(
ctx.A.transpose(2, 3),
grad_output,
ctx.ragged_counts,
ctx.num_W,
ctx.batch_size,
ctx.k,
ctx.m,
0,
)
else:
grad_A = op(
grad_output,
ctx.B,
ctx.ragged_counts,
ctx.num_W,
ctx.batch_size,
ctx.m,
ctx.k,
0,
)
grad_B = op(
grad_output.transpose(2, 3),
ctx.A,
ctx.ragged_counts,
ctx.num_W,
ctx.batch_size,
ctx.k,
ctx.m,
0,
)
return grad_A, grad_B, None, None, None, None, None, None

torch.library.register_autograd(
"libtorch_tp_jit::group_gemm", backward, setup_context=setup_context
)


if extlib.BUILT_EXTENSION:
register_torch_fakes()
register_autograd()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from openequivariance._torch.symmetric_contraction.symmetric_contraction import (
from openequivariance._torch.symmetric_contraction.SymmetricContraction import (
SymmetricContraction,
)

Expand Down
Loading
Loading