Skip to main content

torsh_sparse/
scipy_sparse.rs

1//! SciPy sparse matrix interoperability
2//!
//! This module provides conversion between ToRSh sparse tensors and SciPy sparse matrices,
//! enabling seamless integration with the Python scientific computing ecosystem.
5
6#[cfg(feature = "scipy")]
7use numpy::{PyArray1, PyReadonlyArray1};
8#[cfg(feature = "scipy")]
9use pyo3::prelude::*;
10#[cfg(feature = "scipy")]
11use pyo3::types::PyDict;
12#[cfg(feature = "scipy")]
13use pyo3::Bound;
14
15use crate::*;
16use std::collections::HashMap;
17use torsh_core::Result as TorshResult;
18
/// Sparse storage layouts that have a matching class in `scipy.sparse`.
///
/// Each variant names the SciPy matrix class it corresponds to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScipyFormat {
    /// Compressed Sparse Row storage (`csr_matrix`)
    Csr,
    /// Compressed Sparse Column storage (`csc_matrix`)
    Csc,
    /// Coordinate (triplet) storage (`coo_matrix`)
    Coo,
    /// Block Sparse Row storage (`bsr_matrix`)
    Bsr,
    /// Diagonal storage (`dia_matrix`)
    Dia,
}
33
34impl From<SparseFormat> for ScipyFormat {
35    fn from(format: SparseFormat) -> Self {
36        match format {
37            SparseFormat::Coo => ScipyFormat::Coo,
38            SparseFormat::Csr => ScipyFormat::Csr,
39            SparseFormat::Csc => ScipyFormat::Csc,
40            SparseFormat::Bsr => ScipyFormat::Bsr,
41            SparseFormat::Dia => ScipyFormat::Dia,
42            SparseFormat::Ell => ScipyFormat::Csr, // ELL -> CSR fallback
43            SparseFormat::Rle => ScipyFormat::Csr, // RLE -> CSR fallback
44            SparseFormat::Symmetric => ScipyFormat::Csr, // Symmetric -> CSR fallback
45            SparseFormat::Dsr => ScipyFormat::Csr, // DSR -> CSR fallback
46        }
47    }
48}
49
50impl From<ScipyFormat> for SparseFormat {
51    fn from(format: ScipyFormat) -> Self {
52        match format {
53            ScipyFormat::Coo => SparseFormat::Coo,
54            ScipyFormat::Csr => SparseFormat::Csr,
55            ScipyFormat::Csc => SparseFormat::Csc,
56            ScipyFormat::Bsr => SparseFormat::Bsr,
57            ScipyFormat::Dia => SparseFormat::Dia,
58        }
59    }
60}
61
62/// SciPy sparse matrix representation for data exchange
63#[derive(Debug, Clone)]
64pub struct ScipySparseData {
65    /// Matrix format
66    pub format: ScipyFormat,
67    /// Matrix shape (rows, cols)
68    pub shape: (usize, usize),
69    /// Data values
70    pub data: Vec<f64>,
71    /// Row indices (for COO and CSR) or column indices (for CSC)
72    pub indices: Vec<usize>,
73    /// Row pointers (for CSR) or column pointers (for CSC) or coordinate rows (for COO)
74    pub indptr_or_row: Vec<usize>,
75    /// Block size for BSR format
76    pub blocksize: Option<(usize, usize)>,
77    /// Number of diagonals for DIA format
78    pub diagonals: Option<Vec<i32>>,
79}
80
81impl ScipySparseData {
82    /// Create new SciPy sparse data
83    pub fn new(format: ScipyFormat, shape: (usize, usize)) -> Self {
84        Self {
85            format,
86            shape,
87            data: Vec::new(),
88            indices: Vec::new(),
89            indptr_or_row: Vec::new(),
90            blocksize: None,
91            diagonals: None,
92        }
93    }
94
95    /// Create from COO data
96    pub fn from_coo(
97        shape: (usize, usize),
98        row_indices: Vec<usize>,
99        col_indices: Vec<usize>,
100        values: Vec<f64>,
101    ) -> Self {
102        Self {
103            format: ScipyFormat::Coo,
104            shape,
105            data: values,
106            indices: col_indices,
107            indptr_or_row: row_indices,
108            blocksize: None,
109            diagonals: None,
110        }
111    }
112
113    /// Create from CSR data
114    pub fn from_csr(
115        shape: (usize, usize),
116        row_ptr: Vec<usize>,
117        col_indices: Vec<usize>,
118        values: Vec<f64>,
119    ) -> Self {
120        Self {
121            format: ScipyFormat::Csr,
122            shape,
123            data: values,
124            indices: col_indices,
125            indptr_or_row: row_ptr,
126            blocksize: None,
127            diagonals: None,
128        }
129    }
130
131    /// Create from CSC data
132    pub fn from_csc(
133        shape: (usize, usize),
134        col_ptr: Vec<usize>,
135        row_indices: Vec<usize>,
136        values: Vec<f64>,
137    ) -> Self {
138        Self {
139            format: ScipyFormat::Csc,
140            shape,
141            data: values,
142            indices: row_indices,
143            indptr_or_row: col_ptr,
144            blocksize: None,
145            diagonals: None,
146        }
147    }
148}
149
150/// SciPy sparse matrix integration utilities
151pub struct ScipySparseIntegration;
152
153impl ScipySparseIntegration {
154    /// Convert ToRSh sparse tensor to SciPy sparse data
155    pub fn to_scipy_data(sparse: &dyn SparseTensor) -> TorshResult<ScipySparseData> {
156        let shape = sparse.shape();
157        let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
158
159        match sparse.format() {
160            SparseFormat::Coo => {
161                let coo = sparse.to_coo()?;
162                let triplets = coo.triplets();
163
164                let mut row_indices = Vec::new();
165                let mut col_indices = Vec::new();
166                let mut values = Vec::new();
167
168                for (row, col, val) in triplets {
169                    row_indices.push(row);
170                    col_indices.push(col);
171                    values.push(val as f64);
172                }
173
174                Ok(ScipySparseData::from_coo(
175                    (rows, cols),
176                    row_indices,
177                    col_indices,
178                    values,
179                ))
180            }
181            SparseFormat::Csr => {
182                let csr = sparse.to_csr()?;
183                let row_ptr = csr.row_ptr().to_vec();
184                let col_indices = csr.col_indices().to_vec();
185                let values = csr.values().iter().map(|&v| v as f64).collect();
186
187                Ok(ScipySparseData::from_csr(
188                    (rows, cols),
189                    row_ptr,
190                    col_indices,
191                    values,
192                ))
193            }
194            SparseFormat::Csc => {
195                let csc = sparse.to_csc()?;
196                let col_ptr = csc.col_ptr().to_vec();
197                let row_indices = csc.row_indices().to_vec();
198                let values = csc.values().iter().map(|&v| v as f64).collect();
199
200                Ok(ScipySparseData::from_csc(
201                    (rows, cols),
202                    col_ptr,
203                    row_indices,
204                    values,
205                ))
206            }
207            _ => {
208                // Convert other formats to COO first
209                let coo = sparse.to_coo()?;
210                Self::to_scipy_data(&coo)
211            }
212        }
213    }
214
215    /// Convert SciPy sparse data to ToRSh sparse tensor
216    pub fn from_scipy_data(
217        data: &ScipySparseData,
218    ) -> TorshResult<Box<dyn SparseTensor + Send + Sync>> {
219        let shape = Shape::new(vec![data.shape.0, data.shape.1]);
220
221        match data.format {
222            ScipyFormat::Coo => {
223                let mut rows = Vec::new();
224                let mut cols = Vec::new();
225                let mut values = Vec::new();
226
227                for i in 0..data.data.len() {
228                    rows.push(data.indptr_or_row[i]);
229                    cols.push(data.indices[i]);
230                    values.push(data.data[i] as f32);
231                }
232
233                let coo = CooTensor::new(rows, cols, values, shape)?;
234                Ok(Box::new(coo))
235            }
236            ScipyFormat::Csr => {
237                let row_ptr = &data.indptr_or_row;
238                let col_indices = &data.indices;
239                let values: Vec<f32> = data.data.iter().map(|&v| v as f32).collect();
240
241                let csr =
242                    CsrTensor::from_raw_parts(row_ptr.clone(), col_indices.clone(), values, shape)?;
243
244                Ok(Box::new(csr))
245            }
246            ScipyFormat::Csc => {
247                let col_ptr = &data.indptr_or_row;
248                let row_indices = &data.indices;
249                let values: Vec<f32> = data.data.iter().map(|&v| v as f32).collect();
250
251                let csc =
252                    CscTensor::from_raw_parts(col_ptr.clone(), row_indices.clone(), values, shape)?;
253
254                Ok(Box::new(csc))
255            }
256            _ => {
257                // Convert to COO first, then to target format
258                let coo_data = ScipySparseData {
259                    format: ScipyFormat::Coo,
260                    ..data.clone()
261                };
262                let coo = Self::from_scipy_data(&coo_data)?;
263                convert_sparse_format(coo.as_ref(), data.format.into())
264            }
265        }
266    }
267
268    /// Serialize sparse tensor to dictionary format compatible with SciPy
269    pub fn to_dict(sparse: &dyn SparseTensor) -> TorshResult<HashMap<String, Vec<f64>>> {
270        let scipy_data = Self::to_scipy_data(sparse)?;
271
272        let mut dict = HashMap::new();
273        dict.insert("data".to_string(), scipy_data.data);
274        dict.insert(
275            "indices".to_string(),
276            scipy_data.indices.iter().map(|&x| x as f64).collect(),
277        );
278        dict.insert(
279            "indptr".to_string(),
280            scipy_data.indptr_or_row.iter().map(|&x| x as f64).collect(),
281        );
282        dict.insert(
283            "shape".to_string(),
284            vec![scipy_data.shape.0 as f64, scipy_data.shape.1 as f64],
285        );
286
287        Ok(dict)
288    }
289
290    /// Generate Python code to create equivalent SciPy sparse matrix
291    pub fn to_python_code(sparse: &dyn SparseTensor, var_name: &str) -> TorshResult<String> {
292        let scipy_data = Self::to_scipy_data(sparse)?;
293        let format_name = match scipy_data.format {
294            ScipyFormat::Coo => "coo_matrix",
295            ScipyFormat::Csr => "csr_matrix",
296            ScipyFormat::Csc => "csc_matrix",
297            ScipyFormat::Bsr => "bsr_matrix",
298            ScipyFormat::Dia => "dia_matrix",
299        };
300
301        let mut code = String::new();
302        code.push_str("import numpy as np\n");
303        code.push_str("from scipy.sparse import ");
304        code.push_str(format_name);
305        code.push_str("\n\n");
306
307        match scipy_data.format {
308            ScipyFormat::Coo => {
309                code.push_str("# COO format data\n");
310                code.push_str(&format!("row = np.array({:?})\n", scipy_data.indptr_or_row));
311                code.push_str(&format!("col = np.array({:?})\n", scipy_data.indices));
312                code.push_str(&format!("data = np.array({:?})\n", scipy_data.data));
313                code.push_str(&format!("shape = {:?}\n", scipy_data.shape));
314                code.push_str(&format!(
315                    "{var_name} = {format_name}((data, (row, col)), shape=shape)\n"
316                ));
317            }
318            ScipyFormat::Csr | ScipyFormat::Csc => {
319                let ptr_name = "indptr";
320                code.push_str(&format!("# {} format data\n", format_name.to_uppercase()));
321                code.push_str(&format!("data = np.array({:?})\n", scipy_data.data));
322                code.push_str(&format!("indices = np.array({:?})\n", scipy_data.indices));
323                code.push_str(&format!(
324                    "{} = np.array({:?})\n",
325                    ptr_name, scipy_data.indptr_or_row
326                ));
327                code.push_str(&format!("shape = {:?}\n", scipy_data.shape));
328                code.push_str(&format!(
329                    "{var_name} = {format_name}((data, indices, {ptr_name}), shape=shape)\n"
330                ));
331            }
332            _ => {
333                // Fallback to COO for other formats
334                code.push_str(&format!(
335                    "# Note: {format_name} format converted to COO for compatibility\n"
336                ));
337                code.push_str(&format!("row = np.array({:?})\n", scipy_data.indptr_or_row));
338                code.push_str(&format!("col = np.array({:?})\n", scipy_data.indices));
339                code.push_str(&format!("data = np.array({:?})\n", scipy_data.data));
340                code.push_str(&format!("shape = {:?}\n", scipy_data.shape));
341                code.push_str(&format!(
342                    "{var_name} = coo_matrix((data, (row, col)), shape=shape)\n"
343                ));
344            }
345        }
346
347        Ok(code)
348    }
349}
350
/// Python bindings for SciPy sparse integration (when scipy feature is enabled)
#[cfg(feature = "scipy")]
pub mod python_bindings {
    use super::*;

