Skip to main content

rusty_genius_cortex/
backend.rs

1use crate::Engine;
2use anyhow::{anyhow, Result};
3use async_std::task::{self, sleep};
4use async_trait::async_trait;
5use futures::channel::mpsc;
6use futures::sink::SinkExt;
7use rusty_genius_core::protocol::{InferenceEvent, ThoughtEvent};
8use std::time::Duration;
9
10#[cfg(feature = "real-engine")]
11use llama_cpp_2::context::params::LlamaContextParams;
12#[cfg(feature = "real-engine")]
13use llama_cpp_2::llama_backend::LlamaBackend;
14#[cfg(feature = "real-engine")]
15use llama_cpp_2::llama_batch::LlamaBatch;
16#[cfg(feature = "real-engine")]
17use llama_cpp_2::model::params::LlamaModelParams;
18#[cfg(feature = "real-engine")]
19use llama_cpp_2::model::{AddBos, LlamaModel};
20#[cfg(feature = "real-engine")]
21use std::num::NonZeroU32;
22#[cfg(feature = "real-engine")]
23use std::sync::Arc;
24
25// --- Pinky (Stub) ---
26
/// Lightweight stub [`Engine`] that fakes model loading and streams a canned reply.
///
/// It needs no model files: `load_model` only waits briefly and flips an internal flag,
/// and `infer` refuses to run until that flag is set.
#[derive(Debug, Default)]
pub struct Pinky {
    /// Set by `load_model`, cleared by `unload_model`, checked by `infer`.
    model_loaded: bool,
}

impl Pinky {
    /// Creates a stub engine with no model loaded (same as [`Pinky::default`]).
    pub fn new() -> Self {
        Self::default()
    }
}
38
39#[async_trait]
40impl Engine for Pinky {
41    async fn load_model(&mut self, _model_path: &str) -> Result<()> {
42        // Simulate loading time
43        sleep(Duration::from_millis(100)).await;
44        self.model_loaded = true;
45        Ok(())
46    }
47
48    async fn unload_model(&mut self) -> Result<()> {
49        self.model_loaded = false;
50        Ok(())
51    }
52
53    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
54        if !self.model_loaded {
55            return Err(anyhow!("Pinky Error: No model loaded!"));
56        }
57
58        let (mut tx, rx) = mpsc::channel(100);
59        let prompt = prompt.to_string();
60
61        task::spawn(async move {
62            let _ = tx.send(Ok(InferenceEvent::ProcessStart)).await;
63            task::sleep(Duration::from_millis(50)).await;
64
65            // Emit a "thought"
66            let _ = tx
67                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Start)))
68                .await;
69            let _ = tx
70                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Delta(
71                    "Narf!".to_string(),
72                ))))
73                .await;
74            task::sleep(Duration::from_millis(50)).await;
75            let _ = tx
76                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Stop)))
77                .await;
78
79            // Emit content (echo prompt mostly)
80            let _ = tx
81                .send(Ok(InferenceEvent::Content(format!(
82                    "Pinky says: {}",
83                    prompt
84                ))))
85                .await;
86
87            let _ = tx.send(Ok(InferenceEvent::Complete)).await;
88        });
89
90        Ok(rx)
91    }
92}
93
94// --- Brain (Real) ---
95
/// Real inference engine backed by `llama_cpp_2`.
///
/// Holds the llama backend handle and, once `load_model` has succeeded, the loaded model.
/// Both are reference-counted, so in-flight inference tasks can keep using them after the
/// engine unloads or replaces the model.
#[cfg(feature = "real-engine")]
pub struct Brain {
    /// Currently loaded model, if any.
    model: Option<Arc<LlamaModel>>,
    /// Process-wide llama backend, shared by every `Brain`.
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "real-engine")]
impl Brain {
    /// Creates an engine with no model loaded.
    ///
    /// `LlamaBackend::init` refuses to initialise a second backend while one is alive, so
    /// the backend is created once, on first use, and shared by every `Brain` constructed
    /// afterwards. Without this, building a second `Brain` would panic.
    ///
    /// # Panics
    ///
    /// Panics if the llama backend fails to initialise on first use.
    pub fn new() -> Self {
        static BACKEND: std::sync::OnceLock<Arc<LlamaBackend>> = std::sync::OnceLock::new();
        let backend = Arc::clone(BACKEND.get_or_init(|| {
            Arc::new(LlamaBackend::init().expect("Failed to init llama backend"))
        }));
        Self {
            model: None,
            backend,
        }
    }
}

#[cfg(feature = "real-engine")]
impl Default for Brain {
    /// Equivalent to [`Brain::new`].
    fn default() -> Self {
        Self::new()
    }
}
111
/// One classified piece of streamed model output.
#[cfg_attr(not(feature = "real-engine"), allow(dead_code))]
#[derive(Debug, Clone, PartialEq, Eq)]
enum ThinkSegment {
    /// A `<think>` tag opened a reasoning block.
    ThoughtStart,
    /// Text inside a reasoning block.
    ThoughtDelta(String),
    /// A `</think>` tag (or the end of the stream) closed the reasoning block.
    ThoughtStop,
    /// Text outside any reasoning block.
    Content(String),
}

