1use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12use std::collections::HashMap;
13
/// Strategy interface for curriculum learning.
///
/// An implementation scores samples by difficulty and decides which sample
/// indices are used for training at a given epoch.
pub trait CurriculumStrategy {
    /// Returns the indices of the samples selected for `epoch`.
    ///
    /// * `epoch` - the current epoch.
    /// * `total_epochs` - the planned number of epochs; implementations may
    ///   ignore it.
    /// * `difficulties` - one difficulty score per sample; the returned
    ///   indices refer to positions in this view.
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>>;

    /// Computes one difficulty score per sample.
    ///
    /// * `data` - input features, one row per sample.
    /// * `labels` - targets, one row per sample.
    /// * `predictions` - optional model outputs, one row per sample; how a
    ///   missing value is handled is up to the implementation.
    fn compute_difficulty(
        &self,
        data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>>;
}
48
/// Curriculum whose selected share of the dataset grows linearly with the
/// epoch, from `start_percentage` up to the whole dataset.
#[derive(Debug, Clone)]
pub struct LinearCurriculum {
    /// Fraction of the samples selected at epoch 0; `new` requires `[0, 1]`.
    pub start_percentage: f64,
    /// When `true`, the easiest samples (lowest difficulty) are selected
    /// first; when `false`, the leading indices `0..k` are selected.
    pub sort_by_difficulty: bool,
}
60
61impl LinearCurriculum {
62 pub fn new(start_percentage: f64) -> TrainResult<Self> {
67 if !(0.0..=1.0).contains(&start_percentage) {
68 return Err(TrainError::InvalidParameter(
69 "start_percentage must be in [0, 1]".to_string(),
70 ));
71 }
72 Ok(Self {
73 start_percentage,
74 sort_by_difficulty: true,
75 })
76 }
77
78 pub fn without_sorting(mut self) -> Self {
80 self.sort_by_difficulty = false;
81 self
82 }
83}
84
85impl Default for LinearCurriculum {
86 fn default() -> Self {
87 Self {
88 start_percentage: 0.2,
89 sort_by_difficulty: true,
90 }
91 }
92}
93
94impl CurriculumStrategy for LinearCurriculum {
95 fn select_samples(
96 &self,
97 epoch: usize,
98 total_epochs: usize,
99 difficulties: &ArrayView1<f64>,
100 ) -> TrainResult<Vec<usize>> {
101 let n = difficulties.len();
102 if n == 0 {
103 return Ok(Vec::new());
104 }
105
106 let progress = if total_epochs > 1 {
108 epoch as f64 / (total_epochs - 1) as f64
109 } else {
110 1.0
111 };
112 let current_percentage = self.start_percentage + (1.0 - self.start_percentage) * progress;
113 let num_samples = ((n as f64 * current_percentage).ceil() as usize).min(n);
114
115 if !self.sort_by_difficulty {
116 return Ok((0..num_samples).collect());
118 }
119
120 let mut indices: Vec<usize> = (0..n).collect();
122 indices.sort_by(|&a, &b| {
123 difficulties[a]
124 .partial_cmp(&difficulties[b])
125 .unwrap_or(std::cmp::Ordering::Equal)
126 });
127
128 Ok(indices.into_iter().take(num_samples).collect())
129 }
130
131 fn compute_difficulty(
132 &self,
133 _data: &Array2<f64>,
134 _labels: &Array2<f64>,
135 predictions: Option<&Array2<f64>>,
136 ) -> TrainResult<Array1<f64>> {
137 if let Some(preds) = predictions {
140 let n = preds.nrows();
141 let mut difficulties = Array1::zeros(n);
142
143 for i in 0..n {
144 let pred = preds.row(i);
145 let mut entropy = 0.0;
147 for &p in pred.iter() {
148 if p > 1e-10 {
149 entropy -= p * p.ln();
150 }
151 }
152 difficulties[i] = entropy;
153 }
154
155 Ok(difficulties)
156 } else {
157 Ok(Array1::zeros(_labels.nrows()))
159 }
160 }
161}
162
/// Curriculum whose selected share of the dataset grows exponentially with
/// training progress, capped at the whole dataset.
#[derive(Debug, Clone)]
pub struct ExponentialCurriculum {
    /// Fraction of the samples selected at epoch 0; `new` requires `[0, 1]`.
    pub start_percentage: f64,
    /// Multiplier of the progress inside the exponential; `new` requires a
    /// positive value.
    pub growth_rate: f64,
}
173
174impl ExponentialCurriculum {
175 pub fn new(start_percentage: f64, growth_rate: f64) -> TrainResult<Self> {
181 if !(0.0..=1.0).contains(&start_percentage) {
182 return Err(TrainError::InvalidParameter(
183 "start_percentage must be in [0, 1]".to_string(),
184 ));
185 }
186 if growth_rate <= 0.0 {
187 return Err(TrainError::InvalidParameter(
188 "growth_rate must be positive".to_string(),
189 ));
190 }
191 Ok(Self {
192 start_percentage,
193 growth_rate,
194 })
195 }
196}
197
198impl Default for ExponentialCurriculum {
199 fn default() -> Self {
200 Self {
201 start_percentage: 0.1,
202 growth_rate: 2.0,
203 }
204 }
205}
206
207impl CurriculumStrategy for ExponentialCurriculum {
208 fn select_samples(
209 &self,
210 epoch: usize,
211 total_epochs: usize,
212 difficulties: &ArrayView1<f64>,
213 ) -> TrainResult<Vec<usize>> {
214 let n = difficulties.len();
215 if n == 0 {
216 return Ok(Vec::new());
217 }
218
219 let progress = if total_epochs > 1 {
221 epoch as f64 / (total_epochs - 1) as f64
222 } else {
223 1.0
224 };
225 let current_percentage =
226 (self.start_percentage * (self.growth_rate * progress).exp()).min(1.0);
227 let num_samples = ((n as f64 * current_percentage).ceil() as usize).min(n);
228
229 let mut indices: Vec<usize> = (0..n).collect();
231 indices.sort_by(|&a, &b| {
232 difficulties[a]
233 .partial_cmp(&difficulties[b])
234 .unwrap_or(std::cmp::Ordering::Equal)
235 });
236
237 Ok(indices.into_iter().take(num_samples).collect())
238 }
239
240 fn compute_difficulty(
241 &self,
242 _data: &Array2<f64>,
243 _labels: &Array2<f64>,
244 predictions: Option<&Array2<f64>>,
245 ) -> TrainResult<Array1<f64>> {
246 if let Some(preds) = predictions {
248 let n = preds.nrows();
249 let mut difficulties = Array1::zeros(n);
250
251 for i in 0..n {
252 let pred = preds.row(i);
253 let mut entropy = 0.0;
254 for &p in pred.iter() {
255 if p > 1e-10 {
256 entropy -= p * p.ln();
257 }
258 }
259 difficulties[i] = entropy;
260 }
261
262 Ok(difficulties)
263 } else {
264 Ok(Array1::zeros(_labels.nrows()))
265 }
266 }
267}
268
/// Curriculum that selects every sample whose difficulty lies below a fixed
/// threshold, independent of the epoch.
#[derive(Debug, Clone)]
pub struct SelfPacedCurriculum {
    /// Factor applied to the per-sample loss when computing difficulties;
    /// `new` requires a positive value.
    pub lambda: f64,
    /// Samples with a difficulty strictly below this value are selected.
    pub threshold: f64,
}
280
281impl SelfPacedCurriculum {
282 pub fn new(lambda: f64, threshold: f64) -> TrainResult<Self> {
288 if lambda <= 0.0 {
289 return Err(TrainError::InvalidParameter(
290 "lambda must be positive".to_string(),
291 ));
292 }
293 Ok(Self { lambda, threshold })
294 }
295}
296
297impl Default for SelfPacedCurriculum {
298 fn default() -> Self {
299 Self {
300 lambda: 1.0,
301 threshold: 0.5,
302 }
303 }
304}
305
306impl CurriculumStrategy for SelfPacedCurriculum {
307 fn select_samples(
308 &self,
309 _epoch: usize,
310 _total_epochs: usize,
311 difficulties: &ArrayView1<f64>,
312 ) -> TrainResult<Vec<usize>> {
313 let indices: Vec<usize> = difficulties
315 .iter()
316 .enumerate()
317 .filter(|(_, &d)| d < self.threshold)
318 .map(|(i, _)| i)
319 .collect();
320
321 Ok(indices)
322 }
323
324 fn compute_difficulty(
325 &self,
326 _data: &Array2<f64>,
327 labels: &Array2<f64>,
328 predictions: Option<&Array2<f64>>,
329 ) -> TrainResult<Array1<f64>> {
330 if let Some(preds) = predictions {
331 let n = preds.nrows();
332 let mut difficulties = Array1::zeros(n);
333
334 for i in 0..n {
335 let pred = preds.row(i);
337 let label = labels.row(i);
338
339 let mut loss = 0.0;
340 for j in 0..pred.len() {
341 let p = pred[j].clamp(1e-10, 1.0 - 1e-10);
342 loss -= label[j] * p.ln();
343 }
344
345 difficulties[i] = loss * self.lambda;
347 }
348
349 Ok(difficulties)
350 } else {
351 Err(TrainError::InvalidParameter(
352 "SelfPacedCurriculum requires predictions for difficulty computation".to_string(),
353 ))
354 }
355 }
356}
357
/// Curriculum that selects every sample whose difficulty does not exceed the
/// current competence, which grows linearly with the epoch.
#[derive(Debug, Clone)]
pub struct CompetenceCurriculum {
    /// Competence at epoch 0; `new` requires `[0, 1]`.
    pub initial_competence: f64,
    /// Amount added to the competence per epoch.
    pub growth_rate: f64,
    /// Upper bound applied to the competence; `new` sets it to 1.
    pub max_competence: f64,
}
370
371impl CompetenceCurriculum {
372 pub fn new(initial_competence: f64, growth_rate: f64) -> TrainResult<Self> {
378 if !(0.0..=1.0).contains(&initial_competence) {
379 return Err(TrainError::InvalidParameter(
380 "initial_competence must be in [0, 1]".to_string(),
381 ));
382 }
383 Ok(Self {
384 initial_competence,
385 growth_rate,
386 max_competence: 1.0,
387 })
388 }
389}
390
391impl Default for CompetenceCurriculum {
392 fn default() -> Self {
393 Self {
394 initial_competence: 0.3,
395 growth_rate: 0.05,
396 max_competence: 1.0,
397 }
398 }
399}
400
401impl CurriculumStrategy for CompetenceCurriculum {
402 fn select_samples(
403 &self,
404 epoch: usize,
405 _total_epochs: usize,
406 difficulties: &ArrayView1<f64>,
407 ) -> TrainResult<Vec<usize>> {
408 let competence =
410 (self.initial_competence + self.growth_rate * epoch as f64).min(self.max_competence);
411
412 let indices: Vec<usize> = difficulties
414 .iter()
415 .enumerate()
416 .filter(|(_, &d)| d <= competence)
417 .map(|(i, _)| i)
418 .collect();
419
420 Ok(indices)
421 }
422
423 fn compute_difficulty(
424 &self,
425 _data: &Array2<f64>,
426 _labels: &Array2<f64>,
427 predictions: Option<&Array2<f64>>,
428 ) -> TrainResult<Array1<f64>> {
429 if let Some(preds) = predictions {
431 let n = preds.nrows();
432 let mut difficulties = Array1::zeros(n);
433
434 for i in 0..n {
435 let pred = preds.row(i);
436 let mut entropy = 0.0;
437 for &p in pred.iter() {
438 if p > 1e-10 {
439 entropy -= p * p.ln();
440 }
441 }
442 difficulties[i] = entropy;
443 }
444
445 let max_difficulty = difficulties.iter().cloned().fold(0.0f64, f64::max);
447 if max_difficulty > 0.0 {
448 difficulties.mapv_inplace(|d| d / max_difficulty);
449 }
450
451 Ok(difficulties)
452 } else {
453 Ok(Array1::zeros(_labels.nrows()))
454 }
455 }
456}
457
/// Curriculum over whole tasks: each task becomes active from a given epoch
/// onwards.
#[derive(Debug, Clone)]
pub struct TaskCurriculum {
    /// `(start_epoch, task_id)` pairs, kept sorted by `start_epoch`.
    task_schedule: Vec<(usize, usize)>,
}

impl TaskCurriculum {
    /// Builds a task curriculum from `(start_epoch, task_id)` pairs.
    ///
    /// The pairs are sorted by start epoch; entries sharing a start epoch
    /// keep their relative order.
    pub fn new(mut schedule: Vec<(usize, usize)>) -> Self {
        schedule.sort_by_key(|&(start_epoch, _)| start_epoch);
        Self {
            task_schedule: schedule,
        }
    }

