// ruvector_sona/loops/background.rs
use crate::ewc::EwcPlusPlus;
6use crate::lora::BaseLoRA;
7use crate::reasoning_bank::ReasoningBank;
8use crate::types::{QueryTrajectory, SonaConfig, LearnedPattern};
9use parking_lot::RwLock;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
/// Tuning knobs for the background consolidation loop.
#[derive(Clone, Debug)]
pub struct BackgroundLoopConfig {
    /// Minimum number of trajectories a cycle needs before it will run;
    /// `run_cycle` returns a "skipped" result below this threshold.
    pub min_trajectories: usize,
    /// Learning rate applied when folding constrained gradients into the
    /// base LoRA `up_proj` weights.
    pub base_lora_lr: f32,
    /// EWC regularization strength. NOTE(review): not read anywhere in this
    /// file — presumably consumed by `EwcPlusPlus`; confirm at the call site.
    pub ewc_lambda: f32,
    /// Minimum wall-clock time between extraction cycles (see `should_run`).
    pub extraction_interval: Duration,
}
25
26impl Default for BackgroundLoopConfig {
27 fn default() -> Self {
28 Self {
29 min_trajectories: 100,
30 base_lora_lr: 0.0001,
31 ewc_lambda: 1000.0,
32 extraction_interval: Duration::from_secs(3600),
33 }
34 }
35}
36
37impl From<&SonaConfig> for BackgroundLoopConfig {
38 fn from(config: &SonaConfig) -> Self {
39 Self {
40 min_trajectories: 100,
41 base_lora_lr: config.base_lora_lr,
42 ewc_lambda: config.ewc_lambda,
43 extraction_interval: Duration::from_millis(config.background_interval_ms),
44 }
45 }
46}
47
/// Summary of the outcome of one background cycle.
#[derive(Debug)]
pub struct BackgroundResult {
    /// Number of trajectories ingested into the reasoning bank this cycle.
    pub trajectories_processed: usize,
    /// Number of patterns the reasoning bank extracted this cycle.
    pub patterns_extracted: usize,
    /// Whether the EWC Fisher information was updated.
    pub ewc_updated: bool,
    /// Wall-clock duration of the cycle; `Duration::ZERO` when skipped.
    pub elapsed: Duration,
    /// Either `"completed"` or `"skipped: <reason>"`.
    pub status: String,
}
57
58impl BackgroundResult {
59 fn skipped(reason: &str) -> Self {
60 Self {
61 trajectories_processed: 0,
62 patterns_extracted: 0,
63 ewc_updated: false,
64 elapsed: Duration::ZERO,
65 status: format!("skipped: {}", reason),
66 }
67 }
68}
69
/// Periodic consolidation loop: ingests query trajectories, extracts
/// patterns from the reasoning bank, and applies EWC-constrained gradient
/// updates to the shared base LoRA weights.
pub struct BackgroundLoop {
    /// Loop tuning parameters (thresholds, learning rate, cadence).
    config: BackgroundLoopConfig,
    /// Shared store of trajectories and extracted patterns.
    reasoning_bank: Arc<RwLock<ReasoningBank>>,
    /// Elastic Weight Consolidation state (constraints, Fisher info,
    /// task-boundary detection).
    ewc: Arc<RwLock<EwcPlusPlus>>,
    /// Base LoRA weights updated with consolidated gradients.
    base_lora: Arc<RwLock<BaseLoRA>>,
    /// Instant of the most recent completed extraction cycle; drives
    /// `should_run`.
    last_extraction: RwLock<Instant>,
}
83
84impl BackgroundLoop {
85 pub fn new(
87 config: BackgroundLoopConfig,
88 reasoning_bank: Arc<RwLock<ReasoningBank>>,
89 ewc: Arc<RwLock<EwcPlusPlus>>,
90 base_lora: Arc<RwLock<BaseLoRA>>,
91 ) -> Self {
92 Self {
93 config,
94 reasoning_bank,
95 ewc,
96 base_lora,
97 last_extraction: RwLock::new(Instant::now()),
98 }
99 }
100
101 pub fn should_run(&self) -> bool {
103 self.last_extraction.read().elapsed() >= self.config.extraction_interval
104 }
105
106 pub fn run_cycle(&self, trajectories: Vec<QueryTrajectory>) -> BackgroundResult {
108 if trajectories.len() < self.config.min_trajectories {
109 return BackgroundResult::skipped("insufficient trajectories");
110 }
111
112 let start = Instant::now();
113
114 {
116 let mut bank = self.reasoning_bank.write();
117 for trajectory in &trajectories {
118 bank.add_trajectory(trajectory);
119 }
120 }
121
122 let patterns = {
124 let mut bank = self.reasoning_bank.write();
125 bank.extract_patterns()
126 };
127
128 let gradients = self.compute_pattern_gradients(&patterns);
130
131 let constrained_gradients = {
133 let ewc = self.ewc.read();
134 ewc.apply_constraints(&gradients)
135 };
136
137 let task_boundary = {
139 let ewc = self.ewc.read();
140 ewc.detect_task_boundary(&gradients)
141 };
142
143 if task_boundary {
144 let mut ewc = self.ewc.write();
145 ewc.start_new_task();
146 }
147
148 {
150 let mut ewc = self.ewc.write();
151 ewc.update_fisher(&constrained_gradients);
152 }
153
154 self.update_base_lora(&constrained_gradients);
156
157 *self.last_extraction.write() = Instant::now();
159
160 BackgroundResult {
161 trajectories_processed: trajectories.len(),
162 patterns_extracted: patterns.len(),
163 ewc_updated: true,
164 elapsed: start.elapsed(),
165 status: "completed".to_string(),
166 }
167 }
168
169 fn compute_pattern_gradients(&self, patterns: &[LearnedPattern]) -> Vec<f32> {
170 if patterns.is_empty() {
171 return Vec::new();
172 }
173
174 let dim = patterns[0].centroid.len();
175 let mut gradient = vec![0.0f32; dim];
176 let mut total_weight = 0.0f32;
177
178 for pattern in patterns {
179 let weight = pattern.avg_quality * pattern.cluster_size as f32;
180 for (i, &v) in pattern.centroid.iter().enumerate() {
181 if i < dim {
182 gradient[i] += v * weight;
183 }
184 }
185 total_weight += weight;
186 }
187
188 if total_weight > 0.0 {
189 for g in &mut gradient {
190 *g /= total_weight;
191 }
192 }
193
194 gradient
195 }
196
197 fn update_base_lora(&self, gradients: &[f32]) {
198 let mut lora = self.base_lora.write();
199 let num_layers = lora.num_layers();
200
201 if num_layers == 0 || gradients.is_empty() {
202 return;
203 }
204
205 let per_layer = gradients.len() / num_layers;
206
207 for (layer_idx, layer) in lora.layers.iter_mut().enumerate() {
208 let start = layer_idx * per_layer;
209 let end = (start + per_layer).min(gradients.len());
210
211 for (i, &grad) in gradients[start..end].iter().enumerate() {
212 if i < layer.up_proj.len() {
213 layer.up_proj[i] += grad * self.config.base_lora_lr;
214 }
215 }
216 }
217 }
218
219 pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
221 &self.reasoning_bank
222 }
223
224 pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
226 &self.ewc
227 }
228
229 pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
231 &self.base_lora
232 }
233}