torsh_data/sampler/advanced.rs
1//! Advanced sampling strategies for specialized machine learning scenarios.
2//!
3//! This module provides sophisticated sampling techniques that go beyond basic
4//! random and sequential sampling. These strategies are particularly useful for
5//! handling imbalanced datasets, implementing importance sampling, and creating
6//! structured sampling patterns for specific machine learning applications.
7//!
8//! # Key Components
9//!
10//! ## Weighted Sampling
11//! - [`WeightedRandomSampler`] - Probability-weighted random sampling
12//! - Support for unnormalized weights and automatic normalization
13//! - Efficient implementation using alias method for O(1) sampling
14//!
15//! ## Grouped Sampling
16//! - [`GroupedSampler`] - Groups samples by user-defined criteria
17//! - Configurable shuffling within and between groups
18//! - Useful for batch sampling with specific constraints
19//!
20//! ## Stratified Sampling
21//! - [`StratifiedSampler`] - Maintains proportional representation across strata
22//! - Automatic balancing and class-aware sampling
23//! - Essential for classification with imbalanced datasets
24//!
25//! ## Importance Sampling
26//! - [`ImportanceSampler`] - Samples based on importance scores
27//! - Adaptive importance weight calculation
28//! - Critical for active learning and hard negative mining
29//!
30//! # Examples
31//!
32//! ## Weighted Random Sampling
33//! ```rust,ignore
34//! use torsh_data::sampler::{Sampler, WeightedRandomSampler};
35//!
36//! // Sample with higher probability for larger weights
37//! let weights = vec![0.1, 0.3, 0.2, 0.4];
38//! let sampler = WeightedRandomSampler::new(weights, true)
39//! .with_generator(42);
40//!
41//! let indices: Vec<usize> = sampler.iter().take(10).collect();
42//! // Index 3 (weight 0.4) will appear more frequently
43//! ```
44//!
45//! ## Grouped Sampling
46//! ```rust,ignore
47//! use torsh_data::sampler::{Sampler, GroupedSampler};
48//!
49//! // Group samples by some criterion (e.g., class label)
50//! let group_fn = |idx: usize| idx % 3; // 3 groups
51//! # struct DummyDataset { len: usize }
52//! # impl crate::dataset::Dataset for DummyDataset {
53//! # type Item = usize;
54//! # fn get(&self, index: usize) -> Option<Self::Item> { Some(index) }
55//! # fn len(&self) -> usize { self.len }
56//! # }
57//! let dataset = DummyDataset { len: 12 };
58//!
59//! let sampler = GroupedSampler::new(&dataset, group_fn)
60//! .with_shuffle_groups(true)
61//! .with_shuffle_within_groups(true);
62//!
63//! // Samples will be grouped together but in random order
64//! ```
65//!
66//! ## Stratified Sampling
67//! ```rust,ignore
68//! use torsh_data::sampler::{Sampler, StratifiedSampler};
69//!
70//! // Ensure balanced representation across classes
71//! let class_labels = vec![0, 0, 1, 1, 1, 2, 2, 2, 2];
72//! let sampler = StratifiedSampler::new(class_labels)
73//! .with_proportional(true)
74//! .with_generator(123);
75//!
76//! // Each class will be represented proportionally
77//! ```
78//!
79//! ## Importance Sampling
80//! ```rust,ignore
81//! use torsh_data::sampler::{Sampler, ImportanceSampler};
82//!
83//! // Sample based on importance scores (e.g., loss values)
84//! let importance_scores = vec![0.1, 0.8, 0.3, 0.9, 0.2];
85//! let sampler = ImportanceSampler::new(importance_scores)
86//! .with_temperature(2.0) // Higher temp = more uniform
87//! .with_generator(456);
88//!
89//! // High-importance samples will be selected more frequently
90//! ```
91
92#[cfg(not(feature = "std"))]
93use alloc::{collections::HashMap, vec, vec::Vec};
94#[cfg(feature = "std")]
95use std::collections::HashMap;
96
97use super::core::{rng_utils, Sampler, SamplerIterator};
98use scirs2_core::rand_prelude::SliceRandom;
99use scirs2_core::random::Random;
100use scirs2_core::RngExt;
101
/// Weighted random sampler for probability-based sampling.
///
/// This sampler allows you to specify different probabilities for each sample
/// in the dataset. Samples with higher weights are more likely to be selected.
/// This is essential for handling imbalanced datasets or implementing custom
/// sampling distributions.
///
/// # Implementation Details
///
/// The sampler uses the alias method for efficient O(1) sampling after O(n)
/// preprocessing. This makes it suitable for large datasets where you need
/// to draw many samples.
///
/// # Performance Characteristics
///
/// - **Preprocessing**: O(n) time and space to build alias table
/// - **Sampling**: O(1) per sample after preprocessing
/// - **Memory**: O(n) for alias table storage
/// - **Numerical Stability**: Handles unnormalized weights robustly
#[derive(Debug, Clone)]
pub struct WeightedRandomSampler {
    /// Per-sample weights; may be unnormalized (normalization happens when
    /// the alias table is built).
    weights: Vec<f32>,
    /// Whether indices are drawn with replacement.
    replacement: bool,
    /// Optional RNG seed for deterministic sampling.
    generator: Option<u64>,
    /// Alias table built lazily on first use; `None` until then.
    alias_table: Option<AliasTable>,
}
128
129impl WeightedRandomSampler {
130 /// Create a new weighted random sampler.
131 ///
132 /// # Arguments
133 ///
134 /// * `weights` - Vector of weights for each sample (will be normalized)
135 /// * `replacement` - Whether to sample with replacement
136 ///
137 /// # Panics
138 ///
139 /// Panics if weights vector is empty or contains only zeros.
140 ///
141 /// # Examples
142 ///
143 /// ```rust,ignore
144 /// use torsh_data::sampler::{Sampler, WeightedRandomSampler};
145 ///
146 /// let weights = vec![1.0, 2.0, 3.0]; // Unnormalized weights
147 /// let sampler = WeightedRandomSampler::new(weights, true);
148 ///
149 /// // Sample probabilities will be [1/6, 2/6, 3/6]
150 /// let indices: Vec<usize> = sampler.iter().take(100).collect();
151 /// // Index 2 should appear most frequently
152 /// ```
153 pub fn new(weights: Vec<f32>, replacement: bool) -> Self {
154 assert!(!weights.is_empty(), "Weights vector cannot be empty");
155 assert!(
156 weights.iter().any(|&w| w > 0.0),
157 "At least one weight must be positive"
158 );
159
160 Self {
161 weights,
162 replacement,
163 generator: None,
164 alias_table: None,
165 }
166 }
167
168 /// Set random generator seed.
169 ///
170 /// # Arguments
171 ///
172 /// * `seed` - Seed for deterministic sampling
173 ///
174 /// # Examples
175 ///
176 /// ```rust,ignore
177 /// use torsh_data::sampler::{Sampler, WeightedRandomSampler};
178 ///
179 /// let weights = vec![1.0, 2.0, 3.0];
180 /// let sampler = WeightedRandomSampler::new(weights, true)
181 /// .with_generator(42);
182 ///
183 /// // Sampling will be deterministic
184 /// ```
185 pub fn with_generator(mut self, seed: u64) -> Self {
186 self.generator = Some(seed);
187 self
188 }
189
190 /// Get the weights used by this sampler.
191 pub fn weights(&self) -> &[f32] {
192 &self.weights
193 }
194
195 /// Check if sampling is done with replacement.
196 pub fn uses_replacement(&self) -> bool {
197 self.replacement
198 }
199
200 /// Get the generator seed if set.
201 pub fn generator_seed(&self) -> Option<u64> {
202 self.generator
203 }
204
205 /// Build the alias table for efficient sampling.
206 fn build_alias_table(&mut self) {
207 if self.alias_table.is_none() {
208 self.alias_table = Some(AliasTable::new(&self.weights));
209 }
210 }
211
212 /// Generate weighted random indices.
213 fn generate_indices(&mut self, count: usize) -> Vec<usize> {
214 self.build_alias_table();
215 let alias_table = self
216 .alias_table
217 .as_ref()
218 .expect("alias table should be built");
219
220 let mut rng = rng_utils::create_rng(self.generator);
221 let mut indices = Vec::with_capacity(count);
222
223 for _ in 0..count {
224 let idx = alias_table.sample(&mut rng);
225 indices.push(idx);
226 }
227
228 indices
229 }
230}
231
232impl Sampler for WeightedRandomSampler {
233 type Iter = SamplerIterator;
234
235 fn iter(&self) -> Self::Iter {
236 let count = if self.replacement {
237 self.weights.len() // With replacement, sample as many as we have
238 } else {
239 self.weights.len() // Without replacement, sample each once
240 };
241
242 let mut sampler = self.clone();
243 let indices = if self.replacement {
244 sampler.generate_indices(count)
245 } else {
246 // Without replacement: weighted shuffle
247 let mut weighted_indices: Vec<(usize, f32)> = self
248 .weights
249 .iter()
250 .enumerate()
251 .map(|(i, &w)| (i, w))
252 .collect();
253
254 let mut rng = rng_utils::create_rng(self.generator);
255
256 // Fisher-Yates shuffle with weights
257 for i in (1..weighted_indices.len()).rev() {
258 let total_weight: f32 = weighted_indices[..=i].iter().map(|(_, w)| w).sum();
259 let mut target_weight = rng.random::<f32>() * total_weight;
260
261 let mut selected_idx = 0;
262 for (j, (_, weight)) in weighted_indices[..=i].iter().enumerate() {
263 target_weight -= weight;
264 if target_weight <= 0.0 {
265 selected_idx = j;
266 break;
267 }
268 }
269
270 weighted_indices.swap(i, selected_idx);
271 }
272
273 weighted_indices.into_iter().map(|(idx, _)| idx).collect()
274 };
275
276 SamplerIterator::new(indices)
277 }
278
279 fn len(&self) -> usize {
280 self.weights.len()
281 }
282}
283
/// Efficient alias table implementation for O(1) weighted sampling.
///
/// The alias method allows constant-time sampling from a discrete probability
/// distribution by preprocessing the weights into a lookup table.
#[derive(Debug, Clone)]
struct AliasTable {
    /// Acceptance probability of each bucket, in [0, 1].
    prob: Vec<f32>,
    /// Fallback index used when a bucket's biased coin flip rejects it.
    alias: Vec<usize>,
}
293
294impl AliasTable {
295 /// Build an alias table from unnormalized weights.
296 fn new(weights: &[f32]) -> Self {
297 let n = weights.len();
298 let sum: f32 = weights.iter().sum();
299
300 assert!(sum > 0.0, "Total weight must be positive");
301
302 let mut prob = vec![0.0; n];
303 let mut alias = vec![0; n];
304
305 // Normalize weights to probabilities
306 let normalized: Vec<f32> = weights.iter().map(|&w| w * n as f32 / sum).collect();
307
308 // Separate into small and large probability buckets
309 let mut small = Vec::new();
310 let mut large = Vec::new();
311
312 for (i, &p) in normalized.iter().enumerate() {
313 if p < 1.0 {
314 small.push(i);
315 } else {
316 large.push(i);
317 }
318 }
319
320 prob.copy_from_slice(&normalized);
321
322 // Build alias table
323 while let (Some(l), Some(g)) = (small.pop(), large.pop()) {
324 alias[l] = g;
325 prob[g] = prob[g] + prob[l] - 1.0;
326
327 if prob[g] < 1.0 {
328 small.push(g);
329 } else {
330 large.push(g);
331 }
332 }
333
334 // Handle remaining large probabilities
335 while let Some(g) = large.pop() {
336 prob[g] = 1.0;
337 }
338
339 // Handle remaining small probabilities
340 while let Some(l) = small.pop() {
341 prob[l] = 1.0;
342 }
343
344 Self { prob, alias }
345 }
346
347 /// Sample an index using the alias table.
348 fn sample(&self, rng: &mut Random<scirs2_core::rngs::StdRng>) -> usize {
349 let i = rng.gen_range(0..self.prob.len());
350 let coin_flip = rng.random::<f32>();
351
352 if coin_flip < self.prob[i] {
353 i
354 } else {
355 self.alias[i]
356 }
357 }
358}
359
/// Sampler that groups indices by a key function and samples groups together.
///
/// This sampler allows you to define custom grouping criteria and control
/// how samples within and between groups are ordered. This is useful for
/// scenarios where you want to process related samples together.
///
/// # Use Cases
///
/// - **Sequence Data**: Group by sequence ID to process complete sequences
/// - **Hierarchical Data**: Group by category for structured processing
/// - **Batch Constraints**: Ensure certain samples appear in the same batch
/// - **Memory Efficiency**: Group similar samples for better cache locality
#[derive(Debug)]
pub struct GroupedSampler<F> {
    /// Sample indices partitioned by group, ordered by ascending group key.
    groups: Vec<Vec<usize>>,
    /// Whether the order of groups is randomized during iteration.
    shuffle_groups: bool,
    /// Whether indices inside each group are randomized during iteration.
    shuffle_within_groups: bool,
    /// Optional RNG seed for deterministic shuffling.
    generator: Option<u64>,
    /// Ties the sampler to the grouping function's type; no `F` value is
    /// stored because grouping happens once in `new`.
    _phantom: std::marker::PhantomData<F>,
}
380
381impl<F> GroupedSampler<F>
382where
383 F: Fn(usize) -> usize + Send,
384{
385 /// Create a new grouped sampler.
386 ///
387 /// # Arguments
388 ///
389 /// * `dataset` - Dataset to sample from
390 /// * `group_fn` - Function that maps sample index to group ID
391 ///
392 /// # Examples
393 ///
394 /// ```rust,ignore
395 /// use torsh_data::sampler::{Sampler, GroupedSampler};
396 ///
397 /// # struct DummyDataset { len: usize }
398 /// # impl crate::dataset::Dataset for DummyDataset {
399 /// # type Item = usize;
400 /// # fn get(&self, index: usize) -> Option<Self::Item> { Some(index) }
401 /// # fn len(&self) -> usize { self.len }
402 /// # }
403 /// let dataset = DummyDataset { len: 10 };
404 ///
405 /// // Group by class (assuming 3 classes)
406 /// let group_by_class = |idx: usize| idx % 3;
407 /// let sampler = GroupedSampler::new(&dataset, group_by_class);
408 /// ```
409 pub fn new<D>(dataset: &D, group_fn: F) -> Self
410 where
411 D: crate::dataset::Dataset,
412 {
413 let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
414
415 // Group indices by the group function
416 for idx in 0..dataset.len() {
417 let group_key = group_fn(idx);
418 groups.entry(group_key).or_default().push(idx);
419 }
420
421 // Convert to sorted vector of groups
422 let mut group_list: Vec<(usize, Vec<usize>)> = groups.into_iter().collect();
423 group_list.sort_by_key(|(key, _)| *key);
424 let groups: Vec<Vec<usize>> = group_list.into_iter().map(|(_, indices)| indices).collect();
425
426 Self {
427 groups,
428 shuffle_groups: false,
429 shuffle_within_groups: false,
430 generator: None,
431 _phantom: std::marker::PhantomData,
432 }
433 }
434
435 /// Set whether to shuffle the order of groups.
436 ///
437 /// # Arguments
438 ///
439 /// * `shuffle` - Whether to randomize group order
440 ///
441 /// # Examples
442 ///
443 /// ```rust,ignore
444 /// # use torsh_data::sampler::GroupedSampler;
445 /// # struct DummyDataset { len: usize }
446 /// # impl crate::dataset::Dataset for DummyDataset {
447 /// # type Item = usize;
448 /// # fn get(&self, index: usize) -> Option<Self::Item> { Some(index) }
449 /// # fn len(&self) -> usize { self.len }
450 /// # }
451 /// let dataset = DummyDataset { len: 10 };
452 /// let sampler = GroupedSampler::new(&dataset, |idx| idx % 3)
453 /// .with_shuffle_groups(true);
454 /// ```
455 pub fn with_shuffle_groups(mut self, shuffle: bool) -> Self {
456 self.shuffle_groups = shuffle;
457 self
458 }
459
460 /// Set whether to shuffle within each group.
461 ///
462 /// # Arguments
463 ///
464 /// * `shuffle` - Whether to randomize order within groups
465 ///
466 /// # Examples
467 ///
468 /// ```rust,ignore
469 /// # use torsh_data::sampler::GroupedSampler;
470 /// # struct DummyDataset { len: usize }
471 /// # impl crate::dataset::Dataset for DummyDataset {
472 /// # type Item = usize;
473 /// # fn get(&self, index: usize) -> Option<Self::Item> { Some(index) }
474 /// # fn len(&self) -> usize { self.len }
475 /// # }
476 /// let dataset = DummyDataset { len: 10 };
477 /// let sampler = GroupedSampler::new(&dataset, |idx| idx % 3)
478 /// .with_shuffle_within_groups(true);
479 /// ```
480 pub fn with_shuffle_within_groups(mut self, shuffle: bool) -> Self {
481 self.shuffle_within_groups = shuffle;
482 self
483 }
484
485 /// Set random generator seed.
486 ///
487 /// # Arguments
488 ///
489 /// * `seed` - Seed for deterministic shuffling
490 pub fn with_generator(mut self, seed: u64) -> Self {
491 self.generator = Some(seed);
492 self
493 }
494
495 /// Get the number of groups.
496 pub fn num_groups(&self) -> usize {
497 self.groups.len()
498 }
499
500 /// Get the sizes of all groups.
501 pub fn group_sizes(&self) -> Vec<usize> {
502 self.groups.iter().map(|group| group.len()).collect()
503 }
504
505 /// Check if groups will be shuffled.
506 pub fn shuffles_groups(&self) -> bool {
507 self.shuffle_groups
508 }
509
510 /// Check if samples within groups will be shuffled.
511 pub fn shuffles_within_groups(&self) -> bool {
512 self.shuffle_within_groups
513 }
514}
515
516impl<F: Send> Sampler for GroupedSampler<F> {
517 type Iter = SamplerIterator;
518
519 fn iter(&self) -> Self::Iter {
520 let mut rng = rng_utils::create_rng(self.generator);
521 let mut groups = self.groups.clone();
522
523 // Shuffle within groups if requested
524 if self.shuffle_within_groups {
525 for group in &mut groups {
526 group.shuffle(&mut rng);
527 }
528 }
529
530 // Shuffle the order of groups if requested
531 if self.shuffle_groups {
532 groups.shuffle(&mut rng);
533 }
534
535 // Flatten all groups into a single list of indices
536 let indices: Vec<usize> = groups.into_iter().flatten().collect();
537
538 SamplerIterator::new(indices)
539 }
540
541 fn len(&self) -> usize {
542 self.groups.iter().map(|group| group.len()).sum()
543 }
544}
545
/// Stratified sampler for balanced representation across strata.
///
/// This sampler ensures that each stratum (class/category) is represented
/// proportionally in the sample. This is essential for classification tasks
/// with imbalanced datasets where you want to maintain class balance.
///
/// # Key Features
///
/// - **Proportional Sampling**: Maintains original class proportions
/// - **Balanced Sampling**: Equal samples per class (when specified)
/// - **Minimum Guarantees**: Ensures each class gets at least one sample
/// - **Reproducible**: Deterministic when seeded
#[derive(Debug, Clone)]
pub struct StratifiedSampler {
    /// Map from stratum id to the sample indices belonging to it.
    strata: HashMap<usize, Vec<usize>>,
    /// Proportional (true) vs equal-per-stratum (false) allocation.
    proportional: bool,
    /// Lower bound on samples drawn from each stratum.
    min_samples_per_stratum: usize,
    /// Optional RNG seed for deterministic sampling.
    generator: Option<u64>,
}

impl StratifiedSampler {
    /// Create a new stratified sampler.
    ///
    /// # Arguments
    ///
    /// * `class_labels` - Vector mapping sample index to class/stratum
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use torsh_data::sampler::{Sampler, StratifiedSampler};
    ///
    /// let labels = vec![0, 0, 1, 1, 1, 2]; // 2 class 0, 3 class 1, 1 class 2
    /// let sampler = StratifiedSampler::new(labels);
    ///
    /// // Will maintain proportional representation
    /// ```
    pub fn new(class_labels: Vec<usize>) -> Self {
        let mut strata: HashMap<usize, Vec<usize>> = HashMap::new();

        // Group indices by class label.
        for (idx, &class) in class_labels.iter().enumerate() {
            strata.entry(class).or_default().push(idx);
        }

        Self {
            strata,
            proportional: true,
            min_samples_per_stratum: 1,
            generator: None,
        }
    }

    /// Create stratified sampler from pre-grouped strata.
    ///
    /// # Arguments
    ///
    /// * `strata` - Map from stratum ID to vector of sample indices
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use std::collections::HashMap;
    /// use torsh_data::sampler::StratifiedSampler;
    ///
    /// let mut strata = HashMap::new();
    /// strata.insert(0, vec![0, 1, 2]); // Stratum 0: indices 0, 1, 2
    /// strata.insert(1, vec![3, 4, 5, 6]); // Stratum 1: indices 3, 4, 5, 6
    ///
    /// let sampler = StratifiedSampler::from_strata(strata);
    /// ```
    pub fn from_strata(strata: HashMap<usize, Vec<usize>>) -> Self {
        Self {
            strata,
            proportional: true,
            min_samples_per_stratum: 1,
            generator: None,
        }
    }

