Skip to content

Commit fad622e

Browse files
author
Chris Warren-Smith
committed
LLM: plugin module - impl add_stop func
1 parent 059de75 commit fad622e

File tree

3 files changed

+92
-35
lines changed

3 files changed

+92
-35
lines changed

llama/llama-sb.cpp

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
// Copyright(C) 2026 Chris Warren-Smith
77

88
#include <chrono>
9-
#include <vector>
109
#include "llama.h"
1110
#include "llama-sb.h"
1211

12+
constexpr int MAX_REPEAT = 5;
13+
1314
LlamaIter::LlamaIter() :
1415
_llama(nullptr),
1516
_tokens_sec(0),
17+
_repetition_count(0),
1618
_has_next(false) {
1719
}
1820

@@ -21,20 +23,21 @@ Llama::Llama() :
2123
_ctx(nullptr),
2224
_sampler(nullptr),
2325
_vocab(nullptr),
24-
_penalty_last_n(64),
25-
_penalty_repeat(1.1f),
26+
_penalty_last_n(0),
27+
_penalty_repeat(0),
2628
_temperature(0),
2729
_top_k(0),
28-
_top_p(1.0f),
29-
_min_p(0.0f),
30-
_max_tokens(150),
30+
_top_p(0),
31+
_min_p(0),
32+
_max_tokens(0),
3133
_log_level(GGML_LOG_LEVEL_CONT) {
3234
llama_log_set([](enum ggml_log_level level, const char * text, void *user_data) {
3335
Llama *llama = (Llama *)user_data;
3436
if (level > llama->_log_level) {
3537
fprintf(stderr, "LLAMA: %s", text);
3638
}
3739
}, this);
40+
reset();
3841
}
3942

4043
Llama::~Llama() {
@@ -49,6 +52,18 @@ Llama::~Llama() {
4952
}
5053
}
5154

55+
void Llama::reset() {
56+
_stop_sequences.clear();
57+
_last_error = "";
58+
_penalty_last_n = 64;
59+
_penalty_repeat = 1.1f;
60+
_temperature = 0;
61+
_top_k = 0;
62+
_top_p = 1.0f;
63+
_min_p = 0.0f;
64+
_max_tokens = 150;
65+
}
66+
5267
bool Llama::construct(string model_path, int n_ctx, int n_batch) {
5368
ggml_backend_load_all();
5469

@@ -100,31 +115,31 @@ void Llama::configure_sampler() {
100115
}
101116
}
102117

