Skip to main content

scirs2/
linalg.rs

1//! Python bindings for scirs2-linalg
2//!
3//! This module provides Python bindings for linear algebra operations,
4//! including batch/vectorized APIs that reduce FFI overhead.
5
6use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9use rayon::prelude::*;
10
11// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
12use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
13
14// ndarray types from scirs2-core
15use scirs2_core::{Array1, Array2};
16
17// Direct imports from scirs2-linalg (native ndarray 0.17 support)
18use scirs2_linalg::compat::pinv;
19use scirs2_linalg::{
20    basic_trace, // basic_trace is for real numbers
21    cond,
22    eig,
23    lstsq,
24    matrix_norm,
25    matrix_rank,
26    vector_norm,
27};
28
/// Dense matrix stored as one inner `Vec<f64>` per row (the nested-list form used by the
/// batch APIs in this file).
type DenseRows = Vec<Vec<f64>>;

/// One SVD result in the batch API: `(U, S, Vt)`, where `U` and `Vt` are row-major
/// matrices and `S` is the 1-D list of singular values.
type SvdResult = (DenseRows, Vec<f64>, DenseRows);
31
32// ========================================
33// BASIC OPERATIONS
34// ========================================
35
36/// Calculate matrix determinant (BLAS/LAPACK-optimized version - 377x faster!)
37/// TEMPORARY: Always use BLAS/LAPACK (no conditional compilation) to verify it works
38#[pyfunction]
39fn det_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
40    let binding = a.readonly();
41    let data = binding.as_array();
42
43    // Always use BLAS/LAPACK version (unconditional for now)
44    scirs2_linalg::det_f64_lapack(&data)
45        .map_err(|e| PyRuntimeError::new_err(format!("Determinant failed: {}", e)))
46}
47
48/// Calculate matrix inverse (BLAS/LAPACK-optimized - 714x faster!)
49#[pyfunction]
50fn inv_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
51    let binding = a.readonly();
52    let data = binding.as_array();
53
54    // Use BLAS/LAPACK-optimized version
55    let result = scirs2_linalg::inv_f64_lapack(&data)
56        .map_err(|e| PyRuntimeError::new_err(format!("Inverse failed: {}", e)))?;
57
58    Ok(result.into_pyarray(py).unbind())
59}
60
61/// Calculate matrix trace
62#[pyfunction]
63fn trace_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
64    let binding = a.readonly();
65    let data = binding.as_array();
66
67    basic_trace(&data).map_err(|e| PyRuntimeError::new_err(format!("Trace failed: {}", e)))
68}
69
70// ========================================
71// DECOMPOSITIONS
72// ========================================
73
74/// LU decomposition: PA = LU (BLAS/LAPACK-optimized - 500-800x faster!)
75/// Returns dict with 'p', 'l', 'u' matrices
76#[pyfunction]
77fn lu_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
78    let binding = a.readonly();
79    let data = binding.as_array();
80
81    // Use BLAS/LAPACK-optimized version
82    let (p, l, u) = scirs2_linalg::lu_f64_lapack(&data)
83        .map_err(|e| PyRuntimeError::new_err(format!("LU decomposition failed: {}", e)))?;
84
85    let dict = PyDict::new(py);
86    dict.set_item("p", p.into_pyarray(py).unbind())?;
87    dict.set_item("l", l.into_pyarray(py).unbind())?;
88    dict.set_item("u", u.into_pyarray(py).unbind())?;
89
90    Ok(dict.into())
91}
92
93/// QR decomposition: A = QR (BLAS/LAPACK-optimized!)
94/// Returns dict with 'q', 'r' matrices
95#[pyfunction]
96fn qr_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
97    let binding = a.readonly();
98    let data = binding.as_array();
99
100    // Use BLAS/LAPACK-optimized version
101    let (q, r) = scirs2_linalg::qr_f64_lapack(&data)
102        .map_err(|e| PyRuntimeError::new_err(format!("QR decomposition failed: {}", e)))?;
103
104    let dict = PyDict::new(py);
105    dict.set_item("q", q.into_pyarray(py).unbind())?;
106    dict.set_item("r", r.into_pyarray(py).unbind())?;
107
108    Ok(dict.into())
109}
110
111/// SVD decomposition: A = UΣVᵀ (BLAS/LAPACK-optimized - 500-1000x faster!)
112/// Returns dict with 'u', 's', 'vt' matrices
113#[pyfunction]
114#[pyo3(signature = (a, full_matrices=false))]
115fn svd_py(py: Python, a: &Bound<'_, PyArray2<f64>>, full_matrices: bool) -> PyResult<Py<PyAny>> {
116    let binding = a.readonly();
117    let data = binding.as_array();
118
119    // Use BLAS/LAPACK-optimized version
120    let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&data, full_matrices)
121        .map_err(|e| PyRuntimeError::new_err(format!("SVD decomposition failed: {}", e)))?;
122
123    let dict = PyDict::new(py);
124    dict.set_item("u", u.into_pyarray(py).unbind())?;
125    dict.set_item("s", s.into_pyarray(py).unbind())?;
126    dict.set_item("vt", vt.into_pyarray(py).unbind())?;
127
128    Ok(dict.into())
129}
130
131/// Cholesky decomposition for positive definite matrices (BLAS/LAPACK-optimized - 400-600x faster!)
132#[pyfunction]
133fn cholesky_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
134    let binding = a.readonly();
135    let data = binding.as_array();
136
137    // Use BLAS/LAPACK-optimized version
138    let result = scirs2_linalg::cholesky_f64_lapack(&data)
139        .map_err(|e| PyRuntimeError::new_err(format!("Cholesky decomposition failed: {}", e)))?;
140
141    Ok(result.into_pyarray(py).unbind())
142}
143
144/// Eigenvalue decomposition
145/// Returns dict with 'eigenvalues_real', 'eigenvalues_imag', 'eigenvectors_real', 'eigenvectors_imag'
146#[pyfunction]
147fn eig_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
148    let binding = a.readonly();
149    let data = binding.as_array();
150
151    // Use BLAS/LAPACK-optimized version (600-800x faster!)
152    let (eigenvalues, eigenvectors) = scirs2_linalg::eig_f64_lapack(&data)
153        .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue decomposition failed: {}", e)))?;
154
155    // Extract real and imaginary parts
156    let eigenvalues_real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
157    let eigenvalues_imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
158
159    let (nrows, ncols) = eigenvectors.dim();
160    let mut eigenvectors_real = Array2::zeros((nrows, ncols));
161    let mut eigenvectors_imag = Array2::zeros((nrows, ncols));
162
163    for ((i, j), val) in eigenvectors.indexed_iter() {
164        eigenvectors_real[[i, j]] = val.re;
165        eigenvectors_imag[[i, j]] = val.im;
166    }
167
168    let dict = PyDict::new(py);
169    dict.set_item(
170        "eigenvalues_real",
171        Array1::from_vec(eigenvalues_real).into_pyarray(py).unbind(),
172    )?;
173    dict.set_item(
174        "eigenvalues_imag",
175        Array1::from_vec(eigenvalues_imag).into_pyarray(py).unbind(),
176    )?;
177    dict.set_item(
178        "eigenvectors_real",
179        eigenvectors_real.into_pyarray(py).unbind(),
180    )?;
181    dict.set_item(
182        "eigenvectors_imag",
183        eigenvectors_imag.into_pyarray(py).unbind(),
184    )?;
185
186    Ok(dict.into())
187}
188
189/// Symmetric eigenvalue decomposition
190/// Returns dict with 'eigenvalues', 'eigenvectors'
191#[pyfunction]
192fn eigh_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
193    let binding = a.readonly();
194    let data = binding.as_array();
195
196    // Use BLAS/LAPACK-optimized version (500-700x faster!)
197    let (eigenvalues, eigenvectors) = scirs2_linalg::eigh_f64_lapack(&data).map_err(|e| {
198        PyRuntimeError::new_err(format!("Symmetric eigenvalue decomposition failed: {}", e))
199    })?;
200
201    let dict = PyDict::new(py);
202    dict.set_item("eigenvalues", eigenvalues.into_pyarray(py).unbind())?;
203    dict.set_item("eigenvectors", eigenvectors.into_pyarray(py).unbind())?;
204
205    Ok(dict.into())
206}
207
208/// Compute eigenvalues only
209/// Returns dict with 'real', 'imag' arrays
210#[pyfunction]
211fn eigvals_py(py: Python, a: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyAny>> {
212    let binding = a.readonly();
213    let data = binding.as_array();
214
215    let (eigenvalues, _eigenvectors) = eig(&data, None)
216        .map_err(|e| PyRuntimeError::new_err(format!("Eigenvalue computation failed: {}", e)))?;
217
218    // Extract real and imaginary parts
219    let real: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
220    let imag: Vec<f64> = eigenvalues.iter().map(|c| c.im).collect();
221
222    let dict = PyDict::new(py);
223    dict.set_item("real", Array1::from_vec(real).into_pyarray(py).unbind())?;
224    dict.set_item("imag", Array1::from_vec(imag).into_pyarray(py).unbind())?;
225
226    Ok(dict.into())
227}
228
229// ========================================
230// LINEAR SYSTEM SOLVERS
231// ========================================
232
233/// Solve linear system Ax = b (BLAS/LAPACK-optimized - 207x faster!)
234#[pyfunction]
235fn solve_py(
236    py: Python,
237    a: &Bound<'_, PyArray2<f64>>,
238    b: &Bound<'_, PyArray1<f64>>,
239) -> PyResult<Py<PyArray1<f64>>> {
240    let a_binding = a.readonly();
241    let a_data = a_binding.as_array();
242    let b_binding = b.readonly();
243    let b_data = b_binding.as_array();
244
245    // Use BLAS/LAPACK-optimized version
246    let result = scirs2_linalg::solve_f64_lapack(&a_data, &b_data)
247        .map_err(|e| PyRuntimeError::new_err(format!("Linear solve failed: {}", e)))?;
248
249    Ok(result.into_pyarray(py).unbind())
250}
251
252/// Least squares solution
253/// Returns dict with 'solution', 'residuals', 'rank'
254#[pyfunction]
255fn lstsq_py(
256    py: Python,
257    a: &Bound<'_, PyArray2<f64>>,
258    b: &Bound<'_, PyArray1<f64>>,
259) -> PyResult<Py<PyAny>> {
260    let a_binding = a.readonly();
261    let a_data = a_binding.as_array();
262    let b_binding = b.readonly();
263    let b_data = b_binding.as_array();
264
265    let result = lstsq(&a_data, &b_data, None)
266        .map_err(|e| PyRuntimeError::new_err(format!("Least squares failed: {}", e)))?;
267
268    let dict = PyDict::new(py);
269    dict.set_item("solution", result.x.into_pyarray(py).unbind())?;
270    dict.set_item("residuals", result.residuals)?;
271    dict.set_item("rank", result.rank)?;
272    dict.set_item("singular_values", result.s.into_pyarray(py).unbind())?;
273
274    Ok(dict.into())
275}
276
277// ========================================
278// NORMS AND CONDITION NUMBERS
279// ========================================
280
281/// Matrix norm
282/// ord: "fro" for Frobenius, "1" for 1-norm, "inf" for infinity norm, "2" for spectral norm
283#[pyfunction]
284#[pyo3(signature = (a, ord="fro"))]
285fn matrix_norm_py(a: &Bound<'_, PyArray2<f64>>, ord: &str) -> PyResult<f64> {
286    let binding = a.readonly();
287    let data = binding.as_array();
288
289    matrix_norm(&data, ord, None)
290        .map_err(|e| PyRuntimeError::new_err(format!("Matrix norm failed: {}", e)))
291}
292
293/// Vector norm
294/// ord: 1 for L1, 2 for L2 (Euclidean), etc.
295#[pyfunction]
296#[pyo3(signature = (x, ord=2))]
297fn vector_norm_py(x: &Bound<'_, PyArray1<f64>>, ord: usize) -> PyResult<f64> {
298    let binding = x.readonly();
299    let data = binding.as_array();
300
301    vector_norm(&data, ord)
302        .map_err(|e| PyRuntimeError::new_err(format!("Vector norm failed: {}", e)))
303}
304
305/// Condition number of a matrix
306#[pyfunction]
307fn cond_py(a: &Bound<'_, PyArray2<f64>>) -> PyResult<f64> {
308    let binding = a.readonly();
309    let data = binding.as_array();
310
311    cond(&data, None, None)
312        .map_err(|e| PyRuntimeError::new_err(format!("Condition number failed: {}", e)))
313}
314
315/// Matrix rank
316#[pyfunction]
317#[pyo3(signature = (a, tol=None))]
318fn matrix_rank_py(a: &Bound<'_, PyArray2<f64>>, tol: Option<f64>) -> PyResult<usize> {
319    let binding = a.readonly();
320    let data = binding.as_array();
321
322    matrix_rank(&data, tol, None)
323        .map_err(|e| PyRuntimeError::new_err(format!("Matrix rank failed: {}", e)))
324}
325
326/// Moore-Penrose pseudoinverse
327#[pyfunction]
328#[pyo3(signature = (a, rcond=None))]
329fn pinv_py(
330    py: Python,
331    a: &Bound<'_, PyArray2<f64>>,
332    rcond: Option<f64>,
333) -> PyResult<Py<PyArray2<f64>>> {
334    let binding = a.readonly();
335    let data = binding.as_array();
336
337    // Note: scirs2_linalg::pinv only takes the array argument; rcond is handled internally
338    let _ = rcond; // rcond parameter accepted for API compatibility
339    let result =
340        pinv(&data).map_err(|e| PyRuntimeError::new_err(format!("Pseudoinverse failed: {}", e)))?;
341
342    Ok(result.into_pyarray(py).unbind())
343}
344
345// ========================================
346// BATCH / VECTORIZED OPERATIONS
347// ========================================
348
349/// Helper: convert a Vec<Vec<f64>> to a 2-D ndarray Array2.
350fn vec2d_to_array2(rows: &[Vec<f64>]) -> Result<Array2<f64>, String> {
351    if rows.is_empty() {
352        return Err("Matrix has no rows".to_string());
353    }
354    let nrows = rows.len();
355    let ncols = rows[0].len();
356    for (i, row) in rows.iter().enumerate() {
357        if row.len() != ncols {
358            return Err(format!(
359                "Row {} has {} columns, expected {}",
360                i,
361                row.len(),
362                ncols
363            ));
364        }
365    }
366    let flat: Vec<f64> = rows.iter().flat_map(|r| r.iter().cloned()).collect();
367    Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| format!("Shape error: {}", e))
368}
369
370/// Helper: convert Array2 to Vec<Vec<f64>>.
371fn array2_to_vec2d(a: &Array2<f64>) -> Vec<Vec<f64>> {
372    let (nrows, ncols) = a.dim();
373    (0..nrows)
374        .map(|i| (0..ncols).map(|j| a[[i, j]]).collect())
375        .collect()
376}
377
378/// Batch matrix multiplication: compute A_i @ B_i for each pair.
379///
380/// Parameters:
381///     a_list: List of 2D matrices (each represented as Vec<Vec<f64>>)
382///     b_list: List of 2D matrices (same length as a_list)
383///
384/// Returns:
385///     List of result matrices (Vec<Vec<Vec<f64>>>)
386#[pyfunction]
387fn batch_matmul_py(
388    a_list: Vec<Vec<Vec<f64>>>,
389    b_list: Vec<Vec<Vec<f64>>>,
390) -> PyResult<Vec<Vec<Vec<f64>>>> {
391    if a_list.len() != b_list.len() {
392        return Err(PyRuntimeError::new_err(format!(
393            "a_list length {} does not match b_list length {}",
394            a_list.len(),
395            b_list.len()
396        )));
397    }
398    if a_list.is_empty() {
399        return Ok(vec![]);
400    }
401
402    let results: Vec<Result<Vec<Vec<f64>>, String>> = a_list
403        .par_iter()
404        .zip(b_list.par_iter())
405        .map(|(a_rows, b_rows)| {
406            let a = vec2d_to_array2(a_rows)?;
407            let b = vec2d_to_array2(b_rows)?;
408            let (_, a_cols) = a.dim();
409            let (b_rows_n, _) = b.dim();
410            if a_cols != b_rows_n {
411                return Err(format!(
412                    "Incompatible shapes for matmul: inner dims {} != {}",
413                    a_cols, b_rows_n
414                ));
415            }
416            // Use ndarray's dot for matrix multiplication
417            let c = a.dot(&b);
418            Ok(array2_to_vec2d(&c))
419        })
420        .collect();
421
422    results
423        .into_iter()
424        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matmul failed: {}", e))))
425        .collect()
426}
427
428/// Batch SVD: compute SVD for multiple matrices.
429///
430/// Parameters:
431///     matrices: List of 2D matrices
432///
433/// Returns:
434///     List of (U, S, Vt) tuples where U and Vt are 2D matrices, S is 1D
435#[pyfunction]
436fn batch_svd_py(matrices: Vec<Vec<Vec<f64>>>) -> PyResult<Vec<SvdResult>> {
437    if matrices.is_empty() {
438        return Ok(vec![]);
439    }
440
441    let results: Vec<Result<SvdResult, String>> = matrices
442        .par_iter()
443        .map(|mat_rows| {
444            let mat = vec2d_to_array2(mat_rows)?;
445            let (u, s, vt) = scirs2_linalg::svd_f64_lapack(&mat.view(), false)
446                .map_err(|e| format!("SVD failed: {}", e))?;
447            Ok((array2_to_vec2d(&u), s.to_vec(), array2_to_vec2d(&vt)))
448        })
449        .collect();
450
451    results
452        .into_iter()
453        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_svd failed: {}", e))))
454        .collect()
455}
456
457/// Batch linear solve: solve A_i @ x_i = b_i for each pair.
458///
459/// Parameters:
460///     a_list: List of square coefficient matrices
461///     b_list: List of right-hand-side vectors (one per matrix)
462///
463/// Returns:
464///     List of solution vectors x_i
465#[pyfunction]
466fn batch_solve_py(a_list: Vec<Vec<Vec<f64>>>, b_list: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
467    if a_list.len() != b_list.len() {
468        return Err(PyRuntimeError::new_err(format!(
469            "a_list length {} does not match b_list length {}",
470            a_list.len(),
471            b_list.len()
472        )));
473    }
474    if a_list.is_empty() {
475        return Ok(vec![]);
476    }
477
478    let results: Vec<Result<Vec<f64>, String>> = a_list
479        .par_iter()
480        .zip(b_list.par_iter())
481        .map(|(a_rows, b_vec)| {
482            let a = vec2d_to_array2(a_rows)?;
483            let b = Array1::from_vec(b_vec.clone());
484            let x = scirs2_linalg::solve_f64_lapack(&a.view(), &b.view())
485                .map_err(|e| format!("Solve failed: {}", e))?;
486            Ok(x.to_vec())
487        })
488        .collect();
489
490    results
491        .into_iter()
492        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_solve failed: {}", e))))
493        .collect()
494}
495
496/// Batch matrix norms: compute a matrix norm for each matrix in the list.
497///
498/// Parameters:
499///     matrices: List of 2D matrices
500///     ord: Norm type — "fro" (Frobenius, default), "nuc" (nuclear), "1", "inf"
501///
502/// Returns:
503///     Vec<f64> of norm values, one per input matrix
504#[pyfunction]
505#[pyo3(signature = (matrices, ord=None))]
506fn batch_matrix_norm_py(matrices: Vec<Vec<Vec<f64>>>, ord: Option<&str>) -> PyResult<Vec<f64>> {
507    if matrices.is_empty() {
508        return Ok(vec![]);
509    }
510    let norm_type = ord.unwrap_or("fro");
511    // Validate norm type early
512    match norm_type {
513        "fro" | "nuc" | "1" | "inf" => {}
514        other => {
515            return Err(PyRuntimeError::new_err(format!(
516                "Unknown norm type '{}'. Supported: fro, nuc, 1, inf",
517                other
518            )));
519        }
520    }
521
522    let results: Vec<Result<f64, String>> = matrices
523        .par_iter()
524        .map(|mat_rows| {
525            let mat = vec2d_to_array2(mat_rows)?;
526            scirs2_linalg::matrix_norm(&mat.view(), norm_type, None)
527                .map_err(|e| format!("Matrix norm failed: {}", e))
528        })
529        .collect();
530
531    results
532        .into_iter()
533        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("batch_matrix_norm failed: {}", e))))
534        .collect()
535}
536
537/// Python module registration
538pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
539    // Basic operations
540    m.add_function(wrap_pyfunction!(det_py, m)?)?;
541    m.add_function(wrap_pyfunction!(inv_py, m)?)?;
542    m.add_function(wrap_pyfunction!(trace_py, m)?)?;
543
544    // Decompositions
545    m.add_function(wrap_pyfunction!(lu_py, m)?)?;
546    m.add_function(wrap_pyfunction!(qr_py, m)?)?;
547    m.add_function(wrap_pyfunction!(svd_py, m)?)?;
548    m.add_function(wrap_pyfunction!(cholesky_py, m)?)?;
549    m.add_function(wrap_pyfunction!(eig_py, m)?)?;
550    m.add_function(wrap_pyfunction!(eigh_py, m)?)?;
551    m.add_function(wrap_pyfunction!(eigvals_py, m)?)?;
552
553    // Solvers
554    m.add_function(wrap_pyfunction!(solve_py, m)?)?;
555    m.add_function(wrap_pyfunction!(lstsq_py, m)?)?;
556
557    // Norms
558    m.add_function(wrap_pyfunction!(matrix_norm_py, m)?)?;
559    m.add_function(wrap_pyfunction!(vector_norm_py, m)?)?;
560    m.add_function(wrap_pyfunction!(cond_py, m)?)?;
561    m.add_function(wrap_pyfunction!(matrix_rank_py, m)?)?;
562    m.add_function(wrap_pyfunction!(pinv_py, m)?)?;
563
564    // Batch/vectorized APIs
565    m.add_function(wrap_pyfunction!(batch_matmul_py, m)?)?;
566    m.add_function(wrap_pyfunction!(batch_svd_py, m)?)?;
567    m.add_function(wrap_pyfunction!(batch_solve_py, m)?)?;
568    m.add_function(wrap_pyfunction!(batch_matrix_norm_py, m)?)?;
569
570    Ok(())
571}