// scirs2_neural/training/optimized_dataloader.rs
//! Optimized data loading pipeline with prefetching and parallel loading
//!
//! This module provides an optimized data loading pipeline with:
//! - Prefetching for overlapping data loading and computation
//! - Parallel batch loading with configurable worker threads
//! - Memory-efficient batch caching
//! - Automatic batch size optimization
use crate::data::Dataset;
use crate::error::{NeuralError, Result};
use scirs2_core::chunking::{ChunkConfig, ChunkStrategy, ChunkingUtils};
use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::NumAssign;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
23
/// Type alias for a batch pair of input and target arrays: `(inputs, targets)`.
type BatchPair<F> = (Array<F, IxDyn>, Array<F, IxDyn>);
26
27// =============================================================================
28// Configuration
29// =============================================================================
30
/// Configuration for optimized data loading
#[derive(Debug, Clone)]
pub struct OptimizedLoaderConfig {
    /// Number of samples per batch
    pub batch_size: usize,
    /// Number of batches to prefetch ahead of the consumer
    pub prefetch_size: usize,
    /// Number of worker threads (0 for single-threaded)
    pub num_workers: usize,
    /// Whether to drop the last incomplete batch
    pub drop_last: bool,
    /// Whether to shuffle sample indices each epoch
    pub shuffle: bool,
    /// Pin memory for faster GPU transfer (placeholder for future GPU support)
    pub pin_memory: bool,
    /// Cache loaded batches in memory so later epochs can reuse them
    pub cache_batches: bool,
    /// Maximum memory for cache (in bytes, 0 for unlimited)
    // NOTE(review): this limit is not currently enforced by `BatchCache` — confirm
    pub max_cache_memory: usize,
}
51
impl Default for OptimizedLoaderConfig {
    /// Defaults: batch size 32, two prefetched batches, single-threaded
    /// loading, keep the final partial batch, shuffle every epoch, no memory
    /// pinning, no batch caching.
    fn default() -> Self {
        Self {
            batch_size: 32,
            prefetch_size: 2,
            num_workers: 0, // 0 = load on the calling thread
            drop_last: false,
            shuffle: true,
            pin_memory: false, // placeholder; no GPU transfer yet
            cache_batches: false,
            max_cache_memory: 0, // 0 = unlimited
        }
    }
}
66
/// Statistics for data loading performance
#[derive(Debug, Clone, Default)]
pub struct LoadingStats {
    /// Total batches loaded from the dataset (cache hits excluded)
    pub batches_loaded: usize,
    /// Total samples loaded from the dataset
    pub samples_loaded: usize,
    /// Total time spent loading batches from the dataset
    pub total_load_time: Duration,
    /// Average per-batch load time (`total_load_time / batches_loaded`)
    pub avg_batch_time: Duration,
    /// Number of batches served from the in-memory cache
    pub cache_hits: usize,
    /// Number of batches that had to be loaded from the dataset
    pub cache_misses: usize,
    /// Cumulative time consumers spent waiting on the prefetch queue
    pub prefetch_wait_time: Duration,
}
85
86// =============================================================================
87// Batch Result Type
88// =============================================================================
89
/// Type alias for a batch result: `Ok((inputs, targets))` or a loading error
pub type BatchResult<F> = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
92
93// =============================================================================
94// Batch Cache
95// =============================================================================
96
/// Cache for storing loaded batches, addressed by batch index
struct BatchCache<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> {
    /// Cached batches by batch index (`None` = slot not yet filled)
    cache: Vec<Option<BatchPair<F>>>,
    /// Maximum number of cached batches (fixed at construction)
    max_batches: usize,
    /// Running estimate of cached memory, in bytes
    memory_usage: usize,
}
106
107impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> BatchCache<F> {
108    fn new(max_batches: usize) -> Self {
109        Self {
110            cache: vec![None; max_batches],
111            max_batches,
112            memory_usage: 0,
113        }
114    }
115
116    fn get(&self, index: usize) -> Option<&BatchPair<F>> {
117        if index < self.cache.len() {
118            self.cache[index].as_ref()
119        } else {
120            None
121        }
122    }
123
124    fn insert(&mut self, index: usize, batch: BatchPair<F>) {
125        if index < self.cache.len() {
126            let batch_size = estimate_array_memory(&batch.0) + estimate_array_memory(&batch.1);
127            self.memory_usage += batch_size;
128            self.cache[index] = Some(batch);
129        }
130    }
131
132    fn clear(&mut self) {
133        self.cache.iter_mut().for_each(|b| *b = None);
134        self.memory_usage = 0;
135    }
136}
137
138/// Estimate memory usage of an array
139fn estimate_array_memory<F: Float + NumAssign>(array: &Array<F, IxDyn>) -> usize {
140    array.len() * std::mem::size_of::<F>()
141}
142
143// =============================================================================
144// Prefetch Queue
145// =============================================================================
146
/// Thread-safe bounded queue carrying prefetched batches from the worker
/// thread to the consumer, each tagged with its batch index.
struct PrefetchQueue<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> {
    /// FIFO of `(batch_index, batch_result)` pairs
    queue: Mutex<VecDeque<(usize, BatchResult<F>)>>,
    /// Maximum number of queued batches before producers block
    max_size: usize,
    /// Current queue length, mirrored atomically so it can be read without
    /// taking the mutex
    size: AtomicUsize,
    /// Set to true to make producers stop pushing and terminate
    stop: AtomicBool,
}
158
impl<F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync> PrefetchQueue<F> {
    /// Create an empty queue holding at most `max_size` batches.
    fn new(max_size: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::with_capacity(max_size)),
            max_size,
            size: AtomicUsize::new(0),
            stop: AtomicBool::new(false),
        }
    }

    /// Push a batch (tagged with its index) onto the queue.
    ///
    /// Busy-waits in 100 µs sleeps while the queue is at capacity.  Returns
    /// `false` when the queue has been stopped or its mutex is poisoned, in
    /// which case the producer should terminate.
    ///
    /// NOTE(review): the capacity check happens outside the lock, so with
    /// multiple producers the queue could momentarily exceed `max_size`;
    /// the current design spawns a single producer — confirm before adding
    /// more workers.
    fn push(&self, index: usize, batch: BatchResult<F>) -> bool {
        if self.stop.load(Ordering::Relaxed) {
            return false;
        }

        // Wait if queue is full (re-checking the stop flag each pass).
        while self.size.load(Ordering::Relaxed) >= self.max_size {
            if self.stop.load(Ordering::Relaxed) {
                return false;
            }
            thread::sleep(Duration::from_micros(100));
        }

        let mut queue = match self.queue.lock() {
            Ok(q) => q,
            Err(_) => return false, // poisoned mutex: give up silently
        };
        queue.push_back((index, batch));
        self.size.fetch_add(1, Ordering::Relaxed);
        true
    }

    /// Pop the oldest queued batch; `None` when the queue is empty or its
    /// mutex is poisoned.
    fn pop(&self) -> Option<(usize, BatchResult<F>)> {
        let mut queue = match self.queue.lock() {
            Ok(q) => q,
            Err(_) => return None,
        };
        let result = queue.pop_front();
        if result.is_some() {
            // Keep the atomic mirror in sync while still holding the lock.
            self.size.fetch_sub(1, Ordering::Relaxed);
        }
        result
    }

    /// Signal producers to stop; `push` will return `false` from now on.
    fn stop(&self) {
        self.stop.store(true, Ordering::Relaxed);
    }

    /// Whether the queue currently holds no batches.
    fn is_empty(&self) -> bool {
        self.size.load(Ordering::Relaxed) == 0
    }
}
211
212// =============================================================================
213// Optimized Data Loader
214// =============================================================================
215
/// Optimized data loader with prefetching and parallel loading
pub struct OptimizedDataLoader<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// The underlying dataset (shared with any prefetch worker via `Arc`)
    dataset: Arc<D>,
    /// Loader configuration
    config: OptimizedLoaderConfig,
    /// Sample indices for the current epoch (shuffled by `reset` if enabled)
    indices: Vec<usize>,
    /// Next batch index to hand out; atomic so `next_batch` can take `&self`
    position: AtomicUsize,
    /// Total number of batches per epoch
    num_batches: usize,
    /// Optional per-batch cache (present only when `config.cache_batches`)
    cache: Option<Mutex<BatchCache<F>>>,
    /// Loading performance statistics
    stats: Mutex<LoadingStats>,
    /// Phantom marker for `F` (F already appears in `cache`/`stats`, so this
    /// is kept for explicitness rather than necessity)
    _phantom: PhantomData<F>,
}
238
239impl<
240        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
241        D: Dataset<F> + Send + Sync + Clone + 'static,
242    > OptimizedDataLoader<F, D>
243{
244    /// Create a new optimized data loader
245    pub fn new(dataset: D, config: OptimizedLoaderConfig) -> Self {
246        let dataset_len = dataset.len();
247        let batch_size = config.batch_size;
248        let drop_last = config.drop_last;
249
250        let num_batches = if drop_last {
251            dataset_len / batch_size
252        } else {
253            dataset_len.div_ceil(batch_size)
254        };
255
256        let indices: Vec<usize> = (0..dataset_len).collect();
257
258        let cache = if config.cache_batches {
259            Some(Mutex::new(BatchCache::new(num_batches)))
260        } else {
261            None
262        };
263
264        Self {
265            dataset: Arc::new(dataset),
266            config,
267            indices,
268            position: AtomicUsize::new(0),
269            num_batches,
270            cache,
271            stats: Mutex::new(LoadingStats::default()),
272            _phantom: PhantomData,
273        }
274    }
275
276    /// Reset the loader for a new epoch
277    pub fn reset(&mut self) {
278        if self.config.shuffle {
279            let mut rng = scirs2_core::random::rng();
280            self.indices.shuffle(&mut rng);
281        }
282        self.position.store(0, Ordering::Relaxed);
283    }
284
285    /// Get the number of batches
286    pub fn num_batches(&self) -> usize {
287        self.num_batches
288    }
289
290    /// Get the dataset length
291    pub fn len(&self) -> usize {
292        self.dataset.len()
293    }
294
295    /// Check if the loader is empty
296    pub fn is_empty(&self) -> bool {
297        self.len() == 0
298    }
299
300    /// Get loading statistics
301    pub fn stats(&self) -> LoadingStats {
302        self.stats
303            .lock()
304            .map_or_else(|_| LoadingStats::default(), |s| s.clone())
305    }
306
307    /// Load a single batch
308    fn load_batch(&self, batch_idx: usize) -> BatchResult<F> {
309        let start = batch_idx * self.config.batch_size;
310        let end = (start + self.config.batch_size).min(self.indices.len());
311
312        if start >= self.indices.len() {
313            return Err(NeuralError::TrainingError(
314                "Batch index out of range".to_string(),
315            ));
316        }
317
318        let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
319
320        if batch_indices.is_empty() {
321            return Err(NeuralError::TrainingError("Empty batch".to_string()));
322        }
323
324        // Load first sample to determine shapes
325        let (first_x, first_y) = self.dataset.get(batch_indices[0])?;
326
327        // Create batch arrays
328        let batch_x_shape: Vec<usize> = std::iter::once(batch_indices.len())
329            .chain(first_x.shape().iter().copied())
330            .collect();
331        let batch_y_shape: Vec<usize> = std::iter::once(batch_indices.len())
332            .chain(first_y.shape().iter().copied())
333            .collect();
334
335        let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
336        let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
337
338        // Fill batch arrays
339        for (i, &idx) in batch_indices.iter().enumerate() {
340            let (x, y) = self.dataset.get(idx)?;
341
342            // Copy data into batch arrays
343            let mut batch_x_slice = batch_x.slice_mut(scirs2_core::ndarray::s![i, ..]);
344            batch_x_slice.assign(&x);
345
346            let mut batch_y_slice = batch_y.slice_mut(scirs2_core::ndarray::s![i, ..]);
347            batch_y_slice.assign(&y);
348        }
349
350        Ok((batch_x, batch_y))
351    }
352
353    /// Get the next batch
354    pub fn next_batch(&self) -> Option<BatchResult<F>> {
355        let batch_idx = self.position.fetch_add(1, Ordering::Relaxed);
356
357        if batch_idx >= self.num_batches {
358            return None;
359        }
360
361        // Check cache first
362        if let Some(ref cache) = self.cache {
363            if let Ok(cache_guard) = cache.lock() {
364                if let Some(batch) = cache_guard.get(batch_idx) {
365                    if let Ok(mut stats) = self.stats.lock() {
366                        stats.cache_hits += 1;
367                    }
368                    return Some(Ok((batch.0.clone(), batch.1.clone())));
369                }
370            }
371        }
372
373        // Load batch
374        let start = Instant::now();
375        let result = self.load_batch(batch_idx);
376        let load_time = start.elapsed();
377
378        // Update statistics
379        if let Ok(mut stats) = self.stats.lock() {
380            stats.batches_loaded += 1;
381            stats.samples_loaded += self.config.batch_size.min(
382                self.indices
383                    .len()
384                    .saturating_sub(batch_idx * self.config.batch_size),
385            );
386            stats.total_load_time += load_time;
387            stats.avg_batch_time = stats.total_load_time / stats.batches_loaded as u32;
388            stats.cache_misses += 1;
389        }
390
391        // Cache the result if enabled
392        if let Some(ref cache) = self.cache {
393            if let Ok(ref batch) = result {
394                if let Ok(mut cache_guard) = cache.lock() {
395                    cache_guard.insert(batch_idx, (batch.0.clone(), batch.1.clone()));
396                }
397            }
398        }
399
400        Some(result)
401    }
402
403    /// Create a prefetching iterator
404    pub fn prefetch_iter(self) -> PrefetchingIterator<F, D> {
405        PrefetchingIterator::new(self)
406    }
407}
408
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > Iterator for OptimizedDataLoader<F, D>
{
    type Item = BatchResult<F>;

    /// Yield the next batch; `None` once the epoch's batches are exhausted
    /// (call `reset` to start a new epoch).
    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}
