1#[cfg(not(feature = "std"))]
6use alloc::{boxed::Box, vec::Vec};
7
8use scirs2_core::random::Random;
10
11pub(crate) mod rng_utils {
13 use super::*;
14
15 pub fn create_rng(seed: Option<u64>) -> Random<scirs2_core::rngs::StdRng> {
17 if let Some(seed) = seed {
18 Random::seed(seed)
19 } else {
20 Random::seed(42) }
22 }
23
24 pub fn shuffle_indices<T: Clone>(indices: &mut [T], seed: Option<u64>) {
26 let mut rng = create_rng(seed);
27
28 for i in (1..indices.len()).rev() {
30 let j = rng.gen_range(0..=i);
31 indices.swap(i, j);
32 }
33 }
34
35 pub fn gen_range(
37 rng: &mut Random<scirs2_core::rngs::StdRng>,
38 range: std::ops::Range<usize>,
39 ) -> usize {
40 rng.gen_range(range)
41 }
42}
43
44pub trait Sampler: Send {
49 type Iter: Iterator<Item = usize> + Send;
51
52 fn iter(&self) -> Self::Iter;
54
55 fn len(&self) -> usize;
57
58 fn is_empty(&self) -> bool {
60 self.len() == 0
61 }
62
63 fn into_batch_sampler(
65 self,
66 batch_size: usize,
67 drop_last: bool,
68 ) -> super::batch::BatchingSampler<Self>
69 where
70 Self: Sized,
71 {
72 super::batch::BatchingSampler::new(self, batch_size, drop_last)
73 }
74
75 fn into_distributed(
77 self,
78 num_replicas: usize,
79 rank: usize,
80 ) -> super::distributed::DistributedWrapper<Self>
81 where
82 Self: Sized,
83 {
84 super::distributed::DistributedWrapper::new(self, num_replicas, rank)
85 }
86}
87
/// A source of index batches.
///
/// Implementors hand out `Vec<usize>` batches through [`BatchSampler::iter`]
/// and report the batch count through [`BatchSampler::num_batches`]. Both the
/// sampler and its iterator type are required to be `Send`.
pub trait BatchSampler: Send {
    /// Iterator type returned by [`BatchSampler::iter`].
    type Iter: Iterator<Item = Vec<usize>> + Send;

    /// Returns an iterator over the batches.
    fn iter(&self) -> Self::Iter;

    /// Number of batches this sampler reports it yields.
    fn num_batches(&self) -> usize;

    /// Same as [`BatchSampler::num_batches`].
    fn len(&self) -> usize {
        self.num_batches()
    }

