Skip to main content

ruvector_sona/
ewc.rs

//! EWC++ (Enhanced Elastic Weight Consolidation) for SONA.
//!
//! Mitigates catastrophic forgetting with:
//! - Online Fisher information estimation (exponential moving average of squared gradients)
//! - Multi-task memory held in a bounded circular buffer
//! - Task boundary detection based on gradient distribution shift
//! - Adaptive lambda scheduling driven by the number of remembered tasks
8
9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11
12/// EWC++ configuration
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct EwcConfig {
15    /// Number of parameters
16    pub param_count: usize,
17    /// Maximum tasks to remember
18    pub max_tasks: usize,
19    /// Initial lambda
20    pub initial_lambda: f32,
21    /// Minimum lambda
22    pub min_lambda: f32,
23    /// Maximum lambda
24    pub max_lambda: f32,
25    /// Fisher EMA decay factor
26    pub fisher_ema_decay: f32,
27    /// Task boundary detection threshold
28    pub boundary_threshold: f32,
29    /// Gradient history for boundary detection
30    pub gradient_history_size: usize,
31}
32
33impl Default for EwcConfig {
34    fn default() -> Self {
35        // OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
36        // - Lambda 2000 optimal for catastrophic forgetting prevention
37        // - Higher max_lambda (15000) for aggressive protection when needed
38        Self {
39            param_count: 1000,
40            max_tasks: 10,
41            initial_lambda: 2000.0, // OPTIMIZED: Better forgetting prevention
42            min_lambda: 100.0,
43            max_lambda: 15000.0, // OPTIMIZED: Higher ceiling for multi-task
44            fisher_ema_decay: 0.999,
45            boundary_threshold: 2.0,
46            gradient_history_size: 100,
47        }
48    }
49}
50
51/// Task-specific Fisher information
52#[derive(Clone, Debug, Serialize, Deserialize)]
53pub struct TaskFisher {
54    /// Task ID
55    pub task_id: usize,
56    /// Fisher diagonal
57    pub fisher: Vec<f32>,
58    /// Optimal weights for this task
59    pub optimal_weights: Vec<f32>,
60    /// Task importance (for weighted consolidation)
61    pub importance: f32,
62}
63
64/// EWC++ implementation
65#[derive(Clone, Debug, Serialize, Deserialize)]
66pub struct EwcPlusPlus {
67    /// Configuration
68    config: EwcConfig,
69    /// Current Fisher information (online estimate)
70    current_fisher: Vec<f32>,
71    /// Current optimal weights
72    current_weights: Vec<f32>,
73    /// Task memory (circular buffer)
74    task_memory: VecDeque<TaskFisher>,
75    /// Current task ID
76    current_task_id: usize,
77    /// Current lambda
78    lambda: f32,
79    /// Gradient history for boundary detection
80    gradient_history: VecDeque<Vec<f32>>,
81    /// Running gradient mean
82    gradient_mean: Vec<f32>,
83    /// Running gradient variance
84    gradient_var: Vec<f32>,
85    /// Samples seen for current task
86    samples_seen: u64,
87}
88
89impl EwcPlusPlus {
90    /// Create new EWC++
91    pub fn new(config: EwcConfig) -> Self {
92        let param_count = config.param_count;
93        let initial_lambda = config.initial_lambda;
94
95        Self {
96            config: config.clone(),
97            current_fisher: vec![0.0; param_count],
98            current_weights: vec![0.0; param_count],
99            task_memory: VecDeque::with_capacity(config.max_tasks),
100            current_task_id: 0,
101            lambda: initial_lambda,
102            gradient_history: VecDeque::with_capacity(config.gradient_history_size),
103            gradient_mean: vec![0.0; param_count],
104            gradient_var: vec![1.0; param_count],
105            samples_seen: 0,
106        }
107    }
108
109    /// Update Fisher information online using EMA
110    pub fn update_fisher(&mut self, gradients: &[f32]) {
111        if gradients.len() != self.config.param_count {
112            return;
113        }
114
115        let decay = self.config.fisher_ema_decay;
116
117        // Online Fisher update: F_t = decay * F_{t-1} + (1 - decay) * g^2
118        for (i, &g) in gradients.iter().enumerate() {
119            self.current_fisher[i] = decay * self.current_fisher[i] + (1.0 - decay) * g * g;
120        }
121
122        // Update gradient statistics for boundary detection
123        self.update_gradient_stats(gradients);
124        self.samples_seen += 1;
125    }
126
127    /// Update gradient statistics for boundary detection
128    fn update_gradient_stats(&mut self, gradients: &[f32]) {
129        // Store in history
130        if self.gradient_history.len() >= self.config.gradient_history_size {
131            self.gradient_history.pop_front();
132        }
133        self.gradient_history.push_back(gradients.to_vec());
134
135        // Update running mean and variance (Welford's algorithm)
136        let n = self.samples_seen as f32 + 1.0;
137
138        for (i, &g) in gradients.iter().enumerate() {
139            let delta = g - self.gradient_mean[i];
140            self.gradient_mean[i] += delta / n;
141            let delta2 = g - self.gradient_mean[i];
142            self.gradient_var[i] += delta * delta2;
143        }
144    }
145
146    /// Detect task boundary using distribution shift
147    pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
148        if self.samples_seen < 50 || gradients.len() != self.config.param_count {
149            return false;
150        }
151
152        // Compute z-score of current gradients vs running stats
153        let mut z_score_sum = 0.0f32;
154        let mut count = 0;
155
156        for (i, &g) in gradients.iter().enumerate() {
157            let var = self.gradient_var[i] / self.samples_seen as f32;
158            if var > 1e-8 {
159                let std = var.sqrt();
160                let z = (g - self.gradient_mean[i]).abs() / std;
161                z_score_sum += z;
162                count += 1;
163            }
164        }
165
166        if count == 0 {
167            return false;
168        }
169
170        let avg_z = z_score_sum / count as f32;
171        avg_z > self.config.boundary_threshold
172    }
173
174    /// Start new task - saves current Fisher to memory
175    pub fn start_new_task(&mut self) {
176        // Save current task's Fisher
177        let task_fisher = TaskFisher {
178            task_id: self.current_task_id,
179            fisher: self.current_fisher.clone(),
180            optimal_weights: self.current_weights.clone(),
181            importance: 1.0,
182        };
183
184        // Add to circular buffer
185        if self.task_memory.len() >= self.config.max_tasks {
186            self.task_memory.pop_front();
187        }
188        self.task_memory.push_back(task_fisher);
189
190        // Reset for new task
191        self.current_task_id += 1;
192        self.current_fisher.fill(0.0);
193        self.gradient_history.clear();
194        self.gradient_mean.fill(0.0);
195        self.gradient_var.fill(1.0);
196        self.samples_seen = 0;
197
198        // Adapt lambda based on task count
199        self.adapt_lambda();
200    }
201
202    /// Adapt lambda based on accumulated tasks
203    fn adapt_lambda(&mut self) {
204        let task_count = self.task_memory.len();
205        if task_count == 0 {
206            return;
207        }
208
209        // Increase lambda as more tasks accumulate (more to protect)
210        let scale = 1.0 + 0.1 * task_count as f32;
211        self.lambda = (self.config.initial_lambda * scale)
212            .clamp(self.config.min_lambda, self.config.max_lambda);
213    }
214
215    /// Apply EWC++ constraints to gradients
216    pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
217        if gradients.len() != self.config.param_count {
218            return gradients.to_vec();
219        }
220
221        let mut constrained = gradients.to_vec();
222
223        // Apply constraint from each remembered task
224        for task in &self.task_memory {
225            for (i, g) in constrained.iter_mut().enumerate() {
226                // Penalty: lambda * F_i * (w_i - w*_i)
227                // Gradient of penalty: lambda * F_i
228                // Project gradient to preserve important weights
229                let importance = task.fisher[i] * task.importance;
230                if importance > 1e-8 {
231                    let penalty_grad = self.lambda * importance;
232                    // Reduce gradient magnitude for important parameters
233                    *g *= 1.0 / (1.0 + penalty_grad);
234                }
235            }
236        }
237
238        // Also apply current task's Fisher (online)
239        for (i, g) in constrained.iter_mut().enumerate() {
240            if self.current_fisher[i] > 1e-8 {
241                let penalty_grad = self.lambda * self.current_fisher[i] * 0.1; // Lower weight for current
242                *g *= 1.0 / (1.0 + penalty_grad);
243            }
244        }
245
246        constrained
247    }
248
249    /// Compute EWC regularization loss
250    pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
251        if current_weights.len() != self.config.param_count {
252            return 0.0;
253        }
254
255        let mut loss = 0.0f32;
256
257        for task in &self.task_memory {
258            for ((&cw, &ow), &fi) in current_weights
259                .iter()
260                .zip(task.optimal_weights.iter())
261                .zip(task.fisher.iter())
262                .take(self.config.param_count)
263            {
264                let diff = cw - ow;
265                loss += fi * diff * diff * task.importance;
266            }
267        }
268
269        self.lambda * loss / 2.0
270    }
271
272    /// Update optimal weights reference
273    pub fn set_optimal_weights(&mut self, weights: &[f32]) {
274        if weights.len() == self.config.param_count {
275            self.current_weights.copy_from_slice(weights);
276        }
277    }
278
279    /// Consolidate all tasks (merge Fisher information)
280    pub fn consolidate_all_tasks(&mut self) {
281        if self.task_memory.is_empty() {
282            return;
283        }
284
285        // Compute weighted average of Fisher matrices
286        let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
287        let mut total_importance = 0.0f32;
288
289        for task in &self.task_memory {
290            for (i, &f) in task.fisher.iter().enumerate() {
291                consolidated_fisher[i] += f * task.importance;
292            }
293            total_importance += task.importance;
294        }
295
296        if total_importance > 0.0 {
297            for f in &mut consolidated_fisher {
298                *f /= total_importance;
299            }
300        }
301
302        // Store as single consolidated task
303        let consolidated = TaskFisher {
304            task_id: 0,
305            fisher: consolidated_fisher,
306            optimal_weights: self.current_weights.clone(),
307            importance: total_importance,
308        };
309
310        self.task_memory.clear();
311        self.task_memory.push_back(consolidated);
312    }
313
314    /// Get current lambda
315    pub fn lambda(&self) -> f32 {
316        self.lambda
317    }
318
319    /// Set lambda manually
320    pub fn set_lambda(&mut self, lambda: f32) {
321        self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
322    }
323
324    /// Get task count
325    pub fn task_count(&self) -> usize {
326        self.task_memory.len()
327    }
328
329    /// Get current task ID
330    pub fn current_task_id(&self) -> usize {
331        self.current_task_id
332    }
333
334    /// Get samples seen for current task
335    pub fn samples_seen(&self) -> u64 {
336        self.samples_seen
337    }
338
339    /// Get parameter importance scores
340    pub fn importance_scores(&self) -> Vec<f32> {
341        let mut scores = self.current_fisher.clone();
342
343        for task in &self.task_memory {
344            for (i, &f) in task.fisher.iter().enumerate() {
345                scores[i] += f * task.importance;
346            }
347        }
348
349        scores
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_creation() {
        let config = EwcConfig {
            param_count: 100,
            ..Default::default()
        };
        let ewc = EwcPlusPlus::new(config);

        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
    }

    #[test]
    fn test_fisher_update() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.update_fisher(&[0.5; 10]);

        assert_eq!(ewc.samples_seen(), 1);
        assert!(ewc.current_fisher.iter().all(|&f| f > 0.0));
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Train on consistent gradients.
        for _ in 0..60 {
            ewc.update_fisher(&[0.1; 10]);
        }

        // A gradient matching the running mean has z = 0 and must not trigger a boundary.
        assert!(!ewc.detect_task_boundary(&[0.1; 10]));

        // The M2 seed of 1.0 floors the variance at 1/60, so a jump to 10.0 gives
        // z ≈ 9.9 / 0.129 ≈ 77, far above the threshold of 2.0.
        assert!(ewc.detect_task_boundary(&[10.0; 10]));
    }

    #[test]
    fn test_constraint_application() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Build up some Fisher information.
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        let gradients = [1.0f32; 5];
        let constrained = ewc.apply_constraints(&gradients);

        // Every parameter had non-zero Fisher in the stored task, so all are damped.
        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag < orig_mag);
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Set up optimal weights and Fisher.
        ewc.set_optimal_weights(&[0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        // Loss is exactly zero at the stored optimum.
        let at_optimal = ewc.regularization_loss(&[0.0; 5]);
        assert_eq!(at_optimal, 0.0);

        // Loss is positive when deviated.
        let deviated = ewc.regularization_loss(&[1.0; 5]);
        assert!(deviated > at_optimal);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Create multiple tasks.
        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&[1.0; 5]);
            }
            ewc.start_new_task();
        }

        assert_eq!(ewc.task_count(), 3);

        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let initial_lambda = ewc.lambda();

        // Add tasks.
        for _ in 0..5 {
            ewc.start_new_task();
        }

        // Five stored tasks scale lambda by 1.5, well inside the default bounds.
        assert!(ewc.lambda() > initial_lambda);
    }
}