420
421// =============================================================================
422// Prefetching Iterator
423// =============================================================================
424
/// Iterator that prefetches batches in the background
pub struct PrefetchingIterator<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// The underlying loader, shared with the worker thread
    loader: Arc<OptimizedDataLoader<F, D>>,
    /// Bounded queue the worker pushes loaded batches into
    queue: Arc<PrefetchQueue<F>>,
    /// Handle of the background worker thread (taken and joined on drop)
    worker_handle: Option<thread::JoinHandle<()>>,
    /// Next batch index the consumer expects to deliver in order
    expected_idx: usize,
    /// Batches popped from the queue ahead of their turn, held until
    /// `expected_idx` catches up
    buffer: VecDeque<(usize, BatchResult<F>)>,
}
441
442impl<
443        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
444        D: Dataset<F> + Send + Sync + Clone + 'static,
445    > PrefetchingIterator<F, D>
446{
447    /// Create a new prefetching iterator
448    fn new(loader: OptimizedDataLoader<F, D>) -> Self {
449        let prefetch_size = loader.config.prefetch_size;
450        let loader = Arc::new(loader);
451        let queue = Arc::new(PrefetchQueue::new(prefetch_size));
452
453        // Start prefetch worker
454        let worker_loader = Arc::clone(&loader);
455        let worker_queue = Arc::clone(&queue);
456
457        let worker_handle = thread::spawn(move || {
458            let mut batch_idx = 0;
459            loop {
460                if worker_queue.stop.load(Ordering::Relaxed) {
461                    break;
462                }
463
464                if batch_idx >= worker_loader.num_batches {
465                    break;
466                }
467
468                let result = worker_loader.load_batch(batch_idx);
469                if !worker_queue.push(batch_idx, result) {
470                    break;
471                }
472                batch_idx += 1;
473            }
474        });
475
476        Self {
477            loader,
478            queue,
479            worker_handle: Some(worker_handle),
480            expected_idx: 0,
481            buffer: VecDeque::new(),
482        }
483    }
484}
485
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > Iterator for PrefetchingIterator<F, D>
{
    type Item = BatchResult<F>;

    /// Deliver batches strictly in index order, waiting on the prefetch
    /// queue as needed; batches popped ahead of their turn are buffered.
    fn next(&mut self) -> Option<Self::Item> {
        // All batches for the epoch have been delivered.
        if self.expected_idx >= self.loader.num_batches {
            return None;
        }

        // Check the out-of-order buffer first: the expected batch may have
        // been popped earlier while waiting for a previous one.
        if let Some(pos) = self
            .buffer
            .iter()
            .position(|(idx, _)| *idx == self.expected_idx)
        {
            let (_, result) = self.buffer.remove(pos).expect("Position was just found");
            self.expected_idx += 1;
            return Some(result);
        }

        // Wait for the expected batch from the prefetch queue.
        // NOTE(review): if the worker thread dies without setting the stop
        // flag this loop would spin indefinitely; the current worker only
        // exits after finishing all batches or observing stop — confirm if
        // the worker is ever changed to propagate panics.
        let wait_start = Instant::now();
        loop {
            if let Some((idx, result)) = self.queue.pop() {
                if idx == self.expected_idx {
                    self.expected_idx += 1;

                    // Record how long the consumer stalled on prefetching.
                    if let Ok(mut stats) = self.loader.stats.lock() {
                        stats.prefetch_wait_time += wait_start.elapsed();
                    }

                    return Some(result);
                } else {
                    // Buffer out-of-order batches for later delivery.
                    self.buffer.push_back((idx, result));
                }
            } else if self.queue.is_empty() && self.queue.stop.load(Ordering::Relaxed) {
                // Producer stopped and nothing is queued: no more batches.
                return None;
            } else {
                // Brief backoff while the producer catches up.
                thread::sleep(Duration::from_micros(10));
            }
        }
    }
}
536
impl<
        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
        D: Dataset<F> + Send + Sync + Clone + 'static,
    > Drop for PrefetchingIterator<F, D>
{
    /// Signal the producer to stop and join it so the worker thread never
    /// outlives the iterator; a join error (worker panic) is ignored.
    fn drop(&mut self) {
        self.queue.stop();
        if let Some(handle) = self.worker_handle.take() {
            let _ = handle.join();
        }
    }
}
549
550// =============================================================================
551// Automatic Batch Size Optimization
552// =============================================================================
553
/// Result of batch size optimization
#[derive(Debug, Clone)]
pub struct BatchSizeOptimizationResult {
    /// Batch size with the highest measured throughput
    pub recommended_batch_size: usize,
    /// `(batch_size, samples_per_second)` for each tested batch size
    pub throughput_results: Vec<(usize, f64)>,
    /// `(batch_size, avg_batch_memory_bytes)` for each tested batch size
    pub memory_results: Vec<(usize, usize)>,
    /// Whether the search stopped early because the memory limit was hit
    pub memory_limited: bool,
}
566
/// Optimizer for finding the best batch size
pub struct BatchSizeOptimizer {
    /// Minimum batch size to test (search starts here)
    min_batch_size: usize,
    /// Maximum batch size to test (search doubles up to this cap)
    max_batch_size: usize,
    /// Number of warmup batches loaded before timing begins
    warmup_batches: usize,
    /// Number of batches timed per candidate batch size
    timing_batches: usize,
    /// Maximum memory per batch (bytes, 0 for no limit)
    max_memory: usize,
}
580
impl Default for BatchSizeOptimizer {
    /// Defaults: test batch sizes 8..=512 (doubling), 2 warmup batches,
    /// 5 timed batches, no memory limit.
    fn default() -> Self {
        Self {
            min_batch_size: 8,
            max_batch_size: 512,
            warmup_batches: 2,
            timing_batches: 5,
            max_memory: 0, // 0 = unlimited
        }
    }
}
592
593impl BatchSizeOptimizer {
594    /// Create a new batch size optimizer
595    pub fn new() -> Self {
596        Self::default()
597    }
598
599    /// Set the batch size range to test
600    pub fn with_range(mut self, min: usize, max: usize) -> Self {
601        self.min_batch_size = min;
602        self.max_batch_size = max;
603        self
604    }
605
606    /// Set the maximum memory limit
607    pub fn with_max_memory(mut self, max_memory: usize) -> Self {
608        self.max_memory = max_memory;
609        self
610    }
611
612    /// Find the optimal batch size for a dataset
613    pub fn find_optimal<
614        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
615        D: Dataset<F> + Send + Sync + Clone + 'static,
616    >(
617        &self,
618        dataset: D,
619    ) -> Result<BatchSizeOptimizationResult> {
620        let mut throughput_results = Vec::new();
621        let mut memory_results = Vec::new();
622        let mut best_throughput = 0.0;
623        let mut best_batch_size = self.min_batch_size;
624        let mut memory_limited = false;
625
626        let mut batch_size = self.min_batch_size;
627
628        while batch_size <= self.max_batch_size && batch_size <= dataset.len() {
629            let config = OptimizedLoaderConfig {
630                batch_size,
631                shuffle: false,
632                drop_last: true,
633                ..Default::default()
634            };
635
636            let mut loader = OptimizedDataLoader::new(dataset.clone(), config);
637            loader.reset();
638
639            // Warmup
640            for _ in 0..self.warmup_batches {
641                if loader.next_batch().is_none() {
642                    break;
643                }
644            }
645
646            // Timing
647            let start = Instant::now();
648            let mut batches_processed = 0;
649            let mut total_memory = 0;
650
651            for _ in 0..self.timing_batches {
652                match loader.next_batch() {
653                    Some(Ok((x, y))) => {
654                        batches_processed += 1;
655                        total_memory += estimate_array_memory(&x) + estimate_array_memory(&y);
656                    }
657                    Some(Err(_)) => break,
658                    None => break,
659                }
660            }
661
662            if batches_processed == 0 {
663                break;
664            }
665
666            let elapsed = start.elapsed().as_secs_f64();
667            let samples_per_second = (batches_processed * batch_size) as f64 / elapsed;
668            let avg_memory = total_memory / batches_processed;
669
670            throughput_results.push((batch_size, samples_per_second));
671            memory_results.push((batch_size, avg_memory));
672
673            // Check memory limit
674            if self.max_memory > 0 && avg_memory > self.max_memory {
675                memory_limited = true;
676                break;
677            }
678
679            if samples_per_second > best_throughput {
680                best_throughput = samples_per_second;
681                best_batch_size = batch_size;
682            }
683
684            // Increase batch size
685            batch_size = (batch_size * 2).min(self.max_batch_size + 1);
686        }
687
688        Ok(BatchSizeOptimizationResult {
689            recommended_batch_size: best_batch_size,
690            throughput_results,
691            memory_results,
692            memory_limited,
693        })
694    }
695}
696
697// =============================================================================
698// Memory-Aware Data Loader (Phase 7.1 — scirs2-core::chunking integration)
699// =============================================================================
700
/// Configuration for memory-aware data loading.
///
/// Controls how the loader queries available memory and selects chunk/batch
/// sizes to avoid pressure on the allocator during training.
#[derive(Debug, Clone)]
pub struct MemoryAwareConfig {
    /// Target fraction of estimated available system memory to use per batch.
    /// Must be in (0.0, 1.0].  Defaults to 0.25 (use ≤ 25 % of available RAM
    /// per batch so that forward + backward passes have head-room).
    pub target_memory_fraction: f64,
    /// Per-sample byte count used for sizing calculations.  Set this to the
    /// actual element count × `size_of::<F>()` for your dataset.  If `None`,
    /// the loader will query the first sample at construction time.
    pub bytes_per_sample: Option<usize>,
    /// Hard lower bound on batch size.
    pub min_batch_size: usize,
    /// Hard upper bound on batch size.
    pub max_batch_size: usize,
    /// Whether to shuffle indices each epoch.
    pub shuffle: bool,
    /// Whether to drop the final incomplete batch.
    pub drop_last: bool,
    /// Number of batches to keep prefetched in the background queue.
    pub prefetch_ahead: usize,
}
726
impl Default for MemoryAwareConfig {
    /// Defaults: use ≤ 25 % of available RAM per batch, probe the dataset for
    /// the per-sample byte cost, bound batches to 4..=4096, shuffle each
    /// epoch, keep the final partial batch, prefetch 2 batches ahead.
    fn default() -> Self {
        Self {
            target_memory_fraction: 0.25,
            bytes_per_sample: None, // probe the first sample at construction
            min_batch_size: 4,
            max_batch_size: 4096,
            shuffle: true,
            drop_last: false,
            prefetch_ahead: 2,
        }
    }
}
740
/// Estimate available system memory in bytes using a conservative heuristic.
///
/// We do not depend on any OS-specific crate here; instead we read
/// `/proc/meminfo` on Linux and fall back to a safe 512 MiB constant on other
/// platforms.  This keeps the crate 100 % pure-Rust and cross-platform.
fn estimate_available_memory_bytes() -> usize {
    // Attempt Linux /proc/meminfo first (most accurate without extra deps).
    #[cfg(target_os = "linux")]
    {
        if let Ok(contents) = std::fs::read_to_string("/proc/meminfo") {
            // "MemAvailable" accounts for reclaimable pages and is a better
            // predictor of usable memory than "MemFree".  The line format is
            // `MemAvailable:   12345678 kB`.
            let parsed = contents.lines().find_map(|line| {
                line.strip_prefix("MemAvailable:")?
                    .split_whitespace()
                    .next()?
                    .parse::<usize>()
                    .ok()
            });
            if let Some(kb) = parsed {
                // saturating_mul guards against overflow on 32-bit targets.
                return kb.saturating_mul(1024);
            }
        }
    }
    // Conservative fallback: assume 512 MiB is available.
    512 * 1024 * 1024
}
768
769/// Compute a batch size that respects the `target_memory_fraction` and the
770/// bounds in `config`, using `ChunkingUtils::optimal_chunk_size` from
771/// `scirs2-core::chunking` as a starting point for the element-count hint.
772///
773/// # Arguments
774/// * `dataset_len`       – number of samples in the dataset.
775/// * `bytes_per_sample`  – estimated byte cost of one (input + label) sample.
776/// * `config`            – `MemoryAwareConfig` with fraction and bounds.
777fn compute_adaptive_batch_size(
778    dataset_len: usize,
779    bytes_per_sample: usize,
780    config: &MemoryAwareConfig,
781) -> usize {
782    // ── Step 1: ask ChunkingUtils for a purely data-size-driven hint ─────────
783    // We use Adaptive strategy with bounds derived from the memory config so
784    // that the core chunking logic (CPU count, work-stealing, etc.) still
785    // participates in the decision.
786    let chunk_cfg = ChunkConfig {
787        strategy: ChunkStrategy::Adaptive,
788        min_chunk_size: config.min_batch_size,
789        max_chunk_size: config.max_batch_size,
790        ..ChunkConfig::default()
791    };
792    let chunking_hint = ChunkingUtils::optimal_chunk_size(dataset_len, &chunk_cfg);
793
794    // ── Step 2: derive a memory-budget-constrained upper bound ───────────────
795    let available = estimate_available_memory_bytes();
796    // How many bytes may we use for a single batch?
797    let budget_bytes = ((available as f64) * config.target_memory_fraction) as usize;
798    // Convert to sample count, guarding against division by zero.
799    let budget_samples = budget_bytes
800        .checked_div(bytes_per_sample)
801        .map(|v| v.max(1))
802        .unwrap_or(config.max_batch_size);
803
804    // ── Step 3: reconcile the two hints ──────────────────────────────────────
805    // Take the more conservative (smaller) of the two estimates, then clamp to
806    // the configured bounds.
807    let raw = chunking_hint.min(budget_samples);
808    raw.max(config.min_batch_size).min(config.max_batch_size)
809}
810
/// A data loader that automatically sizes its batches based on available system
/// memory and the `scirs2-core::chunking` adaptive strategy.
///
/// `MemoryAwareDataLoader` wraps an existing `Dataset` and selects a batch size
/// at construction time (and can recompute it at epoch boundaries) so that the
/// training process stays within a configurable fraction of available RAM.
///
/// The loader prefetches the next batch in a background thread while the caller
/// consumes the current one, overlapping I/O and computation.
///
/// # Type Parameters
/// * `F` – floating-point element type (e.g. `f32` or `f64`).
/// * `D` – the underlying dataset type implementing [`crate::data::Dataset`].
pub struct MemoryAwareDataLoader<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// The underlying dataset (shared with the prefetch thread).
    dataset: Arc<D>,
    /// The configuration supplied by the caller.
    config: MemoryAwareConfig,
    /// Shuffled or sequential sample indices for the current epoch.
    indices: Vec<usize>,
    /// Current read position (batch-level index, not sample-level).
    position: AtomicUsize,
    /// Batch size selected at construction (may be refreshed with
    /// `refresh_batch_size`).
    batch_size: usize,
    /// Derived total number of batches for the current epoch.
    num_batches: usize,
    /// Loading performance statistics.
    stats: Mutex<LoadingStats>,
    /// Phantom marker tying the loader to its element type `F`.
    _phantom: PhantomData<F>,
}
845
846impl<
847        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
848        D: Dataset<F> + Send + Sync + Clone + 'static,
849    > MemoryAwareDataLoader<F, D>
850{
851    /// Create a new `MemoryAwareDataLoader` with an automatically selected
852    /// batch size.
853    ///
854    /// The batch size is computed once at construction from the dataset length,
855    /// the per-sample byte cost, and available system memory.  Call
856    /// [`Self::refresh_batch_size`] at epoch boundaries if you want it to
857    /// re-examine memory pressure.
858    ///
859    /// # Arguments
860    /// * `dataset` – the dataset to load from.
861    /// * `config`  – memory-aware loading configuration.
862    pub fn new_adaptive(dataset: D, config: MemoryAwareConfig) -> Result<Self> {
863        let dataset_len = dataset.len();
864        if dataset_len == 0 {
865            return Err(NeuralError::TrainingError(
866                "Cannot create MemoryAwareDataLoader from an empty dataset".to_string(),
867            ));
868        }
869
870        // Resolve bytes_per_sample: use the caller's hint or probe the dataset.
871        let bytes_per_sample = match config.bytes_per_sample {
872            Some(b) => b,
873            None => {
874                // Peek at the first sample to determine array shapes.
875                let (x0, y0) = dataset.get(0)?;
876                (x0.len() + y0.len()) * std::mem::size_of::<F>()
877            }
878        };
879
880        let batch_size = compute_adaptive_batch_size(dataset_len, bytes_per_sample, &config);
881
882        let num_batches = if config.drop_last {
883            dataset_len / batch_size
884        } else {
885            dataset_len.div_ceil(batch_size)
886        };
887
888        let indices: Vec<usize> = (0..dataset_len).collect();
889
890        Ok(Self {
891            dataset: Arc::new(dataset),
892            config,
893            indices,
894            position: AtomicUsize::new(0),
895            batch_size,
896            num_batches,
897            stats: Mutex::new(LoadingStats::default()),
898            _phantom: PhantomData,
899        })
900    }
901
902    /// Recompute and update the batch size based on the current system memory
903    /// state.  Call this between epochs to react to changing memory pressure
904    /// (e.g. other processes using RAM).
905    ///
906    /// Returns the new batch size.
907    pub fn refresh_batch_size(&mut self) -> Result<usize> {
908        let dataset_len = self.dataset.len();
909        let bytes_per_sample = match self.config.bytes_per_sample {
910            Some(b) => b,
911            None => {
912                let (x0, y0) = self.dataset.get(0)?;
913                (x0.len() + y0.len()) * std::mem::size_of::<F>()
914            }
915        };
916
917        let new_batch_size =
918            compute_adaptive_batch_size(dataset_len, bytes_per_sample, &self.config);
919        self.batch_size = new_batch_size;
920        self.num_batches = if self.config.drop_last {
921            dataset_len / new_batch_size
922        } else {
923            dataset_len.div_ceil(new_batch_size)
924        };
925        Ok(new_batch_size)
926    }
927
928    /// Returns the batch size currently in use.
929    pub fn adaptive_batch_size(&self) -> usize {
930        self.batch_size
931    }
932
933    /// Returns the number of batches in the current epoch configuration.
934    pub fn num_batches(&self) -> usize {
935        self.num_batches
936    }
937
938    /// Returns the number of samples in the dataset.
939    pub fn len(&self) -> usize {
940        self.dataset.len()
941    }
942
943    /// Returns `true` if the dataset is empty.
944    pub fn is_empty(&self) -> bool {
945        self.dataset.len() == 0
946    }
947
948    /// Returns a snapshot of loading performance statistics.
949    pub fn stats(&self) -> LoadingStats {
950        self.stats
951            .lock()
952            .map_or_else(|_| LoadingStats::default(), |s| s.clone())
953    }
954
955    /// Reset state for the beginning of a new epoch (optionally shuffling
956    /// indices).  Does *not* refresh the batch size; call
957    /// [`Self::refresh_batch_size`] explicitly if desired.
958    pub fn reset(&mut self) {
959        if self.config.shuffle {
960            let mut rng = scirs2_core::random::rng();
961            self.indices.shuffle(&mut rng);
962        }
963        self.position.store(0, Ordering::Relaxed);
964    }
965
966    /// Load a single batch by batch index.  This is the core loading routine
967    /// shared between `next_batch` and the prefetch worker.
968    fn load_batch(&self, batch_idx: usize) -> BatchResult<F> {
969        let start = batch_idx * self.batch_size;
970        let end = (start + self.batch_size).min(self.indices.len());
971
972        if start >= self.indices.len() {
973            return Err(NeuralError::TrainingError(
974                "Batch index out of range".to_string(),
975            ));
976        }
977
978        let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
979
980        if batch_indices.is_empty() {
981            return Err(NeuralError::TrainingError("Empty batch".to_string()));
982        }
983
984        // Probe the first sample to determine array shapes.
985        let (first_x, first_y) = self.dataset.get(batch_indices[0])?;
986
987        let batch_x_shape: Vec<usize> = std::iter::once(batch_indices.len())
988            .chain(first_x.shape().iter().copied())
989            .collect();
990        let batch_y_shape: Vec<usize> = std::iter::once(batch_indices.len())
991            .chain(first_y.shape().iter().copied())
992            .collect();
993
994        let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
995        let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
996
997        for (i, &idx) in batch_indices.iter().enumerate() {
998            let (x, y) = self.dataset.get(idx)?;
999            let mut sx = batch_x.slice_mut(scirs2_core::ndarray::s![i, ..]);
1000            sx.assign(&x);
1001            let mut sy = batch_y.slice_mut(scirs2_core::ndarray::s![i, ..]);
1002            sy.assign(&y);
1003        }
1004
1005        Ok((batch_x, batch_y))
1006    }
1007
1008    /// Fetch the next batch sequentially (no background prefetch).  Returns
1009    /// `None` once all batches for the current epoch have been consumed.
1010    pub fn next_batch(&self) -> Option<BatchResult<F>> {
1011        let batch_idx = self.position.fetch_add(1, Ordering::Relaxed);
1012        if batch_idx >= self.num_batches {
1013            return None;
1014        }
1015
1016        let start_time = Instant::now();
1017        let result = self.load_batch(batch_idx);
1018        let elapsed = start_time.elapsed();
1019
1020        if let Ok(mut stats) = self.stats.lock() {
1021            stats.batches_loaded += 1;
1022            stats.samples_loaded += self.batch_size.min(
1023                self.indices
1024                    .len()
1025                    .saturating_sub(batch_idx * self.batch_size),
1026            );
1027            stats.total_load_time += elapsed;
1028            stats.avg_batch_time = stats.total_load_time / stats.batches_loaded as u32;
1029            stats.cache_misses += 1;
1030        }
1031
1032        Some(result)
1033    }
1034
1035    /// Consume `self` and return a [`MemoryAwarePrefetchIter`] that loads the
1036    /// next batch in a background thread while the caller processes the current
1037    /// one.
1038    pub fn into_prefetch_iter(self) -> MemoryAwarePrefetchIter<F, D> {
1039        MemoryAwarePrefetchIter::new(self)
1040    }
1041}
1042
1043impl<
1044        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
1045        D: Dataset<F> + Send + Sync + Clone + 'static,
1046    > Iterator for MemoryAwareDataLoader<F, D>
1047{
1048    type Item = BatchResult<F>;
1049
1050    fn next(&mut self) -> Option<Self::Item> {
1051        self.next_batch()
1052    }
1053}
1054
1055// =============================================================================
1056// Prefetching iterator for MemoryAwareDataLoader
1057// =============================================================================
1058
/// Background-prefetching iterator produced by
/// [`MemoryAwareDataLoader::into_prefetch_iter`].
///
/// Internally it spawns a single worker thread that fills a bounded queue with
/// pre-loaded batches.  The consumer calls `next()` and receives batches in
/// order; if the worker is faster the consumer never waits, if the consumer is
/// faster it blocks briefly until the worker catches up.
pub struct MemoryAwarePrefetchIter<
    F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    D: Dataset<F> + Send + Sync + Clone + 'static,
