Description
I found a backward-hook cleanup issue in TransformerBridge.
After a plain `with model.hooks(bwd_hooks=[...])` context exits, `TransformerBridge` still has non-permanent backward hooks attached. The equivalent `HookedTransformer` path cleans them up correctly.
Expected behavior
After a with model.hooks(...): context exits, non-permanent hooks added inside the context should be removed.
This is what happens with HookedTransformer.
Actual behavior
TransformerBridge retains non-permanent backward hooks after the context exits.
On my local repro:
HookedTransformer leaked hooks: {}
TransformerBridge leaked hooks: {'hook_embed': 1, 'embed.hook_out': 1}
Minimal repro
import copy
from pathlib import Path
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import transformer_lens
from transformer_lens.config import HookedTransformerConfig
from transformer_lens.model_bridge import TransformerBridge
from transformer_lens.pretrained.weight_conversions.llama import convert_llama_weights
# RoPE base shared by both model builds: passed as `rope_theta` to the HF
# LlamaConfig and as `rotary_base` to the HookedTransformerConfig.
LLAMA_ROPE_BASE = 500000.0
# Text tokenized by build_inputs() to form the single-example batch.
PROMPT = "The cat sat on the mat"
# Hook point the backward hook is registered on in run_plain_backward_context().
HOOK_NAME = "hook_embed"
def llama_snapshot() -> str:
    """Return the path of the pinned Llama-3.2-3B snapshot in the local HF cache.

    The path is rooted at the current user's home directory and is returned
    as a plain string; nothing is checked for existence here.
    """
    hub_cache = Path.home().joinpath(".cache", "huggingface", "hub")
    snapshot_dir = hub_cache.joinpath(
        "models--meta-llama--Llama-3.2-3B",
        "snapshots",
        "13afe5124825b4f3751f836b40dafda64c1ed062",
    )
    return str(snapshot_dir)
def make_tokenizer():
    """Load the tokenizer from the local snapshot and set right-side padding."""
    snapshot_path = llama_snapshot()
    # local_files_only=True: the tokenizer is read from the on-disk snapshot.
    tok = AutoTokenizer.from_pretrained(snapshot_path, local_files_only=True)
    tok.padding_side = "right"
    return tok
def make_hf_model(tokenizer):
    """Build a small 2-layer LlamaForCausalLM sized to the tokenizer, in eval mode.

    The vocabulary size and BOS/EOS ids are taken from ``tokenizer``; every
    other hyperparameter is fixed below.
    """
    config_kwargs = {
        "vocab_size": len(tokenizer),
        "hidden_size": 128,
        "intermediate_size": 256,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
        "max_position_embeddings": 128,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "rms_norm_eps": 1e-5,
        "hidden_act": "silu",
        "rope_theta": LLAMA_ROPE_BASE,
        "attention_bias": False,
        "mlp_bias": False,
        "tie_word_embeddings": False,
    }
    hf_model = LlamaForCausalLM(LlamaConfig(**config_kwargs))
    hf_model.eval()
    return hf_model
def make_base_state(tokenizer):
    """Return a detached copy of a freshly built model's state dict.

    The torch RNG is seeded with 0 before the model is constructed.
    """
    torch.manual_seed(0)
    reference_model = make_hf_model(tokenizer)
    snapshot = {}
    for name, tensor in reference_model.state_dict().items():
        # detach + clone so the stored tensors are independent copies.
        snapshot[name] = tensor.detach().clone()
    return snapshot
def load_hf_model_from_state(tokenizer, state_dict):
    """Build a fresh HF model, load ``state_dict`` into it, and return it in eval mode."""
    hf_model = make_hf_model(tokenizer)
    hf_model.load_state_dict(state_dict)
    hf_model.eval()
    return hf_model
def load_hooked_transformer(tokenizer, state_dict):
    """Build a HookedTransformer carrying the same weights as the HF model.

    An HF model is first rebuilt from ``state_dict``; its config values are
    used for the HookedTransformerConfig, its weights are converted with
    ``convert_llama_weights`` and loaded into the new model, which is
    returned in eval mode.
    """
    hf_model = load_hf_model_from_state(tokenizer, state_dict)
    hf_cfg = hf_model.config
    # Per-head width; also used as the rotary dimension below.
    head_dim = hf_cfg.hidden_size // hf_cfg.num_attention_heads

    cfg_dict = {
        "d_model": hf_cfg.hidden_size,
        "d_head": head_dim,
        "n_heads": hf_cfg.num_attention_heads,
        "d_mlp": hf_cfg.intermediate_size,
        "n_layers": hf_cfg.num_hidden_layers,
        "n_ctx": hf_cfg.max_position_embeddings,
        "eps": hf_cfg.rms_norm_eps,
        "d_vocab": hf_cfg.vocab_size,
        "act_fn": hf_cfg.hidden_act,
        "n_key_value_heads": None,
        "normalization_type": "RMS",
        "positional_embedding_type": "rotary",
        "rotary_adjacent_pairs": False,
        "rotary_dim": head_dim,
        "rotary_base": LLAMA_ROPE_BASE,
        "final_rms": True,
        "gated_mlp": True,
        "dtype": torch.float32,
        "device": "cpu",
        "n_devices": 1,
        "model_name": "tiny-llama-local",
        "original_architecture": "LlamaForCausalLM",
        "tokenizer_name": llama_snapshot(),
        "default_prepend_bos": True,
        "init_weights": False,
    }
    cfg = HookedTransformerConfig.from_dict(cfg_dict)

    converted_weights = convert_llama_weights(hf_model, cfg)
    hooked_model = transformer_lens.HookedTransformer(
        cfg,
        tokenizer=copy.deepcopy(tokenizer),
        move_to_device=False,
    )
    hooked_model.load_and_process_state_dict(converted_weights)
    hooked_model.eval()
    return hooked_model
