diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index 2ce8c53c844..2f181d7740a 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -40,7 +40,7 @@ from tqdm import tqdm from fastdeploy import envs -from fastdeploy.config import FDConfig, LoadConfig +from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.linear import KVBatchLinear from fastdeploy.model_executor.utils import multi_switch_config_context @@ -72,6 +72,11 @@ def layers_are_grouped(keys): return True +def values_are_naturally_ordered(values): + """Check if values are sorted in natural order.""" + return list(values) == sorted(values, key=natural_key) + + def pdparams_weight_iterator(paddle_file_list: list[str]): for pdparams_file in tqdm( paddle_file_list, @@ -117,10 +122,12 @@ def get_model_path(fd_config: FDConfig): return model_path -def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None): +def get_weight_iterator(model_path: str, fd_config: Optional[FDConfig] = None): files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path) if use_safetensors: + load_config = fd_config.load_config if fd_config else None extra_config = load_config.model_loader_extra_config if load_config else None + parallel_config = fd_config.parallel_config if fd_config else None if extra_config is not None and extra_config.get("enable_multithread_load", False): weights_iterator = multi_thread_safetensors_weights_iterator( files_list, @@ -128,7 +135,7 @@ def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = Non disable_mmap=extra_config.get("disable_mmap", False), ) else: - if is_layers_are_grouped: + if is_layers_are_grouped or (parallel_config is not None and parallel_config.tensor_parallel_size == 1): weights_iterator = safetensors_weights_iterator(files_list) else: weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map) @@ -532,7 +539,10 @@ def get_all_weights_file(model_path: str): with index_file.open("r") as f: weight_map = json.load(f)["weight_map"] keys = list(weight_map.keys()) - is_layers_are_grouped = layers_are_grouped(keys) + values = list(weight_map.values()) + is_keys_orders = layers_are_grouped(keys) + is_values_naturally_ordered = values_are_naturally_ordered(values) + is_layers_are_grouped = is_keys_orders and is_values_naturally_ordered ordered_weight_map = { key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key) } diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index 1217f9de28b..ab24323f305 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ -57,7 +57,7 @@ def clean_memory_fragments(self) -> None: @measure_time() def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None: model_path = get_model_path(fd_config) - weights_iterator = get_weight_iterator(model_path, fd_config.load_config) + weights_iterator = get_weight_iterator(model_path, fd_config) if enable_cache: load_weights_from_cache(model, weights_iterator) else: