Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions .github/workflows/prek.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,34 @@ name: Prek
on:
pull_request:
push:
branches: [ main ]
branches: [main]

jobs:
quality-checks:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH

- name: Install dependencies
run: |
uv sync --all-extras
- name: Run prek hooks (lint, format, typecheck, uv.lock, tests)
run: |
uv run prek run --all-files

- name: Run unit tests (via prek)
run: |
uv run prek run pytest
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH

- name: Install dependencies (with all optional extras for complete type checking)
run: |
uv sync --all-extras

- name: Run prek hooks (lint, format, typecheck, uv.lock, tests)
run: |
uv run prek run --all-files

- name: Run unit tests (via prek)
run: |
uv run prek run pytest
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ repos:

- repo: local
hooks:
- id: pyright
name: Pyright type checking
entry: uv run pyright src tests
- id: ty
name: ty type checking
entry: uv run ty check src tests
language: system
pass_filenames: false

Expand Down
49 changes: 48 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,53 @@ asyncio_mode = "auto"
[tool.uv]
required-version = ">=0.6.15"

[tool.ty.environment]
python-version = "3.11"

[tool.ty.rules]
# Ignore unused-ignore-comment warnings because they vary depending on whether
# optional deps are installed. The ty:ignore comments are needed in CI (with all deps)
# but become unused locally (without all deps).
unused-ignore-comment = "ignore"

[tool.ty.analysis]
# Allow unresolved imports for optional dependencies that may not be installed locally.
# In CI, we install all optional deps so these will be resolved and type-checked.
allowed-unresolved-imports = [
# backend deps
"accelerate.**",
"awscli.**",
"bitsandbytes.**",
"duckdb.**",
"fastapi.**",
"gql.**",
"hf_xet.**",
"nbclient.**",
"nbmake.**",
"peft.**",
"pyarrow.**",
"torch.**",
"torchao.**",
"transformers.**",
"trl.**",
"unsloth.**",
"unsloth_zoo.**",
"uvicorn.**",
"vllm.**",
"wandb.**",
# skypilot deps
"semver.**",
"sky.**",
"skypilot.**",
# langgraph deps
"langchain_core.**",
"langchain_openai.**",
"langgraph.**",
# plotting deps
"matplotlib.**",
"seaborn.**",
]

