1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3use std::sync::mpsc;
4use std::thread;
5
6use yscv_tensor::Tensor;
7
8use crate::ModelError;
9
/// Tuning knobs for [`DataLoader`]: batching, worker threads, prefetch
/// depth, tail-batch handling, and per-epoch shuffling.
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Samples stacked into each batch. Must be non-zero.
    pub batch_size: usize,
    /// Background worker threads building batches. Must be non-zero.
    pub num_workers: usize,
    /// Batches each worker may queue ahead of the consumer.
    pub prefetch_factor: usize,
    /// Drop a trailing batch smaller than `batch_size` instead of yielding it.
    pub drop_last: bool,
    /// Reshuffle the sample order at the start of every epoch.
    pub shuffle: bool,
}

impl Default for DataLoaderConfig {
    /// Conservative defaults: batches of 32, one worker with a prefetch
    /// depth of 2, keep the short tail batch, no shuffling.
    fn default() -> Self {
        DataLoaderConfig {
            batch_size: 32,
            num_workers: 1,
            prefetch_factor: 2,
            drop_last: false,
            shuffle: false,
        }
    }
}
36
/// One batch yielded by [`DataLoader`]: stacked inputs and their matching
/// targets, each carrying a leading batch dimension added by `stack_tensors`.
#[derive(Debug, Clone, PartialEq)]
pub struct DataLoaderBatch {
    /// Stacked input samples, shape `[batch, ...sample_shape]`.
    pub inputs: Tensor,
    /// Stacked target samples, shape `[batch, ...sample_shape]`.
    pub targets: Tensor,
}
45
/// In-memory loader that batches paired input/target tensors on background
/// worker threads (see [`DataLoader::iter`]).
pub struct DataLoader {
    // Batching/worker settings; validated by `new`.
    config: DataLoaderConfig,
    // Per-sample input tensors; all share one shape (checked in `new`).
    inputs: Vec<Tensor>,
    // Per-sample target tensors; all share one shape (checked in `new`).
    targets: Vec<Tensor>,
    // Epoch count, advanced by each `iter()` call and used as the shuffle
    // seed. `Cell` lets `iter(&self)` bump it without `&mut self`, at the
    // cost of making the loader non-`Sync`.
    epoch_counter: std::cell::Cell<u64>,
}
53
impl DataLoader {
    /// Validates the dataset and configuration and builds a loader.
    ///
    /// Checks performed:
    /// * `inputs` and `targets` must have the same number of samples;
    /// * `batch_size` and `num_workers` must both be non-zero;
    /// * every input tensor must share the first input's shape, and every
    ///   target tensor must share the first target's shape.
    ///
    /// # Errors
    /// Returns `DatasetShapeMismatch`, `InvalidBatchSize`, or
    /// `InvalidParameterShape` when the corresponding check fails.
    pub fn new(
        inputs: Vec<Tensor>,
        targets: Vec<Tensor>,
        config: DataLoaderConfig,
    ) -> Result<Self, ModelError> {
        if inputs.len() != targets.len() {
            return Err(ModelError::DatasetShapeMismatch {
                inputs: vec![inputs.len()],
                targets: vec![targets.len()],
            });
        }
        if config.batch_size == 0 {
            return Err(ModelError::InvalidBatchSize {
                batch_size: config.batch_size,
            });
        }
        // NOTE(review): a zero worker count reuses `InvalidBatchSize`, so the
        // error carries `num_workers` in its `batch_size` field. A dedicated
        // variant would make the failure less misleading to callers.
        if config.num_workers == 0 {
            return Err(ModelError::InvalidBatchSize {
                batch_size: config.num_workers,
            });
        }
        // Every input sample must match the shape of the first input.
        if let Some(first) = inputs.first() {
            let expected = first.shape();
            for t in inputs.iter().skip(1) {
                if t.shape() != expected {
                    return Err(ModelError::InvalidParameterShape {
                        parameter: "data_loader_input",
                        expected: expected.to_vec(),
                        got: t.shape().to_vec(),
                    });
                }
            }
        }
        // Every target sample must match the shape of the first target.
        if let Some(first) = targets.first() {
            let expected = first.shape();
            for t in targets.iter().skip(1) {
                if t.shape() != expected {
                    return Err(ModelError::InvalidParameterShape {
                        parameter: "data_loader_target",
                        expected: expected.to_vec(),
                        got: t.shape().to_vec(),
                    });
                }
            }
        }
        Ok(Self {
            config,
            inputs,
            targets,
            epoch_counter: std::cell::Cell::new(0),
        })
    }

