torsh_data/sampler/
basic.rs

1//! Basic sampling strategies
2//!
3//! This module provides fundamental sampling implementations including
4//! sequential and random sampling patterns.
5
6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
10use scirs2_core::random::{Random, Rng};
11
12use super::core::{Sampler, SamplerIterator};
13
/// Sequential sampler that yields indices in order
///
/// Produces every index from `0` up to `dataset_size - 1` in ascending
/// order, which makes iteration over a dataset fully deterministic.
/// A `dataset_size` of `0` is allowed and yields nothing.
///
/// Samplers compare equal when they cover the same number of indices.
///
/// # Examples
///
/// ```rust
/// use torsh_data::sampler::{SequentialSampler, Sampler};
///
/// let sampler = SequentialSampler::new(5);
/// let indices: Vec<usize> = sampler.iter().collect();
/// assert_eq!(indices, vec![0, 1, 2, 3, 4]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequentialSampler {
    /// Number of indices this sampler walks over.
    dataset_size: usize,
}

impl SequentialSampler {
    /// Create a new sequential sampler
    ///
    /// # Arguments
    ///
    /// * `dataset_size` - Number of samples in the dataset (can be 0 for empty datasets)
    pub fn new(dataset_size: usize) -> Self {
        Self { dataset_size }
    }

