1use crate::{TrainError, TrainResult};
50use scirs2_core::ndarray::{Array1, Array2};
51use std::collections::HashMap;
52
/// Configuration for the MAML (Model-Agnostic Meta-Learning) algorithm.
#[derive(Debug, Clone)]
pub struct MAMLConfig {
    /// Number of inner-loop (task-adaptation) gradient steps.
    pub inner_steps: usize,
    /// Learning rate for the inner (task-adaptation) updates.
    pub inner_lr: f64,
    /// Learning rate for the outer (meta) parameter update.
    pub outer_lr: f64,
    /// If true, use the first-order MAML approximation (ignore
    /// second-order gradient terms through the inner loop).
    pub first_order: bool,
}
68
69impl Default for MAMLConfig {
70 fn default() -> Self {
71 Self {
72 inner_steps: 5,
73 inner_lr: 0.01,
74 outer_lr: 0.001,
75 first_order: false,
76 }
77 }
78}
79
/// Configuration for the Reptile meta-learning algorithm.
#[derive(Debug, Clone)]
pub struct ReptileConfig {
    /// Number of inner-loop (task-adaptation) gradient steps.
    pub inner_steps: usize,
    /// Learning rate for the inner (task-adaptation) updates.
    pub inner_lr: f64,
    /// Step size for moving meta-parameters toward task-adapted parameters.
    pub outer_lr: f64,
}
95
96impl Default for ReptileConfig {
97 fn default() -> Self {
98 Self {
99 inner_steps: 10,
100 inner_lr: 0.01,
101 outer_lr: 0.1,
102 }
103 }
104}
105
/// A single meta-learning task: a support set for adaptation and a
/// query set for evaluating the adapted parameters.
#[derive(Debug, Clone)]
pub struct MetaTask {
    /// Support-set inputs, shape (n_support, n_features).
    pub support_x: Array2<f64>,
    /// Support-set targets, row-aligned with `support_x`.
    pub support_y: Array2<f64>,
    /// Query-set inputs, shape (n_query, n_features).
    pub query_x: Array2<f64>,
    /// Query-set targets, row-aligned with `query_x`.
    pub query_y: Array2<f64>,
}
121
122impl MetaTask {
123 pub fn new(
125 support_x: Array2<f64>,
126 support_y: Array2<f64>,
127 query_x: Array2<f64>,
128 query_y: Array2<f64>,
129 ) -> TrainResult<Self> {
130 if support_x.nrows() != support_y.nrows() {
131 return Err(TrainError::InvalidParameter(format!(
132 "Support X rows ({}) must match support Y rows ({})",
133 support_x.nrows(),
134 support_y.nrows()
135 )));
136 }
137
138 if query_x.nrows() != query_y.nrows() {
139 return Err(TrainError::InvalidParameter(format!(
140 "Query X rows ({}) must match query Y rows ({})",
141 query_x.nrows(),
142 query_y.nrows()
143 )));
144 }
145
146 Ok(Self {
147 support_x,
148 support_y,
149 query_x,
150 query_y,
151 })
152 }
153
154 pub fn support_size(&self) -> usize {
156 self.support_x.nrows()
157 }
158
159 pub fn query_size(&self) -> usize {
161 self.query_x.nrows()
162 }
163}
164
/// Common interface for gradient-based meta-learning algorithms.
pub trait MetaLearner {
    /// Performs one outer (meta) update over a batch of tasks.
    ///
    /// Returns the updated parameter map together with the mean loss
    /// across `tasks`.
    fn meta_step(
        &self,
        tasks: &[MetaTask],
        parameters: &HashMap<String, Array1<f64>>,
    ) -> TrainResult<(HashMap<String, Array1<f64>>, f64)>;

