ruvector_sona/engine.rs

//! SONA Engine - Main interface for self-optimizing neural architecture
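//!
//! A minimal end-to-end sketch (dimensions and quality scores are
//! illustrative):
//!
//! ```ignore
//! let engine = SonaEngine::new(256);
//!
//! // Record an inference trajectory and submit it for learning
//! let mut builder = engine.begin_trajectory(vec![0.0; 256]);
//! builder.add_step(vec![0.5; 256], vec![], 0.9);
//! engine.end_trajectory(builder, 0.9);
//!
//! // Periodically give the background learning loop a chance to run
//! engine.tick();
//! ```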

use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
use crate::trajectory::TrajectoryBuilder;
use crate::types::{LearnedPattern, QueryTrajectory, SonaConfig};

#[cfg(feature = "serde-support")]
use crate::export::safetensors::{LoRALayerState, LoRAState};

#[cfg(feature = "serde-support")]
use crate::export::dataset::{QualityTrajectory, RoutingDecision};
/// Main SONA engine integrating all components
pub struct SonaEngine {
    /// Loop coordinator
    coordinator: LoopCoordinator,
    /// Configuration
    config: SonaConfig,
    /// Whether engine is enabled
    enabled: bool,
}

impl SonaEngine {
    /// Create new SONA engine with default config
    pub fn new(hidden_dim: usize) -> Self {
        Self::with_config(SonaConfig {
            hidden_dim,
            embedding_dim: hidden_dim,
            ..Default::default()
        })
    }

    /// Create with custom config
    pub fn with_config(config: SonaConfig) -> Self {
        Self {
            coordinator: LoopCoordinator::with_config(config.clone()),
            config,
            enabled: true,
        }
    }

    /// Start trajectory recording for a query
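    ///
    /// Pairs with [`end_trajectory`](Self::end_trajectory). A minimal sketch
    /// (embedding and scores are illustrative):
    ///
    /// ```ignore
    /// let mut builder = engine.begin_trajectory(vec![0.1; 256]);
    /// builder.add_step(vec![0.5; 256], vec![], 0.8);
    /// engine.end_trajectory(builder, 0.85);
    /// ```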
    pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
        let id = self.coordinator.next_trajectory_id();
        TrajectoryBuilder::new(id, query_embedding)
    }

    /// Complete trajectory and submit for learning
    pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
        if !self.enabled {
            return;
        }

        let trajectory = builder.build(quality);
        self.coordinator.on_inference(trajectory);
    }

    /// Submit pre-built trajectory
    pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
        if self.enabled {
            self.coordinator.on_inference(trajectory);
        }
    }

    /// Apply micro-LoRA to hidden states
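    ///
    /// Uses a non-blocking `try_read`, so the call is a no-op while a
    /// concurrent learning update holds the weights. Sketch (buffers sized
    /// to `hidden_dim`; values are illustrative):
    ///
    /// ```ignore
    /// let input = vec![1.0_f32; 256];
    /// let mut output = vec![0.0_f32; 256];
    /// engine.apply_micro_lora(&input, &mut output);
    /// ```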
    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
        if !self.enabled {
            return;
        }

        if let Some(lora) = self.coordinator.micro_lora().try_read() {
            lora.forward(input, output);
        }
    }

    /// Apply base-LoRA to layer output
    pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
        if !self.enabled {
            return;
        }

        if let Some(lora) = self.coordinator.base_lora().try_read() {
            lora.forward_layer(layer_idx, input, output);
        }
    }

    /// Run background learning cycle if due
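    ///
    /// Returns a summary string when a cycle actually ran, `None` otherwise.
    /// Intended to be called periodically, e.g. from a scheduler loop:
    ///
    /// ```ignore
    /// if let Some(summary) = engine.tick() {
    ///     println!("{summary}");
    /// }
    /// ```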
    pub fn tick(&self) -> Option<String> {
        if !self.enabled {
            return None;
        }

        self.coordinator.maybe_run_background().map(|result| {
            format!(
                "Background cycle: {} trajectories -> {} patterns in {:?}",
                result.trajectories_processed, result.patterns_extracted, result.elapsed
            )
        })
    }

    /// Force background learning cycle
    pub fn force_learn(&self) -> String {
        let result = self.coordinator.force_background();
        format!(
            "Forced learning: {} trajectories -> {} patterns, status: {}",
            result.trajectories_processed,
            result.patterns_extracted,
            result.status
        )
    }

    /// Flush instant loop updates
    pub fn flush(&self) {
        self.coordinator.flush_instant();
    }

    /// Find similar patterns to query
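    ///
    /// A minimal retrieval sketch (`k` and the embedding are illustrative):
    ///
    /// ```ignore
    /// let top = engine.find_patterns(&[0.1_f32; 256], 5);
    /// for pattern in &top {
    ///     println!("quality: {}", pattern.avg_quality);
    /// }
    /// ```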
    pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<LearnedPattern> {
        self.coordinator
            .reasoning_bank()
            .read()
            .find_similar(query_embedding, k)
            .into_iter()
            .cloned()
            .collect()
    }

    /// Get engine statistics
    pub fn stats(&self) -> CoordinatorStats {
        self.coordinator.stats()
    }

    /// Enable/disable engine
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Check if enabled
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Get config
    pub fn config(&self) -> &SonaConfig {
        &self.config
    }

    /// Get all learned patterns from ReasoningBank
    pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
        self.coordinator
            .reasoning_bank()
            .read()
            .get_all_patterns()
    }

    /// Export LoRA state for SafeTensors serialization
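    ///
    /// Weight reads are non-blocking, so layers held by a concurrent update
    /// are skipped. Sketch:
    ///
    /// ```ignore
    /// let state = engine.export_lora_state();
    /// println!(
    ///     "micro layers: {}, base layers: {}",
    ///     state.micro_lora_layers.len(),
    ///     state.base_lora_layers.len()
    /// );
    /// ```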
    #[cfg(feature = "serde-support")]
    pub fn export_lora_state(&self) -> LoRAState {
        let mut state = LoRAState::default();

        // Export MicroLoRA weights
        if let Some(micro_lora) = self.coordinator.micro_lora().try_read() {
            let (lora_a, lora_b) = micro_lora.get_weights();
            state.micro_lora_layers.push(LoRALayerState {
                lora_a: lora_a.clone(),
                lora_b: lora_b.clone(),
                rank: self.config.micro_lora_rank,
                input_dim: self.config.hidden_dim,
                output_dim: self.config.hidden_dim,
            });
        }

        // Export BaseLoRA weights
        if let Some(base_lora) = self.coordinator.base_lora().try_read() {
            for layer_idx in 0..base_lora.num_layers() {
                if let Some((lora_a, lora_b)) = base_lora.get_layer_weights(layer_idx) {
                    state.base_lora_layers.push(LoRALayerState {
                        lora_a: lora_a.clone(),
                        lora_b: lora_b.clone(),
                        rank: self.config.base_lora_rank,
                        input_dim: self.config.hidden_dim,
                        output_dim: self.config.hidden_dim,
                    });
                }
            }
        }

        state
    }

    /// Get quality trajectories for preference learning export
    #[cfg(feature = "serde-support")]
    pub fn get_quality_trajectories(&self) -> Vec<QualityTrajectory> {
        self.coordinator
            .trajectory_buffer()
            .get_all()
            .iter()
            .map(|t| QualityTrajectory {
                query_embedding: t.query_embedding.clone(),
                response_embedding: t.steps.last()
                    .map(|s| s.activations.clone())
                    .unwrap_or_default(),
                route: t.model_route.clone().unwrap_or_default(),
                quality: t.final_quality,
                context_ids: t.context_ids.clone(),
            })
            .collect()
    }

    /// Get routing decisions for distillation export
    #[cfg(feature = "serde-support")]
    pub fn get_routing_decisions(&self) -> Vec<RoutingDecision> {
        // Extract routing decisions from learned patterns
        self.get_all_patterns()
            .iter()
            .map(|p| RoutingDecision {
                query_embedding: p.centroid.clone(),
                routing_logits: vec![p.avg_quality; 4], // Placeholder logits
                selected_route: p.pattern_type.to_string(),
                confidence: p.avg_quality,
                quality: p.avg_quality,
            })
            .collect()
    }
}

/// Builder for SonaEngine
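///
/// A typical chained construction (values are illustrative):
///
/// ```ignore
/// let engine = SonaEngineBuilder::new()
///     .hidden_dim(512)
///     .micro_lora_rank(2)
///     .quality_threshold(0.7)
///     .build();
/// ```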
pub struct SonaEngineBuilder {
    config: SonaConfig,
}

impl SonaEngineBuilder {
    /// Create new builder
    pub fn new() -> Self {
        Self {
            config: SonaConfig::default(),
        }
    }

    /// Set hidden dimension (also used as the embedding dimension)
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.config.hidden_dim = dim;
        self.config.embedding_dim = dim;
        self
    }

    /// Set micro-LoRA rank (clamped to the supported range 1..=2)
    pub fn micro_lora_rank(mut self, rank: usize) -> Self {
        self.config.micro_lora_rank = rank.clamp(1, 2);
        self
    }

    /// Set base-LoRA rank
    pub fn base_lora_rank(mut self, rank: usize) -> Self {
        self.config.base_lora_rank = rank;
        self
    }

    /// Set micro-LoRA learning rate
    pub fn micro_lr(mut self, lr: f32) -> Self {
        self.config.micro_lora_lr = lr;
        self
    }

    /// Set base-LoRA learning rate
    pub fn base_lr(mut self, lr: f32) -> Self {
        self.config.base_lora_lr = lr;
        self
    }

    /// Set EWC (elastic weight consolidation) lambda
    pub fn ewc_lambda(mut self, lambda: f32) -> Self {
        self.config.ewc_lambda = lambda;
        self
    }

    /// Set pattern clusters
    pub fn pattern_clusters(mut self, k: usize) -> Self {
        self.config.pattern_clusters = k;
        self
    }

    /// Set trajectory buffer capacity
    pub fn buffer_capacity(mut self, capacity: usize) -> Self {
        self.config.trajectory_capacity = capacity;
        self
    }

    /// Set quality threshold
    pub fn quality_threshold(mut self, threshold: f32) -> Self {
        self.config.quality_threshold = threshold;
        self
    }

    /// Build the engine
    pub fn build(self) -> SonaEngine {
        SonaEngine::with_config(self.config)
    }
}

impl Default for SonaEngineBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        // Begin trajectory
        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        // End trajectory
        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        // Train a bit first
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        // Apply LoRA
        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // Output may or may not be modified depending on accumulated gradients
    }

    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        // Should not record when disabled
        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}