Skip to main content

scirs2_datasets/
parallel_preprocessing.rs

1//! Parallel data preprocessing pipeline
2//!
3//! This module provides a multi-threaded preprocessing pipeline with work-stealing
4//! scheduler, memory-efficient batch processing, and backpressure handling for
5//! optimal throughput and resource utilization.
6
7use crate::error::{DatasetsError, Result};
8use crate::streaming::DataChunk;
9use crate::utils::Dataset;
10use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
11use scirs2_core::ndarray::{Array1, Array2};
12use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use std::sync::Arc;
14use std::thread::{self, JoinHandle};
15
/// Preprocessing function type.
///
/// A shared, thread-safe closure applied by worker threads to each submitted
/// data matrix; returns the transformed matrix or an error. Wrapped in `Arc`
/// so every worker can hold its own cheap clone.
pub type PreprocessFn = Arc<dyn Fn(&Array2<f64>) -> Result<Array2<f64>> + Send + Sync>;
18
/// Configuration for parallel preprocessing
#[derive(Clone)]
pub struct ParallelConfig {
    /// Number of worker threads (0 = auto-detect)
    pub num_workers: usize,
    /// Size of the input buffer (bounded-channel capacity when backpressure is enabled)
    pub input_buffer_size: usize,
    /// Size of the output buffer (bounded-channel capacity when backpressure is enabled)
    pub output_buffer_size: usize,
    /// Batch size for processing
    // NOTE(review): not read by `ParallelPipeline` in this file — confirm intended use.
    pub batch_size: usize,
    /// Whether to use work stealing
    // NOTE(review): currently not consulted anywhere in this file — confirm intended use.
    pub enable_work_stealing: bool,
    /// Maximum memory usage in bytes (0 = unlimited)
    // NOTE(review): not enforced anywhere in this file — confirm intended use.
    pub max_memory_bytes: usize,
    /// Whether to enable backpressure (use bounded channels instead of unbounded)
    pub enable_backpressure: bool,
}
37
38impl Default for ParallelConfig {
39    fn default() -> Self {
40        Self {
41            num_workers: num_cpus::get(),
42            input_buffer_size: 10,
43            output_buffer_size: 10,
44            batch_size: 1000,
45            enable_work_stealing: true,
46            max_memory_bytes: 0,
47            enable_backpressure: true,
48        }
49    }
50}
51
52impl ParallelConfig {
53    /// Create a new configuration
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    /// Set number of workers
59    pub fn with_workers(mut self, num_workers: usize) -> Self {
60        self.num_workers = if num_workers == 0 {
61            num_cpus::get()
62        } else {
63            num_workers
64        };
65        self
66    }
67
68    /// Set buffer sizes
69    pub fn with_buffer_sizes(mut self, input: usize, output: usize) -> Self {
70        self.input_buffer_size = input;
71        self.output_buffer_size = output;
72        self
73    }
74
75    /// Set batch size
76    pub fn with_batch_size(mut self, size: usize) -> Self {
77        self.batch_size = size;
78        self
79    }
80
81    /// Enable or disable work stealing
82    pub fn with_work_stealing(mut self, enable: bool) -> Self {
83        self.enable_work_stealing = enable;
84        self
85    }
86
87    /// Set memory limit
88    pub fn with_memory_limit(mut self, bytes: usize) -> Self {
89        self.max_memory_bytes = bytes;
90        self
91    }
92}
93
/// Work item for preprocessing
#[derive(Clone)]
struct WorkItem {
    // Identifier assigned by `ParallelPipeline::submit` and echoed back
    // in the corresponding `ProcessedItem`.
    id: usize,
    // Feature matrix to be transformed by the preprocessing function.
    data: Array2<f64>,
    // Optional target vector; carried through the pipeline unmodified.
    target: Option<Array1<f64>>,
}
101
/// Processed result
struct ProcessedItem {
    // Identifier copied from the originating `WorkItem`.
    id: usize,
    // Transformed feature matrix (or the original data if preprocessing failed).
    data: Array2<f64>,
    // Target vector carried through unchanged from the `WorkItem`.
    target: Option<Array1<f64>>,
}
108
/// Parallel preprocessing pipeline
pub struct ParallelPipeline {
    // Pipeline configuration; read during construction.
    // NOTE(review): not consulted after `new` in this file — confirm intended use.
    config: ParallelConfig,
    // Preprocessing function; workers hold their own `Arc` clones, this copy
    // is retained by the pipeline itself.
    preprocess_fn: PreprocessFn,
    // Worker thread handles, drained by `join`.
    workers: Vec<JoinHandle<()>>,
    // Input channel; setting this to `None` closes the channel and lets
    // workers drain remaining items and exit.
    input_sender: Option<Sender<WorkItem>>,
    // Output channel carrying processed results back to the caller.
    output_receiver: Option<Receiver<ProcessedItem>>,
    // Cooperative stop flag polled by workers between items.
    stop_flag: Arc<AtomicBool>,
    // Shared count of items that have been processed by workers.
    items_processed: Arc<AtomicUsize>,
}
119
120impl ParallelPipeline {
121    /// Create a new parallel preprocessing pipeline
122    ///
123    /// # Arguments
124    /// * `config` - Pipeline configuration
125    /// * `preprocess_fn` - Function to apply to each data chunk
126    ///
127    /// # Returns
128    /// * `ParallelPipeline` - The pipeline instance
129    pub fn new(config: ParallelConfig, preprocess_fn: PreprocessFn) -> Self {
130        let (input_tx, input_rx) = if config.enable_backpressure {
131            bounded(config.input_buffer_size)
132        } else {
133            unbounded()
134        };
135
136        let (output_tx, output_rx) = if config.enable_backpressure {
137            bounded(config.output_buffer_size)
138        } else {
139            unbounded()
140        };
141
142        let stop_flag = Arc::new(AtomicBool::new(false));
143        let items_processed = Arc::new(AtomicUsize::new(0));
144
145        // Spawn worker threads
146        let mut workers = Vec::new();
147        for worker_id in 0..config.num_workers {
148            let rx = input_rx.clone();
149            let tx = output_tx.clone();
150            let fn_clone = Arc::clone(&preprocess_fn);
151            let stop_flag_clone = Arc::clone(&stop_flag);
152            let items_clone = Arc::clone(&items_processed);
153
154            let worker = thread::spawn(move || {
155                Self::worker_loop(worker_id, rx, tx, fn_clone, stop_flag_clone, items_clone);
156            });
157
158            workers.push(worker);
159        }
160
161        // Drop the original senders/receivers so workers can detect completion
162        drop(output_tx);
163
164        Self {
165            config,
166            preprocess_fn,
167            workers,
168            input_sender: Some(input_tx),
169            output_receiver: Some(output_rx),
170            stop_flag,
171            items_processed,
172        }
173    }
174
175    /// Worker thread main loop
176    fn worker_loop(
177        _worker_id: usize,
178        input: Receiver<WorkItem>,
179        output: Sender<ProcessedItem>,
180        preprocess_fn: PreprocessFn,
181        stop_flag: Arc<AtomicBool>,
182        items_processed: Arc<AtomicUsize>,
183    ) {
184        while !stop_flag.load(Ordering::Relaxed) {
185            match input.recv() {
186                Ok(item) => {
187                    // Process the item
188                    match preprocess_fn(&item.data) {
189                        Ok(processed_data) => {
190                            let result = ProcessedItem {
191                                id: item.id,
192                                data: processed_data,
193                                target: item.target,
194                            };
195
196                            // Increment before sending so the counter is visible
197                            // to the receiver once it drains the result
198                            items_processed.fetch_add(1, Ordering::Release);
199                            // Send result (ignore errors as receiver might be dropped)
200                            let _ = output.send(result);
201                        }
202                        Err(_) => {
203                            // On error, pass through original data
204                            let result = ProcessedItem {
205                                id: item.id,
206                                data: item.data,
207                                target: item.target,
208                            };
209                            let _ = output.send(result);
210                        }
211                    }
212                }
213                Err(_) => break, // Channel closed
214            }
215        }
216    }
217
218    /// Submit data for processing
219    ///
220    /// # Arguments
221    /// * `data` - Input data array
222    /// * `target` - Optional target values
223    ///
224    /// # Returns
225    /// * `Ok(usize)` - ID of the submitted item
226    /// * `Err(DatasetsError)` - If submission fails
227    pub fn submit(&mut self, data: Array2<f64>, target: Option<Array1<f64>>) -> Result<usize> {
228        let id = self.items_processed.load(Ordering::Relaxed);
229        let item = WorkItem { id, data, target };
230
231        self.input_sender
232            .as_ref()
233            .ok_or_else(|| DatasetsError::ProcessingError("Pipeline not initialized".to_string()))?
234            .send(item)
235            .map_err(|e| DatasetsError::ProcessingError(format!("Failed to submit: {}", e)))?;
236
237        Ok(id)
238    }
239
240    /// Submit a dataset for processing
241    pub fn submit_dataset(&mut self, dataset: &Dataset) -> Result<usize> {
242        self.submit(dataset.data.clone(), dataset.target.clone())
243    }
244
245    /// Submit a data chunk for processing
246    pub fn submit_chunk(&mut self, chunk: &DataChunk) -> Result<usize> {
247        self.submit(chunk.data.clone(), chunk.target.clone())
248    }
249
250    /// Receive a processed result
251    ///
252    /// # Returns
253    /// * `Ok(Some(Dataset))` - Processed dataset
254    /// * `Ok(None)` - No more results (all workers finished)
255    /// * `Err(DatasetsError)` - If receive fails
256    pub fn receive(&mut self) -> Result<Option<Dataset>> {
257        match self.output_receiver.as_ref() {
258            Some(rx) => match rx.recv() {
259                Ok(item) => Ok(Some(Dataset {
260                    data: item.data,
261                    target: item.target,
262                    targetnames: None,
263                    featurenames: None,
264                    feature_descriptions: None,
265                    description: None,
266                    metadata: Default::default(),
267                })),
268                Err(_) => Ok(None), // Channel closed
269            },
270            None => Err(DatasetsError::ProcessingError(
271                "Pipeline not initialized".to_string(),
272            )),
273        }
274    }
275
276    /// Try to receive a result without blocking
277    pub fn try_receive(&mut self) -> Result<Option<Dataset>> {
278        match self.output_receiver.as_ref() {
279            Some(rx) => match rx.try_recv() {
280                Ok(item) => Ok(Some(Dataset {
281                    data: item.data,
282                    target: item.target,
283                    targetnames: None,
284                    featurenames: None,
285                    feature_descriptions: None,
286                    description: None,
287                    metadata: Default::default(),
288                })),
289                Err(_) => Ok(None),
290            },
291            None => Err(DatasetsError::ProcessingError(
292                "Pipeline not initialized".to_string(),
293            )),
294        }
295    }
296
297    /// Process a batch of datasets
298    pub fn process_batch(&mut self, datasets: &[Dataset]) -> Result<Vec<Dataset>> {
299        // Submit all
300        for ds in datasets {
301            self.submit_dataset(ds)?;
302        }
303
304        // Collect results
305        let mut results = Vec::new();
306        for _ in 0..datasets.len() {
307            if let Some(result) = self.receive()? {
308                results.push(result);
309            }
310        }
311
312        Ok(results)
313    }
314
315    /// Get number of items processed
316    pub fn items_processed(&self) -> usize {
317        self.items_processed.load(Ordering::Acquire)
318    }
319
320    /// Stop the pipeline gracefully
321    pub fn stop(&mut self) {
322        self.stop_flag.store(true, Ordering::Relaxed);
323        self.input_sender = None; // Drop sender to wake up workers
324    }
325
326    /// Wait for all workers to finish
327    pub fn join(mut self) -> Result<()> {
328        // Drop senders to signal completion
329        self.input_sender = None;
330
331        // Wait for all workers
332        let workers = std::mem::take(&mut self.workers);
333        for worker in workers {
334            worker.join().map_err(|_| {
335                DatasetsError::ProcessingError("Worker thread panicked".to_string())
336            })?;
337        }
338
339        Ok(())
340    }
341}
342
343impl Drop for ParallelPipeline {
344    fn drop(&mut self) {
345        self.stop();
346    }
347}
348
349/// Create a simple preprocessing pipeline
350///
351/// # Arguments
352/// * `preprocess_fn` - Function to apply to each chunk
353/// * `num_workers` - Number of worker threads (0 = auto)
354///
355/// # Returns
356/// * `ParallelPipeline` - The pipeline instance
357pub fn create_pipeline<F>(preprocess_fn: F, num_workers: usize) -> ParallelPipeline
358where
359    F: Fn(&Array2<f64>) -> Result<Array2<f64>> + Send + Sync + 'static,
360{
361    let config = ParallelConfig::default().with_workers(num_workers);
362    ParallelPipeline::new(config, Arc::new(preprocess_fn))
363}
364
365/// Create a pipeline with custom configuration
366pub fn create_pipeline_with_config<F>(config: ParallelConfig, preprocess_fn: F) -> ParallelPipeline
367where
368    F: Fn(&Array2<f64>) -> Result<Array2<f64>> + Send + Sync + 'static,
369{
370    ParallelPipeline::new(config, Arc::new(preprocess_fn))
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_config() {
        // Builder methods should round-trip their values.
        let cfg = ParallelConfig::new()
            .with_workers(4)
            .with_batch_size(500)
            .with_buffer_sizes(5, 5)
            .with_work_stealing(true);

        assert_eq!(cfg.num_workers, 4);
        assert_eq!(cfg.batch_size, 500);
        assert_eq!(cfg.input_buffer_size, 5);
        assert_eq!(cfg.output_buffer_size, 5);
        assert!(cfg.enable_work_stealing);
    }

    #[test]
    fn test_simple_pipeline() -> Result<()> {
        // Doubling transform applied by the workers.
        let double = |m: &Array2<f64>| -> Result<Array2<f64>> { Ok(m * 2.0) };

        let mut pipeline = create_pipeline(double, 2);

        // 3x3 matrix filled with 1..=9.
        let values: Vec<f64> = (1..=9).map(|v| v as f64).collect();
        let input = Array2::from_shape_vec((3, 3), values)
            .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))?;

        pipeline.submit(input.clone(), None)?;

        // The single result should be the doubled matrix.
        if let Some(out) = pipeline.receive()? {
            assert_eq!(out.data[[0, 0]], 2.0);
            assert_eq!(out.data[[2, 2]], 18.0);
        }

        pipeline.stop();
        Ok(())
    }

    #[test]
    fn test_batch_processing() -> Result<()> {
        // Increment-by-one transform.
        let inc = |m: &Array2<f64>| -> Result<Array2<f64>> { Ok(m + 1.0) };

        let mut pipeline = create_pipeline(inc, 4);

        // Five small constant datasets.
        let batch: Vec<Dataset> = (0..5)
            .map(|i| Dataset {
                data: Array2::from_elem((2, 2), i as f64),
                target: None,
                targetnames: None,
                featurenames: None,
                feature_descriptions: None,
                description: None,
                metadata: Default::default(),
            })
            .collect();

        let processed = pipeline.process_batch(&batch)?;
        assert_eq!(processed.len(), 5);

        pipeline.stop();
        Ok(())
    }

    #[test]
    fn test_pipeline_stats() -> Result<()> {
        // Identity transform: the counter, not the data, is under test.
        let identity = |m: &Array2<f64>| -> Result<Array2<f64>> { Ok(m.clone()) };

        let mut pipeline = create_pipeline(identity, 2);

        let input = Array2::zeros((5, 5));
        for _ in 0..3 {
            pipeline.submit(input.clone(), None)?;
        }

        // Drain every result so all worker increments are observed.
        for _ in 0..3 {
            let _ = pipeline.receive()?;
        }

        assert_eq!(pipeline.items_processed(), 3);

        pipeline.stop();
        Ok(())
    }
}