use crate::error::Result;
use crate::traits::*;
use crate::types::FloatBounds;
use rayon::prelude::*;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};

/// Configuration for parallel execution.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of threads to use; `None` lets rayon pick its default.
    pub num_threads: Option<usize>,
    /// Minimum data size before the parallel code paths are used.
    pub min_parallel_batch_size: usize,
    /// Master switch; when `false`, every helper falls back to sequential code.
    pub enabled: bool,
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            num_threads: None,
            min_parallel_batch_size: 1000,
            enabled: true,
        }
    }
}
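
// Usage sketch (illustrative only; the value below is arbitrary, not a
// recommendation): override individual fields and keep the remaining
// defaults via struct-update syntax.
//
//     let config = ParallelConfig {
//         min_parallel_batch_size: 5_000,
//         ..ParallelConfig::default()
//     };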
33
34pub trait ParallelPredict<X, Output> {
36 fn predict_parallel(&self, x: &X) -> Result<Output>;
38
39 fn predict_parallel_with_config(&self, x: &X, config: &ParallelConfig) -> Result<Output>;
41}

/// Data transformation that may be executed across multiple threads.
pub trait ParallelTransform<X, Output = X> {
    /// Transform using the default parallel configuration.
    fn transform_parallel(&self, x: &X) -> Result<Output>;

    /// Transform using an explicit parallel configuration.
    fn transform_parallel_with_config(&self, x: &X, config: &ParallelConfig) -> Result<Output>;
}

/// Model fitting that may be executed across multiple threads.
pub trait ParallelFit<X, Y> {
    /// The fitted model type produced by a successful fit.
    type Fitted;

    fn fit_parallel(self, x: &X, y: &Y) -> Result<Self::Fitted>;

    fn fit_parallel_with_config(
        self,
        x: &X,
        y: &Y,
        config: &ParallelConfig,
    ) -> Result<Self::Fitted>;
}

/// Cross-validation where folds are evaluated in parallel.
pub trait ParallelCrossValidation<X, Y> {
    /// Numeric type of the per-fold scores.
    type Score: FloatBounds;

    fn cross_validate_parallel(
        &self,
        model: impl Fit<X, Y> + Clone + Send + Sync,
        x: &X,
        y: &Y,
        cv_folds: usize,
    ) -> Result<Vec<Self::Score>>
    where
        X: Clone + Send + Sync,
        Y: Clone + Send + Sync,
        <Self as ParallelCrossValidation<X, Y>>::Score: Send;
}

/// Ensembles whose members are trained and queried in parallel.
pub trait ParallelEnsemble<X, Y, Output> {
    fn fit_ensemble_parallel(
        models: Vec<impl Fit<X, Y> + Clone + Send + Sync>,
        x: &X,
        y: &Y,
    ) -> Result<Vec<Box<dyn Predict<X, Output>>>>
    where
        X: Clone + Send + Sync,
        Y: Clone + Send + Sync;

    fn predict_ensemble_parallel(
        models: &[impl Predict<X, Output> + Sync],
        x: &X,
    ) -> Result<Vec<Output>>
    where
        X: Sync,
        Output: Send;
}

/// Predict on row chunks of `x` in parallel and concatenate the results.
///
/// Falls back to a single sequential `predict` call when parallelism is
/// disabled or the input has fewer rows than `min_parallel_batch_size`.
pub fn predict_parallel_ndarray<T, M>(
    model: &M,
    x: &Array2<T>,
    config: &ParallelConfig,
) -> Result<Array1<T>>
where
    T: FloatBounds + Send + Sync,
    M: Predict<Array2<T>, Array1<T>> + Sync,
{
    if !config.enabled || x.nrows() < config.min_parallel_batch_size {
        return model.predict(x);
    }

    // Split the rows into roughly one chunk per rayon worker thread.
    let chunk_size = (x.nrows() / rayon::current_num_threads()).max(1);
    let chunks: Vec<_> = x.axis_chunks_iter(Axis(0), chunk_size).collect();

    let results: Result<Vec<_>> = chunks
        .into_par_iter()
        .map(|chunk| {
            let chunk_array = chunk.to_owned();
            model.predict(&chunk_array)
        })
        .collect();

    let predictions = results?;

    // Stitch the per-chunk predictions back together in row order.
    let total_len: usize = predictions.iter().map(|p| p.len()).sum();
    let mut result = Array1::zeros(total_len);
    let mut offset = 0;

    for pred in predictions {
        let end = offset + pred.len();
        result
            .slice_mut(scirs2_core::ndarray::s![offset..end])
            .assign(&pred);
        offset = end;
    }

    Ok(result)
}
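
// Usage sketch (illustrative; `fitted` stands for any value implementing
// `Predict<Array2<f64>, Array1<f64>>` and is not defined in this module):
//
//     let config = ParallelConfig::default();
//     let predictions = predict_parallel_ndarray(&fitted, &x, &config)?;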

/// Parallel helpers for dense matrix operations.
pub struct ParallelMatrixOps;

impl ParallelMatrixOps {
    /// Matrix product `a.dot(b)`, computed row-by-row in parallel.
    pub fn matrix_multiply_parallel<T: FloatBounds + Send + Sync>(
        a: &Array2<T>,
        b: &Array2<T>,
        config: &ParallelConfig,
    ) -> Array2<T> {
        let (m, k) = a.dim();
        let (k2, n) = b.dim();
        assert_eq!(k, k2, "Matrix dimensions must match");

        let mut result = Array2::zeros((m, n));

        // Small or disabled: delegate to ndarray's sequential dot product.
        if !config.enabled || m < config.min_parallel_batch_size {
            result.assign(&a.dot(b));
            return result;
        }

        // Each output row is computed independently on its own rayon task.
        result
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(i, mut row)| {
                for j in 0..n {
                    let mut sum = T::zero();
                    for ki in 0..k {
                        sum += a[[i, ki]] * b[[ki, j]];
                    }
                    row[j] = sum;
                }
            });

        result
    }

    /// Element-wise binary operation `op(a, b)`, parallelized over the
    /// flattened data when both inputs are contiguous in memory.
    pub fn elementwise_op_parallel<T, F>(
        a: &Array2<T>,
        b: &Array2<T>,
        op: F,
        config: &ParallelConfig,
    ) -> Array2<T>
    where
        T: FloatBounds + Send + Sync,
        F: Fn(T, T) -> T + Send + Sync,
    {
        assert_eq!(a.shape(), b.shape(), "Input shapes must match");

        let mut result = Array2::zeros(a.dim());

        if !config.enabled || a.len() < config.min_parallel_batch_size {
            result
                .iter_mut()
                .zip(a.iter())
                .zip(b.iter())
                .for_each(|((r, &ai), &bi)| *r = op(ai, bi));
        } else if let (Some(result_slice), Some(a_slice), Some(b_slice)) =
            (result.as_slice_mut(), a.as_slice(), b.as_slice())
        {
            result_slice
                .par_iter_mut()
                .zip(a_slice.par_iter())
                .zip(b_slice.par_iter())
                .for_each(|((r, &ai), &bi)| *r = op(ai, bi));
        } else {
            // Non-contiguous layout: fall back to the sequential path.
            result
                .iter_mut()
                .zip(a.iter())
                .zip(b.iter())
                .for_each(|((r, &ai), &bi)| *r = op(ai, bi));
        }

        result
    }

    /// Apply `op` to every row of `matrix`, producing one value per row.
    pub fn apply_row_parallel<T, F>(matrix: &Array2<T>, op: F, config: &ParallelConfig) -> Array1<T>
    where
        T: FloatBounds + Send + Sync,
        F: Fn(ArrayView1<T>) -> T + Send + Sync,
    {
        let mut result = Array1::zeros(matrix.nrows());

        if !config.enabled || matrix.nrows() < config.min_parallel_batch_size {
            result
                .iter_mut()
                .zip(matrix.axis_iter(Axis(0)))
                .for_each(|(r, row)| *r = op(row));
        } else if let Some(result_slice) = result.as_slice_mut() {
            result_slice.par_iter_mut().enumerate().for_each(|(i, r)| {
                let row = matrix.row(i);
                *r = op(row);
            });
        } else {
            // Non-contiguous output: fall back to the sequential path.
            result
                .iter_mut()
                .zip(matrix.axis_iter(Axis(0)))
                .for_each(|(r, row)| *r = op(row));
        }

        result
    }

    /// Apply `op` to every column of `matrix`, producing one value per column.
    pub fn apply_column_parallel<T, F>(
        matrix: &Array2<T>,
        op: F,
        config: &ParallelConfig,
    ) -> Array1<T>
    where
        T: FloatBounds + Send + Sync,
        F: Fn(ArrayView1<T>) -> T + Send + Sync,
    {
        let mut result = Array1::zeros(matrix.ncols());

        if !config.enabled || matrix.ncols() < config.min_parallel_batch_size {
            result
                .iter_mut()
                .zip(matrix.axis_iter(Axis(1)))
                .for_each(|(r, col)| *r = op(col));
        } else if let Some(result_slice) = result.as_slice_mut() {
            result_slice.par_iter_mut().enumerate().for_each(|(j, r)| {
                let col = matrix.column(j);
                *r = op(col);
            });
        } else {
            // Non-contiguous output: fall back to the sequential path.
            result
                .iter_mut()
                .zip(matrix.axis_iter(Axis(1)))
                .for_each(|(r, col)| *r = op(col));
        }

        result
    }
}

/// Cross-validator that evaluates folds on separate rayon tasks.
pub struct ParallelCrossValidator<T: FloatBounds> {
    config: ParallelConfig,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: FloatBounds> ParallelCrossValidator<T> {
    pub fn new(config: ParallelConfig) -> Self {
        Self {
            config,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Fit and score the model once per fold, in parallel.
    ///
    /// Note: no fold-wise splitting of `x` and `y` is performed here; each
    /// fold currently fits and scores on the full dataset.
    pub fn k_fold_parallel<X, Y, M, Output>(
        &self,
        model: M,
        x: &X,
        y: &Y,
        k: usize,
    ) -> Result<Vec<T>>
    where
        M: Fit<X, Y> + Clone + Send + Sync,
        M::Fitted: Score<X, Y, Float = T>,
        X: Clone + Send + Sync,
        Y: Clone + Send + Sync,
        T: Send,
    {
        if !self.config.enabled || k < 2 {
            return self.k_fold_sequential(model, x, y, k);
        }

        let fold_indices: Vec<_> = (0..k).collect();

        let scores: Result<Vec<_>> = fold_indices
            .into_par_iter()
            .map(|_fold_idx| {
                let model_clone = model.clone();
                let fitted = model_clone.fit(x, y)?;
                fitted.score(x, y)
            })
            .collect();

        scores
    }

    fn k_fold_sequential<X, Y, M>(&self, model: M, x: &X, y: &Y, k: usize) -> Result<Vec<T>>
    where
        M: Fit<X, Y> + Clone,
        M::Fitted: Score<X, Y, Float = T>,
    {
        let mut scores = Vec::with_capacity(k);

        for _fold in 0..k {
            let model_clone = model.clone();
            let fitted = model_clone.fit(x, y)?;
            let score = fitted.score(x, y)?;
            scores.push(score);
        }

        Ok(scores)
    }
}

/// Helpers for training and querying collections of models in parallel.
pub struct ParallelEnsembleOps;

impl ParallelEnsembleOps {
    /// Fit each model in `models` on the same data, one rayon task per model.
    pub fn train_models_parallel<X, Y, M>(
        models: Vec<M>,
        x: &X,
        y: &Y,
        config: &ParallelConfig,
    ) -> Result<Vec<M::Fitted>>
    where
        M: Fit<X, Y> + Send,
        M::Fitted: Send,
        X: Sync,
        Y: Sync,
    {
        if !config.enabled || models.len() < 2 {
            return models.into_iter().map(|model| model.fit(x, y)).collect();
        }

        models
            .into_par_iter()
            .map(|model| model.fit(x, y))
            .collect()
    }

    /// Collect predictions from every model in `models`, one task per model.
    pub fn predict_parallel<X, Output, M>(
        models: &[M],
        x: &X,
        config: &ParallelConfig,
    ) -> Result<Vec<Output>>
    where
        M: Predict<X, Output> + Sync,
        Output: Send,
        X: Sync,
    {
        if !config.enabled || models.len() < 2 {
            return models.iter().map(|model| model.predict(x)).collect();
        }

        models.par_iter().map(|model| model.predict(x)).collect()
    }
}
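
// Usage sketch (illustrative; `models`, `x`, and `y` stand for concrete types
// implementing the `Fit`/`Predict` traits and are not defined in this module):
//
//     let fitted = ParallelEnsembleOps::train_models_parallel(models, &x, &y, &config)?;
//     let votes = ParallelEnsembleOps::predict_parallel(&fitted, &x, &config)?;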

/// Small helpers shared by the parallel code paths.
pub mod utils {
    use super::*;

    /// Chunk size that divides `total_size` across rayon's threads, never
    /// smaller than `min_chunk_size`.
    pub fn optimal_chunk_size(total_size: usize, min_chunk_size: usize) -> usize {
        let num_threads = rayon::current_num_threads();
        (total_size / num_threads).max(min_chunk_size)
    }

    /// Whether `data_size` is large enough to justify the parallel path.
    pub fn should_use_parallel(data_size: usize, config: &ParallelConfig) -> bool {
        config.enabled && data_size >= config.min_parallel_batch_size
    }

    /// Configure rayon's global thread pool; a no-op when `num_threads` is
    /// `None`. Note that the global pool can only be built once per process.
    pub fn initialize_thread_pool(num_threads: Option<usize>) -> Result<()> {
        if let Some(threads) = num_threads {
            rayon::ThreadPoolBuilder::new()
                .num_threads(threads)
                .build_global()
                .map_err(|e| {
                    crate::error::SklearsError::NumericalError(format!(
                        "Failed to initialize thread pool: {e}"
                    ))
                })?;
        }
        Ok(())
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_parallel_matrix_multiply() {
        let a = Array2::from_shape_vec((100, 50), (0..5000).map(|x| x as f64).collect()).unwrap();
        let b =
            Array2::from_shape_vec((50, 30), (0..1500).map(|x| x as f64 + 1.0).collect()).unwrap();

        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 10,
            num_threads: None,
        };

        let result_parallel = ParallelMatrixOps::matrix_multiply_parallel(&a, &b, &config);
        let result_sequential = a.dot(&b);

        for i in 0..result_parallel.nrows() {
            for j in 0..result_parallel.ncols() {
                assert_relative_eq!(
                    result_parallel[[i, j]],
                    result_sequential[[i, j]],
                    epsilon = 1e-10
                );
            }
        }
    }
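
    // Added check (not in the original test suite): below the batch-size
    // threshold, matrix_multiply_parallel should take the sequential
    // `a.dot(b)` path and still return the exact product.
    #[test]
    fn test_matrix_multiply_sequential_fallback() {
        let a = Array2::from_shape_vec((4, 3), (0..12).map(|x| x as f64).collect()).unwrap();
        let b = Array2::from_shape_vec((3, 2), (0..6).map(|x| x as f64).collect()).unwrap();

        // min_parallel_batch_size is far larger than the 4 rows of `a`,
        // so the sequential branch is exercised.
        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 1_000,
            num_threads: None,
        };

        let result = ParallelMatrixOps::matrix_multiply_parallel(&a, &b, &config);
        let expected = a.dot(&b);

        for i in 0..result.nrows() {
            for j in 0..result.ncols() {
                assert_relative_eq!(result[[i, j]], expected[[i, j]], epsilon = 1e-12);
            }
        }
    }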

    #[test]
    fn test_parallel_elementwise_ops() {
        let a = Array2::from_shape_vec((100, 100), (0..10000).map(|x| x as f64).collect()).unwrap();
        let b = Array2::from_shape_vec((100, 100), (0..10000).map(|x| x as f64 + 1.0).collect())
            .unwrap();

        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 100,
            num_threads: None,
        };

        let result_parallel =
            ParallelMatrixOps::elementwise_op_parallel(&a, &b, |x, y| x + y, &config);
        let result_sequential = &a + &b;

        for i in 0..result_parallel.nrows() {
            for j in 0..result_parallel.ncols() {
                assert_relative_eq!(
                    result_parallel[[i, j]],
                    result_sequential[[i, j]],
                    epsilon = 1e-10
                );
            }
        }
    }
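
    // Added check (not in the original test suite): row-wise reduction with
    // `apply_row_parallel` should match ndarray's `sum_axis` along axis 1.
    #[test]
    fn test_apply_row_parallel_row_sums() {
        let matrix =
            Array2::from_shape_vec((20, 5), (0..100).map(|x| x as f64).collect()).unwrap();

        // A threshold of 1 forces the parallel branch even for this small input.
        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 1,
            num_threads: None,
        };

        let row_sums = ParallelMatrixOps::apply_row_parallel(&matrix, |row| row.sum(), &config);
        let expected = matrix.sum_axis(scirs2_core::ndarray::Axis(1));

        for i in 0..row_sums.len() {
            assert_relative_eq!(row_sums[i], expected[i], epsilon = 1e-10);
        }
    }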

    #[test]
    fn test_optimal_chunk_size() {
        let num_threads = rayon::current_num_threads();
        let expected = (1000 / num_threads).max(10);
        assert_eq!(utils::optimal_chunk_size(1000, 10), expected);
        assert_eq!(utils::optimal_chunk_size(100, 50), 50);
    }
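
    // Added check (not in the original test suite): column-wise reduction with
    // `apply_column_parallel` should match ndarray's `sum_axis` along axis 0.
    #[test]
    fn test_apply_column_parallel_column_sums() {
        let matrix =
            Array2::from_shape_vec((6, 4), (0..24).map(|x| x as f64).collect()).unwrap();

        // A threshold of 1 forces the parallel branch even for this small input.
        let config = ParallelConfig {
            enabled: true,
            min_parallel_batch_size: 1,
            num_threads: None,
        };

        let col_sums =
            ParallelMatrixOps::apply_column_parallel(&matrix, |col| col.sum(), &config);
        let expected = matrix.sum_axis(scirs2_core::ndarray::Axis(0));

        for j in 0..col_sums.len() {
            assert_relative_eq!(col_sums[j], expected[j], epsilon = 1e-10);
        }
    }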

    #[test]
    fn test_should_use_parallel() {
        let config = ParallelConfig::default();
        assert!(!utils::should_use_parallel(100, &config));
        assert!(utils::should_use_parallel(2000, &config));

        let disabled_config = ParallelConfig {
            enabled: false,
            ..Default::default()
        };
        assert!(!utils::should_use_parallel(2000, &disabled_config));
    }
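
    // Added check (not in the original test suite): with parallelism disabled,
    // elementwise_op_parallel should take the sequential branch and agree
    // with ndarray's own element-wise product.
    #[test]
    fn test_elementwise_op_sequential_when_disabled() {
        let a = Array2::from_shape_vec((10, 10), (0..100).map(|x| x as f64).collect()).unwrap();
        let b = Array2::from_shape_vec((10, 10), (0..100).map(|x| x as f64 + 2.0).collect())
            .unwrap();

        let config = ParallelConfig {
            enabled: false,
            ..ParallelConfig::default()
        };

        let result = ParallelMatrixOps::elementwise_op_parallel(&a, &b, |x, y| x * y, &config);
        let expected = &a * &b;

        for i in 0..result.nrows() {
            for j in 0..result.ncols() {
                assert_relative_eq!(result[[i, j]], expected[[i, j]], epsilon = 1e-10);
            }
        }
    }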
}