// ruvector_sona/ewc.rs

//! EWC++ (Enhanced Elastic Weight Consolidation) for SONA.
//!
//! Mitigates catastrophic forgetting with:
//! - Online Fisher information estimation (exponential moving average of squared gradients)
//! - Multi-task memory held in a bounded circular buffer
//! - Task boundary detection from gradient distribution shift (reported to the caller)
//! - Adaptive lambda scheduling

9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11
12/// EWC++ configuration
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct EwcConfig {
15    /// Number of parameters
16    pub param_count: usize,
17    /// Maximum tasks to remember
18    pub max_tasks: usize,
19    /// Initial lambda
20    pub initial_lambda: f32,
21    /// Minimum lambda
22    pub min_lambda: f32,
23    /// Maximum lambda
24    pub max_lambda: f32,
25    /// Fisher EMA decay factor
26    pub fisher_ema_decay: f32,
27    /// Task boundary detection threshold
28    pub boundary_threshold: f32,
29    /// Gradient history for boundary detection
30    pub gradient_history_size: usize,
31}
32
33impl Default for EwcConfig {
34    fn default() -> Self {
35        Self {
36            param_count: 1000,
37            max_tasks: 10,
38            initial_lambda: 1000.0,
39            min_lambda: 100.0,
40            max_lambda: 10000.0,
41            fisher_ema_decay: 0.999,
42            boundary_threshold: 2.0,
43            gradient_history_size: 100,
44        }
45    }
46}
47
48/// Task-specific Fisher information
49#[derive(Clone, Debug, Serialize, Deserialize)]
50pub struct TaskFisher {
51    /// Task ID
52    pub task_id: usize,
53    /// Fisher diagonal
54    pub fisher: Vec<f32>,
55    /// Optimal weights for this task
56    pub optimal_weights: Vec<f32>,
57    /// Task importance (for weighted consolidation)
58    pub importance: f32,
59}
60
61/// EWC++ implementation
62#[derive(Clone, Debug, Serialize, Deserialize)]
63pub struct EwcPlusPlus {
64    /// Configuration
65    config: EwcConfig,
66    /// Current Fisher information (online estimate)
67    current_fisher: Vec<f32>,
68    /// Current optimal weights
69    current_weights: Vec<f32>,
70    /// Task memory (circular buffer)
71    task_memory: VecDeque<TaskFisher>,
72    /// Current task ID
73    current_task_id: usize,
74    /// Current lambda
75    lambda: f32,
76    /// Gradient history for boundary detection
77    gradient_history: VecDeque<Vec<f32>>,
78    /// Running gradient mean
79    gradient_mean: Vec<f32>,
80    /// Running gradient variance
81    gradient_var: Vec<f32>,
82    /// Samples seen for current task
83    samples_seen: u64,
84}
85
86impl EwcPlusPlus {
87    /// Create new EWC++
88    pub fn new(config: EwcConfig) -> Self {
89        let param_count = config.param_count;
90        let initial_lambda = config.initial_lambda;
91
92        Self {
93            config: config.clone(),
94            current_fisher: vec![0.0; param_count],
95            current_weights: vec![0.0; param_count],
96            task_memory: VecDeque::with_capacity(config.max_tasks),
97            current_task_id: 0,
98            lambda: initial_lambda,
99            gradient_history: VecDeque::with_capacity(config.gradient_history_size),
100            gradient_mean: vec![0.0; param_count],
101            gradient_var: vec![1.0; param_count],
102            samples_seen: 0,
103        }
104    }
105
106    /// Update Fisher information online using EMA
107    pub fn update_fisher(&mut self, gradients: &[f32]) {
108        if gradients.len() != self.config.param_count {
109            return;
110        }
111
112        let decay = self.config.fisher_ema_decay;
113
114        // Online Fisher update: F_t = decay * F_{t-1} + (1 - decay) * g^2
115        for (i, &g) in gradients.iter().enumerate() {
116            self.current_fisher[i] = decay * self.current_fisher[i] + (1.0 - decay) * g * g;
117        }
118
119        // Update gradient statistics for boundary detection
120        self.update_gradient_stats(gradients);
121        self.samples_seen += 1;
122    }
123
124    /// Update gradient statistics for boundary detection
125    fn update_gradient_stats(&mut self, gradients: &[f32]) {
126        // Store in history
127        if self.gradient_history.len() >= self.config.gradient_history_size {
128            self.gradient_history.pop_front();
129        }
130        self.gradient_history.push_back(gradients.to_vec());
131
132        // Update running mean and variance (Welford's algorithm)
133        let n = self.samples_seen as f32 + 1.0;
134
135        for (i, &g) in gradients.iter().enumerate() {
136            let delta = g - self.gradient_mean[i];
137            self.gradient_mean[i] += delta / n;
138            let delta2 = g - self.gradient_mean[i];
139            self.gradient_var[i] += delta * delta2;
140        }
141    }
142
143    /// Detect task boundary using distribution shift
144    pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
145        if self.samples_seen < 50 || gradients.len() != self.config.param_count {
146            return false;
147        }
148
149        // Compute z-score of current gradients vs running stats
150        let mut z_score_sum = 0.0f32;
151        let mut count = 0;
152
153        for (i, &g) in gradients.iter().enumerate() {
154            let var = self.gradient_var[i] / self.samples_seen as f32;
155            if var > 1e-8 {
156                let std = var.sqrt();
157                let z = (g - self.gradient_mean[i]).abs() / std;
158                z_score_sum += z;
159                count += 1;
160            }
161        }
162
163        if count == 0 {
164            return false;
165        }
166
167        let avg_z = z_score_sum / count as f32;
168        avg_z > self.config.boundary_threshold
169    }
170
171    /// Start new task - saves current Fisher to memory
172    pub fn start_new_task(&mut self) {
173        // Save current task's Fisher
174        let task_fisher = TaskFisher {
175            task_id: self.current_task_id,
176            fisher: self.current_fisher.clone(),
177            optimal_weights: self.current_weights.clone(),
178            importance: 1.0,
179        };
180
181        // Add to circular buffer
182        if self.task_memory.len() >= self.config.max_tasks {
183            self.task_memory.pop_front();
184        }
185        self.task_memory.push_back(task_fisher);
186
187        // Reset for new task
188        self.current_task_id += 1;
189        self.current_fisher.fill(0.0);
190        self.gradient_history.clear();
191        self.gradient_mean.fill(0.0);
192        self.gradient_var.fill(1.0);
193        self.samples_seen = 0;
194
195        // Adapt lambda based on task count
196        self.adapt_lambda();
197    }
198
199    /// Adapt lambda based on accumulated tasks
200    fn adapt_lambda(&mut self) {
201        let task_count = self.task_memory.len();
202        if task_count == 0 {
203            return;
204        }
205
206        // Increase lambda as more tasks accumulate (more to protect)
207        let scale = 1.0 + 0.1 * task_count as f32;
208        self.lambda = (self.config.initial_lambda * scale)
209            .clamp(self.config.min_lambda, self.config.max_lambda);
210    }
211
212    /// Apply EWC++ constraints to gradients
213    pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
214        if gradients.len() != self.config.param_count {
215            return gradients.to_vec();
216        }
217
218        let mut constrained = gradients.to_vec();
219
220        // Apply constraint from each remembered task
221        for task in &self.task_memory {
222            for (i, g) in constrained.iter_mut().enumerate() {
223                // Penalty: lambda * F_i * (w_i - w*_i)
224                // Gradient of penalty: lambda * F_i
225                // Project gradient to preserve important weights
226                let importance = task.fisher[i] * task.importance;
227                if importance > 1e-8 {
228                    let penalty_grad = self.lambda * importance;
229                    // Reduce gradient magnitude for important parameters
230                    *g *= 1.0 / (1.0 + penalty_grad);
231                }
232            }
233        }
234
235        // Also apply current task's Fisher (online)
236        for (i, g) in constrained.iter_mut().enumerate() {
237            if self.current_fisher[i] > 1e-8 {
238                let penalty_grad = self.lambda * self.current_fisher[i] * 0.1; // Lower weight for current
239                *g *= 1.0 / (1.0 + penalty_grad);
240            }
241        }
242
243        constrained
244    }
245
246    /// Compute EWC regularization loss
247    pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
248        if current_weights.len() != self.config.param_count {
249            return 0.0;
250        }
251
252        let mut loss = 0.0f32;
253
254        for task in &self.task_memory {
255            for i in 0..self.config.param_count {
256                let diff = current_weights[i] - task.optimal_weights[i];
257                loss += task.fisher[i] * diff * diff * task.importance;
258            }
259        }
260
261        self.lambda * loss / 2.0
262    }
263
264    /// Update optimal weights reference
265    pub fn set_optimal_weights(&mut self, weights: &[f32]) {
266        if weights.len() == self.config.param_count {
267            self.current_weights.copy_from_slice(weights);
268        }
269    }
270
271    /// Consolidate all tasks (merge Fisher information)
272    pub fn consolidate_all_tasks(&mut self) {
273        if self.task_memory.is_empty() {
274            return;
275        }
276
277        // Compute weighted average of Fisher matrices
278        let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
279        let mut total_importance = 0.0f32;
280
281        for task in &self.task_memory {
282            for (i, &f) in task.fisher.iter().enumerate() {
283                consolidated_fisher[i] += f * task.importance;
284            }
285            total_importance += task.importance;
286        }
287
288        if total_importance > 0.0 {
289            for f in &mut consolidated_fisher {
290                *f /= total_importance;
291            }
292        }
293
294        // Store as single consolidated task
295        let consolidated = TaskFisher {
296            task_id: 0,
297            fisher: consolidated_fisher,
298            optimal_weights: self.current_weights.clone(),
299            importance: total_importance,
300        };
301
302        self.task_memory.clear();
303        self.task_memory.push_back(consolidated);
304    }
305
306    /// Get current lambda
307    pub fn lambda(&self) -> f32 {
308        self.lambda
309    }
310
311    /// Set lambda manually
312    pub fn set_lambda(&mut self, lambda: f32) {
313        self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
314    }
315
316    /// Get task count
317    pub fn task_count(&self) -> usize {
318        self.task_memory.len()
319    }
320
321    /// Get current task ID
322    pub fn current_task_id(&self) -> usize {
323        self.current_task_id
324    }
325
326    /// Get samples seen for current task
327    pub fn samples_seen(&self) -> u64 {
328        self.samples_seen
329    }
330
331    /// Get parameter importance scores
332    pub fn importance_scores(&self) -> Vec<f32> {
333        let mut scores = self.current_fisher.clone();
334
335        for task in &self.task_memory {
336            for (i, &f) in task.fisher.iter().enumerate() {
337                scores[i] += f * task.importance;
338            }
339        }
340
341        scores
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_creation() {
        let config = EwcConfig {
            param_count: 100,
            ..Default::default()
        };
        let ewc = EwcPlusPlus::new(config);

        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
        assert_eq!(ewc.samples_seen(), 0);
    }

    #[test]
    fn test_fisher_update() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.update_fisher(&[0.5; 10]);
        assert_eq!(ewc.samples_seen(), 1);
        assert!(ewc.current_fisher.iter().all(|&f| f > 0.0));

        // Wrong-length input is ignored entirely.
        ewc.update_fisher(&[0.5; 3]);
        assert_eq!(ewc.samples_seen(), 1);
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Below the minimum sample count nothing is ever reported.
        for _ in 0..10 {
            ewc.update_fisher(&[0.1; 10]);
        }
        assert!(!ewc.detect_task_boundary(&[10.0; 10]));

        // Train on consistent gradients up to 60 samples.
        for _ in 10..60 {
            ewc.update_fisher(&[0.1; 10]);
        }

        // A gradient matching the running mean is not a boundary.
        assert!(!ewc.detect_task_boundary(&[0.1; 10]));

        // M2 is seeded with 1.0, so after 60 identical samples each
        // component's std is sqrt(1/60) ~= 0.13; a jump to 10.0 gives a
        // z-score of roughly 77, far above the 2.0 threshold.
        assert!(ewc.detect_task_boundary(&[10.0; 10]));
    }

    #[test]
    fn test_constraint_application() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Build up some Fisher information, then store it as a task.
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        let gradients = [1.0f32; 5];
        let constrained = ewc.apply_constraints(&gradients);

        // Every component carries non-zero stored Fisher, so each is damped.
        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag < orig_mag);
        assert!(constrained.iter().all(|&g| g > 0.0));
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Set up optimal weights and Fisher, then store them as a task.
        ewc.set_optimal_weights(&[0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        // Loss is exactly zero at the anchor weights...
        let at_optimal = ewc.regularization_loss(&[0.0; 5]);
        assert_eq!(at_optimal, 0.0);

        // ...and positive once the weights deviate.
        let deviated = ewc.regularization_loss(&[1.0; 5]);
        assert!(deviated > at_optimal);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Create multiple tasks.
        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&[1.0; 5]);
            }
            ewc.start_new_task();
        }
        assert_eq!(ewc.task_count(), 3);

        let before = ewc.importance_scores();
        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);

        // Consolidation keeps sum(fisher * importance) per parameter.
        for (b, a) in before.iter().zip(ewc.importance_scores()) {
            assert!((*b - a).abs() < 1e-6);
        }
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let initial_lambda = ewc.lambda();

        for _ in 0..5 {
            ewc.start_new_task();
        }

        // 1000 * (1 + 0.1 * 5) = 1500, inside the default [100, 10000] range.
        assert!(ewc.lambda() > initial_lambda);
        assert!((ewc.lambda() - 1500.0).abs() < 1e-3);
    }
}