//
// Copyright(C) 2026 Chris Warren-Smith
//

#include <chrono>
#include <vector>
#include "llama.h"
#include "llama-sb.h"
@@ -14,12 +15,14 @@ Llama::Llama() :
1415 _ctx(nullptr ),
1516 _sampler(nullptr ),
1617 _vocab(nullptr ),
18+ _penalty_last_n(64 ),
19+ _penalty_repeat(1 .1f ),
1720 _temperature(0 ),
1821 _top_k(0 ),
1922 _top_p(1 .0f ),
2023 _min_p(0 .0f ),
2124 _max_tokens(150 ),
22- _log_level(GGML_LOG_LEVEL_NONE ) {
25+ _log_level(GGML_LOG_LEVEL_CONT ) {
2326 llama_log_set ([](enum ggml_log_level level, const char * text, void *user_data) {
2427 Llama *llama = (Llama *)user_data;
2528 if (level > llama->_log_level ) {
@@ -82,6 +85,10 @@ bool Llama::construct(string model_path, int n_ctx, int n_batch) {
8285
8386void Llama::configure_sampler () {
8487 llama_sampler_reset (_sampler);
88+ if (_penalty_last_n != 0 && _penalty_repeat != 1 .0f ) {
89+ auto penalties = llama_sampler_init_penalties (_penalty_last_n, _penalty_repeat, 0 .0f , 0 .0f );
90+ llama_sampler_chain_add (_sampler, penalties);
91+ }
8592 if (_temperature <= 0 .0f ) {
8693 llama_sampler_chain_add (_sampler, llama_sampler_init_greedy ());
8794 } else {
@@ -99,72 +106,104 @@ void Llama::configure_sampler() {
99106 }
100107}
101108
109+ void Llama::reset () {
110+ // llama_kv_cache_clear(it->second->ctx);
111+ _chat_prompt.clear ();
112+ }
113+
102114string Llama::generate (const string &prompt) {
103- string out;
115+ string out = prompt ;
104116
105- // find the number of tokens in the prompt
106- int n_prompt = -llama_tokenize (_vocab, prompt.c_str (), prompt.size (), nullptr , 0 , true , true );
117+ // ---- tokenize prompt ----
118+ int n_prompt = -llama_tokenize (_vocab, prompt.c_str (), prompt.size (),
119+ nullptr , 0 , true , true );
120+
121+ if (n_prompt <= 0 ) {
122+ _last_error = " failed to tokenize prompt" ;
123+ return out;
124+ }
107125
108- // allocate space for the tokens and tokenize the prompt
109126 std::vector<llama_token> prompt_tokens (n_prompt);
110- if (llama_tokenize (_vocab, prompt.c_str (), prompt.size (), prompt_tokens.data (), prompt_tokens.size (), true , true ) < 0 ) {
111- _last_error = " failed tokenize the prompt" ;
127+ if (llama_tokenize (_vocab, prompt.c_str (), prompt.size (),
128+ prompt_tokens.data (), n_prompt, true , true ) < 0 ) {
129+ _last_error = " failed to tokenize prompt" ;
112130 return out;
113131 }
114132
115- // initialize the sampler
133+ // ---- sampler ----
116134 configure_sampler ();
117135
118- // prepare a batch for the prompt
119- llama_batch batch = llama_batch_get_one (prompt_tokens.data (), prompt_tokens.size ());
120- if (llama_model_has_encoder (_model)) {
121- if (llama_encode (_ctx, batch)) {
122- _last_error = " failed to eval" ;
123- return out;
124- }
136+ // ---- decode prompt ----
137+ llama_batch batch = llama_batch_get_one (prompt_tokens.data (), n_prompt);
138+ if (llama_decode (_ctx, batch)) {
139+ _last_error = " failed to eval prompt" ;
140+ return out;
141+ }
125142
143+ // ---- handle encoder models ----
144+ if (llama_model_has_encoder (_model)) {
126145 llama_token decoder_start_token_id = llama_model_decoder_start_token (_model);
127146 if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
128147 decoder_start_token_id = llama_vocab_bos (_vocab);
129148 }
130-
131149 batch = llama_batch_get_one (&decoder_start_token_id, 1 );
132- }
133-
134- for (int n_pos = 0 ; n_pos + batch.n_tokens < n_prompt + _max_tokens;) {
135- // evaluate the current batch with the transformer model
136150 if (llama_decode (_ctx, batch)) {
137- _last_error = " failed to eval" ;
138- break ;
151+ _last_error = " failed to eval decoder start token " ;
152+ return out ;
139153 }
154+ }
155+
156+ // ---- generation loop ----
157+ std::vector<llama_token> decoded;
158+ decoded.reserve (_max_tokens);
140159
141- n_pos += batch.n_tokens ;
160+ int generated = 0 ;
161+ auto t_start = std::chrono::high_resolution_clock::now ();
142162
143- // sample the next token
144- llama_token new_token_id = llama_sampler_sample (_sampler, _ctx, -1 );
163+ while (generated < _max_tokens) {
164+ // sample one token from the current logits
165+ llama_token tok = llama_sampler_sample (_sampler, _ctx, -1 );
145166
146- // is it an end of generation?
147- if (llama_vocab_is_eog (_vocab, new_token_id )) {
167+ // end-of- generation check
168+ if (llama_vocab_is_eog (_vocab, tok )) {
148169 break ;
149170 }
150171
151- char buf[128 ];
152- int n = llama_token_to_piece (_vocab, new_token_id, buf, sizeof (buf), 0 , true );
153- if (n < 0 ) {
154- _last_error = " failed to convert token to piece" ;
172+ // append token to decoded list
173+ decoded.push_back (tok);
174+ ++generated;
175+
176+ // ---- decode the token immediately ----
177+ llama_batch batch = llama_batch_get_one (&tok, 1 );
178+ if (llama_decode (_ctx, batch)) {
179+ _last_error = " failed to eval token during generation" ;
155180 break ;
156- } else if (n > 0 ) {
157- out.append (buf, n);
158181 }
182+ }
159183
160- // prepare the next batch with the sampled token
161- batch = llama_batch_get_one (&new_token_id, 1 );
184+ // ---- detokenize sequentially ----
185+ if (!decoded.empty ()) {
186+ char buf[512 ];
187+ for (llama_token tok : decoded) {
188+ if (llama_vocab_is_control (_vocab, tok)) {
189+ continue ;
190+ }
191+ int n = llama_token_to_piece (_vocab, tok, buf, sizeof (buf), 0 , false );
192+ if (n > 0 ) {
193+ out.append (buf, n);
194+ }
195+ }
162196 }
163197
198+ // ---- timing ----
199+ auto t_end = std::chrono::high_resolution_clock::now ();
200+ double secs = std::chrono::duration<double >(t_end - t_start).count ();
201+ double tokps = secs > 0 ? generated / secs : 0 ;
202+
203+ fprintf (stderr,
204+ " [tok/s=%.2f] generated=%d time=%.3fs\n " ,
205+ tokps, generated, secs);
206+
164207 return out;
165208}
166209