1use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::trajectory::TrajectoryBuilder;
5use crate::types::{QueryTrajectory, SonaConfig};
6
/// Top-level learning engine: wraps a `LoopCoordinator` behind a small facade
/// for recording trajectories, applying LoRA adapters, and driving learning cycles.
pub struct SonaEngine {
    // Owns the learning loops (trajectory buffer, background cycles, LoRA state).
    coordinator: LoopCoordinator,
    // Configuration captured at construction; a clone is handed to the coordinator.
    config: SonaConfig,
    // Master switch: when false, trajectory submission, LoRA application, and
    // `tick` become no-ops (see the respective methods).
    enabled: bool,
}
16
// Manual Debug: only `config` and `enabled` are printed; `coordinator` is omitted
// and `finish_non_exhaustive()` renders a trailing `..` to signal the hidden field.
impl std::fmt::Debug for SonaEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SonaEngine")
            .field("config", &self.config)
            .field("enabled", &self.enabled)
            .finish_non_exhaustive()
    }
}
25
26impl SonaEngine {
27 pub fn new(hidden_dim: usize) -> Self {
29 Self::with_config(SonaConfig {
30 hidden_dim,
31 embedding_dim: hidden_dim,
32 ..Default::default()
33 })
34 }
35
36 pub fn with_config(config: SonaConfig) -> Self {
38 Self {
39 coordinator: LoopCoordinator::with_config(config.clone()),
40 config,
41 enabled: true,
42 }
43 }
44
45 pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
47 let id = self.coordinator.next_trajectory_id();
48 TrajectoryBuilder::new(id, query_embedding)
49 }
50
51 pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
53 if !self.enabled {
54 return;
55 }
56
57 let trajectory = builder.build(quality);
58 self.coordinator.on_inference(trajectory);
59 }
60
61 pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
63 if self.enabled {
64 self.coordinator.on_inference(trajectory);
65 }
66 }
67
68 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
70 if !self.enabled {
71 return;
72 }
73
74 if let Some(lora) = self.coordinator.micro_lora().try_read() {
75 lora.forward(input, output);
76 }
77 }
78
79 pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
81 if !self.enabled {
82 return;
83 }
84
85 if let Some(lora) = self.coordinator.base_lora().try_read() {
86 lora.forward_layer(layer_idx, input, output);
87 }
88 }
89
90 pub fn tick(&self) -> Option<String> {
92 if !self.enabled {
93 return None;
94 }
95
96 if let Some(result) = self.coordinator.maybe_run_background() {
97 Some(format!(
98 "Background cycle: {} trajectories -> {} patterns in {:?}",
99 result.trajectories_processed, result.patterns_extracted, result.elapsed
100 ))
101 } else {
102 None
103 }
104 }
105
106 pub fn force_learn(&self) -> String {
108 let result = self.coordinator.force_background();
109 format!(
110 "Forced learning: {} trajectories -> {} patterns, status: {}",
111 result.trajectories_processed, result.patterns_extracted, result.status
112 )
113 }
114
115 pub fn flush(&self) {
117 self.coordinator.flush_instant();
118 }
119
120 pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
122 self.coordinator
123 .reasoning_bank()
124 .read()
125 .find_similar(query_embedding, k)
126 .into_iter()
127 .cloned()
128 .collect()
129 }
130
131 pub fn stats(&self) -> CoordinatorStats {
133 self.coordinator.stats()
134 }
135
136 pub fn coordinator(&self) -> &LoopCoordinator {
138 &self.coordinator
139 }
140
141 pub fn set_enabled(&mut self, enabled: bool) {
143 self.enabled = enabled;
144 }
145
146 pub fn is_enabled(&self) -> bool {
148 self.enabled
149 }
150
151 pub fn config(&self) -> &SonaConfig {
153 &self.config
154 }
155
156 #[cfg(feature = "serde-support")]
158 pub fn get_all_patterns(&self) -> Vec<crate::LearnedPattern> {
159 self.coordinator.reasoning_bank().read().get_all_patterns()
160 }
161
162 #[cfg(feature = "serde-support")]
164 pub fn export_lora_state(&self) -> crate::export::safetensors::LoRAState {
165 use crate::export::safetensors::{LoRALayerState, LoRAState};
166
167 let mut state = LoRAState::default();
168
169 if let Some(lora) = self.coordinator.micro_lora().try_read() {
171 let (down, up) = lora.get_weights();
172 state.micro_lora_layers.push(LoRALayerState {
173 lora_a: down.clone(),
174 lora_b: up.clone(),
175 rank: self.config.micro_lora_rank,
176 input_dim: self.config.hidden_dim,
177 output_dim: self.config.hidden_dim,
178 });
179 }
180
181 if let Some(lora) = self.coordinator.base_lora().try_read() {
183 for idx in 0..lora.num_layers() {
184 if let Some((down, up)) = lora.get_layer_weights(idx) {
185 state.base_lora_layers.push(LoRALayerState {
186 lora_a: down.clone(),
187 lora_b: up.clone(),
188 rank: lora.rank,
189 input_dim: lora.hidden_dim,
190 output_dim: lora.hidden_dim,
191 });
192 }
193 }
194 }
195
196 state
197 }
198
199 #[cfg(feature = "serde-support")]
201 pub fn get_quality_trajectories(&self) -> Vec<crate::export::dataset::QualityTrajectory> {
202 use crate::export::dataset::QualityTrajectory;
203
204 let trajectories = self.coordinator.reasoning_bank().read().get_all_patterns();
206
207 trajectories
208 .iter()
209 .map(|p| {
210 QualityTrajectory {
211 query_embedding: p.centroid.clone(),
212 response_embedding: p.centroid.clone(), route: p.pattern_type.to_string(),
214 quality: p.avg_quality,
215 context_ids: vec![],
216 }
217 })
218 .collect()
219 }
220
221 #[cfg(feature = "serde-support")]
223 pub fn get_routing_decisions(&self) -> Vec<crate::export::dataset::RoutingDecision> {
224 use crate::export::dataset::RoutingDecision;
225
226 let patterns = self.coordinator.reasoning_bank().read().get_all_patterns();
227
228 patterns
229 .iter()
230 .map(|p| {
231 RoutingDecision {
232 query_embedding: p.centroid.clone(),
233 routing_logits: vec![p.avg_quality], selected_route: p.pattern_type.to_string(),
235 confidence: p.avg_quality,
236 quality: p.avg_quality,
237 }
238 })
239 .collect()
240 }
241}
242
/// Fluent builder for `SonaEngine`: accumulates overrides into a `SonaConfig`
/// and hands it to `SonaEngine::with_config` on `build()`.
pub struct SonaEngineBuilder {
    // Working configuration, starting from `SonaConfig::default()`.
    config: SonaConfig,
}
247
248impl SonaEngineBuilder {
249 pub fn new() -> Self {
251 Self {
252 config: SonaConfig::default(),
253 }
254 }
255
256 pub fn hidden_dim(mut self, dim: usize) -> Self {
258 self.config.hidden_dim = dim;
259 self.config.embedding_dim = dim;
260 self
261 }
262
263 pub fn micro_lora_rank(mut self, rank: usize) -> Self {
265 self.config.micro_lora_rank = rank.clamp(1, 2);
266 self
267 }
268
269 pub fn base_lora_rank(mut self, rank: usize) -> Self {
271 self.config.base_lora_rank = rank;
272 self
273 }
274
275 pub fn micro_lr(mut self, lr: f32) -> Self {
277 self.config.micro_lora_lr = lr;
278 self
279 }
280
281 pub fn base_lr(mut self, lr: f32) -> Self {
283 self.config.base_lora_lr = lr;
284 self
285 }
286
287 pub fn ewc_lambda(mut self, lambda: f32) -> Self {
289 self.config.ewc_lambda = lambda;
290 self
291 }
292
293 pub fn pattern_clusters(mut self, k: usize) -> Self {
295 self.config.pattern_clusters = k;
296 self
297 }
298
299 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
301 self.config.trajectory_capacity = capacity;
302 self
303 }
304
305 pub fn quality_threshold(mut self, threshold: f32) -> Self {
307 self.config.quality_threshold = threshold;
308 self
309 }
310
311 pub fn build(self) -> SonaEngine {
313 SonaEngine::with_config(self.config)
314 }
315}
316
// `Default` simply delegates to `new()` so the builder works with
// `Default::default()`-style construction and struct-update syntax.
impl Default for SonaEngineBuilder {
    fn default() -> Self {
        Self::new()
    }
}
322
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        // Feed enough trajectories to give the learner something to process.
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // The forward pass must leave the output buffer well-formed regardless
        // of whether the adapter lock was acquired (apply is best-effort).
        assert_eq!(output.len(), 64);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    #[test]
    fn test_force_learn_with_few_trajectories() {
        let engine = SonaEngine::new(64);

        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(
            result.contains("10 trajectories"),
            "Expected '10 trajectories' but got: {}",
            result
        );
    }

    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}