scirs2_sparse/
sparray.rs

// Sparse Array API
//
// This module provides the base trait for sparse arrays, inspired by SciPy's transition
// from a matrix-based API to an array-based API.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Float;
8use std::fmt::Debug;
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12
13/// Trait for sparse array types.
14///
15/// This trait defines the common interface for all sparse array implementations.
16/// It is designed to align with SciPy's sparse array API, providing array-like semantics
17/// rather than matrix-like semantics.
18///
19/// # Notes
20///
21/// The sparse array API differs from the sparse matrix API in the following ways:
22///
23/// - `*` operator performs element-wise multiplication, not matrix multiplication
24/// - Matrix multiplication is done with the `dot` method or `@` operator in Python
25/// - Operations like `sum` produce arrays, not matrices
26/// - Sparse arrays use array-style slicing operations
27///
28pub trait SparseArray<T>: std::any::Any
29where
30    T: Float
31        + Add<Output = T>
32        + Sub<Output = T>
33        + Mul<Output = T>
34        + Div<Output = T>
35        + Debug
36        + Copy
37        + 'static,
38{
39    /// Returns the shape of the sparse array.
40    fn shape(&self) -> (usize, usize);
41
42    /// Returns the number of stored (non-zero) elements.
43    fn nnz(&self) -> usize;
44
45    /// Returns the data type of the sparse array.
46    fn dtype(&self) -> &str;
47
48    /// Returns a view of the sparse array as a dense ndarray.
49    fn to_array(&self) -> Array2<T>;
50
51    /// Returns a dense copy of the sparse array.
52    fn toarray(&self) -> Array2<T>;
53
54    /// Returns a sparse array in COO format.
55    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
56
57    /// Returns a sparse array in CSR format.
58    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
59
60    /// Returns a sparse array in CSC format.
61    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
62
63    /// Returns a sparse array in DOK format.
64    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
65
66    /// Returns a sparse array in LIL format.
67    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
68
69    /// Returns a sparse array in DIA format.
70    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
71
72    /// Returns a sparse array in BSR format.
73    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
74
75    /// Element-wise addition.
76    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
77
78    /// Element-wise subtraction.
79    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
80
81    /// Element-wise multiplication.
82    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
83
84    /// Element-wise division.
85    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
86
87    /// Matrix multiplication.
88    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
89
90    /// Matrix-vector multiplication.
91    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>>;
92
93    /// Transpose the sparse array.
94    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
95
96    /// Return a copy of the sparse array with the specified elements.
97    fn copy(&self) -> Box<dyn SparseArray<T>>;
98
99    /// Get a value at the specified position.
100    fn get(&self, i: usize, j: usize) -> T;
101
102    /// Set a value at the specified position.
103    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()>;
104
105    /// Eliminate zeros from the sparse array.
106    fn eliminate_zeros(&mut self);
107
108    /// Sort indices of the sparse array.
109    fn sort_indices(&mut self);
110
111    /// Return a sorted copy of this sparse array.
112    fn sorted_indices(&self) -> Box<dyn SparseArray<T>>;
113
114    /// Check if indices are sorted.
115    fn has_sorted_indices(&self) -> bool;
116
117    /// Sum the sparse array elements.
118    ///
119    /// Parameters:
120    /// - `axis`: The axis along which to sum. If None, sum over both axes.
121    ///
122    /// Returns a sparse array if summing over a single axis, or a scalar if summing over both axes.
123    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>>;
124
125    /// Compute the maximum value of the sparse array elements.
126    fn max(&self) -> T;
127
128    /// Compute the minimum value of the sparse array elements.
129    fn min(&self) -> T;
130
131    /// Return the indices and values of the nonzero elements.
132    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>);
133
134    /// Return a slice of the sparse array.
135    fn slice(
136        &self,
137        row_range: (usize, usize),
138        col_range: (usize, usize),
139    ) -> SparseResult<Box<dyn SparseArray<T>>>;
140
141    /// Returns the concrete type of the array for downcasting.
142    fn as_any(&self) -> &dyn std::any::Any;
143
144    /// Returns the indptr array for CSR/CSC formats.
145    /// For formats that don't have indptr, returns None.
146    fn get_indptr(&self) -> Option<&Array1<usize>> {
147        None
148    }
149
150    /// Returns the indptr array for CSR/CSC formats.
151    /// For formats that don't have indptr, returns None.
152    fn indptr(&self) -> Option<&Array1<usize>> {
153        None
154    }
155}
156
157/// Represents the result of a sum operation on a sparse array.
158// Manually implement Debug and Clone instead of deriving them
159pub enum SparseSum<T>
160where
161    T: Float + Debug + Copy + 'static,
162{
163    /// Sum over a single axis, returning a sparse array.
164    SparseArray(Box<dyn SparseArray<T>>),
165
166    /// Sum over both axes, returning a scalar.
167    Scalar(T),
168}
169
170impl<T> Debug for SparseSum<T>
171where
172    T: Float + Debug + Copy + 'static,
173{
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        match self {
176            SparseSum::SparseArray(_) => write!(f, "SparseSum::SparseArray(...)"),
177            SparseSum::Scalar(value) => write!(f, "SparseSum::Scalar({value:?})"),
178        }
179    }
180}
181
182impl<T> Clone for SparseSum<T>
183where
184    T: Float + Debug + Copy + 'static,
185{
186    fn clone(&self) -> Self {
187        match self {
188            SparseSum::SparseArray(array) => SparseSum::SparseArray(array.copy()),
189            SparseSum::Scalar(value) => SparseSum::Scalar(*value),
190        }
191    }
192}
193
194/// Identifies sparse arrays (both matrix and array types)
195#[allow(dead_code)]
196pub fn is_sparse<T>(obj: &dyn SparseArray<T>) -> bool
197where
198    T: Float
199        + Add<Output = T>
200        + Sub<Output = T>
201        + Mul<Output = T>
202        + Div<Output = T>
203        + Debug
204        + Copy
205        + 'static,
206{
207    true // Since this is a trait method, any object that implements it is sparse
208}
209
210/// Create a base SparseArray implementation for demonstrations and testing
211pub struct SparseArrayBase<T>
212where
213    T: Float
214        + Add<Output = T>
215        + Sub<Output = T>
216        + Mul<Output = T>
217        + Div<Output = T>
218        + Debug
219        + Copy
220        + 'static,
221{
222    data: Array2<T>,
223}
224
225impl<T> SparseArrayBase<T>
226where
227    T: Float
228        + Add<Output = T>
229        + Sub<Output = T>
230        + Mul<Output = T>
231        + Div<Output = T>
232        + Debug
233        + Copy
234        + 'static,
235{
236    /// Create a new SparseArrayBase from a dense ndarray.
237    pub fn new(data: Array2<T>) -> Self {
238        Self { data }
239    }
240}
241
242impl<T> SparseArray<T> for SparseArrayBase<T>
243where
244    T: Float
245        + Add<Output = T>
246        + Sub<Output = T>
247        + Mul<Output = T>
248        + Div<Output = T>
249        + Debug
250        + Copy
251        + 'static,
252{
253    fn shape(&self) -> (usize, usize) {
254        let shape = self.data.shape();
255        (shape[0], shape[1])
256    }
257
258    fn nnz(&self) -> usize {
259        self.data.iter().filter(|&&x| !x.is_zero()).count()
260    }
261
262    fn dtype(&self) -> &str {
263        "float" // This is a placeholder; ideally, we'd return the actual type
264    }
265
266    fn to_array(&self) -> Array2<T> {
267        self.data.clone()
268    }
269
270    fn toarray(&self) -> Array2<T> {
271        self.data.clone()
272    }
273
274    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
275        // In a real implementation, this would convert to COO format
276        Ok(Box::new(self.clone()))
277    }
278
279    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
280        // In a real implementation, this would convert to CSR format
281        Ok(Box::new(self.clone()))
282    }
283
284    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
285        // In a real implementation, this would convert to CSC format
286        Ok(Box::new(self.clone()))
287    }
288
289    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
290        // In a real implementation, this would convert to DOK format
291        Ok(Box::new(self.clone()))
292    }
293
294    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
295        // In a real implementation, this would convert to LIL format
296        Ok(Box::new(self.clone()))
297    }
298
299    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
300        // In a real implementation, this would convert to DIA format
301        Ok(Box::new(self.clone()))
302    }
303
304    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
305        // In a real implementation, this would convert to BSR format
306        Ok(Box::new(self.clone()))
307    }
308
309    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
310        let other_array = other.to_array();
311        let result = &self.data + &other_array;
312        Ok(Box::new(SparseArrayBase::new(result)))
313    }
314
315    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
316        let other_array = other.to_array();
317        let result = &self.data - &other_array;
318        Ok(Box::new(SparseArrayBase::new(result)))
319    }
320
321    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
322        let other_array = other.to_array();
323        let result = &self.data * &other_array;
324        Ok(Box::new(SparseArrayBase::new(result)))
325    }
326
327    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
328        let other_array = other.to_array();
329        let result = &self.data / &other_array;
330        Ok(Box::new(SparseArrayBase::new(result)))
331    }
332
333    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
334        let other_array = other.to_array();
335        let (m, n) = self.shape();
336        let (p, q) = other.shape();
337
338        if n != p {
339            return Err(SparseError::DimensionMismatch {
340                expected: n,
341                found: p,
342            });
343        }
344
345        let mut result = Array2::zeros((m, q));
346        for i in 0..m {
347            for j in 0..q {
348                let mut sum = T::zero();
349                for k in 0..n {
350                    sum = sum + self.data[[i, k]] * other_array[[k, j]];
351                }
352                result[[i, j]] = sum;
353            }
354        }
355
356        Ok(Box::new(SparseArrayBase::new(result)))
357    }
358
359    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
360        let (m, n) = self.shape();
361        if n != other.len() {
362            return Err(SparseError::DimensionMismatch {
363                expected: n,
364                found: other.len(),
365            });
366        }
367
368        let mut result = Array1::zeros(m);
369        for i in 0..m {
370            let mut sum = T::zero();
371            for j in 0..n {
372                sum = sum + self.data[[i, j]] * other[j];
373            }
374            result[i] = sum;
375        }
376
377        Ok(result)
378    }
379
380    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
381        Ok(Box::new(SparseArrayBase::new(self.data.t().to_owned())))
382    }
383
384    fn copy(&self) -> Box<dyn SparseArray<T>> {
385        Box::new(self.clone())
386    }
387
388    fn get(&self, i: usize, j: usize) -> T {
389        self.data[[i, j]]
390    }
391
392    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
393        let (m, n) = self.shape();
394        if i >= m || j >= n {
395            return Err(SparseError::IndexOutOfBounds {
396                index: (i, j),
397                shape: (m, n),
398            });
399        }
400        self.data[[i, j]] = value;
401        Ok(())
402    }
403
404    fn eliminate_zeros(&mut self) {
405        // No-op for dense array
406    }
407
408    fn sort_indices(&mut self) {
409        // No-op for dense array
410    }
411
412    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
413        self.copy()
414    }
415
416    fn has_sorted_indices(&self) -> bool {
417        true // Dense array has implicitly sorted indices
418    }
419
420    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
421        match axis {
422            None => {
423                let mut sum = T::zero();
424                for &val in self.data.iter() {
425                    sum = sum + val;
426                }
427                Ok(SparseSum::Scalar(sum))
428            }
429            Some(0) => {
430                let (_, n) = self.shape();
431                let mut result = Array2::zeros((1, n));
432                for j in 0..n {
433                    let mut sum = T::zero();
434                    for i in 0..self.data.shape()[0] {
435                        sum = sum + self.data[[i, j]];
436                    }
437                    result[[0, j]] = sum;
438                }
439                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
440                    result,
441                ))))
442            }
443            Some(1) => {
444                let (m_, _) = self.shape();
445                let mut result = Array2::zeros((m_, 1));
446                for i in 0..m_ {
447                    let mut sum = T::zero();
448                    for j in 0..self.data.shape()[1] {
449                        sum = sum + self.data[[i, j]];
450                    }
451                    result[[i, 0]] = sum;
452                }
453                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
454                    result,
455                ))))
456            }
457            _ => Err(SparseError::InvalidAxis),
458        }
459    }
460
461    fn max(&self) -> T {
462        self.data
463            .iter()
464            .fold(T::neg_infinity(), |acc, &x| acc.max(x))
465    }
466
467    fn min(&self) -> T {
468        self.data.iter().fold(T::infinity(), |acc, &x| acc.min(x))
469    }
470
471    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
472        let (m, n) = self.shape();
473        let nnz = self.nnz();
474        let mut rows = Vec::with_capacity(nnz);
475        let mut cols = Vec::with_capacity(nnz);
476        let mut values = Vec::with_capacity(nnz);
477
478        for i in 0..m {
479            for j in 0..n {
480                let value = self.data[[i, j]];
481                if !value.is_zero() {
482                    rows.push(i);
483                    cols.push(j);
484                    values.push(value);
485                }
486            }
487        }
488
489        (
490            Array1::from_vec(rows),
491            Array1::from_vec(cols),
492            Array1::from_vec(values),
493        )
494    }
495
496    fn slice(
497        &self,
498        row_range: (usize, usize),
499        col_range: (usize, usize),
500    ) -> SparseResult<Box<dyn SparseArray<T>>> {
501        let (start_row, end_row) = row_range;
502        let (start_col, end_col) = col_range;
503        let (m, n) = self.shape();
504
505        if start_row >= m
506            || end_row > m
507            || start_col >= n
508            || end_col > n
509            || start_row >= end_row
510            || start_col >= end_col
511        {
512            return Err(SparseError::InvalidSliceRange);
513        }
514
515        let view = self.data.slice(scirs2_core::ndarray::s![
516            start_row..end_row,
517            start_col..end_col
518        ]);
519        Ok(Box::new(SparseArrayBase::new(view.to_owned())))
520    }
521
522    fn as_any(&self) -> &dyn std::any::Any {
523        self
524    }
525}
526
527impl<T> Clone for SparseArrayBase<T>
528where
529    T: Float
530        + Add<Output = T>
531        + Sub<Output = T>
532        + Mul<Output = T>
533        + Div<Output = T>
534        + Debug
535        + Copy
536        + 'static,
537{
538    fn clone(&self) -> Self {
539        Self {
540            data: self.data.clone(),
541        }
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a `SparseArrayBase` from row-major values; panics if the length
    /// does not match `shape`.
    fn build(shape: (usize, usize), values: Vec<f64>) -> SparseArrayBase<f64> {
        SparseArrayBase::new(Array::from_shape_vec(shape, values).unwrap())
    }

    /// Asserts that a 2x2 dense result matches `expected` entry by entry.
    fn assert_2x2(actual: &Array2<f64>, expected: [[f64; 2]; 2]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_eq!(actual[[i, j]], want, "mismatch at ({i}, {j})");
            }
        }
    }

    #[test]
    fn test_sparse_array_base() {
        let sparse = build((3, 3), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0]);

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 5);
        for &(i, j, want) in &[(0, 0, 1.0), (1, 1, 3.0), (2, 2, 5.0), (0, 1, 0.0)] {
            assert_eq!(sparse.get(i, j), want);
        }
    }

    #[test]
    fn test_sparse_array_operations() {
        let lhs = build((2, 2), vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = build((2, 2), vec![5.0, 6.0, 7.0, 8.0]);

        // Element-wise addition.
        let sum = lhs.add(&rhs).unwrap().to_array();
        assert_2x2(&sum, [[6.0, 8.0], [10.0, 12.0]]);

        // Matrix product.
        let product = lhs.dot(&rhs).unwrap().to_array();
        assert_2x2(&product, [[19.0, 22.0], [43.0, 50.0]]);
    }
}