use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::VecDeque;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::thread;

/// Configuration for streaming dataset processing.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Number of samples per chunk.
    pub chunk_size: usize,
    /// Maximum number of chunks buffered between producer and consumer.
    pub buffer_size: usize,
    /// Number of worker threads used for parallel processing.
    pub num_workers: usize,
    /// Optional soft memory limit in megabytes.
    pub memory_limit_mb: Option<usize>,
    /// Whether to compress buffered chunks.
    pub enable_compression: bool,
    /// Whether to prefetch chunks ahead of consumption.
    pub enable_prefetch: bool,
    /// Optional upper bound on the number of chunks to produce.
    pub max_chunks: Option<usize>,
}

impl Default for StreamConfig {
    fn default() -> Self {
        Self {
            chunk_size: 10_000,
            buffer_size: 3,
            num_workers: num_cpus::get(),
            memory_limit_mb: None,
            enable_compression: false,
            enable_prefetch: true,
            max_chunks: None,
        }
    }
}

/// A chunk of data produced by a streaming iterator.
#[derive(Debug, Clone)]
pub struct DataChunk {
    /// Feature matrix for this chunk (samples x features).
    pub data: Array2<f64>,
    /// Optional target values for this chunk.
    pub target: Option<Array1<f64>>,
    /// Zero-based index of this chunk in the stream.
    pub chunk_index: usize,
    /// Global sample indices covered by this chunk.
    pub sample_indices: Vec<usize>,
    /// Whether this is the last chunk in the stream.
    pub is_last: bool,
}

impl DataChunk {
    /// Returns the number of samples in this chunk.
    pub fn n_samples(&self) -> usize {
        self.data.nrows()
    }

    /// Returns the number of features in this chunk.
    pub fn n_features(&self) -> usize {
        self.data.ncols()
    }

    /// Converts this chunk into a standalone `Dataset`.
    pub fn to_dataset(&self) -> Dataset {
        Dataset {
            data: self.data.clone(),
            target: self.target.clone(),
            targetnames: None,
            featurenames: None,
            feature_descriptions: None,
            description: None,
            metadata: Default::default(),
        }
    }
}

/// Iterator that streams a dataset in fixed-size chunks, with a background
/// producer thread filling a bounded buffer.
pub struct StreamingIterator {
    config: StreamConfig,
    chunk_buffer: Arc<Mutex<VecDeque<DataChunk>>>,
    current_chunk: usize,
    total_chunks: Option<usize>,
    finished: bool,
    producer_handle: Option<thread::JoinHandle<Result<()>>>,
}

impl StreamingIterator {
    /// Streams a CSV file, skipping the header row. The last column of each
    /// row is treated as the target.
    pub fn from_csv<P: AsRef<Path>>(path: P, config: StreamConfig) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let chunk_buffer = Arc::new(Mutex::new(VecDeque::new()));
        let buffer_clone = Arc::clone(&chunk_buffer);
        let config_clone = config.clone();

        let producer_handle =
            thread::spawn(move || Self::csv_producer(path, config_clone, buffer_clone));

        Ok(Self {
            config,
            chunk_buffer,
            current_chunk: 0,
            total_chunks: None,
            finished: false,
            producer_handle: Some(producer_handle),
        })
    }

    /// Streams a raw binary file of little-endian `f64` values laid out
    /// row-major with `n_features` values per sample. No targets are produced.
    pub fn from_binary<P: AsRef<Path>>(
        path: P,
        n_features: usize,
        config: StreamConfig,
    ) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let chunk_buffer = Arc::new(Mutex::new(VecDeque::new()));
        let buffer_clone = Arc::clone(&chunk_buffer);
        let config_clone = config.clone();

        let producer_handle = thread::spawn(move || {
            Self::binary_producer(path, n_features, config_clone, buffer_clone)
        });

        Ok(Self {
            config,
            chunk_buffer,
            current_chunk: 0,
            total_chunks: None,
            finished: false,
            producer_handle: Some(producer_handle),
        })
    }

    /// Streams synthetic data from a generator function. The generator is
    /// called as `generator(chunk_samples, n_features, start_index)` and must
    /// return the chunk's data and optional targets.
    pub fn from_generator<F>(
        generator: F,
        total_samples: usize,
        n_features: usize,
        config: StreamConfig,
    ) -> Result<Self>
    where
        F: Fn(usize, usize, usize) -> Result<(Array2<f64>, Option<Array1<f64>>)> + Send + 'static,
    {
        let chunk_buffer = Arc::new(Mutex::new(VecDeque::new()));
        let buffer_clone = Arc::clone(&chunk_buffer);
        let config_clone = config.clone();

        let producer_handle = thread::spawn(move || {
            Self::generator_producer(
                generator,
                total_samples,
                n_features,
                config_clone,
                buffer_clone,
            )
        });

        let total_chunks = total_samples.div_ceil(config.chunk_size);

        Ok(Self {
            config,
            chunk_buffer,
            current_chunk: 0,
            total_chunks: Some(total_chunks),
            finished: false,
            producer_handle: Some(producer_handle),
        })
    }

    /// Returns the next chunk, blocking until one is available or the
    /// producer thread has finished.
    pub fn next_chunk(&mut self) -> Result<Option<DataChunk>> {
        if self.finished {
            return Ok(None);
        }

        if let Some(max_chunks) = self.config.max_chunks {
            if self.current_chunk >= max_chunks {
                self.finished = true;
                return Ok(None);
            }
        }

        loop {
            {
                let mut buffer = self.chunk_buffer.lock().unwrap();
                if let Some(chunk) = buffer.pop_front() {
                    self.current_chunk += 1;

                    if chunk.is_last {
                        self.finished = true;
                    }

                    return Ok(Some(chunk));
                }
            }

            // The buffer is empty. If the producer has finished, drain any
            // chunk it pushed between our check and its exit, then stop.
            if let Some(handle) = &self.producer_handle {
                if handle.is_finished() {
                    let handle = self.producer_handle.take().unwrap();
                    handle.join().unwrap()?;

                    let mut buffer = self.chunk_buffer.lock().unwrap();
                    if let Some(chunk) = buffer.pop_front() {
                        self.current_chunk += 1;
                        if chunk.is_last {
                            self.finished = true;
                        }
                        return Ok(Some(chunk));
                    } else {
                        self.finished = true;
                        return Ok(None);
                    }
                }
            }

            thread::sleep(std::time::Duration::from_millis(10));
        }
    }

    /// Returns a snapshot of the stream's progress and buffer state.
    pub fn stats(&self) -> StreamStats {
        let buffer = self.chunk_buffer.lock().unwrap();
        StreamStats {
            current_chunk: self.current_chunk,
            total_chunks: self.total_chunks,
            buffer_size: buffer.len(),
            buffer_capacity: self.config.buffer_size,
            finished: self.finished,
        }
    }

    fn csv_producer(
        path: std::path::PathBuf,
        config: StreamConfig,
        buffer: Arc<Mutex<VecDeque<DataChunk>>>,
    ) -> Result<()> {
        use std::fs::File;
        use std::io::{BufRead, BufReader};

        let file = File::open(&path)?;
        let reader = BufReader::new(file);
        let mut lines = reader.lines();

        // Skip the header row.
        let _header = lines.next();

        let mut chunk_data = Vec::new();
        let mut chunk_index = 0;
        let mut global_sample_index = 0;

        for line in lines {
            let line = line?;
            // Unparseable fields fall back to 0.0 rather than aborting the stream.
            let values: Vec<f64> = line
                .split(',')
                .map(|s| s.trim().parse().unwrap_or(0.0))
                .collect();

            if !values.is_empty() {
                chunk_data.push((values, global_sample_index));
                global_sample_index += 1;

                if chunk_data.len() >= config.chunk_size {
                    let chunk = Self::create_chunk_from_data(&chunk_data, chunk_index, false)?;

                    // Backpressure: wait until the bounded buffer has room.
                    loop {
                        let mut buffer_guard = buffer.lock().unwrap();
                        if buffer_guard.len() < config.buffer_size {
                            buffer_guard.push_back(chunk);
                            break;
                        }
                        drop(buffer_guard);
                        thread::sleep(std::time::Duration::from_millis(10));
                    }

                    chunk_data.clear();
                    chunk_index += 1;

                    if let Some(max_chunks) = config.max_chunks {
                        if chunk_index >= max_chunks {
                            break;
                        }
                    }
                }
            }
        }

        // Emit any remaining samples as the final chunk.
        if !chunk_data.is_empty() {
            let chunk = Self::create_chunk_from_data(&chunk_data, chunk_index, true)?;
            let mut buffer_guard = buffer.lock().unwrap();
            buffer_guard.push_back(chunk);
        }

        Ok(())
    }

    fn binary_producer(
        path: std::path::PathBuf,
        n_features: usize,
        config: StreamConfig,
        buffer: Arc<Mutex<VecDeque<DataChunk>>>,
    ) -> Result<()> {
        use std::fs::File;
        use std::io::Read;

        let mut file = File::open(&path)?;
        let mut chunk_index = 0;
        let mut global_sample_index = 0;

        let values_per_chunk = config.chunk_size * n_features;
        let bytes_per_chunk = values_per_chunk * std::mem::size_of::<f64>();

        loop {
            let mut buffer_data = vec![0u8; bytes_per_chunk];
            // Fill the buffer as far as possible: a single `read` call may
            // return fewer bytes than requested even before end-of-file.
            let mut bytes_read = 0;
            while bytes_read < bytes_per_chunk {
                let n = file.read(&mut buffer_data[bytes_read..])?;
                if n == 0 {
                    break;
                }
                bytes_read += n;
            }

            if bytes_read == 0 {
                break;
            }

            let values_read = bytes_read / std::mem::size_of::<f64>();
            let samples_read = values_read / n_features;

            if samples_read == 0 {
                break;
            }

            // Decode little-endian f64 values; any trailing partial record at
            // end-of-file is dropped so the matrix shape stays consistent.
            let usable_bytes = samples_read * n_features * std::mem::size_of::<f64>();
            let float_data: Vec<f64> = buffer_data[..usable_bytes]
                .chunks_exact(std::mem::size_of::<f64>())
                .map(|chunk| {
                    let mut bytes = [0u8; 8];
                    bytes.copy_from_slice(chunk);
                    f64::from_le_bytes(bytes)
                })
                .collect();

            let data = Array2::from_shape_vec((samples_read, n_features), float_data)
                .map_err(|e| DatasetsError::Other(format!("Shape error: {e}")))?;
            let sample_indices: Vec<usize> =
                (global_sample_index..global_sample_index + samples_read).collect();

            let chunk = DataChunk {
                data,
                target: None,
                chunk_index,
                sample_indices,
                is_last: bytes_read < bytes_per_chunk,
            };

            // Backpressure: wait until the bounded buffer has room.
            loop {
                let mut buffer_guard = buffer.lock().unwrap();
                if buffer_guard.len() < config.buffer_size {
                    buffer_guard.push_back(chunk);
                    break;
                }
                drop(buffer_guard);
                thread::sleep(std::time::Duration::from_millis(10));
            }

            global_sample_index += samples_read;
            chunk_index += 1;

            if let Some(max_chunks) = config.max_chunks {
                if chunk_index >= max_chunks {
                    break;
                }
            }

            if bytes_read < bytes_per_chunk {
                break;
            }
        }

        Ok(())
    }

    fn generator_producer<F>(
        generator: F,
        total_samples: usize,
        n_features: usize,
        config: StreamConfig,
        buffer: Arc<Mutex<VecDeque<DataChunk>>>,
    ) -> Result<()>
    where
        F: Fn(usize, usize, usize) -> Result<(Array2<f64>, Option<Array1<f64>>)>,
    {
        let mut chunk_index = 0;
        let mut processed_samples = 0;

        while processed_samples < total_samples {
            let remaining_samples = total_samples - processed_samples;
            let chunk_samples = config.chunk_size.min(remaining_samples);

            let (data, target) = generator(chunk_samples, n_features, processed_samples)?;

            let sample_indices: Vec<usize> =
                (processed_samples..processed_samples + chunk_samples).collect();
            let is_last = processed_samples + chunk_samples >= total_samples;

            let chunk = DataChunk {
                data,
                target,
                chunk_index,
                sample_indices,
                is_last,
            };

            loop {
                let mut buffer_guard = buffer.lock().unwrap();
                if buffer_guard.len() < config.buffer_size {
                    buffer_guard.push_back(chunk);
                    break;
                }
                drop(buffer_guard);
                thread::sleep(std::time::Duration::from_millis(10));
            }

            processed_samples += chunk_samples;
            chunk_index += 1;

            if let Some(max_chunks) = config.max_chunks {
                if chunk_index >= max_chunks {
                    break;
                }
            }
        }

        Ok(())
    }

    fn create_chunk_from_data(
        data: &[(Vec<f64>, usize)],
        chunk_index: usize,
        is_last: bool,
    ) -> Result<DataChunk> {
        if data.is_empty() {
            return Err(DatasetsError::InvalidFormat("Empty chunk data".to_string()));
        }

        let n_samples = data.len();
        // The last column of each row is treated as the target.
        let n_features = data[0].0.len() - 1;

        let mut chunk_data = Array2::zeros((n_samples, n_features));
        let mut chunk_target = Array1::zeros(n_samples);
        let mut sample_indices = Vec::with_capacity(n_samples);

        for (i, (values, global_idx)) in data.iter().enumerate() {
            for j in 0..n_features {
                chunk_data[[i, j]] = values[j];
            }
            chunk_target[i] = values[n_features];
            sample_indices.push(*global_idx);
        }

        Ok(DataChunk {
            data: chunk_data,
            target: Some(chunk_target),
            chunk_index,
            sample_indices,
            is_last,
        })
    }
}

/// Statistics describing the current state of a stream.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Number of chunks consumed so far.
    pub current_chunk: usize,
    /// Total number of chunks, if known ahead of time.
    pub total_chunks: Option<usize>,
    /// Number of chunks currently buffered.
    pub buffer_size: usize,
    /// Maximum number of chunks the buffer can hold.
    pub buffer_capacity: usize,
    /// Whether the stream has finished.
    pub finished: bool,
}

impl StreamStats {
    /// Returns progress as a percentage, if the total chunk count is known.
    pub fn progress_percent(&self) -> Option<f64> {
        self.total_chunks
            .map(|total| (self.current_chunk as f64 / total as f64) * 100.0)
    }

    /// Returns buffer utilization as a percentage of capacity.
    pub fn buffer_utilization(&self) -> f64 {
        (self.buffer_size as f64 / self.buffer_capacity as f64) * 100.0
    }
}

/// Processes a stream of chunks in parallel across worker threads.
pub struct StreamProcessor<T> {
    config: StreamConfig,
    phantom: std::marker::PhantomData<T>,
}

impl<T> StreamProcessor<T>
where
    T: Send + Sync + 'static,
{
    /// Creates a new processor with the given configuration.
    pub fn new(config: StreamConfig) -> Self {
        Self {
            config,
            phantom: std::marker::PhantomData,
        }
    }

    /// Applies `processor` to every chunk of `iterator`, distributing work
    /// across `config.num_workers` threads. Results are returned in chunk
    /// order.
    pub fn process_parallel<F, R>(
        &self,
        mut iterator: StreamingIterator,
        processor: F,
    ) -> Result<Vec<R>>
    where
        F: Fn(DataChunk) -> Result<R> + Send + Sync + Clone + 'static,
        R: Send + 'static,
    {
        use std::sync::mpsc;

        // Work channel: `Some((chunk_id, chunk))` is a work item, `None` is
        // the shutdown signal for a worker.
        let (work_tx, work_rx) = mpsc::channel();
        let work_rx = Arc::new(Mutex::new(work_rx));

        let (result_tx, result_rx) = mpsc::channel();
        let mut worker_handles = Vec::new();

        for worker_id in 0..self.config.num_workers {
            let work_rx_clone = Arc::clone(&work_rx);
            let result_tx_clone = result_tx.clone();
            let processor_clone = processor.clone();

            let handle = thread::spawn(move || {
                loop {
                    let chunk = {
                        let rx = work_rx_clone.lock().unwrap();
                        rx.recv().ok()
                    };

                    match chunk {
                        Some(Some((chunk_id, chunk))) => {
                            match processor_clone(chunk) {
                                Ok(result) => {
                                    if result_tx_clone.send((chunk_id, Ok(result))).is_err() {
                                        eprintln!("Worker {worker_id} failed to send result");
                                        break;
                                    }
                                }
                                Err(e) => {
                                    eprintln!("Worker {worker_id} processing error: {e}");
                                    if result_tx_clone.send((chunk_id, Err(e))).is_err() {
                                        break;
                                    }
                                }
                            }
                        }
                        // Shutdown signal or closed channel: exit the worker.
                        Some(None) => break,
                        None => break,
                    }
                }
            });

            worker_handles.push(handle);
        }

        // Feed all chunks to the workers.
        let mut chunk_count = 0;
        while let Some(chunk) = iterator.next_chunk()? {
            work_tx
                .send(Some((chunk_count, chunk)))
                .map_err(|e| DatasetsError::Other(format!("Work send error: {e}")))?;
            chunk_count += 1;
        }

        // Send one shutdown signal per worker.
        for _ in 0..self.config.num_workers {
            work_tx
                .send(None)
                .map_err(|e| DatasetsError::Other(format!("End signal send error: {e}")))?;
        }

        drop(work_tx);

        // Collect results, restoring chunk order by chunk_id.
        let mut results: Vec<Option<R>> = (0..chunk_count).map(|_| None).collect();
        let mut received_count = 0;

        while received_count < chunk_count {
            match result_rx.recv() {
                Ok((chunk_id, result)) => {
                    match result {
                        Ok(value) => {
                            if chunk_id < results.len() {
                                results[chunk_id] = Some(value);
                                received_count += 1;
                            }
                        }
                        Err(e) => {
                            return Err(e);
                        }
                    }
                }
                Err(_) => {
                    return Err(DatasetsError::Other(
                        "Failed to receive results from workers".to_string(),
                    ));
                }
            }
        }

        for handle in worker_handles {
            if let Err(e) = handle.join() {
                eprintln!("Worker thread panicked: {e:?}");
            }
        }

        let final_results: Vec<R> = results
            .into_iter()
            .collect::<Option<Vec<R>>>()
            .ok_or_else(|| {
                DatasetsError::Other("Missing results from parallel processing".to_string())
            })?;

        Ok(final_results)
    }
}

/// Applies a sequence of in-place transformations to chunks as they stream.
pub struct StreamTransformer {
    #[allow(clippy::type_complexity)]
    transformations: Vec<Box<dyn Fn(&mut DataChunk) -> Result<()> + Send + Sync>>,
}

impl StreamTransformer {
    /// Creates an empty transformer.
    pub fn new() -> Self {
        Self {
            transformations: Vec::new(),
        }
    }

    /// Appends a transformation to the pipeline.
    pub fn add_transform<F>(mut self, transform: F) -> Self
    where
        F: Fn(&mut DataChunk) -> Result<()> + Send + Sync + 'static,
    {
        self.transformations.push(Box::new(transform));
        self
    }

    /// Applies all transformations to a chunk, in the order they were added.
    pub fn transform_chunk(&self, chunk: &mut DataChunk) -> Result<()> {
        for transform in &self.transformations {
            transform(chunk)?;
        }
        Ok(())
    }

    /// Adds per-chunk standard scaling: each feature is centered and scaled
    /// using the mean and standard deviation computed within the chunk.
    pub fn add_standard_scaling(self) -> Self {
        self.add_transform(|chunk| {
            let mean = chunk.data.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
            let std = chunk.data.std_axis(scirs2_core::ndarray::Axis(0), 0.0);

            for mut row in chunk.data.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
                for (i, val) in row.iter_mut().enumerate() {
                    if std[i] > 0.0 {
                        *val = (*val - mean[i]) / std[i];
                    }
                }
            }
            Ok(())
        })
    }

    /// Adds mean imputation for missing values: NaN entries are replaced by
    /// the per-chunk mean of the non-NaN values in the same column.
    pub fn add_missing_value_imputation(self) -> Self {
        self.add_transform(|chunk| {
            // Compute column means over non-NaN entries only; a mean taken
            // over the raw column would itself be NaN whenever a value is
            // missing, which would make the imputation a no-op.
            let n_cols = chunk.data.ncols();
            let mut means = vec![0.0; n_cols];
            for j in 0..n_cols {
                let (sum, count) = chunk
                    .data
                    .column(j)
                    .iter()
                    .filter(|v| !v.is_nan())
                    .fold((0.0, 0usize), |(s, c), v| (s + v, c + 1));
                if count > 0 {
                    means[j] = sum / count as f64;
                }
            }

            for mut row in chunk.data.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
                for (i, val) in row.iter_mut().enumerate() {
                    if val.is_nan() {
                        *val = means[i];
                    }
                }
            }
            Ok(())
        })
    }
}

impl Default for StreamTransformer {
    fn default() -> Self {
        Self::new()
    }
}

/// Convenience function: stream a CSV file with the given configuration.
#[allow(dead_code)]
pub fn stream_csv<P: AsRef<Path>>(path: P, config: StreamConfig) -> Result<StreamingIterator> {
    StreamingIterator::from_csv(path, config)
}

/// Convenience function: stream a synthetic classification dataset generated
/// chunk by chunk.
#[allow(dead_code)]
pub fn stream_classification(
    total_samples: usize,
    n_features: usize,
    n_classes: usize,
    config: StreamConfig,
) -> Result<StreamingIterator> {
    use crate::generators::make_classification;

    let generator = move |chunk_size: usize, n_features: usize, start_idx: usize| {
        // Offset the seed by the chunk's start index so each chunk differs.
        let dataset = make_classification(
            chunk_size,
            n_features,
            n_classes,
            2,
            n_features / 2,
            Some(42 + start_idx as u64),
        )?;
        Ok((dataset.data, dataset.target))
    };

    StreamingIterator::from_generator(generator, total_samples, n_features, config)
}

/// Convenience function: stream a synthetic regression dataset generated
/// chunk by chunk.
#[allow(dead_code)]
pub fn stream_regression(
    total_samples: usize,
    n_features: usize,
    config: StreamConfig,
) -> Result<StreamingIterator> {
    use crate::generators::make_regression;

    let generator = move |chunk_size: usize, n_features: usize, start_idx: usize| {
        let dataset = make_regression(
            chunk_size,
            n_features,
            n_features / 2,
            0.1,
            Some(42 + start_idx as u64),
        )?;
        Ok((dataset.data, dataset.target))
    };

    StreamingIterator::from_generator(generator, total_samples, n_features, config)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_config() {
        let config = StreamConfig::default();
        assert_eq!(config.chunk_size, 10_000);
        assert_eq!(config.buffer_size, 3);
        assert!(config.num_workers > 0);
    }
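
    // A minimal round-trip sketch for the binary reader: we write raw
    // little-endian f64 values (the layout `binary_producer` decodes) to a
    // file in the system temp directory and stream them back. The file name
    // is arbitrary.
    #[test]
    fn test_from_binary_roundtrip() {
        use std::io::Write;

        let mut path = std::env::temp_dir();
        path.push("scirs2_streaming_binary_roundtrip.bin");

        // 7 samples x 2 features = 14 values.
        let mut file = std::fs::File::create(&path).unwrap();
        for i in 0..14 {
            file.write_all(&(i as f64).to_le_bytes()).unwrap();
        }
        drop(file);

        let config = StreamConfig {
            chunk_size: 4,
            ..Default::default()
        };
        let mut stream = StreamingIterator::from_binary(&path, 2, config).unwrap();

        let mut total_samples = 0;
        while let Some(chunk) = stream.next_chunk().unwrap() {
            assert_eq!(chunk.n_features(), 2);
            total_samples += chunk.n_samples();
        }
        assert_eq!(total_samples, 7);

        let _ = std::fs::remove_file(&path);
    }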

    #[test]
    fn test_data_chunk() {
        let data = Array2::zeros((100, 5));
        let target = Array1::zeros(100);
        let chunk = DataChunk {
            data,
            target: Some(target),
            chunk_index: 0,
            sample_indices: (0..100).collect(),
            is_last: false,
        };

        assert_eq!(chunk.n_samples(), 100);
        assert_eq!(chunk.n_features(), 5);
        assert!(!chunk.is_last);
    }
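
    // A small sketch of the chunk-to-Dataset conversion: data and target
    // should carry over unchanged, with the metadata fields left empty.
    #[test]
    fn test_chunk_to_dataset() {
        let chunk = DataChunk {
            data: Array2::ones((10, 3)),
            target: Some(Array1::ones(10)),
            chunk_index: 0,
            sample_indices: (0..10).collect(),
            is_last: true,
        };

        let dataset = chunk.to_dataset();
        assert_eq!(dataset.data.nrows(), 10);
        assert_eq!(dataset.data.ncols(), 3);
        assert!(dataset.target.is_some());
        assert!(dataset.featurenames.is_none());
    }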

    #[test]
    fn test_stream_stats() {
        let stats = StreamStats {
            current_chunk: 5,
            total_chunks: Some(10),
            buffer_size: 2,
            buffer_capacity: 3,
            finished: false,
        };

        assert_eq!(stats.progress_percent(), Some(50.0));
        assert!((stats.buffer_utilization() - 66.66666666666667).abs() < 1e-10);
    }
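
    // A sketch of mean imputation on a single chunk: the NaN in column 0
    // should be replaced by the mean of that column's remaining values.
    #[test]
    fn test_missing_value_imputation() {
        let data =
            Array2::from_shape_vec((3, 2), vec![1.0, 10.0, f64::NAN, 20.0, 3.0, 30.0]).unwrap();
        let mut chunk = DataChunk {
            data,
            target: None,
            chunk_index: 0,
            sample_indices: vec![0, 1, 2],
            is_last: true,
        };

        let transformer = StreamTransformer::new().add_missing_value_imputation();
        transformer.transform_chunk(&mut chunk).unwrap();

        // Mean of the non-NaN entries in column 0 is (1.0 + 3.0) / 2 = 2.0.
        assert!((chunk.data[[1, 0]] - 2.0).abs() < 1e-12);
    }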

    #[test]
    fn test_stream_classification() {
        let config = StreamConfig {
            chunk_size: 100,
            buffer_size: 2,
            max_chunks: Some(3),
            ..Default::default()
        };

        let stream = stream_classification(1000, 10, 3, config).unwrap();
        assert!(stream.total_chunks.is_some());
    }
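
    // An end-to-end sketch of the generator-backed stream, assuming
    // make_regression yields an (n_samples, n_features) data matrix as the
    // generator above relies on: every sample should arrive exactly once, in
    // chunks of at most chunk_size samples.
    #[test]
    fn test_stream_regression_consumes_all_samples() {
        let config = StreamConfig {
            chunk_size: 50,
            buffer_size: 2,
            ..Default::default()
        };

        let mut stream = stream_regression(120, 5, config).unwrap();
        let mut total_samples = 0;
        while let Some(chunk) = stream.next_chunk().unwrap() {
            assert_eq!(chunk.n_features(), 5);
            assert!(chunk.n_samples() <= 50);
            total_samples += chunk.n_samples();
        }
        assert_eq!(total_samples, 120);
    }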

    #[test]
    fn test_stream_transformer() {
        let transformer = StreamTransformer::new()
            .add_standard_scaling()
            .add_missing_value_imputation();

        assert_eq!(transformer.transformations.len(), 2);
    }
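
    // A sketch of parallel chunk processing: a synthetic all-ones stream is
    // reduced to per-chunk sums, which lets us check both completeness and
    // the chunk ordering of the results.
    #[test]
    fn test_process_parallel_chunk_sums() {
        let config = StreamConfig {
            chunk_size: 10,
            num_workers: 2,
            ..Default::default()
        };

        let generator = |n: usize,
                         d: usize,
                         _start: usize|
         -> Result<(Array2<f64>, Option<Array1<f64>>)> {
            Ok((Array2::ones((n, d)), None))
        };
        let stream = StreamingIterator::from_generator(generator, 35, 3, config.clone()).unwrap();

        let processor = StreamProcessor::<f64>::new(config);
        let sums = processor
            .process_parallel(stream, |chunk| Ok(chunk.data.sum()))
            .unwrap();

        // 35 samples in chunks of 10 -> chunks of 10, 10, 10, 5.
        assert_eq!(sums.len(), 4);
        assert!((sums.iter().sum::<f64>() - 105.0).abs() < 1e-10);
    }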
}