    /// Build a `scipy.sparse` matrix from raw arrays.
    ///
    /// `format` names a class in `scipy.sparse` (e.g. `"csr_matrix"`).
    ///
    /// - Compressed classes are built as
    ///   `format((data, indices, indptr), shape=shape)`. The three arrays go
    ///   into one tuple argument, as SciPy's constructors require.
    /// - Classes whose name starts with `"coo"` use the [`ScipySparseData`]
    ///   COO convention: `indptr` holds row indices and `indices` holds column
    ///   indices. They are built as `format((data, (indptr, indices)), shape=shape)`.
    ///
    /// # Errors
    /// Returns the Python exception if `scipy.sparse` cannot be imported, the
    /// class does not exist, or its constructor rejects the arrays.
    #[pyfunction]
    pub fn torsh_to_scipy(
        py: Python,
        format: &str,
        shape: (usize, usize),
        data: Vec<f64>,
        indices: Vec<usize>,
        indptr: Vec<usize>,
    ) -> PyResult<Py<PyAny>> {
        let scipy = py.import("scipy.sparse")?;

        let data_array = PyArray1::from_vec(py, data);
        let indices_array = PyArray1::from_vec(py, indices);
        let indptr_array = PyArray1::from_vec(py, indptr);

        let kwargs = PyDict::new(py);
        kwargs.set_item("shape", shape)?;

        let matrix_class = scipy.getattr(format)?;
        // The outer 1-tuple is the positional-argument list; its single element
        // is the array tuple SciPy expects as `arg1`.
        let result = if format.starts_with("coo") {
            matrix_class.call(
                ((data_array, (indptr_array, indices_array)),),
                Some(&kwargs),
            )?
        } else {
            matrix_class.call(((data_array, indices_array, indptr_array),), Some(&kwargs))?
        };

        Ok(result.unbind())
    }

