diff --git a/flake8_async/visitors/__init__.py b/flake8_async/visitors/__init__.py index a97345b..0c5a006 100644 --- a/flake8_async/visitors/__init__.py +++ b/flake8_async/visitors/__init__.py @@ -38,6 +38,7 @@ visitor111, visitor118, visitor123, + visitor430, visitor_utility, visitors, ) diff --git a/flake8_async/visitors/visitor430.py b/flake8_async/visitors/visitor430.py new file mode 100644 index 0000000..ce5b53c --- /dev/null +++ b/flake8_async/visitors/visitor430.py @@ -0,0 +1,84 @@ +"""Visitor to check for pytest.raises(ExceptionGroup) usage. + +ASYNC430: Suggests using pytest.RaisesGroup instead of pytest.raises(ExceptionGroup). +""" + +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING, Any + +from .flake8asyncvisitor import Flake8AsyncVisitor +from .helpers import error_class + +if TYPE_CHECKING: + from collections.abc import Mapping + + +@error_class +class Visitor430(Flake8AsyncVisitor): + error_codes: Mapping[str, str] = { + "ASYNC430": ( + "Using `pytest.raises(ExceptionGroup)` is discouraged, consider using " + "`pytest.RaisesGroup` instead." + ) + } + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.imports_pytest: bool = False + self.imports_exceptiongroup: bool = False + self.async_function = False + + def visit_AsyncFunctionDef( + self, node: ast.AsyncFunctionDef | ast.FunctionDef | ast.Lambda + ): + self.save_state(node, "async_function") + self.async_function = isinstance(node, ast.AsyncFunctionDef) + + visit_FunctionDef = visit_AsyncFunctionDef + visit_Lambda = visit_AsyncFunctionDef + + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + if alias.name == "pytest": + self.imports_pytest = True + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module == "pytest": + self.imports_pytest = True + elif node.module == "builtins" or node.module is None: + # Check for `from builtins import ExceptionGroup` + for alias in node.names: + if alias.name in ("ExceptionGroup", "BaseExceptionGroup"): + self.imports_exceptiongroup = True + + def visit_Call(self, node: ast.Call) -> None: + if not self.async_function: + return + + func_name = ast.unparse(node.func) + + # Check for pytest.raises(ExceptionGroup) or pytest.raises(BaseExceptionGroup) + if not ( + func_name == "pytest.raises" + or (self.imports_pytest and func_name == "raises") + ): + return + + # Check first argument (exception type) + if not node.args: + return + + first_arg = node.args[0] + if isinstance(first_arg, ast.Name) and first_arg.id in ( + "ExceptionGroup", + "BaseExceptionGroup", + ): + self.error(node) + elif isinstance(first_arg, ast.Attribute) and first_arg.attr in ( + "ExceptionGroup", + "BaseExceptionGroup", + ): + # Handle pytest.raises(pytest.ExceptionGroup) or similar + self.error(node) diff --git a/tests/eval_files/async430.py b/tests/eval_files/async430.py new file mode 100644 index 0000000..6c80c72 --- /dev/null +++ b/tests/eval_files/async430.py @@ -0,0 +1,24 @@ +# type: ignore +import pytest + + +async def test_pytest_raises_exceptiongroup(): + with pytest.raises(ExceptionGroup): # ASYNC430: 9 + pass + + +async def test_pytest_raises_baseexceptiongroup(): + with pytest.raises(BaseExceptionGroup): # ASYNC430: 9 + pass + + +async def test_pytest_raises_other(): + # Should not error + with pytest.raises(ValueError): + pass + + +async def test_pytest_raises_group(): + # Should not error - this is what we want users to use + with pytest.RaisesGroup(ExceptionGroup): + pass