// ruvector_sona/loops/background.rs

use crate::ewc::EwcPlusPlus;
use crate::lora::BaseLoRA;
use crate::reasoning_bank::ReasoningBank;
use crate::time_compat::Instant;
use crate::types::{LearnedPattern, QueryTrajectory, SonaConfig};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;

/// Tuning parameters for the background consolidation loop.
#[derive(Clone, Debug)]
pub struct BackgroundLoopConfig {
    /// Minimum number of trajectories required before a cycle will run;
    /// `run_cycle` returns a "skipped" result below this count.
    pub min_trajectories: usize,
    /// Learning rate used when folding constrained gradients into the base LoRA.
    pub base_lora_lr: f32,
    /// EWC penalty strength (lambda). Not read directly in this module;
    /// presumably consumed by `EwcPlusPlus` — TODO confirm against that type.
    pub ewc_lambda: f32,
    /// Minimum wall-clock time between extraction cycles (see `should_run`).
    pub extraction_interval: Duration,
}
26
27impl Default for BackgroundLoopConfig {
28 fn default() -> Self {
29 Self {
30 min_trajectories: 100,
31 base_lora_lr: 0.0001,
32 ewc_lambda: 1000.0,
33 extraction_interval: Duration::from_secs(3600),
34 }
35 }
36}
37
38impl From<&SonaConfig> for BackgroundLoopConfig {
39 fn from(config: &SonaConfig) -> Self {
40 Self {
41 min_trajectories: 100,
42 base_lora_lr: config.base_lora_lr,
43 ewc_lambda: config.ewc_lambda,
44 extraction_interval: Duration::from_millis(config.background_interval_ms),
45 }
46 }
47}
48
/// Summary of one background consolidation cycle.
#[derive(Debug)]
pub struct BackgroundResult {
    /// Number of trajectories ingested this cycle (0 when the cycle was skipped).
    pub trajectories_processed: usize,
    /// Number of patterns returned by the reasoning bank (0 when skipped).
    pub patterns_extracted: usize,
    /// True when the EWC Fisher information was updated this cycle.
    pub ewc_updated: bool,
    /// Wall-clock duration of the cycle (`Duration::ZERO` when skipped).
    pub elapsed: Duration,
    /// Either `"completed"` or `"skipped: <reason>"`.
    pub status: String,
}
58
59impl BackgroundResult {
60 fn skipped(reason: &str) -> Self {
61 Self {
62 trajectories_processed: 0,
63 patterns_extracted: 0,
64 ewc_updated: false,
65 elapsed: Duration::ZERO,
66 status: format!("skipped: {}", reason),
67 }
68 }
69}
70
/// Periodic consolidation loop: ingests trajectories into the reasoning
/// bank, extracts patterns, and applies EWC-constrained gradient updates
/// to the base LoRA.
pub struct BackgroundLoop {
    // Loop tuning parameters (thresholds, learning rate, interval).
    config: BackgroundLoopConfig,
    // Shared store that accumulates trajectories and extracts patterns.
    reasoning_bank: Arc<RwLock<ReasoningBank>>,
    // Elastic-weight-consolidation state (constraints, task boundaries, Fisher).
    ewc: Arc<RwLock<EwcPlusPlus>>,
    // Base LoRA adapter nudged by constrained gradients each cycle.
    base_lora: Arc<RwLock<BaseLoRA>>,
    // Timestamp of the last completed extraction; drives `should_run`.
    last_extraction: RwLock<Instant>,
}
84
85impl BackgroundLoop {
86 pub fn new(
88 config: BackgroundLoopConfig,
89 reasoning_bank: Arc<RwLock<ReasoningBank>>,
90 ewc: Arc<RwLock<EwcPlusPlus>>,
91 base_lora: Arc<RwLock<BaseLoRA>>,
92 ) -> Self {
93 Self {
94 config,
95 reasoning_bank,
96 ewc,
97 base_lora,
98 last_extraction: RwLock::new(Instant::now()),
99 }
100 }
101
102 pub fn should_run(&self) -> bool {
104 self.last_extraction.read().elapsed() >= self.config.extraction_interval
105 }
106
107 pub fn run_cycle(&self, trajectories: Vec<QueryTrajectory>) -> BackgroundResult {
109 if trajectories.len() < self.config.min_trajectories {
110 return BackgroundResult::skipped("insufficient trajectories");
111 }
112
113 let start = Instant::now();
114
115 {
117 let mut bank = self.reasoning_bank.write();
118 for trajectory in &trajectories {
119 bank.add_trajectory(trajectory);
120 }
121 }
122
123 let patterns = {
125 let mut bank = self.reasoning_bank.write();
126 bank.extract_patterns()
127 };
128
129 let gradients = self.compute_pattern_gradients(&patterns);
131
132 let constrained_gradients = {
134 let ewc = self.ewc.read();
135 ewc.apply_constraints(&gradients)
136 };
137
138 let task_boundary = {
140 let ewc = self.ewc.read();
141 ewc.detect_task_boundary(&gradients)
142 };
143
144 if task_boundary {
145 let mut ewc = self.ewc.write();
146 ewc.start_new_task();
147 }
148
149 {
151 let mut ewc = self.ewc.write();
152 ewc.update_fisher(&constrained_gradients);
153 }
154
155 self.update_base_lora(&constrained_gradients);
157
158 *self.last_extraction.write() = Instant::now();
160
161 BackgroundResult {
162 trajectories_processed: trajectories.len(),
163 patterns_extracted: patterns.len(),
164 ewc_updated: true,
165 elapsed: start.elapsed(),
166 status: "completed".to_string(),
167 }
168 }
169
170 fn compute_pattern_gradients(&self, patterns: &[LearnedPattern]) -> Vec<f32> {
171 if patterns.is_empty() {
172 return Vec::new();
173 }
174
175 let dim = patterns[0].centroid.len();
176 let mut gradient = vec![0.0f32; dim];
177 let mut total_weight = 0.0f32;
178
179 for pattern in patterns {
180 let weight = pattern.avg_quality * pattern.cluster_size as f32;
181 for (i, &v) in pattern.centroid.iter().enumerate() {
182 if i < dim {
183 gradient[i] += v * weight;
184 }
185 }
186 total_weight += weight;
187 }
188
189 if total_weight > 0.0 {
190 for g in &mut gradient {
191 *g /= total_weight;
192 }
193 }
194
195 gradient
196 }
197
198 fn update_base_lora(&self, gradients: &[f32]) {
199 let mut lora = self.base_lora.write();
200 let num_layers = lora.num_layers();
201
202 if num_layers == 0 || gradients.is_empty() {
203 return;
204 }
205
206 let per_layer = gradients.len() / num_layers;
207
208 for (layer_idx, layer) in lora.layers.iter_mut().enumerate() {
209 let start = layer_idx * per_layer;
210 let end = (start + per_layer).min(gradients.len());
211
212 for (i, &grad) in gradients[start..end].iter().enumerate() {
213 if i < layer.up_proj.len() {
214 layer.up_proj[i] += grad * self.config.base_lora_lr;
215 }
216 }
217 }
218 }
219
220 pub fn reasoning_bank(&self) -> &Arc<RwLock<ReasoningBank>> {
222 &self.reasoning_bank
223 }
224
225 pub fn ewc(&self) -> &Arc<RwLock<EwcPlusPlus>> {
227 &self.ewc
228 }
229
230 pub fn base_lora(&self) -> &Arc<RwLock<BaseLoRA>> {
232 &self.base_lora
233 }
234}