task: add patch methods for mkl_random#90
task: add patch methods for mkl_random#90jharlow-intel wants to merge 6 commits intoIntelPython:masterfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds monkey-patching functionality to mkl_random, allowing users to temporarily or permanently replace numpy.random functions with their mkl_random equivalents. The implementation provides both imperative (monkey_patch(), restore()) and context manager (mkl_random.mkl_random()) interfaces.
Changes:
- Adds new
_patch.pyxCython module implementing patching logic with thread-local state tracking - Extends
setup.pyto build the new_patchextension module - Exports patch functions (
monkey_patch,use_in_numpy,restore,is_patched,patched_names,mkl_random) in__init__.py - Adds comprehensive test suite in
test_patch.pycovering basic patching, restoration, and context manager functionality
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| setup.py | Adds Extension configuration for mkl_random._patch module |
| mkl_random/src/_patch.pyx | Implements core patching logic with thread-local storage, patch/unpatch methods, and context manager |
| mkl_random/init.py | Exports patch-related functions to public API |
| mkl_random/tests/test_patch.py | Adds tests for patching, restoration, context manager, and patched function behavior |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
setup.py
Outdated
| Extension( | ||
| "mkl_random._patch", | ||
| sources=[join("mkl_random", "src", "_patch.pyx")], | ||
| include_dirs=[np.get_include()], | ||
| define_macros=defs + [("NDEBUG", None)], | ||
| language="c", | ||
| ) |
There was a problem hiding this comment.
The _patch extension includes define_macros with PY_ARRAY_UNIQUE_SYMBOL, but _patch.pyx doesn't use the NumPy C API (no cimport numpy or usage of numpy C structures). While this doesn't cause harm, these macros are unnecessary for this extension and could be removed for clarity. The extension only needs include_dirs=[np.get_include()] to compile successfully with Cython's NumPy support.
There was a problem hiding this comment.
@antonwolfy @ndgrigorian I'm unsure about this one honestly, looks like removing PY_ARRAY_UNIQUE_SYMBOL from setup.py is ultimately unnecessary
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
mkl_random/init.py:101
- The patching API is imported into the top-level package, but the new public symbols (
monkey_patch,use_in_numpy,restore,is_patched,patched_names,mkl_random) are not added to__all__. This makes the export surface inconsistent with the rest of this module (which explicitly enumerates public names) and can breakfrom mkl_random import *expectations.
from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random
from mkl_random import interfaces
__all__ = [
"MKLRandomState",
"RandomState",
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if local_count <= 0: | ||
| if verbose: | ||
| print( | ||
| "Warning: restore called more times than monkey_patch in this thread." | ||
| ) |
There was a problem hiding this comment.
restore(verbose=True) currently emits a warning via print(...). Elsewhere in this project warnings are surfaced via warnings.warn(...) (often with an appropriate category/stacklevel). Using warnings.warn here would make the warning easier to filter/test and avoids writing directly to stdout.
| def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False): | ||
| """ | ||
| Backward-compatible alias for monkey_patch(). | ||
| """ | ||
| monkey_patch( | ||
| numpy_module=numpy_module, | ||
| names=names, | ||
| strict=strict, | ||
| verbose=verbose, | ||
| ) |
There was a problem hiding this comment.
use_in_numpy() is documented as a backward-compatible alias for monkey_patch(), but it's currently a separate wrapper function. This breaks identity-based expectations (and the added test asserts use_in_numpy is monkey_patch). Either make use_in_numpy a true alias (assignment) or adjust the test and docs to reflect wrapper semantics.
| def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False): | |
| """ | |
| Backward-compatible alias for monkey_patch(). | |
| """ | |
| monkey_patch( | |
| numpy_module=numpy_module, | |
| names=names, | |
| strict=strict, | |
| verbose=verbose, | |
| ) | |
| # Backward-compatible alias for monkey_patch(). | |
| use_in_numpy = monkey_patch |
| mkl_random.monkey_patch(np) | ||
| mkl_random.monkey_patch(np) | ||
|
|
||
| assert mkl_random.is_patched() | ||
| assert np.random.normal is mkl_random.mklrand.normal | ||
|
|
||
| mkl_random.restore() | ||
| assert mkl_random.is_patched() | ||
| assert np.random.normal is mkl_random.mklrand.normal | ||
|
|
||
| mkl_random.restore() | ||
| assert not mkl_random.is_patched() | ||
| assert np.random.normal is orig_normal |
There was a problem hiding this comment.
test_patch_redundant_patching doesn't use a try/finally (or fixture finalizer) to guarantee mkl_random.restore() is called if an intermediate assertion fails. Because this test intentionally keeps NumPy patched after the first restore(), a failure before the final cleanup can leak global patched state into subsequent tests.
| mkl_random.monkey_patch(np) | |
| mkl_random.monkey_patch(np) | |
| assert mkl_random.is_patched() | |
| assert np.random.normal is mkl_random.mklrand.normal | |
| mkl_random.restore() | |
| assert mkl_random.is_patched() | |
| assert np.random.normal is mkl_random.mklrand.normal | |
| mkl_random.restore() | |
| assert not mkl_random.is_patched() | |
| assert np.random.normal is orig_normal | |
| try: | |
| mkl_random.monkey_patch(np) | |
| mkl_random.monkey_patch(np) | |
| assert mkl_random.is_patched() | |
| assert np.random.normal is mkl_random.mklrand.normal | |
| mkl_random.restore() | |
| assert mkl_random.is_patched() | |
| assert np.random.normal is mkl_random.mklrand.normal | |
| mkl_random.restore() | |
| assert not mkl_random.is_patched() | |
| assert np.random.normal is orig_normal | |
| finally: | |
| # Ensure that any remaining patches are removed even if the test fails | |
| while mkl_random.is_patched(): | |
| mkl_random.restore() |
| # Common global sampling helpers | ||
| "random", | ||
| "random_sample", | ||
| "sample", | ||
| "rand", | ||
| "randn", | ||
| "bytes", |
There was a problem hiding this comment.
_DEFAULT_NAMES includes very common NumPy APIs like numpy.random.random and numpy.random.sample, but mkl_random.mklrand doesn't define random/sample (it has random_sample). With the current hasattr(_mr, name) check, those names will silently remain unpatched in non-strict mode, which undermines the goal of routing NumPy calls through mkl_random. Consider adding random/sample aliases on the mklrand side or explicitly mapping NumPy's random/sample to mklrand.random_sample during patching.
|
|
||
| def _normalize_names(self, names): | ||
| if names is None: | ||
| names = _DEFAULT_NAMES |
There was a problem hiding this comment.
_normalize_names() uses tuple(names), so passing a string like names="normal" will be treated as an iterable of characters and attempt to patch n, o, etc. This leads to confusing behavior and hard-to-diagnose failures. Consider rejecting str/bytes explicitly (and/or accepting a single name by wrapping it in a 1-tuple) and validating that all entries are strings.
| names = _DEFAULT_NAMES | |
| names = _DEFAULT_NAMES | |
| elif isinstance(names, str): | |
| # Treat a single string as a single name, not an iterable of characters | |
| names = (names,) | |
| elif isinstance(names, bytes): | |
| # Attribute names must be strings; reject bytes explicitly | |
| raise TypeError("names must be a string or an iterable of strings, not bytes") | |
| else: | |
| names = tuple(names) | |
| for name in names: | |
| if not isinstance(name, str): | |
| raise TypeError("All names must be strings; got {!r} of type {!r}".format(name, type(name))) |
| @@ -0,0 +1,292 @@ | |||
| # Copyright (c) 2019, Intel Corporation | |||
There was a problem hiding this comment.
I don't think we need Cython here
like with mkl_fft's patch, we can do it entirely on the Python level. We aren't using any of NumPy/Python C API.
There was a problem hiding this comment.
Oops, overlooked that thanks for mentioning it
Gonna need some expert's eyes on this one, but it built fine and the tests look okay.
Definitely going to have to make sure it properly interfaced numpy.random though