Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 275 additions & 0 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import re
import sys
import json
import ctypes
Expand Down Expand Up @@ -4026,3 +4027,277 @@ def chatml_function_calling(
}

raise ValueError("Automatic streaming tool choice is not supported")


# ==========================================================================
# Gemma 4 native tool-call parsing + chat completion handler
# ==========================================================================
#
# Gemma 4 (released 2026-04-02) emits tool calls as native tokens of the form:
#
# <|tool_call>call:FUNCTION_NAME{key:value,key:value,...}<tool_call|>
#
# Argument values are encoded by type:
# - string : key:<|"|>value<|"|>
# - int : key:30
# - float : key:3.5
# - bool : key:true / key:false
# - null : key:null
# - list : key:[v1,v2,...] (element grammar is the same as above)
#
# An optional thinking-mode block precedes the tool call when thinking is on:
#
# <|channel>thought ... <channel|><|tool_call>call:...<tool_call|>
#
# Without a dedicated handler, `create_chat_completion()` returns these tokens
# verbatim in `message.content` and `tool_calls` stays `None`, which silently
# breaks every OpenAI-compatible Gemma 4 client (issue #2227). The C++ server
# parses the same tokens via the PEG grammar added in
# https://github.com/ggml-org/llama.cpp/pull/21326; this module is the Python port.

# Matches one complete native tool call and captures the function name plus the
# raw text between the outer braces.  The name class admits a leading digit,
# dashes and dots in addition to identifier characters: OpenAI-style clients may
# register tools such as ``get-weather`` and the model echoes the name verbatim,
# so an identifier-only class would leave those calls unparsed in the content.
_GEMMA4_TOOL_CALL_RE = re.compile(
    r"<\|tool_call>\s*call:(?P<name>[A-Za-z0-9_][A-Za-z0-9_.-]*)\s*\{(?P<args>.*?)\}\s*<tool_call\|>",
    re.DOTALL,
)
# Matches a thinking-mode block so it can be dropped before tool-call parsing.
_GEMMA4_THOUGHT_RE = re.compile(r"<\|channel>\s*thought.*?<channel\|>", re.DOTALL)
# Token sequence that both opens and closes a string literal in the arguments.
_GEMMA4_STR_DELIM = '<|"|>'


def _gemma4_parse_value(s: str, pos: int) -> Tuple[Any, int]:
    """Parse a single Gemma 4 value starting at ``s[pos]``.

    Leading whitespace is skipped, then one of three forms is read:

    * a string literal wrapped in ``<|"|>`` delimiters (an unterminated
      literal runs to the end of ``s``);
    * a list literal ``[v1,v2,...]`` whose elements are parsed recursively
      (an unterminated list yields the elements collected so far);
    * a primitive read up to the next ``,``, ``}`` or ``]`` and decoded as
      ``true`` / ``false`` / ``null``, an int, a finite float, or, when none
      of those fit, the raw stripped text.

    Nested ``{...}`` objects are not understood by this parser.  A stray ``}``
    met inside a list is skipped so that malformed or unsupported model output
    can never stall the loop.

    Returns ``(value, new_pos)`` where ``new_pos`` points just past the value.
    """
    while pos < len(s) and s[pos].isspace():
        pos += 1
    # string literal: <|"|>...<|"|>
    if s.startswith(_GEMMA4_STR_DELIM, pos):
        start = pos + len(_GEMMA4_STR_DELIM)
        end = s.find(_GEMMA4_STR_DELIM, start)
        if end < 0:
            return s[start:], len(s)
        return s[start:end], end + len(_GEMMA4_STR_DELIM)
    # list literal: [v1,v2,...]
    if pos < len(s) and s[pos] == "[":
        items: List[Any] = []
        pos += 1
        while pos < len(s):
            while pos < len(s) and s[pos].isspace():
                pos += 1
            if pos >= len(s):
                # Only whitespace was left: do not invent an empty element.
                break
            if s[pos] == "]":
                return items, pos + 1
            if s[pos] == "}":
                # The primitive branch stops *on* "}" without consuming it, so
                # recursing here would return an empty value at the same
                # position forever.  Drop the stray brace instead.
                pos += 1
            else:
                val, pos = _gemma4_parse_value(s, pos)
                items.append(val)
            while pos < len(s) and s[pos] in " \t,":
                pos += 1
        return items, pos
    # primitive literal: read until separator
    start = pos
    while pos < len(s) and s[pos] not in ",}]":
        pos += 1
    raw = s[start:pos].strip()
    if raw == "true":
        return True, pos
    if raw == "false":
        return False, pos
    if raw == "null":
        return None, pos
    try:
        if "." in raw or "e" in raw.lower():
            number = float(raw)
            # An overflowing exponent gives +/-inf, which json.dumps() would
            # emit as the non-JSON token "Infinity"; keep the raw text instead.
            if abs(number) == float("inf"):
                return raw, pos
            return number, pos
        return int(raw), pos
    except ValueError:
        return raw, pos


