diff --git a/mypy/checker.py b/mypy/checker.py index 20a825e9cc5e..78e99f99afee 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -85,11 +85,13 @@ CallExpr, ClassDef, ComparisonExpr, + ConditionalExpr, Context, ContinueStmt, Decorator, DelStmt, DictExpr, + DictionaryComprehension, EllipsisExpr, Expression, ExpressionStmt, @@ -98,6 +100,7 @@ FuncBase, FuncDef, FuncItem, + GeneratorExpr, GlobalDecl, IfStmt, Import, @@ -107,6 +110,7 @@ IndexExpr, IntExpr, LambdaExpr, + ListComprehension, ListExpr, Lvalue, MatchStmt, @@ -124,7 +128,9 @@ RaiseStmt, RefExpr, ReturnStmt, + SetComprehension, SetExpr, + SliceExpr, StarExpr, Statement, StrExpr, @@ -4928,6 +4934,8 @@ def infer_context_dependent( return typ # If there are errors with the original type context, try re-inferring in empty context. + # However, skip this fallback if the expression contains assignment expressions (walrus + # operator), as they can cause incorrect type inference when the context is removed. original_messages = msg.filtered_errors() original_type_map = type_map with self.msg.filter_errors( @@ -4937,7 +4945,12 @@ def infer_context_dependent( alt_typ = get_proper_type( self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call) ) - if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx): + + if ( + not msg.has_new_errors() + and is_subtype(alt_typ, type_ctx) + and not self.contains_assignment_expr(expr) + ): self.store_types(type_map) return alt_typ @@ -4979,7 +4992,10 @@ def check_return_stmt(self, s: ReturnStmt) -> None: # Return with a value. if ( - isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr)) + isinstance( + s.expr, + (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr, AssignmentExpr), + ) or isinstance(s.expr, AwaitExpr) and isinstance(s.expr.expr, CallExpr) ): @@ -5057,6 +5073,125 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if self.in_checked_function(): self.fail(message_registry.RETURN_VALUE_EXPECTED, s) + def contains_assignment_expr(self, expr: Expression) -> bool: + """Check if expression contains any AssignmentExpr (walrus operator).""" + # Base case: found an assignment expression + if isinstance(expr, AssignmentExpr): + return True + + # Recursively check nested expressions in various expression types + + # Container expressions + if isinstance(expr, (TupleExpr, ListExpr, SetExpr)): + return any(self.contains_assignment_expr(item) for item in expr.items) + + if isinstance(expr, DictExpr): + # Check both keys and values + # DictExpr.items is list[tuple[Expression | None, Expression]] + for key_expr, value_expr in expr.items: + if key_expr is not None and self.contains_assignment_expr(key_expr): + return True + if self.contains_assignment_expr(value_expr): + return True + return False + + # Binary operations (left and right operands) + if isinstance(expr, OpExpr): + return self.contains_assignment_expr(expr.left) or self.contains_assignment_expr( + expr.right + ) + + # Unary operations + if isinstance(expr, UnaryExpr): + return self.contains_assignment_expr(expr.expr) + + # Comparison expressions (multiple operands) + if isinstance(expr, ComparisonExpr): + return any(self.contains_assignment_expr(operand) for operand in expr.operands) + + # Function calls (check arguments) + if isinstance(expr, CallExpr): + # Check callee and all arguments + if self.contains_assignment_expr(expr.callee): + return True + return any(self.contains_assignment_expr(arg) for arg in expr.args) + + # Index expressions (subscripts) + if isinstance(expr, IndexExpr): + if self.contains_assignment_expr(expr.base): + return True + return self.contains_assignment_expr(expr.index) + + # Member access + if isinstance(expr, MemberExpr): + return self.contains_assignment_expr(expr.expr) + + # Starred expressions (unpacking) + if isinstance(expr, StarExpr): + return self.contains_assignment_expr(expr.expr) + + # Await expressions + if isinstance(expr, AwaitExpr): + return self.contains_assignment_expr(expr.expr) + + # Yield expressions + if isinstance(expr, YieldExpr): + if expr.expr is not None: + return self.contains_assignment_expr(expr.expr) + return False + + # Conditional expressions (ternary operator: x if cond else y) + if isinstance(expr, ConditionalExpr): + return ( + self.contains_assignment_expr(expr.cond) + or self.contains_assignment_expr(expr.if_expr) + or self.contains_assignment_expr(expr.else_expr) + ) + + # Slice expressions (x:y:z) + if isinstance(expr, SliceExpr): + return ( + (expr.begin_index is not None and self.contains_assignment_expr(expr.begin_index)) + or (expr.end_index is not None and self.contains_assignment_expr(expr.end_index)) + or (expr.stride is not None and self.contains_assignment_expr(expr.stride)) + ) + + # Generator expressions and comprehensions + if isinstance(expr, GeneratorExpr): + if self.contains_assignment_expr(expr.left_expr): + return True + for seq in expr.sequences: + if self.contains_assignment_expr(seq): + return True + for condlist in expr.condlists: + for cond in condlist: + if self.contains_assignment_expr(cond): + return True + return False + + if isinstance(expr, ListComprehension): + return self.contains_assignment_expr(expr.generator) + + if isinstance(expr, SetComprehension): + return self.contains_assignment_expr(expr.generator) + + if isinstance(expr, DictionaryComprehension): + if self.contains_assignment_expr(expr.key) or self.contains_assignment_expr( + expr.value + ): + return True + for seq in expr.sequences: + if self.contains_assignment_expr(seq): + return True + for condlist in expr.condlists: + for cond in condlist: + if self.contains_assignment_expr(cond): + return True + return False + + # All other expression types (NameExpr, IntExpr, StrExpr, etc.) don't contain nested expressions + return False + def visit_if_stmt(self, s: IfStmt) -> None: """Type check an if statement.""" # This frame records the knowledge from previous if/elif clauses not being taken. @@ -6114,8 +6249,35 @@ def conditional_callable_type_map( if not current_type: return {}, {} - if isinstance(get_proper_type(current_type), AnyType): - return {}, {} + proper_type = get_proper_type(current_type) + if isinstance(proper_type, AnyType): + # Narrow Any to a generic callable type to satisfy no-any-return in strict mode. + # We use a synthesized fallback "" to preserve attribute + # access (fixing regressions in sympy, pandas, etc.) without triggering + # metaclass-related internal errors or breaking invariant subtyping. + obj_fallback = self.named_type("builtins.object") + if obj_fallback.type.fullname == "builtins.object": + cdef = nodes.ClassDef("", nodes.Block([])) + cdef._fullname = "" + info = TypeInfo(nodes.SymbolTable(), cdef, "") + info.mro = obj_fallback.type.mro + info.bases = obj_fallback.type.bases + info.fallback_to_any = True + fallback_instance = Instance(info, []) + + return { + expr: CallableType( + [ + AnyType(TypeOfAny.from_another_any, source_any=proper_type), + AnyType(TypeOfAny.from_another_any, source_any=proper_type), + ], + [nodes.ARG_STAR, nodes.ARG_STAR2], + [None, None], + ret_type=AnyType(TypeOfAny.from_another_any, source_any=proper_type), + fallback=fallback_instance, + is_ellipsis_args=True, + ) + }, {} callables, uncallables = self.partition_by_callable(current_type, unsound_partition=False) diff --git a/mypy/nodes.py b/mypy/nodes.py index 4168b2e00f15..f377c27c822f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -3829,7 +3829,9 @@ def is_metaclass(self, *, precise: bool = False) -> bool: return ( self.has_base("builtins.type") or self.fullname == "abc.ABCMeta" - or (self.fallback_to_any and not precise) + or ( + self.fallback_to_any and not precise and self.fullname != "" + ) ) def has_base(self, fullname: str) -> bool: diff --git a/test-data/unit/check-callable.test b/test-data/unit/check-callable.test index f3ec6ec5f939..6e5c75db3537 100644 --- a/test-data/unit/check-callable.test +++ b/test-data/unit/check-callable.test @@ -188,7 +188,7 @@ from typing import Any x = 5 # type: Any if callable(x): - reveal_type(x) # N: Revealed type is "Any" + reveal_type(x) # N: Revealed type is "def (*Any, **Any) -> Any" else: reveal_type(x) # N: Revealed type is "Any" [builtins fixtures/callable.pyi]