Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:nosignatures:
:toctree: generated

angle
apply_where
argpartition
at
Expand Down
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from ._lib._at import at
from ._lib._funcs import (
angle,
apply_where,
broadcast_shapes,
default_dtype,
Expand All @@ -32,6 +33,7 @@
# pylint: disable=duplicate-code
__all__ = [
"__version__",
"angle",
"apply_where",
"argpartition",
"at",
Expand Down
49 changes: 49 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._utils._typing import Array, Device, DType

__all__ = [
"angle",
"apply_where",
"atleast_nd",
"broadcast_shapes",
Expand Down Expand Up @@ -818,3 +819,51 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
b = xp.reshape(b, (-1,))
# XXX: `sparse` returns NumPy arrays from `unique_values`
return xp.asarray(xp.unique_values(xp.concat([a, b])))


def angle(z: Array, /, *, deg: bool = False, xp: ModuleType | None = None) -> Array:
"""
Return the angle of the complex argument.

Parameters
----------
z : Array
Input array.
deg : bool, optional
Return angle in degrees if True, radians if False (default).
xp : array_namespace, optional
The standard-compatible namespace for `z`. Default: infer.

Returns
-------
array
The counterclockwise angle from the positive real axis on the complex
plane in the range ``(-pi, pi]``.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a note to the documentation stating what happens if the input x is real? i.e. interpreted as x + 0j

Notes
-----
A real input x is interpreted as x + 0j

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.angle(xp.asarray([1.0, 1.0j, 1 + 1j]), xp=xp)
Array([0. , 1.57079633, 0.78539816], dtype=array_api_strict.float64)
Comment thread
prady0t marked this conversation as resolved.
>>> xpx.angle(xp.asarray([1.0, 1.0j, 1 + 1j]), deg=True, xp=xp)
Array([ 0., 90., 45.], dtype=array_api_strict.float64)
"""
Comment thread
lucascolley marked this conversation as resolved.
if xp is None:
xp = array_namespace(z)
if xp.isdtype(z.dtype, "complex floating"):
zimag = xp.imag(z)
zreal = xp.real(z)
else:
if not xp.isdtype(z.dtype, "real floating"):
z = xp.astype(z, default_dtype(xp, device=_compat.device(z)))
zimag = xp.zeros_like(z)
zreal = z
a = xp.atan2(zimag, zreal)
if deg:
a = a * 180 / xp.pi
return a
76 changes: 76 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing_extensions import override

from array_api_extra import (
angle,
apply_where,
argpartition,
at,
Expand Down Expand Up @@ -1881,3 +1882,78 @@ def test_device(self, xp: ModuleType, device: Device):
a = xp.asarray([-1, 1, 0], device=device)
b = xp.asarray([2, -2, 0], device=device)
assert get_device(union1d(a, b)) == device


class TestAngle:
def test_simple(self, xp: ModuleType):
Comment thread
lucascolley marked this conversation as resolved.
a = xp.asarray([1, 0])
res = angle(a)
expected = xp.asarray([0.0, 0.0], dtype=res.dtype)
xp_assert_equal(res, expected)

def test_basic(self, xp: ModuleType):
Comment thread
lucascolley marked this conversation as resolved.
x = xp.asarray(
[
1 + 3j,
np.sqrt(2) / 2.0 + 1j * np.sqrt(2) / 2,
1,
1j,
-1,
-1j,
1 - 3j,
-1 + 3j,
],
dtype=xp.complex128,
)
expected = xp.asarray(
[
np.arctan(3.0 / 1.0),
np.arctan(1.0),
0,
np.pi / 2,
np.pi,
-np.pi / 2.0,
-np.arctan(3.0 / 1.0),
np.pi - np.arctan(3.0 / 1.0),
],
dtype=xp.float64,
)
xp_assert_close(angle(x, xp=xp), expected, rtol=0, atol=1e-11)
xp_assert_close(
angle(x, deg=True, xp=xp),
expected * 180 / xp.pi,
rtol=0,
atol=1e-11,
)

def test_real(self, xp: ModuleType):
x = xp.asarray([0.0, -0.0, 1.0, -1.0])
expected = xp.asarray([0.0, xp.pi, 0.0, xp.pi], dtype=x.dtype)
xp_assert_close(angle(x, xp=xp), expected)

def test_complex(self, xp: ModuleType):
Comment thread
lucascolley marked this conversation as resolved.
a = xp.asarray([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j])
expected = xp.asarray([xp.pi / 4, -xp.pi / 4, 3 * xp.pi / 4, -3 * xp.pi / 4])
res = angle(a, xp=xp)
xp_assert_equal(res, expected)

def test_integral(self, xp: ModuleType):
x = xp.asarray([0, -1, 1], dtype=xp.int32)
actual = angle(x, xp=xp)
expected = xp.asarray(
[0.0, xp.pi, 0.0], dtype=default_dtype(xp, device=get_device(x))
)
xp_assert_close(actual, expected)

def test_2d(self, xp: ModuleType):
a = xp.asarray([[1 + 1j, 1 - 1j], [-1 + 1j, -1 - 1j]])
expected = xp.asarray(
[[xp.pi / 4, -xp.pi / 4], [3 * xp.pi / 4, -3 * xp.pi / 4]]
)
res = angle(a, xp=xp)
xp_assert_equal(res, expected)

Comment thread
lucascolley marked this conversation as resolved.
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
Comment thread
lucascolley marked this conversation as resolved.
def test_device(self, xp: ModuleType, device: Device):
a = xp.asarray([1 + 1j], device=device)
assert get_device(angle(a)) == device