1use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use trustformers_core::{errors::invalid_input, tensor::Tensor, traits::Model, Result};
40
/// Configuration for curriculum learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CurriculumConfig {
    /// Strategy that decides how the curriculum threshold evolves over training.
    pub strategy: CurriculumStrategy,
    /// How per-example difficulty is estimated.
    pub difficulty_measure: DifficultyMeasure,
    /// Function mapping training progress (0..1) to the data-inclusion fraction.
    pub pacing_function: PacingFunction,
    /// Fraction of the (easiest) data admitted at the start of training.
    pub initial_data_percentage: f32,
    /// Whether the curriculum stays active for the whole run.
    /// NOTE(review): not consulted by the trainer in this file — confirm intended use.
    pub use_throughout_training: bool,
    /// Number of epochs over which epoch-based pacing ramps from 0 to 1.
    pub curriculum_epochs: usize,
    /// Whether to shuffle within the "easy" subset.
    /// NOTE(review): not consulted by the trainer in this file.
    pub shuffle_easy_examples: bool,
    /// Whether the threshold adapts to observed performance.
    /// NOTE(review): not consulted by the trainer in this file.
    pub adaptive_threshold: bool,
    /// Lower clamp applied to the difficulty threshold after each update.
    pub min_difficulty_threshold: f32,
    /// Upper clamp applied to the difficulty threshold after each update.
    pub max_difficulty_threshold: f32,
    /// Steps between curriculum evaluations.
    /// NOTE(review): not consulted by the trainer in this file.
    pub evaluation_frequency: usize,
}
67
impl Default for CurriculumConfig {
    /// Defaults: self-paced strategy (lambda 0.5, gamma 1.1), loss-based
    /// difficulty, linear pacing, starting from the easiest 10% of data.
    fn default() -> Self {
        Self {
            strategy: CurriculumStrategy::SelfPaced {
                lambda: 0.5,
                gamma: 1.1,
            },
            difficulty_measure: DifficultyMeasure::LossBasedDifficulty,
            pacing_function: PacingFunction::Linear,
            initial_data_percentage: 0.1,
            use_throughout_training: true,
            curriculum_epochs: 10,
            shuffle_easy_examples: true,
            adaptive_threshold: true,
            min_difficulty_threshold: 0.1,
            max_difficulty_threshold: 0.9,
            evaluation_frequency: 1000,
        }
    }
}
88
/// Scheduling strategy controlling how the inclusion threshold changes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CurriculumStrategy {
    /// Self-paced: the threshold is multiplied by `gamma` whenever recent
    /// performance is high. `lambda` is carried but not read by the trainer
    /// in this file — TODO confirm intended use.
    SelfPaced { lambda: f32, gamma: f32 },
    /// Grows the threshold by `increase_rate` once measured competence
    /// exceeds `competence_threshold`.
    CompetenceBased {
        competence_threshold: f32,
        increase_rate: f32,
    },
    /// Fixed schedule: each difficulty level is held for the paired duration
    /// (in steps); the schedule repeats modulo its total length.
    Predefined {
        difficulty_levels: Vec<f32>,
        level_durations: Vec<usize>,
    },
    /// Small fixed increments applied once `patience` recent scores average
    /// above a fixed bar.
    BabySteps { step_size: f32, patience: usize },
    /// Hardest-first ordering when `reverse_pacing` is set.
    AntiCurriculum { reverse_pacing: bool },
    /// Threshold cycles with period `cycle_length` steps. `num_cycles` is
    /// stored but not consulted by the trainer in this file.
    Cyclical {
        cycle_length: usize,
        num_cycles: usize,
    },
    /// Adversarial teacher/student weighting. Parameters are not consulted
    /// by the trainer in this file — TODO confirm intended use.
    Minimax {
        teacher_lambda: f32,
        student_lambda: f32,
    },
    /// No strategy-specific update; uses the default epoch-based pacing path.
    Random,
}
121
/// How the difficulty of an individual example is estimated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DifficultyMeasure {
    /// Difficulty = current model loss on the example.
    LossBasedDifficulty,
    /// Gradient-norm based scoring (currently a stub returning 0.5).
    GradientNormDifficulty,
    /// Difficulty = 1 - max predicted probability (low confidence = hard).
    ConfidenceDifficulty,
    /// Difficulty proportional to sequence length (dimension 1 of the input).
    LengthDifficulty,
    /// Structural complexity (currently a stub returning 0.5).
    ComplexityDifficulty,
    /// Weighted combination of several measures; `weights` pairs with
    /// `measures` positionally.
    MultiCriteria {
        measures: Vec<DifficultyMeasure>,
        weights: Vec<f32>,
    },
    /// Difficulty predicted by a learned scorer. The network is referenced
    /// by name only; scoring is currently a stub.
    LearnedDifficulty {
        difficulty_network: Option<String>,
    },
    /// Difficulty supplied manually on each example.
    ManualDifficulty,
}
147
/// Shape of the curve mapping training progress to the inclusion fraction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PacingFunction {
    /// Straight line from the initial percentage to 1.0.
    Linear,
    /// Saturating exponential ramp controlled by `rate`.
    Exponential { rate: f32 },
    /// Logarithmic ramp; `base` controls curvature.
    Logarithmic { base: f32 },
    /// S-curve centered at `midpoint` with slope `steepness`.
    Sigmoid { steepness: f32, midpoint: f32 },
    /// Piecewise-constant: `(step_threshold, value)` pairs matched against
    /// the global step counter.
    StepWise { steps: Vec<(usize, f32)> },
    /// Polynomial ramp `progress^degree`.
    Polynomial { degree: f32 },
    /// Named custom function. NOTE(review): the name is not dispatched in
    /// this file — it falls back to linear pacing.
    Custom { function_name: String },
}
166
/// One training example annotated with a difficulty score and loss weight.
#[derive(Debug, Clone)]
pub struct CurriculumExample {
    /// Model input tensor.
    pub input: Tensor,
    /// Expected output / target tensor.
    pub target: Tensor,
    /// Difficulty score; 0.0 is treated as "not yet scored" by
    /// `estimate_difficulties`.
    pub difficulty: f32,
    /// Free-form string metadata (provenance, tags, ...).
    pub metadata: HashMap<String, String>,
    /// Multiplier applied to this example's loss (defaults to 1.0).
    pub weight: f32,
}
181
182impl CurriculumExample {
183 pub fn new(input: Tensor, target: Tensor, difficulty: f32) -> Self {
185 Self {
186 input,
187 target,
188 difficulty,
189 metadata: HashMap::new(),
190 weight: 1.0,
191 }
192 }
193
194 pub fn with_metadata(
196 input: Tensor,
197 target: Tensor,
198 difficulty: f32,
199 metadata: HashMap<String, String>,
200 ) -> Self {
201 Self {
202 input,
203 target,
204 difficulty,
205 metadata,
206 weight: 1.0,
207 }
208 }
209
210 pub fn with_weight(mut self, weight: f32) -> Self {
212 self.weight = weight;
213 self
214 }
215}
216
/// Trainer that feeds a model examples in order of increasing difficulty,
/// widening the admitted pool according to the configured strategy.
pub struct CurriculumLearningTrainer<M: Model> {
    /// Model being trained.
    pub model: M,
    /// Curriculum configuration.
    pub config: CurriculumConfig,
    /// Full example pool, kept sorted ascending by difficulty.
    pub examples: Vec<CurriculumExample>,
    /// Current data-inclusion fraction in [min, max] threshold bounds.
    pub current_threshold: f32,
    /// Number of completed epochs.
    pub current_epoch: usize,
    /// Number of completed training steps.
    pub step_counter: usize,
    /// Recent per-step accuracies (truncated to bound memory).
    pub performance_history: Vec<f32>,
    /// Learned-difficulty scorer; present only for `LearnedDifficulty`.
    pub difficulty_scorer: Option<DifficultyScorer>,
}
236
237impl<M: Model<Input = Tensor, Output = Tensor>> CurriculumLearningTrainer<M> {
238 pub fn new(model: M, config: CurriculumConfig) -> Result<Self> {
240 let difficulty_scorer = match &config.difficulty_measure {
241 DifficultyMeasure::LearnedDifficulty { .. } => {
242 Some(DifficultyScorer::new(&config.difficulty_measure)?)
243 },
244 _ => None,
245 };
246
247 let initial_data_percentage = config.initial_data_percentage;
248
249 Ok(Self {
250 model,
251 config,
252 examples: Vec::new(),
253 current_threshold: initial_data_percentage,
254 current_epoch: 0,
255 step_counter: 0,
256 performance_history: Vec::new(),
257 difficulty_scorer,
258 })
259 }
260
261 pub fn add_examples(&mut self, examples: Vec<CurriculumExample>) {
263 self.examples.extend(examples);
264 self.sort_examples_by_difficulty();
265 }
266
267 pub fn add_example(&mut self, example: CurriculumExample) {
269 self.examples.push(example);
270 self.sort_examples_by_difficulty();
271 }
272
273 pub fn estimate_difficulties(&mut self) -> Result<()> {
275 let mut indices_to_update = Vec::new();
276
277 for (i, example) in self.examples.iter().enumerate() {
279 if example.difficulty == 0.0 {
280 indices_to_update.push(i);
282 }
283 }
284
285 for i in indices_to_update {
287 let input = self.examples[i].input.clone();
288 let target = self.examples[i].target.clone();
289 let difficulty = self.compute_difficulty(&input, &target)?;
290 self.examples[i].difficulty = difficulty;
291 }
292
293 self.sort_examples_by_difficulty();
294 Ok(())
295 }
296
297 fn compute_difficulty(&self, input: &Tensor, target: &Tensor) -> Result<f32> {
299 match &self.config.difficulty_measure {
300 DifficultyMeasure::LossBasedDifficulty => {
301 let outputs = self.model.forward(input.clone())?;
302 let loss = self.compute_loss(&outputs, target)?;
303 loss.to_scalar().map_err(|e| {
304 invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
305 })
306 },
307 DifficultyMeasure::GradientNormDifficulty => {
308 Ok(0.5) },
312 DifficultyMeasure::ConfidenceDifficulty => {
313 let outputs = self.model.forward(input.clone())?;
314 let probs = outputs.softmax(-1)?;
315 let max_prob = self.compute_max_probability(&probs)?;
316 Ok(1.0 - max_prob) },
318 DifficultyMeasure::LengthDifficulty => {
319 let seq_len = input.shape()[1] as f32; Ok(seq_len / 1000.0) },
323 DifficultyMeasure::ComplexityDifficulty => {
324 Ok(0.5) },
328 DifficultyMeasure::MultiCriteria { measures, weights } => {
329 let mut total_difficulty = 0.0;
330 let mut total_weight = 0.0;
331
332 for (measure, &weight) in measures.iter().zip(weights.iter()) {
333 let difficulty = self.compute_individual_difficulty(measure, input, target)?;
335 total_difficulty += difficulty * weight;
336 total_weight += weight;
337 }
338
339 Ok(if total_weight > 0.0 { total_difficulty / total_weight } else { 0.5 })
340 },
341 DifficultyMeasure::LearnedDifficulty { .. } => {
342 if let Some(scorer) = &self.difficulty_scorer {
343 scorer.score_difficulty(input, target)
344 } else {
345 Ok(0.5)
346 }
347 },
348 DifficultyMeasure::ManualDifficulty => {
349 Ok(0.5) },
352 }
353 }
354
355 fn compute_individual_difficulty(
357 &self,
358 measure: &DifficultyMeasure,
359 input: &Tensor,
360 target: &Tensor,
361 ) -> Result<f32> {
362 match measure {
363 DifficultyMeasure::LossBasedDifficulty => {
364 let outputs = self.model.forward(input.clone())?;
365 let loss = self.compute_loss(&outputs, target)?;
366 loss.to_scalar().map_err(|e| {
367 invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
368 })
369 },
370 DifficultyMeasure::LengthDifficulty => {
371 let seq_len = input.shape()[1] as f32; Ok(seq_len / 1000.0) },
374 DifficultyMeasure::GradientNormDifficulty => {
375 Ok(0.5) },
378 DifficultyMeasure::ConfidenceDifficulty => {
379 let _outputs = self.model.forward(input.clone())?;
381 Ok(0.5) },
384 DifficultyMeasure::ComplexityDifficulty => {
385 Ok(0.5) },
389 DifficultyMeasure::LearnedDifficulty { .. } => {
390 if let Some(scorer) = &self.difficulty_scorer {
391 scorer.score_difficulty(input, target)
392 } else {
393 Ok(0.5)
394 }
395 },
396 DifficultyMeasure::ManualDifficulty => {
397 Ok(0.5) },
400 DifficultyMeasure::MultiCriteria { .. } => {
401 Ok(0.5)
403 },
404 }
405 }
406
407 fn sort_examples_by_difficulty(&mut self) {
409 self.examples.sort_by(|a, b| {
410 a.difficulty.partial_cmp(&b.difficulty).unwrap_or(std::cmp::Ordering::Equal)
411 });
412 }
413
414 pub fn get_current_curriculum(&self) -> Vec<CurriculumExample> {
416 let num_examples = self.examples.len();
417 let threshold_count = (num_examples as f32 * self.current_threshold) as usize;
418
419 match &self.config.strategy {
420 CurriculumStrategy::AntiCurriculum { reverse_pacing } => {
421 if *reverse_pacing {
422 self.examples.iter().rev().take(threshold_count).cloned().collect()
424 } else {
425 self.examples.iter().take(threshold_count).cloned().collect()
426 }
427 },
428 _ => {
429 self.examples.iter().take(threshold_count).cloned().collect()
431 },
432 }
433 }
434
435 pub fn update_curriculum_threshold(&mut self) -> Result<()> {
437 match &self.config.strategy {
438 CurriculumStrategy::SelfPaced { lambda: _, gamma } => {
439 let recent_performance = self.get_recent_performance();
441 if recent_performance > 0.8 {
442 self.current_threshold = (self.current_threshold * gamma).min(1.0);
444 }
445 },
446 CurriculumStrategy::CompetenceBased {
447 competence_threshold,
448 increase_rate,
449 } => {
450 let competence = self.compute_competence()?;
451 if competence > *competence_threshold {
452 self.current_threshold = (self.current_threshold + increase_rate).min(1.0);
453 }
454 },
455 CurriculumStrategy::Predefined {
456 difficulty_levels,
457 level_durations,
458 } => {
459 let total_steps: usize = level_durations.iter().sum();
461 let current_step = self.step_counter % total_steps;
462 let mut cumulative_steps = 0;
463
464 for (i, &duration) in level_durations.iter().enumerate() {
465 cumulative_steps += duration;
466 if current_step < cumulative_steps {
467 if i < difficulty_levels.len() {
468 self.current_threshold = difficulty_levels[i];
469 }
470 break;
471 }
472 }
473 },
474 CurriculumStrategy::BabySteps {
475 step_size,
476 patience,
477 } => {
478 if self.performance_history.len() >= *patience {
480 let recent_avg =
481 self.performance_history.iter().rev().take(*patience).sum::<f32>()
482 / *patience as f32;
483
484 if recent_avg > 0.85 {
485 self.current_threshold = (self.current_threshold + step_size).min(1.0);
487 }
488 }
489 },
490 CurriculumStrategy::Cyclical { cycle_length, .. } => {
491 let cycle_position =
493 (self.step_counter % cycle_length) as f32 / *cycle_length as f32;
494 self.current_threshold = self.apply_pacing_function(cycle_position);
495 },
496 _ => {
497 let progress = self.current_epoch as f32 / self.config.curriculum_epochs as f32;
499 self.current_threshold = self.apply_pacing_function(progress);
500 },
501 }
502
503 self.current_threshold = self
505 .current_threshold
506 .max(self.config.min_difficulty_threshold)
507 .min(self.config.max_difficulty_threshold);
508
509 Ok(())
510 }
511
512 fn apply_pacing_function(&self, progress: f32) -> f32 {
514 let clamped_progress = progress.clamp(0.0, 1.0);
515
516 match &self.config.pacing_function {
517 PacingFunction::Linear => {
518 self.config.initial_data_percentage
519 + (1.0 - self.config.initial_data_percentage) * clamped_progress
520 },
521 PacingFunction::Exponential { rate } => {
522 self.config.initial_data_percentage
523 + (1.0 - self.config.initial_data_percentage)
524 * (1.0 - (-rate * clamped_progress).exp())
525 },
526 PacingFunction::Logarithmic { base } => {
527 self.config.initial_data_percentage
528 + (1.0 - self.config.initial_data_percentage) * (clamped_progress * base).ln()
529 / base.ln()
530 },
531 PacingFunction::Sigmoid {
532 steepness,
533 midpoint,
534 } => {
535 let sigmoid = 1.0 / (1.0 + (-steepness * (clamped_progress - midpoint)).exp());
536 self.config.initial_data_percentage
537 + (1.0 - self.config.initial_data_percentage) * sigmoid
538 },
539 PacingFunction::StepWise { steps } => {
540 let total_steps = self.step_counter;
541 for &(step_threshold, threshold_value) in steps {
542 if total_steps <= step_threshold {
543 return threshold_value;
544 }
545 }
546 1.0 },
548 PacingFunction::Polynomial { degree } => {
549 self.config.initial_data_percentage
550 + (1.0 - self.config.initial_data_percentage) * clamped_progress.powf(*degree)
551 },
552 PacingFunction::Custom { .. } => {
553 self.apply_pacing_function_linear(clamped_progress)
555 },
556 }
557 }
558
559 fn apply_pacing_function_linear(&self, progress: f32) -> f32 {
561 self.config.initial_data_percentage + (1.0 - self.config.initial_data_percentage) * progress
562 }
563
564 fn compute_competence(&self) -> Result<f32> {
566 if self.performance_history.is_empty() {
567 return Ok(0.0);
568 }
569
570 let recent_performance = self.get_recent_performance();
571 Ok(recent_performance)
572 }
573
574 fn get_recent_performance(&self) -> f32 {
576 if self.performance_history.is_empty() {
577 return 0.0;
578 }
579
580 let window_size = 10.min(self.performance_history.len());
581 self.performance_history.iter().rev().take(window_size).sum::<f32>() / window_size as f32
582 }
583
584 pub fn train_step(&mut self) -> Result<CurriculumLearningOutput> {
586 self.update_curriculum_threshold()?;
588
589 let curriculum_examples = self.get_current_curriculum();
591
592 if curriculum_examples.is_empty() {
593 return Err(invalid_input(
594 "No examples available for training".to_string(),
595 ));
596 }
597
598 let example = &curriculum_examples[self.step_counter % curriculum_examples.len()];
600
601 let outputs = self.model.forward(example.input.clone())?;
603 let loss = self.compute_loss(&outputs, &example.target)?;
604
605 let weighted_loss = loss.scalar_mul(example.weight)?;
607
608 let accuracy = self.compute_accuracy(&outputs, &example.target)?;
610 self.performance_history.push(accuracy);
611
612 if self.performance_history.len() > 1000 {
614 self.performance_history = self.performance_history.split_off(500);
615 }
616
617 self.step_counter += 1;
618
619 Ok(CurriculumLearningOutput {
620 loss: weighted_loss,
621 accuracy,
622 difficulty_threshold: self.current_threshold,
623 examples_used: curriculum_examples.len(),
624 current_difficulty: example.difficulty,
625 })
626 }
627
628 pub fn train_epoch(&mut self) -> Result<CurriculumEpochOutput> {
630 let mut total_loss = 0.0;
631 let mut total_accuracy = 0.0;
632 let mut num_steps = 0;
633
634 let curriculum_examples = self.get_current_curriculum();
635
636 for example in &curriculum_examples {
637 let outputs = self.model.forward(example.input.clone())?;
638 let loss = self.compute_loss(&outputs, &example.target)?;
639 let accuracy = self.compute_accuracy(&outputs, &example.target)?;
640
641 let loss_scalar = loss.to_scalar().map_err(|e| {
642 invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
643 })?;
644 total_loss += loss_scalar * example.weight;
645 total_accuracy += accuracy;
646 num_steps += 1;
647 }
648
649 self.current_epoch += 1;
650
651 Ok(CurriculumEpochOutput {
652 epoch: self.current_epoch,
653 average_loss: total_loss / num_steps as f32,
654 average_accuracy: total_accuracy / num_steps as f32,
655 difficulty_threshold: self.current_threshold,
656 examples_used: curriculum_examples.len(),
657 total_examples: self.examples.len(),
658 })
659 }
660
661 fn compute_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
663 self.compute_cross_entropy_loss(outputs, targets)
664 }
665
666 fn compute_accuracy(&self, outputs: &Tensor, targets: &Tensor) -> Result<f32> {
668 let predicted = self.compute_argmax(outputs)?;
669 let target_indices = self.compute_argmax(targets)?;
670
671 let total_samples = predicted.len() as f32;
673 if total_samples == 0.0 {
674 return Ok(0.0);
675 }
676
677 let mut correct = 0.0;
678 for (pred, target) in predicted.iter().zip(target_indices.iter()) {
679 if (pred - target).abs() < f32::EPSILON {
680 correct += 1.0;
681 }
682 }
683
684 Ok(correct / total_samples)
685 }
686
687 pub fn get_curriculum_stats(&self) -> CurriculumStats {
689 let curriculum_examples = self.get_current_curriculum();
690 let difficulties: Vec<f32> = curriculum_examples.iter().map(|e| e.difficulty).collect();
691
692 let min_difficulty = difficulties.iter().fold(f32::INFINITY, |a, &b| a.min(b));
693 let max_difficulty = difficulties.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
694 let avg_difficulty = if !difficulties.is_empty() {
695 difficulties.iter().sum::<f32>() / difficulties.len() as f32
696 } else {
697 0.0
698 };
699
700 CurriculumStats {
701 current_threshold: self.current_threshold,
702 examples_in_curriculum: curriculum_examples.len(),
703 total_examples: self.examples.len(),
704 min_difficulty,
705 max_difficulty,
706 avg_difficulty,
707 epoch: self.current_epoch,
708 step: self.step_counter,
709 }
710 }
711
712 fn compute_max_probability(&self, probs: &Tensor) -> Result<f32> {
714 match probs {
715 Tensor::F32(arr) => {
716 let max_val = arr.iter().fold(0.0f32, |acc, &x| acc.max(x));
718 Ok(max_val)
719 },
720 _ => {
721 Ok(0.5) },
723 }
724 }
725
726 fn compute_cross_entropy_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
728 let probs = outputs.softmax(-1)?;
730
731 let log_probs = probs.log()?;
733
734 match (log_probs, targets) {
736 (Tensor::F32(log_prob_arr), Tensor::F32(target_arr)) => {
737 let batch_size = log_prob_arr.shape()[0];
739 let num_classes = log_prob_arr.shape().get(1).copied().ok_or_else(|| {
740 invalid_input(format!(
741 "Invalid tensor shape: expected at least 2 dimensions, got {}",
742 log_prob_arr.shape().len()
743 ))
744 })?;
745
746 let mut total_loss = 0.0f32;
747
748 for batch_idx in 0..batch_size {
749 if target_arr.shape().len() == 1 {
750 let target_class = target_arr[[batch_idx]] as usize;
752 if target_class < num_classes {
753 total_loss -= log_prob_arr[[batch_idx, target_class]];
754 }
755 } else if target_arr.shape().len() >= 2 && target_arr.shape()[1] == num_classes
756 {
757 for class_idx in 0..num_classes {
759 let target_prob = target_arr[[batch_idx, class_idx]];
760 if target_prob > 0.0 {
761 total_loss -= target_prob * log_prob_arr[[batch_idx, class_idx]];
762 }
763 }
764 }
765 }
766
767 let mean_loss = total_loss / batch_size as f32;
769 Ok(Tensor::scalar(mean_loss)?)
770 },
771 _ => {
772 Ok(Tensor::scalar(1.0f32)?)
774 },
775 }
776 }
777
778 fn compute_argmax(&self, tensor: &Tensor) -> Result<Vec<f32>> {
780 match tensor {
781 Tensor::F32(arr) => {
782 let mut argmax_values = Vec::new();
783
784 if arr.ndim() == 1 {
786 let mut max_idx = 0;
788 let mut max_val = arr[0];
789 for (idx, &val) in arr.iter().enumerate() {
790 if val > max_val {
791 max_val = val;
792 max_idx = idx;
793 }
794 }
795 argmax_values.push(max_idx as f32);
796 } else if arr.ndim() == 2 {
797 let batch_size = arr.shape()[0];
799 let num_classes = arr.shape()[1];
800
801 for batch_idx in 0..batch_size {
802 let mut max_idx = 0;
803 let mut max_val = arr[[batch_idx, 0]];
804
805 for class_idx in 1..num_classes {
806 let val = arr[[batch_idx, class_idx]];
807 if val > max_val {
808 max_val = val;
809 max_idx = class_idx;
810 }
811 }
812 argmax_values.push(max_idx as f32);
813 }
814 } else {
815 let mut max_idx = 0;
817 let mut max_val = arr.iter().next().copied().ok_or_else(|| {
818 invalid_input("Cannot compute argmax on empty tensor".to_string())
819 })?;
820
821 for (idx, &val) in arr.iter().enumerate() {
822 if val > max_val {
823 max_val = val;
824 max_idx = idx;
825 }
826 }
827 argmax_values.push(max_idx as f32);
828 }
829
830 Ok(argmax_values)
831 },
832 _ => {
833 Ok(vec![0.0])
835 },
836 }
837 }
838}
839
/// Learned difficulty scorer. Currently a stub that records its measure and
/// returns a constant score.
pub struct DifficultyScorer {
    // Retained for a future learned implementation; not read yet.
    #[allow(dead_code)]
    method: DifficultyMeasure,
}
846
847impl DifficultyScorer {
848 pub fn new(method: &DifficultyMeasure) -> Result<Self> {
849 Ok(Self {
850 method: method.clone(),
851 })
852 }
853
854 pub fn score_difficulty(&self, _input: &Tensor, _target: &Tensor) -> Result<f32> {
855 Ok(0.5) }
859}
860
/// Result of a single curriculum training step.
#[derive(Debug, Clone)]
pub struct CurriculumLearningOutput {
    /// Weighted loss for the trained example.
    pub loss: Tensor,
    /// Accuracy on the trained example.
    pub accuracy: f32,
    /// Threshold in effect when the step ran.
    pub difficulty_threshold: f32,
    /// Number of examples currently admitted by the curriculum.
    pub examples_used: usize,
    /// Difficulty of the example that was trained on.
    pub current_difficulty: f32,
}
870
/// Aggregated result of one pass over the admitted curriculum.
#[derive(Debug, Clone)]
pub struct CurriculumEpochOutput {
    /// Epoch number after this pass (1-based).
    pub epoch: usize,
    /// Weight-scaled mean loss over the admitted examples.
    pub average_loss: f32,
    /// Mean accuracy over the admitted examples.
    pub average_accuracy: f32,
    /// Threshold in effect during the pass.
    pub difficulty_threshold: f32,
    /// Number of examples admitted during the pass.
    pub examples_used: usize,
    /// Total size of the example pool.
    pub total_examples: usize,
}
881
/// Monitoring snapshot of the curriculum state.
#[derive(Debug, Clone)]
pub struct CurriculumStats {
    /// Current data-inclusion threshold.
    pub current_threshold: f32,
    /// Examples currently admitted.
    pub examples_in_curriculum: usize,
    /// Total pool size.
    pub total_examples: usize,
    /// Minimum admitted difficulty (+inf when curriculum is empty).
    pub min_difficulty: f32,
    /// Maximum admitted difficulty (-inf when curriculum is empty).
    pub max_difficulty: f32,
    /// Mean admitted difficulty (0.0 when curriculum is empty).
    pub avg_difficulty: f32,
    /// Completed epochs.
    pub epoch: usize,
    /// Completed steps.
    pub step: usize,
}
894
895pub mod utils {
897 use super::*;
898
899 pub fn self_paced_config(lambda: f32, gamma: f32) -> CurriculumConfig {
901 CurriculumConfig {
902 strategy: CurriculumStrategy::SelfPaced { lambda, gamma },
903 ..Default::default()
904 }
905 }
906
907 pub fn competence_based_config(threshold: f32, increase_rate: f32) -> CurriculumConfig {
909 CurriculumConfig {
910 strategy: CurriculumStrategy::CompetenceBased {
911 competence_threshold: threshold,
912 increase_rate,
913 },
914 ..Default::default()
915 }
916 }
917
918 pub fn baby_steps_config(step_size: f32, patience: usize) -> CurriculumConfig {
920 CurriculumConfig {
921 strategy: CurriculumStrategy::BabySteps {
922 step_size,
923 patience,
924 },
925 pacing_function: PacingFunction::Linear,
926 ..Default::default()
927 }
928 }
929
930 pub fn predefined_config(
932 difficulty_levels: Vec<f32>,
933 level_durations: Vec<usize>,
934 ) -> CurriculumConfig {
935 CurriculumConfig {
936 strategy: CurriculumStrategy::Predefined {
937 difficulty_levels,
938 level_durations,
939 },
940 ..Default::default()
941 }
942 }
943
944 pub fn anti_curriculum_config() -> CurriculumConfig {
946 CurriculumConfig {
947 strategy: CurriculumStrategy::AntiCurriculum {
948 reverse_pacing: true,
949 },
950 ..Default::default()
951 }
952 }
953
954 pub fn cyclical_config(cycle_length: usize, num_cycles: usize) -> CurriculumConfig {
956 CurriculumConfig {
957 strategy: CurriculumStrategy::Cyclical {
958 cycle_length,
959 num_cycles,
960 },
961 ..Default::default()
962 }
963 }
964
965 pub fn create_length_based_examples(
967 inputs: Vec<Tensor>,
968 targets: Vec<Tensor>,
969 ) -> Vec<CurriculumExample> {
970 inputs
971 .into_iter()
972 .zip(targets)
973 .map(|(input, target)| {
974 let length = input.shape()[1] as f32; let difficulty = (length / 512.0).min(1.0); CurriculumExample::new(input, target, difficulty)
977 })
978 .collect()
979 }
980
981 pub fn create_loss_based_examples<M: Model<Input = Tensor, Output = Tensor>>(
983 model: &M,
984 inputs: Vec<Tensor>,
985 targets: Vec<Tensor>,
986 ) -> Result<Vec<CurriculumExample>> {
987 let mut examples = Vec::new();
988
989 for (input, target) in inputs.into_iter().zip(targets) {
990 let outputs = model.forward(input.clone())?;
991 let loss = simple_cross_entropy_loss(&outputs, &target)?;
993 let difficulty = loss.to_scalar().map_err(|e| {
994 invalid_input(format!(
995 "Failed to convert loss tensor to scalar for difficulty estimation: {}",
996 e
997 ))
998 })?;
999
1000 examples.push(CurriculumExample::new(input, target, difficulty));
1001 }
1002
1003 Ok(examples)
1004 }
1005
1006 fn simple_cross_entropy_loss(outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
1008 let probs = outputs.softmax(-1)?;
1010
1011 match targets.data() {
1014 Ok(target_data) => {
1015 if let Ok(prob_data) = probs.data() {
1016 let batch_size = targets.shape()[0];
1017 let mut total_loss = 0.0f32;
1018
1019 for i in 0..batch_size {
1020 let target_idx = target_data[i] as usize;
1021 if target_idx < prob_data.len() {
1022 let prob = prob_data[target_idx].max(1e-8); total_loss += -prob.ln();
1024 }
1025 }
1026
1027 let mean_loss = total_loss / batch_size as f32;
1028 Ok(Tensor::scalar(mean_loss)?)
1029 } else {
1030 Ok(Tensor::scalar(1.0f32)?)
1031 }
1032 },
1033 Err(_) => Ok(Tensor::scalar(1.0f32)?),
1034 }
1035 }
1036
1037 pub fn create_manual_examples(
1039 inputs: Vec<Tensor>,
1040 targets: Vec<Tensor>,
1041 difficulties: Vec<f32>,
1042 ) -> Result<Vec<CurriculumExample>> {
1043 if inputs.len() != targets.len() || inputs.len() != difficulties.len() {
1044 return Err(invalid_input("Mismatched array lengths".to_string()));
1045 }
1046
1047 Ok(inputs
1048 .into_iter()
1049 .zip(targets)
1050 .zip(difficulties)
1051 .map(|((input, target), difficulty)| CurriculumExample::new(input, target, difficulty))
1052 .collect())
1053 }
1054
1055 pub fn analyze_curriculum_effectiveness(
1057 baseline_accuracies: &[f32],
1058 curriculum_accuracies: &[f32],
1059 ) -> CurriculumAnalysis {
1060 let baseline_final = baseline_accuracies.last().copied().unwrap_or_else(|| {
1062 eprintln!("Warning: Empty baseline accuracies array, using 0.0");
1063 0.0
1064 });
1065 let curriculum_final = curriculum_accuracies.last().copied().unwrap_or_else(|| {
1066 eprintln!("Warning: Empty curriculum accuracies array, using 0.0");
1067 0.0
1068 });
1069
1070 let improvement = curriculum_final - baseline_final;
1071
1072 let baseline_auc = baseline_accuracies.iter().sum::<f32>();
1074 let curriculum_auc = curriculum_accuracies.iter().sum::<f32>();
1075 let convergence_speedup = curriculum_auc / baseline_auc.max(1e-8);
1076
1077 CurriculumAnalysis {
1078 final_accuracy_improvement: improvement,
1079 convergence_speedup,
1080 baseline_final_accuracy: baseline_final,
1081 curriculum_final_accuracy: curriculum_final,
1082 }
1083 }
1084}
1085
/// Comparison of a curriculum run against a baseline run.
#[derive(Debug, Clone)]
pub struct CurriculumAnalysis {
    /// Final-accuracy delta (curriculum minus baseline).
    pub final_accuracy_improvement: f32,
    /// Ratio of accuracy-sums (curriculum / baseline); > 1.0 favors curriculum.
    pub convergence_speedup: f32,
    /// Last recorded baseline accuracy (0.0 if the series was empty).
    pub baseline_final_accuracy: f32,
    /// Last recorded curriculum accuracy (0.0 if the series was empty).
    pub curriculum_final_accuracy: f32,
}
1094
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config exposes the documented defaults and SelfPaced strategy.
    #[test]
    fn test_curriculum_config_default() {
        let config = CurriculumConfig::default();
        assert_eq!(config.initial_data_percentage, 0.1);
        assert!(config.use_throughout_training);
        assert!(config.shuffle_easy_examples);

        if let CurriculumStrategy::SelfPaced { lambda, gamma } = config.strategy {
            assert_eq!(lambda, 0.5);
            assert_eq!(gamma, 1.1);
        } else {
            panic!("Expected SelfPaced strategy");
        }
    }

    /// `new` starts with empty metadata and unit weight.
    #[test]
    fn test_curriculum_example() {
        let input = Tensor::zeros(&[1, 10]).expect("operation failed");
        let target = Tensor::zeros(&[1]).expect("operation failed");
        let example = CurriculumExample::new(input, target, 0.5);

        assert_eq!(example.difficulty, 0.5);
        assert_eq!(example.weight, 1.0);
        assert!(example.metadata.is_empty());
    }

    /// `with_metadata` stores caller-provided key/value pairs.
    #[test]
    fn test_curriculum_example_with_metadata() {
        let input = Tensor::zeros(&[1, 10]).expect("operation failed");
        let target = Tensor::zeros(&[1]).expect("operation failed");
        let mut metadata = HashMap::new();
        metadata.insert("source".to_string(), "test".to_string());

        let example = CurriculumExample::with_metadata(input, target, 0.7, metadata);
        assert_eq!(example.difficulty, 0.7);
        assert_eq!(
            example.metadata.get("source").expect("operation failed"),
            "test"
        );
    }

    /// `with_weight` overrides the default unit weight.
    #[test]
    fn test_curriculum_example_with_weight() {
        let input = Tensor::zeros(&[1, 10]).expect("operation failed");
        let target = Tensor::zeros(&[1]).expect("operation failed");
        let example = CurriculumExample::new(input, target, 0.3).with_weight(2.0);

        assert_eq!(example.difficulty, 0.3);
        assert_eq!(example.weight, 2.0);
    }

    /// Preset builder forwards self-paced parameters.
    #[test]
    fn test_self_paced_config() {
        let config = utils::self_paced_config(0.8, 1.2);

        if let CurriculumStrategy::SelfPaced { lambda, gamma } = config.strategy {
            assert_eq!(lambda, 0.8);
            assert_eq!(gamma, 1.2);
        } else {
            panic!("Expected SelfPaced strategy");
        }
    }

    /// Preset builder forwards competence-based parameters.
    #[test]
    fn test_competence_based_config() {
        let config = utils::competence_based_config(0.85, 0.1);

        if let CurriculumStrategy::CompetenceBased {
            competence_threshold,
            increase_rate,
        } = config.strategy
        {
            assert_eq!(competence_threshold, 0.85);
            assert_eq!(increase_rate, 0.1);
        } else {
            panic!("Expected CompetenceBased strategy");
        }
    }

    /// Preset builder forwards baby-steps parameters.
    #[test]
    fn test_baby_steps_config() {
        let config = utils::baby_steps_config(0.05, 5);

        if let CurriculumStrategy::BabySteps {
            step_size,
            patience,
        } = config.strategy
        {
            assert_eq!(step_size, 0.05);
            assert_eq!(patience, 5);
        } else {
            panic!("Expected BabySteps strategy");
        }
    }

    /// Preset builder forwards the predefined schedule unchanged.
    #[test]
    fn test_predefined_config() {
        let levels = vec![0.2, 0.5, 0.8, 1.0];
        let durations = vec![1000, 1500, 2000, 2500];
        let config = utils::predefined_config(levels.clone(), durations.clone());

        if let CurriculumStrategy::Predefined {
            difficulty_levels,
            level_durations,
        } = config.strategy
        {
            assert_eq!(difficulty_levels, levels);
            assert_eq!(level_durations, durations);
        } else {
            panic!("Expected Predefined strategy");
        }
    }

    /// Anti-curriculum preset enables reversed pacing.
    #[test]
    fn test_anti_curriculum_config() {
        let config = utils::anti_curriculum_config();

        if let CurriculumStrategy::AntiCurriculum { reverse_pacing } = config.strategy {
            assert!(reverse_pacing);
        } else {
            panic!("Expected AntiCurriculum strategy");
        }
    }

    /// Cyclical preset forwards cycle parameters.
    #[test]
    fn test_cyclical_config() {
        let config = utils::cyclical_config(1000, 3);

        if let CurriculumStrategy::Cyclical {
            cycle_length,
            num_cycles,
        } = config.strategy
        {
            assert_eq!(cycle_length, 1000);
            assert_eq!(num_cycles, 3);
        } else {
            panic!("Expected Cyclical strategy");
        }
    }

    /// Manual examples carry the supplied difficulties positionally.
    #[test]
    fn test_create_manual_examples() {
        let inputs = vec![
            Tensor::zeros(&[1, 10]).expect("operation failed"),
            Tensor::ones(&[1, 10]).expect("operation failed"),
        ];
        let targets = vec![
            Tensor::zeros(&[1]).expect("operation failed"),
            Tensor::ones(&[1]).expect("operation failed"),
        ];
        let difficulties = vec![0.2, 0.8];

        let examples =
            utils::create_manual_examples(inputs, targets, difficulties).expect("operation failed");
        assert_eq!(examples.len(), 2);
        assert_eq!(examples[0].difficulty, 0.2);
        assert_eq!(examples[1].difficulty, 0.8);
    }

    /// Mismatched vector lengths are rejected with an error.
    #[test]
    fn test_create_manual_examples_mismatched_lengths() {
        let inputs = vec![Tensor::zeros(&[1, 10]).expect("operation failed")];
        let targets = vec![Tensor::zeros(&[1]).expect("operation failed")];
        // Two difficulties for one input/target pair.
        let difficulties = vec![0.2, 0.8];
        let result = utils::create_manual_examples(inputs, targets, difficulties);
        assert!(result.is_err());
    }

    /// Effectiveness analysis compares final accuracies and AUC ratio.
    #[test]
    fn test_curriculum_analysis() {
        let baseline = vec![0.6, 0.7, 0.75, 0.8];
        let curriculum = vec![0.7, 0.8, 0.85, 0.9];

        let analysis = utils::analyze_curriculum_effectiveness(&baseline, &curriculum);
        assert!((analysis.final_accuracy_improvement - 0.1).abs() < 1e-6);
        assert!((analysis.baseline_final_accuracy - 0.8).abs() < 1e-6);
        assert!((analysis.curriculum_final_accuracy - 0.9).abs() < 1e-6);
        assert!(analysis.convergence_speedup > 1.0);
    }
}