Skip to content

Forward and reverse Enzyme tests and rules for linalg#449

Open
kshyatt wants to merge 6 commits into
mainfrom
ksh/enz_linalg
Open

Forward and reverse Enzyme tests and rules for linalg#449
kshyatt wants to merge 6 commits into
mainfrom
ksh/enz_linalg

Conversation

@kshyatt

@kshyatt kshyatt commented Jun 10, 2026

Copy link
Copy Markdown
Member

Trying to make these a little more manageable and pick up the fwd rules where possible

@kshyatt kshyatt requested review from Jutho and lkdvos June 10, 2026 13:30
@codecov

codecov Bot commented Jun 10, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 0.47393% with 210 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
ext/TensorKitEnzymeExt/linalg.jl 0.00% 127 Missing ⚠️
ext/TensorKitEnzymeExt/utility.jl 0.00% 44 Missing ⚠️
ext/TensorKitEnzymeTestUtilsExt.jl 0.00% 39 Missing ⚠️
Files with missing lines Coverage Δ
ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl 100.00% <100.00%> (ø)
ext/TensorKitEnzymeTestUtilsExt.jl 0.00% <0.00%> (ø)
ext/TensorKitEnzymeExt/utility.jl 0.00% <0.00%> (ø)
ext/TensorKitEnzymeExt/linalg.jl 0.00% <0.00%> (ø)

... and 12 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment thread ext/TensorKitEnzymeExt/linalg.jl
@github-actions

github-actions Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Your PR no longer requires formatting changes. Thank you for your contribution!

@kshyatt kshyatt marked this pull request as draft June 11, 2026 07:18
@kshyatt kshyatt marked this pull request as ready for review June 11, 2026 09:26
@kshyatt

kshyatt commented Jun 12, 2026

Copy link
Copy Markdown
Member Author

The test on 1.12 is passing locally for me! I assume it's getting OOMed or something...

@kshyatt

kshyatt commented Jun 12, 2026

Copy link
Copy Markdown
Member Author

OK, everything looks happy now except the GPU stuff which is unrelated. Are we good to go?

project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))

# in-place multiplication and accumulation which might project to (real)
# TODO: this could probably be done without allocating

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not, because BLAS doesn't support matrix multiplication with matrices that do not have stride one, which is what you will get if you try to isolate the real and imaginary part of a complex matrix, in order to do something like
C = Ar * Br - Ai * Bi.

We might be able to make the final real as a view instead of an extra allocation.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That one is a comment of @lkdvos (I ported these over from the Mooncake extension) so I'll defer to him here 😉

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought there was a trick to make use of BLAS anyways, especially if one of the arrays is real, but I am now no longer really convinced... In any case, this is probably trying to over-optimize anyways, so I'm definitely not saying this is something to tackle here...

(the trick I was thinking of probably is this:)

If you have something like A' * B, you can turn this into a (m x 2n) * (2n x k) multiplication simply by reinterpreting as real arrays, in which case this will compute exactly transpose(Ar) * Br - transpose(Ai) * Bi, without allocating and with BLAS. The problem being that this is not super useful here, since we don't actually get to choose the layout of the matrices 😁

Comment on lines +48 to +64
trivtuple(N) = ntuple(identity, N)

Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
return TupleTools.getindices(p, trivtuple(N₁)),
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
end
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
return _repartition(linearize(p), N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
return _repartition(p, N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap)
return _repartition(p, TensorKit.numout(t))
end

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like much of this must already exist in the main TensorKit module, or in TensorOperations. If not, it is possibly useful to add it in the main module. (TensorOperations.jl/src/indices.jl has some stuff).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, especially since this is now duplicated across the two AD extensions. I'll see if I can't move it into the main TK module.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think that these are now in the later versions of TensorOperations, so it's probably better to use that? E.g.: https://github.com/QuantumKitHub/TensorOperations.jl/blob/1cf29e29e43efac41f713292d99a9d58b7078c57/src/indices.jl#L35-L44

Comment on lines +75 to +77
!isa(α, Const) && project_mul!(C.dval, A.val, B.val, α.dval)
!isa(A, Const) && project_mul!(C.dval, A.dval, B.val, α.val)
!isa(B, Const) && project_mul!(C.dval, A.val, B.dval, α.val)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are probably all regular mul! calls with beta = 1. Maybe it is a bit confusing that project_mul! does admit a beta and implicitly has beta=1 (I understand this is all you need for the tangents, but still), whereas mul! has default beta=0.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that it might be nicer to simply have the default beta = 0 and be explicit about beta = 1 here, also since it makes it more clear that this is accumulating when reading

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants