1use crate::error::LinalgResult;
8use crate::parallel::{algorithms, WorkerConfig};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
10use scirs2_core::numeric::{Float, NumAssign, One, Zero};
11use std::iter::Sum;
12
/// Parallel-aware entry points for matrix decompositions (Cholesky, LU, QR, SVD).
///
/// Each method dispatches to a parallel algorithm when a worker count is given
/// and the problem is large enough, otherwise falls back to the serial routine.
pub struct ParallelDecomposition;
15
16impl ParallelDecomposition {
17 pub fn cholesky<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
19 where
20 F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
21 {
22 if let Some(num_workers) = workers {
23 let config = WorkerConfig::new().with_workers(num_workers);
24 let (m, n) = a.dim();
25
26 if m * n > config.parallel_threshold {
28 return algorithms::parallel_cholesky(a, &config);
29 }
30 }
31
32 crate::decomposition::cholesky(a, workers)
34 }
35
36 pub fn lu<F>(
38 a: &ArrayView2<F>,
39 workers: Option<usize>,
40 ) -> LinalgResult<(Array2<F>, Array2<F>, Array2<F>)>
41 where
42 F: Float
43 + NumAssign
44 + One
45 + Sum
46 + Send
47 + Sync
48 + scirs2_core::ndarray::ScalarOperand
49 + 'static,
50 {
51 if let Some(num_workers) = workers {
52 let config = WorkerConfig::new().with_workers(num_workers);
53 let (m, n) = a.dim();
54
55 if m * n > config.parallel_threshold {
57 return algorithms::parallel_lu(a, &config);
58 }
59 }
60
61 crate::decomposition::lu(a, workers)
63 }
64
65 pub fn qr<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<(Array2<F>, Array2<F>)>
67 where
68 F: Float
69 + NumAssign
70 + One
71 + Sum
72 + Send
73 + Sync
74 + scirs2_core::ndarray::ScalarOperand
75 + 'static,
76 {
77 if let Some(num_workers) = workers {
78 let config = WorkerConfig::new().with_workers(num_workers);
79 let (m, n) = a.dim();
80
81 if m * n > config.parallel_threshold {
83 return algorithms::parallel_qr(a, &config);
84 }
85 }
86
87 crate::decomposition::qr(a, workers)
89 }
90
91 pub fn svd<F>(
93 a: &ArrayView2<F>,
94 full_matrices: bool,
95 workers: Option<usize>,
96 ) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
97 where
98 F: Float
99 + NumAssign
100 + One
101 + Sum
102 + Send
103 + Sync
104 + scirs2_core::ndarray::ScalarOperand
105 + 'static,
106 {
107 if let Some(num_workers) = workers {
108 let config = WorkerConfig::new().with_workers(num_workers);
109 let (m, n) = a.dim();
110
111 if m * n > config.parallel_threshold && !full_matrices {
113 return algorithms::parallel_svd(a, &config);
114 }
115 }
116
117 crate::decomposition::svd(a, full_matrices, workers)
119 }
120}
121
/// Parallel-aware entry points for iterative linear-system solvers
/// (CG, GMRES, BiCGSTAB, Jacobi, SOR).
pub struct ParallelSolver;
124
125impl ParallelSolver {
126 pub fn conjugate_gradient<F>(
128 a: &ArrayView2<F>,
129 b: &ArrayView1<F>,
130 max_iter: usize,
131 tolerance: F,
132 workers: Option<usize>,
133 ) -> LinalgResult<Array1<F>>
134 where
135 F: Float
136 + NumAssign
137 + One
138 + Sum
139 + Send
140 + Sync
141 + scirs2_core::ndarray::ScalarOperand
142 + 'static,
143 {
144 if let Some(num_workers) = workers {
145 let config = WorkerConfig::new().with_workers(num_workers);
146 let (m, n) = a.dim();
147
148 if m * n > config.parallel_threshold {
150 return algorithms::parallel_conjugate_gradient(a, b, max_iter, tolerance, &config);
151 }
152 }
153
154 crate::iterative_solvers::conjugate_gradient(a, b, max_iter, tolerance, None)
156 }
157
158 pub fn gmres<F>(
160 a: &ArrayView2<F>,
161 b: &ArrayView1<F>,
162 max_iter: usize,
163 tolerance: F,
164 restart: usize,
165 workers: Option<usize>,
166 ) -> LinalgResult<Array1<F>>
167 where
168 F: Float
169 + NumAssign
170 + One
171 + Sum
172 + Send
173 + Sync
174 + scirs2_core::ndarray::ScalarOperand
175 + std::fmt::Debug
176 + std::fmt::Display
177 + 'static,
178 {
179 if let Some(num_workers) = workers {
180 let config = WorkerConfig::new().with_workers(num_workers);
181 let (m, n) = a.dim();
182
183 if m * n > config.parallel_threshold {
185 return algorithms::parallel_gmres(a, b, max_iter, tolerance, restart, &config);
186 }
187 }
188
189 let options = crate::solvers::iterative::IterativeSolverOptions {
191 max_iterations: max_iter,
192 tolerance,
193 verbose: false,
194 restart: Some(restart),
195 };
196 crate::solvers::iterative::gmres(a, b, None, &options).map(|result| result.solution)
197 }
198
199 pub fn bicgstab<F>(
201 a: &ArrayView2<F>,
202 b: &ArrayView1<F>,
203 max_iter: usize,
204 tolerance: F,
205 workers: Option<usize>,
206 ) -> LinalgResult<Array1<F>>
207 where
208 F: Float
209 + NumAssign
210 + One
211 + Sum
212 + Send
213 + Sync
214 + scirs2_core::ndarray::ScalarOperand
215 + 'static,
216 {
217 if let Some(num_workers) = workers {
218 let config = WorkerConfig::new().with_workers(num_workers);
219 let (m, n) = a.dim();
220
221 if m * n > config.parallel_threshold {
223 return algorithms::parallel_bicgstab(a, b, max_iter, tolerance, &config);
224 }
225 }
226
227 crate::iterative_solvers::bicgstab(a, b, max_iter, tolerance, None)
229 }
230
231 pub fn jacobi<F>(
233 a: &ArrayView2<F>,
234 b: &ArrayView1<F>,
235 max_iter: usize,
236 tolerance: F,
237 workers: Option<usize>,
238 ) -> LinalgResult<Array1<F>>
239 where
240 F: Float
241 + NumAssign
242 + One
243 + Sum
244 + Send
245 + Sync
246 + scirs2_core::ndarray::ScalarOperand
247 + 'static,
248 {
249 if let Some(num_workers) = workers {
250 let config = WorkerConfig::new().with_workers(num_workers);
251 let (m, n) = a.dim();
252
253 if m * n > config.parallel_threshold {
255 return algorithms::parallel_jacobi(a, b, max_iter, tolerance, &config);
256 }
257 }
258
259 crate::iterative_solvers::jacobi_method(a, b, max_iter, tolerance, None)
261 }
262
263 pub fn sor<F>(
265 a: &ArrayView2<F>,
266 b: &ArrayView1<F>,
267 omega: F,
268 max_iter: usize,
269 tolerance: F,
270 workers: Option<usize>,
271 ) -> LinalgResult<Array1<F>>
272 where
273 F: Float
274 + NumAssign
275 + One
276 + Sum
277 + Send
278 + Sync
279 + scirs2_core::ndarray::ScalarOperand
280 + 'static,
281 {
282 if let Some(num_workers) = workers {
283 let config = WorkerConfig::new().with_workers(num_workers);
284 let (m, n) = a.dim();
285
286 if m * n > config.parallel_threshold {
288 return algorithms::parallel_sor(a, b, omega, max_iter, tolerance, &config);
289 }
290 }
291
292 crate::iterative_solvers::successive_over_relaxation(a, b, omega, max_iter, tolerance, None)
294 }
295}
296
/// Parallel-aware entry points for basic matrix operations
/// (matrix-matrix product, matrix-vector product, power iteration).
pub struct ParallelOperations;
299
300impl ParallelOperations {
301 pub fn matmul<F>(
303 a: &ArrayView2<F>,
304 b: &ArrayView2<F>,
305 workers: Option<usize>,
306 ) -> LinalgResult<Array2<F>>
307 where
308 F: Float + NumAssign + Zero + Sum + Send + Sync + 'static,
309 {
310 if let Some(num_workers) = workers {
311 let config = WorkerConfig::new().with_workers(num_workers);
312 let (m, k) = a.dim();
313 let (_, n) = b.dim();
314
315 if m * k * n > config.parallel_threshold {
317 return algorithms::parallel_gemm(a, b, &config);
318 }
319 }
320
321 Ok(a.dot(b))
323 }
324
325 pub fn matvec<F>(
327 a: &ArrayView2<F>,
328 x: &ArrayView1<F>,
329 workers: Option<usize>,
330 ) -> LinalgResult<Array1<F>>
331 where
332 F: Float + Zero + Sum + Send + Sync + 'static,
333 {
334 if let Some(num_workers) = workers {
335 let config = WorkerConfig::new().with_workers(num_workers);
336 let (m, n) = a.dim();
337
338 if m * n > config.parallel_threshold {
340 return algorithms::parallel_matvec(a, x, &config);
341 }
342 }
343
344 Ok(a.dot(x))
346 }
347
348 pub fn power_iteration<F>(
350 a: &ArrayView2<F>,
351 max_iter: usize,
352 tolerance: F,
353 workers: Option<usize>,
354 ) -> LinalgResult<(F, Array1<F>)>
355 where
356 F: Float
357 + NumAssign
358 + One
359 + Zero
360 + Sum
361 + Send
362 + Sync
363 + scirs2_core::ndarray::ScalarOperand
364 + 'static,
365 {
366 if let Some(num_workers) = workers {
367 let config = WorkerConfig::new().with_workers(num_workers);
368 let (m, n) = a.dim();
369
370 if m * n > config.parallel_threshold {
372 return algorithms::parallel_power_iteration(a, max_iter, tolerance, &config);
373 }
374 }
375
376 crate::eigen::power_iteration(a, max_iter, tolerance)
378 }
379}
380
/// Builder for a `WorkerConfig`, letting callers customize the worker count
/// and scale the parallel-dispatch threshold before calling `build()`.
pub struct ParallelConfig {
    // Explicit worker-thread count; `None` leaves WorkerConfig's default.
    workers: Option<usize>,
    // Factor applied to the base `parallel_threshold` in `build()` (1.0 = unchanged).
    threshold_multiplier: f64,
}
386
387impl ParallelConfig {
388 pub fn new() -> Self {
390 Self {
391 workers: None,
392 threshold_multiplier: 1.0,
393 }
394 }
395
396 pub fn with_workers(mut self, workers: usize) -> Self {
398 self.workers = Some(workers);
399 self
400 }
401
402 pub fn with_threshold_multiplier(mut self, multiplier: f64) -> Self {
407 self.threshold_multiplier = multiplier;
408 self
409 }
410
411 pub fn build(&self) -> WorkerConfig {
413 let mut config = WorkerConfig::new();
414
415 if let Some(workers) = self.workers {
416 config = config.with_workers(workers);
417 }
418
419 let base_threshold = config.parallel_threshold;
420 config =
421 config.with_threshold((base_threshold as f64 * self.threshold_multiplier) as usize);
422
423 config
424 }
425}
426
impl Default for ParallelConfig {
    /// Equivalent to [`ParallelConfig::new`]: no explicit workers, multiplier 1.0.
    fn default() -> Self {
        Self::new()
    }
}
432
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // A 2x2 SPD matrix is far below the parallel threshold, so even with
    // `workers = Some(4)` the call should take the serial fallback path and
    // still succeed.
    #[test]
    fn test_parallel_dispatch_smallmatrix() {
        let a = array![[1.0, 2.0], [2.0, 5.0]];
        let result = ParallelDecomposition::cholesky(&a.view(), Some(4));
        assert!(result.is_ok());
    }

    // The builder should propagate the worker count and scale the threshold.
    #[test]
    fn test_parallel_config_builder() {
        let config = ParallelConfig::new()
            .with_workers(8)
            .with_threshold_multiplier(2.0)
            .build();

        assert_eq!(config.workers, Some(8));
        // NOTE(review): assumes WorkerConfig's default parallel_threshold is
        // 1000 (2.0 * 1000 = 2000) — confirm against WorkerConfig::new().
        assert_eq!(config.parallel_threshold, 2000);
    }
}