// ruvector_sona/engine.rs
1//! SONA Engine - Main interface for self-optimizing neural architecture
2
3use crate::loops::coordinator::{CoordinatorStats, LoopCoordinator};
4use crate::lora::MicroLoRA;
5use crate::trajectory::TrajectoryBuilder;
6use crate::types::{QueryTrajectory, SonaConfig};
7use parking_lot::RwLock;
8use std::sync::Arc;
9
/// Main SONA engine integrating all components.
///
/// A thin facade over the [`LoopCoordinator`]: trajectory recording,
/// LoRA application, learning-cycle control, and pattern lookup.
pub struct SonaEngine {
    /// Coordinates the instant and background learning loops.
    coordinator: LoopCoordinator,
    /// Engine configuration (the coordinator holds its own clone; see `with_config`).
    config: SonaConfig,
    /// Whether the engine is enabled. When false, trajectory submission,
    /// LoRA application, and `tick` become no-ops.
    enabled: bool,
}
19
20impl std::fmt::Debug for SonaEngine {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        f.debug_struct("SonaEngine")
23            .field("config", &self.config)
24            .field("enabled", &self.enabled)
25            .finish_non_exhaustive()
26    }
27}
28
29impl SonaEngine {
30    /// Create new SONA engine with default config
31    pub fn new(hidden_dim: usize) -> Self {
32        Self::with_config(SonaConfig {
33            hidden_dim,
34            embedding_dim: hidden_dim,
35            ..Default::default()
36        })
37    }
38
39    /// Create with custom config
40    pub fn with_config(config: SonaConfig) -> Self {
41        Self {
42            coordinator: LoopCoordinator::with_config(config.clone()),
43            config,
44            enabled: true,
45        }
46    }
47
48    /// Start trajectory recording for a query
49    pub fn begin_trajectory(&self, query_embedding: Vec<f32>) -> TrajectoryBuilder {
50        let id = self.coordinator.next_trajectory_id();
51        TrajectoryBuilder::new(id, query_embedding)
52    }
53
54    /// Complete trajectory and submit for learning
55    pub fn end_trajectory(&self, builder: TrajectoryBuilder, quality: f32) {
56        if !self.enabled {
57            return;
58        }
59
60        let trajectory = builder.build(quality);
61        self.coordinator.on_inference(trajectory);
62    }
63
64    /// Submit pre-built trajectory
65    pub fn submit_trajectory(&self, trajectory: QueryTrajectory) {
66        if self.enabled {
67            self.coordinator.on_inference(trajectory);
68        }
69    }
70
71    /// Apply micro-LoRA to hidden states
72    pub fn apply_micro_lora(&self, input: &[f32], output: &mut [f32]) {
73        if !self.enabled {
74            return;
75        }
76
77        if let Some(lora) = self.coordinator.micro_lora().try_read() {
78            lora.forward(input, output);
79        }
80    }
81
82    /// Apply base-LoRA to layer output
83    pub fn apply_base_lora(&self, layer_idx: usize, input: &[f32], output: &mut [f32]) {
84        if !self.enabled {
85            return;
86        }
87
88        if let Some(lora) = self.coordinator.base_lora().try_read() {
89            lora.forward_layer(layer_idx, input, output);
90        }
91    }
92
93    /// Run background learning cycle if due
94    pub fn tick(&self) -> Option<String> {
95        if !self.enabled {
96            return None;
97        }
98
99        if let Some(result) = self.coordinator.maybe_run_background() {
100            Some(format!(
101                "Background cycle: {} trajectories -> {} patterns in {:?}",
102                result.trajectories_processed, result.patterns_extracted, result.elapsed
103            ))
104        } else {
105            None
106        }
107    }
108
109    /// Force background learning cycle
110    pub fn force_learn(&self) -> String {
111        let result = self.coordinator.force_background();
112        format!(
113            "Forced learning: {} trajectories -> {} patterns, status: {}",
114            result.trajectories_processed, result.patterns_extracted, result.status
115        )
116    }
117
118    /// Flush instant loop updates
119    pub fn flush(&self) {
120        self.coordinator.flush_instant();
121    }
122
123    /// Find similar patterns to query
124    pub fn find_patterns(&self, query_embedding: &[f32], k: usize) -> Vec<crate::LearnedPattern> {
125        self.coordinator
126            .reasoning_bank()
127            .read()
128            .find_similar(query_embedding, k)
129            .into_iter()
130            .cloned()
131            .collect()
132    }
133
134    /// Get engine statistics
135    pub fn stats(&self) -> CoordinatorStats {
136        self.coordinator.stats()
137    }
138
139    /// Enable/disable engine
140    pub fn set_enabled(&mut self, enabled: bool) {
141        self.enabled = enabled;
142    }
143
144    /// Check if enabled
145    pub fn is_enabled(&self) -> bool {
146        self.enabled
147    }
148
149    /// Get config
150    pub fn config(&self) -> &SonaConfig {
151        &self.config
152    }
153
154    /// Get all learned patterns from the reasoning bank
155    #[cfg(feature = "serde-support")]
156    pub fn get_all_patterns(&self) -> Vec<crate::LearnedPattern> {
157        self.coordinator.reasoning_bank().read().get_all_patterns()
158    }
159
160    /// Export LoRA state for serialization
161    #[cfg(feature = "serde-support")]
162    pub fn export_lora_state(&self) -> crate::export::safetensors::LoRAState {
163        use crate::export::safetensors::{LoRALayerState, LoRAState};
164
165        let mut state = LoRAState::default();
166
167        // Export MicroLoRA (single layer)
168        if let Some(lora) = self.coordinator.micro_lora().try_read() {
169            let (down, up) = lora.get_weights();
170            state.micro_lora_layers.push(LoRALayerState {
171                lora_a: down.clone(),
172                lora_b: up.clone(),
173                rank: self.config.micro_lora_rank,
174                input_dim: self.config.hidden_dim,
175                output_dim: self.config.hidden_dim,
176            });
177        }
178
179        // Export BaseLoRA (multi-layer)
180        if let Some(lora) = self.coordinator.base_lora().try_read() {
181            for idx in 0..lora.num_layers() {
182                if let Some((down, up)) = lora.get_layer_weights(idx) {
183                    state.base_lora_layers.push(LoRALayerState {
184                        lora_a: down.clone(),
185                        lora_b: up.clone(),
186                        rank: lora.rank,
187                        input_dim: lora.hidden_dim,
188                        output_dim: lora.hidden_dim,
189                    });
190                }
191            }
192        }
193
194        state
195    }
196
197    /// Get quality trajectories for preference learning export
198    #[cfg(feature = "serde-support")]
199    pub fn get_quality_trajectories(&self) -> Vec<crate::export::dataset::QualityTrajectory> {
200        use crate::export::dataset::QualityTrajectory;
201
202        // Get buffered trajectories from the instant loop via coordinator
203        let trajectories = self.coordinator.reasoning_bank().read().get_all_patterns();
204
205        trajectories
206            .iter()
207            .map(|p| {
208                QualityTrajectory {
209                    query_embedding: p.centroid.clone(),
210                    response_embedding: p.centroid.clone(), // Use centroid as proxy
211                    route: p.pattern_type.to_string(),
212                    quality: p.avg_quality,
213                    context_ids: vec![],
214                }
215            })
216            .collect()
217    }
218
219    /// Get routing decisions for distillation export
220    #[cfg(feature = "serde-support")]
221    pub fn get_routing_decisions(&self) -> Vec<crate::export::dataset::RoutingDecision> {
222        use crate::export::dataset::RoutingDecision;
223
224        let patterns = self.coordinator.reasoning_bank().read().get_all_patterns();
225
226        patterns
227            .iter()
228            .map(|p| {
229                RoutingDecision {
230                    query_embedding: p.centroid.clone(),
231                    routing_logits: vec![p.avg_quality], // Simplified
232                    selected_route: p.pattern_type.to_string(),
233                    confidence: p.avg_quality,
234                    quality: p.avg_quality,
235                }
236            })
237            .collect()
238    }
239}
240
/// Builder for [`SonaEngine`]: fluent setters over a [`SonaConfig`],
/// finalized by `build()`.
pub struct SonaEngineBuilder {
    /// Configuration accumulated by the setter methods.
    config: SonaConfig,
}
245
246impl SonaEngineBuilder {
247    /// Create new builder
248    pub fn new() -> Self {
249        Self {
250            config: SonaConfig::default(),
251        }
252    }
253
254    /// Set hidden dimension
255    pub fn hidden_dim(mut self, dim: usize) -> Self {
256        self.config.hidden_dim = dim;
257        self.config.embedding_dim = dim;
258        self
259    }
260
261    /// Set micro-LoRA rank
262    pub fn micro_lora_rank(mut self, rank: usize) -> Self {
263        self.config.micro_lora_rank = rank.clamp(1, 2);
264        self
265    }
266
267    /// Set base-LoRA rank
268    pub fn base_lora_rank(mut self, rank: usize) -> Self {
269        self.config.base_lora_rank = rank;
270        self
271    }
272
273    /// Set micro-LoRA learning rate
274    pub fn micro_lr(mut self, lr: f32) -> Self {
275        self.config.micro_lora_lr = lr;
276        self
277    }
278
279    /// Set base-LoRA learning rate
280    pub fn base_lr(mut self, lr: f32) -> Self {
281        self.config.base_lora_lr = lr;
282        self
283    }
284
285    /// Set EWC lambda
286    pub fn ewc_lambda(mut self, lambda: f32) -> Self {
287        self.config.ewc_lambda = lambda;
288        self
289    }
290
291    /// Set pattern clusters
292    pub fn pattern_clusters(mut self, k: usize) -> Self {
293        self.config.pattern_clusters = k;
294        self
295    }
296
297    /// Set trajectory buffer capacity
298    pub fn buffer_capacity(mut self, capacity: usize) -> Self {
299        self.config.trajectory_capacity = capacity;
300        self
301    }
302
303    /// Set quality threshold
304    pub fn quality_threshold(mut self, threshold: f32) -> Self {
305        self.config.quality_threshold = threshold;
306        self
307    }
308
309    /// Build the engine
310    pub fn build(self) -> SonaEngine {
311        SonaEngine::with_config(self.config)
312    }
313}
314
315impl Default for SonaEngineBuilder {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh engine starts enabled.
    #[test]
    fn test_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    /// Builder setters land in the final config (incl. rank clamping to <= 2).
    #[test]
    fn test_builder() {
        let engine = SonaEngineBuilder::new()
            .hidden_dim(512)
            .micro_lora_rank(2)
            .base_lora_rank(16)
            .micro_lr(0.002)
            .ewc_lambda(500.0)
            .build();

        assert_eq!(engine.config().hidden_dim, 512);
        assert_eq!(engine.config().micro_lora_rank, 2);
    }

    /// Begin/step/end buffers exactly one trajectory.
    #[test]
    fn test_trajectory_workflow() {
        let engine = SonaEngine::new(64);

        // Begin trajectory
        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![], 0.8);
        builder.add_step(vec![0.6; 64], vec![], 0.9);

        // End trajectory
        engine.end_trajectory(builder, 0.85);

        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 1);
    }

    /// Smoke test: apply micro-LoRA after flushing a few trajectories.
    #[test]
    fn test_micro_lora_application() {
        let engine = SonaEngine::new(64);

        // Train a bit first (loop index unused; `_` avoids a warning).
        for _ in 0..10 {
            let mut builder = engine.begin_trajectory(vec![0.1; 64]);
            builder.add_step(vec![0.5; 64], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }
        engine.flush();

        // Apply LoRA
        let input = vec![1.0; 64];
        let mut output = vec![0.0; 64];
        engine.apply_micro_lora(&input, &mut output);

        // Output may or may not be modified depending on accumulated gradients
    }

    /// Forced learning reports the number of buffered trajectories.
    #[test]
    fn test_force_learn() {
        let engine = SonaEngine::new(256);

        for _ in 0..150 {
            let mut builder = engine.begin_trajectory(vec![0.1; 256]);
            builder.add_step(vec![0.5; 256], vec![], 0.8);
            engine.end_trajectory(builder, 0.8);
        }

        let result = engine.force_learn();
        assert!(result.contains("150 trajectories"));
    }

    /// A disabled engine drops submitted trajectories.
    #[test]
    fn test_disabled_engine() {
        let mut engine = SonaEngine::new(64);
        engine.set_enabled(false);

        let builder = engine.begin_trajectory(vec![0.1; 64]);
        engine.end_trajectory(builder, 0.8);

        // Should not record when disabled
        let stats = engine.stats();
        assert_eq!(stats.trajectories_buffered, 0);
    }
}