diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 9e7e6600..7dc19c8d 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -398,29 +398,20 @@ function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{ (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool + # the output matrices here are the same size as for svd_full! do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) return Utr, Str, Vᴴtr end function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) - U, S, Vᴴ = USVᴴ - check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) - _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) - - # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong - (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - + Utr, Str, Vᴴtr = svd_trunc_no_error!(A, USVᴴ, alg) # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum normS = norm(diagview(Str)) normA = norm(A) # equivalent to sqrt(normA^2 - normS^2) # but may be more accurate - ϵ = sqrt((normA + normS) * (normA - normS)) - - do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool - do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) - + ϵ = sqrt((normA + normS) * abs(normA - normS)) return Utr, Str, Vᴴtr, ϵ end diff --git a/test/svd.jl b/test/svd.jl index affe2942..bc9e8c17 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -25,6 +25,9 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63) CUSOLVER_Jacobi(), ) TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS) + k = 5 + p = min(m, n) - k - 2 + min(m, n) > k + 2 && TestSuite.test_randomized_svd(CuMatrix{T}, (m, n), (MatrixAlgebraKit.TruncatedAlgorithm(CUSOLVER_Randomized(; k, p, niters = 20), truncrank(k)),)) if n == m TestSuite.test_svd(Diagonal{T, CuVector{T}}, m) TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) diff --git a/test/testsuite/svd.jl b/test/testsuite/svd.jl index 800018b2..6d5799ba 100644 --- a/test/testsuite/svd.jl +++ b/test/testsuite/svd.jl @@ -312,3 +312,16 @@ function test_svd_trunc_algs( end end end + +function test_randomized_svd(T::Type, sz, algs; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "randomized svd_trunc! algorithm $alg $summary_str" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + S₀ = collect(svd_vals(A)) + U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; alg) + @test collect(diagview(S1))[1:alg.alg.k] ≈ S₀[1:alg.alg.k] + end +end