Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ jobs:
# To speed-up process until ast_serialize is on PyPI.
- name: Install pinned ast-serialize
if: ${{ matrix.dev_ast_serialize }}
run: pip install ast-serialize@git+https://github.com/mypyc/ast_serialize.git@d277690a078c7784667a640ed1045e725bc42c00
run: pip install ast-serialize@git+https://github.com/mypyc/ast_serialize.git@052c5bfa3b2a5bf07c0b163ccbe2c5ccbfae9ac5

- name: Setup tox environment
run: |
Expand Down
141 changes: 104 additions & 37 deletions mypy/build.py

Large diffs are not rendered by default.

97 changes: 53 additions & 44 deletions mypy/build_worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,24 @@

from mypy import util
from mypy.build import (
SCC,
AckMessage,
BuildManager,
Graph,
GraphMessage,
SccRequestMessage,
SccResponseMessage,
SccsDataMessage,
SourcesDataMessage,
load_graph,
load_plugins,
process_stale_scc,
)
from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT
from mypy.errors import CompileError, Errors, report_internal_error
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
from mypy.fscache import FileSystemCache
from mypy.ipc import IPCException, IPCServer, receive, send
from mypy.modulefinder import BuildSource, BuildSourceSet, compute_search_paths
from mypy.nodes import FileRawData
from mypy.options import Options
from mypy.util import read_py_file
from mypy.version import __version__
Expand Down Expand Up @@ -123,42 +125,24 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
if manager is None:
return

# Mirror the GC freeze hack in the coordinator.
if platform.python_implementation() == "CPython":
gc.disable()
try:
graph = load_graph(sources, manager)
except CompileError:
# CompileError during loading will be reported by the coordinator.
return
if platform.python_implementation() == "CPython":
gc.freeze()
gc.unfreeze()
gc.enable()
for id in graph:
manager.import_map[id] = graph[id].dependencies_set
# Ignore errors during local graph loading to check that receiving
# early errors from coordinator works correctly.
manager.errors.reset()

# Notify worker we are done loading graph.
# Notify coordinator we are done with setup.
send(server, AckMessage())

# Compare worker graph and coordinator, with parallel parser we will only use the latter.
graph_data = GraphMessage.read(receive(server), manager)
assert set(manager.missing_modules) == graph_data.missing_modules
coordinator_graph = graph_data.graph
assert coordinator_graph.keys() == graph.keys()
# Update some manager data in-place as it has been passed to semantic analyzer.
manager.missing_modules |= graph_data.missing_modules
graph = graph_data.graph
for id in graph:
assert graph[id].dependencies_set == coordinator_graph[id].dependencies_set
assert graph[id].suppressed_set == coordinator_graph[id].suppressed_set
send(server, AckMessage())
manager.import_map[id] = graph[id].dependencies_set
# Link modules dicts, so that plugins will get access to ASTs as we parse them.
manager.plugin.set_modules(manager.modules)

# Notify coordinator we are ready to receive computed graph SCC structure.
send(server, AckMessage())
sccs = SccsDataMessage.read(receive(server)).sccs
manager.scc_by_id = {scc.id: scc for scc in sccs}
manager.top_order = [scc.id for scc in sccs]

