1use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::lora::MicroLoRA;
5use crate::trajectory::TrajectoryBuilder;
6use crate::types::{QueryTrajectory, SonaConfig};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
10pub struct SonaEngine {
12 coordinator: LoopCoordinator,
14 config: SonaConfig,
16 enabled: bool,
18}
19
20impl SonaEngine {
21 pub fn new(hidden_dim: usize) -> Self {
23 Self::with_config(SonaConfig {
24 hidden_dim,
25 embedding_dim: hidden_dim,
26 ..Default::default()
27 })
28 }
29
30 pub fn with_config(config: SonaConfig) -> Self {
32 Self {
33 coordinator: LoopCoordinator::with_config(config.clone()),
34 config,
35 enabled: true,
36 }
37 }
38
39 pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
41 let id = self.coordinator.next_trajectory_id();
42 TrajectoryBuilder::new(id, query_embedding)
43 }
44
45 pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
47 if !self.enabled {
48 return;
49 }
50
51 let trajectory = builder.build(quality);
52 self.coordinator.on_inference(trajectory);
53 }
54
55 pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
57 if self.enabled {
58 self.coordinator.on_inference(trajectory);
59 }
60 }
61
62 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
64 if !self.enabled {
65 return;
66 }
67
68 if let Some(lora) = self.coordinator.micro_lora().try_read() {
69 lora.forward(input, output);
70 }
71 }
72
73 pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
75 if !self.enabled {
76 return;
77 }
78
79 if let Some(lora) = self.coordinator.base_lora().try_read() {
80 lora.forward_layer(layer_idx, input, output);
81 }
82 }
83
84 pub fn tick(&self) -> Option<String> {
86 if !self.enabled {
87 return None;
88 }
89
90 if let Some(result) = self.coordinator.maybe_run_background() {
91 Some(format!(
92 "Background cycle: {} trajectories -> {} patterns in {:?}",
93 result.trajectories_processed,
94 result.patterns_extracted,
95 result.elapsed
96 ))
97 } else {
98 None
99 }
100 }
101
102 pub fn force_learn(&self) -> String {
104 let result = self.coordinator.force_background();
105 format!(
106 "Forced learning: {} trajectories -> {} patterns, status: {}",
107 result.trajectories_processed,
108 result.patterns_extracted,
109 result.status
110 )
111 }
112
113 pub fn flush(&self) {
115 self.coordinator.flush_instant();
116 }
117
118 pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
120 self.coordinator
121 .reasoning_bank()
122 .read()
123 .find_similar(query_embedding, k)
124 .into_iter()
125 .cloned()
126 .collect()
127 }
128
129 pub fn stats(&self) -> CoordinatorStats {
131 self.coordinator.stats()
132 }
133
134 pub fn set_enabled(&mut self, enabled: bool) {
136 self.enabled = enabled;
137 }
138
139 pub fn is_enabled(&self) -> bool {
141 self.enabled
142 }
143
144 pub fn config(&self) -> &SonaConfig {
146 &self.config
147 }
148
149 #[cfg(feature = "serde-support")]
151 pub fn get_all_patterns(&self) -> Vec<crate::LearnedPattern> {
152 self.coordinator.reasoning_bank().read().get_all_patterns()
153 }
154
155 #[cfg(feature = "serde-support")]
157 pub fn export_lora_state(&self) -> crate::export::safetensors::LoRAState {
158 use crate::export::safetensors::{LoRAState, LoRALayerState};
159
160 let mut state = LoRAState::default();
161
162 if let Some(lora) = self.coordinator.micro_lora().try_read() {
164 let (down, up) = lora.get_weights();
165 state.micro_lora_layers.push(LoRALayerState {
166 lora_a: down.clone(),
167 lora_b: up.clone(),
168 rank: self.config.micro_lora_rank,
169 input_dim: self.config.hidden_dim,
170 output_dim: self.config.hidden_dim,
171 });
172 }
173
174 if let Some(lora) = self.coordinator.base_lora().try_read() {
176 for idx in 0..lora.num_layers() {
177 if let Some((down, up)) = lora.get_layer_weights(idx) {
178 state.base_lora_layers.push(LoRALayerState {
179 lora_a: down.clone(),
180 lora_b: up.clone(),
181 rank: lora.rank,
182 input_dim: lora.hidden_dim,
183 output_dim: lora.hidden_dim,
184 });
185 }
186 }
187 }
188
189 state
190 }
191
192 #[cfg(feature = "serde-support")]
194 pub fn get_quality_trajectories(&self) -> Vec<crate::export::dataset::QualityTrajectory> {
195 use crate::export::dataset::QualityTrajectory;
196
197 let trajectories = self.coordinator.reasoning_bank().read().get_all_patterns();
199
200 trajectories.iter().map(|p| {
201 QualityTrajectory {
202 query_embedding: p.centroid.clone(),
203 response_embedding: p.centroid.clone(), route: p.pattern_type.to_string(),
205 quality: p.avg_quality,
206 context_ids: vec![],
207 }
208 }).collect()
209 }
210
211 #[cfg(feature = "serde-support")]
213 pub fn get_routing_decisions(&self) -> Vec<crate::export::dataset::RoutingDecision> {
214 use crate::export::dataset::RoutingDecision;
215
216 let patterns = self.coordinator.reasoning_bank().read().get_all_patterns();
217
218 patterns.iter().map(|p| {
219 RoutingDecision {
220 query_embedding: p.centroid.clone(),
221 routing_logits: vec![p.avg_quality], selected_route: p.pattern_type.to_string(),
223 confidence: p.avg_quality,
224 quality: p.avg_quality,
225 }
226 }).collect()
227 }
228}
229
230pub struct SonaEngineBuilder {
232 config: SonaConfig,
233}
234
235impl SonaEngineBuilder {
236 pub fn new() -> Self {
238 Self {
239 config: SonaConfig::default(),
240 }
241 }
242
243 pub fn hidden_dim(mut self, dim: usize) -> Self {
245 self.config.hidden_dim = dim;
246 self.config.embedding_dim = dim;
247 self
248 }
249
250 pub fn micro_lora_rank(mut self, rank: usize) -> Self {
252 self.config.micro_lora_rank = rank.clamp(1, 2);
253 self
254 }
255
256 pub fn base_lora_rank(mut self, rank: usize) -> Self {
258 self.config.base_lora_rank = rank;
259 self
260 }
261
262 pub fn micro_lr(mut self, lr: f32) -> Self {
264 self.config.micro_lora_lr = lr;
265 self
266 }
267
268 pub fn base_lr(mut self, lr: f32) -> Self {
270 self.config.base_lora_lr = lr;
271 self
272 }
273
274 pub fn ewc_lambda(mut self, lambda: f32) -> Self {
276 self.config.ewc_lambda = lambda;
277 self
278 }
279
280 pub fn pattern_clusters(mut self, k: usize) -> Self {
282 self.config.pattern_clusters = k;
283 self
284 }
285
286 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
288 self.config.trajectory_capacity = capacity;
289 self
290 }
291
292 pub fn quality_threshold(mut self, threshold: f32) -> Self {
294 self.config.quality_threshold = threshold;
295 self
296 }
297
298 pub fn build(self) -> SonaEngine {
300 SonaEngine::with_config(self.config)
301 }
302}
303
304impl Default for SonaEngineBuilder {
305 fn default() -> Self {
306 Self::new()
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    #[test]
    fn test_builder_clamps_micro_lora_rank() {
        // `micro_lora_rank` clamps into 1..=2 rather than rejecting out-of-range values.
        let high = SonaEngineBuilder::new().hidden_dim(64).micro_lora_rank(8).build();
        assert_eq!(high.config().micro_lora_rank, 2);

        let low = SonaEngineBuilder::new().hidden_dim(64).micro_lora_rank(0).build();
        assert_eq!(low.config().micro_lora_rank, 1);
    }

    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // Exact values depend on learned weights; at minimum they must be well-formed.
        assert_eq!(output.len(), 64);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);

        // A disabled engine must leave the output buffer untouched and never run a cycle.
        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);
        assert!(output.iter().all(|&v| v == 0.0));
        assert!(engine.tick().is_none());
    }
}
397}