Skip to content

Commit 10a438b

Browse files
mnadzamLalatenduMohanty
authored andcommitted
test(hooks): add coverage for hook plugin system
Closes #1033 Signed-off-by: Marcel Nadzam <mnadzam@redhat.com> Co-Authored-By: Cursor
1 parent 767422d commit 10a438b

1 file changed

Lines changed: 249 additions & 0 deletions

File tree

tests/test_hooks.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
from __future__ import annotations
2+
3+
import pathlib
4+
import typing
5+
from importlib.metadata import EntryPoint
6+
from unittest.mock import MagicMock, Mock, patch
7+
8+
import pytest
9+
from packaging.requirements import Requirement
10+
11+
from fromager import hooks
12+
13+
14+
@pytest.fixture(autouse=True)
15+
def _clear_hook_cache() -> typing.Generator[None, None, None]:
16+
hooks._mgrs.clear()
17+
yield
18+
hooks._mgrs.clear()
19+
20+
21+
def _make_fake_ext(plugin: typing.Callable[..., typing.Any]) -> Mock:
22+
ext = Mock()
23+
ext.plugin = plugin
24+
return ext
25+
26+
27+
def _make_fake_mgr(plugins: list[typing.Callable[..., typing.Any]]) -> MagicMock:
28+
"""Return a mock HookManager that iterates over the given plugins."""
29+
fake_mgr = MagicMock()
30+
fake_mgr.names.return_value = [p.__name__ for p in plugins]
31+
fake_mgr.__iter__ = lambda self: iter([_make_fake_ext(p) for p in plugins])
32+
return fake_mgr
33+
34+
35+
def test_die_on_plugin_load_failure_raises() -> None:
36+
ep = EntryPoint(name="bad_plugin", value="some.module:func", group="fromager.hooks")
37+
original_err = ImportError("no such module")
38+
39+
with pytest.raises(RuntimeError, match="bad_plugin") as exc_info:
40+
hooks._die_on_plugin_load_failure(
41+
mgr=Mock(),
42+
ep=ep,
43+
err=original_err,
44+
)
45+
46+
assert exc_info.value.__cause__ is original_err
47+
48+
49+
@patch("fromager.hooks.hook.HookManager")
50+
def test_get_hooks_creates_manager(mock_hm_cls: Mock) -> None:
51+
fake_mgr = MagicMock()
52+
fake_mgr.names.return_value = ["my_hook"]
53+
mock_hm_cls.return_value = fake_mgr
54+
55+
result = hooks._get_hooks("post_build")
56+
57+
mock_hm_cls.assert_called_once_with(
58+
namespace="fromager.hooks",
59+
name="post_build",
60+
invoke_on_load=False,
61+
on_load_failure_callback=hooks._die_on_plugin_load_failure,
62+
)
63+
assert result is fake_mgr
64+
65+
66+
@patch("fromager.hooks.hook.HookManager")
67+
def test_get_hooks_returns_cached(mock_hm_cls: Mock) -> None:
68+
fake_mgr = MagicMock()
69+
fake_mgr.names.return_value = ["my_hook"]
70+
mock_hm_cls.return_value = fake_mgr
71+
72+
first = hooks._get_hooks("post_build")
73+
second = hooks._get_hooks("post_build")
74+
75+
mock_hm_cls.assert_called_once()
76+
assert first is second
77+
78+
79+
@patch("fromager.hooks._get_hooks")
80+
def test_run_post_build_hooks_exception_propagates(mock_get: Mock) -> None:
81+
def bad_plugin(**kwargs: typing.Any) -> None:
82+
raise ValueError("hook failed")
83+
84+
mock_get.return_value = _make_fake_mgr([bad_plugin])
85+
86+
with pytest.raises(ValueError, match="hook failed"):
87+
hooks.run_post_build_hooks(
88+
ctx=Mock(),
89+
req=Requirement("pkg"),
90+
dist_name="pkg",
91+
dist_version="1.0",
92+
sdist_filename=pathlib.Path("/tmp/a.tar.gz"),
93+
wheel_filename=pathlib.Path("/tmp/a.whl"),
94+
)
95+
96+
97+
@patch("fromager.hooks._get_hooks")
98+
def test_run_post_build_hooks_calls_plugin(mock_get: Mock) -> None:
99+
called_with: dict[str, typing.Any] = {}
100+
101+
def fake_plugin(**kwargs: typing.Any) -> None:
102+
called_with.update(kwargs)
103+
104+
mock_get.return_value = _make_fake_mgr([fake_plugin])
105+
106+
ctx = Mock()
107+
req = Requirement("numpy>=1.0")
108+
sdist = pathlib.Path("/tmp/numpy-1.0.tar.gz")
109+
wheel = pathlib.Path("/tmp/numpy-1.0-cp312-linux_x86_64.whl")
110+
111+
hooks.run_post_build_hooks(
112+
ctx=ctx,
113+
req=req,
114+
dist_name="numpy",
115+
dist_version="1.0",
116+
sdist_filename=sdist,
117+
wheel_filename=wheel,
118+
)
119+
120+
mock_get.assert_called_once_with("post_build")
121+
assert called_with["ctx"] is ctx
122+
assert called_with["req"] is req
123+
assert called_with["dist_name"] == "numpy"
124+
assert called_with["dist_version"] == "1.0"
125+
assert called_with["sdist_filename"] is sdist
126+
assert called_with["wheel_filename"] is wheel
127+
128+
129+
@patch("fromager.hooks._get_hooks")
130+
def test_run_post_bootstrap_hooks_exception_propagates(mock_get: Mock) -> None:
131+
def bad_plugin(**kwargs: typing.Any) -> None:
132+
raise ValueError("hook failed")
133+
134+
mock_get.return_value = _make_fake_mgr([bad_plugin])
135+
136+
with pytest.raises(ValueError, match="hook failed"):
137+
hooks.run_post_bootstrap_hooks(
138+
ctx=Mock(),
139+
req=Requirement("pkg"),
140+
dist_name="pkg",
141+
dist_version="1.0",
142+
sdist_filename=None,
143+
wheel_filename=None,
144+
)
145+
146+
147+
@patch("fromager.hooks._get_hooks")
148+
def test_run_post_bootstrap_hooks_calls_plugin(mock_get: Mock) -> None:
149+
called_with: dict[str, typing.Any] = {}
150+
151+
def fake_plugin(**kwargs: typing.Any) -> None:
152+
called_with.update(kwargs)
153+
154+
mock_get.return_value = _make_fake_mgr([fake_plugin])
155+
156+
ctx = Mock()
157+
req = Requirement("flask>=2.0")
158+
159+
hooks.run_post_bootstrap_hooks(
160+
ctx=ctx,
161+
req=req,
162+
dist_name="flask",
163+
dist_version="2.0",
164+
sdist_filename=None,
165+
wheel_filename=None,
166+
)
167+
168+
mock_get.assert_called_once_with("post_bootstrap")
169+
assert called_with["ctx"] is ctx
170+
assert called_with["req"] is req
171+
assert called_with["dist_name"] == "flask"
172+
assert called_with["dist_version"] == "2.0"
173+
assert called_with["sdist_filename"] is None
174+
assert called_with["wheel_filename"] is None
175+
176+
177+
@patch("fromager.hooks._get_hooks")
178+
def test_run_prebuilt_wheel_hooks_exception_propagates(mock_get: Mock) -> None:
179+
def bad_plugin(**kwargs: typing.Any) -> None:
180+
raise ValueError("hook failed")
181+
182+
mock_get.return_value = _make_fake_mgr([bad_plugin])
183+
184+
with pytest.raises(ValueError, match="hook failed"):
185+
hooks.run_prebuilt_wheel_hooks(
186+
ctx=Mock(),
187+
req=Requirement("pkg"),
188+
dist_name="pkg",
189+
dist_version="1.0",
190+
wheel_filename=pathlib.Path("/tmp/a.whl"),
191+
)
192+
193+
194+
@patch("fromager.hooks._get_hooks")
195+
def test_run_prebuilt_wheel_hooks_calls_plugin(mock_get: Mock) -> None:
196+
called_with: dict[str, typing.Any] = {}
197+
198+
def fake_plugin(**kwargs: typing.Any) -> None:
199+
called_with.update(kwargs)
200+
201+
mock_get.return_value = _make_fake_mgr([fake_plugin])
202+
203+
ctx = Mock()
204+
req = Requirement("torch>=2.0")
205+
wheel = pathlib.Path("/tmp/torch-2.0-cp312-linux_x86_64.whl")
206+
207+
hooks.run_prebuilt_wheel_hooks(
208+
ctx=ctx,
209+
req=req,
210+
dist_name="torch",
211+
dist_version="2.0",
212+
wheel_filename=wheel,
213+
)
214+
215+
mock_get.assert_called_once_with("prebuilt_wheel")
216+
assert called_with["ctx"] is ctx
217+
assert called_with["req"] is req
218+
assert called_with["dist_name"] == "torch"
219+
assert called_with["dist_version"] == "2.0"
220+
assert called_with["wheel_filename"] is wheel
221+
assert "sdist_filename" not in called_with
222+
223+
224+
@patch("fromager.hooks.overrides._get_dist_info", return_value=("mypkg", "1.0.0"))
225+
@patch("fromager.hooks.extension.ExtensionManager")
226+
def test_log_hooks_logs_each_extension(
227+
mock_em_cls: Mock,
228+
mock_dist_info: Mock,
229+
) -> None:
230+
ext_a = Mock()
231+
ext_a.name = "post_build"
232+
ext_a.module_name = "my_plugins.hooks"
233+
234+
ext_b = Mock()
235+
ext_b.name = "post_bootstrap"
236+
ext_b.module_name = "other_plugins.hooks"
237+
238+
mock_em_cls.return_value = [ext_a, ext_b]
239+
240+
hooks.log_hooks()
241+
242+
mock_em_cls.assert_called_once_with(
243+
namespace="fromager.hooks",
244+
invoke_on_load=False,
245+
on_load_failure_callback=hooks._die_on_plugin_load_failure,
246+
)
247+
assert mock_dist_info.call_count == 2
248+
mock_dist_info.assert_any_call("my_plugins.hooks")
249+
mock_dist_info.assert_any_call("other_plugins.hooks")

0 commit comments

Comments
 (0)