1use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use scirs2_core::ndarray::{Array1, Array2};
10use std::collections::VecDeque;
11use std::path::Path;
12use std::sync::{Arc, Mutex};
13use std::thread;
14
/// Tuning knobs for chunked dataset streaming.
///
/// Build one with struct-update syntax over the defaults, e.g.
/// `StreamConfig { chunk_size: 500, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Number of samples (rows) per emitted chunk.
    pub chunk_size: usize,
    /// Maximum number of chunks a producer queues ahead of the consumer.
    pub buffer_size: usize,
    /// Worker thread count used by parallel chunk processing.
    pub num_workers: usize,
    /// Optional memory cap in megabytes.
    /// NOTE(review): not read by the streaming code in this module — confirm where it is honored.
    pub memory_limit_mb: Option<usize>,
    /// Compression toggle.
    /// NOTE(review): not read by the streaming code in this module — confirm where it is honored.
    pub enable_compression: bool,
    /// Prefetch toggle.
    /// NOTE(review): not read by the streaming code in this module — confirm where it is honored.
    pub enable_prefetch: bool,
    /// Optional upper bound on the number of chunks produced/consumed.
    pub max_chunks: Option<usize>,
}
33
34impl Default for StreamConfig {
35    fn default() -> Self {
36        Self {
37            chunk_size: 10_000,
38            buffer_size: 3,
39            num_workers: num_cpus::get(),
40            memory_limit_mb: None,
41            enable_compression: false,
42            enable_prefetch: true,
43            max_chunks: None,
44        }
45    }
46}
47
48#[derive(Debug, Clone)]
50pub struct DataChunk {
51    pub data: Array2<f64>,
53    pub target: Option<Array1<f64>>,
55    pub chunk_index: usize,
57    pub sample_indices: Vec<usize>,
59    pub is_last: bool,
61}
62
63impl DataChunk {
64    pub fn n_samples(&self) -> usize {
66        self.data.nrows()
67    }
68
69    pub fn n_features(&self) -> usize {
71        self.data.ncols()
72    }
73
74    pub fn to_dataset(&self) -> Dataset {
76        Dataset {
77            data: self.data.clone(),
78            target: self.target.clone(),
79            targetnames: None,
80            featurenames: None,
81            feature_descriptions: None,
82            description: None,
83            metadata: Default::default(),
84        }
85    }
86}
87
88pub struct StreamingIterator {
90    config: StreamConfig,
91    chunk_buffer: Arc<Mutex<VecDeque<DataChunk>>>,
92    current_chunk: usize,
93    total_chunks: Option<usize>,
94    finished: bool,
95    producer_handle: Option<thread::JoinHandle<Result<()>>>,
96}
97
98impl StreamingIterator {
99    pub fn from_csv<P: AsRef<Path>>(path: P, config: StreamConfig) -> Result<Self> {
101        let path = path.as_ref().to_path_buf();
102        let chunk_buffer = Arc::new(Mutex::new(VecDeque::new()));
103        let buffer_clone = Arc::clone(&chunk_buffer);
104        let config_clone = config.clone();
105
106        let producer_handle =
108            thread::spawn(move || Self::csv_producer(path, config_clone, buffer_clone));
109
110        Ok(Self {
111            config,
112            chunk_buffer,
113            current_chunk: 0,
114            total_chunks: None,
115            finished: false,
116            producer_handle: Some(producer_handle),
117        })
118    }
119
120    pub fn from_binary<P: AsRef<Path>>(
122        path: P,
123        n_features: usize,
124        config: StreamConfig,
125    ) -> Result<Self> {
126        let path = path.as_ref().to_path_buf();
127        let chunk_buffer = Arc::new(Mutex::new(VecDeque::new()));
128        let buffer_clone = Arc::clone(&chunk_buffer);
129        let config_clone = config.clone();
130
131        let producer_handle = thread::spawn(move || {
132            Self::binary_producer(path, n_features, config_clone, buffer_clone)
133        });
134
135        Ok(Self {
136            config,
137            chunk_buffer,
138            current_chunk: 0,
139            total_chunks: None,
140            finished: false,
141            producer_handle: Some(producer_handle),
142        })
143    }
144
145    pub fn from_generator<F>(
147        generator: F,
148        total_samples: usize,
149        n_features: usize,
150        config: StreamConfig,
151    ) -> Result<Self>
152    where
153        F: Fn(usize, usize, usize) -> Result<(Array2<f64>, Option<Array1<f64>>)> + Send + 'static,
154    {
155        let chunk_buffer = Arc::new(Mutex::new(VecDeque::new()));
156        let buffer_clone = Arc::clone(&chunk_buffer);
157        let config_clone = config.clone();
158
159        let producer_handle = thread::spawn(move || {
160            Self::generator_producer(
161                generator,
162                total_samples,
163                n_features,
164                config_clone,
165                buffer_clone,
166            )
167        });
168
169        let total_chunks = total_samples.div_ceil(config.chunk_size);
170
171        Ok(Self {
172            config,
173            chunk_buffer,
174            current_chunk: 0,
175            total_chunks: Some(total_chunks),
176            finished: false,
177            producer_handle: Some(producer_handle),
178        })
179    }
180
181    pub fn next_chunk(&mut self) -> Result<Option<DataChunk>> {
183        if self.finished {
184            return Ok(None);
185        }
186
187        if let Some(max_chunks) = self.config.max_chunks {
189            if self.current_chunk >= max_chunks {
190                self.finished = true;
191                return Ok(None);
192            }
193        }
194
195        loop {
197            {
198                let mut buffer = self.chunk_buffer.lock().unwrap();
199                if let Some(chunk) = buffer.pop_front() {
200                    self.current_chunk += 1;
201
202                    if chunk.is_last {
203                        self.finished = true;
204                    }
205
206                    return Ok(Some(chunk));
207                }
208            }
209
210            if let Some(handle) = &self.producer_handle {
212                if handle.is_finished() {
213                    let handle = self.producer_handle.take().unwrap();
215                    handle.join().unwrap()?;
216
217                    let mut buffer = self.chunk_buffer.lock().unwrap();
219                    if let Some(chunk) = buffer.pop_front() {
220                        self.current_chunk += 1;
221                        if chunk.is_last {
222                            self.finished = true;
223                        }
224                        return Ok(Some(chunk));
225                    } else {
226                        self.finished = true;
227                        return Ok(None);
228                    }
229                }
230            }
231
232            thread::sleep(std::time::Duration::from_millis(10));
234        }
235    }
236
237    pub fn stats(&self) -> StreamStats {
239        let buffer = self.chunk_buffer.lock().unwrap();
240        StreamStats {
241            current_chunk: self.current_chunk,
242            total_chunks: self.total_chunks,
243            buffer_size: buffer.len(),
244            buffer_capacity: self.config.buffer_size,
245            finished: self.finished,
246        }
247    }
248
249    fn csv_producer(
251        path: std::path::PathBuf,
252        config: StreamConfig,
253        buffer: Arc<Mutex<VecDeque<DataChunk>>>,
254    ) -> Result<()> {
255        use std::fs::File;
256        use std::io::{BufRead, BufReader};
257
258        let file = File::open(&path)?;
259        let reader = BufReader::new(file);
260        let mut lines = reader.lines();
261
262        let _header = lines.next();
264
265        let mut chunk_data = Vec::new();
266        let mut chunk_index = 0;
267        let mut global_sample_index = 0;
268
269        for line in lines {
270            let line = line?;
271            let values: Vec<f64> = line
272                .split(',')
273                .map(|s| s.trim().parse().unwrap_or(0.0))
274                .collect();
275
276            if !values.is_empty() {
277                chunk_data.push((values, global_sample_index));
278                global_sample_index += 1;
279
280                if chunk_data.len() >= config.chunk_size {
281                    let chunk = Self::create_chunk_from_data(&chunk_data, chunk_index, false)?;
282
283                    loop {
285                        let mut buffer_guard = buffer.lock().unwrap();
286                        if buffer_guard.len() < config.buffer_size {
287                            buffer_guard.push_back(chunk);
288                            break;
289                        }
290                        drop(buffer_guard);
291                        thread::sleep(std::time::Duration::from_millis(10));
292                    }
293
294                    chunk_data.clear();
295                    chunk_index += 1;
296
297                    if let Some(max_chunks) = config.max_chunks {
298                        if chunk_index >= max_chunks {
299                            break;
300                        }
301                    }
302                }
303            }
304        }
305
306        if !chunk_data.is_empty() {
308            let chunk = Self::create_chunk_from_data(&chunk_data, chunk_index, true)?;
309            let mut buffer_guard = buffer.lock().unwrap();
310            buffer_guard.push_back(chunk);
311        }
312
313        Ok(())
314    }
315
316    fn binary_producer(
318        path: std::path::PathBuf,
319        n_features: usize,
320        config: StreamConfig,
321        buffer: Arc<Mutex<VecDeque<DataChunk>>>,
322    ) -> Result<()> {
323        use std::fs::File;
324        use std::io::Read;
325
326        let mut file = File::open(&path)?;
327        let mut chunk_index = 0;
328        let mut global_sample_index = 0;
329
330        let values_per_chunk = config.chunk_size * n_features;
331        let bytes_per_chunk = values_per_chunk * std::mem::size_of::<f64>();
332
333        loop {
334            let mut buffer_data = vec![0u8; bytes_per_chunk];
335            let bytes_read = file.read(&mut buffer_data)?;
336
337            if bytes_read == 0 {
338                break; }
340
341            let values_read = bytes_read / std::mem::size_of::<f64>();
342            let samples_read = values_read / n_features;
343
344            if samples_read == 0 {
345                break;
346            }
347
348            let float_data: Vec<f64> = buffer_data[..bytes_read]
350                .chunks_exact(std::mem::size_of::<f64>())
351                .map(|chunk| {
352                    let mut bytes = [0u8; 8];
353                    bytes.copy_from_slice(chunk);
354                    f64::from_le_bytes(bytes)
355                })
356                .collect();
357
358            let data = Array2::from_shape_vec((samples_read, n_features), float_data)
360                .map_err(|e| DatasetsError::Other(format!("Shape error: {e}")))?;
361            let sample_indices: Vec<usize> =
362                (global_sample_index..global_sample_index + samples_read).collect();
363
364            let chunk = DataChunk {
365                data,
366                target: None,
367                chunk_index,
368                sample_indices,
369                is_last: bytes_read < bytes_per_chunk,
370            };
371
372            loop {
374                let mut buffer_guard = buffer.lock().unwrap();
375                if buffer_guard.len() < config.buffer_size {
376                    buffer_guard.push_back(chunk);
377                    break;
378                }
379                drop(buffer_guard);
380                thread::sleep(std::time::Duration::from_millis(10));
381            }
382
383            global_sample_index += samples_read;
384            chunk_index += 1;
385
386            if let Some(max_chunks) = config.max_chunks {
387                if chunk_index >= max_chunks {
388                    break;
389                }
390            }
391
392            if bytes_read < bytes_per_chunk {
393                break; }
395        }
396
397        Ok(())
398    }
399
400    fn generator_producer<F>(
402        generator: F,
403        total_samples: usize,
404        n_features: usize,
405        config: StreamConfig,
406        buffer: Arc<Mutex<VecDeque<DataChunk>>>,
407    ) -> Result<()>
408    where
409        F: Fn(usize, usize, usize) -> Result<(Array2<f64>, Option<Array1<f64>>)>,
410    {
411        let mut chunk_index = 0;
412        let mut processed_samples = 0;
413
414        while processed_samples < total_samples {
415            let remaining_samples = total_samples - processed_samples;
416            let chunk_samples = config.chunk_size.min(remaining_samples);
417
418            let (data, target) = generator(chunk_samples, n_features, processed_samples)?;
420
421            let sample_indices: Vec<usize> =
422                (processed_samples..processed_samples + chunk_samples).collect();
423            let is_last = processed_samples + chunk_samples >= total_samples;
424
425            let chunk = DataChunk {
426                data,
427                target,
428                chunk_index,
429                sample_indices,
430                is_last,
431            };
432
433            loop {
435                let mut buffer_guard = buffer.lock().unwrap();
436                if buffer_guard.len() < config.buffer_size {
437                    buffer_guard.push_back(chunk);
438                    break;
439                }
440                drop(buffer_guard);
441                thread::sleep(std::time::Duration::from_millis(10));
442            }
443
444            processed_samples += chunk_samples;
445            chunk_index += 1;
446
447            if let Some(max_chunks) = config.max_chunks {
448                if chunk_index >= max_chunks {
449                    break;
450                }
451            }
452        }
453
454        Ok(())
455    }
456
457    fn create_chunk_from_data(
459        data: &[(Vec<f64>, usize)],
460        chunk_index: usize,
461        is_last: bool,
462    ) -> Result<DataChunk> {
463        if data.is_empty() {
464            return Err(DatasetsError::InvalidFormat("Empty chunk data".to_string()));
465        }
466
467        let n_samples = data.len();
468        let n_features = data[0].0.len() - 1; let mut chunk_data = Array2::zeros((n_samples, n_features));
471        let mut chunk_target = Array1::zeros(n_samples);
472        let mut sample_indices = Vec::with_capacity(n_samples);
473
474        for (i, (values, global_idx)) in data.iter().enumerate() {
475            for j in 0..n_features {
476                chunk_data[[i, j]] = values[j];
477            }
478            chunk_target[i] = values[n_features];
479            sample_indices.push(*global_idx);
480        }
481
482        Ok(DataChunk {
483            data: chunk_data,
484            target: Some(chunk_target),
485            chunk_index,
486            sample_indices,
487            is_last,
488        })
489    }
490}
491
/// Point-in-time progress snapshot of a [`StreamingIterator`].
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Chunks consumed so far.
    pub current_chunk: usize,
    /// Expected chunk count, if known up front.
    pub total_chunks: Option<usize>,
    /// Chunks currently waiting in the buffer.
    pub buffer_size: usize,
    /// Configured buffer capacity (`StreamConfig::buffer_size`).
    pub buffer_capacity: usize,
    /// Whether the consumer has reached the end of the stream.
    pub finished: bool,
}
506
507impl StreamStats {
508    pub fn progress_percent(&self) -> Option<f64> {
510        self.total_chunks
511            .map(|total| (self.current_chunk as f64 / total as f64) * 100.0)
512    }
513
514    pub fn buffer_utilization(&self) -> f64 {
516        (self.buffer_size as f64 / self.buffer_capacity as f64) * 100.0
517    }
518}
519
520pub struct StreamProcessor<T> {
522    config: StreamConfig,
523    phantom: std::marker::PhantomData<T>,
524}
525
526impl<T> StreamProcessor<T>
527where
528    T: Send + Sync + 'static,
529{
530    pub fn new(config: StreamConfig) -> Self {
532        Self {
533            config,
534            phantom: std::marker::PhantomData,
535        }
536    }
537
538    pub fn process_parallel<F, R>(
540        &self,
541        mut iterator: StreamingIterator,
542        processor: F,
543    ) -> Result<Vec<R>>
544    where
545        F: Fn(DataChunk) -> Result<R> + Send + Sync + Clone + 'static,
546        R: Send + 'static,
547    {
548        use std::sync::mpsc;
549
550        let (work_tx, work_rx) = mpsc::channel();
552        let work_rx = Arc::new(Mutex::new(work_rx));
553
554        let (result_tx, result_rx) = mpsc::channel();
555        let mut worker_handles = Vec::new();
556
557        for worker_id in 0..self.config.num_workers {
559            let work_rx_clone = Arc::clone(&work_rx);
560            let result_tx_clone = result_tx.clone();
561            let processor_clone = processor.clone();
562
563            let handle = thread::spawn(move || {
564                loop {
565                    let chunk = {
567                        let rx = work_rx_clone.lock().unwrap();
568                        rx.recv().ok()
569                    };
570
571                    match chunk {
572                        Some(Some((chunk_id, chunk))) => {
573                            match processor_clone(chunk) {
575                                Ok(result) => {
576                                    if result_tx_clone.send((chunk_id, Ok(result))).is_err() {
578                                        eprintln!("Worker {worker_id} failed to send result");
579                                        break;
580                                    }
581                                }
582                                Err(e) => {
583                                    eprintln!("Worker {worker_id} processing error: {e}");
584                                    if result_tx_clone.send((chunk_id, Err(e))).is_err() {
586                                        break;
587                                    }
588                                }
589                            }
590                        }
591                        Some(None) => break, None => break,       }
594                }
595            });
596
597            worker_handles.push(handle);
598        }
599
600        let mut chunk_count = 0;
602        while let Some(chunk) = iterator.next_chunk()? {
603            work_tx
604                .send(Some((chunk_count, chunk)))
605                .map_err(|e| DatasetsError::Other(format!("Work send error: {e}")))?;
606            chunk_count += 1;
607        }
608
609        for _ in 0..self.config.num_workers {
611            work_tx
612                .send(None)
613                .map_err(|e| DatasetsError::Other(format!("End signal send error: {e}")))?;
614        }
615
616        drop(work_tx);
618
619        let mut results: Vec<Option<R>> = (0..chunk_count).map(|_| None).collect();
621        let mut received_count = 0;
622
623        while received_count < chunk_count {
625            match result_rx.recv() {
626                Ok((chunk_id, result)) => {
627                    match result {
628                        Ok(value) => {
629                            if chunk_id < results.len() {
630                                results[chunk_id] = Some(value);
631                                received_count += 1;
632                            }
633                        }
634                        Err(e) => {
635                            return Err(e);
637                        }
638                    }
639                }
640                Err(_) => {
641                    return Err(DatasetsError::Other(
642                        "Failed to receive results from workers".to_string(),
643                    ));
644                }
645            }
646        }
647
648        for handle in worker_handles {
650            if let Err(e) = handle.join() {
651                eprintln!("Worker thread panicked: {e:?}");
652            }
653        }
654
655        let final_results: Vec<R> =
657            results
658                .into_iter()
659                .collect::<Option<Vec<R>>>()
660                .ok_or_else(|| {
661                    DatasetsError::Other("Missing results from parallel processing".to_string())
662                })?;
663
664        Ok(final_results)
665    }
666}
667
668pub struct StreamTransformer {
670    #[allow(clippy::type_complexity)]
671    transformations: Vec<Box<dyn Fn(&mut DataChunk) -> Result<()> + Send + Sync>>,
672}
673
674impl StreamTransformer {
675    pub fn new() -> Self {
677        Self {
678            transformations: Vec::new(),
679        }
680    }
681
682    pub fn add_transform<F>(mut self, transform: F) -> Self
684    where
685        F: Fn(&mut DataChunk) -> Result<()> + Send + Sync + 'static,
686    {
687        self.transformations.push(Box::new(transform));
688        self
689    }
690
691    pub fn transform_chunk(&self, chunk: &mut DataChunk) -> Result<()> {
693        for transform in &self.transformations {
694            transform(chunk)?;
695        }
696        Ok(())
697    }
698
699    pub fn add_standard_scaling(self) -> Self {
701        self.add_transform(|chunk| {
702            let mean = chunk.data.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
704            let std = chunk.data.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
705
706            for mut row in chunk.data.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
707                for (i, val) in row.iter_mut().enumerate() {
708                    if std[i] > 0.0 {
709                        *val = (*val - mean[i]) / std[i];
710                    }
711                }
712            }
713            Ok(())
714        })
715    }
716
717    pub fn add_missing_value_imputation(self) -> Self {
719        self.add_transform(|chunk| {
720            let means = chunk.data.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
722
723            for mut row in chunk.data.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
724                for (i, val) in row.iter_mut().enumerate() {
725                    if val.is_nan() {
726                        *val = means[i];
727                    }
728                }
729            }
730            Ok(())
731        })
732    }
733}
734
735impl Default for StreamTransformer {
736    fn default() -> Self {
737        Self::new()
738    }
739}
740
741#[allow(dead_code)]
745pub fn stream_csv<P: AsRef<Path>>(path: P, config: StreamConfig) -> Result<StreamingIterator> {
746    StreamingIterator::from_csv(path, config)
747}
748
749#[allow(dead_code)]
751pub fn stream_classification(
752    total_samples: usize,
753    n_features: usize,
754    n_classes: usize,
755    config: StreamConfig,
756) -> Result<StreamingIterator> {
757    use crate::generators::make_classification;
758
759    let generator = move |chunk_size: usize, _features: usize, start_idx: usize| {
760        let dataset = make_classification(
761            chunk_size,
762            _features,
763            n_classes,
764            2,
765            _features / 2,
766            Some(42 + start_idx as u64),
767        )?;
768        Ok((dataset.data, dataset.target))
769    };
770
771    StreamingIterator::from_generator(generator, total_samples, n_features, config)
772}
773
774#[allow(dead_code)]
776pub fn stream_regression(
777    total_samples: usize,
778    n_features: usize,
779    config: StreamConfig,
780) -> Result<StreamingIterator> {
781    use crate::generators::make_regression;
782
783    let generator = move |chunk_size: usize, _features: usize, start_idx: usize| {
784        let dataset = make_regression(
785            chunk_size,
786            _features,
787            _features / 2,
788            0.1,
789            Some(42 + start_idx as u64),
790        )?;
791        Ok((dataset.data, dataset.target))
792    };
793
794    StreamingIterator::from_generator(generator, total_samples, n_features, config)
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_config() {
        let StreamConfig {
            chunk_size,
            buffer_size,
            num_workers,
            ..
        } = StreamConfig::default();
        assert_eq!(chunk_size, 10_000);
        assert_eq!(buffer_size, 3);
        assert!(num_workers > 0);
    }

    #[test]
    fn test_data_chunk() {
        let rows = 100;
        let chunk = DataChunk {
            data: Array2::zeros((rows, 5)),
            target: Some(Array1::zeros(rows)),
            chunk_index: 0,
            sample_indices: (0..rows).collect(),
            is_last: false,
        };

        assert_eq!(chunk.n_samples(), rows);
        assert_eq!(chunk.n_features(), 5);
        assert!(!chunk.is_last);
    }

    #[test]
    fn test_stream_stats() {
        let stats = StreamStats {
            current_chunk: 5,
            total_chunks: Some(10),
            buffer_size: 2,
            buffer_capacity: 3,
            finished: false,
        };

        let expected_utilization = 66.66666666666667;
        assert_eq!(stats.progress_percent(), Some(50.0));
        assert!((stats.buffer_utilization() - expected_utilization).abs() < 1e-10);
    }

    #[test]
    fn test_stream_classification() {
        let config = StreamConfig {
            chunk_size: 100,
            buffer_size: 2,
            max_chunks: Some(3),
            ..StreamConfig::default()
        };

        let stream = stream_classification(1000, 10, 3, config).unwrap();
        assert!(stream.total_chunks.is_some());
    }

    #[test]
    fn test_stream_transformer() {
        let pipeline = StreamTransformer::new()
            .add_standard_scaling()
            .add_missing_value_imputation();

        assert_eq!(pipeline.transformations.len(), 2);
    }
}