From 04b1058c68de8a7bddd82697f83420679a2f64d8 Mon Sep 17 00:00:00 2001 From: Zach Anderson Date: Wed, 7 Jan 2026 01:56:43 -0600 Subject: [PATCH] Explictly call Ord::cmp to compare priorities Previously we would call `PartialOrd::partial_cmp` implicilty using the overloaded comparison operators `>` and `<` when comparing priorities, despite the fact that P is required to be `Ord`. A well behaved implementation of `partial_cmp` is supposed to return `Some(Ord::cmp(self, other))`. That relationship is only a convention, so a misbehaved implementation may return None, which may cause the order items are popped from the queue to behave seemingly randomly. We can be a bit more defensive here and instead always call `Ord::cmp` directly, ensuring that we never try to compare things that could possibly return `None`. In order to enforce this going forward I added a test that panics in the implementation for partial_cmp and exercised all of the code paths that might call it. This isn't perfect, since new callsites could be added, but I figure its probably good enough for now. Not sure exactly how to version this, it's not a breaking change if you implementation of `PartialOrd` follows the convention, but if you don't it would be a breaking change, since we'd be using a different function for comparisons. --- src/double_priority_queue/mod.rs | 45 ++++++++++++++++++++++++-------- src/priority_queue/mod.rs | 28 +++++++++++++++----- tests/double_priority_queue.rs | 35 +++++++++++++++++++++++++ tests/priority_queue.rs | 35 +++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 18 deletions(-) diff --git a/src/double_priority_queue/mod.rs b/src/double_priority_queue/mod.rs index 103930a..1f2ca20 100644 --- a/src/double_priority_queue/mod.rs +++ b/src/double_priority_queue/mod.rs @@ -665,7 +665,10 @@ where /// /// Computes in **O(log(N))** time. pub fn push_increase(&mut self, item: I, priority: P) -> Option