    /// Number of batches a full epoch yields, honoring `drop_last`.
    pub fn len(&self) -> usize {
        let n = self.inputs.len();
        // `batch_size == 0` is rejected by `new`; this guard is defensive.
        if n == 0 || self.config.batch_size == 0 {
            return 0;
        }
        if self.config.drop_last {
            n / self.config.batch_size
        } else {
            n.div_ceil(self.config.batch_size)
        }
    }

    /// True when an epoch would yield no batches at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The configuration this loader was built with.
    pub fn config(&self) -> &DataLoaderConfig {
        &self.config
    }

    /// Total number of samples in the dataset (not batches).
    pub fn sample_count(&self) -> usize {
        self.inputs.len()
    }

    /// Starts a new epoch and returns an iterator over its batches.
    ///
    /// Each call bumps the internal epoch counter (wrapping at `u64::MAX`);
    /// when `shuffle` is enabled the counter seeds the permutation, so every
    /// epoch sees a different order. Batches are built on up to
    /// `num_workers` background threads that split the batch list
    /// round-robin and push results through a bounded channel of capacity
    /// `num_workers * prefetch_factor`.
    ///
    /// NOTE(review): with `num_workers > 1`, batches arrive in completion
    /// order, not batch-index order — consumers must not rely on ordering.
    pub fn iter(&self) -> DataLoaderIter {
        let epoch = self.epoch_counter.get();
        self.epoch_counter.set(epoch.wrapping_add(1));

        let num_samples = self.inputs.len();
        let batch_size = self.config.batch_size;

        // Sample visitation order for this epoch (identity unless shuffled).
        let mut indices: Vec<usize> = (0..num_samples).collect();
        if self.config.shuffle {
            lcg_shuffle(&mut indices, epoch);
        }

        // Half-open (start, end) ranges into `indices`, one per batch; a
        // short tail range is kept only when `drop_last` is off.
        let mut batch_ranges: Vec<(usize, usize)> = Vec::new();
        let mut start = 0;
        while start < num_samples {
            let end = (start + batch_size).min(num_samples);
            let is_full = (end - start) == batch_size;
            if is_full || !self.config.drop_last {
                batch_ranges.push((start, end));
            }
            start = end;
        }

        let total_batches = batch_ranges.len();

        if total_batches == 0 {
            // No work: hand back an already-exhausted iterator. The sender
            // is dropped immediately, so `recv` would disconnect anyway.
            let (_tx, rx) = mpsc::sync_channel::<Result<DataLoaderBatch, String>>(0);
            return DataLoaderIter {
                receiver: rx,
                _workers: Vec::new(),
                remaining: 0,
            };
        }

        // Bounded channel provides backpressure: workers block once
        // `num_workers * prefetch_factor` batches are queued unread.
        let channel_capacity = self
            .config
            .num_workers
            .saturating_mul(self.config.prefetch_factor)
            .max(1);
        let (tx, rx) = mpsc::sync_channel::<Result<DataLoaderBatch, String>>(channel_capacity);

        // NOTE(review): this deep-copies the whole dataset on every epoch so
        // the threads can own it; storing `Arc<Vec<Tensor>>` in the loader
        // would avoid the per-epoch clone.
        let shared_inputs = Arc::new(self.inputs.clone());
        let shared_targets = Arc::new(self.targets.clone());
        let shared_indices = Arc::new(indices);

        // Never spawn more workers than there are batches to build.
        let num_workers = self.config.num_workers.min(total_batches);
        let mut workers = Vec::with_capacity(num_workers);

        for worker_id in 0..num_workers {
            // Worker k takes batches k, k + W, k + 2W, ... (round-robin).
            let worker_batch_indices: Vec<usize> =
                (worker_id..total_batches).step_by(num_workers).collect();
            let worker_ranges: Vec<(usize, usize)> = worker_batch_indices
                .iter()
                .map(|&bi| batch_ranges[bi])
                .collect();

            let tx = tx.clone();
            let inputs = Arc::clone(&shared_inputs);
            let targets = Arc::clone(&shared_targets);
            let sample_indices = Arc::clone(&shared_indices);

            let handle = thread::spawn(move || {
                for (range_start, range_end) in worker_ranges {
                    // Map the range through the (possibly shuffled) order.
                    let batch_indices: Vec<usize> = (range_start..range_end)
                        .map(|i| sample_indices[i])
                        .collect();

                    let result = build_batch(&inputs, &targets, &batch_indices);
                    // Errors cross the channel as strings because
                    // `ModelError` is not required to be `Send` here.
                    let send_result = match result {
                        Ok(batch) => tx.send(Ok(batch)),
                        Err(e) => tx.send(Err(e.to_string())),
                    };
                    // A failed send means the receiver was dropped; stop
                    // building batches nobody will consume.
                    if send_result.is_err() {
                        break;
                    }
                }
            });
            workers.push(handle);
        }

        // Drop the original sender so the channel disconnects once every
        // worker finishes (the iterator then sees `recv` fail).
        drop(tx);

        DataLoaderIter {
            receiver: rx,
            _workers: workers,
            remaining: total_batches,
        }
    }
}
242
/// Iterator over one epoch's batches, fed by background worker threads.
pub struct DataLoaderIter {
    // Channel carrying finished batches (or stringified build errors).
    receiver: mpsc::Receiver<Result<DataLoaderBatch, String>>,
    // Worker handles; never joined. Dropping the iterator drops the
    // receiver, workers' sends then fail and the threads exit on their own.
    _workers: Vec<thread::JoinHandle<()>>,
    // Batches still expected from the channel; zero ends the iteration.
    remaining: usize,
}
249
250impl Iterator for DataLoaderIter {
251 type Item = Result<DataLoaderBatch, ModelError>;
252
253 fn next(&mut self) -> Option<Self::Item> {
254 if self.remaining == 0 {
255 return None;
256 }
257 match self.receiver.recv() {
258 Ok(Ok(batch)) => {
259 self.remaining -= 1;
260 Some(Ok(batch))
261 }
262 Ok(Err(msg)) => {
263 self.remaining -= 1;
264 Some(Err(ModelError::DatasetLoadIo {
265 path: String::new(),
266 message: msg,
267 }))
268 }
269 Err(_) => {
270 self.remaining = 0;
272 None
273 }
274 }
275 }
276}
277
278fn stack_tensors(tensors: &[&Tensor]) -> Result<Tensor, ModelError> {
283 if tensors.is_empty() {
284 return Err(ModelError::EmptyDataset);
285 }
286 let sample_shape = tensors[0].shape();
287 let sample_len = tensors[0].len();
288
289 let batch_size = tensors.len();
290 let mut batch_shape = Vec::with_capacity(sample_shape.len() + 1);
291 batch_shape.push(batch_size);
292 batch_shape.extend_from_slice(sample_shape);
293
294 let total_len = batch_size * sample_len;
295 let mut data = Vec::with_capacity(total_len);
296 for tensor in tensors {
297 data.extend_from_slice(tensor.data());
298 }
299
300 Tensor::from_vec(batch_shape, data).map_err(ModelError::from)
301}
302
303fn build_batch(
305 inputs: &[Tensor],
306 targets: &[Tensor],
307 indices: &[usize],
308) -> Result<DataLoaderBatch, ModelError> {
309 let input_refs: Vec<&Tensor> = indices.iter().map(|&i| &inputs[i]).collect();
310 let target_refs: Vec<&Tensor> = indices.iter().map(|&i| &targets[i]).collect();
311
312 let stacked_inputs = stack_tensors(&input_refs)?;
313 let stacked_targets = stack_tensors(&target_refs)?;
314
315 Ok(DataLoaderBatch {
316 inputs: stacked_inputs,
317 targets: stacked_targets,
318 })
319}
320
/// Deterministic in-place Fisher-Yates shuffle driven by a 64-bit linear
/// congruential generator; the same seed always yields the same permutation.
///
/// The swap index is taken from the top bits of the LCG state (`state >> 33`)
/// since the low bits of an LCG have short periods. Slices of length 0 or 1
/// are left untouched.
fn lcg_shuffle(indices: &mut [usize], seed: u64) {
    let mut state = seed ^ 0x6C62_272E_07BB_0142;
    // Classic backwards Fisher-Yates: position i swaps with a random j <= i.
    for i in (1..indices.len()).rev() {
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        let j = ((state >> 33) as usize) % (i + 1);
        indices.swap(i, j);
    }
}
334
/// Sampler that yields dataset indices in their natural order `0..len`.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    // Number of samples to cover.
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler over `len` samples.
    pub fn new(len: usize) -> Self {
        Self { len }
    }

    /// Returns the indices `0, 1, ..., len - 1` in ascending order.
    pub fn indices(&self) -> Vec<usize> {
        Vec::from_iter(0..self.len)
    }
}
355
356#[derive(Debug, Clone)]
358pub struct RandomSampler {
359 len: usize,
360 seed: u64,
361}
362
363impl RandomSampler {
364 pub fn new(len: usize, seed: u64) -> Self {
365 Self { len, seed }
366 }
367
368 pub fn indices(&self) -> Vec<usize> {
370 let mut idx: Vec<usize> = (0..self.len).collect();
371 lcg_shuffle(&mut idx, self.seed);
372 idx
373 }
374}
375
376#[derive(Debug, Clone)]
380pub struct WeightedRandomSampler {
381 weights: Vec<f64>,
382 num_samples: usize,
383 seed: u64,
384}
385
386impl WeightedRandomSampler {
387 pub fn new(weights: Vec<f64>, num_samples: usize, seed: u64) -> Result<Self, ModelError> {
393 if weights.is_empty() {
394 return Err(ModelError::EmptyDataset);
395 }
396 Ok(Self {
397 weights,
398 num_samples,
399 seed,
400 })
401 }
402
403 pub fn indices(&self) -> Vec<usize> {
405 let total: f64 = self.weights.iter().sum();
406 if total <= 0.0 {
407 return (0..self.num_samples)
408 .map(|i| i % self.weights.len())
409 .collect();
410 }
411
412 let mut cdf = Vec::with_capacity(self.weights.len());
414 let mut acc = 0.0;
415 for &w in &self.weights {
416 acc += w / total;
417 cdf.push(acc);
418 }
419
420 let mut state = self.seed ^ 0x5DEE_CE66_D1A4_F87D;
421 let mut result = Vec::with_capacity(self.num_samples);
422 for _ in 0..self.num_samples {
423 state = state
424 .wrapping_mul(6_364_136_223_846_793_005)
425 .wrapping_add(1);
426 let u = (state >> 11) as f64 / (1u64 << 53) as f64; let idx = match cdf
429 .binary_search_by(|v| v.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal))
430 {
431 Ok(i) => i,
432 Err(i) => i.min(self.weights.len() - 1),
433 };
434 result.push(idx);
435 }
436 result
437 }
438
439 pub fn num_samples(&self) -> usize {
441 self.num_samples
442 }
443}
444
/// Loader that streams pre-serialized batches (`batch_*.bin` files) from a
/// directory, prefetching the next file on a background thread.
pub struct StreamingDataLoader {
    // Directory the batch files were scanned from.
    path: PathBuf,
    // Declared batch size; stored for callers, not validated against the
    // files' actual contents.
    batch_size: usize,
    // Sorted list of `batch_*.bin` files found in `path`.
    file_paths: Vec<PathBuf>,
    // Index of the next file to yield.
    current_index: usize,
    // Receiver for the batch currently being prefetched, if any.
    prefetch_rx: Option<mpsc::Receiver<Result<(Tensor, Tensor), ModelError>>>,
    // Handle of the in-flight prefetch thread; never joined (dropping a
    // JoinHandle detaches the thread).
    _prefetch_handle: Option<thread::JoinHandle<()>>,
}
468
impl StreamingDataLoader {
    /// Scans `path` for `batch_*.bin` files and starts prefetching the
    /// first one on a background thread.
    ///
    /// A missing or unreadable directory yields an empty (but valid) loader
    /// rather than an error.
    ///
    /// # Errors
    /// Returns `InvalidBatchSize` when `batch_size` is zero.
    pub fn new(path: impl Into<PathBuf>, batch_size: usize) -> Result<Self, ModelError> {
        let path = path.into();
        if batch_size == 0 {
            return Err(ModelError::InvalidBatchSize { batch_size });
        }
        let file_paths = Self::scan_batch_files(&path);
        let mut loader = Self {
            path,
            batch_size,
            file_paths,
            current_index: 0,
            prefetch_rx: None,
            _prefetch_handle: None,
        };
        loader.start_prefetch();
        Ok(loader)
    }

