Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(
self.nested_fitems = pbv.nested_funcs.keys()
self.fdefs_to_decorators = pbv.funcs_to_decorators
self.module_import_groups = pbv.module_import_groups
self.comprehension_to_fitem = pbv.comprehension_to_fitem

self.singledispatch_impls = singledispatch_impls

Expand Down Expand Up @@ -1263,6 +1264,37 @@ def leave(self) -> tuple[list[Register], list[RuntimeArg], list[BasicBlock], RTy
self.fn_info = self.fn_infos[-1]
return builder.args, runtime_args, builder.blocks, ret_type, fn_info

@contextmanager
def enter_scope(self, fn_info: FuncInfo) -> Iterator[None]:
"""Push a lightweight scope for comprehensions.

Unlike enter(), this reuses the same LowLevelIRBuilder (same basic
blocks and registers) but pushes new symtable and fn_info entries
so that the closure machinery sees a scope boundary.
"""
self.builders.append(self.builder)
# Copy the parent symtable so variables from the enclosing scope
# (e.g. function parameters used as the comprehension iterable)
# remain accessible. The comprehension is inlined (same basic blocks
# and registers), so the parent's register references are still valid.
self.symtables.append(dict(self.symtables[-1]))
self.runtime_args.append([])
self.fn_info = fn_info
self.fn_infos.append(self.fn_info)
self.ret_types.append(none_rprimitive)
self.nonlocal_control.append(BaseNonlocalControl())
try:
yield
finally:
self.builders.pop()
self.symtables.pop()
self.runtime_args.pop()
self.ret_types.pop()
self.fn_infos.pop()
self.nonlocal_control.pop()
self.builder = self.builders[-1]
self.fn_info = self.fn_infos[-1]

@contextmanager
def enter_method(
self,
Expand Down
6 changes: 4 additions & 2 deletions mypyc/irbuild/callable_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,13 @@ def instantiate_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> Value:
# - A generator function: the callable class is instantiated
# from the '__next__' method of the generator class, and hence the
# environment of the generator class is used.
# - Regular function: we use the environment of the original function.
# - Regular function or comprehension scope: we use the environment
# of the original function. Comprehension scopes are inlined (no
# callable class), so they fall into this case despite is_nested.
curr_env_reg = None
if builder.fn_info.is_generator:
curr_env_reg = builder.fn_info.generator_class.curr_env_reg
elif builder.fn_info.is_nested:
elif builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope:
curr_env_reg = builder.fn_info.callable_class.curr_env_reg
elif builder.fn_info.contains_nested:
curr_env_reg = builder.fn_info.curr_env_reg
Expand Down
6 changes: 6 additions & 0 deletions mypyc/irbuild/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
is_decorated: bool = False,
in_non_ext: bool = False,
add_nested_funcs_to_env: bool = False,
is_comprehension_scope: bool = False,
) -> None:
self.fitem = fitem
self.name = name
Expand All @@ -49,6 +50,11 @@ def __init__(
self.is_decorated = is_decorated
self.in_non_ext = in_non_ext
self.add_nested_funcs_to_env = add_nested_funcs_to_env
# Comprehension scopes are lightweight scope boundaries created when
# a comprehension body contains a lambda. The comprehension is still
# inlined (same basic blocks), but we push a new FuncInfo so the
# closure machinery can capture loop variables through env classes.
self.is_comprehension_scope = is_comprehension_scope

# TODO: add field for ret_type: RType = none_rprimitive

Expand Down
43 changes: 33 additions & 10 deletions mypyc/irbuild/env_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class is generated, the function environment has not yet been
)
env_class.reuse_freed_instance = True
env_class.attributes[SELF_NAME] = RInstance(env_class)
if builder.fn_info.is_nested:
if builder.fn_info.is_nested and builder.fn_infos[-2]._env_class is not None:
# If the function is nested, its environment class must contain an environment
# attribute pointing to its encapsulating functions' environment class.
env_class.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_infos[-2].env_class)
Expand All @@ -73,11 +73,14 @@ def finalize_env_class(builder: IRBuilder, prefix: str = "") -> None:

# Iterate through the function arguments and replace local definitions (using registers)
# that were previously added to the environment with references to the function's
# environment class.
if builder.fn_info.is_nested:
add_args_to_env(builder, local=False, base=builder.fn_info.callable_class, prefix=prefix)
else:
add_args_to_env(builder, local=False, base=builder.fn_info, prefix=prefix)
# environment class. Comprehension scopes have no arguments to add.
if not builder.fn_info.is_comprehension_scope:
if builder.fn_info.is_nested:
add_args_to_env(
builder, local=False, base=builder.fn_info.callable_class, prefix=prefix
)
else:
add_args_to_env(builder, local=False, base=builder.fn_info, prefix=prefix)


def instantiate_env_class(builder: IRBuilder) -> Value:
Expand All @@ -86,7 +89,7 @@ def instantiate_env_class(builder: IRBuilder) -> Value:
Call(builder.fn_info.env_class.ctor, [], builder.fn_info.fitem.line)
)

if builder.fn_info.is_nested:
if builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope:
builder.fn_info.callable_class._curr_env_reg = curr_env_reg
builder.add(
SetAttr(
Expand All @@ -97,7 +100,22 @@ def instantiate_env_class(builder: IRBuilder) -> Value:
)
)
else:
# Top-level functions and comprehension scopes store env reg directly.
builder.fn_info._curr_env_reg = curr_env_reg
# Comprehension scopes link to parent env if it exists.
if (
builder.fn_info.is_nested
and builder.fn_infos[-2]._env_class is not None
and builder.fn_infos[-2]._curr_env_reg is not None
):
builder.add(
SetAttr(
curr_env_reg,
ENV_ATTR_NAME,
builder.fn_infos[-2].curr_env_reg,
builder.fn_info.fitem.line,
)
)

