From 2bde844e21c5223f846594a3ae842597231fa909 Mon Sep 17 00:00:00 2001 From: "P. Clawmogorov" <262173731+Alm0stSurely@users.noreply.github.com> Date: Wed, 18 Feb 2026 13:57:28 +0000 Subject: [PATCH] Add ASYNC430: lint rule for pytest.raises(ExceptionGroup) Adds a new rule to detect usage of pytest.raises(ExceptionGroup) or pytest.raises(BaseExceptionGroup) in async functions, suggesting the use of pytest.RaisesGroup instead. This is recommended because RaisesGroup provides better support for exception group testing in async contexts, matching the structure of exception groups more accurately. Closes #430 --- flake8_async/visitors/__init__.py | 1 + flake8_async/visitors/visitor430.py | 84 +++++++++++++++++++++++++++++ tests/eval_files/async430.py | 24 +++++++++ 3 files changed, 109 insertions(+) create mode 100644 flake8_async/visitors/visitor430.py create mode 100644 tests/eval_files/async430.py diff --git a/flake8_async/visitors/__init__.py b/flake8_async/visitors/__init__.py index a97345b..0c5a006 100644 --- a/flake8_async/visitors/__init__.py +++ b/flake8_async/visitors/__init__.py @@ -38,6 +38,7 @@ visitor111, visitor118, visitor123, + visitor430, visitor_utility, visitors, ) diff --git a/flake8_async/visitors/visitor430.py b/flake8_async/visitors/visitor430.py new file mode 100644 index 0000000..ce5b53c --- /dev/null +++ b/flake8_async/visitors/visitor430.py @@ -0,0 +1,84 @@ +"""Visitor to check for pytest.raises(ExceptionGroup) usage. + +ASYNC430: Suggests using pytest.RaisesGroup instead of pytest.raises(ExceptionGroup). +""" + +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING, Any + +from .flake8asyncvisitor import Flake8AsyncVisitor +from .helpers import error_class + +if TYPE_CHECKING: + from collections.abc import Mapping + + +@error_class +class Visitor430(Flake8AsyncVisitor): + error_codes: Mapping[str, str] = { + "ASYNC430": ( + "Using `pytest.raises(ExceptionGroup)` is discouraged, consider using " + "`pytest.RaisesGroup` instead." + ) + } + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.imports_pytest: bool = False + self.imports_exceptiongroup: bool = False + self.async_function = False + + def visit_AsyncFunctionDef( + self, node: ast.AsyncFunctionDef | ast.FunctionDef | ast.Lambda + ): + self.save_state(node, "async_function") + self.async_function = isinstance(node, ast.AsyncFunctionDef) + + visit_FunctionDef = visit_AsyncFunctionDef + visit_Lambda = visit_AsyncFunctionDef + + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + if alias.name == "pytest": + self.imports_pytest = True + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module == "pytest": + self.imports_pytest = True + elif node.module == "builtins" or node.module is None: + # Check for `from builtins import ExceptionGroup` + for alias in node.names: + if alias.name in ("ExceptionGroup", "BaseExceptionGroup"): + self.imports_exceptiongroup = True + + def visit_Call(self, node: ast.Call) -> None: + if not self.async_function: + return + + func_name = ast.unparse(node.func) + + # Check for pytest.raises(ExceptionGroup) or pytest.raises(BaseExceptionGroup) + if not ( + func_name == "pytest.raises" + or (self.imports_pytest and func_name == "raises") + ): + return + + # Check first argument (exception type) + if not node.args: + return + + first_arg = node.args[0] + if isinstance(first_arg, ast.Name) and first_arg.id in ( + "ExceptionGroup", + "BaseExceptionGroup", + ): + self.error(node) + elif isinstance(first_arg, ast.Attribute) and first_arg.attr in ( + "ExceptionGroup", + "BaseExceptionGroup", + ): + # Handle pytest.raises(pytest.ExceptionGroup) or similar + self.error(node) diff --git a/tests/eval_files/async430.py b/tests/eval_files/async430.py new file mode 100644 index 0000000..6c80c72 --- /dev/null +++ b/tests/eval_files/async430.py @@ -0,0 +1,24 @@ +# type: ignore +import pytest + + +async def test_pytest_raises_exceptiongroup(): + with pytest.raises(ExceptionGroup): # ASYNC430: 9 + pass + + +async def test_pytest_raises_baseexceptiongroup(): + with pytest.raises(BaseExceptionGroup): # ASYNC430: 9 + pass + + +async def test_pytest_raises_other(): + # Should not error + with pytest.raises(ValueError): + pass + + +async def test_pytest_raises_group(): + # Should not error - this is what we want users to use + with pytest.RaisesGroup(ExceptionGroup): + pass