    /// Set whether to maintain proportional representation.
    ///
    /// When true (default), the number of samples per stratum is proportional
    /// to the stratum size. When false, each stratum gets equal samples.
    ///
    /// # Arguments
    ///
    /// * `proportional` - Whether to use proportional sampling
    pub fn with_proportional(mut self, proportional: bool) -> Self {
        self.proportional = proportional;
        self
    }

    /// Set minimum samples per stratum.
    ///
    /// Ensures each stratum gets at least this many samples, even if
    /// proportional sampling would give it fewer.
    ///
    /// # Arguments
    ///
    /// * `min_samples` - Minimum samples per stratum
    pub fn with_min_samples_per_stratum(mut self, min_samples: usize) -> Self {
        self.min_samples_per_stratum = min_samples;
        self
    }

    /// Set random generator seed.
    pub fn with_generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// Get the number of strata.
    pub fn num_strata(&self) -> usize {
        self.strata.len()
    }

    /// Get the size of each stratum.
    pub fn stratum_sizes(&self) -> HashMap<usize, usize> {
        self.strata.iter().map(|(&k, v)| (k, v.len())).collect()
    }

    /// Check if proportional sampling is enabled.
    pub fn uses_proportional(&self) -> bool {
        self.proportional
    }

    /// Calculate how many samples each stratum should contribute.
    ///
    /// Returns an empty map when there are no strata (or all strata are
    /// empty), which previously caused a divide-by-zero panic.
    fn calculate_stratum_samples(&self, total_samples: usize) -> HashMap<usize, usize> {
        let mut stratum_samples = HashMap::new();

        // Guard the divisions below against empty input.
        if self.strata.is_empty() {
            return stratum_samples;
        }
        let total_stratum_size: usize = self.strata.values().map(|v| v.len()).sum();
        if total_stratum_size == 0 {
            return stratum_samples;
        }

        if self.proportional {
            // Allocation proportional to stratum size, floored at the
            // configured minimum.
            for (&stratum_id, indices) in &self.strata {
                let proportional_samples = (indices.len() * total_samples) / total_stratum_size;
                stratum_samples.insert(
                    stratum_id,
                    proportional_samples.max(self.min_samples_per_stratum),
                );
            }
        } else {
            // Equal samples per stratum, floored at the configured minimum.
            let samples_per_stratum = total_samples / self.strata.len();
            for &stratum_id in self.strata.keys() {
                stratum_samples.insert(
                    stratum_id,
                    samples_per_stratum.max(self.min_samples_per_stratum),
                );
            }
        }

        stratum_samples
    }
}
699
700impl Sampler for StratifiedSampler {
701 type Iter = SamplerIterator;
702
703 fn iter(&self) -> Self::Iter {
704 let total_samples: usize = self.strata.values().map(|v| v.len()).sum();
705 let stratum_samples = self.calculate_stratum_samples(total_samples);
706
707 let mut rng = rng_utils::create_rng(self.generator);
708 let mut all_indices = Vec::new();
709
710 // Sample from each stratum
711 for (&stratum_id, indices) in &self.strata {
712 let target_samples = stratum_samples[&stratum_id];
713 let mut stratum_indices = indices.clone();
714 stratum_indices.shuffle(&mut rng);
715
716 // Take samples with replacement if needed
717 if target_samples <= indices.len() {
718 all_indices.extend(&stratum_indices[..target_samples]);
719 } else {
720 // Need sampling with replacement
721 all_indices.extend(&stratum_indices);
722 for _ in indices.len()..target_samples {
723 let idx = rng.gen_range(0..indices.len());
724 all_indices.push(indices[idx]);
725 }
726 }
727 }
728
729 // Final shuffle to mix strata
730 all_indices.shuffle(&mut rng);
731
732 SamplerIterator::new(all_indices)
733 }
734
735 fn len(&self) -> usize {
736 let total_samples: usize = self.strata.values().map(|v| v.len()).sum();
737 let stratum_samples = self.calculate_stratum_samples(total_samples);
738 stratum_samples.values().sum()
739 }
740}
741
/// Importance sampler for adaptive sample selection.
///
/// This sampler selects samples based on importance scores, which can represent
/// various metrics like loss values, prediction confidence, or gradient norms.
/// High-importance samples are selected more frequently, making this ideal for
/// active learning and hard negative mining.
///
/// # Applications
///
/// - **Active Learning**: Sample uncertain or informative examples
/// - **Hard Negative Mining**: Focus on difficult examples
/// - **Curriculum Learning**: Gradually increase sample difficulty
/// - **Online Learning**: Adapt to changing data distributions
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// Importance score per sample; higher means sampled more often.
    importance_scores: Vec<f32>,
    /// Softmax temperature; higher values flatten the distribution.
    temperature: f32,
    /// Optional RNG seed for deterministic sampling.
    generator: Option<u64>,
    /// Whether `update_importance_scores` applies exponential blending.
    adaptive: bool,
    /// Blend factor for adaptive updates, in [0, 1].
    update_rate: f32,
}

