scirs2/parallel.rs
1//! Parallel batch processing utilities.
2//!
3//! Provides Python-facing functions that run Rayon-based parallel computations
4//! while releasing the Python GIL so that Python threads are not blocked.
5//!
6//! # Example (Python)
7//! ```python
8//! import scirs2
9//!
10//! # Compute means of multiple float arrays in parallel.
11//! arrays = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0, 7.0, 8.0, 9.0]]
12//! means = scirs2.parallel_map_mean(arrays)
13//! # means == [2.0, 4.5, 7.5]
14//!
15//! # Parallel batch matrix-vector multiply.
16//! mats = [[1.0, 0.0, 0.0, 1.0], [2.0, 0.0, 0.0, 2.0]]
17//! vecs = [[3.0, 4.0], [3.0, 4.0]]
18//! results = scirs2.parallel_batch_matvec(mats, vecs, 2, 2)
19//! # results == [[3.0, 4.0], [6.0, 8.0]]
20//! ```
21
22use pyo3::prelude::*;
23
24// ──────────────────────────────────────────────────────────────────────────────
25// parallel_map_mean
26// ──────────────────────────────────────────────────────────────────────────────
27
28/// Compute the arithmetic mean of each sub-array in parallel.
29///
30/// The GIL is released during Rayon computation, allowing Python threads to run
31/// concurrently.
32///
33/// Empty sub-arrays return `0.0`.
34///
35/// # Arguments
36/// * `arrays` – list of float arrays, one mean value is produced per element.
37///
38/// # Returns
39/// A list of `f64` mean values with the same length as `arrays`.
40#[pyfunction]
41pub fn parallel_map_mean(py: Python<'_>, arrays: Vec<Vec<f64>>) -> PyResult<Vec<f64>> {
42 let result = py.detach(|| {
43 use rayon::prelude::*;
44 arrays
45 .par_iter()
46 .map(|arr| {
47 if arr.is_empty() {
48 0.0
49 } else {
50 arr.iter().sum::<f64>() / arr.len() as f64
51 }
52 })
53 .collect::<Vec<f64>>()
54 });
55 Ok(result)
56}
57
58// ──────────────────────────────────────────────────────────────────────────────
59// parallel_batch_matvec
60// ──────────────────────────────────────────────────────────────────────────────
61
62/// Perform a batch of matrix-vector multiplications in parallel.
63///
64/// Each pair `(matrices[i], vectors[i])` represents a multiplication
65/// `A * v` where `A` is stored in row-major order with shape
66/// `(n_rows, n_cols)` and `v` is a vector of length `n_cols`.
67///
68/// The GIL is released during Rayon computation.
69///
70/// # Arguments
71/// * `matrices` – list of flat row-major matrices, each with `n_rows * n_cols` elements.
72/// * `vectors` – list of vectors, each with `n_cols` elements.
73/// * `n_rows` – number of rows in each matrix.
74/// * `n_cols` – number of columns in each matrix (= length of each vector).
75///
76/// # Errors
77/// Returns `ValueError` if:
78/// - `matrices` and `vectors` have different lengths.
79/// - Any matrix does not have exactly `n_rows * n_cols` elements.
80/// - Any vector does not have exactly `n_cols` elements.
81#[pyfunction]
82pub fn parallel_batch_matvec(
83 py: Python<'_>,
84 matrices: Vec<Vec<f64>>,
85 vectors: Vec<Vec<f64>>,
86 n_rows: usize,
87 n_cols: usize,
88) -> PyResult<Vec<Vec<f64>>> {
89 if matrices.len() != vectors.len() {
90 return Err(pyo3::exceptions::PyValueError::new_err(format!(
91 "matrices ({}) and vectors ({}) must have equal length",
92 matrices.len(),
93 vectors.len()
94 )));
95 }
96
97 let expected_mat = n_rows * n_cols;
98 for (i, mat) in matrices.iter().enumerate() {
99 if mat.len() != expected_mat {
100 return Err(pyo3::exceptions::PyValueError::new_err(format!(
101 "matrix[{i}] has {} elements but expected {expected_mat} (n_rows={n_rows}, n_cols={n_cols})",
102 mat.len()
103 )));
104 }
105 }
106 for (i, vec) in vectors.iter().enumerate() {
107 if vec.len() != n_cols {
108 return Err(pyo3::exceptions::PyValueError::new_err(format!(
109 "vector[{i}] has {} elements but expected {n_cols} (n_cols)",
110 vec.len()
111 )));
112 }
113 }
114
115 let result = py.detach(|| {
116 use rayon::prelude::*;
117 matrices
118 .par_iter()
119 .zip(vectors.par_iter())
120 .map(|(mat, vec)| {
121 (0..n_rows)
122 .map(|r| {
123 (0..n_cols)
124 .map(|c| mat[r * n_cols + c] * vec[c])
125 .sum::<f64>()
126 })
127 .collect::<Vec<f64>>()
128 })
129 .collect::<Vec<Vec<f64>>>()
130 });
131 Ok(result)
132}
133
134// ──────────────────────────────────────────────────────────────────────────────
135// Module registration
136// ──────────────────────────────────────────────────────────────────────────────
137
138/// Register parallel batch-processing functions into a PyO3 module.
139///
140/// Exposes [`parallel_map_mean`] and [`parallel_batch_matvec`].
141pub fn register_parallel_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
142 m.add_function(wrap_pyfunction!(parallel_map_mean, m)?)?;
143 m.add_function(wrap_pyfunction!(parallel_batch_matvec, m)?)?;
144 Ok(())
145}
146
147// ──────────────────────────────────────────────────────────────────────────────
148// Tests
149// ──────────────────────────────────────────────────────────────────────────────
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `actual` differs from `expected` by less than `f64::EPSILON`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn parallel_map_mean_empty_arrays() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            let inputs: Vec<Vec<f64>> = vec![Vec::new(), vec![1.0, 2.0, 3.0], Vec::new()];
            let means = parallel_map_mean(py, inputs).expect("parallel_map_mean failed");
            assert_eq!(means.len(), 3);
            // Empty inputs map to 0.0; the middle array averages to 2.0.
            let expected = [0.0, 2.0, 0.0];
            for (idx, want) in expected.iter().enumerate() {
                assert_close(means[idx], *want);
            }
        });
    }

    #[test]
    fn parallel_batch_matvec_identity_matrix() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // First matrix is the 2×2 identity, second is twice the identity;
            // both are applied to [3.0, 4.0].
            let identity = vec![1.0, 0.0, 0.0, 1.0];
            let doubled = vec![2.0, 0.0, 0.0, 2.0];
            let inputs = vec![vec![3.0, 4.0], vec![3.0, 4.0]];
            let results = parallel_batch_matvec(py, vec![identity, doubled], inputs, 2, 2)
                .expect("matvec failed");
            assert_eq!(results.len(), 2);
            let expected = [[3.0, 4.0], [6.0, 8.0]];
            for (i, want_row) in expected.iter().enumerate() {
                for (j, want) in want_row.iter().enumerate() {
                    assert_close(results[i][j], *want);
                }
            }
        });
    }

    #[test]
    fn parallel_batch_matvec_mismatched_lengths_errors() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            // One matrix paired with two vectors must be rejected.
            let one_matrix = vec![vec![1.0, 0.0, 0.0, 1.0]];
            let two_vectors = vec![vec![1.0], vec![2.0]];
            assert!(parallel_batch_matvec(py, one_matrix, two_vectors, 2, 2).is_err());
        });
    }
}
195}