Skip to content

[Bug Report] TransformerBridge leaks non-permanent backward hooks after hooks(...) context exit #1323

@SamuelePunzo

Description

@SamuelePunzo

Description

I found a backward-hook cleanup issue in TransformerBridge.

After a plain `with model.hooks(bwd_hooks=[...])` context exits, `TransformerBridge` still has non-permanent backward hooks attached. The equivalent `HookedTransformer` path cleans them up correctly.

Expected behavior

After a `with model.hooks(...):` context exits, all non-permanent hooks added inside the context should be removed.
This is what happens with `HookedTransformer`.

Actual behavior

TransformerBridge retains non-permanent backward hooks after the context exits.

On my local repro:

  • HookedTransformer leaked hooks: {}
  • TransformerBridge leaked hooks: {'hook_embed': 1, 'embed.hook_out': 1}

Minimal repro

import copy
from pathlib import Path

import torch

from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import transformer_lens
from transformer_lens.config import HookedTransformerConfig
from transformer_lens.model_bridge import TransformerBridge
from transformer_lens.pretrained.weight_conversions.llama import convert_llama_weights


# RoPE base frequency; passed as `rope_theta` to the HF config and as
# `rotary_base` to the HookedTransformer config so both models agree.
LLAMA_ROPE_BASE = 500000.0
# Single prompt that build_inputs() tokenizes for the repro.
PROMPT = "The cat sat on the mat"
# Hook point on which run_plain_backward_context() registers the backward hook.
HOOK_NAME = "hook_embed"


def llama_snapshot() -> str:
    """Return the path of the pinned Llama-3.2-3B snapshot as a string.

    The path is assembled under the current user's home directory, inside
    ``.cache/huggingface/hub``, and ends in one fixed snapshot hash. The
    directory is not checked for existence here.
    """
    snapshot_dir = Path.home().joinpath(
        ".cache",
        "huggingface",
        "hub",
        "models--meta-llama--Llama-3.2-3B",
        "snapshots",
        "13afe5124825b4f3751f836b40dafda64c1ed062",
    )
    return str(snapshot_dir)


def make_tokenizer():
    """Load the tokenizer from the local snapshot and set right-side padding.

    The tokenizer is loaded with ``local_files_only=True`` from the path
    returned by ``llama_snapshot()``.
    """
    snapshot_path = llama_snapshot()
    tok = AutoTokenizer.from_pretrained(snapshot_path, local_files_only=True)
    tok.padding_side = "right"
    return tok


def make_hf_model(tokenizer):
    """Build a small two-layer ``LlamaForCausalLM`` and put it in eval mode.

    The vocabulary size and the BOS/EOS ids are taken from ``tokenizer``;
    every other dimension is a fixed small value so the model is cheap to
    run. Weights are whatever the model constructor produces.
    """
    config_kwargs = {
        "vocab_size": len(tokenizer),
        "hidden_size": 128,
        "intermediate_size": 256,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
        "max_position_embeddings": 128,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "rms_norm_eps": 1e-5,
        "hidden_act": "silu",
        "rope_theta": LLAMA_ROPE_BASE,
        "attention_bias": False,
        "mlp_bias": False,
        "tie_word_embeddings": False,
    }
    tiny_config = LlamaConfig(**config_kwargs)
    tiny_model = LlamaForCausalLM(tiny_config)
    tiny_model.eval()
    return tiny_model


def make_base_state(tokenizer):
    """Return a detached copy of a freshly built model's state dict.

    The torch RNG is seeded with 0 before the model is constructed, and
    every tensor is detached and cloned so the result does not share
    storage with the temporary model.
    """
    torch.manual_seed(0)
    reference_model = make_hf_model(tokenizer)
    frozen_state = {}
    for key, value in reference_model.state_dict().items():
        frozen_state[key] = value.detach().clone()
    return frozen_state


def load_hf_model_from_state(tokenizer, state_dict):
    """Build a new HF model, load ``state_dict`` into it, and return it in eval mode."""
    restored = make_hf_model(tokenizer)
    restored.load_state_dict(state_dict)
    # make_hf_model() already calls eval(); it is called again here after loading.
    restored.eval()
    return restored


def load_hooked_transformer(tokenizer, state_dict):
    """Build a ``HookedTransformer`` carrying the weights in ``state_dict``.

    An HF model is first restored from ``state_dict``. Its config supplies
    the dimensions of a ``HookedTransformerConfig`` (float32, CPU), the HF
    weights are converted with ``convert_llama_weights``, and the result is
    loaded into a new ``HookedTransformer`` that gets its own deep copy of
    the tokenizer. The model is returned in eval mode.
    """
    hf_model = load_hf_model_from_state(tokenizer, state_dict)
    hf_cfg = hf_model.config
    # Per-head width; used for both d_head and rotary_dim.
    head_dim = hf_cfg.hidden_size // hf_cfg.num_attention_heads

    settings = {
        "d_model": hf_cfg.hidden_size,
        "d_head": head_dim,
        "n_heads": hf_cfg.num_attention_heads,
        "d_mlp": hf_cfg.intermediate_size,
        "n_layers": hf_cfg.num_hidden_layers,
        "n_ctx": hf_cfg.max_position_embeddings,
        "eps": hf_cfg.rms_norm_eps,
        "d_vocab": hf_cfg.vocab_size,
        "act_fn": hf_cfg.hidden_act,
        "n_key_value_heads": None,
        "normalization_type": "RMS",
        "positional_embedding_type": "rotary",
        "rotary_adjacent_pairs": False,
        "rotary_dim": head_dim,
        "rotary_base": LLAMA_ROPE_BASE,
        "final_rms": True,
        "gated_mlp": True,
        "dtype": torch.float32,
        "device": "cpu",
        "n_devices": 1,
        "model_name": "tiny-llama-local",
        "original_architecture": "LlamaForCausalLM",
        "tokenizer_name": llama_snapshot(),
        "default_prepend_bos": True,
        "init_weights": False,
    }
    cfg = HookedTransformerConfig.from_dict(settings)

    converted_weights = convert_llama_weights(hf_model, cfg)
    hooked_model = transformer_lens.HookedTransformer(
        cfg,
        tokenizer=copy.deepcopy(tokenizer),
        move_to_device=False,
    )
    hooked_model.load_and_process_state_dict(converted_weights)
    hooked_model.eval()
    return hooked_model


