1use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
7use scirs2_core::RngExt;
8use std::collections::HashMap;
9use tenflowers_core::{Device, Result, Tensor, TensorError};
10
11#[cfg(feature = "serialize")]
12use serde::{Deserialize, Serialize};
13
/// Strategy used by [`BatchSampler`] to order dataset indices when forming batches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum SamplingStrategy {
    /// Visit samples in dataset order (0, 1, 2, ...).
    Sequential,
    /// Random sampling. NOTE(review): not yet handled by `BatchSampler::new`,
    /// so it currently behaves like `Sequential` — confirm intent.
    Random,
    /// Shuffle all indices once per epoch (Fisher–Yates; reshuffled on `reset`).
    Shuffle,
    /// Stratified sampling. NOTE(review): not yet handled by `BatchSampler`.
    Stratified,
    /// Weighted sampling. NOTE(review): not yet handled by `BatchSampler`.
    Weighted,
}
29
/// How sequences in a batch are padded to a common length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum PaddingStrategy {
    /// Pad every sequence to the longest sequence in the batch.
    LongestInBatch,
    /// Pad to a fixed length (`BatchConfig::max_sequence_length`, defaulting to 512).
    FixedLength,
    /// Pad to the next multiple of `BatchConfig::max_sequence_length`
    /// (defaulting to a multiple of 8).
    NearestMultiple,
    /// Do not pad; sequences keep their original lengths.
    NoPadding,
}
43
/// How individual samples are combined into a single batch tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum CollationStrategy {
    /// Stack samples into one batch tensor.
    Stack,
    /// Concatenate samples. NOTE(review): `Collator::collate` currently falls
    /// back to the stacking path for this variant.
    Concatenate,
    /// Pad samples to a common length, then stack.
    PadAndStack,
    /// User-provided collation. NOTE(review): `Collator::collate` currently
    /// falls back to the stacking path for this variant.
    Custom,
}
57
/// Configuration for batch construction: sizing, sampling, padding and collation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct BatchConfig {
    /// Number of samples per batch (default: 32).
    pub batch_size: usize,
    /// Drop the final batch when it holds fewer than `batch_size` samples.
    pub drop_last: bool,
    /// How dataset indices are ordered / drawn.
    pub sampling_strategy: SamplingStrategy,
    /// How sequences are padded to a common length.
    pub padding_strategy: PaddingStrategy,
    /// How samples are combined into one batch tensor.
    pub collation_strategy: CollationStrategy,
    /// Length target used by `FixedLength` (as the length) and by
    /// `NearestMultiple` (as the multiple); `None` selects built-in defaults.
    pub max_sequence_length: Option<usize>,
    /// Value written into padded positions (default: 0.0).
    pub padding_value: f32,
    /// RNG seed for deterministic shuffling; `None` falls back to seed 0.
    pub seed: Option<u64>,
}
79
80impl Default for BatchConfig {
81 fn default() -> Self {
82 Self {
83 batch_size: 32,
84 drop_last: false,
85 sampling_strategy: SamplingStrategy::Sequential,
86 padding_strategy: PaddingStrategy::LongestInBatch,
87 collation_strategy: CollationStrategy::Stack,
88 max_sequence_length: None,
89 padding_value: 0.0,
90 seed: None,
91 }
92 }
93}
94
95impl BatchConfig {
96 pub fn new(batch_size: usize) -> Self {
98 Self {
99 batch_size,
100 ..Default::default()
101 }
102 }
103
104 pub fn with_drop_last(mut self, drop_last: bool) -> Self {
106 self.drop_last = drop_last;
107 self
108 }
109
110 pub fn with_sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
112 self.sampling_strategy = strategy;
113 self
114 }
115
116 pub fn with_padding_strategy(mut self, strategy: PaddingStrategy) -> Self {
118 self.padding_strategy = strategy;
119 self
120 }
121
122 pub fn with_collation_strategy(mut self, strategy: CollationStrategy) -> Self {
124 self.collation_strategy = strategy;
125 self
126 }
127
128 pub fn with_max_sequence_length(mut self, max_len: usize) -> Self {
130 self.max_sequence_length = Some(max_len);
131 self
132 }
133
134 pub fn with_padding_value(mut self, value: f32) -> Self {
136 self.padding_value = value;
137 self
138 }
139
140 pub fn with_seed(mut self, seed: u64) -> Self {
142 self.seed = Some(seed);
143 self
144 }
145}
146
147fn fisher_yates_shuffle(indices: &mut [usize], rng: &mut StdRng) {
149 let n = indices.len();
150 if n <= 1 {
151 return;
152 }
153 for i in (1..n).rev() {
154 let j = rng.random_range(0..=i);
155 indices.swap(i, j);
156 }
157}
158
/// Iterates over dataset indices in batches according to a [`BatchConfig`].
pub struct BatchSampler {
    // Total number of samples in the dataset.
    dataset_size: usize,
    // Batch sizing / sampling configuration.
    config: BatchConfig,
    // Cursor into `indices`: start position of the next batch.
    current_index: usize,
    // Epoch ordering of dataset indices (shuffled when strategy is `Shuffle`).
    indices: Vec<usize>,
    // Seeded RNG used for reshuffling between epochs.
    rng: StdRng,
}
167
168impl BatchSampler {
169 pub fn new(dataset_size: usize, config: BatchConfig) -> Self {
171 let seed = config.seed.unwrap_or(0);
172 let mut rng = StdRng::seed_from_u64(seed);
173
174 let mut indices: Vec<usize> = (0..dataset_size).collect();
175 if config.sampling_strategy == SamplingStrategy::Shuffle {
176 fisher_yates_shuffle(&mut indices, &mut rng);
177 }
178
179 Self {
180 dataset_size,
181 config,
182 current_index: 0,
183 indices,
184 rng,
185 }
186 }
187
188 pub fn next_batch(&mut self) -> Option<Vec<usize>> {
190 if self.current_index >= self.dataset_size {
191 return None;
192 }
193
194 let end_index = (self.current_index + self.config.batch_size).min(self.dataset_size);
195 let batch_indices: Vec<usize> = self.indices[self.current_index..end_index].to_vec();
196
197 self.current_index = end_index;
198
199 if self.config.drop_last && batch_indices.len() < self.config.batch_size {
201 None
202 } else {
203 Some(batch_indices)
204 }
205 }
206
207 pub fn reset(&mut self) {
209 self.current_index = 0;
210
211 if self.config.sampling_strategy == SamplingStrategy::Shuffle {
215 self.indices = (0..self.dataset_size).collect();
216 fisher_yates_shuffle(&mut self.indices, &mut self.rng);
217 }
218 }
219
220 pub fn num_batches(&self) -> usize {
222 let total = (self.dataset_size + self.config.batch_size - 1) / self.config.batch_size;
223 if self.config.drop_last && self.dataset_size % self.config.batch_size != 0 {
224 total - 1
225 } else {
226 total
227 }
228 }
229
230 pub fn current_batch_index(&self) -> usize {
232 self.current_index / self.config.batch_size
233 }
234}
235
/// Combines per-sample tensors into a batch according to a [`BatchConfig`].
pub struct Collator {
    // Collation / padding configuration.
    config: BatchConfig,
}
240
241impl Collator {
242 pub fn new(config: BatchConfig) -> Self {
244 Self { config }
245 }
246
247 pub fn collate<T>(&self, samples: &[Tensor<T>]) -> Result<Tensor<T>>
249 where
250 T: Clone + Default,
251 {
252 if samples.is_empty() {
253 return Err(TensorError::invalid_shape_simple(
254 "Cannot collate empty batch".to_string(),
255 ));
256 }
257
258 match self.config.collation_strategy {
259 CollationStrategy::Stack => self.stack_samples(samples),
260 CollationStrategy::PadAndStack => self.pad_and_stack_samples(samples),
261 _ => {
262 self.stack_samples(samples)
264 }
265 }
266 }
267
268 fn stack_samples<T>(&self, samples: &[Tensor<T>]) -> Result<Tensor<T>>
270 where
271 T: Clone + Default,
272 {
273 Ok(samples[0].clone())
276 }
277
278 fn pad_and_stack_samples<T>(&self, samples: &[Tensor<T>]) -> Result<Tensor<T>>
280 where
281 T: Clone + Default,
282 {
283 Ok(samples[0].clone())
289 }
290
291 fn get_padding_length(&self, sample_lengths: &[usize]) -> usize {
293 match self.config.padding_strategy {
294 PaddingStrategy::LongestInBatch => *sample_lengths.iter().max().unwrap_or(&0),
295 PaddingStrategy::FixedLength => self.config.max_sequence_length.unwrap_or(512),
296 PaddingStrategy::NearestMultiple => {
297 let max_len = *sample_lengths.iter().max().unwrap_or(&0);
298 let multiple = self.config.max_sequence_length.unwrap_or(8);
299 ((max_len + multiple - 1) / multiple) * multiple
300 }
301 PaddingStrategy::NoPadding => sample_lengths[0],
302 }
303 }
304}
305
/// Running statistics over batches produced during iteration.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct BatchStatistics {
    /// Number of batches recorded so far.
    pub total_batches: usize,
    /// Total number of samples across all recorded batches.
    pub total_samples: usize,
    /// Mean batch size over all recorded batches (0.0 until one is recorded).
    pub avg_batch_size: f64,
    /// Smallest batch seen (`usize::MAX` until the first batch is recorded).
    pub min_batch_size: usize,
    /// Largest batch seen (0 until the first batch is recorded).
    pub max_batch_size: usize,
    /// Mean per-batch padding ratio as passed to `record_batch`.
    pub avg_padding_ratio: f64,
}

