1use anyhow::{anyhow, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use trustformers_core::tensor::Tensor;
19
/// Configuration for Elastic Weight Consolidation (EWC).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EWCConfig {
    /// Step size applied to the combined (task + EWC) gradient.
    pub learning_rate: f32,
    /// Strength of the EWC quadratic penalty pulling parameters
    /// toward the anchor values of previous tasks.
    pub lambda: f32,
    /// How the Fisher information should be estimated.
    /// NOTE(review): this field is not consulted by `compute_fisher_information`
    /// in this file — confirm where it is honored.
    pub fisher_method: FisherMethod,
    /// Number of samples intended for Fisher estimation.
    /// NOTE(review): also unused in this file; the sample count is taken from
    /// the slice passed to `compute_fisher_information`.
    pub fisher_samples: usize,
    /// Use online EWC (a single decayed, accumulated importance estimate)
    /// instead of the per-task importance weights.
    pub online: bool,
    /// Decay applied to the accumulated importance in online mode.
    pub decay_factor: f32,
}
36
37impl Default for EWCConfig {
38 fn default() -> Self {
39 Self {
40 learning_rate: 1e-3,
41 lambda: 1000.0,
42 fisher_method: FisherMethod::Empirical,
43 fisher_samples: 1000,
44 online: false,
45 decay_factor: 0.9,
46 }
47 }
48}
49
/// Strategy for estimating the Fisher information used as parameter importance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FisherMethod {
    /// Empirical Fisher: average of squared gradients of the training loss.
    /// This is the accumulation actually performed in `EWC::compute_fisher_information`.
    Empirical,
    /// "True" Fisher, sampled from the model's predictive distribution.
    /// NOTE(review): no code path in this file implements this — confirm usage.
    True,
    /// Diagonal approximation of the Fisher matrix.
    /// NOTE(review): not referenced in this file — confirm usage.
    Diagonal,
}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct PackNetConfig {
64 pub learning_rate: f32,
66 pub sparsity_level: f32,
68 pub num_tasks: usize,
70 pub allocation_strategy: AllocationStrategy,
72}
73
74impl Default for PackNetConfig {
75 fn default() -> Self {
76 Self {
77 learning_rate: 1e-3,
78 sparsity_level: 0.5,
79 num_tasks: 10,
80 allocation_strategy: AllocationStrategy::Sequential,
81 }
82 }
83}
84
/// How PackNet picks which parameter slots belong to a new task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AllocationStrategy {
    /// Consecutive index ranges, advancing with each allocated task.
    Sequential,
    /// Uniformly random slots (without replacement).
    Random,
    /// Slots ranked by importance.
    /// NOTE(review): the current implementation falls back to taking the first
    /// N slots — no importance scores are computed in this file.
    ImportanceBased,
}
95
/// Configuration for the L2-toward-anchor regularization baseline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct L2RegularizationConfig {
    /// Step size applied to the combined (task + regularization) gradient.
    pub learning_rate: f32,
    /// Coefficient of the pull toward the anchor parameters.
    pub reg_strength: f32,
    /// When/how the anchor parameters are refreshed.
    pub update_strategy: UpdateStrategy,
}
106
107impl Default for L2RegularizationConfig {
108 fn default() -> Self {
109 Self {
110 learning_rate: 1e-3,
111 reg_strength: 0.1,
112 update_strategy: UpdateStrategy::EMA,
113 }
114 }
115}
116
/// Policy for refreshing the anchor parameters of `L2Regularization`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UpdateStrategy {
    /// Anchors never change after construction.
    Fixed,
    /// Anchors follow the live parameters via an exponential moving average.
    EMA,
    /// Anchors are snapshotted only when `finish_task` is called.
    TaskBoundary,
}
127
/// Configuration for the gradient memory-replay optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryReplayConfig {
    /// Step size for fresh gradients (replayed gradients use half of this).
    pub learning_rate: f32,
    /// Maximum number of gradient samples kept in the buffer.
    pub memory_size: usize,
    /// A replay pass runs every this-many optimizer steps.
    pub replay_frequency: usize,
    /// Number of stored samples re-applied per replay pass.
    pub replay_batch_size: usize,
    /// Which stored sample is evicted when the buffer is full.
    pub selection_strategy: MemorySelectionStrategy,
}
142
143impl Default for MemoryReplayConfig {
144 fn default() -> Self {
145 Self {
146 learning_rate: 1e-3,
147 memory_size: 1000,
148 replay_frequency: 10,
149 replay_batch_size: 32,
150 selection_strategy: MemorySelectionStrategy::Random,
151 }
152 }
153}
154
/// Eviction policy for a full replay buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemorySelectionStrategy {
    /// Evict a uniformly random stored sample.
    Random,
    /// Evict based on gradient statistics.
    /// NOTE(review): not implemented in this file — falls back to FIFO eviction.
    GradientBased,
    /// Evict based on model uncertainty.
    /// NOTE(review): not implemented in this file — falls back to FIFO eviction.
    UncertaintyBased,
}
165
/// Elastic Weight Consolidation optimizer state.
pub struct EWC {
    /// Penalty strength, Fisher settings, and online-mode flags.
    config: EWCConfig,
    /// Current trainable parameters.
    parameters: Vec<Tensor>,
    /// Per-parameter importance from the most recent Fisher estimate.
    importance_weights: Vec<Tensor>,
    /// Parameter snapshot that the EWC penalty pulls toward.
    anchor_parameters: Vec<Tensor>,
    /// Index of the task currently being trained.
    current_task: usize,
    /// Decayed running importance, used only when `config.online` is true.
    accumulated_importance: Vec<Tensor>,
}
175
176impl EWC {
177 pub fn new(config: EWCConfig, initial_parameters: Vec<Tensor>) -> Result<Self> {
179 let param_count = initial_parameters.len();
180
181 Ok(Self {
182 config,
183 parameters: initial_parameters.clone(),
184 importance_weights: (0..param_count)
185 .map(|i| Tensor::zeros(&initial_parameters[i].shape()).unwrap())
186 .collect(),
187 anchor_parameters: initial_parameters.clone(),
188 current_task: 0,
189 accumulated_importance: (0..param_count)
190 .map(|i| Tensor::zeros(&initial_parameters[i].shape()).unwrap())
191 .collect(),
192 })
193 }
194
195 pub fn compute_fisher_information(&mut self, gradients_samples: &[Vec<Tensor>]) -> Result<()> {
197 let num_samples = gradients_samples.len();
198 if num_samples == 0 {
199 return Err(anyhow!("No gradient samples provided"));
200 }
201
202 for importance in self.importance_weights.iter_mut() {
204 *importance = Tensor::zeros(&importance.shape())?;
205 }
206
207 for gradient_sample in gradients_samples {
209 for (i, gradient) in gradient_sample.iter().enumerate() {
210 if i < self.importance_weights.len() {
211 let squared_grad = gradient.mul(gradient)?;
212 self.importance_weights[i] = self.importance_weights[i].add(&squared_grad)?;
213 }
214 }
215 }
216
217 for importance in self.importance_weights.iter_mut() {
219 *importance = importance.div_scalar(num_samples as f32)?;
220 }
221
222 if self.config.online {
224 for i in 0..self.accumulated_importance.len() {
225 let decayed =
226 self.accumulated_importance[i].mul_scalar(self.config.decay_factor)?;
227 self.accumulated_importance[i] = decayed.add(&self.importance_weights[i])?;
228 }
229 }
230
231 Ok(())
232 }
233
234 pub fn finish_task(&mut self) -> Result<()> {
236 self.anchor_parameters = self.parameters.clone();
238 self.current_task += 1;
239 Ok(())
240 }
241
242 pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
244 for (i, gradient) in gradients.iter().enumerate() {
245 if i < self.parameters.len() {
246 let param_diff = self.parameters[i].sub(&self.anchor_parameters[i])?;
248 let importance = if self.config.online {
249 &self.accumulated_importance[i]
250 } else {
251 &self.importance_weights[i]
252 };
253 let ewc_grad = param_diff.mul(importance)?.mul_scalar(self.config.lambda)?;
254
255 let total_grad = gradient.add(&ewc_grad)?;
257
258 let update = total_grad.mul_scalar(self.config.learning_rate)?;
260 self.parameters[i] = self.parameters[i].sub(&update)?;
261 }
262 }
263 Ok(())
264 }
265
266 pub fn get_parameters(&self) -> &[Tensor] {
268 &self.parameters
269 }
270
271 pub fn get_importance_weights(&self) -> &[Tensor] {
273 &self.importance_weights
274 }
275}
276
/// PackNet-style optimizer state: each task trains a disjoint subset of
/// parameters selected by a binary mask.
pub struct PackNet {
    /// Sparsity level and allocation strategy settings.
    config: PackNetConfig,
    /// Current trainable parameters (shared across all tasks).
    parameters: Vec<Tensor>,
    /// Global masks per parameter tensor; currently initialized to ones and
    /// otherwise unused (hence the dead_code allowance).
    #[allow(dead_code)]
    parameter_masks: Vec<Tensor>,
    /// Per-task binary masks: task id -> one mask per parameter tensor.
    task_allocations: HashMap<usize, Vec<Tensor>>,
    /// Task whose mask is applied by `step`.
    current_task: usize,
    /// Remaining unallocated fraction (0..1) of each parameter tensor.
    available_capacity: Vec<f32>,
}
287
288impl PackNet {
289 pub fn new(config: PackNetConfig, initial_parameters: Vec<Tensor>) -> Result<Self> {
291 let param_count = initial_parameters.len();
292
293 Ok(Self {
294 config,
295 parameters: initial_parameters.clone(),
296 parameter_masks: (0..param_count)
297 .map(|i| Tensor::ones(&initial_parameters[i].shape()).unwrap())
298 .collect(),
299 task_allocations: HashMap::new(),
300 current_task: 0,
301 available_capacity: vec![1.0; param_count],
302 })
303 }
304
305 pub fn allocate_task(&mut self, task_id: usize) -> Result<()> {
307 if self.available_capacity.iter().any(|&cap| cap < self.config.sparsity_level) {
308 return Err(anyhow!("Insufficient parameter capacity for new task"));
309 }
310
311 let mut task_masks = Vec::new();
312
313 for (i, param) in self.parameters.iter().enumerate() {
314 let shape = param.shape();
315 let total_params = shape.iter().product::<usize>();
316 let allocated_params = (total_params as f32 * self.config.sparsity_level) as usize;
317
318 let mut mask_data = vec![0.0; total_params];
320
321 match self.config.allocation_strategy {
322 AllocationStrategy::Sequential => {
323 let start_idx =
324 ((1.0 - self.available_capacity[i]) * total_params as f32) as usize;
325 let end_idx = (start_idx + allocated_params).min(total_params);
326 for idx in start_idx..end_idx {
327 mask_data[idx] = 1.0;
328 }
329 },
330 AllocationStrategy::Random => {
331 use scirs2_core::random::*; let mut indices: Vec<usize> = (0..total_params).collect();
333 let mut rng = thread_rng();
334 indices.shuffle(rng.rng_mut());
335 for &idx in indices.iter().take(allocated_params) {
336 mask_data[idx] = 1.0;
337 }
338 },
339 AllocationStrategy::ImportanceBased => {
340 for idx in 0..allocated_params.min(total_params) {
343 mask_data[idx] = 1.0;
344 }
345 },
346 }
347
348 let task_mask = Tensor::new(mask_data)?;
349 task_masks.push(task_mask);
350
351 self.available_capacity[i] -= self.config.sparsity_level;
353 }
354
355 self.task_allocations.insert(task_id, task_masks);
356 self.current_task = task_id;
357 Ok(())
358 }
359
360 pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
362 let task_masks = self
363 .task_allocations
364 .get(&self.current_task)
365 .ok_or_else(|| anyhow!("No allocation for current task"))?;
366
367 for (i, gradient) in gradients.iter().enumerate() {
368 if i < self.parameters.len() && i < task_masks.len() {
369 let masked_grad = gradient.mul(&task_masks[i])?;
371
372 let update = masked_grad.mul_scalar(self.config.learning_rate)?;
374 self.parameters[i] = self.parameters[i].sub(&update)?;
375 }
376 }
377 Ok(())
378 }
379
380 pub fn get_parameters(&self) -> &[Tensor] {
382 &self.parameters
383 }
384
385 pub fn get_available_capacity(&self) -> &[f32] {
387 &self.available_capacity
388 }
389}
390
/// L2-toward-anchor regularization optimizer state.
pub struct L2Regularization {
    /// Regularization strength and anchor-update policy.
    config: L2RegularizationConfig,
    /// Current trainable parameters.
    parameters: Vec<Tensor>,
    /// Reference parameters the penalty pulls toward.
    anchor_parameters: Vec<Tensor>,
    /// EMA coefficient for anchor tracking (fixed at 0.999 in `new`).
    ema_decay: f32,
}
398
399impl L2Regularization {
400 pub fn new(config: L2RegularizationConfig, initial_parameters: Vec<Tensor>) -> Self {
402 Self {
403 config,
404 parameters: initial_parameters.clone(),
405 anchor_parameters: initial_parameters,
406 ema_decay: 0.999,
407 }
408 }
409
410 pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
412 for (i, gradient) in gradients.iter().enumerate() {
413 if i < self.parameters.len() {
414 let param_diff = self.parameters[i].sub(&self.anchor_parameters[i])?;
416 let reg_grad = param_diff.mul_scalar(self.config.reg_strength)?;
417
418 let total_grad = gradient.add(®_grad)?;
420
421 let update = total_grad.mul_scalar(self.config.learning_rate)?;
423 self.parameters[i] = self.parameters[i].sub(&update)?;
424
425 match self.config.update_strategy {
427 UpdateStrategy::Fixed => {
428 },
430 UpdateStrategy::EMA => {
431 let anchor_update = self.parameters[i].mul_scalar(1.0 - self.ema_decay)?;
433 let anchor_keep = self.anchor_parameters[i].mul_scalar(self.ema_decay)?;
434 self.anchor_parameters[i] = anchor_update.add(&anchor_keep)?;
435 },
436 UpdateStrategy::TaskBoundary => {
437 },
439 }
440 }
441 }
442 Ok(())
443 }
444
445 pub fn finish_task(&mut self) -> Result<()> {
447 if matches!(self.config.update_strategy, UpdateStrategy::TaskBoundary) {
448 self.anchor_parameters = self.parameters.clone();
449 }
450 Ok(())
451 }
452
453 pub fn get_parameters(&self) -> &[Tensor] {
455 &self.parameters
456 }
457}
458
/// Gradient memory-replay optimizer state.
pub struct MemoryReplay {
    /// Buffer size, replay schedule, and eviction policy.
    config: MemoryReplayConfig,
    /// Current trainable parameters.
    parameters: Vec<Tensor>,
    /// Bounded buffer of past gradient samples (one `Vec<Tensor>` per step).
    memory_buffer: Vec<Vec<Tensor>>,
    /// Number of optimizer steps performed so far.
    step_count: usize,
}
466
467impl MemoryReplay {
468 pub fn new(config: MemoryReplayConfig, initial_parameters: Vec<Tensor>) -> Self {
470 Self {
471 config,
472 parameters: initial_parameters,
473 memory_buffer: Vec::new(),
474 step_count: 0,
475 }
476 }
477
478 pub fn store_gradient(&mut self, gradients: &[Tensor]) -> Result<()> {
480 if self.memory_buffer.len() >= self.config.memory_size {
481 match self.config.selection_strategy {
483 MemorySelectionStrategy::Random => {
484 use scirs2_core::random::*; let idx = thread_rng().gen_range(0..self.memory_buffer.len());
486 self.memory_buffer.remove(idx);
487 },
488 _ => {
489 self.memory_buffer.remove(0); },
491 }
492 }
493
494 self.memory_buffer.push(gradients.to_vec());
495 Ok(())
496 }
497
498 pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
500 for (i, gradient) in gradients.iter().enumerate() {
502 if i < self.parameters.len() {
503 let update = gradient.mul_scalar(self.config.learning_rate)?;
504 self.parameters[i] = self.parameters[i].sub(&update)?;
505 }
506 }
507
508 self.store_gradient(gradients)?;
510
511 if self.step_count % self.config.replay_frequency == 0 && !self.memory_buffer.is_empty() {
513 self.replay_step()?;
514 }
515
516 self.step_count += 1;
517 Ok(())
518 }
519
520 fn replay_step(&mut self) -> Result<()> {
521 let batch_size = self.config.replay_batch_size.min(self.memory_buffer.len());
522
523 use scirs2_core::random::*; let mut indices: Vec<usize> = (0..self.memory_buffer.len()).collect();
526 let mut rng = thread_rng();
527 indices.shuffle(rng.rng_mut());
528
529 for &idx in indices.iter().take(batch_size) {
530 let replay_gradients = &self.memory_buffer[idx];
531
532 let replay_lr = self.config.learning_rate * 0.5;
534 for (i, gradient) in replay_gradients.iter().enumerate() {
535 if i < self.parameters.len() {
536 let update = gradient.mul_scalar(replay_lr)?;
537 self.parameters[i] = self.parameters[i].sub(&update)?;
538 }
539 }
540 }
541
542 Ok(())
543 }
544
545 pub fn get_parameters(&self) -> &[Tensor] {
547 &self.parameters
548 }
549
550 pub fn memory_size(&self) -> usize {
552 self.memory_buffer.len()
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Default EWC settings: lr 1e-3, lambda 1000, offline mode.
    #[test]
    fn test_ewc_config() {
        let cfg = EWCConfig::default();
        assert!(!cfg.online);
        assert_eq!(cfg.lambda, 1000.0);
        assert_eq!(cfg.learning_rate, 1e-3);
    }

    /// Default PackNet settings: half-capacity per task, 10 tasks.
    #[test]
    fn test_packnet_config() {
        let cfg = PackNetConfig::default();
        assert_eq!(cfg.num_tasks, 10);
        assert_eq!(cfg.sparsity_level, 0.5);
    }

    /// Default L2 settings: strength 0.1 with EMA-tracked anchors.
    #[test]
    fn test_l2_regularization_config() {
        let cfg = L2RegularizationConfig::default();
        assert!(matches!(cfg.update_strategy, UpdateStrategy::EMA));
        assert_eq!(cfg.reg_strength, 0.1);
    }

    /// Default replay settings: 1000-sample buffer, replay every 10 steps,
    /// random eviction.
    #[test]
    fn test_memory_replay_config() {
        let cfg = MemoryReplayConfig::default();
        assert!(matches!(cfg.selection_strategy, MemorySelectionStrategy::Random));
        assert_eq!(cfg.replay_frequency, 10);
        assert_eq!(cfg.memory_size, 1000);
    }

    /// Each Fisher-method variant matches itself.
    #[test]
    fn test_fisher_methods() {
        for method in [FisherMethod::Empirical, FisherMethod::True, FisherMethod::Diagonal] {
            match method {
                FisherMethod::Empirical => assert!(matches!(method, FisherMethod::Empirical)),
                FisherMethod::True => assert!(matches!(method, FisherMethod::True)),
                FisherMethod::Diagonal => assert!(matches!(method, FisherMethod::Diagonal)),
            }
        }
    }
}