1#[cfg(feature = "scipy")]
7use numpy::{PyArray1, PyReadonlyArray1};
8#[cfg(feature = "scipy")]
9use pyo3::prelude::*;
10#[cfg(feature = "scipy")]
11use pyo3::types::PyDict;
12#[cfg(feature = "scipy")]
13use pyo3::Bound;
14
15use crate::*;
16use std::collections::HashMap;
17use torsh_core::Result as TorshResult;
18
/// Sparse storage layouts that correspond to `scipy.sparse` matrix classes.
///
/// Each variant pairs with the SciPy class of the same layout
/// (`coo_matrix`, `csr_matrix`, `csc_matrix`, `bsr_matrix`, `dia_matrix`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScipyFormat {
    /// Compressed Sparse Row.
    Csr,
    /// Compressed Sparse Column.
    Csc,
    /// Coordinate (row, column, value) triplets.
    Coo,
    /// Block Sparse Row.
    Bsr,
    /// Diagonal storage.
    Dia,
}
33
34impl From<SparseFormat> for ScipyFormat {
35 fn from(format: SparseFormat) -> Self {
36 match format {
37 SparseFormat::Coo => ScipyFormat::Coo,
38 SparseFormat::Csr => ScipyFormat::Csr,
39 SparseFormat::Csc => ScipyFormat::Csc,
40 SparseFormat::Bsr => ScipyFormat::Bsr,
41 SparseFormat::Dia => ScipyFormat::Dia,
42 SparseFormat::Ell => ScipyFormat::Csr, SparseFormat::Rle => ScipyFormat::Csr, SparseFormat::Symmetric => ScipyFormat::Csr, SparseFormat::Dsr => ScipyFormat::Csr, }
47 }
48}
49
50impl From<ScipyFormat> for SparseFormat {
51 fn from(format: ScipyFormat) -> Self {
52 match format {
53 ScipyFormat::Coo => SparseFormat::Coo,
54 ScipyFormat::Csr => SparseFormat::Csr,
55 ScipyFormat::Csc => SparseFormat::Csc,
56 ScipyFormat::Bsr => SparseFormat::Bsr,
57 ScipyFormat::Dia => SparseFormat::Dia,
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
64pub struct ScipySparseData {
65 pub format: ScipyFormat,
67 pub shape: (usize, usize),
69 pub data: Vec<f64>,
71 pub indices: Vec<usize>,
73 pub indptr_or_row: Vec<usize>,
75 pub blocksize: Option<(usize, usize)>,
77 pub diagonals: Option<Vec<i32>>,
79}
80
81impl ScipySparseData {
82 pub fn new(format: ScipyFormat, shape: (usize, usize)) -> Self {
84 Self {
85 format,
86 shape,
87 data: Vec::new(),
88 indices: Vec::new(),
89 indptr_or_row: Vec::new(),
90 blocksize: None,
91 diagonals: None,
92 }
93 }
94
95 pub fn from_coo(
97 shape: (usize, usize),
98 row_indices: Vec<usize>,
99 col_indices: Vec<usize>,
100 values: Vec<f64>,
101 ) -> Self {
102 Self {
103 format: ScipyFormat::Coo,
104 shape,
105 data: values,
106 indices: col_indices,
107 indptr_or_row: row_indices,
108 blocksize: None,
109 diagonals: None,
110 }
111 }
112
113 pub fn from_csr(
115 shape: (usize, usize),
116 row_ptr: Vec<usize>,
117 col_indices: Vec<usize>,
118 values: Vec<f64>,
119 ) -> Self {
120 Self {
121 format: ScipyFormat::Csr,
122 shape,
123 data: values,
124 indices: col_indices,
125 indptr_or_row: row_ptr,
126 blocksize: None,
127 diagonals: None,
128 }
129 }
130
131 pub fn from_csc(
133 shape: (usize, usize),
134 col_ptr: Vec<usize>,
135 row_indices: Vec<usize>,
136 values: Vec<f64>,
137 ) -> Self {
138 Self {
139 format: ScipyFormat::Csc,
140 shape,
141 data: values,
142 indices: row_indices,
143 indptr_or_row: col_ptr,
144 blocksize: None,
145 diagonals: None,
146 }
147 }
148}
149
150pub struct ScipySparseIntegration;
152
153impl ScipySparseIntegration {
154 pub fn to_scipy_data(sparse: &dyn SparseTensor) -> TorshResult<ScipySparseData> {
156 let shape = sparse.shape();
157 let (rows, cols) = (shape.dims()[0], shape.dims()[1]);
158
159 match sparse.format() {
160 SparseFormat::Coo => {
161 let coo = sparse.to_coo()?;
162 let triplets = coo.triplets();
163
164 let mut row_indices = Vec::new();
165 let mut col_indices = Vec::new();
166 let mut values = Vec::new();
167
168 for (row, col, val) in triplets {
169 row_indices.push(row);
170 col_indices.push(col);
171 values.push(val as f64);
172 }
173
174 Ok(ScipySparseData::from_coo(
175 (rows, cols),
176 row_indices,
177 col_indices,
178 values,
179 ))
180 }
181 SparseFormat::Csr => {
182 let csr = sparse.to_csr()?;
183 let row_ptr = csr.row_ptr().to_vec();
184 let col_indices = csr.col_indices().to_vec();
185 let values = csr.values().iter().map(|&v| v as f64).collect();
186
187 Ok(ScipySparseData::from_csr(
188 (rows, cols),
189 row_ptr,
190 col_indices,
191 values,
192 ))
193 }
194 SparseFormat::Csc => {
195 let csc = sparse.to_csc()?;
196 let col_ptr = csc.col_ptr().to_vec();
197 let row_indices = csc.row_indices().to_vec();
198 let values = csc.values().iter().map(|&v| v as f64).collect();
199
200 Ok(ScipySparseData::from_csc(
201 (rows, cols),
202 col_ptr,
203 row_indices,
204 values,
205 ))
206 }
207 _ => {
208 let coo = sparse.to_coo()?;
210 Self::to_scipy_data(&coo)
211 }
212 }
213 }
214
215 pub fn from_scipy_data(
217 data: &ScipySparseData,
218 ) -> TorshResult<Box<dyn SparseTensor + Send + Sync>> {
219 let shape = Shape::new(vec![data.shape.0, data.shape.1]);
220
221 match data.format {
222 ScipyFormat::Coo => {
223 let mut rows = Vec::new();
224 let mut cols = Vec::new();
225 let mut values = Vec::new();
226
227 for i in 0..data.data.len() {
228 rows.push(data.indptr_or_row[i]);
229 cols.push(data.indices[i]);
230 values.push(data.data[i] as f32);
231 }
232
233 let coo = CooTensor::new(rows, cols, values, shape)?;
234 Ok(Box::new(coo))
235 }
236 ScipyFormat::Csr => {
237 let row_ptr = &data.indptr_or_row;
238 let col_indices = &data.indices;
239 let values: Vec<f32> = data.data.iter().map(|&v| v as f32).collect();
240
241 let csr =
242 CsrTensor::from_raw_parts(row_ptr.clone(), col_indices.clone(), values, shape)?;
243
244 Ok(Box::new(csr))
245 }
246 ScipyFormat::Csc => {
247 let col_ptr = &data.indptr_or_row;
248 let row_indices = &data.indices;
249 let values: Vec<f32> = data.data.iter().map(|&v| v as f32).collect();
250
251 let csc =
252 CscTensor::from_raw_parts(col_ptr.clone(), row_indices.clone(), values, shape)?;
253
254 Ok(Box::new(csc))
255 }
256 _ => {
257 let coo_data = ScipySparseData {
259 format: ScipyFormat::Coo,
260 ..data.clone()
261 };
262 let coo = Self::from_scipy_data(&coo_data)?;
263 convert_sparse_format(coo.as_ref(), data.format.into())
264 }
265 }
266 }
267
268 pub fn to_dict(sparse: &dyn SparseTensor) -> TorshResult<HashMap<String, Vec<f64>>> {
270 let scipy_data = Self::to_scipy_data(sparse)?;
271
272 let mut dict = HashMap::new();
273 dict.insert("data".to_string(), scipy_data.data);
274 dict.insert(
275 "indices".to_string(),
276 scipy_data.indices.iter().map(|&x| x as f64).collect(),
277 );
278 dict.insert(
279 "indptr".to_string(),
280 scipy_data.indptr_or_row.iter().map(|&x| x as f64).collect(),
281 );
282 dict.insert(
283 "shape".to_string(),
284 vec![scipy_data.shape.0 as f64, scipy_data.shape.1 as f64],
285 );
286
287 Ok(dict)
288 }
289
290 pub fn to_python_code(sparse: &dyn SparseTensor, var_name: &str) -> TorshResult<String> {
292 let scipy_data = Self::to_scipy_data(sparse)?;
293 let format_name = match scipy_data.format {
294 ScipyFormat::Coo => "coo_matrix",
295 ScipyFormat::Csr => "csr_matrix",
296 ScipyFormat::Csc => "csc_matrix",
297 ScipyFormat::Bsr => "bsr_matrix",
298 ScipyFormat::Dia => "dia_matrix",
299 };
300
301 let mut code = String::new();
302 code.push_str("import numpy as np\n");
303 code.push_str("from scipy.sparse import ");
304 code.push_str(format_name);
305 code.push_str("\n\n");
306
307 match scipy_data.format {
308 ScipyFormat::Coo => {
309 code.push_str("# COO format data\n");
310 code.push_str(&format!("row = np.array({:?})\n", scipy_data.indptr_or_row));
311 code.push_str(&format!("col = np.array({:?})\n", scipy_data.indices));
312 code.push_str(&format!("data = np.array({:?})\n", scipy_data.data));
313 code.push_str(&format!("shape = {:?}\n", scipy_data.shape));
314 code.push_str(&format!(
315 "{var_name} = {format_name}((data, (row, col)), shape=shape)\n"
316 ));
317 }
318 ScipyFormat::Csr | ScipyFormat::Csc => {
319 let ptr_name = "indptr";
320 code.push_str(&format!("# {} format data\n", format_name.to_uppercase()));
321 code.push_str(&format!("data = np.array({:?})\n", scipy_data.data));
322 code.push_str(&format!("indices = np.array({:?})\n", scipy_data.indices));
323 code.push_str(&format!(
324 "{} = np.array({:?})\n",
325 ptr_name, scipy_data.indptr_or_row
326 ));
327 code.push_str(&format!("shape = {:?}\n", scipy_data.shape));
328 code.push_str(&format!(
329 "{var_name} = {format_name}((data, indices, {ptr_name}), shape=shape)\n"
330 ));
331 }
332 _ => {
333 code.push_str(&format!(
335 "# Note: {format_name} format converted to COO for compatibility\n"
336 ));
337 code.push_str(&format!("row = np.array({:?})\n", scipy_data.indptr_or_row));
338 code.push_str(&format!("col = np.array({:?})\n", scipy_data.indices));
339 code.push_str(&format!("data = np.array({:?})\n", scipy_data.data));
340 code.push_str(&format!("shape = {:?}\n", scipy_data.shape));
341 code.push_str(&format!(
342 "{var_name} = coo_matrix((data, (row, col)), shape=shape)\n"
343 ));
344 }
345 }
346
347 Ok(code)
348 }
349}
350
#[cfg(feature = "scipy")]
pub mod python_bindings {
    use super::*;
    use pyo3::exceptions::PyValueError;

    /// Builds a `scipy.sparse` matrix from raw arrays.
    ///
    /// `format` is the name of a class in `scipy.sparse` (e.g. `"csr_matrix"`,
    /// `"csc_matrix"`, `"coo_matrix"`). For compressed formats the arrays are
    /// passed as SciPy's `(data, indices, indptr)` triple. For class names
    /// starting with `"coo"` the arrays follow the [`ScipySparseData`] COO
    /// layout: `indices` holds column indices and `indptr` holds row indices.
    /// They are passed as `(data, (row, col))`.
    ///
    /// # Errors
    /// Returns the Python exception if `scipy.sparse` cannot be imported, the
    /// class does not exist, or its constructor rejects the arrays.
    #[pyfunction]
    pub fn torsh_to_scipy(
        py: Python,
        format: &str,
        shape: (usize, usize),
        data: Vec<f64>,
        indices: Vec<usize>,
        indptr: Vec<usize>,
    ) -> PyResult<Py<PyAny>> {
        let scipy = py.import("scipy.sparse")?;

        let data_array = PyArray1::from_vec(py, data);
        let indices_array = PyArray1::from_vec(py, indices);
        let indptr_array = PyArray1::from_vec(py, indptr);

        let kwargs = PyDict::new(py);
        kwargs.set_item("shape", shape)?;

        let matrix_class = scipy.getattr(format)?;
        // SciPy constructors take the arrays as ONE tuple argument, e.g.
        // `csr_matrix((data, indices, indptr), shape=...)`. Passing them as
        // three positional arguments would bind `indices` to `shape` and fail.
        let result = if format.starts_with("coo") {
            matrix_class.call(
                ((data_array, (indptr_array, indices_array)),),
                Some(&kwargs),
            )?
        } else {
            matrix_class.call(((data_array, indices_array, indptr_array),), Some(&kwargs))?
        };

        Ok(result.unbind())
    }

    /// Extracts a SciPy sparse matrix as COO triplets.
    ///
    /// Returns `(format, shape, data, col, row)`. `format` is the *original*
    /// matrix's `format` attribute (e.g. `"csr"`), while the arrays always
    /// describe the matrix after `tocoo()`. The order of the last two entries
    /// mirrors the [`ScipySparseData`] COO layout: `indices` = columns,
    /// `indptr_or_row` = rows.
    ///
    /// Index arrays are accepted in any integer dtype and values in any dtype
    /// NumPy can cast to `float64`.
    ///
    /// # Errors
    /// Returns the Python exception if the object lacks the expected
    /// attributes or a cast fails, and `ValueError` if an index is negative.
    #[pyfunction]
    pub fn scipy_to_torsh(
        _py: Python,
        scipy_matrix: &Bound<PyAny>,
    ) -> PyResult<(String, (usize, usize), Vec<f64>, Vec<usize>, Vec<usize>)> {
        let format: String = scipy_matrix.getattr("format")?.extract()?;
        let shape: (usize, usize) = scipy_matrix.getattr("shape")?.extract()?;

        let coo_matrix = scipy_matrix.call_method0("tocoo")?;

        // Normalise dtypes on the Python side. SciPy chooses int32 or int64
        // index arrays depending on matrix size, and values may be float32, so
        // a fixed `i32` / `f64` extraction rejects many valid matrices.
        // `astype` copies, so the resulting 1-D arrays are contiguous.
        let data_obj = coo_matrix
            .getattr("data")?
            .call_method1("astype", ("float64",))?;
        let row_obj = coo_matrix
            .getattr("row")?
            .call_method1("astype", ("int64",))?;
        let col_obj = coo_matrix
            .getattr("col")?
            .call_method1("astype", ("int64",))?;

        let data: PyReadonlyArray1<f64> = data_obj.extract()?;
        let row: PyReadonlyArray1<i64> = row_obj.extract()?;
        let col: PyReadonlyArray1<i64> = col_obj.extract()?;

        let data_vec = data.as_slice()?.to_vec();
        let row_vec = to_index_vec(row.as_slice()?, "row")?;
        let col_vec = to_index_vec(col.as_slice()?, "col")?;

        Ok((format, shape, data_vec, col_vec, row_vec))
    }

    /// Converts signed index values to `usize`, raising `ValueError` on a
    /// negative index instead of wrapping it to a huge value.
    fn to_index_vec(values: &[i64], axis: &str) -> PyResult<Vec<usize>> {
        values
            .iter()
            .map(|&x| {
                usize::try_from(x).map_err(|_| {
                    PyValueError::new_err(format!("negative {axis} index {x} in sparse matrix"))
                })
            })
            .collect()
    }
}
426
/// Converts a sparse tensor to `ScipySparseData`.
///
/// * `to_scipy!(sparse)` expands to
///   `ScipySparseIntegration::to_scipy_data(sparse)`.
/// * `to_scipy!(sparse, format)` first converts `sparse` with
///   `convert_sparse_format(sparse, format)`, then exports the converted
///   tensor.
///
/// Both forms evaluate to the `Result` itself and never return early from
/// the calling function. `sparse` is evaluated exactly once. The names are
/// resolved at the call site, so `ScipySparseIntegration` and
/// `convert_sparse_format` must be in scope there.
#[macro_export]
macro_rules! to_scipy {
    ($sparse:expr) => {
        ScipySparseIntegration::to_scipy_data($sparse)
    };
    ($sparse:expr, $format:expr) => {
        convert_sparse_format($sparse, $format)
            .and_then(|converted| ScipySparseIntegration::to_scipy_data(converted.as_ref()))
    };
}
439
/// Rebuilds a boxed sparse tensor from `ScipySparseData`.
///
/// Expands to `ScipySparseIntegration::from_scipy_data(data)`, so it yields
/// that function's `Result`. `ScipySparseIntegration` must be in scope at the
/// call site.
#[macro_export]
macro_rules! from_scipy {
    ($data:expr) => (ScipySparseIntegration::from_scipy_data($data));
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coo::CooTensor;
    use torsh_core::{DType, Shape};

    /// A COO tensor survives export to SciPy data and re-import unchanged
    /// in format, shape and non-zero count.
    #[test]
    fn test_scipy_data_conversion() {
        let shape = Shape::new(vec![3, 3]);
        let mut coo = CooTensor::empty(shape.clone(), DType::F32).unwrap();
        for (row, col, value) in [(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0), (0, 2, 4.0)] {
            coo.insert(row, col, value).unwrap();
        }

        let exported = ScipySparseIntegration::to_scipy_data(&coo).unwrap();
        assert_eq!(exported.format, ScipyFormat::Coo);
        assert_eq!(exported.shape, (3, 3));
        assert_eq!(exported.data.len(), 4);

        let restored = ScipySparseIntegration::from_scipy_data(&exported).unwrap();
        assert_eq!(restored.nnz(), 4);
        assert_eq!(restored.shape(), &shape);
    }

    /// The generated script imports its dependencies and binds the variable.
    #[test]
    fn test_python_code_generation() {
        let mut coo = CooTensor::empty(Shape::new(vec![2, 2]), DType::F32).unwrap();
        for (row, col, value) in [(0, 0, 1.0), (1, 1, 2.0)] {
            coo.insert(row, col, value).unwrap();
        }

        let code = ScipySparseIntegration::to_python_code(&coo, "matrix").unwrap();
        for expected in ["import numpy as np", "from scipy.sparse import", "matrix ="] {
            assert!(code.contains(expected), "generated code lacks {expected:?}:\n{code}");
        }
    }

    /// The dictionary export exposes every expected key with matching contents.
    #[test]
    fn test_dict_conversion() {
        let mut coo = CooTensor::empty(Shape::new(vec![2, 2]), DType::F32).unwrap();
        for (row, col, value) in [(0, 0, 1.0), (1, 1, 2.0)] {
            coo.insert(row, col, value).unwrap();
        }

        let dict = ScipySparseIntegration::to_dict(&coo).unwrap();
        for key in ["data", "indices", "indptr", "shape"] {
            assert!(dict.contains_key(key), "missing key {key:?}");
        }

        assert_eq!(dict["shape"], vec![2.0, 2.0]);
        assert_eq!(dict["data"].len(), 2);
    }

    /// Format enums convert in both directions; unsupported formats fall back to CSR.
    #[test]
    fn test_format_conversion() {
        let pairs = [
            (SparseFormat::Coo, ScipyFormat::Coo),
            (SparseFormat::Csr, ScipyFormat::Csr),
            (SparseFormat::Csc, ScipyFormat::Csc),
        ];
        for (torsh_format, scipy_format) in pairs {
            assert_eq!(SparseFormat::from(scipy_format), torsh_format);
            assert_eq!(ScipyFormat::from(torsh_format), scipy_format);
        }

        assert_eq!(ScipyFormat::from(SparseFormat::Ell), ScipyFormat::Csr);
    }

    /// The convenience macros round-trip a tensor through SciPy data.
    #[test]
    fn test_macro_usage() {
        let mut coo = CooTensor::empty(Shape::new(vec![2, 2]), DType::F32).unwrap();
        for (row, col, value) in [(0, 0, 1.0), (1, 1, 2.0)] {
            coo.insert(row, col, value).unwrap();
        }

        let exported = to_scipy!(&coo).unwrap();
        assert_eq!(exported.data.len(), 2);

        let restored = from_scipy!(&exported).unwrap();
        assert_eq!(restored.nnz(), 2);
    }
}