def load_transformer_bridge(tokenizer, state_dict):
    """Build a ``TransformerBridge`` around an HF model restored from ``state_dict``.

    The bridge is booted on CPU with its own deep copy of the tokenizer.
    Compatibility mode is then enabled, followed by
    ``set_use_attn_result(True)``, ``set_use_split_qkv_input(True)`` and
    ``set_use_hook_mlp_in(True)``. The bridge is returned in eval mode.
    """
    restored_hf = load_hf_model_from_state(tokenizer, state_dict)
    tokenizer_copy = copy.deepcopy(tokenizer)
    wrapped = TransformerBridge.boot_transformers(
        "meta-llama/Llama-3.2-3B",
        hf_model=restored_hf,
        tokenizer=tokenizer_copy,
        device="cpu",
    )

    # Order kept as in the original: compatibility mode first, then the flags.
    wrapped.enable_compatibility_mode()
    wrapped.set_use_attn_result(True)
    wrapped.set_use_split_qkv_input(True)
    wrapped.set_use_hook_mlp_in(True)

    wrapped.eval()
    return wrapped


def build_inputs(tokenizer):
    """Tokenize ``PROMPT`` and return ``(tokens, attention_mask, input_lengths)``.

    The prompt is encoded without special tokens and a BOS id is prepended
    by hand as an extra first column. The attention mask is all ones, and
    ``input_lengths`` is the per-row sum of that mask.
    """
    encoding = tokenizer([PROMPT], return_tensors="pt", add_special_tokens=False)
    prompt_ids = encoding.input_ids

    bos_column = torch.full(
        (prompt_ids.size(0), 1),
        tokenizer.bos_token_id,
        dtype=prompt_ids.dtype,
    )
    tokens = torch.cat([bos_column, prompt_ids], dim=1)

    attention_mask = torch.ones_like(tokens)
    return tokens, attention_mask, attention_mask.sum(1)


def run_plain_backward_context(model, tokens, attention_mask, input_lengths):
    """Run one forward and backward pass inside a ``model.hooks(...)`` context.

    All inputs are moved to the device of the model's first parameter. A
    backward hook that returns ``None`` is registered on ``HOOK_NAME`` via
    ``model.hooks(bwd_hooks=...)``. Inside the context the model is called,
    and the vocabulary-index-0 logit at each row's last position
    (``input_lengths - 1``) is summed and back-propagated. Nothing is
    returned.
    """
    target_device = next(model.parameters()).device
    tokens, attention_mask, input_lengths = (
        tensor.to(target_device)
        for tensor in (tokens, attention_mask, input_lengths)
    )

    def bwd_hook(grad, hook):
        """Backward hook that returns ``None`` and does nothing else."""
        return None

    model.zero_grad()
    with model.hooks(bwd_hooks=[(HOOK_NAME, bwd_hook)]):
        logits = model(tokens, attention_mask=attention_mask)
        row_index = torch.arange(logits.size(0), device=logits.device)
        last_index = input_lengths - 1
        scalar = logits[row_index, last_index, 0].sum()
        scalar.backward()


def non_permanent_bwd_hooks(model):
    """Count the non-permanent backward hooks on ``model``, per hook name.

    Calls ``model.list_hooks(dir="bwd", including_permanent=False)`` and
    returns a dict mapping each name to the number of handles listed for
    it. Names with no handles are left out.
    """
    listed = model.list_hooks(
        dir="bwd",
        including_permanent=False,
    )
    counts = {}
    for hook_name, hook_handles in listed.items():
        if hook_handles:
            counts[hook_name] = len(hook_handles)
    return counts


# Shared fixtures: one tokenizer, one seeded set of base weights, one batch of
# inputs. Both models below are built from the same state_dict.
tokenizer = make_tokenizer()
state_dict = make_base_state(tokenizer)
tokens, attention_mask, input_lengths = build_inputs(tokenizer)

# The two implementations under comparison.
hooked = load_hooked_transformer(tokenizer, state_dict)
bridge = load_transformer_bridge(tokenizer, state_dict)

# Run the identical forward/backward pass inside a hooks(...) context on each.
run_plain_backward_context(hooked, tokens, attention_mask, input_lengths)
run_plain_backward_context(bridge, tokens, attention_mask, input_lengths)

# After the context has exited, report any non-permanent backward hooks that
# are still listed on each model.
print("HookedTransformer leaked hooks:", non_permanent_bwd_hooks(hooked))
print("TransformerBridge leaked hooks:", non_permanent_bwd_hooks(bridge))

Output

HookedTransformer leaked hooks: {}
TransformerBridge leaked hooks: {'hook_embed': 1, 'embed.hook_out': 1}

Notes

I ran into this during a larger EAP-IG parity investigation, but the repro above does not depend on EAP-IG and appears to isolate a hook cleanup bug on the `TransformerBridge` side.

I also observed the same pattern on other hook names, not just hook_embed.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions