ruvector_sona/loops/
background.rs

//! Loop B - Background Learning
//!
//! Runs on a configurable interval (hourly by default): accumulated query
//! trajectories are mined for reusable patterns, which then drive
//! EWC++-constrained updates to the base LoRA.
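//!
//! A minimal driver sketch is shown below; the `ReasoningBank::new`,
//! `EwcPlusPlus::new`, and `BaseLoRA::new` constructors are assumed for
//! illustration and may differ from the real signatures.
//!
//! ```ignore
//! use std::sync::Arc;
//! use parking_lot::RwLock;
//! // (imports of the crate types elided)
//!
//! let background = BackgroundLoop::new(
//!     BackgroundLoopConfig::default(),
//!     Arc::new(RwLock::new(ReasoningBank::new())), // hypothetical ctor
//!     Arc::new(RwLock::new(EwcPlusPlus::new())),   // hypothetical ctor
//!     Arc::new(RwLock::new(BaseLoRA::new())),      // hypothetical ctor
//! );
//!
//! // Typical driver: poll on a timer, run when the interval has elapsed.
//! if background.should_run() {
//!     let result = background.run_cycle(pending_trajectories);
//!     println!("{}: {} patterns", result.status, result.patterns_extracted);
//! }
//! ```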

use crate::ewc::EwcPlusPlus;
use crate::lora::BaseLoRA;
use crate::reasoning_bank::ReasoningBank;
use crate::types::{LearnedPattern, QueryTrajectory, SonaConfig};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::{Duration, Instant};

/// Background loop configuration
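///
/// A sketch of overriding selected fields (crate path assumed from the
/// module layout, so the example is not compiled):
///
/// ```ignore
/// use std::time::Duration;
///
/// let config = BackgroundLoopConfig {
///     min_trajectories: 50,                           // run on smaller batches
///     extraction_interval: Duration::from_secs(1800), // every 30 minutes
///     ..Default::default()                            // keep lr and lambda defaults
/// };
/// ```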
#[derive(Clone, Debug)]
pub struct BackgroundLoopConfig {
    /// Minimum trajectories to process
    pub min_trajectories: usize,
    /// Base LoRA learning rate
    pub base_lora_lr: f32,
    /// EWC lambda
    pub ewc_lambda: f32,
    /// Pattern extraction interval
    pub extraction_interval: Duration,
}

impl Default for BackgroundLoopConfig {
    fn default() -> Self {
        Self {
            min_trajectories: 100,
            base_lora_lr: 0.0001,
            ewc_lambda: 1000.0,
            extraction_interval: Duration::from_secs(3600),
        }
    }
}

impl From<&SonaConfig> for BackgroundLoopConfig {
    fn from(config: &SonaConfig) -> Self {
        Self {
            // Matches the Default; not taken from SonaConfig here.
            min_trajectories: 100,
            base_lora_lr: config.base_lora_lr,
            ewc_lambda: config.ewc_lambda,
            extraction_interval: Duration::from_millis(config.background_interval_ms),
        }
    }
}

/// Background cycle result
#[derive(Debug)]
pub struct BackgroundResult {
    pub trajectories_processed: usize,
    pub patterns_extracted: usize,
    pub ewc_updated: bool,
    pub elapsed: Duration,
    pub status: String,
}

impl BackgroundResult {
    fn skipped(reason: &str) -> Self {
        Self {
            trajectories_processed: 0,
            patterns_extracted: 0,
            ewc_updated: false,
            elapsed: Duration::ZERO,
            status: format!("skipped: {}", reason),
        }
    }
}

/// Background learning loop (Loop B)
pub struct BackgroundLoop {
    /// Configuration
    config: BackgroundLoopConfig,
    /// ReasoningBank for pattern storage
    reasoning_bank: Arc<RwLock<ReasoningBank>>,
    /// EWC++ for catastrophic-forgetting prevention
    ewc: Arc<RwLock<EwcPlusPlus>>,
    /// Base LoRA
    base_lora: Arc<RwLock<BaseLoRA>>,
    /// Last extraction time
    last_extraction: RwLock<Instant>,
}

impl BackgroundLoop {
    /// Create a new background loop
    pub fn new(
        config: BackgroundLoopConfig,
        reasoning_bank: Arc<RwLock<ReasoningBank>>,
        ewc: Arc<RwLock<EwcPlusPlus>>,
        base_lora: Arc<RwLock<BaseLoRA>>,
    ) -> Self {
        Self {
            config,
            reasoning_bank,
            ewc,
            base_lora,
            last_extraction: RwLock::new(Instant::now()),
        }
    }

    /// Check whether it's time for a background cycle
    pub fn should_run(&self) -> bool {
        self.last_extraction.read().elapsed() >= self.config.extraction_interval
    }

    /// Run background learning cycle
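    ///
    /// One cycle, in order: ingest trajectories into the ReasoningBank,
    /// extract patterns, fold the patterns into a single gradient, apply
    /// EWC++ constraints, check for a task boundary (starting a new EWC++
    /// task if one is detected), update the Fisher information, and finally
    /// nudge the base LoRA weights.
    ///
    /// Returns a skipped result without touching any state when fewer than
    /// `min_trajectories` are supplied.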
    pub fn run_cycle(&self, trajectories: Vec<QueryTrajectory>) -> BackgroundResult {
        if trajectories.len() < self.config.min_trajectories {
            return BackgroundResult::skipped("insufficient trajectories");
        }

        let start = Instant::now();

        // 1. Add trajectories to reasoning bank
        {
            let mut bank = self.reasoning_bank.write();
            for trajectory in &trajectories {
                bank.add_trajectory(trajectory);
            }
        }

        // 2. Extract patterns
        let patterns = {
            let mut bank = self.reasoning_bank.write();
            bank.extract_patterns()
        };

        // 3. Compute gradients from patterns
        let gradients = self.compute_pattern_gradients(&patterns);

        // 4. Apply EWC++ constraints
        let constrained_gradients = {
            let ewc = self.ewc.read();
            ewc.apply_constraints(&gradients)
        };

        // 5. Check for task boundary (uses the raw, unconstrained gradients)
        let task_boundary = {
            let ewc = self.ewc.read();
            ewc.detect_task_boundary(&gradients)
        };

        if task_boundary {
            let mut ewc = self.ewc.write();
            ewc.start_new_task();
        }

        // 6. Update EWC++ Fisher from the constrained gradients
        {
            let mut ewc = self.ewc.write();
            ewc.update_fisher(&constrained_gradients);
        }

        // 7. Update base LoRA
        self.update_base_lora(&constrained_gradients);

        // Update last extraction time
        *self.last_extraction.write() = Instant::now();

        BackgroundResult {
            trajectories_processed: trajectories.len(),
            patterns_extracted: patterns.len(),
            ewc_updated: true,
            elapsed: start.elapsed(),
            status: "completed".to_string(),
        }
    }

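    /// Collapse the extracted patterns into one pseudo-gradient: the
    /// weighted mean of pattern centroids, with each pattern weighted by
    /// `avg_quality * cluster_size`. The output dimension follows the
    /// first pattern's centroid.
    ///
    /// For example, a pattern with quality 0.8 over 10 members (weight 8.0)
    /// pulls the result four times harder than one with quality 0.5 over
    /// 4 members (weight 2.0).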
    fn compute_pattern_gradients(&self, patterns: &[LearnedPattern]) -> Vec<f32> {
        if patterns.is_empty() {
            return Vec::new();
        }

        let dim = patterns[0].centroid.len();
        let mut gradient = vec![0.0f32; dim];
        let mut total_weight = 0.0f32;

        for pattern in patterns {
            let weight = pattern.avg_quality * pattern.cluster_size as f32;
            for (i, &v) in pattern.centroid.iter().enumerate() {
                // Guard against centroids longer than the first pattern's.
                if i < dim {
                    gradient[i] += v * weight;
                }
            }
            total_weight += weight;
        }

        if total_weight > 0.0 {
            for g in &mut gradient {
                *g /= total_weight;
            }
        }

        gradient
    }

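    /// Spread the aggregate gradient evenly across LoRA layers and apply it
    /// to each layer's `up_proj`, scaled by `base_lora_lr`. Integer division
    /// means any trailing `gradients.len() % num_layers` elements are
    /// dropped, and slices longer than `up_proj` are truncated to fit.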
    fn update_base_lora(&self, gradients: &[f32]) {
        let mut lora = self.base_lora.write();
        let num_layers = lora.num_layers();

        if num_layers == 0 || gradients.is_empty() {
            return;
        }

        let per_layer = gradients.len() / num_layers;

        for (layer_idx, layer) in lora.layers.iter_mut().enumerate() {
            let start = layer_idx * per_layer;
            let end = (start + per_layer).min(gradients.len());

            for (i, &grad) in gradients[start..end].iter().enumerate() {
                if i < layer.up_proj.len() {
                    layer.up_proj[i] += grad * self.config.base_lora_lr;
                }
            }
        }
    }

    /// Get reasoning bank reference
    pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
        &self.reasoning_bank
    }

    /// Get EWC reference
    pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
        &self.ewc
    }

    /// Get base LoRA reference
    pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
        &self.base_lora
    }
}