1use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11
/// Configuration for the EWC++ continual-learning regularizer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcConfig {
    // Number of model parameters tracked (length of all Fisher/weight vectors).
    pub param_count: usize,
    // Maximum number of past-task snapshots retained; oldest are evicted first.
    pub max_tasks: usize,
    // Starting regularization strength (lambda).
    pub initial_lambda: f32,
    // Lower clamp applied whenever lambda is set or adapted.
    pub min_lambda: f32,
    // Upper clamp applied whenever lambda is set or adapted.
    pub max_lambda: f32,
    // EMA decay used when folding squared gradients into the Fisher estimate.
    pub fisher_ema_decay: f32,
    // Average gradient z-score above which a task boundary is reported.
    pub boundary_threshold: f32,
    // Capacity of the rolling raw-gradient history buffer.
    pub gradient_history_size: usize,
}
32
33impl Default for EwcConfig {
34 fn default() -> Self {
35 Self {
39 param_count: 1000,
40 max_tasks: 10,
41 initial_lambda: 2000.0, min_lambda: 100.0,
43 max_lambda: 15000.0, fisher_ema_decay: 0.999,
45 boundary_threshold: 2.0,
46 gradient_history_size: 100,
47 }
48 }
49}
50
/// Snapshot of one completed task: its Fisher information and the weights
/// that were recorded as optimal when the task was archived.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskFisher {
    // Identifier assigned from `current_task_id` at archive time.
    pub task_id: usize,
    // Per-parameter diagonal Fisher information for this task.
    pub fisher: Vec<f32>,
    // Parameter values snapshotted as optimal for this task.
    pub optimal_weights: Vec<f32>,
    // Scalar multiplier applied to this task's penalty terms.
    pub importance: f32,
}
63
/// EWC++ state: online Fisher estimation for the task in progress, archived
/// snapshots of past tasks, an adaptive lambda, and running gradient
/// statistics used for task-boundary detection.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcPlusPlus {
    config: EwcConfig,
    // EMA estimate of the diagonal Fisher for the task in progress.
    current_fisher: Vec<f32>,
    // Weights most recently recorded via `set_optimal_weights`.
    current_weights: Vec<f32>,
    // Archived past-task snapshots, bounded by `config.max_tasks`.
    task_memory: VecDeque<TaskFisher>,
    // Id assigned to the current task when it is archived; then incremented.
    current_task_id: usize,
    // Current regularization strength, clamped to [min_lambda, max_lambda].
    lambda: f32,
    // Rolling window of raw gradient vectors, bounded by
    // `config.gradient_history_size` (stored but not read back here).
    gradient_history: VecDeque<Vec<f32>>,
    // Welford running mean of each gradient component for the current task.
    gradient_mean: Vec<f32>,
    // Welford M2 accumulator (sum of squared deviations), seeded at 1.0 so
    // the derived variance never starts at zero.
    gradient_var: Vec<f32>,
    // Gradient samples folded in since the last task switch.
    samples_seen: u64,
}
88
89impl EwcPlusPlus {
90 pub fn new(config: EwcConfig) -> Self {
92 let param_count = config.param_count;
93 let initial_lambda = config.initial_lambda;
94
95 Self {
96 config: config.clone(),
97 current_fisher: vec![0.0; param_count],
98 current_weights: vec![0.0; param_count],
99 task_memory: VecDeque::with_capacity(config.max_tasks),
100 current_task_id: 0,
101 lambda: initial_lambda,
102 gradient_history: VecDeque::with_capacity(config.gradient_history_size),
103 gradient_mean: vec![0.0; param_count],
104 gradient_var: vec![1.0; param_count],
105 samples_seen: 0,
106 }
107 }
108
109 pub fn update_fisher(&mut self, gradients: &[f32]) {
111 if gradients.len() != self.config.param_count {
112 return;
113 }
114
115 let decay = self.config.fisher_ema_decay;
116
117 for (i, &g) in gradients.iter().enumerate() {
119 self.current_fisher[i] = decay * self.current_fisher[i] + (1.0 - decay) * g * g;
120 }
121
122 self.update_gradient_stats(gradients);
124 self.samples_seen += 1;
125 }
126
127 fn update_gradient_stats(&mut self, gradients: &[f32]) {
129 if self.gradient_history.len() >= self.config.gradient_history_size {
131 self.gradient_history.pop_front();
132 }
133 self.gradient_history.push_back(gradients.to_vec());
134
135 let n = self.samples_seen as f32 + 1.0;
137
138 for (i, &g) in gradients.iter().enumerate() {
139 let delta = g - self.gradient_mean[i];
140 self.gradient_mean[i] += delta / n;
141 let delta2 = g - self.gradient_mean[i];
142 self.gradient_var[i] += delta * delta2;
143 }
144 }
145
146 pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
148 if self.samples_seen < 50 || gradients.len() != self.config.param_count {
149 return false;
150 }
151
152 let mut z_score_sum = 0.0f32;
154 let mut count = 0;
155
156 for (i, &g) in gradients.iter().enumerate() {
157 let var = self.gradient_var[i] / self.samples_seen as f32;
158 if var > 1e-8 {
159 let std = var.sqrt();
160 let z = (g - self.gradient_mean[i]).abs() / std;
161 z_score_sum += z;
162 count += 1;
163 }
164 }
165
166 if count == 0 {
167 return false;
168 }
169
170 let avg_z = z_score_sum / count as f32;
171 avg_z > self.config.boundary_threshold
172 }
173
174 pub fn start_new_task(&mut self) {
176 let task_fisher = TaskFisher {
178 task_id: self.current_task_id,
179 fisher: self.current_fisher.clone(),
180 optimal_weights: self.current_weights.clone(),
181 importance: 1.0,
182 };
183
184 if self.task_memory.len() >= self.config.max_tasks {
186 self.task_memory.pop_front();
187 }
188 self.task_memory.push_back(task_fisher);
189
190 self.current_task_id += 1;
192 self.current_fisher.fill(0.0);
193 self.gradient_history.clear();
194 self.gradient_mean.fill(0.0);
195 self.gradient_var.fill(1.0);
196 self.samples_seen = 0;
197
198 self.adapt_lambda();
200 }
201
202 fn adapt_lambda(&mut self) {
204 let task_count = self.task_memory.len();
205 if task_count == 0 {
206 return;
207 }
208
209 let scale = 1.0 + 0.1 * task_count as f32;
211 self.lambda = (self.config.initial_lambda * scale)
212 .clamp(self.config.min_lambda, self.config.max_lambda);
213 }
214
215 pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
217 if gradients.len() != self.config.param_count {
218 return gradients.to_vec();
219 }
220
221 let mut constrained = gradients.to_vec();
222
223 for task in &self.task_memory {
225 for (i, g) in constrained.iter_mut().enumerate() {
226 let importance = task.fisher[i] * task.importance;
230 if importance > 1e-8 {
231 let penalty_grad = self.lambda * importance;
232 *g *= 1.0 / (1.0 + penalty_grad);
234 }
235 }
236 }
237
238 for (i, g) in constrained.iter_mut().enumerate() {
240 if self.current_fisher[i] > 1e-8 {
241 let penalty_grad = self.lambda * self.current_fisher[i] * 0.1; *g *= 1.0 / (1.0 + penalty_grad);
243 }
244 }
245
246 constrained
247 }
248
249 pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
251 if current_weights.len() != self.config.param_count {
252 return 0.0;
253 }
254
255 let mut loss = 0.0f32;
256
257 for task in &self.task_memory {
258 for ((&cw, &ow), &fi) in current_weights
259 .iter()
260 .zip(task.optimal_weights.iter())
261 .zip(task.fisher.iter())
262 .take(self.config.param_count)
263 {
264 let diff = cw - ow;
265 loss += fi * diff * diff * task.importance;
266 }
267 }
268
269 self.lambda * loss / 2.0
270 }
271
272 pub fn set_optimal_weights(&mut self, weights: &[f32]) {
274 if weights.len() == self.config.param_count {
275 self.current_weights.copy_from_slice(weights);
276 }
277 }
278
279 pub fn consolidate_all_tasks(&mut self) {
281 if self.task_memory.is_empty() {
282 return;
283 }
284
285 let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
287 let mut total_importance = 0.0f32;
288
289 for task in &self.task_memory {
290 for (i, &f) in task.fisher.iter().enumerate() {
291 consolidated_fisher[i] += f * task.importance;
292 }
293 total_importance += task.importance;
294 }
295
296 if total_importance > 0.0 {
297 for f in &mut consolidated_fisher {
298 *f /= total_importance;
299 }
300 }
301
302 let consolidated = TaskFisher {
304 task_id: 0,
305 fisher: consolidated_fisher,
306 optimal_weights: self.current_weights.clone(),
307 importance: total_importance,
308 };
309
310 self.task_memory.clear();
311 self.task_memory.push_back(consolidated);
312 }
313
314 pub fn lambda(&self) -> f32 {
316 self.lambda
317 }
318
319 pub fn set_lambda(&mut self, lambda: f32) {
321 self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
322 }
323
324 pub fn task_count(&self) -> usize {
326 self.task_memory.len()
327 }
328
329 pub fn current_task_id(&self) -> usize {
331 self.current_task_id
332 }
333
334 pub fn samples_seen(&self) -> u64 {
336 self.samples_seen
337 }
338
339 pub fn importance_scores(&self) -> Vec<f32> {
341 let mut scores = self.current_fisher.clone();
342
343 for task in &self.task_memory {
344 for (i, &f) in task.fisher.iter().enumerate() {
345 scores[i] += f * task.importance;
346 }
347 }
348
349 scores
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_creation() {
        let config = EwcConfig {
            param_count: 100,
            ..Default::default()
        };
        let ewc = EwcPlusPlus::new(config);

        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
    }

    #[test]
    fn test_fisher_update() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let gradients = vec![0.5; 10];
        ewc.update_fisher(&gradients);

        assert!(ewc.samples_seen() > 0);
        assert!(ewc.current_fisher.iter().any(|&f| f > 0.0));
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Build stable statistics from 60 identical gradient samples
        // (detection requires at least 50).
        for _ in 0..60 {
            let gradients = vec![0.1; 10];
            ewc.update_fisher(&gradients);
        }

        // An in-distribution gradient must not trigger a boundary...
        let normal = vec![0.1; 10];
        assert!(!ewc.detect_task_boundary(&normal));

        // ...while a far out-of-distribution one must.
        let different = vec![10.0; 10];
        assert!(ewc.detect_task_boundary(&different));
    }

    #[test]
    fn test_constraint_application() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..10 {
            ewc.update_fisher(&vec![1.0; 5]);
        }
        ewc.start_new_task();

        let gradients = vec![1.0; 5];
        let constrained = ewc.apply_constraints(&gradients);

        // Constraints only ever shrink gradient magnitude.
        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag <= orig_mag);
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.set_optimal_weights(&vec![0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&vec![1.0; 5]);
        }
        ewc.start_new_task();

        // Deviating from the archived optimum must cost more than staying.
        let at_optimal = ewc.regularization_loss(&vec![0.0; 5]);
        let deviated = ewc.regularization_loss(&vec![1.0; 5]);
        assert!(deviated > at_optimal);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&vec![1.0; 5]);
            }
            ewc.start_new_task();
        }

        assert_eq!(ewc.task_count(), 3);

        // Consolidation collapses all snapshots into a single one.
        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let initial_lambda = ewc.lambda();

        // Lambda scales with the number of remembered tasks.
        for _ in 0..5 {
            ewc.start_new_task();
        }

        assert!(ewc.lambda() >= initial_lambda);
    }
}