1use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::lora::MicroLoRA;
5use crate::trajectory::TrajectoryBuilder;
6use crate::types::{QueryTrajectory, SonaConfig};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
10pub struct SonaEngine {
12 coordinator: LoopCoordinator,
14 config: SonaConfig,
16 enabled: bool,
18}
19
20impl SonaEngine {
21 pub fn new(hidden_dim: usize) -> Self {
23 Self::with_config(SonaConfig {
24 hidden_dim,
25 embedding_dim: hidden_dim,
26 ..Default::default()
27 })
28 }
29
30 pub fn with_config(config: SonaConfig) -> Self {
32 Self {
33 coordinator: LoopCoordinator::with_config(config.clone()),
34 config,
35 enabled: true,
36 }
37 }
38
39 pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
41 let id = self.coordinator.next_trajectory_id();
42 TrajectoryBuilder::new(id, query_embedding)
43 }
44
45 pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
47 if !self.enabled {
48 return;
49 }
50
51 let trajectory = builder.build(quality);
52 self.coordinator.on_inference(trajectory);
53 }
54
55 pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
57 if self.enabled {
58 self.coordinator.on_inference(trajectory);
59 }
60 }
61
62 pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
64 if !self.enabled {
65 return;
66 }
67
68 if let Some(lora) = self.coordinator.micro_lora().try_read() {
69 lora.forward(input, output);
70 }
71 }
72
73 pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
75 if !self.enabled {
76 return;
77 }
78
79 if let Some(lora) = self.coordinator.base_lora().try_read() {
80 lora.forward_layer(layer_idx, input, output);
81 }
82 }
83
84 pub fn tick(&self) -> Option<String> {
86 if !self.enabled {
87 return None;
88 }
89
90 if let Some(result) = self.coordinator.maybe_run_background() {
91 Some(format!(
92 "Background cycle: {} trajectories -> {} patterns in {:?}",
93 result.trajectories_processed,
94 result.patterns_extracted,
95 result.elapsed
96 ))
97 } else {
98 None
99 }
100 }
101
102 pub fn force_learn(&self) -> String {
104 let result = self.coordinator.force_background();
105 format!(
106 "Forced learning: {} trajectories -> {} patterns, status: {}",
107 result.trajectories_processed,
108 result.patterns_extracted,
109 result.status
110 )
111 }
112
113 pub fn flush(&self) {
115 self.coordinator.flush_instant();
116 }
117
118 pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
120 self.coordinator
121 .reasoning_bank()
122 .read()
123 .find_similar(query_embedding, k)
124 .into_iter()
125 .cloned()
126 .collect()
127 }
128
129 pub fn stats(&self) -> CoordinatorStats {
131 self.coordinator.stats()
132 }
133
134 pub fn set_enabled(&mut self, enabled: bool) {
136 self.enabled = enabled;
137 }
138
139 pub fn is_enabled(&self) -> bool {
141 self.enabled
142 }
143
144 pub fn config(&self) -> &SonaConfig {
146 &self.config
147 }
148}
149
150pub struct SonaEngineBuilder {
152 config: SonaConfig,
153}
154
155impl SonaEngineBuilder {
156 pub fn new() -> Self {
158 Self {
159 config: SonaConfig::default(),
160 }
161 }
162
163 pub fn hidden_dim(mut self, dim: usize) -> Self {
165 self.config.hidden_dim = dim;
166 self.config.embedding_dim = dim;
167 self
168 }
169
170 pub fn micro_lora_rank(mut self, rank: usize) -> Self {
172 self.config.micro_lora_rank = rank.clamp(1, 2);
173 self
174 }
175
176 pub fn base_lora_rank(mut self, rank: usize) -> Self {
178 self.config.base_lora_rank = rank;
179 self
180 }
181
182 pub fn micro_lr(mut self, lr: f32) -> Self {
184 self.config.micro_lora_lr = lr;
185 self
186 }
187
188 pub fn base_lr(mut self, lr: f32) -> Self {
190 self.config.base_lora_lr = lr;
191 self
192 }
193
194 pub fn ewc_lambda(mut self, lambda: f32) -> Self {
196 self.config.ewc_lambda = lambda;
197 self
198 }
199
200 pub fn pattern_clusters(mut self, k: usize) -> Self {
202 self.config.pattern_clusters = k;
203 self
204 }
205
206 pub fn buffer_capacity(mut self, capacity: usize) -> Self {
208 self.config.trajectory_capacity = capacity;
209 self
210 }
211
212 pub fn quality_threshold(mut self, threshold: f32) -> Self {
214 self.config.quality_threshold = threshold;
215 self
216 }
217
218 pub fn build(self) -> SonaEngine {
220 SonaEngine::with_config(self.config)
221 }
222}
223
224impl Default for SonaEngineBuilder {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed engine starts out enabled.
    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    /// Values set through the builder are visible in the engine's config.
    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    /// Beginning and ending one trajectory leaves exactly one buffered.
    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    /// Smoke test: applying the micro-LoRA after a flush must not panic.
    /// There is deliberately no assertion on the output values.
    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);
    }

    /// All 150 recorded trajectories are reported by the forced-learning summary.
    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    /// A disabled engine drops ended trajectories instead of buffering them.
    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}