103-
void Llama::reset() {
104-
llama_sampler_reset(_sampler);
105-
_chat_prompt.clear();
106-
}
107-
108-
bool Llama::generate(LlamaIter &iter, const string &prompt) {
109-
int n_prompt = -llama_tokenize(_vocab, prompt.c_str(), prompt.size(),
110-
nullptr, 0, true, true);
118+
vector<llama_token> Llama::tokenize(const string &prompt) {
119+
vector<llama_token> result;
111120

121+
int n_prompt = -llama_tokenize(_vocab, prompt.c_str(), prompt.size(), nullptr, 0, true, true);
112122
if (n_prompt <= 0) {
113123
_last_error = "failed to tokenize prompt";
114-
return false;
124+
} else {
125+
result.reserve(n_prompt);
126+
result.resize(n_prompt);
127+
if (llama_tokenize(_vocab, prompt.c_str(), prompt.size(),
128+
result.data(), n_prompt, true, true) < 0) {
129+
_last_error = "failed to tokenize prompt";
130+
}
115131
}
132+
return result;
133+
}
116134

117-
std::vector<llama_token> prompt_tokens(n_prompt);
118-
if (llama_tokenize(_vocab, prompt.c_str(), prompt.size(),
119-
prompt_tokens.data(), n_prompt, true, true) < 0) {
120-
_last_error = "failed to tokenize prompt";
135+
bool Llama::generate(LlamaIter &iter, const string &prompt) {
136+
vector<llama_token> prompt_tokens = tokenize(prompt);
137+
if (prompt_tokens.size() == 0) {
121138
return false;
122139
}
123140

124-
configure_sampler();
125-
126141
// decode prompt
127-
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), n_prompt);
142+
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
128143
if (llama_decode(_ctx, batch)) {
129144
_last_error = "failed to eval prompt";
130145
return false;
@@ -143,8 +158,9 @@ bool Llama::generate(LlamaIter &iter, const string &prompt) {
143158
}
144159
}
145160

161+
configure_sampler();
162+
146163
iter._llama = this;
147-
iter._batch = batch;
148164
iter._has_next = true;
149165
iter._tokens_sec = 0;
150166
return true;
@@ -153,7 +169,7 @@ bool Llama::generate(LlamaIter &iter, const string &prompt) {
153169
string Llama::next(LlamaIter &iter) {
154170
string out;
155171

156-
std::vector<llama_token> decoded;
172+
vector<llama_token> decoded;
157173
decoded.reserve(_max_tokens);
158174

159175
int generated = 0;
@@ -183,14 +199,32 @@ string Llama::next(LlamaIter &iter) {
183199

184200
// detokenize sequentially
185201
if (!decoded.empty()) {
186-
char buf[512];
187202
for (llama_token tok : decoded) {
188-
if (llama_vocab_is_control(_vocab, tok)) {
189-
continue;
190-
}
191-
int n = llama_token_to_piece(_vocab, tok, buf, sizeof(buf), 0, false);
192-
if (n > 0) {
193-
out.append(buf, n);
203+
if (!llama_vocab_is_control(_vocab, tok)) {
204+
char buf[512];
205+
int n = llama_token_to_piece(_vocab, tok, buf, sizeof(buf), 0, false);
206+
if (n > 0) {
207+
if (iter._last_word == buf) {
208+
if (++iter._repetition_count == MAX_REPEAT) {
209+
iter._has_next = false;
210+
break;
211+
}
212+
} else {
213+
iter._repetition_count = 0;
214+
iter._last_word = buf;
215+
}
216+
out.append(buf, n);
217+
218+
for (const auto &stop : _stop_sequences) {
219+
size_t pos = out.find(stop);
220+
if (pos != std::string::npos) {
221+
// found stop sequence - truncate and signal end
222+
out = out.substr(0, pos);
223+
iter._has_next = false;
224+
break;
225+
}
226+
}
227+
}
194228
}
195229
}
196230
}

llama/llama-sb.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
#include <string>
11+
#include <vector>
1112
#include "llama.h"
1213

1314
using namespace std;
@@ -19,8 +20,9 @@ struct LlamaIter {
1920
~LlamaIter() {}
2021

2122
Llama *_llama;
22-
llama_batch _batch;
2323
float _tokens_sec;
24+
string _last_word;
25+
int _repetition_count;
2426
bool _has_next;
2527
};
2628

@@ -36,6 +38,8 @@ struct Llama {
3638
string next(LlamaIter &iter);
3739

3840
// generation parameters
41+
void add_stop(const char *stop) { _stop_sequences.push_back(stop); }
42+
void clear_stops() { _stop_sequences.clear(); }
3943
void set_penalty_last_n(int32_t penalty_last_n) { _penalty_last_n = penalty_last_n; }
4044
void set_penalty_repeat(float penalty_repeat) { _penalty_repeat = penalty_repeat; }
4145
void set_max_tokens(int max_tokens) { _max_tokens = max_tokens; }
@@ -51,12 +55,13 @@ struct Llama {
5155

5256
private:
5357
void configure_sampler();
58+
vector<llama_token> tokenize(const string &prompt);
5459

5560
llama_model *_model;
5661
llama_context *_ctx;
5762
llama_sampler *_sampler;
5863
const llama_vocab *_vocab;
59-
string _chat_prompt;
64+
vector<string> _stop_sequences;
6065
string _last_error;
6166
int32_t _penalty_last_n;
6267
float _penalty_repeat;

llama/main.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,23 @@ static string expand_path(const char *path) {
6666
return result;
6767
}
6868

69+
//
70+
// llama.add_stop('xyz')
71+
//
72+
static int cmd_llama_add_stop(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
73+
int result = 0;
74+
if (argc != 1) {
75+
error(retval, "llama.add_stop", 1, 1);
76+
} else {
77+
int id = get_llama_class_id(self, retval);
78+
if (id != -1) {
79+
Llama &llama = g_llama.at(id);
80+
llama.add_stop(get_param_str(argc, arg, 0, "stop"));
81+
result = 1;
82+
}
83+
}
84+
return result;
85+
}
6986

7087
//
7188
// llama.set_penalty_repeat(0.8)
@@ -304,15 +321,16 @@ static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
304321
Llama &llama = g_llama[id];
305322
if (llama.construct(model, n_ctx, n_batch)) {
306323
map_init_id(retval, id, CLASS_ID_LLAMA);
324+
v_create_callback(retval, "add_stop", cmd_llama_add_stop);
325+
v_create_callback(retval, "generate", cmd_llama_generate);
326+
v_create_callback(retval, "reset", cmd_llama_reset);
307327
v_create_callback(retval, "set_penalty_repeat", cmd_llama_set_penalty_repeat);
308328
v_create_callback(retval, "set_penalty_last_n", cmd_llama_set_penalty_last_n);
309329
v_create_callback(retval, "set_max_tokens", cmd_llama_set_max_tokens);
310330
v_create_callback(retval, "set_min_p", cmd_llama_set_min_p);
311331
v_create_callback(retval, "set_temperature", cmd_llama_set_temperature);
312332
v_create_callback(retval, "set_top_k", cmd_llama_set_top_k);
313333
v_create_callback(retval, "set_top_p", cmd_llama_set_top_p);
314-
v_create_callback(retval, "generate", cmd_llama_generate);
315-
v_create_callback(retval, "reset", cmd_llama_reset);
316334
result = 1;
317335
} else {
318336
error(retval, llama.last_error());

0 commit comments

Comments
 (0)