diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index e4ec256f..f44fd123 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,13 +3,15 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero! -using MatrixAlgebraKit: qr_pullback!, lq_pullback! -using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! -using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! -using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback! -using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! -using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero!, truncate, truncation_error! +using MatrixAlgebraKit: qr_pullback!, qr_pushforward!, lq_pullback!, lq_pushforward! +using MatrixAlgebraKit: qr_null_pullback!, qr_null_pushforward!, lq_null_pullback!, lq_null_pushforward! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback! +using MatrixAlgebraKit: eig_vals_pullback!, eigh_vals_pullback!, eig_vals_pushforward!, eigh_vals_pushforward! +using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_trunc_pushforward!, eigh_trunc_pushforward! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward! +using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_pushforward!, svd_trunc_pushforward! +using MatrixAlgebraKit: svd_vals_pullback!, svd_vals_pushforward! using LinearAlgebra @@ -28,20 +30,21 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu end Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any} -# two-argument in-place factorizations like LQ, QR, EIG -for (f!, f, pb, adj) in ( - (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), - (:lq_full!, :lq_full, :lq_pullback!, :lq_adjoint), - (:qr_compact!, :qr_compact, :qr_pullback!, :qr_adjoint), - (:lq_compact!, :lq_compact, :lq_pullback!, :lq_adjoint), - (:eig_full!, :eig_full, :eig_pullback!, :eig_adjoint), - (:eigh_full!, :eigh_full, :eigh_pullback!, :eigh_adjoint), - (:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_adjoint), - (:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_adjoint), +# two-argument factorizations like LQ, QR, EIG +for (f!, f, pb, pf, adj) in ( + (:qr_full!, :qr_full, :qr_pullback!, :qr_pushforward!, :dqr_adjoint), + (:qr_compact!, :qr_compact, :qr_pullback!, :qr_pushforward!, :dqr_adjoint), + (:lq_full!, :lq_full, :lq_pullback!, :lq_pushforward!, :dlq_adjoint), + (:lq_compact!, :lq_compact, :lq_pullback!, :lq_pushforward!, :dlq_adjoint), + (:eig_full!, :eig_full, :eig_pullback!, :eig_pushforward!, :deig_adjoint), + (:eigh_full!, :eigh_full, :eigh_pullback!, :eigh_pushforward!, :deigh_adjoint), + (:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_pushforward!, :dleft_polar_adjoint), + (:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_pushforward!, :dright_polar_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) args = Mooncake.primal(args_dargs) @@ -63,7 +66,6 @@ for (f!, f, pb, adj) in ( end return args_dargs, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) @@ -84,15 +86,39 @@ for (f!, f, pb, adj) in ( end return output_codual, $adj end + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + args = $f!(A, args, Mooncake.primal(alg_dalg)) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2)) + zero!(dA) + return args_dargs + end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + args = $f(A, Mooncake.primal(alg_dalg)) + args_dargs = Mooncake.zero_dual(args) + arg1, arg2 = args + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(arg1, dargs[1]) + arg2, darg2 = arrayify(arg2, dargs[2]) + $pf(dA, A, (arg1, arg2), (darg1, darg2)) + return args_dargs + end end end -for (f!, f, pb, adj) in ( - (:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_adjoint), - (:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint), +for (f!, f, pb, pf, adj) in ( + (:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_pushforward!, :dqr_null_adjoint), + (:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_pushforward!, :dlq_null_adjoint), ) + #forward mode not implemented yet @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) Ac = copy(A) @@ -108,7 +134,6 @@ for (f!, f, pb, adj) in ( end return arg_darg, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) @@ -121,15 +146,32 @@ for (f!, f, pb, adj) in ( end return output_codual, $adj end + function Mooncake.frule!!(f_df::Dual{typeof($f!)}, A_dA::Dual, arg_darg::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + Ac = MatrixAlgebraKit.copy_input($f, A) + arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg)) + arg = $f!(A, arg, Mooncake.primal(alg_dalg)) + $pf(dA, Ac, arg, darg) + zero!(dA) + return arg_darg + end + function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + arg = $f(A, Mooncake.primal(alg_dalg)) + darg = Mooncake.zero_tangent(arg) + $pf(dA, A, arg, darg) + return Dual(arg, darg) + end end end -for (f!, f, f_full, pb, adj) in ( - (:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint), - (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint), +for (f!, f, f_full, pb, pf, adj) in ( + (:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_pushforward!, :eig_vals_adjoint), + (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_pushforward!, :eigh_vals_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -147,7 +189,16 @@ for (f!, f, f_full, pb, adj) in ( end return D_dD, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + D, dD = arrayify(D_dD) + nD, V = $f_full(A, Mooncake.primal(alg_dalg)) + copy!(D, diagview(nD)) + $pf(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + zero!(dA) + return D_dD + end function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -164,16 +215,25 @@ for (f!, f, f_full, pb, adj) in ( end return output_codual, $adj end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + fullD, V = $f_full(A, Mooncake.primal(alg_dalg)) + D_dD = Mooncake.zero_dual(diagview(fullD)) + D, dD = arrayify(D_dD) + $pf(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + return D_dD + end end end -for (f!, f, f_ne!, f_ne, pb, adj) in ( - (:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), - (:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), +for (f!, f, f_ne!, f_ne, pb, pf, adj) in ( + (:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_pushforward!, :eig_trunc_adjoint), + (:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_pushforward!, :eigh_trunc_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -227,8 +287,8 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( end return output_codual, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -280,6 +340,32 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( end return output_codual, $adj end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + output = $f(A, alg) + output_dual = Mooncake.zero_dual(output) + dD_ = Mooncake.tangent(output_dual)[1] + dV_ = Mooncake.tangent(output_dual)[2] + D, dD = arrayify(output[1], dD_) + V, dV = arrayify(output[2], dV_) + $pf(dA, A, (D, V), (dD, dV)) + return output_dual + end + function Mooncake.frule!!(::Dual{typeof($f_ne)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + output = $f_ne(A, alg) + output_dual = Mooncake.zero_dual(output) + dD_ = Mooncake.tangent(output_dual)[1] + dV_ = Mooncake.tangent(output_dual)[2] + D, dD = arrayify(output[1], dD_) + V, dV = arrayify(output[2], dV_) + $pf(dA, A, (D, V), (dD, dV)) + return output_dual + end end end @@ -288,7 +374,8 @@ for (f!, f) in ( (:svd_compact!, :svd_compact), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) Ac = copy(A) @@ -323,7 +410,6 @@ for (f!, f) in ( end return CoDual(output, dUSVᴴ), svd_adjoint end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = $f(A, Mooncake.primal(alg_dalg)) @@ -357,10 +443,75 @@ for (f!, f) in ( end return USVᴴ_codual, svd_adjoint end + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual) + # compute primal + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + A, dA = arrayify(A_dA) + $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) + # update tangents + U_, S_, Vᴴ_ = USVᴴ + dU_, dS_, dVᴴ_ = dUSVᴴ + U, dU = arrayify(U_, dU_) + S, dS = arrayify(S_, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_) + minmn = min(size(A)...) + if $(f == svd_compact!) # compact + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full + vU = view(U, :, 1:minmn) + vS = view(S, 1:minmn, 1:minmn) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = view(dS, 1:minmn, 1:minmn) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + end + zero!(dA) + return USVᴴ_dUSVᴴ + end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + USVᴴ = $f(A, Mooncake.primal(alg_dalg)) + # update tangents + U, S, Vᴴ = USVᴴ + dU_ = Mooncake.zero_tangent(U) + dS_ = Mooncake.zero_tangent(S) + dVᴴ_ = Mooncake.zero_tangent(Vᴴ) + U, dU = arrayify(U, dU_) + S, dS = arrayify(S, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_) + if $(f == svd_compact!) # compact + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full + minmn = min(size(A)...) + vU = view(U, :, 1:minmn) + vS = view(S, 1:minmn, 1:minmn) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = view(dS, 1:minmn, 1:minmn) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + end + return Dual(USVᴴ, (dU_, dS_, dVᴴ_)) + end end end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual) + # compute primal + S, dS = Mooncake.arrayify(S_dS) + A, dA = Mooncake.arrayify(A_dA) + U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + # update tangent + copyto!(dS, diag(real.(Vᴴ * dA' * U))) + copyto!(S, diagview(nS)) + zero!(dA) + return S_dS +end + function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -377,7 +528,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua return S_dS, svd_vals_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -396,8 +547,17 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co end return S_codual, svd_vals_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + S_dS = Mooncake.zero_dual(diagview(S)) + S_, dS = arrayify(S_dS) + copyto!(dS, diag(real.(Vᴴ * dA' * U))) + return S_dS +end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -437,10 +597,8 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS end return output_codual, svd_trunc_adjoint end - -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) - # compute primal A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) output = svd_trunc(A, alg) @@ -465,7 +623,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C return output_codual, svd_trunc_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -505,7 +663,34 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U return output_codual, svd_trunc_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{typeof(svd_trunc)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = Mooncake.arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + USVᴴ = svd_compact(A, alg.alg) + U, S, Vᴴ = USVᴴ + dUfull = zeros(eltype(U), size(U)) + dSfull = Diagonal(zeros(eltype(S), length(diagview(S)))) + dVᴴfull = zeros(eltype(Vᴴ), size(Vᴴ)) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dUfull, dSfull, dVᴴfull)) + + USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = truncation_error!(diagview(S), ind) + output = (USVᴴtrunc..., ϵ) + output_dual = Mooncake.zero_dual(output) + Utrunc, Strunc, Vᴴtrunc, ϵ = output + dU_, dS_, dVᴴ_, dϵ = Mooncake.tangent(output_dual) + Utrunc, dU = arrayify(Utrunc, dU_) + Strunc, dS = arrayify(Strunc, dS_) + Vᴴtrunc, dVᴴ = arrayify(Vᴴtrunc, dVᴴ_) + dU .= view(dUfull, :, ind) + diagview(dS) .= view(diagview(dSfull), ind) + dVᴴ .= view(dVᴴfull, ind, :) + return output_dual +end + + +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -531,4 +716,29 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al return output_codual, svd_trunc_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_trunc_no_error)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + USVᴴ = svd_compact(A, alg.alg) + U, S, Vᴴ = USVᴴ + dUfull = zeros(eltype(U), size(U)) + dSfull = Diagonal(zeros(eltype(S), length(diagview(S)))) + dVᴴfull = zeros(eltype(Vᴴ), size(Vᴴ)) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dUfull, dSfull, dVᴴfull)) + + USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc) + output = USVᴴtrunc + output_dual = Mooncake.zero_dual(output) + Utrunc, Strunc, Vᴴtrunc = output + dU_, dS_, dVᴴ_ = Mooncake.tangent(output_dual) + Utrunc, dU = arrayify(Utrunc, dU_) + Strunc, dS = arrayify(Strunc, dS_) + Vᴴtrunc, dVᴴ = arrayify(Vᴴtrunc, dVᴴ_) + dU .= view(dUfull, :, ind) + diagview(dS) .= view(diagview(dSfull), ind) + dVᴴ .= view(dVᴴfull, ind, :) + return output_dual +end + end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 03fb05bd..6651de67 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -117,4 +117,11 @@ include("pullbacks/eigh.jl") include("pullbacks/svd.jl") include("pullbacks/polar.jl") +include("pushforwards/qr.jl") +include("pushforwards/lq.jl") +include("pushforwards/eig.jl") +include("pushforwards/eigh.jl") +include("pushforwards/polar.jl") +include("pushforwards/svd.jl") + end diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl new file mode 100644 index 00000000..a48df3aa --- /dev/null +++ b/src/pushforwards/eig.jl @@ -0,0 +1,19 @@ +function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...) + D, V = DV + ΔD, ΔV = ΔDV + iVΔAV = inv(V) * ΔA * V + diagview(ΔD) .= diagview(iVΔAV) + if !iszerotangent(ΔV) + F = 1 ./ (transpose(diagview(D)) .- diagview(D)) + fill!(diagview(F), zero(eltype(F))) + K̇ = F .* iVΔAV + mul!(ΔV, V, K̇, 1, 0) + end + return ΔDV +end + +function eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...) end + +function eig_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...) + return eig_pushforward!(ΔA, A, DV, ΔD; kwargs...) +end diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl new file mode 100644 index 00000000..ebb470b9 --- /dev/null +++ b/src/pushforwards/eigh.jl @@ -0,0 +1,21 @@ +function eigh_pushforward!(dA, A, DV, dDV; kwargs...) + D, V = DV + dD, dV = dDV + tmpV = V \ dA + ∂K = tmpV * V + ∂Kdiag = diag(∂K) + diagview(dD) .= real.(∂Kdiag) + if !iszerotangent(dV) + dDD = transpose(diagview(D)) .- diagview(D) + F = one(eltype(dDD)) ./ dDD + diagview(F) .= zero(eltype(F)) + ∂K .*= F + ∂V = mul!(tmpV, V, ∂K) + copyto!(dV, ∂V) + end + return (dD, dV) +end + +function eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...) end + +function eigh_vals_pushforward!(dA, A, DV, dDV, ind = Colon(); kwargs...) end diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl new file mode 100644 index 00000000..6490e1ef --- /dev/null +++ b/src/pushforwards/lq.jl @@ -0,0 +1,7 @@ +function lq_pushforward!(dA, A, LQ, dLQ; tol::Real = default_pullback_gauge_atol(LQ[1]), rank_atol::Real = tol, gauge_atol::Real = tol) + return qr_pushforward!(adjoint(dA), adjoint(A), adjoint.(reverse(LQ)), adjoint.(reverse(dLQ)); tol, rank_atol, gauge_atol) +end + +function lq_null_pushforward!(dA, A, Nᴴ, dNᴴ; tol::Real = default_pullback_gauge_atol(Nᴴ), rank_atol::Real = tol, gauge_atol::Real = tol) + return iszero(min(size(Nᴴ)...)) && return # nothing to do +end diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl new file mode 100644 index 00000000..1e0da1b2 --- /dev/null +++ b/src/pushforwards/polar.jl @@ -0,0 +1,21 @@ +function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) + W, P = WP + ΔW, ΔP = ΔWP + aWdA = adjoint(W) * ΔA + K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA))) + L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P) + ΔW .= W * K̇ + L̇ + ΔP .= aWdA - K̇ * P + return (ΔW, ΔP) +end + +function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) + P, Wᴴ = PWᴴ + ΔP, ΔWᴴ = ΔPWᴴ + dAW = ΔA * adjoint(Wᴴ) + K̇ = sylvester(P, P, -(dAW - adjoint(dAW))) + L̇ = inv(P) * ΔA * (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) + ΔWᴴ .= K̇ * Wᴴ + L̇ + ΔP .= dAW - P * K̇ + return (ΔWᴴ, ΔP) +end diff --git a/src/pushforwards/qr.jl b/src/pushforwards/qr.jl new file mode 100644 index 00000000..37781193 --- /dev/null +++ b/src/pushforwards/qr.jl @@ -0,0 +1,61 @@ +function qr_pushforward!(dA, A, QR, dQR; tol::Real = default_pullback_gauge_atol(QR[2]), rank_atol::Real = tol, gauge_atol::Real = tol) + Q, R = QR + m = size(A, 1) + n = size(A, 2) + minmn = min(m, n) + Rd = diagview(R) + p = findlast(>=(rank_atol) ∘ abs, Rd) + + m1 = p + m2 = minmn - p + m3 = m - minmn + n1 = p + n2 = n - p + + Q1 = view(Q, 1:m, 1:m1) # full rank portion + Q2 = view(Q, 1:m, (m1 + 1):(m2 + m1)) + R11 = view(R, 1:m1, 1:n1) + R12 = view(R, 1:m1, (n1 + 1):n) + + dA1 = view(dA, 1:m, 1:n1) + dA2 = view(dA, 1:m, (n1 + 1):n) + + dQ, dR = dQR + dQ1 = view(dQ, 1:m, 1:m1) + dQ2 = view(dQ, 1:m, (m1 + 1):(m2 + m1)) + dQ3 = minmn + 1 < size(dQ, 2) ? view(dQ, :, (minmn + 1):size(dQ, 2)) : similar(dQ, eltype(dQ), (0, 0)) + dR11 = view(dR, 1:m1, 1:n1) + dR12 = view(dR, 1:m1, (n1 + 1):n) + dR22 = view(dR, (m1 + 1):(m1 + m2), (n1 + 1):n) + + # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need + invR11 = inv(R11) + tmp = Q1' * dA1 * invR11 + Rtmp = tmp + tmp' + diagview(Rtmp) ./= 2 + ltRtmp = view(Rtmp, lowertriangularind(Rtmp)) + ltRtmp .= zero(eltype(Rtmp)) + dR11 .= Rtmp * R11 + dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 + dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) + if size(Q2, 2) > 0 + dQ2 .= -Q1 * (Q1' * Q2) + dQ2 .+= Q2 * (Q2' * dQ2) + end + if m3 > 0 && size(Q, 2) > minmn + # only present for qr_full or rank-deficient qr_compact + Q′ = view(Q, :, 1:minmn) + Q3 = view(Q, :, (minmn + 1):m) + #dQ3 .= Q′ * (Q′' * Q3) + dQ3 .= Q3 + end + if !isempty(dR22) + _, r22 = qr_compact(dA2 - dQ1 * R12 - Q1 * dR12; positive = true) + dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) + end + return (dQ, dR) +end + +function qr_null_pushforward!(dA, A, N, dN; tol::Real = default_pullback_gauge_atol(N), rank_atol::Real = tol, gauge_atol::Real = tol) + return iszero(min(size(N)...)) && return # nothing to do +end diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl new file mode 100644 index 00000000..3d2d2733 --- /dev/null +++ b/src/pushforwards/svd.jl @@ -0,0 +1,82 @@ +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol = default_pullback_rank_atol(A), kwargs...) + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + minmn = min(m, n) + S = diagview(Smat) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + r = searchsortedlast(S, rank_atol; rev = true) # rank + + vΔU = view(ΔU, :, 1:r) + vΔS = view(ΔS, 1:r, 1:r) + vΔVᴴ = view(ΔVᴴ, 1:r, :) + + vU = view(U, :, 1:r) + vS = view(S, 1:r) + vSmat = view(Smat, 1:r, 1:r) + vVᴴ = view(Vᴴ, 1:r, :) + + # compact region + vV = adjoint(vVᴴ) + UΔAV = vU' * ΔA * vV + copyto!(diagview(vΔS), diag(real.(UΔAV))) + F = one(eltype(S)) ./ (transpose(vS) .- vS) + G = one(eltype(S)) ./ (transpose(vS) .+ vS) + diagview(F) .= zero(eltype(F)) + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 + K̇ = hUΔAV + aUΔAV + Ṁ = hUΔAV - aUΔAV + + # check gauge condition + @assert isantihermitian(K̇) + @assert isantihermitian(Ṁ) + K̇diag = diagview(K̇) + for i in 1:length(K̇diag) + @assert K̇diag[i] ≈ (im / 2) * imag(diagview(UΔAV)[i]) / S[i] + end + + ∂U = vU * K̇ + ∂V = vV * Ṁ + # full component + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn + Uperp = view(U, :, (minmn + 1):m) + Vᴴperp = view(Vᴴ, (minmn + 1):n, :) + + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) + + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) + fill!(UÃÃV, 0) + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' + rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U) + superKM = -sylvester(UÃÃV, Smat, rhs) + K̇perp = view(superKM, 1:size(aUAV, 2)) + Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2))) + ∂U .+= Uperp * K̇perp + ∂V .+= Vᴴperp * Ṁperp + else + ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU') + ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ) + upper = ImUU * ΔA * vV + lower = ImVV * ΔA' * vU + rhs = vcat(upper, lower) + + Ã = ImUU * A * ImVV + ÃÃ = similar(A, (m + n, m + n)) + fill!(ÃÃ, 0) + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã + view(ÃÃ, m .+ (1:n), 1:m) .= Ã' + + superLN = -sylvester(ÃÃ, vSmat, rhs) + ∂U += view(superLN, 1:size(upper, 1), :) + ∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :) + end + copyto!(vΔU, ∂U) + adjoint!(vΔVᴴ, ∂V) + return (ΔU, ΔS, ΔVᴴ) +end + +function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...) + +end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index 29d65e31..967348b4 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -243,19 +243,19 @@ function test_mooncake_qr( @testset "qr_compact" begin QR, ΔQR = ad_qr_compact_setup(A) dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, output_tangent = dQR, atol, rtol) test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) end @testset "qr_null" begin N, ΔN = ad_qr_null_setup(A) dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol, rtol) + Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, output_tangent = dN, atol, rtol) test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) end @testset "qr_full" begin QR, ΔQR = ad_qr_full_setup(A) dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, output_tangent = dQR, atol, rtol) test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) end @testset "qr_compact - rank-deficient A" begin @@ -264,7 +264,7 @@ function test_mooncake_qr( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, output_tangent = dQR, atol, rtol) test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) end end @@ -280,19 +280,19 @@ function test_mooncake_lq( A = instantiate_matrix(T, sz) @testset "lq_compact" begin LQ, ΔLQ = ad_lq_compact_setup(A) - Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, atol, rtol) test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) end @testset "lq_null" begin Nᴴ, ΔNᴴ = ad_lq_null_setup(A) dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, output_tangent = dNᴴ, atol, rtol) test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) end @testset "lq_full" begin LQ, ΔLQ = ad_lq_full_setup(A) dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, output_tangent = dLQ, atol, rtol) test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) end @testset "lq_compact - rank-deficient A" begin @@ -301,7 +301,7 @@ function test_mooncake_lq( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, output_tangent = dLQ, atol, rtol) test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) end end @@ -319,13 +319,13 @@ function test_mooncake_eig( @testset "eig_full" begin DV, ΔDV, ΔD2V = ad_eig_full_setup(A) dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, output_tangent = dDV, atol, rtol) test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) end @testset "eig_vals" begin D, ΔD = ad_eig_vals_setup(A) dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, output_tangent = dD, atol, rtol) test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) end @testset "eig_trunc" begin @@ -334,20 +334,20 @@ function test_mooncake_eig( DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dDVerr = make_mooncake_tangent((copy.(ΔDVtrunc)..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVerr, atol, rtol) test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; output_tangent = dDVtrunc, atol, rtol) test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVerr, atol, rtol) test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; output_tangent = dDVtrunc, atol, rtol) test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end end @@ -365,13 +365,13 @@ function test_mooncake_eigh( @testset "eigh_full" begin DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol, rtol) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; output_tangent = dDV, is_primitive = false, atol, rtol) test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) end @testset "eigh_vals" begin D, ΔD = ad_eigh_vals_setup(A) dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol, rtol) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; output_tangent = dD, is_primitive = false, atol, rtol) test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) end @testset "eigh_trunc" begin @@ -380,10 +380,10 @@ function test_mooncake_eigh( DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; output_tangent = dDVerr, atol, rtol, is_primitive = false) test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; output_tangent = dDVtrunc, atol, rtol, is_primitive = false) test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end D = eigh_vals(A / 2) @@ -391,10 +391,10 @@ function test_mooncake_eigh( DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; output_tangent = dDVerr, atol, rtol, is_primitive = false) test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; output_tangent = dDVtrunc, atol, rtol, is_primitive = false) test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) end end @@ -412,18 +412,18 @@ function test_mooncake_svd( @testset "svd_compact" begin USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) end @testset "svd_full" begin USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) end @testset "svd_vals" begin S, ΔS = ad_svd_vals_setup(A) - Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, atol, rtol) test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) end @testset "svd_trunc" begin @@ -432,10 +432,10 @@ function test_mooncake_svd( USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; output_tangent = dUSVᴴerr, atol, rtol) test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) end @testset "trunctol" begin @@ -445,10 +445,10 @@ function test_mooncake_svd( USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) ϵ = zero(real(eltype(T))) dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; output_tangent = dUSVᴴerr, atol, rtol) test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) end end @@ -467,14 +467,14 @@ function test_mooncake_polar( @testset "left_polar" begin if m >= n WP, ΔWP = ad_left_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, atol, rtol) test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) end end @testset "right_polar" begin if m <= n PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, atol, rtol) test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) end end @@ -506,34 +506,34 @@ function test_mooncake_orthnull( m, n = size(A) VC, ΔVC = ad_left_orth_setup(A) CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth, A; atol, rtol, is_primitive = false) test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth, A; atol, rtol, is_primitive = false) test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; atol, rtol, is_primitive = false) test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; atol, rtol, is_primitive = false) test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) end N, ΔN = ad_left_null_setup(A) dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; atol, rtol, is_primitive = false, output_tangent = dN) test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; atol, rtol, is_primitive = false) test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; atol, rtol, is_primitive = false) test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) end Nᴴ, ΔNᴴ = ad_right_null_setup(A) dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; atol, rtol, is_primitive = false, output_tangent = dNᴴ) test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) end end