From 3c6a0f2f911f8b0adfcaf312d6089e59a8034d6a Mon Sep 17 00:00:00 2001 From: rubik Date: Thu, 14 May 2026 11:16:13 +0800 Subject: [PATCH] =?UTF-8?q?issue/350=20-=20Support=20ChatGLM=20model=201?= =?UTF-8?q?=E3=80=81attention=E5=B1=82=E4=B8=8EGLM4=E4=B8=80=E8=87=B4?= =?UTF-8?q?=EF=BC=8C=E5=A4=8D=E7=94=A8=E3=80=82=202=E3=80=81decoder?= =?UTF-8?q?=E4=B8=8E=E6=A0=87=E5=87=86llama=E4=B8=80=E6=A0=B7=203=E3=80=81?= =?UTF-8?q?=E5=90=84=E4=B8=AAlayer=E5=B1=82=E5=90=8D=E5=AD=97=E4=B8=8Ellam?= =?UTF-8?q?a=E8=BF=9B=E8=A1=8C=E6=98=A0=E5=B0=84=204=E3=80=81=E4=BF=AE?= =?UTF-8?q?=E5=A4=8Dexamples/test=5Finfer.py=E4=B8=AD=E7=9A=84batch?= =?UTF-8?q?=E5=A4=84=E7=90=86=205=E3=80=81=E4=BF=AE=E5=A4=8Dbench=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E6=97=B6=E7=9A=84key=E5=90=8D=E5=AD=97=E4=B8=8D?= =?UTF-8?q?=E4=B8=80=E8=87=B4=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- csrc/models/chatglm/chatglm_for_causal_lm.cpp | 51 +++++ csrc/models/chatglm/chatglm_for_causal_lm.hpp | 23 +++ examples/bench.py | 20 ++ examples/test_infer.py | 2 +- python/infinilm/modeling_utils.py | 176 ++++++++++++++---- python/infinilm/processors/__init__.py | 3 + .../infinilm/processors/chatglm_processor.py | 46 +++++ 7 files changed, 285 insertions(+), 36 deletions(-) create mode 100644 csrc/models/chatglm/chatglm_for_causal_lm.cpp create mode 100644 csrc/models/chatglm/chatglm_for_causal_lm.hpp create mode 100644 python/infinilm/processors/chatglm_processor.py diff --git a/csrc/models/chatglm/chatglm_for_causal_lm.cpp b/csrc/models/chatglm/chatglm_for_causal_lm.cpp new file mode 100644 index 00000000..f366528c --- /dev/null +++ b/csrc/models/chatglm/chatglm_for_causal_lm.cpp @@ -0,0 +1,51 @@ +#include "chatglm_for_causal_lm.hpp" +#include "../models_registry.hpp" + +namespace infinilm::models::chatglm { + +std::shared_ptr create_chatglm_model_config( + std::shared_ptr model_config) { + const std::string &model_type = model_config->get("model_type"); + if 
("chatglm" != model_type) { + throw std::runtime_error( + "infinilm::models::chatglm::create_chatglm_model_config: model_type is not chatglm"); + } + + nlohmann::json &config_json = model_config->get_config_json(); + auto rename_key = [&config_json](const std::string &old_key, const std::string &new_key) { + if (config_json.contains(old_key) && !config_json.contains(new_key)) { + config_json[new_key] = config_json[old_key]; + } + }; + + rename_key("num_layers", "num_hidden_layers"); + rename_key("multi_query_group_num", "num_key_value_heads"); + rename_key("kv_channels", "head_dim"); + rename_key("layernorm_epsilon", "rms_norm_eps"); + rename_key("seq_length", "max_position_embeddings"); + rename_key("ffn_hidden_size", "intermediate_size"); + + if (!config_json.contains("vocab_size") && config_json.contains("padded_vocab_size")) { + config_json["vocab_size"] = config_json["padded_vocab_size"]; + } + + if (!config_json.contains("attention_bias")) { + config_json["attention_bias"] = true; + } + + if (!config_json.contains("rope_theta")) { + config_json["rope_theta"] = 10000.0; + } + + return model_config; +} + +} // namespace infinilm::models::chatglm + +namespace { + +INFINILM_REGISTER_CAUSAL_LM_MODEL( + chatglm, + infinilm::models::chatglm::ChatglmForCausalLM, + infinilm::models::chatglm::create_chatglm_model_config); +} // namespace diff --git a/csrc/models/chatglm/chatglm_for_causal_lm.hpp b/csrc/models/chatglm/chatglm_for_causal_lm.hpp new file mode 100644 index 00000000..0d7a4c57 --- /dev/null +++ b/csrc/models/chatglm/chatglm_for_causal_lm.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "../../layers/common_modules.hpp" +#include "../glm4/glm4_attention.hpp" +#include + +namespace infinilm::models::chatglm { + +using ChatglmMLP = infinilm::layers::MLP; + +// Reuse Glm4Attention as ChatGLM and GLM4 share the identical attention layer +using ChatglmAttention = infinilm::models::glm4::Glm4Attention; + +using ChatglmDecoderLayer = 
infinilm::layers::causal_lm_templates::TextDecoderLayer; + +using ChatglmModel = infinilm::layers::causal_lm_templates::TextModel; + +using ChatglmForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM; + +std::shared_ptr create_chatglm_model_config( + std::shared_ptr model_config); + +} // namespace infinilm::models::chatglm diff --git a/examples/bench.py b/examples/bench.py index 6c5e68ef..1c8837a0 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -24,6 +24,24 @@ _PAGED_KV_BLOCK_SIZE = 256 +_CONFIG_KEY_MAP = { + "chatglm": { + "num_key_value_heads": "multi_query_group_num", + "num_hidden_layers": "num_layers", + "head_dim": "kv_channels", + }, +} + +def _normalize_config(config, model_type): + key_map = _CONFIG_KEY_MAP.get(model_type) + if not key_map: + return config + normalized = dict(config) + for std_key, model_key in key_map.items(): + if model_key in normalized: + normalized.setdefault(std_key, normalized[model_key]) + return normalized + # BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128] # INPUT_LENS = [32, 256, 1024, 4096] # OUTPUT_LENS = [256, 1024, 4096] @@ -46,6 +64,8 @@ def get_test_cases( """Generate cases ordered by ascending KV cache memory usage.""" # Load model config to derive attention dimensions config = read_json_file(os.path.join(model_path, "config.json")) + model_type = config.get("model_type", "") + config = _normalize_config(config, model_type) head_dim = config.get( "head_dim", config.get("hidden_size") // config.get("num_attention_heads") ) diff --git a/examples/test_infer.py b/examples/test_infer.py index 97bbdd3a..a3ce5e6d 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -42,7 +42,7 @@ def test( ) conversations = [ - {"role": "user", "content": [{"type": "text", "text": prompt}]} + [{"role": "user", "content": [{"type": "text", "text": prompt}]}] for prompt in prompts ] if image_path is not None: diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index 93b9bf45..36bcc5f0 
100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -197,7 +197,7 @@ def load_model_state_dict_by_file( # Apply model-specific weight remapping remapper = _WEIGHT_REMAPPER.get(model_type) if remapper is not None: - model_param = remapper(model_param) + model_param = remapper(model_param, config=model.hf_config) already_loaded_keys.extend(model_param.keys()) @@ -228,7 +228,7 @@ def load_model_state_dict_by_file( # Apply model-specific weight remapping remapper = _WEIGHT_REMAPPER.get(model_type) if remapper is not None: - model_params = remapper(model_params) + model_params = remapper(model_params, config=model.hf_config) # Scale embed_tokens on torch side before converting if "model.embed_tokens.weight" in model_params: @@ -332,77 +332,143 @@ def load_model_state_dict_by_tensor( t2 = time.time() print(f" load weights over! {(t2 - t1) * 1000} ms \n") - # ============================================================================ # Common weight transformation utilities # ============================================================================ +def drop_keys( + state_dict: Dict[str, torch.Tensor], + substrings: List[str], +) -> Dict[str, torch.Tensor]: + """Drop keys containing any of the given substrings.""" + return { + k: v for k, v in state_dict.items() + if not any(sub in k for sub in substrings) + } + + +def rename_keys( + state_dict: Dict[str, torch.Tensor], + mapping: Dict[str, str], +) -> Dict[str, torch.Tensor]: + """Rename weight keys according to a substring mapping.""" + result = {} + for key, tensor in state_dict.items(): + new_key = key + for old_str, new_str in mapping.items(): + new_key = new_key.replace(old_str, new_str) + result[new_key] = tensor + return result + def split_fused_weight( state_dict: Dict[str, torch.Tensor], fused_key: str, output_names: List[str], split_dim: int = 0, - split_ratios: Optional[List[float]] = None, + split_sizes: Optional[List[int]] = None, ) -> Dict[str, torch.Tensor]: 
"""Split fused weight tensors into separate weights. Args: - state_dict: Original state dict from HuggingFace safetensors. - fused_key: Substring to match in key names (e.g. "gate_up_proj"). - output_names: Names of the split outputs (e.g. ["gate_proj", "up_proj"]). + state_dict: Original state dict. + fused_key: Substring to match in key names (e.g. "query_key_value"). + output_names: Names of the split outputs (e.g. ["q_proj", "k_proj", "v_proj"]). split_dim: Dimension along which to split. Default 0. - split_ratios: Optional ratios. If None, split equally. + split_sizes: Optional explicit sizes for each split. Supports -1 to mean + "the remaining size". If None, split equally. Returns: New state dict with fused keys replaced by split keys. + + Examples: + # Equal 2-way split (e.g. gate_up_proj.weight) + split_fused_weight(sd, "gate_up_proj", ["gate_proj", "up_proj"]) + + # Dynamic 3-way split with bias (e.g. query_key_value.weight + bias) + split_fused_weight(sd, "query_key_value", ["q_proj", "k_proj", "v_proj"], + split_sizes=[q_dim, k_dim, -1]) """ result = {} + marker = f".{fused_key}." 
+ for key, tensor in state_dict.items(): - if fused_key not in key: + if marker not in key: result[key] = tensor continue - base_key = key.replace(f".{fused_key}.weight", "") + # Extract base_key and suffix (handles both .weight and .bias) + base_key, suffix = key.split(marker, 1) dim_size = tensor.shape[split_dim] - num_splits = len(output_names) - if split_ratios is not None: - total_ratio = sum(split_ratios) - sizes = [int(dim_size * r / total_ratio) for r in split_ratios[:-1]] - sizes.append(dim_size - sum(sizes)) + # Calculate split sizes + if split_sizes is not None: + sizes = [] + remainder = dim_size + for s in split_sizes: + if s == -1: + sizes.append(0) # placeholder + else: + sizes.append(s) + remainder -= s + # Fill -1 placeholders with remainder + sizes = [remainder if s == 0 else s for s in sizes] else: + num_splits = len(output_names) chunk = dim_size // num_splits sizes = [chunk] * (num_splits - 1) sizes.append(dim_size - chunk * (num_splits - 1)) splits = torch.split(tensor, sizes, dim=split_dim) for name, split_tensor in zip(output_names, splits): - result[f"{base_key}.{name}.weight"] = split_tensor + result[f"{base_key}.{name}.{suffix}"] = split_tensor return result - -def rename_keys( +def split_fused_weight_with_sizes( state_dict: Dict[str, torch.Tensor], - mapping: Dict[str, str], + fused_key: str, + output_names: List[str], + split_sizes: List[int], + split_dim: int = 0, ) -> Dict[str, torch.Tensor]: - """Rename weight keys according to a substring mapping.""" + """Split fused weight tensors into separate weights with explicit sizes. + Supports -1 in split_sizes to mean "the remaining size". + Handles both .weight and .bias suffixes, the same way split_fused_weight + does when given explicit split_sizes. + """ result = {} + marker = f".{fused_key}." 
+ for key, tensor in state_dict.items(): - new_key = key - for old_str, new_str in mapping.items(): - new_key = new_key.replace(old_str, new_str) - result[new_key] = tensor - return result + if marker not in key: + result[key] = tensor + continue + + base_key, suffix = key.split(marker, 1) + dim_size = tensor.shape[split_dim] + # Resolve -1 (remainder) + sizes = [] + remainder = dim_size + for s in split_sizes: + if s == -1: + sizes.append(0) # placeholder + else: + sizes.append(s) + remainder -= s + sizes = [remainder if s == 0 else s for s in sizes] + + splits = torch.split(tensor, sizes, dim=split_dim) + for name, split_tensor in zip(output_names, splits): + result[f"{base_key}.{name}.{suffix}"] = split_tensor + + return result # ============================================================================ # Model-specific remap functions # ============================================================================ - - -def _remap_glm4(state_dict): +def _remap_glm4(state_dict, config=None): """Split GLM-4 fused gate_up_proj into gate_proj + up_proj.""" return split_fused_weight( state_dict, @@ -411,15 +477,55 @@ def _remap_glm4(state_dict): ) -# Add more model remap functions here as needed: -# -# def _remap_qwen3(state_dict): -# state_dict = split_fused_weight(state_dict, "gate_up_proj", ["gate_proj", "up_proj"]) -# state_dict = rename_keys(state_dict, {"model.layers": "decoder.layers"}) -# return state_dict +def _remap_chatglm(state_dict, config=None): + """Remap ChatGLM weights to InfiniLM format. + + Faithfully ported from the original working _remap_chatglm_weights. + """ + hf_config = config or {} + num_heads = hf_config.get("num_attention_heads", 32) + num_kv = hf_config.get("multi_query_group_num", 2) + head_dim = hf_config.get("kv_channels", 128) + ffn_hidden = hf_config.get("ffn_hidden_size", 13696) + + q_dim = num_heads * head_dim + k_dim = num_kv * head_dim + + # 1. Drop unused keys + state_dict = drop_keys(state_dict, ["rotary_pos_emb"]) + + # 2. 
Split QKV + state_dict = split_fused_weight_with_sizes( + state_dict, + fused_key="query_key_value", + output_names=["q_proj", "k_proj", "v_proj"], + split_sizes=[q_dim, k_dim, -1], + ) + + # 3. Split gate_up + state_dict = split_fused_weight_with_sizes( + state_dict, + fused_key="dense_h_to_4h", + output_names=["gate_proj", "up_proj"], + split_sizes=[ffn_hidden, -1], + ) + + # 4. Rename keys + state_dict = rename_keys(state_dict, { + "transformer.encoder.layers.": "model.layers.", + "transformer.embedding.word_embeddings": "model.embed_tokens", + "transformer.encoder.final_layernorm": "model.norm", + "transformer.output_layer": "lm_head", + "self_attention.": "self_attn.", + "self_attn.dense": "self_attn.o_proj", + "mlp.dense_4h_to_h": "mlp.down_proj", + }) + + return state_dict + # Model type → remap function mapping _WEIGHT_REMAPPER = { "glm4": _remap_glm4, - # "qwen3": _remap_qwen3, + "chatglm": _remap_chatglm, } diff --git a/python/infinilm/processors/__init__.py b/python/infinilm/processors/__init__.py index 61adff6d..0d9ba5fe 100644 --- a/python/infinilm/processors/__init__.py +++ b/python/infinilm/processors/__init__.py @@ -1,6 +1,7 @@ from .processor import InfinilmProcessor from .basic_llm_processor import BasicLLMProcessor from .llama_processor import LlamaProcessor +from .chatglm_processor import ChatGLMProcessor from transformers import AutoConfig @@ -14,5 +15,7 @@ def from_pretrained(cls, model_dir_path: str, **kwargs) -> InfinilmProcessor: if model_type in ["llama"]: return LlamaProcessor(model_dir_path) + elif model_type in ["chatglm"]: + return ChatGLMProcessor(model_dir_path) else: return BasicLLMProcessor(model_dir_path) diff --git a/python/infinilm/processors/chatglm_processor.py b/python/infinilm/processors/chatglm_processor.py new file mode 100644 index 00000000..ba2e3432 --- /dev/null +++ b/python/infinilm/processors/chatglm_processor.py @@ -0,0 +1,46 @@ +# python/infinilm/processors/chatglm_processor.py + +import re +import types +from 
.basic_llm_processor import BasicLLMProcessor + + +class ChatGLMProcessor(BasicLLMProcessor): + def __init__(self, model_dir_path: str): + super().__init__(model_dir_path) + self._fix_tokenizer_decode(self.tokenizer) + + @staticmethod + def _fix_tokenizer_decode(tokenizer): + """Fix ChatGLM tokenizer: patch convert_tokens_to_string. + + ChatGLM uses SentencePiece which encodes spaces as ▁ (U+2581). + Its convert_tokens_to_string calls self.tokenizer.decode_tokens(tokens), + which strips ▁ when decoding tokens incrementally, losing inter-word spaces. + + Fix: replace convert_tokens_to_string to: + 1. Join tokens + replace ▁ → space (the only thing decode_tokens gets wrong) + 2. Handle byte fallback: consecutive <0xHH> sequences → UTF-8 chars + + This keeps decode()'s other logic (skip_special_tokens, + clean_up_tokenization_spaces, etc.) intact. + + ▁ (U+2581) and _ (U+005F) are different characters in SentencePiece, + so this replacement will NOT affect real underscores. + """ + def patched_convert_tokens_to_string(self_tok, tokens): + # 1. Join tokens + replace ▁ (U+2581) with space + text = "".join(tokens).replace("\u2581", " ") + + # 2. Handle SentencePiece byte fallback: consecutive <0xHH> → UTF-8 + def byte_fallback_replace(match): + hex_strs = re.findall(r"<0x([0-9A-Fa-f]{2})>", match.group(0)) + byte_values = bytes([int(h, 16) for h in hex_strs]) + return byte_values.decode("utf-8", errors="replace") + + text = re.sub(r"(<0x[0-9A-Fa-f]{2}>)+", byte_fallback_replace, text) + return text + + tokenizer.convert_tokens_to_string = types.MethodType( + patched_convert_tokens_to_string, tokenizer + )