3333# 3rd party
3434import flake8_helper
3535
36- __all__ = ("Plugin" , "Visitor" , "get_decorator_names" , "check_params" )
36+ __all__ = ("Plugin" , "Visitor" , "get_decorator_names" , "check_params" , "get_docstring_args" , "get_signature_args" )
3737
3838__author__ = "Dominic Davis-Foster"
3939__copyright__ = "2025 Dominic Davis-Foster"
4242__email__ = "dominic@davis-foster.co.uk"
4343
4444PRM001 = "PRM001 Docstring parameters in wrong order."
45- PRM002 = "PRM002 Missing parameters in docstring. "
46- PRM003 = "PRM003 Extra parameters in docstring. "
45+ PRM002 = "PRM002 Missing parameters in docstring"
46+ PRM003 = "PRM003 Extra parameters in docstring"
4747# TODO: class-specific codes?
4848
4949deco_allowed_attr_names = {
@@ -124,14 +124,58 @@ def check_params(
124124 return PRM001
125125 elif signature_set - docstring_set :
126126 # Extras in signature
127- return PRM002
127+ return PRM002 + ": " + ' ' . join ( sorted ( signature_set - docstring_set ))
128128 elif docstring_set - signature_set :
129129 # Extras in docstrings
130- return PRM003
130+ return PRM003 + ": " + ' ' . join ( sorted ( docstring_set - signature_set ))
131131
132132 return None # pragma: no cover
133133
134134
135+ def get_signature_args (function : Union [ast .FunctionDef , ast .AsyncFunctionDef ]) -> Iterator [str ]:
136+ """
137+ Extract arguments from the function signature.
138+
139+ :param function:
140+
141+ :rtype:
142+
143+ ..versionadded:: 0.2.0
144+ """
145+
146+ for arg in function .args .posonlyargs :
147+ yield arg .arg
148+
149+ for arg in function .args .args :
150+ yield arg .arg
151+
152+ if function .args .vararg :
153+ yield '*' + function .args .vararg .arg
154+
155+ for arg in function .args .kwonlyargs :
156+ yield arg .arg
157+
158+ if function .args .kwarg :
159+ yield "**" + function .args .kwarg .arg
160+
161+
162+ def get_docstring_args (docstring : str ) -> Iterator [str ]:
163+ """
164+ Extract arguments from the docstring.
165+
166+ :param docstring:
167+
168+ :rtype:
169+
170+ ..versionadded:: 0.2.0
171+ """
172+
173+ for line in docstring .split ('\n ' ):
174+ line = line .strip ()
175+ if line .startswith (":param" ):
176+ yield line [6 :].split (':' , 1 )[0 ].strip ().replace (r"\*" , '*' )
177+
178+
135179class Visitor (flake8_helper .Visitor ):
136180 """
137181 AST node visitor for identifying mismatches between function signatures and docstring params.
@@ -154,14 +198,8 @@ def _visit_function(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) ->
154198 self .generic_visit (node )
155199 return
156200
157- docstring_args = []
158- for line in docstring .split ('\n ' ):
159- line = line .strip ()
160- if line .startswith (":param" ):
161- docstring_args .append (line [6 :].split (':' , 1 )[0 ].strip ())
162-
163- signature_args = [a .arg for a in node .args .args ]
164-
201+ docstring_args = list (get_docstring_args (docstring ))
202+ signature_args = list (get_signature_args (node ))
165203 decorators = list (get_decorator_names (node ))
166204
167205 error = check_params (signature_args , docstring_args , decorators )
@@ -184,20 +222,15 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: # noqa: D102
184222 self .generic_visit (node )
185223 return
186224
187- docstring_args = []
188- for line in docstring .split ('\n ' ):
189- line = line .strip ()
190- if line .startswith (":param" ):
191- docstring_args .append (line [6 :].split (':' , 1 )[0 ].strip ())
192-
225+ docstring_args = list (get_docstring_args (docstring ))
193226 decorators = list (get_decorator_names (node ))
194227
195228 signature_args = []
196229 functions_in_body : List [ast .FunctionDef ] = [n for n in node .body if isinstance (n , ast .FunctionDef )]
197230
198231 for function in functions_in_body :
199232 if function .name == "__init__" :
200- signature_args = [ a . arg for a in function . args . args ]
233+ signature_args = list ( get_signature_args ( function ))
201234 break
202235 else :
203236 # No __init__; maybe it comes from a base class.
0 commit comments