> {
    /// Loader shared with the worker thread, which calls its `load_batch`.
    loader: Arc<MemoryAwareDataLoader<F, D>>,
    /// Bounded queue the worker fills and the consumer drains.
    queue: Arc<PrefetchQueue<F>>,
    /// Handle to the worker thread; taken (set to `None`) when joined on drop.
    worker: Option<thread::JoinHandle<()>>,
    /// Index of the next batch the consumer expects to yield.
    expected_idx: usize,
    /// Buffer for batches that arrived out of the expected order.
    out_of_order: VecDeque<(usize, BatchResult<F>)>,
}
1077
1078impl<
1079        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
1080        D: Dataset<F> + Send + Sync + Clone + 'static,
1081    > MemoryAwarePrefetchIter<F, D>
1082{
1083    fn new(loader: MemoryAwareDataLoader<F, D>) -> Self {
1084        let prefetch_ahead = loader.config.prefetch_ahead;
1085        let num_batches = loader.num_batches;
1086        let loader = Arc::new(loader);
1087        let queue = Arc::new(PrefetchQueue::new(prefetch_ahead));
1088
1089        let worker_loader = Arc::clone(&loader);
1090        let worker_queue = Arc::clone(&queue);
1091
1092        let worker = thread::spawn(move || {
1093            for batch_idx in 0..num_batches {
1094                if worker_queue.stop.load(Ordering::Relaxed) {
1095                    break;
1096                }
1097                let result = worker_loader.load_batch(batch_idx);
1098                if !worker_queue.push(batch_idx, result) {
1099                    break;
1100                }
1101            }
1102        });
1103
1104        Self {
1105            loader,
1106            queue,
1107            worker: Some(worker),
1108            expected_idx: 0,
1109            out_of_order: VecDeque::new(),
1110        }
1111    }
1112}
1113
1114impl<
1115        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
1116        D: Dataset<F> + Send + Sync + Clone + 'static,
1117    > Iterator for MemoryAwarePrefetchIter<F, D>
1118{
1119    type Item = BatchResult<F>;
1120
1121    fn next(&mut self) -> Option<Self::Item> {
1122        if self.expected_idx >= self.loader.num_batches {
1123            return None;
1124        }
1125
1126        // Drain the out-of-order buffer first.
1127        if let Some(pos) = self
1128            .out_of_order
1129            .iter()
1130            .position(|(idx, _)| *idx == self.expected_idx)
1131        {
1132            let (_, result) = self
1133                .out_of_order
1134                .remove(pos)
1135                .expect("position was just found in out_of_order buffer");
1136            self.expected_idx += 1;
1137            return Some(result);
1138        }
1139
1140        // Block until the worker delivers the expected batch.
1141        let wait_start = Instant::now();
1142        loop {
1143            if let Some((idx, result)) = self.queue.pop() {
1144                if idx == self.expected_idx {
1145                    if let Ok(mut stats) = self.loader.stats.lock() {
1146                        stats.prefetch_wait_time += wait_start.elapsed();
1147                    }
1148                    self.expected_idx += 1;
1149                    return Some(result);
1150                }
1151                // Not the one we need — stash for later.
1152                self.out_of_order.push_back((idx, result));
1153            } else if self.queue.is_empty() && self.queue.stop.load(Ordering::Relaxed) {
1154                return None;
1155            } else {
1156                thread::sleep(Duration::from_micros(10));
1157            }
1158        }
1159    }
1160}
1161
1162impl<
1163        F: Float + Debug + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
1164        D: Dataset<F> + Send + Sync + Clone + 'static,
1165    > Drop for MemoryAwarePrefetchIter<F, D>
1166{
1167    fn drop(&mut self) {
1168        self.queue.stop();
1169        if let Some(handle) = self.worker.take() {
1170            let _ = handle.join();
1171        }
1172    }
1173}
1174
1175// =============================================================================
1176// Tests
1177// =============================================================================
1178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::InMemoryDataset;

    /// Builds a small all-zeros dataset: 100 samples, 10 features, 2 labels.
    fn create_test_dataset() -> InMemoryDataset<f64> {
        let features = Array::zeros(IxDyn(&[100, 10]));
        let labels = Array::zeros(IxDyn(&[100, 2]));
        InMemoryDataset::new(features, labels).expect("Failed to create test dataset")
    }

    // Pins the documented defaults of OptimizedLoaderConfig.
    #[test]
    fn test_optimized_loader_config_default() {
        let config = OptimizedLoaderConfig::default();
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.prefetch_size, 2);
        assert_eq!(config.num_workers, 0);
        assert!(!config.drop_last);
        assert!(config.shuffle);
    }

    #[test]
    fn test_optimized_dataloader_creation() {
        let dataset = create_test_dataset();
        let config = OptimizedLoaderConfig {
            batch_size: 10,
            shuffle: false,
            ..Default::default()
        };

        // 100 samples at batch size 10 → exactly 10 batches.
        let loader = OptimizedDataLoader::new(dataset, config);
        assert_eq!(loader.len(), 100);
        assert_eq!(loader.num_batches(), 10);
    }

    #[test]
    fn test_optimized_dataloader_iteration() {
        let dataset = create_test_dataset();
        let config = OptimizedLoaderConfig {
            batch_size: 10,
            shuffle: false,
            drop_last: true,
            ..Default::default()
        };

        let mut loader = OptimizedDataLoader::new(dataset, config);
        loader.reset();

        // Every yielded batch must carry a full leading dimension of 10.
        let mut batch_count = 0;
        while let Some(result) = loader.next_batch() {
            let (x, y) = result.expect("Failed to load batch");
            assert_eq!(x.shape()[0], 10);
            assert_eq!(y.shape()[0], 10);
            batch_count += 1;
        }

        assert_eq!(batch_count, 10);
    }

    #[test]
    fn test_optimized_dataloader_stats() {
        let dataset = create_test_dataset();
        let config = OptimizedLoaderConfig {
            batch_size: 20,
            shuffle: false,
            ..Default::default()
        };

        let mut loader = OptimizedDataLoader::new(dataset, config);
        loader.reset();

        // Load all batches
        while loader.next_batch().is_some() {}

        // 100 samples / 20 per batch → 5 batches and all samples counted.
        let stats = loader.stats();
        assert_eq!(stats.batches_loaded, 5);
        assert_eq!(stats.samples_loaded, 100);
    }

    #[test]
    fn test_batch_cache() {
        let mut cache: BatchCache<f64> = BatchCache::new(10);

        let batch1 = (Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2])));

        cache.insert(0, batch1.clone());

        // Hit: the inserted batch comes back with its shape intact.
        let cached = cache.get(0);
        assert!(cached.is_some());
        assert_eq!(cached.map(|b| b.0.shape()[0]), Some(5));

        // Miss: a key that was never inserted.
        assert!(cache.get(1).is_none());

        // Clearing evicts everything, including previously cached keys.
        cache.clear();
        assert!(cache.get(0).is_none());
    }

    #[test]
    fn test_prefetch_queue() {
        let queue: PrefetchQueue<f64> = PrefetchQueue::new(3);

        let batch = Ok((Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2]))));

        assert!(queue.push(0, batch));
        assert!(!queue.is_empty());

        // Popping returns the (index, result) pair that was pushed.
        let popped = queue.pop();
        assert!(popped.is_some());
        assert_eq!(popped.map(|(idx, _)| idx), Some(0));

        assert!(queue.is_empty());

        queue.stop();
        // After stop, push should return false
        let batch2 = Ok((Array::zeros(IxDyn(&[5, 10])), Array::zeros(IxDyn(&[5, 2]))));
        assert!(!queue.push(1, batch2));
    }

    // Fresh stats must start from zero so accumulation is meaningful.
    #[test]
    fn test_loading_stats_default() {
        let stats = LoadingStats::default();
        assert_eq!(stats.batches_loaded, 0);
        assert_eq!(stats.samples_loaded, 0);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
    }

    // Memory estimate is element count × element size.
    #[test]
    fn test_estimate_array_memory() {
        let array: Array<f64, IxDyn> = Array::zeros(IxDyn(&[10, 20]));
        let memory = estimate_array_memory(&array);
        assert_eq!(memory, 10 * 20 * std::mem::size_of::<f64>());
    }

    // Pins the documented defaults of BatchSizeOptimizer.
    #[test]
    fn test_batch_size_optimizer_default() {
        let optimizer = BatchSizeOptimizer::default();
        assert_eq!(optimizer.min_batch_size, 8);
        assert_eq!(optimizer.max_batch_size, 512);
    }

    #[test]
    fn test_batch_size_optimizer_with_range() {
        let optimizer = BatchSizeOptimizer::new()
            .with_range(16, 256)
            .with_max_memory(1024 * 1024);

        // Builder setters must be reflected in the stored fields.
        assert_eq!(optimizer.min_batch_size, 16);
        assert_eq!(optimizer.max_batch_size, 256);
        assert_eq!(optimizer.max_memory, 1024 * 1024);
    }

    #[test]
    fn test_find_optimal_batch_size() {
        let dataset = create_test_dataset();
        let optimizer = BatchSizeOptimizer::new().with_range(10, 50);

        let result = optimizer.find_optimal(dataset);
        assert!(result.is_ok());

        // The recommendation must land inside the configured range and be
        // backed by at least one throughput measurement.
        let result = result.expect("Optimization should succeed");
        assert!(result.recommended_batch_size >= 10);
        assert!(result.recommended_batch_size <= 50);
        assert!(!result.throughput_results.is_empty());
    }

    #[test]
    fn test_dataloader_with_caching() {
        let dataset = create_test_dataset();
        let config = OptimizedLoaderConfig {
            batch_size: 10,
            shuffle: false,
            cache_batches: true,
            ..Default::default()
        };

        let mut loader = OptimizedDataLoader::new(dataset, config);
        loader.reset();

        // First pass - all cache misses
        while loader.next_batch().is_some() {}

        let stats = loader.stats();
        assert_eq!(stats.cache_misses, 10);
        assert_eq!(stats.cache_hits, 0);
    }

    #[test]
    fn test_iterator_trait() {
        let dataset = create_test_dataset();
        let config = OptimizedLoaderConfig {
            batch_size: 25,
            shuffle: false,
            drop_last: true,
            ..Default::default()
        };

        let mut loader = OptimizedDataLoader::new(dataset, config);
        loader.reset();

        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 4); // 100 / 25 = 4 batches
    }

    // -------------------------------------------------------------------------
    // MemoryAwareDataLoader tests
    // -------------------------------------------------------------------------

    // Defaults must satisfy their own documented invariants.
    #[test]
    fn test_memory_aware_config_default() {
        let cfg = MemoryAwareConfig::default();
        assert!(
            cfg.target_memory_fraction > 0.0 && cfg.target_memory_fraction <= 1.0,
            "target_memory_fraction must be in (0, 1]"
        );
        assert!(cfg.min_batch_size >= 1);
        assert!(cfg.max_batch_size >= cfg.min_batch_size);
    }

    #[test]
    fn test_memory_aware_loader_creation() {
        let dataset = create_test_dataset();
        let config = MemoryAwareConfig {
            shuffle: false,
            drop_last: false,
            ..Default::default()
        };

        let loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
            .expect("loader creation must succeed");

        // Batch size must be within the configured bounds.
        let bs = loader.adaptive_batch_size();
        assert!(bs >= 4, "batch_size ({bs}) must be >= min_batch_size (4)");
        assert!(
            bs <= 4096,
            "batch_size ({bs}) must be <= max_batch_size (4096)"
        );
        // Dataset has 100 samples, so there must be at least 1 batch.
        assert!(loader.num_batches() >= 1);
        assert_eq!(loader.len(), 100);
        assert!(!loader.is_empty());
    }

    #[test]
    fn test_memory_aware_loader_iteration_all_samples() {
        let dataset = create_test_dataset();
        let config = MemoryAwareConfig {
            shuffle: false,
            drop_last: false,
            min_batch_size: 10,
            max_batch_size: 10,
            target_memory_fraction: 1.0, // doesn't matter when bounds force batch_size
            ..Default::default()
        };

        let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
            .expect("loader creation must succeed");
        loader.reset();

        // Sum leading dimensions across all batches to verify full coverage.
        let mut total_samples = 0usize;
        let mut batch_count = 0usize;
        while let Some(result) = loader.next_batch() {
            let (x, _y) = result.expect("batch load must succeed");
            total_samples += x.shape()[0];
            batch_count += 1;
        }

        assert_eq!(total_samples, 100, "all 100 samples must be yielded");
        assert_eq!(batch_count, 10, "100 samples / batch_size 10 = 10 batches");
    }

    #[test]
    fn test_memory_aware_loader_drop_last() {
        // 100 samples, batch_size clamped to 32 by bounds.
        // drop_last=true → 3 full batches of 32, 4 samples discarded.
        let dataset = create_test_dataset();
        let config = MemoryAwareConfig {
            shuffle: false,
            drop_last: true,
            min_batch_size: 32,
            max_batch_size: 32,
            target_memory_fraction: 1.0,
            ..Default::default()
        };

        let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
            .expect("loader creation must succeed");
        loader.reset();

        let batches: Vec<_> = loader.collect();
        assert_eq!(batches.len(), 3, "drop_last: 100/32 = 3 full batches");
    }

    #[test]
    fn test_memory_aware_loader_refresh_batch_size() {
        let dataset = create_test_dataset();
        let config = MemoryAwareConfig {
            shuffle: false,
            ..Default::default()
        };

        let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
            .expect("loader creation must succeed");

        // Refresh must keep the batch size inside bounds and update the getter.
        let new_bs = loader.refresh_batch_size().expect("refresh must succeed");
        assert!(new_bs >= loader.config.min_batch_size);
        assert!(new_bs <= loader.config.max_batch_size);
        assert_eq!(new_bs, loader.adaptive_batch_size());
    }

    #[test]
    fn test_memory_aware_loader_stats() {
        let dataset = create_test_dataset();
        let config = MemoryAwareConfig {
            shuffle: false,
            drop_last: false,
            min_batch_size: 10,
            max_batch_size: 10,
            target_memory_fraction: 1.0,
            ..Default::default()
        };

        let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
            .expect("loader creation must succeed");
        loader.reset();

        while loader.next_batch().is_some() {}

        let stats = loader.stats();
        assert_eq!(stats.batches_loaded, 10);
        assert_eq!(stats.samples_loaded, 100);
    }

    #[test]
    fn test_memory_aware_prefetch_iter() {
        let dataset = create_test_dataset();
        let config = MemoryAwareConfig {
            shuffle: false,
            drop_last: false,
            min_batch_size: 10,
            max_batch_size: 10,
            target_memory_fraction: 1.0,
            prefetch_ahead: 2,
            ..Default::default()
        };

        let mut loader = MemoryAwareDataLoader::<f64, _>::new_adaptive(dataset, config)
            .expect("loader creation must succeed");
        loader.reset();

        let iter = loader.into_prefetch_iter();
        let batches: Vec<_> = iter.collect();

        // Each batch must be a successful result with the right shape.
        for batch_result in &batches {
            let (x, _y) = batch_result
                .as_ref()
                .expect("prefetch batch must not be an error");
            assert_eq!(x.shape()[0], 10);
        }
        assert_eq!(batches.len(), 10);
    }

    #[test]
    fn test_estimate_available_memory_is_positive() {
        let mem = estimate_available_memory_bytes();
        assert!(mem > 0, "available memory estimate must be > 0");
    }

    #[test]
    fn test_compute_adaptive_batch_size_bounds() {
        let config = MemoryAwareConfig {
            min_batch_size: 8,
            max_batch_size: 64,
            target_memory_fraction: 0.1,
            bytes_per_sample: Some(1024),
            ..Default::default()
        };
        let bs = compute_adaptive_batch_size(1000, 1024, &config);
        assert!(bs >= 8, "must respect min_batch_size");
        assert!(bs <= 64, "must respect max_batch_size");
    }
}