Skip to main content

torsh_data/sampler/
advanced.rs

1//! Advanced sampling strategies for specialized machine learning scenarios.
2//!
3//! This module provides sophisticated sampling techniques that go beyond basic
4//! random and sequential sampling. These strategies are particularly useful for
5//! handling imbalanced datasets, implementing importance sampling, and creating
6//! structured sampling patterns for specific machine learning applications.
7//!
8//! # Key Components
9//!
10//! ## Weighted Sampling
11//! - [`WeightedRandomSampler`] - Probability-weighted random sampling
12//! - Support for unnormalized weights and automatic normalization
13//! - Efficient implementation using alias method for O(1) sampling
14//!
15//! ## Grouped Sampling
16//! - [`GroupedSampler`] - Groups samples by user-defined criteria
17//! - Configurable shuffling within and between groups
18//! - Useful for batch sampling with specific constraints
19//!
20//! ## Stratified Sampling
21//! - [`StratifiedSampler`] - Maintains proportional representation across strata
22//! - Automatic balancing and class-aware sampling
23//! - Essential for classification with imbalanced datasets
24//!
25//! ## Importance Sampling
26//! - [`ImportanceSampler`] - Samples based on importance scores
27//! - Adaptive importance weight calculation
28//! - Critical for active learning and hard negative mining
29//!
30//! # Examples
31//!
32//! ## Weighted Random Sampling
33//! ```rust,ignore
34//! use torsh_data::sampler::{Sampler, WeightedRandomSampler};
35//!
36//! // Sample with higher probability for larger weights
37//! let weights = vec![0.1, 0.3, 0.2, 0.4];
38//! let sampler = WeightedRandomSampler::new(weights, true)
39//!     .with_generator(42);
40//!
41//! let indices: Vec<usize> = sampler.iter().take(10).collect();
42//! // Index 3 (weight 0.4) will appear more frequently
43//! ```
44//!
45//! ## Grouped Sampling
46//! ```rust,ignore
47//! use torsh_data::sampler::{Sampler, GroupedSampler};
48//!
49//! // Group samples by some criterion (e.g., class label)
50//! let group_fn = |idx: usize| idx % 3; // 3 groups
51//! # struct DummyDataset { len: usize }
52//! # impl crate::dataset::Dataset for DummyDataset {
53//! #     type Item = usize;
//! #     fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> { Ok(index) }
55//! #     fn len(&self) -> usize { self.len }
56//! # }
57//! let dataset = DummyDataset { len: 12 };
58//!
59//! let sampler = GroupedSampler::new(&dataset, group_fn)
60//!     .with_shuffle_groups(true)
61//!     .with_shuffle_within_groups(true);
62//!
63//! // Samples will be grouped together but in random order
64//! ```
65//!
66//! ## Stratified Sampling
67//! ```rust,ignore
68//! use torsh_data::sampler::{Sampler, StratifiedSampler};
69//!
70//! // Ensure balanced representation across classes
71//! let class_labels = vec![0, 0, 1, 1, 1, 2, 2, 2, 2];
72//! let sampler = StratifiedSampler::new(class_labels)
73//!     .with_proportional(true)
74//!     .with_generator(123);
75//!
76//! // Each class will be represented proportionally
77//! ```
78//!
79//! ## Importance Sampling
80//! ```rust,ignore
81//! use torsh_data::sampler::{Sampler, ImportanceSampler};
82//!
83//! // Sample based on importance scores (e.g., loss values)
84//! let importance_scores = vec![0.1, 0.8, 0.3, 0.9, 0.2];
85//! let sampler = ImportanceSampler::new(importance_scores)
86//!     .with_temperature(2.0)  // Higher temp = more uniform
87//!     .with_generator(456);
88//!
89//! // High-importance samples will be selected more frequently
90//! ```
91
92#[cfg(not(feature = "std"))]
93use alloc::{collections::HashMap, vec, vec::Vec};
94#[cfg(feature = "std")]
95use std::collections::HashMap;
96
97use super::core::{rng_utils, Sampler, SamplerIterator};
98use scirs2_core::rand_prelude::SliceRandom;
99use scirs2_core::random::Random;
100use scirs2_core::RngExt;
101
102/// Weighted random sampler for probability-based sampling.
103///
104/// This sampler allows you to specify different probabilities for each sample
105/// in the dataset. Samples with higher weights are more likely to be selected.
106/// This is essential for handling imbalanced datasets or implementing custom
107/// sampling distributions.
108///
109/// # Implementation Details
110///
111/// The sampler uses the alias method for efficient O(1) sampling after O(n)
112/// preprocessing. This makes it suitable for large datasets where you need
113/// to draw many samples.
114///
115/// # Performance Characteristics
116///
117/// - **Preprocessing**: O(n) time and space to build alias table
118/// - **Sampling**: O(1) per sample after preprocessing
119/// - **Memory**: O(n) for alias table storage
120/// - **Numerical Stability**: Handles unnormalized weights robustly
#[derive(Debug, Clone)]
pub struct WeightedRandomSampler {
    // Per-sample weights; may be unnormalized (normalization happens when
    // probabilities/alias table are computed).
    weights: Vec<f32>,
    // Whether indices are drawn with replacement.
    replacement: bool,
    // Optional RNG seed for deterministic sampling.
    generator: Option<u64>,
    // Lazily-built alias table for O(1) draws; `None` until first use.
    alias_table: Option<AliasTable>,
}
128
129impl WeightedRandomSampler {
130    /// Create a new weighted random sampler.
131    ///
132    /// # Arguments
133    ///
134    /// * `weights` - Vector of weights for each sample (will be normalized)
135    /// * `replacement` - Whether to sample with replacement
136    ///
137    /// # Panics
138    ///
139    /// Panics if weights vector is empty or contains only zeros.
140    ///
141    /// # Examples
142    ///
143    /// ```rust,ignore
144    /// use torsh_data::sampler::{Sampler, WeightedRandomSampler};
145    ///
146    /// let weights = vec![1.0, 2.0, 3.0]; // Unnormalized weights
147    /// let sampler = WeightedRandomSampler::new(weights, true);
148    ///
149    /// // Sample probabilities will be [1/6, 2/6, 3/6]
150    /// let indices: Vec<usize> = sampler.iter().take(100).collect();
151    /// // Index 2 should appear most frequently
152    /// ```
153    pub fn new(weights: Vec<f32>, replacement: bool) -> Self {
154        assert!(!weights.is_empty(), "Weights vector cannot be empty");
155        assert!(
156            weights.iter().any(|&w| w > 0.0),
157            "At least one weight must be positive"
158        );
159
160        Self {
161            weights,
162            replacement,
163            generator: None,
164            alias_table: None,
165        }
166    }
167
168    /// Set random generator seed.
169    ///
170    /// # Arguments
171    ///
172    /// * `seed` - Seed for deterministic sampling
173    ///
174    /// # Examples
175    ///
176    /// ```rust,ignore
177    /// use torsh_data::sampler::{Sampler, WeightedRandomSampler};
178    ///
179    /// let weights = vec![1.0, 2.0, 3.0];
180    /// let sampler = WeightedRandomSampler::new(weights, true)
181    ///     .with_generator(42);
182    ///
183    /// // Sampling will be deterministic
184    /// ```
185    pub fn with_generator(mut self, seed: u64) -> Self {
186        self.generator = Some(seed);
187        self
188    }
189
190    /// Get the weights used by this sampler.
191    pub fn weights(&self) -> &[f32] {
192        &self.weights
193    }
194
195    /// Check if sampling is done with replacement.
196    pub fn uses_replacement(&self) -> bool {
197        self.replacement
198    }
199
200    /// Get the generator seed if set.
201    pub fn generator_seed(&self) -> Option<u64> {
202        self.generator
203    }
204
205    /// Build the alias table for efficient sampling.
206    fn build_alias_table(&mut self) {
207        if self.alias_table.is_none() {
208            self.alias_table = Some(AliasTable::new(&self.weights));
209        }
210    }
211
212    /// Generate weighted random indices.
213    fn generate_indices(&mut self, count: usize) -> Vec<usize> {
214        self.build_alias_table();
215        let alias_table = self
216            .alias_table
217            .as_ref()
218            .expect("alias table should be built");
219
220        let mut rng = rng_utils::create_rng(self.generator);
221        let mut indices = Vec::with_capacity(count);
222
223        for _ in 0..count {
224            let idx = alias_table.sample(&mut rng);
225            indices.push(idx);
226        }
227
228        indices
229    }
230}
231
232impl Sampler for WeightedRandomSampler {
233    type Iter = SamplerIterator;
234
235    fn iter(&self) -> Self::Iter {
236        let count = if self.replacement {
237            self.weights.len() // With replacement, sample as many as we have
238        } else {
239            self.weights.len() // Without replacement, sample each once
240        };
241
242        let mut sampler = self.clone();
243        let indices = if self.replacement {
244            sampler.generate_indices(count)
245        } else {
246            // Without replacement: weighted shuffle
247            let mut weighted_indices: Vec<(usize, f32)> = self
248                .weights
249                .iter()
250                .enumerate()
251                .map(|(i, &w)| (i, w))
252                .collect();
253
254            let mut rng = rng_utils::create_rng(self.generator);
255
256            // Fisher-Yates shuffle with weights
257            for i in (1..weighted_indices.len()).rev() {
258                let total_weight: f32 = weighted_indices[..=i].iter().map(|(_, w)| w).sum();
259                let mut target_weight = rng.random::<f32>() * total_weight;
260
261                let mut selected_idx = 0;
262                for (j, (_, weight)) in weighted_indices[..=i].iter().enumerate() {
263                    target_weight -= weight;
264                    if target_weight <= 0.0 {
265                        selected_idx = j;
266                        break;
267                    }
268                }
269
270                weighted_indices.swap(i, selected_idx);
271            }
272
273            weighted_indices.into_iter().map(|(idx, _)| idx).collect()
274        };
275
276        SamplerIterator::new(indices)
277    }
278
279    fn len(&self) -> usize {
280        self.weights.len()
281    }
282}
283
/// Efficient alias table implementation for O(1) weighted sampling.
///
/// The alias method allows constant-time sampling from a discrete probability
/// distribution by preprocessing the weights into a lookup table.
#[derive(Debug, Clone)]
struct AliasTable {
    // Acceptance probability for each bucket (in [0, 1] after construction).
    prob: Vec<f32>,
    // Fallback index used when a bucket's coin flip rejects.
    alias: Vec<usize>,
}
293
294impl AliasTable {
295    /// Build an alias table from unnormalized weights.
296    fn new(weights: &[f32]) -> Self {
297        let n = weights.len();
298        let sum: f32 = weights.iter().sum();
299
300        assert!(sum > 0.0, "Total weight must be positive");
301
302        let mut prob = vec![0.0; n];
303        let mut alias = vec![0; n];
304
305        // Normalize weights to probabilities
306        let normalized: Vec<f32> = weights.iter().map(|&w| w * n as f32 / sum).collect();
307
308        // Separate into small and large probability buckets
309        let mut small = Vec::new();
310        let mut large = Vec::new();
311
312        for (i, &p) in normalized.iter().enumerate() {
313            if p < 1.0 {
314                small.push(i);
315            } else {
316                large.push(i);
317            }
318        }
319
320        prob.copy_from_slice(&normalized);
321
322        // Build alias table
323        while let (Some(l), Some(g)) = (small.pop(), large.pop()) {
324            alias[l] = g;
325            prob[g] = prob[g] + prob[l] - 1.0;
326
327            if prob[g] < 1.0 {
328                small.push(g);
329            } else {
330                large.push(g);
331            }
332        }
333
334        // Handle remaining large probabilities
335        while let Some(g) = large.pop() {
336            prob[g] = 1.0;
337        }
338
339        // Handle remaining small probabilities
340        while let Some(l) = small.pop() {
341            prob[l] = 1.0;
342        }
343
344        Self { prob, alias }
345    }
346
347    /// Sample an index using the alias table.
348    fn sample(&self, rng: &mut Random<scirs2_core::rngs::StdRng>) -> usize {
349        let i = rng.gen_range(0..self.prob.len());
350        let coin_flip = rng.random::<f32>();
351
352        if coin_flip < self.prob[i] {
353            i
354        } else {
355            self.alias[i]
356        }
357    }
358}
359
360/// Sampler that groups indices by a key function and samples groups together.
361///
362/// This sampler allows you to define custom grouping criteria and control
363/// how samples within and between groups are ordered. This is useful for
364/// scenarios where you want to process related samples together.
365///
366/// # Use Cases
367///
368/// - **Sequence Data**: Group by sequence ID to process complete sequences
369/// - **Hierarchical Data**: Group by category for structured processing
370/// - **Batch Constraints**: Ensure certain samples appear in the same batch
371/// - **Memory Efficiency**: Group similar samples for better cache locality
#[derive(Debug)]
pub struct GroupedSampler<F> {
    // Sample indices partitioned into groups, ordered by ascending group key.
    groups: Vec<Vec<usize>>,
    // Whether the order of groups is randomized on each `iter()`.
    shuffle_groups: bool,
    // Whether indices inside each group are randomized on each `iter()`.
    shuffle_within_groups: bool,
    // Optional RNG seed for deterministic shuffling.
    generator: Option<u64>,
    // Ties the grouping-function type to the sampler without storing it.
    // `core::` (not `std::`) keeps the type consistent with this file's
    // `no_std` gating of the `alloc` imports.
    _phantom: core::marker::PhantomData<F>,
}
380
381impl<F> GroupedSampler<F>
382where
383    F: Fn(usize) -> usize + Send,
384{
385    /// Create a new grouped sampler.
386    ///
387    /// # Arguments
388    ///
389    /// * `dataset` - Dataset to sample from
390    /// * `group_fn` - Function that maps sample index to group ID
391    ///
392    /// # Examples
393    ///
394    /// ```rust,ignore
395    /// use torsh_data::sampler::{Sampler, GroupedSampler};
396    ///
397    /// # struct DummyDataset { len: usize }
398    /// # impl crate::dataset::Dataset for DummyDataset {
399    /// #     type Item = usize;
400    /// #     fn get(&self, index: usize) -> Option<Self::Item> { Some(index) }
401    /// #     fn len(&self) -> usize { self.len }
402    /// # }
403    /// let dataset = DummyDataset { len: 10 };
404    ///
405    /// // Group by class (assuming 3 classes)
406    /// let group_by_class = |idx: usize| idx % 3;
407    /// let sampler = GroupedSampler::new(&dataset, group_by_class);
408    /// ```
409    pub fn new<D>(dataset: &D, group_fn: F) -> Self
410    where
411        D: crate::dataset::Dataset,
412    {
413        let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
414
415        // Group indices by the group function
416        for idx in 0..dataset.len() {
417            let group_key = group_fn(idx);
418            groups.entry(group_key).or_default().push(idx);
419        }
420
421        // Convert to sorted vector of groups
422        let mut group_list: Vec<(usize, Vec<usize>)> = groups.into_iter().collect();
423        group_list.sort_by_key(|(key, _)| *key);
424        let groups: Vec<Vec<usize>> = group_list.into_iter().map(|(_, indices)| indices).collect();
425
426        Self {
427            groups,
428            shuffle_groups: false,
429            shuffle_within_groups: false,
430            generator: None,
431            _phantom: std::marker::PhantomData,
432        }
433    }
434
435    /// Set whether to shuffle the order of groups.
436    ///
437    /// # Arguments
438    ///
439    /// * `shuffle` - Whether to randomize group order
440    ///
441    /// # Examples
442    ///
443    /// ```rust,ignore
444    /// # use torsh_data::sampler::GroupedSampler;
445    /// # struct DummyDataset { len: usize }
446    /// # impl crate::dataset::Dataset for DummyDataset {
447    /// #     type Item = usize;
448    /// #     fn get(&self, index: usize) -> Option<Self::Item> { Some(index) }
449    /// #     fn len(&self) -> usize { self.len }
450    /// # }
451    /// let dataset = DummyDataset { len: 10 };
452    /// let sampler = GroupedSampler::new(&dataset, |idx| idx % 3)
453    ///     .with_shuffle_groups(true);
454    /// ```
455    pub fn with_shuffle_groups(mut self, shuffle: bool) -> Self {
456        self.shuffle_groups = shuffle;
457        self
458    }
459
460    /// Set whether to shuffle within each group.
461    ///
462    /// # Arguments
463    ///
464    /// * `shuffle` - Whether to randomize order within groups
465    ///
466    /// # Examples
467    ///
468    /// ```rust,ignore
469    /// # use torsh_data::sampler::GroupedSampler;
470    /// # struct DummyDataset { len: usize }
471    /// # impl crate::dataset::Dataset for DummyDataset {
472    /// #     type Item = usize;
473    /// #     fn get(&self, index: usize) -> Option<Self::Item> { Some(index) }
474    /// #     fn len(&self) -> usize { self.len }
475    /// # }
476    /// let dataset = DummyDataset { len: 10 };
477    /// let sampler = GroupedSampler::new(&dataset, |idx| idx % 3)
478    ///     .with_shuffle_within_groups(true);
479    /// ```
480    pub fn with_shuffle_within_groups(mut self, shuffle: bool) -> Self {
481        self.shuffle_within_groups = shuffle;
482        self
483    }
484
485    /// Set random generator seed.
486    ///
487    /// # Arguments
488    ///
489    /// * `seed` - Seed for deterministic shuffling
490    pub fn with_generator(mut self, seed: u64) -> Self {
491        self.generator = Some(seed);
492        self
493    }
494
495    /// Get the number of groups.
496    pub fn num_groups(&self) -> usize {
497        self.groups.len()
498    }
499
500    /// Get the sizes of all groups.
501    pub fn group_sizes(&self) -> Vec<usize> {
502        self.groups.iter().map(|group| group.len()).collect()
503    }
504
505    /// Check if groups will be shuffled.
506    pub fn shuffles_groups(&self) -> bool {
507        self.shuffle_groups
508    }
509
510    /// Check if samples within groups will be shuffled.
511    pub fn shuffles_within_groups(&self) -> bool {
512        self.shuffle_within_groups
513    }
514}
515
516impl<F: Send> Sampler for GroupedSampler<F> {
517    type Iter = SamplerIterator;
518
519    fn iter(&self) -> Self::Iter {
520        let mut rng = rng_utils::create_rng(self.generator);
521        let mut groups = self.groups.clone();
522
523        // Shuffle within groups if requested
524        if self.shuffle_within_groups {
525            for group in &mut groups {
526                group.shuffle(&mut rng);
527            }
528        }
529
530        // Shuffle the order of groups if requested
531        if self.shuffle_groups {
532            groups.shuffle(&mut rng);
533        }
534
535        // Flatten all groups into a single list of indices
536        let indices: Vec<usize> = groups.into_iter().flatten().collect();
537
538        SamplerIterator::new(indices)
539    }
540
541    fn len(&self) -> usize {
542        self.groups.iter().map(|group| group.len()).sum()
543    }
544}
545
546/// Stratified sampler for balanced representation across strata.
547///
548/// This sampler ensures that each stratum (class/category) is represented
549/// proportionally in the sample. This is essential for classification tasks
550/// with imbalanced datasets where you want to maintain class balance.
551///
552/// # Key Features
553///
554/// - **Proportional Sampling**: Maintains original class proportions
555/// - **Balanced Sampling**: Equal samples per class (when specified)
556/// - **Minimum Guarantees**: Ensures each class gets at least one sample
557/// - **Reproducible**: Deterministic when seeded
#[derive(Debug, Clone)]
pub struct StratifiedSampler {
    // Map from stratum (class) ID to the sample indices belonging to it.
    strata: HashMap<usize, Vec<usize>>,
    // Proportional (true) vs. equal-per-stratum (false) allocation.
    proportional: bool,
    // Lower bound on samples drawn from each stratum.
    min_samples_per_stratum: usize,
    // Optional RNG seed for deterministic sampling.
    generator: Option<u64>,
}
565
566impl StratifiedSampler {
567    /// Create a new stratified sampler.
568    ///
569    /// # Arguments
570    ///
571    /// * `class_labels` - Vector mapping sample index to class/stratum
572    ///
573    /// # Examples
574    ///
575    /// ```rust,ignore
576    /// use torsh_data::sampler::{Sampler, StratifiedSampler};
577    ///
578    /// let labels = vec![0, 0, 1, 1, 1, 2]; // 2 class 0, 3 class 1, 1 class 2
579    /// let sampler = StratifiedSampler::new(labels);
580    ///
581    /// // Will maintain proportional representation
582    /// ```
583    pub fn new(class_labels: Vec<usize>) -> Self {
584        let mut strata: HashMap<usize, Vec<usize>> = HashMap::new();
585
586        // Group indices by class label
587        for (idx, &class) in class_labels.iter().enumerate() {
588            strata.entry(class).or_default().push(idx);
589        }
590
591        Self {
592            strata,
593            proportional: true,
594            min_samples_per_stratum: 1,
595            generator: None,
596        }
597    }
598
599    /// Create stratified sampler from pre-grouped strata.
600    ///
601    /// # Arguments
602    ///
603    /// * `strata` - Map from stratum ID to vector of sample indices
604    ///
605    /// # Examples
606    ///
607    /// ```rust,ignore
608    /// use std::collections::HashMap;
609    /// use torsh_data::sampler::StratifiedSampler;
610    ///
611    /// let mut strata = HashMap::new();
612    /// strata.insert(0, vec![0, 1, 2]);    // Stratum 0: indices 0, 1, 2
613    /// strata.insert(1, vec![3, 4, 5, 6]); // Stratum 1: indices 3, 4, 5, 6
614    ///
615    /// let sampler = StratifiedSampler::from_strata(strata);
616    /// ```
617    pub fn from_strata(strata: HashMap<usize, Vec<usize>>) -> Self {
618        Self {
619            strata,
620            proportional: true,
621            min_samples_per_stratum: 1,
622            generator: None,
623        }
624    }
625
626    /// Set whether to maintain proportional representation.
627    ///
628    /// When true (default), the number of samples per stratum is proportional
629    /// to the stratum size. When false, each stratum gets equal samples.
630    ///
631    /// # Arguments
632    ///
633    /// * `proportional` - Whether to use proportional sampling
634    pub fn with_proportional(mut self, proportional: bool) -> Self {
635        self.proportional = proportional;
636        self
637    }
638
639    /// Set minimum samples per stratum.
640    ///
641    /// Ensures each stratum gets at least this many samples, even if
642    /// proportional sampling would give it fewer.
643    ///
644    /// # Arguments
645    ///
646    /// * `min_samples` - Minimum samples per stratum
647    pub fn with_min_samples_per_stratum(mut self, min_samples: usize) -> Self {
648        self.min_samples_per_stratum = min_samples;
649        self
650    }
651
652    /// Set random generator seed.
653    pub fn with_generator(mut self, seed: u64) -> Self {
654        self.generator = Some(seed);
655        self
656    }
657
658    /// Get the number of strata.
659    pub fn num_strata(&self) -> usize {
660        self.strata.len()
661    }
662
663    /// Get the size of each stratum.
664    pub fn stratum_sizes(&self) -> HashMap<usize, usize> {
665        self.strata.iter().map(|(&k, v)| (k, v.len())).collect()
666    }
667
668    /// Check if proportional sampling is enabled.
669    pub fn uses_proportional(&self) -> bool {
670        self.proportional
671    }
672
673    /// Calculate how many samples each stratum should contribute.
674    fn calculate_stratum_samples(&self, total_samples: usize) -> HashMap<usize, usize> {
675        let total_stratum_size: usize = self.strata.values().map(|v| v.len()).sum();
676        let mut stratum_samples = HashMap::new();
677
678        if self.proportional {
679            // Proportional to stratum size
680            for (&stratum_id, indices) in &self.strata {
681                let proportional_samples = (indices.len() * total_samples) / total_stratum_size;
682                let final_samples = proportional_samples.max(self.min_samples_per_stratum);
683                stratum_samples.insert(stratum_id, final_samples);
684            }
685        } else {
686            // Equal samples per stratum
687            let samples_per_stratum = total_samples / self.strata.len();
688            for &stratum_id in self.strata.keys() {
689                stratum_samples.insert(
690                    stratum_id,
691                    samples_per_stratum.max(self.min_samples_per_stratum),
692                );
693            }
694        }
695
696        stratum_samples
697    }
698}
699
700impl Sampler for StratifiedSampler {
701    type Iter = SamplerIterator;
702
703    fn iter(&self) -> Self::Iter {
704        let total_samples: usize = self.strata.values().map(|v| v.len()).sum();
705        let stratum_samples = self.calculate_stratum_samples(total_samples);
706
707        let mut rng = rng_utils::create_rng(self.generator);
708        let mut all_indices = Vec::new();
709
710        // Sample from each stratum
711        for (&stratum_id, indices) in &self.strata {
712            let target_samples = stratum_samples[&stratum_id];
713            let mut stratum_indices = indices.clone();
714            stratum_indices.shuffle(&mut rng);
715
716            // Take samples with replacement if needed
717            if target_samples <= indices.len() {
718                all_indices.extend(&stratum_indices[..target_samples]);
719            } else {
720                // Need sampling with replacement
721                all_indices.extend(&stratum_indices);
722                for _ in indices.len()..target_samples {
723                    let idx = rng.gen_range(0..indices.len());
724                    all_indices.push(indices[idx]);
725                }
726            }
727        }
728
729        // Final shuffle to mix strata
730        all_indices.shuffle(&mut rng);
731
732        SamplerIterator::new(all_indices)
733    }
734
735    fn len(&self) -> usize {
736        let total_samples: usize = self.strata.values().map(|v| v.len()).sum();
737        let stratum_samples = self.calculate_stratum_samples(total_samples);
738        stratum_samples.values().sum()
739    }
740}
741
742/// Importance sampler for adaptive sample selection.
743///
744/// This sampler selects samples based on importance scores, which can represent
745/// various metrics like loss values, prediction confidence, or gradient norms.
746/// High-importance samples are selected more frequently, making this ideal for
747/// active learning and hard negative mining.
748///
749/// # Applications
750///
751/// - **Active Learning**: Sample uncertain or informative examples
752/// - **Hard Negative Mining**: Focus on difficult examples
753/// - **Curriculum Learning**: Gradually increase sample difficulty
754/// - **Online Learning**: Adapt to changing data distributions
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    // Per-sample importance values (e.g. loss magnitudes); higher = more
    // likely to be selected.
    importance_scores: Vec<f32>,
    // Softmax temperature: higher values flatten the distribution toward
    // uniform, lower values sharpen it.
    temperature: f32,
    // Optional RNG seed for deterministic sampling.
    generator: Option<u64>,
    // Whether `update_importance_scores` applies exponential-moving-average
    // updates.
    adaptive: bool,
    // EMA blending factor in [0, 1] used when `adaptive` is true.
    update_rate: f32,
}
763
764impl ImportanceSampler {
765    /// Create a new importance sampler.
766    ///
767    /// # Arguments
768    ///
769    /// * `importance_scores` - Vector of importance values for each sample
770    ///
771    /// # Examples
772    ///
773    /// ```rust,ignore
774    /// use torsh_data::sampler::{Sampler, ImportanceSampler};
775    ///
776    /// // Higher scores = more important
777    /// let scores = vec![0.1, 0.8, 0.3, 0.9, 0.2];
778    /// let sampler = ImportanceSampler::new(scores);
779    ///
780    /// // Samples 1 and 3 will be selected more frequently
781    /// ```
782    pub fn new(importance_scores: Vec<f32>) -> Self {
783        assert!(
784            !importance_scores.is_empty(),
785            "Importance scores cannot be empty"
786        );
787
788        Self {
789            importance_scores,
790            temperature: 1.0,
791            generator: None,
792            adaptive: false,
793            update_rate: 0.1,
794        }
795    }
796
797    /// Set the temperature for importance sampling.
798    ///
799    /// Higher temperature makes sampling more uniform, lower temperature
800    /// makes it more focused on high-importance samples.
801    ///
802    /// # Arguments
803    ///
804    /// * `temperature` - Temperature parameter (> 0.0)
805    ///
806    /// # Examples
807    ///
808    /// ```rust,ignore
809    /// use torsh_data::sampler::ImportanceSampler;
810    ///
811    /// let scores = vec![0.1, 0.8, 0.3];
812    /// let sampler = ImportanceSampler::new(scores)
813    ///     .with_temperature(2.0); // More uniform sampling
814    /// ```
815    pub fn with_temperature(mut self, temperature: f32) -> Self {
816        assert!(temperature > 0.0, "Temperature must be positive");
817        self.temperature = temperature;
818        self
819    }
820
821    /// Enable adaptive importance updates.
822    ///
823    /// When enabled, importance scores can be updated based on recent
824    /// sampling feedback to adapt to changing data characteristics.
825    ///
826    /// # Arguments
827    ///
828    /// * `adaptive` - Whether to enable adaptive updates
829    /// * `update_rate` - Rate of adaptation (0.0 to 1.0)
830    pub fn with_adaptive(mut self, adaptive: bool, update_rate: f32) -> Self {
831        assert!(
832            update_rate >= 0.0 && update_rate <= 1.0,
833            "Update rate must be in [0, 1]"
834        );
835        self.adaptive = adaptive;
836        self.update_rate = update_rate;
837        self
838    }
839
840    /// Set random generator seed.
841    pub fn with_generator(mut self, seed: u64) -> Self {
842        self.generator = Some(seed);
843        self
844    }
845
846    /// Get the importance scores.
847    pub fn importance_scores(&self) -> &[f32] {
848        &self.importance_scores
849    }
850
851    /// Get the temperature parameter.
852    pub fn temperature(&self) -> f32 {
853        self.temperature
854    }
855
856    /// Check if adaptive updates are enabled.
857    pub fn is_adaptive(&self) -> bool {
858        self.adaptive
859    }
860
861    /// Update importance scores (for adaptive sampling).
862    ///
863    /// # Arguments
864    ///
865    /// * `new_scores` - Updated importance scores
866    ///
867    /// # Examples
868    ///
869    /// ```rust,ignore
870    /// use torsh_data::sampler::ImportanceSampler;
871    ///
872    /// let mut sampler = ImportanceSampler::new(vec![0.1, 0.5, 0.3])
873    ///     .with_adaptive(true, 0.1);
874    ///
875    /// // Update based on new loss values
876    /// let new_losses = vec![0.2, 0.8, 0.1];
877    /// sampler.update_importance_scores(new_losses);
878    /// ```
879    pub fn update_importance_scores(&mut self, new_scores: Vec<f32>) {
880        if self.adaptive && new_scores.len() == self.importance_scores.len() {
881            for (old, &new) in self.importance_scores.iter_mut().zip(new_scores.iter()) {
882                *old = (1.0 - self.update_rate) * *old + self.update_rate * new;
883            }
884        }
885    }
886
887    /// Convert importance scores to sampling probabilities.
888    fn compute_probabilities(&self) -> Vec<f32> {
889        // Apply temperature scaling
890        let scaled_scores: Vec<f32> = self
891            .importance_scores
892            .iter()
893            .map(|&score| (score / self.temperature).exp())
894            .collect();
895
896        // Normalize to probabilities
897        let total: f32 = scaled_scores.iter().sum();
898        if total > 0.0 {
899            scaled_scores.iter().map(|&score| score / total).collect()
900        } else {
901            // Fallback to uniform if all scores are zero
902            vec![1.0 / self.importance_scores.len() as f32; self.importance_scores.len()]
903        }
904    }
905}
906
907impl Sampler for ImportanceSampler {
908    type Iter = SamplerIterator;
909
910    fn iter(&self) -> Self::Iter {
911        let probabilities = self.compute_probabilities();
912        let mut weighted_sampler = WeightedRandomSampler::new(probabilities, false);
913
914        if let Some(seed) = self.generator {
915            weighted_sampler = weighted_sampler.with_generator(seed);
916        }
917
918        weighted_sampler.iter()
919    }
920
921    fn len(&self) -> usize {
922        self.importance_scores.len()
923    }
924}
925
#[cfg(test)]
mod tests {
    use super::*;

