1use crate::error::{DatasetsError, Result};
8
/// Minimal seeded pseudo-random generator used by the samplers.
///
/// A 64-bit linear congruential generator (modulus 2^64) using Knuth's MMIX
/// multiplier and increment. It is not suitable for cryptography; it exists so
/// that sampling is reproducible from a `u64` seed without an RNG dependency.
struct Lcg64 {
    /// Current LCG state; advanced on every draw.
    state: u64,
}

impl Lcg64 {
    /// Creates a generator whose initial state is `seed + 1` (wrapping).
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator one step and returns the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Returns a pseudo-random value in `0..n`, or `0` when `n == 0`.
    ///
    /// The low bits of a power-of-two-modulus LCG have very short periods
    /// (bit 0 simply alternates on every step), so reducing the raw state with
    /// `% n` would e.g. make `next_usize(2)` return 0,1,0,1,... and give
    /// Fisher–Yates a strongly patterned shuffle. The multiply-shift reduction
    /// below maps the well-mixed *high* bits into `0..n` instead.
    fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        // (x * n) >> 64 < n for every x < 2^64, so the cast back cannot truncate.
        ((u128::from(self.next_u64()) * n as u128) >> 64) as usize
    }

    /// Returns an `f64` in `[0, 1)` built from the top 53 bits of one draw.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}
