1use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::lora::MicroLoRA;
5use crate::trajectory::TrajectoryBuilder;
6use crate::types::{QueryTrajectory, SonaConfig, LearnedPattern};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
10#[cfg(feature = "serde-support")]
11use crate::export::safetensors::{LoRAState, LoRALayerState};
12
13#[cfg(feature = "serde-support")]
14use crate::export::dataset::{QualityTrajectory, RoutingDecision};
15
16pub struct SonaEngine {
18 coordinator: LoopCoordinator,
20 config: SonaConfig,
22 enabled: bool,
24}
25
26impl SonaEngine {
27 pub fn new(hidden_dim: usize) -> Self {
29 Self::with_config(SonaConfig {
30 hidden_dim,
31 embedding_dim: hidden_dim,
32 ..Default::default()
33 })
34 }
35
36 pub fn with_config(config: SonaConfig) -> Self {
38 Self {
39 coordinator: LoopCoordinator::with_config(config.clone()),
40 config,
41 enabled: true,
42 }
43 }
44
45 pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
47 let id = self.coordinator.next_trajectory_id();
48 TrajectoryBuilder::new(id, query_embedding)
49 }
50
51 pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
53 if !self.enabled {
54 return;
55 }
56
57 let trajectory = builder.build(quality);
58 self.coordinator.on_inference(trajectory);
59 }
60
61 pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
63 if self.enabled {
64 self.coordinator.on_inference(trajectory);
65 }
66 }
67
68 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
70 if !self.enabled {
71 return;
72 }
73
74 if let Some(lora) = self.coordinator.micro_lora().try_read() {
75 lora.forward(input, output);
76 }
77 }
78
79 pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
81 if !self.enabled {
82 return;
83 }
84
85 if let Some(lora) = self.coordinator.base_lora().try_read() {
86 lora.forward_layer(layer_idx, input, output);
87 }
88 }
89
90 pub fn tick(&self) -> Option<String> {
92 if !self.enabled {
93 return None;
94 }
95
96 if let Some(result) = self.coordinator.maybe_run_background() {
97 Some(format!(
98 "Background cycle: {} trajectories -> {} patterns in {:?}",
99 result.trajectories_processed,
100 result.patterns_extracted,
101 result.elapsed
102 ))
103 } else {
104 None
105 }
106 }
107
108 pub fn force_learn(&self) -> String {
110 let result = self.coordinator.force_background();
111 format!(
112 "Forced learning: {} trajectories -> {} patterns, status: {}",
113 result.trajectories_processed,
114 result.patterns_extracted,
115 result.status
116 )
117 }
118
119 pub fn flush(&self) {
121 self.coordinator.flush_instant();
122 }
123
124 pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
126 self.coordinator
127 .reasoning_bank()
128 .read()
129 .find_similar(query_embedding, k)
130 .into_iter()
131 .cloned()
132 .collect()
133 }
134
135 pub fn stats(&self) -> CoordinatorStats {
137 self.coordinator.stats()
138 }
139
140 pub fn set_enabled(&mut self, enabled: bool) {
142 self.enabled = enabled;
143 }
144
145 pub fn is_enabled(&self) -> bool {
147 self.enabled
148 }
149
150 pub fn config(&self) -> &SonaConfig {
152 &self.config
153 }
154
155 pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
157 self.coordinator
158 .reasoning_bank()
159 .read()
160 .get_all_patterns()
161 }
162
163 #[cfg(feature = "serde-support")]
165 pub fn export_lora_state(&self) -> LoRAState {
166 let mut state = LoRAState::default();
167
168 if let Some(micro_lora) = self.coordinator.micro_lora().try_read() {
170 let (lora_a, lora_b) = micro_lora.get_weights();
171 state.micro_lora_layers.push(LoRALayerState {
172 lora_a: lora_a.clone(),
173 lora_b: lora_b.clone(),
174 rank: self.config.micro_lora_rank,
175 input_dim: self.config.hidden_dim,
176 output_dim: self.config.hidden_dim,
177 });
178 }
179
180 if let Some(base_lora) = self.coordinator.base_lora().try_read() {
182 for layer_idx in 0..base_lora.num_layers() {
183 if let Some((lora_a, lora_b)) = base_lora.get_layer_weights(layer_idx) {
184 state.base_lora_layers.push(LoRALayerState {
185 lora_a: lora_a.clone(),
186 lora_b: lora_b.clone(),
187 rank: self.config.base_lora_rank,
188 input_dim: self.config.hidden_dim,
189 output_dim: self.config.hidden_dim,
190 });
191 }
192 }
193 }
194
195 state
196 }
197
198 #[cfg(feature = "serde-support")]
200 pub fn get_quality_trajectories(&self) -> Vec<QualityTrajectory> {
201 self.coordinator
202 .trajectory_buffer()
203 .get_all()
204 .iter()
205 .map(|t| QualityTrajectory {
206 query_embedding: t.query_embedding.clone(),
207 response_embedding: t.steps.last()
208 .map(|s| s.activations.clone())
209 .unwrap_or_default(),
210 route: t.model_route.clone().unwrap_or_default(),
211 quality: t.final_quality,
212 context_ids: t.context_ids.clone(),
213 })
214 .collect()
215 }
216
217 #[cfg(feature = "serde-support")]
219 pub fn get_routing_decisions(&self) -> Vec<RoutingDecision> {
220 self.get_all_patterns()
222 .iter()
223 .map(|p| RoutingDecision {
224 query_embedding: p.centroid.clone(),
225 routing_logits: vec![p.avg_quality; 4], selected_route: p.pattern_type.to_string(),
227 confidence: p.avg_quality,
228 quality: p.avg_quality,
229 })
230 .collect()
231 }
232}
233
234pub struct SonaEngineBuilder {
236 config: SonaConfig,
237}
238
239impl SonaEngineBuilder {
240 pub fn new() -> Self {
242 Self {
243 config: SonaConfig::default(),
244 }
245 }
246
247 pub fn hidden_dim(mut self, dim: usize) -> Self {
249 self.config.hidden_dim = dim;
250 self.config.embedding_dim = dim;
251 self
252 }
253
254 pub fn micro_lora_rank(mut self, rank: usize) -> Self {
256 self.config.micro_lora_rank = rank.clamp(1, 2);
257 self
258 }
259
260 pub fn base_lora_rank(mut self, rank: usize) -> Self {
262 self.config.base_lora_rank = rank;
263 self
264 }
265
266 pub fn micro_lr(mut self, lr: f32) -> Self {
268 self.config.micro_lora_lr = lr;
269 self
270 }
271
272 pub fn base_lr(mut self, lr: f32) -> Self {
274 self.config.base_lora_lr = lr;
275 self
276 }
277
278 pub fn ewc_lambda(mut self, lambda: f32) -> Self {
280 self.config.ewc_lambda = lambda;
281 self
282 }
283
284 pub fn pattern_clusters(mut self, k: usize) -> Self {
286 self.config.pattern_clusters = k;
287 self
288 }
289
290 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
292 self.config.trajectory_capacity = capacity;
293 self
294 }
295
296 pub fn quality_threshold(mut self, threshold: f32) -> Self {
298 self.config.quality_threshold = threshold;
299 self
300 }
301
302 pub fn build(self) -> SonaEngine {
304 SonaEngine::with_config(self.config)
305 }
306}
307
308impl Default for SonaEngineBuilder {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().embedding_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
        assert_eq!(engine.config().base_lora_rank, 16);
    }

    #[test]
    fn test_builder_clamps_micro_lora_rank() {
        // `micro_lora_rank` clamps into 1..=2.
        let low = SonaEngineBuilder::new().micro_lora_rank(0).build();
        assert_eq!(low.config().micro_lora_rank, 1);

        let high = SonaEngineBuilder::new().micro_lora_rank(8).build();
        assert_eq!(high.config().micro_lora_rank, 2);
    }

    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // Whatever the adapter contributes, it must keep the buffer well-formed.
        assert_eq!(output.len(), 64);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);
        assert!(!engine.is_enabled());

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        // A disabled engine records nothing...
        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);

        // ...leaves the caller's output buffer untouched...
        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);
        assert_eq!(output, vec![0.0; 64]);

        // ...and never reports a background cycle.
        assert!(engine.tick().is_none());
    }
}