-
-
Notifications
You must be signed in to change notification settings - Fork 9
Add a context manager as an alternative to monkey patching #91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1a5f808
d0e0fdf
3d338bb
39b4bf9
96fd801
21a607c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| """Stats, linear algebra and einops for xarray.""" | ||
|
|
||
| from __future__ import annotations | ||
| from contextlib import contextmanager | ||
| from collections.abc import Iterable | ||
|
|
||
| import numpy as np | ||
| import xarray as xr | ||
|
|
@@ -9,6 +11,7 @@ | |
| from .accessors import LinAlgAccessor, EinopsAccessor | ||
|
|
||
| __all__ = [ | ||
| "default_linalg_dims", | ||
| "einsum", | ||
| "einsum_path", | ||
| "matmul", | ||
|
|
@@ -188,3 +191,38 @@ def ones_ref(*args, dims, dtype=None): | |
| empty_ref, zeros_ref | ||
| """ | ||
| return _create_ref(*args, dims=dims, np_creator=np.ones, dtype=dtype) | ||
|
|
||
|
|
||
| @contextmanager | ||
| def default_linalg_dims(func_or_dims): | ||
| """Context manager to temporarily set the default dimensions for linalg functions. | ||
|
|
||
| Safer alternative to monkey patching the `get_default_dims` function in `linalg` module, | ||
| as it ensures that the original function is restored even if an error occurs within the context. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| func_or_dims : callable or iterable | ||
| If a callable is provided, it should take the same arguments as `get_default_dims` | ||
| and return the default dimensions based on those arguments. | ||
| If an iterable is provided, it will be used as the default dimensions | ||
| regardless of the input arguments. | ||
|
|
||
| Yields | ||
| ------ | ||
| None | ||
|
Comment on lines
+211
to
+213
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. numpydoc doesn't say anything on documenting context managers but I would remove this section and instead add an examples section or a seealso pointing to the docs on using it. In general I would also update all the docs to remove any reference to monkeypatching and rely on the context manager only (I am more than happy to do this myself if you prefer). |
||
| """ | ||
| from xarray_einstats import linalg | ||
|
|
||
| original_get_default_dims = linalg.get_default_dims | ||
|
|
||
| def func(*args): | ||
| if isinstance(func_or_dims, Iterable): | ||
| return func_or_dims | ||
| return func_or_dims(*args) | ||
|
|
||
| linalg.get_default_dims = func | ||
| try: | ||
| yield | ||
| finally: | ||
| linalg.get_default_dims = original_get_default_dims | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the body can be kept mostly as is, the only change is I would define it inside the
linalg.pyfile so it is also imported fromxarray_einstats.linalginstead of being available at the top level. We might then want to remove the linalg from the name but I am also fine keeping it.