1use crate::data::Dataset;
10use crate::error::{NeuralError, Result};
11use scirs2_core::chunking::{ChunkConfig, ChunkStrategy, ChunkingUtils};
12use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
13use scirs2_core::numeric::{Float, FromPrimitive};
14use scirs2_core::random::seq::SliceRandom;
15use scirs2_core::NumAssign;
16use std::collections::VecDeque;
17use std::fmt::Debug;
18use std::marker::PhantomData;
19use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
20use std::sync::{Arc, Mutex};
21use std::thread;
22use std::time::{Duration, Instant};
23
/// A single mini-batch: `(inputs, targets)` as dynamically-shaped arrays.
type BatchPair<F> = (Array<F, IxDyn>, Array<F, IxDyn>);
26
/// Configuration for [`OptimizedDataLoader`].
#[derive(Debug, Clone)]
pub struct OptimizedLoaderConfig {
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Maximum number of batches the background prefetch queue may hold.
    pub prefetch_size: usize,
    /// Requested number of worker threads.
    /// NOTE(review): not read anywhere in this file — confirm intent.
    pub num_workers: usize,
    /// Drop a trailing partial batch instead of yielding it short.
    pub drop_last: bool,
    /// Reshuffle the sample order on every `reset()`.
    pub shuffle: bool,
    /// Request pinned host memory for faster device transfer.
    /// NOTE(review): not acted upon anywhere in this file — confirm intent.
    pub pin_memory: bool,
    /// Cache assembled batches for reuse within/across epochs.
    pub cache_batches: bool,
    /// Soft cap on cached-batch memory, in bytes (0 = no limit).
    pub max_cache_memory: usize,
}
51
52impl Default for OptimizedLoaderConfig {
53 fn default() -> Self {
54 Self {
55 batch_size: 32,
56 prefetch_size: 2,
57 num_workers: 0,
58 drop_last: false,
59 shuffle: true,
60 pin_memory: false,
61 cache_batches: false,
62 max_cache_memory: 0,
63 }
64 }
65}
66
/// Counters and timings accumulated while a loader serves batches.
#[derive(Debug, Clone, Default)]
pub struct LoadingStats {
    /// Number of batches actually loaded from the dataset.
    pub batches_loaded: usize,
    /// Number of individual samples loaded so far.
    pub samples_loaded: usize,
    /// Total wall-clock time spent assembling batches.
    pub total_load_time: Duration,
    /// `total_load_time / batches_loaded`, updated after every load.
    pub avg_batch_time: Duration,
    /// Batches served from the batch cache.
    pub cache_hits: usize,
    /// Batches that had to be loaded because they were not cached.
    pub cache_misses: usize,
    /// Time the consumer spent waiting on the prefetch queue.
    pub prefetch_wait_time: Duration,
}
85
/// Result of loading one batch: `Ok((inputs, targets))` or a loader error.
pub type BatchResult<F> = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
92
/// Per-epoch cache of assembled batches, indexed by batch number.
struct BatchCache<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> {
    /// One slot per batch index; `None` until that batch is first inserted.
    cache: Vec<Option<BatchPair<F>>>,
    /// Number of slots, fixed at construction.
    max_batches: usize,
    /// Running estimate of the bytes held by cached batches.
    memory_usage: usize,
}
106
107impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> BatchCache<F> {
108 fn new(max_batches: usize) -> Self {
109 Self {
110 cache: vec![None; max_batches],
111 max_batches,
112 memory_usage: 0,
113 }
114 }
115
116 fn get(&self, index: usize) -> Option<&BatchPair<F>> {
117 if index < self.cache.len() {
118 self.cache[index].as_ref()
119 } else {
120 None
121 }
122 }
123
124 fn insert(&mut self, index: usize, batch: BatchPair<F>) {
125 if index < self.cache.len() {
126 let batch_size = estimate_array_memory(&batch.0) + estimate_array_memory(&batch.1);
127 self.memory_usage += batch_size;
128 self.cache[index] = Some(batch);
129 }
130 }
131
132 fn clear(&mut self) {
133 self.cache.iter_mut().for_each(|b| *b = None);
134 self.memory_usage = 0;
135 }
136}
137
138fn estimate_array_memory<F: Float + NumAssign>(array: &Array<F, IxDyn>) -> usize {
140 array.len() * std::mem::size_of::<F>()
141}
142
/// Bounded FIFO handing `(batch_index, result)` pairs from a producer thread
/// to a consumer, with a cooperative stop flag.
struct PrefetchQueue<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> {
    /// The queued batches, protected by a mutex.
    queue: Mutex<VecDeque<(usize, BatchResult<F>)>>,
    /// Capacity bound enforced (approximately) by `push`.
    max_size: usize,
    /// Lock-free mirror of the queue length, used for capacity/emptiness checks.
    size: AtomicUsize,
    /// Set once to tell producers/consumers to shut down.
    stop: AtomicBool,
}
158
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> PrefetchQueue<F> {
    /// Creates an empty queue bounded at `max_size` entries.
    fn new(max_size: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::with_capacity(max_size)),
            max_size,
            size: AtomicUsize::new(0),
            stop: AtomicBool::new(false),
        }
    }

    /// Enqueues `(index, batch)`, spin-waiting (100 µs naps) while the queue
    /// is full. Returns `false` when the queue was stopped or its mutex is
    /// poisoned, signalling the producer to shut down.
    ///
    /// NOTE(review): the capacity check and the push are not atomic, so with
    /// multiple producers the queue could briefly exceed `max_size`; the
    /// iterators in this file use a single producer, where that cannot occur.
    fn push(&self, index: usize, batch: BatchResult<F>) -> bool {
        if self.stop.load(Ordering::Relaxed) {
            return false;
        }

        // Back-pressure: wait for the consumer to drain a slot.
        while self.size.load(Ordering::Relaxed) >= self.max_size {
            if self.stop.load(Ordering::Relaxed) {
                return false;
            }
            thread::sleep(Duration::from_micros(100));
        }

        let mut queue = match self.queue.lock() {
            Ok(q) => q,
            Err(_) => return false,
        };
        queue.push_back((index, batch));
        self.size.fetch_add(1, Ordering::Relaxed);
        true
    }

    /// Dequeues the oldest entry, or `None` when empty (or mutex poisoned).
    fn pop(&self) -> Option<(usize, BatchResult<F>)> {
        let mut queue = match self.queue.lock() {
            Ok(q) => q,
            Err(_) => return None,
        };
        let result = queue.pop_front();
        if result.is_some() {
            self.size.fetch_sub(1, Ordering::Relaxed);
        }
        result
    }

    /// Signals all parties to stop; already-queued items remain poppable.
    fn stop(&self) {
        self.stop.store(true, Ordering::Relaxed);
    }

    /// True when no entries are queued (based on the atomic size mirror).
    fn is_empty(&self) -> bool {
        self.size.load(Ordering::Relaxed) == 0
    }
}
211
/// Batch loader over a [`Dataset`] with optional shuffling, batch caching,
/// and background prefetching (via [`OptimizedDataLoader::prefetch_iter`]).
pub struct OptimizedDataLoader<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// Shared handle to the underlying dataset.
    dataset: Arc<D>,
    /// Loader configuration (batch size, shuffling, caching, ...).
    config: OptimizedLoaderConfig,
    /// Sample-index permutation; reshuffled by `reset()` when configured.
    indices: Vec<usize>,
    /// Next batch index to hand out; atomic so `next_batch(&self)` is
    /// callable from multiple threads.
    position: AtomicUsize,
    /// Batches per epoch (respects `drop_last`).
    num_batches: usize,
    /// Optional per-batch cache, present when `config.cache_batches`.
    cache: Option<Mutex<BatchCache<F>>>,
    /// Accumulated loading statistics.
    stats: Mutex<LoadingStats>,
    /// Ties `F` to the struct even though fields already mention it.
    _phantom: PhantomData<F>,
}
238
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > OptimizedDataLoader<F, D>
{
    /// Creates a loader over `dataset` with the given configuration.
    ///
    /// # Panics
    /// Panics when `config.batch_size` is zero. (Previously a zero batch
    /// size surfaced as an opaque divide-by-zero panic while computing the
    /// batch count; the precondition is now stated explicitly.)
    pub fn new(dataset: D, config: OptimizedLoaderConfig) -> Self {
        assert!(
            config.batch_size > 0,
            "OptimizedLoaderConfig::batch_size must be non-zero"
        );

        let dataset_len = dataset.len();
        let batch_size = config.batch_size;
        let drop_last = config.drop_last;

        // drop_last discards a trailing partial batch; otherwise round up.
        let num_batches = if drop_last {
            dataset_len / batch_size
        } else {
            dataset_len.div_ceil(batch_size)
        };

        let indices: Vec<usize> = (0..dataset_len).collect();

        // One cache slot per batch index when caching is enabled.
        let cache = if config.cache_batches {
            Some(Mutex::new(BatchCache::new(num_batches)))
        } else {
            None
        };

        Self {
            dataset: Arc::new(dataset),
            config,
            indices,
            position: AtomicUsize::new(0),
            num_batches,
            cache,
            stats: Mutex::new(LoadingStats::default()),
            _phantom: PhantomData,
        }
    }
275
276 pub fn reset(&mut self) {
278 if self.config.shuffle {
279 let mut rng = scirs2_core::random::rng();
280 self.indices.shuffle(&mut rng);
281 }
282 self.position.store(0, Ordering::Relaxed);
283 }
284
    /// Number of batches a full epoch yields (respects `drop_last`).
    pub fn num_batches(&self) -> usize {
        self.num_batches
    }

    /// Total number of samples in the underlying dataset.
    pub fn len(&self) -> usize {
        self.dataset.len()
    }

    /// True when the underlying dataset has no samples.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Snapshot of the loading statistics; returns defaults if the stats
    /// lock is poisoned.
    pub fn stats(&self) -> LoadingStats {
        self.stats
            .lock()
            .map_or_else(|_| LoadingStats::default(), |s| s.clone())
    }
306
307 fn load_batch(&self, batch_idx: usize) -> BatchResult<F> {
309 let start = batch_idx * self.config.batch_size;
310 let end = (start + self.config.batch_size).min(self.indices.len());
311
312 if start >= self.indices.len() {
313 return Err(NeuralError::TrainingError(
314 "Batch index out of range".to_string(),
315 ));
316 }
317
318 let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
319
320 if batch_indices.is_empty() {
321 return Err(NeuralError::TrainingError("Empty batch".to_string()));
322 }
323
324 let (first_x, first_y) = self.dataset.get(batch_indices[0])?;
326
327 let batch_x_shape: Vec<usize> = std::iter::once(batch_indices.len())
329 .chain(first_x.shape().iter().copied())
330 .collect();
331 let batch_y_shape: Vec<usize> = std::iter::once(batch_indices.len())
332 .chain(first_y.shape().iter().copied())
333 .collect();
334
335 let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
336 let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
337
338 for (i, &idx) in batch_indices.iter().enumerate() {
340 let (x, y) = self.dataset.get(idx)?;
341
342 let mut batch_x_slice = batch_x.slice_mut(scirs2_core::ndarray::s![i, ..]);
344 batch_x_slice.assign(&x);
345
346 let mut batch_y_slice = batch_y.slice_mut(scirs2_core::ndarray::s![i, ..]);
347 batch_y_slice.assign(&y);
348 }
349
350 Ok((batch_x, batch_y))
351 }
352
353 pub fn next_batch(&self) -> Option<BatchResult<F>> {
355 let batch_idx = self.position.fetch_add(1, Ordering::Relaxed);
356
357 if batch_idx >= self.num_batches {
358 return None;
359 }
360
361 if let Some(ref cache) = self.cache {
363 if let Ok(cache_guard) = cache.lock() {
364 if let Some(batch) = cache_guard.get(batch_idx) {
365 if let Ok(mut stats) = self.stats.lock() {
366 stats.cache_hits += 1;
367 }
368 return Some(Ok((batch.0.clone(), batch.1.clone())));
369 }
370 }
371 }
372
373 let start = Instant::now();
375 let result = self.load_batch(batch_idx);
376 let load_time = start.elapsed();
377
378 if let Ok(mut stats) = self.stats.lock() {
380 stats.batches_loaded += 1;
381 stats.samples_loaded += self.config.batch_size.min(
382 self.indices
383 .len()
384 .saturating_sub(batch_idx * self.config.batch_size),
385 );
386 stats.total_load_time += load_time;
387 stats.avg_batch_time = stats.total_load_time / stats.batches_loaded as u32;
388 stats.cache_misses += 1;
389 }
390
391 if let Some(ref cache) = self.cache {
393 if let Ok(ref batch) = result {
394 if let Ok(mut cache_guard) = cache.lock() {
395 cache_guard.insert(batch_idx, (batch.0.clone(), batch.1.clone()));
396 }
397 }
398 }
399
400 Some(result)
401 }
402
    /// Consumes the loader, returning an iterator that prefetches batches on
    /// a background thread.
    pub fn prefetch_iter(self) -> PrefetchingIterator<F, D> {
        PrefetchingIterator::new(self)
    }
}
408
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > Iterator for OptimizedDataLoader<F, D>
{
    type Item = BatchResult<F>;

    /// Delegates to [`OptimizedDataLoader::next_batch`]; yields `None` once
    /// the epoch is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}
420
/// In-order batch iterator backed by a background prefetch thread.
pub struct PrefetchingIterator<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// The loader the worker thread pulls batches from.
    loader: Arc<OptimizedDataLoader<F, D>>,
    /// Bounded queue between the worker (producer) and this iterator.
    queue: Arc<PrefetchQueue<F>>,
    /// Handle to the worker thread; joined on drop.
    worker_handle: Option<thread::JoinHandle<()>>,
    /// Next batch index the consumer expects to yield.
    expected_idx: usize,
    /// Reorder buffer for batches that arrived ahead of `expected_idx`.
    buffer: VecDeque<(usize, BatchResult<F>)>,
}
441
442impl<
443 F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
444 D: Dataset<F> + Send + Sync + Clone + 'static,
445 > PrefetchingIterator<F, D>
446{
447 fn new(loader: OptimizedDataLoader<F, D>) -> Self {
449 let prefetch_size = loader.config.prefetch_size;
450 let loader = Arc::new(loader);
451 let queue = Arc::new(PrefetchQueue::new(prefetch_size));
452
453 let worker_loader = Arc::clone(&loader);
455 let worker_queue = Arc::clone(&queue);
456
457 let worker_handle = thread::spawn(move || {
458 let mut batch_idx = 0;
459 loop {
460 if worker_queue.stop.load(Ordering::Relaxed) {
461 break;
462 }
463
464 if batch_idx >= worker_loader.num_batches {
465 break;
466 }
467
468 let result = worker_loader.load_batch(batch_idx);
469 if !worker_queue.push(batch_idx, result) {
470 break;
471 }
472 batch_idx += 1;
473 }
474 });
475
476 Self {
477 loader,
478 queue,
479 worker_handle: Some(worker_handle),
480 expected_idx: 0,
481 buffer: VecDeque::new(),
482 }
483 }
484}
485
486impl<
487 F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
488 D: Dataset<F> + Send + Sync + Clone + 'static,
489 > Iterator for PrefetchingIterator<F, D>
490{
491 type Item = BatchResult<F>;
492
493 fn next(&mut self) -> Option<Self::Item> {
494 if self.expected_idx >= self.loader.num_batches {
495 return None;
496 }
497
498 if let Some(pos) = self
500 .buffer
501 .iter()
502 .position(|(idx, _)| *idx == self.expected_idx)
503 {
504 let (_, result) = self.buffer.remove(pos).expect("Position was just found");
505 self.expected_idx += 1;
506 return Some(result);
507 }
508
509 let wait_start = Instant::now();
511 loop {
512 if let Some((idx, result)) = self.queue.pop() {
513 if idx == self.expected_idx {
514 self.expected_idx += 1;
515
516 if let Ok(mut stats) = self.loader.stats.lock() {
518 stats.prefetch_wait_time += wait_start.elapsed();
519 }
520
521 return Some(result);
522 } else {
523 self.buffer.push_back((idx, result));
525 }
526 } else if self.queue.is_empty() && self.queue.stop.load(Ordering::Relaxed) {
527 return None;
529 } else {
530 thread::sleep(Duration::from_micros(10));
532 }
533 }
534 }
535}
536
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > Drop for PrefetchingIterator<F, D>
{
    /// Signals the worker to stop and joins it so no thread outlives the
    /// iterator; a join error (worker panic) is deliberately ignored.
    fn drop(&mut self) {
        self.queue.stop();
        if let Some(handle) = self.worker_handle.take() {
            let _ = handle.join();
        }
    }
}
549
/// Outcome of a [`BatchSizeOptimizer`] search.
#[derive(Debug, Clone)]
pub struct BatchSizeOptimizationResult {
    /// Batch size with the highest measured throughput.
    pub recommended_batch_size: usize,
    /// `(batch_size, samples_per_second)` for every size measured.
    pub throughput_results: Vec<(usize, f64)>,
    /// `(batch_size, avg_bytes_per_batch)` for every size measured.
    pub memory_results: Vec<(usize, usize)>,
    /// True when the search stopped because a size exceeded the memory cap.
    pub memory_limited: bool,
}
566
/// Searches for the batch size with the best loading throughput by timing
/// real loads over a doubling range of candidate sizes.
pub struct BatchSizeOptimizer {
    /// Smallest batch size to try.
    min_batch_size: usize,
    /// Largest batch size to try.
    max_batch_size: usize,
    /// Untimed batches loaded before measurement starts.
    warmup_batches: usize,
    /// Batches timed per candidate size.
    timing_batches: usize,
    /// Per-batch memory cap in bytes; 0 disables the cap.
    max_memory: usize,
}
580
/// Defaults: search 8..=512 (doubling), 2 warm-up batches, 5 timed batches,
/// no memory cap.
impl Default for BatchSizeOptimizer {
    fn default() -> Self {
        Self {
            min_batch_size: 8,
            max_batch_size: 512,
            warmup_batches: 2,
            timing_batches: 5,
            max_memory: 0,
        }
    }
}
592
impl BatchSizeOptimizer {
    /// Creates an optimizer with the default search settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Restricts the search range to batch sizes in `[min, max]`.
    pub fn with_range(self, min: usize, max: usize) -> Self {
        Self {
            min_batch_size: min,
            max_batch_size: max,
            ..self
        }
    }

