// torsh_data/sampler/adaptive.rs
1//! Adaptive sampling functionality
2//!
3//! This module provides adaptive sampling strategies that dynamically adjust sampling
4//! behavior based on training progress, model performance, and sample characteristics.
5
6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
10use scirs2_core::RngExt;
11
12use super::core::{rng_utils, Sampler, SamplerIterator};
13
14/// Adaptive sampling strategies
15///
16/// These strategies define different approaches for selecting samples based on
17/// their characteristics such as difficulty, frequency, or uncertainty.
/// Strategies the adaptive sampler can mix together
///
/// Each variant describes one rule for turning per-sample statistics
/// (loss, difficulty, visit frequency) into sampling weights.
#[derive(Clone, Debug, PartialEq)]
pub enum AdaptiveStrategy {
    /// Prefer samples the model currently struggles with (high loss).
    ///
    /// # Arguments
    ///
    /// * `intensity` - Strength of the preference for hard samples (typically 0.5-2.0)
    HardSampling { intensity: f64 },

    /// Prefer samples the model already handles well (low loss).
    ///
    /// Useful for stable training or curriculum-style schedules.
    ///
    /// # Arguments
    ///
    /// * `intensity` - Strength of the preference for easy samples (typically 0.5-2.0)
    EasySampling { intensity: f64 },

    /// Draw every sample with equal probability.
    ///
    /// Acts as a baseline and guarantees exploration of the whole dataset.
    Uniform,

    /// Prefer samples where model predictions are most uncertain.
    ///
    /// # Arguments
    ///
    /// * `temperature` - Controls the sharpness of uncertainty weighting
    Uncertainty { temperature: f64 },

    /// Prefer samples that have been visited least often during training,
    /// helping to balance the effective training distribution.
    ///
    /// # Arguments
    ///
    /// * `power` - Strength of the inverse-frequency weighting
    InverseFrequency { power: f64 },

    /// Prefer samples whose gradient contribution exceeds a threshold,
    /// indicating they contribute significantly to learning.
    ///
    /// # Arguments
    ///
    /// * `threshold` - Minimum gradient magnitude for full weighting
    GradientMagnitude { threshold: f64 },
}

