From b57ade91c6575746a0166871f94afc956c1ed5b4 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Sep 2025 19:56:22 +0200 Subject: [PATCH 1/9] Reverse rules for Enzyme --- Project.toml | 9 +- .../MatrixAlgebraKitEnzymeExt.jl | 438 ++++++++++++++++ src/common/initialization.jl | 4 + src/common/safemethods.jl | 3 +- test/enzyme.jl | 487 ++++++++++++++++++ test/runtests.jl | 4 + 6 files changed, 943 insertions(+), 2 deletions(-) create mode 100644 ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl create mode 100644 test/enzyme.jl diff --git a/Project.toml b/Project.toml index 72bc22be..c45ec893 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -18,6 +19,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" +MatrixAlgebraKitEnzymeExt = "Enzyme" MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra" MatrixAlgebraKitGenericSchurExt = "GenericSchur" MatrixAlgebraKitMooncakeExt = "Mooncake" @@ -30,6 +32,8 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" +Enzyme = "0.13.116" +EnzymeTestUtils = "0.2.5" JET = "0.9, 0.10" LinearAlgebra = "1" Mooncake = "0.4.183" @@ -47,6 +51,8 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = 
"12d8515a-0907-448a-8884-5fe00fdf1c5a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc" @@ -60,4 +66,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", - "GenericLinearAlgebra", "GenericSchur", "Mooncake", "ParallelTestRunner"] + "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Enzyme", + "EnzymeTestUtils", "ParallelTestRunner"] diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl new file mode 100644 index 00000000..ae3f9c0e --- /dev/null +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -0,0 +1,438 @@ +module MatrixAlgebraKitEnzymeExt + +using MatrixAlgebraKit +using MatrixAlgebraKit: copy_input, initialize_output, zero! +using MatrixAlgebraKit: diagview, inv_safe, truncate +using MatrixAlgebraKit: qr_pullback!, lq_pullback! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback! +using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! 
+using Enzyme +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore: EnzymeRules +using LinearAlgebra + +@inline EnzymeRules.inactive_type(::Type{Alg}) where {Alg <: MatrixAlgebraKit.AbstractAlgorithm} = true +@inline EnzymeRules.inactive_type(::Type{TS}) where {TS <: MatrixAlgebraKit.TruncationStrategy} = true +@inline EnzymeRules.inactive(f::typeof(MatrixAlgebraKit.select_algorithm), func::F, A::AbstractMatrix, alg::Alg) where {F, Alg} = true +@inline EnzymeRules.inactive(f::typeof(MatrixAlgebraKit.default_algorithm), func::F, A::AbstractMatrix) where {F} = true +@inline EnzymeRules.inactive(f::typeof(MatrixAlgebraKit.check_input), func::F, A::AbstractMatrix, alg::Alg) where {F, Alg} = true +@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.defaulttol), ::Any) = true +@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_gauge_atol), ::Any) = true +@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_gauge_atol), ::Any, ::Any...) = true +@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_degeneracy_atol), ::Any) = true +@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_pullback_rank_atol), ::Any) = true +@inline EnzymeRules.inactive(::typeof(MatrixAlgebraKit.default_hermitian_tol), ::Any) = true + +#----------- NOTE about derivatives --------- +# Each Enzyme augmented_primal + reverse pair +# has a "tape" or "cache" -- we can place +# variables on this tape that can be accessed +# in the reverse pass *after* they have been +# "filled in" with accumulated derivatives. +# For many of the rules here, we may create a +# placeholder (usually called `dret`) for +# variables which may be instantiated, then, +# earlier in the reverse pass, this `dret` is +# filled in with accumulated derivatives for +# the created variable. It can then be used +# to update the derivative of `A` or any +# other provided input variable. 
+#-------------------------------------------- + +# this rule is necessary for now as without it, +# a segfault occurs both on 1.10 and 1.12 -- likely +# a deeper internal bug +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(copy_input)}, + ::Type{RT}, + f::Annotation, + A::Annotation + ) where {RT} + ret = func.val(f.val, A.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? zero(A.dval) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, shadow) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(copy_input)}, + dret::Type{RT}, + cache, + f::Annotation, + A::Annotation + ) where {RT} + copy_shadow = cache + if !isa(A, Const) && !isnothing(copy_shadow) + A.dval .+= copy_shadow + end + return (nothing, nothing) +end + +# needed as Enzyme can't diff through this on 1.12 +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(initialize_output)}, + ::Type{RT}, + f::Const, + A::Annotation, + alg::Const + ) where {RT} + ret = func.val(f.val, A.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
make_zero(ret) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(initialize_output)}, + dret::Type{RT}, + cache, + f::Const, + A::Annotation, + alg::Const + ) where {RT} + return (nothing, nothing, nothing) +end + +# two-argument factorizations like LQ, QR, EIG +for (f, pb) in ( + (qr_full!, qr_pullback!), + (lq_full!, lq_pullback!), + (qr_compact!, qr_pullback!), + (lq_compact!, lq_pullback!), + (eig_full!, eig_pullback!), + (eigh_full!, eigh_pullback!), + (left_polar!, left_polar_pullback!), + (right_polar!, right_polar_pullback!), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + arg::Annotation{Tuple{TA, TB}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA, TB} + # form cache if needed + cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + cache_arg = !isa(arg, Const) && EnzymeRules.overwritten(config)[3] ? copy(arg.val) : nothing + ret = func.val(A.val, arg.val, alg.val) + dret = (TA == Nothing && TB == Nothing) ? make_zero!.(similar.(ret)) : arg.dval + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? dret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg, ret, dret)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation, + arg::Annotation{Tuple{TA, TB}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA, TB} + cache_A, cache_arg, argval, darg = cache + argval = arg.val + Aval = something(cache_A, A.val) + argval = something(cache_arg, argval) + ∂arg = isa(arg, Const) ? 
(nothing, nothing) : darg + if !isa(A, Const) + $pb(A.dval, Aval, argval, ∂arg) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + end +end + +for (f, pb) in ( + (qr_null!, qr_null_pullback!), + (lq_null!, lq_null_pullback!), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + arg::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A = copy(A.val) + ret = func.val(A.val, arg.val, alg.val) + dret = isa(arg, Const) ? nothing : arg.dval + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? dret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, dret)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation, + arg::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, dret = cache + Aval = something(cache_A, A.val) + if !isa(A, Const) + $pb(A.dval, Aval, arg.val, dret) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + end +end + +for f in (:svd_compact!, :svd_full!) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + # form cache if needed + cache_A = !isa(A, Const) ? copy(A.val) : nothing + ret = func.val(A.val, USVᴴ.val, alg.val) + cache_USVᴴ = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra + dret = all(isnothing, USVᴴ.val) ? zero.(ret) : USVᴴ.dval + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
dret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, ret, dret)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, cache_USVᴴ, USVᴴval, dUSVᴴ = cache + Aval = something(cache_A, A.val) + USVᴴval = something(cache_USVᴴ, USVᴴval) + U, S, Vᴴ = USVᴴval + ∂USVᴴ = dUSVᴴ + if !isa(A, Const) + minmn = min(size(A.val)...) + if $(f === :svd_compact!) # compact -- `f` is a Symbol in this loop, so it must be compared against the Symbol (comparing against the function is always false) + svd_pullback!(A.dval, Aval, USVᴴval, ∂USVᴴ) + else # full -- only the leading `minmn` columns/rows carry derivative information + vU = view(U, :, 1:minmn) + vS = Diagonal(diagview(S)[1:minmn]) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(∂USVᴴ[1], :, 1:minmn) + vdS = Diagonal(diagview(∂USVᴴ[2])[1:minmn]) + vdVᴴ = view(∂USVᴴ[3], 1:minmn, :) + svd_pullback!(A.dval, Aval, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) + end + end + !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) + return (nothing, nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error!)}, + ::Type{RT}, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + # form cache if needed + cache_A = !isa(A, Const) ? copy(A.val) : nothing + ret = svd_compact!(A.val, USVᴴ.val, alg.val.alg) + cache_USVᴴ = copy.(ret) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, ret, alg.val.trunc) + primal = EnzymeRules.needs_primal(config) ? USVᴴ′ : nothing + shadow_USVᴴ = if !isa(A, Const) + # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra + dU, dS, dVᴴ = all(isnothing, USVᴴ.val) ? zero.(ret) : USVᴴ.dval + # This creates new output shadow matrices, we do this slicing + # to ensure they have the correct eltype and dimensions. 
+ # These new shadow matrices are "filled in" with the accumulated + # results from earlier in reverse-mode AD after this function exits + # and before `reverse` is called. + dStrunc = Diagonal(diagview(dS)[ind]) + dUtrunc = dU[:, ind] + dVᴴtrunc = dVᴴ[ind, :] + (dUtrunc, dStrunc, dVᴴtrunc) + else + (nothing, nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? shadow_USVᴴ : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind)) +end + + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error!)}, + dret::Type{RT}, + cache, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache + Aval = something(cache_A, A.val) + if !isa(A, Const) + svd_pullback!(A.dval, Aval, cache_USVᴴ, shadow_USVᴴ, ind) + end + !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) + return (nothing, nothing, nothing) +end + +for (f, trunc_f, full_f, pb) in ( + (:eigh_trunc_no_error!, :eigh_trunc!, :eigh_full!, :eigh_pullback!), + (:eig_trunc_no_error!, :eig_trunc!, :eig_full!, :eig_pullback!), + ) + @eval function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + DV::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + # form cache if needed + cache_A = !isa(A, Const) ? copy(A.val) : nothing + ret = $full_f(A.val, DV.val, alg.val.alg) + cache_DV = copy.(ret) + DV′, ind = truncate($trunc_f, ret, alg.val.trunc) + primal = EnzymeRules.needs_primal(config) ? DV′ : nothing + shadow_DV = if !isa(A, Const) && !isa(DV, Const) + dD, dV = all(isnothing, DV.val) ? zero!.(similar.(ret)) : DV.dval + dDtrunc = Diagonal(diagview(dD)[ind]) + dVtrunc = dV[:, ind] + (dDtrunc, dVtrunc) + else + (nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? 
shadow_DV : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) + end + @eval function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation, + DV::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, cache_DV, cache_dDVtrunc, ind = cache + Aval = something(cache_A, A.val) + D, V = something(cache_DV, DV.val) # fall back to the wrapped tuple, not the Annotation wrapper itself + dD, dV = cache_dDVtrunc + if !isa(A, Const) + $pb(A.dval, Aval, (D, V), (dD, dV), ind) + end + !isa(DV, Const) && make_zero!(DV.dval) + return (nothing, nothing, nothing) + end +end + +for (f!, f_full!, pb!) in ( + (eig_vals!, eig_full!, eig_vals_pullback!), + (eigh_vals!, eigh_full!, eigh_vals_pullback!), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation, + D::Annotation, + alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + nD, V = initialize_output($f_full!, A.val, alg.val) + nD, V = $f_full!(A.val, (nD, V), alg.val) + ret = something(D.val, similar(A.val, eltype(nD), length(diagview(nD)))) + dret = isa(D, Const) ? zero!(similar(ret)) : something(D.dval, zero!(similar(A.val, eltype(nD), length(diagview(nD))))) # a `Const` has no `dval`; mirror the guard used by the `svd_vals!` rule + copy!(ret, diagview(nD)) + cache_D = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
dret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D, ret, dret, V)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation, + D::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, cache_D, Dval, dD, V = cache + Dval = something(cache_D, Dval) + Aval = something(cache_A, A.val) + ∂D = isa(D, Const) ? nothing : dD + if !isa(A, Const) + $pb!(A.dval, Aval, (Diagonal(Dval), V), ∂D) + end + !isa(D, Const) && make_zero!(D.dval) + return (nothing, nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_vals!)}, + ::Type{RT}, + A::Annotation, + S::Annotation, + alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + U, nS, Vᴴ = svd_compact!(A.val, alg.val) + ret = something(S.val, similar(A.val, real(eltype(A.val)), length(diagview(nS)))) + dret = if isa(S, Const) + zero!(similar(ret)) + else + something(S.dval, zero!(similar(A.val, real(eltype(A.val)), length(diagview(nS))))) + end + copy!(ret, diagview(nS)) + cache_S = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
dret : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_S, ret, dret, U, Vᴴ)) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_vals!)}, + ::Type{RT}, + cache, + A::Annotation, + S::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, cache_S, Sval, dS, U, Vᴴ = cache + Sval = something(cache_S, Sval) + Aval = something(cache_A, A.val) + ∂S = dS + if !isa(A, Const) + svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), ∂S) + end + !isa(S, Const) && make_zero!(S.dval) + return (nothing, nothing, nothing) +end + +end diff --git a/src/common/initialization.jl b/src/common/initialization.jl index cdabcf43..b94d7a84 100644 --- a/src/common/initialization.jl +++ b/src/common/initialization.jl @@ -3,6 +3,10 @@ function zero!(A::AbstractArray) A .= zero(eltype(A)) return A end +function zero!(A::Diagonal) + diagview(A) .= zero(eltype(A)) + return A +end function one!(A::AbstractMatrix) length(A) > 0 || return A diff --git a/src/common/safemethods.jl b/src/common/safemethods.jl index 62f23a4e..43b06513 100644 --- a/src/common/safemethods.jl +++ b/src/common/safemethods.jl @@ -13,8 +13,9 @@ sign_safe(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) # Inverse """ - function inv_safe(a::Number, tol = defaulttol(a)) + inv_safe(a::Number, tol = defaulttol(a)) Compute the inverse of a number `a`, but return zero if `a` is smaller than `tol`. """ inv_safe(a::Number, tol = defaulttol(a)) = abs(a) < tol ? zero(a) : inv(a) +@noinline inv_safe(a::ComplexF32, tol = defaulttol(a)) = abs(a) < tol ? 
zero(a) : inv(a) diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 00000000..19588951 --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,487 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using ChainRulesCore +using Enzyme, EnzymeTestUtils +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!, BlasFloat +using GenericLinearAlgebra + +# https://github.com/EnzymeAD/Enzyme.jl/issues/2888, +# test_reverse doesn't work with BigFloat + +ETs = @static if VERSION < v"1.12.0" + (ComplexF64, BigFloat) +else + (ComplexF64,) +end +include("ad_utils.jl") +function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) + ΔA = randn(rng, eltype(A), size(A)...) + A_ΔA() = Duplicated(copy(A), copy(ΔA)) + function args_Δargs() + if isnothing(args) + return Const(args) + elseif args isa Tuple && all(isnothing, args) + return Const(args) + else + return Duplicated(copy.(args), copy.(Δargs)) + end + end + copy_activities = isnothing(alg) ? (Const(f), A_ΔA()) : (Const(f), A_ΔA(), Const(alg)) + inplace_activities = isnothing(alg) ? (Const(f!), A_ΔA(), args_Δargs()) : (Const(f!), A_ΔA(), args_Δargs(), Const(alg)) + + mode = EnzymeTestUtils.set_runtime_activity(ReverseSplitWithPrimal, false) + c_act = Const(EnzymeTestUtils.call_with_kwargs) + forward_copy, reverse_copy = autodiff_thunk( + mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, copy_activities)... + ) + forward_inplace, reverse_inplace = autodiff_thunk( + mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, inplace_activities)... + ) + copy_tape, copy_y_ad, copy_shadow_result = forward_copy(c_act, Const(()), copy_activities...) + inplace_tape, inplace_y_ad, inplace_shadow_result = forward_inplace(c_act, Const(()), inplace_activities...) 
+ if !(copy_shadow_result === nothing) + flush(stdout) + EnzymeTestUtils.map_fields_recursive(copyto!, copy_shadow_result, copy.(ȳ)) + end + if !(inplace_shadow_result === nothing) + EnzymeTestUtils.map_fields_recursive(copyto!, inplace_shadow_result, copy.(ȳ)) + end + dx_copy_ad = only(reverse_copy(c_act, Const(()), copy_activities..., copy_tape)) + dx_inplace_ad = only(reverse_inplace(c_act, Const(()), inplace_activities..., inplace_tape)) + # check all returned derivatives between copy & inplace + for (i, (copy_act_i, inplace_act_i)) in enumerate(zip(copy_activities[2:end], inplace_activities[2:end])) + if copy_act_i isa Duplicated && inplace_act_i isa Duplicated + msg_deriv = "shadow derivative for argument $(i - 1) should match between copy and inplace" + EnzymeTestUtils.test_approx(copy_act_i.dval, inplace_act_i.dval, msg_deriv) + end + end + return +end + +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + A = randn(rng, T, m, n) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + alg = MatrixAlgebraKit.default_qr_algorithm(A) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "qr_compact" begin + ΔQR = (randn(rng, T, m, minmn), randn(rng, T, minmn, n)) + Q, R = qr_compact(A, alg) + QR = MatrixAlgebraKit.initialize_output(qr_compact!, A, alg) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = ΔQR, fdm = fdm) + test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR, alg) + end + @testset "qr_null" begin + Q, R = qr_compact(A, alg) + N = zeros(T, m, max(0, m - minmn)) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + T <: BlasFloat && test_reverse(qr_null, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = ΔN) + test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(qr_full, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔQ, ΔR), fdm = fdm) + test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) + end + @testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔQ, ΔR), fdm = fdm) + test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) + end + end + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + A = randn(rng, T, m, n) + alg = MatrixAlgebraKit.default_lq_algorithm(A) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "lq_compact" begin + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + L, Q = lq_compact(A, alg) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔL, ΔQ), fdm = fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (ΔL, ΔQ), alg) + end + @testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + Nᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + T <: BlasFloat && test_reverse(lq_null, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = ΔNᴴ) + # runtime activity problems here with BigFloat + T <: BlasFloat && test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) + end + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔL, ΔQ), fdm = fdm) + test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) + end + @testset "lq_compact -- rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + L, Q = lq_compact(Ard, alg) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔL, ΔQ), fdm = fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) + end + end + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = make_eig_matrix(rng, T, m) + D, V = eig_full(A) + Ddiag = diagview(D) + ΔV = randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + alg = MatrixAlgebraKit.default_eig_algorithm(A) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) + test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) + test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) + end + @testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + T <: BlasFloat && test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + T <: BlasFloat && test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc)) + end +end + + +function copy_eigh_full(A, alg) + A = (A + A') / 2 + return eigh_full(A, alg) +end + +function copy_eigh_full!(A, DV::Tuple, 
alg::MatrixAlgebraKit.AbstractAlgorithm) + A = (A + A') / 2 + return eigh_full!(A, DV, alg) +end + +function copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function copy_eigh_vals(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_vals(A, alg; kwargs...) +end + +function copy_eigh_vals!(A, D, alg; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D, alg; kwargs...) +end + +function copy_eigh_trunc_no_error(A, alg) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg) +end + +function copy_eigh_trunc_no_error!(A, DV, alg) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg) +end + +# https://github.com/EnzymeAD/Enzyme.jl/issues/2889 +# the addition methods cannot be compiled +@timedtestset "EIGH AD Rules with eltype $T" for T in filter(T -> <:(T, BlasFloat), ETs) + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = make_eigh_matrix(rng, T, m) + #A = (A + A') / 2 + D, V = eigh_full(A) + D2 = Diagonal(D) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + alg = MatrixAlgebraKit.default_eigh_algorithm(A) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + T <: BlasFloat && test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) + T <: BlasFloat && test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) + test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) + T <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) + test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) + end + @testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + T <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + end + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + T <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol = 
atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + alg = MatrixAlgebraKit.default_svd_algorithm(A) + minmn = min(m, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_compact" begin + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + T <: BlasFloat && test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ), fdm = fdm) + if T <: BlasFloat + test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), alg) + else + USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg) + test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = (ΔU, ΔS, ΔVᴴ)) + end + end + @testset "svd_full" begin + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔUfull = zeros(T, m, m) + ΔSfull = zeros(real(T), m, n) + ΔVᴴfull = zeros(T, n, n) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔU + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ + diagview(ΔSfull)[1:minmn] .= diagview(ΔS) + T <: BlasFloat && test_reverse(svd_full, RT, (A, TA); fkwargs = (alg = 
alg,), atol = atol, rtol = rtol, output_tangent = (ΔUfull, ΔSfull, ΔVᴴfull), fdm = fdm) + if T <: BlasFloat + test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) + else + USVᴴ = MatrixAlgebraKit.initialize_output(svd_full!, A, alg) + test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = (ΔUfull, ΔSfull, ΔVᴴfull)) + end + end + @testset "svd_vals" begin + S = svd_vals(A) + ΔS = randn(rng, real(T), minmn) + T <: BlasFloat && test_reverse(svd_vals, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm = fdm) + if T <: BlasFloat + test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg) + else + S = MatrixAlgebraKit.initialize_output(svd_vals!, A, alg) + test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, nothing, alg; ȳ = ΔS) + end + end + end + # SVD trunc tests segfault with Enzyme on 1.12, possibly will be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/2902 + # but this also needs a new 1.12 release to come out with the backport of https://github.com/JuliaLang/julia/pull/60695 + @static if VERSION < v"1.12.0" + @testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:minmn + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + USVᴴ = T <: BlasFloat ? 
(U, S, Vᴴ) : (nothing, nothing, nothing) + T <: BlasFloat && test_reverse(svd_trunc_no_error!, RT, (A, TA), (USVᴴ, Duplicated), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) + ΔUSVᴴ = T <: BlasFloat ? (ΔU, ΔS2, ΔVᴴ) : (nothing, nothing, nothing) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + end + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] + ΔVᴴtrunc = ΔVᴴ[ind, :] + T <: BlasFloat && test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) + USVᴴ = T <: BlasFloat ? (U, S, Vᴴ) : (nothing, nothing, nothing) + ΔUSVᴴ = T <: BlasFloat ? (ΔU, ΔS2, ΔVᴴ) : (nothing, nothing, nothing) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + end + end + end +end + +# GLA works with polar, but these tests +# segfault because of Sylvester + BigFloat +@timedtestset "Polar AD Rules with eltype $T" for T in filter(T -> <:(T, BlasFloat), ETs) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + alg = MatrixAlgebraKit.default_polar_algorithm(A) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + if m >= n + WP = left_polar(A; alg = alg) + W, P = WP + ΔWP = randn(rng, T, size(W)...), randn(rng, T, size(P)...) 
+ T <: BlasFloat && test_reverse(left_polar, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol) + test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP, alg) + elseif m <= n + PWᴴ = right_polar(A; alg = alg) + P, Wᴴ = PWᴴ + ΔPWᴴ = randn(rng, T, size(P)...), randn(rng, T, size(Wᴴ)...) + T <: BlasFloat && test_reverse(right_polar, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol) + test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ, alg) + end + end + end +end + +# GLA not working with orthnull yet +@timedtestset "Orth and null with eltype $T" for T in filter(T -> <:(T, BlasFloat), ETs) + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "left_orth" begin + @testset for alg in (:polar, :qr) + n > m && alg == :polar && continue + VC = left_orth(A; alg = alg) + V, C = VC + ΔV = randn(rng, T, size(V)...) + ΔC = randn(rng, T, size(C)...) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(left_orth, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), fdm = fdm) + left_orth_alg!(A, VC) = left_orth!(A, VC; alg = alg) + left_orth_alg(A) = left_orth(A; alg = alg) + test_pullbacks_match(rng, left_orth_alg!, left_orth_alg, A, (V, C), (ΔV, ΔC)) + end + end + @testset "right_orth" begin + @testset for alg in (:polar, :lq) + n < m && alg == :polar && continue + CVᴴ = right_orth(A; alg = alg) + C, Vᴴ = CVᴴ + ΔC = randn(rng, T, size(C)...) + ΔVᴴ = randn(rng, T, size(Vᴴ)...) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + T <: BlasFloat && test_reverse(right_orth, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), fdm = fdm) + right_orth_alg!(A, CVᴴ) = right_orth!(A, CVᴴ; alg = alg) + right_orth_alg(A) = right_orth(A; alg = alg) + test_pullbacks_match(rng, right_orth_alg!, right_orth_alg, A, (C, Vᴴ), (ΔC, ΔVᴴ)) + end + end + @testset "left_null" begin + ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + N = similar(ΔN) + left_null_qr!(A, N) = left_null!(A, N; alg = :qr) + left_null_qr(A) = left_null(A; alg = :qr) + T <: BlasFloat && test_reverse(left_null_qr, RT, (A, TA); output_tangent = ΔN, atol = atol, rtol = rtol) + test_pullbacks_match(rng, left_null_qr!, left_null_qr, A, N, ΔN) + end + @testset "right_null" begin + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] + Nᴴ = similar(ΔNᴴ) + right_null_lq!(A, Nᴴ) = right_null!(A, Nᴴ; alg = :lq) + right_null_lq(A) = right_null(A; alg = :lq) + T <: BlasFloat && test_reverse(right_null_lq, RT, (A, TA); output_tangent = ΔNᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, right_null_lq!, right_null_lq, A, Nᴴ, ΔNᴴ) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2fd630d2..28f220f6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,8 +23,12 @@ if filter_tests!(testsuite, args) delete!(testsuite, "truncate") delete!(testsuite, "gen_eig") delete!(testsuite, "mooncake") + delete!(testsuite, "enzyme") delete!(testsuite, "chainrules") delete!(testsuite, "codequality") + else + is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true" + (Sys.iswindows() || is_apple_ci) && delete!(testsuite, "enzyme") end end From c1e7ff358a2ddbe440dfc27af65008b422b8c8e3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 19 Jan 2026 15:22:34 +0100 Subject: [PATCH 2/9] Working on 1.10, 1.11, 1.12 --- Project.toml | 2 
+- .../MatrixAlgebraKitEnzymeExt.jl | 99 ++++++------------- test/enzyme.jl | 57 +++++++---- 3 files changed, 69 insertions(+), 89 deletions(-) diff --git a/Project.toml b/Project.toml index c45ec893..f1da10b4 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" -Enzyme = "0.13.116" +Enzyme = "0.13.118" EnzymeTestUtils = "0.2.5" JET = "0.9, 0.10" LinearAlgebra = "1" diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index ae3f9c0e..a3b51db8 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -72,33 +72,6 @@ function EnzymeRules.reverse( return (nothing, nothing) end -# needed as Enzyme can't diff through this on 1.12 -function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(initialize_output)}, - ::Type{RT}, - f::Const, - A::Annotation, - alg::Const - ) where {RT} - ret = func.val(f.val, A.val, alg.val) - primal = EnzymeRules.needs_primal(config) ? ret : nothing - shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) -end - -function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(initialize_output)}, - dret::Type{RT}, - cache, - f::Const, - A::Annotation, - alg::Const - ) where {RT} - return (nothing, nothing, nothing) -end - # two-argument factorizations like LQ, QR, EIG for (f, pb) in ( (qr_full!, qr_pullback!), @@ -123,7 +96,7 @@ for (f, pb) in ( cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing cache_arg = !isa(arg, Const) && EnzymeRules.overwritten(config)[3] ? copy(arg.val) : nothing ret = func.val(A.val, arg.val, alg.val) - dret = (TA == Nothing && TB == Nothing) ? 
make_zero!.(similar.(ret)) : arg.dval + dret = (TA == Nothing && TB == Nothing) ? zero.(ret) : arg.dval primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg, ret, dret)) @@ -138,10 +111,9 @@ for (f, pb) in ( alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA, TB} cache_A, cache_arg, argval, darg = cache - argval = arg.val Aval = something(cache_A, A.val) argval = something(cache_arg, argval) - ∂arg = isa(arg, Const) ? (nothing, nothing) : darg + ∂arg = darg if !isa(A, Const) $pb(A.dval, Aval, argval, ∂arg) end @@ -180,10 +152,10 @@ for (f, pb) in ( arg::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, dret = cache + cache_A, darg = cache Aval = something(cache_A, A.val) if !isa(A, Const) - $pb(A.dval, Aval, arg.val, dret) + $pb(A.dval, Aval, arg.val, darg) end !isa(arg, Const) && make_zero!(arg.dval) return (nothing, nothing, nothing) @@ -198,15 +170,22 @@ for f in (:svd_compact!, :svd_full!) func::Const{typeof($f)}, ::Type{RT}, A::Annotation, - USVᴴ::Annotation, + USVᴴ::Annotation{Tuple{TU, TS, TVᴴ}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} + ) where {RT, TU, TS, TVᴴ} # form cache if needed cache_A = !isa(A, Const) ? copy(A.val) : nothing ret = func.val(A.val, USVᴴ.val, alg.val) cache_USVᴴ = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra - dret = all(isnothing, USVᴴ.val) ? zero.(ret) : USVᴴ.dval + dret = if (TU == TS == TVᴴ == Nothing) + dU = zero(ret[1]) + dS = $(f == svd_compact!) ? Diagonal(zero(ret[2].diag)) : zero(ret[2]) + dVᴴ = zero(ret[3]) + (dU, dS, dVᴴ) + else + USVᴴ.dval + end primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? 
dret : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, ret, dret)) @@ -220,11 +199,10 @@ for f in (:svd_compact!, :svd_full!) USVᴴ::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, cache_USVᴴ, USVᴴval, dUSVᴴ = cache + cache_A, cache_USVᴴ, USVᴴval, ∂USVᴴ = cache Aval = something(cache_A, A.val) USVᴴval = something(cache_USVᴴ, USVᴴval) U, S, Vᴴ = USVᴴval - ∂USVᴴ = dUSVᴴ if !isa(A, Const) minmn = min(size(A.val)...) if $(f == svd_compact!) # compact @@ -250,30 +228,21 @@ function EnzymeRules.augmented_primal( func::Const{typeof(svd_trunc_no_error!)}, ::Type{RT}, A::Annotation, - USVᴴ::Annotation, + USVᴴ::Annotation{Tuple{TU, TS, TVᴴ}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} + ) where {RT, TU, TS, TVᴴ} # form cache if needed cache_A = !isa(A, Const) ? copy(A.val) : nothing ret = svd_compact!(A.val, USVᴴ.val, alg.val.alg) cache_USVᴴ = copy.(ret) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, ret, alg.val.trunc) primal = EnzymeRules.needs_primal(config) ? USVᴴ′ : nothing - shadow_USVᴴ = if !isa(A, Const) - # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra - dU, dS, dVᴴ = all(isnothing, USVᴴ.val) ? zero.(ret) : USVᴴ.dval - # This creates new output shadow matrices, we do this slicing - # to ensure they have the correct eltype and dimensions. - # These new shadow matrices are "filled in" with the accumulated - # results from earlier in reverse-mode AD after this function exits - # and before `reverse` is called. - dStrunc = Diagonal(diagview(dS)[ind]) - dUtrunc = dU[:, ind] - dVᴴtrunc = dVᴴ[ind, :] - (dUtrunc, dStrunc, dVᴴtrunc) - else - (nothing, nothing, nothing) - end + # This creates new output shadow matrices, we do this slicing + # to ensure they have the correct eltype and dimensions. 
+ # These new shadow matrices are "filled in" with the accumulated + # results from earlier in reverse-mode AD after this function exits + # and before `reverse` is called. + shadow_USVᴴ = (zero(USVᴴ′[1]), Diagonal(zero(USVᴴ′[2].diag)), zero(USVᴴ′[3])) shadow = EnzymeRules.needs_shadow(config) ? shadow_USVᴴ : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind)) end @@ -306,17 +275,17 @@ for (f, trunc_f, full_f, pb) in ( func::Const{typeof($f)}, ::Type{RT}, A::Annotation, - DV::Annotation, + DV::Annotation{Tuple{TA, TB}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} + ) where {RT, TA, TB} # form cache if needed cache_A = !isa(A, Const) ? copy(A.val) : nothing ret = $full_f(A.val, DV.val, alg.val.alg) cache_DV = copy.(ret) DV′, ind = truncate($trunc_f, ret, alg.val.trunc) primal = EnzymeRules.needs_primal(config) ? DV′ : nothing - shadow_DV = if !isa(A, Const) && !isa(DV, Const) - dD, dV = all(isnothing, DV.val) ? zero!.(similar.(ret)) : DV.dval + shadow_DV = if !isa(A, Const) + dD, dV = (TA == Nothing && TB == Nothing) ? zero.(ret) : DV.dval dDtrunc = Diagonal(diagview(dD)[ind]) dVtrunc = dV[:, ind] (dDtrunc, dVtrunc) @@ -364,7 +333,7 @@ for (f!, f_full!, pb!) in ( nD, V = initialize_output($f_full!, A.val, alg.val) nD, V = $f_full!(A.val, (nD, V), alg.val) ret = something(D.val, similar(A.val, eltype(nD), length(diagview(nD)))) - dret = something(D.dval, zero!(similar(A.val, eltype(nD), length(diagview(nD))))) + dret = isa(D, Const) ? zero(ret) : D.dval copy!(ret, diagview(nD)) cache_D = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing primal = EnzymeRules.needs_primal(config) ? ret : nothing @@ -383,9 +352,8 @@ for (f!, f_full!, pb!) in ( cache_A, cache_D, Dval, dD, V = cache Dval = something(cache_D, Dval) Aval = something(cache_A, A.val) - ∂D = isa(D, Const) ? 
nothing : dD if !isa(A, Const) - $pb!(A.dval, Aval, (Diagonal(Dval), V), ∂D) + $pb!(A.dval, Aval, (Diagonal(Dval), V), dD) end !isa(D, Const) && make_zero!(D.dval) return (nothing, nothing, nothing) @@ -404,11 +372,7 @@ function EnzymeRules.augmented_primal( cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing U, nS, Vᴴ = svd_compact!(A.val, alg.val) ret = something(S.val, similar(A.val, real(eltype(A.val)), length(diagview(nS)))) - dret = if isa(S, Const) - zero!(similar(ret)) - else - something(S.dval, zero!(similar(A.val, real(eltype(A.val)), length(diagview(nS))))) - end + dret = isa(S, Const) ? zero(ret) : S.dval copy!(ret, diagview(nS)) cache_S = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing primal = EnzymeRules.needs_primal(config) ? ret : nothing @@ -427,9 +391,8 @@ function EnzymeRules.reverse( cache_A, cache_S, Sval, dS, U, Vᴴ = cache Sval = something(cache_S, Sval) Aval = something(cache_A, A.val) - ∂S = dS if !isa(A, Const) - svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), ∂S) + svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS) end !isa(S, Const) && make_zero!(S.dval) return (nothing, nothing, nothing) diff --git a/test/enzyme.jl b/test/enzyme.jl index 19588951..32053783 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -6,7 +6,7 @@ using ChainRulesCore using Enzyme, EnzymeTestUtils using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!, BlasFloat -using GenericLinearAlgebra +using GenericLinearAlgebra, GenericSchur # https://github.com/EnzymeAD/Enzyme.jl/issues/2888, # test_reverse doesn't work with BigFloat @@ -184,10 +184,15 @@ end fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) alg = MatrixAlgebraKit.default_eig_algorithm(A) @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) - test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) - test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) + if T <: BlasFloat + test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) + test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) + test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) + else + test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV)) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD2.diag) + end end @testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) for r in 1:4:m @@ -197,8 +202,12 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - T <: BlasFloat && test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + if T <: BlasFloat + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, 
eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + else + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + end end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) @@ -206,8 +215,12 @@ end Vtrunc = V[:, ind] ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) ΔVtrunc = ΔV[:, ind] - T <: BlasFloat && test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc)) + if T <: BlasFloat + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc)) + else + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg; ȳ = (ΔDtrunc, ΔVtrunc)) + end end end @@ -315,8 +328,8 @@ end ΔS = Diagonal(randn(rng, real(T), minmn)) ΔVᴴ = randn(rng, T, minmn, n) ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - T <: BlasFloat && test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ), fdm = fdm) if T <: BlasFloat + test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ), fdm = fdm) test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), alg) else USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg) @@ -336,8 +349,8 @@ end view(ΔUfull, :, 1:minmn) .= ΔU view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ diagview(ΔSfull)[1:minmn] .= diagview(ΔS) - T <: BlasFloat && test_reverse(svd_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = 
rtol, output_tangent = (ΔUfull, ΔSfull, ΔVᴴfull), fdm = fdm) if T <: BlasFloat + test_reverse(svd_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (ΔUfull, ΔSfull, ΔVᴴfull), fdm = fdm) test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) else USVᴴ = MatrixAlgebraKit.initialize_output(svd_full!, A, alg) @@ -347,8 +360,8 @@ end @testset "svd_vals" begin S = svd_vals(A) ΔS = randn(rng, real(T), minmn) - T <: BlasFloat && test_reverse(svd_vals, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm = fdm) if T <: BlasFloat + test_reverse(svd_vals, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm = fdm) test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg) else S = MatrixAlgebraKit.initialize_output(svd_vals!, A, alg) @@ -375,10 +388,12 @@ end ΔStrunc = Diagonal(diagview(ΔS2)[ind]) ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] - USVᴴ = T <: BlasFloat ? (U, S, Vᴴ) : (nothing, nothing, nothing) - T <: BlasFloat && test_reverse(svd_trunc_no_error!, RT, (A, TA), (USVᴴ, Duplicated), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) - ΔUSVᴴ = T <: BlasFloat ? 
(ΔU, ΔS2, ΔVᴴ) : (nothing, nothing, nothing) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + if T <: BlasFloat + test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + else + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + end end U, S, Vᴴ = svd_compact(A) ΔU = randn(rng, T, m, minmn) @@ -394,10 +409,12 @@ end ΔStrunc = Diagonal(diagview(ΔS2)[ind]) ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] - T <: BlasFloat && test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) - USVᴴ = T <: BlasFloat ? (U, S, Vᴴ) : (nothing, nothing, nothing) - ΔUSVᴴ = T <: BlasFloat ? 
(ΔU, ΔS2, ΔVᴴ) : (nothing, nothing, nothing) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + if T <: BlasFloat + test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + else + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) + end end end end From a1b5d97fd6d3cce2e9abcdfa9739a9a771da9295 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 14:36:44 +0100 Subject: [PATCH 3/9] Comments --- .../MatrixAlgebraKitEnzymeExt.jl | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index a3b51db8..a5b2db3f 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -60,7 +60,7 @@ end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(copy_input)}, - dret::Type{RT}, + ::Type{RT}, cache, f::Annotation, A::Annotation @@ -104,7 +104,7 @@ for (f, pb) in ( function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof($f)}, - dret::Type{RT}, + ::Type{RT}, cache, A::Annotation, arg::Annotation{Tuple{TA, TB}}, @@ -113,9 +113,8 @@ for (f, pb) in ( cache_A, cache_arg, argval, darg = cache Aval = something(cache_A, A.val) argval = something(cache_arg, argval) - ∂arg = darg if !isa(A, Const) - $pb(A.dval, Aval, argval, ∂arg) + $pb(A.dval, Aval, argval, darg) end !isa(arg, Const) && make_zero!(arg.dval) return (nothing, nothing, nothing) @@ -133,12 +132,12 @@ for 
(f, pb) in ( func::Const{typeof($f)}, ::Type{RT}, A::Annotation, - arg::Annotation, + arg::Annotation{TA}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} + ) where {RT, TA} cache_A = copy(A.val) ret = func.val(A.val, arg.val, alg.val) - dret = isa(arg, Const) ? nothing : arg.dval + dret = TA == Nothing ? zero(ret) : arg.dval primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, dret)) @@ -146,7 +145,7 @@ for (f, pb) in ( function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof($f)}, - dret::Type{RT}, + ::Type{RT}, cache, A::Annotation, arg::Annotation, @@ -180,6 +179,7 @@ for f in (:svd_compact!, :svd_full!) # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra dret = if (TU == TS == TVᴴ == Nothing) dU = zero(ret[1]) + # seems to be necessary due to Enzyme's type analysis dS = $(f == svd_compact!) ? Diagonal(zero(ret[2].diag)) : zero(ret[2]) dVᴴ = zero(ret[3]) (dU, dS, dVᴴ) @@ -193,27 +193,27 @@ for f in (:svd_compact!, :svd_full!) function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof($f)}, - dret::Type{RT}, + ::Type{RT}, cache, A::Annotation, USVᴴ::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, cache_USVᴴ, USVᴴval, ∂USVᴴ = cache + cache_A, cache_USVᴴ, USVᴴval, dUSVᴴ = cache Aval = something(cache_A, A.val) USVᴴval = something(cache_USVᴴ, USVᴴval) U, S, Vᴴ = USVᴴval if !isa(A, Const) minmn = min(size(A.val)...) if $(f == svd_compact!) 
# compact - svd_pullback!(A.dval, Aval, USVᴴval, ∂USVᴴ) + svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ) else # full vU = view(U, :, 1:minmn) vS = Diagonal(diagview(S)[1:minmn]) vVᴴ = view(Vᴴ, 1:minmn, :) - vdU = view(∂USVᴴ[1], :, 1:minmn) - vdS = Diagonal(diagview(∂USVᴴ[2])[1:minmn]) - vdVᴴ = view(∂USVᴴ[3], 1:minmn, :) + vdU = view(dUSVᴴ[1], :, 1:minmn) + vdS = Diagonal(diagview(dUSVᴴ[2])[1:minmn]) + vdVᴴ = view(dUSVᴴ[3], 1:minmn, :) svd_pullback!(A.dval, Aval, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end end @@ -251,16 +251,16 @@ end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(svd_trunc_no_error!)}, - dret::Type{RT}, + ::Type{RT}, cache, A::Annotation, USVᴴ::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache + cache_A, cache_USVᴴ, dUSVᴴ, ind = cache Aval = something(cache_A, A.val) if !isa(A, Const) - svd_pullback!(A.dval, Aval, cache_USVᴴ, shadow_USVᴴ, ind) + svd_pullback!(A.dval, Aval, cache_USVᴴ, dUSVᴴ, ind) end !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) return (nothing, nothing, nothing) @@ -304,12 +304,11 @@ for (f, trunc_f, full_f, pb) in ( DV::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, cache_DV, cache_dDVtrunc, ind = cache + cache_A, cache_DV, dDVtrunc, ind = cache Aval = something(cache_A, A.val) D, V = something(cache_DV, DV) - dD, dV = cache_dDVtrunc if !isa(A, Const) - $pb(A.dval, Aval, (D, V), (dD, dV), ind) + $pb(A.dval, Aval, (D, V), dDVtrunc, ind) end !isa(DV, Const) && make_zero!(DV.dval) return (nothing, nothing, nothing) From 52969b2065e6251d9cc1273653c419232bb173da Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 15:45:10 +0100 Subject: [PATCH 4/9] Allow eigh tests to run with BigFloat --- test/enzyme.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index 32053783..71aa3373 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl 
@@ -265,14 +265,11 @@ function copy_eigh_trunc_no_error!(A, DV, alg) return eigh_trunc_no_error!(A, DV, alg) end -# https://github.com/EnzymeAD/Enzyme.jl/issues/2889 -# the addition methods cannot be compiled -@timedtestset "EIGH AD Rules with eltype $T" for T in filter(T -> <:(T, BlasFloat), ETs) +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) A = make_eigh_matrix(rng, T, m) - #A = (A + A') / 2 D, V = eigh_full(A) D2 = Diagonal(D) ΔV = randn(rng, T, m, m) From 0ed19838e7e1ed244d458076ed6e4a58fc9cfcd4 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:36:18 +0100 Subject: [PATCH 5/9] Update ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl Co-authored-by: Jutho --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index a5b2db3f..ce1a7464 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -209,7 +209,7 @@ for f in (:svd_compact!, :svd_full!) 
svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ) else # full vU = view(U, :, 1:minmn) - vS = Diagonal(diagview(S)[1:minmn]) + vS = Diagonal(view(diagview(S), 1:minmn)) vVᴴ = view(Vᴴ, 1:minmn, :) vdU = view(dUSVᴴ[1], :, 1:minmn) vdS = Diagonal(diagview(dUSVᴴ[2])[1:minmn]) From f99c77e30a52c05c9bc530aacf866d7c7a02cb37 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:38:49 +0100 Subject: [PATCH 6/9] Update ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl Co-authored-by: Jutho --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index ce1a7464..0b19e864 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -94,7 +94,7 @@ for (f, pb) in ( ) where {RT, TA, TB} # form cache if needed cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing - cache_arg = !isa(arg, Const) && EnzymeRules.overwritten(config)[3] ? copy(arg.val) : nothing + cache_arg = !isa(arg, Const) && EnzymeRules.overwritten(config)[3] ? copy.(arg.val) : nothing ret = func.val(A.val, arg.val, alg.val) dret = (TA == Nothing && TB == Nothing) ? zero.(ret) : arg.dval primal = EnzymeRules.needs_primal(config) ? 
ret : nothing From 8d98d6b72047b7432e22817b60659cd28d2f84b0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:41:54 +0100 Subject: [PATCH 7/9] Update ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl Co-authored-by: Jutho --- ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 0b19e864..aae6f152 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -212,7 +212,7 @@ for f in (:svd_compact!, :svd_full!) vS = Diagonal(view(diagview(S), 1:minmn)) vVᴴ = view(Vᴴ, 1:minmn, :) vdU = view(dUSVᴴ[1], :, 1:minmn) - vdS = Diagonal(diagview(dUSVᴴ[2])[1:minmn]) + vdS = Diagonal(view(diagview(dUSVᴴ[2]), 1:minmn)) vdVᴴ = view(dUSVᴴ[3], 1:minmn, :) svd_pullback!(A.dval, Aval, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end From afb32788f47358e40506ca1918a6b14bcd1f07c6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:53:22 +0100 Subject: [PATCH 8/9] Comments --- .../MatrixAlgebraKitEnzymeExt.jl | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index aae6f152..2bdcd542 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -96,10 +96,11 @@ for (f, pb) in ( cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing cache_arg = !isa(arg, Const) && EnzymeRules.overwritten(config)[3] ? copy.(arg.val) : nothing ret = func.val(A.val, arg.val, alg.val) + cache_arg = EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing dret = (TA == Nothing && TB == Nothing) ? 
zero.(ret) : arg.dval primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg, ret, dret)) + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg, dret)) end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, @@ -110,9 +111,9 @@ for (f, pb) in ( arg::Annotation{Tuple{TA, TB}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA, TB} - cache_A, cache_arg, argval, darg = cache + cache_A, cache_arg, darg = cache Aval = something(cache_A, A.val) - argval = something(cache_arg, argval) + argval = something(cache_arg, arg.val) if !isa(A, Const) $pb(A.dval, Aval, argval, darg) end @@ -135,12 +136,13 @@ for (f, pb) in ( arg::Annotation{TA}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA} - cache_A = copy(A.val) + cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing ret = func.val(A.val, arg.val, alg.val) + cache_arg = EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing dret = TA == Nothing ? zero(ret) : arg.dval primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, dret)) + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg, dret)) end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, @@ -151,10 +153,11 @@ for (f, pb) in ( arg::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, darg = cache + cache_A, cache_arg, darg = cache Aval = something(cache_A, A.val) + argval = something(cache_arg, arg.val) if !isa(A, Const) - $pb(A.dval, Aval, arg.val, darg) + $pb(A.dval, Aval, argval, darg) end !isa(arg, Const) && make_zero!(arg.dval) return (nothing, nothing, nothing) @@ -175,7 +178,7 @@ for f in (:svd_compact!, :svd_full!) 
# form cache if needed cache_A = !isa(A, Const) ? copy(A.val) : nothing ret = func.val(A.val, USVᴴ.val, alg.val) - cache_USVᴴ = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + cache_USVᴴ = copy.(ret) # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra dret = if (TU == TS == TVᴴ == Nothing) dU = zero(ret[1]) @@ -188,7 +191,7 @@ for f in (:svd_compact!, :svd_full!) end primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, ret, dret)) + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, dret)) end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, @@ -199,15 +202,14 @@ for f in (:svd_compact!, :svd_full!) USVᴴ::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, cache_USVᴴ, USVᴴval, dUSVᴴ = cache + cache_A, USVᴴval, dUSVᴴ = cache Aval = something(cache_A, A.val) - USVᴴval = something(cache_USVᴴ, USVᴴval) - U, S, Vᴴ = USVᴴval if !isa(A, Const) minmn = min(size(A.val)...) if $(f == svd_compact!) 
# compact svd_pullback!(A.dval, Aval, USVᴴval, dUSVᴴ) else # full + U, S, Vᴴ = USVᴴval vU = view(U, :, 1:minmn) vS = Diagonal(view(diagview(S), 1:minmn)) vVᴴ = view(Vᴴ, 1:minmn, :) From 0b49a70fc0addb02634fd44c29b9dc53d8b04e6e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 23:35:34 +0100 Subject: [PATCH 9/9] Consistent arg caching --- .../MatrixAlgebraKitEnzymeExt.jl | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 2bdcd542..153648d5 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -94,9 +94,11 @@ for (f, pb) in ( ) where {RT, TA, TB} # form cache if needed cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing - cache_arg = !isa(arg, Const) && EnzymeRules.overwritten(config)[3] ? copy.(arg.val) : nothing ret = func.val(A.val, arg.val, alg.val) - cache_arg = EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing + # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed + # if arg isa Const, ret may still be modified further down the call graph so we should + # copy it to protect ourselves + cache_arg = (arg.val !== ret) || (!isa(arg, Const) && EnzymeRules.overwritten(config)[3]) ? copy.(ret) : nothing dret = (TA == Nothing && TB == Nothing) ? zero.(ret) : arg.dval primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing @@ -138,7 +140,10 @@ for (f, pb) in ( ) where {RT, TA} cache_A = !isa(A, Const) && EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing ret = func.val(A.val, arg.val, alg.val) - cache_arg = EnzymeRules.overwritten(config)[3] ? 
copy.(ret) : nothing + # if arg.val == ret, the annotation must be Duplicated or DuplicatedNoNeed + # if arg isa Const, ret may still be modified further down the call graph so we should + # copy it to protect ourselves + cache_arg = (arg.val !== ret) || (!isa(arg, Const) && EnzymeRules.overwritten(config)[3]) ? copy(ret) : nothing dret = TA == Nothing ? zero(ret) : arg.dval primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing @@ -178,7 +183,10 @@ for f in (:svd_compact!, :svd_full!) # form cache if needed cache_A = !isa(A, Const) ? copy(A.val) : nothing ret = func.val(A.val, USVᴴ.val, alg.val) - cache_USVᴴ = copy.(ret) + # if USVᴴ.val == ret, the annotation must be Duplicated or DuplicatedNoNeed + # if USVᴴ isa Const, ret may still be modified further down the call graph so we should + # copy it to protect ourselves + cache_USVᴴ = (USVᴴ.val !== ret) || (!isa(USVᴴ, Const) && EnzymeRules.overwritten(config)[3]) ? copy.(ret) : nothing # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra dret = if (TU == TS == TVᴴ == Nothing) dU = zero(ret[1]) @@ -236,7 +244,7 @@ function EnzymeRules.augmented_primal( # form cache if needed cache_A = !isa(A, Const) ? copy(A.val) : nothing ret = svd_compact!(A.val, USVᴴ.val, alg.val.alg) - cache_USVᴴ = copy.(ret) + cache_USVᴴ = (USVᴴ.val !== ret) || (!isa(USVᴴ, Const) && EnzymeRules.overwritten(config)[3]) ? copy.(ret) : nothing USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, ret, alg.val.trunc) primal = EnzymeRules.needs_primal(config) ? USVᴴ′ : nothing # This creates new output shadow matrices, we do this slicing @@ -283,7 +291,7 @@ for (f, trunc_f, full_f, pb) in ( # form cache if needed cache_A = !isa(A, Const) ? copy(A.val) : nothing ret = $full_f(A.val, DV.val, alg.val.alg) - cache_DV = copy.(ret) + cache_DV = (DV.val !== ret) || (!isa(DV, Const) && EnzymeRules.overwritten(config)[3]) ? 
copy.(ret) : nothing DV′, ind = truncate($trunc_f, ret, alg.val.trunc) primal = EnzymeRules.needs_primal(config) ? DV′ : nothing shadow_DV = if !isa(A, Const) @@ -336,10 +344,10 @@ for (f!, f_full!, pb!) in ( ret = something(D.val, similar(A.val, eltype(nD), length(diagview(nD)))) dret = isa(D, Const) ? zero(ret) : D.dval copy!(ret, diagview(nD)) - cache_D = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + cache_D = (D.val !== ret) || (!isa(D, Const) && EnzymeRules.overwritten(config)[3]) ? copy(ret) : nothing primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? dret : nothing - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D, ret, dret, V)) + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D, dret, V)) end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, @@ -350,8 +358,8 @@ for (f!, f_full!, pb!) in ( D::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, cache_D, Dval, dD, V = cache - Dval = something(cache_D, Dval) + cache_A, cache_D, dD, V = cache + Dval = something(cache_D, D.val) Aval = something(cache_A, A.val) if !isa(A, Const) $pb!(A.dval, Aval, (Diagonal(Dval), V), dD) @@ -375,10 +383,10 @@ function EnzymeRules.augmented_primal( ret = something(S.val, similar(A.val, real(eltype(A.val)), length(diagview(nS)))) dret = isa(S, Const) ? zero(ret) : S.dval copy!(ret, diagview(nS)) - cache_S = EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + cache_S = (S.val !== ret) || (!isa(S, Const) && EnzymeRules.overwritten(config)[3]) ? copy(ret) : nothing primal = EnzymeRules.needs_primal(config) ? ret : nothing shadow = EnzymeRules.needs_shadow(config) ? 
dret : nothing - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_S, ret, dret, U, Vᴴ)) + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_S, dret, U, Vᴴ)) end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, @@ -389,8 +397,8 @@ function EnzymeRules.reverse( S::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - cache_A, cache_S, Sval, dS, U, Vᴴ = cache - Sval = something(cache_S, Sval) + cache_A, cache_S, dS, U, Vᴴ = cache + Sval = something(cache_S, S.val) Aval = something(cache_A, A.val) if !isa(A, Const) svd_vals_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), dS)