    /// Returns `true` when [`BatchSampler::num_batches`] reports zero batches.
    ///
    /// This consults `num_batches` directly, not `len`, so overriding `len`
    /// does not change the result.
    fn is_empty(&self) -> bool {
        let batches = self.num_batches();
        batches == 0
    }
}
109
/// Iterator over a pre-computed list of dataset indices.
///
/// Yields the entries of `indices` in order, starting at `position`, and
/// reports an exact remaining length.
#[derive(Debug, Clone)]
pub struct SamplerIterator {
    /// Indices to yield, in yield order.
    indices: Vec<usize>,
    /// Offset within `indices` of the next index to yield.
    position: usize,
}
115
116impl SamplerIterator {
117 pub fn new(indices: Vec<usize>) -> Self {
119 Self {
120 indices,
121 position: 0,
122 }
123 }
124
125 pub fn from_range(start: usize, end: usize) -> Self {
127 Self::new((start..end).collect())
128 }
129
130 pub fn shuffled(mut indices: Vec<usize>, seed: Option<u64>) -> Self {
132 let mut rng = match seed {
134 Some(s) => Random::seed(s),
135 None => Random::seed(42), };
137
138 for i in (1..indices.len()).rev() {
140 let j = rng.gen_range(0..=i);
141 indices.swap(i, j);
142 }
143
144 Self::new(indices)
145 }
146
147 pub fn remaining(&self) -> usize {
149 self.indices.len() - self.position
150 }
151}
152
153impl Iterator for SamplerIterator {
154 type Item = usize;
155
156 fn next(&mut self) -> Option<Self::Item> {
157 if self.position < self.indices.len() {
158 let item = self.indices[self.position];
159 self.position += 1;
160 Some(item)
161 } else {
162 None
163 }
164 }
165
166 fn size_hint(&self) -> (usize, Option<usize>) {
167 let remaining = self.remaining();
168 (remaining, Some(remaining))
169 }
170}
171
172impl ExactSizeIterator for SamplerIterator {
173 fn len(&self) -> usize {
174 self.remaining()
175 }
176}
177
178pub mod utils {
180 use super::*;
181
182 pub fn random_indices(n: usize, k: usize, seed: Option<u64>) -> Vec<usize> {
184 assert!(k <= n, "Cannot sample more items than available");
185
186 let mut rng = match seed {
188 Some(s) => Random::seed(s),
189 None => Random::seed(42),
190 };
191
192 if k == n {
193 let mut indices: Vec<usize> = (0..n).collect();
195 for i in (1..indices.len()).rev() {
196 let j = rng.gen_range(0..=i);
197 indices.swap(i, j);
198 }
199 indices
200 } else if k <= n / 2 {
201 let mut selected = std::collections::HashSet::new();
203 while selected.len() < k {
204 let idx = rng.gen_range(0..n);
205 selected.insert(idx);
206 }
207 let mut result: Vec<usize> = selected.into_iter().collect();
208 result.sort_unstable(); result
210 } else {
211 let mut excluded = std::collections::HashSet::new();
213 while excluded.len() < n - k {
214 let idx = rng.gen_range(0..n);
215 excluded.insert(idx);
216 }
217 let mut result: Vec<usize> = (0..n).filter(|&i| !excluded.contains(&i)).collect();
218 result.sort_unstable(); result
220 }
221 }
222
223 pub fn stratified_split(
225 indices: &[usize],
226 labels: &[usize],
227 test_ratio: f32,
228 seed: Option<u64>,
229 ) -> (Vec<usize>, Vec<usize>) {
230 use std::collections::HashMap;
231
232 let mut label_groups: HashMap<usize, Vec<usize>> = HashMap::new();
234 for &idx in indices {
235 if idx < labels.len() {
236 label_groups
237 .entry(labels[idx])
238 .or_insert_with(Vec::new)
239 .push(idx);
240 }
241 }
242
243 let mut rng = match seed {
245 Some(s) => Random::seed(s),
246 None => Random::seed(42),
247 };
248
249 let mut train_indices = Vec::new();
250 let mut test_indices = Vec::new();
251
252 for (_, mut group_indices) in label_groups {
254 for i in (1..group_indices.len()).rev() {
256 let j = rng.gen_range(0..=i);
257 group_indices.swap(i, j);
258 }
259
260 let test_size = ((group_indices.len() as f32) * test_ratio).round() as usize;
261 let test_size = test_size.min(group_indices.len());
262
263 test_indices.extend(group_indices.iter().take(test_size));
264 train_indices.extend(group_indices.iter().skip(test_size));
265 }
266
267 (train_indices, test_indices)
268 }
269
270 pub fn calculate_class_weights(labels: &[usize], num_classes: usize) -> Vec<f32> {
272 let mut class_counts = vec![0usize; num_classes];
273
274 for &label in labels {
276 if label < num_classes {
277 class_counts[label] += 1;
278 }
279 }
280
281 let total_samples = labels.len() as f32;
283 class_counts
284 .iter()
285 .map(|&count| {
286 if count > 0 {
287 total_samples / (num_classes as f32 * count as f32)
288 } else {
289 0.0
290 }
291 })
292 .collect()
293 }
294
295 pub fn validate_sampling_params(
297 dataset_size: usize,
298 num_samples: Option<usize>,
299 replacement: bool,
300 ) -> Result<usize, String> {
301 let actual_num_samples = num_samples.unwrap_or(dataset_size);
302
303 if dataset_size == 0 {
305 if actual_num_samples == 0 {
306 return Ok(0);
307 } else {
308 return Err("Cannot sample from empty dataset".to_string());
309 }
310 }
311
312 if !replacement && actual_num_samples > dataset_size {
313 return Err(format!(
314 "Cannot sample {} items without replacement from dataset of size {}",
315 actual_num_samples, dataset_size
316 ));
317 }
318
319 if actual_num_samples == 0 && !replacement {
320 return Err(
321 "Number of samples cannot be zero for non-empty dataset without replacement"
322 .to_string(),
323 );
324 }
325
326 Ok(actual_num_samples)
327 }
328
329 pub fn train_val_split(
331 dataset_size: usize,
332 val_ratio: f32,
333 seed: Option<u64>,
334 ) -> (Vec<usize>, Vec<usize>) {
335 let val_size = (dataset_size as f32 * val_ratio).round() as usize;
336 let indices = random_indices(dataset_size, dataset_size, seed);
337
338 let (val_indices, train_indices) = indices.split_at(val_size);
339 (train_indices.to_vec(), val_indices.to_vec())
340 }
341
342 pub fn kfold_splits(
344 dataset_size: usize,
345 k: usize,
346 seed: Option<u64>,
347 ) -> Vec<(Vec<usize>, Vec<usize>)> {
348 assert!(k > 1, "K must be greater than 1");
349 assert!(k <= dataset_size, "K cannot be larger than dataset size");
350
351 let indices = random_indices(dataset_size, dataset_size, seed);
352 let fold_size = dataset_size / k;
353 let mut splits = Vec::new();
354
355 for i in 0..k {
356 let start = i * fold_size;
357 let end = if i == k - 1 {
358 dataset_size } else {
360 (i + 1) * fold_size
361 };
362
363 let val_indices = indices[start..end].to_vec();
364 let train_indices = [&indices[..start], &indices[end..]].concat();
365 splits.push((train_indices, val_indices));
366 }
367
368 splits
369 }
370
371 pub fn train_val_test_split(
373 dataset_size: usize,
374 train_ratio: f32,
375 val_ratio: f32,
376 seed: Option<u64>,
377 ) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
378 assert!(
379 train_ratio + val_ratio < 1.0,
380 "Train and val ratios must sum to less than 1.0"
381 );
382 assert!(
383 train_ratio > 0.0 && val_ratio > 0.0,
384 "Ratios must be positive"
385 );
386
387 let train_size = (dataset_size as f32 * train_ratio).round() as usize;
388 let val_size = (dataset_size as f32 * val_ratio).round() as usize;
389 let _test_size = dataset_size - train_size - val_size;
390
391 let indices = random_indices(dataset_size, dataset_size, seed);
392
393 let train_indices = indices[..train_size].to_vec();
394 let val_indices = indices[train_size..train_size + val_size].to_vec();
395 let test_indices = indices[train_size + val_size..].to_vec();
396
397 (train_indices, val_indices, test_indices)
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Sorts a copy of `v` so set-like comparisons ignore order.
    fn sorted(mut v: Vec<usize>) -> Vec<usize> {
        v.sort_unstable();
        v
    }

    #[test]
    fn test_sampler_iterator_basic() {
        let indices = vec![0, 1, 2, 3, 4];
        let iter = SamplerIterator::new(indices.clone());

        assert_eq!(iter.len(), 5);
        assert_eq!(iter.remaining(), 5);

        let collected: Vec<usize> = iter.collect();
        assert_eq!(collected, indices);
    }

    #[test]
    fn test_sampler_iterator_from_range() {
        let collected: Vec<usize> = SamplerIterator::from_range(0, 5).collect();
        assert_eq!(collected, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_sampler_iterator_from_empty_range() {
        assert_eq!(SamplerIterator::from_range(5, 5).count(), 0);
        assert_eq!(SamplerIterator::from_range(7, 3).count(), 0);
    }

    #[test]
    fn test_sampler_iterator_shuffled() {
        let indices = vec![0, 1, 2, 3, 4];
        let collected: Vec<usize> = SamplerIterator::shuffled(indices.clone(), Some(42)).collect();

        // A shuffle must be a permutation: same multiset of indices.
        assert_eq!(sorted(collected), indices);
    }

    #[test]
    fn test_sampler_iterator_shuffled_is_deterministic() {
        let a: Vec<usize> = SamplerIterator::shuffled((0..20).collect(), Some(7)).collect();
        let b: Vec<usize> = SamplerIterator::shuffled((0..20).collect(), Some(7)).collect();
        assert_eq!(a, b);
    }

    #[test]
    fn test_utils_random_indices() {
        let indices = utils::random_indices(10, 5, Some(42));
        assert_eq!(indices.len(), 5);

        // All distinct.
        let mut deduped = sorted(indices.clone());
        deduped.dedup();
        assert_eq!(deduped.len(), 5);

        assert!(indices.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_utils_random_indices_most() {
        // k > n / 2 exercises the exclusion branch.
        let indices = utils::random_indices(10, 8, Some(42));
        assert_eq!(indices.len(), 8);
        assert!(indices.windows(2).all(|w| w[0] < w[1]));
        assert!(indices.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_utils_random_indices_all() {
        let indices = utils::random_indices(5, 5, Some(42));
        assert_eq!(indices.len(), 5);
        assert_eq!(sorted(indices), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_utils_stratified_split() {
        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let labels = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];

        let (train, test) = utils::stratified_split(&indices, &labels, 0.3, Some(42));

        assert_eq!(train.len() + test.len(), indices.len());
        // Groups of size 4, 4 and 2 each contribute round(len * 0.3) = 1.
        assert_eq!(test.len(), 3);

        let mut all_indices = train.clone();
        all_indices.extend(test.iter().copied());
        assert_eq!(sorted(all_indices), indices);
    }

    #[test]
    fn test_utils_calculate_class_weights() {
        let labels = vec![0, 0, 1, 1, 1, 2];
        let weights = utils::calculate_class_weights(&labels, 3);

        assert_eq!(weights.len(), 3);
        // Rarer classes get larger weights: class 2 (1 sample) and class 0
        // (2 samples) both outweigh class 1 (3 samples).
        assert!(weights[2] > weights[1]);
        assert!(weights[0] > weights[1]);
        // 6 / (3 * 2) == 1.0
        assert!((weights[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_utils_calculate_class_weights_absent_class() {
        let weights = utils::calculate_class_weights(&[0, 0], 2);
        assert_eq!(weights, vec![0.5, 0.0]);
    }

    #[test]
    fn test_utils_validate_sampling_params() {
        assert!(utils::validate_sampling_params(10, Some(5), false).is_ok());
        assert!(utils::validate_sampling_params(10, Some(15), true).is_ok());
        assert!(utils::validate_sampling_params(10, None, false).is_ok());

        assert!(utils::validate_sampling_params(0, Some(0), false).is_ok());
        assert!(utils::validate_sampling_params(0, None, false).is_ok());

        assert!(utils::validate_sampling_params(10, Some(0), true).is_ok());

        assert!(utils::validate_sampling_params(0, Some(5), false).is_err());
        assert!(utils::validate_sampling_params(10, Some(0), false).is_err());
        assert!(utils::validate_sampling_params(10, Some(15), false).is_err());
    }

    #[test]
    fn test_utils_train_val_split_partition() {
        let (train, val) = utils::train_val_split(10, 0.3, Some(1));
        assert_eq!(val.len(), 3);
        assert_eq!(train.len(), 7);

        let mut all = train.clone();
        all.extend(val.iter().copied());
        assert_eq!(sorted(all), (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_utils_kfold_splits_partition() {
        let splits = utils::kfold_splits(10, 3, Some(42));
        assert_eq!(splits.len(), 3);

        let mut seen: Vec<usize> = Vec::new();
        for (train, val) in &splits {
            assert_eq!(train.len() + val.len(), 10);
            assert!(val.iter().all(|i| !train.contains(i)));
            seen.extend(val.iter().copied());
        }
        // Every index lands in exactly one validation fold.
        assert_eq!(sorted(seen), (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_utils_train_val_test_split_partition() {
        let (train, val, test) = utils::train_val_test_split(10, 0.6, 0.2, Some(3));
        assert_eq!(train.len(), 6);
        assert_eq!(val.len(), 2);
        assert_eq!(test.len(), 2);

        let mut all = train.clone();
        all.extend(val.iter().copied());
        all.extend(test.iter().copied());
        assert_eq!(sorted(all), (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_size_hints() {
        let iter = SamplerIterator::new(vec![0, 1, 2]);
        assert_eq!(iter.size_hint(), (3, Some(3)));

        let mut iter = SamplerIterator::new(vec![0, 1, 2]);
        iter.next();
        assert_eq!(iter.size_hint(), (2, Some(2)));
    }
}