1use ndarray::Array1;
4
/// Tunable settings for the EWC++ regulariser.
#[derive(Debug, Clone)]
pub struct EwcConfig {
    /// Strength of the penalty; it scales both the penalty value and its gradient.
    pub lambda: f32,
    /// Weight kept on the previously stored Fisher estimate when a new one is
    /// blended in during online consolidation (the new estimate gets `1 - decay`).
    pub decay: f32,
    /// When `true`, consolidation blends the new Fisher estimate with the stored
    /// one; when `false`, the stored estimate is simply replaced.
    pub online: bool,
}
11
12impl Default for EwcConfig {
13 fn default() -> Self {
14 Self {
15 lambda: 5000.0,
16 decay: 0.99,
17 online: true,
18 }
19 }
20}
21
22pub struct EwcPlusPlus {
23 config: EwcConfig,
24 fisher_diag: Option<Array1<f32>>,
25 optimal_params: Option<Array1<f32>>,
26 task_count: usize,
27}
28
29impl EwcPlusPlus {
30 pub fn new(config: EwcConfig) -> Self {
31 Self {
32 config,
33 fisher_diag: None,
34 optimal_params: None,
35 task_count: 0,
36 }
37 }
38
39 pub fn consolidate(&mut self, params: &Array1<f32>, fisher: &Array1<f32>) {
41 if self.config.online && self.fisher_diag.is_some() {
42 let current_fisher = self.fisher_diag.as_ref().unwrap();
44 self.fisher_diag =
45 Some(current_fisher * self.config.decay + fisher * (1.0 - self.config.decay));
46 } else {
47 self.fisher_diag = Some(fisher.clone());
48 }
49
50 self.optimal_params = Some(params.clone());
51 self.task_count += 1;
52 }
53
54 pub fn penalty(&self, params: &Array1<f32>) -> f32 {
56 match (&self.fisher_diag, &self.optimal_params) {
57 (Some(fisher), Some(optimal)) => {
58 let diff = params - optimal;
59 let weighted = &diff * &diff * fisher;
60 0.5 * self.config.lambda * weighted.sum()
61 }
62 _ => 0.0,
63 }
64 }
65
66 pub fn penalty_gradient(&self, params: &Array1<f32>) -> Option<Array1<f32>> {
68 match (&self.fisher_diag, &self.optimal_params) {
69 (Some(fisher), Some(optimal)) => {
70 let diff = params - optimal;
71 Some(self.config.lambda * fisher * &diff)
72 }
73 _ => None,
74 }
75 }
76
77 pub fn compute_fisher(gradients: &[Array1<f32>]) -> Array1<f32> {
79 if gradients.is_empty() {
80 return Array1::zeros(0);
81 }
82
83 let dim = gradients[0].len();
84 let mut fisher = Array1::zeros(dim);
85
86 for grad in gradients {
87 fisher = fisher + grad.mapv(|x| x * x);
88 }
89
90 fisher / gradients.len() as f32
91 }
92
93 pub fn has_prior(&self) -> bool {
94 self.fisher_diag.is_some()
95 }
96
97 pub fn task_count(&self) -> usize {
98 self.task_count
99 }
100}