tensorlogic_train/curriculum.rs

//! Curriculum learning strategies for progressive training.
//!
//! This module provides various curriculum learning strategies that gradually increase
//! training difficulty:
//! - Sample-level curriculum (difficulty scoring)
//! - Competence-based pacing (adaptive difficulty)
//! - Self-paced learning
//! - Task-level curriculum (multi-task progressive training)
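//!
//! A minimal end-to-end sketch (illustrative only: the array values and the
//! 10-epoch horizon are made-up assumptions):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//!
//! // Wrap a strategy in a manager, score the samples, then ask which
//! // sample indices to train on at the current epoch.
//! let mut manager = CurriculumManager::new(LinearCurriculum::new(0.2)?);
//! let data = array![[1.0, 2.0], [3.0, 4.0]];
//! let labels = array![[1.0, 0.0], [0.0, 1.0]];
//! manager.compute_difficulty("train", &data, &labels, None)?;
//! manager.set_epoch(0);
//! let selected = manager.get_selected_samples("train", 10)?;
//! ```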

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::collections::HashMap;

/// Trait for curriculum learning strategies.
pub trait CurriculumStrategy {
    /// Get the subset of samples to use for the current training step.
    ///
    /// # Arguments
    /// * `epoch` - Current training epoch
    /// * `total_epochs` - Total number of training epochs
    /// * `difficulties` - Difficulty scores for each sample `[N]`
    ///
    /// # Returns
    /// Indices of samples to include in training at this stage
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>>;

    /// Compute difficulty scores for training samples.
    ///
    /// # Arguments
    /// * `data` - Training data `[N, features]`
    /// * `labels` - Training labels `[N, classes]`
    /// * `predictions` - Model predictions `[N, classes]` (optional, for adaptive strategies)
    ///
    /// # Returns
    /// Difficulty score for each sample `[N]` (higher = more difficult)
    fn compute_difficulty(
        &self,
        data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>>;
}

/// Linear curriculum: gradually increase the percentage of samples used.
///
/// Starts with a small percentage of easiest samples and linearly increases
/// to use all samples by the end of training.
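///
/// A small sketch of the schedule (illustrative; `epoch` and `difficulties` are
/// assumed to come from the surrounding training loop):
///
/// ```ignore
/// // fraction(epoch) = start + (1 - start) * epoch / (total_epochs - 1),
/// // so with start_percentage = 0.2 over 10 epochs the fraction grows
/// // linearly from 0.2 at epoch 0 to 1.0 at epoch 9.
/// let curriculum = LinearCurriculum::new(0.2)?;
/// let selected = curriculum.select_samples(epoch, 10, &difficulties.view())?;
/// ```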
#[derive(Debug, Clone)]
pub struct LinearCurriculum {
    /// Initial percentage of samples to use (0.0 to 1.0).
    pub start_percentage: f64,
    /// Whether to sort samples by difficulty (true) or take them in their original order (false).
    pub sort_by_difficulty: bool,
}

impl LinearCurriculum {
    /// Create a new linear curriculum.
    ///
    /// # Arguments
    /// * `start_percentage` - Initial percentage of samples (e.g., 0.1 for 10%)
    pub fn new(start_percentage: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&start_percentage) {
            return Err(TrainError::InvalidParameter(
                "start_percentage must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self {
            start_percentage,
            sort_by_difficulty: true,
        })
    }

    /// Disable sorting by difficulty (take the first samples in their original order instead).
    pub fn without_sorting(mut self) -> Self {
        self.sort_by_difficulty = false;
        self
    }
}

impl Default for LinearCurriculum {
    fn default() -> Self {
        Self {
            start_percentage: 0.2,
            sort_by_difficulty: true,
        }
    }
}

impl CurriculumStrategy for LinearCurriculum {
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        let n = difficulties.len();
        if n == 0 {
            return Ok(Vec::new());
        }

        // Compute current percentage (linear interpolation)
        let progress = if total_epochs > 1 {
            epoch as f64 / (total_epochs - 1) as f64
        } else {
            1.0
        };
        let current_percentage = self.start_percentage + (1.0 - self.start_percentage) * progress;
        let num_samples = ((n as f64 * current_percentage).ceil() as usize).min(n);

        if !self.sort_by_difficulty {
            // Return first num_samples indices
            return Ok((0..num_samples).collect());
        }

        // Sort by difficulty (ascending) and select easiest samples
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&a, &b| {
            difficulties[a]
                .partial_cmp(&difficulties[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(indices.into_iter().take(num_samples).collect())
    }

    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        // Default: use prediction entropy as difficulty
        // If no predictions provided, use zeros (all equal difficulty)
        if let Some(preds) = predictions {
            let n = preds.nrows();
            let mut difficulties = Array1::zeros(n);

            for i in 0..n {
                let pred = preds.row(i);
                // Compute entropy: -Σ p_i log(p_i)
                let mut entropy = 0.0;
                for &p in pred.iter() {
                    if p > 1e-10 {
                        entropy -= p * p.ln();
                    }
                }
                difficulties[i] = entropy;
            }

            Ok(difficulties)
        } else {
            // No predictions provided, assume all equal difficulty
            Ok(Array1::zeros(labels.nrows()))
        }
    }
}

/// Exponential curriculum: exponentially increase sample percentage.
///
/// Uses an exponential schedule to quickly ramp up the number of training samples.
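///
/// A small sketch of the schedule (illustrative; `epoch`, `total_epochs`, and
/// `difficulties` are assumed placeholders):
///
/// ```ignore
/// // fraction(t) = min(1.0, start_percentage * exp(growth_rate * t)),
/// // where t = epoch / (total_epochs - 1). With start 0.1 and growth 2.0,
/// // the fraction reaches 0.1 * e^2 ≈ 0.74 by the final epoch.
/// let curriculum = ExponentialCurriculum::new(0.1, 2.0)?;
/// let selected = curriculum.select_samples(epoch, total_epochs, &difficulties.view())?;
/// ```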
#[derive(Debug, Clone)]
pub struct ExponentialCurriculum {
    /// Initial percentage of samples.
    pub start_percentage: f64,
    /// Exponential growth rate (higher = faster growth).
    pub growth_rate: f64,
}

impl ExponentialCurriculum {
    /// Create a new exponential curriculum.
    ///
    /// # Arguments
    /// * `start_percentage` - Initial percentage of samples
    /// * `growth_rate` - Exponential growth rate (e.g., 2.0 multiplies the starting fraction by e^2 ≈ 7.4 over the full run, capped at 100% of the samples)
    pub fn new(start_percentage: f64, growth_rate: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&start_percentage) {
            return Err(TrainError::InvalidParameter(
                "start_percentage must be in [0, 1]".to_string(),
            ));
        }
        if growth_rate <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "growth_rate must be positive".to_string(),
            ));
        }
        Ok(Self {
            start_percentage,
            growth_rate,
        })
    }
}

impl Default for ExponentialCurriculum {
    fn default() -> Self {
        Self {
            start_percentage: 0.1,
            growth_rate: 2.0,
        }
    }
}

impl CurriculumStrategy for ExponentialCurriculum {
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        let n = difficulties.len();
        if n == 0 {
            return Ok(Vec::new());
        }

        // Exponential growth: p(t) = start * exp(growth * t)
        let progress = if total_epochs > 1 {
            epoch as f64 / (total_epochs - 1) as f64
        } else {
            1.0
        };
        let current_percentage =
            (self.start_percentage * (self.growth_rate * progress).exp()).min(1.0);
        let num_samples = ((n as f64 * current_percentage).ceil() as usize).min(n);

        // Sort by difficulty and select easiest samples
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&a, &b| {
            difficulties[a]
                .partial_cmp(&difficulties[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(indices.into_iter().take(num_samples).collect())
    }

    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        // Same as LinearCurriculum: prediction entropy, or zeros without predictions
        if let Some(preds) = predictions {
            let n = preds.nrows();
            let mut difficulties = Array1::zeros(n);

            for i in 0..n {
                let pred = preds.row(i);
                let mut entropy = 0.0;
                for &p in pred.iter() {
                    if p > 1e-10 {
                        entropy -= p * p.ln();
                    }
                }
                difficulties[i] = entropy;
            }

            Ok(difficulties)
        } else {
            Ok(Array1::zeros(labels.nrows()))
        }
    }
}

/// Self-paced learning: model determines its own learning pace.
///
/// Adaptively selects samples based on current model performance,
/// prioritizing samples the model is ready to learn from.
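///
/// A small sketch (illustrative; `data`, `labels`, and `predictions` are assumed
/// to be provided by the surrounding training loop):
///
/// ```ignore
/// // Difficulty is lambda * per-sample cross-entropy; only samples whose
/// // weighted loss falls below `threshold` are selected.
/// let curriculum = SelfPacedCurriculum::new(1.0, 0.5)?;
/// let difficulties = curriculum.compute_difficulty(&data, &labels, Some(&predictions))?;
/// let selected = curriculum.select_samples(epoch, total_epochs, &difficulties.view())?;
/// ```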
#[derive(Debug, Clone)]
pub struct SelfPacedCurriculum {
    /// Self-pacing weight applied to each sample's loss when scoring difficulty.
    pub lambda: f64,
    /// Loss threshold below which samples are selected for training.
    pub threshold: f64,
}

impl SelfPacedCurriculum {
    /// Create a new self-paced curriculum.
    ///
    /// # Arguments
    /// * `lambda` - Self-pacing weight applied to per-sample loss (must be positive)
    /// * `threshold` - Loss threshold below which samples are selected
    pub fn new(lambda: f64, threshold: f64) -> TrainResult<Self> {
        if lambda <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "lambda must be positive".to_string(),
            ));
        }
        Ok(Self { lambda, threshold })
    }
}

impl Default for SelfPacedCurriculum {
    fn default() -> Self {
        Self {
            lambda: 1.0,
            threshold: 0.5,
        }
    }
}

impl CurriculumStrategy for SelfPacedCurriculum {
    fn select_samples(
        &self,
        _epoch: usize,
        _total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        // Select samples with difficulty below threshold
        let indices: Vec<usize> = difficulties
            .iter()
            .enumerate()
            .filter(|(_, &d)| d < self.threshold)
            .map(|(i, _)| i)
            .collect();

        Ok(indices)
    }

    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        if let Some(preds) = predictions {
            let n = preds.nrows();
            let mut difficulties = Array1::zeros(n);

            for i in 0..n {
                // Compute loss for each sample (cross-entropy)
                let pred = preds.row(i);
                let label = labels.row(i);

                let mut loss = 0.0;
                for j in 0..pred.len() {
                    let p = pred[j].clamp(1e-10, 1.0 - 1e-10);
                    loss -= label[j] * p.ln();
                }

                // Weight by self-pacing parameter
                difficulties[i] = loss * self.lambda;
            }

            Ok(difficulties)
        } else {
            Err(TrainError::InvalidParameter(
                "SelfPacedCurriculum requires predictions for difficulty computation".to_string(),
            ))
        }
    }
}

/// Competence-based curriculum: limits sample difficulty to the model's current competence level.
///
/// Competence follows a per-epoch schedule, so more difficult samples are admitted
/// gradually as training progresses.
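///
/// A small sketch of the pacing rule (illustrative; `epoch`, `total_epochs`, and
/// `difficulties` are assumed placeholders):
///
/// ```ignore
/// // competence(epoch) = min(max_competence, initial_competence + growth_rate * epoch);
/// // samples whose normalized difficulty is <= competence are admitted.
/// // With initial 0.3 and growth 0.05, competence reaches 1.0 at epoch 14.
/// let curriculum = CompetenceCurriculum::new(0.3, 0.05)?;
/// let selected = curriculum.select_samples(epoch, total_epochs, &difficulties.view())?;
/// ```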
#[derive(Debug, Clone)]
pub struct CompetenceCurriculum {
    /// Initial competence level (0.0 to 1.0).
    pub initial_competence: f64,
    /// Competence growth rate per epoch.
    pub growth_rate: f64,
    /// Maximum competence level.
    pub max_competence: f64,
}

impl CompetenceCurriculum {
    /// Create a new competence-based curriculum.
    ///
    /// # Arguments
    /// * `initial_competence` - Starting competence level
    /// * `growth_rate` - How fast competence grows per epoch
    pub fn new(initial_competence: f64, growth_rate: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&initial_competence) {
            return Err(TrainError::InvalidParameter(
                "initial_competence must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self {
            initial_competence,
            growth_rate,
            max_competence: 1.0,
        })
    }
}

impl Default for CompetenceCurriculum {
    fn default() -> Self {
        Self {
            initial_competence: 0.3,
            growth_rate: 0.05,
            max_competence: 1.0,
        }
    }
}

impl CurriculumStrategy for CompetenceCurriculum {
    fn select_samples(
        &self,
        epoch: usize,
        _total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        // Current competence level
        let competence =
            (self.initial_competence + self.growth_rate * epoch as f64).min(self.max_competence);

        // Select samples with difficulty <= competence
        let indices: Vec<usize> = difficulties
            .iter()
            .enumerate()
            .filter(|(_, &d)| d <= competence)
            .map(|(i, _)| i)
            .collect();

        Ok(indices)
    }

    fn compute_difficulty(
        &self,
        _data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        // Prediction entropy, normalized to [0, 1]
        if let Some(preds) = predictions {
            let n = preds.nrows();
            let mut difficulties = Array1::zeros(n);

            for i in 0..n {
                let pred = preds.row(i);
                let mut entropy = 0.0;
                for &p in pred.iter() {
                    if p > 1e-10 {
                        entropy -= p * p.ln();
                    }
                }
                difficulties[i] = entropy;
            }

            // Normalize to [0, 1]
            let max_difficulty = difficulties.iter().cloned().fold(0.0f64, f64::max);
            if max_difficulty > 0.0 {
                difficulties.mapv_inplace(|d| d / max_difficulty);
            }

            Ok(difficulties)
        } else {
            Ok(Array1::zeros(labels.nrows()))
        }
    }
}

/// Task-level curriculum for multi-task learning.
///
/// Progressively introduces different tasks during training.
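///
/// A small sketch (illustrative schedule; the task IDs are placeholders):
///
/// ```ignore
/// // Task 0 trains from epoch 0, task 1 joins at epoch 5, task 2 at epoch 10.
/// let curriculum = TaskCurriculum::new(vec![(0, 0), (5, 1), (10, 2)]);
/// let active = curriculum.get_active_tasks(7); // [0, 1]
/// ```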
#[derive(Debug, Clone)]
pub struct TaskCurriculum {
    /// Task schedule: (start_epoch, task_id) pairs.
    task_schedule: Vec<(usize, usize)>,
}

impl TaskCurriculum {
    /// Create a new task curriculum.
    ///
    /// # Arguments
    /// * `schedule` - Task introduction schedule [(epoch, task_id)]
    pub fn new(schedule: Vec<(usize, usize)>) -> Self {
        let mut sorted_schedule = schedule;
        sorted_schedule.sort_by_key(|(epoch, _)| *epoch);
        Self {
            task_schedule: sorted_schedule,
        }
    }

    /// Get active tasks for the current epoch.
    ///
    /// # Arguments
    /// * `epoch` - Current training epoch
    ///
    /// # Returns
    /// IDs of all tasks whose start epoch is at or before `epoch`
    pub fn get_active_tasks(&self, epoch: usize) -> Vec<usize> {
        self.task_schedule
            .iter()
            .filter(|(start_epoch, _)| *start_epoch <= epoch)
            .map(|(_, task_id)| *task_id)
            .collect()
    }
}

impl Default for TaskCurriculum {
    fn default() -> Self {
        // Default: single task from epoch 0
        Self {
            task_schedule: vec![(0, 0)],
        }
    }
}

/// Manager for curriculum learning that tracks training progress.
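///
/// A minimal usage sketch (illustrative; `data`, `labels`, `predictions`, and the
/// 10-epoch horizon are assumed placeholders):
///
/// ```ignore
/// let mut manager = CurriculumManager::new(CompetenceCurriculum::default());
/// // Score the samples (re-run whenever predictions are refreshed), keyed by
/// // dataset name, then query the indices to train on at the current epoch.
/// manager.compute_difficulty("train", &data, &labels, Some(&predictions))?;
/// manager.set_epoch(epoch);
/// let indices = manager.get_selected_samples("train", 10)?;
/// ```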
pub struct CurriculumManager {
    strategy: Box<dyn CurriculumStrategyClone>,
    difficulty_cache: HashMap<String, Array1<f64>>,
    current_epoch: usize,
}

impl std::fmt::Debug for CurriculumManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CurriculumManager")
            .field("current_epoch", &self.current_epoch)
            .field("num_cached_difficulties", &self.difficulty_cache.len())
            .finish()
    }
}

/// Helper trait for cloning curriculum strategies.
trait CurriculumStrategyClone: CurriculumStrategy {
    fn clone_box(&self) -> Box<dyn CurriculumStrategyClone>;
}

impl<T: CurriculumStrategy + Clone + 'static> CurriculumStrategyClone for T {
    fn clone_box(&self) -> Box<dyn CurriculumStrategyClone> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn CurriculumStrategyClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

impl CurriculumStrategy for Box<dyn CurriculumStrategyClone> {
    fn select_samples(
        &self,
        epoch: usize,
        total_epochs: usize,
        difficulties: &ArrayView1<f64>,
    ) -> TrainResult<Vec<usize>> {
        (**self).select_samples(epoch, total_epochs, difficulties)
    }

    fn compute_difficulty(
        &self,
        data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<Array1<f64>> {
        (**self).compute_difficulty(data, labels, predictions)
    }
}

impl CurriculumManager {
    /// Create a new curriculum manager.
    ///
    /// # Arguments
    /// * `strategy` - Curriculum learning strategy
    pub fn new<S: CurriculumStrategy + Clone + 'static>(strategy: S) -> Self {
        Self {
            strategy: Box::new(strategy),
            difficulty_cache: HashMap::new(),
            current_epoch: 0,
        }
    }

    /// Update the current epoch.
    pub fn set_epoch(&mut self, epoch: usize) {
        self.current_epoch = epoch;
    }

    /// Compute and cache difficulty scores.
    ///
    /// # Arguments
    /// * `key` - Cache key (e.g., "train", "val")
    /// * `data` - Training data
    /// * `labels` - Training labels
    /// * `predictions` - Optional model predictions
    pub fn compute_difficulty(
        &mut self,
        key: &str,
        data: &Array2<f64>,
        labels: &Array2<f64>,
        predictions: Option<&Array2<f64>>,
    ) -> TrainResult<()> {
        let difficulties = self
            .strategy
            .compute_difficulty(data, labels, predictions)?;
        self.difficulty_cache.insert(key.to_string(), difficulties);
        Ok(())
    }

    /// Get selected sample indices for training.
    ///
    /// # Arguments
    /// * `key` - Cache key for difficulty scores
    /// * `total_epochs` - Total number of training epochs
    ///
    /// # Returns
    /// Indices of samples to use for current epoch
    pub fn get_selected_samples(&self, key: &str, total_epochs: usize) -> TrainResult<Vec<usize>> {
        let difficulties = self.difficulty_cache.get(key).ok_or_else(|| {
            TrainError::InvalidParameter(format!("No difficulty scores cached for key: {}", key))
        })?;

        self.strategy
            .select_samples(self.current_epoch, total_epochs, &difficulties.view())
    }

    /// Clear the difficulty cache.
    pub fn clear_cache(&mut self) {
        self.difficulty_cache.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_linear_curriculum() {
        let curriculum = LinearCurriculum::new(0.2).unwrap();
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];

        // At epoch 0, should select 20% of samples (1 sample)
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected.len(), 1);

        // At epoch 9 (last epoch), should select all samples
        let selected = curriculum
            .select_samples(9, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected.len(), 5);
    }

    #[test]
    fn test_linear_curriculum_invalid() {
        assert!(LinearCurriculum::new(-0.1).is_err());
        assert!(LinearCurriculum::new(1.5).is_err());
    }

    #[test]
    fn test_exponential_curriculum() {
        let curriculum = ExponentialCurriculum::new(0.1, 2.0).unwrap();
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];

        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert!(!selected.is_empty());

        let selected = curriculum
            .select_samples(9, 10, &difficulties.view())
            .unwrap();
        // Should select most/all samples at the end (exponential growth may round differently)
        assert!(selected.len() >= 4);
    }

    #[test]
    fn test_self_paced_curriculum() {
        let curriculum = SelfPacedCurriculum::new(1.0, 0.5).unwrap();
        let difficulties = array![0.1, 0.6, 0.3, 0.9, 0.2];

        // Should select samples with difficulty < 0.5
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected.len(), 3); // indices 0, 2, 4
    }

    #[test]
    fn test_competence_curriculum() {
        let curriculum = CompetenceCurriculum::new(0.3, 0.1).unwrap();
        let difficulties = array![0.1, 0.5, 0.3, 0.9, 0.2];

        // At epoch 0, competence = 0.3, should select difficulties <= 0.3
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected.len(), 3); // indices 0, 2, 4

        // At epoch 5, competence = 0.8, should select more samples
        let selected = curriculum
            .select_samples(5, 10, &difficulties.view())
            .unwrap();
        assert!(selected.len() >= 3);
    }

    #[test]
    fn test_task_curriculum() {
        let curriculum = TaskCurriculum::new(vec![(0, 0), (5, 1), (10, 2)]);

        let tasks = curriculum.get_active_tasks(0);
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0], 0);

        let tasks = curriculum.get_active_tasks(7);
        assert_eq!(tasks.len(), 2);
        assert!(tasks.contains(&0));
        assert!(tasks.contains(&1));

        let tasks = curriculum.get_active_tasks(15);
        assert_eq!(tasks.len(), 3);
    }

    #[test]
    fn test_difficulty_computation() {
        let curriculum = LinearCurriculum::default();

        // Test with predictions
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0]];
        let predictions = array![[0.8, 0.2], [0.3, 0.7]];

        let difficulties = curriculum
            .compute_difficulty(&data, &labels, Some(&predictions))
            .unwrap();
        assert_eq!(difficulties.len(), 2);
        assert!(difficulties.iter().all(|&d| d >= 0.0));

        // Test without predictions
        let difficulties = curriculum.compute_difficulty(&data, &labels, None).unwrap();
        assert_eq!(difficulties.len(), 2);
        assert!(difficulties.iter().all(|&d| d == 0.0));
    }

    #[test]
    fn test_curriculum_manager() {
        let mut manager = CurriculumManager::new(LinearCurriculum::default());

        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];
        let predictions = array![[0.8, 0.2], [0.3, 0.7], [0.6, 0.4]];

        // Compute difficulties
        manager
            .compute_difficulty("train", &data, &labels, Some(&predictions))
            .unwrap();

        // Get selected samples
        manager.set_epoch(0);
        let selected = manager.get_selected_samples("train", 10).unwrap();
        assert!(!selected.is_empty());

        // Clear cache
        manager.clear_cache();
    }

    #[test]
    fn test_curriculum_manager_missing_key() {
        let manager = CurriculumManager::new(LinearCurriculum::default());
        let result = manager.get_selected_samples("nonexistent", 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_linear_curriculum_without_sorting() {
        let curriculum = LinearCurriculum::new(0.5).unwrap().without_sorting();
        let difficulties = array![0.9, 0.1, 0.5, 0.3, 0.7];

        // Should not sort by difficulty
        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected.len(), 3); // 50% of 5 samples, rounded up
    }

    #[test]
    fn test_empty_difficulties() {
        let curriculum = LinearCurriculum::default();
        let difficulties = Array1::<f64>::zeros(0);

        let selected = curriculum
            .select_samples(0, 10, &difficulties.view())
            .unwrap();
        assert_eq!(selected.len(), 0);
    }
}
785}