Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/xarray_einstats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Stats, linear algebra and einops for xarray."""

from __future__ import annotations
from contextlib import contextmanager
from collections.abc import Iterable

import numpy as np
import xarray as xr
Expand All @@ -9,6 +11,7 @@
from .accessors import LinAlgAccessor, EinopsAccessor

__all__ = [
"default_linalg_dims",
"einsum",
"einsum_path",
"matmul",
Expand Down Expand Up @@ -188,3 +191,38 @@ def ones_ref(*args, dims, dtype=None):
empty_ref, zeros_ref
"""
return _create_ref(*args, dims=dims, np_creator=np.ones, dtype=dtype)


@contextmanager
def default_linalg_dims(func_or_dims):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the body can be kept mostly as is, the only change is I would define it inside the linalg.py file so it is also imported from xarray_einstats.linalg instead of being available at the top level. We might then want to remove the linalg from the name but I am also fine keeping it.

"""Context manager to temporarily set the default dimensions for linalg functions.

Safer alternative to monkey patching the `get_default_dims` function in `linalg` module,
as it ensures that the original function is restored even if an error occurs within the context.

Parameters
----------
func_or_dims : callable or iterable
If a callable is provided, it should take the same arguments as `get_default_dims`
and return the default dimensions based on those arguments.
If an iterable is provided, it will be used as the default dimensions
regardless of the input arguments.

Yields
------
None
Comment on lines +211 to +213

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numpydoc doesn't say anything on documenting context managers but I would remove this section and instead add an examples section or a seealso pointing to the docs on using it. In general I would also update all the docs to remove any reference to monkeypatching and rely on the context manager only (I am more than happy to do this myself if you prefer).

"""
from xarray_einstats import linalg

original_get_default_dims = linalg.get_default_dims

def func(*args):
if isinstance(func_or_dims, Iterable):
return func_or_dims
return func_or_dims(*args)

linalg.get_default_dims = func
try:
yield
finally:
linalg.get_default_dims = original_get_default_dims
7 changes: 7 additions & 0 deletions src/xarray_einstats/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import annotations

from collections.abc import Hashable, Iterable, Sequence
from contextlib import contextmanager
from typing import Any, Callable, Generator

import numpy as np
import xarray
Expand All @@ -13,6 +15,7 @@ from .accessors import EinopsAccessor, LinAlgAccessor
from .linalg import einsum, einsum_path, matmul

__all__ = [
"default_linalg_dims",
"einsum",
"einsum_path",
"matmul",
Expand Down Expand Up @@ -52,3 +55,7 @@ def ones_ref(
dims: Sequence[Hashable],
dtype: np.typing.DTypeLike | None = ...,
) -> xarray.DataArray: ...
@contextmanager
def default_linalg_dims(
func_or_dims: Callable | Iterable,
) -> Generator[None, Any, None]: ...