impl BatchStatistics {
    /// Creates an empty statistics accumulator.
    pub fn new() -> Self {
        Self {
            total_batches: 0,
            total_samples: 0,
            avg_batch_size: 0.0,
            min_batch_size: usize::MAX,
            max_batch_size: 0,
            avg_padding_ratio: 0.0,
        }
    }

    /// Records one batch of `batch_size` samples with the given padding ratio.
    pub fn record_batch(&mut self, batch_size: usize, padding_ratio: f64) {
        self.total_batches += 1;
        self.total_samples += batch_size;
        self.min_batch_size = self.min_batch_size.min(batch_size);
        self.max_batch_size = self.max_batch_size.max(batch_size);

        // Derive the mean batch size directly from the running totals. This
        // is exact, whereas the previous incremental running-average
        // recurrence accumulated floating-point rounding error per batch.
        self.avg_batch_size = self.total_samples as f64 / self.total_batches as f64;

        // No running total is stored for padding, so that mean is still
        // updated incrementally.
        let n = self.total_batches as f64;
        self.avg_padding_ratio = (self.avg_padding_ratio * (n - 1.0) + padding_ratio) / n;
    }

    /// Clears all accumulated statistics back to the initial state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Fraction of non-padding work: `1 - avg_padding_ratio`.
    /// Returns 1.0 before any batch has been recorded.
    pub fn efficiency(&self) -> f64 {
        1.0 - self.avg_padding_ratio
    }
}

