Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions fastdeploy/model_executor/load_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from tqdm import tqdm

from fastdeploy import envs
from fastdeploy.config import FDConfig, LoadConfig
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.utils import multi_switch_config_context

Expand Down Expand Up @@ -72,6 +72,11 @@ def layers_are_grouped(keys):
return True


def values_are_naturally_ordered(values):
"""Check if values are sorted in natural order."""
return list(values) == sorted(values, key=natural_key)


def pdparams_weight_iterator(paddle_file_list: list[str]):
for pdparams_file in tqdm(
paddle_file_list,
Expand Down Expand Up @@ -117,18 +122,20 @@ def get_model_path(fd_config: FDConfig):
return model_path


def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
def get_weight_iterator(model_path: str, fd_config: Optional[FDConfig] = None):
files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
if use_safetensors:
load_config = fd_config.load_config if fd_config else None
extra_config = load_config.model_loader_extra_config if load_config else None
parallel_config = fd_config.parallel_config if fd_config else None
if extra_config is not None and extra_config.get("enable_multithread_load", False):
weights_iterator = multi_thread_safetensors_weights_iterator(
files_list,
max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
disable_mmap=extra_config.get("disable_mmap", False),
)
else:
if is_layers_are_grouped:
if is_layers_are_grouped or (parallel_config is not None and parallel_config.tensor_parallel_size == 1):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 tensor_parallel_size == 1 时强制使用无序迭代器缺少注释说明

is_layers_are_grouped=False(keys 未按层聚合 或 values 未按自然序排列)但 tp_size==1 时,仍然选择 safetensors_weights_iterator(按文件顺序遍历所有 tensor)。这一短路逻辑的正确性依赖于「TP=1 时单 rank 加载完整权重,加载顺序不影响正确性」这一前提,建议补充注释说明,例如:

# When tensor_parallel_size == 1, the rank loads full model weights,
# so tensor loading order does not need to match the weight_map order.
if is_layers_are_grouped or (parallel_config is not None and parallel_config.tensor_parallel_size == 1):

weights_iterator = safetensors_weights_iterator(files_list)
else:
weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
Expand Down Expand Up @@ -532,7 +539,10 @@ def get_all_weights_file(model_path: str):
with index_file.open("r") as f:
weight_map = json.load(f)["weight_map"]
keys = list(weight_map.keys())
is_layers_are_grouped = layers_are_grouped(keys)
values = list(weight_map.values())
is_keys_orders = layers_are_grouped(keys)

This comment was marked as outdated.

is_values_naturally_ordered = values_are_naturally_ordered(values)
is_layers_are_grouped = is_keys_orders and is_values_naturally_ordered
ordered_weight_map = {
key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def clean_memory_fragments(self) -> None:
@measure_time()
def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
model_path = get_model_path(fd_config)
weights_iterator = get_weight_iterator(model_path, fd_config.load_config)
weights_iterator = get_weight_iterator(model_path, fd_config)
if enable_cache:
load_weights_from_cache(model, weights_iterator)
else:
Expand Down
Loading