# Notify coordinator we are ready to process SCCs.
# Notify coordinator we are ready to start processing SCCs.
send(server, AckMessage())
while True:
scc_message = SccRequestMessage.read(receive(server))
Expand All @@ -169,20 +153,17 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
scc = manager.scc_by_id[scc_id]
t0 = time.time()
try:
for id in scc.mod_ids:
state = graph[id]
# Extra if below is needed only because we are using local graph.
# TODO: clone options when switching to coordinator graph.
if state.tree is None:
# Parse early to get errors related data, such as ignored
# and skipped lines before replaying the errors.
state.parse_file()
else:
state.setup_errors()
if id in scc_message.import_errors:
manager.errors.set_file(state.xpath, id, state.options)
for err_info in scc_message.import_errors[id]:
manager.errors.add_error_info(err_info)
if platform.python_implementation() == "CPython":
# Since we are splitting the GC freeze hack into multiple smaller freezes,
# we should collect young generations to not accumulate accidental garbage.
gc.collect(generation=1)
gc.collect(generation=0)
gc.disable()
load_states(scc, graph, manager, scc_message.import_errors, scc_message.mod_data)
if platform.python_implementation() == "CPython":
gc.freeze()
gc.unfreeze()
gc.enable()
result = process_stale_scc(graph, scc, manager, from_cache=graph_data.from_cache)
# We must commit after each SCC, otherwise we break --sqlite-cache.
manager.metastore.commit()
Expand All @@ -193,6 +174,34 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1)


def load_states(
    scc: SCC,
    graph: Graph,
    manager: BuildManager,
    import_errors: dict[str, list[ErrorInfo]],
    mod_data: dict[str, tuple[bytes, FileRawData | None]],
) -> None:
    """Rebuild the state of each module in an SCC to mirror the coordinator.

    For every module id in scc.mod_ids this re-clones the module options,
    parses the file from the raw data found in mod_data, fills in the
    attributes needed for writing cache meta, and finally replays any
    errors listed for that module in import_errors.
    """
    for mod_id in scc.mod_ids:
        state = graph[mod_id]
        # Options are not sent to the worker; cloning them again is usually
        # faster than deserializing them.
        state.options = state.options.clone_for_module(state.id)
        suppressed_deps_opts, raw_data = mod_data[mod_id]
        state.parse_file(raw_data=raw_data)
        # The attributes below are needed when cache meta is written.
        state.known_suppressed_deps_opts = suppressed_deps_opts
        tree = state.tree
        assert tree is not None
        import_lines = {imp.line for imp in tree.imports}
        # Keep only the ignored lines that coincide with an import.
        ignored_on_imports = {}
        for line, codes in tree.ignored_lines.items():
            if line in import_lines:
                ignored_on_imports[line] = codes
        state.imports_ignored = ignored_on_imports
        if mod_id not in import_errors:
            continue
        # Replay the original errors found by the coordinator while loading the graph.
        manager.errors.set_file(state.xpath, mod_id, state.options)
        for err_info in import_errors[mod_id]:
            manager.errors.add_error_info(err_info)


def setup_worker_manager(sources: list[BuildSource], ctx: ServerContext) -> BuildManager | None:
data_dir = os.path.dirname(os.path.dirname(__file__))
# This is used for testing only now.
Expand Down
1 change: 1 addition & 0 deletions mypy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None:
LIST_BYTES: Final[Tag] = 23
TUPLE_GEN: Final[Tag] = 24
DICT_STR_GEN: Final[Tag] = 30
DICT_INT_GEN: Final[Tag] = 31

# Misc classes.
EXTRA_ATTRS: Final[Tag] = 150
Expand Down
4 changes: 4 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def main(
stdout, stderr, options.hide_error_codes, hide_success=bool(options.output)
)

if options.num_workers:
# Supporting both parsers would be really tricky, so just support the new one.
options.native_parser = True

