Skip to content

Commit c2cdc3b

Browse files
committed
test(hooks): add coverage for hook plugin system
Closes #1033 Signed-off-by: Marcel Nadzam <mnadzam@redhat.com> Co-Authored-By: Cursor
1 parent d86f938 commit c2cdc3b

File tree

1 file changed

+215
-0
lines changed

1 file changed

+215
-0
lines changed

tests/test_hooks.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from __future__ import annotations
2+
3+
import pathlib
4+
import typing
5+
from importlib.metadata import EntryPoint
6+
from unittest.mock import MagicMock, Mock, patch
7+
8+
import pytest
9+
from packaging.requirements import Requirement
10+
11+
from fromager import hooks
12+
13+
14+
@pytest.fixture(autouse=True)
15+
def _clear_hook_cache() -> typing.Generator[None, None, None]:
16+
hooks._mgrs.clear()
17+
yield
18+
hooks._mgrs.clear()
19+
20+
21+
def _make_fake_ext(plugin: typing.Callable[..., typing.Any]) -> Mock:
22+
ext = Mock()
23+
ext.plugin = plugin
24+
return ext
25+
26+
27+
def test_die_on_plugin_load_failure_raises() -> None:
28+
ep = EntryPoint(name="bad_plugin", value="some.module:func", group="fromager.hooks")
29+
original_err = ImportError("no such module")
30+
31+
with pytest.raises(RuntimeError, match="bad_plugin") as exc_info:
32+
hooks._die_on_plugin_load_failure(
33+
mgr=Mock(),
34+
ep=ep,
35+
err=original_err,
36+
)
37+
38+
assert exc_info.value.__cause__ is original_err
39+
40+
41+
@patch("fromager.hooks.hook.HookManager")
42+
def test_get_hooks_creates_manager(mock_hm_cls: Mock) -> None:
43+
fake_mgr = MagicMock()
44+
fake_mgr.names.return_value = ["my_hook"]
45+
mock_hm_cls.return_value = fake_mgr
46+
47+
result = hooks._get_hooks("post_build")
48+
49+
mock_hm_cls.assert_called_once_with(
50+
namespace="fromager.hooks",
51+
name="post_build",
52+
invoke_on_load=False,
53+
on_load_failure_callback=hooks._die_on_plugin_load_failure,
54+
)
55+
assert result is fake_mgr
56+
57+
58+
@patch("fromager.hooks.hook.HookManager")
59+
def test_get_hooks_returns_cached(mock_hm_cls: Mock) -> None:
60+
fake_mgr = MagicMock()
61+
fake_mgr.names.return_value = ["my_hook"]
62+
mock_hm_cls.return_value = fake_mgr
63+
64+
first = hooks._get_hooks("post_build")
65+
second = hooks._get_hooks("post_build")
66+
67+
mock_hm_cls.assert_called_once()
68+
assert first is second
69+
70+
71+
@patch("fromager.hooks._get_hooks")
72+
def test_run_post_build_hooks_exception_propagates(mock_get: Mock) -> None:
73+
def bad_plugin(**kwargs: typing.Any) -> None:
74+
raise ValueError("hook failed")
75+
76+
fake_mgr = MagicMock()
77+
fake_mgr.names.return_value = ["bad"]
78+
fake_mgr.__iter__ = Mock(return_value=iter([_make_fake_ext(bad_plugin)]))
79+
mock_get.return_value = fake_mgr
80+
81+
with pytest.raises(ValueError, match="hook failed"):
82+
hooks.run_post_build_hooks(
83+
ctx=Mock(),
84+
req=Requirement("pkg"),
85+
dist_name="pkg",
86+
dist_version="1.0",
87+
sdist_filename=pathlib.Path("/tmp/a.tar.gz"),
88+
wheel_filename=pathlib.Path("/tmp/a.whl"),
89+
)
90+
91+
92+
@patch("fromager.hooks._get_hooks")
93+
def test_run_post_build_hooks_calls_plugin(mock_get: Mock) -> None:
94+
called_with: dict[str, typing.Any] = {}
95+
96+
def fake_plugin(**kwargs: typing.Any) -> None:
97+
called_with.update(kwargs)
98+
99+
fake_mgr = MagicMock()
100+
fake_mgr.names.return_value = ["my_hook"]
101+
fake_mgr.__iter__ = Mock(return_value=iter([_make_fake_ext(fake_plugin)]))
102+
mock_get.return_value = fake_mgr
103+
104+
ctx = Mock()
105+
req = Requirement("numpy>=1.0")
106+
sdist = pathlib.Path("/tmp/numpy-1.0.tar.gz")
107+
wheel = pathlib.Path("/tmp/numpy-1.0-cp312-linux_x86_64.whl")
108+
109+
hooks.run_post_build_hooks(
110+
ctx=ctx,
111+
req=req,
112+
dist_name="numpy",
113+
dist_version="1.0",
114+
sdist_filename=sdist,
115+
wheel_filename=wheel,
116+
)
117+
118+
assert called_with["ctx"] is ctx
119+
assert called_with["req"] is req
120+
assert called_with["dist_name"] == "numpy"
121+
assert called_with["dist_version"] == "1.0"
122+
assert called_with["sdist_filename"] is sdist
123+
assert called_with["wheel_filename"] is wheel
124+
125+
126+
@patch("fromager.hooks._get_hooks")
127+
def test_run_post_bootstrap_hooks_calls_plugin(mock_get: Mock) -> None:
128+
called_with: dict[str, typing.Any] = {}
129+
130+
def fake_plugin(**kwargs: typing.Any) -> None:
131+
called_with.update(kwargs)
132+
133+
fake_mgr = MagicMock()
134+
fake_mgr.names.return_value = ["my_hook"]
135+
fake_mgr.__iter__ = Mock(return_value=iter([_make_fake_ext(fake_plugin)]))
136+
mock_get.return_value = fake_mgr
137+
138+
ctx = Mock()
139+
req = Requirement("flask>=2.0")
140+
141+
hooks.run_post_bootstrap_hooks(
142+
ctx=ctx,
143+
req=req,
144+
dist_name="flask",
145+
dist_version="2.0",
146+
sdist_filename=None,
147+
wheel_filename=None,
148+
)
149+
150+
assert called_with["ctx"] is ctx
151+
assert called_with["req"] is req
152+
assert called_with["dist_name"] == "flask"
153+
assert called_with["dist_version"] == "2.0"
154+
assert called_with["sdist_filename"] is None
155+
assert called_with["wheel_filename"] is None
156+
157+
158+
@patch("fromager.hooks._get_hooks")
159+
def test_run_prebuilt_wheel_hooks_calls_plugin(mock_get: Mock) -> None:
160+
called_with: dict[str, typing.Any] = {}
161+
162+
def fake_plugin(**kwargs: typing.Any) -> None:
163+
called_with.update(kwargs)
164+
165+
fake_mgr = MagicMock()
166+
fake_mgr.names.return_value = ["my_hook"]
167+
fake_mgr.__iter__ = Mock(return_value=iter([_make_fake_ext(fake_plugin)]))
168+
mock_get.return_value = fake_mgr
169+
170+
ctx = Mock()
171+
req = Requirement("torch>=2.0")
172+
wheel = pathlib.Path("/tmp/torch-2.0-cp312-linux_x86_64.whl")
173+
174+
hooks.run_prebuilt_wheel_hooks(
175+
ctx=ctx,
176+
req=req,
177+
dist_name="torch",
178+
dist_version="2.0",
179+
wheel_filename=wheel,
180+
)
181+
182+
assert called_with["ctx"] is ctx
183+
assert called_with["req"] is req
184+
assert called_with["dist_name"] == "torch"
185+
assert called_with["dist_version"] == "2.0"
186+
assert called_with["wheel_filename"] is wheel
187+
assert "sdist_filename" not in called_with
188+
189+
190+
@patch("fromager.hooks.overrides._get_dist_info", return_value=("mypkg", "1.0.0"))
191+
@patch("fromager.hooks.extension.ExtensionManager")
192+
def test_log_hooks_logs_each_extension(
193+
mock_em_cls: Mock,
194+
mock_dist_info: Mock,
195+
) -> None:
196+
ext_a = Mock()
197+
ext_a.name = "post_build"
198+
ext_a.module_name = "my_plugins.hooks"
199+
200+
ext_b = Mock()
201+
ext_b.name = "post_bootstrap"
202+
ext_b.module_name = "other_plugins.hooks"
203+
204+
mock_em_cls.return_value = [ext_a, ext_b]
205+
206+
hooks.log_hooks()
207+
208+
mock_em_cls.assert_called_once_with(
209+
namespace="fromager.hooks",
210+
invoke_on_load=False,
211+
on_load_failure_callback=hooks._die_on_plugin_load_failure,
212+
)
213+
assert mock_dist_info.call_count == 2
214+
mock_dist_info.assert_any_call("my_plugins.hooks")
215+
mock_dist_info.assert_any_call("other_plugins.hooks")

0 commit comments

Comments
 (0)