// ruvector_sona/engine.rs

1//! SONA Engine - Main interface for self-optimizing neural architecture
2
3use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::lora::MicroLoRA;
5use crate::trajectory::TrajectoryBuilder;
6use crate::types::{QueryTrajectory, SonaConfig};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
/// Main SONA engine integrating all components.
///
/// Thin facade over [`LoopCoordinator`]: trajectory recording, LoRA
/// application, background learning ticks, and pattern lookup all delegate
/// to the coordinator.
pub struct SonaEngine {
    /// Loop coordinator — owns the micro/base LoRA adapters, the reasoning
    /// bank, and the background learning cycle (see accessor calls below).
    coordinator: LoopCoordinator,
    /// Configuration snapshot used to build the coordinator; exposed
    /// read-only via `config()`.
    config: SonaConfig,
    /// When `false`, trajectory submission, LoRA application, and `tick()`
    /// become no-ops (`begin_trajectory` still allocates an id).
    enabled: bool,
}
19
20impl SonaEngine {
21    /// Create new SONA engine with default config
22    pub fn new(hidden_dim: usize) -> Self {
23        Self::with_config(SonaConfig {
24            hidden_dim,
25            embedding_dim: hidden_dim,
26            ..Default::default()
27        })
28    }
29
30    /// Create with custom config
31    pub fn with_config(config: SonaConfig) -> Self {
32        Self {
33            coordinator: LoopCoordinator::with_config(config.clone()),
34            config,
35            enabled: true,
36        }
37    }
38
39    /// Start trajectory recording for a query
40    pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
41        let id = self.coordinator.next_trajectory_id();
42        TrajectoryBuilder::new(id, query_embedding)
43    }
44
45    /// Complete trajectory and submit for learning
46    pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
47        if !self.enabled {
48            return;
49        }
50
51        let trajectory = builder.build(quality);
52        self.coordinator.on_inference(trajectory);
53    }
54
55    /// Submit pre-built trajectory
56    pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
57        if self.enabled {
58            self.coordinator.on_inference(trajectory);
59        }
60    }
61
62    /// Apply micro-LoRA to hidden states
63    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
64        if !self.enabled {
65            return;
66        }
67
68        if let Some(lora) = self.coordinator.micro_lora().try_read() {
69            lora.forward(input, output);
70        }
71    }
72
73    /// Apply base-LoRA to layer output
74    pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
75        if !self.enabled {
76            return;
77        }
78
79        if let Some(lora) = self.coordinator.base_lora().try_read() {
80            lora.forward_layer(layer_idx, input, output);
81        }
82    }
83
84    /// Run background learning cycle if due
85    pub fn tick(&self) -> Option<String> {
86        if !self.enabled {
87            return None;
88        }
89
90        if let Some(result) = self.coordinator.maybe_run_background() {
91            Some(format!(
92                "Background cycle: {} trajectories -> {} patterns in {:?}",
93                result.trajectories_processed,
94                result.patterns_extracted,
95                result.elapsed
96            ))
97        } else {
98            None
99        }
100    }
101
102    /// Force background learning cycle
103    pub fn force_learn(&self) -> String {
104        let result = self.coordinator.force_background();
105        format!(
106            "Forced learning: {} trajectories -> {} patterns, status: {}",
107            result.trajectories_processed,
108            result.patterns_extracted,
109            result.status
110        )
111    }
112
113    /// Flush instant loop updates
114    pub fn flush(&self) {
115        self.coordinator.flush_instant();
116    }
117
118    /// Find similar patterns to query
119    pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
120        self.coordinator
121            .reasoning_bank()
122            .read()
123            .find_similar(query_embedding, k)
124            .into_iter()
125            .cloned()
126            .collect()
127    }
128
129    /// Get engine statistics
130    pub fn stats(&self) -> CoordinatorStats {
131        self.coordinator.stats()
132    }
133
134    /// Enable/disable engine
135    pub fn set_enabled(&mut self, enabled: bool) {
136        self.enabled = enabled;
137    }
138
139    /// Check if enabled
140    pub fn is_enabled(&self) -> bool {
141        self.enabled
142    }
143
144    /// Get config
145    pub fn config(&self) -> &SonaConfig {
146        &self.config
147    }
148
149    /// Get all learned patterns from the reasoning bank
150    #[cfg(feature = "serde-support")]
151    pub fn get_all_patterns(&self) -> Vec<crate::LearnedPattern> {
152        self.coordinator.reasoning_bank().read().get_all_patterns()
153    }
154
155    /// Export LoRA state for serialization
156    #[cfg(feature = "serde-support")]
157    pub fn export_lora_state(&self) -> crate::export::safetensors::LoRAState {
158        use crate::export::safetensors::{LoRAState, LoRALayerState};
159
160        let mut state = LoRAState::default();
161
162        // Export MicroLoRA (single layer)
163        if let Some(lora) = self.coordinator.micro_lora().try_read() {
164            let (down, up) = lora.get_weights();
165            state.micro_lora_layers.push(LoRALayerState {
166                lora_a: down.clone(),
167                lora_b: up.clone(),
168                rank: self.config.micro_lora_rank,
169                input_dim: self.config.hidden_dim,
170                output_dim: self.config.hidden_dim,
171            });
172        }
173
174        // Export BaseLoRA (multi-layer)
175        if let Some(lora) = self.coordinator.base_lora().try_read() {
176            for idx in 0..lora.num_layers() {
177                if let Some((down, up)) = lora.get_layer_weights(idx) {
178                    state.base_lora_layers.push(LoRALayerState {
179                        lora_a: down.clone(),
180                        lora_b: up.clone(),
181                        rank: lora.rank,
182                        input_dim: lora.hidden_dim,
183                        output_dim: lora.hidden_dim,
184                    });
185                }
186            }
187        }
188
189        state
190    }
191
192    /// Get quality trajectories for preference learning export
193    #[cfg(feature = "serde-support")]
194    pub fn get_quality_trajectories(&self) -> Vec<crate::export::dataset::QualityTrajectory> {
195        use crate::export::dataset::QualityTrajectory;
196
197        // Get buffered trajectories from the instant loop via coordinator
198        let trajectories = self.coordinator.reasoning_bank().read().get_all_patterns();
199
200        trajectories.iter().map(|p| {
201            QualityTrajectory {
202                query_embedding: p.centroid.clone(),
203                response_embedding: p.centroid.clone(), // Use centroid as proxy
204                route: p.pattern_type.to_string(),
205                quality: p.avg_quality,
206                context_ids: vec![],
207            }
208        }).collect()
209    }
210
211    /// Get routing decisions for distillation export
212    #[cfg(feature = "serde-support")]
213    pub fn get_routing_decisions(&self) -> Vec<crate::export::dataset::RoutingDecision> {
214        use crate::export::dataset::RoutingDecision;
215
216        let patterns = self.coordinator.reasoning_bank().read().get_all_patterns();
217
218        patterns.iter().map(|p| {
219            RoutingDecision {
220                query_embedding: p.centroid.clone(),
221                routing_logits: vec![p.avg_quality], // Simplified
222                selected_route: p.pattern_type.to_string(),
223                confidence: p.avg_quality,
224                quality: p.avg_quality,
225            }
226        }).collect()
227    }
228}
229
/// Builder for [`SonaEngine`] with a fluent setter API.
pub struct SonaEngineBuilder {
    /// Configuration accumulated by the setters; consumed by `build()`.
    config: SonaConfig,
}
234
235impl SonaEngineBuilder {
236    /// Create new builder
237    pub fn new() -> Self {
238        Self {
239            config: SonaConfig::default(),
240        }
241    }
242
243    /// Set hidden dimension
244    pub fn hidden_dim(mut self, dim: usize) -> Self {
245        self.config.hidden_dim = dim;
246        self.config.embedding_dim = dim;
247        self
248    }
249
250    /// Set micro-LoRA rank
251    pub fn micro_lora_rank(mut self, rank: usize) -> Self {
252        self.config.micro_lora_rank = rank.clamp(1, 2);
253        self
254    }
255
256    /// Set base-LoRA rank
257    pub fn base_lora_rank(mut self, rank: usize) -> Self {
258        self.config.base_lora_rank = rank;
259        self
260    }
261
262    /// Set micro-LoRA learning rate
263    pub fn micro_lr(mut self, lr: f32) -> Self {
264        self.config.micro_lora_lr = lr;
265        self
266    }
267
268    /// Set base-LoRA learning rate
269    pub fn base_lr(mut self, lr: f32) -> Self {
270        self.config.base_lora_lr = lr;
271        self
272    }
273
274    /// Set EWC lambda
275    pub fn ewc_lambda(mut self, lambda: f32) -> Self {
276        self.config.ewc_lambda = lambda;
277        self
278    }
279
280    /// Set pattern clusters
281    pub fn pattern_clusters(mut self, k: usize) -> Self {
282        self.config.pattern_clusters = k;
283        self
284    }
285
286    /// Set trajectory buffer capacity
287    pub fn buffer_capacity(mut self, capacity: usize) -> Self {
288        self.config.trajectory_capacity = capacity;
289        self
290    }
291
292    /// Set quality threshold
293    pub fn quality_threshold(mut self, threshold: f32) -> Self {
294        self.config.quality_threshold = threshold;
295        self
296    }
297
298    /// Build the engine
299    pub fn build(self) -> SonaEngine {
300        SonaEngine::with_config(self.config)
301    }
302}
303
304impl Default for SonaEngineBuilder {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    // Removed unused `use crate::types::TrajectoryStep;` — steps are added
    // via `TrajectoryBuilder::add_step`, the type itself is never named.

    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        // Begin trajectory
        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        // End trajectory
        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        // Train a bit first (`_`: loop counter unused — was an `unused_variables` warning)
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        // Apply LoRA
        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // Output may or may not be modified depending on accumulated gradients
    }

    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        // Should not record when disabled
        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}