/// Incrementally splits streamed text into reasoning (`<think>…</think>`) and content.
///
/// A tag may be split across tokens (e.g. `"</"` + `"think>"`). Trailing text that could
/// still be the start of the next expected tag is therefore held back until more text
/// arrives or [`ThinkSplitter::finish`] is called.
#[cfg_attr(not(feature = "real-engine"), allow(dead_code))]
#[derive(Debug, Default)]
struct ThinkSplitter {
    /// Received text that has not been emitted yet.
    pending: String,
    /// Whether the stream is currently inside a `<think>` block.
    in_think: bool,
}

#[cfg_attr(not(feature = "real-engine"), allow(dead_code))]
impl ThinkSplitter {
    const OPEN: &'static str = "<think>";
    const CLOSE: &'static str = "</think>";

    /// Appends `piece` and returns every segment that is now unambiguous.
    fn push(&mut self, piece: &str) -> Vec<ThinkSegment> {
        self.pending.push_str(piece);
        let mut out = Vec::new();

        while let Some(idx) = self.pending.find(self.expected_tag()) {
            let tag_len = self.expected_tag().len();
            let before = self.pending[..idx].to_string();
            self.pending.replace_range(..idx + tag_len, "");
            // Text before the tag belongs to the mode we are leaving.
            self.emit_text(before, &mut out);
            self.in_think = !self.in_think;
            out.push(if self.in_think {
                ThinkSegment::ThoughtStart
            } else {
                ThinkSegment::ThoughtStop
            });
        }

        // No complete tag left: emit everything except a possible partial tag at the end.
        // Tags are ASCII, so the held-back suffix always starts on a char boundary.
        let held = Self::partial_tag_len(&self.pending, self.expected_tag());
        let ready_len = self.pending.len() - held;
        let ready: String = self.pending.drain(..ready_len).collect();
        self.emit_text(ready, &mut out);
        out
    }

    /// Flushes held-back text and closes a reasoning block the model never closed.
    fn finish(&mut self) -> Vec<ThinkSegment> {
        let mut out = Vec::new();
        let rest = std::mem::take(&mut self.pending);
        self.emit_text(rest, &mut out);
        if self.in_think {
            self.in_think = false;
            out.push(ThinkSegment::ThoughtStop);
        }
        out
    }

    /// The tag that would switch the current mode.
    fn expected_tag(&self) -> &'static str {
        if self.in_think {
            Self::CLOSE
        } else {
            Self::OPEN
        }
    }

    /// Pushes non-empty `text` as a thought delta or as content, depending on the mode.
    fn emit_text(&self, text: String, out: &mut Vec<ThinkSegment>) {
        if text.is_empty() {
            return;
        }
        out.push(if self.in_think {
            ThinkSegment::ThoughtDelta(text)
        } else {
            ThinkSegment::Content(text)
        });
    }

    /// Length of the longest proper prefix of the ASCII `tag` that `text` ends with (0 if none).
    fn partial_tag_len(text: &str, tag: &str) -> usize {
        (1..tag.len())
            .rev()
            .find(|&n| text.ends_with(&tag[..n]))
            .unwrap_or(0)
    }
}

#[cfg(feature = "real-engine")]
impl ThinkSegment {
    /// Maps a parsed segment onto the protocol event streamed to clients.
    fn into_event(self) -> InferenceEvent {
        match self {
            ThinkSegment::ThoughtStart => InferenceEvent::Thought(ThoughtEvent::Start),
            ThinkSegment::ThoughtDelta(text) => {
                InferenceEvent::Thought(ThoughtEvent::Delta(text))
            }
            ThinkSegment::ThoughtStop => InferenceEvent::Thought(ThoughtEvent::Stop),
            ThinkSegment::Content(text) => InferenceEvent::Content(text),
        }
    }
}

#[cfg(feature = "real-engine")]
#[async_trait]
impl Engine for Brain {
    /// Loads the model file at `model_path` with default llama model parameters,
    /// replacing any previously loaded model.
    ///
    /// # Errors
    ///
    /// Returns an error if `llama_cpp_2` fails to load the file.
    async fn load_model(&mut self, model_path: &str) -> Result<()> {
        let backend = Arc::clone(&self.backend);
        let path = model_path.to_string();
        // `load_from_file` is synchronous file I/O; run it on the blocking pool so the
        // async executor thread is not stalled while the model is read.
        let model = task::spawn_blocking(move || {
            let params = LlamaModelParams::default();
            LlamaModel::load_from_file(&backend, path.as_str(), &params)
                .map_err(|e| anyhow!("Failed to load model from {}: {}", path, e))
        })
        .await?;
        self.model = Some(Arc::new(model));
        Ok(())
    }

    /// Drops this engine's handle to the model.
    ///
    /// Inference tasks that are already running hold their own `Arc` clone and finish
    /// normally.
    async fn unload_model(&mut self) -> Result<()> {
        self.model = None;
        Ok(())
    }

