scirs2_sparse/
sparray.rs

// Sparse Array API
//
// This module provides the base trait for sparse arrays, inspired by SciPy's
// transition from a matrix-based API to an array-based API.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::Debug;
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12
13/// Trait for sparse array types.
14///
15/// This trait defines the common interface for all sparse array implementations.
16/// It is designed to align with SciPy's sparse array API, providing array-like semantics
17/// rather than matrix-like semantics.
18///
19/// # Notes
20///
21/// The sparse array API differs from the sparse matrix API in the following ways:
22///
23/// - `*` operator performs element-wise multiplication, not matrix multiplication
24/// - Matrix multiplication is done with the `dot` method or `@` operator in Python
25/// - Operations like `sum` produce arrays, not matrices
26/// - Sparse arrays use array-style slicing operations
27///
28pub trait SparseArray<T>: std::any::Any
29where
30    T: Float
31        + Add<Output = T>
32        + Sub<Output = T>
33        + Mul<Output = T>
34        + Div<Output = T>
35        + Debug
36        + Copy
37        + 'static,
38{
39    /// Returns the shape of the sparse array.
40    fn shape(&self) -> (usize, usize);
41
42    /// Returns the number of stored (non-zero) elements.
43    fn nnz(&self) -> usize;
44
45    /// Returns the data type of the sparse array.
46    fn dtype(&self) -> &str;
47
48    /// Returns a view of the sparse array as a dense ndarray.
49    fn to_array(&self) -> Array2<T>;
50
51    /// Returns a dense copy of the sparse array.
52    fn toarray(&self) -> Array2<T>;
53
54    /// Returns a sparse array in COO format.
55    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
56
57    /// Returns a sparse array in CSR format.
58    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
59
60    /// Returns a sparse array in CSC format.
61    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
62
63    /// Returns a sparse array in DOK format.
64    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
65
66    /// Returns a sparse array in LIL format.
67    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
68
69    /// Returns a sparse array in DIA format.
70    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
71
72    /// Returns a sparse array in BSR format.
73    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
74
75    /// Element-wise addition.
76    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
77
78    /// Element-wise subtraction.
79    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
80
81    /// Element-wise multiplication.
82    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
83
84    /// Element-wise division.
85    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
86
87    /// Matrix multiplication.
88    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
89
90    /// Matrix-vector multiplication.
91    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>>;
92
93    /// Transpose the sparse array.
94    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
95
96    /// Return a copy of the sparse array with the specified elements.
97    fn copy(&self) -> Box<dyn SparseArray<T>>;
98
99    /// Get a value at the specified position.
100    fn get(&self, i: usize, j: usize) -> T;
101
102    /// Set a value at the specified position.
103    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()>;
104
105    /// Eliminate zeros from the sparse array.
106    fn eliminate_zeros(&mut self);
107
108    /// Sort indices of the sparse array.
109    fn sort_indices(&mut self);
110
111    /// Return a sorted copy of this sparse array.
112    fn sorted_indices(&self) -> Box<dyn SparseArray<T>>;
113
114    /// Check if indices are sorted.
115    fn has_sorted_indices(&self) -> bool;
116
117    /// Sum the sparse array elements.
118    ///
119    /// Parameters:
120    /// - `axis`: The axis along which to sum. If None, sum over both axes.
121    ///
122    /// Returns a sparse array if summing over a single axis, or a scalar if summing over both axes.
123    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>>;
124
125    /// Compute the maximum value of the sparse array elements.
126    fn max(&self) -> T;
127
128    /// Compute the minimum value of the sparse array elements.
129    fn min(&self) -> T;
130
131    /// Return the indices and values of the nonzero elements.
132    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>);
133
134    /// Return a slice of the sparse array.
135    fn slice(
136        &self,
137        row_range: (usize, usize),
138        col_range: (usize, usize),
139    ) -> SparseResult<Box<dyn SparseArray<T>>>;
140
141    /// Returns the concrete type of the array for downcasting.
142    fn as_any(&self) -> &dyn std::any::Any;
143}
144
145/// Represents the result of a sum operation on a sparse array.
146// Manually implement Debug and Clone instead of deriving them
147pub enum SparseSum<T>
148where
149    T: Float + Debug + Copy + 'static,
150{
151    /// Sum over a single axis, returning a sparse array.
152    SparseArray(Box<dyn SparseArray<T>>),
153
154    /// Sum over both axes, returning a scalar.
155    Scalar(T),
156}
157
158impl<T> Debug for SparseSum<T>
159where
160    T: Float + Debug + Copy + 'static,
161{
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        match self {
164            SparseSum::SparseArray(_) => write!(f, "SparseSum::SparseArray(...)"),
165            SparseSum::Scalar(value) => write!(f, "SparseSum::Scalar({:?})", value),
166        }
167    }
168}
169
170impl<T> Clone for SparseSum<T>
171where
172    T: Float + Debug + Copy + 'static,
173{
174    fn clone(&self) -> Self {
175        match self {
176            SparseSum::SparseArray(array) => SparseSum::SparseArray(array.copy()),
177            SparseSum::Scalar(value) => SparseSum::Scalar(*value),
178        }
179    }
180}
181
182/// Identifies sparse arrays (both matrix and array types)
183pub fn is_sparse<T>(_obj: &dyn SparseArray<T>) -> bool
184where
185    T: Float
186        + Add<Output = T>
187        + Sub<Output = T>
188        + Mul<Output = T>
189        + Div<Output = T>
190        + Debug
191        + Copy
192        + 'static,
193{
194    true // Since this is a trait method, any object that implements it is sparse
195}
196
197/// Create a base SparseArray implementation for demonstrations and testing
198pub struct SparseArrayBase<T>
199where
200    T: Float
201        + Add<Output = T>
202        + Sub<Output = T>
203        + Mul<Output = T>
204        + Div<Output = T>
205        + Debug
206        + Copy
207        + 'static,
208{
209    data: Array2<T>,
210}
211
212impl<T> SparseArrayBase<T>
213where
214    T: Float
215        + Add<Output = T>
216        + Sub<Output = T>
217        + Mul<Output = T>
218        + Div<Output = T>
219        + Debug
220        + Copy
221        + 'static,
222{
223    /// Create a new SparseArrayBase from a dense ndarray.
224    pub fn new(data: Array2<T>) -> Self {
225        Self { data }
226    }
227}
228
229impl<T> SparseArray<T> for SparseArrayBase<T>
230where
231    T: Float
232        + Add<Output = T>
233        + Sub<Output = T>
234        + Mul<Output = T>
235        + Div<Output = T>
236        + Debug
237        + Copy
238        + 'static,
239{
240    fn shape(&self) -> (usize, usize) {
241        let shape = self.data.shape();
242        (shape[0], shape[1])
243    }
244
245    fn nnz(&self) -> usize {
246        self.data.iter().filter(|&&x| !x.is_zero()).count()
247    }
248
249    fn dtype(&self) -> &str {
250        "float" // This is a placeholder; ideally, we'd return the actual type
251    }
252
253    fn to_array(&self) -> Array2<T> {
254        self.data.clone()
255    }
256
257    fn toarray(&self) -> Array2<T> {
258        self.data.clone()
259    }
260
261    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
262        // In a real implementation, this would convert to COO format
263        Ok(Box::new(self.clone()))
264    }
265
266    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
267        // In a real implementation, this would convert to CSR format
268        Ok(Box::new(self.clone()))
269    }
270
271    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
272        // In a real implementation, this would convert to CSC format
273        Ok(Box::new(self.clone()))
274    }
275
276    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
277        // In a real implementation, this would convert to DOK format
278        Ok(Box::new(self.clone()))
279    }
280
281    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
282        // In a real implementation, this would convert to LIL format
283        Ok(Box::new(self.clone()))
284    }
285
286    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
287        // In a real implementation, this would convert to DIA format
288        Ok(Box::new(self.clone()))
289    }
290
291    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
292        // In a real implementation, this would convert to BSR format
293        Ok(Box::new(self.clone()))
294    }
295
296    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
297        let other_array = other.to_array();
298        let result = &self.data + &other_array;
299        Ok(Box::new(SparseArrayBase::new(result)))
300    }
301
302    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
303        let other_array = other.to_array();
304        let result = &self.data - &other_array;
305        Ok(Box::new(SparseArrayBase::new(result)))
306    }
307
308    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
309        let other_array = other.to_array();
310        let result = &self.data * &other_array;
311        Ok(Box::new(SparseArrayBase::new(result)))
312    }
313
314    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
315        let other_array = other.to_array();
316        let result = &self.data / &other_array;
317        Ok(Box::new(SparseArrayBase::new(result)))
318    }
319
320    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
321        let other_array = other.to_array();
322        let (m, n) = self.shape();
323        let (p, q) = other.shape();
324
325        if n != p {
326            return Err(SparseError::DimensionMismatch {
327                expected: n,
328                found: p,
329            });
330        }
331
332        let mut result = Array2::zeros((m, q));
333        for i in 0..m {
334            for j in 0..q {
335                let mut sum = T::zero();
336                for k in 0..n {
337                    sum = sum + self.data[[i, k]] * other_array[[k, j]];
338                }
339                result[[i, j]] = sum;
340            }
341        }
342
343        Ok(Box::new(SparseArrayBase::new(result)))
344    }
345
346    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
347        let (m, n) = self.shape();
348        if n != other.len() {
349            return Err(SparseError::DimensionMismatch {
350                expected: n,
351                found: other.len(),
352            });
353        }
354
355        let mut result = Array1::zeros(m);
356        for i in 0..m {
357            let mut sum = T::zero();
358            for j in 0..n {
359                sum = sum + self.data[[i, j]] * other[j];
360            }
361            result[i] = sum;
362        }
363
364        Ok(result)
365    }
366
367    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
368        Ok(Box::new(SparseArrayBase::new(self.data.t().to_owned())))
369    }
370
371    fn copy(&self) -> Box<dyn SparseArray<T>> {
372        Box::new(self.clone())
373    }
374
375    fn get(&self, i: usize, j: usize) -> T {
376        self.data[[i, j]]
377    }
378
379    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
380        let (m, n) = self.shape();
381        if i >= m || j >= n {
382            return Err(SparseError::IndexOutOfBounds {
383                index: (i, j),
384                shape: (m, n),
385            });
386        }
387        self.data[[i, j]] = value;
388        Ok(())
389    }
390
391    fn eliminate_zeros(&mut self) {
392        // No-op for dense array
393    }
394
395    fn sort_indices(&mut self) {
396        // No-op for dense array
397    }
398
399    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
400        self.copy()
401    }
402
403    fn has_sorted_indices(&self) -> bool {
404        true // Dense array has implicitly sorted indices
405    }
406
407    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
408        match axis {
409            None => {
410                let mut sum = T::zero();
411                for &val in self.data.iter() {
412                    sum = sum + val;
413                }
414                Ok(SparseSum::Scalar(sum))
415            }
416            Some(0) => {
417                let (_, n) = self.shape();
418                let mut result = Array2::zeros((1, n));
419                for j in 0..n {
420                    let mut sum = T::zero();
421                    for i in 0..self.data.shape()[0] {
422                        sum = sum + self.data[[i, j]];
423                    }
424                    result[[0, j]] = sum;
425                }
426                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
427                    result,
428                ))))
429            }
430            Some(1) => {
431                let (m, _) = self.shape();
432                let mut result = Array2::zeros((m, 1));
433                for i in 0..m {
434                    let mut sum = T::zero();
435                    for j in 0..self.data.shape()[1] {
436                        sum = sum + self.data[[i, j]];
437                    }
438                    result[[i, 0]] = sum;
439                }
440                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
441                    result,
442                ))))
443            }
444            _ => Err(SparseError::InvalidAxis),
445        }
446    }
447
448    fn max(&self) -> T {
449        self.data
450            .iter()
451            .fold(T::neg_infinity(), |acc, &x| acc.max(x))
452    }
453
454    fn min(&self) -> T {
455        self.data.iter().fold(T::infinity(), |acc, &x| acc.min(x))
456    }
457
458    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
459        let (m, n) = self.shape();
460        let nnz = self.nnz();
461        let mut rows = Vec::with_capacity(nnz);
462        let mut cols = Vec::with_capacity(nnz);
463        let mut values = Vec::with_capacity(nnz);
464
465        for i in 0..m {
466            for j in 0..n {
467                let value = self.data[[i, j]];
468                if !value.is_zero() {
469                    rows.push(i);
470                    cols.push(j);
471                    values.push(value);
472                }
473            }
474        }
475
476        (
477            Array1::from_vec(rows),
478            Array1::from_vec(cols),
479            Array1::from_vec(values),
480        )
481    }
482
483    fn slice(
484        &self,
485        row_range: (usize, usize),
486        col_range: (usize, usize),
487    ) -> SparseResult<Box<dyn SparseArray<T>>> {
488        let (start_row, end_row) = row_range;
489        let (start_col, end_col) = col_range;
490        let (m, n) = self.shape();
491
492        if start_row >= m
493            || end_row > m
494            || start_col >= n
495            || end_col > n
496            || start_row >= end_row
497            || start_col >= end_col
498        {
499            return Err(SparseError::InvalidSliceRange);
500        }
501
502        let view = self
503            .data
504            .slice(ndarray::s![start_row..end_row, start_col..end_col]);
505        Ok(Box::new(SparseArrayBase::new(view.to_owned())))
506    }
507
508    fn as_any(&self) -> &dyn std::any::Any {
509        self
510    }
511}
512
513impl<T> Clone for SparseArrayBase<T>
514where
515    T: Float
516        + Add<Output = T>
517        + Sub<Output = T>
518        + Mul<Output = T>
519        + Div<Output = T>
520        + Debug
521        + Copy
522        + 'static,
523{
524    fn clone(&self) -> Self {
525        Self {
526            data: self.data.clone(),
527        }
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array};

    /// 3x3 fixture whose non-zeros sit on the four corners and the centre.
    fn corner_matrix() -> SparseArrayBase<f64> {
        let data = Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0])
            .unwrap();
        SparseArrayBase::new(data)
    }

    /// 2x2 fixture built from four row-major values.
    fn square(values: [f64; 4]) -> SparseArrayBase<f64> {
        SparseArrayBase::new(Array::from_shape_vec((2, 2), values.to_vec()).unwrap())
    }

    #[test]
    fn test_sparse_array_base() {
        let sparse = corner_matrix();

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 5);
        assert_eq!(sparse.get(0, 0), 1.0);
        assert_eq!(sparse.get(1, 1), 3.0);
        assert_eq!(sparse.get(2, 2), 5.0);
        assert_eq!(sparse.get(0, 1), 0.0);
    }

    #[test]
    fn test_sparse_array_operations() {
        let sparse1 = square([1.0, 2.0, 3.0, 4.0]);
        let sparse2 = square([5.0, 6.0, 7.0, 8.0]);

        // Element-wise addition.
        let sum = sparse1.add(&sparse2).unwrap().to_array();
        assert_eq!(sum, array![[6.0, 8.0], [10.0, 12.0]]);

        // Matrix product.
        let product = sparse1.dot(&sparse2).unwrap().to_array();
        assert_eq!(product, array![[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_dot_vector() {
        let sparse = square([1.0, 2.0, 3.0, 4.0]);

        let ones = array![1.0, 1.0];
        assert_eq!(sparse.dot_vector(&ones.view()).unwrap(), array![3.0, 7.0]);

        let too_short = array![1.0];
        assert!(sparse.dot_vector(&too_short.view()).is_err());
    }

    #[test]
    fn test_sum_along_axes() {
        let sparse = square([1.0, 2.0, 3.0, 4.0]);

        match sparse.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 10.0),
            other => panic!("expected a scalar sum, got {:?}", other),
        }
        match sparse.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(cols) => assert_eq!(cols.to_array(), array![[4.0, 6.0]]),
            other => panic!("expected column sums, got {:?}", other),
        }
        match sparse.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(rows) => assert_eq!(rows.to_array(), array![[3.0], [7.0]]),
            other => panic!("expected row sums, got {:?}", other),
        }
        assert!(sparse.sum(Some(2)).is_err());
    }

    #[test]
    fn test_find_and_extrema() {
        let sparse = corner_matrix();

        let (rows, cols, values) = sparse.find();
        assert_eq!(rows.to_vec(), vec![0, 0, 1, 2, 2]);
        assert_eq!(cols.to_vec(), vec![0, 2, 1, 0, 2]);
        assert_eq!(values.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!(sparse.max(), 5.0);
        assert_eq!(sparse.min(), 0.0);
    }

    #[test]
    fn test_slice_and_transpose() {
        let sparse = corner_matrix();

        let block = sparse.slice((0, 2), (1, 3)).unwrap();
        assert_eq!(block.to_array(), array![[0.0, 2.0], [3.0, 0.0]]);
        // Empty and out-of-range slices are rejected.
        assert!(sparse.slice((1, 1), (0, 3)).is_err());
        assert!(sparse.slice((0, 4), (0, 3)).is_err());

        let transposed = sparse.transpose().unwrap();
        assert_eq!(
            transposed.to_array(),
            array![[1.0, 0.0, 4.0], [0.0, 3.0, 0.0], [2.0, 0.0, 5.0]]
        );
    }

    #[test]
    fn test_set_and_dimension_checks() {
        let mut sparse = corner_matrix();

        sparse.set(0, 1, 7.0).unwrap();
        assert_eq!(sparse.get(0, 1), 7.0);
        assert_eq!(sparse.nnz(), 6);
        assert!(sparse.set(3, 0, 1.0).is_err());

        // A 3x3 matrix cannot be multiplied by a 2x2 one.
        let other = square([1.0, 2.0, 3.0, 4.0]);
        assert!(sparse.dot(&other).is_err());
    }
}