rusty_genius_cortex/backend.rs

use crate::Engine;
use anyhow::{anyhow, Result};
use async_std::task::{self, sleep};
use async_trait::async_trait;
use futures::channel::mpsc;
use futures::sink::SinkExt;
use rusty_genius_core::protocol::{InferenceEvent, ThoughtEvent};
use std::time::Duration;

#[cfg(feature = "real-engine")]
use llama_cpp_2::context::params::LlamaContextParams;
#[cfg(feature = "real-engine")]
use llama_cpp_2::llama_backend::LlamaBackend;
#[cfg(feature = "real-engine")]
use llama_cpp_2::llama_batch::LlamaBatch;
#[cfg(feature = "real-engine")]
use llama_cpp_2::model::params::LlamaModelParams;
#[cfg(feature = "real-engine")]
use llama_cpp_2::model::{AddBos, LlamaModel};
#[cfg(feature = "real-engine")]
use std::num::NonZeroU32;
#[cfg(feature = "real-engine")]
use std::sync::Arc;

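/// A mock engine that never touches a real model: it just streams a canned
/// "Narf!" sequence of events, which is handy for exercising the
/// InferenceEvent protocol without a GGUF file on disk.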
pub struct Pinky {
    model_loaded: bool,
}

impl Pinky {
    pub fn new() -> Self {
        Self {
            model_loaded: false,
        }
    }
}

#[async_trait]
impl Engine for Pinky {
    async fn load_model(&mut self, _model_path: &str) -> Result<()> {
        sleep(Duration::from_millis(100)).await;
        self.model_loaded = true;
        Ok(())
    }

    async fn unload_model(&mut self) -> Result<()> {
        self.model_loaded = false;
        Ok(())
    }

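    // infer streams a fixed script: ProcessStart, a brief "Narf!" thought,
    // one Content message echoing the prompt, then Complete.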
    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
        if !self.model_loaded {
            return Err(anyhow!("Pinky Error: No model loaded!"));
        }

        let (mut tx, rx) = mpsc::channel(100);
        let prompt = prompt.to_string();

        task::spawn(async move {
            let _ = tx.send(Ok(InferenceEvent::ProcessStart)).await;
            task::sleep(Duration::from_millis(50)).await;

            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Start)))
                .await;
            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Delta(
                    "Narf!".to_string(),
                ))))
                .await;
            task::sleep(Duration::from_millis(50)).await;
            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Stop)))
                .await;

            let _ = tx
                .send(Ok(InferenceEvent::Content(format!(
                    "Pinky says: {}",
                    prompt
                ))))
                .await;

            let _ = tx.send(Ok(InferenceEvent::Complete)).await;
        });

        Ok(rx)
    }
}

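// A rough usage sketch (not part of this file): assumes the crate exposes the
// Engine trait and that InferenceEvent derives Debug; the model path and
// prompt below are purely illustrative.
//
//     use futures::StreamExt;
//
//     let mut engine = Pinky::new();
//     engine.load_model("dummy.gguf").await?;
//     let mut events = engine.infer("Are you pondering what I'm pondering?").await?;
//     while let Some(event) = events.next().await {
//         println!("{:?}", event?);
//     }

/// The real engine: drives a llama.cpp model through the llama_cpp_2 bindings.
/// Only compiled when the "real-engine" feature is enabled.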
#[cfg(feature = "real-engine")]
pub struct Brain {
    model: Option<Arc<LlamaModel>>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "real-engine")]
impl Brain {
    pub fn new() -> Self {
        Self {
            model: None,
            backend: Arc::new(LlamaBackend::init().expect("Failed to init llama backend")),
        }
    }
}

#[cfg(feature = "real-engine")]
#[async_trait]
impl Engine for Brain {
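    // Loads the GGUF model from disk. llama_cpp_2's loader is synchronous, so
    // this call blocks the executor thread for the duration of the load.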
    async fn load_model(&mut self, model_path: &str) -> Result<()> {
        let params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&self.backend, model_path, &params)
            .map_err(|e| anyhow!("Failed to load model from {}: {}", model_path, e))?;
        self.model = Some(Arc::new(model));
        Ok(())
    }

    async fn unload_model(&mut self) -> Result<()> {
        self.model = None;
        Ok(())
    }

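    // Runs generation on a dedicated blocking thread and streams results back
    // over a bounded channel, routing <think>...</think> spans to Thought
    // events and everything else to Content events.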
    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
        let model = self
            .model
            .as_ref()
            .ok_or_else(|| anyhow!("No model loaded"))?
            .clone();

        let backend = self.backend.clone();

        let prompt_str = prompt.to_string();
        let (mut tx, rx) = mpsc::channel(100);

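        // Token generation is CPU-bound, so it runs via spawn_blocking; each
        // event is pushed back to the async side with a block_on'd send.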
        task::spawn_blocking(move || {
            let _ = futures::executor::block_on(tx.send(Ok(InferenceEvent::ProcessStart)));

            let backend_ref = &backend;

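            // Create an inference context with a fixed 2048-token window.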
            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

            let mut ctx = match model.new_context(backend_ref, ctx_params) {
                Ok(c) => c,
                Err(e) => {
                    let _ = futures::executor::block_on(
                        tx.send(Err(anyhow!("Context creation failed: {}", e))),
                    );
                    return;
                }
            };

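            // Tokenize the prompt, always prepending a BOS token.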
            let tokens_list = match model.str_to_token(&prompt_str, AddBos::Always) {
                Ok(t) => t,
                Err(e) => {
                    let _ = futures::executor::block_on(
                        tx.send(Err(anyhow!("Tokenize failed: {}", e))),
                    );
                    return;
                }
            };

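            // Queue all prompt tokens in a single batch; only the last token
            // requests logits, since that is where sampling resumes.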
            let n_tokens = tokens_list.len();
            let mut batch = LlamaBatch::new(2048, 1);
            let last_index = n_tokens as i32 - 1;
            for (i, token) in tokens_list.iter().enumerate() {
                let _ = batch.add(*token, i as i32, &[0], i as i32 == last_index);
            }

            if let Err(e) = ctx.decode(&mut batch) {
                let _ = futures::executor::block_on(
                    tx.send(Err(anyhow!("Decode prompt failed: {}", e))),
                );
                return;
            }

            let mut n_cur = n_tokens as i32;
            let mut n_decode = 0;
            let max_tokens = 512;
            let mut think_buffer = String::new();
            let mut in_think_block = false;
            let mut token_str_buffer = String::new();

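            // Greedy decoding loop: sample one token at a time, watch for
            // <think> tags, and stop on end-of-sequence or the token cap.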
            loop {
                let next_token = match ctx.sample_token_greedy(batch.n_tokens() - 1) {
                    Some(t) => t,
                    None => {
                        let _ = futures::executor::block_on(
                            tx.send(Err(anyhow!("Failed to sample token"))),
                        );
                        break;
                    }
                };

                let token_str = match model.token_to_str(next_token) {
                    Ok(s) => s.to_string(),
                    Err(_) => "??".to_string(),
                };

                if next_token == model.token_eos() || n_decode >= max_tokens {
                    break;
                }
                n_decode += 1;

                token_str_buffer.push_str(&token_str);

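                // Text inside <think>...</think> is forwarded as Thought
                // deltas; anything outside the tags is emitted as Content.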
                if !in_think_block {
                    if token_str_buffer.contains("<think>") {
                        in_think_block = true;
                        let _ = futures::executor::block_on(
                            tx.send(Ok(InferenceEvent::Thought(ThoughtEvent::Start))),
                        );

                        token_str_buffer = token_str_buffer.replace("<think>", "");
                    }
                }

                if in_think_block {
                    if token_str_buffer.contains("</think>") {
                        in_think_block = false;
                        let parts: Vec<&str> = token_str_buffer.split("</think>").collect();
                        if let Some(think_content) = parts.first() {
                            if !think_content.is_empty() {
                                let _ = futures::executor::block_on(tx.send(Ok(
                                    InferenceEvent::Thought(ThoughtEvent::Delta(
                                        think_content.to_string(),
                                    )),
                                )));
                            }
                        }

                        let _ = futures::executor::block_on(
                            tx.send(Ok(InferenceEvent::Thought(ThoughtEvent::Stop))),
                        );

                        if parts.len() > 1 {
                            token_str_buffer = parts[1].to_string();
                        } else {
                            token_str_buffer.clear();
                        }
                    } else {
                        if !token_str_buffer.is_empty() {
                            let _ =
                                futures::executor::block_on(tx.send(Ok(InferenceEvent::Thought(
                                    ThoughtEvent::Delta(token_str_buffer.clone()),
                                ))));
                            token_str_buffer.clear();
                        }
                    }
                }

                if !in_think_block && !token_str_buffer.is_empty() {
                    let _ = futures::executor::block_on(
                        tx.send(Ok(InferenceEvent::Content(token_str_buffer.clone()))),
                    );
                    token_str_buffer.clear();
                }

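                // Feed the sampled token back as a single-token batch so the
                // next decode yields logits for the following position.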
                batch.clear();
                let _ = batch.add(next_token, n_cur, &[0], true);
                n_cur += 1;

                if let Err(e) = ctx.decode(&mut batch) {
                    let _ =
                        futures::executor::block_on(tx.send(Err(anyhow!("Decode failed: {}", e))));
                    break;
                }
            }

            let _ = futures::executor::block_on(tx.send(Ok(InferenceEvent::Complete)));
        });

        Ok(rx)
    }
}