Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 334 additions & 0 deletions .claude/skills/write-tests/SKILL.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions tests/composer/test_adapter_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from granite_switch.composer.arch import resolve_arch

pytestmark = pytest.mark.local_fast

# -- Fixtures ----------------------------------------------------------------

@pytest.fixture
Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_adapter_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from granite_switch.composer.adapter_discovery import discover_adapters, discover_adapters_from_yaml
from granite_switch.composer.arch import resolve_arch

pytestmark = pytest.mark.local_fast


@pytest.fixture
def simple_arch():
Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_arch_skinning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from granite_switch.composer.weight_transfer import _classify_base_weights
from granite_switch.composer.weight_remapper import AdapterRemapper

pytestmark = pytest.mark.local_fast


# ---------------------------------------------------------------------------
# Helpers
Expand Down
3 changes: 3 additions & 0 deletions tests/composer/test_chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@
from types import SimpleNamespace
from unittest.mock import patch

import pytest
from jinja2 import Environment

from granite_switch.composer.tokenizer_setup import configure_chat_template

pytestmark = pytest.mark.local_fast

_PATCH_TARGET = "granite_switch.composer.tokenizer_setup._decode_alora_invocation_text"

_FIXTURES = os.path.join(os.path.dirname(__file__), "fixtures")
Expand Down
4 changes: 4 additions & 0 deletions tests/composer/test_debug_fields.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for --debug-fields flag and source propagation in compose reports."""

import pytest

from granite_switch.composer.adapter_discovery import discover_adapters

pytestmark = pytest.mark.local_fast


class TestSourcePropagation:
"""Tests for source propagation in discover_adapters()."""
Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_hf_snapshot_commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
_extract_hf_snapshot_commit,
)

pytestmark = pytest.mark.local_fast


VALID_SHA = "6e4a75e35f1cb272e8d15b4615fb0a123398d1cf"
SHORT_SHA = VALID_SHA[:8]
Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_list_adapters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import pytest

pytestmark = pytest.mark.local_fast


FAKE_ADAPTERS = [
{"name": "rag", "technologies": ["alora", "lora"]},
Expand Down
3 changes: 3 additions & 0 deletions tests/composer/test_lora_substitute_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)


@pytest.mark.requires_model
class TestOnRealGraniteTokenizer:
"""Exercise the probe on actual Granite tokenizers. Network-dependent;
skips cleanly if the model can't be fetched."""
Expand All @@ -47,6 +48,7 @@ def test_granite_4_0_micro(self):
assert tok.convert_ids_to_tokens([sub_id])[0] == "<|start_of_role|>"


@pytest.mark.local_fast
class TestOnSyntheticTokenizer:
"""Verify the probe is generic — it returns whatever the template emits,
not a Granite-specific hardcoded token."""
Expand Down Expand Up @@ -75,6 +77,7 @@ def __call__(self, text, **kwargs):
assert _probe_lora_substitute_token_id(_FakeTokenizer()) == 42


@pytest.mark.local_fast
class TestErrorPaths:

def _minimal_tokenizer_without_template(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
write_model_card,
)

pytestmark = pytest.mark.local_fast


def _fake_base_config(**overrides):
defaults = dict(
Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_save_load_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

import granite_switch.hf # noqa: F401 — registers AutoModel

pytestmark = [pytest.mark.slow, pytest.mark.requires_model]

SEED = 42


Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_tokenizer_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
configure_chat_template,
)

pytestmark = pytest.mark.local_fast

_PATCH_TARGET = "granite_switch.composer.tokenizer_setup._decode_alora_invocation_text"


Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from granite_switch.composer.validator import validate_all_parameters
from granite_switch.composer.arch import ModuleDescriptor, ArchDescriptor

pytestmark = pytest.mark.local_fast


@pytest.fixture
def simple_arch():
Expand Down
2 changes: 2 additions & 0 deletions tests/composer/test_weight_remapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from granite_switch.composer.weight_remapper import AdapterRemapper, RemapResult
from granite_switch.composer.arch import ModuleDescriptor

pytestmark = pytest.mark.local_fast