if options.allow_redefinition_new and not options.local_partial_types:
fail(
"error: --local-partial-types must be enabled if using --allow-redefinition-new",
Expand Down
25 changes: 17 additions & 8 deletions mypy/nativeparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
EllipsisExpr,
Expression,
ExpressionStmt,
FileRawData,
FloatExpr,
ForStmt,
FuncDef,
Expand Down Expand Up @@ -169,7 +170,6 @@ def __init__(self, options: Options) -> None:
self.options = options
self.errors: list[dict[str, Any]] = []
self.num_funcs = 0
self.uses_template_strings = False

def add_error(
self,
Expand All @@ -195,7 +195,7 @@ def add_error(


def native_parse(
filename: str, options: Options, skip_function_bodies: bool = False
filename: str, options: Options, skip_function_bodies: bool = False, imports_only: bool = False
) -> tuple[MypyFile, list[dict[str, Any]], TypeIgnores]:
"""Parse a Python file using the native Rust-based parser.
Expand All @@ -208,6 +208,8 @@ def native_parse(
skip_function_bodies: If True, many function and method bodies are omitted from
the AST, useful for parsing stubs or extracting signatures without full
implementation details
imports_only: If True, create an empty MypyFile with the actual serialized defs
stored in raw_data.
Returns:
A tuple containing:
Expand All @@ -222,20 +224,27 @@ def native_parse(
node.path = filename
return node, [], []

b, errors, ignores, import_bytes, is_partial_package = parse_to_binary_ast(
filename, options, skip_function_bodies
b, errors, ignores, import_bytes, is_partial_package, uses_template_strings = (
parse_to_binary_ast(filename, options, skip_function_bodies)
)
data = ReadBuffer(b)
n = read_int(data)
state = State(options)
defs = read_statements(state, data, n)
if imports_only:
defs = []
else:
defs = read_statements(state, data, n)

imports = deserialize_imports(import_bytes)

node = MypyFile(defs, imports)
node.path = filename
node.is_partial_stub_package = is_partial_package
node.uses_template_strings = state.uses_template_strings
if imports_only:
node.raw_data = FileRawData(
b, import_bytes, errors, dict(ignores), is_partial_package, uses_template_strings
)
node.uses_template_strings = uses_template_strings
# Merge deserialization errors with parsing errors
all_errors = errors + state.errors
return node, all_errors, ignores
Expand Down Expand Up @@ -263,7 +272,7 @@ def read_statements(state: State, data: ReadBuffer, n: int) -> list[Statement]:

def parse_to_binary_ast(
filename: str, options: Options, skip_function_bodies: bool = False
) -> tuple[bytes, list[dict[str, Any]], TypeIgnores, bytes, bool]:
) -> tuple[bytes, list[dict[str, Any]], TypeIgnores, bytes, bool, bool]:
ast_bytes, errors, ignores, import_bytes, ast_data = ast_serialize.parse(
filename,
skip_function_bodies=skip_function_bodies,
Expand All @@ -278,6 +287,7 @@ def parse_to_binary_ast(
ignores,
import_bytes,
ast_data["is_partial_package"],
ast_data["uses_template_strings"],
)


Expand Down Expand Up @@ -1528,7 +1538,6 @@ def read_expression(state: State, data: ReadBuffer) -> Expression:
expect_end_tag(data)
return expr
elif tag == nodes.TSTRING_EXPR:
state.uses_template_strings = True
nparts = read_int(data)
titems: list[Expression | tuple[Expression, str, str | None, Expression | None]] = []
for _ in range(nparts):
Expand Down
70 changes: 70 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import mypy.strconv
from mypy.cache import (
DICT_INT_GEN,
DICT_STR_GEN,
DT_SPEC,
END_TAG,
Expand All @@ -41,6 +42,7 @@
Tag,
WriteBuffer,
read_bool,
read_bytes,
read_int,
read_int_list,
read_int_opt,
Expand All @@ -52,6 +54,7 @@
read_str_opt_list,
read_tag,
write_bool,
write_bytes,
write_int,
write_int_list,
write_int_opt,
Expand Down Expand Up @@ -307,6 +310,69 @@ def read(cls, data: ReadBuffer) -> SymbolNode:
Definition: _TypeAlias = tuple[str, "SymbolTableNode", Optional["TypeInfo"]]


class FileRawData:
    """Raw (binary) data representing a parsed, but not yet deserialized, file.

    Instances bundle the serialized definitions and imports with the parse
    errors, ignored lines and per-file flags. The bundle can be stored in a
    buffer with write() and restored with read().
    """

    __slots__ = (
        "defs",
        "imports",
        "raw_errors",
        "ignored_lines",
        "is_partial_stub_package",
        "uses_template_strings",
    )

    # Serialized definitions of the file.
    defs: bytes
    # Serialized imports of the file.
    imports: bytes
    raw_errors: list[dict[str, Any]]  # TODO: switch to more precise type here.
    # Maps a line number to a list of strings (presumably the ignored error codes).
    ignored_lines: dict[int, list[str]]
    is_partial_stub_package: bool
    uses_template_strings: bool

    def __init__(
        self,
        defs: bytes,
        imports: bytes,
        raw_errors: list[dict[str, Any]],
        ignored_lines: dict[int, list[str]],
        is_partial_stub_package: bool,
        uses_template_strings: bool,
    ) -> None:
        """Store the given values on the instance without copying them."""
        self.defs = defs
        self.imports = imports
        self.raw_errors = raw_errors
        self.ignored_lines = ignored_lines
        self.is_partial_stub_package = is_partial_stub_package
        self.uses_template_strings = uses_template_strings

    def write(self, data: WriteBuffer) -> None:
        """Serialize this object into data.

        The fields are written in the same order in which read() expects them.
        """
        write_bytes(data, self.defs)
        write_bytes(data, self.imports)
        # Errors: a tag, the number of entries, then one JSON value per error.
        write_tag(data, LIST_GEN)
        write_int_bare(data, len(self.raw_errors))
        for err in self.raw_errors:
            write_json(data, err)
        # Ignored lines: a tag, the number of entries, then (line, codes) pairs.
        write_tag(data, DICT_INT_GEN)
        write_int_bare(data, len(self.ignored_lines))
        for line, codes in self.ignored_lines.items():
            write_int(data, line)
            write_str_list(data, codes)
        write_bool(data, self.is_partial_stub_package)
        write_bool(data, self.uses_template_strings)

    @classmethod
    def read(cls, data: ReadBuffer) -> FileRawData:
        """Deserialize an object previously stored with write().

        Returns a new FileRawData built from the values read from data.
        """
        defs = read_bytes(data)
        imports = read_bytes(data)
        # Read each tag in its own statement. A call placed inside an assert is
        # not executed when assertions are disabled (python -O), which would
        # leave the tag unread and misalign every read that follows.
        tag = read_tag(data)
        assert tag == LIST_GEN
        raw_errors = [read_json(data) for _ in range(read_int_bare(data))]
        tag = read_tag(data)
        assert tag == DICT_INT_GEN
        ignored_lines = {read_int(data): read_str_list(data) for _ in range(read_int_bare(data))}
        # Read the two flags in the order they were written.
        is_partial_stub_package = read_bool(data)
        uses_template_strings = read_bool(data)
        return FileRawData(
            defs,
            imports,
            raw_errors,
            ignored_lines,
            is_partial_stub_package,
            uses_template_strings,
        )


class MypyFile(SymbolNode):
"""The abstract syntax tree of a single source file."""

Expand All @@ -328,6 +394,7 @@ class MypyFile(SymbolNode):
"plugin_deps",
"future_import_flags",
"_is_typeshed_file",
"raw_data",
)

__match_args__ = ("name", "path", "defs")
Expand Down Expand Up @@ -370,6 +437,8 @@ class MypyFile(SymbolNode):
# Future imports defined in this file. Populated during semantic analysis.
future_import_flags: set[str]
_is_typeshed_file: bool | None
# For native parser store actual serialized data here.
raw_data: FileRawData | None

def __init__(
self,
Expand Down Expand Up @@ -400,6 +469,7 @@ def __init__(
self.uses_template_strings = False
self.future_import_flags = set()
self._is_typeshed_file = None
self.raw_data = None

def local_definitions(self) -> Iterator[Definition]:
"""Return all definitions within the module (including nested).
Expand Down
Loading
Loading