    /// Yields the next `(inputs, targets)` pair, or `None` once every file
    /// has been consumed.
    ///
    /// Normally the batch comes from the prefetch thread started by the
    /// previous call; if no prefetch is pending, the file is loaded
    /// synchronously as a fallback.
    ///
    /// NOTE(review): a file that fails to load also returns `None`, which
    /// callers cannot distinguish from end-of-data — yet the index still
    /// advances, so a subsequent call may succeed on the next file.
    pub fn next_batch(&mut self) -> Option<(Tensor, Tensor)> {
        if self.current_index >= self.file_paths.len() {
            return None;
        }

        // The pending receiver (if any) corresponds to the file at
        // `current_index`, queued by the previous `start_prefetch` call.
        let result = if let Some(rx) = self.prefetch_rx.take() {
            match rx.recv() {
                Ok(Ok(batch)) => Some(batch),
                Ok(Err(_)) => None,
                Err(_) => None,
            }
        } else {
            // Defensive fallback: load synchronously when no prefetch is
            // in flight.
            Self::load_batch_file(&self.file_paths[self.current_index]).ok()
        };

        self.current_index += 1;

        // Queue the following file so the next call doesn't block on I/O.
        self.start_prefetch();

        result
    }

    /// Rewinds to the first file and restarts prefetching.
    ///
    /// Dropping the old receiver/handle abandons any in-flight prefetch;
    /// its thread finishes (or fails its send) and exits on its own.
    pub fn reset(&mut self) {
        self.prefetch_rx = None;
        self._prefetch_handle = None;
        self.current_index = 0;
        self.start_prefetch();
    }