return curr_env_reg

Expand All @@ -114,7 +132,7 @@ def load_env_registers(builder: IRBuilder, prefix: str = "") -> None:

fn_info = builder.fn_info
fitem = fn_info.fitem
if fn_info.is_nested:
if fn_info.is_nested and builder.fn_infos[-2]._env_class is not None:
load_outer_envs(builder, fn_info.callable_class)
# If this is a FuncDef, then make sure to load the FuncDef into its own environment
# class so that the function can be called recursively.
Expand Down Expand Up @@ -155,7 +173,8 @@ def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None:

# Load the first outer environment. This one is special because it gets saved in the
# FuncInfo instance's prev_env_reg field.
if index > 1:
has_outer = index > 1 or (index == 1 and builder.fn_infos[1].contains_nested)
if has_outer and builder.fn_infos[index]._env_class is not None:
# outer_env = builder.fn_infos[index].environment
outer_env = builder.symtables[index]
if isinstance(base, GeneratorClass):
Expand All @@ -167,6 +186,8 @@ def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None:

# Load the remaining outer environments into registers.
while index > 1:
if builder.fn_infos[index]._env_class is None:
break
# outer_env = builder.fn_infos[index].environment
outer_env = builder.symtables[index]
env_reg = load_outer_env(builder, env_reg, outer_env)
Expand Down Expand Up @@ -224,7 +245,9 @@ def add_vars_to_env(builder: IRBuilder, prefix: str = "") -> None:
env_for_func: FuncInfo | ImplicitClass = builder.fn_info
if builder.fn_info.is_generator:
env_for_func = builder.fn_info.generator_class
elif builder.fn_info.is_nested or builder.fn_info.in_non_ext:
elif (
builder.fn_info.is_nested or builder.fn_info.in_non_ext
) and not builder.fn_info.is_comprehension_scope:
env_for_func = builder.fn_info.callable_class

if builder.fn_info.fitem in builder.free_variables:
Expand Down
64 changes: 62 additions & 2 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,20 +1164,45 @@ def _visit_display(


# Comprehensions
#
# mypyc always inlines comprehensions (the loop body is emitted directly into
# the enclosing function's IR, no implicit function call like CPython).
#
# However, when a comprehension body contains a lambda, we need a lightweight
# scope boundary so the closure/env-class machinery can see the comprehension
# as a separate scope. The comprehension is still inlined (same basic blocks
# and registers), but we push a new FuncInfo and set up an env class so the
# lambda can capture loop variables through the standard env-class chain.


def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value:
return translate_list_comprehension(builder, o.generator)
gen = o.generator
if gen in builder.comprehension_to_fitem:
return _translate_comprehension_with_scope(
builder, gen, lambda: translate_list_comprehension(builder, gen)
)
return translate_list_comprehension(builder, gen)


def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value:
return translate_set_comprehension(builder, o.generator)
gen = o.generator
if gen in builder.comprehension_to_fitem:
return _translate_comprehension_with_scope(
builder, gen, lambda: translate_set_comprehension(builder, gen)
)
return translate_set_comprehension(builder, gen)


def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value:
if raise_error_if_contains_unreachable_names(builder, o):
return builder.none()

if o in builder.comprehension_to_fitem:
return _translate_comprehension_with_scope(builder, o, lambda: _dict_comp_body(builder, o))
return _dict_comp_body(builder, o)


def _dict_comp_body(builder: IRBuilder, o: DictionaryComprehension) -> Value:
d = builder.maybe_spill(builder.call_c(dict_new_op, [], o.line))
loop_params = list(zip(o.indices, o.sequences, o.condlists, o.is_async))

Expand All @@ -1190,6 +1215,31 @@ def gen_inner_stmts() -> None:
return builder.read(d, o.line)


def _translate_comprehension_with_scope(
builder: IRBuilder,
node: GeneratorExpr | DictionaryComprehension,
gen_body: Callable[[], Value],
) -> Value:
"""Wrap a comprehension body with a lightweight scope for closure capture."""
from mypyc.irbuild.context import FuncInfo
from mypyc.irbuild.env_class import add_vars_to_env, finalize_env_class, setup_env_class

comprehension_fdef = builder.comprehension_to_fitem[node]
fn_info = FuncInfo(
fitem=comprehension_fdef,
name=comprehension_fdef.name,
is_nested=True,
contains_nested=True,
is_comprehension_scope=True,
)

with builder.enter_scope(fn_info):
setup_env_class(builder)
finalize_env_class(builder)
add_vars_to_env(builder)
return gen_body()


# Misc


Expand All @@ -1206,6 +1256,16 @@ def get_arg(arg: Expression | None) -> Value:

def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value:
builder.warning("Treating generator comprehension as list", o.line)
if o in builder.comprehension_to_fitem:
return builder.primitive_op(
iter_op,
[
_translate_comprehension_with_scope(
builder, o, lambda: translate_list_comprehension(builder, o)
)
],
o.line,
)
return builder.primitive_op(iter_op, [translate_list_comprehension(builder, o)], o.line)


Expand Down
Loading
Loading