Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/runpod_flash/core/resources/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,17 @@ def to_gpu_ids_str(cls, gpu_types: List[GpuType | GpuGroup]) -> str:
out.append(pool_id.value)
else:
out.append(str(pool_id))
return ",".join(out)
return cls.normalize_gpu_ids_str(",".join(out))

@classmethod
def normalize_gpu_ids_str(cls, gpu_ids_str: str) -> str:
"""Return canonical gpuIds string for stable comparisons and hashing."""
tokens = {token.strip() for token in gpu_ids_str.split(",") if token.strip()}
ordered_tokens = sorted(
tokens,
key=lambda token: (token.startswith("-"), token.lstrip("-").lower()),
)
return ",".join(ordered_tokens)

@classmethod
def from_gpu_ids_str(cls, gpu_ids_str: str) -> List[GpuGroup | GpuType]:
Expand Down
3 changes: 3 additions & 0 deletions src/runpod_flash/core/resources/serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,9 @@ async def _sync_graphql_object_with_inputs(
return returned_endpoint

def _sync_input_fields_gpu(self):
if self.gpuIds:
self.gpuIds = GpuGroup.normalize_gpu_ids_str(self.gpuIds)

# GPU-specific fields (idempotent - only set if not already set)
if self.gpus and not self.gpuIds:
# Convert gpus list to gpuIds string
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/resources/test_gpu_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ def test_b200_type_maps_to_blackwell_180(self):
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
assert parsed == [GpuGroup.BLACKWELL_180]

def test_to_gpu_ids_str_is_order_stable(self):
first = GpuGroup.to_gpu_ids_str([GpuGroup.AMPERE_24, GpuType.NVIDIA_RTX_A5000])
second = GpuGroup.to_gpu_ids_str([GpuType.NVIDIA_RTX_A5000, GpuGroup.AMPERE_24])

assert first == second

def test_normalize_gpu_ids_str_sorts_and_deduplicates_tokens(self):
normalized = GpuGroup.normalize_gpu_ids_str(
"-NVIDIA GeForce RTX 3090,AMPERE_24,AMPERE_24,NVIDIA L4"
)

assert normalized == "AMPERE_24,NVIDIA L4,-NVIDIA GeForce RTX 3090"

def test_rtx_pro_6000_type_maps_to_blackwell_96(self):
gpu_ids = GpuGroup.to_gpu_ids_str(
[GpuType.NVIDIA_RTX_PRO_6000_BLACKWELL_SERVER_EDITION]
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/resources/test_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,14 @@ def test_reverse_sync_gpuids_to_gpus(self):
assert GpuGroup.AMPERE_48 in serverless.gpus
assert GpuGroup.AMPERE_24 in serverless.gpus

def test_sync_input_fields_normalizes_gpuids_order(self):
serverless = ServerlessResource(
name="test",
gpuIds="-NVIDIA GeForce RTX 3090,AMPERE_24,NVIDIA L4",
)

assert serverless.gpuIds == "AMPERE_24,NVIDIA L4,-NVIDIA GeForce RTX 3090"

def test_reverse_sync_cuda_versions(self):
"""Test reverse sync from allowedCudaVersions string to cudaVersions list."""
serverless = ServerlessResource(
Expand Down
Loading