1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use scirs2_core::random::{Random, Rng};
11
12use super::core::{Sampler, SamplerIterator};
13
/// Sampler that yields every index of a dataset in ascending order.
///
/// For a dataset of `n` elements, one pass produces `0, 1, ..., n - 1`,
/// each exactly once.
// `PartialEq`/`Eq` let callers compare sampler configurations; the only field
// is an integer, so the derived structural comparison is exact.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequentialSampler {
    /// Number of elements in the dataset; also the number of indices yielded.
    dataset_size: usize,
}
32
33impl SequentialSampler {
34 pub fn new(dataset_size: usize) -> Self {
40 Self { dataset_size }
41 }
42
43 pub fn dataset_size(&self) -> usize {
45 self.dataset_size
46 }
47}
48
49impl Sampler for SequentialSampler {
50 type Iter = SamplerIterator;
51
52 fn iter(&self) -> Self::Iter {
53 SamplerIterator::from_range(0, self.dataset_size)
54 }
55
56 fn len(&self) -> usize {
57 self.dataset_size
58 }
59}
60
/// Sampler that yields dataset indices in random order.
///
/// Without replacement, each index appears at most once per pass. With
/// replacement, indices are drawn independently and may repeat. An optional
/// seed (see [`RandomSampler::with_generator`]) makes the order reproducible.
// Derived `PartialEq`/`Eq` compare the stored configuration structurally: a
// sampler built with `num_samples: None` is NOT equal to one built with
// `Some(dataset_size)`, even though both yield the same number of indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomSampler {
    /// Number of elements in the dataset; indices are drawn from `0..dataset_size`.
    dataset_size: usize,
    /// Requested number of indices per pass; `None` means `dataset_size`.
    num_samples: Option<usize>,
    /// Whether the same index may be drawn more than once.
    replacement: bool,
    /// Seed for the random number generator, if one was set.
    generator: Option<u64>,
}
89
90impl RandomSampler {
91 pub fn new(dataset_size: usize, num_samples: Option<usize>, replacement: bool) -> Self {
104 let actual_num_samples = num_samples.unwrap_or(dataset_size);
105
106 super::core::utils::validate_sampling_params(
107 dataset_size,
108 Some(actual_num_samples),
109 replacement,
110 )
111 .expect("Invalid sampling parameters");
112
113 Self {
114 dataset_size,
115 num_samples,
116 replacement,
117 generator: None,
118 }
119 }
120
121 pub fn simple(dataset_size: usize) -> Self {
131 Self::new(dataset_size, None, false)
132 }
133
134 pub fn with_replacement(
146 dataset_size: usize,
147 replacement: bool,
148 num_samples: Option<usize>,
149 ) -> Self {
150 Self::new(dataset_size, num_samples, replacement)
151 }
152
153 pub fn with_generator(mut self, seed: u64) -> Self {
159 self.generator = Some(seed);
160 self
161 }
162
163 pub fn dataset_size(&self) -> usize {
165 self.dataset_size
166 }
167
168 pub fn num_samples(&self) -> usize {
170 self.num_samples.unwrap_or(self.dataset_size)
171 }
172
173 pub fn replacement(&self) -> bool {
175 self.replacement
176 }
177
178 pub fn generator(&self) -> Option<u64> {
180 self.generator
181 }
182}
183
184impl Sampler for RandomSampler {
185 type Iter = SamplerIterator;
186
187 fn iter(&self) -> Self::Iter {
188 let num_samples = self.num_samples();
189
190 if self.replacement {
191 self.iter_with_replacement(num_samples)
192 } else {
193 self.iter_without_replacement(num_samples)
194 }
195 }
196
197 fn len(&self) -> usize {
198 self.num_samples()
199 }
200}
201
202impl RandomSampler {
203 fn iter_with_replacement(&self, num_samples: usize) -> SamplerIterator {
205 let mut rng = match self.generator {
207 Some(seed) => Random::seed(seed),
208 None => Random::seed(42),
209 };
210
211 let indices: Vec<usize> = (0..num_samples)
212 .map(|_| rng.gen_range(0..self.dataset_size))
213 .collect();
214
215 SamplerIterator::new(indices)
216 }
217
218 fn iter_without_replacement(&self, num_samples: usize) -> SamplerIterator {
220 if num_samples == self.dataset_size {
221 let indices: Vec<usize> = (0..self.dataset_size).collect();
223 SamplerIterator::shuffled(indices, self.generator)
224 } else {
225 let indices =
227 super::core::utils::random_indices(self.dataset_size, num_samples, self.generator);
228 SamplerIterator::new(indices)
229 }
230 }
231}
232
233pub fn sequential(dataset_size: usize) -> SequentialSampler {
241 SequentialSampler::new(dataset_size)
242}
243
244pub fn random(dataset_size: usize, seed: Option<u64>) -> RandomSampler {
254 let mut sampler = RandomSampler::new(dataset_size, None, false);
255 if let Some(s) = seed {
256 sampler = sampler.with_generator(s);
257 }
258 sampler
259}
260
261pub fn random_with_replacement(
272 dataset_size: usize,
273 num_samples: usize,
274 seed: Option<u64>,
275) -> RandomSampler {
276 let mut sampler = RandomSampler::new(dataset_size, Some(num_samples), true);
277 if let Some(s) = seed {
278 sampler = sampler.with_generator(s);
279 }
280 sampler
281}
282
283pub fn random_subset(dataset_size: usize, num_samples: usize, seed: Option<u64>) -> RandomSampler {
294 let mut sampler = RandomSampler::new(dataset_size, Some(num_samples), false);
295 if let Some(s) = seed {
296 sampler = sampler.with_generator(s);
297 }
298 sampler
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sequential_sampler() {
        let sampler = SequentialSampler::new(5);
        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.dataset_size(), 5);
        assert!(!sampler.is_empty());

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_sequential_sampler_zero_size() {
        let sampler = SequentialSampler::new(0);
        assert_eq!(sampler.len(), 0);
        assert!(sampler.is_empty());
        assert_eq!(sampler.dataset_size(), 0);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 0);
    }

    #[test]
    fn test_random_sampler_all_indices() {
        let sampler = RandomSampler::new(5, None, false).with_generator(42);
        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.dataset_size(), 5);
        assert_eq!(sampler.num_samples(), 5);
        assert!(!sampler.replacement());
        assert_eq!(sampler.generator(), Some(42));

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 5);

        // A full pass without replacement must be a permutation of 0..5.
        let mut sorted_indices = indices.clone();
        sorted_indices.sort();
        assert_eq!(sorted_indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_random_sampler_subset() {
        let sampler = RandomSampler::new(10, Some(3), false).with_generator(42);
        assert_eq!(sampler.len(), 3);
        assert_eq!(sampler.num_samples(), 3);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 3);

        // Without replacement every drawn index must be unique.
        let mut unique_indices = indices.clone();
        unique_indices.sort();
        unique_indices.dedup();
        assert_eq!(unique_indices.len(), 3);

        for &idx in &indices {
            assert!(idx < 10);
        }
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let sampler = RandomSampler::new(3, Some(10), true).with_generator(42);
        assert_eq!(sampler.len(), 10);
        assert_eq!(sampler.num_samples(), 10);
        assert!(sampler.replacement());

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 10);

        for &idx in &indices {
            // Indices may repeat but must stay within the dataset.
            assert!(idx < 3);
        }
    }

    #[test]
    fn test_random_sampler_with_replacement_reproducible() {
        let sampler1 = RandomSampler::new(4, Some(8), true).with_generator(7);
        let sampler2 = RandomSampler::new(4, Some(8), true).with_generator(7);

        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();

        assert_eq!(indices1.len(), 8);
        assert_eq!(indices1, indices2);
    }

    #[test]
    #[should_panic(expected = "Invalid sampling parameters")]
    fn test_random_sampler_invalid_no_replacement() {
        RandomSampler::new(5, Some(10), false);
    }

    #[test]
    fn test_random_sampler_reproducible() {
        let sampler1 = RandomSampler::new(10, Some(5), false).with_generator(42);
        let sampler2 = RandomSampler::new(10, Some(5), false).with_generator(42);

        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();

        assert_eq!(indices1, indices2);
    }

    #[test]
    fn test_alternate_constructors() {
        let simple = RandomSampler::simple(4);
        assert_eq!(simple.len(), 4);
        assert!(!simple.replacement());
        assert_eq!(simple.generator(), None);

        let mut all: Vec<usize> = simple.iter().collect();
        all.sort();
        assert_eq!(all, vec![0, 1, 2, 3]);

        // Argument order here is (dataset_size, replacement, num_samples).
        let repl = RandomSampler::with_replacement(3, true, Some(7)).with_generator(1);
        assert_eq!(repl.len(), 7);
        assert_eq!(repl.num_samples(), 7);
        assert!(repl.replacement());
        assert!(repl.iter().all(|idx| idx < 3));
    }

    #[test]
    fn test_convenience_functions() {
        let seq = sequential(5);
        assert_eq!(seq.len(), 5);

        let rand = random(5, Some(42));
        assert_eq!(rand.len(), 5);
        assert!(!rand.replacement());

        let rand_repl = random_with_replacement(3, 10, Some(42));
        assert_eq!(rand_repl.len(), 10);
        assert!(rand_repl.replacement());

        let subset = random_subset(10, 3, Some(42));
        assert_eq!(subset.len(), 3);
        assert!(!subset.replacement());
    }

    #[test]
    fn test_convenience_functions_seed() {
        assert_eq!(random(5, None).generator(), None);
        assert_eq!(random(5, Some(9)).generator(), Some(9));
        assert_eq!(random_with_replacement(3, 4, Some(11)).generator(), Some(11));
        assert_eq!(random_subset(10, 2, None).generator(), None);
    }

    #[test]
    fn test_random_sampler_clone() {
        let sampler = RandomSampler::new(5, Some(3), false).with_generator(42);
        let cloned = sampler.clone();

        assert_eq!(sampler.len(), cloned.len());
        assert_eq!(sampler.dataset_size(), cloned.dataset_size());
        assert_eq!(sampler.replacement(), cloned.replacement());
        assert_eq!(sampler.generator(), cloned.generator());
    }

    #[test]
    fn test_edge_cases() {
        // Single-element dataset: both samplers can only yield index 0.
        let seq = SequentialSampler::new(1);
        let indices: Vec<usize> = seq.iter().collect();
        assert_eq!(indices, vec![0]);

        let rand = RandomSampler::new(1, None, false);
        let indices: Vec<usize> = rand.iter().collect();
        assert_eq!(indices, vec![0]);

        // Requesting zero samples yields an empty pass.
        let rand_zero = RandomSampler::new(5, Some(0), true);
        assert_eq!(rand_zero.len(), 0);
        assert!(rand_zero.is_empty());

        let indices: Vec<usize> = rand_zero.iter().collect();
        assert_eq!(indices.len(), 0);
    }

    #[test]
    fn test_iterator_properties() {
        let sampler = RandomSampler::new(5, Some(3), false).with_generator(42);
        let mut iter = sampler.iter();

        // The iterator reports an exact size...
        assert_eq!(iter.size_hint(), (3, Some(3)));

        assert_eq!(iter.len(), 3);

        // ...which shrinks as elements are consumed.
        iter.next();
        assert_eq!(iter.len(), 2);
    }
}