Skip to main content

scirs2/
sparse.rs

1//! Python bindings for scirs2-sparse
2//!
3//! This module provides Python bindings for sparse matrix operations,
4//! including CSR, CSC, COO formats and basic sparse operations.
5
6use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9
10// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13// ndarray types from scirs2-core
14use scirs2_core::ndarray::Array1;
15
16// Direct imports from scirs2-sparse
17use scirs2_sparse::{diag_matrix, eye, CooArray, CscArray, CsrArray, SparseArray};
18
19// ========================================
20// SPARSE ARRAY CREATION
21// ========================================
22
23/// Create a CSR sparse array from triplets (row, col, data)
24#[pyfunction]
25#[pyo3(signature = (rows, cols, data, shape, sum_duplicates=false))]
26#[allow(clippy::too_many_arguments)]
27fn csr_array_from_triplets(
28    py: Python,
29    rows: Vec<usize>,
30    cols: Vec<usize>,
31    data: Vec<f64>,
32    shape: (usize, usize),
33    sum_duplicates: bool,
34) -> PyResult<Py<PyAny>> {
35    let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, sum_duplicates)
36        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create CSR array: {}", e)))?;
37
38    // Return as dict with internal representation
39    let dict = PyDict::new(py);
40    dict.set_item("format", "csr")?;
41    dict.set_item("shape", shape)?;
42    dict.set_item("nnz", csr.nnz())?;
43
44    // Get CSR components
45    dict.set_item("indptr", csr.get_indptr().to_vec())?;
46    dict.set_item("indices", csr.get_indices().to_vec())?;
47    dict.set_item("data", csr.get_data().to_vec())?;
48
49    Ok(dict.into())
50}
51
52/// Create a COO sparse array from triplets
53#[pyfunction]
54#[pyo3(signature = (rows, cols, data, shape, sum_duplicates=false))]
55#[allow(clippy::too_many_arguments)]
56fn coo_array_from_triplets(
57    py: Python,
58    rows: Vec<usize>,
59    cols: Vec<usize>,
60    data: Vec<f64>,
61    shape: (usize, usize),
62    sum_duplicates: bool,
63) -> PyResult<Py<PyAny>> {
64    let coo = CooArray::from_triplets(&rows, &cols, &data, shape, sum_duplicates)
65        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create COO array: {}", e)))?;
66
67    let dict = PyDict::new(py);
68    dict.set_item("format", "coo")?;
69    dict.set_item("shape", shape)?;
70    dict.set_item("nnz", coo.nnz())?;
71
72    // Get COO components
73    dict.set_item("row", coo.get_rows().to_vec())?;
74    dict.set_item("col", coo.get_cols().to_vec())?;
75    dict.set_item("data", coo.get_data().to_vec())?;
76
77    Ok(dict.into())
78}
79
80/// Create a CSC sparse array from triplets
81#[pyfunction]
82#[pyo3(signature = (rows, cols, data, shape, sum_duplicates=false))]
83#[allow(clippy::too_many_arguments)]
84fn csc_array_from_triplets(
85    py: Python,
86    rows: Vec<usize>,
87    cols: Vec<usize>,
88    data: Vec<f64>,
89    shape: (usize, usize),
90    sum_duplicates: bool,
91) -> PyResult<Py<PyAny>> {
92    let csc = CscArray::from_triplets(&rows, &cols, &data, shape, sum_duplicates)
93        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create CSC array: {}", e)))?;
94
95    let dict = PyDict::new(py);
96    dict.set_item("format", "csc")?;
97    dict.set_item("shape", shape)?;
98    dict.set_item("nnz", csc.nnz())?;
99
100    // Get CSC components
101    dict.set_item("indptr", csc.get_indptr().to_vec())?;
102    dict.set_item("indices", csc.get_indices().to_vec())?;
103    dict.set_item("data", csc.get_data().to_vec())?;
104
105    Ok(dict.into())
106}
107
108/// Create sparse identity matrix
109#[pyfunction]
110fn sparse_eye_py(py: Python, n: usize) -> PyResult<Py<PyAny>> {
111    let csr = eye::<f64>(n)
112        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create identity matrix: {}", e)))?;
113
114    let dict = PyDict::new(py);
115    dict.set_item("format", "csr")?;
116    dict.set_item("shape", (n, n))?;
117    dict.set_item("nnz", csr.data.len())?;
118
119    // CsrMatrix has public fields
120    dict.set_item("indptr", csr.indptr.clone())?;
121    dict.set_item("indices", csr.indices.clone())?;
122    dict.set_item("data", csr.data.clone())?;
123
124    Ok(dict.into())
125}
126
127/// Create sparse diagonal matrix from vector
128#[pyfunction]
129fn sparse_diag_py(py: Python, diag: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyAny>> {
130    let binding = diag.readonly();
131    let diag_data = binding.as_array();
132
133    // diag_matrix takes &[F] and Option<usize>, returns CsrMatrix
134    let diag_slice: Vec<f64> = diag_data.iter().copied().collect();
135    let csr = diag_matrix::<f64>(&diag_slice, None)
136        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create diagonal matrix: {}", e)))?;
137
138    let n = diag_data.len();
139    let dict = PyDict::new(py);
140    dict.set_item("format", "csr")?;
141    dict.set_item("shape", (n, n))?;
142    dict.set_item("nnz", csr.data.len())?;
143
144    // CsrMatrix has public fields
145    dict.set_item("indptr", csr.indptr.clone())?;
146    dict.set_item("indices", csr.indices.clone())?;
147    dict.set_item("data", csr.data.clone())?;
148
149    Ok(dict.into())
150}
151
152// ========================================
153// SPARSE ARRAY OPERATIONS
154// ========================================
155
156/// Convert sparse array to dense
157#[pyfunction]
158fn sparse_to_dense_py(
159    py: Python,
160    indptr: Vec<usize>,
161    indices: Vec<usize>,
162    data: Vec<f64>,
163    shape: (usize, usize),
164) -> PyResult<Py<PyArray2<f64>>> {
165    // Reconstruct CSR array - CsrArray::new expects (data, indices, indptr, shape)
166    let csr = CsrArray::new(
167        Array1::from_vec(data),
168        Array1::from_vec(indices),
169        Array1::from_vec(indptr),
170        shape,
171    )
172    .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data: {}", e)))?;
173
174    let dense = csr.to_array();
175    Ok(dense.into_pyarray(py).unbind())
176}
177
178/// Sparse matrix-vector multiplication
179#[pyfunction]
180fn sparse_matvec_py(
181    py: Python,
182    indptr: Vec<usize>,
183    indices: Vec<usize>,
184    data: Vec<f64>,
185    shape: (usize, usize),
186    x: &Bound<'_, PyArray1<f64>>,
187) -> PyResult<Py<PyArray1<f64>>> {
188    let csr = CsrArray::new(
189        Array1::from_vec(data),
190        Array1::from_vec(indices),
191        Array1::from_vec(indptr),
192        shape,
193    )
194    .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data: {}", e)))?;
195
196    let x_binding = x.readonly();
197    let x_data = x_binding.as_array();
198
199    let result = csr.dot_vector(&x_data).map_err(|e| {
200        PyRuntimeError::new_err(format!("Matrix-vector multiplication failed: {}", e))
201    })?;
202
203    Ok(result.into_pyarray(py).unbind())
204}
205
206/// Sparse matrix-matrix multiplication (returns CSR)
207#[pyfunction]
208#[allow(clippy::too_many_arguments)]
209fn sparse_matmul_py(
210    py: Python,
211    a_indptr: Vec<usize>,
212    a_indices: Vec<usize>,
213    a_data: Vec<f64>,
214    a_shape: (usize, usize),
215    b_indptr: Vec<usize>,
216    b_indices: Vec<usize>,
217    b_data: Vec<f64>,
218    b_shape: (usize, usize),
219) -> PyResult<Py<PyAny>> {
220    let a = CsrArray::new(
221        Array1::from_vec(a_data),
222        Array1::from_vec(a_indices),
223        Array1::from_vec(a_indptr),
224        a_shape,
225    )
226    .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data for A: {}", e)))?;
227    let b = CsrArray::new(
228        Array1::from_vec(b_data),
229        Array1::from_vec(b_indices),
230        Array1::from_vec(b_indptr),
231        b_shape,
232    )
233    .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data for B: {}", e)))?;
234
235    let result = a
236        .dot(&b)
237        .map_err(|e| PyRuntimeError::new_err(format!("Matrix multiplication failed: {}", e)))?;
238
239    // Use find() to get row, col, data from the boxed SparseArray
240    let (rows, cols, data) = result.find();
241    let shape = result.shape();
242    let nnz = result.nnz();
243
244    // Convert result to CSR format for return
245    let result_csr = CsrArray::from_triplets(
246        rows.as_slice().unwrap_or(&[]),
247        cols.as_slice().unwrap_or(&[]),
248        data.as_slice().unwrap_or(&[]),
249        shape,
250        false,
251    )
252    .map_err(|e| PyRuntimeError::new_err(format!("Failed to create result CSR: {}", e)))?;
253
254    let dict = PyDict::new(py);
255    dict.set_item("format", "csr")?;
256    dict.set_item("shape", shape)?;
257    dict.set_item("nnz", nnz)?;
258
259    dict.set_item("indptr", result_csr.get_indptr().to_vec())?;
260    dict.set_item("indices", result_csr.get_indices().to_vec())?;
261    dict.set_item("data", result_csr.get_data().to_vec())?;
262
263    Ok(dict.into())
264}
265
266/// Transpose sparse matrix
267#[pyfunction]
268fn sparse_transpose_py(
269    py: Python,
270    indptr: Vec<usize>,
271    indices: Vec<usize>,
272    data: Vec<f64>,
273    shape: (usize, usize),
274) -> PyResult<Py<PyAny>> {
275    let csr = CsrArray::new(
276        Array1::from_vec(data),
277        Array1::from_vec(indices),
278        Array1::from_vec(indptr),
279        shape,
280    )
281    .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data: {}", e)))?;
282
283    let transposed = csr
284        .transpose()
285        .map_err(|e| PyRuntimeError::new_err(format!("Transpose failed: {}", e)))?;
286
287    // Use find() to get row, col, data from the boxed SparseArray
288    let (rows, cols, data) = transposed.find();
289    let t_shape = transposed.shape();
290    let nnz = transposed.nnz();
291
292    // Convert result to CSR format for return
293    let result_csr = CsrArray::from_triplets(
294        rows.as_slice().unwrap_or(&[]),
295        cols.as_slice().unwrap_or(&[]),
296        data.as_slice().unwrap_or(&[]),
297        t_shape,
298        false,
299    )
300    .map_err(|e| PyRuntimeError::new_err(format!("Failed to create transposed CSR: {}", e)))?;
301
302    let dict = PyDict::new(py);
303    dict.set_item("format", "csr")?;
304    dict.set_item("shape", t_shape)?;
305    dict.set_item("nnz", nnz)?;
306
307    dict.set_item("indptr", result_csr.get_indptr().to_vec())?;
308    dict.set_item("indices", result_csr.get_indices().to_vec())?;
309    dict.set_item("data", result_csr.get_data().to_vec())?;
310
311    Ok(dict.into())
312}
313
314/// Python module registration
315pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
316    // Array creation
317    m.add_function(wrap_pyfunction!(csr_array_from_triplets, m)?)?;
318    m.add_function(wrap_pyfunction!(coo_array_from_triplets, m)?)?;
319    m.add_function(wrap_pyfunction!(csc_array_from_triplets, m)?)?;
320    m.add_function(wrap_pyfunction!(sparse_eye_py, m)?)?;
321    m.add_function(wrap_pyfunction!(sparse_diag_py, m)?)?;
322
323    // Array operations
324    m.add_function(wrap_pyfunction!(sparse_to_dense_py, m)?)?;
325    m.add_function(wrap_pyfunction!(sparse_matvec_py, m)?)?;
326    m.add_function(wrap_pyfunction!(sparse_matmul_py, m)?)?;
327    m.add_function(wrap_pyfunction!(sparse_transpose_py, m)?)?;
328
329    Ok(())
330}