1use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9
10use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13use scirs2_core::ndarray::Array1;
15
16use scirs2_sparse::{diag_matrix, eye, CooArray, CscArray, CsrArray, SparseArray};
18
19#[pyfunction]
25#[pyo3(signature = (rows, cols, data, shape, sum_duplicates=false))]
26#[allow(clippy::too_many_arguments)]
27fn csr_array_from_triplets(
28 py: Python,
29 rows: Vec<usize>,
30 cols: Vec<usize>,
31 data: Vec<f64>,
32 shape: (usize, usize),
33 sum_duplicates: bool,
34) -> PyResult<Py<PyAny>> {
35 let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, sum_duplicates)
36 .map_err(|e| PyRuntimeError::new_err(format!("Failed to create CSR array: {}", e)))?;
37
38 let dict = PyDict::new(py);
40 dict.set_item("format", "csr")?;
41 dict.set_item("shape", shape)?;
42 dict.set_item("nnz", csr.nnz())?;
43
44 dict.set_item("indptr", csr.get_indptr().to_vec())?;
46 dict.set_item("indices", csr.get_indices().to_vec())?;
47 dict.set_item("data", csr.get_data().to_vec())?;
48
49 Ok(dict.into())
50}
51
52#[pyfunction]
54#[pyo3(signature = (rows, cols, data, shape, sum_duplicates=false))]
55#[allow(clippy::too_many_arguments)]
56fn coo_array_from_triplets(
57 py: Python,
58 rows: Vec<usize>,
59 cols: Vec<usize>,
60 data: Vec<f64>,
61 shape: (usize, usize),
62 sum_duplicates: bool,
63) -> PyResult<Py<PyAny>> {
64 let coo = CooArray::from_triplets(&rows, &cols, &data, shape, sum_duplicates)
65 .map_err(|e| PyRuntimeError::new_err(format!("Failed to create COO array: {}", e)))?;
66
67 let dict = PyDict::new(py);
68 dict.set_item("format", "coo")?;
69 dict.set_item("shape", shape)?;
70 dict.set_item("nnz", coo.nnz())?;
71
72 dict.set_item("row", coo.get_rows().to_vec())?;
74 dict.set_item("col", coo.get_cols().to_vec())?;
75 dict.set_item("data", coo.get_data().to_vec())?;
76
77 Ok(dict.into())
78}
79
80#[pyfunction]
82#[pyo3(signature = (rows, cols, data, shape, sum_duplicates=false))]
83#[allow(clippy::too_many_arguments)]
84fn csc_array_from_triplets(
85 py: Python,
86 rows: Vec<usize>,
87 cols: Vec<usize>,
88 data: Vec<f64>,
89 shape: (usize, usize),
90 sum_duplicates: bool,
91) -> PyResult<Py<PyAny>> {
92 let csc = CscArray::from_triplets(&rows, &cols, &data, shape, sum_duplicates)
93 .map_err(|e| PyRuntimeError::new_err(format!("Failed to create CSC array: {}", e)))?;
94
95 let dict = PyDict::new(py);
96 dict.set_item("format", "csc")?;
97 dict.set_item("shape", shape)?;
98 dict.set_item("nnz", csc.nnz())?;
99
100 dict.set_item("indptr", csc.get_indptr().to_vec())?;
102 dict.set_item("indices", csc.get_indices().to_vec())?;
103 dict.set_item("data", csc.get_data().to_vec())?;
104
105 Ok(dict.into())
106}
107
108#[pyfunction]
110fn sparse_eye_py(py: Python, n: usize) -> PyResult<Py<PyAny>> {
111 let csr = eye::<f64>(n)
112 .map_err(|e| PyRuntimeError::new_err(format!("Failed to create identity matrix: {}", e)))?;
113
114 let dict = PyDict::new(py);
115 dict.set_item("format", "csr")?;
116 dict.set_item("shape", (n, n))?;
117 dict.set_item("nnz", csr.data.len())?;
118
119 dict.set_item("indptr", csr.indptr.clone())?;
121 dict.set_item("indices", csr.indices.clone())?;
122 dict.set_item("data", csr.data.clone())?;
123
124 Ok(dict.into())
125}
126
127#[pyfunction]
129fn sparse_diag_py(py: Python, diag: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyAny>> {
130 let binding = diag.readonly();
131 let diag_data = binding.as_array();
132
133 let diag_slice: Vec<f64> = diag_data.iter().copied().collect();
135 let csr = diag_matrix::<f64>(&diag_slice, None)
136 .map_err(|e| PyRuntimeError::new_err(format!("Failed to create diagonal matrix: {}", e)))?;
137
138 let n = diag_data.len();
139 let dict = PyDict::new(py);
140 dict.set_item("format", "csr")?;
141 dict.set_item("shape", (n, n))?;
142 dict.set_item("nnz", csr.data.len())?;
143
144 dict.set_item("indptr", csr.indptr.clone())?;
146 dict.set_item("indices", csr.indices.clone())?;
147 dict.set_item("data", csr.data.clone())?;
148
149 Ok(dict.into())
150}
151
152#[pyfunction]
158fn sparse_to_dense_py(
159 py: Python,
160 indptr: Vec<usize>,
161 indices: Vec<usize>,
162 data: Vec<f64>,
163 shape: (usize, usize),
164) -> PyResult<Py<PyArray2<f64>>> {
165 let csr = CsrArray::new(
167 Array1::from_vec(data),
168 Array1::from_vec(indices),
169 Array1::from_vec(indptr),
170 shape,
171 )
172 .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data: {}", e)))?;
173
174 let dense = csr.to_array();
175 Ok(dense.into_pyarray(py).unbind())
176}
177
178#[pyfunction]
180fn sparse_matvec_py(
181 py: Python,
182 indptr: Vec<usize>,
183 indices: Vec<usize>,
184 data: Vec<f64>,
185 shape: (usize, usize),
186 x: &Bound<'_, PyArray1<f64>>,
187) -> PyResult<Py<PyArray1<f64>>> {
188 let csr = CsrArray::new(
189 Array1::from_vec(data),
190 Array1::from_vec(indices),
191 Array1::from_vec(indptr),
192 shape,
193 )
194 .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data: {}", e)))?;
195
196 let x_binding = x.readonly();
197 let x_data = x_binding.as_array();
198
199 let result = csr.dot_vector(&x_data).map_err(|e| {
200 PyRuntimeError::new_err(format!("Matrix-vector multiplication failed: {}", e))
201 })?;
202
203 Ok(result.into_pyarray(py).unbind())
204}
205
206#[pyfunction]
208#[allow(clippy::too_many_arguments)]
209fn sparse_matmul_py(
210 py: Python,
211 a_indptr: Vec<usize>,
212 a_indices: Vec<usize>,
213 a_data: Vec<f64>,
214 a_shape: (usize, usize),
215 b_indptr: Vec<usize>,
216 b_indices: Vec<usize>,
217 b_data: Vec<f64>,
218 b_shape: (usize, usize),
219) -> PyResult<Py<PyAny>> {
220 let a = CsrArray::new(
221 Array1::from_vec(a_data),
222 Array1::from_vec(a_indices),
223 Array1::from_vec(a_indptr),
224 a_shape,
225 )
226 .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data for A: {}", e)))?;
227 let b = CsrArray::new(
228 Array1::from_vec(b_data),
229 Array1::from_vec(b_indices),
230 Array1::from_vec(b_indptr),
231 b_shape,
232 )
233 .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data for B: {}", e)))?;
234
235 let result = a
236 .dot(&b)
237 .map_err(|e| PyRuntimeError::new_err(format!("Matrix multiplication failed: {}", e)))?;
238
239 let (rows, cols, data) = result.find();
241 let shape = result.shape();
242 let nnz = result.nnz();
243
244 let result_csr = CsrArray::from_triplets(
246 rows.as_slice().unwrap_or(&[]),
247 cols.as_slice().unwrap_or(&[]),
248 data.as_slice().unwrap_or(&[]),
249 shape,
250 false,
251 )
252 .map_err(|e| PyRuntimeError::new_err(format!("Failed to create result CSR: {}", e)))?;
253
254 let dict = PyDict::new(py);
255 dict.set_item("format", "csr")?;
256 dict.set_item("shape", shape)?;
257 dict.set_item("nnz", nnz)?;
258
259 dict.set_item("indptr", result_csr.get_indptr().to_vec())?;
260 dict.set_item("indices", result_csr.get_indices().to_vec())?;
261 dict.set_item("data", result_csr.get_data().to_vec())?;
262
263 Ok(dict.into())
264}
265
266#[pyfunction]
268fn sparse_transpose_py(
269 py: Python,
270 indptr: Vec<usize>,
271 indices: Vec<usize>,
272 data: Vec<f64>,
273 shape: (usize, usize),
274) -> PyResult<Py<PyAny>> {
275 let csr = CsrArray::new(
276 Array1::from_vec(data),
277 Array1::from_vec(indices),
278 Array1::from_vec(indptr),
279 shape,
280 )
281 .map_err(|e| PyRuntimeError::new_err(format!("Invalid CSR data: {}", e)))?;
282
283 let transposed = csr
284 .transpose()
285 .map_err(|e| PyRuntimeError::new_err(format!("Transpose failed: {}", e)))?;
286
287 let (rows, cols, data) = transposed.find();
289 let t_shape = transposed.shape();
290 let nnz = transposed.nnz();
291
292 let result_csr = CsrArray::from_triplets(
294 rows.as_slice().unwrap_or(&[]),
295 cols.as_slice().unwrap_or(&[]),
296 data.as_slice().unwrap_or(&[]),
297 t_shape,
298 false,
299 )
300 .map_err(|e| PyRuntimeError::new_err(format!("Failed to create transposed CSR: {}", e)))?;
301
302 let dict = PyDict::new(py);
303 dict.set_item("format", "csr")?;
304 dict.set_item("shape", t_shape)?;
305 dict.set_item("nnz", nnz)?;
306
307 dict.set_item("indptr", result_csr.get_indptr().to_vec())?;
308 dict.set_item("indices", result_csr.get_indices().to_vec())?;
309 dict.set_item("data", result_csr.get_data().to_vec())?;
310
311 Ok(dict.into())
312}
313
314pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
316 m.add_function(wrap_pyfunction!(csr_array_from_triplets, m)?)?;
318 m.add_function(wrap_pyfunction!(coo_array_from_triplets, m)?)?;
319 m.add_function(wrap_pyfunction!(csc_array_from_triplets, m)?)?;
320 m.add_function(wrap_pyfunction!(sparse_eye_py, m)?)?;
321 m.add_function(wrap_pyfunction!(sparse_diag_py, m)?)?;
322
323 m.add_function(wrap_pyfunction!(sparse_to_dense_py, m)?)?;
325 m.add_function(wrap_pyfunction!(sparse_matvec_py, m)?)?;
326 m.add_function(wrap_pyfunction!(sparse_matmul_py, m)?)?;
327 m.add_function(wrap_pyfunction!(sparse_transpose_py, m)?)?;
328
329 Ok(())
330}