1use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11
/// Configuration for the EWC++ continual-learning regularizer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcConfig {
    /// Number of model parameters tracked by the Fisher estimates.
    pub param_count: usize,
    /// Maximum number of past-task snapshots kept; oldest are evicted first.
    pub max_tasks: usize,
    /// Starting regularization strength (lambda).
    pub initial_lambda: f32,
    /// Lower clamp applied when lambda is adapted or set.
    pub min_lambda: f32,
    /// Upper clamp applied when lambda is adapted or set.
    pub max_lambda: f32,
    /// EMA decay used when folding squared gradients into the Fisher estimate.
    pub fisher_ema_decay: f32,
    /// Average absolute z-score above which a task boundary is declared.
    pub boundary_threshold: f32,
    /// Number of recent gradient vectors retained in the history buffer.
    pub gradient_history_size: usize,
}
32
33impl Default for EwcConfig {
34 fn default() -> Self {
35 Self {
36 param_count: 1000,
37 max_tasks: 10,
38 initial_lambda: 1000.0,
39 min_lambda: 100.0,
40 max_lambda: 10000.0,
41 fisher_ema_decay: 0.999,
42 boundary_threshold: 2.0,
43 gradient_history_size: 100,
44 }
45 }
46}
47
/// Snapshot of a completed task: its diagonal Fisher estimate, the weights it
/// converged to, and a relative importance weighting used in the penalty.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskFisher {
    /// Identifier assigned when the task was frozen.
    pub task_id: usize,
    /// Per-parameter diagonal Fisher-information estimate for this task.
    pub fisher: Vec<f32>,
    /// Parameter values recorded for this task at freeze time.
    pub optimal_weights: Vec<f32>,
    /// Relative weight of this task in the combined penalty.
    pub importance: f32,
}
60
/// EWC++ state: a running diagonal Fisher estimate for the task currently
/// being learned, a bounded memory of frozen past tasks, and online gradient
/// statistics used for task-boundary detection.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcPlusPlus {
    config: EwcConfig,
    /// EMA of squared gradients for the in-progress task.
    current_fisher: Vec<f32>,
    /// Most recently recorded weights (see `set_optimal_weights`).
    current_weights: Vec<f32>,
    /// Snapshots of past tasks, oldest first; bounded by `config.max_tasks`.
    task_memory: VecDeque<TaskFisher>,
    /// Id of the in-progress task; assigned to its snapshot when frozen.
    current_task_id: usize,
    /// Current regularization strength.
    lambda: f32,
    /// Recent raw gradient vectors; bounded by `config.gradient_history_size`.
    gradient_history: VecDeque<Vec<f32>>,
    /// Per-parameter running gradient mean (Welford's algorithm).
    gradient_mean: Vec<f32>,
    /// Per-parameter Welford M2 accumulator (sum of squared deviations),
    /// initialized to 1.0 rather than 0.0 — presumably a smoothing prior;
    /// divided by `samples_seen` to obtain a variance estimate.
    gradient_var: Vec<f32>,
    /// Gradient samples folded in since the last task boundary.
    samples_seen: u64,
}
85
86impl EwcPlusPlus {
87 pub fn new(config: EwcConfig) -> Self {
89 let param_count = config.param_count;
90 let initial_lambda = config.initial_lambda;
91
92 Self {
93 config: config.clone(),
94 current_fisher: vec![0.0; param_count],
95 current_weights: vec![0.0; param_count],
96 task_memory: VecDeque::with_capacity(config.max_tasks),
97 current_task_id: 0,
98 lambda: initial_lambda,
99 gradient_history: VecDeque::with_capacity(config.gradient_history_size),
100 gradient_mean: vec![0.0; param_count],
101 gradient_var: vec![1.0; param_count],
102 samples_seen: 0,
103 }
104 }
105
106 pub fn update_fisher(&mut self, gradients: &[f32]) {
108 if gradients.len() != self.config.param_count {
109 return;
110 }
111
112 let decay = self.config.fisher_ema_decay;
113
114 for (i, &g) in gradients.iter().enumerate() {
116 self.current_fisher[i] = decay * self.current_fisher[i] + (1.0 - decay) * g * g;
117 }
118
119 self.update_gradient_stats(gradients);
121 self.samples_seen += 1;
122 }
123
124 fn update_gradient_stats(&mut self, gradients: &[f32]) {
126 if self.gradient_history.len() >= self.config.gradient_history_size {
128 self.gradient_history.pop_front();
129 }
130 self.gradient_history.push_back(gradients.to_vec());
131
132 let n = self.samples_seen as f32 + 1.0;
134
135 for (i, &g) in gradients.iter().enumerate() {
136 let delta = g - self.gradient_mean[i];
137 self.gradient_mean[i] += delta / n;
138 let delta2 = g - self.gradient_mean[i];
139 self.gradient_var[i] += delta * delta2;
140 }
141 }
142
143 pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
145 if self.samples_seen < 50 || gradients.len() != self.config.param_count {
146 return false;
147 }
148
149 let mut z_score_sum = 0.0f32;
151 let mut count = 0;
152
153 for (i, &g) in gradients.iter().enumerate() {
154 let var = self.gradient_var[i] / self.samples_seen as f32;
155 if var > 1e-8 {
156 let std = var.sqrt();
157 let z = (g - self.gradient_mean[i]).abs() / std;
158 z_score_sum += z;
159 count += 1;
160 }
161 }
162
163 if count == 0 {
164 return false;
165 }
166
167 let avg_z = z_score_sum / count as f32;
168 avg_z > self.config.boundary_threshold
169 }
170
171 pub fn start_new_task(&mut self) {
173 let task_fisher = TaskFisher {
175 task_id: self.current_task_id,
176 fisher: self.current_fisher.clone(),
177 optimal_weights: self.current_weights.clone(),
178 importance: 1.0,
179 };
180
181 if self.task_memory.len() >= self.config.max_tasks {
183 self.task_memory.pop_front();
184 }
185 self.task_memory.push_back(task_fisher);
186
187 self.current_task_id += 1;
189 self.current_fisher.fill(0.0);
190 self.gradient_history.clear();
191 self.gradient_mean.fill(0.0);
192 self.gradient_var.fill(1.0);
193 self.samples_seen = 0;
194
195 self.adapt_lambda();
197 }
198
199 fn adapt_lambda(&mut self) {
201 let task_count = self.task_memory.len();
202 if task_count == 0 {
203 return;
204 }
205
206 let scale = 1.0 + 0.1 * task_count as f32;
208 self.lambda = (self.config.initial_lambda * scale)
209 .clamp(self.config.min_lambda, self.config.max_lambda);
210 }
211
212 pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
214 if gradients.len() != self.config.param_count {
215 return gradients.to_vec();
216 }
217
218 let mut constrained = gradients.to_vec();
219
220 for task in &self.task_memory {
222 for (i, g) in constrained.iter_mut().enumerate() {
223 let importance = task.fisher[i] * task.importance;
227 if importance > 1e-8 {
228 let penalty_grad = self.lambda * importance;
229 *g *= 1.0 / (1.0 + penalty_grad);
231 }
232 }
233 }
234
235 for (i, g) in constrained.iter_mut().enumerate() {
237 if self.current_fisher[i] > 1e-8 {
238 let penalty_grad = self.lambda * self.current_fisher[i] * 0.1; *g *= 1.0 / (1.0 + penalty_grad);
240 }
241 }
242
243 constrained
244 }
245
246 pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
248 if current_weights.len() != self.config.param_count {
249 return 0.0;
250 }
251
252 let mut loss = 0.0f32;
253
254 for task in &self.task_memory {
255 for i in 0..self.config.param_count {
256 let diff = current_weights[i] - task.optimal_weights[i];
257 loss += task.fisher[i] * diff * diff * task.importance;
258 }
259 }
260
261 self.lambda * loss / 2.0
262 }
263
264 pub fn set_optimal_weights(&mut self, weights: &[f32]) {
266 if weights.len() == self.config.param_count {
267 self.current_weights.copy_from_slice(weights);
268 }
269 }
270
271 pub fn consolidate_all_tasks(&mut self) {
273 if self.task_memory.is_empty() {
274 return;
275 }
276
277 let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
279 let mut total_importance = 0.0f32;
280
281 for task in &self.task_memory {
282 for (i, &f) in task.fisher.iter().enumerate() {
283 consolidated_fisher[i] += f * task.importance;
284 }
285 total_importance += task.importance;
286 }
287
288 if total_importance > 0.0 {
289 for f in &mut consolidated_fisher {
290 *f /= total_importance;
291 }
292 }
293
294 let consolidated = TaskFisher {
296 task_id: 0,
297 fisher: consolidated_fisher,
298 optimal_weights: self.current_weights.clone(),
299 importance: total_importance,
300 };
301
302 self.task_memory.clear();
303 self.task_memory.push_back(consolidated);
304 }
305
306 pub fn lambda(&self) -> f32 {
308 self.lambda
309 }
310
311 pub fn set_lambda(&mut self, lambda: f32) {
313 self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
314 }
315
316 pub fn task_count(&self) -> usize {
318 self.task_memory.len()
319 }
320
321 pub fn current_task_id(&self) -> usize {
323 self.current_task_id
324 }
325
326 pub fn samples_seen(&self) -> u64 {
328 self.samples_seen
329 }
330
331 pub fn importance_scores(&self) -> Vec<f32> {
333 let mut scores = self.current_fisher.clone();
334
335 for task in &self.task_memory {
336 for (i, &f) in task.fisher.iter().enumerate() {
337 scores[i] += f * task.importance;
338 }
339 }
340
341 scores
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_creation() {
        let config = EwcConfig {
            param_count: 100,
            ..Default::default()
        };
        let ewc = EwcPlusPlus::new(config);

        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
    }

    #[test]
    fn test_fisher_update() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.update_fisher(&[0.5; 10]);

        assert!(ewc.samples_seen() > 0);
        assert!(ewc.current_fisher.iter().any(|&f| f > 0.0));
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Establish a stable gradient distribution (past the 50-sample
        // warmup in detect_task_boundary).
        for _ in 0..60 {
            ewc.update_fisher(&[0.1; 10]);
        }

        // In-distribution gradients must not trigger a boundary.
        let normal = vec![0.1; 10];
        assert!(!ewc.detect_task_boundary(&normal));

        // Far-out-of-distribution gradients must trigger one.
        let different = vec![10.0; 10];
        assert!(ewc.detect_task_boundary(&different));
    }

    #[test]
    fn test_constraint_application() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        let gradients = vec![1.0; 5];
        let constrained = ewc.apply_constraints(&gradients);

        // Constraints may only shrink gradient magnitude, never grow it.
        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag <= orig_mag);
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.set_optimal_weights(&[0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        // Penalty grows as weights move away from the task's optimum.
        let at_optimal = ewc.regularization_loss(&[0.0; 5]);
        let deviated = ewc.regularization_loss(&[1.0; 5]);
        assert!(deviated > at_optimal);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&[1.0; 5]);
            }
            ewc.start_new_task();
        }

        assert_eq!(ewc.task_count(), 3);

        // Consolidation collapses all snapshots into a single entry.
        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let initial_lambda = ewc.lambda();

        // Each new task scales lambda up (clamped), so it never decreases.
        for _ in 0..5 {
            ewc.start_new_task();
        }

        assert!(ewc.lambda() >= initial_lambda);
    }
}