
Commit 2f6015b

Conversation & Streaming support + some refactoring + updated test program
1 parent 9341ffb commit 2f6015b

6 files changed: 497 additions & 61 deletions

coresdk/src/backend/backend_types.h

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@ namespace splashkit_lib
         ADC_PTR= 0x41444350, //'ADCP';
         MOTOR_DRIVER_PTR = 0x4d444950, //'MDIP';
         SERVO_DRIVER_PTR = 0x53455256, //'SERV';
+        CONVERSATION_PTR = 0x434f4e56, //'CONV';
         NONE_PTR = 0x4e4f4e45 //'NONE';
     };
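The pointer identifiers encode a four-character ASCII tag in a 32-bit value, and the new CONVERSATION_PTR follows the same pattern. As a quick illustration (not part of the commit), the constant can be derived from the characters 'C', 'O', 'N', 'V':

// Illustration only: 'C' = 0x43, 'O' = 0x4f, 'N' = 0x4e, 'V' = 0x56 -> 0x434f4e56.
#include <cstdint>
#include <cstdio>

constexpr std::uint32_t four_cc(char a, char b, char c, char d)
{
    return (std::uint32_t(a) << 24) | (std::uint32_t(b) << 16) | (std::uint32_t(c) << 8) | std::uint32_t(d);
}

static_assert(four_cc('C', 'O', 'N', 'V') == 0x434f4e56, "matches CONVERSATION_PTR");

int main()
{
    std::printf("0x%08x\n", (unsigned)four_cc('C', 'O', 'N', 'V'));
    return 0;
}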

coresdk/src/backend/genai_driver.cpp

Lines changed: 80 additions & 35 deletions

@@ -60,6 +60,13 @@ namespace splashkit_lib
                 return {false};
             }
 
+            if (llama_model_has_encoder(model))
+            {
+                llama_model_free(model);
+                CLOG(ERROR, "GenAI") << "Unsupported model, requires encoder-decoder support.";
+                return {false};
+            }
+
             const llama_vocab * vocab = llama_model_get_vocab(model);
             const char* tmpl = llama_model_chat_template(model, /* name */ nullptr);
 
@@ -82,7 +89,7 @@ namespace splashkit_lib
             llama_model_free(mdl.model);
         }
 
-        std::string format_chat(model& mdl, const std::vector<message>& messages)
+        std::string format_chat(model& mdl, const std::vector<message>& messages, bool add_assistant)
         {
             std::vector<llama_chat_message> llama_formatted;
             std::vector<char> formatted(0);
@@ -94,27 +101,27 @@
                 llama_formatted.push_back({msg.role.c_str(), msg.content.c_str()});
             }
 
-            int new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), true, formatted.data(), formatted.size());
+            int new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), add_assistant, formatted.data(), formatted.size());
             if (new_len > (int)formatted.size())
             {
                 formatted.resize(new_len);
-                new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), true, formatted.data(), formatted.size());
+                new_len = llama_chat_apply_template(mdl.tmpl, llama_formatted.data(), llama_formatted.size(), add_assistant, formatted.data(), formatted.size());
             }
 
             return std::string(formatted.begin(), formatted.end());
         }
 
-        llama_tokens tokenize_string(model& mdl, const std::string& prompt)
+        llama_tokens tokenize_string(model& mdl, const std::string& prompt, bool is_first)
         {
             // get token count
             // note: returns a negative number, the count of tokens it would have returned if the buffer was large enough
-            const int n_prompt = -llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), NULL, 0, true, true);
+            const int n_prompt = -llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), NULL, 0, is_first, true);
 
             // create buffer
             std::vector<llama_token> prompt_tokens(n_prompt);
 
             // recieve the tokens
-            if (llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0)
+            if (llama_tokenize(mdl.vocab, prompt.data(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0)
             {
                 CLOG(ERROR, "GenAI") << "Failed to tokenize the prompt.";
                 return {};
@@ -128,7 +135,7 @@ namespace splashkit_lib
             // Create the context
             llama_context_params ctx_params = llama_context_default_params();
             ctx_params.n_ctx = starting_context.size() + settings.max_length - 1;
-            ctx_params.n_batch = starting_context.size();
+            ctx_params.n_batch = ctx_params.n_ctx;
             ctx_params.no_perf = true;
 
             llama_context * ctx = llama_init_from_model(mdl.model, ctx_params);
@@ -153,60 +160,58 @@ namespace splashkit_lib
             llama_sampler_chain_add(smpl, llama_sampler_init_penalties(64, 0, 0, settings.presence_penalty));
             llama_sampler_chain_add(smpl, llama_sampler_init_dist(settings.seed));
 
-            // Prepare batch and encode starting context
-            llama_batch batch = llama_batch_get_one(starting_context.data(), starting_context.size());
+            // Prepare batch for starting context
+            llama_tokens next_batch = starting_context;
 
-            if (llama_model_has_encoder(mdl.model))
-            {
-                if (llama_encode(ctx, batch))
-                {
-                    llama_free(ctx);
-                    llama_sampler_free(smpl);
-                    CLOG(ERROR, "GenAI") << "Failed to encode prompt.";
-                    return {nullptr};
-                }
-
-                llama_token decoder_start_token_id = llama_model_decoder_start_token(mdl.model);
-                if (decoder_start_token_id == LLAMA_TOKEN_NULL)
-                {
-                    decoder_start_token_id = llama_vocab_bos(mdl.vocab);
-                }
-
-                batch = llama_batch_get_one(&decoder_start_token_id, 1);
-            }
+            // Cache newline token - we use this manually in some spots
+            llama_token newline_token;
+            llama_tokenize(mdl.vocab, "\n", 1, &newline_token, 1, false, true);
 
             return
             {
                 ctx,
                 smpl,
-                batch,
+                next_batch,
                 (int)ctx_params.n_ctx,
                 mdl.vocab,
+                newline_token,
                 0,
-                ""
+                {},
+                false
             };
         }
 
-        int context_step(context& ctx)
+        int context_step(context& ctx, token_result* token)
         {
+            const string THINKING_START = "<think>";
+            const string THINKING_END = "</think>";
+
            if (!ctx.ctx)
                return -1;
 
+            llama_batch batch = llama_batch_get_one(ctx.next_batch.data(), ctx.next_batch.size());
            // Decode current batch with the model
-            if (llama_decode(ctx.ctx, ctx.batch))
+            if (llama_decode(ctx.ctx, batch))
            {
                CLOG(ERROR, "GenAI") << "Failed to process response from language model.";
+                if (token)
+                    token->type = token_result::NONE;
                return -1;
            }
 
-            ctx.n_pos += ctx.batch.n_tokens;
+            ctx.total_context.insert(ctx.total_context.end(), ctx.next_batch.begin(), ctx.next_batch.end());
+            ctx.n_pos += batch.n_tokens;
 
            // Sample next token
            llama_token new_token_id = llama_sampler_sample(ctx.smpl, ctx.ctx, -1);
 
            // Has the model finished its response?
            if (llama_vocab_is_eog(ctx.vocab, new_token_id))
+            {
+                if (token)
+                    token->type = token_result::NONE;
                return 1;
+            }
 
            char buf[128];
            int n = llama_token_to_piece(ctx.vocab, new_token_id, buf, sizeof(buf), 0, true);
@@ -217,19 +222,46 @@ namespace splashkit_lib
            }
 
            std::string s(buf, n);
-            ctx.ctx_string += s;
+
+            if (token)
+            {
+                bool is_meta = s == THINKING_START || s == THINKING_END;
+                token->text = s;
+                if (is_meta)
+                    token->type = token_result::META;
+                else if (ctx.in_thinking)
+                    token->type = token_result::THINKING;
+                else
+                    token->type = token_result::CONTENT;
+            }
+
+            if (s == THINKING_START)
+                ctx.in_thinking = true;
+            else if (s == THINKING_END)
+                ctx.in_thinking = false;
 
            // prepare the next batch with the sampled token
-            ctx.batch = llama_batch_get_one(&new_token_id, 1);
+            ctx.next_batch = {new_token_id};
 
            // Have we reached the end of the context?
            // If so, stop now.
-            if (ctx.n_pos + ctx.batch.n_tokens >= ctx.ctx_size)
+            if (ctx.n_pos + ctx.next_batch.size() >= ctx.ctx_size)
                return 1;
 
            return 0;
        }
 
+        void add_to_context(context& ctx, llama_tokens& message)
+        {
+            ctx.next_batch.insert(ctx.next_batch.end(), message.begin(), message.end());
+        }
+
+        void manual_end_message(context& ctx)
+        {
+            ctx.next_batch.push_back(llama_vocab_eot(ctx.vocab));
+            ctx.next_batch.push_back(ctx.newline_token);
+        }
+
        void delete_context(context& ctx)
        {
            if (ctx.smpl)
@@ -238,5 +270,18 @@ namespace splashkit_lib
            if (ctx.ctx)
                llama_free(ctx.ctx);
        }
+
+        void __print_debug_context(context& ctx)
+        {
+            for (auto& x : ctx.total_context)
+            {
+                char buf[128];
+                int n = llama_token_to_piece(ctx.vocab, x, buf, sizeof(buf), 0, true);
+
+                std::string s(buf, n);
+                std::cout << "|" << s;
+            }
+            std::cout << std::endl;
+        }
    }
}
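The reworked context_step now reports each sampled piece through a token_result, so callers can stream output as it is generated and separate <think> reasoning from the visible reply. The sketch below is not part of this commit (and is not the updated test program); it only shows one way the llamacpp layer could be driven end to end. The message aggregate layout, the inference_settings defaults, and the include path are assumptions.

// Hypothetical streaming driver loop for the llamacpp layer (illustration only).
// Assumes message has {role, content} string fields and inference_settings is
// default-constructible with usable values; both are implied but not shown here.
#include "genai_driver.h"

#include <iostream>
#include <string>
#include <vector>

using namespace splashkit_lib;

void stream_one_reply(llamacpp::model &mdl, const std::string &user_text)
{
    std::vector<message> history = { { "user", user_text } };   // assumed field order

    // Build the templated prompt and tokenize it as the first batch (is_first adds BOS).
    std::string prompt = llamacpp::format_chat(mdl, history, /*add_assistant*/ true);
    llamacpp::llama_tokens tokens = llamacpp::tokenize_string(mdl, prompt, /*is_first*/ true);

    inference_settings settings{};   // assumed defaults
    llamacpp::context ctx = llamacpp::start_context(mdl, tokens, settings);

    // context_step returns 0 to keep going, 1 when finished, -1 on error.
    llamacpp::token_result tok;
    while (llamacpp::context_step(ctx, &tok) == 0)
    {
        if (tok.type == llamacpp::token_result::CONTENT)
            std::cout << tok.text << std::flush;   // visible reply text
        // THINKING and META tokens could be logged or discarded instead.
    }
    std::cout << std::endl;

    llamacpp::delete_context(ctx);
}

Note that n_batch is now sized to n_ctx rather than to the initial prompt alone, presumably because add_to_context can enqueue an entire follow-up message as a single batch later in the conversation.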

coresdk/src/backend/genai_driver.h

Lines changed: 41 additions & 6 deletions

@@ -18,6 +18,8 @@ namespace splashkit_lib
 
    namespace llamacpp
    {
+        typedef std::vector<llama_token> llama_tokens;
+
        struct model
        {
            bool valid;
@@ -47,29 +49,62 @@
        {
            llama_context* ctx;
            llama_sampler* smpl;
-            llama_batch batch;
+            llama_tokens next_batch;
            int ctx_size = 0;
 
            const llama_vocab* vocab;
+            llama_token newline_token;
 
            int n_pos;
-            std::string ctx_string;
+            llama_tokens total_context;
+
+            bool in_thinking = false;
        };
 
-        typedef std::vector<llama_token> llama_tokens;
+        struct token_result
+        {
+            enum token_type {
+                NONE,
+                CONTENT,
+                THINKING,
+                META
+            };
+            string text;
+            token_type type;
+        };
 
        void init();
 
        model create_model(std::string path);
        void delete_model(model mdl);
 
-        std::string format_chat(model& mdl, const std::vector<message>& messages);
-        llama_tokens tokenize_string(model& mdl, const std::string& prompt);
+        std::string format_chat(model& mdl, const std::vector<message>& messages, bool add_assistant);
+        llama_tokens tokenize_string(model& mdl, const std::string& prompt, bool is_first);
 
        context start_context(model& mdl, llama_tokens& starting_context, inference_settings settings);
-        int context_step(context& ctx);
        void delete_context(context& ctx);
+
+        int context_step(context& ctx, token_result* token);
+        void add_to_context(context& ctx, llama_tokens& message);
+        void manual_end_message(context& ctx);
+
+        void __print_debug_context(context& ctx);
    }
+
+    struct sk_conversation
+    {
+        pointer_identifier id;
+
+        llamacpp::model model;
+        llamacpp::context context;
+
+        bool was_generating;
+        bool is_generating;
+
+        string prompt_append;
+
+        llamacpp::token_result next_token;
+    };
}
 
#endif /* defined(graphics_driver) */
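The new sk_conversation struct bundles a model, a live context, and streaming state for the frontend. The frontend code itself is not in this excerpt, so the following is only a sketch of how a follow-up user message might be queued onto an existing conversation using the new add_to_context and manual_end_message helpers; the role formatting, the meaning of prompt_append, and the use of is_generating are assumptions.

// Hypothetical helper (illustration only): queue a follow-up user message on a
// conversation whose context is already running. Field usage is assumed, not
// taken from the commit's frontend code.
#include "genai_driver.h"

#include <string>

namespace splashkit_lib
{
    void queue_user_message(sk_conversation *conv, const std::string &text)
    {
        // Later messages are tokenized without a BOS token (is_first = false).
        // Any role/template formatting (e.g. via prompt_append) is omitted here.
        llamacpp::llama_tokens msg_tokens =
            llamacpp::tokenize_string(conv->model, text, /*is_first*/ false);

        // Append the tokens to the pending batch, then close the message with
        // an end-of-turn token plus the cached newline token.
        llamacpp::add_to_context(conv->context, msg_tokens);
        llamacpp::manual_end_message(conv->context);

        // Subsequent context_step calls would then stream the model's reply
        // into conv->next_token.
        conv->is_generating = true;
    }
}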

0 commit comments
