Skip to content

Commit 059de75

Browse files
author
Chris Warren-Smith
committed
LLM: plugin module - generation iterator
1 parent fc00a83 commit 059de75

File tree

7 files changed

+185
-126
lines changed

7 files changed

+185
-126
lines changed

include/param.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,14 +589,14 @@ void v_create_func(var_p_t map, const char *name, method cb) {
589589
var_p_t v_func = map_add_var(map, name, 0);
590590
v_func->type = V_FUNC;
591591
v_func->v.fn.cb = cb;
592-
v_func->v.fn.mcb = NULL;
592+
v_func->v.fn.mcb = nullptr;
593593
v_func->v.fn.id = 0;
594594
}
595595

596596
void v_create_callback(var_p_t map, const char *name, callback cb) {
597597
var_p_t v_func = map_add_var(map, name, 0);
598598
v_func->type = V_FUNC;
599-
v_func->v.fn.cb = NULL;
599+
v_func->v.fn.cb = nullptr;
600600
v_func->v.fn.mcb = cb;
601601
v_func->v.fn.id = 0;
602602
}

include/var.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ typedef struct var_s {
8080

8181
// associative array/map
8282
struct {
83-
// pointer the map structure
83+
// pointer to the map structure
8484
void *map;
8585

8686
uint32_t count;
@@ -132,7 +132,7 @@ typedef struct var_s {
132132
// non-zero if constant
133133
uint8_t const_flag;
134134

135-
// whether help in pooled memory
135+
// whether held in pooled memory
136136
uint8_t pooled;
137137
} var_t;
138138

@@ -154,7 +154,7 @@ var_t *v_new(void);
154154
*
155155
* @return a newly created var_t array of the given size
156156
*/
157-
void v_new_array(var_t *var, unsigned size);
157+
void v_new_array(var_t *var, uint32_t size);
158158

159159
/**
160160
* @ingroup var

llama/llama-sb.cpp

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
#include "llama.h"
1111
#include "llama-sb.h"
1212

13+
LlamaIter::LlamaIter() :
14+
_llama(nullptr),
15+
_tokens_sec(0),
16+
_has_next(false) {
17+
}
18+
1319
Llama::Llama() :
1420
_model(nullptr),
1521
_ctx(nullptr),
@@ -43,18 +49,6 @@ Llama::~Llama() {
4349
}
4450
}
4551

46-
void Llama::append_response(const string &response) {
47-
_chat_prompt += response;
48-
_chat_prompt += "\n";
49-
}
50-
51-
const string Llama::build_chat_prompt(const string &user_msg) {
52-
_chat_prompt += "User: ";
53-
_chat_prompt += user_msg;
54-
_chat_prompt += "\nAssistant: ";
55-
return _chat_prompt;
56-
}
57-
5852
bool Llama::construct(string model_path, int n_ctx, int n_batch) {
5953
ggml_backend_load_all();
6054

@@ -107,40 +101,36 @@ void Llama::configure_sampler() {
107101
}
108102

109103
void Llama::reset() {
110-
// llama_kv_cache_clear(it->second->ctx);
104+
llama_sampler_reset(_sampler);
111105
_chat_prompt.clear();
112106
}
113107

114-
string Llama::generate(const string &prompt) {
115-
string out = prompt;
116-
117-
// ---- tokenize prompt ----
108+
bool Llama::generate(LlamaIter &iter, const string &prompt) {
118109
int n_prompt = -llama_tokenize(_vocab, prompt.c_str(), prompt.size(),
119110
nullptr, 0, true, true);
120111

121112
if (n_prompt <= 0) {
122113
_last_error = "failed to tokenize prompt";
123-
return out;
114+
return false;
124115
}
125116

126117
std::vector<llama_token> prompt_tokens(n_prompt);
127118
if (llama_tokenize(_vocab, prompt.c_str(), prompt.size(),
128119
prompt_tokens.data(), n_prompt, true, true) < 0) {
129120
_last_error = "failed to tokenize prompt";
130-
return out;
121+
return false;
131122
}
132123

133-
// ---- sampler ----
134124
configure_sampler();
135125

136-
// ---- decode prompt ----
126+
// decode prompt
137127
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), n_prompt);
138128
if (llama_decode(_ctx, batch)) {
139129
_last_error = "failed to eval prompt";
140-
return out;
130+
return false;
141131
}
142132

143-
// ---- handle encoder models ----
133+
// handle encoder models
144134
if (llama_model_has_encoder(_model)) {
145135
llama_token decoder_start_token_id = llama_model_decoder_start_token(_model);
146136
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -149,11 +139,20 @@ string Llama::generate(const string &prompt) {
149139
batch = llama_batch_get_one(&decoder_start_token_id, 1);
150140
if (llama_decode(_ctx, batch)) {
151141
_last_error = "failed to eval decoder start token";
152-
return out;
142+
return false;
153143
}
154144
}
155145

156-
// ---- generation loop ----
146+
iter._llama = this;
147+
iter._batch = batch;
148+
iter._has_next = true;
149+
iter._tokens_sec = 0;
150+
return true;
151+
}
152+
153+
string Llama::next(LlamaIter &iter) {
154+
string out;
155+
157156
std::vector<llama_token> decoded;
158157
decoded.reserve(_max_tokens);
159158

@@ -166,22 +165,23 @@ string Llama::generate(const string &prompt) {
166165

167166
// end-of-generation check
168167
if (llama_vocab_is_eog(_vocab, tok)) {
168+
iter._has_next = false;
169169
break;
170170
}
171171

172172
// append token to decoded list
173173
decoded.push_back(tok);
174174
++generated;
175175

176-
// ---- decode the token immediately ----
176+
// decode the token
177177
llama_batch batch = llama_batch_get_one(&tok, 1);
178178
if (llama_decode(_ctx, batch)) {
179179
_last_error = "failed to eval token during generation";
180180
break;
181181
}
182182
}
183183

184-
// ---- detokenize sequentially ----
184+
// detokenize sequentially
185185
if (!decoded.empty()) {
186186
char buf[512];
187187
for (llama_token tok : decoded) {
@@ -195,14 +195,10 @@ string Llama::generate(const string &prompt) {
195195
}
196196
}
197197

198-
// ---- timing ----
198+
// timing
199199
auto t_end = std::chrono::high_resolution_clock::now();
200200
double secs = std::chrono::duration<double>(t_end - t_start).count();
201-
double tokps = secs > 0 ? generated / secs : 0;
202-
203-
fprintf(stderr,
204-
"[tok/s=%.2f] generated=%d time=%.3fs\n",
205-
tokps, generated, secs);
201+
iter._tokens_sec = secs > 0 ? generated / secs : 0;
206202

207203
return out;
208204
}

llama/llama-sb.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212

1313
using namespace std;
1414

15+
struct Llama;
16+
17+
struct LlamaIter {
18+
explicit LlamaIter();
19+
~LlamaIter() {}
20+
21+
Llama *_llama;
22+
llama_batch _batch;
23+
float _tokens_sec;
24+
bool _has_next;
25+
};
26+
1527
struct Llama {
1628
explicit Llama();
1729
~Llama();
@@ -20,10 +32,10 @@ struct Llama {
2032
bool construct(string model_path, int n_ctx, int n_batch);
2133

2234
// generation
23-
string generate(const string &prompt);
35+
bool generate(LlamaIter &iter, const string &prompt);
36+
string next(LlamaIter &iter);
2437

2538
// generation parameters
26-
2739
void set_penalty_last_n(int32_t penalty_last_n) { _penalty_last_n = penalty_last_n; }
2840
void set_penalty_repeat(float penalty_repeat) { _penalty_repeat = penalty_repeat; }
2941
void set_max_tokens(int max_tokens) { _max_tokens = max_tokens; }
@@ -32,12 +44,6 @@ struct Llama {
3244
void set_top_k(int top_k) { _top_k = top_k; }
3345
void set_top_p(float top_p) { _top_p = top_p; }
3446

35-
// messages
36-
void append_response(const string &response);
37-
void append_user_message(const string &user_msg);
38-
const string& get_chat_history() const;
39-
const string build_chat_prompt(const string &user_msg);
40-
4147
// error handling
4248
const char *last_error() { return _last_error.c_str(); }
4349
void set_log_level(int level) { _log_level = level; }

0 commit comments

Comments (0)