1use scirs2_core::ndarray::Array1;
35use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
38use crate::error::{TrainError, TrainResult};
39
/// Strategy for choosing which (hard) samples to keep from a batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MiningStrategy {
    /// Keep the given number of highest-loss samples.
    TopK(usize),
    /// Keep every sample whose loss is strictly above the threshold.
    Threshold(f64),
    /// Keep the top fraction (expected in 0.0..=1.0) of samples ranked by loss.
    TopPercentage(f64),
    /// Rank samples by the focal-style weight `(1 - e^{-loss})^gamma` and
    /// keep the `num_samples` largest.
    Focal { gamma: f64, num_samples: usize },
}
52
/// Selects every positive sample plus a set of "hard" (high-loss) negatives.
#[derive(Debug, Clone)]
pub struct HardNegativeMiner {
    /// How hard negatives are chosen.
    pub strategy: MiningStrategy,
    /// Target negatives per positive; when > 0.0 this overrides the count
    /// implied by `strategy` (except for `Threshold`, which keeps all
    /// negatives above its threshold).
    pub pos_neg_ratio: f64,
}
66
67impl HardNegativeMiner {
68 pub fn new(strategy: MiningStrategy, pos_neg_ratio: f64) -> Self {
70 Self {
71 strategy,
72 pos_neg_ratio,
73 }
74 }
75
76 pub fn select_samples(
85 &self,
86 losses: &Array1<f64>,
87 labels: &Array1<f64>,
88 ) -> TrainResult<Vec<usize>> {
89 if losses.len() != labels.len() {
90 return Err(TrainError::InvalidParameter(
91 "Losses and labels must have same length".to_string(),
92 ));
93 }
94
95 let mut pos_indices = Vec::new();
97 let mut neg_indices = Vec::new();
98
99 for (idx, &label) in labels.iter().enumerate() {
100 if label > 0.5 {
101 pos_indices.push(idx);
102 } else {
103 neg_indices.push(idx);
104 }
105 }
106
107 let mut selected = pos_indices.clone();
109
110 let num_negatives = if self.pos_neg_ratio > 0.0 {
112 (pos_indices.len() as f64 * self.pos_neg_ratio) as usize
113 } else {
114 match &self.strategy {
115 MiningStrategy::TopK(k) => *k,
116 MiningStrategy::TopPercentage(p) => (neg_indices.len() as f64 * p) as usize,
117 MiningStrategy::Focal { num_samples, .. } => *num_samples,
118 MiningStrategy::Threshold(_) => neg_indices.len(),
119 }
120 };
121
122 let hard_negatives = self.select_hard_negatives(losses, &neg_indices, num_negatives)?;
123 selected.extend(hard_negatives);
124
125 Ok(selected)
126 }
127
128 fn select_hard_negatives(
130 &self,
131 losses: &Array1<f64>,
132 neg_indices: &[usize],
133 num_samples: usize,
134 ) -> TrainResult<Vec<usize>> {
135 if neg_indices.is_empty() {
136 return Ok(Vec::new());
137 }
138
139 match &self.strategy {
140 MiningStrategy::TopK(_) | MiningStrategy::TopPercentage(_) => {
141 let mut neg_with_loss: Vec<(usize, f64)> =
143 neg_indices.iter().map(|&idx| (idx, losses[idx])).collect();
144 neg_with_loss.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
145
146 let k = num_samples.min(neg_with_loss.len());
147 Ok(neg_with_loss.iter().take(k).map(|(idx, _)| *idx).collect())
148 }
149 MiningStrategy::Threshold(threshold) => {
150 Ok(neg_indices
152 .iter()
153 .filter(|&&idx| losses[idx] > *threshold)
154 .copied()
155 .collect())
156 }
157 MiningStrategy::Focal { gamma, .. } => {
158 let mut neg_with_weight: Vec<(usize, f64)> = neg_indices
160 .iter()
161 .map(|&idx| {
162 let loss = losses[idx];
163 let p = (-loss).exp(); let weight = (1.0 - p).powf(*gamma);
165 (idx, weight)
166 })
167 .collect();
168 neg_with_weight.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
169
170 let k = num_samples.min(neg_with_weight.len());
171 Ok(neg_with_weight
172 .iter()
173 .take(k)
174 .map(|(idx, _)| *idx)
175 .collect())
176 }
177 }
178 }
179}
180
/// Draws sample indices with probability proportional to importance scores,
/// using a deterministic internal LCG so results are reproducible per seed.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// Number of indices to draw per call to `sample`.
    pub num_samples: usize,
    /// Seed for the internal pseudo-random generator.
    pub seed: u64,
}
192
193impl ImportanceSampler {
194 pub fn new(num_samples: usize, seed: u64) -> Self {
196 Self { num_samples, seed }
197 }
198
199 pub fn sample(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
207 if scores.is_empty() {
208 return Ok(Vec::new());
209 }
210
211 let total: f64 = scores.iter().sum();
213 if total <= 0.0 {
214 return Err(TrainError::InvalidParameter(
215 "Importance scores must be positive".to_string(),
216 ));
217 }
218
219 let probabilities: Vec<f64> = scores.iter().map(|&s| s / total).collect();
220
221 let mut cumulative = Vec::with_capacity(probabilities.len());
223 let mut sum = 0.0;
224 for &p in &probabilities {
225 sum += p;
226 cumulative.push(sum);
227 }
228
229 let mut selected = Vec::new();
231 let mut rng_state = self.seed;
232
233 for _ in 0..self.num_samples {
234 rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
236 let rand = (rng_state as f64) / (0x7fffffff as f64);
237
238 match cumulative.binary_search_by(|&p| {
240 if p < rand {
241 std::cmp::Ordering::Less
242 } else {
243 std::cmp::Ordering::Greater
244 }
245 }) {
246 Ok(idx) => selected.push(idx),
247 Err(idx) => selected.push(idx.min(cumulative.len() - 1)),
248 }
249 }
250
251 Ok(selected)
252 }
253
254 pub fn sample_with_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
256 self.sample(scores)
257 }
258
259 pub fn sample_without_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
261 let mut samples = self.sample(scores)?;
262 samples.sort_unstable();
263 samples.dedup();
264 Ok(samples)
265 }
266}
267
/// Selects the samples with the largest focal weights `(1 - p_t)^gamma`,
/// i.e. the examples the model is least confident about.
#[derive(Debug, Clone)]
pub struct FocalSampler {
    /// Focusing parameter: larger values concentrate on harder examples.
    pub gamma: f64,
    /// Maximum number of indices to return.
    pub num_samples: usize,
}
281
282impl FocalSampler {
283 pub fn new(gamma: f64, num_samples: usize) -> Self {
285 Self { gamma, num_samples }
286 }
287
288 pub fn select_samples(
297 &self,
298 predictions: &Array1<f64>,
299 labels: &Array1<f64>,
300 ) -> TrainResult<Vec<usize>> {
301 if predictions.len() != labels.len() {
302 return Err(TrainError::InvalidParameter(
303 "Predictions and labels must have same length".to_string(),
304 ));
305 }
306
307 let mut weights = Vec::with_capacity(predictions.len());
309 for (&pred, &label) in predictions.iter().zip(labels.iter()) {
310 let p_t = if label > 0.5 { pred } else { 1.0 - pred };
311 let weight = (1.0 - p_t).powf(self.gamma);
312 weights.push(weight);
313 }
314
315 let mut indexed_weights: Vec<(usize, f64)> = weights.into_iter().enumerate().collect();
317 indexed_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
318
319 let k = self.num_samples.min(indexed_weights.len());
320 Ok(indexed_weights
321 .iter()
322 .take(k)
323 .map(|(idx, _)| *idx)
324 .collect())
325 }
326}
327
/// Draws an equal number of samples from each class (labels rounded to i32),
/// using a seeded pseudo-random shuffle.
#[derive(Debug, Clone)]
pub struct ClassBalancedSampler {
    /// Number of samples to draw per class (capped by class size).
    pub samples_per_class: usize,
    /// Seed for the internal pseudo-random generator.
    pub seed: u64,
}
338
339impl ClassBalancedSampler {
340 pub fn new(samples_per_class: usize, seed: u64) -> Self {
342 Self {
343 samples_per_class,
344 seed,
345 }
346 }
347
348 pub fn sample(&self, labels: &Array1<f64>) -> TrainResult<Vec<usize>> {
356 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
358
359 for (idx, &label) in labels.iter().enumerate() {
360 let class = label.round() as i32;
361 class_indices.entry(class).or_default().push(idx);
362 }
363
364 if class_indices.is_empty() {
365 return Ok(Vec::new());
366 }
367
368 let mut selected = Vec::new();
370 let mut rng_state = self.seed;
371
372 for (_, indices) in class_indices.iter() {
373 let num_to_sample = self.samples_per_class.min(indices.len());
374
375 let mut shuffled = indices.clone();
377 for i in 0..num_to_sample {
378 rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
379 let j = i + ((rng_state as usize) % (shuffled.len() - i));
380 shuffled.swap(i, j);
381 }
382
383 selected.extend_from_slice(&shuffled[..num_to_sample]);
384 }
385
386 Ok(selected)
387 }
388
389 pub fn compute_class_weights(&self, labels: &Array1<f64>) -> TrainResult<HashMap<i32, f64>> {
391 let mut class_counts: HashMap<i32, usize> = HashMap::new();
392
393 for &label in labels.iter() {
394 let class = label.round() as i32;
395 *class_counts.entry(class).or_insert(0) += 1;
396 }
397
398 let total = labels.len() as f64;
399 let num_classes = class_counts.len() as f64;
400
401 let weights: HashMap<i32, f64> = class_counts
403 .into_iter()
404 .map(|(class, count)| {
405 let weight = total / (num_classes * count as f64);
406 (class, weight)
407 })
408 .collect();
409
410 Ok(weights)
411 }
412}
413
/// Curriculum learning sampler: exposes progressively harder samples as
/// training progress advances from 0.0 to 1.0.
#[derive(Debug, Clone)]
pub struct CurriculumSampler {
    /// Current training progress, clamped to [0.0, 1.0] by `update_progress`.
    pub progress: f64,
    /// Per-sample difficulty; compared directly against `progress`, so
    /// scores are presumably expected in [0.0, 1.0] as well — TODO confirm.
    pub difficulty_scores: Array1<f64>,
    /// Maximum number of indices returned by `select_samples`.
    pub num_samples: usize,
}
426
427impl CurriculumSampler {
428 pub fn new(difficulty_scores: Array1<f64>, num_samples: usize) -> Self {
430 Self {
431 progress: 0.0,
432 difficulty_scores,
433 num_samples,
434 }
435 }
436
437 pub fn update_progress(&mut self, progress: f64) {
439 self.progress = progress.clamp(0.0, 1.0);
440 }
441
442 pub fn select_samples(&self) -> TrainResult<Vec<usize>> {
447 let max_difficulty = self.progress;
449
450 let mut candidates: Vec<usize> = self
452 .difficulty_scores
453 .iter()
454 .enumerate()
455 .filter(|(_, &score)| score <= max_difficulty)
456 .map(|(idx, _)| idx)
457 .collect();
458
459 if candidates.len() < self.num_samples {
461 let mut all_sorted: Vec<(usize, f64)> = self
462 .difficulty_scores
463 .iter()
464 .enumerate()
465 .map(|(idx, &score)| (idx, score))
466 .collect();
467 all_sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
468
469 candidates = all_sorted
470 .iter()
471 .take(self.num_samples)
472 .map(|(idx, _)| *idx)
473 .collect();
474 }
475
476 if candidates.len() > self.num_samples {
478 candidates.truncate(self.num_samples);
479 }
480
481 Ok(candidates)
482 }
483}
484
/// Online hard example mining (OHEM-style): keeps the hardest examples of a
/// batch plus a small fraction of easy ones.
#[derive(Debug, Clone)]
pub struct OnlineHardExampleMiner {
    /// How the hard subset of the batch is chosen.
    pub strategy: MiningStrategy,
    /// Fraction of the batch kept from the low-loss (easy) end.
    pub keep_easy_ratio: f64,
}
495
496impl OnlineHardExampleMiner {
497 pub fn new(strategy: MiningStrategy, keep_easy_ratio: f64) -> Self {
499 Self {
500 strategy,
501 keep_easy_ratio,
502 }
503 }
504
505 pub fn mine_batch(&self, losses: &Array1<f64>) -> TrainResult<Vec<usize>> {
513 if losses.is_empty() {
514 return Ok(Vec::new());
515 }
516
517 let mut indexed_losses: Vec<(usize, f64)> = losses.iter().copied().enumerate().collect();
519 indexed_losses.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
520
521 let total_samples = losses.len();
522 let num_hard = match &self.strategy {
523 MiningStrategy::TopK(k) => (*k).min(total_samples),
524 MiningStrategy::TopPercentage(p) => (total_samples as f64 * p) as usize,
525 MiningStrategy::Threshold(t) => {
526 indexed_losses.iter().filter(|(_, loss)| *loss > *t).count()
527 }
528 MiningStrategy::Focal { num_samples, .. } => (*num_samples).min(total_samples),
529 };
530
531 let num_easy = (total_samples as f64 * self.keep_easy_ratio) as usize;
533
534 let mut selected = Vec::new();
536 selected.extend(indexed_losses.iter().take(num_hard).map(|(idx, _)| *idx));
537 if num_easy > 0 {
538 selected.extend(
539 indexed_losses
540 .iter()
541 .skip(total_samples - num_easy)
542 .map(|(idx, _)| *idx),
543 );
544 }
545
546 Ok(selected)
547 }
548}
549
/// Computes per-sample loss weights for a batch according to a
/// `ReweightingStrategy`; weights are normalized to sum to the batch size.
#[derive(Debug, Clone)]
pub struct BatchReweighter {
    /// How raw per-sample weights are derived from losses.
    pub strategy: ReweightingStrategy,
}
558
/// How per-sample weights are derived from losses in `BatchReweighter`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReweightingStrategy {
    /// Every sample gets weight 1.0.
    Uniform,
    /// `1 / (loss + epsilon)`: down-weights high-loss samples.
    InverseLoss { epsilon: f64 },
    /// `(1 - e^{-loss})^gamma`: up-weights high-loss (hard) samples.
    Focal { gamma: f64 },
    /// `sqrt(loss) + epsilon`: weight grows with loss magnitude.
    GradientNorm { epsilon: f64 },
}
571
572impl BatchReweighter {
573 pub fn new(strategy: ReweightingStrategy) -> Self {
575 Self { strategy }
576 }
577
578 pub fn compute_weights(&self, losses: &Array1<f64>) -> TrainResult<Array1<f64>> {
586 match &self.strategy {
587 ReweightingStrategy::Uniform => Ok(Array1::ones(losses.len())),
588 ReweightingStrategy::InverseLoss { epsilon } => {
589 let weights = losses.mapv(|loss| 1.0 / (loss + epsilon));
590 let sum: f64 = weights.sum();
592 Ok(weights * (losses.len() as f64 / sum))
593 }
594 ReweightingStrategy::Focal { gamma } => {
595 let weights = losses.mapv(|loss| {
597 let p = (-loss).exp().min(0.9999);
598 (1.0 - p).powf(*gamma)
599 });
600 let sum: f64 = weights.sum();
602 Ok(weights * (losses.len() as f64 / sum))
603 }
604 ReweightingStrategy::GradientNorm { epsilon } => {
605 let weights = losses.mapv(|loss| loss.sqrt() + epsilon);
607 let sum: f64 = weights.sum();
608 Ok(weights * (losses.len() as f64 / sum))
609 }
610 }
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hard_negative_miner_topk() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::TopK(2), 0.0);
        let picked = miner.select_samples(&losses, &labels).unwrap();

        // Every positive is retained unconditionally.
        for pos_idx in &[0usize, 2, 4] {
            assert!(picked.contains(pos_idx));
        }
        // The two highest-loss negatives (0.9 at 1, 0.8 at 3) are mined.
        for hard_neg in &[1usize, 3] {
            assert!(picked.contains(hard_neg));
        }
    }

    #[test]
    fn test_hard_negative_miner_threshold() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::Threshold(0.5), 0.0);
        let picked = miner.select_samples(&losses, &labels).unwrap();

        // Positives 0 and 2, plus negatives whose loss exceeds 0.5.
        for idx in &[0usize, 2, 1, 3] {
            assert!(picked.contains(idx));
        }
        // Index 4 (loss 0.2) stays below the threshold.
        assert!(!picked.contains(&4));
    }

    #[test]
    fn test_importance_sampler() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(3, 42);

        let drawn = sampler.sample(&scores).unwrap();

        // Exactly the requested number of draws, all within bounds.
        assert_eq!(drawn.len(), 3);
        assert!(drawn.len() <= 4);
    }

    #[test]
    fn test_importance_sampler_without_replacement() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(5, 42);

        let drawn = sampler.sample_without_replacement(&scores).unwrap();

        // Deduplicating must not shrink the result: no index repeats.
        let mut unique = drawn.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), drawn.len());
    }

    #[test]
    fn test_focal_sampler() {
        let predictions = Array1::from_vec(vec![0.9, 0.1, 0.5, 0.8, 0.3]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0]);

        let sampler = FocalSampler::new(2.0, 3);
        let picked = sampler.select_samples(&predictions, &labels).unwrap();

        assert_eq!(picked.len(), 3);
        // Index 2 (pred 0.5 on a positive) is uncertain, hence high weight.
        assert!(picked.contains(&2));
    }

    #[test]
    fn test_class_balanced_sampler() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let picked = sampler.sample(&labels).unwrap();

        // 2 + 2 + 1 samples: class 2 has only one member.
        assert_eq!(picked.len(), 5);

        // Every class must be represented.
        let picked_labels: Vec<f64> = picked.iter().map(|&idx| labels[idx]).collect();
        for class in &[0.0, 1.0, 2.0] {
            assert!(picked_labels.contains(class));
        }
    }

    #[test]
    fn test_class_balanced_weights() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let weights = sampler.compute_class_weights(&labels).unwrap();

        // weight = total / (num_classes * count): 6/(3*3), 6/(3*2), 6/(3*1).
        assert!((weights[&0] - 0.667).abs() < 0.01);
        assert!((weights[&1] - 1.0).abs() < 0.01);
        assert!((weights[&2] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_curriculum_sampler() {
        let difficulty = Array1::from_vec(vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 3);

        // Even at zero progress the fallback supplies the easiest samples.
        sampler.update_progress(0.0);
        assert!(!sampler.select_samples().unwrap().is_empty());

        sampler.update_progress(0.5);
        assert!(sampler.select_samples().unwrap().len() >= 3);

        sampler.update_progress(1.0);
        assert_eq!(sampler.select_samples().unwrap().len(), 3);
    }

    #[test]
    fn test_online_hard_example_miner() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let miner = OnlineHardExampleMiner::new(MiningStrategy::TopK(2), 0.2);

        let picked = miner.mine_batch(&losses).unwrap();

        assert!(picked.len() >= 2);
        // The two largest losses (indices 1 and 3) must be mined.
        assert!(picked.contains(&1));
        assert!(picked.contains(&3));
    }

    #[test]
    fn test_batch_reweighter_uniform() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Uniform);

        let weights = reweighter.compute_weights(&losses).unwrap();

        assert_eq!(weights.len(), 3);
        // Uniform weighting: every sample gets exactly 1.0.
        for w in weights.iter() {
            assert!((w - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_batch_reweighter_inverse_loss() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::InverseLoss { epsilon: 0.01 });

        let weights = reweighter.compute_weights(&losses).unwrap();

        // Smaller loss => larger weight; strictly decreasing here.
        assert!(weights[0] > weights[1]);
        assert!(weights[1] > weights[2]);

        // Normalized so the mean weight is 1.
        let total: f64 = weights.sum();
        assert!((total - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_batch_reweighter_focal() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Focal { gamma: 2.0 });

        let weights = reweighter.compute_weights(&losses).unwrap();

        // Focal weighting favors the high-loss (hard) samples.
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);

        let total: f64 = weights.sum();
        assert!((total - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_hard_negative_miner_pos_neg_ratio() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        // Ratio 1.0 caps negatives at the positive count despite TopK(10).
        let miner = HardNegativeMiner::new(MiningStrategy::TopK(10), 1.0);
        let picked = miner.select_samples(&losses, &labels).unwrap();

        let num_pos = picked.iter().filter(|&&idx| labels[idx] > 0.5).count();
        let num_neg = picked.iter().filter(|&&idx| labels[idx] < 0.5).count();

        assert_eq!(num_pos, 3);
        assert_eq!(num_neg, 3);
    }

    #[test]
    fn test_curriculum_sampler_progress_bounds() {
        let difficulty = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 2);

        // Out-of-range progress values are clamped into [0, 1].
        sampler.update_progress(-0.5);
        assert_eq!(sampler.progress, 0.0);

        sampler.update_progress(1.5);
        assert_eq!(sampler.progress, 1.0);
    }
}