rusty_genius_cortex/
backend.rs

use crate::Engine;
use anyhow::{anyhow, Result};
use async_std::task::{self, sleep};
use async_trait::async_trait;
use futures::channel::mpsc;
use futures::sink::SinkExt;
use rusty_genius_core::protocol::{InferenceEvent, ThoughtEvent};
use std::time::Duration;

#[cfg(feature = "real-engine")]
use llama_cpp_2::context::params::LlamaContextParams;
#[cfg(feature = "real-engine")]
use llama_cpp_2::llama_backend::LlamaBackend;
#[cfg(feature = "real-engine")]
use llama_cpp_2::llama_batch::LlamaBatch;
#[cfg(feature = "real-engine")]
use llama_cpp_2::model::params::LlamaModelParams;
#[cfg(feature = "real-engine")]
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
#[cfg(feature = "real-engine")]
use llama_cpp_2::sampling::LlamaSampler;
#[cfg(feature = "real-engine")]
use std::num::NonZeroU32;
#[cfg(feature = "real-engine")]
use std::sync::Arc;

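/// Stub engine: fakes model loading with a short delay and streams a canned
/// thought ("Narf!") followed by an echo of the prompt, so the event pipeline
/// can be exercised without the `real-engine` feature or a model file.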
pub struct Pinky {
    model_loaded: bool,
}

impl Pinky {
    pub fn new() -> Self {
        Self {
            model_loaded: false,
        }
    }
}

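// The stub emits the same kinds of events, in the same order, as the real
// engine: ProcessStart, Thought(Start -> Delta -> Stop), Content, Complete.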
#[async_trait]
impl Engine for Pinky {
    async fn load_model(&mut self, _model_path: &str) -> Result<()> {
        sleep(Duration::from_millis(100)).await;
        self.model_loaded = true;
        Ok(())
    }

    async fn unload_model(&mut self) -> Result<()> {
        self.model_loaded = false;
        Ok(())
    }

    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
        if !self.model_loaded {
            return Err(anyhow!("Pinky Error: No model loaded!"));
        }

        let (mut tx, rx) = mpsc::channel(100);
        let prompt = prompt.to_string();

        task::spawn(async move {
            let _ = tx.send(Ok(InferenceEvent::ProcessStart)).await;
            task::sleep(Duration::from_millis(50)).await;

            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Start)))
                .await;
            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Delta(
                    "Narf!".to_string(),
                ))))
                .await;
            task::sleep(Duration::from_millis(50)).await;
            let _ = tx
                .send(Ok(InferenceEvent::Thought(ThoughtEvent::Stop)))
                .await;

            let _ = tx
                .send(Ok(InferenceEvent::Content(format!(
                    "Pinky says: {}",
                    prompt
                ))))
                .await;

            let _ = tx.send(Ok(InferenceEvent::Complete)).await;
        });

        Ok(rx)
    }
}

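/// llama.cpp-backed engine, available behind the `real-engine` feature. Holds
/// the llama backend handle and the currently loaded model, if any.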
#[cfg(feature = "real-engine")]
pub struct Brain {
    model: Option<Arc<LlamaModel>>,
    backend: Arc<LlamaBackend>,
}

#[cfg(feature = "real-engine")]
impl Brain {
    pub fn new() -> Self {
        Self {
            model: None,
            backend: Arc::new(LlamaBackend::init().expect("Failed to init llama backend")),
        }
    }
}

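// llama.cpp calls are synchronous, so `infer` runs generation on a blocking
// task and pushes events back through the channel with `block_on`.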
#[cfg(feature = "real-engine")]
#[async_trait]
impl Engine for Brain {
    async fn load_model(&mut self, model_path: &str) -> Result<()> {
        let params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&self.backend, model_path, &params)
            .map_err(|e| anyhow!("Failed to load model from {}: {}", model_path, e))?;
        self.model = Some(Arc::new(model));
        Ok(())
    }

    async fn unload_model(&mut self) -> Result<()> {
        self.model = None;
        Ok(())
    }

    async fn infer(&mut self, prompt: &str) -> Result<mpsc::Receiver<Result<InferenceEvent>>> {
        let model = self
            .model
            .as_ref()
            .ok_or_else(|| anyhow!("No model loaded"))?
            .clone();

        let backend = self.backend.clone();

        let prompt_str = prompt.to_string();
        let (mut tx, rx) = mpsc::channel(100);

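        // The model, backend handle, prompt, and sender are all moved into the
        // blocking task below.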
        task::spawn_blocking(move || {
            let _ = futures::executor::block_on(tx.send(Ok(InferenceEvent::ProcessStart)));

            let backend_ref = &backend;

            let ctx_params =
                LlamaContextParams::default().with_n_ctx(Some(NonZeroU32::new(2048).unwrap()));

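            // A fresh context per request, with a fixed 2048-token window to
            // match the batch capacity below.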
            let mut ctx = match model.new_context(backend_ref, ctx_params) {
                Ok(c) => c,
                Err(e) => {
                    let _ = futures::executor::block_on(
                        tx.send(Err(anyhow!("Context creation failed: {}", e))),
                    );
                    return;
                }
            };

            let tokens_list = match model.str_to_token(&prompt_str, AddBos::Always) {
                Ok(t) => t,
                Err(e) => {
                    let _ = futures::executor::block_on(
                        tx.send(Err(anyhow!("Tokenize failed: {}", e))),
                    );
                    return;
                }
            };

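            // Feed the whole prompt as one batch; logits are only requested for
            // the last position, which is where sampling starts.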
            let n_tokens = tokens_list.len();
            let mut batch = LlamaBatch::new(2048, 1);
            let last_index = n_tokens as i32 - 1;
            for (i, token) in tokens_list.iter().enumerate() {
                let _ = batch.add(*token, i as i32, &[0], i as i32 == last_index);
            }

            if let Err(e) = ctx.decode(&mut batch) {
                let _ = futures::executor::block_on(
                    tx.send(Err(anyhow!("Decode prompt failed: {}", e))),
                );
                return;
            }

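            // Generation loop: decoded text accumulates in `token_str_buffer`
            // and is split on <think>/</think> markers, so reasoning text
            // streams as Thought events and everything else as Content.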
            let mut n_cur = n_tokens as i32;
            let mut n_decode = 0;
            let max_tokens = 512;
            let mut in_think_block = false;
            let mut token_str_buffer = String::new();

            loop {
                let mut sampler = LlamaSampler::greedy();
                let next_token = sampler.sample(&ctx, batch.n_tokens() - 1);

                let token_str = match model.token_to_str(next_token, Special::Plaintext) {
                    Ok(s) => s.to_string(),
                    Err(_) => "??".to_string(),
                };

                if next_token == model.token_eos() || n_decode >= max_tokens {
                    break;
                }

                token_str_buffer.push_str(&token_str);

                // Opening tag seen: switch to thought mode and strip the marker.
                if !in_think_block {
                    if token_str_buffer.contains("<think>") {
                        in_think_block = true;
                        let _ = futures::executor::block_on(
                            tx.send(Ok(InferenceEvent::Thought(ThoughtEvent::Start))),
                        );

                        token_str_buffer = token_str_buffer.replace("<think>", "");
                    }
                }

                if in_think_block {
                    if token_str_buffer.contains("</think>") {
                        // Closing tag seen: flush the text before it as a final
                        // thought delta, emit Stop, and keep anything after the
                        // tag as regular content.
                        in_think_block = false;
                        let parts: Vec<&str> = token_str_buffer.split("</think>").collect();
                        if let Some(think_content) = parts.first() {
                            if !think_content.is_empty() {
                                let _ = futures::executor::block_on(tx.send(Ok(
                                    InferenceEvent::Thought(ThoughtEvent::Delta(
                                        think_content.to_string(),
                                    )),
                                )));
                            }
                        }

                        let _ = futures::executor::block_on(
                            tx.send(Ok(InferenceEvent::Thought(ThoughtEvent::Stop))),
                        );

                        if parts.len() > 1 {
                            token_str_buffer = parts[1].to_string();
                        } else {
                            token_str_buffer.clear();
                        }
                    } else {
                        // Still inside the think block: stream the buffered text
                        // as a thought delta.
                        if !token_str_buffer.is_empty() {
                            let _ =
                                futures::executor::block_on(tx.send(Ok(InferenceEvent::Thought(
                                    ThoughtEvent::Delta(token_str_buffer.clone()),
                                ))));
                            token_str_buffer.clear();
                        }
                    }
                }

                // Outside a think block, whatever is buffered is plain content.
                if !in_think_block && !token_str_buffer.is_empty() {
                    let _ = futures::executor::block_on(
                        tx.send(Ok(InferenceEvent::Content(token_str_buffer.clone()))),
                    );
                    token_str_buffer.clear();
                }

                // Feed the sampled token back in for the next decode step.
                batch.clear();
                let _ = batch.add(next_token, n_cur, &[0], true);
                n_cur += 1;
                n_decode += 1;

                if let Err(e) = ctx.decode(&mut batch) {
                    let _ =
                        futures::executor::block_on(tx.send(Err(anyhow!("Decode failed: {}", e))));
                    break;
                }
            }

            let _ = futures::executor::block_on(tx.send(Ok(InferenceEvent::Complete)));
        });

        Ok(rx)
    }
}
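
// Minimal usage sketch: drives the stub `Pinky` engine end to end using only
// the trait methods defined above, so it runs without the `real-engine`
// feature or a model on disk.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[test]
    fn pinky_streams_events_and_completes() {
        task::block_on(async {
            let mut engine = Pinky::new();
            // The stub ignores the path, so any placeholder string works here.
            engine
                .load_model("placeholder.gguf")
                .await
                .expect("stub load_model should succeed");

            let mut rx = engine
                .infer("Are you pondering what I'm pondering?")
                .await
                .expect("infer should return an event stream");

            // Drain the stream and remember whether a Complete event arrived.
            let mut saw_complete = false;
            while let Some(event) = rx.next().await {
                if matches!(event, Ok(InferenceEvent::Complete)) {
                    saw_complete = true;
                }
            }
            assert!(saw_complete, "stream should end with InferenceEvent::Complete");
        });
    }
}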