Skip to content

Commit d0a249c

Browse files
committed
Tighten memory detection and target validation
1 parent 2b2daf7 commit d0a249c

File tree

3 files changed

+15
-16
lines changed

3 files changed

+15
-16
lines changed

src/arraybridge/converters.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from arraybridge.converters_registry import get_converter
88
from arraybridge.framework_config import _FRAMEWORK_CONFIG
9-
from arraybridge.types import MemoryType
9+
from arraybridge.types import MemoryType, VALID_MEMORY_TYPES
1010

1111

1212
def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -> Any:
@@ -26,6 +26,13 @@ def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -
2626
ValueError: If source_type or target_type is invalid
2727
MemoryConversionError: If conversion fails
2828
"""
29+
if isinstance(target_type, MemoryType):
30+
target_type = target_type.value
31+
if target_type not in VALID_MEMORY_TYPES:
32+
raise ValueError(
33+
f"Invalid target_type '{target_type}'. Available types: {sorted(VALID_MEMORY_TYPES)}"
34+
)
35+
2936
converter = get_converter(source_type) # Will raise ValueError if invalid
3037
method = getattr(converter, f"to_{target_type}")
3138
return method(data, gpu_id)
@@ -51,10 +58,14 @@ def detect_memory_type(data: Any) -> str:
5158
# Check all frameworks using their module names from config
5259
module_name = type(data).__module__
5360

61+
top_level = module_name.split(".")[0]
62+
5463
for mem_type, config in _FRAMEWORK_CONFIG.items():
5564
import_name = config["import_name"]
56-
# Check if module name starts with or contains the import name
57-
if module_name.startswith(import_name) or import_name in module_name:
65+
aliases = {import_name}
66+
if import_name == "jax":
67+
aliases.add("jaxlib")
68+
if top_level in aliases:
5869
return mem_type.value
5970

6071
raise ValueError(f"Unknown memory type for {type(data)} (module: {module_name})")

src/arraybridge/framework_config.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,6 @@ def _tensorflow_validate_dlpack(obj: Any, mod: Any) -> bool:
120120
f"Clause 88 violation: Cannot infer DLPack capability."
121121
)
122122

123-
# Check GPU
124-
"""Validate TensorFlow DLPack support."""
125-
# Check version
126-
major, minor = map(int, mod.__version__.split(".")[:2])
127-
if major < 2 or (major == 2 and minor < 12):
128-
raise RuntimeError(
129-
f"TensorFlow {mod.__version__} does not support stable DLPack. "
130-
f"Version 2.12.0+ required. "
131-
f"Clause 88 violation: Cannot infer DLPack capability."
132-
)
133-
134123
# Check GPU
135124
device_str = obj.device.lower()
136125
if "gpu" not in device_str:

tests/test_converters.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ def test_convert_invalid_source_type(self):
8585
def test_convert_invalid_target_type(self):
8686
"""Test that invalid target type raises error."""
8787
arr = np.array([1, 2, 3])
88-
# This might raise ValueError or AttributeError depending on implementation
89-
with pytest.raises((ValueError, AttributeError)):
88+
with pytest.raises(ValueError):
9089
convert_memory(arr, source_type="numpy", target_type="invalid_type", gpu_id=0)
9190

9291
@pytest.mark.torch

0 commit comments

Comments
 (0)