ruvector_sona/
ewc.rs

//! EWC++ (Enhanced Elastic Weight Consolidation) for SONA.
//!
//! Prevents catastrophic forgetting with:
//! - Online (EMA) estimation of the diagonal Fisher information
//! - Multi-task memory kept in a bounded circular buffer
//! - Automatic task-boundary detection from gradient distribution shift
//! - Adaptive lambda scheduling

9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11
12/// EWC++ configuration
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct EwcConfig {
15    /// Number of parameters
16    pub param_count: usize,
17    /// Maximum tasks to remember
18    pub max_tasks: usize,
19    /// Initial lambda
20    pub initial_lambda: f32,
21    /// Minimum lambda
22    pub min_lambda: f32,
23    /// Maximum lambda
24    pub max_lambda: f32,
25    /// Fisher EMA decay factor
26    pub fisher_ema_decay: f32,
27    /// Task boundary detection threshold
28    pub boundary_threshold: f32,
29    /// Gradient history for boundary detection
30    pub gradient_history_size: usize,
31}
32
33impl Default for EwcConfig {
34    fn default() -> Self {
35        // OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
36        // - Lambda 2000 optimal for catastrophic forgetting prevention
37        // - Higher max_lambda (15000) for aggressive protection when needed
38        Self {
39            param_count: 1000,
40            max_tasks: 10,
41            initial_lambda: 2000.0,   // OPTIMIZED: Better forgetting prevention
42            min_lambda: 100.0,
43            max_lambda: 15000.0,      // OPTIMIZED: Higher ceiling for multi-task
44            fisher_ema_decay: 0.999,
45            boundary_threshold: 2.0,
46            gradient_history_size: 100,
47        }
48    }
49}
50
51/// Task-specific Fisher information
52#[derive(Clone, Debug, Serialize, Deserialize)]
53pub struct TaskFisher {
54    /// Task ID
55    pub task_id: usize,
56    /// Fisher diagonal
57    pub fisher: Vec<f32>,
58    /// Optimal weights for this task
59    pub optimal_weights: Vec<f32>,
60    /// Task importance (for weighted consolidation)
61    pub importance: f32,
62}
63
64/// EWC++ implementation
65#[derive(Clone, Debug, Serialize, Deserialize)]
66pub struct EwcPlusPlus {
67    /// Configuration
68    config: EwcConfig,
69    /// Current Fisher information (online estimate)
70    current_fisher: Vec<f32>,
71    /// Current optimal weights
72    current_weights: Vec<f32>,
73    /// Task memory (circular buffer)
74    task_memory: VecDeque<TaskFisher>,
75    /// Current task ID
76    current_task_id: usize,
77    /// Current lambda
78    lambda: f32,
79    /// Gradient history for boundary detection
80    gradient_history: VecDeque<Vec<f32>>,
81    /// Running gradient mean
82    gradient_mean: Vec<f32>,
83    /// Running gradient variance
84    gradient_var: Vec<f32>,
85    /// Samples seen for current task
86    samples_seen: u64,
87}
88
89impl EwcPlusPlus {
90    /// Create new EWC++
91    pub fn new(config: EwcConfig) -> Self {
92        let param_count = config.param_count;
93        let initial_lambda = config.initial_lambda;
94
95        Self {
96            config: config.clone(),
97            current_fisher: vec![0.0; param_count],
98            current_weights: vec![0.0; param_count],
99            task_memory: VecDeque::with_capacity(config.max_tasks),
100            current_task_id: 0,
101            lambda: initial_lambda,
102            gradient_history: VecDeque::with_capacity(config.gradient_history_size),
103            gradient_mean: vec![0.0; param_count],
104            gradient_var: vec![1.0; param_count],
105            samples_seen: 0,
106        }
107    }
108
109    /// Update Fisher information online using EMA
110    pub fn update_fisher(&mut self, gradients: &[f32]) {
111        if gradients.len() != self.config.param_count {
112            return;
113        }
114
115        let decay = self.config.fisher_ema_decay;
116
117        // Online Fisher update: F_t = decay * F_{t-1} + (1 - decay) * g^2
118        for (i, &g) in gradients.iter().enumerate() {
119            self.current_fisher[i] = decay * self.current_fisher[i] + (1.0 - decay) * g * g;
120        }
121
122        // Update gradient statistics for boundary detection
123        self.update_gradient_stats(gradients);
124        self.samples_seen += 1;
125    }
126
127    /// Update gradient statistics for boundary detection
128    fn update_gradient_stats(&mut self, gradients: &[f32]) {
129        // Store in history
130        if self.gradient_history.len() >= self.config.gradient_history_size {
131            self.gradient_history.pop_front();
132        }
133        self.gradient_history.push_back(gradients.to_vec());
134
135        // Update running mean and variance (Welford's algorithm)
136        let n = self.samples_seen as f32 + 1.0;
137
138        for (i, &g) in gradients.iter().enumerate() {
139            let delta = g - self.gradient_mean[i];
140            self.gradient_mean[i] += delta / n;
141            let delta2 = g - self.gradient_mean[i];
142            self.gradient_var[i] += delta * delta2;
143        }
144    }
145
146    /// Detect task boundary using distribution shift
147    pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
148        if self.samples_seen < 50 || gradients.len() != self.config.param_count {
149            return false;
150        }
151
152        // Compute z-score of current gradients vs running stats
153        let mut z_score_sum = 0.0f32;
154        let mut count = 0;
155
156        for (i, &g) in gradients.iter().enumerate() {
157            let var = self.gradient_var[i] / self.samples_seen as f32;
158            if var > 1e-8 {
159                let std = var.sqrt();
160                let z = (g - self.gradient_mean[i]).abs() / std;
161                z_score_sum += z;
162                count += 1;
163            }
164        }
165
166        if count == 0 {
167            return false;
168        }
169
170        let avg_z = z_score_sum / count as f32;
171        avg_z > self.config.boundary_threshold
172    }
173
174    /// Start new task - saves current Fisher to memory
175    pub fn start_new_task(&mut self) {
176        // Save current task's Fisher
177        let task_fisher = TaskFisher {
178            task_id: self.current_task_id,
179            fisher: self.current_fisher.clone(),
180            optimal_weights: self.current_weights.clone(),
181            importance: 1.0,
182        };
183
184        // Add to circular buffer
185        if self.task_memory.len() >= self.config.max_tasks {
186            self.task_memory.pop_front();
187        }
188        self.task_memory.push_back(task_fisher);
189
190        // Reset for new task
191        self.current_task_id += 1;
192        self.current_fisher.fill(0.0);
193        self.gradient_history.clear();
194        self.gradient_mean.fill(0.0);
195        self.gradient_var.fill(1.0);
196        self.samples_seen = 0;
197
198        // Adapt lambda based on task count
199        self.adapt_lambda();
200    }
201
202    /// Adapt lambda based on accumulated tasks
203    fn adapt_lambda(&mut self) {
204        let task_count = self.task_memory.len();
205        if task_count == 0 {
206            return;
207        }
208
209        // Increase lambda as more tasks accumulate (more to protect)
210        let scale = 1.0 + 0.1 * task_count as f32;
211        self.lambda = (self.config.initial_lambda * scale)
212            .clamp(self.config.min_lambda, self.config.max_lambda);
213    }
214
215    /// Apply EWC++ constraints to gradients
216    pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
217        if gradients.len() != self.config.param_count {
218            return gradients.to_vec();
219        }
220
221        let mut constrained = gradients.to_vec();
222
223        // Apply constraint from each remembered task
224        for task in &self.task_memory {
225            for (i, g) in constrained.iter_mut().enumerate() {
226                // Penalty: lambda * F_i * (w_i - w*_i)
227                // Gradient of penalty: lambda * F_i
228                // Project gradient to preserve important weights
229                let importance = task.fisher[i] * task.importance;
230                if importance > 1e-8 {
231                    let penalty_grad = self.lambda * importance;
232                    // Reduce gradient magnitude for important parameters
233                    *g *= 1.0 / (1.0 + penalty_grad);
234                }
235            }
236        }
237
238        // Also apply current task's Fisher (online)
239        for (i, g) in constrained.iter_mut().enumerate() {
240            if self.current_fisher[i] > 1e-8 {
241                let penalty_grad = self.lambda * self.current_fisher[i] * 0.1; // Lower weight for current
242                *g *= 1.0 / (1.0 + penalty_grad);
243            }
244        }
245
246        constrained
247    }
248
249    /// Compute EWC regularization loss
250    pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
251        if current_weights.len() != self.config.param_count {
252            return 0.0;
253        }
254
255        let mut loss = 0.0f32;
256
257        for task in &self.task_memory {
258            for i in 0..self.config.param_count {
259                let diff = current_weights[i] - task.optimal_weights[i];
260                loss += task.fisher[i] * diff * diff * task.importance;
261            }
262        }
263
264        self.lambda * loss / 2.0
265    }
266
267    /// Update optimal weights reference
268    pub fn set_optimal_weights(&mut self, weights: &[f32]) {
269        if weights.len() == self.config.param_count {
270            self.current_weights.copy_from_slice(weights);
271        }
272    }
273
274    /// Consolidate all tasks (merge Fisher information)
275    pub fn consolidate_all_tasks(&mut self) {
276        if self.task_memory.is_empty() {
277            return;
278        }
279
280        // Compute weighted average of Fisher matrices
281        let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
282        let mut total_importance = 0.0f32;
283
284        for task in &self.task_memory {
285            for (i, &f) in task.fisher.iter().enumerate() {
286                consolidated_fisher[i] += f * task.importance;
287            }
288            total_importance += task.importance;
289        }
290
291        if total_importance > 0.0 {
292            for f in &mut consolidated_fisher {
293                *f /= total_importance;
294            }
295        }
296
297        // Store as single consolidated task
298        let consolidated = TaskFisher {
299            task_id: 0,
300            fisher: consolidated_fisher,
301            optimal_weights: self.current_weights.clone(),
302            importance: total_importance,
303        };
304
305        self.task_memory.clear();
306        self.task_memory.push_back(consolidated);
307    }
308
309    /// Get current lambda
310    pub fn lambda(&self) -> f32 {
311        self.lambda
312    }
313
314    /// Set lambda manually
315    pub fn set_lambda(&mut self, lambda: f32) {
316        self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
317    }
318
319    /// Get task count
320    pub fn task_count(&self) -> usize {
321        self.task_memory.len()
322    }
323
324    /// Get current task ID
325    pub fn current_task_id(&self) -> usize {
326        self.current_task_id
327    }
328
329    /// Get samples seen for current task
330    pub fn samples_seen(&self) -> u64 {
331        self.samples_seen
332    }
333
334    /// Get parameter importance scores
335    pub fn importance_scores(&self) -> Vec<f32> {
336        let mut scores = self.current_fisher.clone();
337
338        for task in &self.task_memory {
339            for (i, &f) in task.fisher.iter().enumerate() {
340                scores[i] += f * task.importance;
341            }
342        }
343
344        scores
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with `param_count` parameters and every other field at its default.
    fn config_with(param_count: usize) -> EwcConfig {
        EwcConfig {
            param_count,
            ..Default::default()
        }
    }

    #[test]
    fn test_ewc_creation() {
        let ewc = EwcPlusPlus::new(config_with(100));

        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
        assert_eq!(ewc.samples_seen(), 0);
        assert_eq!(ewc.lambda(), EwcConfig::default().initial_lambda);
    }

    #[test]
    fn test_fisher_update() {
        let mut ewc = EwcPlusPlus::new(config_with(10));

        ewc.update_fisher(&[0.5; 10]);
        assert_eq!(ewc.samples_seen(), 1);
        assert!(ewc.current_fisher.iter().all(|&f| f > 0.0));

        // Wrong-length gradients are ignored entirely.
        ewc.update_fisher(&[0.5; 3]);
        assert_eq!(ewc.samples_seen(), 1);
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Below the 50-sample warm-up nothing is ever reported.
        for _ in 0..49 {
            ewc.update_fisher(&[0.1; 10]);
        }
        assert!(!ewc.detect_task_boundary(&[10.0; 10]));

        // Finish training on consistent gradients (60 samples total).
        for _ in 0..11 {
            ewc.update_fisher(&[0.1; 10]);
        }

        // A gradient matching the running mean is not a boundary...
        assert!(!ewc.detect_task_boundary(&[0.1; 10]));
        // ...but a large shift is (the seeded M2 keeps the variance positive).
        assert!(ewc.detect_task_boundary(&[10.0; 10]));
        // Wrong length is never a boundary.
        assert!(!ewc.detect_task_boundary(&[10.0; 3]));
    }

    #[test]
    fn test_constraint_application() {
        let mut ewc = EwcPlusPlus::new(config_with(5));

        // Build up some Fisher information, then freeze it into task memory.
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        let gradients = [1.0f32; 5];
        let constrained = ewc.apply_constraints(&gradients);

        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag < orig_mag);
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.set_optimal_weights(&[0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        // Zero at the optimum, positive when deviated, zero on length mismatch.
        assert_eq!(ewc.regularization_loss(&[0.0; 5]), 0.0);
        assert!(ewc.regularization_loss(&[1.0; 5]) > 0.0);
        assert_eq!(ewc.regularization_loss(&[1.0; 4]), 0.0);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&[1.0; 5]);
            }
            ewc.start_new_task();
        }
        assert_eq!(ewc.task_count(), 3);

        let before = ewc.importance_scores();
        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);

        // Merging keeps fisher * importance (and hence the scores) intact.
        let after = ewc.importance_scores();
        for (b, a) in before.iter().zip(&after) {
            assert!((b - a).abs() < 1e-6, "{} vs {}", b, a);
        }
    }

    #[test]
    fn test_task_memory_is_bounded() {
        let config = EwcConfig {
            param_count: 2,
            max_tasks: 2,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..5 {
            ewc.start_new_task();
        }

        assert_eq!(ewc.task_count(), 2);
        assert_eq!(ewc.current_task_id(), 5);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);
        let initial_lambda = ewc.lambda();

        for _ in 0..5 {
            ewc.start_new_task();
        }

        // 1000 * (1 + 0.1 * 5) = 1500
        assert!(ewc.lambda() > initial_lambda);
        assert!((ewc.lambda() - 1500.0).abs() < 1e-2);
    }

    #[test]
    fn test_set_lambda_clamps() {
        let mut ewc = EwcPlusPlus::new(config_with(1));

        ewc.set_lambda(1e9);
        assert_eq!(ewc.lambda(), 15000.0);

        ewc.set_lambda(0.0);
        assert_eq!(ewc.lambda(), 100.0);
    }
}