44
/// How sample indices are ordered before being split into mini-batches.
#[derive(Debug, Clone, PartialEq, Default)]
#[non_exhaustive]
pub enum SamplerStrategy {
    /// Indices in ascending order, shuffled with the config seed only when
    /// `SamplerConfig::shuffle` is `true`. This is the default strategy.
    #[default]
    Sequential,
    /// Indices always shuffled with the config seed (the `shuffle` flag is ignored).
    Random,
    /// Indices grouped by label and interleaved round-robin across classes
    /// (in ascending label order); each class is shuffled first when `shuffle` is set.
    Stratified,
    /// `n` draws with replacement, each index chosen with probability proportional
    /// to its weight. Weights beyond the dataset length are ignored; missing
    /// trailing weights are treated as `1.0`.
    WeightedRandom { weights: Vec<f64> },
}
66
67#[derive(Debug, Clone)]
69pub struct SamplerConfig {
70 pub batch_size: usize,
72 pub shuffle: bool,
74 pub drop_last: bool,
76 pub seed: u64,
78 pub strategy: SamplerStrategy,
80}
81
82impl Default for SamplerConfig {
83 fn default() -> Self {
84 Self {
85 batch_size: 32,
86 shuffle: true,
87 drop_last: false,
88 seed: 42,
89 strategy: SamplerStrategy::default(),
90 }
91 }
92}
93
/// One batch produced by the sampler.
///
/// The three vectors are parallel: position `k` in each refers to the same sample.
#[derive(Debug, Clone)]
pub struct MiniBatch {
    /// Feature rows, cloned from the input dataset.
    pub data: Vec<Vec<f64>>,
    /// Labels of the rows in `data`.
    pub labels: Vec<usize>,
    /// Positions of these rows in the original dataset.
    pub indices: Vec<usize>,
}
104
105#[derive(Debug, Clone)]
109pub struct MiniBatchSampler {
110 config: SamplerConfig,
111}
112
113impl MiniBatchSampler {
114 pub fn new(config: SamplerConfig) -> Self {
116 Self { config }
117 }
118
119 pub fn config(&self) -> &SamplerConfig {
121 &self.config
122 }
123
124 pub fn iter_batches(&self, data: &[Vec<f64>], labels: &[usize]) -> Result<Vec<MiniBatch>> {
134 iter_batches(data, labels, &self.config)
135 }
136}
137
138pub fn iter_batches(
146 data: &[Vec<f64>],
147 labels: &[usize],
148 config: &SamplerConfig,
149) -> Result<Vec<MiniBatch>> {
150 let n = data.len();
151 if n != labels.len() {
152 return Err(DatasetsError::InvalidFormat(format!(
153 "data length ({}) != labels length ({})",
154 n,
155 labels.len()
156 )));
157 }
158 if config.batch_size == 0 {
159 return Err(DatasetsError::InvalidFormat(
160 "batch_size must be >= 1".into(),
161 ));
162 }
163 if n == 0 {
164 return Ok(Vec::new());
165 }
166
167 let indices = build_index_order(n, labels, config);
168 let mut batches = Vec::new();
169 let mut offset = 0;
170
171 while offset < indices.len() {
172 let end = (offset + config.batch_size).min(indices.len());
173 let batch_indices: Vec<usize> = indices[offset..end].to_vec();
174
175 if config.drop_last && batch_indices.len() < config.batch_size {
176 break;
177 }
178
179 let batch_data: Vec<Vec<f64>> = batch_indices.iter().map(|&i| data[i].clone()).collect();
180 let batch_labels: Vec<usize> = batch_indices.iter().map(|&i| labels[i]).collect();
181
182 batches.push(MiniBatch {
183 data: batch_data,
184 labels: batch_labels,
185 indices: batch_indices,
186 });
187
188 offset = end;
189 }
190
191 Ok(batches)
192}
193
194fn build_index_order(n: usize, labels: &[usize], config: &SamplerConfig) -> Vec<usize> {
200 match &config.strategy {
201 SamplerStrategy::Sequential => {
202 let mut indices: Vec<usize> = (0..n).collect();
203 if config.shuffle {
204 fisher_yates_shuffle(&mut indices, config.seed);
205 }
206 indices
207 }
208
209 SamplerStrategy::Random => {
210 let mut indices: Vec<usize> = (0..n).collect();
211 fisher_yates_shuffle(&mut indices, config.seed);
212 indices
213 }
214
215 SamplerStrategy::Stratified => build_stratified_order(n, labels, config),
216
217 SamplerStrategy::WeightedRandom { weights } => build_weighted_order(n, weights, config),
218 }
219}
220
221fn fisher_yates_shuffle(indices: &mut [usize], seed: u64) {
223 let n = indices.len();
224 if n <= 1 {
225 return;
226 }
227 let mut rng = Lcg64::new(seed);
228 for i in (1..n).rev() {
229 let j = rng.next_usize(i + 1);
230 indices.swap(i, j);
231 }
232}
233
234fn build_stratified_order(n: usize, labels: &[usize], config: &SamplerConfig) -> Vec<usize> {
237 if n == 0 {
238 return Vec::new();
239 }
240
241 let max_class = labels.iter().copied().max().unwrap_or(0);
243 let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); max_class + 1];
244 for (i, &label) in labels.iter().enumerate() {
245 class_indices[label].push(i);
246 }
247
248 if config.shuffle {
250 for (cls, indices) in class_indices.iter_mut().enumerate() {
251 let class_seed = config.seed.wrapping_add(cls as u64 * 0x9e37_79b9_7f4a_7c15);
252 fisher_yates_shuffle(indices, class_seed);
253 }
254 }
255
256 let mut result = Vec::with_capacity(n);
259 let mut cursors: Vec<usize> = vec![0; class_indices.len()];
260 let mut remaining = n;
261
262 while remaining > 0 {
263 let mut added = false;
264 for (cls, indices) in class_indices.iter().enumerate() {
265 if cursors[cls] < indices.len() {
266 result.push(indices[cursors[cls]]);
267 cursors[cls] += 1;
268 remaining -= 1;
269 added = true;
270 if remaining == 0 {
271 break;
272 }
273 }
274 }
275 if !added {
276 break;
277 }
278 }
279
280 result
281}
282
283fn build_weighted_order(n: usize, weights: &[f64], config: &SamplerConfig) -> Vec<usize> {
285 if n == 0 || weights.is_empty() {
286 return Vec::new();
287 }
288
289 let mut rng = Lcg64::new(config.seed);
290 let actual_weights: Vec<f64> = if weights.len() >= n {
291 weights[..n].to_vec()
292 } else {
293 let mut w = weights.to_vec();
295 w.resize(n, 1.0);
296 w
297 };
298
299 let total: f64 = actual_weights.iter().sum();
301 if total <= 0.0 {
302 return (0..n).collect();
304 }
305 let cumulative: Vec<f64> = actual_weights
306 .iter()
307 .scan(0.0, |acc, &w| {
308 *acc += w / total;
309 Some(*acc)
310 })
311 .collect();
312
313 (0..n)
315 .map(|_| {
316 let u = rng.next_f64();
317 match cumulative.binary_search_by(|probe| {
319 probe.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal)
320 }) {
321 Ok(idx) => idx.min(n - 1),
322 Err(idx) => idx.min(n - 1),
323 }
324 })
325 .collect()
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` rows of `n_features` values each. Row `i` holds
    /// `i * n_features + j` for `j in 0..n_features`, so every row is distinct.
    fn make_simple_data(n: usize, n_features: usize) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| {
                (0..n_features)
                    .map(|j| (i * n_features + j) as f64)
                    .collect()
            })
            .collect()
    }

    /// 100 samples in batches of 32 give 3 full batches plus a final batch of 4.
    #[test]
    fn test_sequential_batches_correct_size() {
        let data = make_simple_data(100, 5);
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let config = SamplerConfig {
            batch_size: 32,
            shuffle: false,
            drop_last: false,
            seed: 42,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].data.len(), 32);
        // 100 - 3 * 32 = 4 samples left for the last batch.
        assert_eq!(batches[3].data.len(), 4);
    }

    /// 50 samples in batches of 16: the trailing batch of 2 is dropped.
    #[test]
    fn test_drop_last() {
        let data = make_simple_data(50, 3);
        let labels: Vec<usize> = vec![0; 50];
        let config = SamplerConfig {
            batch_size: 16,
            shuffle: false,
            drop_last: true,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        assert_eq!(batches.len(), 3);
        // Every surviving batch is full-sized.
        for b in &batches {
            assert_eq!(b.data.len(), 16);
        }
    }

    /// The Random strategy yields a permutation of 0..20 that is not the identity.
    #[test]
    fn test_random_shuffles_indices() {
        let data = make_simple_data(20, 2);
        let labels: Vec<usize> = vec![0; 20];
        let config = SamplerConfig {
            batch_size: 20,
            shuffle: true,
            drop_last: false,
            seed: 99,
            strategy: SamplerStrategy::Random,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        assert_eq!(batches.len(), 1);
        // Sorting recovers 0..20, so no index is lost or duplicated.
        let mut sorted = batches[0].indices.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..20).collect::<Vec<_>>());
        // The unsorted order must differ from the identity permutation.
        assert_ne!(batches[0].indices, (0..20).collect::<Vec<_>>());
    }

    /// 60 samples of class 0 and 40 of class 1, interleaved round-robin:
    /// the first 80 positions alternate classes and the last 20 are class 0 only,
    /// so 4 of the 5 batches contain both classes.
    #[test]
    fn test_stratified_label_proportions() {
        let n = 100;
        // Labels are sorted (all 0s, then all 1s) so interleaving is observable.
        let mut labels: Vec<usize> = vec![0; 60];
        labels.extend(vec![1; 40]);
        let data = make_simple_data(n, 2);

        let config = SamplerConfig {
            batch_size: 20,
            shuffle: false,
            drop_last: false,
            seed: 42,
            strategy: SamplerStrategy::Stratified,
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        assert_eq!(batches.len(), 5);
        // Every sample appears exactly once overall.
        let total_c0: usize = batches
            .iter()
            .map(|b| b.labels.iter().filter(|&&l| l == 0).count())
            .sum();
        let total_c1: usize = batches
            .iter()
            .map(|b| b.labels.iter().filter(|&&l| l == 1).count())
            .sum();
        assert_eq!(total_c0, 60);
        assert_eq!(total_c1, 40);

        // Count the batches that contain at least one sample of each class.
        let batches_with_both: usize = batches
            .iter()
            .filter(|b| {
                let c0 = b.labels.iter().filter(|&&l| l == 0).count();
                let c1 = b.labels.iter().filter(|&&l| l == 1).count();
                c0 > 0 && c1 > 0
            })
            .count();
        assert!(
            batches_with_both >= 4,
            "Expected most batches to have both classes, got {batches_with_both}"
        );
    }

    /// With all weight on index 0, every draw must select index 0.
    #[test]
    fn test_weighted_sampling() {
        let n = 50;
        let data = make_simple_data(n, 2);
        let labels: Vec<usize> = vec![0; n];
        let mut weights = vec![0.0; n];
        weights[0] = 1.0;

        let config = SamplerConfig {
            batch_size: 10,
            shuffle: false,
            drop_last: false,
            seed: 42,
            strategy: SamplerStrategy::WeightedRandom { weights },
        };
        let batches = iter_batches(&data, &labels, &config).expect("should succeed");
        for batch in &batches {
            // Zero-weight indices must never be drawn.
            for &idx in &batch.indices {
                assert_eq!(idx, 0, "All indices should be 0 with weight=[1,0,0,...]");
            }
        }
    }

    /// Two runs with an identical config (same seed) produce identical batches.
    #[test]
    fn test_reproducibility_same_seed() {
        let data = make_simple_data(40, 3);
        let labels: Vec<usize> = (0..40).map(|i| i % 2).collect();
        let config = SamplerConfig {
            batch_size: 10,
            shuffle: true,
            drop_last: false,
            seed: 777,
            strategy: SamplerStrategy::Random,
        };
        let b1 = iter_batches(&data, &labels, &config).expect("ok");
        let b2 = iter_batches(&data, &labels, &config).expect("ok");
        assert_eq!(b1.len(), b2.len());
        for (a, b) in b1.iter().zip(b2.iter()) {
            assert_eq!(a.indices, b.indices);
        }
    }

    /// 10 data rows vs 5 labels is rejected.
    #[test]
    fn test_mismatched_lengths_error() {
        let data = make_simple_data(10, 2);
        let labels: Vec<usize> = vec![0; 5];
        let config = SamplerConfig::default();
        assert!(iter_batches(&data, &labels, &config).is_err());
    }

    /// A batch size of zero is rejected.
    #[test]
    fn test_zero_batch_size_error() {
        let data = make_simple_data(10, 2);
        let labels: Vec<usize> = vec![0; 10];
        let config = SamplerConfig {
            batch_size: 0,
            ..Default::default()
        };
        assert!(iter_batches(&data, &labels, &config).is_err());
    }

    /// An empty dataset yields no batches rather than an error.
    #[test]
    fn test_empty_dataset() {
        let data: Vec<Vec<f64>> = Vec::new();
        let labels: Vec<usize> = Vec::new();
        let config = SamplerConfig::default();
        let batches = iter_batches(&data, &labels, &config).expect("ok");
        assert!(batches.is_empty());
    }

    /// The struct wrapper delegates to `iter_batches` and exposes its config.
    #[test]
    fn test_sampler_struct() {
        let data = make_simple_data(20, 2);
        let labels: Vec<usize> = vec![0; 20];
        let sampler = MiniBatchSampler::new(SamplerConfig {
            batch_size: 5,
            shuffle: false,
            drop_last: false,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        });
        let batches = sampler.iter_batches(&data, &labels).expect("ok");
        assert_eq!(batches.len(), 4);
        assert_eq!(sampler.config().batch_size, 5);
    }

    /// Without drop_last, the batches together cover every index exactly once.
    #[test]
    fn test_all_indices_covered_sequential() {
        let n = 37;
        let data = make_simple_data(n, 2);
        let labels: Vec<usize> = vec![0; n];
        let config = SamplerConfig {
            batch_size: 10,
            shuffle: false,
            drop_last: false,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("ok");
        let mut all_indices: Vec<usize> = batches
            .iter()
            .flat_map(|b| b.indices.iter().copied())
            .collect();
        all_indices.sort_unstable();
        assert_eq!(all_indices, (0..n).collect::<Vec<_>>());
    }

    /// Each batch row and label matches the original sample at the recorded index.
    #[test]
    fn test_batch_data_matches_original() {
        let data = make_simple_data(15, 3);
        let labels: Vec<usize> = (0..15).map(|i| i % 2).collect();
        let config = SamplerConfig {
            batch_size: 5,
            shuffle: false,
            drop_last: false,
            seed: 0,
            strategy: SamplerStrategy::Sequential,
        };
        let batches = iter_batches(&data, &labels, &config).expect("ok");
        for batch in &batches {
            for (pos, &idx) in batch.indices.iter().enumerate() {
                assert_eq!(batch.data[pos], data[idx]);
                assert_eq!(batch.labels[pos], labels[idx]);
            }
        }
    }
}