    /// Runs greedy generation for `prompt` on a blocking thread and streams the events.
    ///
    /// The stream starts with `ProcessStart`. Generated text is then split on
    /// `<think>`/`</think>` into `Thought(Start | Delta | Stop)` and `Content` events, and
    /// the stream ends with `Complete`. Generation stops at EOS, after `MAX_NEW_TOKENS`
    /// tokens, or when the context window is full, whichever comes first. It also stops
    /// as soon as the receiver is dropped.
    ///
    /// Failures are sent as `Err` items:
    /// - errors before generation starts (context creation, tokenization, prompt length,
    ///   prompt decode) end the stream without `Complete`;
    /// - errors during generation are followed by `Complete`.
    ///
    /// # Errors
    ///
    /// Returns an error immediately if no model is loaded.
    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
        // Context window, in tokens, allocated per request (also the batch capacity).
        const N_CTX: u32 = 2048;
        // Hard cap on generated tokens per request.
        const MAX_NEW_TOKENS: usize = 512;

        let model = self
            .model
            .as_ref()
            .ok_or_else(|| anyhow!("No model loaded"))?
            .clone();
        // Share the backend handle; it is never re-initialised per request.
        let backend = Arc::clone(&self.backend);

        let prompt_str = prompt.to_string();
        let (mut tx, rx) = mpsc::channel(100);

        task::spawn_blocking(move || {
            // Blocking send. `false` means the receiver is gone and the work can stop.
            let mut emit = |event: Result<InferenceEvent>| -> bool {
                futures::executor::block_on(tx.send(event)).is_ok()
            };

            if !emit(Ok(InferenceEvent::ProcessStart)) {
                return;
            }

            // `NonZeroU32::new` yields `Some` because N_CTX is non-zero.
            let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(N_CTX));
            let mut ctx = match model.new_context(&backend, ctx_params) {
                Ok(c) => c,
                Err(e) => {
                    emit(Err(anyhow!("Context creation failed: {}", e)));
                    return;
                }
            };

            let tokens = match model.str_to_token(&prompt_str, AddBos::Always) {
                Ok(t) => t,
                Err(e) => {
                    emit(Err(anyhow!("Tokenize failed: {}", e)));
                    return;
                }
            };

            // The prompt must fit in the batch/context and leave room for a generated token.
            let n_ctx = N_CTX as usize;
            let n_prompt = tokens.len();
            if n_prompt == 0 || n_prompt >= n_ctx {
                emit(Err(anyhow!(
                    "Prompt is {} tokens; it must be between 1 and {} tokens",
                    n_prompt,
                    n_ctx - 1
                )));
                return;
            }

            let mut batch = LlamaBatch::new(n_ctx, 1);
            let last_index = n_prompt - 1;
            for (pos, token) in tokens.iter().enumerate() {
                // Only the final prompt token needs logits: it predicts the first new token.
                // `pos < N_CTX`, so the cast to i32 cannot truncate.
                if let Err(e) = batch.add(*token, pos as i32, &[0], pos == last_index) {
                    emit(Err(anyhow!("Failed to queue prompt token: {}", e)));
                    return;
                }
            }

            if let Err(e) = ctx.decode(&mut batch) {
                emit(Err(anyhow!("Decode prompt failed: {}", e)));
                return;
            }

            let mut splitter = ThinkSplitter::default();
            let mut n_cur = n_prompt as i32;
            // Never decode past the context window, even if MAX_NEW_TOKENS would allow it.
            let budget = MAX_NEW_TOKENS.min(n_ctx - n_prompt);

            for _ in 0..budget {
                let next_token = match ctx.sample_token_greedy(batch.n_tokens() - 1) {
                    Some(t) => t,
                    None => {
                        emit(Err(anyhow!("Failed to sample token")));
                        break;
                    }
                };

                if next_token == model.token_eos() {
                    break;
                }

                // NOTE(review): "??" stands in for pieces `token_to_str` cannot render —
                // presumably tokens that are not valid UTF-8 on their own; confirm.
                let piece = match model.token_to_str(next_token) {
                    Ok(s) => s.to_string(),
                    Err(_) => "??".to_string(),
                };

                for segment in splitter.push(&piece) {
                    if !emit(Ok(segment.into_event())) {
                        // Nobody is listening any more; stop generating.
                        return;
                    }
                }

                // Feed the sampled token back so the next iteration predicts from it.
                batch.clear();
                if let Err(e) = batch.add(next_token, n_cur, &[0], true) {
                    emit(Err(anyhow!("Failed to queue token: {}", e)));
                    break;
                }
                n_cur += 1;

                if let Err(e) = ctx.decode(&mut batch) {
                    emit(Err(anyhow!("Decode failed: {}", e)));
                    break;
                }
            }

            // Flush text held back as a possible partial tag and close an unterminated thought.
            for segment in splitter.finish() {
                if !emit(Ok(segment.into_event())) {
                    return;
                }
            }
            emit(Ok(InferenceEvent::Complete));
        });

        Ok(rx)
    }
}