1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use scirs2_core::RngExt;
11
12use super::core::{rng_utils, Sampler, SamplerIterator};
13
/// Per-sample weighting rule that an `AdaptiveSampler` can mix with others.
///
/// Every variant turns the sampler's per-sample statistics (smoothed loss,
/// loss z-score "difficulty", update count) into an unnormalized sampling
/// weight. `Uniform` is the default.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum AdaptiveStrategy {
    /// Weight `exp(difficulty * intensity)`: favours samples whose loss is above average.
    HardSampling { intensity: f64 },

    /// Weight `exp(-difficulty * intensity)`: favours samples whose loss is below average.
    EasySampling { intensity: f64 },

    /// Constant weight `1.0` for every sample.
    #[default]
    Uniform,

    /// Weight `exp(loss / temperature)` computed from the smoothed loss.
    Uncertainty { temperature: f64 },

    /// Weight `1 / (updates + 1)^power`: favours samples that received fewer loss updates.
    InverseFrequency { power: f64 },

    /// Weight equal to the smoothed loss when it exceeds `threshold`, otherwise `0.1`.
    GradientMagnitude { threshold: f64 },
}
81
/// Index sampler that mixes several `AdaptiveStrategy` weightings and adapts
/// their mixing proportions from per-sample loss feedback.
///
/// Observed losses are fed in with `update_sample_losses`; each call to
/// `Sampler::iter` then draws `num_samples` indices with replacement from
/// `0..dataset_size`.
#[derive(Clone)]
pub struct AdaptiveSampler {
    /// Number of items in the dataset; sampled indices lie in `0..dataset_size`.
    dataset_size: usize,
    /// Number of indices drawn per call to `iter`.
    num_samples: usize,
    /// Active strategies; index-aligned with `strategy_weights`.
    strategies: Vec<AdaptiveStrategy>,
    /// Mixing weight of each strategy. Renormalized to sum to one by
    /// `add_strategy` and after every adaptation step.
    strategy_weights: Vec<f64>,
    /// Exponential moving average of reported losses, one entry per sample (starts at 0.0).
    sample_losses: Vec<f64>,
    /// Z-score of each smoothed loss across the dataset (0.0 when all losses are equal).
    sample_difficulties: Vec<f64>,
    /// Number of loss updates received by each sample.
    sample_frequencies: Vec<usize>,
    /// Multiplicative step used when adapting `strategy_weights`; its builder clamps it to `[0, 1]`.
    adaptation_rate: f64,
    /// EMA coefficient applied to the previous loss; its builder clamps it to `[0, 1]`.
    smoothing_factor: f64,
    /// Epoch number supplied by the caller via `set_epoch`.
    current_epoch: usize,
    /// Strategy weights are not adapted while `current_epoch < warmup_epochs`.
    warmup_epochs: usize,
    /// Optional seed handed to `rng_utils::create_rng` on every draw. A seeded
    /// sampler presumably restarts the same random stream each time `iter` is
    /// called (see `test_reproducibility`) — confirm against `rng_utils`.
    generator: Option<u64>,
}
130
131impl AdaptiveSampler {
132 pub fn new(dataset_size: usize, num_samples: usize) -> Self {
150 let strategies = vec![
151 AdaptiveStrategy::HardSampling { intensity: 1.0 },
152 AdaptiveStrategy::Uniform,
153 AdaptiveStrategy::Uncertainty { temperature: 1.0 },
154 ];
155
156 let strategy_weights = vec![0.4, 0.3, 0.3];
157
158 Self {
159 dataset_size,
160 num_samples,
161 strategies,
162 strategy_weights,
163 sample_losses: vec![0.0; dataset_size],
164 sample_difficulties: vec![0.0; dataset_size],
165 sample_frequencies: vec![0; dataset_size],
166 adaptation_rate: 0.1,
167 smoothing_factor: 0.9,
168 current_epoch: 0,
169 warmup_epochs: 5,
170 generator: None,
171 }
172 }
173
174 pub fn add_strategy(mut self, strategy: AdaptiveStrategy, weight: f64) -> Self {
181 self.strategies.push(strategy);
182 self.strategy_weights.push(weight);
183 self.normalize_strategy_weights();
184 self
185 }
186
187 pub fn with_adaptation_rate(mut self, rate: f64) -> Self {
193 self.adaptation_rate = rate.clamp(0.0, 1.0);
194 self
195 }
196
197 pub fn with_smoothing_factor(mut self, factor: f64) -> Self {
203 self.smoothing_factor = factor.clamp(0.0, 1.0);
204 self
205 }
206
207 pub fn with_warmup_epochs(mut self, epochs: usize) -> Self {
213 self.warmup_epochs = epochs;
214 self
215 }
216
217 pub fn with_generator(mut self, seed: u64) -> Self {
223 self.generator = Some(seed);
224 self
225 }
226
227 pub fn current_epoch(&self) -> usize {
229 self.current_epoch
230 }
231
232 pub fn warmup_epochs(&self) -> usize {
234 self.warmup_epochs
235 }
236
237 pub fn adaptation_rate(&self) -> f64 {
239 self.adaptation_rate
240 }
241
242 pub fn smoothing_factor(&self) -> f64 {
244 self.smoothing_factor
245 }
246
247 pub fn strategy_weights(&self) -> &[f64] {
249 &self.strategy_weights
250 }
251
252 pub fn strategies(&self) -> &[AdaptiveStrategy] {
254 &self.strategies
255 }
256
257 pub fn sample_losses(&self) -> &[f64] {
259 &self.sample_losses
260 }
261
262 pub fn sample_difficulties(&self) -> &[f64] {
264 &self.sample_difficulties
265 }
266
267 pub fn sample_frequencies(&self) -> &[usize] {
269 &self.sample_frequencies
270 }
271
272 pub fn is_warming_up(&self) -> bool {
274 self.current_epoch < self.warmup_epochs
275 }
276
277 pub fn update_sample_losses(&mut self, sample_indices: &[usize], losses: &[f64]) {
291 assert_eq!(sample_indices.len(), losses.len());
292
293 for (&idx, &loss) in sample_indices.iter().zip(losses.iter()) {
294 if idx < self.dataset_size {
295 self.sample_losses[idx] = self.smoothing_factor * self.sample_losses[idx]
297 + (1.0 - self.smoothing_factor) * loss;
298
299 self.sample_frequencies[idx] += 1;
301 }
302 }
303
304 self.update_sample_difficulties();
305 self.adapt_strategy_weights();
306 }
307
308 pub fn set_epoch(&mut self, epoch: usize) {
314 self.current_epoch = epoch;
315 }
316
317 pub fn reset(&mut self) {
319 self.sample_losses.fill(0.0);
320 self.sample_difficulties.fill(0.0);
321 self.sample_frequencies.fill(0);
322 self.current_epoch = 0;
323 }
324
325 pub fn adaptive_stats(&self) -> AdaptiveStats {
327 let hard_samples = self
328 .sample_difficulties
329 .iter()
330 .filter(|&&d| d > 0.5)
331 .count();
332 let max_freq = self.sample_frequencies.iter().max().copied().unwrap_or(0);
333 let min_freq = self.sample_frequencies.iter().min().copied().unwrap_or(0);
334 let mean_loss = self.sample_losses.iter().sum::<f64>() / self.sample_losses.len() as f64;
335
336 AdaptiveStats {
337 current_epoch: self.current_epoch,
338 warmup_epochs: self.warmup_epochs,
339 is_warming_up: self.is_warming_up(),
340 hard_samples_count: hard_samples,
341 hard_samples_ratio: hard_samples as f64 / self.dataset_size as f64,
342 frequency_imbalance: if min_freq > 0 {
343 max_freq as f64 / min_freq as f64
344 } else {
345 0.0
346 },
347 mean_loss,
348 adaptation_rate: self.adaptation_rate,
349 num_strategies: self.strategies.len(),
350 }
351 }
352
353 fn update_sample_difficulties(&mut self) {
355 if self.sample_losses.is_empty() {
356 return;
357 }
358
359 let mean_loss = self.sample_losses.iter().sum::<f64>() / self.sample_losses.len() as f64;
360 let variance = self
361 .sample_losses
362 .iter()
363 .map(|&loss| (loss - mean_loss).powi(2))
364 .sum::<f64>()
365 / self.sample_losses.len() as f64;
366 let std_dev = variance.sqrt();
367
368 for (i, &loss) in self.sample_losses.iter().enumerate() {
369 self.sample_difficulties[i] = if std_dev > 0.0 {
371 (loss - mean_loss) / std_dev
372 } else {
373 0.0
374 };
375 }
376 }
377
378 fn adapt_strategy_weights(&mut self) {
380 if self.is_warming_up() {
381 return;
382 }
383
384 let hard_samples_ratio = self
386 .sample_difficulties
387 .iter()
388 .filter(|&&d| d > 0.5)
389 .count() as f64
390 / self.dataset_size as f64;
391
392 let frequency_imbalance = {
393 let max_freq = self.sample_frequencies.iter().max().unwrap_or(&1);
394 let min_freq = self.sample_frequencies.iter().min().unwrap_or(&1);
395 (*max_freq as f64 / (*min_freq as f64).max(1.0)).ln()
396 };
397
398 let mut new_weights = self.strategy_weights.clone();
400
401 if hard_samples_ratio > 0.3 {
403 for (i, strategy) in self.strategies.iter().enumerate() {
404 match strategy {
405 AdaptiveStrategy::HardSampling { .. } => {
406 new_weights[i] *= 1.0 - self.adaptation_rate;
407 }
408 AdaptiveStrategy::EasySampling { .. } => {
409 new_weights[i] *= 1.0 + self.adaptation_rate;
410 }
411 _ => {}
412 }
413 }
414 }
415
416 if frequency_imbalance > 1.0 {
418 for (i, strategy) in self.strategies.iter().enumerate() {
419 if let AdaptiveStrategy::InverseFrequency { .. } = strategy {
420 new_weights[i] *= 1.0 + self.adaptation_rate;
421 }
422 }
423 }
424
425 self.strategy_weights = new_weights;
426 self.normalize_strategy_weights();
427 }
428
429 fn normalize_strategy_weights(&mut self) {
431 let sum: f64 = self.strategy_weights.iter().sum();
432 if sum > 0.0 {
433 for weight in &mut self.strategy_weights {
434 *weight /= sum;
435 }
436 } else {
437 let uniform_weight = 1.0 / self.strategy_weights.len() as f64;
439 self.strategy_weights.fill(uniform_weight);
440 }
441 }
442
443 fn get_strategy_weights(&self, strategy: &AdaptiveStrategy) -> Vec<f64> {
445 match strategy {
446 AdaptiveStrategy::HardSampling { intensity } => self
447 .sample_difficulties
448 .iter()
449 .map(|&d| (d * intensity).exp())
450 .collect(),
451 AdaptiveStrategy::EasySampling { intensity } => self
452 .sample_difficulties
453 .iter()
454 .map(|&d| (-d * intensity).exp())
455 .collect(),
456 AdaptiveStrategy::Uniform => {
457 vec![1.0; self.dataset_size]
458 }
459 AdaptiveStrategy::Uncertainty { temperature } => self
460 .sample_losses
461 .iter()
462 .map(|&loss| (loss / temperature).exp())
463 .collect(),
464 AdaptiveStrategy::InverseFrequency { power } => self
465 .sample_frequencies
466 .iter()
467 .map(|&freq| 1.0 / (freq as f64 + 1.0).powf(*power))
468 .collect(),
469 AdaptiveStrategy::GradientMagnitude { threshold } => self
470 .sample_losses
471 .iter()
472 .map(|&loss| if loss > *threshold { loss } else { 0.1 })
473 .collect(),
474 }
475 }
476
477 fn get_combined_weights(&self) -> Vec<f64> {
479 let mut combined = vec![0.0; self.dataset_size];
480
481 for (strategy, &weight) in self.strategies.iter().zip(self.strategy_weights.iter()) {
482 let strategy_weights = self.get_strategy_weights(strategy);
483 for (i, &w) in strategy_weights.iter().enumerate() {
484 combined[i] += weight * w;
485 }
486 }
487
488 for w in &mut combined {
490 *w = w.max(1e-6);
491 }
492
493 combined
494 }
495
496 fn sample_with_replacement(&self, weights: &[f64]) -> Vec<usize> {
498 let mut rng = rng_utils::create_rng(self.generator);
500
501 let weight_sum: f64 = weights.iter().sum();
503 if weight_sum <= 0.0 {
504 return (0..self.num_samples)
506 .map(|_| rng_utils::gen_range(&mut rng, 0..self.dataset_size))
507 .collect();
508 }
509
510 let mut cumulative_weights = Vec::with_capacity(weights.len());
511 let mut cumsum = 0.0;
512
513 for &weight in weights {
514 cumsum += weight / weight_sum;
515 cumulative_weights.push(cumsum);
516 }
517
518 if let Some(last) = cumulative_weights.last_mut() {
520 *last = 1.0;
521 }
522
523 (0..self.num_samples)
525 .map(|_| {
526 let rand_val: f64 = rng.random();
527 cumulative_weights
528 .binary_search_by(|&x| {
529 x.partial_cmp(&rand_val)
530 .unwrap_or(std::cmp::Ordering::Equal)
531 })
532 .unwrap_or_else(|i| i)
533 .min(self.dataset_size - 1)
534 })
535 .collect()
536 }
537}
538
539impl Sampler for AdaptiveSampler {
540 type Iter = SamplerIterator;
541
542 fn iter(&self) -> Self::Iter {
543 let weights = self.get_combined_weights();
544 let indices = self.sample_with_replacement(&weights);
545 SamplerIterator::new(indices)
546 }
547
548 fn len(&self) -> usize {
549 self.num_samples
550 }
551}
552
/// Snapshot of an `AdaptiveSampler`'s state, produced by `AdaptiveSampler::adaptive_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptiveStats {
    /// Epoch most recently set on the sampler.
    pub current_epoch: usize,
    /// Number of warm-up epochs during which strategy weights are not adapted.
    pub warmup_epochs: usize,
    /// Whether `current_epoch < warmup_epochs`.
    pub is_warming_up: bool,
    /// Number of samples whose difficulty (loss z-score) exceeds 0.5.
    pub hard_samples_count: usize,
    /// `hard_samples_count` divided by the dataset size.
    pub hard_samples_ratio: f64,
    /// Largest per-sample update count divided by the smallest; 0.0 when some
    /// sample has never been updated.
    pub frequency_imbalance: f64,
    /// Mean of the smoothed per-sample losses.
    pub mean_loss: f64,
    /// Adaptation rate configured on the sampler.
    pub adaptation_rate: f64,
    /// Number of strategies in the sampler's mixture.
    pub num_strategies: usize,
}
575
576pub fn hard_adaptive_sampler(
587 dataset_size: usize,
588 num_samples: usize,
589 intensity: f64,
590 seed: Option<u64>,
591) -> AdaptiveSampler {
592 let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
593 .add_strategy(AdaptiveStrategy::HardSampling { intensity }, 0.7);
594 if let Some(s) = seed {
595 sampler = sampler.with_generator(s);
596 }
597 sampler
598}
599
600pub fn frequency_balanced_sampler(
611 dataset_size: usize,
612 num_samples: usize,
613 power: f64,
614 seed: Option<u64>,
615) -> AdaptiveSampler {
616 let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
617 .add_strategy(AdaptiveStrategy::InverseFrequency { power }, 0.6);
618 if let Some(s) = seed {
619 sampler = sampler.with_generator(s);
620 }
621 sampler
622}
623
624pub fn uncertainty_adaptive_sampler(
635 dataset_size: usize,
636 num_samples: usize,
637 temperature: f64,
638 seed: Option<u64>,
639) -> AdaptiveSampler {
640 let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
641 .add_strategy(AdaptiveStrategy::Uncertainty { temperature }, 0.8);
642 if let Some(s) = seed {
643 sampler = sampler.with_generator(s);
644 }
645 sampler
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for sums of normalized strategy weights: dividing each weight
    /// by the total may leave a few ULPs of rounding, more than `f64::EPSILON`.
    const WEIGHT_SUM_TOLERANCE: f64 = 1e-12;

    #[test]
    fn test_adaptive_sampler_basic() {
        let dataset_size = 100;
        let num_samples = 50;
        let sampler = AdaptiveSampler::new(dataset_size, num_samples).with_generator(42);

        assert_eq!(sampler.len(), num_samples);
        assert_eq!(sampler.current_epoch(), 0);
        assert_eq!(sampler.warmup_epochs(), 5);
        assert!(sampler.is_warming_up());
        assert_eq!(sampler.strategies().len(), 3);
        assert_eq!(sampler.strategy_weights().len(), 3);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), num_samples);

        for &idx in &indices {
            assert!(idx < dataset_size);
        }
    }

    #[test]
    fn test_adaptive_sampler_with_losses() {
        let dataset_size = 10;
        let num_samples = 5;
        let mut sampler = AdaptiveSampler::new(dataset_size, num_samples).with_generator(42);

        assert!(sampler.sample_difficulties().iter().all(|&d| d == 0.0));

        let sample_indices = vec![0, 2, 4, 6, 8];
        let losses = vec![0.1, 0.8, 0.2, 0.9, 0.3];
        sampler.update_sample_losses(&sample_indices, &losses);

        // EMA with smoothing 0.9 starting from 0.0: new value = 0.1 * loss.
        assert!((sampler.sample_losses()[0] - 0.01).abs() < 1e-10);
        assert!((sampler.sample_losses()[2] - 0.08).abs() < 1e-10);
        assert!((sampler.sample_losses()[6] - 0.09).abs() < 1e-10);

        assert_eq!(sampler.sample_frequencies()[0], 1);
        assert_eq!(sampler.sample_frequencies()[2], 1);
        assert_eq!(sampler.sample_frequencies()[1], 0);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), num_samples);
    }

    #[test]
    fn test_adaptive_sampler_strategy_adaptation() {
        let dataset_size = 20;
        let num_samples = 10;
        let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
            .with_warmup_epochs(2)
            .with_generator(42);

        sampler = sampler.add_strategy(AdaptiveStrategy::InverseFrequency { power: 1.0 }, 0.2);
        assert_eq!(sampler.strategies().len(), 4);

        let initial_weights = sampler.strategy_weights().to_vec();
        let initial_sum: f64 = initial_weights.iter().sum();
        assert!((initial_sum - 1.0).abs() < WEIGHT_SUM_TOLERANCE);

        sampler.set_epoch(1);
        assert!(sampler.is_warming_up());

        let sample_indices: Vec<usize> = (0..10).collect();
        let losses = vec![0.5; 10];
        sampler.update_sample_losses(&sample_indices, &losses);

        // No adaptation happens during warm-up.
        assert_eq!(sampler.strategy_weights(), initial_weights.as_slice());

        sampler.set_epoch(3);
        assert!(!sampler.is_warming_up());

        sampler.update_sample_losses(&sample_indices, &losses);

        // Half the dataset now has above-average loss (z-score ~ +1), which is
        // more than 30% hard samples, so the hard-sampling weight is damped.
        assert!(sampler.strategy_weights()[0] < initial_weights[0]);
        let adapted_sum: f64 = sampler.strategy_weights().iter().sum();
        assert!((adapted_sum - 1.0).abs() < WEIGHT_SUM_TOLERANCE);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), num_samples);
    }

    #[test]
    fn test_adaptive_strategies() {
        let dataset_size = 10;
        let num_samples = 5;

        let strategies = vec![
            AdaptiveStrategy::HardSampling { intensity: 1.0 },
            AdaptiveStrategy::EasySampling { intensity: 1.0 },
            AdaptiveStrategy::Uniform,
            AdaptiveStrategy::Uncertainty { temperature: 1.0 },
            AdaptiveStrategy::InverseFrequency { power: 1.0 },
            AdaptiveStrategy::GradientMagnitude { threshold: 0.5 },
        ];

        for strategy in strategies {
            let sampler = AdaptiveSampler::new(dataset_size, num_samples)
                .add_strategy(strategy, 0.5)
                .with_generator(42);

            let indices: Vec<usize> = sampler.iter().collect();
            assert_eq!(indices.len(), num_samples);

            for &idx in &indices {
                assert!(idx < dataset_size);
            }
        }
    }

    #[test]
    fn test_adaptive_sampler_difficulty_calculation() {
        let dataset_size = 5;
        let num_samples = 3;
        let mut sampler = AdaptiveSampler::new(dataset_size, num_samples).with_generator(42);

        let sample_indices = vec![0, 1, 2, 3, 4];
        let losses = vec![0.1, 0.2, 0.8, 0.9, 0.3];
        sampler.update_sample_losses(&sample_indices, &losses);

        let difficulties = sampler.sample_difficulties();

        assert!(difficulties[2] > difficulties[0]);
        assert!(difficulties[3] > difficulties[1]);
        assert!(difficulties[2] > 0.0);
        assert!(difficulties[3] > 0.0);
    }

    #[test]
    fn test_adaptive_sampler_methods() {
        let mut sampler = AdaptiveSampler::new(20, 10)
            .with_adaptation_rate(0.2)
            .with_smoothing_factor(0.8)
            .with_warmup_epochs(3)
            .with_generator(42);

        assert_eq!(sampler.adaptation_rate(), 0.2);
        assert_eq!(sampler.smoothing_factor(), 0.8);
        assert_eq!(sampler.warmup_epochs(), 3);

        sampler.set_epoch(5);
        assert_eq!(sampler.current_epoch(), 5);
        assert!(!sampler.is_warming_up());

        sampler.update_sample_losses(&[0, 1, 2], &[0.5, 0.6, 0.7]);
        assert!(sampler.sample_losses().iter().any(|&l| l > 0.0));
        assert!(sampler.sample_frequencies().iter().any(|&f| f > 0));

        sampler.reset();
        assert!(sampler.sample_losses().iter().all(|&l| l == 0.0));
        assert!(sampler.sample_frequencies().iter().all(|&f| f == 0));
        assert_eq!(sampler.current_epoch(), 0);
    }

    #[test]
    fn test_adaptive_stats() {
        let mut sampler = AdaptiveSampler::new(100, 32);

        let stats = sampler.adaptive_stats();
        assert_eq!(stats.current_epoch, 0);
        assert_eq!(stats.warmup_epochs, 5);
        assert!(stats.is_warming_up);
        assert_eq!(stats.hard_samples_count, 0);
        assert_eq!(stats.hard_samples_ratio, 0.0);
        assert_eq!(stats.mean_loss, 0.0);
        assert_eq!(stats.num_strategies, 3);

        let sample_indices: Vec<usize> = (0..20).collect();
        let losses: Vec<f64> = (0..20).map(|i| if i > 15 { 0.8 } else { 0.2 }).collect();
        sampler.update_sample_losses(&sample_indices, &losses);

        let stats = sampler.adaptive_stats();
        assert!(stats.mean_loss > 0.0);
        assert!(stats.hard_samples_count > 0);
        assert!(stats.hard_samples_ratio > 0.0);
        // Samples 20..100 were never updated, so the imbalance is reported as 0.0.
        assert_eq!(stats.frequency_imbalance, 0.0);
    }

    #[test]
    fn test_convenience_functions() {
        let hard_sampler = hard_adaptive_sampler(100, 32, 1.5, Some(42));
        assert_eq!(hard_sampler.len(), 32);
        assert_eq!(hard_sampler.strategies().len(), 4);

        let freq_sampler = frequency_balanced_sampler(100, 32, 1.0, Some(42));
        assert_eq!(freq_sampler.len(), 32);
        assert_eq!(freq_sampler.strategies().len(), 4);

        let uncertainty_sampler = uncertainty_adaptive_sampler(100, 32, 0.8, Some(42));
        assert_eq!(uncertainty_sampler.len(), 32);
        assert_eq!(uncertainty_sampler.strategies().len(), 4);
    }

    #[test]
    fn test_weight_normalization() {
        let mut sampler = AdaptiveSampler::new(10, 5);

        sampler = sampler
            .add_strategy(AdaptiveStrategy::HardSampling { intensity: 1.0 }, 2.0)
            .add_strategy(AdaptiveStrategy::EasySampling { intensity: 1.0 }, 3.0);

        let sum: f64 = sampler.strategy_weights().iter().sum();
        assert!((sum - 1.0).abs() < WEIGHT_SUM_TOLERANCE);
    }

    #[test]
    fn test_strategy_weights() {
        let sampler = AdaptiveSampler::new(5, 3);

        let uniform_weights = sampler.get_strategy_weights(&AdaptiveStrategy::Uniform);
        assert!(uniform_weights.iter().all(|&w| w == 1.0));

        let mut sampler_with_data = AdaptiveSampler::new(5, 3);
        sampler_with_data.update_sample_losses(&[0, 1, 2], &[0.1, 0.5, 0.9]);

        let hard_weights = sampler_with_data
            .get_strategy_weights(&AdaptiveStrategy::HardSampling { intensity: 1.0 });
        assert_eq!(hard_weights.len(), 5);
        // Higher loss means higher difficulty, which means more hard-sampling weight.
        assert!(hard_weights[2] > hard_weights[0]);

        let easy_weights = sampler_with_data
            .get_strategy_weights(&AdaptiveStrategy::EasySampling { intensity: 1.0 });
        assert_eq!(easy_weights.len(), 5);
        assert!(easy_weights[0] > easy_weights[2]);
    }

    #[test]
    fn test_edge_cases() {
        let empty_sampler = AdaptiveSampler::new(10, 0);
        assert_eq!(empty_sampler.len(), 0);
        let indices: Vec<usize> = empty_sampler.iter().collect();
        assert!(indices.is_empty());

        let single_sampler = AdaptiveSampler::new(10, 1);
        let indices: Vec<usize> = single_sampler.iter().collect();
        assert_eq!(indices.len(), 1);

        let large_sampler = AdaptiveSampler::new(10000, 64);
        assert_eq!(large_sampler.len(), 64);

        // Out-of-range indices are ignored rather than panicking.
        let mut sampler = AdaptiveSampler::new(5, 3);
        sampler.update_sample_losses(&[0, 10, 2], &[0.1, 0.5, 0.3]);
        assert_eq!(sampler.sample_frequencies(), &[1usize, 0, 1, 0, 0][..]);

        let mut zero_weight_sampler = AdaptiveSampler::new(5, 3);
        zero_weight_sampler.strategy_weights = vec![0.0, 0.0, 0.0];
        zero_weight_sampler.normalize_strategy_weights();
        let sum: f64 = zero_weight_sampler.strategy_weights().iter().sum();
        assert!((sum - 1.0).abs() < WEIGHT_SUM_TOLERANCE);
    }

    #[test]
    fn test_adaptive_strategy_equality() {
        assert_eq!(
            AdaptiveStrategy::HardSampling { intensity: 1.0 },
            AdaptiveStrategy::HardSampling { intensity: 1.0 }
        );
        assert_eq!(AdaptiveStrategy::Uniform, AdaptiveStrategy::Uniform);
        assert_ne!(
            AdaptiveStrategy::HardSampling { intensity: 1.0 },
            AdaptiveStrategy::EasySampling { intensity: 1.0 }
        );
    }

    #[test]
    fn test_adaptive_strategy_default() {
        assert_eq!(AdaptiveStrategy::default(), AdaptiveStrategy::Uniform);
    }

    #[test]
    fn test_parameter_clamping() {
        let sampler = AdaptiveSampler::new(10, 5)
            .with_adaptation_rate(1.5)
            .with_smoothing_factor(-0.1);
        assert_eq!(sampler.adaptation_rate(), 1.0);
        assert_eq!(sampler.smoothing_factor(), 0.0);
    }

    #[test]
    fn test_reproducibility() {
        let mut sampler1 = AdaptiveSampler::new(20, 10).with_generator(123);
        let mut sampler2 = AdaptiveSampler::new(20, 10).with_generator(123);

        let sample_indices = vec![0, 1, 2, 3, 4];
        let losses = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        sampler1.update_sample_losses(&sample_indices, &losses);
        sampler2.update_sample_losses(&sample_indices, &losses);

        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();

        assert_eq!(indices1, indices2);
    }
}