Add Fast Per-Shot Edge Reweighting #178
Open · oscarhiggott wants to merge 11 commits into `master` from `u/oscarhiggott/fast-edge-reweighting-issue172`
Conversation
1. C++ Core (`MatchingGraph` and `SearchGraph`):
* Added an `apply_temp_reweights` method to `MatchingGraph` in `src/pymatching/sparse_blossom/flooder/graph.cc`. This method accepts a list of edge reweights `(u, v, new_weight)`, converts the weights to the internal integer representation (handling the normalising constant), updates the edge weights, and saves the previous weights for later restoration.
* Added an `apply_temp_reweights` method to `SearchGraph` in `src/pymatching/sparse_blossom/search/search_graph.cc` to support correlated decoding (which uses a separate search graph). This method requires the `normalising_constant` to be passed from the main graph.
2. C++ Python Bindings (`user_graph.pybind.cc`):
* Updated `decode` to accept an optional `edge_reweights` argument (a NumPy array of shape `(N, 3)`). If provided, it applies the temporary reweights to `mwpm.flooder.graph` (and `mwpm.search_flooder.graph` if correlations are enabled) before decoding, and reverts them using `undo_reweights` immediately after.
* Updated `decode_batch` to accept an optional `edge_reweights` argument (a list of NumPy arrays). It iterates through the batch, applying and reverting reweights for each shot if a reweight rule is provided for that shot.
3. Python API (`src/pymatching/matching.py`):
* Updated the `Matching.decode` signature to include `edge_reweights: Optional[np.ndarray]`.
* Updated the `Matching.decode_batch` signature to include `edge_reweights: Optional[List[Optional[np.ndarray]]]`.
* Updated docstrings to explain the new functionality and the format of `edge_reweights`.
4. Verification:
* Created `tests/matching/test_reweight.py` and verified:
  * Single-shot decoding with reweighting (standard and boundary edges).
  * Batch decoding with per-shot reweighting.
  * Weights are correctly restored after decoding.
* Ran existing tests (`decode_test.py`) to ensure no regressions.
The feature allows users to dynamically adjust edge weights for individual shots or batches without the performance penalty of rebuilding the `Mwpm` object.
Example Usage:
```python
import pymatching
import numpy as np

m = pymatching.Matching()
m.add_edge(0, 1, weight=2)
m.add_edge(1, 2, weight=2)

# Normal decode: path 0-1-2 has weight 4
m.decode(np.array([1, 0, 1]))

# Decode with a temporary reweight of edge (0, 1) to weight 5
# New path weight: 5 + 2 = 7
reweights = np.array([[0, 1, 5.0]])
correction, weight = m.decode(np.array([1, 0, 1]), return_weight=True, edge_reweights=reweights)
print(weight)  # 7.0
```
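For intuition, the temporary reweight/undo mechanics described in point 1 might look like the following Python sketch. The class name, scaling rule, and storage layout here are illustrative assumptions, not the actual C++ implementation:

```python
# Hypothetical sketch of the apply_temp_reweights / undo_reweights pattern.
# All details (rounding, dict storage) are illustrative, not PyMatching's code.

class TempReweightGraph:
    def __init__(self, normalising_constant):
        self.normalising_constant = normalising_constant
        # edge weights in the internal integer representation, keyed by (u, v)
        self.weights = {}
        self.previous_weights = []  # stack of (edge, old_weight) for undo

    def _to_internal(self, w):
        # scale a float weight to the internal integer representation
        return round(w * self.normalising_constant)

    def add_edge(self, u, v, weight):
        self.weights[(min(u, v), max(u, v))] = self._to_internal(weight)

    def apply_temp_reweights(self, reweights):
        # reweights: iterable of (u, v, new_weight) rows
        for u, v, new_weight in reweights:
            edge = (min(u, v), max(u, v))
            # save the old weight so it can be restored after decoding
            self.previous_weights.append((edge, self.weights[edge]))
            self.weights[edge] = self._to_internal(new_weight)

    def undo_reweights(self):
        # restore saved weights in reverse order
        while self.previous_weights:
            edge, old = self.previous_weights.pop()
            self.weights[edge] = old
```

The key invariant is that every `apply_temp_reweights` call pushes the displaced weights onto a stack, so a single `undo_reweights` call restores the graph exactly.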
1. `Mwpm` Class Update:
* Added a `search_flooder_available()` method to the `Mwpm` class in `src/pymatching/sparse_blossom/matcher/mwpm.h`. This method checks whether `flooder.graph.nodes.size() == search_flooder.graph.nodes.size()` to determine if the search graph is present and synchronized.
2. Reweighting Logic Update:
* Updated `src/pymatching/sparse_blossom/driver/user_graph.pybind.cc` to use `mwpm.search_flooder_available()` instead of `enable_correlations` when deciding whether to apply or undo temporary reweights on the `search_flooder.graph`. This ensures that if the search graph exists (e.g., due to a large number of observables) but `enable_correlations` is `False`, the reweights are still correctly applied to it.
3. Test Coverage:
* Added a new test case, `test_decode_reweight_large_observables`, to `tests/matching/test_reweight.py`. This test constructs a matching graph with more than 64 observables (triggering the creation of the search graph) but decodes with `enable_correlations=False`. It verifies that reweighting an edge correctly affects the decoding result, confirming that the search graph is being reweighted.
Files Modified:
* src/pymatching/sparse_blossom/matcher/mwpm.h
* src/pymatching/sparse_blossom/driver/user_graph.pybind.cc
* tests/matching/test_reweight.py
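The availability check and the way it gates reweighting can be modelled with a minimal Python sketch. The class and attribute names below are illustrative, not the actual C++ members:

```python
# Hedged model of the search_flooder_available() logic described above:
# the search graph is treated as present iff its node count matches the
# main flooder graph's node count (an absent search graph has zero nodes).

class MwpmSketch:
    def __init__(self, n_flooder_nodes, n_search_nodes):
        self.n_flooder_nodes = n_flooder_nodes
        self.n_search_nodes = n_search_nodes

    def search_flooder_available(self):
        # present and synchronized iff the node counts match
        return self.n_search_nodes == self.n_flooder_nodes

    def apply_reweights(self, apply_main, apply_search):
        # always reweight the main graph; also reweight the search graph
        # whenever it exists, independently of enable_correlations
        apply_main()
        if self.search_flooder_available():
            apply_search()
```

This captures why the fix matters: the decision to reweight the search graph is based on the graph's existence, not on whether correlations were requested.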
1. Positive-to-Positive Reweighting: Standard update of edge weights.
2. Negative-to-Negative Reweighting: Updates edge weights AND correctly updates the global `negative_weight_sum` to ensure the final result weight is accurate.
3. Sign Flips: Strictly disallowed, raising a `ValueError` (via `std::invalid_argument`).
This was achieved by:
1. Adding `neighbor_markers` to `DetectorNode` to store the original sign of edge weights (using the `WEIGHT_SIGN` flag).
2. Populating `neighbor_markers` in `MatchingGraph::add_edge` and `MatchingGraph::add_boundary_edge`.
3. Updating `MatchingGraph::apply_temp_reweights` to:
   * Check for sign consistency using `neighbor_markers`.
   * Update `negative_weight_sum` (and track the delta) if reweighting a negative edge.
4. Updating `MatchingGraph::undo_reweights` to revert `negative_weight_sum`.
5. Updating `SearchGraph::apply_temp_reweights` to also check for sign consistency.
6. Crucially, updating `mwpm_decoding.cc` to use `mwpm.flooder.graph.negative_weight_sum` instead of the stale `mwpm.flooder.negative_weight_sum`, ensuring the dynamic updates are used in the final weight calculation.
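The sign rules above can be modelled with a small Python sketch. All names and storage details here are illustrative assumptions, not the C++ implementation:

```python
# Illustrative model: same-sign reweights are allowed, sign flips raise,
# and negative-edge reweights keep negative_weight_sum accurate.

class SignedReweightGraph:
    def __init__(self):
        self.weights = {}            # (u, v) -> current weight
        self.signs = {}              # (u, v) -> original sign marker (True = negative)
        self.negative_weight_sum = 0.0
        self._undo = []              # stack of (edge, old_weight, neg_delta)

    def add_edge(self, u, v, weight):
        e = (min(u, v), max(u, v))
        self.weights[e] = weight
        self.signs[e] = weight < 0
        if weight < 0:
            self.negative_weight_sum += weight

    def apply_temp_reweights(self, reweights):
        for u, v, new_w in reweights:
            e = (min(u, v), max(u, v))
            if (new_w < 0) != self.signs[e]:
                # sign flips are strictly disallowed
                raise ValueError("sign flips are not allowed when reweighting")
            delta = 0.0
            if self.signs[e]:
                # keep the running sum of negative weights accurate
                delta = new_w - self.weights[e]
                self.negative_weight_sum += delta
            self._undo.append((e, self.weights[e], delta))
            self.weights[e] = new_w

    def undo_reweights(self):
        while self._undo:
            e, old_w, delta = self._undo.pop()
            self.weights[e] = old_w
            self.negative_weight_sum -= delta
```

Tracking the delta per reweight is what lets `undo_reweights` restore both the edge weights and the global sum exactly.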
1. `MatchingGraph` Class Update:
* Added a `std::vector<std::tuple<size_t, int64_t, double>> reweight_buffer` member to `MatchingGraph` in `src/pymatching/sparse_blossom/flooder/graph.h`. This buffer is reused for temporary storage of reweighting instructions.
2. `decode` and `decode_batch` Optimization:
* Updated `src/pymatching/sparse_blossom/driver/user_graph.pybind.cc` to use `mwpm.flooder.graph.reweight_buffer` instead of a local `std::vector`. The buffer is cleared and reserved before use, preserving its capacity across calls.
3. Test Verification:
* Re-ran tests/matching/test_reweight.py and confirmed that all tests pass, ensuring the optimization did not introduce any regressions.
This change minimizes heap allocations during repeated decoding with reweights, improving performance, especially for batch processing.
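The buffer-reuse pattern can be illustrated with a minimal Python sketch. In C++, `clear()` preserves a `std::vector`'s capacity; the Python analogy below is simply reusing one list object across calls (names are illustrative):

```python
# Sketch of the reweight_buffer reuse described above: one container attached
# to the graph is cleared and refilled per decode call, instead of allocating
# a fresh container every time.

class GraphWithBuffer:
    def __init__(self):
        self.reweight_buffer = []  # persists across decode calls

    def fill_reweight_buffer(self, edge_reweights):
        # clear and refill in place; in C++ this keeps the vector's capacity
        self.reweight_buffer.clear()
        for u, v, w in edge_reweights:
            self.reweight_buffer.append((int(u), int(v), float(w)))
        return self.reweight_buffer
```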
1. `apply_temp_reweights_generic` Implementation:
* Implemented a templated helper function `apply_temp_reweights_generic` in `src/pymatching/sparse_blossom/flooder/graph.h`.
* This function encapsulates the logic for iterating over reweights, scaling weights, checking sign consistency (using `neighbor_markers`), updating `neighbor_weights`, and pushing to `previous_weights`.
* It accepts a lambda, `on_negative_edge`, to handle graph-specific side effects (updating `negative_weight_sum`).
2. `MatchingGraph` Refactor:
* Updated `MatchingGraph::apply_temp_reweights` in `src/pymatching/sparse_blossom/flooder/graph.cc` to use the generic helper.
* Passes a lambda that updates `negative_weight_sum` and `negative_weight_sum_delta` when a negative edge is reweighted.
3. `SearchGraph` Refactor:
* Updated `SearchGraph::apply_temp_reweights` in `src/pymatching/sparse_blossom/search/search_graph.cc` to use the generic helper.
* Passes a no-op lambda, since `SearchGraph` does not track `negative_weight_sum` internally.
4. Verification:
* Re-ran `tests/matching/test_reweight.py` and confirmed that all 7 tests passed, ensuring that the refactoring preserved all functionality and correctness guarantees.
The codebase is now cleaner, more maintainable, and retains the robust sign-checking logic implemented earlier.
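The shape of the generic helper can be sketched in Python as a function taking an `on_negative_edge` callback. The names and data layout are illustrative, not the actual template:

```python
# Sketch of the apply_temp_reweights_generic idea: shared iteration,
# sign-check, and save logic in one place, with a callback for
# graph-specific side effects on negative edges.

def apply_temp_reweights_generic(weights, signs, previous_weights, reweights,
                                 on_negative_edge=lambda edge, delta: None):
    # weights/signs: dicts keyed by (u, v); previous_weights: undo stack
    for u, v, new_w in reweights:
        e = (min(u, v), max(u, v))
        if (new_w < 0) != signs[e]:
            raise ValueError("sign flips are not allowed")
        if signs[e]:
            # graph-specific side effect, e.g. MatchingGraph updating
            # negative_weight_sum; SearchGraph passes a no-op here
            on_negative_edge(e, new_w - weights[e])
        previous_weights.append((e, weights[e]))
        weights[e] = new_w
```

The callback is what lets `MatchingGraph` and `SearchGraph` share one implementation while only the former tracks `negative_weight_sum`.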
Great! Hopefully this feature will be available soon!
Description:
This PR introduces functionality to efficiently reweight a subset of edges in the matching graph on a per-shot basis during decoding, without rebuilding the underlying graph structures. This addresses the need for dynamically adjusting edge weights while maintaining high performance (#172).
Key Changes:
Python API:
* Updated `Matching.decode` and `Matching.decode_batch` to accept a new optional argument `edge_reweights`.
* `edge_reweights` accepts a NumPy array (or list of arrays for batch decoding) of shape `(N, 3)`, where each row `[u, v, new_weight]` specifies the new weight for the edge between nodes `u` and `v`.

Core Logic (C++):
* Added a templated helper `apply_temp_reweights_generic` to consolidate the reweighting logic for `MatchingGraph` and `SearchGraph`. This method temporarily updates edge weights in place and stores the original values in a `previous_weights` stack.
* Updated `negative_weight_sum` (and tracked the delta) in `MatchingGraph` to ensure the final solution weight remains accurate.
* Updated the `decode` and `decode_batch` bindings to reuse a single `reweight_buffer` on the graph object, avoiding repeated heap allocations.
* Updated `Mwpm` to include a `search_flooder_available()` check, ensuring reweights are correctly applied to the `SearchGraph` if it exists (e.g., due to a high observable count or the use of correlations).

Tests:
* Added Python tests (`tests/matching/test_reweight.py`) covering single-shot and batch reweighting and restoration of the original weights.
* Added C++ tests in `graph.test.cc` and `search_graph.test.cc` to verify the low-level correctness of the reweighting logic and delta tracking.
* Added benchmarks (`Decode_surface_r11_d11_p100_reweight`, `Decode_surface_r21_d21_p1000_reweight`, `Decode_surface_r21_d21_p1000_reweight_with_correlations`) in `mwpm_decoding.perf.cc` to measure the overhead of dynamic reweighting. These are placed alongside similar non-reweighted benchmarks for easy comparison.

Currently a WIP. TODO: check how weights are handled if they exceed the normalization constant (and handle this appropriately).