1use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::lora::MicroLoRA;
5use crate::trajectory::TrajectoryBuilder;
6use crate::types::{QueryTrajectory, SonaConfig};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
/// Self-learning engine facade: wraps a [`LoopCoordinator`] and exposes the
/// public API for recording trajectories, applying LoRA adapters, and
/// triggering learning cycles.
pub struct SonaEngine {
    // Orchestrates the instant/background learning loops and owns the
    // micro-LoRA, base-LoRA, and reasoning-bank state.
    coordinator: LoopCoordinator,
    // Configuration snapshot; also used to build `coordinator`.
    config: SonaConfig,
    // When false, trajectory ingestion, LoRA application, and `tick` become no-ops.
    enabled: bool,
}
19
// Manual Debug: only `config` and `enabled` are printed; the coordinator field
// is elided via `finish_non_exhaustive()` (presumably `LoopCoordinator` does
// not implement `Debug` — verify before deriving instead).
impl std::fmt::Debug for SonaEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SonaEngine")
            .field("config", &self.config)
            .field("enabled", &self.enabled)
            .finish_non_exhaustive()
    }
}
28
29impl SonaEngine {
30 pub fn new(hidden_dim: usize) -> Self {
32 Self::with_config(SonaConfig {
33 hidden_dim,
34 embedding_dim: hidden_dim,
35 ..Default::default()
36 })
37 }
38
39 pub fn with_config(config: SonaConfig) -> Self {
41 Self {
42 coordinator: LoopCoordinator::with_config(config.clone()),
43 config,
44 enabled: true,
45 }
46 }
47
48 pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
50 let id = self.coordinator.next_trajectory_id();
51 TrajectoryBuilder::new(id, query_embedding)
52 }
53
54 pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
56 if !self.enabled {
57 return;
58 }
59
60 let trajectory = builder.build(quality);
61 self.coordinator.on_inference(trajectory);
62 }
63
64 pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
66 if self.enabled {
67 self.coordinator.on_inference(trajectory);
68 }
69 }
70
71 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
73 if !self.enabled {
74 return;
75 }
76
77 if let Some(lora) = self.coordinator.micro_lora().try_read() {
78 lora.forward(input, output);
79 }
80 }
81
82 pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
84 if !self.enabled {
85 return;
86 }
87
88 if let Some(lora) = self.coordinator.base_lora().try_read() {
89 lora.forward_layer(layer_idx, input, output);
90 }
91 }
92
93 pub fn tick(&self) -> Option<String> {
95 if !self.enabled {
96 return None;
97 }
98
99 if let Some(result) = self.coordinator.maybe_run_background() {
100 Some(format!(
101 "Background cycle: {} trajectories -> {} patterns in {:?}",
102 result.trajectories_processed, result.patterns_extracted, result.elapsed
103 ))
104 } else {
105 None
106 }
107 }
108
109 pub fn force_learn(&self) -> String {
111 let result = self.coordinator.force_background();
112 format!(
113 "Forced learning: {} trajectories -> {} patterns, status: {}",
114 result.trajectories_processed, result.patterns_extracted, result.status
115 )
116 }
117
118 pub fn flush(&self) {
120 self.coordinator.flush_instant();
121 }
122
123 pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
125 self.coordinator
126 .reasoning_bank()
127 .read()
128 .find_similar(query_embedding, k)
129 .into_iter()
130 .cloned()
131 .collect()
132 }
133
134 pub fn stats(&self) -> CoordinatorStats {
136 self.coordinator.stats()
137 }
138
139 pub fn set_enabled(&mut self, enabled: bool) {
141 self.enabled = enabled;
142 }
143
144 pub fn is_enabled(&self) -> bool {
146 self.enabled
147 }
148
149 pub fn config(&self) -> &SonaConfig {
151 &self.config
152 }
153
154 #[cfg(feature = "serde-support")]
156 pub fn get_all_patterns(&self) -> Vec<crate::LearnedPattern> {
157 self.coordinator.reasoning_bank().read().get_all_patterns()
158 }
159
160 #[cfg(feature = "serde-support")]
162 pub fn export_lora_state(&self) -> crate::export::safetensors::LoRAState {
163 use crate::export::safetensors::{LoRALayerState, LoRAState};
164
165 let mut state = LoRAState::default();
166
167 if let Some(lora) = self.coordinator.micro_lora().try_read() {
169 let (down, up) = lora.get_weights();
170 state.micro_lora_layers.push(LoRALayerState {
171 lora_a: down.clone(),
172 lora_b: up.clone(),
173 rank: self.config.micro_lora_rank,
174 input_dim: self.config.hidden_dim,
175 output_dim: self.config.hidden_dim,
176 });
177 }
178
179 if let Some(lora) = self.coordinator.base_lora().try_read() {
181 for idx in 0..lora.num_layers() {
182 if let Some((down, up)) = lora.get_layer_weights(idx) {
183 state.base_lora_layers.push(LoRALayerState {
184 lora_a: down.clone(),
185 lora_b: up.clone(),
186 rank: lora.rank,
187 input_dim: lora.hidden_dim,
188 output_dim: lora.hidden_dim,
189 });
190 }
191 }
192 }
193
194 state
195 }
196
197 #[cfg(feature = "serde-support")]
199 pub fn get_quality_trajectories(&self) -> Vec<crate::export::dataset::QualityTrajectory> {
200 use crate::export::dataset::QualityTrajectory;
201
202 let trajectories = self.coordinator.reasoning_bank().read().get_all_patterns();
204
205 trajectories
206 .iter()
207 .map(|p| {
208 QualityTrajectory {
209 query_embedding: p.centroid.clone(),
210 response_embedding: p.centroid.clone(), route: p.pattern_type.to_string(),
212 quality: p.avg_quality,
213 context_ids: vec![],
214 }
215 })
216 .collect()
217 }
218
219 #[cfg(feature = "serde-support")]
221 pub fn get_routing_decisions(&self) -> Vec<crate::export::dataset::RoutingDecision> {
222 use crate::export::dataset::RoutingDecision;
223
224 let patterns = self.coordinator.reasoning_bank().read().get_all_patterns();
225
226 patterns
227 .iter()
228 .map(|p| {
229 RoutingDecision {
230 query_embedding: p.centroid.clone(),
231 routing_logits: vec![p.avg_quality], selected_route: p.pattern_type.to_string(),
233 confidence: p.avg_quality,
234 quality: p.avg_quality,
235 }
236 })
237 .collect()
238 }
239}
240
/// Fluent builder for [`SonaEngine`]: chain setters, then call `build()`.
pub struct SonaEngineBuilder {
    // Pending configuration, consumed by `build()`.
    config: SonaConfig,
}
245
246impl SonaEngineBuilder {
247 pub fn new() -> Self {
249 Self {
250 config: SonaConfig::default(),
251 }
252 }
253
254 pub fn hidden_dim(mut self, dim: usize) -> Self {
256 self.config.hidden_dim = dim;
257 self.config.embedding_dim = dim;
258 self
259 }
260
261 pub fn micro_lora_rank(mut self, rank: usize) -> Self {
263 self.config.micro_lora_rank = rank.clamp(1, 2);
264 self
265 }
266
267 pub fn base_lora_rank(mut self, rank: usize) -> Self {
269 self.config.base_lora_rank = rank;
270 self
271 }
272
273 pub fn micro_lr(mut self, lr: f32) -> Self {
275 self.config.micro_lora_lr = lr;
276 self
277 }
278
279 pub fn base_lr(mut self, lr: f32) -> Self {
281 self.config.base_lora_lr = lr;
282 self
283 }
284
285 pub fn ewc_lambda(mut self, lambda: f32) -> Self {
287 self.config.ewc_lambda = lambda;
288 self
289 }
290
291 pub fn pattern_clusters(mut self, k: usize) -> Self {
293 self.config.pattern_clusters = k;
294 self
295 }
296
297 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
299 self.config.trajectory_capacity = capacity;
300 self
301 }
302
303 pub fn quality_threshold(mut self, threshold: f32) -> Self {
305 self.config.quality_threshold = threshold;
306 self
307 }
308
309 pub fn build(self) -> SonaEngine {
311 SonaEngine::with_config(self.config)
312 }
313}
314
impl Default for SonaEngineBuilder {
    /// Equivalent to [`SonaEngineBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrajectoryStep;

    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        // Loop counter unused; `_` avoids an unused_variables warning.
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        // Smoke test only: verifies the call does not panic.
        engine.apply_micro_lora(&input, &mut output);
    }

    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}
408}