    /// Adapts `parameters` to a single task via inner-loop updates,
    /// returning the task-specific parameters (meta-parameters untouched).
    fn adapt(
        &self,
        task: &MetaTask,
        parameters: &HashMap<String, Array1<f64>>,
    ) -> TrainResult<HashMap<String, Array1<f64>>>;
}
195
/// Model-Agnostic Meta-Learning (Finn et al., 2017).
#[derive(Debug, Clone)]
pub struct MAML {
    // Algorithm hyperparameters (inner/outer learning rates, etc.).
    config: MAMLConfig,
}
204
205impl MAML {
206 pub fn new(config: MAMLConfig) -> Self {
208 Self { config }
209 }
210}
211
212impl Default for MAML {
213 fn default() -> Self {
214 Self::new(MAMLConfig::default())
215 }
216}
217
218impl MetaLearner for MAML {
219 fn meta_step(
220 &self,
221 tasks: &[MetaTask],
222 parameters: &HashMap<String, Array1<f64>>,
223 ) -> TrainResult<(HashMap<String, Array1<f64>>, f64)> {
224 let mut meta_gradients: HashMap<String, Array1<f64>> = HashMap::new();
225 let mut total_loss = 0.0;
226
227 for (name, param) in parameters {
229 meta_gradients.insert(name.clone(), Array1::zeros(param.len()));
230 }
231
232 for task in tasks {
234 let adapted_params = self.adapt(task, parameters)?;
236
237 let query_loss = self.compute_query_loss(task, &adapted_params)?;
240 total_loss += query_loss;
241
242 let task_gradients = if self.config.first_order {
246 self.compute_first_order_gradients(task, &adapted_params)?
247 } else {
248 self.compute_second_order_gradients(task, parameters, &adapted_params)?
249 };
250
251 for (name, grad) in task_gradients {
253 if let Some(meta_grad) = meta_gradients.get_mut(&name) {
254 *meta_grad = meta_grad.clone() + grad;
255 }
256 }
257 }
258
259 let n_tasks = tasks.len() as f64;
261 for grad in meta_gradients.values_mut() {
262 *grad = grad.mapv(|x| x / n_tasks);
263 }
264 total_loss /= n_tasks;
265
266 let mut updated_params = HashMap::new();
268 for (name, param) in parameters {
269 if let Some(grad) = meta_gradients.get(name) {
270 let updated = param - &grad.mapv(|g| g * self.config.outer_lr);
271 updated_params.insert(name.clone(), updated);
272 }
273 }
274
275 Ok((updated_params, total_loss))
276 }
277
278 fn adapt(
279 &self,
280 task: &MetaTask,
281 parameters: &HashMap<String, Array1<f64>>,
282 ) -> TrainResult<HashMap<String, Array1<f64>>> {
283 let mut adapted_params = parameters.clone();
284
285 for _ in 0..self.config.inner_steps {
287 let gradients = self.compute_support_gradients(task, &adapted_params)?;
289
290 for (name, param) in &mut adapted_params {
292 if let Some(grad) = gradients.get(name) {
293 *param = param.clone() - &grad.mapv(|g| g * self.config.inner_lr);
294 }
295 }
296 }
297
298 Ok(adapted_params)
299 }
300}
301
302impl MAML {
303 fn compute_support_gradients(
305 &self,
306 task: &MetaTask,
307 _parameters: &HashMap<String, Array1<f64>>,
308 ) -> TrainResult<HashMap<String, Array1<f64>>> {
309 let mut gradients = HashMap::new();
312 gradients.insert("weights".to_string(), Array1::zeros(task.support_x.ncols()));
313 Ok(gradients)
314 }
315
316 fn compute_query_loss(
318 &self,
319 task: &MetaTask,
320 _parameters: &HashMap<String, Array1<f64>>,
321 ) -> TrainResult<f64> {
322 Ok(task.query_size() as f64 * 0.1)
325 }
326
327 fn compute_first_order_gradients(
329 &self,
330 task: &MetaTask,
331 _parameters: &HashMap<String, Array1<f64>>,
332 ) -> TrainResult<HashMap<String, Array1<f64>>> {
333 let mut gradients = HashMap::new();
335 gradients.insert("weights".to_string(), Array1::zeros(task.query_x.ncols()));
336 Ok(gradients)
337 }
338
339 fn compute_second_order_gradients(
341 &self,
342 task: &MetaTask,
343 _meta_params: &HashMap<String, Array1<f64>>,
344 _adapted_params: &HashMap<String, Array1<f64>>,
345 ) -> TrainResult<HashMap<String, Array1<f64>>> {
346 let mut gradients = HashMap::new();
349 gradients.insert("weights".to_string(), Array1::zeros(task.query_x.ncols()));
350 Ok(gradients)
351 }
352}
353
/// Reptile meta-learning algorithm (Nichol et al., 2018).
#[derive(Debug, Clone)]
pub struct Reptile {
    // Algorithm hyperparameters (inner/outer learning rates, etc.).
    config: ReptileConfig,
}
364
365impl Reptile {
366 pub fn new(config: ReptileConfig) -> Self {
368 Self { config }
369 }
370}
371
372impl Default for Reptile {
373 fn default() -> Self {
374 Self::new(ReptileConfig::default())
375 }
376}
377
378impl MetaLearner for Reptile {
379 fn meta_step(
380 &self,
381 tasks: &[MetaTask],
382 parameters: &HashMap<String, Array1<f64>>,
383 ) -> TrainResult<(HashMap<String, Array1<f64>>, f64)> {
384 let mut total_loss = 0.0;
385 let mut accumulated_delta: HashMap<String, Array1<f64>> = HashMap::new();
386
387 for (name, param) in parameters {
389 accumulated_delta.insert(name.clone(), Array1::zeros(param.len()));
390 }
391
392 for task in tasks {
394 let task_params = self.adapt(task, parameters)?;
396
397 let task_loss = self.compute_task_loss(task, &task_params)?;
399 total_loss += task_loss;
400
401 for (name, param) in parameters {
403 if let Some(task_param) = task_params.get(name) {
404 let delta = task_param - param;
405 if let Some(acc_delta) = accumulated_delta.get_mut(name) {
406 *acc_delta = acc_delta.clone() + delta;
407 }
408 }
409 }
410 }
411
412 let n_tasks = tasks.len() as f64;
414 for delta in accumulated_delta.values_mut() {
415 *delta = delta.mapv(|x| x / n_tasks);
416 }
417 total_loss /= n_tasks;
418
419 let mut updated_params = HashMap::new();
421 for (name, param) in parameters {
422 if let Some(delta) = accumulated_delta.get(name) {
423 let updated = param + &delta.mapv(|d| d * self.config.outer_lr);
424 updated_params.insert(name.clone(), updated);
425 }
426 }
427
428 Ok((updated_params, total_loss))
429 }
430
431 fn adapt(
432 &self,
433 task: &MetaTask,
434 parameters: &HashMap<String, Array1<f64>>,
435 ) -> TrainResult<HashMap<String, Array1<f64>>> {
436 let mut task_params = parameters.clone();
437
438 for _ in 0..self.config.inner_steps {
440 let gradients = self.compute_support_gradients(task, &task_params)?;
442
443 for (name, param) in &mut task_params {
445 if let Some(grad) = gradients.get(name) {
446 *param = param.clone() - &grad.mapv(|g| g * self.config.inner_lr);
447 }
448 }
449 }
450
451 Ok(task_params)
452 }
453}
454
455impl Reptile {
456 fn compute_support_gradients(
458 &self,
459 task: &MetaTask,
460 _parameters: &HashMap<String, Array1<f64>>,
461 ) -> TrainResult<HashMap<String, Array1<f64>>> {
462 let mut gradients = HashMap::new();
464 gradients.insert("weights".to_string(), Array1::zeros(task.support_x.ncols()));
465 Ok(gradients)
466 }
467
468 fn compute_task_loss(
470 &self,
471 task: &MetaTask,
472 _parameters: &HashMap<String, Array1<f64>>,
473 ) -> TrainResult<f64> {
474 Ok(task.query_size() as f64 * 0.1)
476 }
477}
478
/// Accumulated statistics from meta-training.
#[derive(Debug, Clone, Default)]
pub struct MetaStats {
    /// Meta (outer-loop) loss recorded at each meta step.
    pub meta_losses: Vec<f64>,
    /// Per-task inner-loop loss trajectories, indexed by task id.
    pub task_losses: Vec<Vec<f64>>,
    /// Number of meta steps recorded so far.
    pub iterations: usize,
}
489
490impl MetaStats {
491 pub fn new() -> Self {
493 Self::default()
494 }
495
496 pub fn record_meta_step(&mut self, meta_loss: f64) {
498 self.meta_losses.push(meta_loss);
499 self.iterations += 1;
500 }
501
502 pub fn record_task_adaptation(&mut self, task_id: usize, losses: Vec<f64>) {
504 while self.task_losses.len() <= task_id {
505 self.task_losses.push(Vec::new());
506 }
507 self.task_losses[task_id] = losses;
508 }
509
510 pub fn avg_meta_loss(&self, last_n: usize) -> f64 {
512 if self.meta_losses.is_empty() {
513 return 0.0;
514 }
515
516 let n = last_n.min(self.meta_losses.len());
517 let start = self.meta_losses.len() - n;
518 self.meta_losses[start..].iter().sum::<f64>() / n as f64
519 }
520
521 pub fn is_improving(&self, window: usize) -> bool {
523 if self.meta_losses.len() < window * 2 {
524 return false;
525 }
526
527 let recent = self.avg_meta_loss(window);
528 let previous = {
529 let start = self.meta_losses.len() - window * 2;
530 let end = self.meta_losses.len() - window;
531 self.meta_losses[start..end].iter().sum::<f64>() / window as f64
532 };
533
534 recent < previous
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_maml_config_default() {
        let config = MAMLConfig::default();
        assert_eq!(config.inner_steps, 5);
        assert_eq!(config.inner_lr, 0.01);
        assert_eq!(config.outer_lr, 0.001);
        assert!(!config.first_order);
    }

    #[test]
    fn test_reptile_config_default() {
        let config = ReptileConfig::default();
        assert_eq!(config.inner_steps, 10);
        assert_eq!(config.inner_lr, 0.01);
        assert_eq!(config.outer_lr, 0.1);
    }

    #[test]
    fn test_meta_task_creation() {
        let support_x = Array2::zeros((5, 10));
        let support_y = Array2::zeros((5, 2));
        let query_x = Array2::zeros((15, 10));
        let query_y = Array2::zeros((15, 2));

        let task = MetaTask::new(support_x, support_y, query_x, query_y).unwrap();
        assert_eq!(task.support_size(), 5);
        assert_eq!(task.query_size(), 15);
    }

    #[test]
    fn test_meta_task_validation() {
        let support_x = Array2::zeros((5, 10));
        // Row-count mismatch with support_x must be rejected.
        let support_y = Array2::zeros((4, 2));
        let query_x = Array2::zeros((15, 10));
        let query_y = Array2::zeros((15, 2));

        let result = MetaTask::new(support_x, support_y, query_x, query_y);
        assert!(result.is_err());
    }

    #[test]
    fn test_maml_creation() {
        let config = MAMLConfig::default();
        let maml = MAML::new(config);
        assert_eq!(maml.config.inner_steps, 5);
    }

    #[test]
    fn test_maml_default() {
        let maml = MAML::default();
        assert_eq!(maml.config.inner_steps, 5);
    }

    #[test]
    fn test_reptile_creation() {
        let config = ReptileConfig::default();
        let reptile = Reptile::new(config);
        assert_eq!(reptile.config.inner_steps, 10);
    }

    #[test]
    fn test_reptile_default() {
        let reptile = Reptile::default();
        assert_eq!(reptile.config.inner_steps, 10);
    }

    #[test]
    fn test_maml_adapt() {
        let maml = MAML::default();

        let task = create_dummy_task();
        let mut params = HashMap::new();
        params.insert("weights".to_string(), Array1::zeros(10));

        let adapted = maml.adapt(&task, &params).unwrap();
        assert!(adapted.contains_key("weights"));
    }

    #[test]
    fn test_reptile_adapt() {
        let reptile = Reptile::default();

        let task = create_dummy_task();
        let mut params = HashMap::new();
        params.insert("weights".to_string(), Array1::zeros(10));

        let adapted = reptile.adapt(&task, &params).unwrap();
        assert!(adapted.contains_key("weights"));
    }

    #[test]
    fn test_maml_meta_step() {
        let maml = MAML::default();

        let tasks = vec![create_dummy_task(), create_dummy_task()];
        let mut params = HashMap::new();
        params.insert("weights".to_string(), Array1::zeros(10));

        let (updated_params, loss) = maml.meta_step(&tasks, &params).unwrap();
        assert!(updated_params.contains_key("weights"));
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_reptile_meta_step() {
        let reptile = Reptile::default();

        let tasks = vec![create_dummy_task(), create_dummy_task()];
        let mut params = HashMap::new();
        params.insert("weights".to_string(), Array1::zeros(10));

        let (updated_params, loss) = reptile.meta_step(&tasks, &params).unwrap();
        assert!(updated_params.contains_key("weights"));
        assert!(loss >= 0.0);
    }

    #[test]
    fn test_meta_stats() {
        let mut stats = MetaStats::new();

        stats.record_meta_step(1.0);
        stats.record_meta_step(0.8);
        stats.record_meta_step(0.6);

        assert_eq!(stats.iterations, 3);
        assert_eq!(stats.meta_losses.len(), 3);
        assert_eq!(stats.avg_meta_loss(2), 0.7);
    }

    #[test]
    fn test_meta_stats_improvement() {
        let mut stats = MetaStats::new();

        // Strictly decreasing losses must register as improving.
        for i in 0..20 {
            stats.record_meta_step(1.0 - i as f64 * 0.01);
        }

        assert!(stats.is_improving(5));
    }

    #[test]
    fn test_meta_stats_task_adaptation() {
        let mut stats = MetaStats::new();

        stats.record_task_adaptation(0, vec![1.0, 0.8, 0.6]);
        stats.record_task_adaptation(1, vec![1.2, 0.9, 0.7]);

        assert_eq!(stats.task_losses.len(), 2);
        assert_eq!(stats.task_losses[0].len(), 3);
    }

    /// Builds a small, shape-consistent task for the tests above.
    fn create_dummy_task() -> MetaTask {
        let support_x = Array2::zeros((5, 10));
        let support_y = Array2::zeros((5, 2));
        let query_x = Array2::zeros((15, 10));
        let query_y = Array2::zeros((15, 2));
        MetaTask::new(support_x, support_y, query_x, query_y).unwrap()
    }
}