Skip to content

Conversation

@oscarhiggott
Copy link
Owner

@oscarhiggott oscarhiggott commented Nov 27, 2025

Description:

This PR introduces functionality to efficiently reweight a subset of edges in the matching graph on a per-shot basis during decoding, without rebuilding the underlying graph structures. This addresses the need for dynamically adjusting edge weights while maintaining high performance (#172).

Key Changes:

  • Python API:

    • Updated Matching.decode and Matching.decode_batch to accept a new optional argument edge_reweights.
    • edge_reweights accepts a NumPy array (or list of arrays for batch decoding) of shape (N, 3), where each row [u, v, new_weight] specifies the new weight for the edge between nodes u and v.
  • Core Logic (C++):

    • Implemented a generic templated helper function apply_temp_reweights_generic to consolidate the reweighting logic for MatchingGraph and SearchGraph. This method temporarily updates edge weights in place and stores the original values in a previous_weights stack.
    • Added a sign-consistency check: reweighting is disallowed if it flips the sign of the edge weight (e.g., positive to negative). This prevents invalidating the pre-calculated negative weight syndrome/observables.
    • Correctly handles Negative-to-Negative reweighting by updating the global negative_weight_sum (and tracking the delta) in MatchingGraph to ensure the final solution weight remains accurate.
    • Optimized decode and decode_batch bindings to reuse a single reweight_buffer on the graph object, avoiding repeated heap allocations.
    • Updated Mwpm to include a search_flooder_available() check, ensuring reweights are correctly applied to the SearchGraph if it exists (e.g., due to a high observable count or the use of correlations).
  • Tests:

    • Added Python unit tests (tests/matching/test_reweight.py) covering:
      • Single-shot and batch reweighting.
      • Positive-to-Positive and Negative-to-Negative updates.
      • Error handling for forbidden sign flips.
      • Correct restoration of weights after decoding.
    • Added C++ unit tests in graph.test.cc and search_graph.test.cc to verify the low-level correctness of the reweighting logic and delta tracking.
    • Added performance benchmarks (Decode_surface_r11_d11_p100_reweight, Decode_surface_r21_d21_p1000_reweight, Decode_surface_r21_d21_p1000_reweight_with_correlations) in mwpm_decoding.perf.cc to measure the overhead of dynamic reweighting. These are placed alongside similar non-reweighted benchmarks for easy comparison.

Currently a WIP. TODO: check how weights are handled if they exceed the normalization constant (and handle this appropriately).

   1. C++ Core (`MatchingGraph` and `SearchGraph`):
       * Added apply_temp_reweights method to MatchingGraph in src/pymatching/sparse_blossom/flooder/graph.cc. This method accepts a list of edge reweights (u, v,
         new_weight), converts the weights to the internal integer representation (handling the normalising constant), updates the edge weights, and saves the previous
         weights for later restoration.
       * Added apply_temp_reweights method to SearchGraph in src/pymatching/sparse_blossom/search/search_graph.cc to support correlated decoding (which uses a separate
         search graph). This method requires the normalising_constant to be passed from the main graph.

   2. C++ Python Bindings (`user_graph.pybind.cc`):
       * Updated decode to accept an optional edge_reweights argument (numpy array of shape (N, 3)). If provided, it applies the temporary reweights to mwpm.flooder.graph
         (and mwpm.search_flooder.graph if correlations are enabled) before decoding, and reverts them using undo_reweights immediately after.
       * Updated decode_batch to accept an optional edge_reweights argument (list of numpy arrays). It iterates through the batch, applying and reverting reweights for each
         shot if a reweight rule is provided for that shot.

   3. Python API (`src/pymatching/matching.py`):
       * Updated Matching.decode signature to include edge_reweights: Optional[np.ndarray].
       * Updated Matching.decode_batch signature to include edge_reweights: Optional[List[Optional[np.ndarray]]].
       * Updated docstrings to explain the new functionality and format of edge_reweights.

   4. Verification:
       * Created tests/matching/test_reweight.py and verified:
           * Single shot decoding with reweighting (standard and boundary edges).
           * Batch decoding with per-shot reweighting.
           * Weights are correctly restored after decoding.
       * Ran existing tests (decode_test.py) to ensure no regressions.

  The feature allows users to dynamically adjust edge weights for individual shots or batches without the performance penalty of rebuilding the Mwpm object.

  Example Usage:

    1 import pymatching
    2 import numpy as np
    3
    4 m = pymatching.Matching()
    5 m.add_edge(0, 1, weight=2)
    6 m.add_edge(1, 2, weight=2)
    7
    8 # Normal decode
    9 # path 0-1-2 has weight 4
   10 m.decode(np.array([1, 0, 1]))
   11
   12 # Decode with temporary reweight of edge (0, 1) to weight 5
   13 # New path weight 5+2 = 7
   14 reweights = np.array([[0, 1, 5.0]])
   15 correction, weight = m.decode(np.array([1, 0, 1]), return_weight=True, edge_reweights=reweights)
   16 print(weight) # 7.0
   1. `Mwpm` Class Update:
       * Added search_flooder_available() method to the Mwpm class in src/pymatching/sparse_blossom/matcher/mwpm.h. This method checks if flooder.graph.nodes.size() ==
         search_flooder.graph.nodes.size() to determine if the search graph is present and synchronized.

   2. Reweighting Logic Update:
       * Updated src/pymatching/sparse_blossom/driver/user_graph.pybind.cc to use mwpm.search_flooder_available() instead of enable_correlations when deciding whether to
         apply or undo temporary reweights to the search_flooder.graph. This ensures that if the search graph exists (e.g., due to a large number of observables) but
         enable_correlations is False, the reweights are still correctly applied to it.

   3. Test Coverage:
       * Added a new test case test_decode_reweight_large_observables to tests/matching/test_reweight.py. This test constructs a matching graph with more than 64
         observables (triggering the creation of the search graph) but decodes with enable_correlations=False. It verifies that reweighting an edge correctly affects the
         decoding result, confirming that the search graph is being reweighted.

  Files Modified:
   * src/pymatching/sparse_blossom/matcher/mwpm.h
   * src/pymatching/sparse_blossom/driver/user_graph.pybind.cc
   * tests/matching/test_reweight.py
   1. Positive-to-Positive Reweighting: Standard update of edge weights.
   2. Negative-to-Negative Reweighting: Updates edge weights AND correctly updates the global negative_weight_sum to ensure the final result weight is accurate.
   3. Sign Flips: Strictly disallowed, raising a ValueError (via std::invalid_argument).

  This was achieved by:
   1. Adding neighbor_markers to DetectorNode to store the original sign of edge weights (using the WEIGHT_SIGN flag).
   2. Populating neighbor_markers in MatchingGraph::add_edge and MatchingGraph::add_boundary_edge.
   3. Updating MatchingGraph::apply_temp_reweights to:
       * Check for sign consistency using neighbor_markers.
       * Update negative_weight_sum (and track delta) if reweighting a negative edge.
   4. Updating MatchingGraph::undo_reweights to revert negative_weight_sum.
   5. Updating SearchGraph::apply_temp_reweights to also check for sign consistency.
   6. Crucially, updating mwpm_decoding.cc to use mwpm.flooder.graph.negative_weight_sum instead of the stale mwpm.flooder.negative_weight_sum, ensuring the dynamic
      updates are used in the final weight calculation.
   1. `MatchingGraph` Class Update:
       * Added std::vector<std::tuple<size_t, int64_t, double>> reweight_buffer member to MatchingGraph in src/pymatching/sparse_blossom/flooder/graph.h. This buffer is
         reused for temporary storage of reweighting instructions.

   2. `decode` and `decode_batch` Optimization:
       * Updated src/pymatching/sparse_blossom/driver/user_graph.pybind.cc to use mwpm.flooder.graph.reweight_buffer instead of a local std::vector. The buffer is
         cleared and reserved before use, preserving its capacity across calls.

   3. Test Verification:
       * Re-ran tests/matching/test_reweight.py and confirmed that all tests pass, ensuring the optimization did not introduce any regressions.

  This change minimizes heap allocations during repeated decoding with reweights, improving performance especially for batch processing.
   1. `apply_temp_reweights_generic` Implementation:
       * Implemented a templated helper function apply_temp_reweights_generic in src/pymatching/sparse_blossom/flooder/graph.h.
       * This function encapsulates the logic for iterating reweights, scaling weights, checking sign consistency (using neighbor_markers), updating neighbor_weights,
         and pushing to previous_weights.
       * It accepts a lambda on_negative_edge to handle graph-specific side effects (updating negative_weight_sum).

   2. `MatchingGraph` Refactor:
       * Updated MatchingGraph::apply_temp_reweights in src/pymatching/sparse_blossom/flooder/graph.cc to use the generic helper.
       * Passes a lambda that updates negative_weight_sum and negative_weight_sum_delta when a negative edge is reweighted.

   3. `SearchGraph` Refactor:
       * Updated SearchGraph::apply_temp_reweights in src/pymatching/sparse_blossom/search/search_graph.cc to use the generic helper.
       * Passes a no-op lambda since SearchGraph does not track negative_weight_sum internally.

   4. Verification:
       * Re-ran tests/matching/test_reweight.py and confirmed that all 7 tests passed, ensuring that the refactoring preserved all functionality and correctness
         guarantees.

  The codebase is now cleaner, more maintainable, and retains the robust sign-checking logic implemented earlier.
@qHaipengXie
Copy link

Great! Hopefully this feature will be available soon!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants