1use crate::error::{DatasetsError, Result};
8use crate::streaming::DataChunk;
9use crate::utils::Dataset;
10use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
11use scirs2_core::ndarray::{Array1, Array2};
12use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use std::sync::Arc;
14use std::thread::{self, JoinHandle};
15
/// Shared preprocessing callback applied by every worker thread: maps an
/// input feature matrix to a transformed matrix, or fails with an error.
/// `Send + Sync` is required because the same `Arc` is cloned into each worker.
pub type PreprocessFn = Arc<dyn Fn(&Array2<f64>) -> Result<Array2<f64>> + Send + Sync>;
18
/// Tuning knobs for a [`ParallelPipeline`].
#[derive(Clone)]
pub struct ParallelConfig {
    /// Number of worker threads to spawn.
    pub num_workers: usize,
    /// Capacity of the input (work) channel when backpressure is enabled.
    pub input_buffer_size: usize,
    /// Capacity of the output (results) channel when backpressure is enabled.
    pub output_buffer_size: usize,
    /// Batch size hint. NOTE(review): stored but not read anywhere in this
    /// file — confirm whether other modules consume it.
    pub batch_size: usize,
    /// Work-stealing toggle. NOTE(review): stored but not read anywhere in
    /// this file — confirm whether other modules consume it.
    pub enable_work_stealing: bool,
    /// Memory cap in bytes; 0 means unlimited. NOTE(review): not enforced
    /// anywhere in this file — confirm intended enforcement point.
    pub max_memory_bytes: usize,
    /// When true, channels are bounded so producers block once buffers fill;
    /// when false, channels are unbounded.
    pub enable_backpressure: bool,
}
37
38impl Default for ParallelConfig {
39 fn default() -> Self {
40 Self {
41 num_workers: num_cpus::get(),
42 input_buffer_size: 10,
43 output_buffer_size: 10,
44 batch_size: 1000,
45 enable_work_stealing: true,
46 max_memory_bytes: 0,
47 enable_backpressure: true,
48 }
49 }
50}
51
52impl ParallelConfig {
53 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn with_workers(mut self, num_workers: usize) -> Self {
60 self.num_workers = if num_workers == 0 {
61 num_cpus::get()
62 } else {
63 num_workers
64 };
65 self
66 }
67
68 pub fn with_buffer_sizes(mut self, input: usize, output: usize) -> Self {
70 self.input_buffer_size = input;
71 self.output_buffer_size = output;
72 self
73 }
74
75 pub fn with_batch_size(mut self, size: usize) -> Self {
77 self.batch_size = size;
78 self
79 }
80
81 pub fn with_work_stealing(mut self, enable: bool) -> Self {
83 self.enable_work_stealing = enable;
84 self
85 }
86
87 pub fn with_memory_limit(mut self, bytes: usize) -> Self {
89 self.max_memory_bytes = bytes;
90 self
91 }
92}
93
/// A unit of work queued on the input channel for the worker threads.
#[derive(Clone)]
struct WorkItem {
    /// Identifier assigned at submission time and echoed back in the result.
    id: usize,
    /// Feature matrix to be transformed by the preprocessing function.
    data: Array2<f64>,
    /// Optional target values, passed through the pipeline unmodified.
    target: Option<Array1<f64>>,
}
101
/// A processed result sent back on the output channel.
struct ProcessedItem {
    /// Identifier of the originating [`WorkItem`].
    id: usize,
    /// Transformed feature matrix (or the original data if preprocessing failed).
    data: Array2<f64>,
    /// Target values carried over from the submitted item.
    target: Option<Array1<f64>>,
}
108
/// Multi-threaded preprocessing pipeline: work items submitted on an input
/// channel are transformed by a pool of worker threads and collected from an
/// output channel.
pub struct ParallelPipeline {
    /// Configuration the pipeline was constructed with.
    config: ParallelConfig,
    /// The preprocessing closure; each worker holds its own `Arc` clone.
    preprocess_fn: PreprocessFn,
    /// Join handles of the spawned worker threads.
    workers: Vec<JoinHandle<()>>,
    /// Sending half of the input channel; set to `None` on shutdown so the
    /// channel disconnects and workers can exit.
    input_sender: Option<Sender<WorkItem>>,
    /// Receiving half of the output channel.
    output_receiver: Option<Receiver<ProcessedItem>>,
    /// Cooperative shutdown flag polled by each worker loop.
    stop_flag: Arc<AtomicBool>,
    /// Count of items successfully processed by the workers.
    items_processed: Arc<AtomicUsize>,
}
119
120impl ParallelPipeline {
121 pub fn new(config: ParallelConfig, preprocess_fn: PreprocessFn) -> Self {
130 let (input_tx, input_rx) = if config.enable_backpressure {
131 bounded(config.input_buffer_size)
132 } else {
133 unbounded()
134 };
135
136 let (output_tx, output_rx) = if config.enable_backpressure {
137 bounded(config.output_buffer_size)
138 } else {
139 unbounded()
140 };
141
142 let stop_flag = Arc::new(AtomicBool::new(false));
143 let items_processed = Arc::new(AtomicUsize::new(0));
144
145 let mut workers = Vec::new();
147 for worker_id in 0..config.num_workers {
148 let rx = input_rx.clone();
149 let tx = output_tx.clone();
150 let fn_clone = Arc::clone(&preprocess_fn);
151 let stop_flag_clone = Arc::clone(&stop_flag);
152 let items_clone = Arc::clone(&items_processed);
153
154 let worker = thread::spawn(move || {
155 Self::worker_loop(worker_id, rx, tx, fn_clone, stop_flag_clone, items_clone);
156 });
157
158 workers.push(worker);
159 }
160
161 drop(output_tx);
163
164 Self {
165 config,
166 preprocess_fn,
167 workers,
168 input_sender: Some(input_tx),
169 output_receiver: Some(output_rx),
170 stop_flag,
171 items_processed,
172 }
173 }
174
175 fn worker_loop(
177 _worker_id: usize,
178 input: Receiver<WorkItem>,
179 output: Sender<ProcessedItem>,
180 preprocess_fn: PreprocessFn,
181 stop_flag: Arc<AtomicBool>,
182 items_processed: Arc<AtomicUsize>,
183 ) {
184 while !stop_flag.load(Ordering::Relaxed) {
185 match input.recv() {
186 Ok(item) => {
187 match preprocess_fn(&item.data) {
189 Ok(processed_data) => {
190 let result = ProcessedItem {
191 id: item.id,
192 data: processed_data,
193 target: item.target,
194 };
195
196 items_processed.fetch_add(1, Ordering::Release);
199 let _ = output.send(result);
201 }
202 Err(_) => {
203 let result = ProcessedItem {
205 id: item.id,
206 data: item.data,
207 target: item.target,
208 };
209 let _ = output.send(result);
210 }
211 }
212 }
213 Err(_) => break, }
215 }
216 }
217
218 pub fn submit(&mut self, data: Array2<f64>, target: Option<Array1<f64>>) -> Result<usize> {
228 let id = self.items_processed.load(Ordering::Relaxed);
229 let item = WorkItem { id, data, target };
230
231 self.input_sender
232 .as_ref()
233 .ok_or_else(|| DatasetsError::ProcessingError("Pipeline not initialized".to_string()))?
234 .send(item)
235 .map_err(|e| DatasetsError::ProcessingError(format!("Failed to submit: {}", e)))?;
236
237 Ok(id)
238 }
239
240 pub fn submit_dataset(&mut self, dataset: &Dataset) -> Result<usize> {
242 self.submit(dataset.data.clone(), dataset.target.clone())
243 }
244
245 pub fn submit_chunk(&mut self, chunk: &DataChunk) -> Result<usize> {
247 self.submit(chunk.data.clone(), chunk.target.clone())
248 }
249
250 pub fn receive(&mut self) -> Result<Option<Dataset>> {
257 match self.output_receiver.as_ref() {
258 Some(rx) => match rx.recv() {
259 Ok(item) => Ok(Some(Dataset {
260 data: item.data,
261 target: item.target,
262 targetnames: None,
263 featurenames: None,
264 feature_descriptions: None,
265 description: None,
266 metadata: Default::default(),
267 })),
268 Err(_) => Ok(None), },
270 None => Err(DatasetsError::ProcessingError(
271 "Pipeline not initialized".to_string(),
272 )),
273 }
274 }
275
276 pub fn try_receive(&mut self) -> Result<Option<Dataset>> {
278 match self.output_receiver.as_ref() {
279 Some(rx) => match rx.try_recv() {
280 Ok(item) => Ok(Some(Dataset {
281 data: item.data,
282 target: item.target,
283 targetnames: None,
284 featurenames: None,
285 feature_descriptions: None,
286 description: None,
287 metadata: Default::default(),
288 })),
289 Err(_) => Ok(None),
290 },
291 None => Err(DatasetsError::ProcessingError(
292 "Pipeline not initialized".to_string(),
293 )),
294 }
295 }
296
297 pub fn process_batch(&mut self, datasets: &[Dataset]) -> Result<Vec<Dataset>> {
299 for ds in datasets {
301 self.submit_dataset(ds)?;
302 }
303
304 let mut results = Vec::new();
306 for _ in 0..datasets.len() {
307 if let Some(result) = self.receive()? {
308 results.push(result);
309 }
310 }
311
312 Ok(results)
313 }
314
315 pub fn items_processed(&self) -> usize {
317 self.items_processed.load(Ordering::Acquire)
318 }
319
320 pub fn stop(&mut self) {
322 self.stop_flag.store(true, Ordering::Relaxed);
323 self.input_sender = None; }
325
326 pub fn join(mut self) -> Result<()> {
328 self.input_sender = None;
330
331 let workers = std::mem::take(&mut self.workers);
333 for worker in workers {
334 worker.join().map_err(|_| {
335 DatasetsError::ProcessingError("Worker thread panicked".to_string())
336 })?;
337 }
338
339 Ok(())
340 }
341}
342
343impl Drop for ParallelPipeline {
344 fn drop(&mut self) {
345 self.stop();
346 }
347}
348
349pub fn create_pipeline<F>(preprocess_fn: F, num_workers: usize) -> ParallelPipeline
358where
359 F: Fn(&Array2<f64>) -> Result<Array2<f64>> + Send + Sync + 'static,
360{
361 let config = ParallelConfig::default().with_workers(num_workers);
362 ParallelPipeline::new(config, Arc::new(preprocess_fn))
363}
364
365pub fn create_pipeline_with_config<F>(config: ParallelConfig, preprocess_fn: F) -> ParallelPipeline
367where
368 F: Fn(&Array2<f64>) -> Result<Array2<f64>> + Send + Sync + 'static,
369{
370 ParallelPipeline::new(config, Arc::new(preprocess_fn))
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies that the builder methods set the corresponding config fields.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::new()
            .with_workers(4)
            .with_batch_size(500)
            .with_buffer_sizes(5, 5)
            .with_work_stealing(true);

        assert_eq!(config.num_workers, 4);
        assert_eq!(config.batch_size, 500);
        assert_eq!(config.input_buffer_size, 5);
        assert_eq!(config.output_buffer_size, 5);
        assert!(config.enable_work_stealing);
    }

    /// End-to-end check: one submitted matrix comes back doubled.
    #[test]
    fn test_simple_pipeline() -> Result<()> {
        let preprocess = |data: &Array2<f64>| -> Result<Array2<f64>> { Ok(data * 2.0) };

        let mut pipeline = create_pipeline(preprocess, 2);

        let data =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .map_err(|e| DatasetsError::InvalidFormat(format!("{}", e)))?;

        pipeline.submit(data.clone(), None)?;

        // Blocks until the worker has processed the single item.
        if let Some(result) = pipeline.receive()? {
            assert_eq!(result.data[[0, 0]], 2.0);
            assert_eq!(result.data[[2, 2]], 18.0);
        }

        pipeline.stop();
        Ok(())
    }

    /// Verifies that `process_batch` returns one result per submitted dataset.
    #[test]
    fn test_batch_processing() -> Result<()> {
        let preprocess = |data: &Array2<f64>| -> Result<Array2<f64>> { Ok(data + 1.0) };

        let mut pipeline = create_pipeline(preprocess, 4);

        let datasets: Vec<Dataset> = (0..5)
            .map(|i| {
                let data = Array2::from_elem((2, 2), i as f64);
                Dataset {
                    data,
                    target: None,
                    targetnames: None,
                    featurenames: None,
                    feature_descriptions: None,
                    description: None,
                    metadata: Default::default(),
                }
            })
            .collect();

        let results = pipeline.process_batch(&datasets)?;
        assert_eq!(results.len(), 5);

        pipeline.stop();
        Ok(())
    }

    /// Verifies the processed-items counter after receiving every result.
    #[test]
    fn test_pipeline_stats() -> Result<()> {
        let preprocess = |data: &Array2<f64>| -> Result<Array2<f64>> { Ok(data.clone()) };

        let mut pipeline = create_pipeline(preprocess, 2);

        let data = Array2::zeros((5, 5));
        for _ in 0..3 {
            pipeline.submit(data.clone(), None)?;
        }

        // Receiving all three results guarantees the workers have finished,
        // so the counter read below is not racy.
        for _ in 0..3 {
            let _ = pipeline.receive()?;
        }

        assert_eq!(pipeline.items_processed(), 3);

        pipeline.stop();
        Ok(())
    }
}