{ - if self.get_priority(&item).map_or(true, |p| priority > *p) { + if self + .get_priority(&item) + .map_or(true, |p| priority.cmp(p).is_gt()) + { self.push(item, priority) } else { Some(priority) @@ -705,7 +708,10 @@ where /// /// Computes in **O(log(N))** time. pub fn push_decrease(&mut self, item: I, priority: P) -> Option

{ - if self.get_priority(&item).map_or(true, |p| priority < *p) { + if self + .get_priority(&item) + .map_or(true, |p| priority.cmp(p).is_lt()) + { self.push(item, priority) } else { Some(priority) @@ -901,15 +907,20 @@ where .0; if unsafe { - self.store.get_priority_from_position(i) < self.store.get_priority_from_position(m) + self.store + .get_priority_from_position(i) + .cmp(self.store.get_priority_from_position(m)) + .is_lt() } { self.store.swap(i, m); if i > r { // i is a grandchild of m let p = parent(i); if unsafe { - self.store.get_priority_from_position(i) - > self.store.get_priority_from_position(p) + self.store + .get_priority_from_position(i) + .cmp(self.store.get_priority_from_position(p)) + .is_gt() } { self.store.swap(i, p); } @@ -943,15 +954,20 @@ where .0; if unsafe { - self.store.get_priority_from_position(i) > self.store.get_priority_from_position(m) + self.store + .get_priority_from_position(i) + .cmp(self.store.get_priority_from_position(m)) + .is_gt() } { self.store.swap(i, m); if i > r { // i is a grandchild of m let p = parent(i); if unsafe { - self.store.get_priority_from_position(i) - < self.store.get_priority_from_position(p) + self.store + .get_priority_from_position(i) + .cmp(self.store.get_priority_from_position(p)) + .is_lt() } { self.store.swap(i, p); } @@ -970,7 +986,10 @@ where let parent = parent(position); let parent_priority = unsafe { self.store.get_priority_from_position(parent) }; let parent_index = unsafe { *self.store.heap.get_unchecked(parent.0) }; - position = match (level(position) % 2 == 0, parent_priority < priority) { + position = match ( + level(position) % 2 == 0, + parent_priority.cmp(priority).is_lt(), + ) { // on a min level and greater then parent (true, true) => { unsafe { @@ -1008,7 +1027,9 @@ where let mut grand_parent = Position(0); while if position.0 > 0 && parent(position).0 > 0 { grand_parent = parent(parent(position)); - (unsafe { self.store.get_priority_from_position(grand_parent) }) > priority + (unsafe { self.store.get_priority_from_position(grand_parent) }) + .cmp(priority) + .is_gt() } else { false } { @@ -1027,7 +1048,9 @@ where let mut grand_parent = Position(0); while if position.0 > 0 && parent(position).0 > 0 { grand_parent = parent(parent(position)); - (unsafe { self.store.get_priority_from_position(grand_parent) }) < priority + (unsafe { self.store.get_priority_from_position(grand_parent) }) + .cmp(priority) + .is_lt() } else { false } { diff --git a/src/priority_queue/mod.rs b/src/priority_queue/mod.rs index 5eb66bc..a2c00bb 100644 --- a/src/priority_queue/mod.rs +++ b/src/priority_queue/mod.rs @@ -539,7 +539,10 @@ where /// /// Computes in **O(log(N))** time. pub fn push_increase(&mut self, item: I, priority: P) -> Option

{ - if self.get_priority(&item).map_or(true, |p| priority > *p) { + if self + .get_priority(&item) + .map_or(true, |p| priority.cmp(p).is_gt()) + { self.push(item, priority) } else { Some(priority) @@ -579,7 +582,10 @@ where /// /// Computes in **O(log(N))** time. pub fn push_decrease(&mut self, item: I, priority: P) -> Option

{ - if self.get_priority(&item).map_or(true, |p| priority < *p) { + if self + .get_priority(&item) + .map_or(true, |p| priority.cmp(p).is_lt()) + { self.push(item, priority) } else { Some(priority) @@ -762,12 +768,16 @@ where let mut largestp = unsafe { self.store.get_priority_from_position(i) }; if l.0 < self.len() { let childp = unsafe { self.store.get_priority_from_position(l) }; - if childp > largestp { + if childp.cmp(largestp).is_gt() { largest = l; largestp = childp; } - if r.0 < self.len() && unsafe { self.store.get_priority_from_position(r) } > largestp { + if r.0 < self.len() + && unsafe { self.store.get_priority_from_position(r) } + .cmp(largestp) + .is_gt() + { largest = r; } } @@ -780,14 +790,16 @@ where l = left(i); if l.0 < self.len() { let childp = unsafe { self.store.get_priority_from_position(l) }; - if childp > largestp { + if childp.cmp(largestp).is_gt() { largest = l; largestp = childp; } r = right(i); if r.0 < self.len() - && unsafe { self.store.get_priority_from_position(r) } > largestp + && unsafe { self.store.get_priority_from_position(r) } + .cmp(largestp) + .is_gt() { largest = r; } @@ -802,7 +814,9 @@ where let mut parent_position = Position(0); while if position.0 > 0 { parent_position = parent(position); - (unsafe { self.store.get_priority_from_position(parent_position) }) < priority + (unsafe { self.store.get_priority_from_position(parent_position) }) + .cmp(priority) + .is_lt() } else { false } { diff --git a/tests/double_priority_queue.rs b/tests/double_priority_queue.rs index c5bacf4..5330029 100644 --- a/tests/double_priority_queue.rs +++ b/tests/double_priority_queue.rs @@ -1275,6 +1275,41 @@ mod doublepq_tests { ); } } + + #[test] + fn partial_cmp_not_called() { + use std::cmp::{Ordering, PartialOrd}; + + #[derive(Debug, PartialEq, Eq, Hash, Ord)] + struct PanicPartial(i64); + + // This is an invalid implementation of PartialOrd according to + // the docs in `std::cmp::PartialOrd`, since Ord is also implemented, + // this should always return Some(Ord::cmp(self, other)). Instead this + // function panics as a way to ensure that we don't accidently + // use PartialOrd::partial_cmp when we should be using Ord::cmp + // instead. Enforcing the explicit use of Ord::cmp lets us rely on + // the compiler instead of the convention that PartialOrd::partial_cmp + // _should_ call Ord::cmp + impl PartialOrd for PanicPartial { + fn partial_cmp(&self, _other: &Self) -> Option { + panic!("partial_cmp should not be called"); + } + } + + let mut dpq = DoublePriorityQueue::new(); + dpq.push(0, PanicPartial(100)); + dpq.push(1, PanicPartial(200)); + dpq.push(2, PanicPartial(150)); + dpq.push_increase(2, PanicPartial(300)); + dpq.push_decrease(2, PanicPartial(0)); + + // These asserts are redundant since this behavior is tested elsewhere, we're + // mainly just interested in not panicking for this test. + assert_eq!(dpq.pop_min(), Some((2, PanicPartial(0)))); + assert_eq!(dpq.pop_max(), Some((1, PanicPartial(200)))); + assert_eq!(dpq.pop_min(), Some((0, PanicPartial(100)))); + } } #[cfg(all(feature = "serde", test))] diff --git a/tests/priority_queue.rs b/tests/priority_queue.rs index 93b3cab..25070c5 100644 --- a/tests/priority_queue.rs +++ b/tests/priority_queue.rs @@ -1143,6 +1143,41 @@ mod pqueue_tests { assert_eq!(removed_priority, 200); assert!(!pq.contains(&bob_view)); } + + #[test] + fn partial_cmp_not_called() { + use std::cmp::{Ordering, PartialOrd}; + + #[derive(Debug, PartialEq, Eq, Hash, Ord)] + struct PanicPartial(i64); + + // This is an invalid implementation of PartialOrd according to + // the docs in `std::cmp::PartialOrd`, since Ord is also implemented, + // this should always return Some(Ord::cmp(self, other)). Instead this + // function panics as a way to ensure that we don't accidently + // use PartialOrd::partial_cmp when we should be using Ord::cmp + // instead. Enforcing the explicit use of Ord::cmp lets us rely on + // the compiler instead of the convention that PartialOrd::partial_cmp + // _should_ call Ord::cmp + impl PartialOrd for PanicPartial { + fn partial_cmp(&self, _other: &Self) -> Option { + panic!("partial_cmp should not be called"); + } + } + + let mut pq = PriorityQueue::new(); + pq.push(0, PanicPartial(100)); + pq.push(1, PanicPartial(200)); + pq.push(2, PanicPartial(150)); + pq.push_increase(2, PanicPartial(300)); + pq.push_decrease(2, PanicPartial(0)); + + // These asserts are redundant since this behavior is tested elsewhere, we're + // mainly just interested in not panicking for this test. + assert_eq!(pq.pop(), Some((1, PanicPartial(200)))); + assert_eq!(pq.pop(), Some((0, PanicPartial(100)))); + assert_eq!(pq.pop(), Some((2, PanicPartial(0)))); + } } #[cfg(all(feature = "serde", test))]