Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions tests/acceptance/model_bridge/compatibility/test_backward_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,42 @@ def sum_bridge_grads(grad, hook=None):
f"Gradient sums should be identical but differ by "
f"{abs(hooked_grad_sum - bridge_grad_sum).item():.6f}"
)


def test_transformer_bridge_hooks_context_cleans_up_backward_hooks(
    gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing
):
    """Backward hooks added via ``hooks(bwd_hooks=...)`` must be gone after the context exits.

    The same forward/backward pass is run inside a ``hooks`` context on the
    hooked model and then on the bridge model. Afterwards neither model's
    ``blocks[0].hook_resid_post`` may report a non-permanent backward hook.
    """
    hook_name = "blocks.0.hook_resid_post"
    models = (gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing)
    # Grab the hook points up front so the final check inspects the same objects.
    hook_points = [model.blocks[0].hook_resid_post for model in models]
    tokens = torch.tensor([[1, 2, 3]])

    def noop_backward_hook(grad, hook=None):
        # Does nothing on purpose; only its registration matters to this test.
        return None

    # Hooked model first, bridge model second.
    for model in models:
        model.zero_grad()
        with model.hooks(bwd_hooks=[(hook_name, noop_backward_hook)]):
            output = model(tokens)
            output.sum().backward()

    for hook_point in hook_points:
        assert not hook_point.has_hooks(dir="bwd", including_permanent=False)


def test_transformer_bridge_reset_hooks_removes_backward_hooks(gpt2_bridge_compat_no_processing):
    """Check that the bridge's ``reset_hooks`` clears a manually added backward hook."""
    model = gpt2_bridge_compat_no_processing
    hook_point = model.blocks[0].hook_resid_post

    def noop_backward_hook(grad, hook=None):
        # Does nothing on purpose; only its registration matters to this test.
        return None

    hook_point.add_hook(noop_backward_hook, dir="bwd")

    # Precondition: the hook is visible before the reset.
    registered_before_reset = hook_point.has_hooks(dir="bwd", including_permanent=False)
    assert registered_before_reset

    model.reset_hooks()

    registered_after_reset = hook_point.has_hooks(dir="bwd", including_permanent=False)
    assert not registered_after_reset
18 changes: 18 additions & 0 deletions tests/acceptance/model_bridge/compatibility/test_run_with_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@ def test_run_with_cache_accepts_1d_tensor(self, gpt2_bridge_compat_no_processing
assert torch.allclose(
cache_1d["blocks.0.hook_mlp_out"], cache_2d["blocks.0.hook_mlp_out"], atol=1e-5
)


def test_transformer_bridge_run_with_cache_preserves_existing_backward_hooks(
    gpt2_bridge_compat_no_processing,
):
    """run_with_cache should not remove unrelated backward hooks on the same HookPoint.

    A backward hook is added by hand to ``blocks[0].hook_resid_post``, then
    ``run_with_cache`` is called with a ``names_filter`` targeting that same
    hook point. The manually added backward hook must still be registered
    afterwards.
    """
    bridge_model = gpt2_bridge_compat_no_processing
    target_hook = bridge_model.blocks[0].hook_resid_post

    target_hook.add_hook(lambda grad, hook=None: None, dir="bwd")
    try:
        # Precondition: the hook is visible before run_with_cache runs.
        assert target_hook.has_hooks(dir="bwd", including_permanent=False)

        bridge_model.run_with_cache(
            torch.tensor([[1, 2, 3]]), names_filter="blocks.0.hook_resid_post"
        )

        assert target_hook.has_hooks(dir="bwd", including_permanent=False)
    finally:
        # Clean up even when an assertion or run_with_cache fails, so the
        # manually added hook is not left behind on the fixture model.
        # NOTE(review): fixture scope is not visible here -- if the fixture is
        # shared across tests, this is what keeps a failure from leaking into them.
        bridge_model.reset_hooks()
18 changes: 9 additions & 9 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,7 +2081,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
raise e
finally:
for hp, _ in hooks:
hp.remove_hooks()
hp.remove_hooks(dir="fwd")
if self.compatibility_mode == True:
reverse_aliases = {}
for old_name, new_name in aliases.items():
Expand Down Expand Up @@ -2148,7 +2148,7 @@ def run_with_hooks(
Returns:
Model output
"""
added_hooks: List[Tuple[HookPoint, str]] = []
added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []
effective_stop_layer = None
if stop_at_layer is not None and hasattr(self, "blocks"):
if stop_at_layer < 0:
Expand All @@ -2174,7 +2174,7 @@ def add_hook_to_point(
hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
else:
hook_point.add_hook(hook_fn, dir=dir)
added_hooks.append((hook_point, name))
added_hooks.append((hook_point, dir))

if stop_at_layer is not None and hasattr(self, "blocks"):
if stop_at_layer < 0:
Expand Down Expand Up @@ -2243,8 +2243,8 @@ def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn):
return output
finally:
if reset_hooks_end:
for hook_point, name in added_hooks:
hook_point.remove_hooks()
for hook_point, direction in added_hooks:
hook_point.remove_hooks(dir=direction)

def _generate_tokens(
self,
Expand Down Expand Up @@ -3306,7 +3306,7 @@ def hooks(self, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts

@contextmanager
def _hooks_context():
added_hooks: List[Tuple[HookPoint, str]] = []
added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []

def add_hook_to_point(
hook_point: HookPoint,
Expand All @@ -3322,7 +3322,7 @@ def add_hook_to_point(
hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
else:
hook_point.add_hook(hook_fn, dir=dir)
added_hooks.append((hook_point, name))
added_hooks.append((hook_point, dir))

def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
Expand Down Expand Up @@ -3355,8 +3355,8 @@ def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool
yield self
finally:
if reset_hooks_end:
for hook_point, name in added_hooks:
hook_point.remove_hooks()
for hook_point, direction in added_hooks:
hook_point.remove_hooks(dir=direction)

return _hooks_context()

Expand Down
8 changes: 4 additions & 4 deletions transformer_lens/model_bridge/generalized_components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,12 @@ def remove_hooks(self, hook_name: str | None = None) -> None:
hook_name: Name of the hook point to remove. If None, removes all hooks.
"""
if hook_name is None:
self.hook_in.remove_hooks()
self.hook_out.remove_hooks()
self.hook_in.remove_hooks(dir="both")
self.hook_out.remove_hooks(dir="both")
elif hook_name == "output":
self.hook_out.remove_hooks()
self.hook_out.remove_hooks(dir="both")
elif hook_name == "input":
self.hook_in.remove_hooks()
self.hook_in.remove_hooks(dir="both")
else:
raise ValueError(
f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."
Expand Down
Loading