use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{TrainError, TrainResult};

/// Strategy for choosing hard examples from a pool of candidates.
///
/// Consumed by [`HardNegativeMiner`] (ranking negatives) and
/// [`OnlineHardExampleMiner`] (ranking a whole batch by loss).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MiningStrategy {
    /// Keep the `k` highest-loss candidates.
    TopK(usize),
    /// Keep every candidate whose loss exceeds this threshold.
    Threshold(f64),
    /// Keep the given fraction (expected in `[0, 1]`) of the highest-loss
    /// candidates.
    TopPercentage(f64),
    /// Rank candidates by focal weight `(1 - exp(-loss))^gamma` and keep
    /// `num_samples` of them.
    Focal { gamma: f64, num_samples: usize },
}
52
/// Selects every positive sample plus a mined subset of hard negatives.
#[derive(Debug, Clone)]
pub struct HardNegativeMiner {
    /// How hard negatives are ranked and counted when `pos_neg_ratio` is 0.
    pub strategy: MiningStrategy,
    /// When > 0, the negative count is `num_positives * pos_neg_ratio`,
    /// overriding the count implied by `strategy` (the strategy still
    /// decides the ranking).
    pub pos_neg_ratio: f64,
}
66
67impl HardNegativeMiner {
68 pub fn new(strategy: MiningStrategy, pos_neg_ratio: f64) -> Self {
70 Self {
71 strategy,
72 pos_neg_ratio,
73 }
74 }
75
76 pub fn select_samples(
85 &self,
86 losses: &Array1<f64>,
87 labels: &Array1<f64>,
88 ) -> TrainResult<Vec<usize>> {
89 if losses.len() != labels.len() {
90 return Err(TrainError::InvalidParameter(
91 "Losses and labels must have same length".to_string(),
92 ));
93 }
94
95 let mut pos_indices = Vec::new();
97 let mut neg_indices = Vec::new();
98
99 for (idx, &label) in labels.iter().enumerate() {
100 if label > 0.5 {
101 pos_indices.push(idx);
102 } else {
103 neg_indices.push(idx);
104 }
105 }
106
107 let mut selected = pos_indices.clone();
109
110 let num_negatives = if self.pos_neg_ratio > 0.0 {
112 (pos_indices.len() as f64 * self.pos_neg_ratio) as usize
113 } else {
114 match &self.strategy {
115 MiningStrategy::TopK(k) => *k,
116 MiningStrategy::TopPercentage(p) => (neg_indices.len() as f64 * p) as usize,
117 MiningStrategy::Focal { num_samples, .. } => *num_samples,
118 MiningStrategy::Threshold(_) => neg_indices.len(),
119 }
120 };
121
122 let hard_negatives = self.select_hard_negatives(losses, &neg_indices, num_negatives)?;
123 selected.extend(hard_negatives);
124
125 Ok(selected)
126 }
127
128 fn select_hard_negatives(
130 &self,
131 losses: &Array1<f64>,
132 neg_indices: &[usize],
133 num_samples: usize,
134 ) -> TrainResult<Vec<usize>> {
135 if neg_indices.is_empty() {
136 return Ok(Vec::new());
137 }
138
139 match &self.strategy {
140 MiningStrategy::TopK(_) | MiningStrategy::TopPercentage(_) => {
141 let mut neg_with_loss: Vec<(usize, f64)> =
143 neg_indices.iter().map(|&idx| (idx, losses[idx])).collect();
144 neg_with_loss
145 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
146
147 let k = num_samples.min(neg_with_loss.len());
148 Ok(neg_with_loss.iter().take(k).map(|(idx, _)| *idx).collect())
149 }
150 MiningStrategy::Threshold(threshold) => {
151 Ok(neg_indices
153 .iter()
154 .filter(|&&idx| losses[idx] > *threshold)
155 .copied()
156 .collect())
157 }
158 MiningStrategy::Focal { gamma, .. } => {
159 let mut neg_with_weight: Vec<(usize, f64)> = neg_indices
161 .iter()
162 .map(|&idx| {
163 let loss = losses[idx];
164 let p = (-loss).exp(); let weight = (1.0 - p).powf(*gamma);
166 (idx, weight)
167 })
168 .collect();
169 neg_with_weight
170 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
171
172 let k = num_samples.min(neg_with_weight.len());
173 Ok(neg_with_weight
174 .iter()
175 .take(k)
176 .map(|(idx, _)| *idx)
177 .collect())
178 }
179 }
180 }
181}
182
/// Draws sample indices with probability proportional to importance scores,
/// using a deterministic seeded pseudo-random generator.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// Number of indices drawn per call to `sample`.
    pub num_samples: usize,
    /// Seed for the internal linear congruential generator.
    pub seed: u64,
}
194
195impl ImportanceSampler {
196 pub fn new(num_samples: usize, seed: u64) -> Self {
198 Self { num_samples, seed }
199 }
200
201 pub fn sample(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
209 if scores.is_empty() {
210 return Ok(Vec::new());
211 }
212
213 let total: f64 = scores.iter().sum();
215 if total <= 0.0 {
216 return Err(TrainError::InvalidParameter(
217 "Importance scores must be positive".to_string(),
218 ));
219 }
220
221 let probabilities: Vec<f64> = scores.iter().map(|&s| s / total).collect();
222
223 let mut cumulative = Vec::with_capacity(probabilities.len());
225 let mut sum = 0.0;
226 for &p in &probabilities {
227 sum += p;
228 cumulative.push(sum);
229 }
230
231 let mut selected = Vec::new();
233 let mut rng_state = self.seed;
234
235 for _ in 0..self.num_samples {
236 rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
238 let rand = (rng_state as f64) / (0x7fffffff as f64);
239
240 match cumulative.binary_search_by(|&p| {
242 if p < rand {
243 std::cmp::Ordering::Less
244 } else {
245 std::cmp::Ordering::Greater
246 }
247 }) {
248 Ok(idx) => selected.push(idx),
249 Err(idx) => selected.push(idx.min(cumulative.len() - 1)),
250 }
251 }
252
253 Ok(selected)
254 }
255
256 pub fn sample_with_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
258 self.sample(scores)
259 }
260
261 pub fn sample_without_replacement(&self, scores: &Array1<f64>) -> TrainResult<Vec<usize>> {
263 let mut samples = self.sample(scores)?;
264 samples.sort_unstable();
265 samples.dedup();
266 Ok(samples)
267 }
268}
269
/// Selects the samples with the highest focal weight `(1 - p_t)^gamma`,
/// i.e. the ones the model is least confident about.
#[derive(Debug, Clone)]
pub struct FocalSampler {
    /// Focusing parameter; larger values concentrate selection on harder
    /// samples.
    pub gamma: f64,
    /// Number of samples to select.
    pub num_samples: usize,
}
283
284impl FocalSampler {
285 pub fn new(gamma: f64, num_samples: usize) -> Self {
287 Self { gamma, num_samples }
288 }
289
290 pub fn select_samples(
299 &self,
300 predictions: &Array1<f64>,
301 labels: &Array1<f64>,
302 ) -> TrainResult<Vec<usize>> {
303 if predictions.len() != labels.len() {
304 return Err(TrainError::InvalidParameter(
305 "Predictions and labels must have same length".to_string(),
306 ));
307 }
308
309 let mut weights = Vec::with_capacity(predictions.len());
311 for (&pred, &label) in predictions.iter().zip(labels.iter()) {
312 let p_t = if label > 0.5 { pred } else { 1.0 - pred };
313 let weight = (1.0 - p_t).powf(self.gamma);
314 weights.push(weight);
315 }
316
317 let mut indexed_weights: Vec<(usize, f64)> = weights.into_iter().enumerate().collect();
319 indexed_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
320
321 let k = self.num_samples.min(indexed_weights.len());
322 Ok(indexed_weights
323 .iter()
324 .take(k)
325 .map(|(idx, _)| *idx)
326 .collect())
327 }
328}
329
/// Samples an equal number of indices from every class (bounded by each
/// class's availability), using a deterministic seeded shuffle.
#[derive(Debug, Clone)]
pub struct ClassBalancedSampler {
    /// Maximum number of indices drawn from each class.
    pub samples_per_class: usize,
    /// Seed for the internal pseudo-random generator.
    pub seed: u64,
}
340
341impl ClassBalancedSampler {
342 pub fn new(samples_per_class: usize, seed: u64) -> Self {
344 Self {
345 samples_per_class,
346 seed,
347 }
348 }
349
350 pub fn sample(&self, labels: &Array1<f64>) -> TrainResult<Vec<usize>> {
358 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
360
361 for (idx, &label) in labels.iter().enumerate() {
362 let class = label.round() as i32;
363 class_indices.entry(class).or_default().push(idx);
364 }
365
366 if class_indices.is_empty() {
367 return Ok(Vec::new());
368 }
369
370 let mut selected = Vec::new();
372 let mut rng_state = self.seed;
373
374 for (_, indices) in class_indices.iter() {
375 let num_to_sample = self.samples_per_class.min(indices.len());
376
377 let mut shuffled = indices.clone();
379 for i in 0..num_to_sample {
380 rng_state = (rng_state.wrapping_mul(1103515245).wrapping_add(12345)) & 0x7fffffff;
381 let j = i + ((rng_state as usize) % (shuffled.len() - i));
382 shuffled.swap(i, j);
383 }
384
385 selected.extend_from_slice(&shuffled[..num_to_sample]);
386 }
387
388 Ok(selected)
389 }
390
391 pub fn compute_class_weights(&self, labels: &Array1<f64>) -> TrainResult<HashMap<i32, f64>> {
393 let mut class_counts: HashMap<i32, usize> = HashMap::new();
394
395 for &label in labels.iter() {
396 let class = label.round() as i32;
397 *class_counts.entry(class).or_insert(0) += 1;
398 }
399
400 let total = labels.len() as f64;
401 let num_classes = class_counts.len() as f64;
402
403 let weights: HashMap<i32, f64> = class_counts
405 .into_iter()
406 .map(|(class, count)| {
407 let weight = total / (num_classes * count as f64);
408 (class, weight)
409 })
410 .collect();
411
412 Ok(weights)
413 }
414}
415
/// Curriculum sampler: admits easy samples first and harder ones as
/// training progress increases.
#[derive(Debug, Clone)]
pub struct CurriculumSampler {
    /// Training progress in `[0, 1]`; acts as the maximum admissible
    /// difficulty score.
    pub progress: f64,
    /// Per-sample difficulty scores, indexed by sample position.
    pub difficulty_scores: Array1<f64>,
    /// Number of samples to return per selection.
    pub num_samples: usize,
}
428
429impl CurriculumSampler {
430 pub fn new(difficulty_scores: Array1<f64>, num_samples: usize) -> Self {
432 Self {
433 progress: 0.0,
434 difficulty_scores,
435 num_samples,
436 }
437 }
438
439 pub fn update_progress(&mut self, progress: f64) {
441 self.progress = progress.clamp(0.0, 1.0);
442 }
443
444 pub fn select_samples(&self) -> TrainResult<Vec<usize>> {
449 let max_difficulty = self.progress;
451
452 let mut candidates: Vec<usize> = self
454 .difficulty_scores
455 .iter()
456 .enumerate()
457 .filter(|(_, &score)| score <= max_difficulty)
458 .map(|(idx, _)| idx)
459 .collect();
460
461 if candidates.len() < self.num_samples {
463 let mut all_sorted: Vec<(usize, f64)> = self
464 .difficulty_scores
465 .iter()
466 .enumerate()
467 .map(|(idx, &score)| (idx, score))
468 .collect();
469 all_sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
470
471 candidates = all_sorted
472 .iter()
473 .take(self.num_samples)
474 .map(|(idx, _)| *idx)
475 .collect();
476 }
477
478 if candidates.len() > self.num_samples {
480 candidates.truncate(self.num_samples);
481 }
482
483 Ok(candidates)
484 }
485}
486
/// Online hard example mining over a batch: keeps the hardest (highest-loss)
/// samples plus a fraction of the easiest ones.
#[derive(Debug, Clone)]
pub struct OnlineHardExampleMiner {
    /// How the hard portion of the batch is counted/selected.
    pub strategy: MiningStrategy,
    /// Fraction of the batch (by count) retained from the easy end.
    pub keep_easy_ratio: f64,
}
497
498impl OnlineHardExampleMiner {
499 pub fn new(strategy: MiningStrategy, keep_easy_ratio: f64) -> Self {
501 Self {
502 strategy,
503 keep_easy_ratio,
504 }
505 }
506
507 pub fn mine_batch(&self, losses: &Array1<f64>) -> TrainResult<Vec<usize>> {
515 if losses.is_empty() {
516 return Ok(Vec::new());
517 }
518
519 let mut indexed_losses: Vec<(usize, f64)> = losses.iter().copied().enumerate().collect();
521 indexed_losses.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
522
523 let total_samples = losses.len();
524 let num_hard = match &self.strategy {
525 MiningStrategy::TopK(k) => (*k).min(total_samples),
526 MiningStrategy::TopPercentage(p) => (total_samples as f64 * p) as usize,
527 MiningStrategy::Threshold(t) => {
528 indexed_losses.iter().filter(|(_, loss)| *loss > *t).count()
529 }
530 MiningStrategy::Focal { num_samples, .. } => (*num_samples).min(total_samples),
531 };
532
533 let num_easy = (total_samples as f64 * self.keep_easy_ratio) as usize;
535
536 let mut selected = Vec::new();
538 selected.extend(indexed_losses.iter().take(num_hard).map(|(idx, _)| *idx));
539 if num_easy > 0 {
540 selected.extend(
541 indexed_losses
542 .iter()
543 .skip(total_samples - num_easy)
544 .map(|(idx, _)| *idx),
545 );
546 }
547
548 Ok(selected)
549 }
550}
551
/// Computes per-sample loss weights for a batch under a configurable
/// reweighting strategy.
#[derive(Debug, Clone)]
pub struct BatchReweighter {
    /// The weighting scheme applied by `compute_weights`.
    pub strategy: ReweightingStrategy,
}
560
/// How per-sample weights are derived from per-sample losses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReweightingStrategy {
    /// Every sample gets weight 1.
    Uniform,
    /// Weight `1 / (loss + epsilon)`: down-weights high-loss samples.
    InverseLoss { epsilon: f64 },
    /// Weight `(1 - exp(-loss))^gamma`: up-weights high-loss samples.
    Focal { gamma: f64 },
    /// Weight `sqrt(loss) + epsilon` — presumably a gradient-magnitude
    /// proxy; see `compute_weights` for the exact formula.
    GradientNorm { epsilon: f64 },
}
573
574impl BatchReweighter {
575 pub fn new(strategy: ReweightingStrategy) -> Self {
577 Self { strategy }
578 }
579
580 pub fn compute_weights(&self, losses: &Array1<f64>) -> TrainResult<Array1<f64>> {
588 match &self.strategy {
589 ReweightingStrategy::Uniform => Ok(Array1::ones(losses.len())),
590 ReweightingStrategy::InverseLoss { epsilon } => {
591 let weights = losses.mapv(|loss| 1.0 / (loss + epsilon));
592 let sum: f64 = weights.sum();
594 Ok(weights * (losses.len() as f64 / sum))
595 }
596 ReweightingStrategy::Focal { gamma } => {
597 let weights = losses.mapv(|loss| {
599 let p = (-loss).exp().min(0.9999);
600 (1.0 - p).powf(*gamma)
601 });
602 let sum: f64 = weights.sum();
604 Ok(weights * (losses.len() as f64 / sum))
605 }
606 ReweightingStrategy::GradientNorm { epsilon } => {
607 let weights = losses.mapv(|loss| loss.sqrt() + epsilon);
609 let sum: f64 = weights.sum();
610 Ok(weights * (losses.len() as f64 / sum))
611 }
612 }
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hard_negative_miner_topk() {
        // Positives at 0, 2, 4; negatives at 1, 3, 5 (losses 0.9, 0.8, 0.7).
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::TopK(2), 0.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");

        // All positives are always kept...
        assert!(selected.contains(&0));
        assert!(selected.contains(&2));
        assert!(selected.contains(&4));
        // ...plus the two hardest negatives (losses 0.9 and 0.8).
        assert!(selected.contains(&1));
        assert!(selected.contains(&3));
    }

    #[test]
    fn test_hard_negative_miner_threshold() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 0.0]);

        let miner = HardNegativeMiner::new(MiningStrategy::Threshold(0.5), 0.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");

        // Positives 0 and 2 are kept; negatives 1 and 3 exceed the 0.5 loss
        // threshold; negative 4 (loss 0.2) must be dropped.
        assert!(selected.contains(&0));
        assert!(selected.contains(&2));
        assert!(selected.contains(&1));
        assert!(selected.contains(&3));
        assert!(!selected.contains(&4));
    }

    #[test]
    fn test_importance_sampler() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(3, 42);

        let selected = sampler.sample(&scores).expect("unwrap");
        // Sampling with replacement always returns exactly num_samples draws.
        assert_eq!(selected.len(), 3);

        assert!(selected.len() <= 4);
    }

    #[test]
    fn test_importance_sampler_without_replacement() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.9, 0.3]);
        let sampler = ImportanceSampler::new(5, 42);

        let selected = sampler.sample_without_replacement(&scores).expect("unwrap");

        // Deduplicating the result must not shrink it: no duplicates allowed.
        let mut sorted = selected.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), selected.len());
    }

    #[test]
    fn test_focal_sampler() {
        let predictions = Array1::from_vec(vec![0.9, 0.1, 0.5, 0.8, 0.3]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0]);

        let sampler = FocalSampler::new(2.0, 3);
        let selected = sampler
            .select_samples(&predictions, &labels)
            .expect("unwrap");

        assert_eq!(selected.len(), 3);
        // Index 2 (positive predicted at only 0.5) is a hard sample and must
        // be among the top-3 focal weights.
        assert!(selected.contains(&2));
    }

    #[test]
    fn test_class_balanced_sampler() {
        // Class 0 has 3 samples, class 1 has 2, class 2 has 1.
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let selected = sampler.sample(&labels).expect("unwrap");

        // min(2, available) per class: 2 + 2 + 1 = 5 total.
        assert_eq!(selected.len(), 5);

        // Every class must be represented in the selection.
        let selected_labels: Vec<f64> = selected.iter().map(|&idx| labels[idx]).collect();
        assert!(selected_labels.contains(&0.0));
        assert!(selected_labels.contains(&1.0));
        assert!(selected_labels.contains(&2.0));
    }

    #[test]
    fn test_class_balanced_weights() {
        let labels = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 2.0]);
        let sampler = ClassBalancedSampler::new(2, 42);

        let weights = sampler.compute_class_weights(&labels).expect("unwrap");

        // total / (num_classes * count): 6/(3*3), 6/(3*2), 6/(3*1).
        assert!((weights[&0] - 0.667).abs() < 0.01);
        assert!((weights[&1] - 1.0).abs() < 0.01);
        assert!((weights[&2] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_curriculum_sampler() {
        let difficulty = Array1::from_vec(vec![0.1, 0.3, 0.5, 0.7, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 3);

        // At progress 0 nothing qualifies, so the easiest-samples fallback
        // must still return something.
        sampler.update_progress(0.0);
        let selected = sampler.select_samples().expect("unwrap");
        assert!(!selected.is_empty());

        // At progress 0.5 exactly three samples (0.1, 0.3, 0.5) qualify.
        sampler.update_progress(0.5);
        let selected = sampler.select_samples().expect("unwrap");
        assert!(selected.len() >= 3);

        // At progress 1.0 all five qualify but output is capped at 3.
        sampler.update_progress(1.0);
        let selected = sampler.select_samples().expect("unwrap");
        assert_eq!(selected.len(), 3);
    }

    #[test]
    fn test_online_hard_example_miner() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2]);
        let miner = OnlineHardExampleMiner::new(MiningStrategy::TopK(2), 0.2);

        let selected = miner.mine_batch(&losses).expect("unwrap");

        assert!(selected.len() >= 2);
        // The two highest losses (0.9 at index 1, 0.8 at index 3) are hard.
        assert!(selected.contains(&1));
        assert!(selected.contains(&3));
    }

    #[test]
    fn test_batch_reweighter_uniform() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Uniform);

        let weights = reweighter.compute_weights(&losses).expect("unwrap");

        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 1.0).abs() < 1e-10);
        assert!((weights[1] - 1.0).abs() < 1e-10);
        assert!((weights[2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_reweighter_inverse_loss() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::InverseLoss { epsilon: 0.01 });

        let weights = reweighter.compute_weights(&losses).expect("unwrap");

        // Inverse-loss weighting: smaller loss => larger weight.
        assert!(weights[0] > weights[1]);
        assert!(weights[1] > weights[2]);

        // Normalized to sum to the batch size.
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_batch_reweighter_focal() {
        let losses = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let reweighter = BatchReweighter::new(ReweightingStrategy::Focal { gamma: 2.0 });

        let weights = reweighter.compute_weights(&losses).expect("unwrap");

        // Focal weighting: larger loss => larger weight.
        assert!(weights[2] > weights[1]);
        assert!(weights[1] > weights[0]);

        // Normalized to sum to the batch size.
        let sum: f64 = weights.sum();
        assert!((sum - 3.0).abs() < 0.01);
    }

    #[test]
    fn test_hard_negative_miner_pos_neg_ratio() {
        let losses = Array1::from_vec(vec![0.1, 0.9, 0.3, 0.8, 0.2, 0.7]);
        let labels = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);

        // Ratio 1.0 overrides TopK(10): negatives = positives = 3.
        let miner = HardNegativeMiner::new(MiningStrategy::TopK(10), 1.0);
        let selected = miner.select_samples(&losses, &labels).expect("unwrap");

        let num_pos = selected.iter().filter(|&&idx| labels[idx] > 0.5).count();
        let num_neg = selected.iter().filter(|&&idx| labels[idx] < 0.5).count();

        assert_eq!(num_pos, 3);
        assert_eq!(num_neg, 3);
    }

    #[test]
    fn test_curriculum_sampler_progress_bounds() {
        let difficulty = Array1::from_vec(vec![0.1, 0.5, 0.9]);
        let mut sampler = CurriculumSampler::new(difficulty, 2);

        // Progress is clamped into [0, 1] on both ends.
        sampler.update_progress(-0.5);
        assert_eq!(sampler.progress, 0.0);

        sampler.update_progress(1.5);
        assert_eq!(sampler.progress, 1.0);
    }
}