scirs2_sparse/
sparray.rs

1// Sparse Array API
2//
3// This module provides the base trait for sparse arrays, inspired by SciPy's transition
4// from matrix-based API to array-based API.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::Debug;
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12
/// Trait for sparse array types.
///
/// This trait defines the common interface for all sparse array implementations.
/// It is designed to align with SciPy's sparse array API, providing array-like semantics
/// rather than matrix-like semantics.
///
/// # Notes
///
/// The sparse array API differs from the sparse matrix API in the following ways:
///
/// - `*` operator performs element-wise multiplication, not matrix multiplication
/// - Matrix multiplication is done with the `dot` method or `@` operator in Python
/// - Operations like `sum` produce arrays, not matrices
/// - Sparse arrays use array-style slicing operations
///
/// # Compatibility
///
/// This trait now supports both integer and floating-point element types through
/// the `SparseElement` trait, enabling use cases like graph adjacency matrices (integers)
/// and numerical computation (floats).
///
/// # Object safety
///
/// Most methods take or return `dyn SparseArray<T>`, so implementations are expected to be
/// usable as trait objects; `as_any` is provided for downcasting back to a concrete type.
pub trait SparseArray<T>: std::any::Any
where
    T: SparseElement + Div<Output = T> + 'static,
{
    /// Returns the shape of the sparse array as a `(rows, columns)` tuple.
    fn shape(&self) -> (usize, usize);

    /// Returns the number of stored (non-zero) elements.
    fn nnz(&self) -> usize;

    /// Returns a string naming the data type of the stored elements.
    fn dtype(&self) -> &str;

    /// Returns the contents of the sparse array as an owned dense `Array2`.
    ///
    /// Despite the name, the return type is an owned array, not a borrowed view.
    fn to_array(&self) -> Array2<T>;

    /// Returns a dense copy of the sparse array.
    ///
    /// Has the same signature as `to_array`.
    // NOTE(review): presumably kept as a SciPy-style spelling of `to_array` — confirm whether
    // both names are still needed.
    fn toarray(&self) -> Array2<T>;

    /// Returns a sparse array in COO format, or an error if the conversion fails.
    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns a sparse array in CSR format, or an error if the conversion fails.
    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns a sparse array in CSC format, or an error if the conversion fails.
    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns a sparse array in DOK format, or an error if the conversion fails.
    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns a sparse array in LIL format, or an error if the conversion fails.
    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns a sparse array in DIA format, or an error if the conversion fails.
    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns a sparse array in BSR format, or an error if the conversion fails.
    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Element-wise addition of `self` and `other`.
    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Element-wise subtraction of `other` from `self`.
    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Element-wise multiplication of `self` and `other` (not a matrix product; see `dot`).
    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Element-wise division of `self` by `other`.
    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Matrix multiplication of `self` with `other`.
    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Matrix-vector multiplication of `self` with the dense vector `other`.
    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>>;

    /// Transpose the sparse array, returning the result as a new boxed array.
    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns a boxed copy of this sparse array.
    ///
    /// This plays the role of `Clone` for trait objects, which cannot be cloned directly.
    fn copy(&self) -> Box<dyn SparseArray<T>>;

    /// Get the value at row `i`, column `j`.
    // NOTE(review): the signature has no way to report an out-of-range index; behaviour in
    // that case is implementation-defined — confirm per implementation.
    fn get(&self, i: usize, j: usize) -> T;

    /// Set the value at row `i`, column `j`, returning an error if the update is rejected.
    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()>;

    /// Eliminate zeros from the sparse array.
    fn eliminate_zeros(&mut self);

    /// Sort indices of the sparse array in place.
    fn sort_indices(&mut self);

    /// Return a sorted copy of this sparse array.
    fn sorted_indices(&self) -> Box<dyn SparseArray<T>>;

    /// Check if indices are sorted.
    fn has_sorted_indices(&self) -> bool;

    /// Sum the sparse array elements.
    ///
    /// Parameters:
    /// - `axis`: The axis along which to sum. If None, sum over both axes.
    ///
    /// Returns a sparse array if summing over a single axis, or a scalar if summing over both axes.
    /// These two outcomes correspond to the `SparseSum::SparseArray` and `SparseSum::Scalar`
    /// variants respectively.
    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>>;

    /// Compute the maximum value of the sparse array elements.
    fn max(&self) -> T;

    /// Compute the minimum value of the sparse array elements.
    fn min(&self) -> T;

    /// Return the indices and values of the nonzero elements.
    ///
    /// The tuple holds three arrays; in the `SparseArrayBase` implementation in this file
    /// they are, in order, row indices, column indices, and the corresponding values.
    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>);

    /// Return a slice of the sparse array.
    ///
    /// `row_range` and `col_range` are `(start, end)` pairs selecting rows and columns.
    // NOTE(review): `SparseArrayBase` treats both ranges as half-open (`start..end`);
    // presumably other implementations agree — confirm.
    fn slice(
        &self,
        row_range: (usize, usize),
        col_range: (usize, usize),
    ) -> SparseResult<Box<dyn SparseArray<T>>>;

    /// Returns the concrete type of the array for downcasting.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Returns the indptr array for CSR/CSC formats.
    /// For formats that don't have indptr, returns None.
    ///
    /// The default implementation returns `None`.
    // NOTE(review): this duplicates `indptr` below (same signature, same default) — confirm
    // which of the two callers rely on before consolidating.
    fn get_indptr(&self) -> Option<&Array1<usize>> {
        None
    }

    /// Returns the indptr array for CSR/CSC formats.
    /// For formats that don't have indptr, returns None.
    ///
    /// The default implementation returns `None`.
    fn indptr(&self) -> Option<&Array1<usize>> {
        None
    }
}
155
/// Represents the result of a sum operation on a sparse array.
///
/// Produced by [`SparseArray::sum`]: reducing along one axis yields another sparse array,
/// while reducing over both axes yields a single value.
// `Debug` and `Clone` are implemented manually instead of being derived: the boxed
// `dyn SparseArray<T>` payload offers neither, because the `SparseArray` trait only
// requires `std::any::Any` as a supertrait.
pub enum SparseSum<T>
where
    T: SparseElement + Div<Output = T> + 'static,
{
    /// Sum over a single axis, returning a sparse array.
    SparseArray(Box<dyn SparseArray<T>>),

    /// Sum over both axes, returning a scalar.
    Scalar(T),
}
168
169impl<T> Debug for SparseSum<T>
170where
171    T: SparseElement + Div<Output = T> + 'static,
172{
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        match self {
175            SparseSum::SparseArray(_) => write!(f, "SparseSum::SparseArray(...)"),
176            SparseSum::Scalar(value) => write!(f, "SparseSum::Scalar({value:?})"),
177        }
178    }
179}
180
181impl<T> Clone for SparseSum<T>
182where
183    T: SparseElement + Div<Output = T> + 'static,
184{
185    fn clone(&self) -> Self {
186        match self {
187            SparseSum::SparseArray(array) => SparseSum::SparseArray(array.copy()),
188            SparseSum::Scalar(value) => SparseSum::Scalar(*value),
189        }
190    }
191}
192
193/// Identifies sparse arrays (both matrix and array types)
194#[allow(dead_code)]
195pub fn is_sparse<T>(obj: &dyn SparseArray<T>) -> bool
196where
197    T: SparseElement + Div<Output = T> + 'static,
198{
199    true // Since this is a trait method, any object that implements it is sparse
200}
201
/// A base `SparseArray` implementation for demonstrations and testing.
///
/// The contents are held in a dense two-dimensional array, so every element (zeros
/// included) is stored explicitly.
pub struct SparseArrayBase<T>
where
    T: SparseElement + Div<Output = T> + 'static,
{
    // Dense backing storage for all elements of the array.
    data: Array2<T>,
}
209
210impl<T> SparseArrayBase<T>
211where
212    T: SparseElement + Div<Output = T> + Zero + 'static,
213{
214    /// Create a new SparseArrayBase from a dense ndarray.
215    pub fn new(data: Array2<T>) -> Self {
216        Self { data }
217    }
218}
219
220impl<T> SparseArray<T> for SparseArrayBase<T>
221where
222    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
223{
224    fn shape(&self) -> (usize, usize) {
225        let shape = self.data.shape();
226        (shape[0], shape[1])
227    }
228
229    fn nnz(&self) -> usize {
230        self.data.iter().filter(|&&x| x != T::sparse_zero()).count()
231    }
232
233    fn dtype(&self) -> &str {
234        "float" // This is a placeholder; ideally, we'd return the actual type
235    }
236
237    fn to_array(&self) -> Array2<T> {
238        self.data.clone()
239    }
240
241    fn toarray(&self) -> Array2<T> {
242        self.data.clone()
243    }
244
245    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
246        // In a real implementation, this would convert to COO format
247        Ok(Box::new(self.clone()))
248    }
249
250    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
251        // In a real implementation, this would convert to CSR format
252        Ok(Box::new(self.clone()))
253    }
254
255    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
256        // In a real implementation, this would convert to CSC format
257        Ok(Box::new(self.clone()))
258    }
259
260    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
261        // In a real implementation, this would convert to DOK format
262        Ok(Box::new(self.clone()))
263    }
264
265    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
266        // In a real implementation, this would convert to LIL format
267        Ok(Box::new(self.clone()))
268    }
269
270    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
271        // In a real implementation, this would convert to DIA format
272        Ok(Box::new(self.clone()))
273    }
274
275    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
276        // In a real implementation, this would convert to BSR format
277        Ok(Box::new(self.clone()))
278    }
279
280    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
281        let other_array = other.to_array();
282        let result = &self.data + &other_array;
283        Ok(Box::new(SparseArrayBase::new(result)))
284    }
285
286    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
287        let other_array = other.to_array();
288        let result = &self.data - &other_array;
289        Ok(Box::new(SparseArrayBase::new(result)))
290    }
291
292    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
293        let other_array = other.to_array();
294        let result = &self.data * &other_array;
295        Ok(Box::new(SparseArrayBase::new(result)))
296    }
297
298    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
299        let other_array = other.to_array();
300        let result = &self.data / &other_array;
301        Ok(Box::new(SparseArrayBase::new(result)))
302    }
303
304    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
305        let other_array = other.to_array();
306        let (m, n) = self.shape();
307        let (p, q) = other.shape();
308
309        if n != p {
310            return Err(SparseError::DimensionMismatch {
311                expected: n,
312                found: p,
313            });
314        }
315
316        let mut result = Array2::zeros((m, q));
317        for i in 0..m {
318            for j in 0..q {
319                let mut sum = T::sparse_zero();
320                for k in 0..n {
321                    sum = sum + self.data[[i, k]] * other_array[[k, j]];
322                }
323                result[[i, j]] = sum;
324            }
325        }
326
327        Ok(Box::new(SparseArrayBase::new(result)))
328    }
329
330    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
331        let (m, n) = self.shape();
332        if n != other.len() {
333            return Err(SparseError::DimensionMismatch {
334                expected: n,
335                found: other.len(),
336            });
337        }
338
339        let mut result = Array1::zeros(m);
340        for i in 0..m {
341            let mut sum = T::sparse_zero();
342            for j in 0..n {
343                sum = sum + self.data[[i, j]] * other[j];
344            }
345            result[i] = sum;
346        }
347
348        Ok(result)
349    }
350
351    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
352        Ok(Box::new(SparseArrayBase::new(self.data.t().to_owned())))
353    }
354
355    fn copy(&self) -> Box<dyn SparseArray<T>> {
356        Box::new(self.clone())
357    }
358
359    fn get(&self, i: usize, j: usize) -> T {
360        self.data[[i, j]]
361    }
362
363    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
364        let (m, n) = self.shape();
365        if i >= m || j >= n {
366            return Err(SparseError::IndexOutOfBounds {
367                index: (i, j),
368                shape: (m, n),
369            });
370        }
371        self.data[[i, j]] = value;
372        Ok(())
373    }
374
375    fn eliminate_zeros(&mut self) {
376        // No-op for dense array
377    }
378
379    fn sort_indices(&mut self) {
380        // No-op for dense array
381    }
382
383    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
384        self.copy()
385    }
386
387    fn has_sorted_indices(&self) -> bool {
388        true // Dense array has implicitly sorted indices
389    }
390
391    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
392        match axis {
393            None => {
394                let mut sum = T::sparse_zero();
395                for &val in self.data.iter() {
396                    sum = sum + val;
397                }
398                Ok(SparseSum::Scalar(sum))
399            }
400            Some(0) => {
401                let (_, n) = self.shape();
402                let mut result = Array2::zeros((1, n));
403                for j in 0..n {
404                    let mut sum = T::sparse_zero();
405                    for i in 0..self.data.shape()[0] {
406                        sum = sum + self.data[[i, j]];
407                    }
408                    result[[0, j]] = sum;
409                }
410                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
411                    result,
412                ))))
413            }
414            Some(1) => {
415                let (m_, _) = self.shape();
416                let mut result = Array2::zeros((m_, 1));
417                for i in 0..m_ {
418                    let mut sum = T::sparse_zero();
419                    for j in 0..self.data.shape()[1] {
420                        sum = sum + self.data[[i, j]];
421                    }
422                    result[[i, 0]] = sum;
423                }
424                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
425                    result,
426                ))))
427            }
428            _ => Err(SparseError::InvalidAxis),
429        }
430    }
431
432    fn max(&self) -> T {
433        if self.data.is_empty() {
434            return T::sparse_zero();
435        }
436        let mut max_val = self.data[[0, 0]];
437        for &val in self.data.iter() {
438            if val > max_val {
439                max_val = val;
440            }
441        }
442        max_val
443    }
444
445    fn min(&self) -> T {
446        if self.data.is_empty() {
447            return T::sparse_zero();
448        }
449        let mut min_val = self.data[[0, 0]];
450        for &val in self.data.iter() {
451            if val < min_val {
452                min_val = val;
453            }
454        }
455        min_val
456    }
457
458    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
459        let (m, n) = self.shape();
460        let nnz = self.nnz();
461        let mut rows = Vec::with_capacity(nnz);
462        let mut cols = Vec::with_capacity(nnz);
463        let mut values = Vec::with_capacity(nnz);
464
465        for i in 0..m {
466            for j in 0..n {
467                let value = self.data[[i, j]];
468                if value != T::sparse_zero() {
469                    rows.push(i);
470                    cols.push(j);
471                    values.push(value);
472                }
473            }
474        }
475
476        (
477            Array1::from_vec(rows),
478            Array1::from_vec(cols),
479            Array1::from_vec(values),
480        )
481    }
482
483    fn slice(
484        &self,
485        row_range: (usize, usize),
486        col_range: (usize, usize),
487    ) -> SparseResult<Box<dyn SparseArray<T>>> {
488        let (start_row, end_row) = row_range;
489        let (start_col, end_col) = col_range;
490        let (m, n) = self.shape();
491
492        if start_row >= m
493            || end_row > m
494            || start_col >= n
495            || end_col > n
496            || start_row >= end_row
497            || start_col >= end_col
498        {
499            return Err(SparseError::InvalidSliceRange);
500        }
501
502        let view = self.data.slice(scirs2_core::ndarray::s![
503            start_row..end_row,
504            start_col..end_col
505        ]);
506        Ok(Box::new(SparseArrayBase::new(view.to_owned())))
507    }
508
509    fn as_any(&self) -> &dyn std::any::Any {
510        self
511    }
512}
513
514impl<T> Clone for SparseArrayBase<T>
515where
516    T: SparseElement + Div<Output = T> + 'static,
517{
518    fn clone(&self) -> Self {
519        Self {
520            data: self.data.clone(),
521        }
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Shape, non-zero count and element lookup on a 3x3 array with five non-zeros.
    #[test]
    fn test_sparse_array_base() {
        let values = vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0];
        let sparse = SparseArrayBase::new(Array::from_shape_vec((3, 3), values).unwrap());

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 5);

        // (row, column, expected value): the diagonal plus one stored zero.
        for (i, j, expected) in [(0, 0, 1.0), (1, 1, 3.0), (2, 2, 5.0), (0, 1, 0.0)] {
            assert_eq!(sparse.get(i, j), expected);
        }
    }

    /// Element-wise addition and matrix product of two 2x2 arrays.
    #[test]
    fn test_sparse_array_operations() {
        let lhs = SparseArrayBase::new(
            Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );
        let rhs = SparseArrayBase::new(
            Array::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap(),
        );

        // Test add
        let sum = lhs.add(&rhs).unwrap().to_array();
        let expected_sum = [[6.0, 8.0], [10.0, 12.0]];
        for (i, row) in expected_sum.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(sum[[i, j]], want);
            }
        }

        // Test dot
        let product = lhs.dot(&rhs).unwrap().to_array();
        let expected_product = [[19.0, 22.0], [43.0, 50.0]];
        for (i, row) in expected_product.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(product[[i, j]], want);
            }
        }
    }
}