diff --git a/pyrefly/lib/alt/attr.rs b/pyrefly/lib/alt/attr.rs index 96b6a475c0..a0b3a7d48d 100644 --- a/pyrefly/lib/alt/attr.rs +++ b/pyrefly/lib/alt/attr.rs @@ -17,6 +17,7 @@ use pyrefly_types::literal::LitEnum; use pyrefly_types::special_form::SpecialForm; use pyrefly_types::tensor::TensorShape; use pyrefly_types::tensor::TensorType; +use pyrefly_types::tuple::Tuple; use pyrefly_types::typed_dict::TypedDictInner; use pyrefly_types::types::Forall; use pyrefly_types::types::Forallable; @@ -59,7 +60,6 @@ use crate::types::module::ModuleType; use crate::types::quantified::Quantified; use crate::types::quantified::QuantifiedKind; use crate::types::read_only::ReadOnlyReason; -use crate::types::tuple::Tuple; use crate::types::type_var::Restriction; use crate::types::typed_dict::TypedDict; use crate::types::types::AnyStyle; @@ -935,9 +935,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .add_to(errors, range, attr_name, todo_ctx); return None; }; - let (lookup_found, lookup_not_found, lookup_error) = self + let (mut lookup_found, mut lookup_not_found, lookup_error) = self .lookup_attr_from_base(attr_base.clone(), attr_name) .decompose(); + let slot_violations = self.apply_slots_restriction_for_write(attr_name, &mut lookup_found); + if !slot_violations.is_empty() { + should_narrow = false; + lookup_not_found.extend(slot_violations); + } for e in lookup_error { e.add_to(errors, range, attr_name, todo_ctx); should_narrow = false; @@ -1065,9 +1070,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .add_to(errors, range, attr_name, todo_ctx); return; }; - let (lookup_found, lookup_not_found, lookup_error) = self + let (mut lookup_found, mut lookup_not_found, lookup_error) = self .lookup_attr_from_base(attr_base.clone(), attr_name) .decompose(); + let slot_violations = self.apply_slots_restriction_for_write(attr_name, &mut lookup_found); + lookup_not_found.extend(slot_violations); for not_found in lookup_not_found { self.check_delattr( attr_base.clone(), @@ -1106,6 +1113,106 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } + fn apply_slots_restriction_for_write( + &self, + attr_name: &Name, + lookup_found: &mut Vec<(Attribute, AttributeBase1)>, + ) -> Vec { + if lookup_found.is_empty() { + return Vec::new(); + } + let mut kept = Vec::with_capacity(lookup_found.len()); + let mut violations = Vec::new(); + for (attr, base) in lookup_found.drain(..) { + match &attr { + Attribute::ClassAttribute(ClassAttribute::Property(..)) + | Attribute::ClassAttribute(ClassAttribute::Descriptor(..)) => { + kept.push((attr, base)); + continue; + } + _ => {} + } + let Some(class) = self.class_for_slots_restriction(&base) else { + kept.push((attr, base)); + continue; + }; + let Some(slots) = self.slots_for_class(&class) else { + kept.push((attr, base)); + continue; + }; + if slots.contains(attr_name) { + kept.push((attr, base)); + } else { + violations.push(NotFoundOn::ClassInstance( + class.class_object().dupe(), + base.clone(), + )); + } + } + *lookup_found = kept; + violations + } + + fn class_for_slots_restriction(&self, base: &AttributeBase1) -> Option { + match base { + AttributeBase1::ClassInstance(cls) + | AttributeBase1::SelfType(cls) + | AttributeBase1::Quantified(_, cls) + | AttributeBase1::SuperInstance(cls, _) => Some(cls.clone()), + AttributeBase1::EnumLiteral(lit) => Some(lit.class.clone()), + _ => None, + } + } + + fn slots_for_class(&self, cls: &ClassType) -> Option> { + let mro = self.get_mro_for_class(cls.class_object()); + let mut slots = SmallSet::new(); + let dict_name = Name::new_static("__dict__"); + let classes = std::iter::once(cls.class_object().dupe()).chain( + mro.ancestors_no_object() + .iter() + .map(|c| c.class_object().dupe()), + ); + for class in classes { + let Some(field) = self.get_field_from_current_class_only(&class, &dunder::SLOTS) else { + return None; + }; + let Some(names) = self.extract_slot_names_from_type(&field.ty()) else { + return None; + }; + if names.contains(&dict_name) { + return None; + } + slots.extend(names); + } + Some(slots) + } + + fn extract_slot_names_from_type(&self, ty: &Type) -> Option> { + let mut slots = SmallSet::new(); + match ty { + Type::Tuple(Tuple::Concrete(elts)) => { + for elt in elts { + let Type::Literal(lit) = elt else { + return None; + }; + let Lit::Str(name) = &lit.value else { + return None; + }; + slots.insert(Name::new(name.as_str())); + } + } + Type::Literal(lit) => { + let Lit::Str(name) = &lit.value else { + return None; + }; + slots.insert(Name::new(name.as_str())); + } + _ => return None, + } + Some(slots) + } + /// Predicate for whether a specific attribute name matches a protocol during structural /// subtyping checks. /// diff --git a/pyrefly/lib/test/dataclasses.rs b/pyrefly/lib/test/dataclasses.rs index 26b1369eef..5281783d86 100644 --- a/pyrefly/lib/test/dataclasses.rs +++ b/pyrefly/lib/test/dataclasses.rs @@ -1729,7 +1729,6 @@ assert_type(dc2.y, str) # E: assert_type(Desc2[str], str) failed ); testcase!( - bug = "conformance: Dataclass with slots=True should error when setting undeclared attributes", test_dataclass_slots_undeclared_attr_conformance, r#" from dataclasses import dataclass @@ -1741,7 +1740,7 @@ class DC2: def __init__(self): self.x = 3 # should error: y is not in slots - self.y = 3 + self.y = 3 # E: Object of class `DC2` has no attribute `y` @dataclass(slots=False) class DC3: @@ -1751,7 +1750,7 @@ class DC3: def __init__(self): self.x = 3 # should error: y is not in slots - self.y = 3 + self.y = 3 # E: Object of class `DC3` has no attribute `y` "#, );