    /// Returns the ids of all tasks whose start epoch is at or before
    /// `epoch`, in schedule order.
    pub fn get_active_tasks(&self, epoch: usize) -> Vec<usize> {
        let mut active = Vec::new();
        for &(start_epoch, task_id) in &self.task_schedule {
            if start_epoch <= epoch {
                active.push(task_id);
            }
        }
        active
    }
}

impl Default for TaskCurriculum {
    /// A single task with id 0 that is active from epoch 0.
    fn default() -> Self {
        Self::new(vec![(0, 0)])
    }
}
504
/// Couples a curriculum strategy with a cache of difficulty scores and the
/// current epoch.
pub struct CurriculumManager {
    /// Boxed strategy used for scoring and sample selection.
    strategy: Box<dyn CurriculumStrategyClone>,
    /// Difficulty scores per caller-chosen key.
    difficulty_cache: HashMap<String, Array1<f64>>,
    /// Epoch passed to the strategy when selecting samples.
    current_epoch: usize,
}
511
512impl std::fmt::Debug for CurriculumManager {
513 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514 f.debug_struct("CurriculumManager")
515 .field("current_epoch", &self.current_epoch)
516 .field("num_cached_difficulties", &self.difficulty_cache.len())
517 .finish()
518 }
519}
520
/// Private helper trait that lets a boxed `CurriculumStrategy` trait object
/// be cloned.
trait CurriculumStrategyClone: CurriculumStrategy {
    /// Returns a boxed copy of `self`.
    fn clone_box(&self) -> Box<dyn CurriculumStrategyClone>;
}
525
526impl<T: CurriculumStrategy + Clone + 'static> CurriculumStrategyClone for T {
527 fn clone_box(&self) -> Box<dyn CurriculumStrategyClone> {
528 Box::new(self.clone())
529 }
530}
531
532impl Clone for Box<dyn CurriculumStrategyClone> {
533 fn clone(&self) -> Self {
534 self.clone_box()
535 }
536}
537
538impl CurriculumStrategy for Box<dyn CurriculumStrategyClone> {
539 fn select_samples(
540 &self,
541 epoch: usize,
542 total_epochs: usize,
543 difficulties: &ArrayView1<f64>,
544 ) -> TrainResult<Vec<usize>> {
545 (**self).select_samples(epoch, total_epochs, difficulties)
546 }
547
548 fn compute_difficulty(
549 &self,
550 data: &Array2<f64>,
551 labels: &Array2<f64>,
552 predictions: Option<&Array2<f64>>,
553 ) -> TrainResult<Array1<f64>> {
554 (**self).compute_difficulty(data, labels, predictions)
555 }
556}
557
558impl CurriculumManager {
559 pub fn new<S: CurriculumStrategy + Clone + 'static>(strategy: S) -> Self {
564 Self {
565 strategy: Box::new(strategy),
566 difficulty_cache: HashMap::new(),
567 current_epoch: 0,
568 }
569 }
570
571 pub fn set_epoch(&mut self, epoch: usize) {
573 self.current_epoch = epoch;
574 }
575
576 pub fn compute_difficulty(
584 &mut self,
585 key: &str,
586 data: &Array2<f64>,
587 labels: &Array2<f64>,
588 predictions: Option<&Array2<f64>>,
589 ) -> TrainResult<()> {
590 let difficulties = self
591 .strategy
592 .compute_difficulty(data, labels, predictions)?;
593 self.difficulty_cache.insert(key.to_string(), difficulties);
594 Ok(())
595 }
596
597 pub fn get_selected_samples(&self, key: &str, total_epochs: usize) -> TrainResult<Vec<usize>> {
606 let difficulties = self.difficulty_cache.get(key).ok_or_else(|| {
607 TrainError::InvalidParameter(format!("No difficulty scores cached for key: {}", key))
608 })?;
609
610 self.strategy
611 .select_samples(self.current_epoch, total_epochs, &difficulties.view())
612 }
613
614 pub fn clear_cache(&mut self) {
616 self.difficulty_cache.clear();
617 }
618}
619
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_linear_curriculum() {
        let curriculum = LinearCurriculum::new(0.2).unwrap();
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];

        // Epoch 0: 20% of 5 samples, i.e. only the easiest one.
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0]);

        // Last epoch: everything, ordered from easiest to hardest.
        let selected = curriculum
            .select_samples(9, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0, 4, 2, 1, 3]);
    }

    #[test]
    fn test_linear_curriculum_invalid() {
        assert!(LinearCurriculum::new(-0.1).is_err());
        assert!(LinearCurriculum::new(1.5).is_err());
    }

    #[test]
    fn test_exponential_curriculum() {
        let curriculum = ExponentialCurriculum::new(0.1, 2.0).unwrap();
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];

        // Epoch 0: 10% of 5 samples, rounded up to the single easiest one.
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0]);

        // Last epoch: 0.1 * e^2 is about 74%, rounded up to the 4 easiest.
        let selected = curriculum
            .select_samples(9, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0, 4, 2, 1]);
    }

    #[test]
    fn test_self_paced_curriculum() {
        let curriculum = SelfPacedCurriculum::new(1.0, 0.5).unwrap();
        let difficulties = array![0.1, 0.6, 0.3, 0.9, 0.2];

        // Only difficulties strictly below 0.5 qualify, in index order.
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0, 2, 4]);
    }

    #[test]
    fn test_competence_curriculum() {
        let curriculum = CompetenceCurriculum::new(0.3, 0.1).unwrap();
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];

        // Epoch 0: competence 0.3 admits difficulties 0.1, 0.3 and 0.2.
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0, 2, 4]);

        // Epoch 5: competence 0.8 admits everything except 0.9.
        let selected = curriculum
            .select_samples(5, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0, 1, 2, 4]);
    }

    #[test]
    fn test_task_curriculum() {
        let curriculum = TaskCurriculum::new(vec![(0, 0), (5, 1), (10, 2)]);

        let tasks = curriculum.get_active_tasks(0);
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0], 0);

        let tasks = curriculum.get_active_tasks(7);
        assert_eq!(tasks.len(), 2);
        assert!(tasks.contains(&0));
        assert!(tasks.contains(&1));

        let tasks = curriculum.get_active_tasks(15);
        assert_eq!(tasks.len(), 3);
    }

    #[test]
    fn test_difficulty_computation() {
        let curriculum = LinearCurriculum::default();

        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let predictions = array![[0.8, 0.2], [0.3, 0.7]];

        // With predictions: entropy-based, non-negative scores.
        let difficulties = curriculum
            .compute_difficulty(&data, &labels, Some(&predictions))
            .unwrap();
        assert_eq!(difficulties.len(), 2);
        assert!(difficulties.iter().all(|&d| d >= 0.0));
        // The less confident prediction (0.3 / 0.7) has the higher entropy.
        assert!(difficulties[1] > difficulties[0]);

        // Without predictions: all zeros, one per label row.
        let difficulties = curriculum.compute_difficulty(&data, &labels, None).unwrap();
        assert_eq!(difficulties.len(), 2);
        assert!(difficulties.iter().all(|&d| d == 0.0));
    }

    #[test]
    fn test_curriculum_manager() {
        let mut manager = CurriculumManager::new(LinearCurriculum::default());

        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let predictions = array![[0.8, 0.2], [0.3, 0.7], [0.6, 0.4]];

        manager
            .compute_difficulty("train", &data, &labels, Some(&predictions))
            .unwrap();

        // Epoch 0 of the default curriculum: 20% of 3 samples, rounded up to
        // one sample, the most confident prediction (row 0).
        manager.set_epoch(0);
        let selected = manager.get_selected_samples("train", 10).unwrap();
        assert_eq!(selected, vec![0]);

        // Clearing the cache must make the key unknown again.
        manager.clear_cache();
        assert!(manager.get_selected_samples("train", 10).is_err());
    }

    #[test]
    fn test_curriculum_manager_missing_key() {
        let manager = CurriculumManager::new(LinearCurriculum::default());
        let result = manager.get_selected_samples("nonexistent", 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_linear_curriculum_without_sorting() {
        let curriculum = LinearCurriculum::new(0.5).unwrap().without_sorting();
        let difficulties = array![0.9, 0.1, 0.5, 0.3, 0.7];

        // 50% of 5 samples rounds up to 3; without sorting these are simply
        // the leading indices, regardless of difficulty.
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected, vec![0, 1, 2]);
    }

    #[test]
    fn test_empty_difficulties() {
        let curriculum = LinearCurriculum::default();
        let difficulties = Array1::<f64>::zeros(0);

        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected.len(), 0);
    }
}