    /// Get the dataset size this sampler was created with
    pub fn dataset_size(&self) -> usize {
        self.dataset_size
    }
}
48
49impl Sampler for SequentialSampler {
50    type Iter = SamplerIterator;
51
52    fn iter(&self) -> Self::Iter {
53        SamplerIterator::from_range(0, self.dataset_size)
54    }
55
56    fn len(&self) -> usize {
57        self.dataset_size
58    }
59}
60
/// Random sampler that yields indices in random order
///
/// This sampler shuffles the indices and yields them in random order.
/// It can sample with or without replacement, and the number of samples
/// returned per pass can be limited or (with replacement) extended.
///
/// Two samplers compare equal when all of their settings, including the
/// optional seed, are equal.
///
/// # Examples
///
/// ```rust
/// use torsh_data::sampler::{RandomSampler, Sampler};
///
/// // Sample all indices in random order
/// let sampler = RandomSampler::new(5, None, false).with_generator(42);
/// let indices: Vec<usize> = sampler.iter().collect();
/// assert_eq!(indices.len(), 5);
///
/// // Sample 3 indices without replacement
/// let sampler = RandomSampler::new(10, Some(3), false).with_generator(42);
/// let indices: Vec<usize> = sampler.iter().collect();
/// assert_eq!(indices.len(), 3);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomSampler {
    /// Size of the index space indices are drawn from.
    dataset_size: usize,
    /// Requested number of samples per pass; `None` means `dataset_size`.
    num_samples: Option<usize>,
    /// Whether the same index may be drawn more than once.
    replacement: bool,
    /// Optional seed for reproducible sampling.
    generator: Option<u64>,
}
89
90impl RandomSampler {
91    /// Create a new random sampler
92    ///
93    /// # Arguments
94    ///
95    /// * `dataset_size` - Number of samples in the dataset
96    /// * `num_samples` - Number of samples to yield (None for all)
97    /// * `replacement` - Whether to sample with replacement
98    ///
99    /// # Panics
100    ///
101    /// Panics if `dataset_size` is 0 or if sampling without replacement
102    /// but `num_samples` > `dataset_size`
103    pub fn new(dataset_size: usize, num_samples: Option<usize>, replacement: bool) -> Self {
104        let actual_num_samples = num_samples.unwrap_or(dataset_size);
105
106        super::core::utils::validate_sampling_params(
107            dataset_size,
108            Some(actual_num_samples),
109            replacement,
110        )
111        .expect("Invalid sampling parameters");
112
113        Self {
114            dataset_size,
115            num_samples,
116            replacement,
117            generator: None,
118        }
119    }
120
121    /// Create a simple random sampler with default settings (no replacement, all samples)
122    ///
123    /// # Arguments
124    ///
125    /// * `dataset_size` - Number of samples in the dataset
126    ///
127    /// # Panics
128    ///
129    /// Panics if `dataset_size` is 0
130    pub fn simple(dataset_size: usize) -> Self {
131        Self::new(dataset_size, None, false)
132    }
133
134    /// Create a random sampler with specific replacement setting
135    ///
136    /// # Arguments
137    ///
138    /// * `dataset_size` - Number of samples in the dataset
139    /// * `replacement` - Whether to sample with replacement
140    /// * `num_samples` - Number of samples to yield (None for all)
141    ///
142    /// # Panics
143    ///
144    /// Panics if `dataset_size` is 0
145    pub fn with_replacement(
146        dataset_size: usize,
147        replacement: bool,
148        num_samples: Option<usize>,
149    ) -> Self {
150        Self::new(dataset_size, num_samples, replacement)
151    }
152
153    /// Set the random number generator seed
154    ///
155    /// # Arguments
156    ///
157    /// * `seed` - Random seed for reproducible sampling
158    pub fn with_generator(mut self, seed: u64) -> Self {
159        self.generator = Some(seed);
160        self
161    }
162
163    /// Get the dataset size
164    pub fn dataset_size(&self) -> usize {
165        self.dataset_size
166    }
167
168    /// Get the number of samples that will be yielded
169    pub fn num_samples(&self) -> usize {
170        self.num_samples.unwrap_or(self.dataset_size)
171    }
172
173    /// Check if sampling with replacement
174    pub fn replacement(&self) -> bool {
175        self.replacement
176    }
177
178    /// Get the generator seed if set
179    pub fn generator(&self) -> Option<u64> {
180        self.generator
181    }
182}
183
184impl Sampler for RandomSampler {
185    type Iter = SamplerIterator;
186
187    fn iter(&self) -> Self::Iter {
188        let num_samples = self.num_samples();
189
190        if self.replacement {
191            self.iter_with_replacement(num_samples)
192        } else {
193            self.iter_without_replacement(num_samples)
194        }
195    }
196
197    fn len(&self) -> usize {
198        self.num_samples()
199    }
200}
201
202impl RandomSampler {
203    /// Generate iterator for sampling with replacement
204    fn iter_with_replacement(&self, num_samples: usize) -> SamplerIterator {
205        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
206        let mut rng = match self.generator {
207            Some(seed) => Random::seed(seed),
208            None => Random::seed(42),
209        };
210
211        let indices: Vec<usize> = (0..num_samples)
212            .map(|_| rng.gen_range(0..self.dataset_size))
213            .collect();
214
215        SamplerIterator::new(indices)
216    }
217
218    /// Generate iterator for sampling without replacement
219    fn iter_without_replacement(&self, num_samples: usize) -> SamplerIterator {
220        if num_samples == self.dataset_size {
221            // Return all indices shuffled
222            let indices: Vec<usize> = (0..self.dataset_size).collect();
223            SamplerIterator::shuffled(indices, self.generator)
224        } else {
225            // Use utility function for efficient sampling
226            let indices =
227                super::core::utils::random_indices(self.dataset_size, num_samples, self.generator);
228            SamplerIterator::new(indices)
229        }
230    }
231}
232
233/// Create a sequential sampler
234///
235/// Convenience function for creating a sequential sampler.
236///
237/// # Arguments
238///
239/// * `dataset_size` - Number of samples in the dataset
240pub fn sequential(dataset_size: usize) -> SequentialSampler {
241    SequentialSampler::new(dataset_size)
242}
243
244/// Create a random sampler
245///
246/// Convenience function for creating a random sampler that yields all
247/// indices in random order without replacement.
248///
249/// # Arguments
250///
251/// * `dataset_size` - Number of samples in the dataset
252/// * `seed` - Optional random seed for reproducible sampling
253pub fn random(dataset_size: usize, seed: Option<u64>) -> RandomSampler {
254    let mut sampler = RandomSampler::new(dataset_size, None, false);
255    if let Some(s) = seed {
256        sampler = sampler.with_generator(s);
257    }
258    sampler
259}
260
261/// Create a random sampler with replacement
262///
263/// Convenience function for creating a random sampler that samples
264/// with replacement.
265///
266/// # Arguments
267///
268/// * `dataset_size` - Number of samples in the dataset
269/// * `num_samples` - Number of samples to yield
270/// * `seed` - Optional random seed for reproducible sampling
271pub fn random_with_replacement(
272    dataset_size: usize,
273    num_samples: usize,
274    seed: Option<u64>,
275) -> RandomSampler {
276    let mut sampler = RandomSampler::new(dataset_size, Some(num_samples), true);
277    if let Some(s) = seed {
278        sampler = sampler.with_generator(s);
279    }
280    sampler
281}
282
283/// Create a random subset sampler
284///
285/// Convenience function for creating a random sampler that yields
286/// a subset of indices without replacement.
287///
288/// # Arguments
289///
290/// * `dataset_size` - Number of samples in the dataset
291/// * `num_samples` - Number of samples to yield
292/// * `seed` - Optional random seed for reproducible sampling
293pub fn random_subset(dataset_size: usize, num_samples: usize, seed: Option<u64>) -> RandomSampler {
294    let mut sampler = RandomSampler::new(dataset_size, Some(num_samples), false);
295    if let Some(s) = seed {
296        sampler = sampler.with_generator(s);
297    }
298    sampler
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// A sequential sampler reports its size and yields 0..n in order.
    #[test]
    fn test_sequential_sampler() {
        let seq = SequentialSampler::new(5);
        assert_eq!(seq.dataset_size(), 5);
        assert_eq!(seq.len(), 5);
        assert!(!seq.is_empty());

        let produced: Vec<usize> = seq.iter().collect();
        assert_eq!(produced, vec![0, 1, 2, 3, 4]);
    }

    /// Zero-size datasets are allowed and produce no indices.
    #[test]
    fn test_sequential_sampler_zero_size() {
        let empty = SequentialSampler::new(0);
        assert_eq!(empty.dataset_size(), 0);

        let produced: Vec<usize> = empty.iter().collect();
        assert!(produced.is_empty());
    }

    /// A full pass without replacement is a permutation of 0..n.
    #[test]
    fn test_random_sampler_all_indices() {
        let shuffler = RandomSampler::new(5, None, false).with_generator(42);
        assert_eq!(shuffler.dataset_size(), 5);
        assert_eq!(shuffler.num_samples(), 5);
        assert_eq!(shuffler.len(), 5);
        assert_eq!(shuffler.generator(), Some(42));
        assert!(!shuffler.replacement());

        let mut produced: Vec<usize> = shuffler.iter().collect();
        assert_eq!(produced.len(), 5);

        // Order is random, so compare after sorting.
        produced.sort();
        assert_eq!(produced, vec![0, 1, 2, 3, 4]);
    }

    /// A partial pass yields unique, in-range indices.
    #[test]
    fn test_random_sampler_subset() {
        let subset = RandomSampler::new(10, Some(3), false).with_generator(42);
        assert_eq!(subset.num_samples(), 3);
        assert_eq!(subset.len(), 3);

        let produced: Vec<usize> = subset.iter().collect();
        assert_eq!(produced.len(), 3);

        // Duplicates would collapse under sort + dedup.
        let mut distinct = produced.clone();
        distinct.sort();
        distinct.dedup();
        assert_eq!(distinct.len(), 3);

        assert!(produced.iter().all(|&idx| idx < 10));
    }

    /// Sampling with replacement may exceed the dataset size but stays in range.
    #[test]
    fn test_random_sampler_with_replacement() {
        let oversampler = RandomSampler::new(3, Some(10), true).with_generator(42);
        assert!(oversampler.replacement());
        assert_eq!(oversampler.num_samples(), 10);
        assert_eq!(oversampler.len(), 10);

        let produced: Vec<usize> = oversampler.iter().collect();
        assert_eq!(produced.len(), 10);

        // Repeats are fine; out-of-range values are not.
        assert!(produced.iter().all(|&idx| idx < 3));
    }

    /// Asking for more samples than exist, without replacement, must panic.
    #[test]
    #[should_panic(expected = "Invalid sampling parameters")]
    fn test_random_sampler_invalid_no_replacement() {
        let _ = RandomSampler::new(5, Some(10), false);
    }

    /// Two samplers with the same seed yield the same sequence.
    #[test]
    fn test_random_sampler_reproducible() {
        let draw = || -> Vec<usize> {
            RandomSampler::new(10, Some(5), false)
                .with_generator(42)
                .iter()
                .collect()
        };

        let first = draw();
        let second = draw();

        assert_eq!(first, second);
    }

    /// The free-function constructors configure the expected samplers.
    #[test]
    fn test_convenience_functions() {
        let seq = sequential(5);
        assert_eq!(seq.len(), 5);

        let shuffled = random(5, Some(42));
        assert_eq!(shuffled.len(), 5);
        assert!(!shuffled.replacement());

        let oversampled = random_with_replacement(3, 10, Some(42));
        assert_eq!(oversampled.len(), 10);
        assert!(oversampled.replacement());

        let partial = random_subset(10, 3, Some(42));
        assert_eq!(partial.len(), 3);
        assert!(!partial.replacement());
    }

    /// Cloning preserves every observable setting.
    #[test]
    fn test_random_sampler_clone() {
        let original = RandomSampler::new(5, Some(3), false).with_generator(42);
        let copy = original.clone();

        assert_eq!(original.len(), copy.len());
        assert_eq!(original.dataset_size(), copy.dataset_size());
        assert_eq!(original.replacement(), copy.replacement());
        assert_eq!(original.generator(), copy.generator());
    }

    /// Single-element datasets and zero-sample requests behave sensibly.
    #[test]
    fn test_edge_cases() {
        // Single element dataset
        let one_seq: Vec<usize> = SequentialSampler::new(1).iter().collect();
        assert_eq!(one_seq, vec![0]);

        let one_rand: Vec<usize> = RandomSampler::new(1, None, false).iter().collect();
        assert_eq!(one_rand, vec![0]);

        // Sample 0 items (should be allowed for with replacement)
        let nothing = RandomSampler::new(5, Some(0), true);
        assert_eq!(nothing.len(), 0);
        assert!(nothing.is_empty());

        let produced: Vec<usize> = nothing.iter().collect();
        assert!(produced.is_empty());
    }

    /// The iterator reports an exact, shrinking length.
    #[test]
    fn test_iterator_properties() {
        let sampler = RandomSampler::new(5, Some(3), false).with_generator(42);
        let mut indices = sampler.iter();

        // Test size hint
        assert_eq!(indices.size_hint(), (3, Some(3)));

        // Test exact size
        assert_eq!(indices.len(), 3);

        // Consume one item
        let _ = indices.next();
        assert_eq!(indices.len(), 2);
    }
}