//! Loop B - Background Learning
//!
//! Hourly pattern extraction and base LoRA updates.

use crate::ewc::EwcPlusPlus;
use crate::lora::BaseLoRA;
use crate::reasoning_bank::ReasoningBank;
use crate::time_compat::Instant;
use crate::types::{LearnedPattern, QueryTrajectory, SonaConfig};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;

/// Configuration for the background learning loop (Loop B).
#[derive(Clone, Debug)]
pub struct BackgroundLoopConfig {
    /// Minimum number of trajectories required before a cycle runs
    /// (bypassable via the `force` flag on `run_cycle`)
    pub min_trajectories: usize,
    /// Learning rate used when folding gradients into the base LoRA weights
    pub base_lora_lr: f32,
    /// EWC++ regularization strength (lambda)
    pub ewc_lambda: f32,
    /// Minimum wall-clock interval between pattern extractions
    pub extraction_interval: Duration,
}
26
27impl Default for BackgroundLoopConfig {
28    fn default() -> Self {
29        Self {
30            min_trajectories: 10, // Was 100; lowered so patterns crystallize from fewer trajectories
31            base_lora_lr: 0.0001,
32            ewc_lambda: 1000.0,
33            extraction_interval: Duration::from_secs(3600),
34        }
35    }
36}
37
38impl From<&SonaConfig> for BackgroundLoopConfig {
39    fn from(config: &SonaConfig) -> Self {
40        Self {
41            min_trajectories: 10, // Was 100; lowered so patterns crystallize from fewer trajectories
42            base_lora_lr: config.base_lora_lr,
43            ewc_lambda: config.ewc_lambda,
44            extraction_interval: Duration::from_millis(config.background_interval_ms),
45        }
46    }
47}
48
/// Outcome of a single background learning cycle.
#[derive(Debug)]
pub struct BackgroundResult {
    /// Number of trajectories fed into the reasoning bank this cycle
    pub trajectories_processed: usize,
    /// Number of patterns returned by pattern extraction
    pub patterns_extracted: usize,
    /// Whether the EWC++ Fisher information was updated this cycle
    pub ewc_updated: bool,
    /// Wall-clock duration of the cycle (zero for skipped cycles)
    pub elapsed: Duration,
    /// `"completed"`, or `"skipped: <reason>"` when no work was performed
    pub status: String,
}
58
59impl BackgroundResult {
60    fn skipped(reason: &str) -> Self {
61        Self {
62            trajectories_processed: 0,
63            patterns_extracted: 0,
64            ewc_updated: false,
65            elapsed: Duration::ZERO,
66            status: format!("skipped: {}", reason),
67        }
68    }
69}
70
/// Background learning loop (Loop B).
///
/// Periodically folds accumulated query trajectories into long-term state:
/// the ReasoningBank (patterns), EWC++ (forgetting prevention), and the
/// base LoRA weights.
pub struct BackgroundLoop {
    /// Loop configuration (thresholds, learning rate, interval)
    config: BackgroundLoopConfig,
    /// ReasoningBank for pattern storage, shared with other loops
    reasoning_bank: Arc<RwLock<ReasoningBank>>,
    /// EWC++ state for forgetting prevention, shared with other loops
    ewc: Arc<RwLock<EwcPlusPlus>>,
    /// Base LoRA weights updated by this loop
    base_lora: Arc<RwLock<BaseLoRA>>,
    /// Instant of the last completed extraction (set at construction,
    /// refreshed at the end of each completed cycle)
    last_extraction: RwLock<Instant>,
}
84
85impl BackgroundLoop {
86    /// Create new background loop
87    pub fn new(
88        config: BackgroundLoopConfig,
89        reasoning_bank: Arc<RwLock<ReasoningBank>>,
90        ewc: Arc<RwLock<EwcPlusPlus>>,
91        base_lora: Arc<RwLock<BaseLoRA>>,
92    ) -> Self {
93        Self {
94            config,
95            reasoning_bank,
96            ewc,
97            base_lora,
98            last_extraction: RwLock::new(Instant::now()),
99        }
100    }
101
102    /// Check if it's time for background cycle
103    pub fn should_run(&self) -> bool {
104        self.last_extraction.read().elapsed() >= self.config.extraction_interval
105    }
106
107    /// Run background learning cycle
108    ///
109    /// If `force` is true, bypasses the minimum trajectory check (for forceLearn API)
110    pub fn run_cycle(&self, trajectories: Vec<QueryTrajectory>, force: bool) -> BackgroundResult {
111        if !force && trajectories.len() < self.config.min_trajectories {
112            return BackgroundResult::skipped(&format!(
113                "insufficient trajectories ({} < {} minimum, use forceLearn to bypass)",
114                trajectories.len(),
115                self.config.min_trajectories
116            ));
117        }
118
119        if trajectories.is_empty() {
120            return BackgroundResult::skipped("no trajectories to process");
121        }
122
123        let start = Instant::now();
124
125        // 1. Add trajectories to reasoning bank
126        {
127            let mut bank = self.reasoning_bank.write();
128            for trajectory in &trajectories {
129                bank.add_trajectory(trajectory);
130            }
131        }
132
133        // 2. Extract patterns
134        let patterns = {
135            let mut bank = self.reasoning_bank.write();
136            bank.extract_patterns()
137        };
138
139        // 3. Compute gradients from patterns
140        let gradients = self.compute_pattern_gradients(&patterns);
141
142        // 4. Apply EWC++ constraints
143        let constrained_gradients = {
144            let ewc = self.ewc.read();
145            ewc.apply_constraints(&gradients)
146        };
147
148        // 5. Check for task boundary
149        let task_boundary = {
150            let ewc = self.ewc.read();
151            ewc.detect_task_boundary(&gradients)
152        };
153
154        if task_boundary {
155            let mut ewc = self.ewc.write();
156            ewc.start_new_task();
157        }
158
159        // 6. Update EWC++ Fisher
160        {
161            let mut ewc = self.ewc.write();
162            ewc.update_fisher(&constrained_gradients);
163        }
164
165        // 7. Update base LoRA
166        self.update_base_lora(&constrained_gradients);
167
168        // Update last extraction time
169        *self.last_extraction.write() = Instant::now();
170
171        BackgroundResult {
172            trajectories_processed: trajectories.len(),
173            patterns_extracted: patterns.len(),
174            ewc_updated: true,
175            elapsed: start.elapsed(),
176            status: "completed".to_string(),
177        }
178    }
179
180    fn compute_pattern_gradients(&self, patterns: &[LearnedPattern]) -> Vec<f32> {
181        if patterns.is_empty() {
182            return Vec::new();
183        }
184
185        let dim = patterns[0].centroid.len();
186        let mut gradient = vec![0.0f32; dim];
187        let mut total_weight = 0.0f32;
188
189        for pattern in patterns {
190            let weight = pattern.avg_quality * pattern.cluster_size as f32;
191            for (i, &v) in pattern.centroid.iter().enumerate() {
192                if i < dim {
193                    gradient[i] += v * weight;
194                }
195            }
196            total_weight += weight;
197        }
198
199        if total_weight > 0.0 {
200            for g in &mut gradient {
201                *g /= total_weight;
202            }
203        }
204
205        gradient
206    }
207
208    fn update_base_lora(&self, gradients: &[f32]) {
209        let mut lora = self.base_lora.write();
210        let num_layers = lora.num_layers();
211
212        if num_layers == 0 || gradients.is_empty() {
213            return;
214        }
215
216        let per_layer = gradients.len() / num_layers;
217
218        for (layer_idx, layer) in lora.layers.iter_mut().enumerate() {
219            let start = layer_idx * per_layer;
220            let end = (start + per_layer).min(gradients.len());
221
222            for (i, &grad) in gradients[start..end].iter().enumerate() {
223                if i < layer.up_proj.len() {
224                    layer.up_proj[i] += grad * self.config.base_lora_lr;
225                }
226            }
227        }
228    }
229
230    /// Get reasoning bank reference
231    pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
232        &self.reasoning_bank
233    }
234
235    /// Get EWC reference
236    pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
237        &self.ewc
238    }
239
240    /// Get base LoRA reference
241    pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
242        &self.base_lora
243    }
244}