impl Default for BatchStatistics {
    fn default() -> Self {
        Self::new()
    }
}
366
/// Free helper functions for batch sizing, padding and grouping computations.
pub mod batch_utils {
    use super::*;

    /// Largest batch size that fits in memory given the per-sample footprint.
    ///
    /// `safety_factor` scales the available memory down to leave headroom.
    /// Always returns at least 1 so callers can make progress; also returns 1
    /// when `sample_memory_bytes` is 0 (unknown footprint — previously this
    /// panicked with a division by zero).
    pub fn calculate_optimal_batch_size(
        sample_memory_bytes: usize,
        available_memory_bytes: usize,
        safety_factor: f64,
    ) -> usize {
        if sample_memory_bytes == 0 {
            return 1;
        }
        let usable_memory = (available_memory_bytes as f64 * safety_factor) as usize;
        (usable_memory / sample_memory_bytes).max(1)
    }

    /// Number of batches `dataset_size` samples yield, honoring `drop_last`.
    /// Returns 0 when `batch_size` is 0 (previously a division-by-zero panic).
    pub fn calculate_num_batches(dataset_size: usize, batch_size: usize, drop_last: bool) -> usize {
        if batch_size == 0 {
            return 0;
        }
        let total = (dataset_size + batch_size - 1) / batch_size;
        if drop_last && dataset_size % batch_size != 0 {
            total - 1
        } else {
            total
        }
    }

    /// Fraction of the padded batch that is padding:
    /// `1 - sum(lengths) / (n * padded_length)`.
    /// Returns 0.0 for an empty batch or a zero padded length.
    pub fn calculate_padding_overhead(original_lengths: &[usize], padded_length: usize) -> f64 {
        let original_total: usize = original_lengths.iter().sum();
        let padded_total = original_lengths.len() * padded_length;

        if padded_total == 0 {
            0.0
        } else {
            1.0 - (original_total as f64 / padded_total as f64)
        }
    }

    /// Smallest multiple of `multiple` that covers the longest length.
    /// Returns the raw maximum when `multiple` is 0 (previously a
    /// division-by-zero panic).
    pub fn find_optimal_padding_length(lengths: &[usize], multiple: usize) -> usize {
        let max_len = lengths.iter().max().copied().unwrap_or(0);
        if multiple == 0 {
            return max_len;
        }
        ((max_len + multiple - 1) / multiple) * multiple
    }

