Skip to content

Commit fc00a83

Browse files
author
Chris Warren-Smith
committed
LLM: plugin module - added repeat penalty handling
1 parent 4753316 commit fc00a83

File tree

4 files changed

+212
-50
lines changed

4 files changed

+212
-50
lines changed

llama/llama-sb.cpp

Lines changed: 78 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//
66
// Copyright(C) 2026 Chris Warren-Smith
77

8+
#include <chrono>
89
#include <vector>
910
#include "llama.h"
1011
#include "llama-sb.h"
@@ -14,12 +15,14 @@ Llama::Llama() :
1415
_ctx(nullptr),
1516
_sampler(nullptr),
1617
_vocab(nullptr),
18+
_penalty_last_n(64),
19+
_penalty_repeat(1.1f),
1720
_temperature(0),
1821
_top_k(0),
1922
_top_p(1.0f),
2023
_min_p(0.0f),
2124
_max_tokens(150),
22-
_log_level(GGML_LOG_LEVEL_NONE) {
25+
_log_level(GGML_LOG_LEVEL_CONT) {
2326
llama_log_set([](enum ggml_log_level level, const char * text, void *user_data) {
2427
Llama *llama = (Llama *)user_data;
2528
if (level > llama->_log_level) {
@@ -82,6 +85,10 @@ bool Llama::construct(string model_path, int n_ctx, int n_batch) {
8285

8386
void Llama::configure_sampler() {
8487
llama_sampler_reset(_sampler);
88+
if (_penalty_last_n != 0 && _penalty_repeat != 1.0f) {
89+
auto penalties = llama_sampler_init_penalties(_penalty_last_n, _penalty_repeat, 0.0f, 0.0f);
90+
llama_sampler_chain_add(_sampler, penalties);
91+
}
8592
if (_temperature <= 0.0f) {
8693
llama_sampler_chain_add(_sampler, llama_sampler_init_greedy());
8794
} else {
@@ -99,72 +106,104 @@ void Llama::configure_sampler() {
99106
}
100107
}
101108

109+
void Llama::reset() {
110+
// llama_kv_cache_clear(it->second->ctx);
111+
_chat_prompt.clear();
112+
}
113+
102114
string Llama::generate(const string &prompt) {
103-
string out;
115+
string out = prompt;
104116

105-
// find the number of tokens in the prompt
106-
int n_prompt = -llama_tokenize(_vocab, prompt.c_str(), prompt.size(), nullptr, 0, true, true);
117+
// ---- tokenize prompt ----
118+
int n_prompt = -llama_tokenize(_vocab, prompt.c_str(), prompt.size(),
119+
nullptr, 0, true, true);
120+
121+
if (n_prompt <= 0) {
122+
_last_error = "failed to tokenize prompt";
123+
return out;
124+
}
107125

108-
// allocate space for the tokens and tokenize the prompt
109126
std::vector<llama_token> prompt_tokens(n_prompt);
110-
if (llama_tokenize(_vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
111-
_last_error = "failed tokenize the prompt";
127+
if (llama_tokenize(_vocab, prompt.c_str(), prompt.size(),
128+
prompt_tokens.data(), n_prompt, true, true) < 0) {
129+
_last_error = "failed to tokenize prompt";
112130
return out;
113131
}
114132

115-
// initialize the sampler
133+
// ---- sampler ----
116134
configure_sampler();
117135

118-
// prepare a batch for the prompt
119-
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
120-
if (llama_model_has_encoder(_model)) {
121-
if (llama_encode(_ctx, batch)) {
122-
_last_error = "failed to eval";
123-
return out;
124-
}
136+
// ---- decode prompt ----
137+
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), n_prompt);
138+
if (llama_decode(_ctx, batch)) {
139+
_last_error = "failed to eval prompt";
140+
return out;
141+
}
125142

143+
// ---- handle encoder models ----
144+
if (llama_model_has_encoder(_model)) {
126145
llama_token decoder_start_token_id = llama_model_decoder_start_token(_model);
127146
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
128147
decoder_start_token_id = llama_vocab_bos(_vocab);
129148
}
130-
131149
batch = llama_batch_get_one(&decoder_start_token_id, 1);
132-
}
133-
134-
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + _max_tokens;) {
135-
// evaluate the current batch with the transformer model
136150
if (llama_decode(_ctx, batch)) {
137-
_last_error = "failed to eval";
138-
break;
151+
_last_error = "failed to eval decoder start token";
152+
return out;
139153
}
154+
}
155+
156+
// ---- generation loop ----
157+
std::vector<llama_token> decoded;
158+
decoded.reserve(_max_tokens);
140159

141-
n_pos += batch.n_tokens;
160+
int generated = 0;
161+
auto t_start = std::chrono::high_resolution_clock::now();
142162

143-
// sample the next token
144-
llama_token new_token_id = llama_sampler_sample(_sampler, _ctx, -1);
163+
while (generated < _max_tokens) {
164+
// sample one token from the current logits
165+
llama_token tok = llama_sampler_sample(_sampler, _ctx, -1);
145166

146-
// is it an end of generation?
147-
if (llama_vocab_is_eog(_vocab, new_token_id)) {
167+
// end-of-generation check
168+
if (llama_vocab_is_eog(_vocab, tok)) {
148169
break;
149170
}
150171

151-
char buf[128];
152-
int n = llama_token_to_piece(_vocab, new_token_id, buf, sizeof(buf), 0, true);
153-
if (n < 0) {
154-
_last_error = "failed to convert token to piece";
172+
// append token to decoded list
173+
decoded.push_back(tok);
174+
++generated;
175+
176+
// ---- decode the token immediately ----
177+
llama_batch batch = llama_batch_get_one(&tok, 1);
178+
if (llama_decode(_ctx, batch)) {
179+
_last_error = "failed to eval token during generation";
155180
break;
156-
} else if (n > 0) {
157-
out.append(buf, n);
158181
}
182+
}
159183

160-
// prepare the next batch with the sampled token
161-
batch = llama_batch_get_one(&new_token_id, 1);
184+
// ---- detokenize sequentially ----
185+
if (!decoded.empty()) {
186+
char buf[512];
187+
for (llama_token tok : decoded) {
188+
if (llama_vocab_is_control(_vocab, tok)) {
189+
continue;
190+
}
191+
int n = llama_token_to_piece(_vocab, tok, buf, sizeof(buf), 0, false);
192+
if (n > 0) {
193+
out.append(buf, n);
194+
}
195+
}
162196
}
163197

198+
// ---- timing ----
199+
auto t_end = std::chrono::high_resolution_clock::now();
200+
double secs = std::chrono::duration<double>(t_end - t_start).count();
201+
double tokps = secs > 0 ? generated / secs : 0;
202+
203+
fprintf(stderr,
204+
"[tok/s=%.2f] generated=%d time=%.3fs\n",
205+
tokps, generated, secs);
206+
164207
return out;
165208
}
166209

167-
void Llama::reset() {
168-
// llama_kv_cache_clear(it->second->ctx);
169-
_chat_prompt.clear();
170-
}

llama/llama-sb.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ struct Llama {
2323
string generate(const string &prompt);
2424

2525
// generation parameters
26+
27+
void set_penalty_last_n(int32_t penalty_last_n) { _penalty_last_n = penalty_last_n; }
28+
void set_penalty_repeat(float penalty_repeat) { _penalty_repeat = penalty_repeat; }
2629
void set_max_tokens(int max_tokens) { _max_tokens = max_tokens; }
2730
void set_min_p(float min_p) { _min_p = min_p; }
2831
void set_temperature(float temperature) { _temperature = temperature; }
@@ -49,6 +52,8 @@ struct Llama {
4952
const llama_vocab *_vocab;
5053
string _chat_prompt;
5154
string _last_error;
55+
int32_t _penalty_last_n;
56+
float _penalty_repeat;
5257
float _temperature;
5358
float _top_p;
5459
float _min_p;

llama/main.cpp

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,44 @@ static string expand_path(const char *path) {
5050
return result;
5151
}
5252

53+
54+
//
55+
// llama.set_penalty_repeat(1.1)  [values > 1.0 penalize repetition; < 1.0 encourages it]
56+
//
57+
static int cmd_llama_set_penalty_repeat(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
58+
int result = 0;
59+
if (argc != 1) {
60+
error(retval, "llama.set_penalty_repeat", 1, 1);
61+
} else {
62+
int id = get_class_id(self, retval);
63+
if (id != -1) {
64+
Llama &llama = g_map.at(id);
65+
llama.set_penalty_repeat(get_param_num(argc, arg, 0, 0));
66+
result = 1;
67+
}
68+
}
69+
return result;
70+
}
71+
72+
//
73+
// llama.set_penalty_last_n(64)  [integer count of recent tokens considered for the penalty]
74+
//
75+
static int cmd_llama_set_penalty_last_n(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
76+
int result = 0;
77+
if (argc != 1) {
78+
error(retval, "llama.set_penalty_last_n", 1, 1);
79+
} else {
80+
int id = get_class_id(self, retval);
81+
if (id != -1) {
82+
Llama &llama = g_map.at(id);
83+
llama.set_penalty_last_n(get_param_num(argc, arg, 0, 0));
84+
result = 1;
85+
}
86+
}
87+
return result;
88+
}
89+
90+
5391
//
5492
// llama.set_max_tokens(50)
5593
//
@@ -105,7 +143,7 @@ static int cmd_llama_set_temperature(var_s *self, int argc, slib_par_t *arg, var
105143
}
106144

107145
//
108-
// llama.set_set_top_k(10.0)
146+
// llama.set_top_k(10)
109147
//
110148
static int cmd_llama_set_top_k(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
111149
int result = 0;
@@ -122,6 +160,9 @@ static int cmd_llama_set_top_k(var_s *self, int argc, slib_par_t *arg, var_s *re
122160
return result;
123161
}
124162

163+
//
164+
// llama.set_top_p(0)
165+
//
125166
static int cmd_llama_set_top_p(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
126167
int result = 0;
127168
if (argc != 1) {
@@ -207,14 +248,16 @@ static int cmd_llama_generate(var_s *self, int argc, slib_par_t *arg, var_s *ret
207248
static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
208249
int result;
209250
auto model = expand_path(get_param_str(argc, params, 0, ""));
210-
auto n_ctx = get_param_int(argc, params, 0, 2048);
211-
auto n_batch = get_param_int(argc, params, 1, 1024);
212-
auto temperature = get_param_num(argc, params, 2, 0.25);
251+
auto n_ctx = get_param_int(argc, params, 1, 2048);
252+
auto n_batch = get_param_int(argc, params, 2, 1024);
253+
auto temperature = get_param_num(argc, params, 3, 0.25);
213254
int id = ++g_nextId;
214255
Llama &llama = g_map[id];
215256
if (llama.construct(model, n_ctx, n_batch)) {
216257
llama.set_temperature(temperature);
217258
map_init_id(retval, id, CLASS_ID);
259+
v_create_callback(retval, "set_penalty_repeat", cmd_llama_set_penalty_repeat);
260+
v_create_callback(retval, "set_penalty_last_n", cmd_llama_set_penalty_last_n);
218261
v_create_callback(retval, "set_max_tokens", cmd_llama_set_max_tokens);
219262
v_create_callback(retval, "set_min_p", cmd_llama_set_min_p);
220263
v_create_callback(retval, "set_temperature", cmd_llama_set_temperature);

llama/samples/chat.bas

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,87 @@
1+
import llm
12

2-
const llama = llm.llama("qwen.gguf", 1024)
3+
const model = "models/Qwen_Qwen2.5-1.5B-Instruct-GGUF-Q4/qwen2.5-1.5b-instruct-q4_k_m.gguf"
4+
const llama = llm.llama(model, 4096, 512)
35

4-
print llama.generate("Write a BASIC program", 256, 0.2)
6+
llama.set_max_tokens(150)
7+
llama.set_min_p(0.5)
8+
llama.set_temperature(.8)
9+
llama.set_top_k(1)
10+
llama.set_top_p(0)
511

6-
print llama.chat("Hello")
7-
print llama.chat("Write a BASIC program to draw a cat")
8-
print llama.chat("Now add color")
912

10-
llama.reset()
13+
rem factual answers, tools, summaries
14+
' llama.set_max_tokens(150)
15+
' llama.set_temperature(0.0)
16+
' llama.set_top_k(1)
17+
' llama.set_top_p(0.0)
18+
' llama.set_min_p(0.0)
1119

12-
print llama.chat("Who are you?")
20+
rem assistant, Q+A, explanations, chat
21+
' llama.set_max_tokens(150)
22+
' llama.set_temperature(0.8)
23+
' llama.set_top_k(40)
24+
' llama.set_top_p(0.0)
25+
' llama.set_min_p(0.05)
26+
27+
rem creative, storytelling
28+
' llama.set_max_tokens(200)
29+
' llama.set_temperature(1.0)
30+
' llama.set_top_k(80)
31+
' llama.set_top_p(0.0)
32+
' llama.set_min_p(0.1)
33+
34+
rem surprises/loco (wild, high-randomness output)
35+
' llama.set_max_tokens(200)
36+
' llama.set_temperature(1.2)
37+
' llama.set_top_k(120)
38+
' llama.set_top_p(0.0)
39+
' llama.set_min_p(0.15)
40+
41+
rem technical, conservative
42+
' llama.set_max_tokens(150)
43+
' llama.set_temperature(0.6)
44+
' llama.set_top_k(30)
45+
' llama.set_top_p(0.0)
46+
' llama.set_min_p(0.02)
47+
48+
rem speed optimised on CPU
49+
llama.set_max_tokens(150)
50+
llama.set_temperature(0.7)
51+
llama.set_top_k(20)
52+
llama.set_top_p(0.0)
53+
llama.set_min_p(0.05)
54+
55+
' // Conservative - minimal repetition control
56+
' _penalty_last_n = 64;
57+
' _penalty_repeat = 1.05f;
58+
59+
' // Balanced - good default
60+
' _penalty_last_n = 64;
61+
' _penalty_repeat = 1.1f;
62+
63+
' // Aggressive - strong anti-repetition
64+
' _penalty_last_n = 128;
65+
' _penalty_repeat = 1.2f;
66+
67+
' // Disabled
68+
' _penalty_last_n = 0;
69+
' _penalty_repeat = 1.0f;
70+
71+
llama.set_penalty_repeat(1.15)
72+
llama.set_penalty_last_n(64)
73+
74+
75+
76+
llm_prompt = """\
77+
you are a helpful assistant\
78+
\nQuestion: when is dinner?\
79+
"""
80+
81+
print llm_prompt
82+
print llama.generate(llm_prompt)
83+
84+
' iter = llama.generate(llm_prompt)
85+
' while iter != 0
86+
' print iter.next()
87+
' wend

0 commit comments

Comments
 (0)