class TestRemapResult:
"""Tests for RemapResult dataclass."""
Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
make_switch_model,
)

pytestmark = pytest.mark.local_fast


# ── Helpers ───────────────────────────────────────────────────────

Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_granite4_fullsize.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def _run_equivalence(cfg_dict, *, seq_len=8):
return upstream_logits, switch_logits


pytestmark = pytest.mark.slow

_MODEL_NAMES = sorted(GRANITE4_FULLSIZE.keys())


Expand Down
3 changes: 3 additions & 0 deletions tests/hf/test_granite4_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
)


pytestmark = pytest.mark.local_fast


def _make_pair(cfg_dict):
"""Create upstream + switch model pair with transferred weights."""
torch.manual_seed(0)
Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
LoRAShapeCorrectnessCases,
)

pytestmark = pytest.mark.local_fast


# ════════════════════════════════════════════════════════════════════
# Section 1: SwitchedLoRALinear — shared mixin tests
Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_model_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from granite_switch.hf import GraniteSwitchForCausalLM
from granite_switch.hf.switch.single import SingleSwitch

pytestmark = pytest.mark.local_fast


# ── Helpers ────────────────────────────────────────────────────────

Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_qk_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from granite_switch.config import GraniteSwitchConfig
from granite_switch.hf.core.lora import GraniteLoRAEmbeddedAttention

pytestmark = pytest.mark.local_fast


# ── Helpers ────────────────────────────────────────────────────────

Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_single_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import torch

from granite_switch.hf.switch.single import SingleSwitch

