diff --git a/tests/acceptance/model_bridge/compatibility/test_backward_hooks.py b/tests/acceptance/model_bridge/compatibility/test_backward_hooks.py index 682f519c3..03f868124 100644 --- a/tests/acceptance/model_bridge/compatibility/test_backward_hooks.py +++ b/tests/acceptance/model_bridge/compatibility/test_backward_hooks.py @@ -51,3 +51,42 @@ def sum_bridge_grads(grad, hook=None): f"Gradient sums should be identical but differ by " f"{abs(hooked_grad_sum - bridge_grad_sum).item():.6f}" ) + + +def test_transformer_bridge_hooks_context_cleans_up_backward_hooks( + gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing +): + """Regression test for backward-hook cleanup on context exit.""" + hooked_model = gpt2_hooked_unprocessed + bridge_model = gpt2_bridge_compat_no_processing + hooked_hook = hooked_model.blocks[0].hook_resid_post + bridge_hook = bridge_model.blocks[0].hook_resid_post + test_input = torch.tensor([[1, 2, 3]]) + + def noop_backward_hook(grad, hook=None): + return None + + hooked_model.zero_grad() + with hooked_model.hooks(bwd_hooks=[("blocks.0.hook_resid_post", noop_backward_hook)]): + hooked_model(test_input).sum().backward() + + bridge_model.zero_grad() + with bridge_model.hooks(bwd_hooks=[("blocks.0.hook_resid_post", noop_backward_hook)]): + bridge_model(test_input).sum().backward() + + assert not hooked_hook.has_hooks(dir="bwd", including_permanent=False) + assert not bridge_hook.has_hooks(dir="bwd", including_permanent=False) + + +def test_transformer_bridge_reset_hooks_removes_backward_hooks(gpt2_bridge_compat_no_processing): + """Regression test for bridge reset_hooks removing backward hooks.""" + bridge_model = gpt2_bridge_compat_no_processing + backward_hook = bridge_model.blocks[0].hook_resid_post + + backward_hook.add_hook(lambda grad, hook=None: None, dir="bwd") + + assert backward_hook.has_hooks(dir="bwd", including_permanent=False) + + bridge_model.reset_hooks() + + assert not backward_hook.has_hooks(dir="bwd", including_permanent=False) diff --git a/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py b/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py index f653de956..d91db5e67 100644 --- a/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py +++ b/tests/acceptance/model_bridge/compatibility/test_run_with_cache.py @@ -77,3 +77,21 @@ def test_run_with_cache_accepts_1d_tensor(self, gpt2_bridge_compat_no_processing assert torch.allclose( cache_1d["blocks.0.hook_mlp_out"], cache_2d["blocks.0.hook_mlp_out"], atol=1e-5 ) + + +def test_transformer_bridge_run_with_cache_preserves_existing_backward_hooks( + gpt2_bridge_compat_no_processing, +): + """run_with_cache should not remove unrelated backward hooks on the same HookPoint.""" + bridge_model = gpt2_bridge_compat_no_processing + target_hook = bridge_model.blocks[0].hook_resid_post + + target_hook.add_hook(lambda grad, hook=None: None, dir="bwd") + + assert target_hook.has_hooks(dir="bwd", including_permanent=False) + + bridge_model.run_with_cache(torch.tensor([[1, 2, 3]]), names_filter="blocks.0.hook_resid_post") + + assert target_hook.has_hooks(dir="bwd", including_permanent=False) + + bridge_model.reset_hooks() diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 30d58363c..440ad04b8 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -2081,7 +2081,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: raise e finally: for hp, _ in hooks: - hp.remove_hooks() + hp.remove_hooks(dir="fwd") if self.compatibility_mode == True: reverse_aliases = {} for old_name, new_name in aliases.items(): @@ -2148,7 +2148,7 @@ def run_with_hooks( Returns: Model output """ - added_hooks: List[Tuple[HookPoint, str]] = [] + added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = [] effective_stop_layer = None if stop_at_layer is not None and hasattr(self, "blocks"): if stop_at_layer < 0: @@ -2174,7 +2174,7 @@ def add_hook_to_point( hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list) else: hook_point.add_hook(hook_fn, dir=dir) - added_hooks.append((hook_point, name)) + added_hooks.append((hook_point, dir)) if stop_at_layer is not None and hasattr(self, "blocks"): if stop_at_layer < 0: @@ -2243,8 +2243,8 @@ def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn): return output finally: if reset_hooks_end: - for hook_point, name in added_hooks: - hook_point.remove_hooks() + for hook_point, direction in added_hooks: + hook_point.remove_hooks(dir=direction) def _generate_tokens( self, @@ -3306,7 +3306,7 @@ def hooks(self, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts @contextmanager def _hooks_context(): - added_hooks: List[Tuple[HookPoint, str]] = [] + added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = [] def add_hook_to_point( hook_point: HookPoint, @@ -3322,7 +3322,7 @@ def add_hook_to_point( hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list) else: hook_point.add_hook(hook_fn, dir=dir) - added_hooks.append((hook_point, name)) + added_hooks.append((hook_point, dir)) def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool): direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd" @@ -3355,8 +3355,8 @@ def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool yield self finally: if reset_hooks_end: - for hook_point, name in added_hooks: - hook_point.remove_hooks() + for hook_point, direction in added_hooks: + hook_point.remove_hooks(dir=direction) return _hooks_context() diff --git a/transformer_lens/model_bridge/generalized_components/base.py b/transformer_lens/model_bridge/generalized_components/base.py index ae6787b67..f891b81fb 100644 --- a/transformer_lens/model_bridge/generalized_components/base.py +++ b/transformer_lens/model_bridge/generalized_components/base.py @@ -188,12 +188,12 @@ def remove_hooks(self, hook_name: str | None = None) -> None: hook_name: Name of the hook point to remove. If None, removes all hooks. """ if hook_name is None: - self.hook_in.remove_hooks() - self.hook_out.remove_hooks() + self.hook_in.remove_hooks(dir="both") + self.hook_out.remove_hooks(dir="both") elif hook_name == "output": - self.hook_out.remove_hooks() + self.hook_out.remove_hooks(dir="both") elif hook_name == "input": - self.hook_in.remove_hooks() + self.hook_in.remove_hooks(dir="both") else: raise ValueError( f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."