diff --git a/.github/workflows/prek.yml b/.github/workflows/prek.yml index 1f3d52b4a..bb4e129ba 100644 --- a/.github/workflows/prek.yml +++ b/.github/workflows/prek.yml @@ -3,34 +3,34 @@ name: Prek on: pull_request: push: - branches: [ main ] + branches: [main] jobs: quality-checks: runs-on: ubuntu-latest - + steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Install uv - run: | - curl -LsSf https://astral.sh/uv/install.sh | sh - echo "$HOME/.cargo/bin" >> $GITHUB_PATH - - - name: Install dependencies - run: | - uv sync --all-extras - - - name: Run prek hooks (lint, format, typecheck, uv.lock, tests) - run: | - uv run prek run --all-files - - - name: Run unit tests (via prek) - run: | - uv run prek run pytest \ No newline at end of file + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.cargo/bin" >> $GITHUB_PATH + + - name: Install dependencies (with all optional extras for complete type checking) + run: | + uv sync --all-extras + + - name: Run prek hooks (lint, format, typecheck, uv.lock, tests) + run: | + uv run prek run --all-files + + - name: Run unit tests (via prek) + run: | + uv run prek run pytest \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b5be17686..622254ac5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,9 +7,9 @@ repos: - repo: local hooks: - - id: pyright - name: Pyright type checking - entry: uv run pyright src tests + - id: ty + name: ty type checking + entry: uv run ty check src tests language: system pass_filenames: false diff --git a/pyproject.toml b/pyproject.toml index 4d0f4bf23..1a50198af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,53 @@ asyncio_mode = "auto" [tool.uv] required-version = ">=0.6.15" +[tool.ty.environment] +python-version = "3.11" + +[tool.ty.rules] +# Ignore unused-ignore-comment warnings because they vary depending on whether +# optional deps are installed. The ty:ignore comments are needed in CI (with all deps) +# but become unused locally (without all deps). +unused-ignore-comment = "ignore" + +[tool.ty.analysis] +# Allow unresolved imports for optional dependencies that may not be installed locally. +# In CI, we install all optional deps so these will be resolved and type-checked. +allowed-unresolved-imports = [ + # backend deps + "accelerate.**", + "awscli.**", + "bitsandbytes.**", + "duckdb.**", + "fastapi.**", + "gql.**", + "hf_xet.**", + "nbclient.**", + "nbmake.**", + "peft.**", + "pyarrow.**", + "torch.**", + "torchao.**", + "transformers.**", + "trl.**", + "unsloth.**", + "unsloth_zoo.**", + "uvicorn.**", + "vllm.**", + "wandb.**", + # skypilot deps + "semver.**", + "sky.**", + "skypilot.**", + # langgraph deps + "langchain_core.**", + "langchain_openai.**", + "langgraph.**", + # plotting deps + "matplotlib.**", + "seaborn.**", +] + [dependency-groups] dev = [ "black>=25.1.0", @@ -112,7 +159,7 @@ dev = [ "pytest>=8.4.1", "nbval>=0.11.0", "pytest-xdist>=3.8.0", - "pyright[nodejs]>=1.1.403", + "ty>=0.0.14", "pytest-asyncio>=1.1.0", "duckdb>=1.0.0", "pyarrow>=15.0.0", diff --git a/src/art/__init__.py b/src/art/__init__.py index 75957e176..b6948f514 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -34,10 +34,10 @@ def __init__(self, **kwargs): # Import unsloth before transformers, peft, and trl to maximize Unsloth optimizations if os.environ.get("IMPORT_UNSLOTH", "0") == "1": - import unsloth # type: ignore # noqa: F401 + import unsloth # noqa: F401 try: - import transformers # type: ignore + import transformers try: from .transformers.patches import patch_preprocess_mask_arguments diff --git a/src/art/auto_trajectory.py b/src/art/auto_trajectory.py index 4979602f7..a22d0be7c 100644 --- a/src/art/auto_trajectory.py +++ b/src/art/auto_trajectory.py @@ -177,10 +177,10 @@ async def patched_aclose(self: httpx._models.Response) -> None: if context := auto_trajectory_context_var.get(None): context.handle_httpx_response(self) - httpx._models.Response.iter_bytes = patched_iter_bytes - httpx._models.Response.aiter_bytes = patched_aiter_bytes - httpx._models.Response.close = patched_close - httpx._models.Response.aclose = patched_aclose + httpx._models.Response.iter_bytes = patched_iter_bytes # ty:ignore[invalid-assignment] + httpx._models.Response.aiter_bytes = patched_aiter_bytes # ty:ignore[invalid-assignment] + httpx._models.Response.close = patched_close # ty:ignore[invalid-assignment] + httpx._models.Response.aclose = patched_aclose # ty:ignore[invalid-assignment] patch_httpx() diff --git a/src/art/cli.py b/src/art/cli.py index dd64fc634..09b32e9de 100644 --- a/src/art/cli.py +++ b/src/art/cli.py @@ -159,7 +159,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: return pydantic.BaseModel.__init__(self, *args, **kwargs) TrajectoryGroup.__new__ = __new__ # type: ignore - TrajectoryGroup.__init__ = __init__ + TrajectoryGroup.__init__ = __init__ # ty:ignore[invalid-assignment] backend = LocalBackend() app = FastAPI() diff --git a/src/art/gather.py b/src/art/gather.py index 3461107c2..3a72df637 100644 --- a/src/art/gather.py +++ b/src/art/gather.py @@ -134,7 +134,7 @@ async def gather_trajectories( ) if context.pbar is not None: context.pbar.close() - return results # type: ignore + return results async def wrap_group_awaitable( @@ -193,7 +193,7 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None: len(l.content or l.refusal or []) for l in logprobs # noqa: E741 ) / len(logprobs) - context.metric_sums["reward"] += trajectory.reward # type: ignore + context.metric_sums["reward"] += trajectory.reward context.metric_divisors["reward"] += 1 context.metric_sums.update(trajectory.metrics) context.metric_divisors.update(trajectory.metrics.keys()) @@ -229,7 +229,7 @@ def too_many_exceptions(self) -> bool: if ( 0 < self.max_exceptions < 1 and self.pbar is not None - and self.metric_sums["exceptions"] / self.pbar.total <= self.max_exceptions + and self.metric_sums["exceptions"] / self.pbar.total <= self.max_exceptions # ty:ignore[unsupported-operator] ) or self.metric_sums["exceptions"] <= self.max_exceptions: return False return True diff --git a/src/art/guided_completion.py b/src/art/guided_completion.py index 6599f486f..c1feff6bc 100644 --- a/src/art/guided_completion.py +++ b/src/art/guided_completion.py @@ -19,11 +19,11 @@ def freeze_tool_schema(tool: dict, fixed_args: dict) -> ChatCompletionToolParam: Each field is cast to typing.Literal[value] so Pydantic emits an enum-of-one in the JSON schema, which vLLM's `guided_json` accepts. """ - fields = {k: (Literal[v], ...) for k, v in fixed_args.items()} + fields = {k: (Literal[v], ...) for k, v in fixed_args.items()} # ty:ignore[invalid-type-form] FrozenModel = create_model( f"{tool['function']['name'].title()}FrozenArgs", - **fields, # type: ignore - ) + **fields, + ) # ty:ignore[no-matching-overload] locked = deepcopy(tool) locked["function"]["parameters"] = FrozenModel.model_json_schema() @@ -71,7 +71,7 @@ def get_guided_completion_params( } chosen_tool = next(t for t in base_tools if t["function"]["name"] == tool_name) tool_params = [ - freeze_tool_schema(chosen_tool, json.loads(tool_call.function.arguments)) # type: ignore + freeze_tool_schema(chosen_tool, json.loads(tool_call.function.arguments)) ] else: content = completion.choices[0].message.content diff --git a/src/art/langgraph/llm_wrapper.py b/src/art/langgraph/llm_wrapper.py index 11f2bbee6..36b5314b3 100644 --- a/src/art/langgraph/llm_wrapper.py +++ b/src/art/langgraph/llm_wrapper.py @@ -118,9 +118,9 @@ def init_chat_model( config = CURRENT_CONFIG.get() return LoggingLLM( ChatOpenAI( - base_url=config["base_url"], - api_key=config["api_key"], - model=config["model"], + base_url=config["base_url"], # ty:ignore[unknown-argument] + api_key=config["api_key"], # ty:ignore[unknown-argument] + model=config["model"], # ty:ignore[unknown-argument] temperature=1.0, ), config["logger"], @@ -222,17 +222,17 @@ def with_config( self.llm, "bound", ChatOpenAI( - base_url=art_config["base_url"], - api_key=art_config["api_key"], - model=art_config["model"], + base_url=art_config["base_url"], # ty:ignore[unknown-argument] + api_key=art_config["api_key"], # ty:ignore[unknown-argument] + model=art_config["model"], # ty:ignore[unknown-argument] temperature=1.0, ), ) else: self.llm = ChatOpenAI( - base_url=art_config["base_url"], - api_key=art_config["api_key"], - model=art_config["model"], + base_url=art_config["base_url"], # ty:ignore[unknown-argument] + api_key=art_config["api_key"], # ty:ignore[unknown-argument] + model=art_config["model"], # ty:ignore[unknown-argument] temperature=1.0, ) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index c017ad5f9..19f26afbe 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -215,7 +215,7 @@ def _get_packed_tensors( packed_tensors = packed_tensors_from_tokenized_results( tokenized_results, sequence_length, - pad_token_id=tokenizer.eos_token_id, # type: ignore + pad_token_id=tokenizer.eos_token_id, advantage_balance=advantage_balance, ) if ( @@ -360,7 +360,7 @@ def _trajectory_log(self, trajectory: Trajectory) -> str: if isinstance(message_or_choice, dict): message = message_or_choice else: - message = cast(Message, message_or_choice.message.model_dump()) + message = cast(Message, message_or_choice.message.model_dump()) # ty:ignore[possibly-missing-attribute] formatted_messages.append(format_message(message)) return header + "\n".join(formatted_messages) diff --git a/src/art/loss.py b/src/art/loss.py index dda539585..79154fde9 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -106,7 +106,7 @@ def loss_fn( ) if upper_bound := experimental_config.get("truncated_importance_sampling", None): if "original_logprobs" in inputs: - original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0) + original_logprobs = shift_tensor(inputs["original_logprobs"], 0.0) # ty:ignore[invalid-key] original_logprobs = torch.where( torch.isnan(original_logprobs), new_logprobs.detach(), diff --git a/src/art/mcp/generate_scenarios.py b/src/art/mcp/generate_scenarios.py index df92ea3c4..6932d9c2e 100644 --- a/src/art/mcp/generate_scenarios.py +++ b/src/art/mcp/generate_scenarios.py @@ -75,13 +75,13 @@ async def generate_scenarios( # Assume it's already a list of dictionaries tools_info = [ { - "name": tool.get("name", "") + "name": tool.get("name", "") # ty:ignore[no-matching-overload] if isinstance(tool, dict) else getattr(tool, "name", ""), - "description": tool.get("description", "") + "description": tool.get("description", "") # ty:ignore[no-matching-overload] if isinstance(tool, dict) else getattr(tool, "description", ""), - "parameters": tool.get("parameters", {}) + "parameters": tool.get("parameters", {}) # ty:ignore[no-matching-overload] if isinstance(tool, dict) else getattr(tool, "parameters", {}), } diff --git a/src/art/model.py b/src/art/model.py index 1afd407e2..2fc38640b 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -153,7 +153,7 @@ def __new__( # pyright: ignore[reportInconsistentOverload] *args, **kwargs, ) -> "Model[ModelConfig, StateType]": - return super().__new__(cls) # type: ignore[return-value] + return super().__new__(cls) def safe_model_dump(self, *args, **kwargs) -> dict: """ @@ -174,7 +174,7 @@ def backend(self) -> "Backend": async def register(self, backend: "Backend") -> None: if self.config is not None: try: - self.config.model_dump_json() + self.config.model_dump_json() # ty:ignore[invalid-argument-type, possibly-missing-attribute] except Exception as e: raise ValueError( "The model config cannot be serialized to JSON. Please ensure that all fields are JSON serializable and try again." @@ -500,7 +500,7 @@ def __init__( entity=entity, id=id, config=config, - base_model=base_model, # type: ignore + base_model=base_model, base_path=base_path, report_metrics=report_metrics, **kwargs, @@ -544,7 +544,7 @@ def __new__( # pyright: ignore[reportInconsistentOverload] *args, **kwargs, ) -> "TrainableModel[ModelConfig, StateType]": - return super().__new__(cls) # type: ignore + return super().__new__(cls) def model_dump(self, *args, **kwargs) -> dict: data = super().model_dump(*args, **kwargs) @@ -649,7 +649,7 @@ async def train( stacklevel=2, ) groups_list = list(trajectory_groups) - _config = _config or {} + _config = _config or {} # ty:ignore[invalid-assignment] # 1. Log trajectories first (frontend handles this now) await self.log(groups_list, split="train") @@ -657,7 +657,11 @@ async def train( # 2. Train (backend no longer logs internally) training_metrics: list[dict[str, float]] = [] async for metrics in self.backend()._train_model( - self, groups_list, config, _config, verbose + self, + groups_list, + config, + _config, # ty:ignore[invalid-argument-type] + verbose, ): training_metrics.append(metrics) diff --git a/src/art/preprocessing/inputs.py b/src/art/preprocessing/inputs.py index 78d9710bb..996c15f20 100644 --- a/src/art/preprocessing/inputs.py +++ b/src/art/preprocessing/inputs.py @@ -24,7 +24,7 @@ def create_train_inputs( warmup: bool, ) -> TrainInputs: """Create TrainInputs for a single batch offset.""" - return TrainInputs( + return TrainInputs( # ty:ignore[missing-typed-dict-key] **{ k: ( v[offset : offset + 1, :1024] diff --git a/src/art/preprocessing/pack.py b/src/art/preprocessing/pack.py index 0e99176af..e943a4306 100644 --- a/src/art/preprocessing/pack.py +++ b/src/art/preprocessing/pack.py @@ -194,8 +194,8 @@ def packed_tensors_from_dir(**kwargs: Unpack[DiskPackedTensors]) -> PackedTensor "weights": torch.float32, }.items() } - _add_tensor_list(packed_tensors, kwargs, "pixel_values", torch.float32) - _add_tensor_list(packed_tensors, kwargs, "image_grid_thw", torch.long) + _add_tensor_list(packed_tensors, kwargs, "pixel_values", torch.float32) # ty:ignore[invalid-argument-type] + _add_tensor_list(packed_tensors, kwargs, "image_grid_thw", torch.long) # ty:ignore[invalid-argument-type] return cast(PackedTensors, packed_tensors) @@ -237,7 +237,7 @@ def packed_tensors_to_dir(tensors: PackedTensors, dir: str) -> DiskPackedTensors if isinstance(tensor, list): for i, t in enumerate(tensor): if t is not None: - t.copy_(tensors[key][i]) + t.copy_(tensors[key][i]) # ty:ignore[invalid-key, unresolved-attribute] else: tensor.copy_(tensors[key]) # type: ignore return disk_packed_tensors @@ -288,7 +288,7 @@ def plot_packed_tensors( tensor.numpy(), cmap="viridis", cbar_kws={"label": label}, - xticklabels=False, # type: ignore + xticklabels=False, ) plt.title(title) plt.xlabel("Sequence Position") diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index 1b4230960..f146e5ca0 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -165,7 +165,7 @@ def tokenize_trajectory( ): last_assistant_index = i elif not isinstance(message, dict) and ( - message.logprobs or allow_training_without_logprobs + message.logprobs or allow_training_without_logprobs # ty:ignore[possibly-missing-attribute] ): last_assistant_index = i # If there are no trainable assistant messages, return None @@ -182,7 +182,7 @@ def tokenize_trajectory( str, tokenizer.apply_chat_template( cast(list[dict], messages), - tools=tools, # type: ignore + tools=tools, continue_final_message=True, tokenize=False, ), @@ -191,7 +191,7 @@ def tokenize_trajectory( list[int], tokenizer.apply_chat_template( cast(list[dict], messages), - tools=tools, # type: ignore + tools=tools, continue_final_message=True, ), ) @@ -214,8 +214,8 @@ def tokenize_trajectory( "role": "assistant", "content": sentinal_token, **( - {"tool_calls": message.get("tool_calls")} # type: ignore[call-overload] - if message.get("tool_calls") # type: ignore[call-overload] + {"tool_calls": message.get("tool_calls")} + if message.get("tool_calls") else {} ), } @@ -226,7 +226,7 @@ def tokenize_trajectory( list[int], tokenizer.apply_chat_template( cast(list[dict], token_template_messages), - tools=tools, # type: ignore + tools=tools, continue_final_message=True, ), ) @@ -238,7 +238,7 @@ def tokenize_trajectory( continue if not allow_training_without_logprobs: continue - elif message.logprobs is None and not allow_training_without_logprobs: + elif message.logprobs is None and not allow_training_without_logprobs: # ty:ignore[possibly-missing-attribute] continue start = token_ids.index(sentinal_token_id) end = start + 1 @@ -263,12 +263,12 @@ def tokenize_trajectory( assistant_mask[start:end] = [1] * len(content_token_ids) else: choice = message - assert choice.logprobs or allow_training_without_logprobs, ( + assert choice.logprobs or allow_training_without_logprobs, ( # ty:ignore[possibly-missing-attribute] "Chat completion choices must have logprobs" ) - if not choice.logprobs: + if not choice.logprobs: # ty:ignore[possibly-missing-attribute] continue - token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] + token_logprobs = choice.logprobs.content or choice.logprobs.refusal or [] # ty:ignore[possibly-missing-attribute] if ( bytes(token_logprobs[0].bytes or []).decode("utf-8") == "" @@ -281,14 +281,14 @@ def tokenize_trajectory( for token_logprob in token_logprobs ) except (IndexError, ValueError): - token_ids[start:end] = [ # type: ignore + token_ids[start:end] = [ token_id if token_id is not None else tokenizer.eos_token_id for token_id in tokenizer.convert_tokens_to_ids( [ token_logprob.token or tokenizer.eos_token for token_logprob in token_logprobs ] - ) # type: ignore + ) ] logprobs[start:end] = ( token_logprob.logprob for token_logprob in token_logprobs @@ -313,7 +313,7 @@ def tokenize_trajectory( image_token_id = cast( int, getattr(image_processor, "image_token_id", None) - or tokenizer.convert_tokens_to_ids( # type: ignore + or tokenizer.convert_tokens_to_ids( getattr(image_processor, "image_token", "<|image_pad|>") ), ) diff --git a/src/art/rewards/ruler.py b/src/art/rewards/ruler.py index 2ea333124..84d41f4de 100644 --- a/src/art/rewards/ruler.py +++ b/src/art/rewards/ruler.py @@ -186,7 +186,7 @@ async def ruler( first_choice = response.choices[0] if debug: - raw_content = first_choice.message.content or "{}" # type: ignore[attr-defined] + raw_content = first_choice.message.content or "{}" try: print("\n[RULER] Pretty-printed LLM choice JSON:") print(json.loads(raw_content)) @@ -194,7 +194,7 @@ async def ruler( print(f"[RULER] Could not parse choice content as JSON: {e}") print(f"[RULER] Raw choice content: {raw_content}") - content = first_choice.message.content or "{}" # type: ignore[attr-defined] + content = first_choice.message.content or "{}" parsed = Response.model_validate_json(content) # If all trajectories were identical, we only sent one to the judge diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 53e427aef..f431d4b17 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -25,7 +25,7 @@ def __init__( self._client = client async def close(self) -> None: - await self._client.close() + await self._client.close() # ty:ignore[possibly-missing-attribute] async def register( self, @@ -44,7 +44,7 @@ async def register( "Registering a non-trainable model with the Serverless backend is not supported." ) return - client_model = await self._client.models.create( + client_model = await self._client.models.create( # ty:ignore[possibly-missing-attribute] entity=model.entity, project=model.project, name=model.name, @@ -72,7 +72,7 @@ async def delete( ) return assert model.id is not None, "Model ID is required" - await self._client.models.delete(model_id=model.id) + await self._client.models.delete(model_id=model.id) # ty:ignore[possibly-missing-attribute] def _model_inference_name(self, model: "Model", step: int | None = None) -> str: """Return the inference name for a model checkpoint. @@ -92,7 +92,7 @@ def _model_inference_name(self, model: "Model", step: int | None = None) -> str: async def _get_step(self, model: "Model") -> int: if model.trainable: assert model.id is not None, "Model ID is required" - async for checkpoint in self._client.models.checkpoints.list( + async for checkpoint in self._client.models.checkpoints.list( # ty:ignore[possibly-missing-attribute] limit=1, order="desc", model_id=model.id ): return checkpoint.step @@ -108,11 +108,11 @@ async def _delete_checkpoint_files( assert model.id is not None, "Model ID is required" # Get all checkpoint steps all_steps: list[int] = [] - async for checkpoint in self._client.models.checkpoints.list(model_id=model.id): + async for checkpoint in self._client.models.checkpoints.list(model_id=model.id): # ty:ignore[possibly-missing-attribute] all_steps.append(checkpoint.step) # Delete all steps not in steps_to_keep if steps_to_delete := [step for step in all_steps if step not in steps_to_keep]: - await self._client.models.checkpoints.delete( + await self._client.models.checkpoints.delete( # ty:ignore[possibly-missing-attribute] model_id=model.id, steps=steps_to_delete, ) @@ -122,7 +122,7 @@ async def _prepare_backend_for_training( model: "TrainableModel", config: dev.OpenAIServerConfig | None, ) -> tuple[str, str]: - return str(self._base_url), self._client.api_key + return str(self._base_url), self._client.api_key # ty:ignore[possibly-missing-attribute] # Note: _log() method has been moved to the Model class (frontend) # Trajectories are now saved locally by the Model.log() method @@ -256,7 +256,7 @@ async def _train_model( verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: assert model.id is not None, "Model ID is required" - training_job = await self._client.training_jobs.create( + training_job = await self._client.training_jobs.create( # ty:ignore[possibly-missing-attribute] model_id=model.id, trajectory_groups=trajectory_groups, experimental_config=ExperimentalTrainingConfig( @@ -280,7 +280,7 @@ async def _train_model( pbar: tqdm.tqdm | None = None while True: await asyncio.sleep(1) - async for event in self._client.training_jobs.events.list( + async for event in self._client.training_jobs.events.list( # ty:ignore[possibly-missing-attribute] training_job_id=training_job.id, after=after or NOT_GIVEN ): if event.type == "gradient_step": @@ -336,7 +336,7 @@ async def _experimental_pull_model_checkpoint( assert model.id is not None, "Model ID is required" # If entity is not set, use the user's default entity from W&B - api = wandb.Api(api_key=self._client.api_key) + api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute] if model.entity is None: model.entity = api.default_entity if verbose: @@ -346,7 +346,7 @@ async def _experimental_pull_model_checkpoint( resolved_step: int if step is None or step == "latest": # Get latest checkpoint from API - async for checkpoint in self._client.models.checkpoints.list( + async for checkpoint in self._client.models.checkpoints.list( # ty:ignore[possibly-missing-attribute] limit=1, order="desc", model_id=model.id ): resolved_step = checkpoint.step diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index b28e9df9e..076418f39 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -132,7 +132,7 @@ async def delete(self, *, model_id: str) -> None: @cached_property def checkpoints(self) -> "Checkpoints": - return Checkpoints(cast(AsyncOpenAI, self._client)) + return Checkpoints(cast(AsyncOpenAI, self._client)) # ty:ignore[redundant-cast] class Checkpoints(AsyncAPIResource): @@ -193,7 +193,7 @@ async def create( @cached_property def events(self) -> "TrainingJobEvents": - return TrainingJobEvents(cast(AsyncOpenAI, self._client)) + return TrainingJobEvents(cast(AsyncOpenAI, self._client)) # ty:ignore[redundant-cast] class TrainingJobEvents(AsyncAPIResource): @@ -241,7 +241,7 @@ def __init__( ) @override - async def request( # type: ignore[reportIncompatibleMethodOverride] + async def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py index 9781880f5..6be096ad4 100644 --- a/src/art/tinker/server.py +++ b/src/art/tinker/server.py @@ -31,7 +31,7 @@ def _patched_parse_tool_call( return _parse_tool_call(self, tool_call_str.replace('"arguments": ', '"args": ')) -renderers.Qwen3InstructRenderer._parse_tool_call = _patched_parse_tool_call +renderers.Qwen3InstructRenderer._parse_tool_call = _patched_parse_tool_call # ty:ignore[invalid-assignment] @dataclass @@ -102,7 +102,7 @@ async def chat_completions( list(body["messages"]), # type: ignore tools=body.get("tools"), # type: ignore add_generation_prompt=True, - ) + ) # ty:ignore[invalid-argument-type] ) try: sample_response = await sampler_client.sample_async( @@ -122,7 +122,7 @@ async def chat_completions( except tinker.APIStatusError as e: error_body = e.body if isinstance(error_body, dict) and "detail" in error_body: - detail = error_body["detail"] + detail = error_body["detail"] # ty:ignore[invalid-argument-type] elif error_body is not None: detail = error_body else: diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 69ab17db9..f74fb3039 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -87,9 +87,9 @@ def for_logging(self) -> dict[str, Any]: for message_or_choice in self.messages_and_choices: trainable = isinstance(message_or_choice, Choice) message = ( - message_or_choice.message.to_dict() if trainable else message_or_choice + message_or_choice.message.to_dict() if trainable else message_or_choice # ty:ignore[possibly-missing-attribute] ) - loggable_dict["messages"].append({**message, "trainable": trainable}) + loggable_dict["messages"].append({**message, "trainable": trainable}) # ty:ignore[invalid-argument-type, possibly-missing-attribute] return loggable_dict @@ -112,7 +112,7 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages: } if tool_calls else {} - ), # type: ignore + ), } ) else: diff --git a/src/art/transformers/patches.py b/src/art/transformers/patches.py index 6b16424fe..97e09f6c8 100644 --- a/src/art/transformers/patches.py +++ b/src/art/transformers/patches.py @@ -34,4 +34,4 @@ def _patched_preprocess_mask_arguments( def patch_preprocess_mask_arguments() -> None: - masking_utils._preprocess_mask_arguments = _patched_preprocess_mask_arguments + masking_utils._preprocess_mask_arguments = _patched_preprocess_mask_arguments # ty:ignore[invalid-assignment] diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index 18c1b9027..d42941357 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -57,7 +57,7 @@ def precalculate_new_logprobs( [ trainer.compute_loss( peft_model, - TrainInputs( + TrainInputs( # ty:ignore[missing-typed-dict-key] **{ k: v[_offset : _offset + 1] for k, v in packed_tensors.items() @@ -70,7 +70,7 @@ def precalculate_new_logprobs( config=config, _config=_config, return_new_logprobs=True, - ), # type: ignore + ), ) for _offset in range(0, packed_tensors["tokens"].shape[0]) ] @@ -471,7 +471,7 @@ def _state(self) -> UnslothState: trainer = GRPOTrainer( model=peft_model, # type: ignore reward_funcs=[], - args=GRPOConfig(**self.config.get("trainer_args", {})), # type: ignore + args=GRPOConfig(**self.config.get("trainer_args", {})), train_dataset=Dataset.from_list([data for _ in range(10_000_000)]), processing_class=tokenizer, ) @@ -513,7 +513,7 @@ def llm(self) -> asyncio.Task[AsyncLLM]: # Remove boolean flags that vLLM's argparse doesn't accept as =False for key in ["enable_log_requests", "disable_log_requests"]: engine_args.pop(key, None) - return asyncio.create_task(get_llm(AsyncEngineArgs(**engine_args))) + return asyncio.create_task(get_llm(AsyncEngineArgs(**engine_args))) # ty:ignore[invalid-argument-type] # ============================================================================ @@ -576,7 +576,7 @@ def do_sleep(*, level: int) -> None: pin_memory=is_pin_memory_available(), ) cpu_ptr = cpu_backup_tensor.data_ptr() - libcudart.cudaMemcpy( + libcudart.cudaMemcpy( # ty:ignore[possibly-missing-attribute] ctypes.c_void_p(cpu_ptr), ctypes.c_void_p(ptr), size_in_bytes ) data.cpu_backup_tensor = cpu_backup_tensor @@ -612,7 +612,7 @@ def do_wake_up() -> None: cpu_backup_tensor = data.cpu_backup_tensor size_in_bytes = cpu_backup_tensor.numel() * cpu_backup_tensor.element_size() cpu_ptr = cpu_backup_tensor.data_ptr() - libcudart.cudaMemcpy( + libcudart.cudaMemcpy( # ty:ignore[possibly-missing-attribute] ctypes.c_void_p(ptr), ctypes.c_void_p(cpu_ptr), size_in_bytes, diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index fba02d22e..e5d229537 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -27,7 +27,7 @@ async def train( _compute_loss = trainer.compute_loss _log = trainer.log trainer.compute_loss = get_compute_loss_fn(trainer) - trainer.log = get_log_fn(trainer, results_queue) + trainer.log = get_log_fn(trainer, results_queue) # ty:ignore[invalid-assignment] # Ensure we have a metrics container in the expected format try: is_dict = isinstance(getattr(trainer, "_metrics", None), dict) @@ -40,7 +40,7 @@ async def train( trainer.train() finally: trainer.compute_loss = _compute_loss - trainer.log = _log + trainer.log = _log # ty:ignore[invalid-assignment] def get_compute_loss_fn(trainer: "GRPOTrainer") -> Callable[..., torch.Tensor]: @@ -82,7 +82,7 @@ def compute_loss( inputs = { key: tensor.to(trainer.accelerator.device) # type: ignore for key, tensor in inputs.items() - } + } # ty:ignore[invalid-assignment] accelerate_mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION") force_float32 = os.environ.get("UNSLOTH_FORCE_FLOAT32") @@ -120,9 +120,9 @@ def compute_loss( os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" forward_kwargs = {} if "pixel_values" in inputs: - forward_kwargs["pixel_values"] = inputs["pixel_values"] # type: ignore + forward_kwargs["pixel_values"] = inputs["pixel_values"] if "image_grid_thw" in inputs: - forward_kwargs["image_grid_thw"] = inputs["image_grid_thw"] # type: ignore + forward_kwargs["image_grid_thw"] = inputs["image_grid_thw"] new_logprobs, entropies = calculate_logprobs( dtype_for_autocasting, trainer, @@ -167,10 +167,10 @@ def compute_loss( trainer._metrics["train"]["learning_rate"].append(config.learning_rate) trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item()) if loss.mean_entropy is not None: - trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item()) # type: ignore + trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item()) if config.beta > 0.0: trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item()) - return loss.mean_policy_loss + config.beta * loss.mean_kl # type: ignore + return loss.mean_policy_loss + config.beta * loss.mean_kl return compute_loss diff --git a/src/art/utils/benchmark_rollout.py b/src/art/utils/benchmark_rollout.py index c8ef57d70..4ff1cfc0e 100644 --- a/src/art/utils/benchmark_rollout.py +++ b/src/art/utils/benchmark_rollout.py @@ -11,7 +11,7 @@ async def benchmark_rollout( rollout: Callable[[str, int, bool], Coroutine[Any, Any, Trajectory]], ) -> float: trajectory_groups = await art.gather_trajectory_groups( - [TrajectoryGroup(rollout(model, i, False) for i in range(num_rollouts))], + [TrajectoryGroup(rollout(model, i, False) for i in range(num_rollouts))], # ty:ignore[invalid-argument-type] pbar_desc="Benchmarking rollout", ) diff --git a/src/art/utils/benchmarking/charts/training_progress_chart.py b/src/art/utils/benchmarking/charts/training_progress_chart.py index 5a41bce1e..1f03507c6 100644 --- a/src/art/utils/benchmarking/charts/training_progress_chart.py +++ b/src/art/utils/benchmarking/charts/training_progress_chart.py @@ -187,7 +187,7 @@ def training_progress_chart( ordered_for_palette = trained_first + comparison_last palette = sns.color_palette(n_colors=len(ordered_for_palette)) - model_colors = {m: c for m, c in zip(ordered_for_palette, palette)} # type: ignore + model_colors = {m: c for m, c in zip(ordered_for_palette, palette)} # Track scores of comparison models to adjust linestyle for overlaps plotted_comparison_scores: set[float] = set() diff --git a/src/art/utils/benchmarking/load_trajectories.py b/src/art/utils/benchmarking/load_trajectories.py index 961e2c12d..2b494ffe4 100644 --- a/src/art/utils/benchmarking/load_trajectories.py +++ b/src/art/utils/benchmarking/load_trajectories.py @@ -239,7 +239,7 @@ async def load_trajectories( } if msg_dict.get("tool_calls"): try: - processed_msg["tool_calls"] = json.loads(msg_dict["tool_calls"]) + processed_msg["tool_calls"] = json.loads(msg_dict["tool_calls"]) # ty:ignore[invalid-argument-type] except (json.JSONDecodeError, TypeError): pass diff --git a/src/art/utils/format_message.py b/src/art/utils/format_message.py index 874eb6ec9..fd3fe20dc 100644 --- a/src/art/utils/format_message.py +++ b/src/art/utils/format_message.py @@ -14,7 +14,7 @@ def format_message(message: Message) -> str: # Format any tool calls tool_calls_text = "\n" if content else "" tool_calls_text += "\n".join( - f"{tool_call['function']['name']}({tool_call['function']['arguments']})" + f"{tool_call['function']['name']}({tool_call['function']['arguments']})" # ty:ignore[invalid-key] for tool_call in message.get("tool_calls") or [] if "function" in tool_call ) diff --git a/src/art/utils/litellm.py b/src/art/utils/litellm.py index 481be5fbb..d65178d24 100644 --- a/src/art/utils/litellm.py +++ b/src/art/utils/litellm.py @@ -40,7 +40,7 @@ def convert_litellm_choice_to_openai( openai_tool_calls.append( ChatCompletionMessageToolCall( id=tool_call.id, - type=tool_call.type, + type=tool_call.type, # ty:ignore[invalid-argument-type] function=Function( name=tool_call.function.name, arguments=tool_call.function.arguments, diff --git a/src/art/utils/old_benchmarking/generate_line_graphs.py b/src/art/utils/old_benchmarking/generate_line_graphs.py index a0612c19d..182c354b1 100644 --- a/src/art/utils/old_benchmarking/generate_line_graphs.py +++ b/src/art/utils/old_benchmarking/generate_line_graphs.py @@ -56,7 +56,7 @@ def has_all_recorded(model): last_x_global: float | None = None for model in line_graph_models: if x_axis_metric == "time": - from matplotlib import dates as mdates # type: ignore + from matplotlib import dates as mdates x_values_float = [ float(mdates.date2num(step.recorded_at or datetime.min)) diff --git a/src/art/utils/retry.py b/src/art/utils/retry.py index 51c22c548..975e53e07 100644 --- a/src/art/utils/retry.py +++ b/src/art/utils/retry.py @@ -61,7 +61,7 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: on_retry(e, attempt) else: logging.warning( - f"Retry {attempt}/{max_attempts} for {func.__name__} " + f"Retry {attempt}/{max_attempts} for {func.__name__} " # ty:ignore[unresolved-attribute] f"after error: {str(e)}" ) @@ -92,7 +92,7 @@ def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: on_retry(e, attempt) else: logging.warning( - f"Retry {attempt}/{max_attempts} for {func.__name__} " + f"Retry {attempt}/{max_attempts} for {func.__name__} " # ty:ignore[unresolved-attribute] f"after error: {str(e)}" ) diff --git a/src/art/utils/trajectory_migration.py b/src/art/utils/trajectory_migration.py index 3b4a14a8a..2cdf200cf 100644 --- a/src/art/utils/trajectory_migration.py +++ b/src/art/utils/trajectory_migration.py @@ -81,17 +81,17 @@ def message_or_choice_to_dict(message_or_choice: MessageOrChoice) -> dict[str, A item_dict = ( message_or_choice if isinstance(message_or_choice, dict) - else message_or_choice.to_dict() + else message_or_choice.to_dict() # ty:ignore[possibly-missing-attribute] ) if "logprobs" in item_dict: # item is a choice with logprobs, remove the logprobs - item_dict.pop("logprobs") + item_dict.pop("logprobs") # ty:ignore[invalid-argument-type] if "content" in item_dict and isinstance(item_dict["content"], Iterator): item_dict["content"] = list(item_dict["content"]) # type: ignore - return dict(item_dict) + return dict(item_dict) # ty:ignore[no-matching-overload] def deserialize_trajectory_groups(serialized: str) -> list[TrajectoryGroup]: diff --git a/src/art/vllm/engine.py b/src/art/vllm/engine.py index 46843d003..c8da5c55b 100644 --- a/src/art/vllm/engine.py +++ b/src/art/vllm/engine.py @@ -14,7 +14,7 @@ from vllm.v1.worker.gpu_worker import Worker -async def get_llm(args: vllm.AsyncEngineArgs) -> AsyncLLM: +async def get_llm(args: vllm.AsyncEngineArgs) -> AsyncLLM: # ty:ignore[unresolved-attribute] """ Create an AsyncLLM engine with model download and patches applied. diff --git a/src/art/vllm/patches.py b/src/art/vllm/patches.py index 54d2eba89..6bd28bf9f 100644 --- a/src/art/vllm/patches.py +++ b/src/art/vllm/patches.py @@ -11,12 +11,12 @@ def subclass_chat_completion_request() -> None: class ChatCompletionRequest(vllm.entrypoints.openai.protocol.ChatCompletionRequest): def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # ty:ignore[invalid-argument-type] self.logprobs = True if self.top_logprobs is None: self.top_logprobs = 0 - vllm.entrypoints.openai.protocol.ChatCompletionRequest = ChatCompletionRequest + vllm.entrypoints.openai.protocol.ChatCompletionRequest = ChatCompletionRequest # ty:ignore[invalid-assignment] def patch_listen_for_disconnect() -> None: @@ -32,7 +32,7 @@ async def patched_listen_for_disconnect(request): # Replace the original function import vllm.entrypoints.utils - vllm.entrypoints.utils.listen_for_disconnect = patched_listen_for_disconnect + vllm.entrypoints.utils.listen_for_disconnect = patched_listen_for_disconnect # ty:ignore[invalid-assignment] def patch_tool_parser_manager() -> None: @@ -54,7 +54,7 @@ def patch( ) -> Any: return original(*args, **kwargs) or DeltaMessage() - tool_parser_class.extract_tool_calls_streaming = patch + tool_parser_class.extract_tool_calls_streaming = patch # ty:ignore[invalid-assignment] return tool_parser_class - ToolParserManager.get_tool_parser = patched_get_tool_parser + ToolParserManager.get_tool_parser = patched_get_tool_parser # ty:ignore[invalid-assignment] diff --git a/src/art/vllm/server.py b/src/art/vllm/server.py index 20f62db4d..0131eaae3 100644 --- a/src/art/vllm/server.py +++ b/src/art/vllm/server.py @@ -58,7 +58,7 @@ def _init(self, *args: Any, **kwargs: Any) -> None: global _openai_serving_models _openai_serving_models = self - serving_models.OpenAIServingModels.__init__ = _init + serving_models.OpenAIServingModels.__init__ = _init # ty:ignore[invalid-assignment] patch_listen_for_disconnect() patch_tool_parser_manager() @@ -86,7 +86,7 @@ async def _add_lora(lora_request) -> bool: _openai_serving_models.lora_requests[lora_request.lora_name] = lora_request return added - engine.add_lora = _add_lora + engine.add_lora = _add_lora # ty:ignore[invalid-assignment] @asynccontextmanager async def build_async_engine_client( diff --git a/tests/integration.py b/tests/integration.py index 43a60c09a..3ed8070c5 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -331,7 +331,7 @@ def main() -> int: processed_config = config.copy() # Resolve path relative to this file - p = (here / processed_config["path"]).resolve() + p = (here / processed_config["path"]).resolve() # ty:ignore[unsupported-operator] if not p.exists(): print(f"Warning: notebook not found: {p}") processed_config["path"] = str(p) diff --git a/tests/integration/test_multi_checkpoint_training.py b/tests/integration/test_multi_checkpoint_training.py index 38c3c3c9f..9252f59a4 100644 --- a/tests/integration/test_multi_checkpoint_training.py +++ b/tests/integration/test_multi_checkpoint_training.py @@ -83,7 +83,7 @@ async def run_training_loop( ] ) for prompt in prompts - ] + ] # ty:ignore[invalid-argument-type] ) result = await backend.train(model, train_groups, learning_rate=1e-5) await model.log( diff --git a/tests/test_backend_train_api.py b/tests/test_backend_train_api.py index bc9551175..b24200ae1 100644 --- a/tests/test_backend_train_api.py +++ b/tests/test_backend_train_api.py @@ -78,7 +78,7 @@ async def main(): ] ) for prompt in prompts - ] + ] # ty:ignore[invalid-argument-type] ) print(f" ✓ Gathered {len(train_groups)} trajectory groups") diff --git a/tests/unit/test_auto_trajectory.py b/tests/unit/test_auto_trajectory.py index d70243bd7..fadd6ab98 100644 --- a/tests/unit/test_auto_trajectory.py +++ b/tests/unit/test_auto_trajectory.py @@ -181,7 +181,7 @@ } ], }, - } + } # ty:ignore[invalid-argument-type] ) @@ -288,9 +288,9 @@ async def say_hi() -> str | None: trajectory = await art.capture_auto_trajectory(say_hi()) assert trajectory.messages_and_choices == [ message, - Choice(**mock_response["choices"][0]), + Choice(**mock_response["choices"][0]), # ty:ignore[invalid-argument-type, not-subscriptable] message, - Choice(**mock_response["choices"][0]), + Choice(**mock_response["choices"][0]), # ty:ignore[invalid-argument-type, not-subscriptable] ] assert trajectory.tools == tools assert trajectory.additional_histories[0].messages_and_choices == [ @@ -305,12 +305,12 @@ async def say_hi() -> str | None: "role": "assistant", }, message, - Choice(**mock_response["choices"][0]), + Choice(**mock_response["choices"][0]), # ty:ignore[invalid-argument-type, not-subscriptable] ] assert trajectory.additional_histories[0].tools is None assert trajectory.additional_histories[1].messages_and_choices == [ message, - Choice(**mock_response["choices"][0]), + Choice(**mock_response["choices"][0]), # ty:ignore[invalid-argument-type, not-subscriptable] ] assert trajectory.additional_histories[1].tools == tools assert trajectory.additional_histories[2].messages_and_choices == [ @@ -387,9 +387,9 @@ async def say_hi() -> str | None: trajectory = await art.capture_auto_trajectory(say_hi()) assert trajectory.messages_and_choices == [ message, - Choice(**mock_response["choices"][0]), + Choice(**mock_response["choices"][0]), # ty:ignore[invalid-argument-type, not-subscriptable] message, - Choice(**mock_response["choices"][0]), + Choice(**mock_response["choices"][0]), # ty:ignore[invalid-argument-type, not-subscriptable] ] assert trajectory.additional_histories[0].messages_and_choices == [ message, diff --git a/tests/unit/test_tokenize_trajectory_groups.ipynb b/tests/unit/test_tokenize_trajectory_groups.ipynb index 0628de3ee..ed488855d 100644 --- a/tests/unit/test_tokenize_trajectory_groups.ipynb +++ b/tests/unit/test_tokenize_trajectory_groups.ipynb @@ -126,7 +126,7 @@ " result.weight = round(result.weight, 2)\n", " # set prompt_id to 0 to eliminate stochasticity\n", " result.prompt_id = 0\n", - " display(result)" + " display(result) # ty:ignore[unresolved-reference]" ] } ], diff --git a/tests/unit/test_trajectory_parquet.py b/tests/unit/test_trajectory_parquet.py index dbcef3bbb..597d93e71 100644 --- a/tests/unit/test_trajectory_parquet.py +++ b/tests/unit/test_trajectory_parquet.py @@ -58,7 +58,7 @@ def _ensure_message(item: MessageOrChoice) -> ChatCompletionMessageParam: """Narrow a trajectory entry to a concrete message (not a Choice).""" assert not isinstance(item, Choice) - return cast(ChatCompletionMessageParam, item) + return cast(ChatCompletionMessageParam, item) # ty:ignore[redundant-cast] def _ensure_assistant_message( @@ -66,19 +66,19 @@ def _ensure_assistant_message( ) -> ChatCompletionAssistantMessageParam: msg = _ensure_message(item) assert msg["role"] == "assistant" - return cast(ChatCompletionAssistantMessageParam, msg) + return cast(ChatCompletionAssistantMessageParam, msg) # ty:ignore[redundant-cast] def _ensure_tool_message(item: MessageOrChoice) -> ChatCompletionToolMessageParam: msg = _ensure_message(item) assert msg["role"] == "tool" - return cast(ChatCompletionToolMessageParam, msg) + return cast(ChatCompletionToolMessageParam, msg) # ty:ignore[redundant-cast] def _ensure_user_message(item: MessageOrChoice) -> ChatCompletionUserMessageParam: msg = _ensure_message(item) assert msg["role"] == "user" - return cast(ChatCompletionUserMessageParam, msg) + return cast(ChatCompletionUserMessageParam, msg) # ty:ignore[redundant-cast] class TestParquetRoundTrip: @@ -173,7 +173,7 @@ def test_tool_calls(self, tmp_path: Path): assert tool_calls, "Assistant message should include tool calls" first_call = tool_calls[0] assert first_call["type"] == "function" - function_call = cast(ChatCompletionMessageFunctionToolCallParam, first_call) + function_call = cast(ChatCompletionMessageFunctionToolCallParam, first_call) # ty:ignore[redundant-cast] assert function_call["function"]["name"] == "search" # Check tool result message diff --git a/tests/unit/test_yield_trajectory.py b/tests/unit/test_yield_trajectory.py index c7114051b..dbc6a3d94 100644 --- a/tests/unit/test_yield_trajectory.py +++ b/tests/unit/test_yield_trajectory.py @@ -150,5 +150,5 @@ async def say_hi() -> str | None: trajectory = await art.capture_yielded_trajectory(say_hi()) assert trajectory.messages_and_choices == [ {"role": "user", "content": "Hi!"}, - Choice(**mock_response["choices"][0]), + Choice(**mock_response["choices"][0]), # ty:ignore[invalid-argument-type, not-subscriptable] ] diff --git a/uv.lock b/uv.lock index 647914fd8..1e73e5e41 100644 --- a/uv.lock +++ b/uv.lock @@ -4703,31 +4703,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/60/90/81ac364ef94209c100e12579629dc92bf7a709a84af32f8c551b02c07e94/nltk-3.9.2-py3-none-any.whl", hash = "sha256:1e209d2b3009110635ed9709a67a1a3e33a10f799490fa71cf4bec218c11c88a", size = 1513404, upload-time = "2025-10-01T07:19:21.648Z" }, ] -[[package]] -name = "nodeenv" -version = "1.10.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, -] - -[[package]] -name = "nodejs-wheel-binaries" -version = "24.12.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b9/35/d806c2ca66072e36dc340ccdbeb2af7e4f1b5bcc33f1481f00ceed476708/nodejs_wheel_binaries-24.12.0.tar.gz", hash = "sha256:f1b50aa25375e264697dec04b232474906b997c2630c8f499f4caf3692938435", size = 8058, upload-time = "2025-12-11T21:12:26.856Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/3b/9d6f044319cd5b1e98f07c41e2465b58cadc1c9c04a74c891578f3be6cb5/nodejs_wheel_binaries-24.12.0-py2.py3-none-macosx_13_0_arm64.whl", hash = "sha256:7564ddea0a87eff34e9b3ef71764cc2a476a8f09a5cccfddc4691148b0a47338", size = 55125859, upload-time = "2025-12-11T21:11:58.132Z" }, - { url = "https://files.pythonhosted.org/packages/48/a5/f5722bf15c014e2f476d7c76bce3d55c341d19122d8a5d86454db32a61a4/nodejs_wheel_binaries-24.12.0-py2.py3-none-macosx_13_0_x86_64.whl", hash = "sha256:8ff929c4669e64613ceb07f5bbd758d528c3563820c75d5de3249eb452c0c0ab", size = 55309035, upload-time = "2025-12-11T21:12:01.754Z" }, - { url = "https://files.pythonhosted.org/packages/a9/61/68d39a6f1b5df67805969fd2829ba7e80696c9af19537856ec912050a2be/nodejs_wheel_binaries-24.12.0-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:6ebacefa8891bc456ad3655e6bce0af7e20ba08662f79d9109986faeb703fd6f", size = 59661017, upload-time = "2025-12-11T21:12:05.268Z" }, - { url = "https://files.pythonhosted.org/packages/16/a1/31aad16f55a5e44ca7ea62d1367fc69f4b6e1dba67f58a0a41d0ed854540/nodejs_wheel_binaries-24.12.0-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:3292649a03682ccbfa47f7b04d3e4240e8c46ef04dc941b708f20e4e6a764f75", size = 60159770, upload-time = "2025-12-11T21:12:08.696Z" }, - { url = "https://files.pythonhosted.org/packages/c4/5e/b7c569aa1862690ca4d4daf3a64cafa1ea6ce667a9e3ae3918c56e127d9b/nodejs_wheel_binaries-24.12.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7fb83df312955ea355ba7f8cbd7055c477249a131d3cb43b60e4aeb8f8c730b1", size = 61653561, upload-time = "2025-12-11T21:12:12.575Z" }, - { url = "https://files.pythonhosted.org/packages/71/87/567f58d7ba69ff0208be849b37be0f2c2e99c69e49334edd45ff44f00043/nodejs_wheel_binaries-24.12.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2473c819448fedd7b036dde236b09f3c8bbf39fbbd0c1068790a0498800f498b", size = 62238331, upload-time = "2025-12-11T21:12:16.143Z" }, - { url = "https://files.pythonhosted.org/packages/6a/9d/c6492188ce8de90093c6755a4a63bb6b2b4efb17094cb4f9a9a49c73ed3b/nodejs_wheel_binaries-24.12.0-py2.py3-none-win_amd64.whl", hash = "sha256:2090d59f75a68079fabc9b86b14df8238b9aecb9577966dc142ce2a23a32e9bb", size = 41342076, upload-time = "2025-12-11T21:12:20.618Z" }, - { url = "https://files.pythonhosted.org/packages/df/af/cd3290a647df567645353feed451ef4feaf5844496ced69c4dcb84295ff4/nodejs_wheel_binaries-24.12.0-py2.py3-none-win_arm64.whl", hash = "sha256:d0c2273b667dd7e3f55e369c0085957b702144b1b04bfceb7ce2411e58333757", size = 39048104, upload-time = "2025-12-11T21:12:23.495Z" }, -] - [[package]] name = "numba" version = "0.61.2" @@ -5112,11 +5087,11 @@ dev = [ { name = "nbval" }, { name = "prek" }, { name = "pyarrow" }, - { name = "pyright", extra = ["nodejs"] }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-xdist" }, { name = "ruff" }, + { name = "ty" }, ] [package.metadata] @@ -5171,11 +5146,11 @@ dev = [ { name = "nbval", specifier = ">=0.11.0" }, { name = "prek", specifier = ">=0.2.29" }, { name = "pyarrow", specifier = ">=15.0.0" }, - { name = "pyright", extras = ["nodejs"], specifier = ">=1.1.403" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=1.1.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "ruff", specifier = ">=0.12.1" }, + { name = "ty", specifier = ">=0.0.14" }, ] [[package]] @@ -6658,24 +6633,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" }, ] -[[package]] -name = "pyright" -version = "1.1.407" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nodeenv" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a6/1b/0aa08ee42948b61745ac5b5b5ccaec4669e8884b53d31c8ec20b2fcd6b6f/pyright-1.1.407.tar.gz", hash = "sha256:099674dba5c10489832d4a4b2d302636152a9a42d317986c38474c76fe562262", size = 4122872, upload-time = "2025-10-24T23:17:15.145Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/93/b69052907d032b00c40cb656d21438ec00b3a471733de137a3f65a49a0a0/pyright-1.1.407-py3-none-any.whl", hash = "sha256:6dd419f54fcc13f03b52285796d65e639786373f433e243f8b94cf93a7444d21", size = 5997008, upload-time = "2025-10-24T23:17:13.159Z" }, -] - -[package.optional-dependencies] -nodejs = [ - { name = "nodejs-wheel-binaries" }, -] - [[package]] name = "pytest" version = "9.0.2" @@ -8648,6 +8605,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/7e/bc19996fa86cad8801e8ffe6f1bba5836ca0160df76d0410d27432193712/trove_classifiers-2025.12.1.14-py3-none-any.whl", hash = "sha256:a8206978ede95937b9959c3aff3eb258bbf7b07dff391ddd4ea7e61f316635ab", size = 14184, upload-time = "2025-12-01T14:47:10.113Z" }, ] +[[package]] +name = "ty" +version = "0.0.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/57/22c3d6bf95c2229120c49ffc2f0da8d9e8823755a1c3194da56e51f1cc31/ty-0.0.14.tar.gz", hash = "sha256:a691010565f59dd7f15cf324cdcd1d9065e010c77a04f887e1ea070ba34a7de2", size = 5036573, upload-time = "2026-01-27T00:57:31.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/cb/cc6d1d8de59beb17a41f9a614585f884ec2d95450306c173b3b7cc090d2e/ty-0.0.14-py3-none-linux_armv6l.whl", hash = "sha256:32cf2a7596e693094621d3ae568d7ee16707dce28c34d1762947874060fdddaa", size = 10034228, upload-time = "2026-01-27T00:57:53.133Z" }, + { url = "https://files.pythonhosted.org/packages/f3/96/dd42816a2075a8f31542296ae687483a8d047f86a6538dfba573223eaf9a/ty-0.0.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f971bf9805f49ce8c0968ad53e29624d80b970b9eb597b7cbaba25d8a18ce9a2", size = 9939162, upload-time = "2026-01-27T00:57:43.857Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b4/73c4859004e0f0a9eead9ecb67021438b2e8e5fdd8d03e7f5aca77623992/ty-0.0.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:45448b9e4806423523268bc15e9208c4f3f2ead7c344f615549d2e2354d6e924", size = 9418661, upload-time = "2026-01-27T00:58:03.411Z" }, + { url = "https://files.pythonhosted.org/packages/58/35/839c4551b94613db4afa20ee555dd4f33bfa7352d5da74c5fa416ffa0fd2/ty-0.0.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94a9b747ff40114085206bdb3205a631ef19a4d3fb89e302a88754cbbae54c", size = 9837872, upload-time = "2026-01-27T00:57:23.718Z" }, + { url = "https://files.pythonhosted.org/packages/41/2b/bbecf7e2faa20c04bebd35fc478668953ca50ee5847ce23e08acf20ea119/ty-0.0.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6756715a3c33182e9ab8ffca2bb314d3c99b9c410b171736e145773ee0ae41c3", size = 9848819, upload-time = "2026-01-27T00:57:58.501Z" }, + { url = "https://files.pythonhosted.org/packages/be/60/3c0ba0f19c0f647ad9d2b5b5ac68c0f0b4dc899001bd53b3a7537fb247a2/ty-0.0.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89d0038a2f698ba8b6fec5cf216a4e44e2f95e4a5095a8c0f57fe549f87087c2", size = 10324371, upload-time = "2026-01-27T00:57:29.291Z" }, + { url = "https://files.pythonhosted.org/packages/24/32/99d0a0b37d0397b0a989ffc2682493286aa3bc252b24004a6714368c2c3d/ty-0.0.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c64a83a2d669b77f50a4957039ca1450626fb474619f18f6f8a3eb885bf7544", size = 10865898, upload-time = "2026-01-27T00:57:33.542Z" }, + { url = "https://files.pythonhosted.org/packages/1a/88/30b583a9e0311bb474269cfa91db53350557ebec09002bfc3fb3fc364e8c/ty-0.0.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:242488bfb547ef080199f6fd81369ab9cb638a778bb161511d091ffd49c12129", size = 10555777, upload-time = "2026-01-27T00:58:05.853Z" }, + { url = "https://files.pythonhosted.org/packages/cd/a2/cb53fb6325dcf3d40f2b1d0457a25d55bfbae633c8e337bde8ec01a190eb/ty-0.0.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4790c3866f6c83a4f424fc7d09ebdb225c1f1131647ba8bdc6fcdc28f09ed0ff", size = 10412913, upload-time = "2026-01-27T00:57:38.834Z" }, + { url = "https://files.pythonhosted.org/packages/42/8f/f2f5202d725ed1e6a4e5ffaa32b190a1fe70c0b1a2503d38515da4130b4c/ty-0.0.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:950f320437f96d4ea9a2332bbfb5b68f1c1acd269ebfa4c09b6970cc1565bd9d", size = 9837608, upload-time = "2026-01-27T00:57:55.898Z" }, + { url = "https://files.pythonhosted.org/packages/f7/ba/59a2a0521640c489dafa2c546ae1f8465f92956fede18660653cce73b4c5/ty-0.0.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4a0ec3ee70d83887f86925bbc1c56f4628bd58a0f47f6f32ddfe04e1f05466df", size = 9884324, upload-time = "2026-01-27T00:57:46.786Z" }, + { url = "https://files.pythonhosted.org/packages/03/95/8d2a49880f47b638743212f011088552ecc454dd7a665ddcbdabea25772a/ty-0.0.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a1a4e6b6da0c58b34415955279eff754d6206b35af56a18bb70eb519d8d139ef", size = 10033537, upload-time = "2026-01-27T00:58:01.149Z" }, + { url = "https://files.pythonhosted.org/packages/e9/40/4523b36f2ce69f92ccf783855a9e0ebbbd0f0bb5cdce6211ee1737159ed3/ty-0.0.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dc04384e874c5de4c5d743369c277c8aa73d1edea3c7fc646b2064b637db4db3", size = 10495910, upload-time = "2026-01-27T00:57:26.691Z" }, + { url = "https://files.pythonhosted.org/packages/08/d5/655beb51224d1bfd4f9ddc0bb209659bfe71ff141bcf05c418ab670698f0/ty-0.0.14-py3-none-win32.whl", hash = "sha256:b20e22cf54c66b3e37e87377635da412d9a552c9bf4ad9fc449fed8b2e19dad2", size = 9507626, upload-time = "2026-01-27T00:57:41.43Z" }, + { url = "https://files.pythonhosted.org/packages/b6/d9/c569c9961760e20e0a4bc008eeb1415754564304fd53997a371b7cf3f864/ty-0.0.14-py3-none-win_amd64.whl", hash = "sha256:e312ff9475522d1a33186657fe74d1ec98e4a13e016d66f5758a452c90ff6409", size = 10437980, upload-time = "2026-01-27T00:57:36.422Z" }, + { url = "https://files.pythonhosted.org/packages/ad/0c/186829654f5bfd9a028f6648e9caeb11271960a61de97484627d24443f91/ty-0.0.14-py3-none-win_arm64.whl", hash = "sha256:b6facdbe9b740cb2c15293a1d178e22ffc600653646452632541d01c36d5e378", size = 9885831, upload-time = "2026-01-27T00:57:49.747Z" }, +] + [[package]] name = "typeguard" version = "4.4.4"