    /// Partitions sample indices into up to `num_groups` groups of similar
    /// length (returned indices refer to positions in the original `lengths`).
    /// Only non-empty groups are returned.
    pub fn group_by_length(lengths: Vec<usize>, num_groups: usize) -> Vec<Vec<usize>> {
        if lengths.is_empty() || num_groups == 0 {
            return vec![];
        }

        let mut indexed_lengths: Vec<_> = lengths.into_iter().enumerate().collect();
        // Stable sort keeps original order among equal lengths.
        indexed_lengths.sort_by_key(|&(_, len)| len);

        let group_size = (indexed_lengths.len() + num_groups - 1) / num_groups;
        // `chunks` never yields an empty chunk and produces at most
        // `num_groups` chunks here, so the groups can be built directly —
        // no placeholder allocation or trailing emptiness filter needed.
        indexed_lengths
            .chunks(group_size)
            .map(|chunk| chunk.iter().map(|&(idx, _)| idx).collect())
            .collect()
    }

    /// Fraction of allocated tensor positions actually used:
    /// `avg_sequence_length / max_sequence_length` scaled by the batch.
    /// Returns 0.0 when the allocation is empty.
    pub fn calculate_memory_efficiency(
        batch_size: usize,
        avg_sequence_length: usize,
        max_sequence_length: usize,
    ) -> f64 {
        let used = batch_size * avg_sequence_length;
        let allocated = batch_size * max_sequence_length;

        if allocated == 0 {
            0.0
        } else {
            used as f64 / allocated as f64
        }
    }
}
444
#[cfg(test)]
mod tests {
    use super::*;

    // Enum smoke test: variant count only.
    #[test]
    fn test_sampling_strategy_variants() {
        let strategies = [
            SamplingStrategy::Sequential,
            SamplingStrategy::Random,
            SamplingStrategy::Shuffle,
            SamplingStrategy::Stratified,
            SamplingStrategy::Weighted,
        ];

        assert_eq!(strategies.len(), 5);
    }

