@@ -60,6 +60,13 @@ namespace splashkit_lib
             return {false};
         }
 
+        if (llama_model_has_encoder(model))
+        {
+            llama_model_free(model);
+            CLOG(ERROR, "GenAI") << "Unsupported model, requires encoder-decoder support.";
+            return {false};
+        }
+
         const llama_vocab * vocab = llama_model_get_vocab(model);
         const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
 
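The guard added above rejects encoder-decoder architectures (T5-style models) at load time, since the simplified decode path further down no longer drives an encoder pass. A caller sees this as an ordinary load failure; a rough sketch, where the loader's name and the failure flag are assumptions (the enclosing function's signature is outside this hunk):

    // Hypothetical caller - identifiers are illustrative, not the real API.
    model mdl = open_language_model("t5-base.gguf"); // assumed loader behind "return {false}"
    if (!mdl.loaded)                                 // assumed flag carried by {false}
        return; // CLOG has already reported the unsupported-model error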
@@ -82,7 +89,7 @@ namespace splashkit_lib
         llama_model_free(mdl.model);
     }
 
-    std::string format_chat(model& mdl, const std::vector<message>& messages)
+    std::string format_chat(model& mdl, const std::vector<message>& messages, bool add_assistant)
     {
         std::vector<llama_chat_message> llama_formatted;
         std::vector<char> formatted(0);
@@ -94,27 +101,27 @@ namespace splashkit_lib
             llama_formatted.push_back({msg.role.c_str(), msg.content.c_str()});
         }
 
-        int new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), true, formatted.data(), formatted.size());
+        int new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), add_assistant, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size())
         {
             formatted.resize(new_len);
-            new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), true, formatted.data(), formatted.size());
+            new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), add_assistant, formatted.data(), formatted.size());
         }
 
         return std::string(formatted.begin(), formatted.end());
     }
 
-    llama_tokens tokenize_string(model& mdl, const std::string& prompt)
+    llama_tokens tokenize_string(model& mdl, const std::string& prompt, bool is_first)
     {
         // get token count
         // note: returns a negative number: the count of tokens it would have returned had the buffer been large enough
-        const int n_prompt = -llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), NULL, 0, true, true);
+        const int n_prompt = -llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), NULL, 0, is_first, true);
 
         // create buffer
         std::vector<llama_token> prompt_tokens(n_prompt);
 
         // receive the tokens
-        if (llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0)
+        if (llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0)
         {
             CLOG(ERROR, "GenAI") << "Failed to tokenize the prompt.";
             return {};
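The two new flags give callers per-turn control over formatting. Note the two-pass idiom in format_chat: llama_chat_apply_template is first run against the empty buffer to learn the required length, then again after resizing. A minimal first-turn sketch, assuming mdl is an already-loaded model and message carries the role/content fields used above:

    std::vector<message> messages = {
        {"system", "You are a helpful assistant."},
        {"user", "Hello!"}
    };

    // add_assistant = true appends the assistant generation prompt,
    // which is what you want immediately before sampling a reply.
    std::string prompt = format_chat(mdl, messages, true);

    // is_first = true lets the tokenizer prepend BOS; that is only correct
    // for the first chunk fed into a fresh context.
    llama_tokens tokens = tokenize_string(mdl, prompt, true);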
@@ -128,7 +135,7 @@ namespace splashkit_lib
         // Create the context
         llama_context_params ctx_params = llama_context_default_params();
         ctx_params.n_ctx = starting_context.size() + settings.max_length - 1;
-        ctx_params.n_batch = starting_context.size();
+        ctx_params.n_batch = ctx_params.n_ctx;
         ctx_params.no_perf = true;
 
         llama_context * ctx = llama_init_from_model(mdl.model, ctx_params);
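Sizing n_batch to the whole context rather than the starting prompt matters once follow-up turns are injected through add_to_context (added below): a later batch can be larger than the original prompt, and n_batch caps how many tokens a single llama_decode call accepts. With illustrative numbers (assumed, not from the source):

    // starting_context.size() = 50, settings.max_length = 200
    //   n_ctx   = 50 + 200 - 1 = 249  // prompt plus every token we may generate
    //   n_batch = n_ctx        = 249  // a whole injected turn decodes in one call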
@@ -153,60 +160,58 @@ namespace splashkit_lib
         llama_sampler_chain_add(smpl, llama_sampler_init_penalties(64, 0, 0, settings.presence_penalty));
         llama_sampler_chain_add(smpl, llama_sampler_init_dist(settings.seed));
 
-        // Prepare batch and encode starting context
-        llama_batch batch = llama_batch_get_one(starting_context.data(), starting_context.size());
+        // Prepare batch for starting context
+        llama_tokens next_batch = starting_context;
 
-        if (llama_model_has_encoder(mdl.model))
-        {
-            if (llama_encode(ctx, batch))
-            {
-                llama_free(ctx);
-                llama_sampler_free(smpl);
-                CLOG(ERROR, "GenAI") << "Failed to encode prompt.";
-                return {nullptr};
-            }
-
-            llama_token decoder_start_token_id = llama_model_decoder_start_token(mdl.model);
-            if (decoder_start_token_id == LLAMA_TOKEN_NULL)
-            {
-                decoder_start_token_id = llama_vocab_bos(mdl.vocab);
-            }
-
-            batch = llama_batch_get_one(&decoder_start_token_id, 1);
-        }
+        // Cache newline token - we append it manually when closing a message (see manual_end_message)
+        llama_token newline_token;
+        llama_tokenize(mdl.vocab, "\n", 1, &newline_token, 1, false, true);
 
         return
         {
             ctx,
             smpl,
-            batch,
+            next_batch,
             (int)ctx_params.n_ctx,
             mdl.vocab,
+            newline_token,
             0,
-            ""
+            {},
+            false
         };
     }
 
-    int context_step(context& ctx)
+    int context_step(context& ctx, token_result* token)
     {
+        const string THINKING_START = "<think>";
+        const string THINKING_END = "</think>";
+
         if (!ctx.ctx)
             return -1;
 
+        llama_batch batch = llama_batch_get_one(ctx.next_batch.data(), ctx.next_batch.size());
         // Decode current batch with the model
-        if (llama_decode(ctx.ctx, ctx.batch))
+        if (llama_decode(ctx.ctx, batch))
         {
             CLOG(ERROR, "GenAI") << "Failed to process response from language model.";
+            if (token)
+                token->type = token_result::NONE;
             return -1;
         }
 
-        ctx.n_pos += ctx.batch.n_tokens;
+        ctx.total_context.insert(ctx.total_context.end(), ctx.next_batch.begin(), ctx.next_batch.end());
+        ctx.n_pos += batch.n_tokens;
 
         // Sample next token
         llama_token new_token_id = llama_sampler_sample(ctx.smpl, ctx.ctx, -1);
 
         // Has the model finished its response?
         if (llama_vocab_is_eog(ctx.vocab, new_token_id))
+        {
+            if (token)
+                token->type = token_result::NONE;
             return 1;
+        }
 
         char buf[128];
         int n = llama_token_to_piece(ctx.vocab, new_token_id, buf, sizeof(buf), 0, true);
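context_step now reports each sampled piece through an optional token_result out-parameter. Its declaration is not part of this diff; a minimal shape consistent with its usage in context_step would be:

    // Inferred sketch only - the real declaration lives in a header not shown here.
    struct token_result
    {
        enum kind { NONE, META, THINKING, CONTENT };
        kind type;        // NONE: nothing usable (decode error or end of generation)
        std::string text; // the decoded piece; META marks the <think>/</think> tags themselves
    };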
@@ -217,19 +222,46 @@ namespace splashkit_lib
         }
 
         std::string s(buf, n);
-        ctx.ctx_string += s;
+
+        if (token)
+        {
+            bool is_meta = s == THINKING_START || s == THINKING_END;
+            token->text = s;
+            if (is_meta)
+                token->type = token_result::META;
+            else if (ctx.in_thinking)
+                token->type = token_result::THINKING;
+            else
+                token->type = token_result::CONTENT;
+        }
+
+        if (s == THINKING_START)
+            ctx.in_thinking = true;
+        else if (s == THINKING_END)
+            ctx.in_thinking = false;
 
         // prepare the next batch with the sampled token
-        ctx.batch = llama_batch_get_one(&new_token_id, 1);
+        ctx.next_batch = {new_token_id};
 
         // Have we reached the end of the context?
         // If so, stop now.
-        if (ctx.n_pos + ctx.batch.n_tokens >= ctx.ctx_size)
+        if (ctx.n_pos + ctx.next_batch.size() >= ctx.ctx_size)
             return 1;
 
         return 0;
     }
 
+    void add_to_context(context& ctx, llama_tokens& message)
+    {
+        ctx.next_batch.insert(ctx.next_batch.end(), message.begin(), message.end());
+    }
+
+    void manual_end_message(context& ctx)
+    {
+        ctx.next_batch.push_back(llama_vocab_eot(ctx.vocab));
+        ctx.next_batch.push_back(ctx.newline_token);
+    }
+
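add_to_context and manual_end_message are the new multi-turn entry points: the former queues freshly tokenized input into the next batch, the latter closes the assistant's turn with an end-of-turn token plus the cached newline before new input is injected. A caller-side sketch (handle_* and next_user_text are placeholders, and the exact per-turn templating is application-specific):

    token_result tok;
    int status;
    while ((status = context_step(ctx, &tok)) == 0)
    {
        if (tok.type == token_result::CONTENT)
            handle_answer_text(tok.text);    // user-visible output
        else if (tok.type == token_result::THINKING)
            handle_reasoning_text(tok.text); // hidden reasoning stream
        // META pieces are the <think>/</think> markers; nothing to print
    }

    if (status == 1) // end of generation, or the context filled up
    {
        manual_end_message(ctx); // close the assistant turn: EOT + newline
        llama_tokens reply = tokenize_string(mdl, next_user_text, false); // BOS already in context
        add_to_context(ctx, reply); // decoded on the next context_step call
    }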
     void delete_context(context& ctx)
     {
         if (ctx.smpl)
@@ -238,5 +270,18 @@ namespace splashkit_lib
         if (ctx.ctx)
             llama_free(ctx.ctx);
     }
+
+    void __print_debug_context(context& ctx)
+    {
+        for (auto& x : ctx.total_context)
+        {
+            char buf[128];
+            int n = llama_token_to_piece(ctx.vocab, x, buf, sizeof(buf), 0, true);
+
+            std::string s(buf, n);
+            std::cout << "|" << s;
+        }
+        std::cout << std::endl;
+    }
 }
}
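Finally, a matching teardown sketch: __print_debug_context replays every token the context has consumed, prefixing each decoded piece with '|' so token boundaries are visible, and delete_context releases the sampler chain and the llama_context:

    __print_debug_context(ctx); // e.g. prints |Hello| world|!
    delete_context(ctx);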