//! Advanced sampling strategies for training.
//!
//! This module provides sophisticated sampling techniques to improve training efficiency:
//! - **Importance sampling**: Weight samples by their importance for learning
//! - **Hard negative mining**: Focus on difficult negative examples
//! - **Focal sampling**: Emphasize hard-to-classify examples
//! - **Class-balanced sampling**: Handle imbalanced datasets
//! - **Curriculum sampling**: Gradually increase sample difficulty
//!
//! # Examples
//!
//! ## Hard Negative Mining
//! ```rust
//! use tensorlogic_train::{HardNegativeMiner, MiningStrategy};
//! use scirs2_core::ndarray::Array1;
//!
//! let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
//! let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0]);
//!
//! let miner = HardNegativeMiner::new(MiningStrategy::TopK(2), 0.0);
//! let selected = miner.select_samples(&losses, &labels).unwrap();
//! ```
//!
//! ## Importance Sampling
//! ```rust
//! use tensorlogic_train::ImportanceSampler;
//! use scirs2_core::ndarray::Array1;
//!
//! let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
//! let sampler = ImportanceSampler::new(2, 42);
//! let selected = sampler.sample(&scores).unwrap();
//! ```

use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{TrainError, TrainResult};

/// Strategy for mining hard examples.
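///
/// # Examples
/// Constructing the different strategies (parameter values are illustrative):
/// ```rust
/// use tensorlogic_train::MiningStrategy;
///
/// let top_k = MiningStrategy::TopK(16);
/// let by_threshold = MiningStrategy::Threshold(0.5);
/// let top_fraction = MiningStrategy::TopPercentage(0.25);
/// let focal = MiningStrategy::Focal { gamma: 2.0, num_samples: 32 };
/// ```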
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MiningStrategy {
    /// Select top-K samples with highest loss
    TopK(usize),
    /// Select samples above a loss threshold
    Threshold(f64),
    /// Select the top fraction of samples (e.g. `0.25` keeps the hardest 25%)
    TopPercentage(f64),
    /// Select samples using focal weighting (emphasize hard examples)
    Focal { gamma: f64, num_samples: usize },
}

/// Hard negative mining for handling imbalanced datasets.
///
/// Focuses training on difficult negative examples to improve classifier discrimination.
///
/// # References
/// - Shrivastava et al. (2016): "Training Region-based Object Detectors with Online Hard Example Mining"
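///
/// # Examples
/// A sketch of ratio-based mining: with `pos_neg_ratio = 1.0`, one hard
/// negative is kept per positive and the `TopK` bound is ignored.
/// ```rust
/// use tensorlogic_train::{HardNegativeMiner, MiningStrategy};
/// use scirs2_core::ndarray::Array1;
///
/// let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8]);
/// let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0]);
///
/// let miner = HardNegativeMiner::new(MiningStrategy::TopK(10), 1.0);
/// let selected = miner.select_samples(&losses, &labels).unwrap();
/// // Both positives (0, 2) plus the two hardest negatives (1, 3).
/// assert_eq!(selected.len(), 4);
/// ```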
#[derive(Debug, Clone)]
pub struct HardNegativeMiner {
    /// Mining strategy to use
    pub strategy: MiningStrategy,
    /// Number of hard negatives to keep per positive; `0.0` disables the
    /// ratio and defers entirely to the mining strategy
    pub pos_neg_ratio: f64,
}

impl HardNegativeMiner {
    /// Create a new hard negative miner.
    pub fn new(strategy: MiningStrategy, pos_neg_ratio: f64) -> Self {
        Self {
            strategy,
            pos_neg_ratio,
        }
    }

    /// Select hard negative samples based on loss values.
    ///
    /// # Arguments
    /// * `losses` - Per-sample loss values
    /// * `labels` - True labels (1.0 for positive, 0.0 for negative)
    ///
    /// # Returns
    /// Indices of selected samples
    pub fn select_samples(
        &self,
        losses: &Array1<f64>,
        labels: &Array1<f64>,
    ) -> TrainResult<Vec<usize>> {
        if losses.len() != labels.len() {
            return Err(TrainError::InvalidParameter(
                "Losses and labels must have same length".to_string(),
            ));
        }

        // Separate positive and negative indices
        let mut pos_indices = Vec::new();
        let mut neg_indices = Vec::new();

        for (idx, &label) in labels.iter().enumerate() {
            if label > 0.5 {
                pos_indices.push(idx);
            } else {
                neg_indices.push(idx);
            }
        }

        // Select all positives
        let mut selected = pos_indices.clone();

        // Select hard negatives based on strategy
        let num_negatives = if self.pos_neg_ratio > 0.0 {
            (pos_indices.len() as f64 * self.pos_neg_ratio) as usize
        } else {
            match &self.strategy {
                MiningStrategy::TopK(k) => *k,
                MiningStrategy::TopPercentage(p) => (neg_indices.len() as f64 * p) as usize,
                MiningStrategy::Focal { num_samples, .. } => *num_samples,
                MiningStrategy::Threshold(_) => neg_indices.len(),
            }
        };

        let hard_negatives = self.select_hard_negatives(losses, &neg_indices, num_negatives)?;
        selected.extend(hard_negatives);

        Ok(selected)
    }

    /// Select hard negative examples.
    fn select_hard_negatives(
        &self,
        losses: &Array1<f64>,
        neg_indices: &[usize],
        num_samples: usize,
    ) -> TrainResult<Vec<usize>> {
        if neg_indices.is_empty() {
            return Ok(Vec::new());
        }

        match &self.strategy {
            MiningStrategy::TopK(_) | MiningStrategy::TopPercentage(_) => {
                // Sort negatives by loss (descending)
                let mut neg_with_loss: Vec<(usize, f64)> =
                    neg_indices.iter().map(|&idx| (idx, losses[idx])).collect();
                neg_with_loss.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

                let k = num_samples.min(neg_with_loss.len());
                Ok(neg_with_loss.iter().take(k).map(|(idx, _)| *idx).collect())
            }
            MiningStrategy::Threshold(threshold) => {
                // Select all negatives above threshold
                Ok(neg_indices
                    .iter()
                    .filter(|&&idx| losses[idx] > *threshold)
                    .copied()
                    .collect())
            }
            MiningStrategy::Focal { gamma, .. } => {
                // Use focal weighting: (1 - p)^gamma
                let mut neg_with_weight: Vec<(usize, f64)> = neg_indices
                    .iter()
                    .map(|&idx| {
                        let loss = losses[idx];
                        let p = (-loss).exp(); // Approximate probability
                        let weight = (1.0 - p).powf(*gamma);
                        (idx, weight)
                    })
                    .collect();
                neg_with_weight.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

                let k = num_samples.min(neg_with_weight.len());
                Ok(neg_with_weight
                    .iter()
                    .take(k)
                    .map(|(idx, _)| *idx)
                    .collect())
            }
        }
    }
}

/// Importance sampling based on sample scores.
///
/// Samples examples with probability proportional to their importance scores.
/// Useful for focusing on informative examples.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// Number of samples to draw
    pub num_samples: usize,
    /// Random seed for reproducibility
    pub seed: u64,
}

impl ImportanceSampler {
    /// Create a new importance sampler.
    pub fn new(num_samples: usize, seed: u64) -> Self {
        Self { num_samples, seed }
    }

    /// Sample indices based on importance scores.
    ///
    /// # Arguments
    /// * `scores` - Importance scores for each sample (higher = more important)
    ///
    /// # Returns
    /// Sampled indices
    pub fn sample(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Normalize scores to probabilities
        let total: f64 = scores.iter().sum();
        if total <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "Importance scores must be positive".to_string(),
            ));
        }

        let probabilities: Vec<f64> = scores.iter().map(|&s| s / total).collect();

        // Compute cumulative probabilities
        let mut cumulative = Vec::with_capacity(probabilities.len());
        let mut sum = 0.0;
        for &p in &probabilities {
            sum += p;
            cumulative.push(sum);
        }

        // Sample using a linear congruential generator (simple, deterministic)
        let mut selected = Vec::new();
        let mut rng_state = self.seed;

        for _ in 0..self.num_samples {
            // Generate random number in [0, 1)
            rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
            let rand = (rng_state as f64) / (0x7fffffff as f64);

            // Binary search for the first cumulative probability >= rand.
            // The comparator never returns Equal, so the Err branch carries
            // the insertion point.
            match cumulative.binary_search_by(|&p| {
                if p < rand {
                    std::cmp::Ordering::Less
                } else {
                    std::cmp::Ordering::Greater
                }
            }) {
                Ok(idx) => selected.push(idx),
                Err(idx) => selected.push(idx.min(cumulative.len() - 1)),
            }
        }

        Ok(selected)
    }

    /// Sample with replacement allowed.
    pub fn sample_with_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
        self.sample(scores)
    }

    /// Sample without replacement (unique indices).
    ///
    /// Duplicates are removed after sampling, so fewer than `num_samples`
    /// indices may be returned.
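    ///
    /// # Examples
    /// ```rust
    /// use tensorlogic_train::ImportanceSampler;
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
    /// let sampler = ImportanceSampler::new(5, 42);
    /// let unique = sampler.sample_without_replacement(&scores).unwrap();
    /// // At most 4 unique indices exist, even though 5 draws were requested.
    /// assert!(unique.len() <= 4);
    /// ```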
    pub fn sample_without_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
        let mut samples = self.sample(scores)?;
        samples.sort_unstable();
        samples.dedup();
        Ok(samples)
    }
}

/// Focal sampling strategy.
///
/// Emphasizes hard-to-classify examples using focal loss weighting.
///
/// # References
/// - Lin et al. (2017): "Focal Loss for Dense Object Detection"
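///
/// # Examples
/// A minimal sketch (assuming `FocalSampler` is re-exported at the crate root
/// like the samplers above); predictions far from their labels receive the
/// largest focal weights and are selected first:
/// ```rust
/// use tensorlogic_train::FocalSampler;
/// use scirs2_core::ndarray::Array1;
///
/// let predictions = Array1::from_vec(vec![0.9, 0.1, 0.5]);
/// let labels = Array1::from_vec(vec![1.0, 0.0, 1.0]);
///
/// let sampler = FocalSampler::new(2.0, 2);
/// let selected = sampler.select_samples(&predictions, &labels).unwrap();
/// assert_eq!(selected.len(), 2);
/// // Index 2 (pred 0.5 vs. label 1.0) is the hardest example.
/// assert!(selected.contains(&2));
/// ```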
#[derive(Debug, Clone)]
pub struct FocalSampler {
    /// Focusing parameter (higher = more focus on hard examples)
    pub gamma: f64,
    /// Number of samples to select
    pub num_samples: usize,
}

impl FocalSampler {
    /// Create a new focal sampler.
    pub fn new(gamma: f64, num_samples: usize) -> Self {
        Self { gamma, num_samples }
    }

    /// Select samples using focal weighting.
    ///
    /// # Arguments
    /// * `predictions` - Model predictions (probabilities)
    /// * `labels` - True labels
    ///
    /// # Returns
    /// Indices of selected samples
    pub fn select_samples(
        &self,
        predictions: &Array1<f64>,
        labels: &Array1<f64>,
    ) -> TrainResult<Vec<usize>> {
        if predictions.len() != labels.len() {
            return Err(TrainError::InvalidParameter(
                "Predictions and labels must have same length".to_string(),
            ));
        }

        // Compute focal weights: (1 - p_t)^gamma
        let mut weights = Vec::with_capacity(predictions.len());
        for (&pred, &label) in predictions.iter().zip(labels.iter()) {
            let p_t = if label > 0.5 { pred } else { 1.0 - pred };
            let weight = (1.0 - p_t).powf(self.gamma);
            weights.push(weight);
        }

        // Select top samples by weight
        let mut indexed_weights: Vec<(usize, f64)> = weights.into_iter().enumerate().collect();
        indexed_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let k = self.num_samples.min(indexed_weights.len());
        Ok(indexed_weights
            .iter()
            .take(k)
            .map(|(idx, _)| *idx)
            .collect())
    }
}

/// Class-balanced sampling for imbalanced datasets.
///
/// Ensures equal representation of all classes during training.
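///
/// # Examples
/// A small sketch (assuming `ClassBalancedSampler` is exported at the crate
/// root like the samplers above); up to two indices are drawn per class:
/// ```rust
/// use tensorlogic_train::ClassBalancedSampler;
/// use scirs2_core::ndarray::Array1;
///
/// let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
/// let sampler = ClassBalancedSampler::new(2, 42);
/// let selected = sampler.sample(&labels).unwrap();
/// // Two from class 0, two from class 1, one from class 2.
/// assert_eq!(selected.len(), 5);
/// ```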
#[derive(Debug, Clone)]
pub struct ClassBalancedSampler {
    /// Number of samples per class
    pub samples_per_class: usize,
    /// Random seed
    pub seed: u64,
}

impl ClassBalancedSampler {
    /// Create a new class-balanced sampler.
    pub fn new(samples_per_class: usize, seed: u64) -> Self {
        Self {
            samples_per_class,
            seed,
        }
    }

    /// Sample balanced batches from each class.
    ///
    /// # Arguments
    /// * `labels` - Class labels
    ///
    /// # Returns
    /// Sampled indices
    pub fn sample(&self, labels: &Array1<f64>) -> TrainResult<Vec<usize>> {
        // Group indices by class
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();

        for (idx, &label) in labels.iter().enumerate() {
            let class = label.round() as i32;
            class_indices.entry(class).or_default().push(idx);
        }

        if class_indices.is_empty() {
            return Ok(Vec::new());
        }

        // Sample from each class. Note that `HashMap` iteration order is
        // unspecified, so the exact draw can differ across runs even with a
        // fixed seed.
        let mut selected = Vec::new();
        let mut rng_state = self.seed;

        for (_, indices) in class_indices.iter() {
            let num_to_sample = self.samples_per_class.min(indices.len());

            // Partial Fisher-Yates shuffle; take the first k positions
            let mut shuffled = indices.clone();
            for i in 0..num_to_sample {
                rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
                let j = i + ((rng_state as usize) % (shuffled.len() - i));
                shuffled.swap(i, j);
            }

            selected.extend_from_slice(&shuffled[..num_to_sample]);
        }

        Ok(selected)
    }

    /// Compute class weights for weighted sampling.
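    ///
    /// # Examples
    /// Inverse-frequency weights, `total / (num_classes * count)` (values
    /// mirror the unit test below):
    /// ```rust
    /// use tensorlogic_train::ClassBalancedSampler;
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
    /// let sampler = ClassBalancedSampler::new(2, 42);
    /// let weights = sampler.compute_class_weights(&labels).unwrap();
    /// // Rarer classes get larger weights: 6/(3*3), 6/(3*2), 6/(3*1).
    /// assert!((weights[&2] - 2.0).abs() < 1e-9);
    /// ```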
    pub fn compute_class_weights(&self, labels: &Array1<f64>) -> TrainResult<HashMap<i32, f64>> {
        let mut class_counts: HashMap<i32, usize> = HashMap::new();

        for &label in labels.iter() {
            let class = label.round() as i32;
            *class_counts.entry(class).or_insert(0) += 1;
        }

        let total = labels.len() as f64;
        let num_classes = class_counts.len() as f64;

        // Inverse frequency weighting
        let weights: HashMap<i32, f64> = class_counts
            .into_iter()
            .map(|(class, count)| {
                let weight = total / (num_classes * count as f64);
                (class, weight)
            })
            .collect();

        Ok(weights)
    }
}

/// Curriculum sampling for progressive difficulty.
///
/// Gradually introduces harder examples as training progresses.
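///
/// # Examples
/// A sketch (assuming `CurriculumSampler` is exported at the crate root);
/// early in training only low-difficulty samples qualify:
/// ```rust
/// use tensorlogic_train::CurriculumSampler;
/// use scirs2_core::ndarray::Array1;
///
/// let difficulty = Array1::from_vec(vec![0.1, 0.3, 0.5, 0.7, 0.9]);
/// let mut sampler = CurriculumSampler::new(difficulty, 3);
///
/// sampler.update_progress(0.5);
/// let selected = sampler.select_samples().unwrap();
/// // Only samples with difficulty <= 0.5 qualify at 50% progress.
/// assert_eq!(selected.len(), 3);
/// ```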
#[derive(Debug, Clone)]
pub struct CurriculumSampler {
    /// Current training progress (0.0 to 1.0)
    pub progress: f64,
    /// Difficulty scores for each sample
    pub difficulty_scores: Array1<f64>,
    /// Number of samples to select
    pub num_samples: usize,
}

impl CurriculumSampler {
    /// Create a new curriculum sampler.
    pub fn new(difficulty_scores: Array1<f64>, num_samples: usize) -> Self {
        Self {
            progress: 0.0,
            difficulty_scores,
            num_samples,
        }
    }

    /// Update training progress.
    pub fn update_progress(&mut self, progress: f64) {
        self.progress = progress.clamp(0.0, 1.0);
    }

    /// Select samples based on current curriculum stage.
    ///
    /// # Returns
    /// Indices of samples appropriate for current training stage
    pub fn select_samples(&self) -> TrainResult<Vec<usize>> {
        // Difficulty threshold increases with progress
        let max_difficulty = self.progress;

        // Select samples below difficulty threshold
        let mut candidates: Vec<usize> = self
            .difficulty_scores
            .iter()
            .enumerate()
            .filter(|(_, &score)| score <= max_difficulty)
            .map(|(idx, _)| idx)
            .collect();

        // If not enough samples, fall back to the easiest ones overall
        if candidates.len() < self.num_samples {
            let mut all_sorted: Vec<(usize, f64)> = self
                .difficulty_scores
                .iter()
                .enumerate()
                .map(|(idx, &score)| (idx, score))
                .collect();
            all_sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

            candidates = all_sorted
                .iter()
                .take(self.num_samples)
                .map(|(idx, _)| *idx)
                .collect();
        }

        // Truncate deterministically if we have too many (keeps the
        // lowest-index candidates; no random subsampling is performed)
        if candidates.len() > self.num_samples {
            candidates.truncate(self.num_samples);
        }

        Ok(candidates)
    }
}

/// Online hard example mining during training.
///
/// Dynamically identifies and focuses on hard examples within each batch.
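///
/// # Examples
/// A sketch (assuming `OnlineHardExampleMiner` is exported at the crate
/// root); the two highest-loss samples are kept, plus 20% easy samples for
/// stability:
/// ```rust
/// use tensorlogic_train::{MiningStrategy, OnlineHardExampleMiner};
/// use scirs2_core::ndarray::Array1;
///
/// let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
/// let miner = OnlineHardExampleMiner::new(MiningStrategy::TopK(2), 0.2);
/// let selected = miner.mine_batch(&losses).unwrap();
/// assert!(selected.contains(&1) && selected.contains(&3));
/// ```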
#[derive(Debug, Clone)]
pub struct OnlineHardExampleMiner {
    /// Mining strategy
    pub strategy: MiningStrategy,
    /// Keep easy examples ratio (for stability)
    pub keep_easy_ratio: f64,
}

impl OnlineHardExampleMiner {
    /// Create a new online hard example miner.
    pub fn new(strategy: MiningStrategy, keep_easy_ratio: f64) -> Self {
        Self {
            strategy,
            keep_easy_ratio,
        }
    }

    /// Mine hard examples from a batch.
    ///
    /// # Arguments
    /// * `losses` - Per-sample losses in the batch
    ///
    /// # Returns
    /// Indices of samples to keep for gradient update
    pub fn mine_batch(&self, losses: &Array1<f64>) -> TrainResult<Vec<usize>> {
        if losses.is_empty() {
            return Ok(Vec::new());
        }

        // Sort by loss (descending)
        let mut indexed_losses: Vec<(usize, f64)> = losses.iter().copied().enumerate().collect();
        indexed_losses.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let total_samples = losses.len();
        let num_hard = match &self.strategy {
            MiningStrategy::TopK(k) => (*k).min(total_samples),
            MiningStrategy::TopPercentage(p) => {
                ((total_samples as f64 * p) as usize).min(total_samples)
            }
            MiningStrategy::Threshold(t) => {
                indexed_losses.iter().filter(|(_, loss)| *loss > *t).count()
            }
            MiningStrategy::Focal { num_samples, .. } => (*num_samples).min(total_samples),
        };

        // Keep some easy examples for stability; clamp so the easy tail
        // cannot overlap the hard head of the sorted list.
        let num_easy = ((total_samples as f64 * self.keep_easy_ratio) as usize)
            .min(total_samples - num_hard);

        // Take hard examples from the front, easy examples from the back
        let mut selected = Vec::new();
        selected.extend(indexed_losses.iter().take(num_hard).map(|(idx, _)| *idx));
        if num_easy > 0 {
            selected.extend(
                indexed_losses
                    .iter()
                    .skip(total_samples - num_easy)
                    .map(|(idx, _)| *idx),
            );
        }

        Ok(selected)
    }
}

/// Batch reweighting based on sample importance.
///
/// Computes weights for each sample in a batch to emphasize important examples.
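///
/// # Examples
/// A sketch (assuming `BatchReweighter` and `ReweightingStrategy` are exported
/// at the crate root); focal reweighting upweights high-loss samples:
/// ```rust
/// use tensorlogic_train::{BatchReweighter, ReweightingStrategy};
/// use scirs2_core::ndarray::Array1;
///
/// let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
/// let reweighter = BatchReweighter::new(ReweightingStrategy::Focal { gamma: 2.0 });
/// let weights = reweighter.compute_weights(&losses).unwrap();
/// assert!(weights[2] > weights[0]);
/// // Weights are normalized to sum to the batch size.
/// assert!((weights.sum() - 3.0).abs() < 1e-6);
/// ```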
#[derive(Debug, Clone)]
pub struct BatchReweighter {
    /// Reweighting strategy
    pub strategy: ReweightingStrategy,
}

/// Strategy for reweighting samples.
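///
/// # Examples
/// Constructing the strategies (parameter values are illustrative; the same
/// crate-root export is assumed as above):
/// ```rust
/// use tensorlogic_train::ReweightingStrategy;
///
/// let uniform = ReweightingStrategy::Uniform;
/// let inverse = ReweightingStrategy::InverseLoss { epsilon: 0.01 };
/// let focal = ReweightingStrategy::Focal { gamma: 2.0 };
/// let grad_norm = ReweightingStrategy::GradientNorm { epsilon: 0.01 };
/// ```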
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReweightingStrategy {
    /// Uniform weights (no reweighting)
    Uniform,
    /// Inverse loss weighting
    InverseLoss { epsilon: f64 },
    /// Focal loss weighting
    Focal { gamma: f64 },
    /// Gradient norm based
    GradientNorm { epsilon: f64 },
}

impl BatchReweighter {
    /// Create a new batch reweighter.
    pub fn new(strategy: ReweightingStrategy) -> Self {
        Self { strategy }
    }

    /// Compute sample weights for a batch.
    ///
    /// # Arguments
    /// * `losses` - Per-sample losses
    ///
    /// # Returns
    /// Weight for each sample
    pub fn compute_weights(&self, losses: &Array1<f64>) -> TrainResult<Array1<f64>> {
        match &self.strategy {
            ReweightingStrategy::Uniform => Ok(Array1::ones(losses.len())),
            ReweightingStrategy::InverseLoss { epsilon } => {
                let weights = losses.mapv(|loss| 1.0 / (loss + epsilon));
                // Normalize so weights sum to the batch size
                let sum: f64 = weights.sum();
                Ok(weights * (losses.len() as f64 / sum))
            }
            ReweightingStrategy::Focal { gamma } => {
                // Weight = (1 - p)^gamma where p = exp(-loss)
                let weights = losses.mapv(|loss| {
                    let p = (-loss).exp().min(0.9999);
                    (1.0 - p).powf(*gamma)
                });
                // Normalize so weights sum to the batch size
                let sum: f64 = weights.sum();
                Ok(weights * (losses.len() as f64 / sum))
            }
            ReweightingStrategy::GradientNorm { epsilon } => {
                // Approximate gradient norm from loss
                let weights = losses.mapv(|loss| loss.sqrt() + epsilon);
                let sum: f64 = weights.sum();
                Ok(weights * (losses.len() as f64 / sum))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hard_negative_miner_topk() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::TopK(2), 0.0);
        let selected = miner.select_samples(&losses, &labels).unwrap();

        // Should include all positives (0, 2, 4) and top 2 negatives (1, 3)
        assert!(selected.contains(&0));
        assert!(selected.contains(&2));
        assert!(selected.contains(&4));
        assert!(selected.contains(&1)); // Loss 0.9
        assert!(selected.contains(&3)); // Loss 0.8
    }

    #[test]
    fn test_hard_negative_miner_threshold() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::Threshold(0.5), 0.0);
        let selected = miner.select_samples(&losses, &labels).unwrap();

        // Should include all positives and negatives with loss > 0.5
        assert!(selected.contains(&0)); // Positive
        assert!(selected.contains(&2)); // Positive
        assert!(selected.contains(&1)); // Negative, loss 0.9 > 0.5
        assert!(selected.contains(&3)); // Negative, loss 0.8 > 0.5
        assert!(!selected.contains(&4)); // Negative, loss 0.2 < 0.5
    }

    #[test]
    fn test_importance_sampler() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(3, 42);

        let selected = sampler.sample(&scores).unwrap();
        assert_eq!(selected.len(), 3);

        // Higher scored items should be more likely
        // With seed 42, we should get deterministic results
        assert!(selected.len() <= 4);
    }

    #[test]
    fn test_importance_sampler_without_replacement() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(5, 42);

        let selected = sampler.sample_without_replacement(&scores).unwrap();

        // Should have unique indices
        let mut sorted = selected.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), selected.len());
    }

    #[test]
    fn test_focal_sampler() {
        let predictions = Array1::from_vec(vec![0.9, 0.1, 0.5, 0.8, 0.3]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0]);

        let sampler = FocalSampler::new(2.0, 3);
        let selected = sampler.select_samples(&predictions, &labels).unwrap();

        assert_eq!(selected.len(), 3);
        // Should select hard examples (where prediction is far from label)
        assert!(selected.contains(&2)); // pred=0.5, label=1.0 (hard)
    }

    #[test]
    fn test_class_balanced_sampler() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let selected = sampler.sample(&labels).unwrap();

        // Should sample up to 2 from each class
        // Class 0: 2, Class 1: 2, Class 2: 1 (only 1 available) = 5 total
        assert_eq!(selected.len(), 5);

        // Verify we got samples from each class
        let selected_labels: Vec<f64> = selected.iter().map(|&idx| labels[idx]).collect();
        assert!(selected_labels.contains(&0.0));
        assert!(selected_labels.contains(&1.0));
        assert!(selected_labels.contains(&2.0));
    }

    #[test]
    fn test_class_balanced_weights() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let weights = sampler.compute_class_weights(&labels).unwrap();

        // Class 0: 3 samples, weight = 6/(3*3) = 0.667
        // Class 1: 2 samples, weight = 6/(3*2) = 1.0
        // Class 2: 1 sample, weight = 6/(3*1) = 2.0
        assert!((weights[&0] - 0.667).abs() < 0.01);
        assert!((weights[&1] - 1.0).abs() < 0.01);
        assert!((weights[&2] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_curriculum_sampler() {
        let difficulty = Array1::from_vec(vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 3);

        // At 0% progress no sample passes the threshold, so the fallback
        // returns the easiest samples
        sampler.update_progress(0.0);
        let selected = sampler.select_samples().unwrap();
        assert!(!selected.is_empty());

        // At 50% progress, samples up to difficulty 0.5 qualify
        sampler.update_progress(0.5);
        let selected = sampler.select_samples().unwrap();
        assert!(selected.len() >= 3);

        // At 100% progress all samples qualify; the result is truncated to
        // num_samples
        sampler.update_progress(1.0);
        let selected = sampler.select_samples().unwrap();
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn test_online_hard_example_miner() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let miner = OnlineHardExampleMiner::new(MiningStrategy::TopK(2), 0.2);

        let selected = miner.mine_batch(&losses).unwrap();

        // Should keep top 2 hard (1, 3) and some easy
        assert!(selected.len() >= 2);
        assert!(selected.contains(&1)); // Highest loss
        assert!(selected.contains(&3)); // Second highest
    }

    #[test]
    fn test_batch_reweighter_uniform() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Uniform);

        let weights = reweighter.compute_weights(&losses).unwrap();

        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 1.0).abs() < 1e-10);
        assert!((weights[1] - 1.0).abs() < 1e-10);
        assert!((weights[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_reweighter_inverse_loss() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::InverseLoss { epsilon: 0.01 });

        let weights = reweighter.compute_weights(&losses).unwrap();

        // Lower loss should have higher weight (inverse)
        assert!(weights[0] > weights[1]);
        assert!(weights[1] > weights[2]);

        // Weights should sum to number of samples (normalized)
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_batch_reweighter_focal() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Focal { gamma: 2.0 });

        let weights = reweighter.compute_weights(&losses).unwrap();

        // Higher loss should have higher weight (focal emphasizes hard)
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);

        // Weights should sum to number of samples
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_hard_negative_miner_pos_neg_ratio() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        // 3 positives, ratio 1.0 means 3 negatives
        let miner = HardNegativeMiner::new(MiningStrategy::TopK(10), 1.0);
        let selected = miner.select_samples(&losses, &labels).unwrap();

        let num_pos = selected.iter().filter(|&&idx| labels[idx] > 0.5).count();
        let num_neg = selected.iter().filter(|&&idx| labels[idx] < 0.5).count();

        assert_eq!(num_pos, 3);
        assert_eq!(num_neg, 3); // Should select 3 negatives (ratio 1:1)
    }

    #[test]
    fn test_curriculum_sampler_progress_bounds() {
        let difficulty = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 2);

        // Test progress clamping
        sampler.update_progress(-0.5);
        assert_eq!(sampler.progress, 0.0);

        sampler.update_progress(1.5);
        assert_eq!(sampler.progress, 1.0);
    }
}