    /// Total number of batch files discovered at construction.
    pub fn len(&self) -> usize {
        self.file_paths.len()
    }

    /// True when no batch files were found.
    pub fn is_empty(&self) -> bool {
        self.file_paths.is_empty()
    }

    /// The batch size declared at construction.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// The directory this loader reads from.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Spawns a thread that loads the file at `current_index` and parks the
    /// result in a 1-slot channel. No-op when all files are consumed.
    fn start_prefetch(&mut self) {
        if self.current_index >= self.file_paths.len() {
            return;
        }
        let file_path = self.file_paths[self.current_index].clone();
        // Capacity 1 lets the worker send its single result and exit even
        // if the receiver is never read (e.g. after `reset`).
        let (tx, rx) = mpsc::sync_channel(1);
        let handle = thread::spawn(move || {
            let result = Self::load_batch_file(&file_path);
            let _ = tx.send(result);
        });
        self.prefetch_rx = Some(rx);
        self._prefetch_handle = Some(handle);
    }

    /// Lists the `batch_*.bin` files in `dir`, sorted by path.
    ///
    /// I/O errors (including a missing directory) yield an empty list.
    /// NOTE(review): the sort is lexicographic, so `batch_10.bin` orders
    /// before `batch_2.bin` unless file names are zero-padded — confirm the
    /// writer's naming convention.
    fn scan_batch_files(dir: &Path) -> Vec<PathBuf> {
        let read_dir = match std::fs::read_dir(dir) {
            Ok(rd) => rd,
            Err(_) => return Vec::new(),
        };
        let mut paths: Vec<PathBuf> = read_dir
            .filter_map(|entry| entry.ok())
            .map(|entry| entry.path())
            .filter(|p| {
                if let Some(name) = p.file_name().and_then(|n| n.to_str()) {
                    name.starts_with("batch_") && name.ends_with(".bin")
                } else {
                    false
                }
            })
            .collect();
        paths.sort();
        paths
    }

