Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions crates/core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,25 @@ impl DutyType {
pub fn is_valid(&self) -> bool {
!matches!(self, DutyType::Unknown | DutyType::DutySentinel(_))
}

/// Returns all valid duty types, matching Go's `AllDutyTypes()`.
pub fn all() -> &'static [DutyType] {
&[
DutyType::Proposer,
DutyType::Attester,
DutyType::Signature,
DutyType::Exit,
DutyType::BuilderProposer,
DutyType::BuilderRegistration,
DutyType::Randao,
DutyType::PrepareAggregator,
DutyType::Aggregator,
DutyType::SyncMessage,
DutyType::PrepareSyncContribution,
DutyType::SyncContribution,
DutyType::InfoSync,
]
}
}

/// Error type for duty type conversion.
Expand Down Expand Up @@ -400,7 +419,7 @@ impl AsRef<[u8]> for PubKey {
// todo: add toEth2Format for the pub key
// https://github.com/ObolNetwork/charon/blob/b3008103c5429b031b63518195f4c49db4e9a68d/core/types.go#L311

/// Duty definition type
/// Duty definition type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DutyDefinition<T: Clone + Serialize + StdDebug>(T);

Expand All @@ -412,11 +431,16 @@ where
pub fn new(duty_definition: T) -> Self {
Self(duty_definition)
}

/// Inner value.
pub fn inner(&self) -> &T {
&self.0
}
}

/// Duty definition set
/// One duty definition per validator, matching Go's `core.DutyDefinitionSet`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DutyDefinitionSet<T>(HashMap<DutyType, DutyDefinition<T>>)
pub struct DutyDefinitionSet<T>(HashMap<PubKey, DutyDefinition<T>>)
where
T: Clone + Serialize + StdDebug;

Expand All @@ -429,28 +453,33 @@ where
Self(HashMap::default())
}

/// Get a duty definition by duty type.
pub fn get(&self, duty_type: &DutyType) -> Option<&DutyDefinition<T>> {
self.0.get(duty_type)
/// Get a duty definition by public key.
pub fn get(&self, pubkey: &PubKey) -> Option<&DutyDefinition<T>> {
self.0.get(pubkey)
}

/// Insert a duty definition.
pub fn insert(&mut self, duty_type: DutyType, duty_definition: DutyDefinition<T>) {
self.0.insert(duty_type, duty_definition);
pub fn insert(&mut self, pubkey: PubKey, duty_definition: DutyDefinition<T>) {
self.0.insert(pubkey, duty_definition);
}

/// Remove a duty definition by duty type.
pub fn remove(&mut self, duty_type: &DutyType) -> Option<DutyDefinition<T>> {
self.0.remove(duty_type)
/// Remove a duty definition by public key.
pub fn remove(&mut self, pubkey: &PubKey) -> Option<DutyDefinition<T>> {
self.0.remove(pubkey)
}

/// Iterate over all public keys in the set.
pub fn keys(&self) -> impl Iterator<Item = &PubKey> {
self.0.keys()
}

/// Inner duty definition set.
pub fn inner(&self) -> &HashMap<DutyType, DutyDefinition<T>> {
/// Inner map.
pub fn inner(&self) -> &HashMap<PubKey, DutyDefinition<T>> {
&self.0
}

/// Inner duty definition set.
pub fn inner_mut(&mut self) -> &mut HashMap<DutyType, DutyDefinition<T>> {
/// Inner map (mutable).
pub fn inner_mut(&mut self) -> &mut HashMap<PubKey, DutyDefinition<T>> {
&mut self.0
}
}
Expand Down Expand Up @@ -996,14 +1025,24 @@ mod tests {
assert_eq!(pk.abbreviated(), "2a2_a2a");
}

#[test]
fn duty_type_all() {
let all = DutyType::all();
assert_eq!(all.len(), 13);
assert!(all.iter().all(DutyType::is_valid));
assert!(!all.contains(&DutyType::Unknown));
}

#[test]
fn duty_definition_set() {
let mut duty_definition_set = DutyDefinitionSet::new();
duty_definition_set.insert(DutyType::Proposer, DutyDefinition::new(DutyType::Proposer));
let pubkey = PubKey::new([1u8; PK_LEN]);
let mut set = DutyDefinitionSet::new();
set.insert(pubkey, DutyDefinition::new(DutyType::Proposer));
assert_eq!(
duty_definition_set.get(&DutyType::Proposer),
set.get(&pubkey),
Some(&DutyDefinition::new(DutyType::Proposer))
);
assert_eq!(set.keys().count(), 1);
}

#[test]
Expand Down
Loading