//! Data-loading utilities for `yscv_model`: an in-memory parallel batch
//! loader (`DataLoader`), index samplers, and a disk-streaming loader
//! (`StreamingDataLoader`).

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3use std::sync::mpsc;
4use std::thread;
5
6use yscv_tensor::Tensor;
7
8use crate::ModelError;
9
/// Configuration options controlling batching, shuffling, and prefetching
/// behaviour of [`DataLoader`].
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Number of worker threads for prefetching.
    pub num_workers: usize,
    /// Number of batches each worker prefetches before blocking.
    pub prefetch_factor: usize,
    /// Whether to drop the last incomplete batch.
    pub drop_last: bool,
    /// Whether to shuffle samples each epoch.
    pub shuffle: bool,
}

impl Default for DataLoaderConfig {
    /// Conservative defaults: batches of 32, a single worker, prefetch depth
    /// of 2, keep the trailing partial batch, no shuffling.
    fn default() -> Self {
        DataLoaderConfig {
            batch_size: 32,
            num_workers: 1,
            prefetch_factor: 2,
            drop_last: false,
            shuffle: false,
        }
    }
}
36
/// A batch of samples produced by the data loader.
///
/// `inputs` and `targets` are stacked from the same sample indices (see
/// `build_batch`), so row `i` of `inputs` corresponds to row `i` of `targets`.
#[derive(Debug, Clone, PartialEq)]
pub struct DataLoaderBatch {
    /// Stacked input tensors with shape `[batch_size, ...]`.
    pub inputs: Tensor,
    /// Stacked target tensors with shape `[batch_size, ...]`.
    pub targets: Tensor,
}
45
/// Parallel data loader that prefetches batches using worker threads.
///
/// Holds the whole dataset in memory as per-sample tensors; `iter()` clones
/// them into `Arc`s shared with that epoch's worker threads.
pub struct DataLoader {
    /// Batching / shuffling / worker configuration (validated in `new`).
    config: DataLoaderConfig,
    /// One tensor per input sample; all share the same shape (checked in `new`).
    inputs: Vec<Tensor>,
    /// One tensor per target sample; all share the same shape (checked in `new`).
    targets: Vec<Tensor>,
    /// Incremented on every `iter()` call and used as the shuffle seed.
    /// `Cell` provides interior mutability from `&self`, but also makes the
    /// loader `!Sync`, so it cannot be shared by reference across threads.
    epoch_counter: std::cell::Cell<u64>,
}
53
impl DataLoader {
    /// Creates a new data loader from individual sample tensors and configuration.
    ///
    /// All input tensors must have the same shape, and all target tensors must
    /// have the same shape. The number of inputs must equal the number of targets.
    ///
    /// # Errors
    ///
    /// * `DatasetShapeMismatch` when `inputs` and `targets` differ in length.
    /// * `InvalidBatchSize` when `batch_size` is zero, and also when
    ///   `num_workers` is zero (the variant is reused with the worker count in
    ///   its `batch_size` field — NOTE(review): a dedicated variant would
    ///   report this more clearly).
    /// * `InvalidParameterShape` when any sample's shape disagrees with the
    ///   first sample of its vector.
    pub fn new(
        inputs: Vec<Tensor>,
        targets: Vec<Tensor>,
        config: DataLoaderConfig,
    ) -> Result<Self, ModelError> {
        if inputs.len() != targets.len() {
            return Err(ModelError::DatasetShapeMismatch {
                inputs: vec![inputs.len()],
                targets: vec![targets.len()],
            });
        }
        if config.batch_size == 0 {
            return Err(ModelError::InvalidBatchSize {
                batch_size: config.batch_size,
            });
        }
        if config.num_workers == 0 {
            return Err(ModelError::InvalidBatchSize {
                batch_size: config.num_workers,
            });
        }
        // Validate that all inputs share the same shape.
        if let Some(first) = inputs.first() {
            let expected = first.shape();
            for t in inputs.iter().skip(1) {
                if t.shape() != expected {
                    return Err(ModelError::InvalidParameterShape {
                        parameter: "data_loader_input",
                        expected: expected.to_vec(),
                        got: t.shape().to_vec(),
                    });
                }
            }
        }
        // Validate that all targets share the same shape.
        if let Some(first) = targets.first() {
            let expected = first.shape();
            for t in targets.iter().skip(1) {
                if t.shape() != expected {
                    return Err(ModelError::InvalidParameterShape {
                        parameter: "data_loader_target",
                        expected: expected.to_vec(),
                        got: t.shape().to_vec(),
                    });
                }
            }
        }
        Ok(Self {
            config,
            inputs,
            targets,
            epoch_counter: std::cell::Cell::new(0),
        })
    }