    /// Reads one batch file: two tensors serialized back-to-back (inputs,
    /// then targets) in the format written by `write_tensor_to_bytes`.
    ///
    /// # Errors
    /// Returns `DatasetLoadIo` on read failure or a truncated/corrupt file.
    fn load_batch_file(path: &Path) -> Result<(Tensor, Tensor), ModelError> {
        let data = std::fs::read(path).map_err(|e| ModelError::DatasetLoadIo {
            path: path.display().to_string(),
            message: e.to_string(),
        })?;
        let input = Self::read_tensor_from_bytes(&data, 0)?;
        // Target starts right after the input record: 4-byte ndims header,
        // one u32 per dimension, then the f32 payload.
        let offset = Self::tensor_byte_size(&input) + 4 + input.shape().len() * 4;
        let target = Self::read_tensor_from_bytes(&data, offset)?;
        Ok((input, target))
    }

    /// Serializes `inputs` then `targets` into a single batch file that
    /// `load_batch_file` can read back.
    ///
    /// # Errors
    /// Returns `DatasetLoadIo` if writing the file fails.
    pub fn write_batch_file(
        path: &Path,
        inputs: &Tensor,
        targets: &Tensor,
    ) -> Result<(), ModelError> {
        let mut buf = Vec::new();
        Self::write_tensor_to_bytes(&mut buf, inputs);
        Self::write_tensor_to_bytes(&mut buf, targets);
        std::fs::write(path, &buf).map_err(|e| ModelError::DatasetLoadIo {
            path: path.display().to_string(),
            message: e.to_string(),
        })
    }

