From 3c2fd067aec0f1025990ff8baeaca14e5d77ac01 Mon Sep 17 00:00:00 2001 From: Matthias Klumpp Date: Tue, 5 May 2026 23:41:02 +0200 Subject: [PATCH] fix: Invert input/output context if a Python function is called from C++ When the Callable itself is an input (parameter) to a C++ function, its arguments are outputs (C++ passes them to the Python callback), and vice versa. Therefore, we must invert them if C++ calls a Python function, but keep them the same in the other direction. --- include/pybind11/detail/descr.h | 5 +++++ include/pybind11/functional.h | 5 ++--- include/pybind11/pybind11.h | 8 +++++++- include/pybind11/typing.h | 8 +++----- tests/test_callbacks.py | 2 +- tests/test_pytypes.py | 16 ++++++++-------- 6 files changed, 26 insertions(+), 18 deletions(-) diff --git a/include/pybind11/detail/descr.h b/include/pybind11/detail/descr.h index 701662c4cf..4f5ae34fc8 100644 --- a/include/pybind11/detail/descr.h +++ b/include/pybind11/detail/descr.h @@ -222,5 +222,10 @@ constexpr descr return_descr(const descr &descr) { return const_name("@$") + descr + const_name("@!"); } +template +constexpr descr inv_descr(const descr &descr) { + return const_name("@~") + descr + const_name("@!"); +} + PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 8f59f5fe5e..c7122795ad 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -138,9 +138,8 @@ struct type_caster> { PYBIND11_TYPE_CASTER( type, const_name("collections.abc.Callable[[") - + ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster::name)...) - + const_name("], ") + ::pybind11::detail::return_descr(make_caster::name) - + const_name("]")); + + ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster::name)...) + + const_name("], ") + make_caster::name + const_name("]")); }; PYBIND11_NAMESPACE_END(detail) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 12c0cd65a2..9cc45bdbdc 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -185,7 +185,8 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel signature += *++pc; } else if (c == '@') { // `@^ ... @!` and `@$ ... @!` are used to force arg/return value type (see - // typing::Callable/detail::arg_descr/detail::return_descr) + // typing::Callable/detail::arg_descr/detail::return_descr). + // `@~ ... @!` inverts the current context (see detail::inv_descr). if (*(pc + 1) == '^') { is_return_value.emplace(false); ++pc; @@ -196,6 +197,11 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel ++pc; continue; } + if (*(pc + 1) == '~') { + is_return_value.emplace(!is_return_value.top()); + ++pc; + continue; + } if (*(pc + 1) == '!') { is_return_value.pop(); ++pc; diff --git a/include/pybind11/typing.h b/include/pybind11/typing.h index 43e2187b9e..4b027dbfea 100644 --- a/include/pybind11/typing.h +++ b/include/pybind11/typing.h @@ -194,9 +194,8 @@ struct handle_type_name> { using retval_type = conditional_t::value, void_type, Return>; static constexpr auto name = const_name("collections.abc.Callable[[") - + ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster::name)...) - + const_name("], ") + ::pybind11::detail::return_descr(make_caster::name) - + const_name("]"); + + ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster::name)...) + + const_name("], ") + make_caster::name + const_name("]"); }; template @@ -204,8 +203,7 @@ struct handle_type_name> { // PEP 484 specifies this syntax for defining only return types of callables using retval_type = conditional_t::value, void_type, Return>; static constexpr auto name = const_name("collections.abc.Callable[..., ") - + ::pybind11::detail::return_descr(make_caster::name) - + const_name("]"); + + make_caster::name + const_name("]"); }; template diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 327e41eb33..1c3117041b 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -140,7 +140,7 @@ def test_cpp_function_roundtrip(): def test_function_signatures(doc): assert ( doc(m.test_callback3) - == "test_callback3(arg0: collections.abc.Callable[[typing.SupportsInt | typing.SupportsIndex], int]) -> str" + == "test_callback3(arg0: collections.abc.Callable[[int], typing.SupportsInt | typing.SupportsIndex]) -> str" ) assert ( doc(m.test_callback4) diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 580371f02d..9a80f1ea41 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -976,14 +976,14 @@ def test_iterator_annotations(doc): def test_fn_annotations(doc): assert ( doc(m.annotate_fn) - == "annotate_fn(arg0: collections.abc.Callable[[list[str], str], int]) -> None" + == "annotate_fn(arg0: collections.abc.Callable[[list[str], str], typing.SupportsInt | typing.SupportsIndex]) -> None" ) def test_fn_return_only(doc): assert ( doc(m.annotate_fn_only_return) - == "annotate_fn_only_return(arg0: collections.abc.Callable[..., int]) -> None" + == "annotate_fn_only_return(arg0: collections.abc.Callable[..., typing.SupportsInt | typing.SupportsIndex]) -> None" ) @@ -1085,7 +1085,7 @@ def test_literal(doc): ) assert ( doc(m.identity_literal_arrow_with_callable) - == 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float | int], float]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]' + == 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float], float | int]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]' ) assert ( doc(m.identity_literal_all_special_chars) @@ -1325,27 +1325,27 @@ def test_arg_return_type_hints(doc, backport_typehints): # Callable identity assert ( doc(m.identity_callable) - == "identity_callable(arg0: collections.abc.Callable[[float | int], float]) -> collections.abc.Callable[[float | int], float]" + == "identity_callable(arg0: collections.abc.Callable[[float], float | int]) -> collections.abc.Callable[[float | int], float]" ) # Callable identity assert ( doc(m.identity_callable_ellipsis) - == "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float]) -> collections.abc.Callable[..., float]" + == "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float | int]) -> collections.abc.Callable[..., float]" ) # Nested Callable identity assert ( doc(m.identity_nested_callable) - == "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]) -> collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]" + == "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float], float | int]]) -> collections.abc.Callable[[collections.abc.Callable[[float], float | int]], collections.abc.Callable[[float | int], float]]" ) # Callable assert ( doc(m.apply_callable) - == "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float | int], float]) -> float" + == "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float], float | int]) -> float" ) # Callable assert ( doc(m.apply_callable_ellipsis) - == "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float]) -> float" + == "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float | int]) -> float" ) # Union assert (