    /// Import a SciPy sparse matrix as COO arrays.
    ///
    /// Returns `(format, shape, data, col_indices, row_indices)`:
    /// - `format` is the *source* matrix's SciPy format string (e.g. `"csr"`).
    /// - The arrays always describe the matrix in COO form, obtained via
    ///   `tocoo()`.
    ///
    /// Index arrays of any signed integer width and float data of any width
    /// are accepted. They are converted with numpy `astype(..., casting="safe")`,
    /// so a lossy cast (e.g. complex data) raises instead of silently
    /// truncating.
    ///
    /// # Errors
    /// Returns a Python exception if:
    /// - an attribute or method is missing;
    /// - a dtype cannot be safely cast;
    /// - an array is not contiguous;
    /// - any index is negative.
    #[pyfunction]
    pub fn scipy_to_torsh(
        _py: Python,
        scipy_matrix: &Bound<PyAny>,
    ) -> PyResult<(String, (usize, usize), Vec<f64>, Vec<usize>, Vec<usize>)> {
        let format: String = scipy_matrix.getattr("format")?.extract()?;
        let shape: (usize, usize) = scipy_matrix.getattr("shape")?.extract()?;

        // Convert to COO format for universal handling
        let coo_matrix = scipy_matrix.call_method0("tocoo")?;

        let data_obj = attr_as_dtype(&coo_matrix, "data", "float64")?;
        let row_obj = attr_as_dtype(&coo_matrix, "row", "int64")?;
        let col_obj = attr_as_dtype(&coo_matrix, "col", "int64")?;

        let data: PyReadonlyArray1<f64> = data_obj.extract()?;
        let row: PyReadonlyArray1<i64> = row_obj.extract()?;
        let col: PyReadonlyArray1<i64> = col_obj.extract()?;

        let data_vec = data.as_slice()?.to_vec();
        let row_vec = indices_to_usize(row.as_slice()?, "row")?;
        let col_vec = indices_to_usize(col.as_slice()?, "col")?;

        Ok((format, shape, data_vec, col_vec, row_vec))
    }

    /// Fetch array attribute `attr` of `matrix` cast to `dtype` with numpy's
    /// `casting="safe"` rule (widening only).
    fn attr_as_dtype<'py>(
        matrix: &Bound<'py, PyAny>,
        attr: &str,
        dtype: &str,
    ) -> PyResult<Bound<'py, PyAny>> {
        let kwargs = PyDict::new(matrix.py());
        kwargs.set_item("casting", "safe")?;
        matrix
            .getattr(attr)?
            .call_method("astype", (dtype,), Some(&kwargs))
    }

    /// Convert signed SciPy indices to `usize`, rejecting negative values
    /// instead of letting them wrap to huge indices.
    fn indices_to_usize(indices: &[i64], axis: &str) -> PyResult<Vec<usize>> {
        indices
            .iter()
            .map(|&i| {
                usize::try_from(i).map_err(|_| {
                    pyo3::exceptions::PyValueError::new_err(format!(
                        "negative {axis} index {i} in SciPy matrix"
                    ))
                })
            })
            .collect()
    }
}
426
/// Convert a sparse tensor to `ScipySparseData`.
///
/// - `to_scipy!(tensor)` converts `tensor` in its current layout.
/// - `to_scipy!(tensor, format)` first converts `tensor` to `format` with
///   `convert_sparse_format`, then converts the result. `tensor` is evaluated once.
///
/// Both forms evaluate to a `TorshResult<ScipySparseData>`. Errors are returned
/// as `Err` rather than propagated with `?`, so the macro works in any function.
/// The expansion names `ScipySparseIntegration` and `convert_sparse_format`
/// unqualified, so both must be in scope at the call site.
#[macro_export]
macro_rules! to_scipy {
    ($sparse:expr) => {
        ScipySparseIntegration::to_scipy_data($sparse)
    };
    ($sparse:expr, $format:expr) => {
        convert_sparse_format($sparse, $format)
            .and_then(|converted| ScipySparseIntegration::to_scipy_data(converted.as_ref()))
    };
}
439
/// Build a ToRSh sparse tensor from `ScipySparseData`.
///
/// Shorthand for `ScipySparseIntegration::from_scipy_data(data)`. That type
/// must be in scope where the macro is used.
#[macro_export]
macro_rules! from_scipy {
    ($data:expr) => {{
        ScipySparseIntegration::from_scipy_data($data)
    }};
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coo::CooTensor;
    use torsh_core::{DType, Shape};

    #[test]
    fn test_scipy_data_conversion() {
        let shape = Shape::new(vec![3, 3]);
        let mut coo = CooTensor::empty(shape.clone(), DType::F32).unwrap();

        // Create a simple sparse matrix
        coo.insert(0, 0, 1.0).unwrap();
        coo.insert(1, 1, 2.0).unwrap();
        coo.insert(2, 2, 3.0).unwrap();
        coo.insert(0, 2, 4.0).unwrap();

        // Convert to SciPy data
        let scipy_data = ScipySparseIntegration::to_scipy_data(&coo).unwrap();

        assert_eq!(scipy_data.format, ScipyFormat::Coo);
        assert_eq!(scipy_data.shape, (3, 3));
        assert_eq!(scipy_data.data.len(), 4);

        // Convert back to ToRSh
        let restored = ScipySparseIntegration::from_scipy_data(&scipy_data).unwrap();
        assert_eq!(restored.nnz(), 4);
        assert_eq!(restored.shape(), &shape);
    }

    #[test]
    fn test_python_code_generation() {
        let shape = Shape::new(vec![2, 2]);
        let mut coo = CooTensor::empty(shape, DType::F32).unwrap();

        coo.insert(0, 0, 1.0).unwrap();
        coo.insert(1, 1, 2.0).unwrap();

        let code = ScipySparseIntegration::to_python_code(&coo, "matrix").unwrap();

        assert!(code.contains("import numpy as np"));
        assert!(code.contains("from scipy.sparse import"));
        assert!(code.contains("matrix ="));
    }

    #[test]
    fn test_dict_conversion() {
        let shape = Shape::new(vec![2, 2]);
        let mut coo = CooTensor::empty(shape, DType::F32).unwrap();

        coo.insert(0, 0, 1.0).unwrap();
        coo.insert(1, 1, 2.0).unwrap();

        let dict = ScipySparseIntegration::to_dict(&coo).unwrap();

        assert!(dict.contains_key("data"));
        assert!(dict.contains_key("indices"));
        assert!(dict.contains_key("indptr"));
        assert!(dict.contains_key("shape"));

        assert_eq!(dict["shape"], vec![2.0, 2.0]);
        assert_eq!(dict["data"].len(), 2);
    }

    #[test]
    fn test_format_conversion() {
        assert_eq!(ScipyFormat::from(SparseFormat::Coo), ScipyFormat::Coo);
        assert_eq!(ScipyFormat::from(SparseFormat::Csr), ScipyFormat::Csr);
        assert_eq!(ScipyFormat::from(SparseFormat::Csc), ScipyFormat::Csc);
        assert_eq!(ScipyFormat::from(SparseFormat::Ell), ScipyFormat::Csr);

        assert_eq!(SparseFormat::from(ScipyFormat::Coo), SparseFormat::Coo);
        assert_eq!(SparseFormat::from(ScipyFormat::Csr), SparseFormat::Csr);
        assert_eq!(SparseFormat::from(ScipyFormat::Csc), SparseFormat::Csc);
    }

    #[test]
    fn test_macro_usage() {
        let shape = Shape::new(vec![2, 2]);
        let mut coo = CooTensor::empty(shape, DType::F32).unwrap();

        coo.insert(0, 0, 1.0).unwrap();
        coo.insert(1, 1, 2.0).unwrap();

        let scipy_data = to_scipy!(&coo).unwrap();
        assert_eq!(scipy_data.data.len(), 2);

        let restored = from_scipy!(&scipy_data).unwrap();
        assert_eq!(restored.nnz(), 2);
    }

    /// CSR arrays should survive a SciPy -> ToRSh -> SciPy round trip unchanged.
    #[test]
    fn test_csr_round_trip() {
        // [[1, 0, 4],
        //  [0, 2, 0]]
        let original =
            ScipySparseData::from_csr((2, 3), vec![0, 2, 3], vec![0, 2, 1], vec![1.0, 4.0, 2.0]);

        let tensor = ScipySparseIntegration::from_scipy_data(&original).unwrap();
        assert_eq!(tensor.nnz(), 3);
        assert_eq!(tensor.format(), SparseFormat::Csr);

        let back = ScipySparseIntegration::to_scipy_data(tensor.as_ref()).unwrap();
        assert_eq!(back.format, ScipyFormat::Csr);
        assert_eq!(back.shape, (2, 3));
        assert_eq!(back.indptr_or_row, original.indptr_or_row);
        assert_eq!(back.indices, original.indices);
        assert_eq!(back.data, original.data);
    }

    /// CSC arrays should survive a SciPy -> ToRSh -> SciPy round trip unchanged.
    #[test]
    fn test_csc_round_trip() {
        // [[1, 0],
        //  [3, 2]]
        let original =
            ScipySparseData::from_csc((2, 2), vec![0, 2, 3], vec![0, 1, 1], vec![1.0, 3.0, 2.0]);

        let tensor = ScipySparseIntegration::from_scipy_data(&original).unwrap();
        assert_eq!(tensor.nnz(), 3);
        assert_eq!(tensor.format(), SparseFormat::Csc);

        let back = ScipySparseIntegration::to_scipy_data(tensor.as_ref()).unwrap();
        assert_eq!(back.format, ScipyFormat::Csc);
        assert_eq!(back.shape, (2, 2));
        assert_eq!(back.indptr_or_row, original.indptr_or_row);
        assert_eq!(back.indices, original.indices);
        assert_eq!(back.data, original.data);
    }
}