Skip to main content

trustformers_models/
curriculum_learning.rs

1//! # Curriculum Learning Framework
2//!
3//! This module provides a comprehensive framework for curriculum learning,
4//! enabling models to learn from training data in a structured, progressive manner
5//! from easy to hard examples.
6//!
7//! ## Features
8//!
9//! - **Multiple Curriculum Strategies**: Self-paced, competence-based, and predefined curricula
10//! - **Difficulty Estimation**: Automatic difficulty scoring for training examples
11//! - **Pacing Functions**: Various functions to control learning pace
12//! - **Multi-criteria Curricula**: Combine multiple difficulty measures
13//! - **Dynamic Curriculum**: Adaptive curriculum based on model performance
14//! - **Evaluation Metrics**: Specialized metrics for curriculum learning
15//!
16//! ## Usage
17//!
//! ```rust,ignore
//! use trustformers_models::curriculum_learning::{
//!     CurriculumConfig, CurriculumLearningTrainer, CurriculumStrategy, DifficultyMeasure,
//!     PacingFunction,
//! };
//!
//! let config = CurriculumConfig {
//!     strategy: CurriculumStrategy::SelfPaced {
//!         lambda: 0.5,
//!         gamma: 1.1,
//!     },
//!     difficulty_measure: DifficultyMeasure::LossBasedDifficulty,
//!     pacing_function: PacingFunction::Linear,
//!     ..Default::default()
//! };
//!
//! // `model` and `training_data` are placeholders supplied by the caller.
//! let mut trainer = CurriculumLearningTrainer::new(model, config)?;
//! trainer.train_with_curriculum(training_data)?;
//! ```
36
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use trustformers_core::{errors::invalid_input, tensor::Tensor, traits::Model, Result};
40
41/// Configuration for curriculum learning
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CurriculumConfig {
44    /// Curriculum learning strategy
45    pub strategy: CurriculumStrategy,
46    /// Method for measuring example difficulty
47    pub difficulty_measure: DifficultyMeasure,
48    /// Function controlling the pace of curriculum
49    pub pacing_function: PacingFunction,
50    /// Starting percentage of data to use (0.0-1.0)
51    pub initial_data_percentage: f32,
52    /// Whether to use curriculum during the entire training
53    pub use_throughout_training: bool,
54    /// Number of epochs for curriculum phase
55    pub curriculum_epochs: usize,
56    /// Whether to shuffle easy examples
57    pub shuffle_easy_examples: bool,
58    /// Whether to adaptively adjust difficulty threshold
59    pub adaptive_threshold: bool,
60    /// Minimum difficulty threshold
61    pub min_difficulty_threshold: f32,
62    /// Maximum difficulty threshold
63    pub max_difficulty_threshold: f32,
64    /// Evaluation frequency for adaptive curriculum
65    pub evaluation_frequency: usize,
66}
67
68impl Default for CurriculumConfig {
69    fn default() -> Self {
70        Self {
71            strategy: CurriculumStrategy::SelfPaced {
72                lambda: 0.5,
73                gamma: 1.1,
74            },
75            difficulty_measure: DifficultyMeasure::LossBasedDifficulty,
76            pacing_function: PacingFunction::Linear,
77            initial_data_percentage: 0.1,
78            use_throughout_training: true,
79            curriculum_epochs: 10,
80            shuffle_easy_examples: true,
81            adaptive_threshold: true,
82            min_difficulty_threshold: 0.1,
83            max_difficulty_threshold: 0.9,
84            evaluation_frequency: 1000,
85        }
86    }
87}
88
89/// Different curriculum learning strategies
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub enum CurriculumStrategy {
92    /// Self-paced learning
93    SelfPaced { lambda: f32, gamma: f32 },
94    /// Competence-based curriculum
95    CompetenceBased {
96        competence_threshold: f32,
97        increase_rate: f32,
98    },
99    /// Predefined curriculum (manually defined difficulty)
100    Predefined {
101        difficulty_levels: Vec<f32>,
102        level_durations: Vec<usize>,
103    },
104    /// Baby steps curriculum
105    BabySteps { step_size: f32, patience: usize },
106    /// Anti-curriculum (hard to easy)
107    AntiCurriculum { reverse_pacing: bool },
108    /// Cyclical curriculum
109    Cyclical {
110        cycle_length: usize,
111        num_cycles: usize,
112    },
113    /// Minimax curriculum
114    Minimax {
115        teacher_lambda: f32,
116        student_lambda: f32,
117    },
118    /// Random curriculum (baseline)
119    Random,
120}
121
122/// Methods for measuring example difficulty
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub enum DifficultyMeasure {
125    /// Loss-based difficulty (higher loss = harder)
126    LossBasedDifficulty,
127    /// Gradient norm-based difficulty
128    GradientNormDifficulty,
129    /// Prediction confidence-based difficulty
130    ConfidenceDifficulty,
131    /// Length-based difficulty (for sequences)
132    LengthDifficulty,
133    /// Complexity-based difficulty (for images/text)
134    ComplexityDifficulty,
135    /// Multi-criteria difficulty
136    MultiCriteria {
137        measures: Vec<DifficultyMeasure>,
138        weights: Vec<f32>,
139    },
140    /// Learned difficulty (using auxiliary network)
141    LearnedDifficulty {
142        difficulty_network: Option<String>, // Path to difficulty network
143    },
144    /// Manual difficulty scores
145    ManualDifficulty,
146}
147
148/// Functions for controlling curriculum pacing
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub enum PacingFunction {
151    /// Linear increase in difficulty
152    Linear,
153    /// Exponential increase
154    Exponential { rate: f32 },
155    /// Logarithmic increase
156    Logarithmic { base: f32 },
157    /// Sigmoid-shaped increase
158    Sigmoid { steepness: f32, midpoint: f32 },
159    /// Step-wise increase
160    StepWise { steps: Vec<(usize, f32)> },
161    /// Polynomial increase
162    Polynomial { degree: f32 },
163    /// Custom pacing function
164    Custom { function_name: String },
165}
166
167/// Training example with difficulty score
168#[derive(Debug, Clone)]
169pub struct CurriculumExample {
170    /// Input data
171    pub input: Tensor,
172    /// Target labels
173    pub target: Tensor,
174    /// Difficulty score (0.0 = easiest, 1.0 = hardest)
175    pub difficulty: f32,
176    /// Optional metadata
177    pub metadata: HashMap<String, String>,
178    /// Example weight for training
179    pub weight: f32,
180}
181
182impl CurriculumExample {
183    /// Create a new curriculum example
184    pub fn new(input: Tensor, target: Tensor, difficulty: f32) -> Self {
185        Self {
186            input,
187            target,
188            difficulty,
189            metadata: HashMap::new(),
190            weight: 1.0,
191        }
192    }
193
194    /// Create with metadata
195    pub fn with_metadata(
196        input: Tensor,
197        target: Tensor,
198        difficulty: f32,
199        metadata: HashMap<String, String>,
200    ) -> Self {
201        Self {
202            input,
203            target,
204            difficulty,
205            metadata,
206            weight: 1.0,
207        }
208    }
209
210    /// Set example weight
211    pub fn with_weight(mut self, weight: f32) -> Self {
212        self.weight = weight;
213        self
214    }
215}
216
217/// Curriculum learning trainer
218pub struct CurriculumLearningTrainer<M: Model> {
219    /// The model being trained
220    pub model: M,
221    /// Configuration
222    pub config: CurriculumConfig,
223    /// All training examples with difficulty scores
224    pub examples: Vec<CurriculumExample>,
225    /// Current difficulty threshold
226    pub current_threshold: f32,
227    /// Current epoch
228    pub current_epoch: usize,
229    /// Training step counter
230    pub step_counter: usize,
231    /// Performance history for adaptive curriculum
232    pub performance_history: Vec<f32>,
233    /// Difficulty scorer for dynamic difficulty estimation
234    pub difficulty_scorer: Option<DifficultyScorer>,
235}
236
237impl<M: Model<Input = Tensor, Output = Tensor>> CurriculumLearningTrainer<M> {
238    /// Create a new curriculum learning trainer
239    pub fn new(model: M, config: CurriculumConfig) -> Result<Self> {
240        let difficulty_scorer = match &config.difficulty_measure {
241            DifficultyMeasure::LearnedDifficulty { .. } => {
242                Some(DifficultyScorer::new(&config.difficulty_measure)?)
243            },
244            _ => None,
245        };
246
247        let initial_data_percentage = config.initial_data_percentage;
248
249        Ok(Self {
250            model,
251            config,
252            examples: Vec::new(),
253            current_threshold: initial_data_percentage,
254            current_epoch: 0,
255            step_counter: 0,
256            performance_history: Vec::new(),
257            difficulty_scorer,
258        })
259    }
260
261    /// Add training examples to the curriculum
262    pub fn add_examples(&mut self, examples: Vec<CurriculumExample>) {
263        self.examples.extend(examples);
264        self.sort_examples_by_difficulty();
265    }
266
267    /// Add a single example
268    pub fn add_example(&mut self, example: CurriculumExample) {
269        self.examples.push(example);
270        self.sort_examples_by_difficulty();
271    }
272
273    /// Estimate difficulty for examples without scores
274    pub fn estimate_difficulties(&mut self) -> Result<()> {
275        let mut indices_to_update = Vec::new();
276
277        // First pass: collect indices that need updating
278        for (i, example) in self.examples.iter().enumerate() {
279            if example.difficulty == 0.0 {
280                // Assume 0.0 means unscored
281                indices_to_update.push(i);
282            }
283        }
284
285        // Second pass: update difficulties without borrowing conflicts
286        for i in indices_to_update {
287            let input = self.examples[i].input.clone();
288            let target = self.examples[i].target.clone();
289            let difficulty = self.compute_difficulty(&input, &target)?;
290            self.examples[i].difficulty = difficulty;
291        }
292
293        self.sort_examples_by_difficulty();
294        Ok(())
295    }
296
297    /// Compute difficulty score for an example
298    fn compute_difficulty(&self, input: &Tensor, target: &Tensor) -> Result<f32> {
299        match &self.config.difficulty_measure {
300            DifficultyMeasure::LossBasedDifficulty => {
301                let outputs = self.model.forward(input.clone())?;
302                let loss = self.compute_loss(&outputs, target)?;
303                loss.to_scalar().map_err(|e| {
304                    invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
305                })
306            },
307            DifficultyMeasure::GradientNormDifficulty => {
308                // Compute gradient norm as difficulty measure
309                // This is a simplified implementation
310                Ok(0.5) // Placeholder
311            },
312            DifficultyMeasure::ConfidenceDifficulty => {
313                let outputs = self.model.forward(input.clone())?;
314                let probs = outputs.softmax(-1)?;
315                let max_prob = self.compute_max_probability(&probs)?;
316                Ok(1.0 - max_prob) // Lower confidence = higher difficulty
317            },
318            DifficultyMeasure::LengthDifficulty => {
319                // For sequence data, use length as difficulty measure
320                let seq_len = input.shape()[1] as f32; // Assuming [batch, seq_len, ...]
321                Ok(seq_len / 1000.0) // Normalize by typical sequence length
322            },
323            DifficultyMeasure::ComplexityDifficulty => {
324                // Compute complexity-based difficulty
325                // This could be entropy, edge density, etc.
326                Ok(0.5) // Placeholder
327            },
328            DifficultyMeasure::MultiCriteria { measures, weights } => {
329                let mut total_difficulty = 0.0;
330                let mut total_weight = 0.0;
331
332                for (measure, &weight) in measures.iter().zip(weights.iter()) {
333                    // Compute difficulty for each individual measure using dedicated method
334                    let difficulty = self.compute_individual_difficulty(measure, input, target)?;
335                    total_difficulty += difficulty * weight;
336                    total_weight += weight;
337                }
338
339                Ok(if total_weight > 0.0 { total_difficulty / total_weight } else { 0.5 })
340            },
341            DifficultyMeasure::LearnedDifficulty { .. } => {
342                if let Some(scorer) = &self.difficulty_scorer {
343                    scorer.score_difficulty(input, target)
344                } else {
345                    Ok(0.5)
346                }
347            },
348            DifficultyMeasure::ManualDifficulty => {
349                // Manual difficulty should already be set
350                Ok(0.5) // Default if not set
351            },
352        }
353    }
354
355    /// Helper method to compute difficulty for individual measures (avoiding recursion)
356    fn compute_individual_difficulty(
357        &self,
358        measure: &DifficultyMeasure,
359        input: &Tensor,
360        target: &Tensor,
361    ) -> Result<f32> {
362        match measure {
363            DifficultyMeasure::LossBasedDifficulty => {
364                let outputs = self.model.forward(input.clone())?;
365                let loss = self.compute_loss(&outputs, target)?;
366                loss.to_scalar().map_err(|e| {
367                    invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
368                })
369            },
370            DifficultyMeasure::LengthDifficulty => {
371                let seq_len = input.shape()[1] as f32; // Assuming [batch, seq_len, ...]
372                Ok(seq_len / 1000.0) // Normalize by typical sequence length
373            },
374            DifficultyMeasure::GradientNormDifficulty => {
375                // Compute gradient norm-based difficulty
376                Ok(0.5) // Placeholder - could be enhanced with actual gradient computation
377            },
378            DifficultyMeasure::ConfidenceDifficulty => {
379                // Compute confidence-based difficulty (higher uncertainty = harder)
380                let _outputs = self.model.forward(input.clone())?;
381                // Simple confidence measure based on max probability
382                Ok(0.5) // Placeholder - could compute actual confidence metrics
383            },
384            DifficultyMeasure::ComplexityDifficulty => {
385                // Compute complexity-based difficulty
386                // This could be entropy, edge density, etc.
387                Ok(0.5) // Placeholder - could be enhanced with actual complexity computation
388            },
389            DifficultyMeasure::LearnedDifficulty { .. } => {
390                if let Some(scorer) = &self.difficulty_scorer {
391                    scorer.score_difficulty(input, target)
392                } else {
393                    Ok(0.5)
394                }
395            },
396            DifficultyMeasure::ManualDifficulty => {
397                // Manual difficulty should already be set
398                Ok(0.5) // Default if not set
399            },
400            DifficultyMeasure::MultiCriteria { .. } => {
401                // Prevent infinite recursion by returning a default value
402                Ok(0.5)
403            },
404        }
405    }
406
407    /// Sort examples by difficulty
408    fn sort_examples_by_difficulty(&mut self) {
409        self.examples.sort_by(|a, b| {
410            a.difficulty.partial_cmp(&b.difficulty).unwrap_or(std::cmp::Ordering::Equal)
411        });
412    }
413
414    /// Get current curriculum subset
415    pub fn get_current_curriculum(&self) -> Vec<CurriculumExample> {
416        let num_examples = self.examples.len();
417        let threshold_count = (num_examples as f32 * self.current_threshold) as usize;
418
419        match &self.config.strategy {
420            CurriculumStrategy::AntiCurriculum { reverse_pacing } => {
421                if *reverse_pacing {
422                    // Start with hardest examples
423                    self.examples.iter().rev().take(threshold_count).cloned().collect()
424                } else {
425                    self.examples.iter().take(threshold_count).cloned().collect()
426                }
427            },
428            _ => {
429                // Normal curriculum: start with easiest
430                self.examples.iter().take(threshold_count).cloned().collect()
431            },
432        }
433    }
434
435    /// Update curriculum threshold based on strategy
436    pub fn update_curriculum_threshold(&mut self) -> Result<()> {
437        match &self.config.strategy {
438            CurriculumStrategy::SelfPaced { lambda: _, gamma } => {
439                // Self-paced learning adjusts threshold based on performance
440                let recent_performance = self.get_recent_performance();
441                if recent_performance > 0.8 {
442                    // Good performance
443                    self.current_threshold = (self.current_threshold * gamma).min(1.0);
444                }
445            },
446            CurriculumStrategy::CompetenceBased {
447                competence_threshold,
448                increase_rate,
449            } => {
450                let competence = self.compute_competence()?;
451                if competence > *competence_threshold {
452                    self.current_threshold = (self.current_threshold + increase_rate).min(1.0);
453                }
454            },
455            CurriculumStrategy::Predefined {
456                difficulty_levels,
457                level_durations,
458            } => {
459                // Use predefined schedule
460                let total_steps: usize = level_durations.iter().sum();
461                let current_step = self.step_counter % total_steps;
462                let mut cumulative_steps = 0;
463
464                for (i, &duration) in level_durations.iter().enumerate() {
465                    cumulative_steps += duration;
466                    if current_step < cumulative_steps {
467                        if i < difficulty_levels.len() {
468                            self.current_threshold = difficulty_levels[i];
469                        }
470                        break;
471                    }
472                }
473            },
474            CurriculumStrategy::BabySteps {
475                step_size,
476                patience,
477            } => {
478                // Increase threshold by small steps when performance is good
479                if self.performance_history.len() >= *patience {
480                    let recent_avg =
481                        self.performance_history.iter().rev().take(*patience).sum::<f32>()
482                            / *patience as f32;
483
484                    if recent_avg > 0.85 {
485                        // Good performance
486                        self.current_threshold = (self.current_threshold + step_size).min(1.0);
487                    }
488                }
489            },
490            CurriculumStrategy::Cyclical { cycle_length, .. } => {
491                // Cyclical curriculum
492                let cycle_position =
493                    (self.step_counter % cycle_length) as f32 / *cycle_length as f32;
494                self.current_threshold = self.apply_pacing_function(cycle_position);
495            },
496            _ => {
497                // Default linear progression
498                let progress = self.current_epoch as f32 / self.config.curriculum_epochs as f32;
499                self.current_threshold = self.apply_pacing_function(progress);
500            },
501        }
502
503        // Apply bounds
504        self.current_threshold = self
505            .current_threshold
506            .max(self.config.min_difficulty_threshold)
507            .min(self.config.max_difficulty_threshold);
508
509        Ok(())
510    }
511
512    /// Apply pacing function to progress
513    fn apply_pacing_function(&self, progress: f32) -> f32 {
514        let clamped_progress = progress.clamp(0.0, 1.0);
515
516        match &self.config.pacing_function {
517            PacingFunction::Linear => {
518                self.config.initial_data_percentage
519                    + (1.0 - self.config.initial_data_percentage) * clamped_progress
520            },
521            PacingFunction::Exponential { rate } => {
522                self.config.initial_data_percentage
523                    + (1.0 - self.config.initial_data_percentage)
524                        * (1.0 - (-rate * clamped_progress).exp())
525            },
526            PacingFunction::Logarithmic { base } => {
527                self.config.initial_data_percentage
528                    + (1.0 - self.config.initial_data_percentage) * (clamped_progress * base).ln()
529                        / base.ln()
530            },
531            PacingFunction::Sigmoid {
532                steepness,
533                midpoint,
534            } => {
535                let sigmoid = 1.0 / (1.0 + (-steepness * (clamped_progress - midpoint)).exp());
536                self.config.initial_data_percentage
537                    + (1.0 - self.config.initial_data_percentage) * sigmoid
538            },
539            PacingFunction::StepWise { steps } => {
540                let total_steps = self.step_counter;
541                for &(step_threshold, threshold_value) in steps {
542                    if total_steps <= step_threshold {
543                        return threshold_value;
544                    }
545                }
546                1.0 // If past all steps, use all data
547            },
548            PacingFunction::Polynomial { degree } => {
549                self.config.initial_data_percentage
550                    + (1.0 - self.config.initial_data_percentage) * clamped_progress.powf(*degree)
551            },
552            PacingFunction::Custom { .. } => {
553                // Custom function would be implemented here
554                self.apply_pacing_function_linear(clamped_progress)
555            },
556        }
557    }
558
559    /// Linear pacing function (fallback)
560    fn apply_pacing_function_linear(&self, progress: f32) -> f32 {
561        self.config.initial_data_percentage + (1.0 - self.config.initial_data_percentage) * progress
562    }
563
564    /// Compute model competence
565    fn compute_competence(&self) -> Result<f32> {
566        if self.performance_history.is_empty() {
567            return Ok(0.0);
568        }
569
570        let recent_performance = self.get_recent_performance();
571        Ok(recent_performance)
572    }
573
574    /// Get recent performance average
575    fn get_recent_performance(&self) -> f32 {
576        if self.performance_history.is_empty() {
577            return 0.0;
578        }
579
580        let window_size = 10.min(self.performance_history.len());
581        self.performance_history.iter().rev().take(window_size).sum::<f32>() / window_size as f32
582    }
583
584    /// Train one step with curriculum
585    pub fn train_step(&mut self) -> Result<CurriculumLearningOutput> {
586        // Update curriculum threshold
587        self.update_curriculum_threshold()?;
588
589        // Get current curriculum examples
590        let curriculum_examples = self.get_current_curriculum();
591
592        if curriculum_examples.is_empty() {
593            return Err(invalid_input(
594                "No examples available for training".to_string(),
595            ));
596        }
597
598        // Sample from curriculum
599        let example = &curriculum_examples[self.step_counter % curriculum_examples.len()];
600
601        // Compute forward pass and loss
602        let outputs = self.model.forward(example.input.clone())?;
603        let loss = self.compute_loss(&outputs, &example.target)?;
604
605        // Weight the loss by example weight
606        let weighted_loss = loss.scalar_mul(example.weight)?;
607
608        // Compute accuracy for performance tracking
609        let accuracy = self.compute_accuracy(&outputs, &example.target)?;
610        self.performance_history.push(accuracy);
611
612        // Keep performance history bounded
613        if self.performance_history.len() > 1000 {
614            self.performance_history = self.performance_history.split_off(500);
615        }
616
617        self.step_counter += 1;
618
619        Ok(CurriculumLearningOutput {
620            loss: weighted_loss,
621            accuracy,
622            difficulty_threshold: self.current_threshold,
623            examples_used: curriculum_examples.len(),
624            current_difficulty: example.difficulty,
625        })
626    }
627
628    /// Train for one epoch with curriculum
629    pub fn train_epoch(&mut self) -> Result<CurriculumEpochOutput> {
630        let mut total_loss = 0.0;
631        let mut total_accuracy = 0.0;
632        let mut num_steps = 0;
633
634        let curriculum_examples = self.get_current_curriculum();
635
636        for example in &curriculum_examples {
637            let outputs = self.model.forward(example.input.clone())?;
638            let loss = self.compute_loss(&outputs, &example.target)?;
639            let accuracy = self.compute_accuracy(&outputs, &example.target)?;
640
641            let loss_scalar = loss.to_scalar().map_err(|e| {
642                invalid_input(format!("Failed to convert loss tensor to scalar: {}", e))
643            })?;
644            total_loss += loss_scalar * example.weight;
645            total_accuracy += accuracy;
646            num_steps += 1;
647        }
648
649        self.current_epoch += 1;
650
651        Ok(CurriculumEpochOutput {
652            epoch: self.current_epoch,
653            average_loss: total_loss / num_steps as f32,
654            average_accuracy: total_accuracy / num_steps as f32,
655            difficulty_threshold: self.current_threshold,
656            examples_used: curriculum_examples.len(),
657            total_examples: self.examples.len(),
658        })
659    }
660
661    /// Compute cross-entropy loss
662    fn compute_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
663        self.compute_cross_entropy_loss(outputs, targets)
664    }
665
666    /// Compute accuracy
667    fn compute_accuracy(&self, outputs: &Tensor, targets: &Tensor) -> Result<f32> {
668        let predicted = self.compute_argmax(outputs)?;
669        let target_indices = self.compute_argmax(targets)?;
670
671        // Compute accuracy as fraction of correct predictions
672        let total_samples = predicted.len() as f32;
673        if total_samples == 0.0 {
674            return Ok(0.0);
675        }
676
677        let mut correct = 0.0;
678        for (pred, target) in predicted.iter().zip(target_indices.iter()) {
679            if (pred - target).abs() < f32::EPSILON {
680                correct += 1.0;
681            }
682        }
683
684        Ok(correct / total_samples)
685    }
686
687    /// Get curriculum statistics
688    pub fn get_curriculum_stats(&self) -> CurriculumStats {
689        let curriculum_examples = self.get_current_curriculum();
690        let difficulties: Vec<f32> = curriculum_examples.iter().map(|e| e.difficulty).collect();
691
692        let min_difficulty = difficulties.iter().fold(f32::INFINITY, |a, &b| a.min(b));
693        let max_difficulty = difficulties.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
694        let avg_difficulty = if !difficulties.is_empty() {
695            difficulties.iter().sum::<f32>() / difficulties.len() as f32
696        } else {
697            0.0
698        };
699
700        CurriculumStats {
701            current_threshold: self.current_threshold,
702            examples_in_curriculum: curriculum_examples.len(),
703            total_examples: self.examples.len(),
704            min_difficulty,
705            max_difficulty,
706            avg_difficulty,
707            epoch: self.current_epoch,
708            step: self.step_counter,
709        }
710    }
711
712    /// Compute maximum probability from softmax output
713    fn compute_max_probability(&self, probs: &Tensor) -> Result<f32> {
714        match probs {
715            Tensor::F32(arr) => {
716                // Find max probability across all dimensions
717                let max_val = arr.iter().fold(0.0f32, |acc, &x| acc.max(x));
718                Ok(max_val)
719            },
720            _ => {
721                Ok(0.5) // Default fallback
722            },
723        }
724    }
725
726    /// Compute cross-entropy loss between outputs and targets
727    fn compute_cross_entropy_loss(&self, outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
728        // Apply softmax to get probabilities
729        let probs = outputs.softmax(-1)?;
730
731        // Compute log probabilities for numerical stability
732        let log_probs = probs.log()?;
733
734        // Compute negative log likelihood based on target format
735        match (log_probs, targets) {
736            (Tensor::F32(log_prob_arr), Tensor::F32(target_arr)) => {
737                // Assuming targets are one-hot encoded or class indices
738                let batch_size = log_prob_arr.shape()[0];
739                let num_classes = log_prob_arr.shape().get(1).copied().ok_or_else(|| {
740                    invalid_input(format!(
741                        "Invalid tensor shape: expected at least 2 dimensions, got {}",
742                        log_prob_arr.shape().len()
743                    ))
744                })?;
745
746                let mut total_loss = 0.0f32;
747
748                for batch_idx in 0..batch_size {
749                    if target_arr.shape().len() == 1 {
750                        // Class indices format
751                        let target_class = target_arr[[batch_idx]] as usize;
752                        if target_class < num_classes {
753                            total_loss -= log_prob_arr[[batch_idx, target_class]];
754                        }
755                    } else if target_arr.shape().len() >= 2 && target_arr.shape()[1] == num_classes
756                    {
757                        // One-hot format
758                        for class_idx in 0..num_classes {
759                            let target_prob = target_arr[[batch_idx, class_idx]];
760                            if target_prob > 0.0 {
761                                total_loss -= target_prob * log_prob_arr[[batch_idx, class_idx]];
762                            }
763                        }
764                    }
765                }
766
767                // Return mean loss
768                let mean_loss = total_loss / batch_size as f32;
769                Ok(Tensor::scalar(mean_loss)?)
770            },
771            _ => {
772                // Fallback for unsupported tensor types
773                Ok(Tensor::scalar(1.0f32)?)
774            },
775        }
776    }
777
778    /// Compute argmax (indices of maximum values)
779    fn compute_argmax(&self, tensor: &Tensor) -> Result<Vec<f32>> {
780        match tensor {
781            Tensor::F32(arr) => {
782                let mut argmax_values = Vec::new();
783
784                // Handle different tensor shapes
785                if arr.ndim() == 1 {
786                    // 1D tensor - find single argmax
787                    let mut max_idx = 0;
788                    let mut max_val = arr[0];
789                    for (idx, &val) in arr.iter().enumerate() {
790                        if val > max_val {
791                            max_val = val;
792                            max_idx = idx;
793                        }
794                    }
795                    argmax_values.push(max_idx as f32);
796                } else if arr.ndim() == 2 {
797                    // 2D tensor - find argmax along last dimension for each batch
798                    let batch_size = arr.shape()[0];
799                    let num_classes = arr.shape()[1];
800
801                    for batch_idx in 0..batch_size {
802                        let mut max_idx = 0;
803                        let mut max_val = arr[[batch_idx, 0]];
804
805                        for class_idx in 1..num_classes {
806                            let val = arr[[batch_idx, class_idx]];
807                            if val > max_val {
808                                max_val = val;
809                                max_idx = class_idx;
810                            }
811                        }
812                        argmax_values.push(max_idx as f32);
813                    }
814                } else {
815                    // Multi-dimensional tensor - flatten and find global argmax
816                    let mut max_idx = 0;
817                    let mut max_val = arr.iter().next().copied().ok_or_else(|| {
818                        invalid_input("Cannot compute argmax on empty tensor".to_string())
819                    })?;
820
821                    for (idx, &val) in arr.iter().enumerate() {
822                        if val > max_val {
823                            max_val = val;
824                            max_idx = idx;
825                        }
826                    }
827                    argmax_values.push(max_idx as f32);
828                }
829
830                Ok(argmax_values)
831            },
832            _ => {
833                // Fallback for unsupported tensor types
834                Ok(vec![0.0])
835            },
836        }
837    }
838}
839
840/// Difficulty scorer for learned difficulty estimation
841pub struct DifficultyScorer {
842    /// Scoring method
843    #[allow(dead_code)]
844    method: DifficultyMeasure,
845}
846
847impl DifficultyScorer {
848    pub fn new(method: &DifficultyMeasure) -> Result<Self> {
849        Ok(Self {
850            method: method.clone(),
851        })
852    }
853
854    pub fn score_difficulty(&self, _input: &Tensor, _target: &Tensor) -> Result<f32> {
855        // Implement learned difficulty scoring
856        // This would typically involve a separate neural network
857        Ok(0.5) // Placeholder
858    }
859}
860
/// Output from a curriculum learning training step
#[derive(Debug, Clone)]
pub struct CurriculumLearningOutput {
    /// Loss tensor produced by the step.
    pub loss: Tensor,
    /// Accuracy reported for the step (presumably fraction of correct predictions — TODO confirm).
    pub accuracy: f32,
    /// Difficulty threshold in effect for the step (presumably the curriculum cut-off — confirm
    /// against the trainer).
    pub difficulty_threshold: f32,
    /// Number of examples used in the step.
    pub examples_used: usize,
    /// Difficulty value reported for the step; its relation to `difficulty_threshold` is not
    /// visible here — TODO confirm.
    pub current_difficulty: f32,
}
870
/// Output from a curriculum learning epoch.
///
/// All fields are plain numbers, so the type also derives `PartialEq` for easy comparison
/// in tests and logs.
#[derive(Debug, Clone, PartialEq)]
pub struct CurriculumEpochOutput {
    /// Epoch this summary refers to (0- or 1-based is decided by the trainer — TODO confirm).
    pub epoch: usize,
    /// Average loss reported for the epoch.
    pub average_loss: f32,
    /// Average accuracy reported for the epoch.
    pub average_accuracy: f32,
    /// Difficulty threshold in effect for the epoch.
    pub difficulty_threshold: f32,
    /// Number of examples used during the epoch (presumably those admitted by the threshold).
    pub examples_used: usize,
    /// Total number of examples available.
    pub total_examples: usize,
}
881
/// Curriculum learning statistics: a snapshot of curriculum progress.
///
/// All fields are plain numbers, so the type also derives `PartialEq`.
#[derive(Debug, Clone, PartialEq)]
pub struct CurriculumStats {
    /// Difficulty threshold currently in effect.
    pub current_threshold: f32,
    /// Number of examples currently admitted by the curriculum
    /// (presumably those within `current_threshold` — TODO confirm).
    pub examples_in_curriculum: usize,
    /// Total number of examples available.
    pub total_examples: usize,
    /// Smallest difficulty score (presumably over all examples — confirm).
    pub min_difficulty: f32,
    /// Largest difficulty score.
    pub max_difficulty: f32,
    /// Mean difficulty score.
    pub avg_difficulty: f32,
    /// Current epoch counter.
    pub epoch: usize,
    /// Current step counter.
    pub step: usize,
}
894
895/// Utilities for curriculum learning
896pub mod utils {
897    use super::*;
898
899    /// Create a self-paced learning configuration
900    pub fn self_paced_config(lambda: f32, gamma: f32) -> CurriculumConfig {
901        CurriculumConfig {
902            strategy: CurriculumStrategy::SelfPaced { lambda, gamma },
903            ..Default::default()
904        }
905    }
906
907    /// Create a competence-based curriculum configuration
908    pub fn competence_based_config(threshold: f32, increase_rate: f32) -> CurriculumConfig {
909        CurriculumConfig {
910            strategy: CurriculumStrategy::CompetenceBased {
911                competence_threshold: threshold,
912                increase_rate,
913            },
914            ..Default::default()
915        }
916    }
917
918    /// Create a baby steps curriculum configuration
919    pub fn baby_steps_config(step_size: f32, patience: usize) -> CurriculumConfig {
920        CurriculumConfig {
921            strategy: CurriculumStrategy::BabySteps {
922                step_size,
923                patience,
924            },
925            pacing_function: PacingFunction::Linear,
926            ..Default::default()
927        }
928    }
929
930    /// Create a predefined curriculum configuration
931    pub fn predefined_config(
932        difficulty_levels: Vec<f32>,
933        level_durations: Vec<usize>,
934    ) -> CurriculumConfig {
935        CurriculumConfig {
936            strategy: CurriculumStrategy::Predefined {
937                difficulty_levels,
938                level_durations,
939            },
940            ..Default::default()
941        }
942    }
943
944    /// Create an anti-curriculum configuration (hard to easy)
945    pub fn anti_curriculum_config() -> CurriculumConfig {
946        CurriculumConfig {
947            strategy: CurriculumStrategy::AntiCurriculum {
948                reverse_pacing: true,
949            },
950            ..Default::default()
951        }
952    }
953
954    /// Create a cyclical curriculum configuration
955    pub fn cyclical_config(cycle_length: usize, num_cycles: usize) -> CurriculumConfig {
956        CurriculumConfig {
957            strategy: CurriculumStrategy::Cyclical {
958                cycle_length,
959                num_cycles,
960            },
961            ..Default::default()
962        }
963    }
964
965    /// Create examples with length-based difficulty
966    pub fn create_length_based_examples(
967        inputs: Vec<Tensor>,
968        targets: Vec<Tensor>,
969    ) -> Vec<CurriculumExample> {
970        inputs
971            .into_iter()
972            .zip(targets)
973            .map(|(input, target)| {
974                let length = input.shape()[1] as f32; // Assuming [batch, seq_len, ...]
975                let difficulty = (length / 512.0).min(1.0); // Normalize by max length
976                CurriculumExample::new(input, target, difficulty)
977            })
978            .collect()
979    }
980
981    /// Create examples with loss-based difficulty
982    pub fn create_loss_based_examples<M: Model<Input = Tensor, Output = Tensor>>(
983        model: &M,
984        inputs: Vec<Tensor>,
985        targets: Vec<Tensor>,
986    ) -> Result<Vec<CurriculumExample>> {
987        let mut examples = Vec::new();
988
989        for (input, target) in inputs.into_iter().zip(targets) {
990            let outputs = model.forward(input.clone())?;
991            // Use a simple cross-entropy loss calculation without trainer for difficulty estimation
992            let loss = simple_cross_entropy_loss(&outputs, &target)?;
993            let difficulty = loss.to_scalar().map_err(|e| {
994                invalid_input(format!(
995                    "Failed to convert loss tensor to scalar for difficulty estimation: {}",
996                    e
997                ))
998            })?;
999
1000            examples.push(CurriculumExample::new(input, target, difficulty));
1001        }
1002
1003        Ok(examples)
1004    }
1005
1006    /// Simple cross-entropy loss computation for difficulty estimation
1007    fn simple_cross_entropy_loss(outputs: &Tensor, targets: &Tensor) -> Result<Tensor> {
1008        // Apply softmax to get probabilities
1009        let probs = outputs.softmax(-1)?;
1010
1011        // Simple cross-entropy: -log(p_target)
1012        // This is a simplified version for difficulty estimation
1013        match targets.data() {
1014            Ok(target_data) => {
1015                if let Ok(prob_data) = probs.data() {
1016                    let batch_size = targets.shape()[0];
1017                    let mut total_loss = 0.0f32;
1018
1019                    for i in 0..batch_size {
1020                        let target_idx = target_data[i] as usize;
1021                        if target_idx < prob_data.len() {
1022                            let prob = prob_data[target_idx].max(1e-8); // Avoid log(0)
1023                            total_loss += -prob.ln();
1024                        }
1025                    }
1026
1027                    let mean_loss = total_loss / batch_size as f32;
1028                    Ok(Tensor::scalar(mean_loss)?)
1029                } else {
1030                    Ok(Tensor::scalar(1.0f32)?)
1031                }
1032            },
1033            Err(_) => Ok(Tensor::scalar(1.0f32)?),
1034        }
1035    }
1036
1037    /// Create examples with manual difficulty scores
1038    pub fn create_manual_examples(
1039        inputs: Vec<Tensor>,
1040        targets: Vec<Tensor>,
1041        difficulties: Vec<f32>,
1042    ) -> Result<Vec<CurriculumExample>> {
1043        if inputs.len() != targets.len() || inputs.len() != difficulties.len() {
1044            return Err(invalid_input("Mismatched array lengths".to_string()));
1045        }
1046
1047        Ok(inputs
1048            .into_iter()
1049            .zip(targets)
1050            .zip(difficulties)
1051            .map(|((input, target), difficulty)| CurriculumExample::new(input, target, difficulty))
1052            .collect())
1053    }
1054
1055    /// Analyze curriculum effectiveness
1056    pub fn analyze_curriculum_effectiveness(
1057        baseline_accuracies: &[f32],
1058        curriculum_accuracies: &[f32],
1059    ) -> CurriculumAnalysis {
1060        // Use 0.0 as default for empty accuracy arrays (no training = no accuracy)
1061        let baseline_final = baseline_accuracies.last().copied().unwrap_or_else(|| {
1062            eprintln!("Warning: Empty baseline accuracies array, using 0.0");
1063            0.0
1064        });
1065        let curriculum_final = curriculum_accuracies.last().copied().unwrap_or_else(|| {
1066            eprintln!("Warning: Empty curriculum accuracies array, using 0.0");
1067            0.0
1068        });
1069
1070        let improvement = curriculum_final - baseline_final;
1071
1072        // Compute area under the curve for convergence speed
1073        let baseline_auc = baseline_accuracies.iter().sum::<f32>();
1074        let curriculum_auc = curriculum_accuracies.iter().sum::<f32>();
1075        let convergence_speedup = curriculum_auc / baseline_auc.max(1e-8);
1076
1077        CurriculumAnalysis {
1078            final_accuracy_improvement: improvement,
1079            convergence_speedup,
1080            baseline_final_accuracy: baseline_final,
1081            curriculum_final_accuracy: curriculum_final,
1082        }
1083    }
1084}
1085
/// Analysis of curriculum learning effectiveness, as produced by
/// `utils::analyze_curriculum_effectiveness`.
///
/// All fields are plain numbers, so the type also derives `PartialEq`.
#[derive(Debug, Clone, PartialEq)]
pub struct CurriculumAnalysis {
    /// Curriculum final accuracy minus baseline final accuracy (positive: curriculum ended higher).
    pub final_accuracy_improvement: f32,
    /// Ratio of summed curriculum accuracies to summed baseline accuracies; above `1.0`
    /// means more accuracy was accumulated over the recorded evaluations.
    pub convergence_speedup: f32,
    /// Last recorded baseline accuracy.
    pub baseline_final_accuracy: f32,
    /// Last recorded curriculum accuracy.
    pub curriculum_final_accuracy: f32,
}
1094
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-filled tensor of `shape`; setup failures panic with the shared message.
    fn zeros(shape: &[usize]) -> Tensor {
        Tensor::zeros(shape).expect("operation failed")
    }

    /// One-filled tensor of `shape`; setup failures panic with the shared message.
    fn ones(shape: &[usize]) -> Tensor {
        Tensor::ones(shape).expect("operation failed")
    }

    /// Approximate float comparison used where rounding makes exact equality brittle.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn test_curriculum_config_default() {
        let config = CurriculumConfig::default();
        assert_eq!(config.initial_data_percentage, 0.1);
        assert!(config.use_throughout_training);
        assert!(config.shuffle_easy_examples);

        match config.strategy {
            CurriculumStrategy::SelfPaced { lambda, gamma } => {
                assert_eq!(lambda, 0.5);
                assert_eq!(gamma, 1.1);
            },
            _ => panic!("Expected SelfPaced strategy"),
        }
    }

    #[test]
    fn test_curriculum_example() {
        let example = CurriculumExample::new(zeros(&[1, 10]), zeros(&[1]), 0.5);

        assert_eq!(example.difficulty, 0.5);
        assert_eq!(example.weight, 1.0);
        assert!(example.metadata.is_empty());
    }

    #[test]
    fn test_curriculum_example_with_metadata() {
        let (input, target) = (zeros(&[1, 10]), zeros(&[1]));
        let mut metadata = HashMap::new();
        metadata.insert("source".to_string(), "test".to_string());

        let example = CurriculumExample::with_metadata(input, target, 0.7, metadata);
        assert_eq!(example.difficulty, 0.7);
        assert_eq!(
            example.metadata.get("source").expect("operation failed"),
            "test"
        );
    }

    #[test]
    fn test_curriculum_example_with_weight() {
        let example = CurriculumExample::new(zeros(&[1, 10]), zeros(&[1]), 0.3).with_weight(2.0);

        assert_eq!(example.difficulty, 0.3);
        assert_eq!(example.weight, 2.0);
    }

    #[test]
    fn test_self_paced_config() {
        match utils::self_paced_config(0.8, 1.2).strategy {
            CurriculumStrategy::SelfPaced { lambda, gamma } => {
                assert_eq!(lambda, 0.8);
                assert_eq!(gamma, 1.2);
            },
            _ => panic!("Expected SelfPaced strategy"),
        }
    }

    #[test]
    fn test_competence_based_config() {
        match utils::competence_based_config(0.85, 0.1).strategy {
            CurriculumStrategy::CompetenceBased {
                competence_threshold,
                increase_rate,
            } => {
                assert_eq!(competence_threshold, 0.85);
                assert_eq!(increase_rate, 0.1);
            },
            _ => panic!("Expected CompetenceBased strategy"),
        }
    }

    #[test]
    fn test_baby_steps_config() {
        match utils::baby_steps_config(0.05, 5).strategy {
            CurriculumStrategy::BabySteps {
                step_size,
                patience,
            } => {
                assert_eq!(step_size, 0.05);
                assert_eq!(patience, 5);
            },
            _ => panic!("Expected BabySteps strategy"),
        }
    }

    #[test]
    fn test_predefined_config() {
        let levels = vec![0.2, 0.5, 0.8, 1.0];
        let durations = vec![1000, 1500, 2000, 2500];

        match utils::predefined_config(levels.clone(), durations.clone()).strategy {
            CurriculumStrategy::Predefined {
                difficulty_levels,
                level_durations,
            } => {
                assert_eq!(difficulty_levels, levels);
                assert_eq!(level_durations, durations);
            },
            _ => panic!("Expected Predefined strategy"),
        }
    }

    #[test]
    fn test_anti_curriculum_config() {
        match utils::anti_curriculum_config().strategy {
            CurriculumStrategy::AntiCurriculum { reverse_pacing } => assert!(reverse_pacing),
            _ => panic!("Expected AntiCurriculum strategy"),
        }
    }

    #[test]
    fn test_cyclical_config() {
        match utils::cyclical_config(1000, 3).strategy {
            CurriculumStrategy::Cyclical {
                cycle_length,
                num_cycles,
            } => {
                assert_eq!(cycle_length, 1000);
                assert_eq!(num_cycles, 3);
            },
            _ => panic!("Expected Cyclical strategy"),
        }
    }

    #[test]
    fn test_create_manual_examples() {
        let inputs = vec![zeros(&[1, 10]), ones(&[1, 10])];
        let targets = vec![zeros(&[1]), ones(&[1])];

        let examples = utils::create_manual_examples(inputs, targets, vec![0.2, 0.8])
            .expect("operation failed");
        let difficulties: Vec<f32> = examples.iter().map(|example| example.difficulty).collect();
        assert_eq!(difficulties, vec![0.2, 0.8]);
    }

    #[test]
    fn test_create_manual_examples_mismatched_lengths() {
        // One input/target pair but two difficulty scores.
        let result =
            utils::create_manual_examples(vec![zeros(&[1, 10])], vec![zeros(&[1])], vec![0.2, 0.8]);
        assert!(result.is_err());
    }

    #[test]
    fn test_curriculum_analysis() {
        let baseline = [0.6, 0.7, 0.75, 0.8];
        let curriculum = [0.7, 0.8, 0.85, 0.9];

        let analysis = utils::analyze_curriculum_effectiveness(&baseline, &curriculum);
        assert!(approx_eq(analysis.final_accuracy_improvement, 0.1)); // 0.9 - 0.8
        assert!(approx_eq(analysis.baseline_final_accuracy, 0.8));
        assert!(approx_eq(analysis.curriculum_final_accuracy, 0.9));
        assert!(analysis.convergence_speedup > 1.0);
    }
}