// tensorlogic_train/src/sampling.rs
1//! Advanced sampling strategies for training.
2//!
3//! This module provides sophisticated sampling techniques to improve training efficiency:
4//! - **Importance sampling**: Weight samples by their importance for learning
5//! - **Hard negative mining**: Focus on difficult negative examples
6//! - **Focal sampling**: Emphasize hard-to-classify examples
7//! - **Class-balanced sampling**: Handle imbalanced datasets
8//! - **Curriculum sampling**: Gradually increase sample difficulty
9//!
10//! # Examples
11//!
12//! ## Hard Negative Mining
13//! ```rust
14//! use tensorlogic_train::{HardNegativeMiner, MiningStrategy};
15//! use scirs2_core::ndarray::Array1;
16//!
17//! let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
18//! let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0]);
19//!
20//! let miner = HardNegativeMiner::new(MiningStrategy::TopK(2), 0.0);
21//! let selected = miner.select_samples(&losses, &labels).expect("unwrap");
22//! ```
23//!
24//! ## Importance Sampling
25//! ```rust
26//! use tensorlogic_train::ImportanceSampler;
27//! use scirs2_core::ndarray::Array1;
28//!
29//! let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
30//! let sampler = ImportanceSampler::new(2, 42);
31//! let selected = sampler.sample(&scores).expect("unwrap");
32//! ```
33
34use scirs2_core::ndarray::Array1;
35use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
38use crate::error::{TrainError, TrainResult};
39
/// Strategy for mining hard examples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MiningStrategy {
    /// Select top-K samples with highest loss
    TopK(usize),
    /// Select samples above a loss threshold
    Threshold(f64),
    /// Select top percentage of samples (fraction in [0, 1] of the candidate pool)
    TopPercentage(f64),
    /// Select samples using focal weighting (emphasize hard examples)
    Focal {
        /// Focusing parameter; larger values emphasize hard examples more strongly.
        gamma: f64,
        /// Number of samples to select.
        num_samples: usize,
    },
}
52
/// Hard negative mining for handling imbalanced datasets.
///
/// Focuses training on difficult negative examples to improve classifier discrimination.
///
/// # References
/// - Shrivastava et al. (2016): "Training Region-based Object Detectors with Online Hard Example Mining"
#[derive(Debug, Clone)]
pub struct HardNegativeMiner {
    /// Mining strategy to use
    pub strategy: MiningStrategy,
    /// Ratio of positives to negatives to maintain.
    ///
    /// Interpreted as negatives *per positive*: when > 0.0, the number of
    /// negatives selected is `num_positives * pos_neg_ratio`, overriding the
    /// strategy's own sample count (except `Threshold`, which selects purely
    /// by loss threshold). When 0.0, the strategy's count is used instead.
    pub pos_neg_ratio: f64,
}
66
67impl HardNegativeMiner {
68    /// Create a new hard negative miner.
69    pub fn new(strategy: MiningStrategy, pos_neg_ratio: f64) -> Self {
70        Self {
71            strategy,
72            pos_neg_ratio,
73        }
74    }
75
76    /// Select hard negative samples based on loss values.
77    ///
78    /// # Arguments
79    /// * `losses` - Per-sample loss values
80    /// * `labels` - True labels (1.0 for positive, 0.0 for negative)
81    ///
82    /// # Returns
83    /// Indices of selected samples
84    pub fn select_samples(
85        &self,
86        losses: &Array1<f64>,
87        labels: &Array1<f64>,
88    ) -> TrainResult<Vec<usize>> {
89        if losses.len() != labels.len() {
90            return Err(TrainError::InvalidParameter(
91                "Losses and labels must have same length".to_string(),
92            ));
93        }
94
95        // Separate positive and negative indices
96        let mut pos_indices = Vec::new();
97        let mut neg_indices = Vec::new();
98
99        for (idx, &label) in labels.iter().enumerate() {
100            if label > 0.5 {
101                pos_indices.push(idx);
102            } else {
103                neg_indices.push(idx);
104            }
105        }
106
107        // Select all positives
108        let mut selected = pos_indices.clone();
109
110        // Select hard negatives based on strategy
111        let num_negatives = if self.pos_neg_ratio > 0.0 {
112            (pos_indices.len() as f64 * self.pos_neg_ratio) as usize
113        } else {
114            match &self.strategy {
115                MiningStrategy::TopK(k) => *k,
116                MiningStrategy::TopPercentage(p) => (neg_indices.len() as f64 * p) as usize,
117                MiningStrategy::Focal { num_samples, .. } => *num_samples,
118                MiningStrategy::Threshold(_) => neg_indices.len(),
119            }
120        };
121
122        let hard_negatives = self.select_hard_negatives(losses, &neg_indices, num_negatives)?;
123        selected.extend(hard_negatives);
124
125        Ok(selected)
126    }
127
128    /// Select hard negative examples.
129    fn select_hard_negatives(
130        &self,
131        losses: &Array1<f64>,
132        neg_indices: &[usize],
133        num_samples: usize,
134    ) -> TrainResult<Vec<usize>> {
135        if neg_indices.is_empty() {
136            return Ok(Vec::new());
137        }
138
139        match &self.strategy {
140            MiningStrategy::TopK(_) | MiningStrategy::TopPercentage(_) => {
141                // Sort negatives by loss (descending)
142                let mut neg_with_loss: Vec<(usize, f64)> =
143                    neg_indices.iter().map(|&idx| (idx, losses[idx])).collect();
144                neg_with_loss
145                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
146
147                let k = num_samples.min(neg_with_loss.len());
148                Ok(neg_with_loss.iter().take(k).map(|(idx, _)| *idx).collect())
149            }
150            MiningStrategy::Threshold(threshold) => {
151                // Select all negatives above threshold
152                Ok(neg_indices
153                    .iter()
154                    .filter(|&&idx| losses[idx] > *threshold)
155                    .copied()
156                    .collect())
157            }
158            MiningStrategy::Focal { gamma, .. } => {
159                // Use focal weighting: (1 - p)^gamma
160                let mut neg_with_weight: Vec<(usize, f64)> = neg_indices
161                    .iter()
162                    .map(|&idx| {
163                        let loss = losses[idx];
164                        let p = (-loss).exp(); // Approximate probability
165                        let weight = (1.0 - p).powf(*gamma);
166                        (idx, weight)
167                    })
168                    .collect();
169                neg_with_weight
170                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
171
172                let k = num_samples.min(neg_with_weight.len());
173                Ok(neg_with_weight
174                    .iter()
175                    .take(k)
176                    .map(|(idx, _)| *idx)
177                    .collect())
178            }
179        }
180    }
181}
182
/// Importance sampling based on sample scores.
///
/// Samples examples with probability proportional to their importance scores.
/// Useful for focusing on informative examples.
///
/// Sampling is deterministic for a fixed `seed`: a simple linear congruential
/// generator is used internally rather than an OS RNG.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// Number of samples to draw
    pub num_samples: usize,
    /// Random seed for reproducibility
    pub seed: u64,
}
194
195impl ImportanceSampler {
196    /// Create a new importance sampler.
197    pub fn new(num_samples: usize, seed: u64) -> Self {
198        Self { num_samples, seed }
199    }
200
201    /// Sample indices based on importance scores.
202    ///
203    /// # Arguments
204    /// * `scores` - Importance scores for each sample (higher = more important)
205    ///
206    /// # Returns
207    /// Sampled indices
208    pub fn sample(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
209        if scores.is_empty() {
210            return Ok(Vec::new());
211        }
212
213        // Normalize scores to probabilities
214        let total: f64 = scores.iter().sum();
215        if total <= 0.0 {
216            return Err(TrainError::InvalidParameter(
217                "Importance scores must be positive".to_string(),
218            ));
219        }
220
221        let probabilities: Vec<f64> = scores.iter().map(|&s| s / total).collect();
222
223        // Compute cumulative probabilities
224        let mut cumulative = Vec::with_capacity(probabilities.len());
225        let mut sum = 0.0;
226        for &p in &probabilities {
227            sum += p;
228            cumulative.push(sum);
229        }
230
231        // Sample using linear congruential generator (simple, deterministic)
232        let mut selected = Vec::new();
233        let mut rng_state = self.seed;
234
235        for _ in 0..self.num_samples {
236            // Generate random number in [0, 1)
237            rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
238            let rand = (rng_state as f64) / (0x7fffffff as f64);
239
240            // Binary search for the sample
241            match cumulative.binary_search_by(|&p| {
242                if p < rand {
243                    std::cmp::Ordering::Less
244                } else {
245                    std::cmp::Ordering::Greater
246                }
247            }) {
248                Ok(idx) => selected.push(idx),
249                Err(idx) => selected.push(idx.min(cumulative.len() - 1)),
250            }
251        }
252
253        Ok(selected)
254    }
255
256    /// Sample with replacement allowed.
257    pub fn sample_with_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
258        self.sample(scores)
259    }
260
261    /// Sample without replacement (unique indices).
262    pub fn sample_without_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
263        let mut samples = self.sample(scores)?;
264        samples.sort_unstable();
265        samples.dedup();
266        Ok(samples)
267    }
268}
269
/// Focal sampling strategy.
///
/// Emphasizes hard-to-classify examples using focal loss weighting.
///
/// # References
/// - Lin et al. (2017): "Focal Loss for Dense Object Detection"
#[derive(Debug, Clone)]
pub struct FocalSampler {
    /// Focusing parameter (higher = more focus on hard examples)
    pub gamma: f64,
    /// Number of samples to select (capped at the dataset size)
    pub num_samples: usize,
}
283
284impl FocalSampler {
285    /// Create a new focal sampler.
286    pub fn new(gamma: f64, num_samples: usize) -> Self {
287        Self { gamma, num_samples }
288    }
289
290    /// Select samples using focal weighting.
291    ///
292    /// # Arguments
293    /// * `predictions` - Model predictions (probabilities)
294    /// * `labels` - True labels
295    ///
296    /// # Returns
297    /// Indices of selected samples
298    pub fn select_samples(
299        &self,
300        predictions: &Array1<f64>,
301        labels: &Array1<f64>,
302    ) -> TrainResult<Vec<usize>> {
303        if predictions.len() != labels.len() {
304            return Err(TrainError::InvalidParameter(
305                "Predictions and labels must have same length".to_string(),
306            ));
307        }
308
309        // Compute focal weights: (1 - p_t)^gamma
310        let mut weights = Vec::with_capacity(predictions.len());
311        for (&pred, &label) in predictions.iter().zip(labels.iter()) {
312            let p_t = if label > 0.5 { pred } else { 1.0 - pred };
313            let weight = (1.0 - p_t).powf(self.gamma);
314            weights.push(weight);
315        }
316
317        // Select top samples by weight
318        let mut indexed_weights: Vec<(usize, f64)> = weights.into_iter().enumerate().collect();
319        indexed_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
320
321        let k = self.num_samples.min(indexed_weights.len());
322        Ok(indexed_weights
323            .iter()
324            .take(k)
325            .map(|(idx, _)| *idx)
326            .collect())
327    }
328}
329
/// Class-balanced sampling for imbalanced datasets.
///
/// Ensures equal representation of all classes during training.
#[derive(Debug, Clone)]
pub struct ClassBalancedSampler {
    /// Number of samples per class (classes with fewer members contribute all of them)
    pub samples_per_class: usize,
    /// Seed for the internal deterministic generator
    pub seed: u64,
}
340
341impl ClassBalancedSampler {
342    /// Create a new class-balanced sampler.
343    pub fn new(samples_per_class: usize, seed: u64) -> Self {
344        Self {
345            samples_per_class,
346            seed,
347        }
348    }
349
350    /// Sample balanced batches from each class.
351    ///
352    /// # Arguments
353    /// * `labels` - Class labels
354    ///
355    /// # Returns
356    /// Sampled indices
357    pub fn sample(&self, labels: &Array1<f64>) -> TrainResult<Vec<usize>> {
358        // Group indices by class
359        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
360
361        for (idx, &label) in labels.iter().enumerate() {
362            let class = label.round() as i32;
363            class_indices.entry(class).or_default().push(idx);
364        }
365
366        if class_indices.is_empty() {
367            return Ok(Vec::new());
368        }
369
370        // Sample from each class
371        let mut selected = Vec::new();
372        let mut rng_state = self.seed;
373
374        for (_, indices) in class_indices.iter() {
375            let num_to_sample = self.samples_per_class.min(indices.len());
376
377            // Fisher-Yates shuffle and take first k
378            let mut shuffled = indices.clone();
379            for i in 0..num_to_sample {
380                rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
381                let j = i + ((rng_state as usize) % (shuffled.len() - i));
382                shuffled.swap(i, j);
383            }
384
385            selected.extend_from_slice(&shuffled[..num_to_sample]);
386        }
387
388        Ok(selected)
389    }
390
391    /// Compute class weights for weighted sampling.
392    pub fn compute_class_weights(&self, labels: &Array1<f64>) -> TrainResult<HashMap<i32, f64>> {
393        let mut class_counts: HashMap<i32, usize> = HashMap::new();
394
395        for &label in labels.iter() {
396            let class = label.round() as i32;
397            *class_counts.entry(class).or_insert(0) += 1;
398        }
399
400        let total = labels.len() as f64;
401        let num_classes = class_counts.len() as f64;
402
403        // Inverse frequency weighting
404        let weights: HashMap<i32, f64> = class_counts
405            .into_iter()
406            .map(|(class, count)| {
407                let weight = total / (num_classes * count as f64);
408                (class, weight)
409            })
410            .collect();
411
412        Ok(weights)
413    }
414}
415
/// Curriculum sampling for progressive difficulty.
///
/// Gradually introduces harder examples as training progresses: the difficulty
/// cutoff equals the current `progress`, so difficulty scores are expected to
/// lie in the same 0.0-1.0 range as progress.
#[derive(Debug, Clone)]
pub struct CurriculumSampler {
    /// Current training progress (0.0 to 1.0)
    pub progress: f64,
    /// Difficulty scores for each sample
    pub difficulty_scores: Array1<f64>,
    /// Number of samples to select
    pub num_samples: usize,
}
428
429impl CurriculumSampler {
430    /// Create a new curriculum sampler.
431    pub fn new(difficulty_scores: Array1<f64>, num_samples: usize) -> Self {
432        Self {
433            progress: 0.0,
434            difficulty_scores,
435            num_samples,
436        }
437    }
438
439    /// Update training progress.
440    pub fn update_progress(&mut self, progress: f64) {
441        self.progress = progress.clamp(0.0, 1.0);
442    }
443
444    /// Select samples based on current curriculum stage.
445    ///
446    /// # Returns
447    /// Indices of samples appropriate for current training stage
448    pub fn select_samples(&self) -> TrainResult<Vec<usize>> {
449        // Difficulty threshold increases with progress
450        let max_difficulty = self.progress;
451
452        // Select samples below difficulty threshold
453        let mut candidates: Vec<usize> = self
454            .difficulty_scores
455            .iter()
456            .enumerate()
457            .filter(|(_, &score)| score <= max_difficulty)
458            .map(|(idx, _)| idx)
459            .collect();
460
461        // If not enough samples, gradually include harder ones
462        if candidates.len() < self.num_samples {
463            let mut all_sorted: Vec<(usize, f64)> = self
464                .difficulty_scores
465                .iter()
466                .enumerate()
467                .map(|(idx, &score)| (idx, score))
468                .collect();
469            all_sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
470
471            candidates = all_sorted
472                .iter()
473                .take(self.num_samples)
474                .map(|(idx, _)| *idx)
475                .collect();
476        }
477
478        // Randomly sample if we have too many
479        if candidates.len() > self.num_samples {
480            candidates.truncate(self.num_samples);
481        }
482
483        Ok(candidates)
484    }
485}
486
/// Online hard example mining during training.
///
/// Dynamically identifies and focuses on hard examples within each batch.
#[derive(Debug, Clone)]
pub struct OnlineHardExampleMiner {
    /// Mining strategy
    pub strategy: MiningStrategy,
    /// Fraction of the batch kept from the easy (lowest-loss) end, for stability
    pub keep_easy_ratio: f64,
}
497
498impl OnlineHardExampleMiner {
499    /// Create a new online hard example miner.
500    pub fn new(strategy: MiningStrategy, keep_easy_ratio: f64) -> Self {
501        Self {
502            strategy,
503            keep_easy_ratio,
504        }
505    }
506
507    /// Mine hard examples from a batch.
508    ///
509    /// # Arguments
510    /// * `losses` - Per-sample losses in the batch
511    ///
512    /// # Returns
513    /// Indices of samples to keep for gradient update
514    pub fn mine_batch(&self, losses: &Array1<f64>) -> TrainResult<Vec<usize>> {
515        if losses.is_empty() {
516            return Ok(Vec::new());
517        }
518
519        // Sort by loss
520        let mut indexed_losses: Vec<(usize, f64)> = losses.iter().copied().enumerate().collect();
521        indexed_losses.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
522
523        let total_samples = losses.len();
524        let num_hard = match &self.strategy {
525            MiningStrategy::TopK(k) => (*k).min(total_samples),
526            MiningStrategy::TopPercentage(p) => (total_samples as f64 * p) as usize,
527            MiningStrategy::Threshold(t) => {
528                indexed_losses.iter().filter(|(_, loss)| *loss > *t).count()
529            }
530            MiningStrategy::Focal { num_samples, .. } => (*num_samples).min(total_samples),
531        };
532
533        // Keep some easy examples for stability
534        let num_easy = (total_samples as f64 * self.keep_easy_ratio) as usize;
535
536        // Take hard examples from the front, easy examples from the back
537        let mut selected = Vec::new();
538        selected.extend(indexed_losses.iter().take(num_hard).map(|(idx, _)| *idx));
539        if num_easy > 0 {
540            selected.extend(
541                indexed_losses
542                    .iter()
543                    .skip(total_samples - num_easy)
544                    .map(|(idx, _)| *idx),
545            );
546        }
547
548        Ok(selected)
549    }
550}
551
/// Batch reweighting based on sample importance.
///
/// Computes weights for each sample in a batch to emphasize important examples.
/// Non-uniform strategies normalize the weights to sum to the batch size.
#[derive(Debug, Clone)]
pub struct BatchReweighter {
    /// Reweighting strategy
    pub strategy: ReweightingStrategy,
}
560
/// Strategy for reweighting samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReweightingStrategy {
    /// Uniform weights (no reweighting)
    Uniform,
    /// Inverse loss weighting: weight = 1 / (loss + epsilon)
    InverseLoss {
        /// Small constant preventing division by zero on zero-loss samples.
        epsilon: f64,
    },
    /// Focal loss weighting: weight = (1 - p)^gamma with p = exp(-loss)
    Focal {
        /// Focusing parameter; larger values emphasize high-loss samples more.
        gamma: f64,
    },
    /// Gradient norm based: weight = sqrt(loss) + epsilon
    GradientNorm {
        /// Small constant keeping weights strictly positive.
        epsilon: f64,
    },
}
573
574impl BatchReweighter {
575    /// Create a new batch reweighter.
576    pub fn new(strategy: ReweightingStrategy) -> Self {
577        Self { strategy }
578    }
579
580    /// Compute sample weights for a batch.
581    ///
582    /// # Arguments
583    /// * `losses` - Per-sample losses
584    ///
585    /// # Returns
586    /// Weight for each sample
587    pub fn compute_weights(&self, losses: &Array1<f64>) -> TrainResult<Array1<f64>> {
588        match &self.strategy {
589            ReweightingStrategy::Uniform => Ok(Array1::ones(losses.len())),
590            ReweightingStrategy::InverseLoss { epsilon } => {
591                let weights = losses.mapv(|loss| 1.0 / (loss + epsilon));
592                // Normalize
593                let sum: f64 = weights.sum();
594                Ok(weights * (losses.len() as f64 / sum))
595            }
596            ReweightingStrategy::Focal { gamma } => {
597                // Weight = (1 - p)^gamma where p = exp(-loss)
598                let weights = losses.mapv(|loss| {
599                    let p = (-loss).exp().min(0.9999);
600                    (1.0 - p).powf(*gamma)
601                });
602                // Normalize
603                let sum: f64 = weights.sum();
604                Ok(weights * (losses.len() as f64 / sum))
605            }
606            ReweightingStrategy::GradientNorm { epsilon } => {
607                // Approximate gradient norm from loss
608                let weights = losses.mapv(|loss| loss.sqrt() + epsilon);
609                let sum: f64 = weights.sum();
610                Ok(weights * (losses.len() as f64 / sum))
611            }
612        }
613    }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hard_negative_miner_topk() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::TopK(2), 0.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");

        // Should include all positives (0, 2, 4) and top 2 negatives (1, 3)
        assert!(selected.contains(&0));
        assert!(selected.contains(&2));
        assert!(selected.contains(&4));
        assert!(selected.contains(&1)); // Loss 0.9
        assert!(selected.contains(&3)); // Loss 0.8
    }

    #[test]
    fn test_hard_negative_miner_threshold() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::Threshold(0.5), 0.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");

        // Should include all positives and negatives with loss > 0.5
        assert!(selected.contains(&0)); // Positive
        assert!(selected.contains(&2)); // Positive
        assert!(selected.contains(&1)); // Negative, loss 0.9 > 0.5
        assert!(selected.contains(&3)); // Negative, loss 0.8 > 0.5
        assert!(!selected.contains(&4)); // Negative, loss 0.2 < 0.5
    }

    #[test]
    fn test_importance_sampler() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(3, 42);

        // sample() draws with replacement, so exactly num_samples are returned.
        let selected = sampler.sample(&scores).expect("unwrap");
        assert_eq!(selected.len(), 3);

        // Higher scored items should be more likely
        // With seed 42, we should get deterministic results
        assert!(selected.len() <= 4);
    }

    #[test]
    fn test_importance_sampler_without_replacement() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(5, 42);

        let selected = sampler.sample_without_replacement(&scores).expect("unwrap");

        // Should have unique indices
        let mut sorted = selected.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), selected.len());
    }

    #[test]
    fn test_focal_sampler() {
        let predictions = Array1::from_vec(vec![0.9, 0.1, 0.5, 0.8, 0.3]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0]);

        let sampler = FocalSampler::new(2.0, 3);
        let selected = sampler
            .select_samples(&predictions, &labels)
            .expect("unwrap");

        assert_eq!(selected.len(), 3);
        // Should select hard examples (where prediction is far from label)
        assert!(selected.contains(&2)); // pred=0.5, label=1.0 (hard)
    }

    #[test]
    fn test_class_balanced_sampler() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let selected = sampler.sample(&labels).expect("unwrap");

        // Should sample up to 2 from each class
        // Class 0: 2, Class 1: 2, Class 2: 1 (only 1 available) = 5 total
        assert_eq!(selected.len(), 5);

        // Verify we got samples from each class
        let selected_labels: Vec<f64> = selected.iter().map(|&idx| labels[idx]).collect();
        assert!(selected_labels.contains(&0.0));
        assert!(selected_labels.contains(&1.0));
        assert!(selected_labels.contains(&2.0));
    }

    #[test]
    fn test_class_balanced_weights() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let weights = sampler.compute_class_weights(&labels).expect("unwrap");

        // Inverse frequency: weight = total / (num_classes * count)
        // Class 0: 3 samples, weight = 6/(3*3) = 0.667
        // Class 1: 2 samples, weight = 6/(3*2) = 1.0
        // Class 2: 1 sample, weight = 6/(3*1) = 2.0
        assert!((weights[&0] - 0.667).abs() < 0.01);
        assert!((weights[&1] - 1.0).abs() < 0.01);
        assert!((weights[&2] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_curriculum_sampler() {
        let difficulty = Array1::from_vec(vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 3);

        // At 0% progress, should only select easiest samples
        // (no score <= 0.0, so the easiest-overall fallback applies)
        sampler.update_progress(0.0);
        let selected = sampler.select_samples().expect("unwrap");
        assert!(!selected.is_empty());

        // At 50% progress, should include medium difficulty
        sampler.update_progress(0.5);
        let selected = sampler.select_samples().expect("unwrap");
        assert!(selected.len() >= 3);

        // At 100% progress, should include all samples
        sampler.update_progress(1.0);
        let selected = sampler.select_samples().expect("unwrap");
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn test_online_hard_example_miner() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let miner = OnlineHardExampleMiner::new(MiningStrategy::TopK(2), 0.2);

        let selected = miner.mine_batch(&losses).expect("unwrap");

        // Should keep top 2 hard (1, 3) and some easy
        assert!(selected.len() >= 2);
        assert!(selected.contains(&1)); // Highest loss
        assert!(selected.contains(&3)); // Second highest
    }

    #[test]
    fn test_batch_reweighter_uniform() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Uniform);

        let weights = reweighter.compute_weights(&losses).expect("unwrap");

        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 1.0).abs() < 1e-10);
        assert!((weights[1] - 1.0).abs() < 1e-10);
        assert!((weights[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_reweighter_inverse_loss() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::InverseLoss { epsilon: 0.01 });

        let weights = reweighter.compute_weights(&losses).expect("unwrap");

        // Lower loss should have higher weight (inverse)
        assert!(weights[0] > weights[1]);
        assert!(weights[1] > weights[2]);

        // Weights should sum to number of samples (normalized)
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_batch_reweighter_focal() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Focal { gamma: 2.0 });

        let weights = reweighter.compute_weights(&losses).expect("unwrap");

        // Higher loss should have higher weight (focal emphasizes hard)
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);

        // Weights should sum to number of samples
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_hard_negative_miner_pos_neg_ratio() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        // 3 positives, ratio 1.0 means 3 negatives
        let miner = HardNegativeMiner::new(MiningStrategy::TopK(10), 1.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");

        let num_pos = selected.iter().filter(|&&idx| labels[idx] > 0.5).count();
        let num_neg = selected.iter().filter(|&&idx| labels[idx] < 0.5).count();

        assert_eq!(num_pos, 3);
        assert_eq!(num_neg, 3); // Should select 3 negatives (ratio 1:1)
    }

    #[test]
    fn test_curriculum_sampler_progress_bounds() {
        let difficulty = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 2);

        // Test progress clamping
        sampler.update_progress(-0.5);
        assert_eq!(sampler.progress, 0.0);

        sampler.update_progress(1.5);
        assert_eq!(sampler.progress, 1.0);
    }
}