Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions ionq_core/results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Pure-Python results post-processing helpers for IonQ's probability mappings.
"""

import math
from collections.abc import Mapping, Sequence

__all__ = ["expectation_z", "marginal", "probabilities_to_counts", "relabel_to_bitstrings"]


def _validate_probabilities(probabilities: Mapping[str, float]) -> None:
"""Validate that all probabilities are finite and non-negative."""
for state, prob in probabilities.items():
if not math.isfinite(prob) or prob < 0.0:
raise ValueError(f"Probability for state '{state}' must be finite and non-negative, got {prob}.")


def probabilities_to_counts(probabilities: Mapping[str, float], shots: int) -> dict[str, int]:
"""
Convert a probability mapping to exact integer counts summing to `shots`.

Uses the largest-remainder method (Hare quota) to handle floating-point
rounding errors and guarantee the final counts sum perfectly to `shots`.
"""
if shots < 1:
raise ValueError(f"Shots must be at least 1, got {shots}.")

_validate_probabilities(probabilities)

base_counts = {}
remainders = {}

for state, prob in probabilities.items():
exact = prob * shots
base = math.floor(exact)
base_counts[state] = base
remainders[state] = exact - base

shortfall = shots - sum(base_counts.values())

# Sort by remainder descending.
# Tie-breaker: sort by integer state ascending to make it deterministic.
sorted_states = sorted(remainders.keys(), key=lambda s: (-remainders[s], int(s)))

counts = base_counts.copy()
for i in range(shortfall):
counts[sorted_states[i]] += 1

# Only return states that actually have at least 1 count to keep the dict clean
return {k: v for k, v in counts.items() if v > 0}


def relabel_to_bitstrings(probabilities: Mapping[str, float], num_qubits: int) -> dict[str, float]:
"""Convert integer state keys to zero-padded big-endian bitstrings."""
if num_qubits < 1:
raise ValueError(f"num_qubits must be at least 1, got {num_qubits}.")

_validate_probabilities(probabilities)
max_state = (1 << num_qubits) - 1

result = {}
for state, prob in probabilities.items():
state_int = int(state)
if state_int < 0 or state_int > max_state:
raise ValueError(f"State integer {state_int} is out of bounds for {num_qubits} qubits.")

bitstring = f"{state_int:0{num_qubits}b}"
result[bitstring] = prob

return result


def marginal(probabilities: Mapping[str, float], qubits: Sequence[int], num_qubits: int) -> dict[str, float]:
"""
Compute the marginal distribution over a specified subset of qubits.
Maintains the requested order of the subset qubits in the new state keys.
"""
if not qubits:
raise ValueError("Must specify at least one qubit index to marginalize over.")
if len(set(qubits)) != len(qubits):
raise ValueError("Qubit indices must be unique.")
for q in qubits:
if q < 0 or q >= num_qubits:
raise ValueError(f"Qubit index {q} is out of bounds for {num_qubits} qubits.")

_validate_probabilities(probabilities)

result: dict[str, float] = {}
for state, prob in probabilities.items():
state_int = int(state)
new_state_int = 0

# Extract bits big-endian style: qubit 0 is the most significant bit
for i, q in enumerate(qubits):
bit = (state_int >> (num_qubits - 1 - q)) & 1
new_state_int |= bit << (len(qubits) - 1 - i)

new_state_str = str(new_state_int)
result[new_state_str] = result.get(new_state_str, 0.0) + prob

return result


def expectation_z(probabilities: Mapping[str, float], num_qubits: int) -> float:
"""
Calculate the Z-parity expectation value: Σ p(x)·(-1)^popcount(x).
"""
_validate_probabilities(probabilities)
max_state = (1 << num_qubits) - 1
expected_value = 0.0

for state, prob in probabilities.items():
state_int = int(state)
if state_int < 0 or state_int > max_state:
raise ValueError(f"State integer {state_int} is out of bounds for {num_qubits} qubits.")

# Z-parity is 1 if popcount is even, -1 if popcount is odd
parity = 1 if state_int.bit_count() % 2 == 0 else -1
expected_value += prob * parity

return expected_value
140 changes: 140 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
Tests for pure-Python results post-processing helpers.
"""

import math

import pytest

from ionq_core.results import (
expectation_z,
marginal,
probabilities_to_counts,
relabel_to_bitstrings,
)

# --- Fixtures ---


@pytest.fixture
def bell_state() -> dict[str, float]:
"""A standard two-qubit Bell state response."""
return {"0": 0.5, "3": 0.5}


@pytest.fixture
def ghz_state() -> dict[str, float]:
"""A three-qubit GHZ state."""
return {"0": 0.5, "7": 0.5}


# --- Validation Tests ---


def test_invalid_probabilities():
"""Ensure non-finite and negative probabilities are rejected."""
with pytest.raises(ValueError, match="finite and non-negative"):
probabilities_to_counts({"0": -0.5}, 100)

with pytest.raises(ValueError, match="finite and non-negative"):
probabilities_to_counts({"0": math.inf}, 100)

with pytest.raises(ValueError, match="finite and non-negative"):
probabilities_to_counts({"0": math.nan}, 100)


# --- probabilities_to_counts Tests ---


def test_probabilities_to_counts_bell(bell_state):
"""Test standard perfect distribution."""
counts = probabilities_to_counts(bell_state, 100)
assert counts == {"0": 50, "3": 50}


def test_probabilities_to_counts_rounding():
"""Test largest-remainder method with tricky fractions and deterministic tie-breaking."""
probs = {"0": 1 / 3, "1": 1 / 3, "2": 1 / 3}
counts = probabilities_to_counts(probs, 10)
# Exact is 3.333 each. Base is 3, 3, 3. Shortfall is 1.
# Tie-breaker should pick the lowest integer state ("0") to get the +1.
assert counts == {"0": 4, "1": 3, "2": 3}
assert sum(counts.values()) == 10


def test_probabilities_to_counts_invalid_shots(bell_state):
with pytest.raises(ValueError, match="at least 1"):
probabilities_to_counts(bell_state, 0)


# --- relabel_to_bitstrings Tests ---


def test_relabel_to_bitstrings_bell(bell_state):
result = relabel_to_bitstrings(bell_state, 2)
assert result == {"00": 0.5, "11": 0.5}


def test_relabel_to_bitstrings_invalid_qubits(bell_state):
with pytest.raises(ValueError, match="at least 1"):
relabel_to_bitstrings(bell_state, 0)


def test_relabel_to_bitstrings_out_of_bounds():
with pytest.raises(ValueError, match="out of bounds"):
relabel_to_bitstrings({"4": 1.0}, 2)


# --- marginal Tests ---


def test_marginal_bell(bell_state):
"""Marginalizing a Bell state on either qubit gives a 50/50 mix."""
res_q0 = marginal(bell_state, [0], 2)
assert res_q0 == {"0": 0.5, "1": 0.5}

res_q1 = marginal(bell_state, [1], 2)
assert res_q1 == {"0": 0.5, "1": 0.5}


def test_marginal_ghz_subset(ghz_state):
"""Extracting qubits 0 and 2 from a 3-qubit GHZ state."""
# Qubit 0 and 2 from |000> is |00> (state 0). From |111> is |11> (state 3).
res = marginal(ghz_state, [0, 2], 3)
assert res == {"0": 0.5, "3": 0.5}


def test_marginal_invalid_inputs(bell_state):
with pytest.raises(ValueError, match="at least one qubit"):
marginal(bell_state, [], 2)

with pytest.raises(ValueError, match="unique"):
marginal(bell_state, [0, 0], 2)

with pytest.raises(ValueError, match="out of bounds"):
marginal(bell_state, [2], 2)

with pytest.raises(ValueError, match="out of bounds"):
marginal(bell_state, [-1], 2)


# --- expectation_z Tests ---


def test_expectation_z_bell(bell_state):
"""
Z-parity of |00> (popcount 0) is 1.
Z-parity of |11> (popcount 2) is 1.
Total expectation: 0.5*1 + 0.5*1 = 1.0
"""
assert expectation_z(bell_state, 2) == 1.0


def test_expectation_z_odd_parity():
"""State '1' is '01', popcount 1 -> parity -1."""
assert expectation_z({"1": 1.0}, 2) == -1.0


def test_expectation_z_out_of_bounds():
with pytest.raises(ValueError, match="out of bounds"):
expectation_z({"4": 1.0}, 2)