use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

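/// Configuration for the EWC++ regularizer: parameter count, task memory size,
/// bounds on the regularization strength (`lambda`), the Fisher EMA decay, and
/// the gradient-statistics settings used for task-boundary detection.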
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcConfig {
    /// Number of model parameters tracked by the regularizer.
    pub param_count: usize,
    /// Maximum number of task snapshots kept in memory.
    pub max_tasks: usize,
    /// Starting regularization strength.
    pub initial_lambda: f32,
    /// Lower bound applied when clamping `lambda`.
    pub min_lambda: f32,
    /// Upper bound applied when clamping `lambda`.
    pub max_lambda: f32,
    /// EMA decay used when accumulating the squared-gradient Fisher estimate.
    pub fisher_ema_decay: f32,
    /// Average gradient z-score above which a task boundary is reported.
    pub boundary_threshold: f32,
    /// Number of recent gradient vectors retained for statistics.
    pub gradient_history_size: usize,
}

impl Default for EwcConfig {
    fn default() -> Self {
        Self {
            param_count: 1000,
            max_tasks: 10,
            initial_lambda: 2000.0,
            min_lambda: 100.0,
            max_lambda: 15000.0,
            fisher_ema_decay: 0.999,
            boundary_threshold: 2.0,
            gradient_history_size: 100,
        }
    }
}

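/// Snapshot of a completed task: its diagonal Fisher estimate, the weights
/// that were optimal for that task, and a relative importance used when
/// weighting the quadratic penalty.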
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskFisher {
    pub task_id: usize,
    pub fisher: Vec<f32>,
    pub optimal_weights: Vec<f32>,
    pub importance: f32,
}

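/// Online EWC++ regularizer.
///
/// Tracks an EMA estimate of the diagonal Fisher information from squared
/// gradients, detects task boundaries from running gradient statistics, and
/// keeps a bounded memory of per-task Fisher/weight snapshots that are used to
/// scale incoming gradients and to compute a quadratic regularization loss.
///
/// Illustrative usage sketch (values are arbitrary):
///
/// ```ignore
/// let mut ewc = EwcPlusPlus::new(EwcConfig { param_count: 8, ..Default::default() });
/// let grads = vec![0.1_f32; 8];
/// if ewc.detect_task_boundary(&grads) {
///     ewc.start_new_task();
/// }
/// ewc.update_fisher(&grads);
/// let constrained = ewc.apply_constraints(&grads);
/// ```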
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcPlusPlus {
    config: EwcConfig,
    /// EMA estimate of the diagonal Fisher information for the current task.
    current_fisher: Vec<f32>,
    /// Reference weights snapshotted as the task optimum on task boundaries.
    current_weights: Vec<f32>,
    /// Bounded memory of completed-task snapshots.
    task_memory: VecDeque<TaskFisher>,
    current_task_id: usize,
    /// Current regularization strength.
    lambda: f32,
    /// Recent raw gradient vectors (bounded by `gradient_history_size`).
    gradient_history: VecDeque<Vec<f32>>,
    /// Running mean of gradients, used for boundary detection.
    gradient_mean: Vec<f32>,
    /// Running variance accumulator (initialized to 1.0), used for boundary detection.
    gradient_var: Vec<f32>,
    /// Number of gradient samples observed for the current task.
    samples_seen: u64,
}

impl EwcPlusPlus {
    pub fn new(config: EwcConfig) -> Self {
        let param_count = config.param_count;
        let initial_lambda = config.initial_lambda;

        Self {
            config: config.clone(),
            current_fisher: vec![0.0; param_count],
            current_weights: vec![0.0; param_count],
            task_memory: VecDeque::with_capacity(config.max_tasks),
            current_task_id: 0,
            lambda: initial_lambda,
            gradient_history: VecDeque::with_capacity(config.gradient_history_size),
            gradient_mean: vec![0.0; param_count],
            gradient_var: vec![1.0; param_count],
            samples_seen: 0,
        }
    }

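    /// Updates the diagonal Fisher estimate with one gradient sample using an
    /// exponential moving average of the squared gradient:
    /// `F_i <- decay * F_i + (1 - decay) * g_i^2`.
    ///
    /// Gradients whose length does not match `param_count` are ignored.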
    pub fn update_fisher(&mut self, gradients: &[f32]) {
        if gradients.len() != self.config.param_count {
            return;
        }

        let decay = self.config.fisher_ema_decay;

        for (i, &g) in gradients.iter().enumerate() {
            self.current_fisher[i] = decay * self.current_fisher[i] + (1.0 - decay) * g * g;
        }

        self.update_gradient_stats(gradients);
        self.samples_seen += 1;
    }

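    /// Maintains a bounded history of recent gradients and Welford-style
    /// running mean/variance statistics used by `detect_task_boundary`.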
    fn update_gradient_stats(&mut self, gradients: &[f32]) {
        if self.gradient_history.len() >= self.config.gradient_history_size {
            self.gradient_history.pop_front();
        }
        self.gradient_history.push_back(gradients.to_vec());

        let n = self.samples_seen as f32 + 1.0;

        for (i, &g) in gradients.iter().enumerate() {
            let delta = g - self.gradient_mean[i];
            self.gradient_mean[i] += delta / n;
            let delta2 = g - self.gradient_mean[i];
            self.gradient_var[i] += delta * delta2;
        }
    }

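    /// Flags a task boundary when the average per-parameter z-score of the new
    /// gradient, relative to the running mean/variance, exceeds
    /// `boundary_threshold`. Requires at least 50 warm-up samples.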
    pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
        if self.samples_seen < 50 || gradients.len() != self.config.param_count {
            return false;
        }

        let mut z_score_sum = 0.0f32;
        let mut count = 0;

        for (i, &g) in gradients.iter().enumerate() {
            let var = self.gradient_var[i] / self.samples_seen as f32;
            if var > 1e-8 {
                let std = var.sqrt();
                let z = (g - self.gradient_mean[i]).abs() / std;
                z_score_sum += z;
                count += 1;
            }
        }

        if count == 0 {
            return false;
        }

        let avg_z = z_score_sum / count as f32;
        avg_z > self.config.boundary_threshold
    }

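    /// Snapshots the current Fisher estimate and reference weights as a
    /// completed task (evicting the oldest snapshot once `max_tasks` is
    /// reached), resets the per-task statistics, and re-adapts `lambda`.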
    pub fn start_new_task(&mut self) {
        let task_fisher = TaskFisher {
            task_id: self.current_task_id,
            fisher: self.current_fisher.clone(),
            optimal_weights: self.current_weights.clone(),
            importance: 1.0,
        };

        if self.task_memory.len() >= self.config.max_tasks {
            self.task_memory.pop_front();
        }
        self.task_memory.push_back(task_fisher);

        self.current_task_id += 1;
        self.current_fisher.fill(0.0);
        self.gradient_history.clear();
        self.gradient_mean.fill(0.0);
        self.gradient_var.fill(1.0);
        self.samples_seen = 0;

        self.adapt_lambda();
    }

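    /// Scales `lambda` with the number of remembered tasks
    /// (`initial_lambda * (1 + 0.1 * task_count)`), clamped to
    /// `[min_lambda, max_lambda]`.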
    fn adapt_lambda(&mut self) {
        let task_count = self.task_memory.len();
        if task_count == 0 {
            return;
        }

        let scale = 1.0 + 0.1 * task_count as f32;
        self.lambda = (self.config.initial_lambda * scale)
            .clamp(self.config.min_lambda, self.config.max_lambda);
    }

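    /// Scales each gradient component down in proportion to how important the
    /// corresponding parameter was for remembered tasks (and, with a 0.1
    /// factor, for the current task). Returns the input unchanged if its
    /// length does not match `param_count`.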
    pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
        if gradients.len() != self.config.param_count {
            return gradients.to_vec();
        }

        let mut constrained = gradients.to_vec();

        for task in &self.task_memory {
            for (i, g) in constrained.iter_mut().enumerate() {
                let importance = task.fisher[i] * task.importance;
                if importance > 1e-8 {
                    let penalty_grad = self.lambda * importance;
                    *g *= 1.0 / (1.0 + penalty_grad);
                }
            }
        }

        for (i, g) in constrained.iter_mut().enumerate() {
            if self.current_fisher[i] > 1e-8 {
                let penalty_grad = self.lambda * self.current_fisher[i] * 0.1;
                *g *= 1.0 / (1.0 + penalty_grad);
            }
        }

        constrained
    }

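    /// Computes the EWC quadratic penalty over all remembered tasks:
    /// `lambda / 2 * sum_t sum_i importance_t * F_{t,i} * (w_i - w_opt_{t,i})^2`.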
    pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
        if current_weights.len() != self.config.param_count {
            return 0.0;
        }

        let mut loss = 0.0f32;

        for task in &self.task_memory {
            for i in 0..self.config.param_count {
                let diff = current_weights[i] - task.optimal_weights[i];
                loss += task.fisher[i] * diff * diff * task.importance;
            }
        }

        self.lambda * loss / 2.0
    }

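    /// Records the given weights as the reference point that will be
    /// snapshotted as the task optimum on the next `start_new_task` call.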
    pub fn set_optimal_weights(&mut self, weights: &[f32]) {
        if weights.len() == self.config.param_count {
            self.current_weights.copy_from_slice(weights);
        }
    }

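    /// Merges all remembered tasks into a single snapshot whose Fisher is the
    /// importance-weighted average of the individual task Fishers; the merged
    /// snapshot carries the summed importance.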
    pub fn consolidate_all_tasks(&mut self) {
        if self.task_memory.is_empty() {
            return;
        }

        let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
        let mut total_importance = 0.0f32;

        for task in &self.task_memory {
            for (i, &f) in task.fisher.iter().enumerate() {
                consolidated_fisher[i] += f * task.importance;
            }
            total_importance += task.importance;
        }

        if total_importance > 0.0 {
            for f in &mut consolidated_fisher {
                *f /= total_importance;
            }
        }

        let consolidated = TaskFisher {
            task_id: 0,
            fisher: consolidated_fisher,
            optimal_weights: self.current_weights.clone(),
            importance: total_importance,
        };

        self.task_memory.clear();
        self.task_memory.push_back(consolidated);
    }

    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    pub fn set_lambda(&mut self, lambda: f32) {
        self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
    }

    pub fn task_count(&self) -> usize {
        self.task_memory.len()
    }

    pub fn current_task_id(&self) -> usize {
        self.current_task_id
    }

    pub fn samples_seen(&self) -> u64 {
        self.samples_seen
    }

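    /// Returns per-parameter importance scores: the current Fisher estimate
    /// plus the importance-weighted Fisher of every remembered task.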
    pub fn importance_scores(&self) -> Vec<f32> {
        let mut scores = self.current_fisher.clone();

        for task in &self.task_memory {
            for (i, &f) in task.fisher.iter().enumerate() {
                scores[i] += f * task.importance;
            }
        }

        scores
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_creation() {
        let config = EwcConfig {
            param_count: 100,
            ..Default::default()
        };
        let ewc = EwcPlusPlus::new(config);

        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
    }

    #[test]
    fn test_fisher_update() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let gradients = vec![0.5; 10];
        ewc.update_fisher(&gradients);

        assert!(ewc.samples_seen() > 0);
        assert!(ewc.current_fisher.iter().any(|&f| f > 0.0));
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..60 {
            let gradients = vec![0.1; 10];
            ewc.update_fisher(&gradients);
        }

        let normal = vec![0.1; 10];
        assert!(!ewc.detect_task_boundary(&normal));

        // A markedly different gradient distribution should be flagged as a boundary.
        let different = vec![10.0; 10];
        assert!(ewc.detect_task_boundary(&different));
    }

    #[test]
    fn test_constraint_application() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..10 {
            ewc.update_fisher(&vec![1.0; 5]);
        }
        ewc.start_new_task();

        let gradients = vec![1.0; 5];
        let constrained = ewc.apply_constraints(&gradients);

        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag <= orig_mag);
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.set_optimal_weights(&vec![0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&vec![1.0; 5]);
        }
        ewc.start_new_task();

        let at_optimal = ewc.regularization_loss(&vec![0.0; 5]);

        let deviated = ewc.regularization_loss(&vec![1.0; 5]);
        assert!(deviated > at_optimal);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&vec![1.0; 5]);
            }
            ewc.start_new_task();
        }

        assert_eq!(ewc.task_count(), 3);

        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let initial_lambda = ewc.lambda();

        for _ in 0..5 {
            ewc.start_new_task();
        }

        assert!(ewc.lambda() >= initial_lambda);
    }
}