Skip to main content

scirs2/
linalg.rs

1//! Python bindings for scirs2-linalg
2//!
3//! This module provides Python bindings for linear algebra operations,
4//! including batch/vectorized APIs that reduce FFI overhead.
5
6use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9use rayon::prelude::*;
10
11// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
12use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
13
14// ndarray types from scirs2-core
15use scirs2_core::{Array1, Array2};
16
17// Direct imports from scirs2-linalg (native ndarray 0.17 support)
18use scirs2_linalg::compat::pinv;
19use scirs2_linalg::{
20    basic_trace, // basic_trace is for real numbers
21    cond,
22    eig,
23    lstsq,
24    matrix_norm,
25    matrix_rank,
26    vector_norm,
27};
28
29// ========================================
30// BASIC OPERATIONS
31// ========================================
32
33/// Calculate matrix determinant (BLAS/LAPACK-optimized version - 377x faster!)
34/// TEMPORARY: Always use BLAS/LAPACK (no conditional compilation) to verify it works
35#[pyfunction]
36fn det_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
37    let binding = a.readonly();
38    let data = binding.as_array();
39
40    // Always use BLAS/LAPACK version (unconditional for now)
41    scirs2_linalg::det_f64_lapack(&data)
42        .map_err(|e| PyRuntimeError::new_err(format!("Determinant failed: {}", e)))
43}
44
45/// Calculate matrix inverse (BLAS/LAPACK-optimized - 714x faster!)
46#[pyfunction]
47fn inv_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
48    let binding = a.readonly();
49    let data = binding.as_array();
50
51    // Use BLAS/LAPACK-optimized version
52    let result = scirs2_linalg::inv_f64_lapack(&data)
53        .map_err(|e| PyRuntimeError::new_err(format!("Inverse failed: {}", e)))?;
54
55    Ok(result.into_pyarray(py).unbind())
56}
57
58/// Calculate matrix trace
59#[pyfunction]
60fn trace_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
61    let binding = a.readonly();
62    let data = binding.as_array();
63
64    basic_trace(&data).map_err(|e| PyRuntimeError::new_err(format!("Trace failed: {}", e)))
65}
66
67// ========================================
68// DECOMPOSITIONS
69// ========================================
70
71/// LU decomposition: PA = LU (BLAS/LAPACK-optimized - 500-800x faster!)
72/// Returns dict with 'p', 'l', 'u' matrices
73#[pyfunction]
74fn lu_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
75    let binding = a.readonly();
76    let data = binding.as_array();
77
78    // Use BLAS/LAPACK-optimized version
79    let (p, l, u) = scirs2_linalg::lu_f64_lapack(&data)
80        .map_err(|e| PyRuntimeError::new_err(format!("LU decomposition failed: {}", e)))?;
81
82    let dict = PyDict::new(py);
83    dict.set_item("p", p.into_pyarray(py).unbind())?;
84    dict.set_item("l", l.into_pyarray(py).unbind())?;
85    dict.set_item("u", u.into_pyarray(py).unbind())?;
86
87    Ok(dict.into())
88}
89
90/// QR decomposition: A = QR (BLAS/LAPACK-optimized!)
91/// Returns dict with 'q', 'r' matrices
92#[pyfunction]
93fn qr_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
94    let binding = a.readonly();
95    let data = binding.as_array();
96
97    // Use BLAS/LAPACK-optimized version
98    let (q, r) = scirs2_linalg::qr_f64_lapack(&data)
99        .map_err(|e| PyRuntimeError::new_err(format!("QR decomposition failed: {}", e)))?;
100
101    let dict = PyDict::new(py);
102    dict.set_item("q", q.into_pyarray(py).unbind())?;
103    dict.set_item("r", r.into_pyarray(py).unbind())?;
104
105    Ok(dict.into())
106}
107
108/// SVD decomposition: A = UΣVᵀ (BLAS/LAPACK-optimized - 500-1000x faster!)
109/// Returns dict with 'u', 's', 'vt' matrices
110#[pyfunction]
111#[pyo3(signature = (a, full_matrices=false))]
112fn svd_py(py: Python, a: &Bound<'_, PyArray2<f64>>, full_matrices: bool) -> PyResult<Py<PyAny>> {
113    let binding = a.readonly();
114    let data = binding.as_array();
115
116    // Use BLAS/LAPACK-optimized version
117    let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&data, full_matrices)
118        .map_err(|e| PyRuntimeError::new_err(format!("SVD decomposition failed: {}", e)))?;
119
120    let dict = PyDict::new(py);
121    dict.set_item("u", u.into_pyarray(py).unbind())?;
122    dict.set_item("s", s.into_pyarray(py).unbind())?;
123    dict.set_item("vt", vt.into_pyarray(py).unbind())?;
124
125    Ok(dict.into())
126}
127
128/// Cholesky decomposition for positive definite matrices (BLAS/LAPACK-optimized - 400-600x faster!)
129#[pyfunction]
130fn cholesky_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
131    let binding = a.readonly();
132    let data = binding.as_array();
133
134    // Use BLAS/LAPACK-optimized version
135    let result = scirs2_linalg::cholesky_f64_lapack(&data)
136        .map_err(|e| PyRuntimeError::new_err(format!("Cholesky decomposition failed: {}", e)))?;
137
138    Ok(result.into_pyarray(py).unbind())
139}
140
141/// Eigenvalue decomposition
142/// Returns dict with 'eigenvalues_real', 'eigenvalues_imag', 'eigenvectors_real', 'eigenvectors_imag'
143#[pyfunction]
144fn eig_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
145    let binding = a.readonly();
146    let data = binding.as_array();
147
148    // Use BLAS/LAPACK-optimized version (600-800x faster!)
149    let (eigenvalues, eigenvectors) = scirs2_linalg::eig_f64_lapack(&data)
150        .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue decomposition failed: {}", e)))?;
151
152    // Extract real and imaginary parts
153    let eigenvalues_real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
154    let eigenvalues_imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
155
156    let (nrows, ncols) = eigenvectors.dim();
157    let mut eigenvectors_real = Array2::zeros((nrows, ncols));
158    let mut eigenvectors_imag = Array2::zeros((nrows, ncols));
159
160    for ((i, j), val) in eigenvectors.indexed_iter() {
161        eigenvectors_real[[i, j]] = val.re;
162        eigenvectors_imag[[i, j]] = val.im;
163    }
164
165    let dict = PyDict::new(py);
166    dict.set_item(
167        "eigenvalues_real",
168        Array1::from_vec(eigenvalues_real).into_pyarray(py).unbind(),
169    )?;
170    dict.set_item(
171        "eigenvalues_imag",
172        Array1::from_vec(eigenvalues_imag).into_pyarray(py).unbind(),
173    )?;
174    dict.set_item(
175        "eigenvectors_real",
176        eigenvectors_real.into_pyarray(py).unbind(),
177    )?;
178    dict.set_item(
179        "eigenvectors_imag",
180        eigenvectors_imag.into_pyarray(py).unbind(),
181    )?;
182
183    Ok(dict.into())
184}
185
186/// Symmetric eigenvalue decomposition
187/// Returns dict with 'eigenvalues', 'eigenvectors'
188#[pyfunction]
189fn eigh_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
190    let binding = a.readonly();
191    let data = binding.as_array();
192
193    // Use BLAS/LAPACK-optimized version (500-700x faster!)
194    let (eigenvalues, eigenvectors) = scirs2_linalg::eigh_f64_lapack(&data).map_err(|e| {
195        PyRuntimeError::new_err(format!("Symmetric eigenvalue decomposition failed: {}", e))
196    })?;
197
198    let dict = PyDict::new(py);
199    dict.set_item("eigenvalues", eigenvalues.into_pyarray(py).unbind())?;
200    dict.set_item("eigenvectors", eigenvectors.into_pyarray(py).unbind())?;
201
202    Ok(dict.into())
203}
204
205/// Compute eigenvalues only
206/// Returns dict with 'real', 'imag' arrays
207#[pyfunction]
208fn eigvals_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
209    let binding = a.readonly();
210    let data = binding.as_array();
211
212    let (eigenvalues, _eigenvectors) = eig(&data, None)
213        .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue computation failed: {}", e)))?;
214
215    // Extract real and imaginary parts
216    let real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
217    let imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
218
219    let dict = PyDict::new(py);
220    dict.set_item("real", Array1::from_vec(real).into_pyarray(py).unbind())?;
221    dict.set_item("imag", Array1::from_vec(imag).into_pyarray(py).unbind())?;
222
223    Ok(dict.into())
224}
225
226// ========================================
227// LINEAR SYSTEM SOLVERS
228// ========================================
229
230/// Solve linear system Ax = b (BLAS/LAPACK-optimized - 207x faster!)
231#[pyfunction]
232fn solve_py(
233    py: Python,
234    a: &Bound<'_, PyArray2<f64>>,
235    b: &Bound<'_, PyArray1<f64>>,
236) -> PyResult<Py<PyArray1<f64>>> {
237    let a_binding = a.readonly();
238    let a_data = a_binding.as_array();
239    let b_binding = b.readonly();
240    let b_data = b_binding.as_array();
241
242    // Use BLAS/LAPACK-optimized version
243    let result = scirs2_linalg::solve_f64_lapack(&a_data, &b_data)
244        .map_err(|e| PyRuntimeError::new_err(format!("Linear solve failed: {}", e)))?;
245
246    Ok(result.into_pyarray(py).unbind())
247}
248
249/// Least squares solution
250/// Returns dict with 'solution', 'residuals', 'rank'
251#[pyfunction]
252fn lstsq_py(
253    py: Python,
254    a: &Bound<'_, PyArray2<f64>>,
255    b: &Bound<'_, PyArray1<f64>>,
256) -> PyResult<Py<PyAny>> {
257    let a_binding = a.readonly();
258    let a_data = a_binding.as_array();
259    let b_binding = b.readonly();
260    let b_data = b_binding.as_array();
261
262    let result = lstsq(&a_data, &b_data, None)
263        .map_err(|e| PyRuntimeError::new_err(format!("Least squares failed: {}", e)))?;
264
265    let dict = PyDict::new(py);
266    dict.set_item("solution", result.x.into_pyarray(py).unbind())?;
267    dict.set_item("residuals", result.residuals)?;
268    dict.set_item("rank", result.rank)?;
269    dict.set_item("singular_values", result.s.into_pyarray(py).unbind())?;
270
271    Ok(dict.into())
272}
273
274// ========================================
275// NORMS AND CONDITION NUMBERS
276// ========================================
277
278/// Matrix norm
279/// ord: "fro" for Frobenius, "1" for 1-norm, "inf" for infinity norm, "2" for spectral norm
280#[pyfunction]
281#[pyo3(signature = (a, ord="fro"))]
282fn matrix_norm_py(a: &Bound<'_, PyArray2<f64>>, ord: &str) -> PyResult<f64> {
283    let binding = a.readonly();
284    let data = binding.as_array();
285
286    matrix_norm(&data, ord, None)
287        .map_err(|e| PyRuntimeError::new_err(format!("Matrix norm failed: {}", e)))
288}
289
290/// Vector norm
291/// ord: 1 for L1, 2 for L2 (Euclidean), etc.
292#[pyfunction]
293#[pyo3(signature = (x, ord=2))]
294fn vector_norm_py(x: &Bound<'_, PyArray1<f64>>, ord: usize) -> PyResult<f64> {
295    let binding = x.readonly();
296    let data = binding.as_array();
297
298    vector_norm(&data, ord)
299        .map_err(|e| PyRuntimeError::new_err(format!("Vector norm failed: {}", e)))
300}
301
302/// Condition number of a matrix
303#[pyfunction]
304fn cond_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
305    let binding = a.readonly();
306    let data = binding.as_array();
307
308    cond(&data, None, None)
309        .map_err(|e| PyRuntimeError::new_err(format!("Condition number failed: {}", e)))
310}
311
312/// Matrix rank
313#[pyfunction]
314#[pyo3(signature = (a, tol=None))]
315fn matrix_rank_py(a: &Bound<'_, PyArray2<f64>>, tol: Option<f64>) -> PyResult<usize> {
316    let binding = a.readonly();
317    let data = binding.as_array();
318
319    matrix_rank(&data, tol, None)
320        .map_err(|e| PyRuntimeError::new_err(format!("Matrix rank failed: {}", e)))
321}
322
323/// Moore-Penrose pseudoinverse
324#[pyfunction]
325#[pyo3(signature = (a, rcond=None))]
326fn pinv_py(
327    py: Python,
328    a: &Bound<'_, PyArray2<f64>>,
329    rcond: Option<f64>,
330) -> PyResult<Py<PyArray2<f64>>> {
331    let binding = a.readonly();
332    let data = binding.as_array();
333
334    // Note: scirs2_linalg::pinv only takes the array argument; rcond is handled internally
335    let _ = rcond; // rcond parameter accepted for API compatibility
336    let result =
337        pinv(&data).map_err(|e| PyRuntimeError::new_err(format!("Pseudoinverse failed: {}", e)))?;
338
339    Ok(result.into_pyarray(py).unbind())
340}
341
342// ========================================
343// BATCH / VECTORIZED OPERATIONS
344// ========================================
345
346/// Helper: convert a Vec<Vec<f64>> to a 2-D ndarray Array2.
347fn vec2d_to_array2(rows: &[Vec<f64>]) -> Result<Array2<f64>, String> {
348    if rows.is_empty() {
349        return Err("Matrix has no rows".to_string());
350    }
351    let nrows = rows.len();
352    let ncols = rows[0].len();
353    for (i, row) in rows.iter().enumerate() {
354        if row.len() != ncols {
355            return Err(format!(
356                "Row {} has {} columns, expected {}",
357                i,
358                row.len(),
359                ncols
360            ));
361        }
362    }
363    let flat: Vec<f64> = rows.iter().flat_map(|r| r.iter().cloned()).collect();
364    Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| format!("Shape error: {}", e))
365}
366
367/// Helper: convert Array2 to Vec<Vec<f64>>.
368fn array2_to_vec2d(a: &Array2<f64>) -> Vec<Vec<f64>> {
369    let (nrows, ncols) = a.dim();
370    (0..nrows)
371        .map(|i| (0..ncols).map(|j| a[[i, j]]).collect())
372        .collect()
373}
374
375/// Batch matrix multiplication: compute A_i @ B_i for each pair.
376///
377/// Parameters:
378///     a_list: List of 2D matrices (each represented as Vec<Vec<f64>>)
379///     b_list: List of 2D matrices (same length as a_list)
380///
381/// Returns:
382///     List of result matrices (Vec<Vec<Vec<f64>>>)
383#[pyfunction]
384fn batch_matmul_py(
385    a_list: Vec<Vec<Vec<f64>>>,
386    b_list: Vec<Vec<Vec<f64>>>,
387) -> PyResult<Vec<Vec<Vec<f64>>>> {
388    if a_list.len() != b_list.len() {
389        return Err(PyRuntimeError::new_err(format!(
390            "a_list length {} does not match b_list length {}",
391            a_list.len(),
392            b_list.len()
393        )));
394    }
395    if a_list.is_empty() {
396        return Ok(vec![]);
397    }
398
399    let results: Vec<Result<Vec<Vec<f64>>, String>> = a_list
400        .par_iter()
401        .zip(b_list.par_iter())
402        .map(|(a_rows, b_rows)| {
403            let a = vec2d_to_array2(a_rows)?;
404            let b = vec2d_to_array2(b_rows)?;
405            let (_, a_cols) = a.dim();
406            let (b_rows_n, _) = b.dim();
407            if a_cols != b_rows_n {
408                return Err(format!(
409                    "Incompatible shapes for matmul: inner dims {} != {}",
410                    a_cols, b_rows_n
411                ));
412            }
413            // Use ndarray's dot for matrix multiplication
414            let c = a.dot(&b);
415            Ok(array2_to_vec2d(&c))
416        })
417        .collect();
418
419    results
420        .into_iter()
421        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matmul failed: {}", e))))
422        .collect()
423}
424
425/// Batch SVD: compute SVD for multiple matrices.
426///
427/// Parameters:
428///     matrices: List of 2D matrices
429///
430/// Returns:
431///     List of (U, S, Vt) tuples where U and Vt are 2D matrices, S is 1D
432#[pyfunction]
433fn batch_svd_py(
434    matrices: Vec<Vec<Vec<f64>>>,
435) -> PyResult<Vec<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>)>> {
436    if matrices.is_empty() {
437        return Ok(vec![]);
438    }
439
440    let results: Vec<Result<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>), String>> = matrices
441        .par_iter()
442        .map(|mat_rows| {
443            let mat = vec2d_to_array2(mat_rows)?;
444            let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&mat.view(), false)
445                .map_err(|e| format!("SVD failed: {}", e))?;
446            Ok((array2_to_vec2d(&u), s.to_vec(), array2_to_vec2d(&vt)))
447        })
448        .collect();
449
450    results
451        .into_iter()
452        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_svd failed: {}", e))))
453        .collect()
454}
455
456/// Batch linear solve: solve A_i @ x_i = b_i for each pair.
457///
458/// Parameters:
459///     a_list: List of square coefficient matrices
460///     b_list: List of right-hand-side vectors (one per matrix)
461///
462/// Returns:
463///     List of solution vectors x_i
464#[pyfunction]
465fn batch_solve_py(a_list: Vec<Vec<Vec<f64>>>, b_list: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
466    if a_list.len() != b_list.len() {
467        return Err(PyRuntimeError::new_err(format!(
468            "a_list length {} does not match b_list length {}",
469            a_list.len(),
470            b_list.len()
471        )));
472    }
473    if a_list.is_empty() {
474        return Ok(vec![]);
475    }
476
477    let results: Vec<Result<Vec<f64>, String>> = a_list
478        .par_iter()
479        .zip(b_list.par_iter())
480        .map(|(a_rows, b_vec)| {
481            let a = vec2d_to_array2(a_rows)?;
482            let b = Array1::from_vec(b_vec.clone());
483            let x = scirs2_linalg::solve_f64_lapack(&a.view(), &b.view())
484                .map_err(|e| format!("Solve failed: {}", e))?;
485            Ok(x.to_vec())
486        })
487        .collect();
488
489    results
490        .into_iter()
491        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_solve failed: {}", e))))
492        .collect()
493}
494
495/// Batch matrix norms: compute a matrix norm for each matrix in the list.
496///
497/// Parameters:
498///     matrices: List of 2D matrices
499///     ord: Norm type — "fro" (Frobenius, default), "nuc" (nuclear), "1", "inf"
500///
501/// Returns:
502///     Vec<f64> of norm values, one per input matrix
503#[pyfunction]
504#[pyo3(signature = (matrices, ord=None))]
505fn batch_matrix_norm_py(matrices: Vec<Vec<Vec<f64>>>, ord: Option<&str>) -> PyResult<Vec<f64>> {
506    if matrices.is_empty() {
507        return Ok(vec![]);
508    }
509    let norm_type = ord.unwrap_or("fro");
510    // Validate norm type early
511    match norm_type {
512        "fro" | "nuc" | "1" | "inf" => {}
513        other => {
514            return Err(PyRuntimeError::new_err(format!(
515                "Unknown norm type '{}'. Supported: fro, nuc, 1, inf",
516                other
517            )));
518        }
519    }
520
521    let results: Vec<Result<f64, String>> = matrices
522        .par_iter()
523        .map(|mat_rows| {
524            let mat = vec2d_to_array2(mat_rows)?;
525            scirs2_linalg::matrix_norm(&mat.view(), norm_type, None)
526                .map_err(|e| format!("Matrix norm failed: {}", e))
527        })
528        .collect();
529
530    results
531        .into_iter()
532        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matrix_norm failed: {}", e))))
533        .collect()
534}
535
536/// Python module registration
537pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
538    // Basic operations
539    m.add_function(wrap_pyfunction!(det_py, m)?)?;
540    m.add_function(wrap_pyfunction!(inv_py, m)?)?;
541    m.add_function(wrap_pyfunction!(trace_py, m)?)?;
542
543    // Decompositions
544    m.add_function(wrap_pyfunction!(lu_py, m)?)?;
545    m.add_function(wrap_pyfunction!(qr_py, m)?)?;
546    m.add_function(wrap_pyfunction!(svd_py, m)?)?;
547    m.add_function(wrap_pyfunction!(cholesky_py, m)?)?;
548    m.add_function(wrap_pyfunction!(eig_py, m)?)?;
549    m.add_function(wrap_pyfunction!(eigh_py, m)?)?;
550    m.add_function(wrap_pyfunction!(eigvals_py, m)?)?;
551
552    // Solvers
553    m.add_function(wrap_pyfunction!(solve_py, m)?)?;
554    m.add_function(wrap_pyfunction!(lstsq_py, m)?)?;
555
556    // Norms
557    m.add_function(wrap_pyfunction!(matrix_norm_py, m)?)?;
558    m.add_function(wrap_pyfunction!(vector_norm_py, m)?)?;
559    m.add_function(wrap_pyfunction!(cond_py, m)?)?;
560    m.add_function(wrap_pyfunction!(matrix_rank_py, m)?)?;
561    m.add_function(wrap_pyfunction!(pinv_py, m)?)?;
562
563    // Batch/vectorized APIs
564    m.add_function(wrap_pyfunction!(batch_matmul_py, m)?)?;
565    m.add_function(wrap_pyfunction!(batch_svd_py, m)?)?;
566    m.add_function(wrap_pyfunction!(batch_solve_py, m)?)?;
567    m.add_function(wrap_pyfunction!(batch_matrix_norm_py, m)?)?;
568
569    Ok(())
570}