//! Background consolidation loop (`ruvector_sona/loops/background.rs`):
//! periodically folds accumulated query trajectories into learned patterns,
//! EWC state, and base LoRA weights.

use crate::ewc::EwcPlusPlus;
use crate::lora::BaseLoRA;
use crate::reasoning_bank::ReasoningBank;
use crate::time_compat::Instant;
use crate::types::{LearnedPattern, QueryTrajectory, SonaConfig};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;

/// Tuning knobs for [`BackgroundLoop`].
#[derive(Clone, Debug)]
pub struct BackgroundLoopConfig {
    /// Minimum number of trajectories required before a non-forced cycle runs.
    pub min_trajectories: usize,
    /// Learning rate used when folding consolidated gradients into the base
    /// LoRA `up_proj` weights.
    pub base_lora_lr: f32,
    /// EWC regularization strength (lambda).
    /// NOTE(review): stored but not read in this file — presumably consumed
    /// by `EwcPlusPlus`; confirm against its constructor/usage.
    pub ewc_lambda: f32,
    /// Minimum wall-clock time between automatic extraction cycles.
    pub extraction_interval: Duration,
}
26
27impl Default for BackgroundLoopConfig {
28 fn default() -> Self {
29 Self {
30 min_trajectories: 100,
31 base_lora_lr: 0.0001,
32 ewc_lambda: 1000.0,
33 extraction_interval: Duration::from_secs(3600),
34 }
35 }
36}
37
38impl From<&SonaConfig> for BackgroundLoopConfig {
39 fn from(config: &SonaConfig) -> Self {
40 Self {
41 min_trajectories: 100,
42 base_lora_lr: config.base_lora_lr,
43 ewc_lambda: config.ewc_lambda,
44 extraction_interval: Duration::from_millis(config.background_interval_ms),
45 }
46 }
47}
48
/// Outcome report for one background cycle (completed or skipped).
#[derive(Debug)]
pub struct BackgroundResult {
    /// How many trajectories were ingested this cycle (0 when skipped).
    pub trajectories_processed: usize,
    /// How many patterns the reasoning bank extracted this cycle.
    pub patterns_extracted: usize,
    /// Whether the EWC Fisher information was updated.
    pub ewc_updated: bool,
    /// Wall-clock duration of the cycle (zero when skipped).
    pub elapsed: Duration,
    /// Free-form status: "completed" or "skipped: <reason>".
    pub status: String,
}
58
59impl BackgroundResult {
60 fn skipped(reason: &str) -> Self {
61 Self {
62 trajectories_processed: 0,
63 patterns_extracted: 0,
64 ewc_updated: false,
65 elapsed: Duration::ZERO,
66 status: format!("skipped: {}", reason),
67 }
68 }
69}
70
/// Slow consolidation loop: turns batches of query trajectories into learned
/// patterns and nudges shared EWC / base-LoRA state accordingly.
pub struct BackgroundLoop {
    /// Loop tuning knobs.
    config: BackgroundLoopConfig,
    /// Pattern store; written (ingest + extract) during each cycle.
    reasoning_bank: Arc<RwLock<ReasoningBank>>,
    /// EWC++ state used to constrain gradients and track task boundaries.
    ewc: Arc<RwLock<EwcPlusPlus>>,
    /// Base LoRA weights updated with consolidated gradients.
    base_lora: Arc<RwLock<BaseLoRA>>,
    /// Timestamp of the last completed extraction cycle.
    last_extraction: RwLock<Instant>,
}
84
85impl BackgroundLoop {
86 pub fn new(
88 config: BackgroundLoopConfig,
89 reasoning_bank: Arc<RwLock<ReasoningBank>>,
90 ewc: Arc<RwLock<EwcPlusPlus>>,
91 base_lora: Arc<RwLock<BaseLoRA>>,
92 ) -> Self {
93 Self {
94 config,
95 reasoning_bank,
96 ewc,
97 base_lora,
98 last_extraction: RwLock::new(Instant::now()),
99 }
100 }
101
102 pub fn should_run(&self) -> bool {
104 self.last_extraction.read().elapsed() >= self.config.extraction_interval
105 }
106
107 pub fn run_cycle(&self, trajectories: Vec<QueryTrajectory>, force: bool) -> BackgroundResult {
111 if !force && trajectories.len() < self.config.min_trajectories {
112 return BackgroundResult::skipped(&format!(
113 "insufficient trajectories ({} < {} minimum, use forceLearn to bypass)",
114 trajectories.len(),
115 self.config.min_trajectories
116 ));
117 }
118
119 if trajectories.is_empty() {
120 return BackgroundResult::skipped("no trajectories to process");
121 }
122
123 let start = Instant::now();
124
125 {
127 let mut bank = self.reasoning_bank.write();
128 for trajectory in &trajectories {
129 bank.add_trajectory(trajectory);
130 }
131 }
132
133 let patterns = {
135 let mut bank = self.reasoning_bank.write();
136 bank.extract_patterns()
137 };
138
139 let gradients = self.compute_pattern_gradients(&patterns);
141
142 let constrained_gradients = {
144 let ewc = self.ewc.read();
145 ewc.apply_constraints(&gradients)
146 };
147
148 let task_boundary = {
150 let ewc = self.ewc.read();
151 ewc.detect_task_boundary(&gradients)
152 };
153
154 if task_boundary {
155 let mut ewc = self.ewc.write();
156 ewc.start_new_task();
157 }
158
159 {
161 let mut ewc = self.ewc.write();
162 ewc.update_fisher(&constrained_gradients);
163 }
164
165 self.update_base_lora(&constrained_gradients);
167
168 *self.last_extraction.write() = Instant::now();
170
171 BackgroundResult {
172 trajectories_processed: trajectories.len(),
173 patterns_extracted: patterns.len(),
174 ewc_updated: true,
175 elapsed: start.elapsed(),
176 status: "completed".to_string(),
177 }
178 }
179
180 fn compute_pattern_gradients(&self, patterns: &[LearnedPattern]) -> Vec<f32> {
181 if patterns.is_empty() {
182 return Vec::new();
183 }
184
185 let dim = patterns[0].centroid.len();
186 let mut gradient = vec![0.0f32; dim];
187 let mut total_weight = 0.0f32;
188
189 for pattern in patterns {
190 let weight = pattern.avg_quality * pattern.cluster_size as f32;
191 for (i, &v) in pattern.centroid.iter().enumerate() {
192 if i < dim {
193 gradient[i] += v * weight;
194 }
195 }
196 total_weight += weight;
197 }
198
199 if total_weight > 0.0 {
200 for g in &mut gradient {
201 *g /= total_weight;
202 }
203 }
204
205 gradient
206 }
207
208 fn update_base_lora(&self, gradients: &[f32]) {
209 let mut lora = self.base_lora.write();
210 let num_layers = lora.num_layers();
211
212 if num_layers == 0 || gradients.is_empty() {
213 return;
214 }
215
216 let per_layer = gradients.len() / num_layers;
217
218 for (layer_idx, layer) in lora.layers.iter_mut().enumerate() {
219 let start = layer_idx * per_layer;
220 let end = (start + per_layer).min(gradients.len());
221
222 for (i, &grad) in gradients[start..end].iter().enumerate() {
223 if i < layer.up_proj.len() {
224 layer.up_proj[i] += grad * self.config.base_lora_lr;
225 }
226 }
227 }
228 }
229
230 pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
232 &self.reasoning_bank
233 }
234
235 pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
237 &self.ewc
238 }
239
240 pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
242 &self.base_lora
243 }
244}