impl ImportanceSampler {
    /// Create a new importance sampler.
    ///
    /// # Arguments
    ///
    /// * `importance_scores` - Vector of importance values for each sample
    ///
    /// # Panics
    ///
    /// Panics if `importance_scores` is empty.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use torsh_data::sampler::{Sampler, ImportanceSampler};
    ///
    /// // Higher scores = more important
    /// let scores = vec![0.1, 0.8, 0.3, 0.9, 0.2];
    /// let sampler = ImportanceSampler::new(scores);
    ///
    /// // Samples 1 and 3 will be selected more frequently
    /// ```
    pub fn new(importance_scores: Vec<f32>) -> Self {
        assert!(
            !importance_scores.is_empty(),
            "Importance scores cannot be empty"
        );

        Self {
            importance_scores,
            temperature: 1.0,
            generator: None,
            adaptive: false,
            update_rate: 0.1,
        }
    }

    /// Set the temperature for importance sampling.
    ///
    /// Higher temperature makes sampling more uniform, lower temperature
    /// makes it more focused on high-importance samples.
    ///
    /// # Arguments
    ///
    /// * `temperature` - Temperature parameter (> 0.0)
    ///
    /// # Panics
    ///
    /// Panics if `temperature` is not strictly positive.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        assert!(temperature > 0.0, "Temperature must be positive");
        self.temperature = temperature;
        self
    }

    /// Enable adaptive importance updates.
    ///
    /// When enabled, importance scores can be updated based on recent
    /// sampling feedback to adapt to changing data characteristics.
    ///
    /// # Arguments
    ///
    /// * `adaptive` - Whether to enable adaptive updates
    /// * `update_rate` - Rate of adaptation (0.0 to 1.0)
    ///
    /// # Panics
    ///
    /// Panics if `update_rate` is outside [0, 1].
    pub fn with_adaptive(mut self, adaptive: bool, update_rate: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&update_rate),
            "Update rate must be in [0, 1]"
        );
        self.adaptive = adaptive;
        self.update_rate = update_rate;
        self
    }

    /// Set random generator seed.
    pub fn with_generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// Get the importance scores.
    pub fn importance_scores(&self) -> &[f32] {
        &self.importance_scores
    }

    /// Get the temperature parameter.
    pub fn temperature(&self) -> f32 {
        self.temperature
    }

    /// Check if adaptive updates are enabled.
    pub fn is_adaptive(&self) -> bool {
        self.adaptive
    }

    /// Update importance scores (for adaptive sampling).
    ///
    /// Each score becomes `(1 - rate) * old + rate * new`. The call is a
    /// no-op unless adaptive mode is enabled and `new_scores` has the same
    /// length as the stored scores.
    ///
    /// # Arguments
    ///
    /// * `new_scores` - Updated importance scores
    pub fn update_importance_scores(&mut self, new_scores: Vec<f32>) {
        if self.adaptive && new_scores.len() == self.importance_scores.len() {
            for (old, &new) in self.importance_scores.iter_mut().zip(new_scores.iter()) {
                *old = (1.0 - self.update_rate) * *old + self.update_rate * new;
            }
        }
    }

    /// Convert importance scores to sampling probabilities via a
    /// temperature-scaled softmax.
    fn compute_probabilities(&self) -> Vec<f32> {
        // Subtract the max score before exponentiating (log-sum-exp trick)
        // so large scores cannot overflow `f32::exp` to infinity; the shift
        // cancels in the normalization, leaving probabilities unchanged.
        let max_score = self
            .importance_scores
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max);

        let scaled_scores: Vec<f32> = self
            .importance_scores
            .iter()
            .map(|&score| ((score - max_score) / self.temperature).exp())
            .collect();

        // Normalize to probabilities.
        let total: f32 = scaled_scores.iter().sum();
        if total > 0.0 && total.is_finite() {
            scaled_scores.iter().map(|&score| score / total).collect()
        } else {
            // Degenerate scores (e.g. NaN inputs) — fall back to uniform.
            vec![1.0 / self.importance_scores.len() as f32; self.importance_scores.len()]
        }
    }
}
906
907impl Sampler for ImportanceSampler {
908 type Iter = SamplerIterator;
909
910 fn iter(&self) -> Self::Iter {
911 let probabilities = self.compute_probabilities();
912 let mut weighted_sampler = WeightedRandomSampler::new(probabilities, false);
913
914 if let Some(seed) = self.generator {
915 weighted_sampler = weighted_sampler.with_generator(seed);
916 }
917
918 weighted_sampler.iter()
919 }
920
921 fn len(&self) -> usize {
922 self.importance_scores.len()
923 }
924}
925
926#[cfg(test)]
927mod tests {
928 use super::*;
929
930 // Mock dataset for testing
931 struct MockDataset {
932 size: usize,
933 }
934
935 impl crate::dataset::Dataset for MockDataset {
936 type Item = usize;
937
938 fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> {
939 if index < self.size {
940 Ok(index)
941 } else {
942 Err(torsh_core::error::TorshError::IndexOutOfBounds {
943 index,
944 size: self.size,
945 })
946 }
947 }
948
949 fn len(&self) -> usize {
950 self.size
951 }
952 }
953
954 #[test]
955 fn test_weighted_random_sampler() {
956 let weights = vec![0.1, 0.3, 0.6]; // Unnormalized weights
957 let sampler = WeightedRandomSampler::new(weights.clone(), true).with_generator(42);
958
959 assert_eq!(sampler.len(), 3);
960 assert_eq!(sampler.weights(), &weights);
961 assert!(sampler.uses_replacement());
962 assert_eq!(sampler.generator_seed(), Some(42));
963
964 let indices: Vec<usize> = sampler.iter().collect();
965 assert_eq!(indices.len(), 3);
966 assert!(indices.iter().all(|&i| i < 3));
967 }
968
969 #[test]
970 fn test_weighted_sampler_deterministic() {
971 let weights = vec![1.0, 2.0, 3.0];
972 let sampler1 = WeightedRandomSampler::new(weights.clone(), true).with_generator(123);
973 let sampler2 = WeightedRandomSampler::new(weights, true).with_generator(123);
974
975 let indices1: Vec<usize> = sampler1.iter().collect();
976 let indices2: Vec<usize> = sampler2.iter().collect();
977
978 assert_eq!(indices1, indices2);
979 }
980
981 #[test]
982 fn test_alias_table() {
983 let weights = vec![1.0, 2.0, 3.0];
984 let table = AliasTable::new(&weights);
985
986 assert_eq!(table.prob.len(), 3);
987 assert_eq!(table.alias.len(), 3);
988
989 let mut rng = rng_utils::create_rng(Some(42));
990
991 // Sample multiple times to check basic functionality
992 let mut counts = vec![0; 3];
993 for _ in 0..1000 {
994 let sample = table.sample(&mut rng);
995 assert!(sample < 3);
996 counts[sample] += 1;
997 }
998
999 // Higher weights should have higher counts (approximately)
1000 assert!(counts[2] > counts[1]); // Weight 3 > Weight 2
1001 assert!(counts[1] > counts[0]); // Weight 2 > Weight 1
1002 }
1003
1004 #[test]
1005 fn test_grouped_sampler() {
1006 let dataset = MockDataset { size: 12 };
1007 let group_fn = |idx: usize| idx % 3; // 3 groups
1008
1009 let sampler = GroupedSampler::new(&dataset, group_fn)
1010 .with_shuffle_groups(false)
1011 .with_shuffle_within_groups(false);
1012
1013 assert_eq!(sampler.len(), 12);
1014 assert_eq!(sampler.num_groups(), 3);
1015 assert_eq!(sampler.group_sizes(), vec![4, 4, 4]); // 12 / 3 = 4 each
1016
1017 let indices: Vec<usize> = sampler.iter().collect();
1018 assert_eq!(indices.len(), 12);
1019
1020 // Without shuffling, should maintain group order
1021 // Group 0: [0, 3, 6, 9], Group 1: [1, 4, 7, 10], Group 2: [2, 5, 8, 11]
1022 }
1023
1024 #[test]
1025 fn test_grouped_sampler_with_shuffling() {
1026 let dataset = MockDataset { size: 9 };
1027 let group_fn = |idx: usize| idx % 3;
1028
1029 let sampler = GroupedSampler::new(&dataset, group_fn)
1030 .with_shuffle_groups(true)
1031 .with_shuffle_within_groups(true)
1032 .with_generator(42);
1033
1034 let indices1: Vec<usize> = sampler.iter().collect();
1035 let indices2: Vec<usize> = sampler.iter().collect();
1036
1037 // Should be deterministic with same seed
1038 assert_eq!(indices1, indices2);
1039 assert_eq!(indices1.len(), 9);
1040
1041 // Should contain all original indices
1042 let mut sorted_indices = indices1;
1043 sorted_indices.sort();
1044 assert_eq!(sorted_indices, (0..9).collect::<Vec<_>>());
1045 }
1046
1047 #[test]
1048 fn test_stratified_sampler() {
1049 let class_labels = vec![0, 0, 1, 1, 1, 2]; // 2 class 0, 3 class 1, 1 class 2
1050 let sampler = StratifiedSampler::new(class_labels)
1051 .with_proportional(true)
1052 .with_generator(42);
1053
1054 assert_eq!(sampler.num_strata(), 3);
1055 assert!(sampler.uses_proportional());
1056
1057 let stratum_sizes = sampler.stratum_sizes();
1058 assert_eq!(stratum_sizes[&0], 2);
1059 assert_eq!(stratum_sizes[&1], 3);
1060 assert_eq!(stratum_sizes[&2], 1);
1061
1062 let indices: Vec<usize> = sampler.iter().collect();
1063 assert!(!indices.is_empty());
1064 }
1065
1066 #[test]
1067 fn test_stratified_sampler_balanced() {
1068 let class_labels = vec![0, 0, 1, 1, 1, 2];
1069 let sampler = StratifiedSampler::new(class_labels)
1070 .with_proportional(false) // Equal samples per stratum
1071 .with_min_samples_per_stratum(2)
1072 .with_generator(42);
1073
1074 assert!(!sampler.uses_proportional());
1075
1076 let indices: Vec<usize> = sampler.iter().collect();
1077 assert!(!indices.is_empty());
1078 }
1079
1080 #[test]
1081 fn test_stratified_sampler_from_strata() {
1082 let mut strata = HashMap::new();
1083 strata.insert(0, vec![0, 1]);
1084 strata.insert(1, vec![2, 3, 4]);
1085 strata.insert(2, vec![5]);
1086
1087 let sampler = StratifiedSampler::from_strata(strata);
1088 assert_eq!(sampler.num_strata(), 3);
1089
1090 let indices: Vec<usize> = sampler.iter().collect();
1091 assert!(!indices.is_empty());
1092 }
1093
1094 #[test]
1095 fn test_importance_sampler() {
1096 let scores = vec![0.1, 0.8, 0.3, 0.9, 0.2];
1097 let sampler = ImportanceSampler::new(scores.clone())
1098 .with_temperature(1.0)
1099 .with_generator(42);
1100
1101 assert_eq!(sampler.len(), 5);
1102 assert_eq!(sampler.importance_scores(), &scores);
1103 assert_eq!(sampler.temperature(), 1.0);
1104 assert!(!sampler.is_adaptive());
1105
1106 let indices: Vec<usize> = sampler.iter().collect();
1107 assert_eq!(indices.len(), 5);
1108 assert!(indices.iter().all(|&i| i < 5));
1109 }
1110
1111 #[test]
1112 fn test_importance_sampler_temperature() {
1113 let scores = vec![0.1, 1.0, 0.1]; // One very high score
1114
1115 // Low temperature - should heavily favor high-importance sample
1116 let low_temp_sampler = ImportanceSampler::new(scores.clone())
1117 .with_temperature(0.1)
1118 .with_generator(42);
1119
1120 // High temperature - should be more uniform
1121 let high_temp_sampler = ImportanceSampler::new(scores)
1122 .with_temperature(10.0)
1123 .with_generator(42);
1124
1125 // Both should work without panicking
1126 let _low_indices: Vec<usize> = low_temp_sampler.iter().collect();
1127 let _high_indices: Vec<usize> = high_temp_sampler.iter().collect();
1128 }
1129
1130 #[test]
1131 fn test_importance_sampler_adaptive() {
1132 let scores = vec![0.1, 0.5, 0.3];
1133 let mut sampler = ImportanceSampler::new(scores)
1134 .with_adaptive(true, 0.2)
1135 .with_generator(42);
1136
1137 assert!(sampler.is_adaptive());
1138
1139 let original_scores = sampler.importance_scores().to_vec();
1140
1141 // Update scores
1142 let new_scores = vec![0.2, 0.8, 0.1];
1143 sampler.update_importance_scores(new_scores);
1144
1145 let updated_scores = sampler.importance_scores().to_vec();
1146 assert_ne!(original_scores, updated_scores);
1147
1148 // The updated scores should be a blend of old and new
1149 for i in 0..3 {
1150 assert!(updated_scores[i] != original_scores[i]);
1151 }
1152 }
1153
1154 #[test]
1155 #[should_panic(expected = "Weights vector cannot be empty")]
1156 fn test_weighted_sampler_empty_weights() {
1157 WeightedRandomSampler::new(vec![], true);
1158 }
1159
1160 #[test]
1161 #[should_panic(expected = "At least one weight must be positive")]
1162 fn test_weighted_sampler_zero_weights() {
1163 WeightedRandomSampler::new(vec![0.0, 0.0, 0.0], true);
1164 }
1165
1166 #[test]
1167 #[should_panic(expected = "Temperature must be positive")]
1168 fn test_importance_sampler_zero_temperature() {
1169 let scores = vec![0.1, 0.2, 0.3];
1170 ImportanceSampler::new(scores).with_temperature(0.0);
1171 }
1172
1173 #[test]
1174 #[should_panic(expected = "Importance scores cannot be empty")]
1175 fn test_importance_sampler_empty_scores() {
1176 ImportanceSampler::new(vec![]);
1177 }
1178
1179 #[test]
1180 fn test_importance_sampler_probabilities() {
1181 let scores = vec![1.0, 2.0, 3.0];
1182 let sampler = ImportanceSampler::new(scores).with_temperature(1.0);
1183
1184 let probabilities = sampler.compute_probabilities();
1185 assert_eq!(probabilities.len(), 3);
1186
1187 // Probabilities should sum to 1 (approximately)
1188 let sum: f32 = probabilities.iter().sum();
1189 assert!((sum - 1.0).abs() < 1e-6);
1190
1191 // Higher scores should have higher probabilities
1192 assert!(probabilities[2] > probabilities[1]);
1193 assert!(probabilities[1] > probabilities[0]);
1194 }
1195}