Skip to content

Commit 0b016cd

Browse files
committed
test(hooks): add coverage for hook plugin system
Closes #1033 Signed-off-by: Marcel Nadzam <mnadzam@redhat.com> Co-Authored-By: Cursor
1 parent d86f938 commit 0b016cd

1 file changed

Lines changed: 259 additions & 0 deletions

File tree

tests/test_hooks.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
from __future__ import annotations
2+
3+
import pathlib
4+
import typing
5+
from importlib.metadata import EntryPoint
6+
from unittest.mock import MagicMock, Mock, patch
7+
8+
import pytest
9+
from packaging.requirements import Requirement
10+
11+
from fromager import hooks
12+
13+
14+
@pytest.fixture(autouse=True)
15+
def _clear_hook_cache() -> typing.Generator[None, None, None]:
16+
hooks._mgrs.clear()
17+
yield
18+
hooks._mgrs.clear()
19+
20+
21+
def _make_fake_ext(plugin: typing.Callable[..., typing.Any]) -> Mock:
22+
ext = Mock()
23+
ext.plugin = plugin
24+
return ext
25+
26+
27+
def test_die_on_plugin_load_failure_raises() -> None:
28+
ep = EntryPoint(name="bad_plugin", value="some.module:func", group="fromager.hooks")
29+
original_err = ImportError("no such module")
30+
31+
with pytest.raises(RuntimeError, match="bad_plugin") as exc_info:
32+
hooks._die_on_plugin_load_failure(
33+
mgr=Mock(),
34+
ep=ep,
35+
err=original_err,
36+
)
37+
38+
assert exc_info.value.__cause__ is original_err
39+
40+
41+
@patch("fromager.hooks.hook.HookManager")
42+
def test_get_hooks_creates_manager(mock_hm_cls: Mock) -> None:
43+
fake_mgr = MagicMock()
44+
fake_mgr.names.return_value = ["my_hook"]
45+
mock_hm_cls.return_value = fake_mgr
46+
47+
result = hooks._get_hooks("post_build")
48+
49+
mock_hm_cls.assert_called_once_with(
50+
namespace="fromager.hooks",
51+
name="post_build",
52+
invoke_on_load=False,
53+
on_load_failure_callback=hooks._die_on_plugin_load_failure,
54+
)
55+
assert result is fake_mgr
56+
57+
58+
@patch("fromager.hooks.hook.HookManager")
59+
def test_get_hooks_returns_cached(mock_hm_cls: Mock) -> None:
60+
fake_mgr = MagicMock()
61+
fake_mgr.names.return_value = ["my_hook"]
62+
mock_hm_cls.return_value = fake_mgr
63+
64+
first = hooks._get_hooks("post_build")
65+
second = hooks._get_hooks("post_build")
66+
67+
mock_hm_cls.assert_called_once()
68+
assert first is second
69+
70+
71+
@patch("fromager.hooks._get_hooks")
72+
def test_run_post_build_hooks_exception_propagates(mock_get: Mock) -> None:
73+
def bad_plugin(**kwargs: typing.Any) -> None:
74+
raise ValueError("hook failed")
75+
76+
fake_mgr = MagicMock()
77+
fake_mgr.names.return_value = ["bad"]
78+
fake_mgr.__iter__ = lambda self: iter([_make_fake_ext(bad_plugin)])
79+
mock_get.return_value = fake_mgr
80+
81+
with pytest.raises(ValueError, match="hook failed"):
82+
hooks.run_post_build_hooks(
83+
ctx=Mock(),
84+
req=Requirement("pkg"),
85+
dist_name="pkg",
86+
dist_version="1.0",
87+
sdist_filename=pathlib.Path("/tmp/a.tar.gz"),
88+
wheel_filename=pathlib.Path("/tmp/a.whl"),
89+
)
90+
91+
92+
@patch("fromager.hooks._get_hooks")
93+
def test_run_post_build_hooks_calls_plugin(mock_get: Mock) -> None:
94+
called_with: dict[str, typing.Any] = {}
95+
96+
def fake_plugin(**kwargs: typing.Any) -> None:
97+
called_with.update(kwargs)
98+
99+
fake_mgr = MagicMock()
100+
fake_mgr.names.return_value = ["my_hook"]
101+
fake_mgr.__iter__ = lambda self: iter([_make_fake_ext(fake_plugin)])
102+
mock_get.return_value = fake_mgr
103+
104+
ctx = Mock()
105+
req = Requirement("numpy>=1.0")
106+
sdist = pathlib.Path("/tmp/numpy-1.0.tar.gz")
107+
wheel = pathlib.Path("/tmp/numpy-1.0-cp312-linux_x86_64.whl")
108+
109+
hooks.run_post_build_hooks(
110+
ctx=ctx,
111+
req=req,
112+
dist_name="numpy",
113+
dist_version="1.0",
114+
sdist_filename=sdist,
115+
wheel_filename=wheel,
116+
)
117+
118+
mock_get.assert_called_once_with("post_build")
119+
assert called_with["ctx"] is ctx
120+
assert called_with["req"] is req
121+
assert called_with["dist_name"] == "numpy"
122+
assert called_with["dist_version"] == "1.0"
123+
assert called_with["sdist_filename"] is sdist
124+
assert called_with["wheel_filename"] is wheel
125+
126+
127+
@patch("fromager.hooks._get_hooks")
128+
def test_run_post_bootstrap_hooks_exception_propagates(mock_get: Mock) -> None:
129+
def bad_plugin(**kwargs: typing.Any) -> None:
130+
raise ValueError("hook failed")
131+
132+
fake_mgr = MagicMock()
133+
fake_mgr.names.return_value = ["bad"]
134+
fake_mgr.__iter__ = lambda self: iter([_make_fake_ext(bad_plugin)])
135+
mock_get.return_value = fake_mgr
136+
137+
with pytest.raises(ValueError, match="hook failed"):
138+
hooks.run_post_bootstrap_hooks(
139+
ctx=Mock(),
140+
req=Requirement("pkg"),
141+
dist_name="pkg",
142+
dist_version="1.0",
143+
sdist_filename=None,
144+
wheel_filename=None,
145+
)
146+
147+
148+
@patch("fromager.hooks._get_hooks")
149+
def test_run_post_bootstrap_hooks_calls_plugin(mock_get: Mock) -> None:
150+
called_with: dict[str, typing.Any] = {}
151+
152+
def fake_plugin(**kwargs: typing.Any) -> None:
153+
called_with.update(kwargs)
154+
155+
fake_mgr = MagicMock()
156+
fake_mgr.names.return_value = ["my_hook"]
157+
fake_mgr.__iter__ = lambda self: iter([_make_fake_ext(fake_plugin)])
158+
mock_get.return_value = fake_mgr
159+
160+
ctx = Mock()
161+
req = Requirement("flask>=2.0")
162+
163+
hooks.run_post_bootstrap_hooks(
164+
ctx=ctx,
165+
req=req,
166+
dist_name="flask",
167+
dist_version="2.0",
168+
sdist_filename=None,
169+
wheel_filename=None,
170+
)
171+
172+
mock_get.assert_called_once_with("post_bootstrap")
173+
assert called_with["ctx"] is ctx
174+
assert called_with["req"] is req
175+
assert called_with["dist_name"] == "flask"
176+
assert called_with["dist_version"] == "2.0"
177+
assert called_with["sdist_filename"] is None
178+
assert called_with["wheel_filename"] is None
179+
180+
181+
@patch("fromager.hooks._get_hooks")
182+
def test_run_prebuilt_wheel_hooks_exception_propagates(mock_get: Mock) -> None:
183+
def bad_plugin(**kwargs: typing.Any) -> None:
184+
raise ValueError("hook failed")
185+
186+
fake_mgr = MagicMock()
187+
fake_mgr.names.return_value = ["bad"]
188+
fake_mgr.__iter__ = lambda self: iter([_make_fake_ext(bad_plugin)])
189+
mock_get.return_value = fake_mgr
190+
191+
with pytest.raises(ValueError, match="hook failed"):
192+
hooks.run_prebuilt_wheel_hooks(
193+
ctx=Mock(),
194+
req=Requirement("pkg"),
195+
dist_name="pkg",
196+
dist_version="1.0",
197+
wheel_filename=pathlib.Path("/tmp/a.whl"),
198+
)
199+
200+
201+
@patch("fromager.hooks._get_hooks")
202+
def test_run_prebuilt_wheel_hooks_calls_plugin(mock_get: Mock) -> None:
203+
called_with: dict[str, typing.Any] = {}
204+
205+
def fake_plugin(**kwargs: typing.Any) -> None:
206+
called_with.update(kwargs)
207+
208+
fake_mgr = MagicMock()
209+
fake_mgr.names.return_value = ["my_hook"]
210+
fake_mgr.__iter__ = lambda self: iter([_make_fake_ext(fake_plugin)])
211+
mock_get.return_value = fake_mgr
212+
213+
ctx = Mock()
214+
req = Requirement("torch>=2.0")
215+
wheel = pathlib.Path("/tmp/torch-2.0-cp312-linux_x86_64.whl")
216+
217+
hooks.run_prebuilt_wheel_hooks(
218+
ctx=ctx,
219+
req=req,
220+
dist_name="torch",
221+
dist_version="2.0",
222+
wheel_filename=wheel,
223+
)
224+
225+
mock_get.assert_called_once_with("prebuilt_wheel")
226+
assert called_with["ctx"] is ctx
227+
assert called_with["req"] is req
228+
assert called_with["dist_name"] == "torch"
229+
assert called_with["dist_version"] == "2.0"
230+
assert called_with["wheel_filename"] is wheel
231+
assert "sdist_filename" not in called_with
232+
233+
234+
@patch("fromager.hooks.overrides._get_dist_info", return_value=("mypkg", "1.0.0"))
235+
@patch("fromager.hooks.extension.ExtensionManager")
236+
def test_log_hooks_logs_each_extension(
237+
mock_em_cls: Mock,
238+
mock_dist_info: Mock,
239+
) -> None:
240+
ext_a = Mock()
241+
ext_a.name = "post_build"
242+
ext_a.module_name = "my_plugins.hooks"
243+
244+
ext_b = Mock()
245+
ext_b.name = "post_bootstrap"
246+
ext_b.module_name = "other_plugins.hooks"
247+
248+
mock_em_cls.return_value = [ext_a, ext_b]
249+
250+
hooks.log_hooks()
251+
252+
mock_em_cls.assert_called_once_with(
253+
namespace="fromager.hooks",
254+
invoke_on_load=False,
255+
on_load_failure_callback=hooks._die_on_plugin_load_failure,
256+
)
257+
assert mock_dist_info.call_count == 2
258+
mock_dist_info.assert_any_call("my_plugins.hooks")
259+
mock_dist_info.assert_any_call("other_plugins.hooks")

0 commit comments

Comments
 (0)