
rusty_genius_cortex/backend.rs
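
//! Inference backends: `Pinky` is a stub engine that fakes the event stream,
//! while `Brain` (enabled by the `real-engine` feature) runs a local model
//! through `llama_cpp_2`.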

use crate::Engine;
use anyhow::{anyhow, Result};
use async_std::task::{self, sleep};
use async_trait::async_trait;
use futures::channel::mpsc;
use futures::sink::SinkExt;
use rusty_genius_core::protocol::{InferenceEvent, ThoughtEvent};
use std::time::Duration;

#[cfg(feature = "real-engine")]
use llama_cpp_2::context::params::LlamaContextParams;
#[cfg(feature = "real-engine")]
use llama_cpp_2::llama_backend::LlamaBackend;
#[cfg(feature = "real-engine")]
use llama_cpp_2::llama_batch::LlamaBatch;
#[cfg(feature = "real-engine")]
use llama_cpp_2::model::params::LlamaModelParams;
#[cfg(feature = "real-engine")]
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
#[cfg(feature = "real-engine")]
use llama_cpp_2::sampling::LlamaSampler;
#[cfg(feature = "real-engine")]
use std::num::NonZeroU32;
#[cfg(feature = "real-engine")]
use std::sync::Arc;

// --- Pinky (Stub) ---

pub struct Pinky {
    model_loaded: bool,
}

impl Pinky {
    pub fn new() -> Self {
        Self {
            model_loaded: false,
        }
    }
}

#[async_trait]
impl Engine for Pinky {
    async fn load_model(&mut self, _model_path: &str) -> Result<()> {
        // Simulate loading time
        sleep(Duration::from_millis(100)).await;
        self.model_loaded = true;
        Ok(())
    }

    async fn unload_model(&mut self) -> Result<()> {
        self.model_loaded = false;
        Ok(())
    }

    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
        if !self.model_loaded {
            return Err(anyhow!("Pinky Error: No model loaded!"));
        }

        let (mut tx, rx) = mpsc::channel(100);
        let prompt = prompt.to_string();

        task::spawn(async move {
            let _ = tx.send(Ok(InferenceEvent::ProcessStart)).await;
            task::sleep(Duration::from_millis(50)).await;

            // Emit a "thought"
            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Start)))
                .await;
            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Delta(
                    "Narf!".to_string(),
                ))))
                .await;
            task::sleep(Duration::from_millis(50)).await;
            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Stop)))
                .await;

            // Emit content (mostly an echo of the prompt)
            let _ = tx
                .send(Ok(InferenceEvent::Content(format!(
                    "Pinky says: {}",
                    prompt
                ))))
                .await;

            let _ = tx.send(Ok(InferenceEvent::Complete)).await;
        });

        Ok(rx)
    }
}

// --- Brain (Real) ---

#[cfg(feature = "real-engine")]
pub struct Brain {
    model: Option<Arc<LlamaModel>>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "real-engine")]
impl Brain {
    pub fn new() -> Self {
        Self {
            model: None,
            backend: Arc::new(LlamaBackend::init().expect("Failed to init llama backend")),
        }
    }
}

#[cfg(feature = "real-engine")]
#[async_trait]
impl Engine for Brain {
    async fn load_model(&mut self, model_path: &str) -> Result<()> {
        // Load model
        let params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&self.backend, model_path, &params)
            .map_err(|e| anyhow!("Failed to load model from {}: {}", model_path, e))?;
        self.model = Some(Arc::new(model));
        Ok(())
    }

    async fn unload_model(&mut self) -> Result<()> {
        self.model = None;
        Ok(())
    }

    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
        let model = self
            .model
            .as_ref()
            .ok_or_else(|| anyhow!("No model loaded"))?
            .clone();

        // Share the backend reference
        let backend = self.backend.clone();

        let prompt_str = prompt.to_string();
        let (mut tx, rx) = mpsc::channel(100);

        task::spawn_blocking(move || {
            // Send ProcessStart
            let _ = futures::executor::block_on(tx.send(Ok(InferenceEvent::ProcessStart)));

            // Use the shared backend (no re-init)
            let backend_ref = &backend;

            // Create context
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut ctx = match model.new_context(backend_ref, ctx_params) {
                Ok(c) => c,
                Err(e) => {
                    let _ = futures::executor::block_on(
                        tx.send(Err(anyhow!("Context creation failed: {}", e))),
                    );
                    return;
                }
            };

            // Tokenize
            let tokens_list = match model.str_to_token(&prompt_str, AddBos::Always) {
                Ok(t) => t,
                Err(e) => {
                    let _ = futures::executor::block_on(
                        tx.send(Err(anyhow!("Tokenize failed: {}", e))),
                    );
                    return;
                }
            };

            // Prepare batch for the prompt
            let n_tokens = tokens_list.len();
            let mut batch = LlamaBatch::new(2048, 1); // Ensure batch size can handle context

            // Load prompt into batch
            let last_index = n_tokens as i32 - 1;
            for (i, token) in tokens_list.iter().enumerate() {
                // add(token, pos, &[seq_id], logits)
                // We only need logits for the very last token to predict the next one
                let _ = batch.add(*token, i as i32, &[0], i as i32 == last_index);
            }

            // Decode the prompt
            if let Err(e) = ctx.decode(&mut batch) {
                let _ = futures::executor::block_on(
                    tx.send(Err(anyhow!("Decode prompt failed: {}", e))),
                );
                return;
            }

            // Generation loop
            let mut n_cur = n_tokens as i32;
            let mut n_decode = 0; // generated tokens count
            let max_tokens = 512; // Hard limit for safety

            let mut in_think_block = false;
            let mut token_str_buffer = String::new();

            loop {
                // Sample the next token (greedy sampling)
                let mut sampler = LlamaSampler::greedy();
                let next_token = sampler.sample(&ctx, batch.n_tokens() - 1);

                // Decode token to string
                let token_str = match model.token_to_str(next_token, Special::Plaintext) {
                    Ok(s) => s.to_string(),
                    Err(_) => "??".to_string(),
                };

                // Stop on EOS or once the safety limit is reached
                if next_token == model.token_eos() || n_decode >= max_tokens {
                    break;
                }
                n_decode += 1;

                // Parse logic for <think> tags (simple stream parsing)
                token_str_buffer.push_str(&token_str);

                // If we are NOT in a think block, check whether one is starting
                if !in_think_block {
                    if token_str_buffer.contains("<think>") {
                        in_think_block = true;
                        // Emit Start Thought event
                        let _ = futures::executor::block_on(
                            tx.send(Ok(InferenceEvent::Thought(ThoughtEvent::Start))),
                        );

                        // Any content before <think> is dropped for simplicity;
                        // we only consume the tag and keep the remainder.
                        token_str_buffer = token_str_buffer.replace("<think>", "");
                    }
                }

                // If we ARE in a think block
                if in_think_block {
                    if token_str_buffer.contains("</think>") {
                        in_think_block = false;
                        // Flush the remaining thought text, then emit Stop
                        let parts: Vec<&str> = token_str_buffer.split("</think>").collect();
                        if let Some(think_content) = parts.first() {
                            if !think_content.is_empty() {
                                let _ = futures::executor::block_on(tx.send(Ok(
                                    InferenceEvent::Thought(ThoughtEvent::Delta(
                                        think_content.to_string(),
                                    )),
                                )));
                            }
                        }

                        let _ = futures::executor::block_on(
                            tx.send(Ok(InferenceEvent::Thought(ThoughtEvent::Stop))),
                        );

                        // Anything after </think> is regular content
                        if parts.len() > 1 {
                            token_str_buffer = parts[1].to_string();
                            // Fall through below to emit it as content
                        } else {
                            token_str_buffer.clear();
                        }
                    } else {
                        // Stream the thought delta
                        if !token_str_buffer.is_empty() {
                            let _ =
                                futures::executor::block_on(tx.send(Ok(InferenceEvent::Thought(
                                    ThoughtEvent::Delta(token_str_buffer.clone()),
                                ))));
                            token_str_buffer.clear();
                        }
                    }
                }

                // If NOT in a think block (anymore), emit as content
                if !in_think_block && !token_str_buffer.is_empty() {
                    let _ = futures::executor::block_on(
                        tx.send(Ok(InferenceEvent::Content(token_str_buffer.clone()))),
                    );
                    token_str_buffer.clear();
                }

                // Feed the sampled token back in for the next decode step
                batch.clear();
                let _ = batch.add(next_token, n_cur, &[0], true);
                n_cur += 1;

                if let Err(e) = ctx.decode(&mut batch) {
                    let _ =
                        futures::executor::block_on(tx.send(Err(anyhow!("Decode failed: {}", e))));
                    break;
                }
            }

            let _ = futures::executor::block_on(tx.send(Ok(InferenceEvent::Complete)));
        });

        Ok(rx)
    }
}
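
// A minimal usage sketch (not part of the original file): it drives the Pinky
// stub through the Engine trait and drains the event stream until Complete.
// Assumptions: async-std is built with the "attributes" feature so that
// `#[async_std::test]` is available, and `futures::StreamExt` provides `.next()`.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[async_std::test]
    async fn pinky_streams_a_complete_run() {
        let mut engine = Pinky::new();
        // The stub ignores the path; any string works here.
        engine.load_model("unused-path.gguf").await.unwrap();

        let mut rx = engine.infer("Narf?").await.unwrap();
        let mut saw_complete = false;
        while let Some(event) = rx.next().await {
            // Each event is a Result; the stub never sends errors.
            if matches!(event.unwrap(), InferenceEvent::Complete) {
                saw_complete = true;
            }
        }
        assert!(saw_complete);
    }
}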