[dependency-groups]
dev = [
"black>=25.1.0",
Expand All @@ -112,7 +159,7 @@ dev = [
"pytest>=8.4.1",
"nbval>=0.11.0",
"pytest-xdist>=3.8.0",
"pyright[nodejs]>=1.1.403",
"ty>=0.0.14",
"pytest-asyncio>=1.1.0",
"duckdb>=1.0.0",
"pyarrow>=15.0.0",
Expand Down
4 changes: 2 additions & 2 deletions src/art/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def __init__(self, **kwargs):

# Import unsloth before transformers, peft, and trl to maximize Unsloth optimizations
if os.environ.get("IMPORT_UNSLOTH", "0") == "1":
import unsloth # type: ignore # noqa: F401
import unsloth # noqa: F401

try:
import transformers # type: ignore
import transformers

try:
from .transformers.patches import patch_preprocess_mask_arguments
Expand Down
8 changes: 4 additions & 4 deletions src/art/auto_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,10 @@ async def patched_aclose(self: httpx._models.Response) -> None:
if context := auto_trajectory_context_var.get(None):
context.handle_httpx_response(self)

httpx._models.Response.iter_bytes = patched_iter_bytes
httpx._models.Response.aiter_bytes = patched_aiter_bytes
httpx._models.Response.close = patched_close
httpx._models.Response.aclose = patched_aclose
httpx._models.Response.iter_bytes = patched_iter_bytes # ty:ignore[invalid-assignment]
httpx._models.Response.aiter_bytes = patched_aiter_bytes # ty:ignore[invalid-assignment]
httpx._models.Response.close = patched_close # ty:ignore[invalid-assignment]
httpx._models.Response.aclose = patched_aclose # ty:ignore[invalid-assignment]


patch_httpx()
2 changes: 1 addition & 1 deletion src/art/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
return pydantic.BaseModel.__init__(self, *args, **kwargs)

TrajectoryGroup.__new__ = __new__ # type: ignore
TrajectoryGroup.__init__ = __init__
TrajectoryGroup.__init__ = __init__ # ty:ignore[invalid-assignment]

backend = LocalBackend()
app = FastAPI()
Expand Down
6 changes: 3 additions & 3 deletions src/art/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def gather_trajectories(
)
if context.pbar is not None:
context.pbar.close()
return results # type: ignore
return results


async def wrap_group_awaitable(
Expand Down Expand Up @@ -193,7 +193,7 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None:
len(l.content or l.refusal or [])
for l in logprobs # noqa: E741
) / len(logprobs)
context.metric_sums["reward"] += trajectory.reward # type: ignore
context.metric_sums["reward"] += trajectory.reward
context.metric_divisors["reward"] += 1
context.metric_sums.update(trajectory.metrics)
context.metric_divisors.update(trajectory.metrics.keys())
Expand Down Expand Up @@ -229,7 +229,7 @@ def too_many_exceptions(self) -> bool:
if (
0 < self.max_exceptions < 1
and self.pbar is not None
and self.metric_sums["exceptions"] / self.pbar.total <= self.max_exceptions
and self.metric_sums["exceptions"] / self.pbar.total <= self.max_exceptions # ty:ignore[unsupported-operator]
) or self.metric_sums["exceptions"] <= self.max_exceptions:
return False
return True
Expand Down
8 changes: 4 additions & 4 deletions src/art/guided_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ def freeze_tool_schema(tool: dict, fixed_args: dict) -> ChatCompletionToolParam:
Each field is cast to typing.Literal[value] so Pydantic emits an
enum-of-one in the JSON schema, which vLLM's `guided_json` accepts.
"""
fields = {k: (Literal[v], ...) for k, v in fixed_args.items()}
fields = {k: (Literal[v], ...) for k, v in fixed_args.items()} # ty:ignore[invalid-type-form]
FrozenModel = create_model(
f"{tool['function']['name'].title()}FrozenArgs",
**fields, # type: ignore
)
**fields,
) # ty:ignore[no-matching-overload]

locked = deepcopy(tool)
locked["function"]["parameters"] = FrozenModel.model_json_schema()
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_guided_completion_params(
}
chosen_tool = next(t for t in base_tools if t["function"]["name"] == tool_name)
tool_params = [
freeze_tool_schema(chosen_tool, json.loads(tool_call.function.arguments)) # type: ignore
freeze_tool_schema(chosen_tool, json.loads(tool_call.function.arguments))
]
else:
content = completion.choices[0].message.content
Expand Down
18 changes: 9 additions & 9 deletions src/art/langgraph/llm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def init_chat_model(
config = CURRENT_CONFIG.get()
return LoggingLLM(
ChatOpenAI(
base_url=config["base_url"],
api_key=config["api_key"],
model=config["model"],
base_url=config["base_url"], # ty:ignore[unknown-argument]
api_key=config["api_key"], # ty:ignore[unknown-argument]
model=config["model"], # ty:ignore[unknown-argument]
temperature=1.0,
),
config["logger"],
Expand Down Expand Up @@ -222,17 +222,17 @@ def with_config(
self.llm,
"bound",
ChatOpenAI(
base_url=art_config["base_url"],
api_key=art_config["api_key"],
model=art_config["model"],
base_url=art_config["base_url"], # ty:ignore[unknown-argument]
api_key=art_config["api_key"], # ty:ignore[unknown-argument]
model=art_config["model"], # ty:ignore[unknown-argument]
temperature=1.0,
),
)
else:
self.llm = ChatOpenAI(
base_url=art_config["base_url"],
api_key=art_config["api_key"],
model=art_config["model"],
base_url=art_config["base_url"], # ty:ignore[unknown-argument]
api_key=art_config["api_key"], # ty:ignore[unknown-argument]
model=art_config["model"], # ty:ignore[unknown-argument]
temperature=1.0,
)

Expand Down
4 changes: 2 additions & 2 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _get_packed_tensors(
packed_tensors = packed_tensors_from_tokenized_results(
tokenized_results,
sequence_length,
pad_token_id=tokenizer.eos_token_id, # type: ignore
pad_token_id=tokenizer.eos_token_id,
advantage_balance=advantage_balance,
)
if (
Expand Down Expand Up @@ -360,7 +360,7 @@ def _trajectory_log(self, trajectory: Trajectory) -> str:
if isinstance(message_or_choice, dict):
message = message_or_choice
else:
message = cast(Message, message_or_choice.message.model_dump())
message = cast(Message, message_or_choice.message.model_dump()) # ty:ignore[possibly-missing-attribute]
formatted_messages.append(format_message(message))
return header + "\n".join(formatted_messages)

Expand Down
2 changes: 1 addition & 1 deletion src/art/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def loss_fn(
)
if upper_bound := experimental_config.get("truncated_importance_sampling", None):
if "original_logprobs" in inputs:
original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0)
original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0) # ty:ignore[invalid-key]
original_logprobs = torch.where(
torch.isnan(original_logprobs),
new_logprobs.detach(),
Expand Down
6 changes: 3 additions & 3 deletions src/art/mcp/generate_scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ async def generate_scenarios(
# Assume it's already a list of dictionaries
tools_info = [
{
"name": tool.get("name", "")
"name": tool.get("name", "") # ty:ignore[no-matching-overload]
if isinstance(tool, dict)
else getattr(tool, "name", ""),
"description": tool.get("description", "")
"description": tool.get("description", "") # ty:ignore[no-matching-overload]
if isinstance(tool, dict)
else getattr(tool, "description", ""),
"parameters": tool.get("parameters", {})
"parameters": tool.get("parameters", {}) # ty:ignore[no-matching-overload]
if isinstance(tool, dict)
else getattr(tool, "parameters", {}),
}
Expand Down
16 changes: 10 additions & 6 deletions src/art/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __new__( # pyright: ignore[reportInconsistentOverload]
*args,
**kwargs,
) -> "Model[ModelConfig, StateType]":
return super().__new__(cls) # type: ignore[return-value]
return super().__new__(cls)

def safe_model_dump(self, *args, **kwargs) -> dict:
"""
Expand All @@ -174,7 +174,7 @@ def backend(self) -> "Backend":
async def register(self, backend: "Backend") -> None:
if self.config is not None:
try:
self.config.model_dump_json()
self.config.model_dump_json() # ty:ignore[invalid-argument-type, possibly-missing-attribute]
except Exception as e:
raise ValueError(
"The model config cannot be serialized to JSON. Please ensure that all fields are JSON serializable and try again."
Expand Down Expand Up @@ -500,7 +500,7 @@ def __init__(
entity=entity,
id=id,
config=config,
base_model=base_model, # type: ignore
base_model=base_model,
base_path=base_path,
report_metrics=report_metrics,
**kwargs,
Expand Down Expand Up @@ -544,7 +544,7 @@ def __new__( # pyright: ignore[reportInconsistentOverload]
*args,
**kwargs,
) -> "TrainableModel[ModelConfig, StateType]":
return super().__new__(cls) # type: ignore
return super().__new__(cls)

def model_dump(self, *args, **kwargs) -> dict:
data = super().model_dump(*args, **kwargs)
Expand Down Expand Up @@ -649,15 +649,19 @@ async def train(
stacklevel=2,
)
groups_list = list(trajectory_groups)
_config = _config or {}
_config = _config or {} # ty:ignore[invalid-assignment]

# 1. Log trajectories first (frontend handles this now)
await self.log(groups_list, split="train")

# 2. Train (backend no longer logs internally)
training_metrics: list[dict[str, float]] = []
async for metrics in self.backend()._train_model(
self, groups_list, config, _config, verbose
self,
groups_list,
config,
_config, # ty:ignore[invalid-argument-type]
verbose,
):
training_metrics.append(metrics)

Expand Down
2 changes: 1 addition & 1 deletion src/art/preprocessing/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def create_train_inputs(
warmup: bool,
) -> TrainInputs:
"""Create TrainInputs for a single batch offset."""
return TrainInputs(
return TrainInputs( # ty:ignore[missing-typed-dict-key]
**{
k: (
v[offset : offset + 1, :1024]
Expand Down
Loading