From 8df520ec05c2a7609dd4e942bb962cc0884efc6b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 8 Jan 2026 15:37:26 +0100 Subject: [PATCH 01/30] try to make truncation GPU-friendly --- src/factorizations/truncation.jl | 49 +++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index b9d060fec..3bc1ebf48 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -191,6 +191,23 @@ function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false) return values_sorted, perms end +function _findtruncvalue_order(values::SectorVector, n::Int; by = identity, rev::Bool = false) + I = sectortype(values) + p = sortperm(parent(values); by, rev) + + if FusionStyle(I) isa UniqueFusion # dimensions are all 1 + return n <= 0 ? nothing : p[min(n, length(p))] + else + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) + end + cumulative_dim = cumsum(Base.permute!(parent(dims), p)) + k = findlast(<=(n), cumulative_dim) + return isnothing(k) ? 
k : p[k] + end +end + # findtruncated # ------------- # Generic fallback @@ -202,25 +219,25 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation) return SectorDict(c => Colon() for c in keys(values)) end +# TruncationByOrder strategy: +# - find the howmany'th value of the input sorted according to the strategy +# - discard everything that is ordered after that value + function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) - values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) - inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) - return SectorDict(c => perms[c][I] for (c, I) in inds) -end -function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) - I = keytype(values) - truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values)) - totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0) - while totaldim > strategy.howmany - next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev) - isnothing(next) && break - _, cmin = next - truncdim[cmin] -= 1 - totaldim -= dim(cmin) - truncdim[cmin] == 0 && delete!(truncdim, cmin) + k = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev) + + if isnothing(k) + # discard everything + return SectorDict{sectortype(values), UnitRange{Int}}() + else + val = strategy.by(values[k]) + strategy = trunctol(; atol = val, strategy.by, keep_below = !strategy.rev) + return MAK.findtruncated_svd(values, strategy) end - return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim) end +# disambiguate +MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) = + MAK.findtruncated(values, strategy) function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter) return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values)) From 1481184ad925591006dec9221bc2fcd5e7646a46 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 11:21:17 
+0100 Subject: [PATCH 02/30] Temporarily fix StridedViews version --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 934bdb6ed..3e0c7a02d 100644 --- a/Project.toml +++ b/Project.toml @@ -52,6 +52,7 @@ Random = "1" SafeTestsets = "0.1" ScopedValues = "1.3.0" Strided = "2" +StridedViews = "=0.4.1" TensorKitSectors = "0.3.3" TensorOperations = "5.1" Test = "1" @@ -75,6 +76,7 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" From 1e46b0d809a76457e7f726bd0b8b95fe841db7d2 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 16:16:49 +0100 Subject: [PATCH 03/30] Revert "Temporarily fix StridedViews version" This reverts commit 77f0ffa0879bd84e3cd304ec6c9e0b29a20b12de. 
--- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3e0c7a02d..934bdb6ed 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,6 @@ Random = "1" SafeTestsets = "0.1" ScopedValues = "1.3.0" Strided = "2" -StridedViews = "=0.4.1" TensorKitSectors = "0.3.3" TensorOperations = "5.1" Test = "1" @@ -76,7 +75,6 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" From 8ef042513acb5109ffcb60f562f0077612c5d31a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 22 Jan 2026 13:54:03 +0100 Subject: [PATCH 04/30] Small update for diagonal pullbacks --- Project.toml | 3 +++ src/factorizations/diagonal.jl | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 934bdb6ed..ae640ba98 100644 --- a/Project.toml +++ b/Project.toml @@ -83,3 +83,6 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"] + +[sources] +MatrixAlgebraKit = {url = "https://github.com/quantumkithub/matrixalgebrakit.jl", rev = "main"} diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index bdcfebd74..92e4348cc 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -17,7 +17,7 @@ for f! in (:eig_full!, :eig_trunc!) 
@eval function MAK.initialize_output( ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm ) - return d, similar(d) + return similar(d, complex(scalartype(d))), similar(d, complex(scalartype(d))) end end @@ -93,7 +93,7 @@ end # For diagonal inputs we don't have to promote the scalartype since we know they are symmetric function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::DiagonalAlgorithm) V_D = fuse(domain(t)) - Tc = scalartype(t) + Tc = complex(scalartype(t)) A = similarstoragetype(t, Tc) return SectorVector{Tc, sectortype(t), A}(undef, V_D) end From 8423ce82da83b6d467ab4fb57afb6da371e8bf34 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 22 Jan 2026 14:04:12 +0100 Subject: [PATCH 05/30] Fix last error --- src/factorizations/diagonal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index 92e4348cc..04606249e 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -17,7 +17,7 @@ for f! in (:eig_full!, :eig_trunc!) 
@eval function MAK.initialize_output( ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm ) - return similar(d, complex(scalartype(d))), similar(d, complex(scalartype(d))) + return similar(d, complex(scalartype(d))), similar(d, complex(scalartype(d)), space(d)) end end From 848f0cc1c508c390e2e3b9720276e38c1e4a9b2c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 22 Jan 2026 14:10:52 +0100 Subject: [PATCH 06/30] Reenable truncated CUDA tests --- test/cuda/factorizations.jl | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/test/cuda/factorizations.jl b/test/cuda/factorizations.jl index f7b6ad6d6..af5c1478a 100644 --- a/test/cuda/factorizations.jl +++ b/test/cuda/factorizations.jl @@ -229,17 +229,17 @@ for V in spacelist @test isisometric(N) @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) - #N = @constinferred left_null(t; trunc = (; atol = 100 * eps(norm(t)))) - #@test isisometric(N) - #@test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) + N = @constinferred left_null(t; trunc = (; atol = 100 * eps(norm(t)))) + @test isisometric(N) + @test norm(N' * t) ≈ 0 atol = 100 * eps(norm(t)) Nᴴ = @constinferred right_null(t; alg = :svd) @test isisometric(Nᴴ; side = :right) @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) - #Nᴴ = @constinferred right_null(t; trunc = (; atol = 100 * eps(norm(t)))) - #@test isisometric(Nᴴ; side = :right) - #@test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) + Nᴴ = @constinferred right_null(t; trunc = (; atol = 100 * eps(norm(t)))) + @test isisometric(Nᴴ; side = :right) + @test norm(t * Nᴴ') ≈ 0 atol = 100 * eps(norm(t)) end # empty tensor @@ -258,15 +258,15 @@ for V in spacelist end end - #=@testset "truncated SVD" begin + @testset "truncated SVD" begin for T in eltypes, t in ( CUDA.randn(T, W, W), - #CUDA.randn(T, W, W)', + CUDA.randn(T, W, W)', CUDA.randn(T, W, V4), CUDA.randn(T, V4, W), - #CUDA.randn(T, W, V4)', - #CUDA.randn(T, V4, W)', + CUDA.randn(T, W, V4)', + CUDA.randn(T, V4, 
W)', DiagonalTensorMap(CUDA.randn(T, reduceddim(V1)), V1), ) @@ -327,7 +327,7 @@ for V in spacelist @test minimum(diagview(S5)) >= λ @test dim(domain(S5)) ≤ nvals end - end=# # TODO + end @testset "Eigenvalue decomposition" begin for T in eltypes, @@ -349,10 +349,10 @@ for V in spacelist @test @constinferred isposdef(vdv) t isa DiagonalTensorMap || @test !isposdef(t) # unlikely for non-hermitian map - #=nvals = round(Int, dim(domain(t)) / 2) + nvals = round(Int, dim(domain(t)) / 2) d, v = @constinferred eig_trunc(t; trunc = truncrank(nvals)) @test t * v ≈ v * d - @test dim(domain(d)) ≤ nvals=# + @test dim(domain(d)) ≤ nvals t2 = @constinferred project_hermitian(t) D, V = eigen(t2) @@ -380,10 +380,9 @@ for V in spacelist @test isposdef(t2 - λ * one(t) + 0.1 * one(t2)) @test !isposdef(t2 - λ * one(t) - 0.1 * one(t2)) - # TODO - #=d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals)) + d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals)) @test t2 * v ≈ v * d - @test dim(domain(d)) ≤ nvals=# + @test dim(domain(d)) ≤ nvals end end From 7ae9b05cca06fd97976f8e8375bd6815917280d3 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 10:02:58 -0500 Subject: [PATCH 07/30] make truncation run on GPU --- ext/TensorKitCUDAExt/TensorKitCUDAExt.jl | 1 + ext/TensorKitCUDAExt/truncation.jl | 20 ++++++++++++++++++++ src/factorizations/truncation.jl | 16 ++++++++-------- 3 files changed, 29 insertions(+), 8 deletions(-) create mode 100644 ext/TensorKitCUDAExt/truncation.jl diff --git a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl index 417970a02..f5efb98bb 100644 --- a/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl +++ b/ext/TensorKitCUDAExt/TensorKitCUDAExt.jl @@ -17,5 +17,6 @@ using TensorKit: MatrixAlgebraKit using Random include("cutensormap.jl") +include("truncation.jl") end diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl new file mode 100644 index 000000000..ca66c394b --- /dev/null 
+++ b/ext/TensorKitCUDAExt/truncation.jl @@ -0,0 +1,20 @@ +const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <: CuVector{T}} + +function Factorizations._findtruncvalue_order( + values::CuSectorVector, n::Int; by = identity, rev::Bool = false +) + I = sectortype(values) + p = sortperm(parent(values); by, rev) + + if FusionStyle(I) isa UniqueFusion # dimensions are all 1 + return n <= 0 ? nothing : @allowscalar(by(values[p[min(n, length(p))]])) + else + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) + end + cumulative_dim = cumsum(Base.permute!(parent(dims), p)) + k = findlast(<=(n), cumulative_dim) + return isnothing(k) ? k : @allowscalar(by(values[p[k]])) + end +end diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 3bc1ebf48..2ae1ad5d5 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -193,19 +193,20 @@ end function _findtruncvalue_order(values::SectorVector, n::Int; by = identity, rev::Bool = false) I = sectortype(values) - p = sortperm(parent(values); by, rev) if FusionStyle(I) isa UniqueFusion # dimensions are all 1 - return n <= 0 ? nothing : p[min(n, length(p))] - else + p = sortperm(parent(values), n; by, rev) + return n <= 0 ? nothing : by(values[p[min(n, length(p))]]) + end + + p = sortperm(parent(values); by, rev) dims = similar(values, Base.promote_op(dim, I)) for (c, v) in pairs(dims) fill!(v, dim(c)) end cumulative_dim = cumsum(Base.permute!(parent(dims), p)) k = findlast(<=(n), cumulative_dim) - return isnothing(k) ? k : p[k] - end + return isnothing(k) ? 
k : by(values[p[k]]) end # findtruncated @@ -224,13 +225,12 @@ end # - discard everything that is ordered after that value function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) - k = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev) + val = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev) - if isnothing(k) + if isnothing(val) # discard everything return SectorDict{sectortype(values), UnitRange{Int}}() else - val = strategy.by(values[k]) strategy = trunctol(; atol = val, strategy.by, keep_below = !strategy.rev) return MAK.findtruncated_svd(values, strategy) end From b1fe3bdaf2b7206591aecb5e7d687aeed0097a05 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 10:03:16 -0500 Subject: [PATCH 08/30] bypass scalar indexing by specializing --- src/factorizations/truncation.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 2ae1ad5d5..2beb940a4 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -29,6 +29,8 @@ end # --------- _blocklength(d::Integer, ind) = _blocklength(Base.OneTo(d), ind) _blocklength(ax, ind) = length(ax[ind]) +_blocklength(ax::Base.OneTo, ind::AbstractVector{<:Integer}) = length(ind) + function truncate_space(V::ElementarySpace, inds) return spacetype(V)(c => _blocklength(dim(V, c), ind) for (c, ind) in inds) end From 94ecfcaf9352230e6551d5149d002e9ea8deb1b3 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 13:52:20 -0500 Subject: [PATCH 09/30] convenience overloads --- src/tensors/sectorvector.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/tensors/sectorvector.jl b/src/tensors/sectorvector.jl index 4c914d6de..67cab2fb9 100644 --- a/src/tensors/sectorvector.jl +++ b/src/tensors/sectorvector.jl @@ -108,3 +108,11 @@ LinearAlgebra.dot(v1::SectorVector, v2::SectorVector) = inner(v1, v2) function LinearAlgebra.norm(v::SectorVector, 
p::Real = 2) return _norm(blocks(v), p, float(zero(real(scalartype(v))))) end + +# Common functionality +# -------------------- +# specific overloads for performance and/or GPU +Base.minimum(x::SectorVector) = minimum(parent(x)) +Base.minimum(f, x::SectorVector) = minimum(f, parent(x)) +Base.maximum(x::SectorVector) = maximum(parent(x)) +Base.maximum(f, x::SectorVector) = maximum(f, parent(x)) From 7395b8a804b7d11baf42dd205ac9bf9b5a959a07 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 16:50:28 -0500 Subject: [PATCH 10/30] gpu-friendly copies --- src/factorizations/truncation.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 2beb940a4..136321896 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -39,7 +39,8 @@ function truncate_domain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, inds for (c, b) in blocks(tdst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, view(block(tsrc, c), :, I)) + b′ = block(tsrc, c) + b .= b′[:, I] end return tdst end @@ -47,7 +48,8 @@ function truncate_codomain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, in for (c, b) in blocks(tdst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(b, view(block(tsrc, c), I, :)) + b′ = block(tsrc, c) + b .= b′[I, :] end return tdst end @@ -55,7 +57,7 @@ function truncate_diagonal!(Ddst::DiagonalTensorMap, Dsrc::DiagonalTensorMap, in for (c, b) in blocks(Ddst) I = get(inds, c, nothing) @assert !isnothing(I) - copy!(diagview(b), view(diagview(block(Dsrc, c)), I)) + diagview(b) .= @view diagview(block(Dsrc, c))[I] end return Ddst end From 9af19b7648d0739b823e5ca38b5d46848af805f2 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 16:50:45 -0500 Subject: [PATCH 11/30] retain storagetype in extended_S --- src/factorizations/truncation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 136321896..fe6fa5a31 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -82,7 +82,7 @@ end function MAK.truncate( ::typeof(left_null!), (U, S)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy ) - extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(codomain(U)))) + extended_S = zerovector!(SectorVector{eltype(S), sectortype(S), storagetype(S)}(undef, fuse(codomain(U)))) for (c, b) in blocks(S) copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter end @@ -95,7 +95,7 @@ end function MAK.truncate( ::typeof(right_null!), (S, Vᴴ)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy ) - extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(domain(Vᴴ)))) + extended_S = zerovector!(SectorVector{eltype(S), sectortype(S), storagetype(S)}(undef, fuse(domain(Vᴴ)))) for (c, b) in blocks(S) copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter end From 180afe6395054aa5b0512949392273cf7c48b7e3 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 16:53:07 -0500 Subject: [PATCH 12/30] avoid GPU issues with truncated adjoint tensormaps --- src/factorizations/adjoint.jl | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/factorizations/adjoint.jl b/src/factorizations/adjoint.jl index 20d6d5986..eae8989ce 100644 --- a/src/factorizations/adjoint.jl +++ b/src/factorizations/adjoint.jl @@ -7,6 +7,7 @@ _adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs. _adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...) _adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...) 
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg)) +_adjoint(alg::TruncatedAlgorithm) = TruncatedAlgorithm(_adjoint(alg.alg), alg.trunc) _adjoint(alg::AbstractAlgorithm) = alg _adjoint(alg::MAK.CUSOLVER_HouseholderQR) = MAK.LQViaTransposedQR(alg) @@ -81,7 +82,7 @@ for (left_f, right_f) in zip( end # 3-arg functions -for f in (:svd_full, :svd_compact) +for f in (:svd_full, :svd_compact, :svd_trunc) f! = Symbol(f, :!) @eval function MAK.copy_input(::typeof($f), t::AdjointTensorMap) return adjoint(MAK.copy_input($f, adjoint(t))) @@ -93,9 +94,16 @@ for f in (:svd_full, :svd_compact) return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg)))) end - @eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) - F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) - return reverse(adjoint.(F′)) + if f === :svd_trunc + function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) + U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return Vᴴ', S, U', ϵ + end + else + @eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm) + F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return reverse(adjoint.(F′)) + end end # disambiguate by prohibition @@ -111,6 +119,15 @@ function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm) F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) return reverse(adjoint.(F′)) end +function MAK.initialize_output( + ::typeof(svd_trunc!), t::AdjointTensorMap, alg::TruncatedAlgorithm + ) + return reverse(adjoint.(MAK.initialize_output(svd_trunc!, adjoint(t), _adjoint(alg)))) +end +function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::TruncatedAlgorithm) + U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) + return Vᴴ', S, U', ϵ +end function LinearAlgebra.isposdef(t::AdjointTensorMap) return isposdef(adjoint(t)) From efbe0886717fd5e4599790ddef4a439b72b9b2ca Mon Sep 17 00:00:00 2001 From: 
lkdvos Date: Thu, 22 Jan 2026 16:54:04 -0500 Subject: [PATCH 13/30] various utility improvements --- src/factorizations/truncation.jl | 5 +++-- src/tensors/sectorvector.jl | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index fe6fa5a31..7bcb4ff5d 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -30,9 +30,10 @@ end _blocklength(d::Integer, ind) = _blocklength(Base.OneTo(d), ind) _blocklength(ax, ind) = length(ax[ind]) _blocklength(ax::Base.OneTo, ind::AbstractVector{<:Integer}) = length(ind) +_blocklength(ax::Base.OneTo, ind::AbstractVector{Bool}) = count(ind) function truncate_space(V::ElementarySpace, inds) - return spacetype(V)(c => _blocklength(dim(V, c), ind) for (c, ind) in inds) + return spacetype(V)(c => _blocklength(dim(V, c), ind) for (c, ind) in pairs(inds)) end function truncate_domain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, inds) @@ -314,7 +315,7 @@ end MAK.truncation_error(values::SectorVector, ind) = MAK.truncation_error!(copy(values), ind) function MAK.truncation_error!(values::SectorVector, ind) - for (c, ind_c) in ind + for (c, ind_c) in pairs(ind) v = values[c] v[ind_c] .= zero(eltype(v)) end diff --git a/src/tensors/sectorvector.jl b/src/tensors/sectorvector.jl index 67cab2fb9..a005c84cd 100644 --- a/src/tensors/sectorvector.jl +++ b/src/tensors/sectorvector.jl @@ -36,6 +36,7 @@ Base.size(v::SectorVector, args...) = size(parent(v), args...) 
Base.similar(v::SectorVector) = SectorVector(similar(v.data), v.structure) Base.similar(v::SectorVector, ::Type{T}) where {T} = SectorVector(similar(v.data, T), v.structure) +Base.similar(v::SectorVector, V::ElementarySpace) where {T} = typeof(v)(undef, V) Base.copy(v::SectorVector) = SectorVector(copy(v.data), v.structure) @@ -53,6 +54,9 @@ Base.keys(v::SectorVector) = keys(v.structure) Base.values(v::SectorVector) = (v[c] for c in keys(v)) Base.pairs(v::SectorVector) = SectorDict(c => v[c] for c in keys(v)) +Base.get(v::SectorVector{<:Any, I}, key::I, default) where {I} = haskey(v, key) ? v[key] : default +Base.haskey(v::SectorVector{<:Any, I}, key::I) where {I} = key in keys(v) + # TensorKit interface # ------------------- sectortype(::Type{T}) where {T <: SectorVector} = keytype(T) From eafd7a80769cb705cd0a6667853949e94de842e1 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 16:54:43 -0500 Subject: [PATCH 14/30] complete rewrite of implementation --- src/factorizations/truncation.jl | 131 ++++++++----------------------- 1 file changed, 34 insertions(+), 97 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 7bcb4ff5d..6313c7f50 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -147,75 +147,11 @@ for f! in (:eig_trunc!, :eigh_trunc!) end end -# Find truncation -# --------------- +# findtruncated +# ------------- # auxiliary functions rtol_to_atol(S, p, atol, rtol) = rtol == 0 ? 
atol : max(atol, norm(S, p) * rtol) -function _compute_truncerr(Σdata, truncdim, p = 2) - I = keytype(Σdata) - S = scalartype(valtype(Σdata)) - return TensorKit._norm( - (c => @view(v[(get(truncdim, c, 0) + 1):end]) for (c, v) in Σdata), - p, zero(S) - ) -end - -function _findnexttruncvalue( - S, truncdim::SectorDict{I, Int}; by = identity, rev::Bool = true - ) where {I <: Sector} - # early return - (isempty(S) || all(iszero, values(truncdim))) && return nothing - if rev - σmin, imin = findmin(keys(truncdim)) do c - d = truncdim[c] - return by(S[c][d]) - end - return σmin, keys(truncdim)[imin] - else - σmax, imax = findmax(keys(truncdim)) do c - d = truncdim[c] - return by(S[c][d]) - end - return σmax, keys(truncdim)[imax] - end -end - -function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false) - values_sorted = similar(values) - perms = SectorDict( - ( - begin - p = sortperm(v; by, rev) - vs = values_sorted[c] - vs .= view(v, p) - c => p - end - ) for (c, v) in pairs(values) - ) - return values_sorted, perms -end - -function _findtruncvalue_order(values::SectorVector, n::Int; by = identity, rev::Bool = false) - I = sectortype(values) - - if FusionStyle(I) isa UniqueFusion # dimensions are all 1 - p = sortperm(parent(values), n; by, rev) - return n <= 0 ? nothing : by(values[p[min(n, length(p))]]) - end - - p = sortperm(parent(values); by, rev) - dims = similar(values, Base.promote_op(dim, I)) - for (c, v) in pairs(dims) - fill!(v, dim(c)) - end - cumulative_dim = cumsum(Base.permute!(parent(dims), p)) - k = findlast(<=(n), cumulative_dim) - return isnothing(k) ? 
k : by(values[p[k]]) -end - -# findtruncated -# ------------- # Generic fallback function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationStrategy) return MAK.findtruncated(values, strategy) @@ -225,20 +161,20 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation) return SectorDict(c => Colon() for c in keys(values)) end -# TruncationByOrder strategy: -# - find the howmany'th value of the input sorted according to the strategy -# - discard everything that is ordered after that value - function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) - val = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev) + I = sectortype(values) - if isnothing(val) - # discard everything - return SectorDict{sectortype(values), UnitRange{Int}}() - else - strategy = trunctol(; atol = val, strategy.by, keep_below = !strategy.rev) - return MAK.findtruncated_svd(values, strategy) + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) end + + perm = sortperm(parent(values); strategy.by, strategy.rev) + cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_dim .<= strategy.howmany + return result end # disambiguate MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) = @@ -259,28 +195,29 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByValue return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in pairs(values)) end -function MAK.findtruncated(values::SectorVector, strategy::TruncationByError) - values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) - inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) - return SectorDict(c => perms[c][I] for (c, I) in inds) -end -function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError) - I = keytype(values) - truncdim = SectorDict{I, Int}(c => 
length(d) for (c, d) in pairs(values)) - by(c, v) = abs(v)^strategy.p * dim(c) - Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), pairs(values)) - ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ) - truncerrᵖ = zero(real(scalartype(valtype(values)))) - next = _findnexttruncvalue(values, truncdim) - while !isnothing(next) - σmin, cmin = next - truncerrᵖ += by(cmin, σmin) - truncerrᵖ >= ϵᵖ && break - (truncdim[cmin] -= 1) == 0 && delete!(truncdim, cmin) - next = _findnexttruncvalue(values, truncdim) +function MAK.findtruncated(values::SectorVector, strategy::MAK.TruncationByError) + ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) + ϵᵖ = similar(values, typeof(ϵᵖmax)) + + if FusionStyle(sectortype(values)) isa UniqueFusion # dimensions are all 1 + parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p + else + for (c, v) in pairs(values) + v′ = ϵᵖ[c] + v′ .= abs.(v) .^ strategy.p .* dim(c) + end end - return SectorDict{I, Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) + + perm = sortperm(parent(values); by = abs, rev = false) + cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_err .> ϵᵖmax + return result end +# disambiguate +MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError) = + MAK.findtruncated(values, strategy) function MAK.findtruncated(values::SectorVector, strategy::TruncationSpace) blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) From f4892cf979d740ebf1e9312093d8be5a30c2c387 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 16:55:02 -0500 Subject: [PATCH 15/30] GPU doesn't like `trues` --- src/factorizations/truncation.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 6313c7f50..93dde1b86 100644 --- a/src/factorizations/truncation.jl +++ 
b/src/factorizations/truncation.jl @@ -232,8 +232,7 @@ function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersectio inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) return SectorDict( c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds; - init = trues(length(values[c])) + Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds ) for c in intersect(map(keys, inds)...) ) end @@ -241,8 +240,7 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationInterse inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) return SectorDict( c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds; - init = trues(length(values[c])) + Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds ) for c in intersect(map(keys, inds)...) ) end From 3f273a1e20cba238015d572d3b0097c388dbda68 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 16:55:31 -0500 Subject: [PATCH 16/30] remove CUDA specializations and temporarily add missing MatrixAlgebraKit thingies --- Project.toml | 3 --- ext/TensorKitCUDAExt/truncation.jl | 25 ++++++------------------- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index ae640ba98..934bdb6ed 100644 --- a/Project.toml +++ b/Project.toml @@ -83,6 +83,3 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"] - -[sources] -MatrixAlgebraKit = {url = "https://github.com/quantumkithub/matrixalgebrakit.jl", rev = "main"} diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index ca66c394b..4412cb221 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -1,20 
+1,7 @@ -const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <: CuVector{T}} - -function Factorizations._findtruncvalue_order( - values::CuSectorVector, n::Int; by = identity, rev::Bool = false -) - I = sectortype(values) - p = sortperm(parent(values); by, rev) - - if FusionStyle(I) isa UniqueFusion # dimensions are all 1 - return n <= 0 ? nothing : @allowscalar(by(values[p[min(n, length(p))]])) - else - dims = similar(values, Base.promote_op(dim, I)) - for (c, v) in pairs(dims) - fill!(v, dim(c)) - end - cumulative_dim = cumsum(Base.permute!(parent(dims), p)) - k = findlast(<=(n), cumulative_dim) - return isnothing(k) ? k : @allowscalar(by(values[p[k]])) - end +function MatrixAlgebraKit._ind_intersect(A::AbstractVector{Bool}, B::CuVector{Int}) + return MatrixAlgebraKit._ind_intersect(findall(A), B) end + +# TODO: intersect doesn't work on GPU +MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = + MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) From 6842a702cc00d641853bd5f6db39fa35a6576f27 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 17:03:25 -0500 Subject: [PATCH 17/30] better dimension testing --- test/cuda/factorizations.jl | 10 +++++----- test/setup.jl | 12 ++++++++++++ test/tensors/factorizations.jl | 8 ++++---- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/test/cuda/factorizations.jl b/test/cuda/factorizations.jl index af5c1478a..f3f15fe4b 100644 --- a/test/cuda/factorizations.jl +++ b/test/cuda/factorizations.jl @@ -286,7 +286,7 @@ for V in spacelist @test isisometric(U1) @test isisometric(Vᴴ1; side = :right) @test norm(t - U1 * S1 * Vᴴ1) ≈ ϵ1 atol = eps(real(T))^(4 / 5) - @test dim(domain(S1)) <= nvals + test_dim_isapprox(domain(S1), nvals) λ = minimum(diagview(S1)) trunc = trunctol(; atol = λ - 10eps(λ)) @@ -325,7 +325,7 @@ for V in spacelist @test isisometric(Vᴴ5; side = :right) @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) @test minimum(diagview(S5)) >= λ - @test 
dim(domain(S5)) ≤ nvals + test_dim_isapprox(domain(S5), nvals) end end @@ -335,7 +335,7 @@ for V in spacelist CUDA.rand(T, V1, V1), CUDA.rand(T, W, W), CUDA.rand(T, W, W)', - DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1), + # DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1), ) d, v = @constinferred eig_full(t) @@ -352,7 +352,7 @@ for V in spacelist nvals = round(Int, dim(domain(t)) / 2) d, v = @constinferred eig_trunc(t; trunc = truncrank(nvals)) @test t * v ≈ v * d - @test dim(domain(d)) ≤ nvals + test_dim_isapprox(domain(d), nvals) t2 = @constinferred project_hermitian(t) D, V = eigen(t2) @@ -382,7 +382,7 @@ for V in spacelist d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals)) @test t2 * v ≈ v * d - @test dim(domain(d)) ≤ nvals + test_dim_isapprox(domain(d), nvals) end end diff --git a/test/setup.jl b/test/setup.jl index 6cde01d28..c0aeed17b 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -3,6 +3,7 @@ module TestSetup export smallset, randsector, hasfusiontensor, force_planar export random_fusion export sectorlist +export test_dim_isapprox export Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VSU₂U₁, Vfib, VIB_diag, VIB_M using Random @@ -88,6 +89,17 @@ function random_fusion(I::Type{<:Sector}, ::Val{N}) where {N} # for fusion tree return (s, tail...) 
end +# helper function to check that max(0, d - dim(c)) <= dim(V) <= d + dim(c), where dim(c) is the largest sector dimension in V, +# to allow for truncations to have some margin with larger sectors +function test_dim_isapprox(V::ElementarySpace, d::Int) + dim_c_max = maximum(dim, sectors(V); init = 1) + return @test max(0, d - dim_c_max) ≤ dim(V) ≤ d + dim_c_max +end +function test_dim_isapprox(V::ProductSpace, d::Int) + dim_c_max = maximum(dim, blocksectors(V); init = 1) + return @test max(0, d - dim_c_max) ≤ dim(V) ≤ d + dim_c_max +end + sectorlist = ( Z2Irrep, Z3Irrep, Z4Irrep, Z3Irrep ⊠ Z4Irrep, U1Irrep, CU1Irrep, SU2Irrep, diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index 176d62657..41f30567b 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -259,7 +259,7 @@ for V in spacelist @test isisometric(U1) @test isisometric(Vᴴ1; side = :right) @test norm(t - U1 * S1 * Vᴴ1) ≈ ϵ1 atol = eps(real(T))^(4 / 5) - @test dim(domain(S1)) <= nvals + test_dim_isapprox(domain(S1), nvals) λ = minimum(diagview(S1)) trunc = trunctol(; atol = λ - 10eps(λ)) @@ -298,7 +298,7 @@ for V in spacelist @test isisometric(Vᴴ5; side = :right) @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) @test minimum(diagview(S5)) >= λ - @test dim(domain(S5)) ≤ nvals + test_dim_isapprox(domain(S5), nvals) end end @@ -323,7 +323,7 @@ for V in spacelist nvals = round(Int, dim(domain(t)) / 2) d, v = @constinferred eig_trunc(t; trunc = truncrank(nvals)) @test t * v ≈ v * d - @test dim(domain(d)) ≤ nvals + test_dim_isapprox(domain(d), nvals) t2 = @constinferred project_hermitian(t) D, V = eigen(t2) @@ -353,7 +353,7 @@ for V in spacelist d, v = @constinferred eigh_trunc(t2; trunc = truncrank(nvals)) @test t2 * v ≈ v * d - @test dim(domain(d)) ≤ nvals + test_dim_isapprox(domain(d), nvals) end end From 5bb2a237f108af8b5737299231e19e7c31d39b34 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 17:07:58 -0500 Subject: [PATCH 18/30] fix unbound type parameter ---
src/tensors/sectorvector.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/sectorvector.jl b/src/tensors/sectorvector.jl index a005c84cd..5d88300c2 100644 --- a/src/tensors/sectorvector.jl +++ b/src/tensors/sectorvector.jl @@ -36,7 +36,7 @@ Base.size(v::SectorVector, args...) = size(parent(v), args...) Base.similar(v::SectorVector) = SectorVector(similar(v.data), v.structure) Base.similar(v::SectorVector, ::Type{T}) where {T} = SectorVector(similar(v.data, T), v.structure) -Base.similar(v::SectorVector, V::ElementarySpace) where {T} = typeof(v)(undef, V) +Base.similar(v::SectorVector, V::ElementarySpace) = typeof(v)(undef, V) Base.copy(v::SectorVector) = SectorVector(copy(v.data), v.structure) From ddd0ed6671ae91bb74319eb8fcce8d309f367c52 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 17:08:35 -0500 Subject: [PATCH 19/30] add missing import --- test/setup.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/setup.jl b/test/setup.jl index c0aeed17b..5c8516eb9 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -7,6 +7,7 @@ export test_dim_isapprox export Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VSU₂U₁, Vfib, VIB_diag, VIB_M using Random +using Test: @test using TensorKit using TensorKit: ℙ, PlanarTrivial using Base.Iterators: take, product From f3b45ef66d5f306dd6bb2bc6d6d0f4125ec0e551 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 22 Jan 2026 17:37:08 -0500 Subject: [PATCH 20/30] be careful about double method definitions --- src/tensors/sectorvector.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tensors/sectorvector.jl b/src/tensors/sectorvector.jl index 5d88300c2..8f133b20e 100644 --- a/src/tensors/sectorvector.jl +++ b/src/tensors/sectorvector.jl @@ -36,7 +36,7 @@ Base.size(v::SectorVector, args...) = size(parent(v), args...) 
Base.similar(v::SectorVector) = SectorVector(similar(v.data), v.structure) Base.similar(v::SectorVector, ::Type{T}) where {T} = SectorVector(similar(v.data, T), v.structure) -Base.similar(v::SectorVector, V::ElementarySpace) = typeof(v)(undef, V) +Base.similar(v::SectorVector, V::ElementarySpace) = SectorVector{eltype(v), sectortype(V), storagetype(v)}(undef, V) Base.copy(v::SectorVector) = SectorVector(copy(v.data), v.structure) @@ -60,8 +60,7 @@ Base.haskey(v::SectorVector{<:Any, I}, key::I) where {I} = key in keys(v) # TensorKit interface # ------------------- sectortype(::Type{T}) where {T <: SectorVector} = keytype(T) - -Base.similar(v::SectorVector, V::ElementarySpace) = SectorVector(undef, V) +storagetype(::Type{SectorVector{T, I, A}}) where {T, I, A} = A blocksectors(v::SectorVector) = keys(v) blocks(v::SectorVector) = pairs(v) From f26cffe42f41785b9818404326b52b4473bb4542 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 23 Jan 2026 07:56:00 -0500 Subject: [PATCH 21/30] disable diagonal test --- test/tensors/factorizations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index 41f30567b..49bb2eff7 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -306,7 +306,7 @@ for V in spacelist for T in eltypes, t in ( rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', - DiagonalTensorMap(rand(T, reduceddim(V1)), V1), + # DiagonalTensorMap(rand(T, reduceddim(V1)), V1), ) d, v = @constinferred eig_full(t) From 6ff9ac8cb067623edbe4fc52e9d632f04522459a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 23 Jan 2026 13:47:42 -0500 Subject: [PATCH 22/30] bump MatrixAlgebraKit dependency --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 934bdb6ed..ddcbe4141 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ FiniteDifferences = "0.12" GPUArrays = "11.3.1" LRUCache = "1.0.2" LinearAlgebra = 
"1" -MatrixAlgebraKit = "0.6.2" +MatrixAlgebraKit = "0.6.3" Mooncake = "0.4.183" OhMyThreads = "0.8.0" Printf = "1" From 66664593919dc4acad0973bc4805a46db2b3fb80 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 23 Jan 2026 13:47:52 -0500 Subject: [PATCH 23/30] Revert "disable diagonal test" This reverts commit f26cffe42f41785b9818404326b52b4473bb4542. --- test/tensors/factorizations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index 49bb2eff7..41f30567b 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -306,7 +306,7 @@ for V in spacelist for T in eltypes, t in ( rand(T, V1, V1), rand(T, W, W), rand(T, W, W)', - # DiagonalTensorMap(rand(T, reduceddim(V1)), V1), + DiagonalTensorMap(rand(T, reduceddim(V1)), V1), ) d, v = @constinferred eig_full(t) From ebbdb84c733e06f30f2b44ae8e7be99273db382f Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 23 Jan 2026 14:07:35 -0500 Subject: [PATCH 24/30] remove unnecessary specializations --- src/factorizations/diagonal.jl | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index 04606249e..dae550ea1 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -13,26 +13,6 @@ for f in ( @eval MAK.copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d) end -for f! in (:eig_full!, :eig_trunc!) - @eval function MAK.initialize_output( - ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm - ) - return similar(d, complex(scalartype(d))), similar(d, complex(scalartype(d)), space(d)) - end -end - -for f! in (:eigh_full!, :eigh_trunc!) - @eval function MAK.initialize_output( - ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm - ) - if scalartype(d) <: Real - return d, similar(d, space(d)) - else - return similar(d, real(scalartype(d))), similar(d, space(d)) - end - end -end - for f! 
in (:qr_full!, :qr_compact!) @eval function MAK.initialize_output( ::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm From 2d7338af1c9f2d4719a29ad5cd272119cea27a8e Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 23 Jan 2026 14:37:36 -0500 Subject: [PATCH 25/30] specialize CPU implementations --- ext/TensorKitCUDAExt/truncation.jl | 29 ++++++++++++++++++++++++----- src/factorizations/truncation.jl | 21 ++++++++++++++++++--- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 4412cb221..87485d804 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -1,7 +1,26 @@ -function MatrixAlgebraKit._ind_intersect(A::AbstractVector{Bool}, B::CuVector{Int}) - return MatrixAlgebraKit._ind_intersect(findall(A), B) +const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <:CuVector{T}} + +function MatrixAlgebraKit.findtruncated( + values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder + ) + I = sectortype(values) + + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) + end + + perm = sortperm(parent(values); strategy.by, strategy.rev) + cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_dim .<= strategy.howmany + return result end -# TODO: intersect doesn't work on GPU -MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = - MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +# Needed until MatrixAlgebraKit patch hits... 
+function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int}) + result = fill!(similar(A), false) + result[B] .= @view A[B] + return result +end diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 93dde1b86..7572714a3 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -164,16 +164,31 @@ end function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) I = sectortype(values) + + if FusionStyle(I) isa UniqueFusion # dimensions are all 1 + perm = partialsortperm(parent(values), 1:strategy.howmany; strategy.by, strategy.rev) + result = similar(values, Bool) + fill!(parent(result), false) + parent(result)[perm] .= true + return result + end + dims = similar(values, Base.promote_op(dim, I)) for (c, v) in pairs(dims) fill!(v, dim(c)) end perm = sortperm(parent(values); strategy.by, strategy.rev) - cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) - result = similar(values, Bool) - parent(result)[perm] .= cumulative_dim .<= strategy.howmany + fill!(parent(result), false) + + totaldim = 0 + for i in perm + totaldim += dims[i] + totaldim > strategy.howmany && break + result[i] = true + end + return result end # disambiguate From bde9c506c5b5fe2b241389ab6484e742d23dbe13 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 10:09:41 -0500 Subject: [PATCH 26/30] add explanation TruncationByOrder --- src/factorizations/truncation.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 7572714a3..f516c1aaa 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -161,11 +161,16 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation) return SectorDict(c => Colon() for c in keys(values)) end +# Need to select the first k values here after sorting across blocks, weighted by quantum dimension +# The strategy is 
therefore to sort all values, and then use a logical array to indicate +# which ones to keep. +# For GenericFusion, we additionally keep a vector of the quantum dimensions to provide the +# correct weight function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) I = sectortype(values) - - if FusionStyle(I) isa UniqueFusion # dimensions are all 1 + # dimensions are all 1 so no need to account for weight + if FusionStyle(I) isa UniqueFusion perm = partialsortperm(parent(values), 1:strategy.howmany; strategy.by, strategy.rev) result = similar(values, Bool) fill!(parent(result), false) @@ -173,17 +178,19 @@ function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) return result end + # allocate vector of weights for each value dims = similar(values, Base.promote_op(dim, I)) for (c, v) in pairs(dims) fill!(v, dim(c)) end - perm = sortperm(parent(values); strategy.by, strategy.rev) + # allocate logical array for the output result = similar(values, Bool) fill!(parent(result), false) + # loop over sorted values and mark first `howmany` as to keep totaldim = 0 - for i in perm + for i in sortperm(parent(values); strategy.by, strategy.rev) totaldim += dims[i] totaldim > strategy.howmany && break result[i] = true From d603b9e0a23c59a01592358cdfd447a5e1128050 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 10:18:05 -0500 Subject: [PATCH 27/30] add explanation and specialization TruncationByError --- ext/TensorKitCUDAExt/truncation.jl | 24 ++++++++++++++++++++++++ src/factorizations/truncation.jl | 24 ++++++++++++++++++------ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 87485d804..49b2aedec 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -18,6 +18,30 @@ function MatrixAlgebraKit.findtruncated( return result end +function MatrixAlgebraKit.findtruncated( + values::CuSectorVector, 
strategy::MatrixAlgebraKit.TruncationByError + ) + ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) + ϵᵖ = similar(values, typeof(ϵᵖmax)) + + # dimensions are all 1 so no need to account for weight + if FusionStyle(sectortype(values)) isa UniqueFusion + parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p + else + for (c, v) in pairs(values) + v′ = ϵᵖ[c] + v′ .= abs.(v) .^ strategy.p .* dim(c) + end + end + + perm = sortperm(parent(values); by = abs, rev = false) + cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) + + result = similar(values, Bool) + parent(result)[perm] .= cumulative_err .> ϵᵖmax + return result +end + # Needed until MatrixAlgebraKit patch hits... function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int}) result = fill!(similar(A), false) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index f516c1aaa..8ae57c696 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -188,7 +188,7 @@ function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) result = similar(values, Bool) fill!(parent(result), false) - # loop over sorted values and mark first `howmany` as to keep + # loop over sorted values and mark as to keep until dimension is reached totaldim = 0 for i in sortperm(parent(values); strategy.by, strategy.rev) totaldim += dims[i] @@ -217,11 +217,16 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByValue return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in pairs(values)) end +# Need to select the first k values here after sorting by error across blocks, +# where k is determined by the cumulative truncation error of these values. +# The strategy is therefore to sort all values, and then use a logical array to indicate +# which ones to keep. 
function MAK.findtruncated(values::SectorVector, strategy::MAK.TruncationByError) ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) ϵᵖ = similar(values, typeof(ϵᵖmax)) - if FusionStyle(sectortype(values)) isa UniqueFusion # dimensions are all 1 + # dimensions are all 1 so no need to account for weight + if FusionStyle(sectortype(values)) isa UniqueFusion parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p else for (c, v) in pairs(values) @@ -230,11 +235,18 @@ function MAK.findtruncated(values::SectorVector, strategy::MAK.TruncationByError end end - perm = sortperm(parent(values); by = abs, rev = false) - cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) - + # allocate logical array for the output result = similar(values, Bool) - parent(result)[perm] .= cumulative_err .> ϵᵖmax + fill!(parent(result), false) + + # loop over sorted values and mark as to keep until maximal error is reached + totalerr = zero(eltype(ϵᵖ)) + for i in sortperm(parent(values); by = abs, rev = false) + totalerr += ϵᵖ[i] + totalerr > ϵᵖmax && break + result[i] = true + end + return result end # disambiguate From 5bc15066b69c87a944d1ffbb9f797b0c1451def4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 14:50:05 -0500 Subject: [PATCH 28/30] fix stupidity --- src/factorizations/truncation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 8ae57c696..bc87f8d17 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -237,14 +237,14 @@ function MAK.findtruncated(values::SectorVector, strategy::MAK.TruncationByError # allocate logical array for the output result = similar(values, Bool) - fill!(parent(result), false) + fill!(parent(result), true) - # loop over sorted values and mark as to keep until maximal error is reached + # loop over sorted values and mark as to discard until maximal error is reached totalerr 
= zero(eltype(ϵᵖ)) for i in sortperm(parent(values); by = abs, rev = false) totalerr += ϵᵖ[i] totalerr > ϵᵖmax && break - result[i] = true + result[i] = false end return result From b44878fdbfa11de9ae777a84b9a8168c92c4e395 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 Jan 2026 14:07:56 -0500 Subject: [PATCH 29/30] fix views --- src/factorizations/truncation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index bc87f8d17..0024077b0 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -41,7 +41,7 @@ function truncate_domain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, inds I = get(inds, c, nothing) @assert !isnothing(I) b′ = block(tsrc, c) - b .= b′[:, I] + b .= view(b′, :, I) end return tdst end @@ -50,7 +50,7 @@ function truncate_codomain!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap, in I = get(inds, c, nothing) @assert !isnothing(I) b′ = block(tsrc, c) - b .= b′[I, :] + b .= view(b′, I, :) end return tdst end @@ -58,7 +58,7 @@ function truncate_diagonal!(Ddst::DiagonalTensorMap, Dsrc::DiagonalTensorMap, in for (c, b) in blocks(Ddst) I = get(inds, c, nothing) @assert !isnothing(I) - diagview(b) .= @view diagview(block(Dsrc, c))[I] + diagview(b) .= view(diagview(block(Dsrc, c)), I) end return Ddst end From 4a722ef11811e11125f117661cb452cfc411aed6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 Jan 2026 14:21:28 -0500 Subject: [PATCH 30/30] enforce positive and finite p-norms --- ext/TensorKitCUDAExt/truncation.jl | 2 ++ src/factorizations/truncation.jl | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 49b2aedec..019ded97b 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -21,6 +21,8 @@ end function MatrixAlgebraKit.findtruncated( values::CuSectorVector,
strategy::MatrixAlgebraKit.TruncationByError ) + (isfinite(strategy.p) && strategy.p > 0) || + throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported.")) ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) ϵᵖ = similar(values, typeof(ϵᵖmax)) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 0024077b0..e8e113ec1 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -221,7 +221,9 @@ end # where k is determined by the cumulative truncation error of these values. # The strategy is therefore to sort all values, and then use a logical array to indicate # which ones to keep. -function MAK.findtruncated(values::SectorVector, strategy::MAK.TruncationByError) +function MAK.findtruncated(values::SectorVector, strategy::TruncationByError) + (isfinite(strategy.p) && strategy.p > 0) || + throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported.")) ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * norm(values, strategy.p)) ϵᵖ = similar(values, typeof(ϵᵖmax))