diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 434f33ed4..8880dfcf1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,7 +30,8 @@ jobs: - symmetries - tensors - other - - autodiff + - mooncake + - chainrules os: - ubuntu-latest - macOS-latest @@ -55,7 +56,8 @@ jobs: - symmetries - tensors - other - - autodiff + - mooncake + - chainrules os: - ubuntu-latest - macOS-latest diff --git a/Project.toml b/Project.toml index ddcbe4141..a27a95f0d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorKit" uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" -authors = ["Jutho Haegeman, Lukas Devos"] version = "0.16.0" +authors = ["Jutho Haegeman, Lukas Devos"] [deps] LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" @@ -22,8 +22,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [extensions] TensorKitAdaptExt = "Adapt" @@ -34,6 +34,7 @@ TensorKitMooncakeExt = "Mooncake" [compat] Adapt = "4" +AllocCheck = "0.2.3" Aqua = "0.6, 0.7, 0.8" ArgParse = "1.2.0" CUDA = "5.9" @@ -42,10 +43,11 @@ ChainRulesTestUtils = "1" Combinatorics = "1" FiniteDifferences = "0.12" GPUArrays = "11.3.1" +JET = "0.9, 0.10, 0.11" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.3" -Mooncake = "0.4.183" +MatrixAlgebraKit = "0.6.4" +Mooncake = "0.5" OhMyThreads = "0.8.0" Printf = "1" Random = "1" @@ -56,7 +58,7 @@ TensorKitSectors = "0.3.3" TensorOperations = "5.1" Test = "1" TestExtras = "0.2,0.3" -TupleTools = "1.1" +TupleTools = "1.5" VectorInterface = "0.4.8, 0.5" Zygote = "0.7" cuTENSOR = "2" @@ -64,6 +66,7 @@ julia = "1.10" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AllocCheck = 
"9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -72,6 +75,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -82,4 +86,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"] +test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"] diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index 3e9520cbc..6bb19dee7 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -184,7 +184,7 @@ type `add_transform!`, for additional expert-mode options that allows for additi scaling, as well as the selection of a custom backend. 
```@docs -permute(::AbstractTensorMap, ::Index2Tuple{N₁,N₂}) where {N₁,N₂} +permute(::AbstractTensorMap, ::Index2Tuple) braid(::AbstractTensorMap, ::Index2Tuple, ::IndexTuple) transpose(::AbstractTensorMap, ::Index2Tuple) repartition(::AbstractTensorMap, ::Int, ::Int) diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index b35c73f4c..7067bb280 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -1,17 +1,23 @@ module TensorKitMooncakeExt using Mooncake -using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal +using Mooncake: @zero_derivative, @is_primitive, + DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, CoDual, Dual, + arrayify, primal, tangent using TensorKit +import TensorKit as TK +using VectorInterface using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize import TensorOperations as TO -using VectorInterface: One, Zero using TupleTools - +using Random: AbstractRNG include("utility.jl") include("tangent.jl") include("linalg.jl") +include("indexmanipulations.jl") +include("vectorinterface.jl") include("tensoroperations.jl") +include("planaroperations.jl") end diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl new file mode 100644 index 000000000..76f2c126b --- /dev/null +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -0,0 +1,388 @@ +for transform in (:permute, :transpose) + add_transform! = Symbol(:add_, transform, :!) 
+ add_transform_pullback = Symbol(add_transform!, :_pullback) + @eval @is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TK.$add_transform!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, + Number, Number, Vararg{Any}, + } + ) + + @eval function Mooncake.rrule!!( + ::CoDual{typeof(TK.$add_transform!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + C_cache = copy(C) + + # if we need to compute Δa, it is faster to allocate an intermediate permuted A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Ap = if _needs_tangent(α) + Ap = $transform(A, p) + add!(C, Ap, α, β) + Ap + else + TK.$add_transform!(C, A, p, α, β, ba...) + nothing + end + + function $add_transform_pullback(::NoRData) + copy!(C, C_cache) + + # ΔA + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + TK.$add_transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) + ΔAr = NoRData() + + # Δα + Δαr = if isnothing(Ap) + NoRData() + else + inner(Ap, ΔC) + end + + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() + + return NoRData(), ΔCr, ΔAr, NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... 
+ end + + return C_ΔC, $add_transform_pullback + end +end + +@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TK.add_braid!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, IndexTuple, + Number, Number, Vararg{Any}, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TK.add_braid!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, levels_Δlevels::CoDual{<:IndexTuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + levels = primal(levels_Δlevels) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + C_cache = copy(C) + + # if we need to compute Δα, it is faster to allocate an intermediate braided A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Ap = if _needs_tangent(α) + Ap = braid(A, p, levels) + add!(C, Ap, α, β) + Ap + else + TK.add_braid!(C, A, p, levels, α, β, ba...) + nothing + end + + function add_braid!_pullback(::NoRData) + copy!(C, C_cache) + + # ΔA + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + ilevels = TupleTools.permute(levels, linearize(p)) + TK.add_braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) + ΔAr = NoRData() + + # Δα + Δαr = if isnothing(Ap) + NoRData() + else + inner(Ap, ΔC) + end + + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() + + return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... 
+ end + + return C_ΔC, add_braid!_pullback +end + +# both are needed for correctly capturing every dispatch +@is_primitive DefaultCtx ReverseMode Tuple{typeof(twist!), AbstractTensorMap, Any} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(twist!), AbstractTensorMap, Any} + +function Mooncake.rrule!!(::CoDual{typeof(twist!)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = false + inds = primal(inds_Δinds) + + # primal call + t_cache = copy(t) + twist!(t, inds; inv) + + function twist_pullback(::NoRData) + copy!(t, t_cache) + twist!(Δt, inds; inv = !inv) + return ntuple(Returns(NoRData()), 3) + end + + return t_Δt, twist_pullback + +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{@NamedTuple{inv::Bool}}, ::CoDual{typeof(twist!)}, + t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual + ) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = primal(kwargs_Δkwargs).inv + inds = primal(inds_Δinds) + + # primal call + t_cache = copy(t) + twist!(t, inds; inv) + + function twist_pullback(::NoRData) + copy!(t, t_cache) + twist!(Δt, inds; inv = !inv) + return ntuple(Returns(NoRData()), 5) + end + + return t_Δt, twist_pullback +end + +# both are needed for correctly capturing every dispatch +@is_primitive DefaultCtx ReverseMode Tuple{typeof(flip), AbstractTensorMap, Any} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(flip), AbstractTensorMap, Any} + +function Mooncake.rrule!!(::CoDual{typeof(flip)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = false + inds = primal(inds_Δinds) + + # primal call + t_flipped = flip(t, inds; inv) + t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) + _, Δt_flipped = arrayify(t_flipped_Δt_flipped) + + function flip_pullback(::NoRData) + Δt_flipflipped = flip(Δt_flipped, inds; 
inv = !inv) + add!(Δt, scalartype(Δt) <: Real ? real(Δt_flipflipped) : Δt_flipflipped) + return ntuple(Returns(NoRData()), 3) + end + + return t_flipped_Δt_flipped, flip_pullback +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{@NamedTuple{inv::Bool}}, ::CoDual{typeof(flip)}, + t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual + ) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = primal(kwargs_Δkwargs).inv + inds = primal(inds_Δinds) + + # primal call + t_flipped = flip(t, inds; inv) + t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) + _, Δt_flipped = arrayify(t_flipped_Δt_flipped) + + function flip_pullback(::NoRData) + Δt_flipflipped = flip(Δt_flipped, inds; inv = !inv) + add!(Δt, scalartype(Δt) <: Real ? real(Δt_flipflipped) : Δt_flipflipped) + return ntuple(Returns(NoRData()), 5) + end + + return t_flipped_Δt_flipped, flip_pullback +end + +for insertunit in (:insertleftunit, :insertrightunit) + insertunit_pullback = Symbol(insertunit, :_pullback) + @eval begin + # both are needed for correctly capturing every dispatch + @is_primitive DefaultCtx ReverseMode Tuple{typeof($insertunit), AbstractTensorMap, Val} + @is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof($insertunit), AbstractTensorMap, Val} + + function Mooncake.rrule!!(::CoDual{typeof($insertunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val}) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + + # tdst shares data with tsrc if <:TensorMap, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap + tsrc_cache = copy(tsrc) + tdst_Δtdst = CoDual( + $insertunit(tsrc, ival), + $insertunit(Mooncake.tangent(tsrc_Δtsrc), ival) + ) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc, ival) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function $insertunit_pullback(::NoRData) + if 
isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + add!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! + end + return ntuple(Returns(NoRData()), 3) + end + + return tdst_Δtdst, $insertunit_pullback + end + function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{<:NamedTuple}, + ::CoDual{typeof($insertunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val} + ) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + kwargs = primal(kwargs_Δkwargs) + + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap && !get(kwargs, :copy, false) + tsrc_cache = copy(tsrc) + tdst = $insertunit(tsrc, ival; kwargs...) + tdst_Δtdst = CoDual( + tdst, + $insertunit(Mooncake.tangent(tsrc_Δtsrc), ival; kwargs...) + ) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc, ival; kwargs...) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function $insertunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + add!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! 
+ end + return ntuple(Returns(NoRData()), 5) + end + + return tdst_Δtdst, $insertunit_pullback + end + end +end + + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(removeunit), AbstractTensorMap, Val} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof(removeunit), AbstractTensorMap, Val} + +function Mooncake.rrule!!(::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{Val{i}}) where {i} + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + + # tdst shares data with tsrc if <:TensorMap, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap + tsrc_cache = copy(tsrc) + tdst_Δtdst = CoDual( + removeunit(tsrc, ival), + removeunit(Mooncake.tangent(tsrc_Δtsrc), ival) + ) + else + tsrc_cache = nothing + tdst = removeunit(tsrc, ival) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function removeunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + add!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! 
+ end + return ntuple(Returns(NoRData()), 3) + end + + return tdst_Δtdst, removeunit_pullback +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{<:NamedTuple}, + ::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val} + ) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + kwargs = primal(kwargs_Δkwargs) + + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap && !get(kwargs, :copy, false) + tsrc_cache = copy(tsrc) + tdst_Δtdst = CoDual( + removeunit(tsrc, ival; kwargs...), + removeunit(Mooncake.tangent(tsrc_Δtsrc), ival) + ) + else + tsrc_cache = nothing + tdst = removeunit(tsrc, ival; kwargs...) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function removeunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + add!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! + end + return ntuple(Returns(NoRData()), 5) + end + + return tdst_Δtdst, removeunit_pullback +end diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 56533d227..8f5306ac4 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,4 +1,47 @@ -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} +# Shared +# ------ +pullback_dC!(ΔC, β) = (scale!(ΔC, conj(β)); return NoRData()) +pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? 
project_scalar(β, inner(C, ΔC)) : NoRData() + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} + +function Mooncake.rrule!!( + ::CoDual{typeof(mul!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number} + ) + (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) + α, β = primal.((α_Δα, β_Δβ)) + + # primal call + C_cache = copy(C) + AB = if _needs_tangent(α) + AB = A * B + add!(C, AB, α, β) + AB + else + mul!(C, A, B, α, β) + nothing + end + + function mul_pullback(::NoRData) + copy!(C, C_cache) + + mul!(ΔA, ΔC, B', conj(α), One()) + mul!(ΔB, A', ΔC, conj(α), One()) + ΔAr = NoRData() + ΔBr = NoRData() + Δαr = isnothing(AB) ? NoRData() : inner(AB, ΔC) + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) + + return NoRData(), ΔCr, ΔAr, ΔBr, Δαr, Δβr + end + + return C_ΔC, mul_pullback +end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real}) t, Δt = arrayify(tΔt) @@ -12,3 +55,34 @@ function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorM end return CoDual(n, Mooncake.NoFData()), norm_pullback end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(tr), AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMap}) + A, ΔA = arrayify(A_ΔA) + trace = tr(A) + + function tr_pullback(Δtrace) + for (_, b) in blocks(ΔA) + TensorKit.diagview(b) .+= Δtrace + end + return NoRData(), NoRData() + end + + return CoDual(trace, Mooncake.NoFData()), tr_pullback +end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(inv), AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorMap}) + A, ΔA = arrayify(A_ΔA) + Ainv_ΔAinv = 
Mooncake.zero_fcodual(inv(A)) + Ainv, ΔAinv = arrayify(Ainv_ΔAinv) + + function inv_pullback(::NoRData) + mul!(ΔA, Ainv' * ΔAinv, Ainv', -1, One()) + return NoRData(), NoRData() + end + + return Ainv_ΔAinv, inv_pullback +end diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl new file mode 100644 index 000000000..9633dfad6 --- /dev/null +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -0,0 +1,101 @@ +# planartrace! +# ------------ +@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TensorKit.planartrace!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Index2Tuple, + Number, Number, + Any, Any, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TensorKit.planartrace!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + backend_Δbackend::CoDual, allocator_Δallocator::CoDual + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + α, β = primal.((α_Δα, β_Δβ)) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) + + # primal call + C_cache = copy(C) + TensorKit.planartrace!(C, A, p, q, α, β, backend, allocator) + + function planartrace_pullback(::NoRData) + copy!(C, C_cache) + + ΔAr = planartrace_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend, allocator) # this typically returns NoRData() + Δαr = planartrace_pullback_Δα(ΔC, A, p, q, α, backend, allocator) + Δβr = planartrace_pullback_Δβ(ΔC, C, β) + ΔCr = planartrace_pullback_ΔC!(ΔC, β) # this typically returns NoRData() + + return NoRData(), + ΔCr, ΔAr, NoRData(), NoRData(), + Δαr, Δβr, NoRData(), NoRData() + end + + return C_ΔC, planartrace_pullback +end + +planartrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) + +# TODO: Fix planartrace pullback +# This implementation is slightly more involved than its non-planar 
counterpart +# this is because we lack a general `pAB` argument in `planarcontract`, and need +# to keep things planar along the way. +# In particular, we can't simply tensor product with multiple identities in one go +# if they aren't "contiguous", e.g. p = ((1, 4, 5), ()), q = ((2, 6), (3, 7)) +function planartrace_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend, allocator + ) + if length(q[1]) == 0 + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + TK.add_transpose!(ΔA, ΔC, pΔA, conj(α), One(), backend, allocator) + return NoRData() + end + # if length(q[1]) == 1 + # ip = invperm((p[1]..., q[2]..., p[2]..., q[1]...)) + # pdA = _repartition(ip, A) + # E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + # twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + # # pE = ((), trivtuple(TO.numind(q))) + # # pΔC = (trivtuple(TO.numind(p)), ()) + # TensorKit.planaradd!(ΔA, ΔC ⊗ E, pdA, conj(α), One(), backend, allocator) + # return NoRData() + # end + error("The reverse rule for `planartrace` is not yet implemented") +end + +function planartrace_pullback_Δα( + ΔC, A, p, q, α, backend, allocator + ) + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Tdα === NoRData && return NoRData() + + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = TO.tensoralloc_add(scalartype(A), A, p, false, Val(true), allocator) + TensorKit.planartrace!(At, A, p, q, One(), Zero(), backend, allocator) + Δα = inner(At, ΔC) + TO.tensorfree!(At, allocator) + return Δα +end + +function planartrace_pullback_Δβ(ΔC, C, β) + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Tdβ === NoRData && return NoRData() + + Δβ = inner(C, ΔC) + return Δβ +end diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 761e626f0..9154ba6f7 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -1,7 +1,219 @@ -function 
Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) - A = Mooncake.primal(A_dA) - dA_fw = Mooncake.tangent(A_dA) - data = dA_fw.data.data - dA = typeof(A)(data, A.space) - return A, dA +Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) = primal(A_dA), tangent(A_dA) + +function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap}) + Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ) + ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ) + A_ΔA = CoDual(Aᴴ', ΔAᴴ.data.parent) + A, ΔA = arrayify(A_ΔA) + return A', ΔA' +end + +# Define the tangent type of a TensorMap to be TensorMap itself. +# This has a number of benefits, but also correctly alters the +# inner product when dealing with non-abelian symmetries. +# +# Note: this implementation is technically a little lazy, since we are +# assuming that the tangent type of the underlying storage is also given +# by that same type. This should in principle work out fine, and will only +# fail for types that would be self-referential, which we choose to not support +# for now. + +Mooncake.@foldable Mooncake.tangent_type(::Type{T}, ::Type{NoRData}) where {T <: TensorMap} = T +Mooncake.@foldable Mooncake.tangent_type(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A} = + TK.tensormaptype(S, N₁, N₂, Mooncake.tangent_type(A)) + +Mooncake.@foldable Mooncake.fdata_type(::Type{T}) where {T <: TensorMap} = Mooncake.tangent_type(T) +Mooncake.@foldable Mooncake.rdata_type(::Type{T}) where {T <: TensorMap} = NoRData + +Mooncake.tangent(t::TensorMap, ::NoRData) = t +Mooncake.zero_tangent_internal(t::TensorMap, c::Mooncake.MaybeCache) = + TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t)) + +Mooncake.randn_tangent_internal(rng::AbstractRNG, p::TensorMap, c::Mooncake.MaybeCache) = + TensorMap(Mooncake.randn_tangent_internal(rng, p.data, c), space(p)) + +Mooncake.set_to_zero_internal!!(::Mooncake.SetToZeroCache, t::TensorMap) = zerovector!(t) +function Mooncake.increment!!(x::TensorMap, y::TensorMap) + data = Mooncake.increment!!(x.data, y.data) + return x.data === data 
? x : TensorMap(data, space(x)) +end +function Mooncake.increment_internal!!(c::Mooncake.IncCache, x::TensorMap, y::TensorMap) + data = Mooncake.increment_internal!!(c, x.data, y.data) + return x.data === data ? x : TensorMap(data, space(x)) +end + +Mooncake._add_to_primal_internal(c::Mooncake.MaybeCache, p::TensorMap, t::TensorMap, unsafe::Bool) = + TensorMap(Mooncake._add_to_primal_internal(c, p.data, t.data, unsafe), space(p)) +function Mooncake.tangent_to_primal_internal!!(p::TensorMap, t::TensorMap, c::Mooncake.MaybeCache) + data = Mooncake.tangent_to_primal_internal!!(p.data, t.data, c) + data === p.data || copy!(p.data, data) + return p +end +Mooncake.primal_to_tangent_internal!!(t::T, p::T, ::Mooncake.MaybeCache) where {T <: TensorMap} = copy!(t, p) + +Mooncake._dot_internal(::Mooncake.MaybeCache, t::TensorMap, s::TensorMap) = Float64(real(inner(t, s))) +Mooncake._scale_internal(::Mooncake.MaybeCache, a::Float64, t::TensorMap) = scale(t, a) + +Mooncake.TestUtils.populate_address_map_internal(m::Mooncake.TestUtils.AddressMap, primal::TensorMap, tangent::TensorMap) = + Mooncake.populate_address_map_internal(m, primal.data, tangent.data) +@inline Mooncake.TestUtils.__get_data_field(t::TensorMap, n) = getfield(t, n) + +function Mooncake.__verify_fdata_value(c::IdDict{Any, Nothing}, p::TensorMap, t::TensorMap) + space(p) == space(t) || + throw(Mooncake.InvalidFDataException(lazy"p has space $(space(p)) but t has space $(space(t))")) + return Mooncake.__verify_fdata_value(c, p.data, t.data) +end + +Mooncake.to_cr_tangent(x::TensorMap) = x + +@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val} + +# TODO: double-check if this has to include quantum dimensions for non-abelian? 
+function Mooncake.frule!!( + ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:TensorMap}, ::Dual{Val{FieldName}} + ) where {FieldName} + val = getfield(primal(t), FieldName) + + return if FieldName === 1 || FieldName === :data + dval = tangent(t).data + Dual(val, dval) + elseif FieldName === 2 || FieldName === :space + Dual(val, NoFData()) + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:TensorMap}, ::CoDual{Val{FieldName}} + ) where {FieldName} + val = getfield(primal(t), FieldName) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3)) + + return if FieldName === 1 || FieldName === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif FieldName === 2 || FieldName === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + +@is_primitive MinimalCtx Tuple{typeof(getfield), <:TensorMap, Any, Vararg{Symbol}} + +Base.@constprop :aggressive function Mooncake.frule!!( + ::Dual{typeof(getfield)}, t::Dual{<:TensorMap}, name::Dual + ) + val = getfield(primal(t), primal(name)) + + return if primal(name) === 1 || primal(name) === :data + dval = tangent(t).data + Dual(val, dval) + elseif primal(name) === 2 || primal(name) === :space + Dual(val, NoFData()) + else + throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + +Base.@constprop :aggressive function Mooncake.rrule!!( + ::CoDual{typeof(getfield)}, t::CoDual{<:TensorMap}, name::CoDual + ) + val = getfield(primal(t), primal(name)) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3)) + + return if primal(name) === 1 || primal(name) === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif primal(name) === 2 || primal(name) === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + 
throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + +Base.@constprop :aggressive function Mooncake.frule!!( + ::Dual{typeof(getfield)}, t::Dual{<:TensorMap}, name::Dual, order::Dual + ) + val = getfield(primal(t), primal(name), primal(order)) + + return if primal(name) === 1 || primal(name) === :data + dval = tangent(t).data + Dual(val, dval) + elseif primal(name) === 2 || primal(name) === :space + Dual(val, NoFData()) + else + throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + +Base.@constprop :aggressive function Mooncake.rrule!!( + ::CoDual{typeof(getfield)}, t::CoDual{<:TensorMap}, name::CoDual, order::CoDual + ) + val = getfield(primal(t), primal(name), primal(order)) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 4)) + + return if primal(name) === 1 || primal(name) === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif primal(name) === 2 || primal(name) === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + + +@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val, Val} + +# TODO: double-check if this has to include quantum dimensions for non-abelian? 
+function Mooncake.frule!!( + ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:TensorMap}, ::Dual{Val{FieldName}}, ::Dual{Val{Order}} + ) where {FieldName, Order} + val = getfield(primal(t), FieldName, Order) + + return if FieldName === 1 || FieldName === :data + dval = tangent(t).data + Dual(val, dval) + elseif FieldName === 2 || FieldName === :space + Dual(val, NoFData()) + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:TensorMap}, ::CoDual{Val{FieldName}}, ::CoDual{Val{Order}} + ) where {FieldName, Order} + val = getfield(primal(t), FieldName, Order) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 4)) + + return if FieldName === 1 || FieldName === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif FieldName === 2 || FieldName === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + + +Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(Mooncake._new_), Type{TensorMap{T, S, N₁, N₂, A}}, UndefInitializer, TensorMapSpace{S, N₁, N₂}} where {T, S, N₁, N₂, A} +@is_primitive Mooncake.MinimalCtx Tuple{typeof(Mooncake._new_), Type{TensorMap{T, S, N₁, N₂, A}}, A, TensorMapSpace{S, N₁, N₂}} where {T, S, N₁, N₂, A} + +function Mooncake.frule!!( + ::Dual{typeof(Mooncake._new_)}, ::Dual{Type{TensorMap{T, S, N₁, N₂, A}}}, data::Dual{A}, space::Dual{TensorMapSpace{S, N₁, N₂}} + ) where {T, S, N₁, N₂, A} + t = Mooncake._new_(TensorMap{T, S, N₁, N₂, A}, primal(data), primal(space)) + dt = Mooncake._new_(TensorMap{T, S, N₁, N₂, A}, tangent(data), primal(space)) + return Dual(t, dt) +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Mooncake._new_)}, ::CoDual{Type{TensorMap{T, S, N₁, N₂, A}}}, data::CoDual{A}, space::CoDual{TensorMapSpace{S, N₁, N₂}} + ) where {T, S, N₁, N₂, A} + return 
Mooncake.zero_fcodual(Mooncake._new_(TensorMap{T, S, N₁, N₂, A}, primal(data), primal(space))), + Returns(ntuple(Returns(NoRData()), 4)) end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index d663a3281..c4468ef65 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -1,73 +1,70 @@ -Mooncake.@is_primitive( +# tensorcontract! +# --------------- +@is_primitive( DefaultCtx, ReverseMode, Tuple{ - typeof(TO.tensorcontract!), + typeof(TensorKit.blas_contract!), AbstractTensorMap, - AbstractTensorMap, Index2Tuple, Bool, - AbstractTensorMap, Index2Tuple, Bool, + AbstractTensorMap, Index2Tuple, + AbstractTensorMap, Index2Tuple, Index2Tuple, Number, Number, - Vararg{Any}, + Any, Any, } ) function Mooncake.rrule!!( - ::CoDual{typeof(TO.tensorcontract!)}, + ::CoDual{typeof(TensorKit.blas_contract!)}, C_ΔC::CoDual{<:AbstractTensorMap}, - A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, - B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool}, + A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, + B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, pAB_ΔpAB::CoDual{<:Index2Tuple}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, - ba_Δba::CoDual..., + backend_Δbackend::CoDual, allocator_Δallocator::CoDual ) # prepare arguments (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB)) - conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB)) α, β = primal.((α_Δα, β_Δβ)) - ba = primal.(ba_Δba) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) # primal call C_cache = copy(C) - TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) 
+ TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) - function tensorcontract_pullback(::NoRData) + function blas_contract_pullback(::NoRData) copy!(C, C_cache) - ΔCr = tensorcontract_pullback_ΔC!(ΔC, β) - ΔAr = tensorcontract_pullback_ΔA!( - ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ΔAr = blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) # this typically returns NoRData() + ΔBr = blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator + ) # this typically returns NoRData() + Δαr = blas_contract_pullback_Δα( + ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - ΔBr = tensorcontract_pullback_ΔB!( - ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... - ) - Δαr = tensorcontract_pullback_Δα( - ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... - ) - Δβr = tensorcontract_pullback_Δβ(ΔC, C, β) + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, - ΔAr, NoRData(), NoRData(), - ΔBr, NoRData(), NoRData(), + ΔAr, NoRData(), + ΔBr, NoRData(), NoRData(), Δαr, Δβr, - map(ba_ -> NoRData(), ba)... + NoRData(), NoRData() end - return C_ΔC, tensorcontract_pullback + return C_ΔC, blas_contract_pullback end -tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) - -function tensorcontract_pullback_ΔA!( - ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) pΔC = _repartition(ipAB, TO.numout(pA)) ipA = _repartition(invperm(linearize(pA)), A) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB tB = twist( B, @@ -79,24 +76,22 @@ function tensorcontract_pullback_ΔA!( TO.tensorcontract!( ΔA, - ΔC, pΔC, conjΔC, - tB, reverse(pB), conjB′, + ΔC, pΔC, false, + tB, reverse(pB), true, ipA, - conjA ? α : conj(α), Zero(), - ba... 
+ conj(α), One(), + backend, allocator ) return NoRData() end -function tensorcontract_pullback_ΔB!( - ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) pΔC = _repartition(ipAB, TO.numout(pA)) ipB = _repartition(invperm(linearize(pB)), B) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA tA = twist( A, @@ -108,30 +103,97 @@ function tensorcontract_pullback_ΔB!( TO.tensorcontract!( ΔB, - tA, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, + tA, reverse(pA), true, + ΔC, pΔC, false, ipB, - conjB ? α : conj(α), Zero(), ba... + conj(α), One(), backend, allocator ) return NoRData() end -function tensorcontract_pullback_Δα( - ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_Δα( + ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) - Tdα === NoRData && return NoRData() + _needs_tangent(α) || return NoRData() - AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator) Δα = inner(AB, ΔC) - return Mooncake._rdata(Δα) + return project_scalar(α, Δα) +end + +# tensortrace! 
+# ------------ +@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TensorKit.trace_permute!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Index2Tuple, + Number, Number, + Any, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TensorKit.trace_permute!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + backend_Δbackend::CoDual + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + α, β = primal.((α_Δα, β_Δβ)) + backend = primal(backend_Δbackend) + + # primal call + C_cache = copy(C) + TensorKit.trace_permute!(C, A, p, q, α, β, backend) + + function trace_permute_pullback(::NoRData) + copy!(C, C_cache) + + ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) # this typically returns NoRData() + Δαr = trace_permute_pullback_Δα(ΔC, A, p, q, α, backend) + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() + + return NoRData(), + ΔCr, ΔAr, NoRData(), NoRData(), + Δαr, Δβr, NoRData() + end + + return C_ΔC, trace_permute_pullback +end + +function trace_permute_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TO.tensorproduct!( + ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend + ) + return NoRData() end -function tensorcontract_pullback_Δβ(ΔC, C, β) - Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) - Tdβ === NoRData && return NoRData() +function trace_permute_pullback_Δα( + ΔC, A, p, q, α, backend + ) + _needs_tangent(α) || return NoRData() - Δβ = inner(C, ΔC) - return Mooncake._rdata(Δβ) + # TODO: this 
result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = TO.tensortrace(A, p, q, false, One(), backend) + Δα = inner(At, ΔC) + return Δα end diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index ca2c79b54..3f50bffa0 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -1,7 +1,15 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) -_needs_tangent(::Type{<:Number}) = true -_needs_tangent(::Type{<:Integer}) = false -_needs_tangent(::Type{<:Union{One, Zero}}) = false +_needs_tangent(::Type{T}) where {T <: Number} = + Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) # IndexTuple utility # ------------------ @@ -25,4 +33,16 @@ end # Ignore derivatives # ------------------ + +# A VectorSpace has no meaningful notion of a vector space (tangent space) +Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent +Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent + @zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} + +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.braid), HomSpace, Index2Tuple, IndexTuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.compose), HomSpace, HomSpace} +@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract), HomSpace, Index2Tuple, Bool, HomSpace, Index2Tuple, Bool, Index2Tuple} diff --git 
a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl new file mode 100644 index 000000000..625aadd61 --- /dev/null +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -0,0 +1,93 @@ +@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number} + +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + scale!(ΔC, conj(α)) + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(C, ΔC) + return NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} + +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, A, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + zerovector!(ΔC) + scale!(ΔA, conj(α)) + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? 
NoRData() : inner(C, ΔC) + return NoRData(), NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} + +function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + β = primal(β_Δβ) + + # primal call + C_cache = copy(C) + add!(C, A, α, β) + + function add_pullback(::NoRData) + copy!(C, C_cache) + scale!(ΔC, conj(β)) + scale!(ΔA, conj(α)) + + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(A, ΔC) + TΔβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Δβr = TΔβ === NoRData ? NoRData() : inner(C, ΔC) + + return NoRData(), NoRData(), NoRData(), Δαr, Δβr + end + + return C_ΔC, add_pullback +end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) + # prepare arguments + A, ΔA = arrayify(A_ΔA) + B, ΔB = arrayify(B_ΔB) + + # primal call + s = inner(A, B) + + function inner_pullback(Δs) + scale!(ΔA, B, conj(Δs)) + scale!(ΔB, A, Δs) + return NoRData(), NoRData(), NoRData() + end + + return CoDual(s, NoFData()), inner_pullback +end diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 1564b1b67..3cc6a16b6 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -692,7 +692,7 @@ function planar_trace( k += 1 end end - k > N₃ && throw(ArgumentError("Not a planar trace")) + k > N₃ && throw(ArgumentError(lazy"not a planar trace: ($q1, $q2)")) q1′ = let i = i, j = j map(l -> (l - (l > i) - (l > j)), TupleTools.deleteat(q1, k)) diff --git 
a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index a06f4431d..982a848f8 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -69,8 +69,7 @@ function planartrace!( α::Number, β::Number, backend, allocator ) - (S = spacetype(C)) == spacetype(A) || - throw(SpaceMismatch("incompatible spacetypes")) + S = check_spacetype(C, A) if BraidingStyle(sectortype(S)) == Bosonic() return trace_permute!(C, A, (p₁, p₂), (q₁, q₂), α, β, backend) end diff --git a/src/spaces/vectorspaces.jl b/src/spaces/vectorspaces.jl index 5ef171c31..867d18959 100644 --- a/src/spaces/vectorspaces.jl +++ b/src/spaces/vectorspaces.jl @@ -376,19 +376,31 @@ abstract type CompositeSpace{S <: ElementarySpace} <: VectorSpace end InnerProductStyle(::Type{<:CompositeSpace{S}}) where {S} = InnerProductStyle(S) """ - spacetype(a) -> Type{S<:IndexSpace} - spacetype(::Type) -> Type{S<:IndexSpace} + spacetype(a) -> Type{S <: IndexSpace} + spacetype(::Type) -> Type{S <: IndexSpace} -Return the type of the elementary space `S` of object `a` (e.g. a tensor). Also works in -type domain. +Return the type of the elementary space `S` of object `a` (e.g. a tensor). +Also works in type domain. """ spacetype(x) = spacetype(typeof(x)) -function spacetype(::Type{T}) where {T} - throw(MethodError(spacetype, (T,))) -end +spacetype(::Type{T}) where {T} = throw(MethodError(spacetype, (T,))) spacetype(::Type{E}) where {E <: ElementarySpace} = E spacetype(::Type{S}) where {E, S <: CompositeSpace{E}} = E +""" + check_spacetype(Bool, x, y, z...) -> Bool + check_spacetype(x, y, z...) -> Type{<:IndexSpace} + +Check whether the given inputs have matching spacetypes. +The first signature returns a `Bool` to indicate whether all spacetypes are equal, +while the second will return the spacetype if all types are equal, and throw a [`SpaceMismatch`](@ref) if not. +""" +check_spacetype(::Type{Bool}, x, y, z...) 
= _allequal(spacetype, (x, y, z...)) +@noinline function check_spacetype(x, y, z...) + check_spacetype(Bool, x, y, z...) || throw(SpaceMismatch("incompatible space types")) + return spacetype(x) +end + # make ElementarySpace instances behave similar to ProductSpace instances blocksectors(V::ElementarySpace) = collect(sectors(V)) blockdim(V::ElementarySpace, c::Sector) = dim(V, c) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 4b3149a4c..0f580c17c 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -103,6 +103,11 @@ similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:Ab similarstoragetype(::Type{T}) where {T <: Number} = Vector{T} +# helper function to determine the scalartype taking into account that recouplings might happen +recoupled_scalartype(::Type{T}, ::Type{I}) where {T <: Number, I <: Sector} = isreal(I) ? T : complex(T) +recoupled_scalartype(t::AbstractTensorMap) = recoupled_scalartype(typeof(t)) +recoupled_scalartype(::Type{T}) where {T <: AbstractTensorMap} = recoupled_scalartype(scalartype(T), sectortype(T)) + # tensor characteristics: space and index information #----------------------------------------------------- """ diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index c191bb6b5..f25a81961 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -260,11 +260,10 @@ function TO.tensorcontract_type( B::DiagonalTensorMap, ::Index2Tuple{1, 1}, ::Bool, ::Index2Tuple{1, 1} ) - M = similarstoragetype(A, TC) - M == similarstoragetype(B, TC) || - throw(ArgumentError("incompatible storage types:\n$(M) ≠ $(similarstoragetype(B, TC))")) - spacetype(A) == spacetype(B) || throw(SpaceMismatch("incompatible space types")) - return DiagonalTensorMap{TC, spacetype(A), M} + S = check_spacetype(A, B) + TC′ = recoupled_scalartype(TC, sectortype(S)) + M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′)) + return 
DiagonalTensorMap{TC, S, M} end function TO.tensoralloc( @@ -291,6 +290,15 @@ function Base.zero(d::DiagonalTensorMap) return DiagonalTensorMap(zero.(d.data), d.domain) end +function compose_dest(A::DiagonalTensorMap, B::DiagonalTensorMap) + S = check_spacetype(A, B) + TC = TO.promote_contract(scalartype(A), scalartype(B), One) + M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC)) + TTC = DiagonalTensorMap{TC, S, M} + structure = codomain(A) ← domain(B) + return TO.tensoralloc(TTC, structure, Val(false)) +end + function LinearAlgebra.mul!( dC::DiagonalTensorMap, dA::DiagonalTensorMap, dB::DiagonalTensorMap, α::Number, β::Number ) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 0c918e85e..548374916 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -15,7 +15,7 @@ Return a new tensor that is isomorphic to `t` but where the arrows on the indice """ function flip(t::AbstractTensorMap, I; inv::Bool = false) P = flip(space(t), I) - t′ = similar(t, P) + t′ = similar(t, recoupled_scalartype(t), P) for (f₁, f₂) in fusiontrees(t) (f₁′, f₂′), factor = only(flip(f₁, f₂, I; inv)) scale!(t′[f₁′, f₂′], t[f₁, f₂], factor) @@ -39,53 +39,35 @@ See [`permute`](@ref) for creating a new tensor and [`add_permute!`](@ref) for a end """ - permute(tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple; - copy::Bool=false) - -> tdst::TensorMap + permute(tsrc::AbstractTensorMap, (p₁, p₂)::Index2Tuple; copy::Bool = false) -> tdst::TensorMap Return tensor `tdst` obtained by permuting the indices of `tsrc`. The codomain and domain of `tdst` correspond to the indices in `p₁` and `p₂` of `tsrc` respectively. -If `copy=false`, `tdst` might share data with `tsrc` whenever possible. Otherwise, a copy is always made. +If `copy = false`, `tdst` might share data with `tsrc` whenever possible. +Otherwise, a copy is always made. 
To permute into an existing destination, see [permute!](@ref) and [`add_permute!`](@ref) """ -function permute( - t::AbstractTensorMap, (p₁, p₂)::Index2Tuple{N₁, N₂}; copy::Bool = false - ) where {N₁, N₂} - space′ = permute(space(t), (p₁, p₂)) - # share data if possible - if !copy && p₁ === codomainind(t) && p₂ === domainind(t) - return t - end - # general case - @inbounds begin - return permute!(similar(t, space′), t, (p₁, p₂)) - end -end -function permute(t::TensorMap, (p₁, p₂)::Index2Tuple{N₁, N₂}; copy::Bool = false) where {N₁, N₂} - space′ = permute(space(t), (p₁, p₂)) +function permute(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; copy::Bool = false) # share data if possible if !copy if p₁ === codomainind(t) && p₂ === domainind(t) return t - elseif has_shared_permute(t, (p₁, p₂)) - return TensorMap(t.data, space′) + elseif t isa TensorMap && has_shared_permute(t, (p₁, p₂)) + return TensorMap(t.data, permute(space(t), (p₁, p₂))) end end + tdst = TO.tensoralloc_add(scalartype(t), t, (p₁, p₂), false, Val(false)) # general case - @inbounds begin - return permute!(similar(t, space′), t, (p₁, p₂)) - end + return @inbounds permute!(tdst, t, (p₁, p₂)) end function permute(t::AdjointTensorMap, (p₁, p₂)::Index2Tuple; copy::Bool = false) p₁′ = adjointtensorindices(t, p₂) p₂′ = adjointtensorindices(t, p₁) return adjoint(permute(adjoint(t), (p₁′, p₂′); copy)) end -function permute(t::AbstractTensorMap, p::IndexTuple; copy::Bool = false) - return permute(t, (p, ()); copy) -end +permute(t::AbstractTensorMap, p::IndexTuple; copy::Bool = false) = permute(t, (p, ()); copy) function has_shared_permute(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple) return (p₁ === codomainind(t) && p₂ === domainind(t)) @@ -145,18 +127,14 @@ To braid into an existing destination, see [braid!](@ref) and [`add_braid!`](@re function braid( t::AbstractTensorMap, (p₁, p₂)::Index2Tuple, levels::IndexTuple; copy::Bool = false ) - @assert length(levels) == numind(t) - if BraidingStyle(sectortype(t)) isa 
SymmetricBraiding - return permute(t, (p₁, p₂); copy = copy) - end - if !copy && p₁ == codomainind(t) && p₂ == domainind(t) - return t - end + length(levels) == numind(t) || throw(ArgumentError("invalid levels")) + + BraidingStyle(sectortype(t)) isa SymmetricBraiding && return permute(t, (p₁, p₂); copy = copy) + (!copy && p₁ == codomainind(t) && p₂ == domainind(t)) && return t + # general case - space′ = permute(space(t), (p₁, p₂)) - @inbounds begin - return braid!(similar(t, space′), t, (p₁, p₂), levels) - end + tdst = TO.tensoralloc_add(scalartype(t), t, (p₁, p₂), false, Val(false)) + return @inbounds braid!(tdst, t, (p₁, p₂), levels) end # TODO: braid for `AdjointTensorMap`; think about how to map the `levels` argument. @@ -199,17 +177,12 @@ function LinearAlgebra.transpose( t::AbstractTensorMap, (p₁, p₂)::Index2Tuple = _transpose_indices(t); copy::Bool = false ) - if sectortype(t) === Trivial - return permute(t, (p₁, p₂); copy = copy) - end - if !copy && p₁ == codomainind(t) && p₂ == domainind(t) - return t - end + sectortype(t) === Trivial && return permute(t, (p₁, p₂); copy) + (!copy && p₁ == codomainind(t) && p₂ == domainind(t)) && return t + # general case - space′ = permute(space(t), (p₁, p₂)) - @inbounds begin - return transpose!(similar(t, space′), t, (p₁, p₂)) - end + tdst = TO.tensoralloc_add(scalartype(t), t, (p₁, p₂), false, Val(false)) + return @inbounds transpose!(tdst, t, (p₁, p₂)) end function LinearAlgebra.transpose( @@ -295,7 +268,13 @@ function twist!(t::AbstractTensorMap, inds; inv::Bool = false) msg = "Can't twist indices $inds of a tensor with only $(numind(t)) indices." 
throw(ArgumentError(msg)) end + (scalartype(t) <: Real && !(sectorscalartype(sectortype(t)) <: Real)) && + throw(ArgumentError("Can't in-place twist a real tensor with complex sector type")) has_shared_twist(t, inds) && return t + + (scalartype(t) <: Real && !(sectorscalartype(sectortype(t)) <: Real)) && + throw(ArgumentError("No in-place `twist!` for a real tensor with complex sector type")) + N₁ = numout(t) for (f₁, f₂) in fusiontrees(t) θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), inds) @@ -317,7 +296,9 @@ See [`twist!`](@ref) for storing the result in place. """ function twist(t::AbstractTensorMap, inds; inv::Bool = false, copy::Bool = false) !copy && has_shared_twist(t, inds) && return t - return twist!(Base.copy(t), inds; inv) + tdst = TO.tensoralloc_add(scalartype(t), t, (codomainind(t), domainind(t)), false, Val(false)) + copy!(tdst, t) + return twist!(tdst, inds; inv) end # Methods which change the number of indices, implement using `Val(i)` for type inference @@ -413,7 +394,7 @@ end spacecheck_transform(f, tdst::AbstractTensorMap, tsrc::AbstractTensorMap, args...) = spacecheck_transform(f, space(tdst), space(tsrc), args...) @noinline function spacecheck_transform(f, Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple) - spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types")) + check_spacetype(Vdst, Vsrc) f(Vsrc, p) == Vdst || throw( SpaceMismatch( @@ -427,7 +408,7 @@ spacecheck_transform(f, tdst::AbstractTensorMap, tsrc::AbstractTensorMap, args.. 
return nothing end @noinline function spacecheck_transform(::typeof(braid), Vdst::TensorMapSpace, Vsrc::TensorMapSpace, p::Index2Tuple, levels::IndexTuple) - spacetype(Vdst) == spacetype(Vsrc) || throw(SectorMismatch("incompatible sector types")) + check_spacetype(Vdst, Vsrc) braid(Vsrc, p, levels) == Vdst || throw( SpaceMismatch( diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index d174bef71..e28ec8153 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -19,17 +19,15 @@ LinearAlgebra.normalize!(t::AbstractTensorMap, p::Real = 2) = scale!(t, inv(norm LinearAlgebra.normalize(t::AbstractTensorMap, p::Real = 2) = scale(t, inv(norm(t, p))) # destination allocation for matrix multiplication +# note that we don't fall back to `tensoralloc_contract` since that needs to account for +# permutations, which might require complex scalartypes even if the inputs are real. function compose_dest(A::AbstractTensorMap, B::AbstractTensorMap) + S = check_spacetype(A, B) TC = TO.promote_contract(scalartype(A), scalartype(B), One) - pA = (codomainind(A), domainind(A)) - pB = (codomainind(B), domainind(B)) - pAB = (codomainind(A), ntuple(i -> i + numout(A), numin(B))) - return TO.tensoralloc_contract( - TC, - A, pA, false, - B, pB, false, - pAB, Val(false) - ) + M = promote_storagetype(similarstoragetype(A, TC), similarstoragetype(B, TC)) + TTC = tensormaptype(S, numout(A), numin(B), M) + structure = codomain(A) ← domain(B) + return TO.tensoralloc(TTC, structure, Val(false)) end """ @@ -538,8 +536,7 @@ absorb(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) = absorb!(copy(tdst), t function absorb!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap) numin(tdst) == numin(tsrc) && numout(tdst) == numout(tsrc) || throw(DimensionError("Incompatible number of indices for source and destination")) - S = spacetype(tdst) - S == spacetype(tsrc) || throw(SpaceMismatch("incompatible spacetypes")) + S = check_spacetype(tdst, tsrc) dom = mapreduce(infimum, ⊗, domain(tdst), 
domain(tsrc); init = one(S)) cod = mapreduce(infimum, ⊗, codomain(tdst), codomain(tsrc); init = one(S)) for (f1, f2) in fusiontrees(cod ← dom) @@ -561,7 +558,7 @@ new `TensorMap` instance whose codomain is `codomain(t1) ⊗ codomain(t2)` and w is `domain(t1) ⊗ domain(t2)`. """ function ⊗(A::AbstractTensorMap, B::AbstractTensorMap) - (S = spacetype(A)) === spacetype(B) || throw(SpaceMismatch("incompatible space types")) + check_spacetype(A, B) # allocate destination with correct scalartype pA = ((codomainind(A)..., domainind(A)...), ()) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 526f4e489..6e6cce626 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -105,6 +105,11 @@ TensorMapWithStorage{T, A}(::UndefInitializer, codomain::TensorSpace, domain::Te TensorMapWithStorage{T, A}(undef, codomain ← domain) TensorWithStorage{T, A}(::UndefInitializer, V::TensorSpace) where {T, A} = TensorMapWithStorage{T, A}(undef, V ← one(V)) +# Utility constructors +# -------------------- +TensorMap(t::TensorMap) = copy(t) + + # raw data constructors # --------------------- # - dispatch starts with TensorMap{T}(::DenseVector{T}, ...) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 6b01d1b2c..bc5d4beeb 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -54,8 +54,7 @@ end function TO.tensoradd_type( TC, A::AbstractTensorMap, ::Index2Tuple{N₁, N₂}, ::Bool ) where {N₁, N₂} - I = sectortype(A) - M = similarstoragetype(A, sectorscalartype(I) <: Real ? 
TC : complex(TC)) + M = similarstoragetype(A, recoupled_scalartype(TC, sectortype(A))) return tensormaptype(spacetype(A), N₁, N₂, M) end @@ -103,7 +102,7 @@ end VB::TensorMapSpace, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple ) - spacetype(VC) == spacetype(VA) == spacetype(VB) || throw(SectorMismatch("incompatible sector types")) + check_spacetype(VC, VA, VB) TO.tensorcontract(VA, pA, conjA, VB, pB, conjB, pAB) == VC || throw( SpaceMismatch( @@ -153,11 +152,10 @@ function TO.tensorcontract_type( B::AbstractTensorMap, ::Index2Tuple, ::Bool, ::Index2Tuple{N₁, N₂} ) where {N₁, N₂} - spacetype(A) == spacetype(B) || throw(SpaceMismatch("incompatible space types")) - I = sectortype(A) - TC′ = isreal(I) ? TC : complex(TC) + S = check_spacetype(A, B) + TC′ = recoupled_scalartype(TC, sectortype(S)) M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′)) - return tensormaptype(spacetype(A), N₁, N₂, M) + return tensormaptype(S, N₁, N₂, M) end # TODO: handle actual promotion rule system @@ -213,8 +211,7 @@ function trace_permute!( backend = TO.DefaultBackend() ) # some input checks - (S = spacetype(tdst)) == spacetype(tsrc) || - throw(SpaceMismatch("incompatible spacetypes")) + S = check_spacetype(tdst, tsrc) if !(BraidingStyle(sectortype(S)) isa SymmetricBraiding) throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead")) end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl deleted file mode 100644 index 1cd74fa27..000000000 --- a/test/autodiff/mooncake.jl +++ /dev/null @@ -1,117 +0,0 @@ -using Test, TestExtras -using TensorKit -using TensorOperations -using Mooncake -using Random - -mode = Mooncake.ReverseMode -rng = Random.default_rng() -is_primitive = false - -function randindextuple(N::Int, k::Int = rand(0:N)) - @assert 0 ≤ k ≤ N - _p = randperm(N) - return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) -end - -const _repartition = @static if isdefined(Base, :get_extension) - 
Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition -else - TensorKit.TensorKitMooncakeExt._repartition -end - -spacelist = ( - (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), - ( - Vect[Z2Irrep](0 => 1, 1 => 1), - Vect[Z2Irrep](0 => 1, 1 => 2)', - Vect[Z2Irrep](0 => 2, 1 => 2)', - Vect[Z2Irrep](0 => 2, 1 => 3), - Vect[Z2Irrep](0 => 2, 1 => 2), - ), - ( - Vect[FermionParity](0 => 1, 1 => 1), - Vect[FermionParity](0 => 1, 1 => 2)', - Vect[FermionParity](0 => 2, 1 => 1)', - Vect[FermionParity](0 => 2, 1 => 3), - Vect[FermionParity](0 => 2, 1 => 2), - ), - ( - Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), - Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), - Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', - Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), - Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', - ), - ( - Vect[SU2Irrep](0 => 2, 1 // 2 => 1), - Vect[SU2Irrep](0 => 1, 1 => 1), - Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', - Vect[SU2Irrep](1 // 2 => 2), - Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', - ), - ( - Vect[FibonacciAnyon](:I => 2, :τ => 1), - Vect[FibonacciAnyon](:I => 1, :τ => 2)', - Vect[FibonacciAnyon](:I => 2, :τ => 2)', - Vect[FibonacciAnyon](:I => 2, :τ => 3), - Vect[FibonacciAnyon](:I => 2, :τ => 2), - ), -) - -for V in spacelist - I = sectortype(eltype(V)) - Istr = TensorKit.type_repr(I) - - symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding - println("---------------------------------------") - println("Mooncake with symmetry: $Istr") - println("---------------------------------------") - eltypes = (Float64,) # no complex support yet - symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes - atol = precision(T) - rtol = precision(T) - - @timedtestset "tensorcontract!" begin - for _ in 1:5 - d = 0 - local V1, V2, V3 - # retry a couple times to make sure there are at least some nonzero elements - for _ in 1:10 - k1 = rand(0:3) - k2 = rand(0:2) - k3 = rand(0:2) - V1 = prod(v -> rand(Bool) ? 
v' : v, rand(V, k1); init = one(V[1])) - V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) - V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) - d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) - d > 0 && break - end - ipA = randindextuple(length(V1) + length(V2)) - pA = _repartition(invperm(linearize(ipA)), length(V1)) - ipB = randindextuple(length(V2) + length(V3)) - pB = _repartition(invperm(linearize(ipB)), length(V2)) - pAB = randindextuple(length(V1) + length(V3)) - - α = randn(T) - β = randn(T) - V2_conj = prod(conj, V2; init = one(V[1])) - - for conjA in (false, true), conjB in (false, true) - A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA)) - B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB)) - C = randn!( - TensorOperations.tensoralloc_contract( - T, A, pA, conjA, B, pB, conjB, pAB, Val(false) - ) - ) - Mooncake.TestUtils.test_rule( - rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β; - atol, rtol, mode, is_primitive - ) - - end - end - end - end -end diff --git a/test/autodiff/chainrules.jl b/test/chainrules/chainrules.jl similarity index 100% rename from test/autodiff/chainrules.jl rename to test/chainrules/chainrules.jl diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl new file mode 100644 index 000000000..614439b23 --- /dev/null +++ b/test/mooncake/indexmanipulations.jl @@ -0,0 +1,137 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 
1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Mooncake - Index Manipulations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + + symmetricbraiding && @timedtestset "add_permute!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randindextuple(numind(A)) + C = randn!(permute(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_permute!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_transpose!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_braid!" 
begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + levels = tuple(randperm(numind(A))...) + C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_braid!, C, A, p, levels, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "flip_n_twist!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + + if !(T <: Real && !(sectorscalartype(sectortype(A)) <: Real)) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) + end + + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), flip, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, [1, 3]; atol, rtol, mode) + end + + @timedtestset "insert and remove units" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + + for insertunit in (insertleftunit, insertrightunit) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(4); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A', Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), insertunit, A, Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A, Val(3); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; 
copy = false, dual = true, conj = true), insertunit, A', Val(3); atol, rtol, mode) + end + + for i in 1:4 + B = insertleftunit(A, i; dual = rand(Bool)) + Mooncake.TestUtils.test_rule(rng, removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), removeunit, B, Val(i); atol, rtol, mode) + end + end +end diff --git a/test/mooncake/linalg.jl b/test/mooncake/linalg.jl new file mode 100644 index 000000000..ead21f7a1 --- /dev/null +++ b/test/mooncake/linalg.jl @@ -0,0 +1,79 @@ +using Test, TestExtras +using TensorKit +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Mooncake - LinearAlgebra: 
$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B; atol, rtol, mode, is_primitive = false) + + Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol, mode) + + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + + Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, inv, D1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inv, D2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inv, D3; atol, rtol, mode) +end diff --git a/test/mooncake/planaroperations.jl b/test/mooncake/planaroperations.jl new file mode 100644 index 000000000..98a9afe22 --- /dev/null +++ b/test/mooncake/planaroperations.jl @@ -0,0 +1,132 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup +using .TestSetup: _repartition + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + 
Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Mooncake - PlanarOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + @timedtestset "planarcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3, k1, k2, k3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 1 && break + end + k′ = rand(0:(k1 + k2)) + pA = randcircshift(k′, k1 + k2 - k′, k1) + ipA = _repartition(invperm(linearize(pA)), k′) + k′ = rand(0:(k2 + k3)) + pB = randcircshift(k′, k2 + k3 - k′, k2) + ipB = _repartition(invperm(linearize(pB)), k′) + # TODO: primal value already is broken for this? 
+ # pAB = randcircshift(k1, k3) + pAB = _repartition(tuple((1:(k1 + k3))...), k1) + + α = randn(T) + β = randn(T) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, One(), Zero(); + atol, rtol, mode, is_primitive = false + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, α, β; + atol, rtol, mode, is_primitive = false + ) + end + end + + # TODO: currently broken + # @timedtestset "planartrace!" begin + # for _ in 1:5 + # k1 = rand(0:2) + # k2 = rand(0:1) + # V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + # V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + # V3 = prod(x -> x ⊗ x', V2[1:k2]; init = one(V[1])) + # V4 = prod(x -> x ⊗ x', V2[(k2 + 1):end]; init = one(V[1])) + # + # k′ = rand(0:(k1 + 2k2)) + # (_p, _q) = randcircshift(k′, k1 + 2k2 - k′, k1) + # p = _repartition(_p, rand(0:k1)) + # q = (tuple(_q[1:2:end]...), tuple(_q[2:2:end]...)) + # ip = _repartition(invperm(linearize((_p, _q))), k′) + # A = randn(T, permute(prod(V1) ⊗ V3 ← V4, ip)) + # + # α = randn(T) + # β = randn(T) + # C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + # Mooncake.TestUtils.test_rule( + # rng, TensorKit.planartrace!, + # C, A, p, q, α, β, + # TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + # atol, rtol, mode + # ) + # end + # end +end diff --git a/test/mooncake/tangent.jl b/test/mooncake/tangent.jl new file mode 100644 index 000000000..5b001fc51 --- /dev/null +++ b/test/mooncake/tangent.jl @@ -0,0 +1,58 @@ +using Test, TestExtras +using TensorKit +using Mooncake +using Random +using JET, AllocCheck + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup +using .TestSetup: _repartition + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist 
= ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +# only run on Linux since allocation tests are broken on other versions +Sys.islinux() && @timedtestset "Mooncake - Tangent type: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + Mooncake.TestUtils.test_data(rng, A) +end diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl new file mode 100644 index 000000000..6802e4a73 --- /dev/null +++ b/test/mooncake/tensoroperations.jl @@ -0,0 +1,134 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: One, Zero +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + 
Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Mooncake - TensorOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + + symmetricbraiding && @timedtestset "tensorcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? 
v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init = one(V[1])) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, One(), Zero(), + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, α, β, + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + T <: Complex && Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, real(α), real(β), + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + end + end + + symmetricbraiding && @timedtestset "trace_permute!" begin + for _ in 1:5 + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? 
v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + Mooncake.TestUtils.test_rule( + rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend(); + atol, rtol, mode + ) + end + end +end diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl new file mode 100644 index 000000000..d43f7014d --- /dev/null +++ b/test/mooncake/vectorinterface.jl @@ -0,0 +1,75 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + 
Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +@timedtestset "Mooncake - VectorInterface: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3b0bfe8b0..8f58d7dc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,7 +57,7 @@ istestfile(fn) = endswith(fn, ".jl") && !contains(fn, "setup") # somehow AD tests are unreasonably slow on Apple CI # and ChainRulesTestUtils doesn't like prereleases - if group == "autodiff" + if group == "chainrules" Sys.isapple() && get(ENV, "CI", "false") == "true" && continue isempty(VERSION.prerelease) || continue end diff --git a/test/setup.jl b/test/setup.jl index 5c8516eb9..c99fbdad1 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -1,5 +1,7 @@ module TestSetup +export randindextuple, randcircshift, _repartition, trivtuple +export default_tol export smallset, randsector, hasfusiontensor, force_planar export 
random_fusion export sectorlist @@ -10,10 +12,52 @@ using Random using Test: @test using TensorKit using TensorKit: ℙ, PlanarTrivial +using TensorOperations: IndexTuple, Index2Tuple using Base.Iterators: take, product +using TupleTools Random.seed!(123456) +# IndexTuple utility +# ------------------ +function randindextuple(N::Int, k::Int = rand(0:N)) + @assert 0 ≤ k ≤ N + _p = randperm(N) + return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) +end +function randcircshift(N₁::Int, N₂::Int, k::Int = rand(0:(N₁ + N₂))) + N = N₁ + N₂ + @assert 0 ≤ k ≤ N + p = TupleTools.vcat(ntuple(identity, N₁), reverse(ntuple(identity, N₂) .+ N₁)) + n = rand(0:N) + _p = TupleTools.circshift(p, n) + return (tuple(_p[1:k]...), reverse(tuple(_p[(k + 1):end]...))) +end + +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Float32 and finite differences don't mix well +default_tol(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2 +default_tol(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5 + +# Sector utility +# -------------- smallset(::Type{I}) where {I <: Sector} = take(values(I), 5) function smallset(::Type{ProductSector{Tuple{I1, I2}}}) where {I1, I2} iter = product(smallset(I1), smallset(I2))