Forward and reverse Enzyme tests and rules for linalg #449
The test on 1.12 is passing locally for me! I assume it's getting OOMed or something...

OK, everything looks happy now except the GPU stuff which is unrelated. Are we good to go?
project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))

# in-place multiplication and accumulation which might project to (real)
# TODO: this could probably be done without allocating
Probably not: BLAS does not support matrix multiplication on matrices whose first dimension does not have stride one, and that is exactly what you get if you isolate the real and imaginary parts of a complex matrix in order to do something like
C = Ar * Br - Ai * Bi.
We might be able to take the final real as a view instead of an extra allocation.
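A minimal sketch of the layout issue described above, using only Base and LinearAlgebra (the variable names are illustrative and not taken from the PR). Complex entries are stored with real and imaginary parts interleaved, so non-allocating views of either part pick every second row of the reinterpreted array and therefore do not have unit stride along the first dimension. The product below is still correct, but it is expected to go through the generic fallback rather than BLAS; that dispatch is an assumption and is not verified here.

```julia
using LinearAlgebra

n, m, k = 4, 3, 5
A = randn(ComplexF64, n, m)
B = randn(ComplexF64, m, k)

# Reinterpret as real without copying: rows alternate real, imag, real, imag, ...
RA = reinterpret(Float64, A)   # size (2n, m)
RB = reinterpret(Float64, B)   # size (2m, k)

# Non-allocating views of the real and imaginary parts (every second row).
Ar, Ai = view(RA, 1:2:2n, :), view(RA, 2:2:2n, :)
Br, Bi = view(RB, 1:2:2m, :), view(RB, 2:2:2m, :)

# Real part of the complex product, assembled from the four real pieces.
C = Ar * Br - Ai * Bi

@assert Ar == real(A) && Ai == imag(A)
@assert C ≈ real(A * B)
```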
That one is a comment of @lkdvos (I ported these over from the Mooncake extension) so I'll defer to him here 😉
I thought there was a trick to make use of BLAS anyway, especially if one of the arrays is real, but I am no longer really convinced... In any case, this is probably over-optimizing, so I'm definitely not saying this is something to tackle here...
(the trick I was thinking of is probably this:)
If you have something like A' * B, you can turn this into an (m x 2n) * (2n x k) multiplication simply by reinterpreting both as real arrays, in which case this will compute exactly real(A' * B) = transpose(Ar) * Br + transpose(Ai) * Bi, without allocating and with BLAS. The problem is that this is not super useful here, since we don't actually get to choose the layout of the matrices 😁
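A small sketch of the reinterpretation trick, assuming A is n x m and B is n x k with ComplexF64 entries (names and sizes are illustrative). Because real and imaginary parts are interleaved along the first dimension, contracting over the 2n real rows sums both contributions; the conjugation in A' makes the imaginary term enter with a plus sign, so the result is real(A' * B). Whether `mul!` actually reaches BLAS for the reinterpreted arrays depends on the Julia version and is not checked here.

```julia
using LinearAlgebra

n, m, k = 4, 3, 5
A = randn(ComplexF64, n, m)
B = randn(ComplexF64, n, k)

# No copies: an n x m complex matrix is viewed as a 2n x m real matrix.
RA = reinterpret(Float64, A)
RB = reinterpret(Float64, B)

# (m x 2n) * (2n x k) real multiplication into a preallocated real output.
C = zeros(Float64, m, k)
mul!(C, transpose(RA), RB)

@assert C ≈ real(A' * B)
@assert C ≈ transpose(real(A)) * real(B) + transpose(imag(A)) * imag(B)
```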
trivtuple(N) = ntuple(identity, N)

Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
    length(p) >= N₁ ||
        throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
    return TupleTools.getindices(p, trivtuple(N₁)),
        TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
end
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
    return _repartition(linearize(p), N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
    return _repartition(p, N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap)
    return _repartition(p, TensorKit.numout(t))
end
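For readers without the surrounding packages loaded, here is a standalone sketch of what the repartitioning in the hunk above appears to do: flatten the index tuple if needed, then split it after the first N₁ entries. The names `linearize_sketch` and `repartition_sketch` are hypothetical stand-ins written with plain tuples; they are not the TensorKit or TensorOperations API.

```julia
# Hypothetical stand-ins using plain tuples; not the TensorKit/TensorOperations API.
linearize_sketch(p::Tuple{Tuple, Tuple}) = (p[1]..., p[2]...)

function repartition_sketch(p::Tuple{Vararg{Int}}, N1::Int)
    length(p) >= N1 ||
        throw(ArgumentError("cannot split $(length(p)) indices after position $N1"))
    # First N1 entries go to the first group, the remainder to the second.
    return ntuple(i -> p[i], N1), ntuple(i -> p[i + N1], length(p) - N1)
end

# A pair of tuples is flattened first, then split at the new position.
repartition_sketch(p::Tuple{Tuple, Tuple}, N1::Int) =
    repartition_sketch(linearize_sketch(p), N1)

@assert repartition_sketch((3, 1, 2, 4), 1) == ((3,), (1, 2, 4))
@assert repartition_sketch(((1, 2), (3, 4)), 3) == ((1, 2, 3), (4,))
```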
It feels like much of this must already exist in the main TensorKit module, or in TensorOperations. If not, it is possibly useful to add it in the main module. (TensorOperations.jl/src/indices.jl has some stuff).
Yeah, especially since this is now duplicated across the two AD extensions. I'll see if I can't move it into the main TK module.
I actually think that these are now in the later versions of TensorOperations, so it's probably better to use that? E.g.: https://github.com/QuantumKitHub/TensorOperations.jl/blob/1cf29e29e43efac41f713292d99a9d58b7078c57/src/indices.jl#L35-L44
!isa(α, Const) && project_mul!(C.dval, A.val, B.val, α.dval)
!isa(A, Const) && project_mul!(C.dval, A.dval, B.val, α.val)
!isa(B, Const) && project_mul!(C.dval, A.val, B.dval, α.val)
These are probably all regular mul! calls with beta = 1. Maybe it is a bit confusing that project_mul! does not admit a beta and implicitly has beta = 1 (I understand this is all you need for the tangents, but still), whereas mul! has default beta = 0.
I agree that it might be nicer to simply have the default beta = 0 and be explicit about beta = 1 here, also since that makes it clearer when reading the code that this is accumulating.
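For reference, a short sketch of the two conventions being contrasted, using only `LinearAlgebra.mul!` (the matrices are illustrative; `project_mul!` from the PR is not reproduced here). The three-argument form overwrites the output, while the five-argument form `mul!(C, A, B, α, β)` computes `A * B * α + C * β`, so passing β = 1 makes the accumulation explicit at the call site.

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]

# Three-argument form: the previous contents of C are discarded (beta = 0).
C = ones(2, 2)
mul!(C, A, B)
@assert C ≈ A * B

# Five-argument form with beta = 1: the product is added onto the existing C.
Cacc = ones(2, 2)
mul!(Cacc, A, B, 2.0, 1.0)
@assert Cacc ≈ 2.0 .* (A * B) .+ 1.0
```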
Trying to make these a little more manageable and pick up the fwd rules where possible