1use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9use rayon::prelude::*;
10
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
13
14use scirs2_core::{Array1, Array2};
16
17use scirs2_linalg::compat::pinv;
19use scirs2_linalg::{
20 basic_trace, cond,
22 eig,
23 lstsq,
24 matrix_norm,
25 matrix_rank,
26 vector_norm,
27};
28
29#[pyfunction]
36fn det_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
37 let binding = a.readonly();
38 let data = binding.as_array();
39
40 scirs2_linalg::det_f64_lapack(&data)
42 .map_err(|e| PyRuntimeError::new_err(format!("Determinant failed: {}", e)))
43}
44
45#[pyfunction]
47fn inv_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
48 let binding = a.readonly();
49 let data = binding.as_array();
50
51 let result = scirs2_linalg::inv_f64_lapack(&data)
53 .map_err(|e| PyRuntimeError::new_err(format!("Inverse failed: {}", e)))?;
54
55 Ok(result.into_pyarray(py).unbind())
56}
57
58#[pyfunction]
60fn trace_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
61 let binding = a.readonly();
62 let data = binding.as_array();
63
64 basic_trace(&data).map_err(|e| PyRuntimeError::new_err(format!("Trace failed: {}", e)))
65}
66
67#[pyfunction]
74fn lu_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
75 let binding = a.readonly();
76 let data = binding.as_array();
77
78 let (p, l, u) = scirs2_linalg::lu_f64_lapack(&data)
80 .map_err(|e| PyRuntimeError::new_err(format!("LU decomposition failed: {}", e)))?;
81
82 let dict = PyDict::new(py);
83 dict.set_item("p", p.into_pyarray(py).unbind())?;
84 dict.set_item("l", l.into_pyarray(py).unbind())?;
85 dict.set_item("u", u.into_pyarray(py).unbind())?;
86
87 Ok(dict.into())
88}
89
90#[pyfunction]
93fn qr_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
94 let binding = a.readonly();
95 let data = binding.as_array();
96
97 let (q, r) = scirs2_linalg::qr_f64_lapack(&data)
99 .map_err(|e| PyRuntimeError::new_err(format!("QR decomposition failed: {}", e)))?;
100
101 let dict = PyDict::new(py);
102 dict.set_item("q", q.into_pyarray(py).unbind())?;
103 dict.set_item("r", r.into_pyarray(py).unbind())?;
104
105 Ok(dict.into())
106}
107
108#[pyfunction]
111#[pyo3(signature = (a, full_matrices=false))]
112fn svd_py(py: Python, a: &Bound<'_, PyArray2<f64>>, full_matrices: bool) -> PyResult<Py<PyAny>> {
113 let binding = a.readonly();
114 let data = binding.as_array();
115
116 let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&data, full_matrices)
118 .map_err(|e| PyRuntimeError::new_err(format!("SVD decomposition failed: {}", e)))?;
119
120 let dict = PyDict::new(py);
121 dict.set_item("u", u.into_pyarray(py).unbind())?;
122 dict.set_item("s", s.into_pyarray(py).unbind())?;
123 dict.set_item("vt", vt.into_pyarray(py).unbind())?;
124
125 Ok(dict.into())
126}
127
128#[pyfunction]
130fn cholesky_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
131 let binding = a.readonly();
132 let data = binding.as_array();
133
134 let result = scirs2_linalg::cholesky_f64_lapack(&data)
136 .map_err(|e| PyRuntimeError::new_err(format!("Cholesky decomposition failed: {}", e)))?;
137
138 Ok(result.into_pyarray(py).unbind())
139}
140
141#[pyfunction]
144fn eig_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
145 let binding = a.readonly();
146 let data = binding.as_array();
147
148 let (eigenvalues, eigenvectors) = scirs2_linalg::eig_f64_lapack(&data)
150 .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue decomposition failed: {}", e)))?;
151
152 let eigenvalues_real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
154 let eigenvalues_imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
155
156 let (nrows, ncols) = eigenvectors.dim();
157 let mut eigenvectors_real = Array2::zeros((nrows, ncols));
158 let mut eigenvectors_imag = Array2::zeros((nrows, ncols));
159
160 for ((i, j), val) in eigenvectors.indexed_iter() {
161 eigenvectors_real[[i, j]] = val.re;
162 eigenvectors_imag[[i, j]] = val.im;
163 }
164
165 let dict = PyDict::new(py);
166 dict.set_item(
167 "eigenvalues_real",
168 Array1::from_vec(eigenvalues_real).into_pyarray(py).unbind(),
169 )?;
170 dict.set_item(
171 "eigenvalues_imag",
172 Array1::from_vec(eigenvalues_imag).into_pyarray(py).unbind(),
173 )?;
174 dict.set_item(
175 "eigenvectors_real",
176 eigenvectors_real.into_pyarray(py).unbind(),
177 )?;
178 dict.set_item(
179 "eigenvectors_imag",
180 eigenvectors_imag.into_pyarray(py).unbind(),
181 )?;
182
183 Ok(dict.into())
184}
185
186#[pyfunction]
189fn eigh_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
190 let binding = a.readonly();
191 let data = binding.as_array();
192
193 let (eigenvalues, eigenvectors) = scirs2_linalg::eigh_f64_lapack(&data).map_err(|e| {
195 PyRuntimeError::new_err(format!("Symmetric eigenvalue decomposition failed: {}", e))
196 })?;
197
198 let dict = PyDict::new(py);
199 dict.set_item("eigenvalues", eigenvalues.into_pyarray(py).unbind())?;
200 dict.set_item("eigenvectors", eigenvectors.into_pyarray(py).unbind())?;
201
202 Ok(dict.into())
203}
204
205#[pyfunction]
208fn eigvals_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
209 let binding = a.readonly();
210 let data = binding.as_array();
211
212 let (eigenvalues, _eigenvectors) = eig(&data, None)
213 .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue computation failed: {}", e)))?;
214
215 let real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
217 let imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
218
219 let dict = PyDict::new(py);
220 dict.set_item("real", Array1::from_vec(real).into_pyarray(py).unbind())?;
221 dict.set_item("imag", Array1::from_vec(imag).into_pyarray(py).unbind())?;
222
223 Ok(dict.into())
224}
225
226#[pyfunction]
232fn solve_py(
233 py: Python,
234 a: &Bound<'_, PyArray2<f64>>,
235 b: &Bound<'_, PyArray1<f64>>,
236) -> PyResult<Py<PyArray1<f64>>> {
237 let a_binding = a.readonly();
238 let a_data = a_binding.as_array();
239 let b_binding = b.readonly();
240 let b_data = b_binding.as_array();
241
242 let result = scirs2_linalg::solve_f64_lapack(&a_data, &b_data)
244 .map_err(|e| PyRuntimeError::new_err(format!("Linear solve failed: {}", e)))?;
245
246 Ok(result.into_pyarray(py).unbind())
247}
248
249#[pyfunction]
252fn lstsq_py(
253 py: Python,
254 a: &Bound<'_, PyArray2<f64>>,
255 b: &Bound<'_, PyArray1<f64>>,
256) -> PyResult<Py<PyAny>> {
257 let a_binding = a.readonly();
258 let a_data = a_binding.as_array();
259 let b_binding = b.readonly();
260 let b_data = b_binding.as_array();
261
262 let result = lstsq(&a_data, &b_data, None)
263 .map_err(|e| PyRuntimeError::new_err(format!("Least squares failed: {}", e)))?;
264
265 let dict = PyDict::new(py);
266 dict.set_item("solution", result.x.into_pyarray(py).unbind())?;
267 dict.set_item("residuals", result.residuals)?;
268 dict.set_item("rank", result.rank)?;
269 dict.set_item("singular_values", result.s.into_pyarray(py).unbind())?;
270
271 Ok(dict.into())
272}
273
274#[pyfunction]
281#[pyo3(signature = (a, ord="fro"))]
282fn matrix_norm_py(a: &Bound<'_, PyArray2<f64>>, ord: &str) -> PyResult<f64> {
283 let binding = a.readonly();
284 let data = binding.as_array();
285
286 matrix_norm(&data, ord, None)
287 .map_err(|e| PyRuntimeError::new_err(format!("Matrix norm failed: {}", e)))
288}
289
290#[pyfunction]
293#[pyo3(signature = (x, ord=2))]
294fn vector_norm_py(x: &Bound<'_, PyArray1<f64>>, ord: usize) -> PyResult<f64> {
295 let binding = x.readonly();
296 let data = binding.as_array();
297
298 vector_norm(&data, ord)
299 .map_err(|e| PyRuntimeError::new_err(format!("Vector norm failed: {}", e)))
300}
301
302#[pyfunction]
304fn cond_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
305 let binding = a.readonly();
306 let data = binding.as_array();
307
308 cond(&data, None, None)
309 .map_err(|e| PyRuntimeError::new_err(format!("Condition number failed: {}", e)))
310}
311
312#[pyfunction]
314#[pyo3(signature = (a, tol=None))]
315fn matrix_rank_py(a: &Bound<'_, PyArray2<f64>>, tol: Option<f64>) -> PyResult<usize> {
316 let binding = a.readonly();
317 let data = binding.as_array();
318
319 matrix_rank(&data, tol, None)
320 .map_err(|e| PyRuntimeError::new_err(format!("Matrix rank failed: {}", e)))
321}
322
323#[pyfunction]
325#[pyo3(signature = (a, rcond=None))]
326fn pinv_py(
327 py: Python,
328 a: &Bound<'_, PyArray2<f64>>,
329 rcond: Option<f64>,
330) -> PyResult<Py<PyArray2<f64>>> {
331 let binding = a.readonly();
332 let data = binding.as_array();
333
334 let _ = rcond; let result =
337 pinv(&data).map_err(|e| PyRuntimeError::new_err(format!("Pseudoinverse failed: {}", e)))?;
338
339 Ok(result.into_pyarray(py).unbind())
340}
341
342fn vec2d_to_array2(rows: &[Vec<f64>]) -> Result<Array2<f64>, String> {
348 if rows.is_empty() {
349 return Err("Matrix has no rows".to_string());
350 }
351 let nrows = rows.len();
352 let ncols = rows[0].len();
353 for (i, row) in rows.iter().enumerate() {
354 if row.len() != ncols {
355 return Err(format!(
356 "Row {} has {} columns, expected {}",
357 i,
358 row.len(),
359 ncols
360 ));
361 }
362 }
363 let flat: Vec<f64> = rows.iter().flat_map(|r| r.iter().cloned()).collect();
364 Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| format!("Shape error: {}", e))
365}
366
367fn array2_to_vec2d(a: &Array2<f64>) -> Vec<Vec<f64>> {
369 let (nrows, ncols) = a.dim();
370 (0..nrows)
371 .map(|i| (0..ncols).map(|j| a[[i, j]]).collect())
372 .collect()
373}
374
375#[pyfunction]
384fn batch_matmul_py(
385 a_list: Vec<Vec<Vec<f64>>>,
386 b_list: Vec<Vec<Vec<f64>>>,
387) -> PyResult<Vec<Vec<Vec<f64>>>> {
388 if a_list.len() != b_list.len() {
389 return Err(PyRuntimeError::new_err(format!(
390 "a_list length {} does not match b_list length {}",
391 a_list.len(),
392 b_list.len()
393 )));
394 }
395 if a_list.is_empty() {
396 return Ok(vec![]);
397 }
398
399 let results: Vec<Result<Vec<Vec<f64>>, String>> = a_list
400 .par_iter()
401 .zip(b_list.par_iter())
402 .map(|(a_rows, b_rows)| {
403 let a = vec2d_to_array2(a_rows)?;
404 let b = vec2d_to_array2(b_rows)?;
405 let (_, a_cols) = a.dim();
406 let (b_rows_n, _) = b.dim();
407 if a_cols != b_rows_n {
408 return Err(format!(
409 "Incompatible shapes for matmul: inner dims {} != {}",
410 a_cols, b_rows_n
411 ));
412 }
413 let c = a.dot(&b);
415 Ok(array2_to_vec2d(&c))
416 })
417 .collect();
418
419 results
420 .into_iter()
421 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matmul failed: {}", e))))
422 .collect()
423}
424
425#[pyfunction]
433fn batch_svd_py(
434 matrices: Vec<Vec<Vec<f64>>>,
435) -> PyResult<Vec<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>)>> {
436 if matrices.is_empty() {
437 return Ok(vec![]);
438 }
439
440 let results: Vec<Result<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>), String>> = matrices
441 .par_iter()
442 .map(|mat_rows| {
443 let mat = vec2d_to_array2(mat_rows)?;
444 let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&mat.view(), false)
445 .map_err(|e| format!("SVD failed: {}", e))?;
446 Ok((array2_to_vec2d(&u), s.to_vec(), array2_to_vec2d(&vt)))
447 })
448 .collect();
449
450 results
451 .into_iter()
452 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_svd failed: {}", e))))
453 .collect()
454}
455
456#[pyfunction]
465fn batch_solve_py(a_list: Vec<Vec<Vec<f64>>>, b_list: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
466 if a_list.len() != b_list.len() {
467 return Err(PyRuntimeError::new_err(format!(
468 "a_list length {} does not match b_list length {}",
469 a_list.len(),
470 b_list.len()
471 )));
472 }
473 if a_list.is_empty() {
474 return Ok(vec![]);
475 }
476
477 let results: Vec<Result<Vec<f64>, String>> = a_list
478 .par_iter()
479 .zip(b_list.par_iter())
480 .map(|(a_rows, b_vec)| {
481 let a = vec2d_to_array2(a_rows)?;
482 let b = Array1::from_vec(b_vec.clone());
483 let x = scirs2_linalg::solve_f64_lapack(&a.view(), &b.view())
484 .map_err(|e| format!("Solve failed: {}", e))?;
485 Ok(x.to_vec())
486 })
487 .collect();
488
489 results
490 .into_iter()
491 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_solve failed: {}", e))))
492 .collect()
493}
494
495#[pyfunction]
504#[pyo3(signature = (matrices, ord=None))]
505fn batch_matrix_norm_py(matrices: Vec<Vec<Vec<f64>>>, ord: Option<&str>) -> PyResult<Vec<f64>> {
506 if matrices.is_empty() {
507 return Ok(vec![]);
508 }
509 let norm_type = ord.unwrap_or("fro");
510 match norm_type {
512 "fro" | "nuc" | "1" | "inf" => {}
513 other => {
514 return Err(PyRuntimeError::new_err(format!(
515 "Unknown norm type '{}'. Supported: fro, nuc, 1, inf",
516 other
517 )));
518 }
519 }
520
521 let results: Vec<Result<f64, String>> = matrices
522 .par_iter()
523 .map(|mat_rows| {
524 let mat = vec2d_to_array2(mat_rows)?;
525 scirs2_linalg::matrix_norm(&mat.view(), norm_type, None)
526 .map_err(|e| format!("Matrix norm failed: {}", e))
527 })
528 .collect();
529
530 results
531 .into_iter()
532 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matrix_norm failed: {}", e))))
533 .collect()
534}
535
536pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
538 m.add_function(wrap_pyfunction!(det_py, m)?)?;
540 m.add_function(wrap_pyfunction!(inv_py, m)?)?;
541 m.add_function(wrap_pyfunction!(trace_py, m)?)?;
542
543 m.add_function(wrap_pyfunction!(lu_py, m)?)?;
545 m.add_function(wrap_pyfunction!(qr_py, m)?)?;
546 m.add_function(wrap_pyfunction!(svd_py, m)?)?;
547 m.add_function(wrap_pyfunction!(cholesky_py, m)?)?;
548 m.add_function(wrap_pyfunction!(eig_py, m)?)?;
549 m.add_function(wrap_pyfunction!(eigh_py, m)?)?;
550 m.add_function(wrap_pyfunction!(eigvals_py, m)?)?;
551
552 m.add_function(wrap_pyfunction!(solve_py, m)?)?;
554 m.add_function(wrap_pyfunction!(lstsq_py, m)?)?;
555
556 m.add_function(wrap_pyfunction!(matrix_norm_py, m)?)?;
558 m.add_function(wrap_pyfunction!(vector_norm_py, m)?)?;
559 m.add_function(wrap_pyfunction!(cond_py, m)?)?;
560 m.add_function(wrap_pyfunction!(matrix_rank_py, m)?)?;
561 m.add_function(wrap_pyfunction!(pinv_py, m)?)?;
562
563 m.add_function(wrap_pyfunction!(batch_matmul_py, m)?)?;
565 m.add_function(wrap_pyfunction!(batch_svd_py, m)?)?;
566 m.add_function(wrap_pyfunction!(batch_solve_py, m)?)?;
567 m.add_function(wrap_pyfunction!(batch_matrix_norm_py, m)?)?;
568
569 Ok(())
570}