pytestmark = pytest.mark.local_fast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from tests.shared.single_switch_cases import (
Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_single_switch_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import pytest
import torch

pytestmark = pytest.mark.local_fast

from tests.shared.generation_models import DENSE_CFG, make_switch_model
from tests.shared.granite4_constants import (
MAX_POSITION_EMBEDDINGS,
Expand Down
2 changes: 2 additions & 0 deletions tests/hf/test_token_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from granite_switch.config import GraniteSwitchConfig
from granite_switch.hf import GraniteSwitchForCausalLM

pytestmark = pytest.mark.local_fast


def _build(num_adapters=2, substitute_ids=(1, 7)):
return GraniteSwitchConfig(
Expand Down
13 changes: 9 additions & 4 deletions tests/integration/test_hf_to_vllm_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,15 @@ def _try_import_vllm():

_VLLM_AVAILABLE = _try_import_vllm() if _CUDA_AVAILABLE else False

pytestmark = pytest.mark.skipif(
not _CUDA_AVAILABLE or not _VLLM_AVAILABLE,
reason="requires CUDA GPU and vLLM installed",
)
pytestmark = [
pytest.mark.skipif(
not _CUDA_AVAILABLE or not _VLLM_AVAILABLE,
reason="requires CUDA GPU and vLLM installed",
),
pytest.mark.vllm,
pytest.mark.gpu,
pytest.mark.slow,
]

if _VLLM_AVAILABLE:
from safetensors.torch import load_file
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from granite_switch.config import GraniteSwitchConfig

pytestmark = pytest.mark.local_fast


# ── Helper ────────────────────────────────────────────────────────────

Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_config_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from granite_switch.config import GraniteSwitchConfig

pytestmark = pytest.mark.local_fast


def _valid_kwargs(num_adapters=2, **overrides):
"""Return kwargs for a valid token-exchange config."""
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_sharpness_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
MAX_POSITION_EMBEDDINGS,
)

pytestmark = pytest.mark.local_fast

# Stress adapter IDs: 1 (smallest), 16 (middle), 32 (largest supported)
ADAPTER_IDS = [1, 16, 32]

Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_token_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from granite_switch.config import GraniteSwitchConfig

pytestmark = pytest.mark.local_fast


def _base(num_adapters=2, **overrides):
names = [f"a{i}" for i in range(num_adapters)]
Expand Down
1 change: 1 addition & 0 deletions tests/vllm/test_generation_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import pytest

pytestmark = [pytest.mark.vllm, pytest.mark.gpu, pytest.mark.slow]

WORKER = Path(__file__).parent / "_generation_equivalence_worker.py"
TIMEOUT = 1200 # 20 min per model (download + build + 2× vLLM load + generate)
Expand Down
4 changes: 4 additions & 0 deletions tests/vllm/test_granite4_fullsize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
# ── Weight transfer tests (HF-level, no vLLM) ────────────────────


@pytest.mark.slow
class TestGranite4FullSizeWeightTransfer:
"""HF-level weight transfer at full model dimensions.

Expand Down Expand Up @@ -85,6 +86,9 @@ def _run_inner_class(class_name):


@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="requires vLLM installed")
@pytest.mark.vllm
@pytest.mark.gpu
@pytest.mark.slow
class TestGranite4FullSizeEquivalence:
def test_suite(self):
_run_inner_class("TestGranite4FullSizeEquivalence")
9 changes: 9 additions & 0 deletions tests/vllm/test_granite4_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
# ── Weight transfer tests (HF-level, no vLLM) ────────────────────


@pytest.mark.local_fast
class TestGranite4FamilyWeightTransfer:
"""HF-level weight transfer: all switch params populated from upstream.

Expand All @@ -70,6 +71,7 @@ def test_weight_transfer(self, model_name):
gc.collect()


@pytest.mark.local_fast
class TestZeroAdapterWeightTransfer:
"""HF-level weight transfer with adapter infrastructure.

Expand Down Expand Up @@ -122,20 +124,27 @@ def _run_inner_class(class_name):


@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="requires vLLM installed")
@pytest.mark.vllm
@pytest.mark.gpu
@pytest.mark.slow
class TestGranite4FamilyEquivalence:
@pytest.mark.parametrize("model_name", _MODEL_NAMES)
def test_suite(self, model_name):
_run_inner_class(f"TestGranite4FamilyEquivalence and {model_name}")


@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="requires vLLM installed")
@pytest.mark.vllm
@pytest.mark.gpu
class TestZeroAdapterNoHiding:
@pytest.mark.parametrize("model_name", _MODEL_NAMES)
def test_suite(self, model_name):
_run_inner_class(f"TestZeroAdapterNoHiding and {model_name}")


@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="requires vLLM installed")
@pytest.mark.vllm
@pytest.mark.gpu
class TestZeroAdapterEquivalence:
@pytest.mark.parametrize("model_name", _MODEL_NAMES)
def test_suite(self, model_name):
Expand Down
9 changes: 5 additions & 4 deletions tests/vllm/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None

pytestmark = pytest.mark.skipif(
not _VLLM_AVAILABLE,
reason="requires vLLM installed (GPU checked by inner tests)",
)
pytestmark = [
pytest.mark.skipif(not _VLLM_AVAILABLE, reason="requires vLLM installed (GPU checked by inner tests)"),
pytest.mark.vllm,
pytest.mark.gpu,
]

_INNER = Path(__file__).parent / "_lora_tests.py"
_TIMEOUT = 300
Expand Down
9 changes: 5 additions & 4 deletions tests/vllm/test_model_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None

pytestmark = pytest.mark.skipif(
not _VLLM_AVAILABLE,
reason="requires vLLM installed (GPU checked by inner tests)",
)
pytestmark = [
pytest.mark.skipif(not _VLLM_AVAILABLE, reason="requires vLLM installed (GPU checked by inner tests)"),
pytest.mark.vllm,
pytest.mark.gpu,
]

_INNER = Path(__file__).parent / "_model_forward_tests.py"
_TIMEOUT = 600
Expand Down
9 changes: 5 additions & 4 deletions tests/vllm/test_noneager_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

_VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None

pytestmark = pytest.mark.skipif(
not _VLLM_AVAILABLE,
reason="requires vLLM installed (GPU checked by inner tests)",
)
pytestmark = [
pytest.mark.skipif(not _VLLM_AVAILABLE, reason="requires vLLM installed (GPU checked by inner tests)"),
pytest.mark.vllm,
pytest.mark.gpu,
]

_INNER = Path(__file__).parent / "_noneager_generation_tests.py"
_TIMEOUT = 600
Expand Down
Loading