diff --git a/src/openfermion/measurements/fermion_partitioning.py b/src/openfermion/measurements/fermion_partitioning.py index e4d67b9c0..d0cf3529e 100644 --- a/src/openfermion/measurements/fermion_partitioning.py +++ b/src/openfermion/measurements/fermion_partitioning.py @@ -10,13 +10,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator, Iterable +from typing import Any, TypeVar + import numpy from openfermion.measurements import partition_iterator MAX_LOOPS = 1e6 +T = TypeVar('T') + +# A yielded pairing is a tuple whose elements are usually (T, T) pairs. +# Some code paths also append bare T values for unpaired leftovers. +Pairing = tuple[tuple[T, T] | T, ...] + -def pair_within(labels: list) -> list: +def pair_within(labels: list[T]) -> Generator[Pairing, None, None]: """ Generates pairings of labels that contain each pair at least once. @@ -51,24 +60,26 @@ def pair_within(labels: list) -> list: # Determine fragment size fragment_size = len(labels) // 2 - frag1 = labels[:fragment_size] - frag2 = labels[fragment_size:] + frag1_part: list[T] = labels[:fragment_size] + frag2: list[T] = labels[fragment_size:] - for pairing in pair_between(frag1, frag2, len(frag2) % 2): + for pairing in pair_between(frag1_part, frag2, len(frag2) % 2): yield pairing if len(labels) % 4 == 1: - frag1.append(None) + frag1_pairings = pair_within([*frag1_part, None]) + else: + frag1_pairings = pair_within(frag1_part) - for pairing1, pairing2 in zip(pair_within(frag1), pair_within(frag2)): + for pairing1, pairing2 in zip(frag1_pairings, pair_within(frag2)): if len(labels) % 4 == 1: if pairing1[-1] is None: yield pairing1[:-1] + pairing2 else: extra_pair = ((pairing1[-1], pairing2[-1]),) (zero_index,) = [pair[0] for pair in pairing1[:-1] if pair[1] is None] - pairing1 = tuple(pair for pair in pairing1[:-1] if pair[1] is not None) - yield pairing1 + pairing2[:-1] + extra_pair + (zero_index,) + pairing1_filtered = tuple(pair for pair in pairing1[:-1] if pair[1] is not None) + yield pairing1_filtered + pairing2[:-1] + extra_pair + (zero_index,) elif len(labels) % 4 == 2: extra_pair = ((pairing1[-1], pairing2[-1]),) @@ -81,7 +92,9 @@ def pair_within(labels: list) -> list: yield pairing1 + pairing2 -def pair_between(frag1: list, frag2: list, start_offset: int = 0) -> tuple: +def pair_between( + frag1: list[T], frag2: list[T], start_offset: int = 0 +) -> Generator[Pairing, None, None]: """Pairs between two fragments of a larger list A pairing of a list is a set of pairs of list elements. E.g. a pairing of @@ -115,6 +128,7 @@ def pair_between(frag1: list, frag2: list, start_offset: int = 0) -> tuple: num_pairs = min(len(frag1), len(frag2)) for index_offset in range(start_offset, num_iter): + pairing: Pairing if len(frag1) > len(frag2): pairing = tuple( (frag1[(index + index_offset) % len(frag1)], frag2[index]) @@ -197,7 +211,7 @@ def _gen_pairings_between_partitions(parta, partb): yield pair_a + pair_b + pair_ab -def pair_within_simultaneously(labels: list) -> tuple: +def pair_within_simultaneously(labels: list[T]) -> Generator[Pairing, None, None]: """Generates simultaneous pairings between four-element combinations A pairing of a list is a set of pairs of list elements. E.g. a pairing of @@ -230,7 +244,7 @@ def pair_within_simultaneously(labels: list) -> tuple: for partition in _gen_partitions(labels): generator_list = [_loop_iterator(pair_within, partition[j]) for j in range(len(partition))] for dummy1 in range(len(partition[-2]) - 1 + len(partition[-2]) % 2): - pairing = tuple() + pairing: Pairing = () for generator in generator_list[::2]: pairing = pairing + next(generator)[0] for dummy2 in range(len(partition[-1]) - 1 + len(partition[-1]) % 2): @@ -277,7 +291,9 @@ def _get_padding(num_bins, bin_size): trial_size += 1 -def _asynchronous_iter(iterators, flatten=False): +def _asynchronous_iter( + iterators: Iterable[Iterable[Any]], flatten: bool = False +) -> Generator[tuple[Any, ...], None, None]: """ Iterates over a set of K iterators with max L elements to generate all pairs between them in O(L^2 + 2L log(L) + log(L)^2), @@ -367,7 +383,9 @@ def _parallel_iter(iterators, flatten=False): yield tuple(next_result) -def pair_within_simultaneously_binned(binned_majoranas: list) -> tuple: +def pair_within_simultaneously_binned( + binned_majoranas: list[list[T]], +) -> Generator[Pairing, None, None]: """Generates symmetry-respecting pairings between four-elements in a list A pairing of a list is a set of pairs of list elements. E.g. a pairing of @@ -426,7 +444,9 @@ def pair_within_simultaneously_binned(binned_majoranas: list) -> tuple: yield pairing -def pair_within_simultaneously_symmetric(num_fermions: int, num_symmetries: int) -> tuple: +def pair_within_simultaneously_symmetric( + num_fermions: int, num_symmetries: int +) -> Generator[Pairing, None, None]: """Generates symmetry-respecting pairings between four-elements in a list A pairing of a list is a set of pairs of list elements. E.g. a pairing of