def load_transformer_bridge(tokenizer, state_dict):
    """Wrap an HF model rebuilt from ``state_dict`` in a TransformerBridge.

    After booting, compatibility mode and the attn-result, split-QKV-input
    and MLP-in hook options are switched on, and the bridge is returned in
    eval mode.
    """
    wrapped_model = load_hf_model_from_state(tokenizer, state_dict)
    tokenizer_copy = copy.deepcopy(tokenizer)
    bridge_model = TransformerBridge.boot_transformers(
        "meta-llama/Llama-3.2-3B",
        hf_model=wrapped_model,
        tokenizer=tokenizer_copy,
        device="cpu",
    )
    bridge_model.enable_compatibility_mode()
    bridge_model.set_use_attn_result(True)
    bridge_model.set_use_split_qkv_input(True)
    bridge_model.set_use_hook_mlp_in(True)
    bridge_model.eval()
    return bridge_model
def build_inputs(tokenizer):
    """Tokenize PROMPT and return ``(tokens, attention_mask, input_lengths)``.

    The prompt is encoded without special tokens and a BOS column is then
    prepended by hand. The attention mask is all ones, and the lengths are
    its sum along dimension 1.
    """
    encoded = tokenizer([PROMPT], return_tensors="pt", add_special_tokens=False)
    prompt_ids = encoded.input_ids
    bos_column = torch.full(
        (prompt_ids.size(0), 1),
        tokenizer.bos_token_id,
        dtype=prompt_ids.dtype,
    )
    tokens = torch.cat([bos_column, prompt_ids], dim=1)
    attention_mask = torch.ones_like(tokens)
    return tokens, attention_mask, attention_mask.sum(1)
def run_plain_backward_context(model, tokens, attention_mask, input_lengths):
    """Run one forward/backward pass with a backward hook on HOOK_NAME.

    The hook is registered through ``model.hooks(bwd_hooks=...)``. The
    scalar that is backpropagated is the sum, over the batch, of the logit
    at vocabulary index 0 taken at each sequence's final position.
    """
    target_device = next(model.parameters()).device
    tokens = tokens.to(target_device)
    attention_mask = attention_mask.to(target_device)
    input_lengths = input_lengths.to(target_device)

    def bwd_hook(grad, hook):
        # Returns None; the hook does not supply a replacement gradient.
        return None

    model.zero_grad()
    # NOTE(review): indentation was lost in the pasted repro; the forward and
    # backward passes are assumed to both sit inside the hooks context — confirm.
    with model.hooks(bwd_hooks=[(HOOK_NAME, bwd_hook)]):
        logits = model(tokens, attention_mask=attention_mask)
        batch_index = torch.arange(logits.size(0), device=logits.device)
        last_position = input_lengths - 1
        selected = logits[batch_index, last_position, 0]
        selected.sum().backward()
def non_permanent_bwd_hooks(model):
    """Map each hook name to its number of non-permanent backward hooks.

    Names whose handle collection is empty are left out, so an empty dict
    means nothing is attached.
    """
    listed = model.list_hooks(dir="bwd", including_permanent=False)
    counts = {}
    for name, handles in listed.items():
        if not handles:
            continue
        counts[name] = len(handles)
    return counts
# Shared fixtures: one tokenizer, one set of weights, one batch of inputs.
tokenizer = make_tokenizer()
state_dict = make_base_state(tokenizer)
tokens, attention_mask, input_lengths = build_inputs(tokenizer)

# Both wrappers are built from the same state dict.
hooked = load_hooked_transformer(tokenizer, state_dict)
bridge = load_transformer_bridge(tokenizer, state_dict)

# Run the identical hooked forward/backward pass through each model.
run_plain_backward_context(hooked, tokens, attention_mask, input_lengths)
run_plain_backward_context(bridge, tokens, attention_mask, input_lengths)

# Report any non-permanent backward hooks still attached after the context exit.
hooked_leaks = non_permanent_bwd_hooks(hooked)
print("HookedTransformer leaked hooks:", hooked_leaks)
bridge_leaks = non_permanent_bwd_hooks(bridge)
print("TransformerBridge leaked hooks:", bridge_leaks)
Output
HookedTransformer leaked hooks: {}
TransformerBridge leaked hooks: {'hook_embed': 1, 'embed.hook_out': 1}
Notes
A slightly larger EAP-IG parity investigation led me to this, but this repro is independent of EAP-IG and seems to isolate a TransformerBridge-side hook cleanup bug.
I also observed the same pattern on other hook names, not just hook_embed.
Checklist
Description
I found a backward-hook cleanup issue in `TransformerBridge`.
After a plain `with model.hooks(bwd_hooks=[...])` context exits, `TransformerBridge` still has non-permanent backward hooks attached. The equivalent `HookedTransformer` path cleans them up correctly.
Expected behavior
After a `with model.hooks(...):` context exits, non-permanent hooks added inside the context should be removed.
This is what happens with `HookedTransformer`.
Actual behavior
`TransformerBridge` retains non-permanent backward hooks after the context exits.
On my local repro:
HookedTransformer leaked hooks: {}
TransformerBridge leaked hooks: {'hook_embed': 1, 'embed.hook_out': 1}
Minimal repro
Output
HookedTransformer leaked hooks: {}
TransformerBridge leaked hooks: {'hook_embed': 1, 'embed.hook_out': 1}
Notes
A slightly larger EAP-IG parity investigation led me to this, but this repro is independent of EAP-IG and seems to isolate a TransformerBridge-side hook cleanup bug.
I also observed the same pattern on other hook names, not just hook_embed.
Checklist