66
77from arraybridge .converters_registry import get_converter
88from arraybridge .framework_config import _FRAMEWORK_CONFIG
9- from arraybridge .types import MemoryType
9+ from arraybridge .types import MemoryType , VALID_MEMORY_TYPES
1010
1111
1212def convert_memory (data : Any , source_type : str , target_type : str , gpu_id : int ) -> Any :
@@ -26,6 +26,13 @@ def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -
2626 ValueError: If source_type or target_type is invalid
2727 MemoryConversionError: If conversion fails
2828 """
29+ if isinstance (target_type , MemoryType ):
30+ target_type = target_type .value
31+ if target_type not in VALID_MEMORY_TYPES :
32+ raise ValueError (
33+ f"Invalid target_type '{ target_type } '. Available types: { sorted (VALID_MEMORY_TYPES )} "
34+ )
35+
2936 converter = get_converter (source_type ) # Will raise ValueError if invalid
3037 method = getattr (converter , f"to_{ target_type } " )
3138 return method (data , gpu_id )
@@ -51,10 +58,14 @@ def detect_memory_type(data: Any) -> str:
5158 # Check all frameworks using their module names from config
5259 module_name = type (data ).__module__
5360
61+ top_level = module_name .split ("." )[0 ]
62+
5463 for mem_type , config in _FRAMEWORK_CONFIG .items ():
5564 import_name = config ["import_name" ]
56- # Check if module name starts with or contains the import name
57- if module_name .startswith (import_name ) or import_name in module_name :
65+ aliases = {import_name }
66+ if import_name == "jax" :
67+ aliases .add ("jaxlib" )
68+ if top_level in aliases :
5869 return mem_type .value
5970
6071 raise ValueError (f"Unknown memory type for { type (data )} (module: { module_name } )" )
0 commit comments