    // Mock dataset for testing: each item is simply its own index, with
    // the usual bounds-checked `get` contract, so samplers can be driven
    // without constructing real data.
    struct MockDataset {
        size: usize,
    }

    impl crate::dataset::Dataset for MockDataset {
        type Item = usize;

        // Identity lookup; errors past `size` like a real dataset would.
        fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> {
            if index < self.size {
                Ok(index)
            } else {
                Err(torsh_core::error::TorshError::IndexOutOfBounds {
                    index,
                    size: self.size,
                })
            }
        }

        fn len(&self) -> usize {
            self.size
        }
    }

    // Accessors plus a basic sampling pass: every drawn index must be in
    // range, and one pass yields as many indices as there are weights.
    #[test]
    fn test_weighted_random_sampler() {
        let weights = vec![0.1, 0.3, 0.6]; // Unnormalized weights
        let sampler = WeightedRandomSampler::new(weights.clone(), true).with_generator(42);

        assert_eq!(sampler.len(), 3);
        assert_eq!(sampler.weights(), &weights);
        assert!(sampler.uses_replacement());
        assert_eq!(sampler.generator_seed(), Some(42));

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&i| i < 3));
    }

    // Two samplers with the same seed must produce the same sequence.
    #[test]
    fn test_weighted_sampler_deterministic() {
        let weights = vec![1.0, 2.0, 3.0];
        let sampler1 = WeightedRandomSampler::new(weights.clone(), true).with_generator(123);
        let sampler2 = WeightedRandomSampler::new(weights, true).with_generator(123);

        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();

        assert_eq!(indices1, indices2);
    }

    // Statistical smoke test for the alias table: over 1000 draws,
    // larger weights should receive more samples. (Not an exact
    // distribution check — just a monotonicity sanity check.)
    #[test]
    fn test_alias_table() {
        let weights = vec![1.0, 2.0, 3.0];
        let table = AliasTable::new(&weights);

        assert_eq!(table.prob.len(), 3);
        assert_eq!(table.alias.len(), 3);

        let mut rng = rng_utils::create_rng(Some(42));

        // Sample multiple times to check basic functionality
        let mut counts = vec![0; 3];
        for _ in 0..1000 {
            let sample = table.sample(&mut rng);
            assert!(sample < 3);
            counts[sample] += 1;
        }

        // Higher weights should have higher counts (approximately)
        assert!(counts[2] > counts[1]); // Weight 3 > Weight 2
        assert!(counts[1] > counts[0]); // Weight 2 > Weight 1
    }

    // Grouping 12 items by `idx % 3` with all shuffling disabled:
    // group metadata must reflect 3 equal groups of 4.
    #[test]
    fn test_grouped_sampler() {
        let dataset = MockDataset { size: 12 };
        let group_fn = |idx: usize| idx % 3; // 3 groups

        let sampler = GroupedSampler::new(&dataset, group_fn)
            .with_shuffle_groups(false)
            .with_shuffle_within_groups(false);

        assert_eq!(sampler.len(), 12);
        assert_eq!(sampler.num_groups(), 3);
        assert_eq!(sampler.group_sizes(), vec![4, 4, 4]); // 12 / 3 = 4 each

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 12);

        // Without shuffling, should maintain group order
        // Group 0: [0, 3, 6, 9], Group 1: [1, 4, 7, 10], Group 2: [2, 5, 8, 11]
    }

    // With both shuffle modes on and a fixed seed: output is
    // deterministic across iterations and is a permutation of 0..9.
    #[test]
    fn test_grouped_sampler_with_shuffling() {
        let dataset = MockDataset { size: 9 };
        let group_fn = |idx: usize| idx % 3;

        let sampler = GroupedSampler::new(&dataset, group_fn)
            .with_shuffle_groups(true)
            .with_shuffle_within_groups(true)
            .with_generator(42);

        let indices1: Vec<usize> = sampler.iter().collect();
        let indices2: Vec<usize> = sampler.iter().collect();

        // Should be deterministic with same seed
        assert_eq!(indices1, indices2);
        assert_eq!(indices1.len(), 9);

        // Should contain all original indices
        let mut sorted_indices = indices1;
        sorted_indices.sort();
        assert_eq!(sorted_indices, (0..9).collect::<Vec<_>>());
    }

    // Proportional stratified sampling: strata are built from class
    // labels and their sizes must match the label counts.
    #[test]
    fn test_stratified_sampler() {
        let class_labels = vec![0, 0, 1, 1, 1, 2]; // 2 class 0, 3 class 1, 1 class 2
        let sampler = StratifiedSampler::new(class_labels)
            .with_proportional(true)
            .with_generator(42);

        assert_eq!(sampler.num_strata(), 3);
        assert!(sampler.uses_proportional());

        let stratum_sizes = sampler.stratum_sizes();
        assert_eq!(stratum_sizes[&0], 2);
        assert_eq!(stratum_sizes[&1], 3);
        assert_eq!(stratum_sizes[&2], 1);

        let indices: Vec<usize> = sampler.iter().collect();
        assert!(!indices.is_empty());
    }

    // Balanced (non-proportional) mode with a per-stratum floor: just
    // checks configuration sticks and iteration produces output.
    #[test]
    fn test_stratified_sampler_balanced() {
        let class_labels = vec![0, 0, 1, 1, 1, 2];
        let sampler = StratifiedSampler::new(class_labels)
            .with_proportional(false) // Equal samples per stratum
            .with_min_samples_per_stratum(2)
            .with_generator(42);

        assert!(!sampler.uses_proportional());

        let indices: Vec<usize> = sampler.iter().collect();
        assert!(!indices.is_empty());
    }

    // Alternative constructor: strata supplied directly as a map of
    // label -> index list instead of per-sample labels.
    #[test]
    fn test_stratified_sampler_from_strata() {
        let mut strata = HashMap::new();
        strata.insert(0, vec![0, 1]);
        strata.insert(1, vec![2, 3, 4]);
        strata.insert(2, vec![5]);

        let sampler = StratifiedSampler::from_strata(strata);
        assert_eq!(sampler.num_strata(), 3);

        let indices: Vec<usize> = sampler.iter().collect();
        assert!(!indices.is_empty());
    }

    // Accessors and one sampling pass of the importance sampler; all
    // drawn indices must be in range.
    #[test]
    fn test_importance_sampler() {
        let scores = vec![0.1, 0.8, 0.3, 0.9, 0.2];
        let sampler = ImportanceSampler::new(scores.clone())
            .with_temperature(1.0)
            .with_generator(42);

        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.importance_scores(), &scores);
        assert_eq!(sampler.temperature(), 1.0);
        assert!(!sampler.is_adaptive());

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 5));
    }

    // Temperature extremes: only asserts both samplers iterate without
    // panicking — no distributional checks here.
    #[test]
    fn test_importance_sampler_temperature() {
        let scores = vec![0.1, 1.0, 0.1]; // One very high score

        // Low temperature - should heavily favor high-importance sample
        let low_temp_sampler = ImportanceSampler::new(scores.clone())
            .with_temperature(0.1)
            .with_generator(42);

        // High temperature - should be more uniform
        let high_temp_sampler = ImportanceSampler::new(scores)
            .with_temperature(10.0)
            .with_generator(42);

        // Both should work without panicking
        let _low_indices: Vec<usize> = low_temp_sampler.iter().collect();
        let _high_indices: Vec<usize> = high_temp_sampler.iter().collect();
    }

    // Adaptive update: with update rate 0.2 and new scores differing in
    // every component, every blended score must change.
    #[test]
    fn test_importance_sampler_adaptive() {
        let scores = vec![0.1, 0.5, 0.3];
        let mut sampler = ImportanceSampler::new(scores)
            .with_adaptive(true, 0.2)
            .with_generator(42);

        assert!(sampler.is_adaptive());

        let original_scores = sampler.importance_scores().to_vec();

        // Update scores
        let new_scores = vec![0.2, 0.8, 0.1];
        sampler.update_importance_scores(new_scores);

        let updated_scores = sampler.importance_scores().to_vec();
        assert_ne!(original_scores, updated_scores);

        // The updated scores should be a blend of old and new
        for i in 0..3 {
            assert!(updated_scores[i] != original_scores[i]);
        }
    }

    // Constructor precondition: empty weight vector panics.
    #[test]
    #[should_panic(expected = "Weights vector cannot be empty")]
    fn test_weighted_sampler_empty_weights() {
        WeightedRandomSampler::new(vec![], true);
    }

    // Constructor precondition: all-zero weights panic.
    #[test]
    #[should_panic(expected = "At least one weight must be positive")]
    fn test_weighted_sampler_zero_weights() {
        WeightedRandomSampler::new(vec![0.0, 0.0, 0.0], true);
    }

    // Builder precondition: non-positive temperature panics.
    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_importance_sampler_zero_temperature() {
        let scores = vec![0.1, 0.2, 0.3];
        ImportanceSampler::new(scores).with_temperature(0.0);
    }

    // Constructor precondition: empty score vector panics.
    #[test]
    #[should_panic(expected = "Importance scores cannot be empty")]
    fn test_importance_sampler_empty_scores() {
        ImportanceSampler::new(vec![]);
    }

    // Softmax properties: probabilities sum to ~1 and preserve the
    // ordering of the underlying scores.
    #[test]
    fn test_importance_sampler_probabilities() {
        let scores = vec![1.0, 2.0, 3.0];
        let sampler = ImportanceSampler::new(scores).with_temperature(1.0);

        let probabilities = sampler.compute_probabilities();
        assert_eq!(probabilities.len(), 3);

        // Probabilities should sum to 1 (approximately)
        let sum: f32 = probabilities.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Higher scores should have higher probabilities
        assert!(probabilities[2] > probabilities[1]);
        assert!(probabilities[1] > probabilities[0]);
    }
}