From 9e8e63c688976a6fc0e94c8b693e0cf87095d856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=94=A4=E6=B5=B7?= Date: Tue, 20 Jan 2026 13:48:07 +0800 Subject: [PATCH 1/2] Add loop corrections for BP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 周唤海 --- Project.toml | 6 +- src/TensorInference.jl | 2 + src/loop_series.jl | 419 +++++++++++++++++++++++++++++++++++++++++ test/loop_series.jl | 200 ++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 627 insertions(+), 1 deletion(-) create mode 100644 src/loop_series.jl create mode 100644 test/loop_series.jl diff --git a/Project.toml b/Project.toml index a578ad1..8af945d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,12 @@ name = "TensorInference" uuid = "c2297e78-99bd-40ad-871d-f50e56b81012" -authors = ["Jin-Guo Liu", "Martin Roa Villescas"] version = "0.6.3" +authors = ["Jin-Guo Liu", "Martin Roa Villescas"] [deps] +BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -20,8 +22,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" TensorInferenceCUDAExt = "CUDA" [compat] +BitBasis = "0.9.10" CUDA = "4, 5" DocStringExtensions = "0.8.6, 0.9" +Graphs = "1.13.3" LinearAlgebra = "1" OMEinsum = "0.9.1" Pkg = "1" diff --git a/src/TensorInference.jl b/src/TensorInference.jl index 563bd33..ea8cebe 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -44,6 +44,7 @@ export update_temperature # belief propagation export BeliefPropgation, belief_propagate +export CycleBasis, LoopExcitation, minimal_loops, loop_basis, loop_weight, bp_vacuum_weight, loop_expansion # fileio export save_tensor_network, load_tensor_network @@ -60,6 +61,7 @@ include("mmap.jl") include("sampling.jl") include("cspmodels.jl") include("belief.jl") +include("loop_series.jl") include("fileio.jl") end # module diff --git a/src/loop_series.jl b/src/loop_series.jl new file mode 100644 index 0000000..9be24de --- /dev/null +++ b/src/loop_series.jl @@ -0,0 +1,419 @@ +""" +Loop series expansion utilities for belief propagation. +""" + +using Graphs +using BitBasis + +# --- Bitmask helpers --- + +bitmask_type(nbits::Int) = LongLongUInt{max(1, cld(nbits, 64))} +edge_key(u::Int, v::Int) = u < v ? (u, v) : (v, u) + +function mask_indices(mask::T, nbits::Int) where {T<:Integer} + ids = Int[] + for i in 1:nbits + readbit(mask, i) == 1 && push!(ids, i) + end + return ids +end + +function mask_highest(mask::T, nbits::Int) where {T<:Integer} + for i in nbits:-1:1 + readbit(mask, i) == 1 && return i + end + return 0 +end + +# --- Graph loop basis (decoupled from tensors) --- + +""" + CycleBasis + +A cycle basis represented by bitmasks over an explicit edge list. +""" +struct CycleBasis{INT<:Integer} + edges::Vector{Tuple{Int, Int}} + cycles::Vector{INT} +end + +cycle_rank(g::SimpleGraph) = ne(g) - nv(g) + length(connected_components(g)) + +function edge_list(g::SimpleGraph) + eds = Tuple{Int, Int}[] + for e in edges(g) + u, v = edge_key(src(e), dst(e)) + push!(eds, (u, v)) + end + sort!(eds) + return eds +end + +function _edge_adjacency(eds::Vector{Tuple{Int, Int}}, nverts::Int) + adj = [Vector{Tuple{Int, Int}}() for _ in 1:nverts] + for (idx, (u, v)) in enumerate(eds) + push!(adj[u], (v, idx)) + push!(adj[v], (u, idx)) + end + return adj +end + +function _shortest_path_masks(adj, start::Int, target::Int, skip_edge::Int, ::Type{INT}) where {INT<:Integer} + n = length(adj) + dist = fill(typemax(Int), n) + parents = [Vector{Tuple{Int, Int}}() for _ in 1:n] + queue = Vector{Int}(undef, n) + head = 1 + tail = 1 + dist[start] = 0 + queue[1] = start + + while head <= tail + u = queue[head] + head += 1 + for (v, eidx) in adj[u] + eidx == skip_edge && continue + nd = dist[u] + 1 + if nd < dist[v] + dist[v] = nd + empty!(parents[v]) + push!(parents[v], (u, eidx)) + tail += 1 + queue[tail] = v + elseif nd == dist[v] + push!(parents[v], (u, eidx)) + end + end + end + + dist[target] == typemax(Int) && return INT[] + masks = INT[] + function backtrack(node::Int, mask::INT) + if node == start + push!(masks, mask) + return + end + for (prev, eidx) in parents[node] + backtrack(prev, mask | bmask(INT, eidx)) + end + end + backtrack(target, zero(INT)) + return masks +end + +function _candidate_cycles(g::SimpleGraph, eds::Vector{Tuple{Int, Int}}) + nbits = length(eds) + INT = bitmask_type(nbits) + adj = _edge_adjacency(eds, nv(g)) + seen = Set{String}() + candidates = INT[] + weights = Int[] + + for (eidx, (u, v)) in enumerate(eds) + paths = _shortest_path_masks(adj, u, v, eidx, INT) + for path_mask in paths + cycle = path_mask | bmask(INT, eidx) + key = join(mask_indices(cycle, nbits), ",") + key in seen && continue + push!(seen, key) + push!(candidates, cycle) + push!(weights, count_ones(cycle)) + end + end + return candidates, weights +end + +function _reduce(vec::INT, basis::Vector{INT}, pivots::Vector{Int}) where {INT<:Integer} + for (b, p) in zip(basis, pivots) + readbit(vec, p) == 0 && continue + vec = vec ⊻ b + end + return vec +end + +function _minimum_cycle_basis(candidates::Vector{INT}, weights::Vector{Int}, rank::Int, nbits::Int) where {INT<:Integer} + rank == 0 && return INT[] + order = sortperm(weights) + basis = INT[] + pivots = Int[] + + for idx in order + vec = _reduce(candidates[idx], basis, pivots) + iszero(vec) && continue + pivot = mask_highest(vec, nbits) + insert_at = findfirst(x -> x < pivot, pivots) + if insert_at === nothing + push!(basis, vec) + push!(pivots, pivot) + else + insert!(basis, insert_at, vec) + insert!(pivots, insert_at, pivot) + end + length(basis) == rank && break + end + + length(basis) == rank || throw(ArgumentError("cycle basis incomplete: got $(length(basis)) of $rank")) + return basis +end + +""" + minimal_loops(g::SimpleGraph) -> CycleBasis + +Return a minimum cycle basis of `g`. Each cycle is a bitmask over the edge +list stored in `CycleBasis.edges`. +""" +function minimal_loops(g::SimpleGraph) + eds = edge_list(g) + rank = cycle_rank(g) + INT = bitmask_type(length(eds)) + rank == 0 && return CycleBasis{INT}(eds, INT[]) + candidates, weights = _candidate_cycles(g, eds) + cycles = _minimum_cycle_basis(candidates, weights, rank, length(eds)) + return CycleBasis{INT}(eds, cycles) +end + +# --- Loop series contractions --- + +""" + LoopExcitation + +A loop excitation represented by bitmasks for edges (variables) and tensors. +""" +struct LoopExcitation{E<:Integer, T<:Integer} + edges::E + tensors::T +end + +loop_degree(loop::LoopExcitation) = count_ones(loop.edges) + +struct LoopSeriesCache{T, VT <: AbstractVector{T}} + message_in_norm::Vector{Vector{VT}} + v2t_pos::Vector{Dict{Int, Int}} + complement_proj::Vector{Union{Nothing, Matrix{T}}} +end + +function LoopSeriesCache(bp::BeliefPropgation, state::BPState{T}) where {T} + nvars = num_variables(bp) + msg_norm = Vector{Vector{typeof(state.message_in[1][1])}}(undef, nvars) + v2t_pos = Vector{Dict{Int, Int}}(undef, nvars) + c_mats = Vector{Union{Nothing, Matrix{T}}}(undef, nvars) + + for v in 1:nvars + msgs = state.message_in[v] + msg_norm[v] = [copy(m) for m in msgs] + pos = Dict{Int, Int}() + for (idx, t) in enumerate(bp.v2t[v]) + pos[t] = idx + end + v2t_pos[v] = pos + + if length(msgs) == 2 + m1, m2 = msgs + s = dot(m2, m1) + iszero(s) && throw(ArgumentError("edge $v has zero message overlap")) + scale1 = inv(sqrt(s)) + scale2 = inv(sqrt(conj(s))) + msg_norm[v][1] = m1 .* scale1 + msg_norm[v][2] = m2 .* scale2 + d = length(m1) + P = msg_norm[v][2] * msg_norm[v][1]' + c_mats[v] = Matrix{T}(I, d, d) - P + else + c_mats[v] = nothing + end + end + + return LoopSeriesCache(msg_norm, v2t_pos, c_mats) +end + +function tensor_connectivity_graph(bp::BeliefPropgation) + g = SimpleGraph(num_tensors(bp)) + edge_to_var = Dict{Tuple{Int, Int}, Int}() + + for v in 1:num_variables(bp) + tids = bp.v2t[v] + length(tids) == 2 || continue + t1, t2 = tids + a, b = edge_key(t1, t2) + if haskey(edge_to_var, (a, b)) + throw(ArgumentError("multiple variables between tensors $a and $b")) + end + edge_to_var[(a, b)] = v + add_edge!(g, a, b) + end + + return g, edge_to_var +end + +function loop_basis(bp::BeliefPropgation) + g, edge_to_var = tensor_connectivity_graph(bp) + basis = minimal_loops(g) + return loops_from_basis(bp, basis, edge_to_var) +end + +function loops_from_basis(bp::BeliefPropgation, basis::CycleBasis{INT}, edge_to_var) where {INT<:Integer} + nvars = num_variables(bp) + nt = num_tensors(bp) + EdgeMask = bitmask_type(nvars) + TensorMask = bitmask_type(nt) + loops = LoopExcitation{EdgeMask, TensorMask}[] + + for cycle in basis.cycles + edge_mask = zero(EdgeMask) + tensor_mask = zero(TensorMask) + for edge_idx in mask_indices(cycle, length(basis.edges)) + u, v = basis.edges[edge_idx] + a, b = edge_key(u, v) + var = edge_to_var[(a, b)] + edge_mask = edge_mask | bmask(EdgeMask, var) + tensor_mask = tensor_mask | bmask(TensorMask, u, v) + end + push!(loops, LoopExcitation(edge_mask, tensor_mask)) + end + + return loops +end + +_scalar(x) = x isa AbstractArray && ndims(x) == 0 ? x[] : x + +function _message_to_tensor(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache, v::Int, t::Int) + tids = bp.v2t[v] + if length(tids) == 2 + idx = cache.v2t_pos[v][t] + idx_other = idx == 1 ? 2 : 1 + return cache.message_in_norm[v][idx_other] + elseif length(tids) == 1 + idx = cache.v2t_pos[v][t] + return state.message_out[v][idx] + else + throw(ArgumentError("loop corrections require variables of degree 1 or 2; variable $v has degree $(length(tids))")) + end +end + +function _reduced_tensor(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache, loop_mask::T, t::Int) where {T<:Integer} + vars = bp.t2v[t] + ixs = Vector{Vector{Int}}() + tensors = Any[] + push!(ixs, vars) + push!(tensors, bp.tensors[t]) + + keep_vars = Int[] + for v in vars + if readbit(loop_mask, v) == 1 + push!(keep_vars, v) + else + msg = _message_to_tensor(bp, state, cache, v, t) + push!(ixs, [v]) + push!(tensors, msg) + end + end + + code = EinCode(ixs, keep_vars) + return code(tensors...), keep_vars +end + +""" + loop_weight(bp::BeliefPropgation, state::BPState, loop::LoopExcitation; optimizer=nothing) + +Compute the weight of a single loop excitation by contracting the tensors +in the loop with BP messages absorbed on external edges and `I - P` projectors +inserted on loop edges. +""" +function loop_weight(bp::BeliefPropgation, state::BPState, loop::LoopExcitation; optimizer=nothing) + cache = LoopSeriesCache(bp, state) + return loop_weight(bp, state, loop, cache; optimizer) +end + +function loop_weight(bp::BeliefPropgation, state::BPState, loop::LoopExcitation, cache::LoopSeriesCache; optimizer=nothing) + nvars = num_variables(bp) + tensors = Any[] + labels = Vector{Vector{Int}}() + + for t in 1:num_tensors(bp) + readbit(loop.tensors, t) == 1 || continue + reduced, keep_vars = _reduced_tensor(bp, state, cache, loop.edges, t) + tensor_labels = Int[] + for v in keep_vars + t1, t2 = bp.v2t[v] + push!(tensor_labels, t == t1 ? v : v + nvars) + end + push!(tensors, reduced) + push!(labels, tensor_labels) + end + + for v in mask_indices(loop.edges, nvars) + C = cache.complement_proj[v] + C === nothing && throw(ArgumentError("loop edge $v is not degree-2")) + push!(tensors, C) + push!(labels, [v, v + nvars]) + end + + code = EinCode(labels, Int[]) + if optimizer !== nothing + size_dict = OMEinsum.get_size_dict(labels, tensors) + code = optimize_code(code, size_dict, optimizer) + end + return _scalar(code(tensors...)) +end + +""" + bp_vacuum_weight(bp::BeliefPropgation, state::BPState) + +Compute the BP vacuum contribution by contracting each tensor with incoming +messages on all edges and multiplying the resulting scalars. +""" +function bp_vacuum_weight(bp::BeliefPropgation, state::BPState) + cache = LoopSeriesCache(bp, state) + return bp_vacuum_weight(bp, state, cache) +end + +function bp_vacuum_weight(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache) + empty_loop = zero(bitmask_type(num_variables(bp))) + vals = map(1:num_tensors(bp)) do t + reduced, _ = _reduced_tensor(bp, state, cache, empty_loop, t) + _scalar(reduced) + end + return prod(vals) +end + +function _disjoint_loop_sum(loops::Vector{LoopExcitation{E, T}}, weights::Vector, K::Int, edge_zero::E, tensor_zero::T) where {E<:Integer, T<:Integer} + isempty(loops) && return zero(weights[1]) + K <= 0 && return zero(weights[1]) + total = zero(weights[1]) + + function backtrack(start::Int, depth::Int, used_edges::E, used_tensors::T, prod_weight) + depth > 0 && (total += prod_weight) + depth == K && return + for i in start:length(loops) + loop = loops[i] + iszero(loop.edges & used_edges) || continue + iszero(loop.tensors & used_tensors) || continue + backtrack(i + 1, depth + 1, used_edges | loop.edges, used_tensors | loop.tensors, prod_weight * weights[i]) + end + end + + backtrack(1, 0, edge_zero, tensor_zero, one(weights[1])) + return total +end + +""" + loop_expansion(bp::BeliefPropgation, state::BPState; loops=loop_basis(bp), K::Int=1, optimizer=nothing) + +Compute a loop series correction to the BP vacuum contribution. The correction +includes all disjoint combinations of up to `K` loops. Returns a named tuple +with the BP vacuum weight, correction, total value, and per-loop weights. +""" +function loop_expansion(bp::BeliefPropgation, state::BPState; loops = loop_basis(bp), K::Int = 1, optimizer = nothing) + cache = LoopSeriesCache(bp, state) + bp_weight = bp_vacuum_weight(bp, state, cache) + + if isempty(loops) || K <= 0 + return (bp_weight = bp_weight, correction = zero(bp_weight), value = bp_weight, loop_weights = typeof(bp_weight)[]) + end + + loop_weights = [loop_weight(bp, state, loop, cache; optimizer) for loop in loops] + edge_zero = zero(bitmask_type(num_variables(bp))) + tensor_zero = zero(bitmask_type(num_tensors(bp))) + correction = _disjoint_loop_sum(loops, loop_weights, K, edge_zero, tensor_zero) + return (bp_weight = bp_weight, correction = correction, value = bp_weight + correction, loop_weights = loop_weights) +end diff --git a/test/loop_series.jl b/test/loop_series.jl new file mode 100644 index 0000000..330bad5 --- /dev/null +++ b/test/loop_series.jl @@ -0,0 +1,200 @@ +using TensorInference, Test, LinearAlgebra, Graphs, BitBasis, Random + +function gf2_basis(masks::Vector{T}, nbits::Int) where {T<:Integer} + basis = T[] + pivots = Int[] + for mask in masks + vec = mask + for (b, p) in zip(basis, pivots) + readbit(vec, p) == 0 && continue + vec = vec ⊻ b + end + iszero(vec) && continue + pivot = 0 + for i in nbits:-1:1 + if readbit(vec, i) == 1 + pivot = i + break + end + end + insert_at = findfirst(x -> x < pivot, pivots) + if insert_at === nothing + push!(basis, vec) + push!(pivots, pivot) + else + insert!(basis, insert_at, vec) + insert!(pivots, insert_at, pivot) + end + end + return basis, pivots +end + +function gf2_rank(masks::Vector{T}, nbits::Int) where {T<:Integer} + basis, _ = gf2_basis(masks, nbits) + return length(basis) +end + +function reduce_with_basis(mask::T, basis::Vector{T}, pivots::Vector{Int}) where {T<:Integer} + vec = mask + for (b, p) in zip(basis, pivots) + readbit(vec, p) == 0 && continue + vec = vec ⊻ b + end + return vec +end + +function cycle_mask(cycle::Vector{Int}, edge_index, ::Type{T}) where {T<:Integer} + n = length(cycle) + n == 0 && return zero(T) + mask = zero(T) + for i in 1:n + u = cycle[i] + v = cycle[i == n ? 1 : i + 1] + idx = edge_index[(u, v)] + mask = mask | bmask(T, idx) + end + return mask +end + +@testset "cycle basis on Petersen graph" begin + g = Graphs.SimpleGraphs.smallgraph(:petersen) + basis = minimal_loops(g) + rank = ne(g) - nv(g) + length(connected_components(g)) + @test length(basis.cycles) == rank + lengths = count_ones.(basis.cycles) + @test minimum(lengths) == 5 + @test all(len -> len >= 5, lengths) + @test gf2_rank(basis.cycles, length(basis.edges)) == rank + edge_index = Dict{Tuple{Int, Int}, Int}() + for (i, (u, v)) in enumerate(basis.edges) + edge_index[(u, v)] = i + edge_index[(v, u)] = i + end + masks = Set{eltype(basis.cycles)}() + for cyc in simplecycles(DiGraph(g)) + length(cyc) < 3 && continue + push!(masks, cycle_mask(cyc, edge_index, eltype(basis.cycles))) + end + basis_vecs, pivots = gf2_basis(basis.cycles, length(basis.edges)) + for mask in masks + @test iszero(reduce_with_basis(mask, basis_vecs, pivots)) + end +end + +function cycle_uai(tensors::Vector{Matrix{T}}) where {T} + n = length(tensors) + d1, d2 = size(tensors[1]) + d1 == d2 || throw(ArgumentError("tensors must be square")) + cards = fill(d1, n) + factors = Vector{TensorInference.Factor{T, 2}}(undef, n) + for i in 1:n + j = i == n ? 1 : i + 1 + size(tensors[i], 1) == d1 || throw(ArgumentError("dimension mismatch")) + size(tensors[i], 2) == d1 || throw(ArgumentError("dimension mismatch")) + factors[i] = TensorInference.Factor((i, j), tensors[i]) + end + return TensorInference.UAIModel(n, cards, factors) +end + +_scalar(x) = x isa AbstractArray && ndims(x) == 0 ? x[] : x + +edge_key(u::Int, v::Int) = u < v ? (u, v) : (v, u) + +function edge_list(g::SimpleGraph) + eds = Tuple{Int, Int}[] + for e in edges(g) + u, v = edge_key(src(e), dst(e)) + push!(eds, (u, v)) + end + sort!(eds) + return eds +end + +function graph_uai(g::SimpleGraph, bond_dim::Int; rng::AbstractRNG = Random.default_rng()) + eds = edge_list(g) + edge_index = Dict{Tuple{Int, Int}, Int}() + for (i, (u, v)) in enumerate(eds) + edge_index[(u, v)] = i + edge_index[(v, u)] = i + end + factors = TensorInference.Factor{Float64}[] + for v in vertices(g) + neis = sort!(collect(neighbors(g, v))) + vars = [edge_index[(v, u)] for u in neis] + tensor = rand(rng, ntuple(_ -> bond_dim, length(vars))...) + push!(factors, TensorInference.Factor((vars...,), tensor)) + end + return TensorInference.UAIModel(length(eds), fill(bond_dim, length(eds)), factors) +end + +exact_weight(uai) = _scalar(probability(TensorNetworkModel(uai))) + +function run_loop_expansion(uai; max_iter::Int = 500, tol::Real = 1e-8, K::Int = 1) + bp = BeliefPropgation(uai) + state, info = belief_propagate(bp; max_iter, tol) + loops = loop_basis(bp) + bp_weight = bp_vacuum_weight(bp, state) + result = loop_expansion(bp, state; loops, K) + return bp, state, info, loops, bp_weight, result +end + +function random_cyclic_graph(n::Int, m::Int; rng::AbstractRNG = Random.default_rng(), max_tries::Int = 100) + max_edges = n * (n - 1) ÷ 2 + m > max_edges && throw(ArgumentError("m must be <= $max_edges")) + for _ in 1:max_tries + g = SimpleGraph(n) + pairs = [(u, v) for u in 1:n-1 for v in u+1:n] + shuffle!(rng, pairs) + for i in 1:m + u, v = pairs[i] + add_edge!(g, u, v) + end + if length(connected_components(g)) == 1 && ne(g) >= nv(g) + return g + end + end + error("failed to sample a connected cyclic graph after $max_tries attempts") +end + +@testset "loop expansion on single cycle" begin + A = [0.2 0.9; 0.9 0.2] + tensors = [A for _ in 1:5] + uai = cycle_uai(tensors) + _, _, info, loops, bp_weight, result = run_loop_expansion(uai; max_iter=500, tol=1e-10, K=1) + @test info.converged + + @test length(loops) == 1 + @test count_ones(loops[1].edges) == 5 + + exact = tr(A^5) + @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) + @test result.value ≈ exact atol=1e-6 +end + +@testset "loop expansion on Petersen random tensors" begin + rng = MersenneTwister(17) + g = Graphs.SimpleGraphs.smallgraph(:petersen) + uai = graph_uai(g, 2; rng) + _, _, info, loops, bp_weight, result = run_loop_expansion(uai; max_iter=500, tol=1e-8, K=1) + @test info.converged + + @test !isempty(loops) + + exact = exact_weight(uai) + @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) + @test isfinite(bp_weight) && isfinite(result.value) && isfinite(exact) +end + +@testset "loop expansion on random simple graphs" begin + for (n, m, seed) in ((6, 7, 23), (7, 9, 41)) + rng = MersenneTwister(seed) + g = random_cyclic_graph(n, m; rng) + uai = graph_uai(g, 2; rng) + _, _, info, loops, bp_weight, result = run_loop_expansion(uai; max_iter=600, tol=1e-8, K=1) + @test info.converged + @test !isempty(loops) + exact = exact_weight(uai) + @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) + @test isfinite(bp_weight) && isfinite(result.value) && isfinite(exact) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 85acd40..4e3f27f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,6 +30,7 @@ end @testset "belief propagation" begin include("belief.jl") + include("loop_series.jl") end @testset "fileio" begin From d7d300ae5d09143baadf8367ca7ddbe9d3471a6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=94=A4=E6=B5=B7?= Date: Tue, 20 Jan 2026 23:21:20 +0800 Subject: [PATCH 2/2] Update loop series logic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 周唤海 --- src/TensorInference.jl | 5 +- src/loop_series.jl | 619 +++++++++++++++++++---------------------- test/loop_series.jl | 139 ++++++--- 3 files changed, 396 insertions(+), 367 deletions(-) diff --git a/src/TensorInference.jl b/src/TensorInference.jl index ea8cebe..72daba2 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -7,6 +7,8 @@ $(EXPORTS) """ module TensorInference +using BitBasis +using Graphs using OMEinsum, LinearAlgebra using OMEinsum: CacheTree, cached_einsum using OMEinsum.OMEinsumContractionOrders.JSON @@ -44,7 +46,8 @@ export update_temperature # belief propagation export BeliefPropgation, belief_propagate -export CycleBasis, LoopExcitation, minimal_loops, loop_basis, loop_weight, bp_vacuum_weight, loop_expansion +export LoopSeriesTruncation, XORLoopSum, UnionLoopSum, Degree, Cyclomatic +export loop_series, loop_basis, loop_weight, bp_vacuum_weight, loop_corrections # fileio export save_tensor_network, load_tensor_network diff --git a/src/loop_series.jl b/src/loop_series.jl index 9be24de..e455f85 100644 --- a/src/loop_series.jl +++ b/src/loop_series.jl @@ -1,419 +1,370 @@ -""" -Loop series expansion utilities for belief propagation. -""" - -using Graphs -using BitBasis - -# --- Bitmask helpers --- - bitmask_type(nbits::Int) = LongLongUInt{max(1, cld(nbits, 64))} -edge_key(u::Int, v::Int) = u < v ? (u, v) : (v, u) +@inline mask_indices(mask::T, nbits::Int) where {T<:Integer} = (i for i in 1:nbits if readbit(mask, i) == 1) -function mask_indices(mask::T, nbits::Int) where {T<:Integer} - ids = Int[] - for i in 1:nbits - readbit(mask, i) == 1 && push!(ids, i) - end - return ids -end +abstract type LoopSeriesTruncation end +struct XORLoopSum <: LoopSeriesTruncation; k::Int; end +struct UnionLoopSum <: LoopSeriesTruncation; k::Int; end +struct Degree <: LoopSeriesTruncation; max::Int; end +struct Cyclomatic <: LoopSeriesTruncation; max::Int; end -function mask_highest(mask::T, nbits::Int) where {T<:Integer} - for i in nbits:-1:1 - readbit(mask, i) == 1 && return i +function edge_list_index(g::SimpleGraph) + eds = [(min(src(e), dst(e)), max(src(e), dst(e))) for e in edges(g)] + sort!(eds) + idx = Dict{Tuple{Int, Int}, Int}() + for (i, (u, v)) in enumerate(eds) + idx[(u, v)] = i + idx[(v, u)] = i end - return 0 + return eds, idx end -# --- Graph loop basis (decoupled from tensors) --- - -""" - CycleBasis - -A cycle basis represented by bitmasks over an explicit edge list. -""" -struct CycleBasis{INT<:Integer} - edges::Vector{Tuple{Int, Int}} - cycles::Vector{INT} -end - -cycle_rank(g::SimpleGraph) = ne(g) - nv(g) + length(connected_components(g)) - -function edge_list(g::SimpleGraph) - eds = Tuple{Int, Int}[] - for e in edges(g) - u, v = edge_key(src(e), dst(e)) - push!(eds, (u, v)) +function cycle_mask(cycle, idx, ::Type{T}) where {T<:Integer} + n = length(cycle); n == 0 && return zero(T) + mask = zero(T) + @inbounds for i in 1:n + u = cycle[i]; v = cycle[i == n ? 1 : i + 1] + mask |= bmask(T, idx[(u, v)]) end - sort!(eds) - return eds + return mask end -function _edge_adjacency(eds::Vector{Tuple{Int, Int}}, nverts::Int) - adj = [Vector{Tuple{Int, Int}}() for _ in 1:nverts] - for (idx, (u, v)) in enumerate(eds) - push!(adj[u], (v, idx)) - push!(adj[v], (u, idx)) +function _generalized_loops(eds, nverts; max_edges, max_order) + m = length(eds); EdgeMask = bitmask_type(m); loops = EdgeMask[] + (m == 0 || max_edges <= 2) && return loops + max_edges = min(max_edges, m); max_order != typemax(Int) && (max_edges = min(max_edges, nverts - 1 + max_order)) + us = first.(eds); vs = last.(eds) + by_vertex = [Int[] for _ in 1:nverts] + for i in 1:m; push!(by_vertex[us[i]], i); push!(by_vertex[vs[i]], i); end + edge_adj = [Int[] for _ in 1:m] + for inc in by_vertex; for i in inc, j in inc; i == j || push!(edge_adj[i], j); end; end + for i in 1:m; sort!(edge_adj[i]); unique!(edge_adj[i]); end + deg = zeros(Int, nverts) + function add_edge!(idx, vcount, d1) + u = us[idx]; v = vs[idx] + deg[u] == 0 && (vcount += 1); deg[v] == 0 && (vcount += 1) + deg[u] == 1 && (d1 -= 1); deg[v] == 1 && (d1 -= 1) + deg[u] += 1; deg[v] += 1 + deg[u] == 1 && (d1 += 1); deg[v] == 1 && (d1 += 1) + return vcount, d1 end - return adj + remove_edge!(idx) = (deg[us[idx]] -= 1; deg[vs[idx]] -= 1) + function backtrack(root, edge_mask, edge_count, vcount, d1, cand, cand_mask, excluded) + cyclo = edge_count - vcount + 1; cyclo > max_order && return + edge_count > 0 && d1 == 0 && cyclo <= max_order && push!(loops, edge_mask) + edge_count == max_edges && return + prefix_mask = zero(EdgeMask) + for pos in 1:length(cand) + idx = cand[pos] + new_cand = cand[pos+1:end]; new_cand_mask = cand_mask & ~prefix_mask & ~bmask(EdgeMask, idx) + new_excluded = excluded | prefix_mask | bmask(EdgeMask, idx) + for nb in edge_adj[idx] + nb <= root && continue + readbit(edge_mask, nb) == 1 && continue + readbit(new_excluded, nb) == 1 && continue + readbit(new_cand_mask, nb) == 1 && continue + push!(new_cand, nb); new_cand_mask |= bmask(EdgeMask, nb) + end + vcount2, d12 = add_edge!(idx, vcount, d1) + backtrack(root, edge_mask | bmask(EdgeMask, idx), edge_count + 1, vcount2, d12, new_cand, new_cand_mask, new_excluded) + remove_edge!(idx); prefix_mask |= bmask(EdgeMask, idx) + end + end + for root in 1:m + edge_mask = bmask(EdgeMask, root) + vcount, d1 = add_edge!(root, 0, 0) + cand = [nb for nb in edge_adj[root] if nb > root] + cand_mask = isempty(cand) ? zero(EdgeMask) : bmask(EdgeMask, cand) + backtrack(root, edge_mask, 1, vcount, d1, cand, cand_mask, zero(EdgeMask)) + remove_edge!(root) + end + sort!(loops; by = count_ones); return loops end -function _shortest_path_masks(adj, start::Int, target::Int, skip_edge::Int, ::Type{INT}) where {INT<:Integer} - n = length(adj) - dist = fill(typemax(Int), n) - parents = [Vector{Tuple{Int, Int}}() for _ in 1:n] - queue = Vector{Int}(undef, n) - head = 1 - tail = 1 - dist[start] = 0 - queue[1] = start - - while head <= tail - u = queue[head] - head += 1 - for (v, eidx) in adj[u] - eidx == skip_edge && continue - nd = dist[u] + 1 - if nd < dist[v] - dist[v] = nd - empty!(parents[v]) - push!(parents[v], (u, eidx)) - tail += 1 - queue[tail] = v - elseif nd == dist[v] - push!(parents[v], (u, eidx)) +function _loop_series_cycles(g, k, op; connected::Bool = false, check_connected::Bool = false) + eds, idx = edge_list_index(g) + EdgeMask = bitmask_type(length(eds)) + cycles = [cycle_mask(c, idx, EdgeMask) for c in cycle_basis(g)] + k = min(k, length(cycles)) + k <= 0 && return (edges = eds, loops = eltype(cycles)[]) + if !connected + results = Set{eltype(cycles)}() + function combine_backtrack(start, depth, acc) + depth > 0 && push!(results, acc) + depth == k && return + for i in start:length(cycles) + combine_backtrack(i + 1, depth + 1, op(acc, cycles[i])) end end + combine_backtrack(1, 0, zero(eltype(cycles))) + delete!(results, zero(eltype(cycles))) + return (edges = eds, loops = collect(results)) end - - dist[target] == typemax(Int) && return INT[] - masks = INT[] - function backtrack(node::Int, mask::INT) - if node == start - push!(masks, mask) - return + ncycles = length(cycles); nverts = nv(g); nedges = length(eds) + VMask = bitmask_type(nverts); CycleMask = bitmask_type(ncycles) + vmasks = Vector{VMask}(undef, ncycles) + for i in 1:ncycles + vmask = zero(VMask) + for edge_idx in mask_indices(cycles[i], length(eds)); u, v = eds[edge_idx]; vmask |= bmask(VMask, u, v); end + vmasks[i] = vmask + end + adjmask = fill(zero(CycleMask), ncycles) + for i in 1:ncycles-1 + for j in i+1:ncycles + iszero(vmasks[i] & vmasks[j]) && continue + adjmask[i] |= bmask(CycleMask, j); adjmask[j] |= bmask(CycleMask, i) end - for (prev, eidx) in parents[node] - backtrack(prev, mask | bmask(INT, eidx)) + end + us = first.(eds); vs = last.(eds) + incident = [Int[] for _ in 1:nverts] + for i in 1:nedges; push!(incident[us[i]], i); push!(incident[vs[i]], i); end + used = zeros(Int, nverts); seen = zeros(Int, nverts) + used_mark = 0; seen_mark = 0; used_nodes = Int[]; stack = Int[] + results = Set{eltype(cycles)}() + function connected_mask(mask) + iszero(mask) && return false + used_mark += 1; start = 0; empty!(used_nodes) + for idx in mask_indices(mask, nedges) + u = us[idx]; v = vs[idx] + if used[u] != used_mark; used[u] = used_mark; push!(used_nodes, u); start == 0 && (start = u); end + if used[v] != used_mark; used[v] = used_mark; push!(used_nodes, v); start == 0 && (start = v); end + end + start == 0 && return false + seen_mark += 1; empty!(stack); push!(stack, start); seen[start] = seen_mark + while !isempty(stack) + u = pop!(stack) + for e in incident[u] + readbit(mask, e) == 1 || continue + v = us[e] == u ? vs[e] : us[e] + seen[v] == seen_mark && continue + seen[v] = seen_mark; push!(stack, v) + end end + for v in used_nodes; seen[v] != seen_mark && return false; end + return true end - backtrack(target, zero(INT)) - return masks -end - -function _candidate_cycles(g::SimpleGraph, eds::Vector{Tuple{Int, Int}}) - nbits = length(eds) - INT = bitmask_type(nbits) - adj = _edge_adjacency(eds, nv(g)) - seen = Set{String}() - candidates = INT[] - weights = Int[] - - for (eidx, (u, v)) in enumerate(eds) - paths = _shortest_path_masks(adj, u, v, eidx, INT) - for path_mask in paths - cycle = path_mask | bmask(INT, eidx) - key = join(mask_indices(cycle, nbits), ",") - key in seen && continue - push!(seen, key) - push!(candidates, cycle) - push!(weights, count_ones(cycle)) + function backtrack(root_lower, sub_cycles, acc_edges, cand_mask, depth) + depth > 0 && (!check_connected || depth == 1 || connected_mask(acc_edges)) && push!(results, acc_edges) + depth == k && return + prefix_mask = zero(CycleMask) + for v in mask_indices(cand_mask, ncycles) + vbit = bmask(CycleMask, v) + new_sub = sub_cycles | vbit; new_acc = op(acc_edges, cycles[v]) + new_cand = cand_mask & ~prefix_mask & ~vbit + new_cand |= adjmask[v] & ~root_lower & ~new_sub & ~new_cand & ~prefix_mask + backtrack(root_lower, new_sub, new_acc, new_cand, depth + 1) + prefix_mask |= vbit end end - return candidates, weights + lower_mask = zero(CycleMask) + for root in 1:ncycles + rootbit = bmask(CycleMask, root); lower_mask |= rootbit + backtrack(lower_mask, rootbit, cycles[root], adjmask[root] & ~lower_mask, 1) + end + return (edges = eds, loops = collect(results)) end -function _reduce(vec::INT, basis::Vector{INT}, pivots::Vector{Int}) where {INT<:Integer} - for (b, p) in zip(basis, pivots) - readbit(vec, p) == 0 && continue - vec = vec ⊻ b +loop_series(g::SimpleGraph, trunc::XORLoopSum) = _loop_series_cycles(g, trunc.k, ⊻; connected = true, check_connected = true) +loop_series(g::SimpleGraph, trunc::UnionLoopSum) = _loop_series_cycles(g, trunc.k, |; connected = true) +loop_series(g::SimpleGraph, trunc::Degree) = begin + corenum = core_number(g) + eds, _ = edge_list_index(g) + eds = [e for e in eds if corenum[e[1]] >= 2 && corenum[e[2]] >= 2] + return (edges = eds, loops = _generalized_loops(eds, nv(g); max_edges = trunc.max, max_order = typemax(Int))) +end +loop_series(g::SimpleGraph, trunc::Cyclomatic) = begin + if trunc.max == 1 + eds, idx = edge_list_index(g) + T = bitmask_type(length(eds)) + masks = unique!(collect(cycle_mask(c, idx, T) for c in simplecycles(DiGraph(g)) if length(c) >= 3)) + sort!(masks; by = count_ones) + return (edges = eds, loops = masks) end - return vec + corenum = core_number(g) + eds, _ = edge_list_index(g) + eds = [e for e in eds if corenum[e[1]] >= 2 && corenum[e[2]] >= 2] + return (edges = eds, loops = _generalized_loops(eds, nv(g); max_edges = typemax(Int), max_order = trunc.max)) end -function _minimum_cycle_basis(candidates::Vector{INT}, weights::Vector{Int}, rank::Int, nbits::Int) where {INT<:Integer} - rank == 0 && return INT[] - order = sortperm(weights) - basis = INT[] - pivots = Int[] +function tensor_connectivity_graph(bp::BeliefPropgation) + g = SimpleGraph(num_tensors(bp)) + edge_to_var = Dict{Tuple{Int, Int}, Int}() + for v in 1:num_variables(bp) + tids = bp.v2t[v]; length(tids) == 2 || continue + a = min(tids[1], tids[2]) + b = max(tids[1], tids[2]) + haskey(edge_to_var, (a, b)) && throw(ArgumentError("multiple variables between tensors $a and $b")) + edge_to_var[(a, b)] = v; add_edge!(g, a, b) + end + return g, edge_to_var +end - for idx in order - vec = _reduce(candidates[idx], basis, pivots) - iszero(vec) && continue - pivot = mask_highest(vec, nbits) - insert_at = findfirst(x -> x < pivot, pivots) - if insert_at === nothing - push!(basis, vec) - push!(pivots, pivot) - else - insert!(basis, insert_at, vec) - insert!(pivots, insert_at, pivot) +function _loops_from_edge_masks(bp, edges, masks, edge_to_var) + nvars = num_variables(bp); nt = num_tensors(bp) + EdgeMask = bitmask_type(nvars); TensorMask = bitmask_type(nt) + loops = NamedTuple{(:edges, :tensors), Tuple{EdgeMask, TensorMask}}[] + for mask in masks + iszero(mask) && continue + edge_mask = zero(EdgeMask); tensor_mask = zero(TensorMask) + for edge_idx in mask_indices(mask, length(edges)) + u, v = edges[edge_idx] + key = u < v ? (u, v) : (v, u) + var = edge_to_var[key] + edge_mask |= bmask(EdgeMask, var); tensor_mask |= bmask(TensorMask, u, v) end - length(basis) == rank && break + push!(loops, (edges = edge_mask, tensors = tensor_mask)) end - - length(basis) == rank || throw(ArgumentError("cycle basis incomplete: got $(length(basis)) of $rank")) - return basis + return loops end -""" - minimal_loops(g::SimpleGraph) -> CycleBasis - -Return a minimum cycle basis of `g`. Each cycle is a bitmask over the edge -list stored in `CycleBasis.edges`. -""" -function minimal_loops(g::SimpleGraph) - eds = edge_list(g) - rank = cycle_rank(g) - INT = bitmask_type(length(eds)) - rank == 0 && return CycleBasis{INT}(eds, INT[]) - candidates, weights = _candidate_cycles(g, eds) - cycles = _minimum_cycle_basis(candidates, weights, rank, length(eds)) - return CycleBasis{INT}(eds, cycles) +function loop_basis(bp::BeliefPropgation) + g, edge_to_var = tensor_connectivity_graph(bp) + eds, idx = edge_list_index(g) + EdgeMask = bitmask_type(length(eds)) + cycles = [cycle_mask(c, idx, EdgeMask) for c in cycle_basis(g)] + return _loops_from_edge_masks(bp, eds, cycles, edge_to_var) end -# --- Loop series contractions --- - -""" - LoopExcitation - -A loop excitation represented by bitmasks for edges (variables) and tensors. -""" -struct LoopExcitation{E<:Integer, T<:Integer} - edges::E - tensors::T +function loop_series(bp::BeliefPropgation, trunc::LoopSeriesTruncation) + g, edge_to_var = tensor_connectivity_graph(bp) + series = loop_series(g, trunc) + return _loops_from_edge_masks(bp, series.edges, series.loops, edge_to_var) end -loop_degree(loop::LoopExcitation) = count_ones(loop.edges) - -struct LoopSeriesCache{T, VT <: AbstractVector{T}} - message_in_norm::Vector{Vector{VT}} - v2t_pos::Vector{Dict{Int, Int}} - complement_proj::Vector{Union{Nothing, Matrix{T}}} -end +struct LoopSeriesCache{T, VT <: AbstractVector{T}}; message_in_norm::Vector{Vector{VT}}; complement_proj::Vector{Union{Nothing, Matrix{T}}}; reduced_cache::Vector{Dict{Any, Any}}; plan_cache::Dict{Any, Any}; end function LoopSeriesCache(bp::BeliefPropgation, state::BPState{T}) where {T} - nvars = num_variables(bp) + nvars = num_variables(bp); nt = num_tensors(bp) msg_norm = Vector{Vector{typeof(state.message_in[1][1])}}(undef, nvars) - v2t_pos = Vector{Dict{Int, Int}}(undef, nvars) c_mats = Vector{Union{Nothing, Matrix{T}}}(undef, nvars) - for v in 1:nvars - msgs = state.message_in[v] - msg_norm[v] = [copy(m) for m in msgs] - pos = Dict{Int, Int}() - for (idx, t) in enumerate(bp.v2t[v]) - pos[t] = idx - end - v2t_pos[v] = pos - + msgs = state.message_in[v]; msg_norm[v] = [copy(m) for m in msgs] if length(msgs) == 2 - m1, m2 = msgs - s = dot(m2, m1) + m1, m2 = msgs; s = dot(m2, m1) iszero(s) && throw(ArgumentError("edge $v has zero message overlap")) - scale1 = inv(sqrt(s)) - scale2 = inv(sqrt(conj(s))) - msg_norm[v][1] = m1 .* scale1 - msg_norm[v][2] = m2 .* scale2 - d = length(m1) - P = msg_norm[v][2] * msg_norm[v][1]' + scale1 = inv(sqrt(s)); scale2 = inv(sqrt(conj(s))) + msg_norm[v][1] = m1 .* scale1; msg_norm[v][2] = m2 .* scale2 + d = length(m1); P = msg_norm[v][2] * msg_norm[v][1]' c_mats[v] = Matrix{T}(I, d, d) - P else c_mats[v] = nothing end end - - return LoopSeriesCache(msg_norm, v2t_pos, c_mats) + return LoopSeriesCache(msg_norm, c_mats, [Dict{Any, Any}() for _ in 1:nt], Dict{Any, Any}()) end -function tensor_connectivity_graph(bp::BeliefPropgation) - g = SimpleGraph(num_tensors(bp)) - edge_to_var = Dict{Tuple{Int, Int}, Int}() - - for v in 1:num_variables(bp) - tids = bp.v2t[v] - length(tids) == 2 || continue - t1, t2 = tids - a, b = edge_key(t1, t2) - if haskey(edge_to_var, (a, b)) - throw(ArgumentError("multiple variables between tensors $a and $b")) - end - edge_to_var[(a, b)] = v - add_edge!(g, a, b) - end - - return g, edge_to_var -end - -function loop_basis(bp::BeliefPropgation) - g, edge_to_var = tensor_connectivity_graph(bp) - basis = minimal_loops(g) - return loops_from_basis(bp, basis, edge_to_var) -end - -function loops_from_basis(bp::BeliefPropgation, basis::CycleBasis{INT}, edge_to_var) where {INT<:Integer} - nvars = num_variables(bp) - nt = num_tensors(bp) - EdgeMask = bitmask_type(nvars) - TensorMask = bitmask_type(nt) - loops = LoopExcitation{EdgeMask, TensorMask}[] - - for cycle in basis.cycles - edge_mask = zero(EdgeMask) - tensor_mask = zero(TensorMask) - for edge_idx in mask_indices(cycle, length(basis.edges)) - u, v = basis.edges[edge_idx] - a, b = edge_key(u, v) - var = edge_to_var[(a, b)] - edge_mask = edge_mask | bmask(EdgeMask, var) - tensor_mask = tensor_mask | bmask(TensorMask, u, v) - end - push!(loops, LoopExcitation(edge_mask, tensor_mask)) - end - - return loops -end - -_scalar(x) = x isa AbstractArray && ndims(x) == 0 ? x[] : x - function _message_to_tensor(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache, v::Int, t::Int) tids = bp.v2t[v] if length(tids) == 2 - idx = cache.v2t_pos[v][t] - idx_other = idx == 1 ? 2 : 1 - return cache.message_in_norm[v][idx_other] + idx = findfirst(==(t), tids) + idx === nothing && throw(ArgumentError("tensor $t not attached to variable $v")) + return cache.message_in_norm[v][idx == 1 ? 2 : 1] elseif length(tids) == 1 - idx = cache.v2t_pos[v][t] - return state.message_out[v][idx] + return state.message_out[v][1] else throw(ArgumentError("loop corrections require variables of degree 1 or 2; variable $v has degree $(length(tids))")) end end -function _reduced_tensor(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache, loop_mask::T, t::Int) where {T<:Integer} +function _reduced_tensor(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache, loop_edges, t::Int) vars = bp.t2v[t] - ixs = Vector{Vector{Int}}() - tensors = Any[] - push!(ixs, vars) - push!(tensors, bp.tensors[t]) - - keep_vars = Int[] - for v in vars - if readbit(loop_mask, v) == 1 - push!(keep_vars, v) + local_mask_type = bitmask_type(length(vars)); local_mask = zero(local_mask_type) + ixs = [vars]; tensors = Any[bp.tensors[t]]; keep = Int[] + for (i, v) in enumerate(vars) + if readbit(loop_edges, v) == 1 + push!(keep, v); local_mask |= bmask(local_mask_type, i) else msg = _message_to_tensor(bp, state, cache, v, t) - push!(ixs, [v]) - push!(tensors, msg) + push!(ixs, [v]); push!(tensors, msg) end end - - code = EinCode(ixs, keep_vars) - return code(tensors...), keep_vars + cache_t = cache.reduced_cache[t] + haskey(cache_t, local_mask) && return cache_t[local_mask], keep + reduced = EinCode(ixs, keep)(tensors...) + cache_t[local_mask] = reduced + return reduced, keep end -""" - loop_weight(bp::BeliefPropgation, state::BPState, loop::LoopExcitation; optimizer=nothing) - -Compute the weight of a single loop excitation by contracting the tensors -in the loop with BP messages absorbed on external edges and `I - P` projectors -inserted on loop edges. -""" -function loop_weight(bp::BeliefPropgation, state::BPState, loop::LoopExcitation; optimizer=nothing) - cache = LoopSeriesCache(bp, state) - return loop_weight(bp, state, loop, cache; optimizer) -end - -function loop_weight(bp::BeliefPropgation, state::BPState, loop::LoopExcitation, cache::LoopSeriesCache; optimizer=nothing) - nvars = num_variables(bp) - tensors = Any[] - labels = Vector{Vector{Int}}() - - for t in 1:num_tensors(bp) - readbit(loop.tensors, t) == 1 || continue +function loop_weight(bp::BeliefPropgation, state::BPState, loop; optimizer = nothing, cache = nothing) + cache = cache === nothing ? LoopSeriesCache(bp, state) : cache + nvars = num_variables(bp); nt = num_tensors(bp) + tensors = Any[]; labels = Vector{Vector{Int}}() + sizehint!(tensors, count_ones(loop.tensors) + count_ones(loop.edges)); sizehint!(labels, count_ones(loop.tensors) + count_ones(loop.edges)) + for t in mask_indices(loop.tensors, nt) reduced, keep_vars = _reduced_tensor(bp, state, cache, loop.edges, t) - tensor_labels = Int[] - for v in keep_vars - t1, t2 = bp.v2t[v] - push!(tensor_labels, t == t1 ? v : v + nvars) - end - push!(tensors, reduced) - push!(labels, tensor_labels) + push!(tensors, reduced); push!(labels, [t == bp.v2t[v][1] ? v : v + nvars for v in keep_vars]) end - for v in mask_indices(loop.edges, nvars) - C = cache.complement_proj[v] - C === nothing && throw(ArgumentError("loop edge $v is not degree-2")) - push!(tensors, C) - push!(labels, [v, v + nvars]) + C = cache.complement_proj[v]; C === nothing && throw(ArgumentError("loop edge $v is not degree-2")) + push!(tensors, C); push!(labels, [v, v + nvars]) end - code = EinCode(labels, Int[]) if optimizer !== nothing - size_dict = OMEinsum.get_size_dict(labels, tensors) - code = optimize_code(code, size_dict, optimizer) + label_map = Dict{Int, Int}(); dims = Int[] + canon_labels = Vector{Vector{Int}}(undef, length(labels)) + for (i, labs) in enumerate(labels) + clabs = Vector{Int}(undef, length(labs)) + for (j, lab) in enumerate(labs) + cid = get!(label_map, lab) do + push!(dims, size(tensors[i], j)); length(dims) + end + clabs[j] = cid + end + canon_labels[i] = clabs + end + key = (Tuple(map(Tuple, canon_labels)), Tuple(dims)) + code = get!(cache.plan_cache, key) do + size_dict = OMEinsum.get_size_dict(canon_labels, tensors) + optimize_code(EinCode(canon_labels, Int[]), size_dict, optimizer) + end end - return _scalar(code(tensors...)) -end - -""" - bp_vacuum_weight(bp::BeliefPropgation, state::BPState) - -Compute the BP vacuum contribution by contracting each tensor with incoming -messages on all edges and multiplying the resulting scalars. -""" -function bp_vacuum_weight(bp::BeliefPropgation, state::BPState) - cache = LoopSeriesCache(bp, state) - return bp_vacuum_weight(bp, state, cache) + return code(tensors...)[] end -function bp_vacuum_weight(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache) +function _bp_vacuum_factors(bp::BeliefPropgation, state::BPState, cache::LoopSeriesCache) empty_loop = zero(bitmask_type(num_variables(bp))) - vals = map(1:num_tensors(bp)) do t + return map(1:num_tensors(bp)) do t reduced, _ = _reduced_tensor(bp, state, cache, empty_loop, t) - _scalar(reduced) + reduced[] end - return prod(vals) end -function _disjoint_loop_sum(loops::Vector{LoopExcitation{E, T}}, weights::Vector, K::Int, edge_zero::E, tensor_zero::T) where {E<:Integer, T<:Integer} - isempty(loops) && return zero(weights[1]) - K <= 0 && return zero(weights[1]) - total = zero(weights[1]) - - function backtrack(start::Int, depth::Int, used_edges::E, used_tensors::T, prod_weight) - depth > 0 && (total += prod_weight) - depth == K && return - for i in start:length(loops) - loop = loops[i] - iszero(loop.edges & used_edges) || continue - iszero(loop.tensors & used_tensors) || continue - backtrack(i + 1, depth + 1, used_edges | loop.edges, used_tensors | loop.tensors, prod_weight * weights[i]) - end - end - - backtrack(1, 0, edge_zero, tensor_zero, one(weights[1])) - return total +function bp_vacuum_weight(bp::BeliefPropgation, state::BPState; cache = nothing) + cache = cache === nothing ? LoopSeriesCache(bp, state) : cache + return prod(_bp_vacuum_factors(bp, state, cache)) end -""" - loop_expansion(bp::BeliefPropgation, state::BPState; loops=loop_basis(bp), K::Int=1, optimizer=nothing) - -Compute a loop series correction to the BP vacuum contribution. The correction -includes all disjoint combinations of up to `K` loops. Returns a named tuple -with the BP vacuum weight, correction, total value, and per-loop weights. -""" -function loop_expansion(bp::BeliefPropgation, state::BPState; loops = loop_basis(bp), K::Int = 1, optimizer = nothing) +function loop_corrections(bp::BeliefPropgation, state::BPState; loops, n_edges_trunc::Int = typemax(Int), n_loops_trunc::Int = 1, optimizer = nothing) cache = LoopSeriesCache(bp, state) - bp_weight = bp_vacuum_weight(bp, state, cache) - - if isempty(loops) || K <= 0 + vacuums = _bp_vacuum_factors(bp, state, cache) + bp_weight = prod(vacuums) + if isempty(loops) || n_edges_trunc <= 0 || n_loops_trunc <= 0 return (bp_weight = bp_weight, correction = zero(bp_weight), value = bp_weight, loop_weights = typeof(bp_weight)[]) end - - loop_weights = [loop_weight(bp, state, loop, cache; optimizer) for loop in loops] - edge_zero = zero(bitmask_type(num_variables(bp))) - tensor_zero = zero(bitmask_type(num_tensors(bp))) - correction = _disjoint_loop_sum(loops, loop_weights, K, edge_zero, tensor_zero) + edge_counts = count_ones.(getfield.(loops, :edges)) + keep = findall(edge_counts .<= n_edges_trunc) + isempty(keep) && return (bp_weight = bp_weight, correction = zero(bp_weight), value = bp_weight, loop_weights = typeof(bp_weight)[]) + loops = loops[keep]; edge_counts = edge_counts[keep] + loop_tensors = getfield.(loops, :tensors) + loop_weights = Vector{typeof(bp_weight)}(undef, length(loops)) + for i in eachindex(loops) + raw = loop_weight(bp, state, loops[i]; optimizer, cache) + vac = prod(vacuums[t] for t in mask_indices(loop_tensors[i], num_tensors(bp))) + iszero(vac) && throw(ArgumentError("loop vacuum factor is zero")) + loop_weights[i] = raw / vac + end + total = zero(bp_weight) + TensorMask = typeof(loop_tensors[1]) + function backtrack(start, depth, used_tensors, weight_prod, edge_total) + depth > 0 && (total += weight_prod) + depth == n_loops_trunc && return + for i in start:length(loop_weights) + iszero(loop_tensors[i] & used_tensors) || continue + new_edge_total = edge_total + edge_counts[i] + new_edge_total > n_edges_trunc && continue + backtrack(i + 1, depth + 1, used_tensors | loop_tensors[i], weight_prod * loop_weights[i], new_edge_total) + end + end + backtrack(1, 0, zero(TensorMask), one(bp_weight), 0) + correction = bp_weight * total return (bp_weight = bp_weight, correction = correction, value = bp_weight + correction, loop_weights = loop_weights) end diff --git a/test/loop_series.jl b/test/loop_series.jl index 330bad5..dfb5181 100644 --- a/test/loop_series.jl +++ b/test/loop_series.jl @@ -58,24 +58,26 @@ end @testset "cycle basis on Petersen graph" begin g = Graphs.SimpleGraphs.smallgraph(:petersen) - basis = minimal_loops(g) - rank = ne(g) - nv(g) + length(connected_components(g)) - @test length(basis.cycles) == rank - lengths = count_ones.(basis.cycles) - @test minimum(lengths) == 5 - @test all(len -> len >= 5, lengths) - @test gf2_rank(basis.cycles, length(basis.edges)) == rank + eds = [(min(src(e), dst(e)), max(src(e), dst(e))) for e in edges(g)] + sort!(eds) edge_index = Dict{Tuple{Int, Int}, Int}() - for (i, (u, v)) in enumerate(basis.edges) + for (i, (u, v)) in enumerate(eds) edge_index[(u, v)] = i edge_index[(v, u)] = i end - masks = Set{eltype(basis.cycles)}() + basis_cycles = [cycle_mask(c, edge_index, TensorInference.bitmask_type(length(eds))) for c in cycle_basis(g)] + rank = ne(g) - nv(g) + length(connected_components(g)) + @test length(basis_cycles) == rank + lengths = count_ones.(basis_cycles) + @test minimum(lengths) == 5 + @test all(len -> len >= 5, lengths) + @test gf2_rank(basis_cycles, length(eds)) == rank + masks = Set{eltype(basis_cycles)}() for cyc in simplecycles(DiGraph(g)) length(cyc) < 3 && continue - push!(masks, cycle_mask(cyc, edge_index, eltype(basis.cycles))) + push!(masks, cycle_mask(cyc, edge_index, eltype(basis_cycles))) end - basis_vecs, pivots = gf2_basis(basis.cycles, length(basis.edges)) + basis_vecs, pivots = gf2_basis(basis_cycles, length(eds)) for mask in masks @test iszero(reduce_with_basis(mask, basis_vecs, pivots)) end @@ -96,7 +98,28 @@ function cycle_uai(tensors::Vector{Matrix{T}}) where {T} return TensorInference.UAIModel(n, cards, factors) end -_scalar(x) = x isa AbstractArray && ndims(x) == 0 ? x[] : x +function disjoint_cycle_uai(tensors1::Vector{Matrix{T}}, tensors2::Vector{Matrix{T}}) where {T} + n1 = length(tensors1) + n2 = length(tensors2) + d1 = size(tensors1[1], 1) + d2 = size(tensors2[1], 1) + size(tensors1[1], 1) == size(tensors1[1], 2) || throw(ArgumentError("tensors1 must be square")) + size(tensors2[1], 1) == size(tensors2[1], 2) || throw(ArgumentError("tensors2 must be square")) + all(t -> size(t, 1) == d1 && size(t, 2) == d1, tensors1) || throw(ArgumentError("dimension mismatch in tensors1")) + all(t -> size(t, 1) == d2 && size(t, 2) == d2, tensors2) || throw(ArgumentError("dimension mismatch in tensors2")) + cards = vcat(fill(d1, n1), fill(d2, n2)) + factors = Vector{TensorInference.Factor{T, 2}}(undef, n1 + n2) + for i in 1:n1 + j = i == n1 ? 1 : i + 1 + factors[i] = TensorInference.Factor((i, j), tensors1[i]) + end + offset = n1 + for i in 1:n2 + j = i == n2 ? 1 : i + 1 + factors[offset + i] = TensorInference.Factor((offset + i, offset + j), tensors2[i]) + end + return TensorInference.UAIModel(n1 + n2, cards, factors) +end edge_key(u::Int, v::Int) = u < v ? (u, v) : (v, u) @@ -127,15 +150,12 @@ function graph_uai(g::SimpleGraph, bond_dim::Int; rng::AbstractRNG = Random.defa return TensorInference.UAIModel(length(eds), fill(bond_dim, length(eds)), factors) end -exact_weight(uai) = _scalar(probability(TensorNetworkModel(uai))) +exact_weight(uai) = probability(TensorNetworkModel(uai))[] -function run_loop_expansion(uai; max_iter::Int = 500, tol::Real = 1e-8, K::Int = 1) +function run_bp(uai; max_iter::Int = 500, tol::Real = 1e-8) bp = BeliefPropgation(uai) state, info = belief_propagate(bp; max_iter, tol) - loops = loop_basis(bp) - bp_weight = bp_vacuum_weight(bp, state) - result = loop_expansion(bp, state; loops, K) - return bp, state, info, loops, bp_weight, result + return bp, state, info, bp_vacuum_weight(bp, state) end function random_cyclic_graph(n::Int, m::Int; rng::AbstractRNG = Random.default_rng(), max_tries::Int = 100) @@ -160,29 +180,78 @@ end A = [0.2 0.9; 0.9 0.2] tensors = [A for _ in 1:5] uai = cycle_uai(tensors) - _, _, info, loops, bp_weight, result = run_loop_expansion(uai; max_iter=500, tol=1e-10, K=1) + bp, state, info, bp_weight = run_bp(uai; max_iter=500, tol=1e-10) @test info.converged - @test length(loops) == 1 - @test count_ones(loops[1].edges) == 5 - exact = tr(A^5) @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) - @test result.value ≈ exact atol=1e-6 + + strategies = [ + ("basis", loop_basis(bp)), + ("xor", loop_series(bp, XORLoopSum(1))), + ("union", loop_series(bp, UnionLoopSum(1))), + ("degree", loop_series(bp, Degree(5))), + ("cyclomatic", loop_series(bp, Cyclomatic(1))), + ] + for (name, loops) in strategies + @testset "$name" begin + @test length(loops) == 1 + @test count_ones(loops[1].edges) == 5 + result = loop_corrections(bp, state; loops) + @info "exact: $exact, BP: $bp_weight, loop corrected: $(result.value)" + @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) + @test result.value ≈ exact atol=1e-6 + end + end end -@testset "loop expansion on Petersen random tensors" begin - rng = MersenneTwister(17) +@testset "loop expansion on disjoint cycles" begin + A = [0.2 0.9; 0.9 0.2] + tensors1 = [A for _ in 1:5] + tensors2 = [A for _ in 1:5] + uai = disjoint_cycle_uai(tensors1, tensors2) + bp, state, info, bp_weight = run_bp(uai; max_iter=500, tol=1e-10) + @test info.converged + exact = tr(A^5)^2 + @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) + + strategies = [ + ("basis", loop_basis(bp)), + ("degree", loop_series(bp, Degree(5))), + ("cyclomatic", loop_series(bp, Cyclomatic(1))), + ("xor", loop_series(bp, XORLoopSum(1))), + ("union", loop_series(bp, UnionLoopSum(2))), + ] + for (name, loops) in strategies + @testset "$name" begin + @test length(loops) == 2 + result_single = loop_corrections(bp, state; loops, n_edges_trunc = 5, n_loops_trunc = 1) + @test !isapprox(result_single.value, exact; atol=1e-6, rtol=1e-6) + result_multi = loop_corrections(bp, state; loops, n_edges_trunc = 10, n_loops_trunc = 2) + @test result_multi.value ≈ exact atol=1e-6 + end + end +end + +@testset "loop expansion on Petersen graph" begin + rng = MersenneTwister(42) g = Graphs.SimpleGraphs.smallgraph(:petersen) uai = graph_uai(g, 2; rng) - _, _, info, loops, bp_weight, result = run_loop_expansion(uai; max_iter=500, tol=1e-8, K=1) + bp, state, info, bp_weight = run_bp(uai; max_iter=500, tol=1e-8) @test info.converged - @test !isempty(loops) - exact = exact_weight(uai) @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) - @test isfinite(bp_weight) && isfinite(result.value) && isfinite(exact) + + for trunc in [Degree(12), Cyclomatic(4)] + @testset "$(nameof(typeof(trunc)))" begin + @time loops = loop_series(bp, trunc) + @test !isempty(loops) + @time result = loop_corrections(bp, state; loops) + @info "exact: $exact, BP: $bp_weight, loop corrected: $(result.value)" + @test result.value ≈ exact atol=1e-6 + end + end end @testset "loop expansion on random simple graphs" begin @@ -190,11 +259,17 @@ end rng = MersenneTwister(seed) g = random_cyclic_graph(n, m; rng) uai = graph_uai(g, 2; rng) - _, _, info, loops, bp_weight, result = run_loop_expansion(uai; max_iter=600, tol=1e-8, K=1) + bp, state, info, bp_weight = run_bp(uai; max_iter=1600, tol=1e-12) @test info.converged - @test !isempty(loops) exact = exact_weight(uai) @test !isapprox(bp_weight, exact; atol=1e-6, rtol=1e-6) - @test isfinite(bp_weight) && isfinite(result.value) && isfinite(exact) + for trunc in (Degree(ne(g)), Cyclomatic(1)) + @testset "$(nameof(typeof(trunc)))" begin + loops = loop_series(bp, trunc) + @test !isempty(loops) + result = loop_corrections(bp, state; loops) + @test isfinite(bp_weight) && isfinite(result.value) && isfinite(exact) + end + end end end