// ruvector_sona/engine.rs
1//! SONA Engine - Main interface for self-optimizing neural architecture
2
3use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::trajectory::TrajectoryBuilder;
5use crate::types::{QueryTrajectory, SonaConfig};
6
/// Main SONA engine integrating all components.
///
/// Wraps a [`LoopCoordinator`] and exposes trajectory recording, LoRA
/// application, and background-learning entry points behind a single
/// enable/disable switch.
pub struct SonaEngine {
    /// Loop coordinator driving the instant and background learning loops
    coordinator: LoopCoordinator,
    /// Configuration the engine (and its coordinator) was built with
    config: SonaConfig,
    /// Whether engine is enabled; when `false`, trajectory submission and
    /// LoRA application become no-ops
    enabled: bool,
}
16
17impl std::fmt::Debug for SonaEngine {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("SonaEngine")
20            .field("config", &self.config)
21            .field("enabled", &self.enabled)
22            .finish_non_exhaustive()
23    }
24}
25
26impl SonaEngine {
27    /// Create new SONA engine with default config
28    pub fn new(hidden_dim: usize) -> Self {
29        Self::with_config(SonaConfig {
30            hidden_dim,
31            embedding_dim: hidden_dim,
32            ..Default::default()
33        })
34    }
35
36    /// Create with custom config
37    pub fn with_config(config: SonaConfig) -> Self {
38        Self {
39            coordinator: LoopCoordinator::with_config(config.clone()),
40            config,
41            enabled: true,
42        }
43    }
44
45    /// Start trajectory recording for a query
46    pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
47        let id = self.coordinator.next_trajectory_id();
48        TrajectoryBuilder::new(id, query_embedding)
49    }
50
51    /// Complete trajectory and submit for learning
52    pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
53        if !self.enabled {
54            return;
55        }
56
57        let trajectory = builder.build(quality);
58        self.coordinator.on_inference(trajectory);
59    }
60
61    /// Submit pre-built trajectory
62    pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
63        if self.enabled {
64            self.coordinator.on_inference(trajectory);
65        }
66    }
67
68    /// Apply micro-LoRA to hidden states
69    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
70        if !self.enabled {
71            return;
72        }
73
74        if let Some(lora) = self.coordinator.micro_lora().try_read() {
75            lora.forward(input, output);
76        }
77    }
78
79    /// Apply base-LoRA to layer output
80    pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
81        if !self.enabled {
82            return;
83        }
84
85        if let Some(lora) = self.coordinator.base_lora().try_read() {
86            lora.forward_layer(layer_idx, input, output);
87        }
88    }
89
90    /// Run background learning cycle if due
91    pub fn tick(&self) -> Option<String> {
92        if !self.enabled {
93            return None;
94        }
95
96        if let Some(result) = self.coordinator.maybe_run_background() {
97            Some(format!(
98                "Background cycle: {} trajectories -> {} patterns in {:?}",
99                result.trajectories_processed, result.patterns_extracted, result.elapsed
100            ))
101        } else {
102            None
103        }
104    }
105
106    /// Force background learning cycle
107    pub fn force_learn(&self) -> String {
108        let result = self.coordinator.force_background();
109        format!(
110            "Forced learning: {} trajectories -> {} patterns, status: {}",
111            result.trajectories_processed, result.patterns_extracted, result.status
112        )
113    }
114
115    /// Flush instant loop updates
116    pub fn flush(&self) {
117        self.coordinator.flush_instant();
118    }
119
120    /// Find similar patterns to query
121    pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
122        self.coordinator
123            .reasoning_bank()
124            .read()
125            .find_similar(query_embedding, k)
126            .into_iter()
127            .cloned()
128            .collect()
129    }
130
131    /// Get engine statistics
132    pub fn stats(&self) -> CoordinatorStats {
133        self.coordinator.stats()
134    }
135
136    /// Get coordinator reference for state serialization (fixes #274)
137    pub fn coordinator(&self) -> &LoopCoordinator {
138        &self.coordinator
139    }
140
141    /// Enable/disable engine
142    pub fn set_enabled(&mut self, enabled: bool) {
143        self.enabled = enabled;
144    }
145
146    /// Check if enabled
147    pub fn is_enabled(&self) -> bool {
148        self.enabled
149    }
150
151    /// Get config
152    pub fn config(&self) -> &SonaConfig {
153        &self.config
154    }
155
156    /// Get all learned patterns from the reasoning bank
157    #[cfg(feature = "serde-support")]
158    pub fn get_all_patterns(&self) -> Vec<crate::LearnedPattern> {
159        self.coordinator.reasoning_bank().read().get_all_patterns()
160    }
161
162    /// Export LoRA state for serialization
163    #[cfg(feature = "serde-support")]
164    pub fn export_lora_state(&self) -> crate::export::safetensors::LoRAState {
165        use crate::export::safetensors::{LoRALayerState, LoRAState};
166
167        let mut state = LoRAState::default();
168
169        // Export MicroLoRA (single layer)
170        if let Some(lora) = self.coordinator.micro_lora().try_read() {
171            let (down, up) = lora.get_weights();
172            state.micro_lora_layers.push(LoRALayerState {
173                lora_a: down.clone(),
174                lora_b: up.clone(),
175                rank: self.config.micro_lora_rank,
176                input_dim: self.config.hidden_dim,
177                output_dim: self.config.hidden_dim,
178            });
179        }
180
181        // Export BaseLoRA (multi-layer)
182        if let Some(lora) = self.coordinator.base_lora().try_read() {
183            for idx in 0..lora.num_layers() {
184                if let Some((down, up)) = lora.get_layer_weights(idx) {
185                    state.base_lora_layers.push(LoRALayerState {
186                        lora_a: down.clone(),
187                        lora_b: up.clone(),
188                        rank: lora.rank,
189                        input_dim: lora.hidden_dim,
190                        output_dim: lora.hidden_dim,
191                    });
192                }
193            }
194        }
195
196        state
197    }
198
199    /// Get quality trajectories for preference learning export
200    #[cfg(feature = "serde-support")]
201    pub fn get_quality_trajectories(&self) -> Vec<crate::export::dataset::QualityTrajectory> {
202        use crate::export::dataset::QualityTrajectory;
203
204        // Get buffered trajectories from the instant loop via coordinator
205        let trajectories = self.coordinator.reasoning_bank().read().get_all_patterns();
206
207        trajectories
208            .iter()
209            .map(|p| {
210                QualityTrajectory {
211                    query_embedding: p.centroid.clone(),
212                    response_embedding: p.centroid.clone(), // Use centroid as proxy
213                    route: p.pattern_type.to_string(),
214                    quality: p.avg_quality,
215                    context_ids: vec![],
216                }
217            })
218            .collect()
219    }
220
221    /// Get routing decisions for distillation export
222    #[cfg(feature = "serde-support")]
223    pub fn get_routing_decisions(&self) -> Vec<crate::export::dataset::RoutingDecision> {
224        use crate::export::dataset::RoutingDecision;
225
226        let patterns = self.coordinator.reasoning_bank().read().get_all_patterns();
227
228        patterns
229            .iter()
230            .map(|p| {
231                RoutingDecision {
232                    query_embedding: p.centroid.clone(),
233                    routing_logits: vec![p.avg_quality], // Simplified
234                    selected_route: p.pattern_type.to_string(),
235                    confidence: p.avg_quality,
236                    quality: p.avg_quality,
237                }
238            })
239            .collect()
240    }
241}
242
/// Builder for [`SonaEngine`], offering fluent configuration of the
/// underlying `SonaConfig` before construction.
pub struct SonaEngineBuilder {
    /// Configuration accumulated by the setters and consumed by `build()`
    config: SonaConfig,
}
247
248impl SonaEngineBuilder {
249    /// Create new builder
250    pub fn new() -> Self {
251        Self {
252            config: SonaConfig::default(),
253        }
254    }
255
256    /// Set hidden dimension
257    pub fn hidden_dim(mut self, dim: usize) -> Self {
258        self.config.hidden_dim = dim;
259        self.config.embedding_dim = dim;
260        self
261    }
262
263    /// Set micro-LoRA rank
264    pub fn micro_lora_rank(mut self, rank: usize) -> Self {
265        self.config.micro_lora_rank = rank.clamp(1, 2);
266        self
267    }
268
269    /// Set base-LoRA rank
270    pub fn base_lora_rank(mut self, rank: usize) -> Self {
271        self.config.base_lora_rank = rank;
272        self
273    }
274
275    /// Set micro-LoRA learning rate
276    pub fn micro_lr(mut self, lr: f32) -> Self {
277        self.config.micro_lora_lr = lr;
278        self
279    }
280
281    /// Set base-LoRA learning rate
282    pub fn base_lr(mut self, lr: f32) -> Self {
283        self.config.base_lora_lr = lr;
284        self
285    }
286
287    /// Set EWC lambda
288    pub fn ewc_lambda(mut self, lambda: f32) -> Self {
289        self.config.ewc_lambda = lambda;
290        self
291    }
292
293    /// Set pattern clusters
294    pub fn pattern_clusters(mut self, k: usize) -> Self {
295        self.config.pattern_clusters = k;
296        self
297    }
298
299    /// Set trajectory buffer capacity
300    pub fn buffer_capacity(mut self, capacity: usize) -> Self {
301        self.config.trajectory_capacity = capacity;
302        self
303    }
304
305    /// Set quality threshold
306    pub fn quality_threshold(mut self, threshold: f32) -> Self {
307        self.config.quality_threshold = threshold;
308        self
309    }
310
311    /// Build the engine
312    pub fn build(self) -> SonaEngine {
313        SonaEngine::with_config(self.config)
314    }
315}
316
317impl Default for SonaEngineBuilder {
318    fn default() -> Self {
319        Self::new()
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrajectoryStep;

    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        // Begin trajectory
        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        // End trajectory
        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        // Train a bit first (loop index unused: `_` avoids the
        // unused_variables warning the old `i` produced)
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        // Apply LoRA
        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // Output may or may not be modified depending on accumulated gradients
    }

    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    #[test]
    fn test_force_learn_with_few_trajectories() {
        // Test that forceLearn works even with fewer than min_trajectories (100)
        let engine = SonaEngine::new(64);

        // Only record 10 trajectories (below the 100 minimum)
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        // Should process 10 trajectories (not "insufficient trajectories")
        assert!(
            result.contains("10 trajectories"),
            "Expected '10 trajectories' but got: {}",
            result
        );
    }

    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        // Should not record when disabled
        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}