    /// Returns the number of batches per epoch.
    ///
    /// With `drop_last` the trailing partial batch is not counted; otherwise
    /// the count rounds up.
    pub fn len(&self) -> usize {
        let n = self.inputs.len();
        // `batch_size == 0` is rejected by `new`, but guard anyway so this
        // method can never divide by zero.
        if n == 0 || self.config.batch_size == 0 {
            return 0;
        }
        if self.config.drop_last {
            n / self.config.batch_size
        } else {
            n.div_ceil(self.config.batch_size)
        }
    }

    /// Returns `true` if the loader has no batches.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the underlying configuration.
    pub fn config(&self) -> &DataLoaderConfig {
        &self.config
    }

    /// Returns the total number of samples.
    pub fn sample_count(&self) -> usize {
        self.inputs.len()
    }

    /// Creates an iterator that spawns worker threads and prefetches batches.
    ///
    /// Each call increments the internal epoch counter, producing a different
    /// shuffle order when `config.shuffle` is `true`.
    ///
    /// Batches are distributed round-robin over at most `num_workers` threads,
    /// which all send into one bounded channel; with more than one worker the
    /// yield order therefore depends on thread scheduling.
    pub fn iter(&self) -> DataLoaderIter {
        // Use-then-bump: the first epoch is seeded with 0.
        let epoch = self.epoch_counter.get();
        self.epoch_counter.set(epoch.wrapping_add(1));

        let num_samples = self.inputs.len();
        let batch_size = self.config.batch_size;

        // Build sample ordering.
        let mut indices: Vec<usize> = (0..num_samples).collect();
        if self.config.shuffle {
            lcg_shuffle(&mut indices, epoch);
        }

        // Build batch index ranges.
        let mut batch_ranges: Vec<(usize, usize)> = Vec::new();
        let mut start = 0;
        while start < num_samples {
            let end = (start + batch_size).min(num_samples);
            let is_full = (end - start) == batch_size;
            // Only the trailing batch can be partial; keep it unless drop_last.
            if is_full || !self.config.drop_last {
                batch_ranges.push((start, end));
            }
            start = end;
        }

        let total_batches = batch_ranges.len();

        if total_batches == 0 {
            // No work to do; return an iterator that immediately finishes
            // (`remaining == 0` means `next()` returns `None` without ever
            // touching the channel).
            let (_tx, rx) = mpsc::sync_channel::<Result<DataLoaderBatch, String>>(0);
            return DataLoaderIter {
                receiver: rx,
                _workers: Vec::new(),
                remaining: 0,
            };
        }

        // Bounded channel: at most `num_workers * prefetch_factor` batches can
        // be in flight before workers block.
        let channel_capacity = self
            .config
            .num_workers
            .saturating_mul(self.config.prefetch_factor)
            .max(1);
        let (tx, rx) = mpsc::sync_channel::<Result<DataLoaderBatch, String>>(channel_capacity);

        // Share data with workers via Arc.
        // NOTE(review): this clones the full dataset on every epoch; storing
        // `Arc<Vec<Tensor>>` in the struct would avoid the copy.
        let shared_inputs = Arc::new(self.inputs.clone());
        let shared_targets = Arc::new(self.targets.clone());
        let shared_indices = Arc::new(indices);

        // Never spawn more workers than there are batches.
        let num_workers = self.config.num_workers.min(total_batches);
        let mut workers = Vec::with_capacity(num_workers);

        for worker_id in 0..num_workers {
            // Each worker handles batches: worker_id, worker_id + N, worker_id + 2N, ...
            let worker_batch_indices: Vec<usize> =
                (worker_id..total_batches).step_by(num_workers).collect();
            let worker_ranges: Vec<(usize, usize)> = worker_batch_indices
                .iter()
                .map(|&bi| batch_ranges[bi])
                .collect();

            let tx = tx.clone();
            let inputs = Arc::clone(&shared_inputs);
            let targets = Arc::clone(&shared_targets);
            let sample_indices = Arc::clone(&shared_indices);

            let handle = thread::spawn(move || {
                for (range_start, range_end) in worker_ranges {
                    // Map positions in the (possibly shuffled) ordering back
                    // to sample indices.
                    let batch_indices: Vec<usize> = (range_start..range_end)
                        .map(|i| sample_indices[i])
                        .collect();

                    // Errors cross the channel as `String`s; the iterator
                    // rewraps them as `ModelError::DatasetLoadIo`.
                    let result = build_batch(&inputs, &targets, &batch_indices);
                    let send_result = match result {
                        Ok(batch) => tx.send(Ok(batch)),
                        Err(e) => tx.send(Err(e.to_string())),
                    };
                    if send_result.is_err() {
                        // Receiver dropped; stop producing.
                        break;
                    }
                }
            });
            workers.push(handle);
        }

        // Drop the original sender so the channel closes when all workers finish.
        drop(tx);

        DataLoaderIter {
            receiver: rx,
            _workers: workers,
            remaining: total_batches,
        }
    }
}
242
/// Iterator over batches produced by worker threads.
pub struct DataLoaderIter {
    /// Receiving end of the bounded channel shared by all workers.
    receiver: mpsc::Receiver<Result<DataLoaderBatch, String>>,
    /// Worker join handles, held so the threads stay associated with the
    /// iterator. NOTE(review): dropping a `JoinHandle` detaches the thread
    /// rather than joining it — confirm detach-on-early-drop is intended.
    _workers: Vec<thread::JoinHandle<()>>,
    /// Number of batches still expected from the channel; `next()` returns
    /// `None` once this reaches zero.
    remaining: usize,
}
249
250impl Iterator for DataLoaderIter {
251    type Item = Result<DataLoaderBatch, ModelError>;
252
253    fn next(&mut self) -> Option<Self::Item> {
254        if self.remaining == 0 {
255            return None;
256        }
257        match self.receiver.recv() {
258            Ok(Ok(batch)) => {
259                self.remaining -= 1;
260                Some(Ok(batch))
261            }
262            Ok(Err(msg)) => {
263                self.remaining -= 1;
264                Some(Err(ModelError::DatasetLoadIo {
265                    path: String::new(),
266                    message: msg,
267                }))
268            }
269            Err(_) => {
270                // Channel closed unexpectedly.
271                self.remaining = 0;
272                None
273            }
274        }
275    }
276}
277
278/// Stack individual sample tensors into a single batch tensor.
279///
280/// Given tensors each with shape `[d0, d1, ...]`, produces a tensor with
281/// shape `[batch_size, d0, d1, ...]`.
282fn stack_tensors(tensors: &[&Tensor]) -> Result<Tensor, ModelError> {
283    if tensors.is_empty() {
284        return Err(ModelError::EmptyDataset);
285    }
286    let sample_shape = tensors[0].shape();
287    let sample_len = tensors[0].len();
288
289    let batch_size = tensors.len();
290    let mut batch_shape = Vec::with_capacity(sample_shape.len() + 1);
291    batch_shape.push(batch_size);
292    batch_shape.extend_from_slice(sample_shape);
293
294    let total_len = batch_size * sample_len;
295    let mut data = Vec::with_capacity(total_len);
296    for tensor in tensors {
297        data.extend_from_slice(tensor.data());
298    }
299
300    Tensor::from_vec(batch_shape, data).map_err(ModelError::from)
301}
302
303/// Build a single batch from the given sample indices.
304fn build_batch(
305    inputs: &[Tensor],
306    targets: &[Tensor],
307    indices: &[usize],
308) -> Result<DataLoaderBatch, ModelError> {
309    let input_refs: Vec<&Tensor> = indices.iter().map(|&i| &inputs[i]).collect();
310    let target_refs: Vec<&Tensor> = indices.iter().map(|&i| &targets[i]).collect();
311
312    let stacked_inputs = stack_tensors(&input_refs)?;
313    let stacked_targets = stack_tensors(&target_refs)?;
314
315    Ok(DataLoaderBatch {
316        inputs: stacked_inputs,
317        targets: stacked_targets,
318    })
319}
320
/// Simple LCG-based Fisher-Yates shuffle, deterministic for a given seed.
fn lcg_shuffle(indices: &mut [usize], seed: u64) {
    let mut state = seed ^ 0x6C62_272E_07BB_0142;
    // Classic Fisher-Yates: walk from the back, swapping each position with
    // a pseudo-random earlier (or equal) position.
    for i in (1..indices.len()).rev() {
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        // Use the high bits of the LCG state; low bits have poor quality.
        let j = ((state >> 33) as usize) % (i + 1);
        indices.swap(i, j);
    }
}
334
335// ---------------------------------------------------------------------------
336// Samplers
337// ---------------------------------------------------------------------------
338
/// A sampler that yields indices in sequential order.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler over `len` samples.
    pub fn new(len: usize) -> Self {
        SequentialSampler { len }
    }

    /// Returns indices `[0, 1, 2, ..., len-1]`.
    pub fn indices(&self) -> Vec<usize> {
        let mut out = Vec::with_capacity(self.len);
        out.extend(0..self.len);
        out
    }
}
355
356/// A sampler that yields indices in a random (deterministic) order.
357#[derive(Debug, Clone)]
358pub struct RandomSampler {
359    len: usize,
360    seed: u64,
361}
362
363impl RandomSampler {
364    pub fn new(len: usize, seed: u64) -> Self {
365        Self { len, seed }
366    }
367
368    /// Returns a shuffled permutation of `[0, len)` using the given seed.
369    pub fn indices(&self) -> Vec<usize> {
370        let mut idx: Vec<usize> = (0..self.len).collect();
371        lcg_shuffle(&mut idx, self.seed);
372        idx
373    }
374}
375
376/// Weighted random sampler: draws `num_samples` indices with probability proportional to weights.
377///
378/// Useful for imbalanced datasets where minority classes should be oversampled.
379#[derive(Debug, Clone)]
380pub struct WeightedRandomSampler {
381    weights: Vec<f64>,
382    num_samples: usize,
383    seed: u64,
384}
385
386impl WeightedRandomSampler {
387    /// Creates a new weighted sampler.
388    ///
389    /// `weights`: per-sample weight (higher = more likely to be sampled).
390    /// `num_samples`: how many indices to draw per epoch.
391    /// `seed`: deterministic random seed.
392    pub fn new(weights: Vec<f64>, num_samples: usize, seed: u64) -> Result<Self, ModelError> {
393        if weights.is_empty() {
394            return Err(ModelError::EmptyDataset);
395        }
396        Ok(Self {
397            weights,
398            num_samples,
399            seed,
400        })
401    }
402
403    /// Draw `num_samples` indices with replacement, proportional to weights.
404    pub fn indices(&self) -> Vec<usize> {
405        let total: f64 = self.weights.iter().sum();
406        if total <= 0.0 {
407            return (0..self.num_samples)
408                .map(|i| i % self.weights.len())
409                .collect();
410        }
411
412        // Build CDF
413        let mut cdf = Vec::with_capacity(self.weights.len());
414        let mut acc = 0.0;
415        for &w in &self.weights {
416            acc += w / total;
417            cdf.push(acc);
418        }
419
420        let mut state = self.seed ^ 0x5DEE_CE66_D1A4_F87D;
421        let mut result = Vec::with_capacity(self.num_samples);
422        for _ in 0..self.num_samples {
423            state = state
424                .wrapping_mul(6_364_136_223_846_793_005)
425                .wrapping_add(1);
426            let u = (state >> 11) as f64 / (1u64 << 53) as f64; // uniform [0, 1)
427            // Binary search in CDF
428            let idx = match cdf
429                .binary_search_by(|v| v.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal))
430            {
431                Ok(i) => i,
432                Err(i) => i.min(self.weights.len() - 1),
433            };
434            result.push(idx);
435        }
436        result
437    }
438
439    /// Number of samples drawn per epoch.
440    pub fn num_samples(&self) -> usize {
441        self.num_samples
442    }
443}
444
445// ---------------------------------------------------------------------------
446// Streaming Data Loader
447// ---------------------------------------------------------------------------
448
/// A data loader that lazily reads batches from disk, using a background thread
/// to prefetch the next batch while the current batch is being processed.
///
/// The loader scans a directory for numbered batch files (`batch_0000.bin`,
/// `batch_0001.bin`, etc.). Each file stores a pair of tensors (inputs + targets)
/// in a simple binary format:
///
/// ```text
/// [input_ndims: u32] [input_shape...] [input_data: f32...]
/// [target_ndims: u32] [target_shape...] [target_data: f32...]
/// ```
pub struct StreamingDataLoader {
    /// Directory the batch files were scanned from.
    path: PathBuf,
    /// Configured batch size. NOTE(review): stored but not consulted by the
    /// loading path — batch contents come entirely from the files on disk.
    batch_size: usize,
    /// Sorted list of `batch_*.bin` files discovered in `path`.
    file_paths: Vec<PathBuf>,
    /// Index of the next file to hand out via `next_batch`.
    current_index: usize,
    /// Receiver for the batch currently being prefetched, if any.
    prefetch_rx: Option<mpsc::Receiver<Result<(Tensor, Tensor), ModelError>>>,
    /// Handle of the prefetch thread; kept so the thread stays associated
    /// with the loader (dropping it detaches rather than joins).
    _prefetch_handle: Option<thread::JoinHandle<()>>,
}
468
469impl StreamingDataLoader {
470    /// Create a new streaming data loader that reads batch files from `path`.
471    ///
472    /// The directory is scanned for files matching `batch_NNNN.bin`. If no
473    /// batch files are found, an empty loader is returned.
474    pub fn new(path: impl Into<PathBuf>, batch_size: usize) -> Result<Self, ModelError> {
475        let path = path.into();
476        if batch_size == 0 {
477            return Err(ModelError::InvalidBatchSize { batch_size });
478        }
479        let file_paths = Self::scan_batch_files(&path);
480        let mut loader = Self {
481            path,
482            batch_size,
483            file_paths,
484            current_index: 0,
485            prefetch_rx: None,
486            _prefetch_handle: None,
487        };
488        loader.start_prefetch();
489        Ok(loader)
490    }
491
492    /// Returns the next batch of (inputs, targets), or `None` if all batches
493    /// have been consumed for this epoch.
494    pub fn next_batch(&mut self) -> Option<(Tensor, Tensor)> {
495        if self.current_index >= self.file_paths.len() {
496            return None;
497        }
498
499        // Receive the prefetched batch.
500        let result = if let Some(rx) = self.prefetch_rx.take() {
501            match rx.recv() {
502                Ok(Ok(batch)) => Some(batch),
503                Ok(Err(_)) => None,
504                Err(_) => None,
505            }
506        } else {
507            // Fallback: load synchronously.
508            Self::load_batch_file(&self.file_paths[self.current_index]).ok()
509        };
510
511        self.current_index += 1;
512
513        // Start prefetching the next batch.
514        self.start_prefetch();
515
516        result
517    }
518
519    /// Reset the loader to the beginning so it can be iterated again.
520    pub fn reset(&mut self) {
521        // Drop any in-flight prefetch.
522        self.prefetch_rx = None;
523        self._prefetch_handle = None;
524        self.current_index = 0;
525        self.start_prefetch();
526    }
527
528    /// Returns the total number of batch files available.
529    pub fn len(&self) -> usize {
530        self.file_paths.len()
531    }
532
533    /// Returns `true` if there are no batch files.
534    pub fn is_empty(&self) -> bool {
535        self.file_paths.is_empty()
536    }
537
538    /// Returns the configured batch size.
539    pub fn batch_size(&self) -> usize {
540        self.batch_size
541    }
542
543    /// Returns the directory path.
544    pub fn path(&self) -> &Path {
545        &self.path
546    }
547
548    // -- internal helpers ---------------------------------------------------
549
550    fn start_prefetch(&mut self) {
551        if self.current_index >= self.file_paths.len() {
552            return;
553        }
554        let file_path = self.file_paths[self.current_index].clone();
555        let (tx, rx) = mpsc::sync_channel(1);
556        let handle = thread::spawn(move || {
557            let result = Self::load_batch_file(&file_path);
558            let _ = tx.send(result);
559        });
560        self.prefetch_rx = Some(rx);
561        self._prefetch_handle = Some(handle);
562    }
563
564    /// Scan directory for `batch_NNNN.bin` files, sorted by name.
565    fn scan_batch_files(dir: &Path) -> Vec<PathBuf> {
566        let read_dir = match std::fs::read_dir(dir) {
567            Ok(rd) => rd,
568            Err(_) => return Vec::new(),
569        };
570        let mut paths: Vec<PathBuf> = read_dir
571            .filter_map(|entry| entry.ok())
572            .map(|entry| entry.path())
573            .filter(|p| {
574                if let Some(name) = p.file_name().and_then(|n| n.to_str()) {
575                    name.starts_with("batch_") && name.ends_with(".bin")
576                } else {
577                    false
578                }
579            })
580            .collect();
581        paths.sort();
582        paths
583    }
584
585    /// Load a single batch file in the simple binary tensor-pair format.
586    fn load_batch_file(path: &Path) -> Result<(Tensor, Tensor), ModelError> {
587        let data = std::fs::read(path).map_err(|e| ModelError::DatasetLoadIo {
588            path: path.display().to_string(),
589            message: e.to_string(),
590        })?;
591        let input = Self::read_tensor_from_bytes(&data, 0)?;
592        let offset = Self::tensor_byte_size(&input) + 4 + input.shape().len() * 4;
593        let target = Self::read_tensor_from_bytes(&data, offset)?;
594        Ok((input, target))
595    }
596
597    /// Write a tensor pair to the simple binary format.
598    pub fn write_batch_file(
599        path: &Path,
600        inputs: &Tensor,
601        targets: &Tensor,
602    ) -> Result<(), ModelError> {
603        let mut buf = Vec::new();
604        Self::write_tensor_to_bytes(&mut buf, inputs);
605        Self::write_tensor_to_bytes(&mut buf, targets);
606        std::fs::write(path, &buf).map_err(|e| ModelError::DatasetLoadIo {
607            path: path.display().to_string(),
608            message: e.to_string(),
609        })
610    }
611
612    fn write_tensor_to_bytes(buf: &mut Vec<u8>, tensor: &Tensor) {
613        let ndims = tensor.shape().len() as u32;
614        buf.extend_from_slice(&ndims.to_le_bytes());
615        for &d in tensor.shape() {
616            buf.extend_from_slice(&(d as u32).to_le_bytes());
617        }
618        for &v in tensor.data() {
619            buf.extend_from_slice(&v.to_le_bytes());
620        }
621    }
622
623    fn tensor_byte_size(tensor: &Tensor) -> usize {
624        tensor.data().len() * 4
625    }
626
627    fn read_tensor_from_bytes(data: &[u8], offset: usize) -> Result<Tensor, ModelError> {
628        if offset + 4 > data.len() {
629            return Err(ModelError::DatasetLoadIo {
630                path: String::new(),
631                message: "unexpected end of batch file (ndims)".to_string(),
632            });
633        }
634        let ndims = u32::from_le_bytes([
635            data[offset],
636            data[offset + 1],
637            data[offset + 2],
638            data[offset + 3],
639        ]) as usize;
640        let shape_start = offset + 4;
641        let shape_end = shape_start + ndims * 4;
642        if shape_end > data.len() {
643            return Err(ModelError::DatasetLoadIo {
644                path: String::new(),
645                message: "unexpected end of batch file (shape)".to_string(),
646            });
647        }
648        let mut shape = Vec::with_capacity(ndims);
649        for i in 0..ndims {
650            let s = shape_start + i * 4;
651            shape.push(
652                u32::from_le_bytes([data[s], data[s + 1], data[s + 2], data[s + 3]]) as usize,
653            );
654        }
655        let num_elements: usize = shape.iter().product();
656        let data_start = shape_end;
657        let data_end = data_start + num_elements * 4;
658        if data_end > data.len() {
659            return Err(ModelError::DatasetLoadIo {
660                path: String::new(),
661                message: "unexpected end of batch file (data)".to_string(),
662            });
663        }
664        let mut values = Vec::with_capacity(num_elements);
665        for i in 0..num_elements {
666            let s = data_start + i * 4;
667            values.push(f32::from_le_bytes([
668                data[s],
669                data[s + 1],
670                data[s + 2],
671                data[s + 3],
672            ]));
673        }
674        Tensor::from_vec(shape, values).map_err(ModelError::from)
675    }
676}