Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[MASTER]
extension-pkg-allow-list=numpy,mkl_random.mklrand

[TYPECHECK]
generated-members=RandomState,min,max
9 changes: 9 additions & 0 deletions mkl_random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@

from mkl_random import interfaces

from ._patch import (
is_patched,
mkl_random,
monkey_patch,
patched_names,
restore,
use_in_numpy,
)

__all__ = [
"MKLRandomState",
"RandomState",
Expand Down
292 changes: 292 additions & 0 deletions mkl_random/src/_patch.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# Copyright (c) 2019, Intel Corporation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need Cython here

like with mkl_fft's patch, we can do it entirely on the Python level. We aren't using any of NumPy/Python C API.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, overlooked that thanks for mentioning it

#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of Intel Corporation nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# distutils: language = c
# cython: language_level=3

"""
Patch NumPy's `numpy.random` symbols to use mkl_random implementations.

This is attribute-level monkey patching. It can replace legacy APIs like
`numpy.random.RandomState` and global distribution functions, but it does not
replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
compatible replacements.
"""

from contextlib import ContextDecorator
from threading import Lock, local

import numpy as _np

from . import mklrand as _mr


cdef tuple _DEFAULT_NAMES = (
# Legacy seeding / state
"seed",
"get_state",
"set_state",
"RandomState",

# Common global sampling helpers
"random",
"random_sample",
"sample",
"rand",
"randn",
"bytes",
Comment on lines +53 to +59
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_DEFAULT_NAMES includes very common NumPy APIs like numpy.random.random and numpy.random.sample, but mkl_random.mklrand doesn't define random/sample (it has random_sample). With the current hasattr(_mr, name) check, those names will silently remain unpatched in non-strict mode, which undermines the goal of routing NumPy calls through mkl_random. Consider adding random/sample aliases on the mklrand side or explicitly mapping NumPy's random/sample to mklrand.random_sample during patching.

Copilot uses AI. Check for mistakes.

# Integers
"randint",

# Common distributions (only patched if present on both sides)
"standard_normal",
"normal",
"uniform",
"exponential",
"gamma",
"beta",
"chisquare",
"f",
"lognormal",
"laplace",
"logistic",
"multivariate_normal",
"poisson",
"power",
"rayleigh",
"triangular",
"vonmises",
"wald",
"weibull",
"zipf",

# Permutations / choices
"choice",
"permutation",
"shuffle",
)


class _GlobalPatch:
def __init__(self):
self._lock = Lock()
self._patch_count = 0
self._numpy_module = None
self._requested_names = None
self._originals = {}
self._patched = ()
self._tls = local()

def _normalize_names(self, names):
if names is None:
names = _DEFAULT_NAMES
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_normalize_names() uses tuple(names), so passing a string like names="normal" will be treated as an iterable of characters and attempt to patch n, o, etc. This leads to confusing behavior and hard-to-diagnose failures. Consider rejecting str/bytes explicitly (and/or accepting a single name by wrapping it in a 1-tuple) and validating that all entries are strings.

Suggested change
names = _DEFAULT_NAMES
names = _DEFAULT_NAMES
elif isinstance(names, str):
# Treat a single string as a single name, not an iterable of characters
names = (names,)
elif isinstance(names, bytes):
# Attribute names must be strings; reject bytes explicitly
raise TypeError("names must be a string or an iterable of strings, not bytes")
else:
names = tuple(names)
for name in names:
if not isinstance(name, str):
raise TypeError("All names must be strings; got {!r} of type {!r}".format(name, type(name)))

Copilot uses AI. Check for mistakes.
return tuple(names)

def _validate_module(self, numpy_module):
if not hasattr(numpy_module, "random"):
raise TypeError(
"Expected a numpy-like module with a `.random` attribute."
)

def _apply_patch(self, numpy_module, names, strict):
np_random = numpy_module.random
originals = {}
patched = []
missing = []
for name in names:
if not hasattr(np_random, name) or not hasattr(_mr, name):
missing.append(name)
continue
originals[name] = getattr(np_random, name)
setattr(np_random, name, getattr(_mr, name))
patched.append(name)

if strict and missing:
for name, value in originals.items():
setattr(np_random, name, value)
raise AttributeError(
"Could not patch these names (missing on numpy.random or "
"mkl_random.mklrand): "
+ ", ".join([str(x) for x in missing])
)

self._numpy_module = numpy_module
self._requested_names = names
self._originals = originals
self._patched = tuple(patched)

def do_patch(
self,
numpy_module=None,
names=None,
strict=False,
verbose=False,
):
if numpy_module is None:
numpy_module = _np
names = self._normalize_names(names)
self._validate_module(numpy_module)
strict = bool(strict)

with self._lock:
local_count = getattr(self._tls, "local_count", 0)
if self._patch_count == 0:
self._apply_patch(numpy_module, names, strict)
else:
if self._numpy_module is not numpy_module:
raise RuntimeError(
"Already patched a different numpy module; "
"call restore() first."
)
if names != self._requested_names:
raise RuntimeError(
"Already patched with a different names set; "
"call restore() first."
)
self._patch_count += 1
self._tls.local_count = local_count + 1

def do_restore(self, verbose=False):
with self._lock:
local_count = getattr(self._tls, "local_count", 0)
if local_count <= 0:
if verbose:
print(
"Warning: restore called more times than monkey_patch "
"in this thread."
)
Comment on lines +175 to +180
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

restore(verbose=True) currently emits a warning via print(...). Elsewhere in this project warnings are surfaced via warnings.warn(...) (often with an appropriate category/stacklevel). Using warnings.warn here would make the warning easier to filter/test and avoids writing directly to stdout.

Copilot uses AI. Check for mistakes.
return

self._tls.local_count = local_count - 1
self._patch_count -= 1
if self._patch_count == 0:
np_random = self._numpy_module.random
for name, value in self._originals.items():
setattr(np_random, name, value)
self._numpy_module = None
self._requested_names = None
self._originals = {}
self._patched = ()

def is_patched(self):
with self._lock:
return self._patch_count > 0

def patched_names(self):
with self._lock:
return list(self._patched)


_patch = _GlobalPatch()


def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
"""
Enables using mkl_random in the given NumPy module by patching
`numpy.random`.

Examples
--------
>>> import numpy as np
>>> import mkl_random
>>> mkl_random.is_patched()
False
>>> mkl_random.monkey_patch(np)
>>> mkl_random.is_patched()
True
>>> mkl_random.restore()
>>> mkl_random.is_patched()
False
"""
_patch.do_patch(
numpy_module=numpy_module,
names=names,
strict=bool(strict),
verbose=bool(verbose),
)


def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
"""
Backward-compatible alias for monkey_patch().
"""
monkey_patch(
numpy_module=numpy_module,
names=names,
strict=strict,
verbose=verbose,
)
Comment on lines +232 to +241
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_in_numpy() is documented as a backward-compatible alias for monkey_patch(), but it's currently a separate wrapper function. This breaks identity-based expectations (and the added test asserts use_in_numpy is monkey_patch). Either make use_in_numpy a true alias (assignment) or adjust the test and docs to reflect wrapper semantics.

Suggested change
def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
"""
Backward-compatible alias for monkey_patch().
"""
monkey_patch(
numpy_module=numpy_module,
names=names,
strict=strict,
verbose=verbose,
)
# Backward-compatible alias for monkey_patch().
use_in_numpy = monkey_patch

Copilot uses AI. Check for mistakes.


def restore(verbose=False):
"""
Disables using mkl_random in NumPy by restoring the original
`numpy.random` symbols.
"""
_patch.do_restore(verbose=bool(verbose))


def is_patched():
"""
Returns whether NumPy has been patched with mkl_random.
"""
return _patch.is_patched()


def patched_names():
"""
Returns the names actually patched in `numpy.random`.
"""
return _patch.patched_names()


class mkl_random(ContextDecorator):
"""
Context manager and decorator to temporarily patch NumPy's `numpy.random`.

Examples
--------
>>> import numpy as np
>>> import mkl_random
>>> with mkl_random.mkl_random(np):
... x = np.random.normal(size=10)
"""
def __init__(self, numpy_module=None, names=None, strict=False):
self._numpy_module = numpy_module
self._names = names
self._strict = strict

def __enter__(self):
monkey_patch(
numpy_module=self._numpy_module,
names=self._names,
strict=self._strict,
)
return self

def __exit__(self, *exc):
restore()
return False
Loading
Loading