impl Default for AdaptiveStrategy {
    /// Uniform sampling is the safe, assumption-free default.
    fn default() -> Self {
        Self::Uniform
    }
}
81
82/// Adaptive sampler that dynamically adjusts sampling strategy based on training progress
83///
84/// This sampler combines multiple sampling strategies and adapts the sampling distribution
85/// based on model performance, loss patterns, and sample difficulty over time.
86///
87/// # Examples
88///
89/// ```rust,ignore
90/// use torsh_data::sampler::{AdaptiveSampler, AdaptiveStrategy, Sampler};
91///
92/// let mut sampler = AdaptiveSampler::new(1000, 64)
93///     .with_adaptation_rate(0.1)
94///     .with_warmup_epochs(5)
95///     .with_generator(42);
96///
97/// // Add custom strategy
98/// sampler = sampler.add_strategy(
99///     AdaptiveStrategy::InverseFrequency { power: 1.0 },
100///     0.2
101/// );
102///
103/// // During training, update with sample losses
104/// let sample_indices = vec![0, 1, 2, 3, 4];
105/// let losses = vec![0.5, 0.8, 0.3, 0.9, 0.2];
106/// sampler.update_sample_losses(&sample_indices, &losses);
107///
108/// // Set current epoch for adaptation
109/// sampler.set_epoch(10);
110///
111/// // Get adaptive samples
112/// let indices: Vec<usize> = sampler.iter().collect();
113/// assert_eq!(indices.len(), 64);
114/// ```
#[derive(Clone)]
pub struct AdaptiveSampler {
    // Total number of elements in the dataset being sampled.
    dataset_size: usize,
    // Number of indices produced per call to `iter()`.
    num_samples: usize,
    // Active strategies; kept in lockstep with `strategy_weights`.
    strategies: Vec<AdaptiveStrategy>,
    // Normalized mixing weight per strategy (sums to 1.0).
    strategy_weights: Vec<f64>,
    // Exponential moving average of each sample's training loss.
    sample_losses: Vec<f64>,
    // Z-score of each sample's loss relative to the dataset mean.
    sample_difficulties: Vec<f64>,
    // How many times each sample has appeared in a loss update.
    sample_frequencies: Vec<usize>,
    // Step size for strategy-weight adaptation (clamped to [0, 1]).
    adaptation_rate: f64,
    // EMA smoothing factor for losses (higher = smoother).
    smoothing_factor: f64,
    // Epoch most recently reported via `set_epoch`.
    current_epoch: usize,
    // Epochs to wait before adapting strategy weights.
    warmup_epochs: usize,
    // Optional RNG seed; `None` means non-deterministic sampling.
    generator: Option<u64>,
}
130
131impl AdaptiveSampler {
132    /// Create a new adaptive sampler
133    ///
134    /// Creates a sampler with default strategies: hard sampling, uniform, and uncertainty.
135    ///
136    /// # Arguments
137    ///
138    /// * `dataset_size` - Total size of the dataset
139    /// * `num_samples` - Number of samples to select per iteration
140    ///
141    /// # Examples
142    ///
143    /// ```rust,ignore
144    /// use torsh_data::sampler::AdaptiveSampler;
145    ///
146    /// let sampler = AdaptiveSampler::new(1000, 32);
147    /// assert_eq!(sampler.len(), 32);
148    /// ```
149    pub fn new(dataset_size: usize, num_samples: usize) -> Self {
150        let strategies = vec![
151            AdaptiveStrategy::HardSampling { intensity: 1.0 },
152            AdaptiveStrategy::Uniform,
153            AdaptiveStrategy::Uncertainty { temperature: 1.0 },
154        ];
155
156        let strategy_weights = vec![0.4, 0.3, 0.3];
157
158        Self {
159            dataset_size,
160            num_samples,
161            strategies,
162            strategy_weights,
163            sample_losses: vec![0.0; dataset_size],
164            sample_difficulties: vec![0.0; dataset_size],
165            sample_frequencies: vec![0; dataset_size],
166            adaptation_rate: 0.1,
167            smoothing_factor: 0.9,
168            current_epoch: 0,
169            warmup_epochs: 5,
170            generator: None,
171        }
172    }
173
174    /// Add a custom sampling strategy
175    ///
176    /// # Arguments
177    ///
178    /// * `strategy` - The adaptive strategy to add
179    /// * `weight` - Initial weight for this strategy (will be normalized)
180    pub fn add_strategy(mut self, strategy: AdaptiveStrategy, weight: f64) -> Self {
181        self.strategies.push(strategy);
182        self.strategy_weights.push(weight);
183        self.normalize_strategy_weights();
184        self
185    }
186
187    /// Set adaptation rate for strategy weight updates
188    ///
189    /// # Arguments
190    ///
191    /// * `rate` - Adaptation rate (typically 0.01-0.2)
192    pub fn with_adaptation_rate(mut self, rate: f64) -> Self {
193        self.adaptation_rate = rate.clamp(0.0, 1.0);
194        self
195    }
196
197    /// Set smoothing factor for exponential moving average of losses
198    ///
199    /// # Arguments
200    ///
201    /// * `factor` - Smoothing factor (0.0-1.0, higher values = more smoothing)
202    pub fn with_smoothing_factor(mut self, factor: f64) -> Self {
203        self.smoothing_factor = factor.clamp(0.0, 1.0);
204        self
205    }
206
207    /// Set number of warmup epochs before adaptation begins
208    ///
209    /// # Arguments
210    ///
211    /// * `epochs` - Number of warmup epochs
212    pub fn with_warmup_epochs(mut self, epochs: usize) -> Self {
213        self.warmup_epochs = epochs;
214        self
215    }
216
217    /// Set random generator seed
218    ///
219    /// # Arguments
220    ///
221    /// * `seed` - Random seed for reproducible sampling
222    pub fn with_generator(mut self, seed: u64) -> Self {
223        self.generator = Some(seed);
224        self
225    }
226
227    /// Get the current epoch
228    pub fn current_epoch(&self) -> usize {
229        self.current_epoch
230    }
231
232    /// Get the warmup epochs
233    pub fn warmup_epochs(&self) -> usize {
234        self.warmup_epochs
235    }
236
237    /// Get the adaptation rate
238    pub fn adaptation_rate(&self) -> f64 {
239        self.adaptation_rate
240    }
241
242    /// Get the smoothing factor
243    pub fn smoothing_factor(&self) -> f64 {
244        self.smoothing_factor
245    }
246
247    /// Get the current strategy weights
248    pub fn strategy_weights(&self) -> &[f64] {
249        &self.strategy_weights
250    }
251
252    /// Get the current strategies
253    pub fn strategies(&self) -> &[AdaptiveStrategy] {
254        &self.strategies
255    }
256
257    /// Get sample losses
258    pub fn sample_losses(&self) -> &[f64] {
259        &self.sample_losses
260    }
261
262    /// Get sample difficulties
263    pub fn sample_difficulties(&self) -> &[f64] {
264        &self.sample_difficulties
265    }
266
267    /// Get sample frequencies
268    pub fn sample_frequencies(&self) -> &[usize] {
269        &self.sample_frequencies
270    }
271
272    /// Check if the sampler is in warmup phase
273    pub fn is_warming_up(&self) -> bool {
274        self.current_epoch < self.warmup_epochs
275    }
276
277    /// Update sample losses from training
278    ///
279    /// This method should be called after each training batch to update
280    /// the sampler's understanding of sample difficulty.
281    ///
282    /// # Arguments
283    ///
284    /// * `sample_indices` - Indices of samples in the batch
285    /// * `losses` - Corresponding loss values for each sample
286    ///
287    /// # Panics
288    ///
289    /// Panics if the lengths of `sample_indices` and `losses` don't match.
290    pub fn update_sample_losses(&mut self, sample_indices: &[usize], losses: &[f64]) {
291        assert_eq!(sample_indices.len(), losses.len());
292
293        for (&idx, &loss) in sample_indices.iter().zip(losses.iter()) {
294            if idx < self.dataset_size {
295                // Exponential moving average for loss smoothing
296                self.sample_losses[idx] = self.smoothing_factor * self.sample_losses[idx]
297                    + (1.0 - self.smoothing_factor) * loss;
298
299                // Track sample frequency
300                self.sample_frequencies[idx] += 1;
301            }
302        }
303
304        self.update_sample_difficulties();
305        self.adapt_strategy_weights();
306    }
307
308    /// Set current epoch for adaptation tracking
309    ///
310    /// # Arguments
311    ///
312    /// * `epoch` - Current training epoch
313    pub fn set_epoch(&mut self, epoch: usize) {
314        self.current_epoch = epoch;
315    }
316
317    /// Reset sampler state
318    pub fn reset(&mut self) {
319        self.sample_losses.fill(0.0);
320        self.sample_difficulties.fill(0.0);
321        self.sample_frequencies.fill(0);
322        self.current_epoch = 0;
323    }
324
325    /// Get statistics about the current adaptive sampling state
326    pub fn adaptive_stats(&self) -> AdaptiveStats {
327        let hard_samples = self
328            .sample_difficulties
329            .iter()
330            .filter(|&&d| d > 0.5)
331            .count();
332        let max_freq = self.sample_frequencies.iter().max().copied().unwrap_or(0);
333        let min_freq = self.sample_frequencies.iter().min().copied().unwrap_or(0);
334        let mean_loss = self.sample_losses.iter().sum::<f64>() / self.sample_losses.len() as f64;
335
336        AdaptiveStats {
337            current_epoch: self.current_epoch,
338            warmup_epochs: self.warmup_epochs,
339            is_warming_up: self.is_warming_up(),
340            hard_samples_count: hard_samples,
341            hard_samples_ratio: hard_samples as f64 / self.dataset_size as f64,
342            frequency_imbalance: if min_freq > 0 {
343                max_freq as f64 / min_freq as f64
344            } else {
345                0.0
346            },
347            mean_loss,
348            adaptation_rate: self.adaptation_rate,
349            num_strategies: self.strategies.len(),
350        }
351    }
352
353    /// Update sample difficulties based on current losses
354    fn update_sample_difficulties(&mut self) {
355        if self.sample_losses.is_empty() {
356            return;
357        }
358
359        let mean_loss = self.sample_losses.iter().sum::<f64>() / self.sample_losses.len() as f64;
360        let variance = self
361            .sample_losses
362            .iter()
363            .map(|&loss| (loss - mean_loss).powi(2))
364            .sum::<f64>()
365            / self.sample_losses.len() as f64;
366        let std_dev = variance.sqrt();
367
368        for (i, &loss) in self.sample_losses.iter().enumerate() {
369            // Normalize difficulty score
370            self.sample_difficulties[i] = if std_dev > 0.0 {
371                (loss - mean_loss) / std_dev
372            } else {
373                0.0
374            };
375        }
376    }
377
378    /// Adapt strategy weights based on current training state
379    fn adapt_strategy_weights(&mut self) {
380        if self.is_warming_up() {
381            return;
382        }
383
384        // Calculate strategy performance metrics
385        let hard_samples_ratio = self
386            .sample_difficulties
387            .iter()
388            .filter(|&&d| d > 0.5)
389            .count() as f64
390            / self.dataset_size as f64;
391
392        let frequency_imbalance = {
393            let max_freq = self.sample_frequencies.iter().max().unwrap_or(&1);
394            let min_freq = self.sample_frequencies.iter().min().unwrap_or(&1);
395            (*max_freq as f64 / (*min_freq as f64).max(1.0)).ln()
396        };
397
398        // Adjust weights based on current training state
399        let mut new_weights = self.strategy_weights.clone();
400
401        // If too many hard samples, reduce hard sampling
402        if hard_samples_ratio > 0.3 {
403            for (i, strategy) in self.strategies.iter().enumerate() {
404                match strategy {
405                    AdaptiveStrategy::HardSampling { .. } => {
406                        new_weights[i] *= 1.0 - self.adaptation_rate;
407                    }
408                    AdaptiveStrategy::EasySampling { .. } => {
409                        new_weights[i] *= 1.0 + self.adaptation_rate;
410                    }
411                    _ => {}
412                }
413            }
414        }
415
416        // If high frequency imbalance, increase inverse frequency sampling
417        if frequency_imbalance > 1.0 {
418            for (i, strategy) in self.strategies.iter().enumerate() {
419                if let AdaptiveStrategy::InverseFrequency { .. } = strategy {
420                    new_weights[i] *= 1.0 + self.adaptation_rate;
421                }
422            }
423        }
424
425        self.strategy_weights = new_weights;
426        self.normalize_strategy_weights();
427    }
428
429    /// Normalize strategy weights to sum to 1.0
430    fn normalize_strategy_weights(&mut self) {
431        let sum: f64 = self.strategy_weights.iter().sum();
432        if sum > 0.0 {
433            for weight in &mut self.strategy_weights {
434                *weight /= sum;
435            }
436        } else {
437            // If all weights are zero, reset to uniform
438            let uniform_weight = 1.0 / self.strategy_weights.len() as f64;
439            self.strategy_weights.fill(uniform_weight);
440        }
441    }
442
443    /// Get sampling weights for a specific strategy
444    fn get_strategy_weights(&self, strategy: &AdaptiveStrategy) -> Vec<f64> {
445        match strategy {
446            AdaptiveStrategy::HardSampling { intensity } => self
447                .sample_difficulties
448                .iter()
449                .map(|&d| (d * intensity).exp())
450                .collect(),
451            AdaptiveStrategy::EasySampling { intensity } => self
452                .sample_difficulties
453                .iter()
454                .map(|&d| (-d * intensity).exp())
455                .collect(),
456            AdaptiveStrategy::Uniform => {
457                vec![1.0; self.dataset_size]
458            }
459            AdaptiveStrategy::Uncertainty { temperature } => self
460                .sample_losses
461                .iter()
462                .map(|&loss| (loss / temperature).exp())
463                .collect(),
464            AdaptiveStrategy::InverseFrequency { power } => self
465                .sample_frequencies
466                .iter()
467                .map(|&freq| 1.0 / (freq as f64 + 1.0).powf(*power))
468                .collect(),
469            AdaptiveStrategy::GradientMagnitude { threshold } => self
470                .sample_losses
471                .iter()
472                .map(|&loss| if loss > *threshold { loss } else { 0.1 })
473                .collect(),
474        }
475    }
476
477    /// Combine weights from all strategies
478    fn get_combined_weights(&self) -> Vec<f64> {
479        let mut combined = vec![0.0; self.dataset_size];
480
481        for (strategy, &weight) in self.strategies.iter().zip(self.strategy_weights.iter()) {
482            let strategy_weights = self.get_strategy_weights(strategy);
483            for (i, &w) in strategy_weights.iter().enumerate() {
484                combined[i] += weight * w;
485            }
486        }
487
488        // Ensure all weights are positive
489        for w in &mut combined {
490            *w = w.max(1e-6);
491        }
492
493        combined
494    }
495
    /// Sample indices using weighted sampling with replacement
    ///
    /// Builds a cumulative distribution over `weights` and draws
    /// `num_samples` indices by inverse transform sampling. Falls back to
    /// uniform sampling when the weights do not form a valid distribution.
    ///
    /// NOTE(review): assumes `dataset_size > 0` — the `dataset_size - 1`
    /// clamp below underflows for an empty dataset; confirm callers.
    fn sample_with_replacement(&self, weights: &[f64]) -> Vec<usize> {
        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
        let mut rng = rng_utils::create_rng(self.generator);

        // Create cumulative distribution
        let weight_sum: f64 = weights.iter().sum();
        if weight_sum <= 0.0 {
            // Fallback to uniform sampling
            return (0..self.num_samples)
                .map(|_| rng_utils::gen_range(&mut rng, 0..self.dataset_size))
                .collect();
        }

        let mut cumulative_weights = Vec::with_capacity(weights.len());
        let mut cumsum = 0.0;

        for &weight in weights {
            cumsum += weight / weight_sum;
            cumulative_weights.push(cumsum);
        }

        // Ensure the last value is exactly 1.0
        // (floating-point accumulation can leave it slightly below 1.0,
        // which would let a draw near 1.0 fall past the end of the table)
        if let Some(last) = cumulative_weights.last_mut() {
            *last = 1.0;
        }

        // Sample using inverse transform sampling
        (0..self.num_samples)
            .map(|_| {
                let rand_val: f64 = rng.random();
                cumulative_weights
                    // partial_cmp only fails on NaN; mapping that to Equal
                    // makes the search terminate instead of panicking.
                    .binary_search_by(|&x| {
                        x.partial_cmp(&rand_val)
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    // Err(i) is the insertion point: the first cumulative
                    // weight greater than rand_val, i.e. the sampled index.
                    .unwrap_or_else(|i| i)
                    .min(self.dataset_size - 1)
            })
            .collect()
    }
537}
538
539impl Sampler for AdaptiveSampler {
540    type Iter = SamplerIterator;
541
542    fn iter(&self) -> Self::Iter {
543        let weights = self.get_combined_weights();
544        let indices = self.sample_with_replacement(&weights);
545        SamplerIterator::new(indices)
546    }
547
548    fn len(&self) -> usize {
549        self.num_samples
550    }
551}
552
/// Statistics about the current adaptive sampling state
///
/// Snapshot produced by `AdaptiveSampler::adaptive_stats`; useful for
/// logging and monitoring how the sampler adapts during training.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptiveStats {
    /// Current training epoch
    pub current_epoch: usize,
    /// Number of warmup epochs
    pub warmup_epochs: usize,
    /// Whether currently in warmup phase
    pub is_warming_up: bool,
    /// Number of samples classified as hard (difficulty score > 0.5)
    pub hard_samples_count: usize,
    /// Ratio of hard samples to total samples
    pub hard_samples_ratio: f64,
    /// Imbalance in sample frequencies (max/min; 0.0 while any sample is unseen)
    pub frequency_imbalance: f64,
    /// Mean loss across all samples
    pub mean_loss: f64,
    /// Current adaptation rate
    pub adaptation_rate: f64,
    /// Number of active strategies
    pub num_strategies: usize,
}
575
576/// Create an adaptive sampler with hard sampling focus
577///
578/// Convenience function for creating an adaptive sampler that emphasizes hard samples.
579///
580/// # Arguments
581///
582/// * `dataset_size` - Total size of the dataset
583/// * `num_samples` - Number of samples to select per iteration
584/// * `intensity` - Intensity of hard sampling focus
585/// * `seed` - Optional random seed for reproducible sampling
586pub fn hard_adaptive_sampler(
587    dataset_size: usize,
588    num_samples: usize,
589    intensity: f64,
590    seed: Option<u64>,
591) -> AdaptiveSampler {
592    let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
593        .add_strategy(AdaptiveStrategy::HardSampling { intensity }, 0.7);
594    if let Some(s) = seed {
595        sampler = sampler.with_generator(s);
596    }
597    sampler
598}
599
600/// Create an adaptive sampler with frequency balancing
601///
602/// Convenience function for creating an adaptive sampler that balances sample frequencies.
603///
604/// # Arguments
605///
606/// * `dataset_size` - Total size of the dataset
607/// * `num_samples` - Number of samples to select per iteration
608/// * `power` - Power for inverse frequency weighting
609/// * `seed` - Optional random seed for reproducible sampling
610pub fn frequency_balanced_sampler(
611    dataset_size: usize,
612    num_samples: usize,
613    power: f64,
614    seed: Option<u64>,
615) -> AdaptiveSampler {
616    let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
617        .add_strategy(AdaptiveStrategy::InverseFrequency { power }, 0.6);
618    if let Some(s) = seed {
619        sampler = sampler.with_generator(s);
620    }
621    sampler
622}
623
624/// Create an adaptive sampler with uncertainty focus
625///
626/// Convenience function for creating an adaptive sampler that emphasizes uncertain samples.
627///
628/// # Arguments
629///
630/// * `dataset_size` - Total size of the dataset
631/// * `num_samples` - Number of samples to select per iteration
632/// * `temperature` - Temperature for uncertainty weighting
633/// * `seed` - Optional random seed for reproducible sampling
634pub fn uncertainty_adaptive_sampler(
635    dataset_size: usize,
636    num_samples: usize,
637    temperature: f64,
638    seed: Option<u64>,
639) -> AdaptiveSampler {
640    let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
641        .add_strategy(AdaptiveStrategy::Uncertainty { temperature }, 0.8);
642    if let Some(s) = seed {
643        sampler = sampler.with_generator(s);
644    }
645    sampler
646}
647
#[cfg(test)]
mod tests {
    use super::*;

    // Coverage: construction/defaults, loss-driven updates, strategy
    // adaptation across warmup, every strategy variant, stats, edge cases,
    // weight normalization, and seeded reproducibility.

    #[test]
    fn test_adaptive_sampler_basic() {
        let dataset_size = 100;
        let num_samples = 50;
        let sampler = AdaptiveSampler::new(dataset_size, num_samples).with_generator(42);

        assert_eq!(sampler.len(), num_samples);
        assert_eq!(sampler.current_epoch(), 0);
        assert_eq!(sampler.warmup_epochs(), 5);
        assert!(sampler.is_warming_up());
        assert_eq!(sampler.strategies().len(), 3); // Default strategies
        assert_eq!(sampler.strategy_weights().len(), 3);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), num_samples);

        // All indices should be valid
        for &idx in &indices {
            assert!(idx < dataset_size);
        }
    }

    #[test]
    fn test_adaptive_sampler_with_losses() {
        let dataset_size = 10;
        let num_samples = 5;
        let mut sampler = AdaptiveSampler::new(dataset_size, num_samples).with_generator(42);

        // Initially all difficulties should be zero
        assert!(sampler.sample_difficulties().iter().all(|&d| d == 0.0));

        // Simulate training with some sample losses
        let sample_indices = vec![0, 2, 4, 6, 8];
        let losses = vec![0.1, 0.8, 0.2, 0.9, 0.3]; // Indices 2 and 6 have high losses

        sampler.update_sample_losses(&sample_indices, &losses);

        // Check that losses were updated (with smoothing factor 0.9: new_loss = 0.9 * 0.0 + 0.1 * input)
        assert!((sampler.sample_losses()[0] - 0.01).abs() < 1e-10); // 0.1 * 0.1 = 0.01
        assert!((sampler.sample_losses()[2] - 0.08).abs() < 1e-10); // 0.1 * 0.8 = 0.08
        assert!((sampler.sample_losses()[6] - 0.09).abs() < 1e-10); // 0.1 * 0.9 = 0.09

        // Check that frequencies were updated
        assert_eq!(sampler.sample_frequencies()[0], 1);
        assert_eq!(sampler.sample_frequencies()[2], 1);
        assert_eq!(sampler.sample_frequencies()[1], 0); // Not sampled

        // Sample should still work
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), num_samples);
    }

    #[test]
    fn test_adaptive_sampler_strategy_adaptation() {
        let dataset_size = 20;
        let num_samples = 10;
        let mut sampler = AdaptiveSampler::new(dataset_size, num_samples)
            .with_warmup_epochs(2)
            .with_generator(42);

        // Add custom strategy
        sampler = sampler.add_strategy(AdaptiveStrategy::InverseFrequency { power: 1.0 }, 0.2);
        assert_eq!(sampler.strategies().len(), 4);

        // Initial weights should be normalized
        let initial_sum: f64 = sampler.strategy_weights().iter().sum();
        assert!((initial_sum - 1.0).abs() < f64::EPSILON);

        // During warmup, weights shouldn't change
        sampler.set_epoch(1);
        assert!(sampler.is_warming_up());

        let sample_indices: Vec<usize> = (0..10).collect();
        let losses = vec![0.5; 10];
        sampler.update_sample_losses(&sample_indices, &losses);

        let _weights_during_warmup = sampler.strategy_weights().to_vec();

        // After warmup, weights can adapt
        sampler.set_epoch(3);
        assert!(!sampler.is_warming_up());

        sampler.update_sample_losses(&sample_indices, &losses);
        // Weights might have changed (depending on adaptation logic)

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), num_samples);
    }

    #[test]
    fn test_adaptive_strategies() {
        let dataset_size = 10;
        let num_samples = 5;

        let strategies = vec![
            AdaptiveStrategy::HardSampling { intensity: 1.0 },
            AdaptiveStrategy::EasySampling { intensity: 1.0 },
            AdaptiveStrategy::Uniform,
            AdaptiveStrategy::Uncertainty { temperature: 1.0 },
            AdaptiveStrategy::InverseFrequency { power: 1.0 },
            AdaptiveStrategy::GradientMagnitude { threshold: 0.5 },
        ];

        // Every variant must produce a full, in-range batch of indices.
        for strategy in strategies {
            let sampler = AdaptiveSampler::new(dataset_size, num_samples)
                .add_strategy(strategy, 0.5)
                .with_generator(42);

            let indices: Vec<usize> = sampler.iter().collect();
            assert_eq!(indices.len(), num_samples);

            // All indices should be valid
            for &idx in &indices {
                assert!(idx < dataset_size);
            }
        }
    }

    #[test]
    fn test_adaptive_sampler_difficulty_calculation() {
        let dataset_size = 5;
        let num_samples = 3;
        let mut sampler = AdaptiveSampler::new(dataset_size, num_samples).with_generator(42);

        // Update with specific losses to test difficulty calculation
        let sample_indices = vec![0, 1, 2, 3, 4];
        let losses = vec![0.1, 0.2, 0.8, 0.9, 0.3]; // Clear difficulty pattern

        sampler.update_sample_losses(&sample_indices, &losses);

        let difficulties = sampler.sample_difficulties();

        // Indices 2 and 3 should have higher difficulty (positive scores)
        assert!(difficulties[2] > difficulties[0]);
        assert!(difficulties[3] > difficulties[1]);
        assert!(difficulties[2] > 0.0);
        assert!(difficulties[3] > 0.0);
    }

    #[test]
    fn test_adaptive_sampler_methods() {
        let mut sampler = AdaptiveSampler::new(20, 10)
            .with_adaptation_rate(0.2)
            .with_smoothing_factor(0.8)
            .with_warmup_epochs(3)
            .with_generator(42);

        assert_eq!(sampler.adaptation_rate(), 0.2);
        assert_eq!(sampler.smoothing_factor(), 0.8);
        assert_eq!(sampler.warmup_epochs(), 3);

        // Test epoch setting
        sampler.set_epoch(5);
        assert_eq!(sampler.current_epoch(), 5);
        assert!(!sampler.is_warming_up());

        // Test reset
        sampler.update_sample_losses(&[0, 1, 2], &[0.5, 0.6, 0.7]);
        assert!(sampler.sample_losses().iter().any(|&l| l > 0.0));
        assert!(sampler.sample_frequencies().iter().any(|&f| f > 0));

        sampler.reset();
        assert!(sampler.sample_losses().iter().all(|&l| l == 0.0));
        assert!(sampler.sample_frequencies().iter().all(|&f| f == 0));
        assert_eq!(sampler.current_epoch(), 0);
    }

    #[test]
    fn test_adaptive_stats() {
        let mut sampler = AdaptiveSampler::new(100, 32);

        let stats = sampler.adaptive_stats();
        assert_eq!(stats.current_epoch, 0);
        assert_eq!(stats.warmup_epochs, 5);
        assert!(stats.is_warming_up);
        assert_eq!(stats.hard_samples_count, 0);
        assert_eq!(stats.hard_samples_ratio, 0.0);
        assert_eq!(stats.mean_loss, 0.0);
        assert_eq!(stats.num_strategies, 3);

        // Update with losses and check stats
        let sample_indices: Vec<usize> = (0..20).collect();
        let losses: Vec<f64> = (0..20).map(|i| if i > 15 { 0.8 } else { 0.2 }).collect();
        sampler.update_sample_losses(&sample_indices, &losses);

        let stats = sampler.adaptive_stats();
        assert!(stats.mean_loss > 0.0);
        assert!(stats.hard_samples_count > 0);
        assert!(stats.hard_samples_ratio > 0.0);
    }

    #[test]
    fn test_convenience_functions() {
        // Test hard_adaptive_sampler
        let hard_sampler = hard_adaptive_sampler(100, 32, 1.5, Some(42));
        assert_eq!(hard_sampler.len(), 32);
        assert!(hard_sampler.strategies().len() > 3); // Should have added hard sampling

        // Test frequency_balanced_sampler
        let freq_sampler = frequency_balanced_sampler(100, 32, 1.0, Some(42));
        assert_eq!(freq_sampler.len(), 32);

        // Test uncertainty_adaptive_sampler
        let uncertainty_sampler = uncertainty_adaptive_sampler(100, 32, 0.8, Some(42));
        assert_eq!(uncertainty_sampler.len(), 32);
    }

    #[test]
    fn test_weight_normalization() {
        let mut sampler = AdaptiveSampler::new(10, 5);

        // Add strategies with arbitrary weights
        sampler = sampler
            .add_strategy(AdaptiveStrategy::HardSampling { intensity: 1.0 }, 2.0)
            .add_strategy(AdaptiveStrategy::EasySampling { intensity: 1.0 }, 3.0);

        // Weights should be normalized
        let sum: f64 = sampler.strategy_weights().iter().sum();
        assert!((sum - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_strategy_weights() {
        let sampler = AdaptiveSampler::new(5, 3);

        // Test uniform strategy
        let uniform_weights = sampler.get_strategy_weights(&AdaptiveStrategy::Uniform);
        assert!(uniform_weights.iter().all(|&w| w == 1.0));

        // Create sampler with some loss data
        let mut sampler_with_data = AdaptiveSampler::new(5, 3);
        sampler_with_data.update_sample_losses(&[0, 1, 2], &[0.1, 0.5, 0.9]);

        // Test hard sampling strategy
        let hard_weights = sampler_with_data
            .get_strategy_weights(&AdaptiveStrategy::HardSampling { intensity: 1.0 });
        assert_eq!(hard_weights.len(), 5);

        // Test easy sampling strategy
        let easy_weights = sampler_with_data
            .get_strategy_weights(&AdaptiveStrategy::EasySampling { intensity: 1.0 });
        assert_eq!(easy_weights.len(), 5);
    }

    #[test]
    fn test_edge_cases() {
        // Empty num_samples
        let empty_sampler = AdaptiveSampler::new(10, 0);
        assert_eq!(empty_sampler.len(), 0);
        let indices: Vec<usize> = empty_sampler.iter().collect();
        assert!(indices.is_empty());

        // Single sample
        let single_sampler = AdaptiveSampler::new(10, 1);
        let indices: Vec<usize> = single_sampler.iter().collect();
        assert_eq!(indices.len(), 1);

        // Large dataset
        let large_sampler = AdaptiveSampler::new(10000, 64);
        assert_eq!(large_sampler.len(), 64);

        // Invalid sample indices (should be ignored)
        let mut sampler = AdaptiveSampler::new(5, 3);
        sampler.update_sample_losses(&[0, 10, 2], &[0.1, 0.5, 0.3]); // Index 10 is out of bounds
                                                                     // Should not panic and should ignore invalid index

        // Zero weights should fallback to uniform
        let mut zero_weight_sampler = AdaptiveSampler::new(5, 3);
        zero_weight_sampler.strategy_weights = vec![0.0, 0.0, 0.0];
        zero_weight_sampler.normalize_strategy_weights();
        let sum: f64 = zero_weight_sampler.strategy_weights().iter().sum();
        assert!((sum - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_adaptive_strategy_equality() {
        assert_eq!(
            AdaptiveStrategy::HardSampling { intensity: 1.0 },
            AdaptiveStrategy::HardSampling { intensity: 1.0 }
        );
        assert_eq!(AdaptiveStrategy::Uniform, AdaptiveStrategy::Uniform);
        assert_ne!(
            AdaptiveStrategy::HardSampling { intensity: 1.0 },
            AdaptiveStrategy::EasySampling { intensity: 1.0 }
        );
    }

    #[test]
    fn test_adaptive_strategy_default() {
        assert_eq!(AdaptiveStrategy::default(), AdaptiveStrategy::Uniform);
    }

    #[test]
    fn test_parameter_clamping() {
        let sampler = AdaptiveSampler::new(10, 5)
            .with_adaptation_rate(1.5) // Should be clamped to 1.0
            .with_smoothing_factor(-0.1); // Should be clamped to 0.0

        assert_eq!(sampler.adaptation_rate(), 1.0);
        assert_eq!(sampler.smoothing_factor(), 0.0);
    }

    #[test]
    fn test_reproducibility() {
        let mut sampler1 = AdaptiveSampler::new(20, 10).with_generator(123);
        let mut sampler2 = AdaptiveSampler::new(20, 10).with_generator(123);

        // Update both with same data
        let sample_indices = vec![0, 1, 2, 3, 4];
        let losses = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        sampler1.update_sample_losses(&sample_indices, &losses);
        sampler2.update_sample_losses(&sample_indices, &losses);

        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();

        assert_eq!(indices1, indices2);
    }
}
971}