def _gemma4_parse_args(args_str: str) -> Dict[str, Any]:
    """Turn the text between a tool call's braces into an argument mapping.

    The input is scanned left to right as ``key:value`` pairs separated by
    commas.  Keys are identifier-like names; each value is decoded by
    ``_gemma4_parse_value``.  Scanning stops quietly at the first position
    where no ``key:`` prefix can be matched, and whatever was collected up to
    that point is returned.
    """
    key_pattern = re.compile(r"\s*([A-Za-z_][A-Za-z0-9_]*)\s*:")
    comma_pattern = re.compile(r"\s*,\s*")

    parsed: Dict[str, Any] = {}
    cursor = 0
    total = len(args_str)
    while cursor < total:
        key_match = key_pattern.match(args_str, cursor)
        if key_match is None:
            break
        value, cursor = _gemma4_parse_value(args_str, key_match.end())
        parsed[key_match.group(1)] = value
        # Step over the separating comma, if there is one.
        comma_match = comma_pattern.match(args_str, cursor)
        if comma_match is not None:
            cursor = comma_match.end()
    return parsed


def _parse_gemma4_native_tool_calls(
    text: str,
) -> Tuple[Optional[str], Optional[List[Dict[str, Any]]]]:
    """Split a completion into plain content and OpenAI-style tool calls.

    Thinking-mode blocks are removed first, then every complete
    ``<|tool_call>...<tool_call|>`` span is converted into a
    ``{"id", "type", "function"}`` dict whose ``arguments`` field is a JSON
    string.

    Returns ``(content_remainder, tool_calls)``:

    * when no complete tool call is found, the *original* ``text`` is returned
      untouched together with ``None``;
    * otherwise the remainder is the text left after removing thought blocks
      and tool-call spans, stripped, or ``None`` if nothing is left.
    """
    without_thoughts = _GEMMA4_THOUGHT_RE.sub("", text)
    if "<|tool_call>" not in without_thoughts:
        return text, None

    calls: List[Dict[str, Any]] = []
    for index, match in enumerate(_GEMMA4_TOOL_CALL_RE.finditer(without_thoughts)):
        fn_name = match.group("name")
        fn_args = _gemma4_parse_args(match.group("args"))
        # Eight random lowercase hex digits keep the ids of repeated calls apart.
        tail = "".join(random.choices(string.hexdigits.lower()[:16], k=8))
        function_payload = {"name": fn_name, "arguments": json.dumps(fn_args)}
        calls.append(
            {
                "id": f"call_{index}_{fn_name}_{tail}",
                "type": "function",
                "function": function_payload,
            }
        )

    if not calls:
        # The opening token was present but no span was well-formed.
        return text, None

    leftover = _GEMMA4_TOOL_CALL_RE.sub("", without_thoughts).strip()
    return (leftover if leftover else None), calls


