Skip to main content

ruvector_dag/sona/
ewc.rs

//! EWC++: Elastic Weight Consolidation, used to prevent catastrophic forgetting of earlier tasks.
2
3use ndarray::Array1;
4
/// Tuning parameters for [`EwcPlusPlus`].
#[derive(Debug, Clone, PartialEq)]
pub struct EwcConfig {
    /// Importance weight that scales the EWC penalty (2000-15000).
    pub lambda: f32,
    /// Fisher decay rate: the share of the previously stored Fisher estimate that is kept
    /// when a new estimate is blended in during online consolidation.
    pub decay: f32,
    /// Use online EWC, i.e. blend successive Fisher estimates instead of overwriting them.
    pub online: bool,
}
11
12impl Default for EwcConfig {
13    fn default() -> Self {
14        Self {
15            lambda: 5000.0,
16            decay: 0.99,
17            online: true,
18        }
19    }
20}
21
22pub struct EwcPlusPlus {
23    config: EwcConfig,
24    fisher_diag: Option<Array1<f32>>,
25    optimal_params: Option<Array1<f32>>,
26    task_count: usize,
27}
28
29impl EwcPlusPlus {
30    pub fn new(config: EwcConfig) -> Self {
31        Self {
32            config,
33            fisher_diag: None,
34            optimal_params: None,
35            task_count: 0,
36        }
37    }
38
39    /// Consolidate current parameters after training
40    pub fn consolidate(&mut self, params: &Array1<f32>, fisher: &Array1<f32>) {
41        if self.config.online && self.fisher_diag.is_some() {
42            // Online EWC: accumulate Fisher information
43            let current_fisher = self.fisher_diag.as_ref().unwrap();
44            self.fisher_diag =
45                Some(current_fisher * self.config.decay + fisher * (1.0 - self.config.decay));
46        } else {
47            self.fisher_diag = Some(fisher.clone());
48        }
49
50        self.optimal_params = Some(params.clone());
51        self.task_count += 1;
52    }
53
54    /// Compute EWC penalty for given parameters
55    pub fn penalty(&self, params: &Array1<f32>) -> f32 {
56        match (&self.fisher_diag, &self.optimal_params) {
57            (Some(fisher), Some(optimal)) => {
58                let diff = params - optimal;
59                let weighted = &diff * &diff * fisher;
60                0.5 * self.config.lambda * weighted.sum()
61            }
62            _ => 0.0,
63        }
64    }
65
66    /// Compute gradient of EWC penalty
67    pub fn penalty_gradient(&self, params: &Array1<f32>) -> Option<Array1<f32>> {
68        match (&self.fisher_diag, &self.optimal_params) {
69            (Some(fisher), Some(optimal)) => {
70                let diff = params - optimal;
71                Some(self.config.lambda * fisher * &diff)
72            }
73            _ => None,
74        }
75    }
76
77    /// Compute Fisher information from gradients
78    pub fn compute_fisher(gradients: &[Array1<f32>]) -> Array1<f32> {
79        if gradients.is_empty() {
80            return Array1::zeros(0);
81        }
82
83        let dim = gradients[0].len();
84        let mut fisher = Array1::zeros(dim);
85
86        for grad in gradients {
87            fisher = fisher + grad.mapv(|x| x * x);
88        }
89
90        fisher / gradients.len() as f32
91    }
92
93    pub fn has_prior(&self) -> bool {
94        self.fisher_diag.is_some()
95    }
96
97    pub fn task_count(&self) -> usize {
98        self.task_count
99    }
100}