    /// Caps the average per-batch memory; a size exceeding it ends the search.
    pub fn with_max_memory(self, max_memory: usize) -> Self {
        Self { max_memory, ..self }
    }
611
612 pub fn find_optimal<
614 F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
615 D: Dataset<F> + Send + Sync + Clone + 'static,
616 >(
617 &self,
618 dataset: D,
619 ) -> Result<BatchSizeOptimizationResult> {
620 let mut throughput_results = Vec::new();
621 let mut memory_results = Vec::new();
622 let mut best_throughput = 0.0;
623 let mut best_batch_size = self.min_batch_size;
624 let mut memory_limited = false;
625
626 let mut batch_size = self.min_batch_size;
627
628 while batch_size <= self.max_batch_size && batch_size <= dataset.len() {
629 let config = OptimizedLoaderConfig {
630 batch_size,
631 shuffle: false,
632 drop_last: true,
633 ..Default::default()
634 };
635
636 let mut loader = OptimizedDataLoader::new(dataset.clone(), config);
637 loader.reset();
638
639 for _ in 0..self.warmup_batches {
641 if loader.next_batch().is_none() {
642 break;
643 }
644 }
645
646 let start = Instant::now();
648 let mut batches_processed = 0;
649 let mut total_memory = 0;
650
651 for _ in 0..self.timing_batches {
652 match loader.next_batch() {
653 Some(Ok((x, y))) => {
654 batches_processed += 1;
655 total_memory += estimate_array_memory(&x) + estimate_array_memory(&y);
656 }
657 Some(Err(_)) => break,
658 None => break,
659 }
660 }
661
662 if batches_processed == 0 {
663 break;
664 }
665
666 let elapsed = start.elapsed().as_secs_f64();
667 let samples_per_second = (batches_processed * batch_size) as f64 / elapsed;
668 let avg_memory = total_memory / batches_processed;
669
670 throughput_results.push((batch_size, samples_per_second));
671 memory_results.push((batch_size, avg_memory));
672
673 if self.max_memory > 0 && avg_memory > self.max_memory {
675 memory_limited = true;
676 break;
677 }
678
679 if samples_per_second > best_throughput {
680 best_throughput = samples_per_second;
681 best_batch_size = batch_size;
682 }
683
684 batch_size = (batch_size * 2).min(self.max_batch_size + 1);
686 }
687
688 Ok(BatchSizeOptimizationResult {
689 recommended_batch_size: best_batch_size,
690 throughput_results,
691 memory_results,
692 memory_limited,
693 })
694 }
695}
696
/// Configuration for [`MemoryAwareDataLoader`].
#[derive(Debug, Clone)]
pub struct MemoryAwareConfig {
    /// Fraction of the estimated available memory a batch may occupy.
    pub target_memory_fraction: f64,
    /// Known per-sample byte cost; when `None` it is measured from sample 0.
    pub bytes_per_sample: Option<usize>,
    /// Lower bound for the adaptive batch size.
    pub min_batch_size: usize,
    /// Upper bound for the adaptive batch size.
    pub max_batch_size: usize,
    /// Reshuffle the sample order on every `reset()`.
    pub shuffle: bool,
    /// Drop a trailing partial batch instead of yielding it short.
    pub drop_last: bool,
    /// Depth of the prefetch queue used by `into_prefetch_iter`.
    pub prefetch_ahead: usize,
}
726
/// Defaults: use up to a quarter of available memory, measure the sample
/// cost, bound the batch size to 4..=4096, shuffle, keep partial batches,
/// and prefetch two batches ahead.
impl Default for MemoryAwareConfig {
    fn default() -> Self {
        Self {
            target_memory_fraction: 0.25,
            bytes_per_sample: None,
            min_batch_size: 4,
            max_batch_size: 4096,
            shuffle: true,
            drop_last: false,
            prefetch_ahead: 2,
        }
    }
}
740
/// Best-effort estimate of currently available system memory, in bytes.
///
/// On Linux this parses the `MemAvailable` row of `/proc/meminfo`
/// (reported in kB). On any other platform, or when the file is missing or
/// unparsable, it falls back to a conservative 512 MiB.
fn estimate_available_memory_bytes() -> usize {
    #[cfg(target_os = "linux")]
    {
        // Line format: `MemAvailable:   12345678 kB` — the value is field 1.
        let available_kb = std::fs::read_to_string("/proc/meminfo")
            .ok()
            .and_then(|contents| {
                contents.lines().find_map(|line| {
                    if !line.starts_with("MemAvailable:") {
                        return None;
                    }
                    line.split_whitespace().nth(1)?.parse::<usize>().ok()
                })
            });
        if let Some(kb) = available_kb {
            return kb * 1024;
        }
    }
    // Conservative default when no platform-specific probe applies.
    512 * 1024 * 1024
}
768
769fn compute_adaptive_batch_size(
778 dataset_len: usize,
779 bytes_per_sample: usize,
780 config: &MemoryAwareConfig,
781) -> usize {
782 let chunk_cfg = ChunkConfig {
787 strategy: ChunkStrategy::Adaptive,
788 min_chunk_size: config.min_batch_size,
789 max_chunk_size: config.max_batch_size,
790 ..ChunkConfig::default()
791 };
792 let chunking_hint = ChunkingUtils::optimal_chunk_size(dataset_len, &chunk_cfg);
793
794 let available = estimate_available_memory_bytes();
796 let budget_bytes = ((available as f64) * config.target_memory_fraction) as usize;
798 let budget_samples = budget_bytes
800 .checked_div(bytes_per_sample)
801 .map(|v| v.max(1))
802 .unwrap_or(config.max_batch_size);
803
804 let raw = chunking_hint.min(budget_samples);
808 raw.max(config.min_batch_size).min(config.max_batch_size)
809}
810
/// Batch loader whose batch size adapts to the memory available at
/// construction (and on demand via `refresh_batch_size`).
pub struct MemoryAwareDataLoader<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// Shared handle to the underlying dataset.
    dataset: Arc<D>,
    /// Memory-awareness configuration.
    config: MemoryAwareConfig,
    /// Sample-index permutation; reshuffled by `reset()` when configured.
    indices: Vec<usize>,
    /// Next batch index to hand out (atomic: `next_batch` takes `&self`).
    position: AtomicUsize,
    /// Current adaptive batch size.
    batch_size: usize,
    /// Batches per epoch for the current batch size (respects `drop_last`).
    num_batches: usize,
    /// Accumulated loading statistics.
    stats: Mutex<LoadingStats>,
    /// Ties `F` to the struct even though fields already mention it.
    _phantom: PhantomData<F>,
}
845
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > MemoryAwareDataLoader<F, D>
{
    /// Resolves the per-sample byte cost: the configured override when
    /// present, otherwise measured from sample 0 of the dataset.
    fn resolve_bytes_per_sample(dataset: &D, configured: Option<usize>) -> Result<usize> {
        match configured {
            Some(b) => Ok(b),
            None => {
                let (x0, y0) = dataset.get(0)?;
                Ok((x0.len() + y0.len()) * std::mem::size_of::<F>())
            }
        }
    }

    /// Number of batches an epoch yields for `batch_size` samples per batch.
    fn batches_for(dataset_len: usize, batch_size: usize, drop_last: bool) -> usize {
        if drop_last {
            dataset_len / batch_size
        } else {
            dataset_len.div_ceil(batch_size)
        }
    }

    /// Builds a loader whose batch size is derived from available memory and
    /// the chunking heuristic (see `compute_adaptive_batch_size`).
    ///
    /// The duplicated probing/batch-count logic previously repeated in
    /// `refresh_batch_size` is factored into the private helpers above.
    ///
    /// # Errors
    /// Fails when the dataset is empty, or when probing sample 0 fails (only
    /// needed if no `bytes_per_sample` override is configured).
    pub fn new_adaptive(dataset: D, config: MemoryAwareConfig) -> Result<Self> {
        let dataset_len = dataset.len();
        if dataset_len == 0 {
            return Err(NeuralError::TrainingError(
                "Cannot create MemoryAwareDataLoader from an empty dataset".to_string(),
            ));
        }

        let bytes_per_sample = Self::resolve_bytes_per_sample(&dataset, config.bytes_per_sample)?;
        let batch_size = compute_adaptive_batch_size(dataset_len, bytes_per_sample, &config);
        let num_batches = Self::batches_for(dataset_len, batch_size, config.drop_last);

        Ok(Self {
            dataset: Arc::new(dataset),
            config,
            indices: (0..dataset_len).collect(),
            position: AtomicUsize::new(0),
            batch_size,
            num_batches,
            stats: Mutex::new(LoadingStats::default()),
            _phantom: PhantomData,
        })
    }

    /// Re-evaluates available memory, updates the batch size and the batch
    /// count accordingly, and returns the new batch size.
    ///
    /// # Errors
    /// Fails when probing sample 0 fails (only needed if no
    /// `bytes_per_sample` override is configured).
    pub fn refresh_batch_size(&mut self) -> Result<usize> {
        let dataset_len = self.dataset.len();
        let bytes_per_sample =
            Self::resolve_bytes_per_sample(&self.dataset, self.config.bytes_per_sample)?;

        let new_batch_size =
            compute_adaptive_batch_size(dataset_len, bytes_per_sample, &self.config);
        self.batch_size = new_batch_size;
        self.num_batches = Self::batches_for(dataset_len, new_batch_size, self.config.drop_last);
        Ok(new_batch_size)
    }
927
    /// The batch size currently in use (set adaptively at construction or by
    /// `refresh_batch_size`).
    pub fn adaptive_batch_size(&self) -> usize {
        self.batch_size
    }

    /// Number of batches a full epoch yields (respects `drop_last`).
    pub fn num_batches(&self) -> usize {
        self.num_batches
    }

    /// Total number of samples in the underlying dataset.
    pub fn len(&self) -> usize {
        self.dataset.len()
    }

    /// True when the underlying dataset has no samples.
    pub fn is_empty(&self) -> bool {
        self.dataset.len() == 0
    }

    /// Snapshot of the loading statistics; returns defaults if the stats
    /// lock is poisoned.
    pub fn stats(&self) -> LoadingStats {
        self.stats
            .lock()
            .map_or_else(|_| LoadingStats::default(), |s| s.clone())
    }
954
955 pub fn reset(&mut self) {
959 if self.config.shuffle {
960 let mut rng = scirs2_core::random::rng();
961 self.indices.shuffle(&mut rng);
962 }
963 self.position.store(0, Ordering::Relaxed);
964 }
965
966 fn load_batch(&self, batch_idx: usize) -> BatchResult<F> {
969 let start = batch_idx * self.batch_size;
970 let end = (start + self.batch_size).min(self.indices.len());
971
972 if start >= self.indices.len() {
973 return Err(NeuralError::TrainingError(
974 "Batch index out of range".to_string(),
975 ));
976 }
977
978 let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
979
980 if batch_indices.is_empty() {
981 return Err(NeuralError::TrainingError("Empty batch".to_string()));
982 }
983
984 let (first_x, first_y) = self.dataset.get(batch_indices[0])?;
986
987 let batch_x_shape: Vec<usize> = std::iter::once(batch_indices.len())
988 .chain(first_x.shape().iter().copied())
989 .collect();
990 let batch_y_shape: Vec<usize> = std::iter::once(batch_indices.len())
991 .chain(first_y.shape().iter().copied())
992 .collect();
993
994 let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
995 let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
996
997 for (i, &idx) in batch_indices.iter().enumerate() {
998 let (x, y) = self.dataset.get(idx)?;
999 let mut sx = batch_x.slice_mut(scirs2_core::ndarray::s![i, ..]);
1000 sx.assign(&x);
1001 let mut sy = batch_y.slice_mut(scirs2_core::ndarray::s![i, ..]);
1002 sy.assign(&y);
1003 }
1004
1005 Ok((batch_x, batch_y))
1006 }
1007
1008 pub fn next_batch(&self) -> Option<BatchResult<F>> {
1011 let batch_idx = self.position.fetch_add(1, Ordering::Relaxed);
1012 if batch_idx >= self.num_batches {
1013 return None;
1014 }
1015
1016 let start_time = Instant::now();
1017 let result = self.load_batch(batch_idx);
1018 let elapsed = start_time.elapsed();
1019
1020 if let Ok(mut stats) = self.stats.lock() {
1021 stats.batches_loaded += 1;
1022 stats.samples_loaded += self.batch_size.min(
1023 self.indices
1024 .len()
1025 .saturating_sub(batch_idx * self.batch_size),
1026 );
1027 stats.total_load_time += elapsed;
1028 stats.avg_batch_time = stats.total_load_time / stats.batches_loaded as u32;
1029 stats.cache_misses += 1;
1030 }
1031
1032 Some(result)
1033 }
1034
    /// Consumes the loader, returning an iterator that prefetches batches on
    /// a background thread (queue depth = `config.prefetch_ahead`).
    pub fn into_prefetch_iter(self) -> MemoryAwarePrefetchIter<F, D> {
        MemoryAwarePrefetchIter::new(self)
    }
}
1042
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > Iterator for MemoryAwareDataLoader<F, D>
{
    type Item = BatchResult<F>;

    /// Delegates to [`MemoryAwareDataLoader::next_batch`]; yields `None`
    /// once the epoch is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}
1054
/// In-order batch iterator over a [`MemoryAwareDataLoader`], backed by a
/// background prefetch thread.
pub struct MemoryAwarePrefetchIter<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// The loader the worker thread pulls batches from.
    loader: Arc<MemoryAwareDataLoader<F, D>>,
    /// Bounded queue between the worker (producer) and this iterator.
    queue: Arc<PrefetchQueue<F>>,
    /// Handle to the worker thread; joined on drop.
    worker: Option<thread::JoinHandle<()>>,
    /// Next batch index the consumer expects to yield.
    expected_idx: usize,
    /// Reorder buffer for batches that arrived ahead of `expected_idx`.
    out_of_order: VecDeque<(usize, BatchResult<F>)>,
}
1077
1078impl<
1079 F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
1080 D: Dataset<F> + Send + Sync + Clone + 'static,
1081 > MemoryAwarePrefetchIter<F, D>
1082{
1083 fn new(loader: MemoryAwareDataLoader<F, D>) -> Self {
1084 let prefetch_ahead = loader.config.prefetch_ahead;
1085 let num_batches = loader.num_batches;
1086 let loader = Arc::new(loader);
1087 let queue = Arc::new(PrefetchQueue::new(prefetch_ahead));
1088
1089 let worker_loader = Arc::clone(&loader);
1090 let worker_queue = Arc::clone(&queue);
1091
1092 let worker = thread::spawn(move || {
1093 for batch_idx in 0..num_batches {
1094 if worker_queue.stop.load(Ordering::Relaxed) {
1095 break;
1096 }
1097 let result = worker_loader.load_batch(batch_idx);
1098 if !worker_queue.push(batch_idx, result) {
1099 break;
1100 }
1101 }
1102 });
1103
1104 Self {
1105 loader,
1106 queue,
1107 worker: Some(worker),
1108 expected_idx: 0,
1109 out_of_order: VecDeque::new(),
1110 }
1111 }
1112}
1113
1114impl<
1115 F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
1116 D: Dataset<F> + Send + Sync + Clone + 'static,
1117 > Iterator for MemoryAwarePrefetchIter<F, D>
1118{
1119 type Item = BatchResult<F>;
1120
1121 fn next(&mut self) -> Option<Self::Item> {
1122 if self.expected_idx >= self.loader.num_batches {
1123 return None;
1124 }
1125
1126 if let Some(pos) = self
1128 .out_of_order
1129 .iter()
1130 .position(|(idx, _)| *idx == self.expected_idx)
1131 {
1132 let (_, result) = self
1133 .out_of_order
1134 .remove(pos)
1135 .expect("position was just found in out_of_order buffer");
1136 self.expected_idx += 1;
1137 return Some(result);
1138 }
1139
1140 let wait_start = Instant::now();
1142 loop {
1143 if let Some((idx, result)) = self.queue.pop() {
1144 if idx == self.expected_idx {
1145 if let Ok(mut stats) = self.loader.stats.lock() {
1146 stats.prefetch_wait_time += wait_start.elapsed();
1147 }
1148 self.expected_idx += 1;
1149 return Some(result);
1150 }
1151 self.out_of_order.push_back((idx, result));
1153 } else if self.queue.is_empty() && self.queue.stop.load(Ordering::Relaxed) {
1154 return None;
1155 } else {
1156 thread::sleep(Duration::from_micros(10));
1157 }
1158 }
1159 }
1160}
1161
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > Drop for MemoryAwarePrefetchIter<F, D>
{
    /// Signals the worker to stop and joins it so no thread outlives the
    /// iterator; a join error (worker panic) is deliberately ignored.
    fn drop(&mut self) {
        self.queue.stop();
        if let Some(handle) = self.worker.take() {
            let _ = handle.join();
        }
    }
}
1174
1175#[cfg(test)]
1180mod tests {
1181 use super::*;
1182 use crate::data::InMemoryDataset;
1183
1184 fn create_test_dataset() -> InMemoryDataset<f64> {
1185 let features = Array::zeros(IxDyn(&[100, 10]));
1186 let labels = Array::zeros(IxDyn(&[100, 2]));
1187 InMemoryDataset::new(features, labels).expect("Failed to create test dataset")
1188 }
1189
1190 #[test]
1191 fn test_optimized_loader_config_default() {
1192 let config = OptimizedLoaderConfig::default();
1193 assert_eq!(config.batch_size, 32);
1194 assert_eq!(config.prefetch_size, 2);
1195 assert_eq!(config.num_workers, 0);
1196 assert!(!config.drop_last);
1197 assert!(config.shuffle);
1198 }
1199
1200 #[test]
1201 fn test_optimized_dataloader_creation() {
1202 let dataset = create_test_dataset();
1203 let config = OptimizedLoaderConfig {
1204 batch_size: 10,
1205 shuffle: false,
1206 ..Default::default()
1207 };
1208
1209 let loader = OptimizedDataLoader::new(dataset, config);
1210 assert_eq!(loader.len(), 100);
1211 assert_eq!(loader.num_batches(), 10);
1212 }
1213
1214 #[test]
1215 fn test_optimized_dataloader_iteration() {
1216 let dataset = create_test_dataset();
1217 let config = OptimizedLoaderConfig {
1218 batch_size: 10,
1219 shuffle: false,
1220 drop_last: true,
1221 ..Default::default()
1222 };
1223
1224 let mut loader = OptimizedDataLoader::new(dataset, config);
1225 loader.reset();
1226
1227 let mut batch_count = 0;
1228 while let Some(result) = loader.next_batch() {
1229 let (x, y) = result.expect("Failed to load batch");
1230 assert_eq!(x.shape()[0], 10);
1231 assert_eq!(y.shape()[0], 10);
1232 batch_count += 1;
1233 }
1234
1235 assert_eq!(batch_count, 10);
1236 }
1237
1238 #[test]
1239 fn test_optimized_dataloader_stats() {
1240 let dataset = create_test_dataset();
1241 let config = OptimizedLoaderConfig {
1242 batch_size: 20,
1243 shuffle: false,
1244 ..Default::default()
1245 };
1246
1247 let mut loader = OptimizedDataLoader::new(dataset, config);
1248 loader.reset();
1249
1250 while loader.next_batch().is_some() {}
1252
1253 let stats = loader.stats();
1254 assert_eq!(stats.batches_loaded, 5);
1255 assert_eq!(stats.samples_loaded, 100);
1256 }
1257
1258 #[test]
1259 fn test_batch_cache() {
1260 let mut cache: BatchCache<f64> = BatchCache::new(10);
1261
1262 let batch1 = (Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2])));
1263
1264 cache.insert(0, batch1.clone());
1265
1266 let cached = cache.get(0);
1267 assert!(cached.is_some());
1268 assert_eq!(cached.map(|b| b.0.shape()[0]), Some(5));
1269
1270 assert!(cache.get(1).is_none());
1271
1272 cache.clear();
1273 assert!(cache.get(0).is_none());
1274 }
1275
1276 #[test]
1277 fn test_prefetch_queue() {
1278 let queue: PrefetchQueue<f64> = PrefetchQueue::new(3);
1279
1280 let batch = Ok((Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2]))));
1281
1282 assert!(queue.push(0, batch));
1283 assert!(!queue.is_empty());
1284
1285 let popped = queue.pop();
1286 assert!(popped.is_some());
1287 assert_eq!(popped.map(|(idx, _)| idx), Some(0));
1288
1289 assert!(queue.is_empty());
1290
1291 queue.stop();
1292 let batch2 = Ok((Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2]))));
1294 assert!(!queue.push(1, batch2));
1295 }
1296
1297 #[test]
1298 fn test_loading_stats_default() {
1299 let stats = LoadingStats::default();
1300 assert_eq!(stats.batches_loaded, 0);
1301 assert_eq!(stats.samples_loaded, 0);
1302 assert_eq!(stats.cache_hits, 0);
1303 assert_eq!(stats.cache_misses, 0);
1304 }
1305
1306 #[test]
1307 fn test_estimate_array_memory() {
1308 let array: Array<f64, IxDyn> = Array::zeros(IxDyn(&[10, 20]));
1309 let memory = estimate_array_memory(&array);
1310 assert_eq!(memory, 10 * 20 * std::mem::size_of::<f64>());
1311 }
1312
1313 #[test]
1314 fn test_batch_size_optimizer_default() {
1315 let optimizer = BatchSizeOptimizer::default();
1316 assert_eq!(optimizer.min_batch_size, 8);
1317 assert_eq!(optimizer.max_batch_size, 512);
1318 }
1319
1320 #[test]
1321 fn test_batch_size_optimizer_with_range() {
1322 let optimizer = BatchSizeOptimizer::new()
1323 .with_range(16, 256)
1324 .with_max_memory(1024 * 1024);
1325
1326 assert_eq!(optimizer.min_batch_size, 16);
1327 assert_eq!(optimizer.max_batch_size, 256);
1328 assert_eq!(optimizer.max_memory, 1024 * 1024);
1329 }
1330
1331 #[test]
1332 fn test_find_optimal_batch_size() {
1333 let dataset = create_test_dataset();
1334 let optimizer = BatchSizeOptimizer::new().with_range(10, 50);
1335
1336 let result = optimizer.find_optimal(dataset);
1337 assert!(result.is_ok());
1338
1339 let result = result.expect("Optimization should succeed");
1340 assert!(result.recommended_batch_size >= 10);
1341 assert!(result.recommended_batch_size <= 50);
1342 assert!(!result.throughput_results.is_empty());
1343 }
1344
1345 #[test]
1346 fn test_dataloader_with_caching() {
1347 let dataset = create_test_dataset();
1348 let config = OptimizedLoaderConfig {
1349 batch_size: 10,
1350 shuffle: false,
1351 cache_batches: true,
1352 ..Default::default()
1353 };
1354
1355 let mut loader = OptimizedDataLoader::new(dataset, config);
1356 loader.reset();
1357
1358 while loader.next_batch().is_some() {}
1360
1361 let stats = loader.stats();
1362 assert_eq!(stats.cache_misses, 10);
1363 assert_eq!(stats.cache_hits, 0);
1364 }
1365
1366 #[test]
1367 fn test_iterator_trait() {
1368 let dataset = create_test_dataset();
1369 let config = OptimizedLoaderConfig {
1370 batch_size: 25,
1371 shuffle: false,
1372 drop_last: true,
1373 ..Default::default()
1374 };
1375
1376 let mut loader = OptimizedDataLoader::new(dataset, config);
1377 loader.reset();
1378
1379 let batches: Vec<_> = loader.collect();
1380 assert_eq!(batches.len(), 4); }
1382
1383 #[test]
1388 fn test_memory_aware_config_default() {
1389 let cfg = MemoryAwareConfig::default();
1390 assert!(
1391 cfg.target_memory_fraction > 0.0 && cfg.target_memory_fraction <= 1.0,
1392 "target_memory_fraction must be in (0, 1]"
1393 );
1394 assert!(cfg.min_batch_size >= 1);
1395 assert!(cfg.max_batch_size >= cfg.min_batch_size);
1396 }
1397
1398 #[test]
1399 fn test_memory_aware_loader_creation() {
1400 let dataset = create_test_dataset();
1401 let config = MemoryAwareConfig {
1402 shuffle: false,
1403 drop_last: false,
1404 ..Default::default()
1405 };
1406
1407 let loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
1408 .expect("loader creation must succeed");
1409
1410 let bs = loader.adaptive_batch_size();
1412 assert!(bs >= 4, "batch_size ({bs}) must be >= min_batch_size (4)");
1413 assert!(
1414 bs <= 4096,
1415 "batch_size ({bs}) must be <= max_batch_size (4096)"
1416 );
1417 assert!(loader.num_batches() >= 1);
1419 assert_eq!(loader.len(), 100);
1420 assert!(!loader.is_empty());
1421 }
1422
1423 #[test]
1424 fn test_memory_aware_loader_iteration_all_samples() {
1425 let dataset = create_test_dataset();
1426 let config = MemoryAwareConfig {
1427 shuffle: false,
1428 drop_last: false,
1429 min_batch_size: 10,
1430 max_batch_size: 10,
1431 target_memory_fraction: 1.0, ..Default::default()
1433 };
1434
1435 let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
1436 .expect("loader creation must succeed");
1437 loader.reset();
1438
1439 let mut total_samples = 0usize;
1440 let mut batch_count = 0usize;
1441 while let Some(result) = loader.next_batch() {
1442 let (x, _y) = result.expect("batch load must succeed");
1443 total_samples += x.shape()[0];
1444 batch_count += 1;
1445 }
1446
1447 assert_eq!(total_samples, 100, "all 100 samples must be yielded");
1448 assert_eq!(batch_count, 10, "100 samples / batch_size 10 = 10 batches");
1449 }
1450
1451 #[test]
1452 fn test_memory_aware_loader_drop_last() {
1453 let dataset = create_test_dataset();
1456 let config = MemoryAwareConfig {
1457 shuffle: false,
1458 drop_last: true,
1459 min_batch_size: 32,
1460 max_batch_size: 32,
1461 target_memory_fraction: 1.0,
1462 ..Default::default()
1463 };
1464
1465 let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
1466 .expect("loader creation must succeed");
1467 loader.reset();
1468
1469 let batches: Vec<_> = loader.collect();
1470 assert_eq!(batches.len(), 3, "drop_last: 100/32 = 3 full batches");
1471 }
1472
1473 #[test]
1474 fn test_memory_aware_loader_refresh_batch_size() {
1475 let dataset = create_test_dataset();
1476 let config = MemoryAwareConfig {
1477 shuffle: false,
1478 ..Default::default()
1479 };
1480
1481 let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
1482 .expect("loader creation must succeed");
1483
1484 let new_bs = loader.refresh_batch_size().expect("refresh must succeed");
1485 assert!(new_bs >= loader.config.min_batch_size);
1486 assert!(new_bs <= loader.config.max_batch_size);
1487 assert_eq!(new_bs, loader.adaptive_batch_size());
1488 }
1489
1490 #[test]
1491 fn test_memory_aware_loader_stats() {
1492 let dataset = create_test_dataset();
1493 let config = MemoryAwareConfig {
1494 shuffle: false,
1495 drop_last: false,
1496 min_batch_size: 10,
1497 max_batch_size: 10,
1498 target_memory_fraction: 1.0,
1499 ..Default::default()
1500 };
1501
1502 let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
1503 .expect("loader creation must succeed");
1504 loader.reset();
1505
1506 while loader.next_batch().is_some() {}
1507
1508 let stats = loader.stats();
1509 assert_eq!(stats.batches_loaded, 10);
1510 assert_eq!(stats.samples_loaded, 100);
1511 }
1512
1513 #[test]
1514 fn test_memory_aware_prefetch_iter() {
1515 let dataset = create_test_dataset();
1516 let config = MemoryAwareConfig {
1517 shuffle: false,
1518 drop_last: false,
1519 min_batch_size: 10,
1520 max_batch_size: 10,
1521 target_memory_fraction: 1.0,
1522 prefetch_ahead: 2,
1523 ..Default::default()
1524 };
1525
1526 let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
1527 .expect("loader creation must succeed");
1528 loader.reset();
1529
1530 let iter = loader.into_prefetch_iter();
1531 let batches: Vec<_> = iter.collect();
1532
1533 for batch_result in &batches {
1535 let (x, _y) = batch_result
1536 .as_ref()
1537 .expect("prefetch batch must not be an error");
1538 assert_eq!(x.shape()[0], 10);
1539 }
1540 assert_eq!(batches.len(), 10);
1541 }
1542
1543 #[test]
1544 fn test_estimate_available_memory_is_positive() {
1545 let mem = estimate_available_memory_bytes();
1546 assert!(mem > 0, "available memory estimate must be > 0");
1547 }
1548
1549 #[test]
1550 fn test_compute_adaptive_batch_size_bounds() {
1551 let config = MemoryAwareConfig {
1552 min_batch_size: 8,
1553 max_batch_size: 64,
1554 target_memory_fraction: 0.1,
1555 bytes_per_sample: Some(1024),
1556 ..Default::default()
1557 };
1558 let bs = compute_adaptive_batch_size(1000, 1024, &config);
1559 assert!(bs >= 8, "must respect min_batch_size");
1560 assert!(bs <= 64, "must respect max_batch_size");
1561 }
1562}