diff --git a/Project.toml b/Project.toml index 6ca22ef9..42124239 100644 --- a/Project.toml +++ b/Project.toml @@ -1,24 +1,24 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" -authors = ["Jutho and contributors"] version = "0.6.3" +authors = ["Jutho and contributors"] [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] -MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" +MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" MatrixAlgebraKitEnzymeExt = "Enzyme" MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra" MatrixAlgebraKitGenericSchurExt = "GenericSchur" @@ -27,13 +27,13 @@ MatrixAlgebraKitMooncakeExt = "Mooncake" [compat] AMDGPU = "2" Aqua = "0.6, 0.7, 0.8" +CUDA = "5" ChainRulesCore = "1" ChainRulesTestUtils = "1" -CUDA = "5" -GenericLinearAlgebra = "0.3.19" -GenericSchur = "0.5.6" Enzyme = "0.13.118" EnzymeTestUtils = "0.2.5" +GenericLinearAlgebra = "0.3.19" +GenericSchur = "0.5.6" JET = "0.9, 0.10" LinearAlgebra = "1" Mooncake = "0.5" @@ -49,8 +49,8 @@ julia = "1.10" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -64,7 +64,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", - "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", - "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Enzyme", - "EnzymeTestUtils", "ParallelTestRunner"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Enzyme", "EnzymeTestUtils", "ParallelTestRunner"] diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 217a48c2..e4ec256f 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output +using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero! using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! @@ -57,8 +57,8 @@ for (f!, f, pb, adj) in ( $pb(dA, A, (arg1, arg2), (darg1, darg2)) copy!(arg1, arg1c) copy!(arg2, arg2c) - MatrixAlgebraKit.zero!(darg1) - MatrixAlgebraKit.zero!(darg2) + zero!(darg1) + zero!(darg2) return NoRData(), NoRData(), NoRData(), NoRData() end return args_dargs, $adj @@ -78,8 +78,8 @@ for (f!, f, pb, adj) in ( arg1, darg1 = arrayify(arg1, darg1_) arg2, darg2 = arrayify(arg2, darg2_) $pb(dA, A, (arg1, arg2), (darg1, darg2)) - MatrixAlgebraKit.zero!(darg1) - MatrixAlgebraKit.zero!(darg2) + zero!(darg1) + zero!(darg2) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -103,7 +103,7 @@ for (f!, f, pb, adj) in ( copy!(A, Ac) $pb(dA, A, arg, darg) copy!(arg, argc) - MatrixAlgebraKit.zero!(darg) + zero!(darg) return NoRData(), NoRData(), NoRData(), NoRData() end return arg_darg, $adj @@ -116,7 +116,7 @@ for (f!, f, pb, adj) in ( function $adj(::NoRData) arg, darg = arrayify(output_codual) $pb(dA, A, arg, darg) - MatrixAlgebraKit.zero!(darg) + zero!(darg) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -134,13 +134,15 @@ for (f!, f, f_full, pb, adj) in ( # compute primal A, dA = arrayify(A_dA) D, dD = arrayify(D_dD) + Dc = copy(D) # update primal DV = $f_full(A, Mooncake.primal(alg_dalg)) copy!(D, diagview(DV[1])) V = DV[2] function $adj(::NoRData) $pb(dA, A, DV, dD) - MatrixAlgebraKit.zero!(dD) + copy!(D, Dc) + zero!(dD) return NoRData(), NoRData(), NoRData(), NoRData() end return D_dD, $adj @@ -157,7 +159,7 @@ for (f!, f, f_full, pb, adj) in ( function $adj(::NoRData) D, dD = arrayify(output_codual) $pb(dA, A, DV, dD) - MatrixAlgebraKit.zero!(dD) + zero!(dD) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -165,12 +167,43 @@ for (f!, f, f_full, pb, adj) in ( end end -for (f, f_ne, pb, adj) in ( - (:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), - (:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), +for (f!, f, f_ne!, f_ne, pb, adj) in ( + (:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), + (:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), ) @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + Ac = copy(A) + DVc = copy.(DV) + alg = Mooncake.primal(alg_dalg) + output = $f!(A, DV, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + copy!(A, Ac) + Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) + dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) + abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" + D′, dD′ = arrayify(Dtrunc, dDtrunc_) + V′, dV′ = arrayify(Vtrunc, dVtrunc_) + $pb(dA, A, (D′, V′), (dD′, dV′)) + copy!(DV[1], DVc[1]) + copy!(DV[2], DVc[2]) + zero!(dD′) + zero!(dV′) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -188,13 +221,43 @@ for (f, f_ne, pb, adj) in ( D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D, V), (dD, dV)) - MatrixAlgebraKit.zero!(dD) - MatrixAlgebraKit.zero!(dV) + zero!(dD) + zero!(dV) return NoRData(), NoRData(), NoRData() end return output_codual, $adj end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + Ac = copy(A) + DVc = copy.(DV) + output = $f_ne!(A, DV, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function $adj(::NoRData) + copy!(A, Ac) + Dtrunc, Vtrunc = Mooncake.primal(output_codual) + dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) + D′, dD′ = arrayify(Dtrunc, dDtrunc_) + V′, dV′ = arrayify(Vtrunc, dVtrunc_) + $pb(dA, A, (D′, V′), (dD′, dV′)) + copy!(DV[1], DVc[1]) + copy!(DV[2], DVc[2]) + zero!(dD′) + zero!(dV′) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -211,8 +274,8 @@ for (f, f_ne, pb, adj) in ( D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D, V), (dD, dV)) - MatrixAlgebraKit.zero!(dD) - MatrixAlgebraKit.zero!(dV) + zero!(dD) + zero!(dV) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -234,6 +297,7 @@ for (f!, f) in ( U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴc = copy.(USVᴴ) output = $f!(A, Mooncake.primal(alg_dalg)) function svd_adjoint(::NoRData) copy!(A, Ac) @@ -249,9 +313,12 @@ for (f!, f) in ( vdVᴴ = view(dVᴴ, 1:minmn, :) svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData(), NoRData() end return CoDual(output, dUSVᴴ), svd_adjoint @@ -283,9 +350,9 @@ for (f!, f) in ( vdVᴴ = view(dVᴴ, 1:minmn, :) svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return USVᴴ_codual, svd_adjoint @@ -298,11 +365,13 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua # compute primal A, dA = arrayify(A_dA) S, dS = arrayify(S_dS) + Sc = copy(S) USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) copy!(S, diagview(USVᴴ[2])) function svd_vals_adjoint(::NoRData) svd_vals_pullback!(dA, A, USVᴴ, dS) - MatrixAlgebraKit.zero!(dS) + zero!(dS) + copy!(S, Sc) return NoRData(), NoRData(), NoRData(), NoRData() end return S_dS, svd_vals_adjoint @@ -322,18 +391,57 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co function svd_vals_adjoint(::NoRData) S, dS = arrayify(S_codual) svd_vals_pullback!(dA, A, USVᴴ, dS) - MatrixAlgebraKit.zero!(dS) + zero!(dS) return NoRData(), NoRData(), NoRData() end return S_codual, svd_vals_adjoint end +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + Ac = copy(A) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴc = copy.(USVᴴ) + output = svd_trunc!(A, USVᴴ, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = Mooncake.zero_fcodual(output) + function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} + copy!(A, Ac) + Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) + abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" + U′, dU′ = arrayify(Utrunc, dUtrunc_) + S′, dS′ = arrayify(Strunc, dStrunc_) + Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′)) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + zero!(dU) + zero!(dS) + zero!(dVᴴ) + zero!(dU′) + zero!(dS′) + zero!(dVᴴ′) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal - A_ = Mooncake.primal(A_dA) - dA_ = Mooncake.tangent(A_dA) - A, dA = arrayify(A_, dA_) + A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) output = svd_trunc(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal @@ -349,9 +457,49 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + Ac = copy(A) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴc = copy.(USVᴴ) + output = svd_trunc_no_error!(A, USVᴴ, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function svd_trunc_adjoint(::NoRData) + copy!(A, Ac) + Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual) + U′, dU′ = arrayify(Utrunc, dUtrunc_) + S′, dS′ = arrayify(Strunc, dStrunc_) + Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′)) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + zero!(dU) + zero!(dS) + zero!(dVᴴ) + zero!(dU′) + zero!(dS′) + zero!(dVᴴ′) return NoRData(), NoRData(), NoRData() end return output_codual, svd_trunc_adjoint @@ -360,9 +508,7 @@ end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal - A_ = Mooncake.primal(A_dA) - dA_ = Mooncake.tangent(A_dA) - A, dA = arrayify(A_, dA_) + A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) output = svd_trunc_no_error(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal @@ -377,9 +523,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return output_codual, svd_trunc_adjoint diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index dc15dde9..29d65e31 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -52,8 +52,8 @@ MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit. MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) make_mooncake_tangent(ΔAelem::T) where {T <: Number} = ΔAelem -make_mooncake_tangent(ΔA::Matrix) = ΔA -make_mooncake_tangent(ΔA::Vector) = ΔA +make_mooncake_tangent(ΔA::AbstractMatrix) = ΔA +make_mooncake_tangent(ΔA::AbstractVector) = ΔA make_mooncake_tangent(ΔD::Diagonal) = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) @@ -62,62 +62,108 @@ make_mooncake_fdata(x) = make_mooncake_tangent(x) make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x) +# copies a preset tangent into a Mooncake CoDual +# for use in the pullback. +function copy_tangent(var::Mooncake.CoDual, Δargs) + dargs = make_mooncake_fdata(deepcopy(Δargs)) + copyto!(Mooncake.tangent(var), dargs) + return +end + +function copy_tangent(var::Mooncake.CoDual, Δargs::Tuple) + dargs = make_mooncake_fdata.(deepcopy(Δargs)) + for (var_tangent, darg) in zip(Mooncake.tangent(var), dargs) + if var_tangent isa Mooncake.FData + for (var_f, darg_f) in zip(Mooncake._fields(var_tangent), Mooncake._fields(darg)) + copyto!(var_f, darg_f) + end + else + copyto!(var_tangent, darg) + end + end + return +end + # no `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) +function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) dA_copy = make_mooncake_fdata(copy(ΔA)) A_copy = copy(A) - dargs_copy = make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) + A_dA = Mooncake.CoDual(A_copy, dA_copy) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA) + # copy Δargs into tangent of the output variable for the pullback check + copy_tangent(copy_out, Δargs) copy_pb!!(rdata) - return dA_copy + @test Mooncake.primal(A_dA) == A + return dA_copy, Mooncake.tangent(copy_out) end # `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) +function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, alg, rdata) dA_copy = make_mooncake_fdata(copy(ΔA)) A_copy = copy(A) - dargs_copy = make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + A_dA = Mooncake.CoDual(A_copy, dA_copy) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA, Mooncake.CoDual(alg, Mooncake.NoFData())) + # copy Δargs into tangent of the output variable for the pullback check + copy_tangent(copy_out, Δargs) copy_pb!!(rdata) - return dA_copy + @test Mooncake.primal(A_dA) == A + return dA_copy, Mooncake.tangent(copy_out) end -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata; ȳ = Δargs) dA_inplace = make_mooncake_fdata(copy(ΔA)) A_inplace = copy(A) + args_copy = deepcopy(args) dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) # not every f! has a handwritten rrule!! inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + A_dA = Mooncake.CoDual(A_inplace, dA_inplace) + args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) else inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) end + # copy reference derivative of output ȳ into inplace_out + # needed for inplace methods like svd_trunc! that generate + # new output variables + copy_tangent(inplace_out, ȳ) inplace_pb!!(rdata) - return dA_inplace + @test Mooncake.primal(A_dA) == A + @test Mooncake.primal(args_dargs) == args_copy + return dA_inplace, Mooncake.tangent(inplace_out) end -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata; ȳ = Δargs) dA_inplace = make_mooncake_fdata(copy(ΔA)) A_inplace = copy(A) + args_copy = deepcopy(args) dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) # not every f! has a handwritten rrule!! inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + A_dA = Mooncake.CoDual(A_inplace, dA_inplace) + args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) else inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) end + # copy reference derivative of output ȳ into inplace_out + # needed for inplace methods like svd_trunc! that generate + # new output variables + copy_tangent(inplace_out, ȳ) inplace_pb!!(rdata) - return dA_inplace + @test Mooncake.primal(A_dA) == A + @test Mooncake.primal(args_dargs) == args_copy + return dA_inplace, Mooncake.tangent(inplace_out) end """ @@ -137,19 +183,34 @@ The arguments to this function are: - `alg` optional algorithm keyword argument - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) """ -function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) - sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} +function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData(), ȳ = deepcopy(Δargs)) + sig = isnothing(alg) ? Tuple{typeof(f), typeof(A)} : Tuple{typeof(f), typeof(A), typeof(alg)} rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = isa(A, Diagonal) ? Diagonal(randn!(similar(A.diag))) : randn!(similar(A)) + ΔA = randn(rng, eltype(A), size(A)) - dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + copy_args = isa(args, Tuple) ? copy.(args) : copy(args) + inplace_args = isa(args, Tuple) ? copy.(args) : copy(args) + dA_copy, dargs_copy = _get_copying_derivative(f, rrule, A, ΔA, copy_args, ȳ, alg, rdata) + dA_inplace, dargs_inplace = _get_inplace_derivative(f!, A, ΔA, inplace_args, Δargs, alg, rdata; ȳ) dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] @test dA_inplace_ ≈ dA_copy_ + @test copy_args == inplace_args + if dargs_copy isa Tuple + for (darg_copy_, darg_inplace_) in zip(dargs_copy, dargs_inplace) + if darg_copy_ isa Mooncake.FData + for (c_f, i_f) in zip(Mooncake._fields(darg_copy_), Mooncake._fields(darg_inplace_)) + @test c_f == i_f + end + else + @test darg_copy_ == darg_inplace_ + end + end + else + @test dargs_copy == dargs_inplace + end return end @@ -182,19 +243,19 @@ function test_mooncake_qr( @testset "qr_compact" begin QR, ΔQR = ad_qr_compact_setup(A) dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) end @testset "qr_null" begin N, ΔN = ad_qr_null_setup(A) dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol, rtol) test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) end @testset "qr_full" begin QR, ΔQR = ad_qr_full_setup(A) dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) end @testset "qr_compact - rank-deficient A" begin @@ -203,7 +264,7 @@ function test_mooncake_qr( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) end end @@ -219,19 +280,19 @@ function test_mooncake_lq( A = instantiate_matrix(T, sz) @testset "lq_compact" begin LQ, ΔLQ = ad_lq_compact_setup(A) - Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) end @testset "lq_null" begin Nᴴ, ΔNᴴ = ad_lq_null_setup(A) dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol, rtol) test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) end @testset "lq_full" begin LQ, ΔLQ = ad_lq_full_setup(A) dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) end @testset "lq_compact - rank-deficient A" begin @@ -240,7 +301,7 @@ function test_mooncake_lq( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) end end @@ -258,13 +319,13 @@ function test_mooncake_eig( @testset "eig_full" begin DV, ΔDV, ΔD2V = ad_eig_full_setup(A) dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol, rtol) test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) end @testset "eig_vals" begin D, ΔD = ad_eig_vals_setup(A) dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol, rtol) test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) end @testset "eig_trunc" begin @@ -272,22 +333,22 @@ function test_mooncake_eig( truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) - test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + dDVerr = make_mooncake_tangent((copy.(ΔDVtrunc)..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) - test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) - test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) - test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end end end @@ -304,13 +365,13 @@ function test_mooncake_eigh( @testset "eigh_full" begin DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol, rtol) test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) end @testset "eigh_vals" begin D, ΔD = ad_eigh_vals_setup(A) dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol, rtol) test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) end @testset "eigh_trunc" begin @@ -319,22 +380,22 @@ function test_mooncake_eigh( DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end D = eigh_vals(A / 2) truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end end end @@ -351,43 +412,44 @@ function test_mooncake_svd( @testset "svd_compact" begin USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) end @testset "svd_full" begin USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) end @testset "svd_vals" begin S, ΔS = ad_svd_vals_setup(A) - Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) end @testset "svd_trunc" begin - S, ΔS = ad_svd_vals_setup(A) @testset for r in 1:4:minmn truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) end @testset "trunctol" begin + A = instantiate_matrix(T, sz) + S, ΔS = ad_svd_vals_setup(A) truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) end end end @@ -405,14 +467,14 @@ function test_mooncake_polar( @testset "left_polar" begin if m >= n WP, ΔWP = ad_left_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) end end @testset "right_polar" begin if m <= n PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) end end @@ -444,34 +506,34 @@ function test_mooncake_orthnull( m, n = size(A) VC, ΔVC = ad_left_orth_setup(A) CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) end N, ΔN = ad_left_null_setup(A) dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dN) test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) end Nᴴ, ΔNᴴ = ad_right_null_setup(A) dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dNᴴ) test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) end end