66// Copyright(C) 2026 Chris Warren-Smith
77
88#include < chrono>
9- #include < vector>
109#include " llama.h"
1110#include " llama-sb.h"
1211
12+ constexpr int MAX_REPEAT = 5 ;
13+
1314LlamaIter::LlamaIter () :
1415 _llama(nullptr ),
1516 _tokens_sec(0 ),
17+ _repetition_count(0 ),
1618 _has_next(false ) {
1719}
1820
@@ -21,20 +23,21 @@ Llama::Llama() :
2123 _ctx(nullptr ),
2224 _sampler(nullptr ),
2325 _vocab(nullptr ),
24- _penalty_last_n(64 ),
25- _penalty_repeat(1 . 1f ),
26+ _penalty_last_n(0 ),
27+ _penalty_repeat(0 ),
2628 _temperature(0 ),
2729 _top_k(0 ),
28- _top_p(1 . 0f ),
29- _min_p(0 . 0f ),
30- _max_tokens(150 ),
30+ _top_p(0 ),
31+ _min_p(0 ),
32+ _max_tokens(0 ),
3133 _log_level(GGML_LOG_LEVEL_CONT) {
3234 llama_log_set ([](enum ggml_log_level level, const char * text, void *user_data) {
3335 Llama *llama = (Llama *)user_data;
3436 if (level > llama->_log_level ) {
3537 fprintf (stderr, " LLAMA: %s" , text);
3638 }
3739 }, this );
40+ reset ();
3841}
3942
4043Llama::~Llama () {
@@ -49,6 +52,18 @@ Llama::~Llama() {
4952 }
5053}
5154
55+ void Llama::reset () {
56+ _stop_sequences.clear ();
57+ _last_error = " " ;
58+ _penalty_last_n = 64 ;
59+ _penalty_repeat = 1 .1f ;
60+ _temperature = 0 ;
61+ _top_k = 0 ;
62+ _top_p = 1 .0f ;
63+ _min_p = 0 .0f ;
64+ _max_tokens = 150 ;
65+ }
66+
5267bool Llama::construct (string model_path, int n_ctx, int n_batch) {
5368 ggml_backend_load_all ();
5469
@@ -100,31 +115,31 @@ void Llama::configure_sampler() {
100115 }
101116}
102117
103- void Llama::reset () {
104- llama_sampler_reset (_sampler);
105- _chat_prompt.clear ();
106- }
107-
108- bool Llama::generate (LlamaIter &iter, const string &prompt) {
109- int n_prompt = -llama_tokenize (_vocab, prompt.c_str (), prompt.size (),
110- nullptr , 0 , true , true );
118+ vector<llama_token> Llama::tokenize (const string &prompt) {
119+ vector<llama_token> result;
111120
121+ int n_prompt = -llama_tokenize (_vocab, prompt.c_str (), prompt.size (), nullptr , 0 , true , true );
112122 if (n_prompt <= 0 ) {
113123 _last_error = " failed to tokenize prompt" ;
114- return false ;
124+ } else {
125+ result.reserve (n_prompt);
126+ result.resize (n_prompt);
127+ if (llama_tokenize (_vocab, prompt.c_str (), prompt.size (),
128+ result.data (), n_prompt, true , true ) < 0 ) {
129+ _last_error = " failed to tokenize prompt" ;
130+ }
115131 }
132+ return result;
133+ }
116134
117- std::vector<llama_token> prompt_tokens (n_prompt);
118- if (llama_tokenize (_vocab, prompt.c_str (), prompt.size (),
119- prompt_tokens.data (), n_prompt, true , true ) < 0 ) {
120- _last_error = " failed to tokenize prompt" ;
135+ bool Llama::generate (LlamaIter &iter, const string &prompt) {
136+ vector<llama_token> prompt_tokens = tokenize (prompt);
137+ if (prompt_tokens.size () == 0 ) {
121138 return false ;
122139 }
123140
124- configure_sampler ();
125-
126141 // decode prompt
127- llama_batch batch = llama_batch_get_one (prompt_tokens.data (), n_prompt );
142+ llama_batch batch = llama_batch_get_one (prompt_tokens.data (), prompt_tokens. size () );
128143 if (llama_decode (_ctx, batch)) {
129144 _last_error = " failed to eval prompt" ;
130145 return false ;
@@ -143,8 +158,9 @@ bool Llama::generate(LlamaIter &iter, const string &prompt) {
143158 }
144159 }
145160
161+ configure_sampler ();
162+
146163 iter._llama = this ;
147- iter._batch = batch;
148164 iter._has_next = true ;
149165 iter._tokens_sec = 0 ;
150166 return true ;
@@ -153,7 +169,7 @@ bool Llama::generate(LlamaIter &iter, const string &prompt) {
153169string Llama::next (LlamaIter &iter) {
154170 string out;
155171
156- std:: vector<llama_token> decoded;
172+ vector<llama_token> decoded;
157173 decoded.reserve (_max_tokens);
158174
159175 int generated = 0 ;
@@ -183,14 +199,32 @@ string Llama::next(LlamaIter &iter) {
183199
184200 // detokenize sequentially
185201 if (!decoded.empty ()) {
186- char buf[512 ];
187202 for (llama_token tok : decoded) {
188- if (llama_vocab_is_control (_vocab, tok)) {
189- continue ;
190- }
191- int n = llama_token_to_piece (_vocab, tok, buf, sizeof (buf), 0 , false );
192- if (n > 0 ) {
193- out.append (buf, n);
203+ if (!llama_vocab_is_control (_vocab, tok)) {
204+ char buf[512 ];
205+ int n = llama_token_to_piece (_vocab, tok, buf, sizeof (buf), 0 , false );
206+ if (n > 0 ) {
207+ if (iter._last_word == buf) {
208+ if (++iter._repetition_count == MAX_REPEAT) {
209+ iter._has_next = false ;
210+ break ;
211+ }
212+ } else {
213+ iter._repetition_count = 0 ;
214+ iter._last_word = buf;
215+ }
216+ out.append (buf, n);
217+
218+ for (const auto &stop : _stop_sequences) {
219+ size_t pos = out.find (stop);
220+ if (pos != std::string::npos) {
221+ // found stop sequence - truncate and signal end
222+ out = out.substr (0 , pos);
223+ iter._has_next = false ;
224+ break ;
225+ }
226+ }
227+ }
194228 }
195229 }
196230 }
0 commit comments