diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 1024fb85b..1af34044d 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import re import sys import json import ctypes @@ -4026,3 +4027,277 @@ def chatml_function_calling( } raise ValueError("Automatic streaming tool choice is not supported") + + +# ========================================================================== +# Gemma 4 native tool-call parsing + chat completion handler +# ========================================================================== +# +# Gemma 4 (released 2026-04-02) emits tool calls as native tokens of the form: +# +# <|tool_call>call:FUNCTION_NAME{key:value,key:value,...} +# +# Argument values are encoded by type: +# - string : key:<|"|>value<|"|> +# - int : key:30 +# - float : key:3.5 +# - bool : key:true / key:false +# - null : key:null +# - list : key:[v1,v2,...] (element grammar is the same as above) +# +# An optional thinking-mode block precedes the tool call when thinking is on: +# +# <|channel>thought ... <|tool_call>call:... +# +# Without a dedicated handler `create_chat_completion()` returns these tokens +# verbatim in `message.content` and `tool_calls` stays `None`, which silently +# breaks every OpenAI-compatible Gemma 4 client (issue #2227). The C++ server +# parses the same tokens via the PEG grammar added in +# https://github.com/ggml-org/llama.cpp/pull/21326 ; this is the Python port. + +_GEMMA4_TOOL_CALL_RE = re.compile( + r"<\|tool_call>\s*call:(?P[A-Za-z_][A-Za-z0-9_]*)\s*\{(?P.*?)\}\s*", + re.DOTALL, +) +_GEMMA4_THOUGHT_RE = re.compile(r"<\|channel>\s*thought.*?", re.DOTALL) +_GEMMA4_STR_DELIM = '<|"|>' + + +def _gemma4_parse_value(s: str, pos: int) -> Tuple[Any, int]: + """Parse a single Gemma 4 value starting at ``s[pos]``. + + Returns ``(value, new_pos)`` where ``new_pos`` points just past the value. + """ + while pos < len(s) and s[pos].isspace(): + pos += 1 + # string literal: <|"|>...<|"|> + if s.startswith(_GEMMA4_STR_DELIM, pos): + start = pos + len(_GEMMA4_STR_DELIM) + end = s.find(_GEMMA4_STR_DELIM, start) + if end < 0: + return s[start:], len(s) + return s[start:end], end + len(_GEMMA4_STR_DELIM) + # list literal: [v1,v2,...] + if pos < len(s) and s[pos] == "[": + items: List[Any] = [] + pos += 1 + while pos < len(s): + while pos < len(s) and s[pos].isspace(): + pos += 1 + if pos < len(s) and s[pos] == "]": + return items, pos + 1 + val, pos = _gemma4_parse_value(s, pos) + items.append(val) + while pos < len(s) and s[pos] in " \t,": + pos += 1 + return items, pos + # primitive literal: read until separator + start = pos + while pos < len(s) and s[pos] not in ",}]": + pos += 1 + raw = s[start:pos].strip() + if raw == "true": + return True, pos + if raw == "false": + return False, pos + if raw == "null": + return None, pos + try: + if "." in raw or "e" in raw.lower(): + return float(raw), pos + return int(raw), pos + except ValueError: + return raw, pos + + +def _gemma4_parse_args(args_str: str) -> Dict[str, Any]: + """Parse the inside of a Gemma 4 ``{...}`` arg block into ``{name: value}``.""" + out: Dict[str, Any] = {} + pos = 0 + while pos < len(args_str): + m = re.match(r"\s*([A-Za-z_][A-Za-z0-9_]*)\s*:", args_str[pos:]) + if not m: + break + key = m.group(1) + pos += m.end() + val, pos = _gemma4_parse_value(args_str, pos) + out[key] = val + sep = re.match(r"\s*,\s*", args_str[pos:]) + pos += sep.end() if sep else 0 + return out + + +def _parse_gemma4_native_tool_calls( + text: str, +) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]: + """Extract Gemma 4 native tool-call tokens from a completion. + + Returns ``(content_remainder, tool_calls)``. When no ``<|tool_call>`` token + is present the original ``text`` is returned with ``tool_calls=None`` so + plain-text replies pass through unchanged. + """ + cleaned = _GEMMA4_THOUGHT_RE.sub("", text) + if "<|tool_call>" not in cleaned: + return text, None + tool_calls: List[Dict[str, Any]] = [] + for i, m in enumerate(_GEMMA4_TOOL_CALL_RE.finditer(cleaned)): + name = m.group("name") + args = _gemma4_parse_args(m.group("args")) + suffix = "".join(random.choices(string.hexdigits.lower()[:16], k=8)) + tool_calls.append( + { + "id": f"call_{i}_{name}_{suffix}", + "type": "function", + "function": {"name": name, "arguments": json.dumps(args)}, + } + ) + if not tool_calls: + return text, None + remainder = _GEMMA4_TOOL_CALL_RE.sub("", cleaned).strip() + return (remainder or None), tool_calls + + +@register_chat_completion_handler("gemma4") +def gemma4_chat_completion( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + min_p: float = 0.05, + typical_p: float = 1.0, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = None, + seed: Optional[int] = None, + response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, + max_tokens: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + **kwargs, # type: ignore +) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], +]: + """Chat completion handler for Gemma 4 (issue #2227). + + Renders prompts via the model's embedded Jinja2 chat template (the + Gemma 4 GGUFs ship a correct one) and post-parses Gemma 4's native + tool-call tokens into OpenAI-compatible ``tool_calls`` on the assistant + message. Streaming responses are passed through unchanged — callers can + buffer chunks and re-parse via ``_parse_gemma4_native_tool_calls`` until + an incremental PEG parser is ported from ggml-org/llama.cpp#21326. + """ + template = (getattr(llama, "metadata", None) or {}).get("tokenizer.chat_template") + if not template: + raise ValueError( + "chat_format='gemma4' requires a GGUF model with an embedded " + "tokenizer.chat_template (Gemma 4 GGUFs ship one by default)." + ) + + eos_id = llama.token_eos() + bos_id = llama.token_bos() + eos_token = llama._model.token_get_text(eos_id) if eos_id != -1 else "" + bos_token = llama._model.token_get_text(bos_id) if bos_id != -1 else "" + + formatter = Jinja2ChatFormatter( + template=template, + eos_token=eos_token, + bos_token=bos_token, + add_generation_prompt=True, + ) + result = formatter( + messages=messages, + functions=functions, + function_call=function_call, + tools=tools, + tool_choice=tool_choice, + ) + prompt = llama.tokenize( + result.prompt.encode("utf-8"), + add_bos=not result.added_special, + special=True, + ) + + effective_stop: List[str] = [] + if stop: + effective_stop = [stop] if isinstance(stop, str) else list(stop) + if result.stop is not None: + effective_stop += result.stop if isinstance(result.stop, list) else [result.stop] + + if response_format is not None and response_format.get("type") == "json_object": + grammar = _grammar_for_response_format(response_format, verbose=llama.verbose) + + completion_or_chunks = llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + typical_p=typical_p, + logprobs=top_logprobs if logprobs else None, + stream=stream, + stop=effective_stop, + seed=seed, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + stopping_criteria=result.stopping_criteria, + grammar=grammar, + logit_bias=logit_bias, + ) + + if stream: + return _convert_completion_to_chat(completion_or_chunks, stream=True) + + completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks) + text = completion["choices"][0]["text"] + content, tool_calls = _parse_gemma4_native_tool_calls(text) + + message: Dict[str, Any] = {"role": "assistant", "content": content if not tool_calls else None} + if tool_calls: + message["tool_calls"] = tool_calls + + chat_response: llama_types.CreateChatCompletionResponse = { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "index": 0, + "finish_reason": ( + "tool_calls" if tool_calls else completion["choices"][0]["finish_reason"] + ), + "logprobs": _convert_text_completion_logprobs_to_chat( + completion["choices"][0]["logprobs"] + ), + "message": message, + } + ], + "usage": completion.get( + "usage", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + ), + } + return chat_response diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py index 18c7279cf..613c385ca 100644 --- a/tests/test_llama_chat_format.py +++ b/tests/test_llama_chat_format.py @@ -92,3 +92,98 @@ def test_hf_tokenizer_config_str_to_chat_formatter(): ) assert chat_formatter_respoonse.prompt == ("[INST] Hello, world! [/INST]") + + +def _parse(text): + return llama_chat_format._parse_gemma4_native_tool_calls(text) + + +def test_gemma4_parse_string_args(): + text = ( + '<|tool_call>call:write_file{' + 'content:<|"|>print("hello")<|"|>,' + 'file_path:<|"|>hello.py<|"|>' + '}' + ) + content, tool_calls = _parse(text) + assert content is None + assert tool_calls is not None and len(tool_calls) == 1 + fn = tool_calls[0]["function"] + assert fn["name"] == "write_file" + assert json.loads(fn["arguments"]) == { + "content": 'print("hello")', + "file_path": "hello.py", + } + + +def test_gemma4_parse_primitive_args(): + text = ( + '<|tool_call>call:do_thing{' + 'timeout:30,temperature:0.5,background:false,note:null' + '}' + ) + _, tool_calls = _parse(text) + assert json.loads(tool_calls[0]["function"]["arguments"]) == { + "timeout": 30, + "temperature": 0.5, + "background": False, + "note": None, + } + + +def test_gemma4_parse_list_of_strings(): + text = ( + '<|tool_call>call:read_files{' + 'files:[<|"|>a.py<|"|>,<|"|>b.py<|"|>]' + '}' + ) + _, tool_calls = _parse(text) + assert json.loads(tool_calls[0]["function"]["arguments"]) == { + "files": ["a.py", "b.py"] + } + + +def test_gemma4_strips_thought_block(): + text = ( + '<|channel>thought\nLet me call the function.\n' + '<|tool_call>call:f{x:1}' + ) + _, tool_calls = _parse(text) + assert tool_calls and json.loads(tool_calls[0]["function"]["arguments"]) == {"x": 1} + + +def test_gemma4_plain_text_passthrough(): + text = "Just a normal reply with no tool call." + content, tool_calls = _parse(text) + assert tool_calls is None + assert content == text + + +def test_gemma4_multiple_tool_calls(): + text = ( + '<|tool_call>call:a{x:1}' + '<|tool_call>call:b{y:<|"|>two<|"|>}' + ) + _, tool_calls = _parse(text) + assert len(tool_calls) == 2 + assert tool_calls[0]["function"]["name"] == "a" + assert tool_calls[1]["function"]["name"] == "b" + assert json.loads(tool_calls[1]["function"]["arguments"]) == {"y": "two"} + # IDs must be unique across calls. + assert tool_calls[0]["id"] != tool_calls[1]["id"] + + +def test_gemma4_surrounding_plain_text(): + text = "Sure, I will help.\n<|tool_call>call:f{x:1}" + content, tool_calls = _parse(text) + assert tool_calls is not None + assert content == "Sure, I will help." + + +def test_gemma4_string_with_embedded_quotes(): + # Delimiter is the 3-char sequence <|"|>, so literal " inside is fine. + text = '<|tool_call>call:say{msg:<|"|>hello, "world"!<|"|>}' + _, tool_calls = _parse(text) + assert json.loads(tool_calls[0]["function"]["arguments"]) == { + "msg": 'hello, "world"!' + }