@register_chat_completion_handler("gemma4")
def gemma4_chat_completion(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = None,
    seed: Optional[int] = None,
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logit_bias: Optional[Dict[str, float]] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs,  # type: ignore
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    """Chat completion handler for Gemma 4 (issue #2227).

    The prompt is rendered with the Jinja2 chat template stored in the model
    metadata under ``tokenizer.chat_template``, tokenized, and sent to
    ``llama.create_completion``.  For non-streaming requests the completion
    text is then run through ``_parse_gemma4_native_tool_calls`` so native
    tool-call tokens become OpenAI-compatible ``tool_calls`` on the assistant
    message, with ``finish_reason`` set to ``"tool_calls"``.

    Streaming requests are converted chunk-for-chunk without tool-call
    parsing; callers can buffer chunks and re-parse them with
    ``_parse_gemma4_native_tool_calls`` until an incremental PEG parser is
    ported from ggml-org/llama.cpp#21326.

    Raises:
        ValueError: if the model metadata carries no chat template.
    """
    metadata = getattr(llama, "metadata", None) or {}
    chat_template = metadata.get("tokenizer.chat_template")
    if not chat_template:
        raise ValueError(
            "chat_format='gemma4' requires a GGUF model with an embedded "
            "tokenizer.chat_template (Gemma 4 GGUFs ship one by default)."
        )

    # An id of -1 is treated as "no such token" and maps to an empty string.
    eos_token_id = llama.token_eos()
    bos_token_id = llama.token_bos()
    eos_text = "" if eos_token_id == -1 else llama._model.token_get_text(eos_token_id)
    bos_text = "" if bos_token_id == -1 else llama._model.token_get_text(bos_token_id)

    render = Jinja2ChatFormatter(
        template=chat_template,
        eos_token=eos_text,
        bos_token=bos_text,
        add_generation_prompt=True,
    )
    rendered = render(
        messages=messages,
        functions=functions,
        function_call=function_call,
        tools=tools,
        tool_choice=tool_choice,
    )
    prompt_tokens = llama.tokenize(
        rendered.prompt.encode("utf-8"),
        add_bos=not rendered.added_special,
        special=True,
    )

    # Caller-supplied stop sequences first, then any provided by the formatter.
    stop_sequences: List[str] = []
    if stop:
        if isinstance(stop, str):
            stop_sequences.append(stop)
        else:
            stop_sequences.extend(stop)
    template_stop = rendered.stop
    if template_stop is not None:
        if isinstance(template_stop, list):
            stop_sequences.extend(template_stop)
        else:
            stop_sequences.append(template_stop)

    # A JSON-object response format replaces any grammar passed by the caller.
    if response_format is not None and response_format.get("type") == "json_object":
        grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)

    completion_or_chunks = llama.create_completion(
        prompt=prompt_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        logprobs=top_logprobs if logprobs else None,
        stream=stream,
        stop=stop_sequences,
        seed=seed,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
        stopping_criteria=rendered.stopping_criteria,
        grammar=grammar,
        logit_bias=logit_bias,
    )

    if stream:
        return _convert_completion_to_chat(completion_or_chunks, stream=True)

    completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
    first_choice = completion["choices"][0]
    content, tool_calls = _parse_gemma4_native_tool_calls(first_choice["text"])

    # NOTE(review): when tool calls are found, any surrounding plain text that
    # the parser returned is discarded here -- confirm this is intended.
    message: Dict[str, Any] = {
        "role": "assistant",
        "content": None if tool_calls else content,
    }
    if tool_calls:
        message["tool_calls"] = tool_calls
        finish_reason = "tool_calls"
    else:
        finish_reason = first_choice["finish_reason"]

    empty_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    chat_response: llama_types.CreateChatCompletionResponse = {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "index": 0,
                "finish_reason": finish_reason,
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    first_choice["logprobs"]
                ),
                "message": message,
            }
        ],
        "usage": completion.get("usage", empty_usage),
    }
    return chat_response
95 changes: 95 additions & 0 deletions tests/test_llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,98 @@ def test_hf_tokenizer_config_str_to_chat_formatter():
)

assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>")


