1use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9use rayon::prelude::*;
10
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
13
14use scirs2_core::{Array1, Array2};
16
17use scirs2_linalg::compat::pinv;
19use scirs2_linalg::{
20 basic_trace, cond,
22 eig,
23 lstsq,
24 matrix_norm,
25 matrix_rank,
26 vector_norm,
27};
28
/// Per-matrix SVD output produced by `batch_svd_py`: the tuple `(u, s, vt)`,
/// with `u` and `vt` as lists of rows and `s` as a flat list.
type SvdResult = (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>);
31
32#[pyfunction]
39fn det_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
40 let binding = a.readonly();
41 let data = binding.as_array();
42
43 scirs2_linalg::det_f64_lapack(&data)
45 .map_err(|e| PyRuntimeError::new_err(format!("Determinant failed: {}", e)))
46}
47
48#[pyfunction]
50fn inv_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
51 let binding = a.readonly();
52 let data = binding.as_array();
53
54 let result = scirs2_linalg::inv_f64_lapack(&data)
56 .map_err(|e| PyRuntimeError::new_err(format!("Inverse failed: {}", e)))?;
57
58 Ok(result.into_pyarray(py).unbind())
59}
60
61#[pyfunction]
63fn trace_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
64 let binding = a.readonly();
65 let data = binding.as_array();
66
67 basic_trace(&data).map_err(|e| PyRuntimeError::new_err(format!("Trace failed: {}", e)))
68}
69
70#[pyfunction]
77fn lu_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
78 let binding = a.readonly();
79 let data = binding.as_array();
80
81 let (p, l, u) = scirs2_linalg::lu_f64_lapack(&data)
83 .map_err(|e| PyRuntimeError::new_err(format!("LU decomposition failed: {}", e)))?;
84
85 let dict = PyDict::new(py);
86 dict.set_item("p", p.into_pyarray(py).unbind())?;
87 dict.set_item("l", l.into_pyarray(py).unbind())?;
88 dict.set_item("u", u.into_pyarray(py).unbind())?;
89
90 Ok(dict.into())
91}
92
93#[pyfunction]
96fn qr_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
97 let binding = a.readonly();
98 let data = binding.as_array();
99
100 let (q, r) = scirs2_linalg::qr_f64_lapack(&data)
102 .map_err(|e| PyRuntimeError::new_err(format!("QR decomposition failed: {}", e)))?;
103
104 let dict = PyDict::new(py);
105 dict.set_item("q", q.into_pyarray(py).unbind())?;
106 dict.set_item("r", r.into_pyarray(py).unbind())?;
107
108 Ok(dict.into())
109}
110
111#[pyfunction]
114#[pyo3(signature = (a, full_matrices=false))]
115fn svd_py(py: Python, a: &Bound<'_, PyArray2<f64>>, full_matrices: bool) -> PyResult<Py<PyAny>> {
116 let binding = a.readonly();
117 let data = binding.as_array();
118
119 let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&data, full_matrices)
121 .map_err(|e| PyRuntimeError::new_err(format!("SVD decomposition failed: {}", e)))?;
122
123 let dict = PyDict::new(py);
124 dict.set_item("u", u.into_pyarray(py).unbind())?;
125 dict.set_item("s", s.into_pyarray(py).unbind())?;
126 dict.set_item("vt", vt.into_pyarray(py).unbind())?;
127
128 Ok(dict.into())
129}
130
131#[pyfunction]
133fn cholesky_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
134 let binding = a.readonly();
135 let data = binding.as_array();
136
137 let result = scirs2_linalg::cholesky_f64_lapack(&data)
139 .map_err(|e| PyRuntimeError::new_err(format!("Cholesky decomposition failed: {}", e)))?;
140
141 Ok(result.into_pyarray(py).unbind())
142}
143
144#[pyfunction]
147fn eig_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
148 let binding = a.readonly();
149 let data = binding.as_array();
150
151 let (eigenvalues, eigenvectors) = scirs2_linalg::eig_f64_lapack(&data)
153 .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue decomposition failed: {}", e)))?;
154
155 let eigenvalues_real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
157 let eigenvalues_imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
158
159 let (nrows, ncols) = eigenvectors.dim();
160 let mut eigenvectors_real = Array2::zeros((nrows, ncols));
161 let mut eigenvectors_imag = Array2::zeros((nrows, ncols));
162
163 for ((i, j), val) in eigenvectors.indexed_iter() {
164 eigenvectors_real[[i, j]] = val.re;
165 eigenvectors_imag[[i, j]] = val.im;
166 }
167
168 let dict = PyDict::new(py);
169 dict.set_item(
170 "eigenvalues_real",
171 Array1::from_vec(eigenvalues_real).into_pyarray(py).unbind(),
172 )?;
173 dict.set_item(
174 "eigenvalues_imag",
175 Array1::from_vec(eigenvalues_imag).into_pyarray(py).unbind(),
176 )?;
177 dict.set_item(
178 "eigenvectors_real",
179 eigenvectors_real.into_pyarray(py).unbind(),
180 )?;
181 dict.set_item(
182 "eigenvectors_imag",
183 eigenvectors_imag.into_pyarray(py).unbind(),
184 )?;
185
186 Ok(dict.into())
187}
188
189#[pyfunction]
192fn eigh_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
193 let binding = a.readonly();
194 let data = binding.as_array();
195
196 let (eigenvalues, eigenvectors) = scirs2_linalg::eigh_f64_lapack(&data).map_err(|e| {
198 PyRuntimeError::new_err(format!("Symmetric eigenvalue decomposition failed: {}", e))
199 })?;
200
201 let dict = PyDict::new(py);
202 dict.set_item("eigenvalues", eigenvalues.into_pyarray(py).unbind())?;
203 dict.set_item("eigenvectors", eigenvectors.into_pyarray(py).unbind())?;
204
205 Ok(dict.into())
206}
207
208#[pyfunction]
211fn eigvals_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
212 let binding = a.readonly();
213 let data = binding.as_array();
214
215 let (eigenvalues, _eigenvectors) = eig(&data, None)
216 .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue computation failed: {}", e)))?;
217
218 let real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
220 let imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
221
222 let dict = PyDict::new(py);
223 dict.set_item("real", Array1::from_vec(real).into_pyarray(py).unbind())?;
224 dict.set_item("imag", Array1::from_vec(imag).into_pyarray(py).unbind())?;
225
226 Ok(dict.into())
227}
228
229#[pyfunction]
235fn solve_py(
236 py: Python,
237 a: &Bound<'_, PyArray2<f64>>,
238 b: &Bound<'_, PyArray1<f64>>,
239) -> PyResult<Py<PyArray1<f64>>> {
240 let a_binding = a.readonly();
241 let a_data = a_binding.as_array();
242 let b_binding = b.readonly();
243 let b_data = b_binding.as_array();
244
245 let result = scirs2_linalg::solve_f64_lapack(&a_data, &b_data)
247 .map_err(|e| PyRuntimeError::new_err(format!("Linear solve failed: {}", e)))?;
248
249 Ok(result.into_pyarray(py).unbind())
250}
251
252#[pyfunction]
255fn lstsq_py(
256 py: Python,
257 a: &Bound<'_, PyArray2<f64>>,
258 b: &Bound<'_, PyArray1<f64>>,
259) -> PyResult<Py<PyAny>> {
260 let a_binding = a.readonly();
261 let a_data = a_binding.as_array();
262 let b_binding = b.readonly();
263 let b_data = b_binding.as_array();
264
265 let result = lstsq(&a_data, &b_data, None)
266 .map_err(|e| PyRuntimeError::new_err(format!("Least squares failed: {}", e)))?;
267
268 let dict = PyDict::new(py);
269 dict.set_item("solution", result.x.into_pyarray(py).unbind())?;
270 dict.set_item("residuals", result.residuals)?;
271 dict.set_item("rank", result.rank)?;
272 dict.set_item("singular_values", result.s.into_pyarray(py).unbind())?;
273
274 Ok(dict.into())
275}
276
277#[pyfunction]
284#[pyo3(signature = (a, ord="fro"))]
285fn matrix_norm_py(a: &Bound<'_, PyArray2<f64>>, ord: &str) -> PyResult<f64> {
286 let binding = a.readonly();
287 let data = binding.as_array();
288
289 matrix_norm(&data, ord, None)
290 .map_err(|e| PyRuntimeError::new_err(format!("Matrix norm failed: {}", e)))
291}
292
293#[pyfunction]
296#[pyo3(signature = (x, ord=2))]
297fn vector_norm_py(x: &Bound<'_, PyArray1<f64>>, ord: usize) -> PyResult<f64> {
298 let binding = x.readonly();
299 let data = binding.as_array();
300
301 vector_norm(&data, ord)
302 .map_err(|e| PyRuntimeError::new_err(format!("Vector norm failed: {}", e)))
303}
304
305#[pyfunction]
307fn cond_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
308 let binding = a.readonly();
309 let data = binding.as_array();
310
311 cond(&data, None, None)
312 .map_err(|e| PyRuntimeError::new_err(format!("Condition number failed: {}", e)))
313}
314
315#[pyfunction]
317#[pyo3(signature = (a, tol=None))]
318fn matrix_rank_py(a: &Bound<'_, PyArray2<f64>>, tol: Option<f64>) -> PyResult<usize> {
319 let binding = a.readonly();
320 let data = binding.as_array();
321
322 matrix_rank(&data, tol, None)
323 .map_err(|e| PyRuntimeError::new_err(format!("Matrix rank failed: {}", e)))
324}
325
326#[pyfunction]
328#[pyo3(signature = (a, rcond=None))]
329fn pinv_py(
330 py: Python,
331 a: &Bound<'_, PyArray2<f64>>,
332 rcond: Option<f64>,
333) -> PyResult<Py<PyArray2<f64>>> {
334 let binding = a.readonly();
335 let data = binding.as_array();
336
337 let _ = rcond; let result =
340 pinv(&data).map_err(|e| PyRuntimeError::new_err(format!("Pseudoinverse failed: {}", e)))?;
341
342 Ok(result.into_pyarray(py).unbind())
343}
344
345fn vec2d_to_array2(rows: &[Vec<f64>]) -> Result<Array2<f64>, String> {
351 if rows.is_empty() {
352 return Err("Matrix has no rows".to_string());
353 }
354 let nrows = rows.len();
355 let ncols = rows[0].len();
356 for (i, row) in rows.iter().enumerate() {
357 if row.len() != ncols {
358 return Err(format!(
359 "Row {} has {} columns, expected {}",
360 i,
361 row.len(),
362 ncols
363 ));
364 }
365 }
366 let flat: Vec<f64> = rows.iter().flat_map(|r| r.iter().cloned()).collect();
367 Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| format!("Shape error: {}", e))
368}
369
370fn array2_to_vec2d(a: &Array2<f64>) -> Vec<Vec<f64>> {
372 let (nrows, ncols) = a.dim();
373 (0..nrows)
374 .map(|i| (0..ncols).map(|j| a[[i, j]]).collect())
375 .collect()
376}
377
378#[pyfunction]
387fn batch_matmul_py(
388 a_list: Vec<Vec<Vec<f64>>>,
389 b_list: Vec<Vec<Vec<f64>>>,
390) -> PyResult<Vec<Vec<Vec<f64>>>> {
391 if a_list.len() != b_list.len() {
392 return Err(PyRuntimeError::new_err(format!(
393 "a_list length {} does not match b_list length {}",
394 a_list.len(),
395 b_list.len()
396 )));
397 }
398 if a_list.is_empty() {
399 return Ok(vec![]);
400 }
401
402 let results: Vec<Result<Vec<Vec<f64>>, String>> = a_list
403 .par_iter()
404 .zip(b_list.par_iter())
405 .map(|(a_rows, b_rows)| {
406 let a = vec2d_to_array2(a_rows)?;
407 let b = vec2d_to_array2(b_rows)?;
408 let (_, a_cols) = a.dim();
409 let (b_rows_n, _) = b.dim();
410 if a_cols != b_rows_n {
411 return Err(format!(
412 "Incompatible shapes for matmul: inner dims {} != {}",
413 a_cols, b_rows_n
414 ));
415 }
416 let c = a.dot(&b);
418 Ok(array2_to_vec2d(&c))
419 })
420 .collect();
421
422 results
423 .into_iter()
424 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matmul failed: {}", e))))
425 .collect()
426}
427
428#[pyfunction]
436fn batch_svd_py(matrices: Vec<Vec<Vec<f64>>>) -> PyResult<Vec<SvdResult>> {
437 if matrices.is_empty() {
438 return Ok(vec![]);
439 }
440
441 let results: Vec<Result<SvdResult, String>> = matrices
442 .par_iter()
443 .map(|mat_rows| {
444 let mat = vec2d_to_array2(mat_rows)?;
445 let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&mat.view(), false)
446 .map_err(|e| format!("SVD failed: {}", e))?;
447 Ok((array2_to_vec2d(&u), s.to_vec(), array2_to_vec2d(&vt)))
448 })
449 .collect();
450
451 results
452 .into_iter()
453 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_svd failed: {}", e))))
454 .collect()
455}
456
457#[pyfunction]
466fn batch_solve_py(a_list: Vec<Vec<Vec<f64>>>, b_list: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
467 if a_list.len() != b_list.len() {
468 return Err(PyRuntimeError::new_err(format!(
469 "a_list length {} does not match b_list length {}",
470 a_list.len(),
471 b_list.len()
472 )));
473 }
474 if a_list.is_empty() {
475 return Ok(vec![]);
476 }
477
478 let results: Vec<Result<Vec<f64>, String>> = a_list
479 .par_iter()
480 .zip(b_list.par_iter())
481 .map(|(a_rows, b_vec)| {
482 let a = vec2d_to_array2(a_rows)?;
483 let b = Array1::from_vec(b_vec.clone());
484 let x = scirs2_linalg::solve_f64_lapack(&a.view(), &b.view())
485 .map_err(|e| format!("Solve failed: {}", e))?;
486 Ok(x.to_vec())
487 })
488 .collect();
489
490 results
491 .into_iter()
492 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_solve failed: {}", e))))
493 .collect()
494}
495
496#[pyfunction]
505#[pyo3(signature = (matrices, ord=None))]
506fn batch_matrix_norm_py(matrices: Vec<Vec<Vec<f64>>>, ord: Option<&str>) -> PyResult<Vec<f64>> {
507 if matrices.is_empty() {
508 return Ok(vec![]);
509 }
510 let norm_type = ord.unwrap_or("fro");
511 match norm_type {
513 "fro" | "nuc" | "1" | "inf" => {}
514 other => {
515 return Err(PyRuntimeError::new_err(format!(
516 "Unknown norm type '{}'. Supported: fro, nuc, 1, inf",
517 other
518 )));
519 }
520 }
521
522 let results: Vec<Result<f64, String>> = matrices
523 .par_iter()
524 .map(|mat_rows| {
525 let mat = vec2d_to_array2(mat_rows)?;
526 scirs2_linalg::matrix_norm(&mat.view(), norm_type, None)
527 .map_err(|e| format!("Matrix norm failed: {}", e))
528 })
529 .collect();
530
531 results
532 .into_iter()
533 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matrix_norm failed: {}", e))))
534 .collect()
535}
536
537pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
539 m.add_function(wrap_pyfunction!(det_py, m)?)?;
541 m.add_function(wrap_pyfunction!(inv_py, m)?)?;
542 m.add_function(wrap_pyfunction!(trace_py, m)?)?;
543
544 m.add_function(wrap_pyfunction!(lu_py, m)?)?;
546 m.add_function(wrap_pyfunction!(qr_py, m)?)?;
547 m.add_function(wrap_pyfunction!(svd_py, m)?)?;
548 m.add_function(wrap_pyfunction!(cholesky_py, m)?)?;
549 m.add_function(wrap_pyfunction!(eig_py, m)?)?;
550 m.add_function(wrap_pyfunction!(eigh_py, m)?)?;
551 m.add_function(wrap_pyfunction!(eigvals_py, m)?)?;
552
553 m.add_function(wrap_pyfunction!(solve_py, m)?)?;
555 m.add_function(wrap_pyfunction!(lstsq_py, m)?)?;
556
557 m.add_function(wrap_pyfunction!(matrix_norm_py, m)?)?;
559 m.add_function(wrap_pyfunction!(vector_norm_py, m)?)?;
560 m.add_function(wrap_pyfunction!(cond_py, m)?)?;
561 m.add_function(wrap_pyfunction!(matrix_rank_py, m)?)?;
562 m.add_function(wrap_pyfunction!(pinv_py, m)?)?;
563
564 m.add_function(wrap_pyfunction!(batch_matmul_py, m)?)?;
566 m.add_function(wrap_pyfunction!(batch_svd_py, m)?)?;
567 m.add_function(wrap_pyfunction!(batch_solve_py, m)?)?;
568 m.add_function(wrap_pyfunction!(batch_matrix_norm_py, m)?)?;
569
570 Ok(())
571}