1010#include " llama.h"
1111#include " llama-sb.h"
1212
13+ LlamaIter::LlamaIter () :
14+ _llama(nullptr ),
15+ _tokens_sec(0 ),
16+ _has_next(false ) {
17+ }
18+
1319Llama::Llama () :
1420 _model(nullptr ),
1521 _ctx(nullptr ),
@@ -43,18 +49,6 @@ Llama::~Llama() {
4349 }
4450}
4551
46- void Llama::append_response (const string &response) {
47- _chat_prompt += response;
48- _chat_prompt += " \n " ;
49- }
50-
51- const string Llama::build_chat_prompt (const string &user_msg) {
52- _chat_prompt += " User: " ;
53- _chat_prompt += user_msg;
54- _chat_prompt += " \n Assistant: " ;
55- return _chat_prompt;
56- }
57-
5852bool Llama::construct (string model_path, int n_ctx, int n_batch) {
5953 ggml_backend_load_all ();
6054
@@ -107,40 +101,36 @@ void Llama::configure_sampler() {
107101}
108102
109103void Llama::reset () {
110- // llama_kv_cache_clear(it->second->ctx );
104+ llama_sampler_reset (_sampler );
111105 _chat_prompt.clear ();
112106}
113107
114- string Llama::generate (const string &prompt) {
115- string out = prompt;
116-
117- // ---- tokenize prompt ----
108+ bool Llama::generate (LlamaIter &iter, const string &prompt) {
118109 int n_prompt = -llama_tokenize (_vocab, prompt.c_str (), prompt.size (),
119110 nullptr , 0 , true , true );
120111
121112 if (n_prompt <= 0 ) {
122113 _last_error = " failed to tokenize prompt" ;
123- return out ;
114+ return false ;
124115 }
125116
126117 std::vector<llama_token> prompt_tokens (n_prompt);
127118 if (llama_tokenize (_vocab, prompt.c_str (), prompt.size (),
128119 prompt_tokens.data (), n_prompt, true , true ) < 0 ) {
129120 _last_error = " failed to tokenize prompt" ;
130- return out ;
121+ return false ;
131122 }
132123
133- // ---- sampler ----
134124 configure_sampler ();
135125
136- // ---- decode prompt ----
126+ // decode prompt
137127 llama_batch batch = llama_batch_get_one (prompt_tokens.data (), n_prompt);
138128 if (llama_decode (_ctx, batch)) {
139129 _last_error = " failed to eval prompt" ;
140- return out ;
130+ return false ;
141131 }
142132
143- // ---- handle encoder models ----
133+ // handle encoder models
144134 if (llama_model_has_encoder (_model)) {
145135 llama_token decoder_start_token_id = llama_model_decoder_start_token (_model);
146136 if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -149,11 +139,20 @@ string Llama::generate(const string &prompt) {
149139 batch = llama_batch_get_one (&decoder_start_token_id, 1 );
150140 if (llama_decode (_ctx, batch)) {
151141 _last_error = " failed to eval decoder start token" ;
152- return out ;
142+ return false ;
153143 }
154144 }
155145
156- // ---- generation loop ----
146+ iter._llama = this ;
147+ iter._batch = batch;
148+ iter._has_next = true ;
149+ iter._tokens_sec = 0 ;
150+ return true ;
151+ }
152+
153+ string Llama::next (LlamaIter &iter) {
154+ string out;
155+
157156 std::vector<llama_token> decoded;
158157 decoded.reserve (_max_tokens);
159158
@@ -166,22 +165,23 @@ string Llama::generate(const string &prompt) {
166165
167166 // end-of-generation check
168167 if (llama_vocab_is_eog (_vocab, tok)) {
168+ iter._has_next = false ;
169169 break ;
170170 }
171171
172172 // append token to decoded list
173173 decoded.push_back (tok);
174174 ++generated;
175175
176- // ---- decode the token immediately ----
176+ // decode the token
177177 llama_batch batch = llama_batch_get_one (&tok, 1 );
178178 if (llama_decode (_ctx, batch)) {
179179 _last_error = " failed to eval token during generation" ;
180180 break ;
181181 }
182182 }
183183
184- // ---- detokenize sequentially ----
184+ // detokenize sequentially
185185 if (!decoded.empty ()) {
186186 char buf[512 ];
187187 for (llama_token tok : decoded) {
@@ -195,14 +195,10 @@ string Llama::generate(const string &prompt) {
195195 }
196196 }
197197
198- // ---- timing ----
198+ // timing
199199 auto t_end = std::chrono::high_resolution_clock::now ();
200200 double secs = std::chrono::duration<double >(t_end - t_start).count ();
201- double tokps = secs > 0 ? generated / secs : 0 ;
202-
203- fprintf (stderr,
204- " [tok/s=%.2f] generated=%d time=%.3fs\n " ,
205- tokps, generated, secs);
201+ iter._tokens_sec = secs > 0 ? generated / secs : 0 ;
206202
207203 return out;
208204}
0 commit comments