diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 30d117e42c71..6018634b4c48 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -227,6 +227,7 @@ def __init__( self.nested_fitems = pbv.nested_funcs.keys() self.fdefs_to_decorators = pbv.funcs_to_decorators self.module_import_groups = pbv.module_import_groups + self.comprehension_to_fitem = pbv.comprehension_to_fitem self.singledispatch_impls = singledispatch_impls @@ -1263,6 +1264,37 @@ def leave(self) -> tuple[list[Register], list[RuntimeArg], list[BasicBlock], RTy self.fn_info = self.fn_infos[-1] return builder.args, runtime_args, builder.blocks, ret_type, fn_info + @contextmanager + def enter_scope(self, fn_info: FuncInfo) -> Iterator[None]: + """Push a lightweight scope for comprehensions. + + Unlike enter(), this reuses the same LowLevelIRBuilder (same basic + blocks and registers) but pushes new symtable and fn_info entries + so that the closure machinery sees a scope boundary. + """ + self.builders.append(self.builder) + # Copy the parent symtable so variables from the enclosing scope + # (e.g. function parameters used as the comprehension iterable) + # remain accessible. The comprehension is inlined (same basic blocks + # and registers), so the parent's register references are still valid. + self.symtables.append(dict(self.symtables[-1])) + self.runtime_args.append([]) + self.fn_info = fn_info + self.fn_infos.append(self.fn_info) + self.ret_types.append(none_rprimitive) + self.nonlocal_control.append(BaseNonlocalControl()) + try: + yield + finally: + self.builders.pop() + self.symtables.pop() + self.runtime_args.pop() + self.ret_types.pop() + self.fn_infos.pop() + self.nonlocal_control.pop() + self.builder = self.builders[-1] + self.fn_info = self.fn_infos[-1] + @contextmanager def enter_method( self, diff --git a/mypyc/irbuild/callable_class.py b/mypyc/irbuild/callable_class.py index f1dc32c6b5c8..9ca1318a843f 100644 --- a/mypyc/irbuild/callable_class.py +++ b/mypyc/irbuild/callable_class.py @@ -224,11 +224,13 @@ def instantiate_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> Value: # - A generator function: the callable class is instantiated # from the '__next__' method of the generator class, and hence the # environment of the generator class is used. - # - Regular function: we use the environment of the original function. + # - Regular function or comprehension scope: we use the environment + # of the original function. Comprehension scopes are inlined (no + # callable class), so they fall into this case despite is_nested. curr_env_reg = None if builder.fn_info.is_generator: curr_env_reg = builder.fn_info.generator_class.curr_env_reg - elif builder.fn_info.is_nested: + elif builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope: curr_env_reg = builder.fn_info.callable_class.curr_env_reg elif builder.fn_info.contains_nested: curr_env_reg = builder.fn_info.curr_env_reg diff --git a/mypyc/irbuild/context.py b/mypyc/irbuild/context.py index d5a48bf838c8..c1997a921333 100644 --- a/mypyc/irbuild/context.py +++ b/mypyc/irbuild/context.py @@ -23,6 +23,7 @@ def __init__( is_decorated: bool = False, in_non_ext: bool = False, add_nested_funcs_to_env: bool = False, + is_comprehension_scope: bool = False, ) -> None: self.fitem = fitem self.name = name @@ -49,6 +50,11 @@ def __init__( self.is_decorated = is_decorated self.in_non_ext = in_non_ext self.add_nested_funcs_to_env = add_nested_funcs_to_env + # Comprehension scopes are lightweight scope boundaries created when + # a comprehension body contains a lambda. The comprehension is still + # inlined (same basic blocks), but we push a new FuncInfo so the + # closure machinery can capture loop variables through env classes. + self.is_comprehension_scope = is_comprehension_scope # TODO: add field for ret_type: RType = none_rprimitive diff --git a/mypyc/irbuild/env_class.py b/mypyc/irbuild/env_class.py index a69340517863..3128543d4cd2 100644 --- a/mypyc/irbuild/env_class.py +++ b/mypyc/irbuild/env_class.py @@ -56,7 +56,7 @@ class is generated, the function environment has not yet been ) env_class.reuse_freed_instance = True env_class.attributes[SELF_NAME] = RInstance(env_class) - if builder.fn_info.is_nested: + if builder.fn_info.is_nested and builder.fn_infos[-2]._env_class is not None: # If the function is nested, its environment class must contain an environment # attribute pointing to its encapsulating functions' environment class. env_class.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_infos[-2].env_class) @@ -73,11 +73,14 @@ def finalize_env_class(builder: IRBuilder, prefix: str = "") -> None: # Iterate through the function arguments and replace local definitions (using registers) # that were previously added to the environment with references to the function's - # environment class. - if builder.fn_info.is_nested: - add_args_to_env(builder, local=False, base=builder.fn_info.callable_class, prefix=prefix) - else: - add_args_to_env(builder, local=False, base=builder.fn_info, prefix=prefix) + # environment class. Comprehension scopes have no arguments to add. + if not builder.fn_info.is_comprehension_scope: + if builder.fn_info.is_nested: + add_args_to_env( + builder, local=False, base=builder.fn_info.callable_class, prefix=prefix + ) + else: + add_args_to_env(builder, local=False, base=builder.fn_info, prefix=prefix) def instantiate_env_class(builder: IRBuilder) -> Value: @@ -86,7 +89,7 @@ def instantiate_env_class(builder: IRBuilder) -> Value: Call(builder.fn_info.env_class.ctor, [], builder.fn_info.fitem.line) ) - if builder.fn_info.is_nested: + if builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope: builder.fn_info.callable_class._curr_env_reg = curr_env_reg builder.add( SetAttr( @@ -97,7 +100,22 @@ def instantiate_env_class(builder: IRBuilder) -> Value: ) ) else: + # Top-level functions and comprehension scopes store env reg directly. builder.fn_info._curr_env_reg = curr_env_reg + # Comprehension scopes link to parent env if it exists. + if ( + builder.fn_info.is_nested + and builder.fn_infos[-2]._env_class is not None + and builder.fn_infos[-2]._curr_env_reg is not None + ): + builder.add( + SetAttr( + curr_env_reg, + ENV_ATTR_NAME, + builder.fn_infos[-2].curr_env_reg, + builder.fn_info.fitem.line, + ) + ) return curr_env_reg @@ -114,7 +132,7 @@ def load_env_registers(builder: IRBuilder, prefix: str = "") -> None: fn_info = builder.fn_info fitem = fn_info.fitem - if fn_info.is_nested: + if fn_info.is_nested and builder.fn_infos[-2]._env_class is not None: load_outer_envs(builder, fn_info.callable_class) # If this is a FuncDef, then make sure to load the FuncDef into its own environment # class so that the function can be called recursively. @@ -155,7 +173,8 @@ def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None: # Load the first outer environment. This one is special because it gets saved in the # FuncInfo instance's prev_env_reg field. - if index > 1: + has_outer = index > 1 or (index == 1 and builder.fn_infos[1].contains_nested) + if has_outer and builder.fn_infos[index]._env_class is not None: # outer_env = builder.fn_infos[index].environment outer_env = builder.symtables[index] if isinstance(base, GeneratorClass): @@ -167,6 +186,8 @@ def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None: # Load the remaining outer environments into registers. while index > 1: + if builder.fn_infos[index]._env_class is None: + break # outer_env = builder.fn_infos[index].environment outer_env = builder.symtables[index] env_reg = load_outer_env(builder, env_reg, outer_env) @@ -224,7 +245,9 @@ def add_vars_to_env(builder: IRBuilder, prefix: str = "") -> None: env_for_func: FuncInfo | ImplicitClass = builder.fn_info if builder.fn_info.is_generator: env_for_func = builder.fn_info.generator_class - elif builder.fn_info.is_nested or builder.fn_info.in_non_ext: + elif ( + builder.fn_info.is_nested or builder.fn_info.in_non_ext + ) and not builder.fn_info.is_comprehension_scope: env_for_func = builder.fn_info.callable_class if builder.fn_info.fitem in builder.free_variables: diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 1e7ece6eeacf..f45319b7ef2b 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -1164,20 +1164,45 @@ def _visit_display( # Comprehensions +# +# mypyc always inlines comprehensions (the loop body is emitted directly into +# the enclosing function's IR, no implicit function call like CPython). +# +# However, when a comprehension body contains a lambda, we need a lightweight +# scope boundary so the closure/env-class machinery can see the comprehension +# as a separate scope. The comprehension is still inlined (same basic blocks +# and registers), but we push a new FuncInfo and set up an env class so the +# lambda can capture loop variables through the standard env-class chain. def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value: - return translate_list_comprehension(builder, o.generator) + gen = o.generator + if gen in builder.comprehension_to_fitem: + return _translate_comprehension_with_scope( + builder, gen, lambda: translate_list_comprehension(builder, gen) + ) + return translate_list_comprehension(builder, gen) def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value: - return translate_set_comprehension(builder, o.generator) + gen = o.generator + if gen in builder.comprehension_to_fitem: + return _translate_comprehension_with_scope( + builder, gen, lambda: translate_set_comprehension(builder, gen) + ) + return translate_set_comprehension(builder, gen) def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value: if raise_error_if_contains_unreachable_names(builder, o): return builder.none() + if o in builder.comprehension_to_fitem: + return _translate_comprehension_with_scope(builder, o, lambda: _dict_comp_body(builder, o)) + return _dict_comp_body(builder, o) + + +def _dict_comp_body(builder: IRBuilder, o: DictionaryComprehension) -> Value: d = builder.maybe_spill(builder.call_c(dict_new_op, [], o.line)) loop_params = list(zip(o.indices, o.sequences, o.condlists, o.is_async)) @@ -1190,6 +1215,31 @@ def gen_inner_stmts() -> None: return builder.read(d, o.line) +def _translate_comprehension_with_scope( + builder: IRBuilder, + node: GeneratorExpr | DictionaryComprehension, + gen_body: Callable[[], Value], +) -> Value: + """Wrap a comprehension body with a lightweight scope for closure capture.""" + from mypyc.irbuild.context import FuncInfo + from mypyc.irbuild.env_class import add_vars_to_env, finalize_env_class, setup_env_class + + comprehension_fdef = builder.comprehension_to_fitem[node] + fn_info = FuncInfo( + fitem=comprehension_fdef, + name=comprehension_fdef.name, + is_nested=True, + contains_nested=True, + is_comprehension_scope=True, + ) + + with builder.enter_scope(fn_info): + setup_env_class(builder) + finalize_env_class(builder) + add_vars_to_env(builder) + return gen_body() + + # Misc @@ -1206,6 +1256,16 @@ def get_arg(arg: Expression | None) -> Value: def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value: builder.warning("Treating generator comprehension as list", o.line) + if o in builder.comprehension_to_fitem: + return builder.primitive_op( + iter_op, + [ + _translate_comprehension_with_scope( + builder, o, lambda: translate_list_comprehension(builder, o) + ) + ], + o.line, + ) return builder.primitive_op(iter_op, [translate_list_comprehension(builder, o)], o.line) diff --git a/mypyc/irbuild/prebuildvisitor.py b/mypyc/irbuild/prebuildvisitor.py index e630fed0d85a..b99a588c34d2 100644 --- a/mypyc/irbuild/prebuildvisitor.py +++ b/mypyc/irbuild/prebuildvisitor.py @@ -4,9 +4,11 @@ AssignmentStmt, Block, Decorator, + DictionaryComprehension, Expression, FuncDef, FuncItem, + GeneratorExpr, Import, LambdaExpr, MemberExpr, @@ -16,12 +18,39 @@ SymbolNode, Var, ) -from mypy.traverser import ExtendedTraverserVisitor +from mypy.traverser import ExtendedTraverserVisitor, TraverserVisitor from mypy.types import Type from mypyc.errors import Errors from mypyc.irbuild.missingtypevisitor import MissingTypesVisitor +class _LambdaChecker(TraverserVisitor): + """Check whether an AST subtree contains a lambda expression.""" + + found = False + + def visit_lambda_expr(self, _o: LambdaExpr) -> None: + self.found = True + + +def _comprehension_has_lambda(node: GeneratorExpr | DictionaryComprehension) -> bool: + """Return True if a comprehension body contains a lambda. + + Only checks body expressions (left_expr/key/value and conditions), + not the sequences, since sequences are evaluated in the enclosing scope. + """ + checker = _LambdaChecker() + if isinstance(node, GeneratorExpr): + node.left_expr.accept(checker) + else: + node.key.accept(checker) + node.value.accept(checker) + for conds in node.condlists: + for cond in conds: + cond.accept(checker) + return checker.found + + class PreBuildVisitor(ExtendedTraverserVisitor): """Mypy file AST visitor run before building the IR. @@ -88,6 +117,17 @@ def __init__( self.missing_types_visitor = MissingTypesVisitor(types) + # Synthetic FuncDef representing the module scope, created on demand + # when a comprehension at module/class level contains a lambda. + self._module_fitem: FuncDef | None = None + + # Counter for generating unique synthetic comprehension scope names. + self._comprehension_counter = 0 + + # Map comprehension AST nodes to synthetic FuncDefs representing + # their scope (only for comprehensions that contain lambdas). + self.comprehension_to_fitem: dict[GeneratorExpr | DictionaryComprehension, FuncDef] = {} + def visit(self, o: Node) -> bool: if not isinstance(o, Import): self._current_import_group = None @@ -157,6 +197,55 @@ def visit_func(self, func: FuncItem) -> None: super().visit_func(func) self.funcs.pop() + def _visit_comprehension_with_scope(self, o: GeneratorExpr | DictionaryComprehension) -> None: + """Visit a comprehension that contains lambdas. + + Creates a synthetic FuncDef to represent the comprehension's scope, + registers it in the function nesting hierarchy, and traverses the + comprehension body with it on the stack. + """ + pushed_module = False + if not self.funcs: + # At module level: push synthetic module FuncDef. + if self._module_fitem is None: + self._module_fitem = FuncDef("__mypyc_module__") + self._module_fitem.line = 1 + self.funcs.append(self._module_fitem) + pushed_module = True + + # Create synthetic FuncDef for the comprehension scope. + comprehension_fdef = FuncDef(f"__comprehension_{self._comprehension_counter}__") + self._comprehension_counter += 1 + comprehension_fdef.line = o.line + self.comprehension_to_fitem[o] = comprehension_fdef + + # Register as nested within enclosing function. + self.encapsulating_funcs.setdefault(self.funcs[-1], []).append(comprehension_fdef) + self.nested_funcs[comprehension_fdef] = self.funcs[-1] + + # Push and traverse. + self.funcs.append(comprehension_fdef) + if isinstance(o, GeneratorExpr): + super().visit_generator_expr(o) + else: + super().visit_dictionary_comprehension(o) + self.funcs.pop() + + if pushed_module: + self.funcs.pop() + + def visit_generator_expr(self, o: GeneratorExpr) -> None: + if _comprehension_has_lambda(o): + self._visit_comprehension_with_scope(o) + else: + super().visit_generator_expr(o) + + def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: + if _comprehension_has_lambda(o): + self._visit_comprehension_with_scope(o) + else: + super().visit_dictionary_comprehension(o) + def visit_import(self, imp: Import) -> None: if self._current_import_group is not None: self.module_import_groups[self._current_import_group].append(imp) diff --git a/mypyc/test-data/run-functions.test b/mypyc/test-data/run-functions.test index 6ceaeb594362..a59d7f729e9f 100644 --- a/mypyc/test-data/run-functions.test +++ b/mypyc/test-data/run-functions.test @@ -1407,3 +1407,120 @@ def test_star2_fastpath() -> None: assert star2(x="a", y=3) == "aaa" def test_star2_fastpath_generic() -> None: assert star2_generic({"x": "a", "y": 3}) == "aaa" + +[case testLambdaInListComprehension] +# Lambda inside list comprehension at module level +funcs_list = [lambda x, i=i: x + i for i in range(3)] + +def test_module_level() -> None: + assert [f(10) for f in funcs_list] == [10, 11, 12] + +def test_in_function() -> None: + funcs = [lambda x, i=i: x + i for i in range(3)] + assert [f(10) for f in funcs] == [10, 11, 12] + # Different parameter name than loop variable + funcs2 = [lambda x, n=i: x + n for i in range(3)] + assert [f(10) for f in funcs2] == [10, 11, 12] + +[case testLambdaInDictComprehension] +# Lambda inside dict comprehension at module level +funcs_dict = {k: lambda x, k=k: x + k for k in range(3)} + +def test_module_level() -> None: + assert funcs_dict[0](10) == 10 + assert funcs_dict[1](10) == 11 + assert funcs_dict[2](10) == 12 + +def test_in_function() -> None: + d = {k: (lambda k=k: k * 2) for k in range(3)} + assert d[0]() == 0 + assert d[1]() == 2 + assert d[2]() == 4 + # Different parameter name than loop variable + d2 = {k: (lambda n=k: n * 2) for k in range(3)} + assert d2[0]() == 0 + assert d2[1]() == 2 + assert d2[2]() == 4 + +[case testLambdaInSetComprehension] +def test_set_comp() -> None: + funcs = {(lambda i=i: i) for i in range(3)} + results = {f() for f in funcs} + assert results == {0, 1, 2} + # Different parameter name than loop variable + funcs2 = {(lambda n=i: n) for i in range(3)} + results2 = {f() for f in funcs2} + assert results2 == {0, 1, 2} + +[case testLambdaInComprehensionCaptureOuter] +# Lambda capturing both loop var and outer function var +def test_capture_outer_and_loop() -> None: + base = 100 + funcs = [lambda i=i: i + base for i in range(3)] + assert funcs[0]() == 100 + assert funcs[1]() == 101 + assert funcs[2]() == 102 + # Different parameter name than loop variable + funcs2 = [lambda n=i: n + base for i in range(3)] + assert funcs2[0]() == 100 + assert funcs2[1]() == 101 + assert funcs2[2]() == 102 + +[case testLambdaInDictComprehensionLateBind] +# Dict comprehension with late-binding lambda (no default arg) +def test_dict_comp_late_bind() -> None: + d = {name: (lambda: name) for name in ("a", "b")} + # Late binding: all lambdas see the final value of 'name' + assert d["a"]() == "b" + assert d["b"]() == "b" + +[case testLambdaInComprehensionClassLevel] +# Lambda inside comprehension at class level +from typing import ClassVar +class Foo: + A: ClassVar[object] = {name: (lambda: name) for name in ("a", "b")} + +[file driver.py] +from native import Foo +assert Foo.A["a"]() == "b" +assert Foo.A["b"]() == "b" + +[case testLambdaInGeneratorExpression] +def test_generator_with_lambda_default() -> None: + result = list((lambda i=i: i * 2) for i in range(4)) + assert [f() for f in result] == [0, 2, 4, 6] + # Different parameter name than loop variable + result2 = list((lambda n=i: n * 2) for i in range(4)) + assert [f() for f in result2] == [0, 2, 4, 6] + +[case testLambdaInComprehensionWithParamIterable] +# Lambda inside comprehension where the iterable is a function parameter. +# The comprehension scope must be able to read variables from the enclosing +# function scope (not just the comprehension's own env class). +from typing import List, Callable + +def transform(items: List[int]) -> List[Callable[[], int]]: + return [(lambda i=i: i * 2) for i in items] # type: ignore[misc] + +def transform_different_name(items: List[int]) -> List[Callable[[], int]]: + return [(lambda n=i: n * 2) for i in items] # type: ignore[misc] + +def uses_multiple_params(items: List[str], sep: str) -> List[Callable[[], str]]: + return [(lambda s=s: s + sep) for s in items] # type: ignore[misc] + +def uses_multiple_params_different_name(items: List[str], sep: str) -> List[Callable[[], str]]: + return [(lambda t=s: t + sep) for s in items] # type: ignore[misc] + +def test_param_iterable() -> None: + funcs = transform([1, 2, 3]) + assert [f() for f in funcs] == [2, 4, 6] + # Different parameter name than loop variable + funcs2 = transform_different_name([1, 2, 3]) + assert [f() for f in funcs2] == [2, 4, 6] + +def test_multiple_params() -> None: + funcs = uses_multiple_params(["a", "b"], "!") + assert [f() for f in funcs] == ["a!", "b!"] + # Different parameter name than loop variable + funcs2 = uses_multiple_params_different_name(["a", "b"], "!") + assert [f() for f in funcs2] == ["a!", "b!"]