// scirs2_linalg/parallel_dispatch.rs
1//! Parallel algorithm dispatch for linear algebra operations
2//!
3//! This module provides utilities for automatically selecting and dispatching
4//! to parallel implementations when appropriate, based on matrix size and
5//! worker configuration.
6
7use crate::error::LinalgResult;
8use crate::parallel::{algorithms, WorkerConfig};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
10use scirs2_core::numeric::{Float, NumAssign, One, Zero};
11use std::iter::Sum;
12
13/// Parallel-aware matrix decomposition dispatcher
14pub struct ParallelDecomposition;
15
16impl ParallelDecomposition {
17    /// Choose and execute the appropriate Cholesky decomposition implementation
18    pub fn cholesky<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
19    where
20        F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
21    {
22        if let Some(num_workers) = workers {
23            let config = WorkerConfig::new().with_workers(num_workers);
24            let (m, n) = a.dim();
25
26            // Use parallel implementation for large matrices
27            if m * n > config.parallel_threshold {
28                return algorithms::parallel_cholesky(a, &config);
29            }
30        }
31
32        // Fall back to standard implementation
33        crate::decomposition::cholesky(a, workers)
34    }
35
36    /// Choose and execute the appropriate LU decomposition implementation
37    pub fn lu<F>(
38        a: &ArrayView2<F>,
39        workers: Option<usize>,
40    ) -> LinalgResult<(Array2<F>, Array2<F>, Array2<F>)>
41    where
42        F: Float
43            + NumAssign
44            + One
45            + Sum
46            + Send
47            + Sync
48            + scirs2_core::ndarray::ScalarOperand
49            + 'static,
50    {
51        if let Some(num_workers) = workers {
52            let config = WorkerConfig::new().with_workers(num_workers);
53            let (m, n) = a.dim();
54
55            // Use parallel implementation for large matrices
56            if m * n > config.parallel_threshold {
57                return algorithms::parallel_lu(a, &config);
58            }
59        }
60
61        // Fall back to standard implementation
62        crate::decomposition::lu(a, workers)
63    }
64
65    /// Choose and execute the appropriate QR decomposition implementation
66    pub fn qr<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<(Array2<F>, Array2<F>)>
67    where
68        F: Float
69            + NumAssign
70            + One
71            + Sum
72            + Send
73            + Sync
74            + scirs2_core::ndarray::ScalarOperand
75            + 'static,
76    {
77        if let Some(num_workers) = workers {
78            let config = WorkerConfig::new().with_workers(num_workers);
79            let (m, n) = a.dim();
80
81            // Use parallel implementation for large matrices
82            if m * n > config.parallel_threshold {
83                return algorithms::parallel_qr(a, &config);
84            }
85        }
86
87        // Fall back to standard implementation
88        crate::decomposition::qr(a, workers)
89    }
90
91    /// Choose and execute the appropriate SVD implementation
92    pub fn svd<F>(
93        a: &ArrayView2<F>,
94        full_matrices: bool,
95        workers: Option<usize>,
96    ) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
97    where
98        F: Float
99            + NumAssign
100            + One
101            + Sum
102            + Send
103            + Sync
104            + scirs2_core::ndarray::ScalarOperand
105            + 'static,
106    {
107        if let Some(num_workers) = workers {
108            let config = WorkerConfig::new().with_workers(num_workers);
109            let (m, n) = a.dim();
110
111            // Use parallel implementation for large _matrices (only supports reduced form)
112            if m * n > config.parallel_threshold && !full_matrices {
113                return algorithms::parallel_svd(a, &config);
114            }
115        }
116
117        // Fall back to standard implementation
118        crate::decomposition::svd(a, full_matrices, workers)
119    }
120}
121
122/// Parallel-aware solver dispatcher
123pub struct ParallelSolver;
124
125impl ParallelSolver {
126    /// Choose and execute the appropriate conjugate gradient implementation
127    pub fn conjugate_gradient<F>(
128        a: &ArrayView2<F>,
129        b: &ArrayView1<F>,
130        max_iter: usize,
131        tolerance: F,
132        workers: Option<usize>,
133    ) -> LinalgResult<Array1<F>>
134    where
135        F: Float
136            + NumAssign
137            + One
138            + Sum
139            + Send
140            + Sync
141            + scirs2_core::ndarray::ScalarOperand
142            + 'static,
143    {
144        if let Some(num_workers) = workers {
145            let config = WorkerConfig::new().with_workers(num_workers);
146            let (m, n) = a.dim();
147
148            // Use parallel implementation for large matrices
149            if m * n > config.parallel_threshold {
150                return algorithms::parallel_conjugate_gradient(a, b, max_iter, tolerance, &config);
151            }
152        }
153
154        // Fall back to standard implementation
155        crate::iterative_solvers::conjugate_gradient(a, b, max_iter, tolerance, None)
156    }
157
158    /// Choose and execute the appropriate GMRES implementation
159    pub fn gmres<F>(
160        a: &ArrayView2<F>,
161        b: &ArrayView1<F>,
162        max_iter: usize,
163        tolerance: F,
164        restart: usize,
165        workers: Option<usize>,
166    ) -> LinalgResult<Array1<F>>
167    where
168        F: Float
169            + NumAssign
170            + One
171            + Sum
172            + Send
173            + Sync
174            + scirs2_core::ndarray::ScalarOperand
175            + std::fmt::Debug
176            + std::fmt::Display
177            + 'static,
178    {
179        if let Some(num_workers) = workers {
180            let config = WorkerConfig::new().with_workers(num_workers);
181            let (m, n) = a.dim();
182
183            // Use parallel implementation for large matrices
184            if m * n > config.parallel_threshold {
185                return algorithms::parallel_gmres(a, b, max_iter, tolerance, restart, &config);
186            }
187        }
188
189        // Fall back to standard implementation
190        let options = crate::solvers::iterative::IterativeSolverOptions {
191            max_iterations: max_iter,
192            tolerance,
193            verbose: false,
194            restart: Some(restart),
195        };
196        crate::solvers::iterative::gmres(a, b, None, &options).map(|result| result.solution)
197    }
198
199    /// Choose and execute the appropriate BiCGSTAB implementation
200    pub fn bicgstab<F>(
201        a: &ArrayView2<F>,
202        b: &ArrayView1<F>,
203        max_iter: usize,
204        tolerance: F,
205        workers: Option<usize>,
206    ) -> LinalgResult<Array1<F>>
207    where
208        F: Float
209            + NumAssign
210            + One
211            + Sum
212            + Send
213            + Sync
214            + scirs2_core::ndarray::ScalarOperand
215            + 'static,
216    {
217        if let Some(num_workers) = workers {
218            let config = WorkerConfig::new().with_workers(num_workers);
219            let (m, n) = a.dim();
220
221            // Use parallel implementation for large matrices
222            if m * n > config.parallel_threshold {
223                return algorithms::parallel_bicgstab(a, b, max_iter, tolerance, &config);
224            }
225        }
226
227        // Fall back to standard implementation
228        crate::iterative_solvers::bicgstab(a, b, max_iter, tolerance, None)
229    }
230
231    /// Choose and execute the appropriate Jacobi method implementation
232    pub fn jacobi<F>(
233        a: &ArrayView2<F>,
234        b: &ArrayView1<F>,
235        max_iter: usize,
236        tolerance: F,
237        workers: Option<usize>,
238    ) -> LinalgResult<Array1<F>>
239    where
240        F: Float
241            + NumAssign
242            + One
243            + Sum
244            + Send
245            + Sync
246            + scirs2_core::ndarray::ScalarOperand
247            + 'static,
248    {
249        if let Some(num_workers) = workers {
250            let config = WorkerConfig::new().with_workers(num_workers);
251            let (m, n) = a.dim();
252
253            // Use parallel implementation for large matrices
254            if m * n > config.parallel_threshold {
255                return algorithms::parallel_jacobi(a, b, max_iter, tolerance, &config);
256            }
257        }
258
259        // Fall back to standard implementation
260        crate::iterative_solvers::jacobi_method(a, b, max_iter, tolerance, None)
261    }
262
263    /// Choose and execute the appropriate SOR method implementation
264    pub fn sor<F>(
265        a: &ArrayView2<F>,
266        b: &ArrayView1<F>,
267        omega: F,
268        max_iter: usize,
269        tolerance: F,
270        workers: Option<usize>,
271    ) -> LinalgResult<Array1<F>>
272    where
273        F: Float
274            + NumAssign
275            + One
276            + Sum
277            + Send
278            + Sync
279            + scirs2_core::ndarray::ScalarOperand
280            + 'static,
281    {
282        if let Some(num_workers) = workers {
283            let config = WorkerConfig::new().with_workers(num_workers);
284            let (m, n) = a.dim();
285
286            // Use parallel implementation for large matrices
287            if m * n > config.parallel_threshold {
288                return algorithms::parallel_sor(a, b, omega, max_iter, tolerance, &config);
289            }
290        }
291
292        // Fall back to standard implementation
293        crate::iterative_solvers::successive_over_relaxation(a, b, omega, max_iter, tolerance, None)
294    }
295}
296
297/// Parallel-aware matrix operations dispatcher
298pub struct ParallelOperations;
299
300impl ParallelOperations {
301    /// Choose and execute the appropriate matrix multiplication implementation
302    pub fn matmul<F>(
303        a: &ArrayView2<F>,
304        b: &ArrayView2<F>,
305        workers: Option<usize>,
306    ) -> LinalgResult<Array2<F>>
307    where
308        F: Float + NumAssign + Zero + Sum + Send + Sync + 'static,
309    {
310        if let Some(num_workers) = workers {
311            let config = WorkerConfig::new().with_workers(num_workers);
312            let (m, k) = a.dim();
313            let (_, n) = b.dim();
314
315            // Use parallel implementation for large matrices
316            if m * k * n > config.parallel_threshold {
317                return algorithms::parallel_gemm(a, b, &config);
318            }
319        }
320
321        // Fall back to standard implementation
322        Ok(a.dot(b))
323    }
324
325    /// Choose and execute the appropriate matrix-vector multiplication implementation
326    pub fn matvec<F>(
327        a: &ArrayView2<F>,
328        x: &ArrayView1<F>,
329        workers: Option<usize>,
330    ) -> LinalgResult<Array1<F>>
331    where
332        F: Float + Zero + Sum + Send + Sync + 'static,
333    {
334        if let Some(num_workers) = workers {
335            let config = WorkerConfig::new().with_workers(num_workers);
336            let (m, n) = a.dim();
337
338            // Use parallel implementation for large matrices
339            if m * n > config.parallel_threshold {
340                return algorithms::parallel_matvec(a, x, &config);
341            }
342        }
343
344        // Fall back to standard implementation
345        Ok(a.dot(x))
346    }
347
348    /// Choose and execute the appropriate power iteration implementation
349    pub fn power_iteration<F>(
350        a: &ArrayView2<F>,
351        max_iter: usize,
352        tolerance: F,
353        workers: Option<usize>,
354    ) -> LinalgResult<(F, Array1<F>)>
355    where
356        F: Float
357            + NumAssign
358            + One
359            + Zero
360            + Sum
361            + Send
362            + Sync
363            + scirs2_core::ndarray::ScalarOperand
364            + 'static,
365    {
366        if let Some(num_workers) = workers {
367            let config = WorkerConfig::new().with_workers(num_workers);
368            let (m, n) = a.dim();
369
370            // Use parallel implementation for large matrices
371            if m * n > config.parallel_threshold {
372                return algorithms::parallel_power_iteration(a, max_iter, tolerance, &config);
373            }
374        }
375
376        // Fall back to standard implementation
377        crate::eigen::power_iteration(a, max_iter, tolerance)
378    }
379}
380
/// Configuration builder for parallel dispatch
///
/// Collects the knobs that control when a parallel implementation is chosen;
/// `build` converts it into a concrete `WorkerConfig`.
pub struct ParallelConfig {
    // Explicit worker-thread count; `None` keeps the `WorkerConfig` default.
    workers: Option<usize>,
    // Scale factor applied to the base parallel threshold (1.0 = unchanged).
    threshold_multiplier: f64,
}
386
387impl ParallelConfig {
388    /// Create a new parallel configuration
389    pub fn new() -> Self {
390        Self {
391            workers: None,
392            threshold_multiplier: 1.0,
393        }
394    }
395
396    /// Set the number of worker threads
397    pub fn with_workers(mut self, workers: usize) -> Self {
398        self.workers = Some(workers);
399        self
400    }
401
402    /// Set the threshold multiplier for parallel execution
403    ///
404    /// A value of 2.0 means matrices need to be 2x larger than default
405    /// threshold to use parallel implementation
406    pub fn with_threshold_multiplier(mut self, multiplier: f64) -> Self {
407        self.threshold_multiplier = multiplier;
408        self
409    }
410
411    /// Build a WorkerConfig from this configuration
412    pub fn build(&self) -> WorkerConfig {
413        let mut config = WorkerConfig::new();
414
415        if let Some(workers) = self.workers {
416            config = config.with_workers(workers);
417        }
418
419        let base_threshold = config.parallel_threshold;
420        config =
421            config.with_threshold((base_threshold as f64 * self.threshold_multiplier) as usize);
422
423        config
424    }
425}
426
427impl Default for ParallelConfig {
428    fn default() -> Self {
429        Self::new()
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_parallel_dispatch_small_matrix() {
        // A 2x2 SPD matrix is far below any parallel threshold, so dispatch
        // must fall through to the serial Cholesky and still succeed even
        // though a worker count was requested.
        let a = array![[1.0, 2.0], [2.0, 5.0]];
        let result = ParallelDecomposition::cholesky(&a.view(), Some(4));
        assert!(result.is_ok());
    }

    #[test]
    fn test_parallel_config_builder() {
        // The builder should record the worker count and scale the base
        // threshold by the multiplier.
        let config = ParallelConfig::new()
            .with_workers(8)
            .with_threshold_multiplier(2.0)
            .build();

        assert_eq!(config.workers, Some(8));
        assert_eq!(config.parallel_threshold, 2000); // Default 1000 * 2.0
    }
}