ruvector_sona/loops/background.rs

//! Loop B - Background Learning
//!
//! Hourly pattern extraction and base LoRA updates.

use crate::ewc::EwcPlusPlus;
use crate::lora::BaseLoRA;
use crate::reasoning_bank::ReasoningBank;
use crate::time_compat::Instant;
use crate::types::{LearnedPattern, QueryTrajectory, SonaConfig};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;

/// Background loop configuration
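///
/// # Examples
///
/// A construction sketch (the override shown is illustrative only;
/// `default()` gives the hourly cadence below):
///
/// ```ignore
/// let config = BackgroundLoopConfig {
///     min_trajectories: 50,
///     ..BackgroundLoopConfig::default()
/// };
/// ```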
#[derive(Clone, Debug)]
pub struct BackgroundLoopConfig {
    /// Minimum number of trajectories required before a cycle runs
    pub min_trajectories: usize,
    /// Base LoRA learning rate
    pub base_lora_lr: f32,
    /// EWC regularization strength (lambda)
    pub ewc_lambda: f32,
    /// Pattern extraction interval
    pub extraction_interval: Duration,
}

impl Default for BackgroundLoopConfig {
    fn default() -> Self {
        Self {
            min_trajectories: 100,
            base_lora_lr: 0.0001,
            ewc_lambda: 1000.0,
            extraction_interval: Duration::from_secs(3600), // hourly
        }
    }
}

impl From<&SonaConfig> for BackgroundLoopConfig {
    fn from(config: &SonaConfig) -> Self {
        Self {
            min_trajectories: 100, // fixed here; not taken from SonaConfig
            base_lora_lr: config.base_lora_lr,
            ewc_lambda: config.ewc_lambda,
            extraction_interval: Duration::from_millis(config.background_interval_ms),
        }
    }
}
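
// Conversion sketch (assumes a `SonaConfig` value named `sona` is in scope):
//
//     let config = BackgroundLoopConfig::from(&sona);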

/// Background cycle result
#[derive(Debug)]
pub struct BackgroundResult {
    pub trajectories_processed: usize,
    pub patterns_extracted: usize,
    pub ewc_updated: bool,
    pub elapsed: Duration,
    pub status: String,
}

impl BackgroundResult {
    fn skipped(reason: &str) -> Self {
        Self {
            trajectories_processed: 0,
            patterns_extracted: 0,
            ewc_updated: false,
            elapsed: Duration::ZERO,
            status: format!("skipped: {}", reason),
        }
    }
}

/// Background learning loop (Loop B)
pub struct BackgroundLoop {
    /// Configuration
    config: BackgroundLoopConfig,
    /// ReasoningBank for pattern storage
    reasoning_bank: Arc<RwLock<ReasoningBank>>,
    /// EWC++ for forgetting prevention
    ewc: Arc<RwLock<EwcPlusPlus>>,
    /// Base LoRA
    base_lora: Arc<RwLock<BaseLoRA>>,
    /// Last extraction time
    last_extraction: RwLock<Instant>,
}

impl BackgroundLoop {
    /// Create new background loop
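    ///
    /// Wiring sketch (`ReasoningBank::new()`, `EwcPlusPlus::new()`, and
    /// `BaseLoRA::new()` are assumed constructors; check their modules
    /// for the real signatures):
    ///
    /// ```ignore
    /// let loop_b = BackgroundLoop::new(
    ///     BackgroundLoopConfig::default(),
    ///     Arc::new(RwLock::new(ReasoningBank::new())),
    ///     Arc::new(RwLock::new(EwcPlusPlus::new())),
    ///     Arc::new(RwLock::new(BaseLoRA::new())),
    /// );
    /// ```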
    pub fn new(
        config: BackgroundLoopConfig,
        reasoning_bank: Arc<RwLock<ReasoningBank>>,
        ewc: Arc<RwLock<EwcPlusPlus>>,
        base_lora: Arc<RwLock<BaseLoRA>>,
    ) -> Self {
        Self {
            config,
            reasoning_bank,
            ewc,
            base_lora,
            last_extraction: RwLock::new(Instant::now()),
        }
    }

    /// Check if it's time for a background cycle
    pub fn should_run(&self) -> bool {
        self.last_extraction.read().elapsed() >= self.config.extraction_interval
    }

    /// Run background learning cycle
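    ///
    /// Driver sketch (hypothetical; `drain_trajectory_buffer` stands in for
    /// however the caller collects trajectories between cycles):
    ///
    /// ```ignore
    /// if loop_b.should_run() {
    ///     let batch: Vec<QueryTrajectory> = drain_trajectory_buffer();
    ///     let result = loop_b.run_cycle(batch);
    ///     println!("Loop B: {} ({} patterns)", result.status, result.patterns_extracted);
    /// }
    /// ```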
    pub fn run_cycle(&self, trajectories: Vec<QueryTrajectory>) -> BackgroundResult {
        if trajectories.len() < self.config.min_trajectories {
            return BackgroundResult::skipped("insufficient trajectories");
        }

        let start = Instant::now();

        // 1. Add trajectories to reasoning bank
        {
            let mut bank = self.reasoning_bank.write();
            for trajectory in &trajectories {
                bank.add_trajectory(trajectory);
            }
        }

        // 2. Extract patterns
        let patterns = {
            let mut bank = self.reasoning_bank.write();
            bank.extract_patterns()
        };

        // 3. Compute gradients from patterns
        let gradients = self.compute_pattern_gradients(&patterns);

        // 4. Apply EWC++ constraints
        let constrained_gradients = {
            let ewc = self.ewc.read();
            ewc.apply_constraints(&gradients)
        };

        // 5. Check for task boundary
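        //    (detection uses the raw gradients, presumably so that EWC
        //    scaling does not mask a genuine distribution shift)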
        let task_boundary = {
            let ewc = self.ewc.read();
            ewc.detect_task_boundary(&gradients)
        };

        if task_boundary {
            let mut ewc = self.ewc.write();
            ewc.start_new_task();
        }

        // 6. Update EWC++ Fisher
        {
            let mut ewc = self.ewc.write();
            ewc.update_fisher(&constrained_gradients);
        }

        // 7. Update base LoRA
        self.update_base_lora(&constrained_gradients);

        // Update last extraction time
        *self.last_extraction.write() = Instant::now();

        BackgroundResult {
            trajectories_processed: trajectories.len(),
            patterns_extracted: patterns.len(),
            ewc_updated: true,
            elapsed: start.elapsed(),
            status: "completed".to_string(),
        }
    }

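    /// Collapse extracted patterns into a single pseudo-gradient: each
    /// pattern's centroid is weighted by `avg_quality * cluster_size`,
    /// the weighted centroids are summed, and the sum is normalized by
    /// the total weight.
    ///
    /// Worked example (hypothetical numbers): two 1-D patterns with
    /// centroids `[2.0]` and `[4.0]` and weights `1.0` and `3.0` yield
    /// `(2.0 * 1.0 + 4.0 * 3.0) / 4.0 = 3.5`.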
    fn compute_pattern_gradients(&self, patterns: &[LearnedPattern]) -> Vec<f32> {
        if patterns.is_empty() {
            return Vec::new();
        }

        let dim = patterns[0].centroid.len();
        let mut gradient = vec![0.0f32; dim];
        let mut total_weight = 0.0f32;

        for pattern in patterns {
            let weight = pattern.avg_quality * pattern.cluster_size as f32;
            for (i, &v) in pattern.centroid.iter().enumerate() {
                // Guard against centroids longer than the first pattern's
                if i < dim {
                    gradient[i] += v * weight;
                }
            }
            total_weight += weight;
        }

        if total_weight > 0.0 {
            for g in &mut gradient {
                *g /= total_weight;
            }
        }

        gradient
    }

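    /// Apply the pseudo-gradient to the base LoRA as a plain SGD step,
    /// splitting the gradient evenly across layers. Only `up_proj` is
    /// nudged in this pass.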
    fn update_base_lora(&self, gradients: &[f32]) {
        let mut lora = self.base_lora.write();
        let num_layers = lora.num_layers();

        if num_layers == 0 || gradients.is_empty() {
            return;
        }

        // Integer division: trailing remainder elements are not applied
        let per_layer = gradients.len() / num_layers;

        for (layer_idx, layer) in lora.layers.iter_mut().enumerate() {
            let start = layer_idx * per_layer;
            let end = (start + per_layer).min(gradients.len());

            for (i, &grad) in gradients[start..end].iter().enumerate() {
                if i < layer.up_proj.len() {
                    layer.up_proj[i] += grad * self.config.base_lora_lr;
                }
            }
        }
    }

    /// Get reasoning bank reference
    pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
        &self.reasoning_bank
    }

    /// Get EWC reference
    pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
        &self.ewc
    }

    /// Get base LoRA reference
    pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
        &self.base_lora
    }
}