def _parse(text):
    """Run the Gemma 4 native tool-call parser on ``text`` and return its result."""
    parser = llama_chat_format._parse_gemma4_native_tool_calls
    return parser(text)


def test_gemma4_parse_string_args():
    """String arguments are unwrapped from their ``<|"|>`` delimiters."""
    expected_args = {
        "content": 'print("hello")',
        "file_path": "hello.py",
    }
    payload = "".join(
        (
            '<|tool_call>call:write_file{',
            'content:<|"|>print("hello")<|"|>,',
            'file_path:<|"|>hello.py<|"|>',
            '}<tool_call|>',
        )
    )

    remainder, calls = _parse(payload)

    assert remainder is None
    assert calls is not None
    assert len(calls) == 1
    function = calls[0]["function"]
    assert function["name"] == "write_file"
    assert json.loads(function["arguments"]) == expected_args


def test_gemma4_parse_primitive_args():
    """Ints, floats, booleans and null decode to their Python equivalents."""
    expected_args = {
        "timeout": 30,
        "temperature": 0.5,
        "background": False,
        "note": None,
    }
    payload = "".join(
        (
            '<|tool_call>call:do_thing{',
            'timeout:30,temperature:0.5,background:false,note:null',
            '}<tool_call|>',
        )
    )

    calls = _parse(payload)[1]

    decoded = json.loads(calls[0]["function"]["arguments"])
    assert decoded == expected_args


def test_gemma4_parse_list_of_strings():
    """A bracketed list of delimited strings becomes a list of ``str``."""
    payload = "".join(
        (
            '<|tool_call>call:read_files{',
            'files:[<|"|>a.py<|"|>,<|"|>b.py<|"|>]',
            '}<tool_call|>',
        )
    )

    calls = _parse(payload)[1]

    decoded = json.loads(calls[0]["function"]["arguments"])
    assert decoded == {"files": ["a.py", "b.py"]}


def test_gemma4_strips_thought_block():
    """A leading thinking-mode block does not prevent tool-call extraction."""
    thought = '<|channel>thought\nLet me call the function.\n<channel|>'
    call = '<|tool_call>call:f{x:1}<tool_call|>'

    calls = _parse(thought + call)[1]

    assert calls
    assert json.loads(calls[0]["function"]["arguments"]) == {"x": 1}


def test_gemma4_plain_text_passthrough():
    """Text without any tool-call token is returned untouched."""
    reply = "Just a normal reply with no tool call."

    remainder, calls = _parse(reply)

    assert calls is None
    assert remainder == reply


def test_gemma4_multiple_tool_calls():
    """Back-to-back tool calls are all extracted, in order, with distinct ids."""
    payload = "".join(
        (
            '<|tool_call>call:a{x:1}<tool_call|>',
            '<|tool_call>call:b{y:<|"|>two<|"|>}<tool_call|>',
        )
    )

    calls = _parse(payload)[1]

    assert len(calls) == 2
    first, second = calls
    assert first["function"]["name"] == "a"
    assert second["function"]["name"] == "b"
    assert json.loads(second["function"]["arguments"]) == {"y": "two"}
    # Each call must carry its own id.
    assert first["id"] != second["id"]


def test_gemma4_surrounding_plain_text():
    """Plain text around a tool call is returned as the stripped remainder."""
    payload = "Sure, I will help.\n<|tool_call>call:f{x:1}<tool_call|>"

    remainder, calls = _parse(payload)

    assert calls is not None
    assert remainder == "Sure, I will help."


def test_gemma4_string_with_embedded_quotes():
    """Literal double quotes inside a delimited string survive parsing."""
    # The delimiter is the 5-character sequence <|"|>, so a bare " inside the
    # string body cannot be mistaken for the end of the literal.
    payload = '<|tool_call>call:say{msg:<|"|>hello, "world"!<|"|>}<tool_call|>'

    calls = _parse(payload)[1]

    decoded = json.loads(calls[0]["function"]["arguments"])
    assert decoded == {"msg": 'hello, "world"!'}
Loading