diff --git a/crates/core/src/types.rs b/crates/core/src/types.rs index 8d971f7d..d332ef30 100644 --- a/crates/core/src/types.rs +++ b/crates/core/src/types.rs @@ -69,6 +69,25 @@ impl DutyType { pub fn is_valid(&self) -> bool { !matches!(self, DutyType::Unknown | DutyType::DutySentinel(_)) } + + /// Returns all valid duty types, matching Go's `AllDutyTypes()`. + pub fn all() -> &'static [DutyType] { + &[ + DutyType::Proposer, + DutyType::Attester, + DutyType::Signature, + DutyType::Exit, + DutyType::BuilderProposer, + DutyType::BuilderRegistration, + DutyType::Randao, + DutyType::PrepareAggregator, + DutyType::Aggregator, + DutyType::SyncMessage, + DutyType::PrepareSyncContribution, + DutyType::SyncContribution, + DutyType::InfoSync, + ] + } } /// Error type for duty type conversion. @@ -400,7 +419,7 @@ impl AsRef<[u8]> for PubKey { // todo: add toEth2Format for the pub key // https://github.com/ObolNetwork/charon/blob/b3008103c5429b031b63518195f4c49db4e9a68d/core/types.go#L311 -/// Duty definition type +/// Duty definition type. #[derive(Debug, Clone, PartialEq, Eq)] pub struct DutyDefinition(T); @@ -412,11 +431,16 @@ where pub fn new(duty_definition: T) -> Self { Self(duty_definition) } + + /// Inner value. + pub fn inner(&self) -> &T { + &self.0 + } } -/// Duty definition set +/// One duty definition per validator, matching Go's `core.DutyDefinitionSet`. #[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct DutyDefinitionSet(HashMap>) +pub struct DutyDefinitionSet(HashMap>) where T: Clone + Serialize + StdDebug; @@ -429,28 +453,33 @@ where Self(HashMap::default()) } - /// Get a duty definition by duty type. - pub fn get(&self, duty_type: &DutyType) -> Option<&DutyDefinition> { - self.0.get(duty_type) + /// Get a duty definition by public key. + pub fn get(&self, pubkey: &PubKey) -> Option<&DutyDefinition> { + self.0.get(pubkey) } /// Insert a duty definition. - pub fn insert(&mut self, duty_type: DutyType, duty_definition: DutyDefinition) { - self.0.insert(duty_type, duty_definition); + pub fn insert(&mut self, pubkey: PubKey, duty_definition: DutyDefinition) { + self.0.insert(pubkey, duty_definition); } - /// Remove a duty definition by duty type. - pub fn remove(&mut self, duty_type: &DutyType) -> Option> { - self.0.remove(duty_type) + /// Remove a duty definition by public key. + pub fn remove(&mut self, pubkey: &PubKey) -> Option> { + self.0.remove(pubkey) + } + + /// Iterate over all public keys in the set. + pub fn keys(&self) -> impl Iterator { + self.0.keys() } - /// Inner duty definition set. - pub fn inner(&self) -> &HashMap> { + /// Inner map. + pub fn inner(&self) -> &HashMap> { &self.0 } - /// Inner duty definition set. - pub fn inner_mut(&mut self) -> &mut HashMap> { + /// Inner map (mutable). + pub fn inner_mut(&mut self) -> &mut HashMap> { &mut self.0 } } @@ -996,14 +1025,24 @@ mod tests { assert_eq!(pk.abbreviated(), "2a2_a2a"); } + #[test] + fn duty_type_all() { + let all = DutyType::all(); + assert_eq!(all.len(), 13); + assert!(all.iter().all(DutyType::is_valid)); + assert!(!all.contains(&DutyType::Unknown)); + } + #[test] fn duty_definition_set() { - let mut duty_definition_set = DutyDefinitionSet::new(); - duty_definition_set.insert(DutyType::Proposer, DutyDefinition::new(DutyType::Proposer)); + let pubkey = PubKey::new([1u8; PK_LEN]); + let mut set = DutyDefinitionSet::new(); + set.insert(pubkey, DutyDefinition::new(DutyType::Proposer)); assert_eq!( - duty_definition_set.get(&DutyType::Proposer), + set.get(&pubkey), Some(&DutyDefinition::new(DutyType::Proposer)) ); + assert_eq!(set.keys().count(), 1); } #[test]