diff --git a/einx/_src/frontend/impl/jax.py b/einx/_src/frontend/impl/jax.py index 4dd4217..e704f53 100644 --- a/einx/_src/frontend/impl/jax.py +++ b/einx/_src/frontend/impl/jax.py @@ -1,3 +1,6 @@ +from typing import ParamSpec, TypeVar, Concatenate, cast +from collections.abc import Callable + import einx._src.tracer as tracer import einx._src.adapter as adapter from ..api import api @@ -48,7 +51,11 @@ def get_shape(tensor): } -def adapt_with_vmap(op, signature=None): +P = ParamSpec("P") +R = TypeVar("R") + + +def adapt_with_vmap(op: Callable[P, R], signature=None) -> Callable[Concatenate[str, P], R]: iskwarg = _make_iskwarg(op) jax = tracer.signature.jax() @@ -61,7 +68,7 @@ def adapt_with_vmap(op, signature=None): op = adapter.namedtensor_calltensorfactory.op(op, expected_type=jax.numpy.ndarray) op = adapter.einx_from_namedtensor.op(op, iskwarg=iskwarg, el_op=signature, implicit_output="bijective") - return api(op, backend=types.SimpleNamespace(**_get_backend_kwargs())) + return cast(Callable[Concatenate[str, P], R], api(op, backend=types.SimpleNamespace(**_get_backend_kwargs()))) adapt_with_vmap.__doc__ = _make_doc_adapt_with_vmap("jax", "``jax.vmap``") diff --git a/einx/_src/frontend/impl/mlx.py b/einx/_src/frontend/impl/mlx.py index 7a4e7ed..e9d2706 100644 --- a/einx/_src/frontend/impl/mlx.py +++ b/einx/_src/frontend/impl/mlx.py @@ -1,3 +1,6 @@ +from typing import ParamSpec, TypeVar, Concatenate, cast +from collections.abc import Callable + import einx._src.tracer as tracer import einx._src.adapter as adapter from ..types import Tensor @@ -39,7 +42,11 @@ def get_shape(tensor): } -def adapt_with_vmap(op, signature=None): +P = ParamSpec("P") +R = TypeVar("R") + + +def adapt_with_vmap(op: Callable[P, R], signature=None) -> Callable[Concatenate[str, P], R]: iskwarg = _make_iskwarg(op) mlx = tracer.signature.mlx() @@ -52,7 +59,7 @@ def adapt_with_vmap(op, signature=None): op = adapter.namedtensor_calltensorfactory.op(op, expected_type=mlx.core.array) op = adapter.einx_from_namedtensor.op(op, iskwarg=iskwarg, el_op=signature, implicit_output="bijective") - return api(op, backend=types.SimpleNamespace(**_get_backend_kwargs())) + return cast(Callable[Concatenate[str, P], R], api(op, backend=types.SimpleNamespace(**_get_backend_kwargs()))) adapt_with_vmap.__doc__ = _make_doc_adapt_with_vmap("mlx", "``mlx.core.vmap``") diff --git a/einx/_src/frontend/impl/torch.py b/einx/_src/frontend/impl/torch.py index 3c3fa52..1a6da73 100644 --- a/einx/_src/frontend/impl/torch.py +++ b/einx/_src/frontend/impl/torch.py @@ -1,3 +1,6 @@ +from typing import ParamSpec, TypeVar, Concatenate, cast +from collections.abc import Callable + import einx._src.tracer as tracer import einx._src.adapter as adapter from ..api import api @@ -75,7 +78,11 @@ def get_shape(tensor): } -def adapt_with_vmap(op, signature=None): +P = ParamSpec("P") +R = TypeVar("R") + + +def adapt_with_vmap(op: Callable[P, R], signature=None) -> Callable[Concatenate[str, P], R]: _raise_on_invalid_version() iskwarg = _make_iskwarg(op) @@ -100,7 +107,7 @@ def adapt_with_vmap(op, signature=None): torch.compiler.allow_in_graph(op) - return op + return cast(Callable[Concatenate[str, P], R], op) adapt_with_vmap.__doc__ = _make_doc_adapt_with_vmap("torch", "``torch.vmap``")