    // Default config pins the documented baseline values.
    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.batch_size, 32);
        assert!(!config.drop_last);
        assert_eq!(config.sampling_strategy, SamplingStrategy::Sequential);
    }

    // Builder methods each set exactly one field and chain.
    #[test]
    fn test_batch_config_builder() {
        let config = BatchConfig::new(64)
            .with_drop_last(true)
            .with_padding_value(1.0)
            .with_max_sequence_length(128)
            .with_seed(42);

        assert_eq!(config.batch_size, 64);
        assert!(config.drop_last);
        assert_eq!(config.padding_value, 1.0);
        assert_eq!(config.max_sequence_length, Some(128));
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_batch_sampler_creation() {
        let config = BatchConfig::new(10);
        let sampler = BatchSampler::new(100, config);

        assert_eq!(sampler.dataset_size, 100);
        assert_eq!(sampler.num_batches(), 10);
    }

    // 25 samples at batch size 10: two full batches, then a partial of 5.
    #[test]
    fn test_batch_sampler_next_batch() {
        let config = BatchConfig::new(10);
        let mut sampler = BatchSampler::new(25, config);

        let batch1 = sampler.next_batch();
        assert!(batch1.is_some());
        assert_eq!(batch1.expect("test: operation should succeed").len(), 10);

        let batch2 = sampler.next_batch();
        assert!(batch2.is_some());
        assert_eq!(batch2.expect("test: operation should succeed").len(), 10);

        let batch3 = sampler.next_batch();
        assert!(batch3.is_some());
        assert_eq!(batch3.expect("test: operation should succeed").len(), 5);
    }

    // With drop_last, the trailing partial batch is discarded.
    #[test]
    fn test_batch_sampler_drop_last() {
        let config = BatchConfig::new(10).with_drop_last(true);
        let mut sampler = BatchSampler::new(25, config);

        sampler.next_batch();
        sampler.next_batch();
        let batch3 = sampler.next_batch();

        assert!(batch3.is_none());
    }

    // ceil(25/10) = 3; dropping the partial batch leaves 2.
    #[test]
    fn test_batch_sampler_num_batches() {
        let config = BatchConfig::new(10);
        let sampler = BatchSampler::new(25, config);
        assert_eq!(sampler.num_batches(), 3);

        let config_drop = BatchConfig::new(10).with_drop_last(true);
        let sampler_drop = BatchSampler::new(25, config_drop);
        assert_eq!(sampler_drop.num_batches(), 2);
    }

    // reset() rewinds the cursor to the start of the epoch.
    #[test]
    fn test_batch_sampler_reset() {
        let config = BatchConfig::new(10);
        let mut sampler = BatchSampler::new(25, config);

        sampler.next_batch();
        sampler.next_batch();
        assert_eq!(sampler.current_batch_index(), 2);

        sampler.reset();
        assert_eq!(sampler.current_batch_index(), 0);
    }

    #[test]
    fn test_collator_creation() {
        let config = BatchConfig::new(32);
        let collator = Collator::new(config);

        assert_eq!(collator.config.batch_size, 32);
    }

    #[test]
    fn test_batch_statistics_creation() {
        let stats = BatchStatistics::new();
        assert_eq!(stats.total_batches, 0);
        assert_eq!(stats.total_samples, 0);
        assert_eq!(stats.avg_batch_size, 0.0);
    }

    // Running average: (32 + 30) / 2 == 31 exactly.
    #[test]
    fn test_batch_statistics_record() {
        let mut stats = BatchStatistics::new();

        stats.record_batch(32, 0.1);
        assert_eq!(stats.total_batches, 1);
        assert_eq!(stats.total_samples, 32);
        assert_eq!(stats.avg_batch_size, 32.0);

        stats.record_batch(30, 0.15);
        assert_eq!(stats.total_batches, 2);
        assert_eq!(stats.total_samples, 62);
        assert_eq!(stats.avg_batch_size, 31.0);
    }

    #[test]
    fn test_batch_statistics_min_max() {
        let mut stats = BatchStatistics::new();

        stats.record_batch(32, 0.1);
        stats.record_batch(20, 0.1);
        stats.record_batch(40, 0.1);

        assert_eq!(stats.min_batch_size, 20);
        assert_eq!(stats.max_batch_size, 40);
    }

    // efficiency == 1 - avg_padding_ratio.
    #[test]
    fn test_batch_statistics_efficiency() {
        let mut stats = BatchStatistics::new();
        stats.record_batch(32, 0.2);
        let efficiency = stats.efficiency();
        assert!((efficiency - 0.8).abs() < 0.01);
    }

    // 64 MiB at 80% safety over 1 MiB samples: bounded by 64.
    #[test]
    fn test_utils_calculate_optimal_batch_size() {
        let batch_size = batch_utils::calculate_optimal_batch_size(
            1024 * 1024,
            1024 * 1024 * 64,
            0.8,
        );

        assert!(batch_size > 0);
        assert!(batch_size <= 64);
    }

    #[test]
    fn test_utils_calculate_num_batches() {
        assert_eq!(batch_utils::calculate_num_batches(100, 32, false), 4);
        assert_eq!(batch_utils::calculate_num_batches(100, 32, true), 3);
        assert_eq!(batch_utils::calculate_num_batches(96, 32, false), 3);
        assert_eq!(batch_utils::calculate_num_batches(96, 32, true), 3);
    }

    // sum(lengths)=45, padded total=4*20=80 → overhead 1 - 45/80 = 0.4375.
    #[test]
    fn test_utils_calculate_padding_overhead() {
        let lengths = vec![10, 15, 12, 8];
        let overhead = batch_utils::calculate_padding_overhead(&lengths, 20);

        assert!((overhead - 0.4375).abs() < 0.01);
    }

    // max length 22 rounded up to the next multiple of 8 is 24.
    #[test]
    fn test_utils_find_optimal_padding_length() {
        let lengths = vec![10, 15, 18, 22];
        let optimal = batch_utils::find_optimal_padding_length(&lengths, 8);

        assert_eq!(optimal, 24);
    }

    #[test]
    fn test_utils_group_by_length() {
        let lengths = vec![10, 25, 15, 30, 20, 12, 28];
        let groups = batch_utils::group_by_length(lengths, 3);

        assert_eq!(groups.len(), 3);
        for group in &groups {
            assert!(!group.is_empty());
        }
    }

    // 100/128 = 0.78125 regardless of batch size.
    #[test]
    fn test_utils_calculate_memory_efficiency() {
        let efficiency = batch_utils::calculate_memory_efficiency(
            32,
            100,
            128,
        );

        assert!((efficiency - 0.78125).abs() < 0.01);
    }

    #[test]
    fn test_padding_strategy_variants() {
        let strategies = [
            PaddingStrategy::LongestInBatch,
            PaddingStrategy::FixedLength,
            PaddingStrategy::NearestMultiple,
            PaddingStrategy::NoPadding,
        ];

        assert_eq!(strategies.len(), 4);
    }

    #[test]
    fn test_collation_strategy_variants() {
        let strategies = [
            CollationStrategy::Stack,
            CollationStrategy::Concatenate,
            CollationStrategy::PadAndStack,
            CollationStrategy::Custom,
        ];

        assert_eq!(strategies.len(), 4);
    }
}