    /// Appends one tensor record to `buf`.
    ///
    /// Record layout (all little-endian): u32 dimension count, one u32 per
    /// dimension, then the data as raw f32 values.
    fn write_tensor_to_bytes(buf: &mut Vec<u8>, tensor: &Tensor) {
        let ndims = tensor.shape().len() as u32;
        buf.extend_from_slice(&ndims.to_le_bytes());
        for &d in tensor.shape() {
            buf.extend_from_slice(&(d as u32).to_le_bytes());
        }
        for &v in tensor.data() {
            buf.extend_from_slice(&v.to_le_bytes());
        }
    }

    /// Size in bytes of a tensor's f32 payload only (header excluded).
    fn tensor_byte_size(tensor: &Tensor) -> usize {
        tensor.data().len() * 4
    }

    /// Decodes one tensor record from `data` starting at `offset`, with
    /// bounds checks at each stage (ndims header, shape list, payload).
    ///
    /// # Errors
    /// Returns `DatasetLoadIo` (with an empty path) when the buffer ends
    /// before the record does, or the tensor-construction error if the
    /// decoded shape and payload disagree.
    fn read_tensor_from_bytes(data: &[u8], offset: usize) -> Result<Tensor, ModelError> {
        if offset + 4 > data.len() {
            return Err(ModelError::DatasetLoadIo {
                path: String::new(),
                message: "unexpected end of batch file (ndims)".to_string(),
            });
        }
        let ndims = u32::from_le_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]) as usize;
        let shape_start = offset + 4;
        let shape_end = shape_start + ndims * 4;
        if shape_end > data.len() {
            return Err(ModelError::DatasetLoadIo {
                path: String::new(),
                message: "unexpected end of batch file (shape)".to_string(),
            });
        }
        let mut shape = Vec::with_capacity(ndims);
        for i in 0..ndims {
            let s = shape_start + i * 4;
            shape.push(
                u32::from_le_bytes([data[s], data[s + 1], data[s + 2], data[s + 3]]) as usize,
            );
        }
        let num_elements: usize = shape.iter().product();
        let data_start = shape_end;
        let data_end = data_start + num_elements * 4;
        if data_end > data.len() {
            return Err(ModelError::DatasetLoadIo {
                path: String::new(),
                message: "unexpected end of batch file (data)".to_string(),
            });
        }
        let mut values = Vec::with_capacity(num_elements);
        for i in 0..num_elements {
            let s = data_start + i * 4;
            values.push(f32::from_le_bytes([
                data[s],
                data[s + 1],
                data[s + 2],
                data[s + 3],
            ]));
        }
        Tensor::from_vec(shape, values).map_err(ModelError::from)
    }
}