diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 343bd968..68e2916b 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -8,6 +8,7 @@ using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_vals_pushforward!, eigh_vals_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using Enzyme @@ -264,6 +265,34 @@ for f in (:svd_compact!, :svd_full!) !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) return (nothing, nothing, nothing) end + function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation{TA}, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + $f(A.val, USVᴴ.val, alg.val) + if !isa(A, Const) + if $(f == svd_compact!) + make_zero!(USVᴴ.dval[2].diag) + else + make_zero!(USVᴴ.dval[2]) + end + !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) + make_zero!(A.dval) + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return USVᴴ + elseif EnzymeRules.needs_primal(config) + return USVᴴ.val + elseif EnzymeRules.needs_shadow(config) + return USVᴴ.dval + else + return nothing + end + end end end @@ -502,5 +531,32 @@ function EnzymeRules.reverse( !isa(S, Const) && !A_is_arg && make_zero!(S.dval) return (nothing, nothing, nothing) end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(svd_vals!)}, + ::Type{RT}, + A::Annotation{TA}, + S::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval + U, S_, Vᴴ = svd_compact!(A.val, alg.val) + if !isa(A, Const) && !isa(S, Const) + ΔS = A_is_arg ? make_zero(S.dval) : S.dval + svd_vals_pushforward!(A.dval, A.val, (U, Diagonal(diagview(S_)), Vᴴ), ΔS) + A_is_arg && (S.dval .= ΔS) + end + !A_is_arg && make_zero!(A.dval) + copyto!(S.val, diagview(S_)) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return S + elseif EnzymeRules.needs_primal(config) + return S.val + elseif EnzymeRules.needs_shadow(config) + return S.dval + else + return nothing + end +end end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 16241385..ebe080eb 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -13,6 +13,7 @@ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pul using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: svd_pushforward!, svd_trunc_pushforward!, svd_vals_pushforward! using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra @@ -538,7 +539,7 @@ for (f!, f) in ( (:svd_compact!, :svd_compact), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) @@ -562,7 +563,18 @@ for (f!, f) in ( end return USVᴴ_dUSVᴴ, svd_adjoint end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual) + A, dA = arrayify(A_dA) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + return USVᴴ_dUSVᴴ + end + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = $f(A, Mooncake.primal(alg_dalg)) @@ -585,10 +597,23 @@ for (f!, f) in ( end return USVᴴ_codual, svd_adjoint end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + A, dA = arrayify(A_dA) + USVᴴ = $f(A, Mooncake.primal(alg_dalg)) + dUSVᴴ = Mooncake.zero_tangent(USVᴴ) + USVᴴ_dual = Dual(USVᴴ, dUSVᴴ) + U, S, Vᴴ = Mooncake.primal(USVᴴ_dual) + dU_, dS_, dVᴴ_ = Mooncake.tangent(USVᴴ_dual) + U, dU = arrayify(U, dU_) + S, dS = arrayify(S, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + return USVᴴ_dual + end end end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -604,8 +629,17 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua end return S_dS, svd_vals_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + S, dS = arrayify(S_dS) + USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + copy!(S, diagview(USVᴴ[2])) + svd_vals_pushforward!(dA, A, USVᴴ, dS) + return S_dS +end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -624,6 +658,16 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co end return S_codual, svd_vals_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + S = diagview(USVᴴ[2]) + S_dual = Dual(S, Mooncake.zero_tangent(S)) + S_, dS = arrayify(S_dual) + svd_vals_pushforward!(dA, A, USVᴴ, dS) + return S_dual +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 65de152c..115b8301 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -132,6 +132,7 @@ include("pullbacks/polar.jl") include("pushforwards/polar.jl") include("pushforwards/eig.jl") include("pushforwards/eigh.jl") +include("pushforwards/svd.jl") include("precompile.jl") diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl new file mode 100644 index 00000000..9acca17f --- /dev/null +++ b/src/pushforwards/svd.jl @@ -0,0 +1,91 @@ +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol = default_pullback_rank_atol(A), kwargs...) + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + minmn = min(m, n) + S = diagview(Smat) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + r = svd_rank(S; rank_atol) + + vΔS = view(diagview(ΔS), 1:r) + + vU = view(U, :, 1:r) + vS = view(S, 1:r) + vSmat = view(Smat, 1:r, 1:r) + vVᴴ = view(Vᴴ, 1:r, :) + + # compact region + vV = adjoint(vVᴴ) + UΔAV = vU' * ΔA * vV + copyto!(vΔS, real.(diagview(UΔAV))) + F = inv_safe.(transpose(vS) .- vS) + G = inv_safe.(transpose(vS) .+ vS) + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 + K̇ = hUΔAV + aUΔAV + Ṁ = hUΔAV - aUΔAV + + # check gauge condition + @assert isantihermitian(K̇) + @assert isantihermitian(Ṁ) + K̇diag = diagview(K̇) + + ∂U = vU * K̇ + ∂V = vV * Ṁ + # full component + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn + Uperp = view(U, :, (minmn + 1):m) + Vᴴperp = view(Vᴴ, (minmn + 1):n, :) + + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) + + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) + fill!(UÃÃV, 0) + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' + rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U) + superKM = -_sylvester(UÃÃV, Smat, rhs) + K̇perp = view(superKM, 1:size(aUAV, 2)) + Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2))) + ∂U .+= Uperp * K̇perp + ∂V .+= Vᴴperp * Ṁperp + else + ImUU = (LinearAlgebra.diagm(one!(similar(U, m))) - vU * vU') + ImVV = (LinearAlgebra.diagm(one!(similar(Vᴴ, n))) - vV * vVᴴ) + upper = ImUU * ΔA * vV + lower = ImVV * ΔA' * vU + rhs = vcat(upper, lower) + + Ã = ImUU * A * ImVV + ÃÃ = similar(A, (m + n, m + n)) + fill!(ÃÃ, 0) + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã + view(ÃÃ, m .+ (1:n), 1:m) .= Ã' + + superLN = -_sylvester(ÃÃ, vSmat, rhs) + ∂U += view(superLN, 1:size(upper, 1), :) + ∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :) + end + if !iszerotangent(ΔU) + vΔU = view(ΔU, :, 1:r) + copyto!(vΔU, ∂U) + end + if !iszerotangent(ΔVᴴ) + vΔVᴴ = view(ΔVᴴ, 1:r, :) + adjoint!(vΔVᴴ, ∂V) + end + return (ΔU, ΔS, ΔVᴴ) +end + +function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...) + # TODO +end + +function svd_vals_pushforward!( + ΔA, A, USVᴴ, ΔS, ind = Colon(); + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]) + ) + ΔUSVᴴ = (nothing, diagonal(ΔS), nothing) + return svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol) +end diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl index e4aaa7aa..bef41e5c 100644 --- a/test/enzyme/svd.jl +++ b/test/enzyme/svd.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/testsuite/enzyme/svd.jl b/test/testsuite/enzyme/svd.jl index 2131aa8d..1861b83b 100644 --- a/test/testsuite/enzyme/svd.jl +++ b/test/testsuite/enzyme/svd.jl @@ -8,48 +8,80 @@ function test_enzyme_svd(T::Type, sz; kwargs...) end end +""" + test_enzyme_svd_compact(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_compact` and its in-place variant. +""" function test_enzyme_svd_compact( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_compact: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A) test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + if eltype(T) <: Real + test_forward(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) + end end end +""" + test_enzyme_svd_full(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The +gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent. +""" function test_enzyme_svd_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_full, A) USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + if eltype(T) <: Real && size(A, 1) == size(A, 2) # finite differences check for free component is very finicky + test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) + end end end +""" + test_enzyme_svd_vals(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_vals` and its in-place variant. +""" function test_enzyme_svd_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_vals, A) S, ΔS = ad_svd_vals_setup(A) test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) - test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) + test_reverse(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) + test_forward(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) end end +""" + test_enzyme_svd_trunc(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rules for `svd_trunc`, `svd_trunc_no_error`, and their +in-place variants, over a range of truncation ranks and a tolerance-based truncation. +""" function test_enzyme_svd_trunc( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), @@ -64,7 +96,7 @@ function test_enzyme_svd_trunc( trunc = truncrank(r) truncalg = TruncatedAlgorithm(alg, trunc) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) end @testset "trunctol" begin @@ -72,7 +104,7 @@ function test_enzyme_svd_trunc( trunc = trunctol(atol = maximum(S) / 2) truncalg = TruncatedAlgorithm(alg, trunc) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) end end diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl index 5ac79744..d58a9c0c 100644 --- a/test/testsuite/mooncake/svd.jl +++ b/test/testsuite/mooncake/svd.jl @@ -16,7 +16,7 @@ end """ test_mooncake_svd_compact(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_compact` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `svd_compact` and its in-place variant. """ function test_mooncake_svd_compact( T, sz; @@ -33,16 +33,26 @@ function test_mooncake_svd_compact( mode = Mooncake.ReverseMode, output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, call_and_zero!, svd_compact!, A, alg; + rng, call_and_zero!, svd_compact!, copy(A), alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false ) + if eltype(T) <: Real # gauge freedom in complex outputs + Mooncake.TestUtils.test_rule( + rng, svd_compact, A, alg; + mode = Mooncake.ForwardMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, call_and_zero!, svd_compact!, copy(A), alg; + mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false + ) + end end end """ test_mooncake_svd_full(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_full` and its in-place variant. The +Test the Mooncake forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent. """ function test_mooncake_svd_full( @@ -60,16 +70,26 @@ function test_mooncake_svd_full( mode = Mooncake.ReverseMode, output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, call_and_zero!, svd_full!, A, alg; + rng, call_and_zero!, svd_full!, copy(A), alg; mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false ) + if eltype(T) <: Real # gauge freedom in complex outputs + Mooncake.TestUtils.test_rule( + rng, svd_full, A, alg; + mode = Mooncake.ForwardMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, call_and_zero!, svd_full!, copy(A), alg; + mode = Mooncake.ForwardMode, atol, rtol, is_primitive = false + ) + end end end """ test_mooncake_svd_vals(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_vals` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `svd_vals` and its in-place variant. """ function test_mooncake_svd_vals( T, sz; @@ -83,11 +103,11 @@ function test_mooncake_svd_vals( Mooncake.TestUtils.test_rule( rng, svd_vals, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, svd_vals!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end