scirs2_sparse/
dok_array.rs

1// Dictionary of Keys (DOK) Array implementation
2//
3// This module provides the DOK (Dictionary of Keys) array format,
4// which is efficient for incremental construction of sparse arrays.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::error::{SparseError, SparseResult};
15use crate::sparray::{SparseArray, SparseSum};
16
17/// DOK Array format
18///
19/// The DOK (Dictionary of Keys) format stores a sparse array in a dictionary (HashMap)
20/// mapping (row, col) coordinate tuples to values.
21///
22/// # Notes
23///
24/// - Efficient for incremental construction (setting elements one by one)
25/// - Fast random access to individual elements (get/set)
26/// - Slow operations that require iterating over all elements
27/// - Slow arithmetic operations
28/// - Not suitable for large-scale computational operations
29///
30#[derive(Clone)]
31pub struct DokArray<T>
32where
33    T: SparseElement + Div<Output = T> + 'static,
34{
35    /// Dictionary mapping (row, col) to value
36    data: HashMap<(usize, usize), T>,
37    /// Shape of the sparse array
38    shape: (usize, usize),
39}
40
41impl<T> DokArray<T>
42where
43    T: SparseElement + Div<Output = T> + 'static,
44{
45    /// Creates a new DOK array with the given shape
46    ///
47    /// # Arguments
48    /// * `shape` - Shape of the sparse array (rows, cols)
49    ///
50    /// # Returns
51    /// A new empty `DokArray`
52    pub fn new(shape: (usize, usize)) -> Self {
53        Self {
54            data: HashMap::new(),
55            shape,
56        }
57    }
58
59    /// Creates a DOK array from triplet format (COO-like)
60    ///
61    /// # Arguments
62    /// * `rows` - Row indices
63    /// * `cols` - Column indices
64    /// * `data` - Values
65    /// * `shape` - Shape of the sparse array
66    ///
67    /// # Returns
68    /// A new `DokArray`
69    ///
70    /// # Errors
71    /// Returns an error if the data is not consistent
72    pub fn from_triplets(
73        rows: &[usize],
74        cols: &[usize],
75        data: &[T],
76        shape: (usize, usize),
77    ) -> SparseResult<Self> {
78        if rows.len() != cols.len() || rows.len() != data.len() {
79            return Err(SparseError::InconsistentData {
80                reason: "rows, cols, and data must have the same length".to_string(),
81            });
82        }
83
84        let mut dok = Self::new(shape);
85        for i in 0..rows.len() {
86            if rows[i] >= shape.0 || cols[i] >= shape.1 {
87                return Err(SparseError::IndexOutOfBounds {
88                    index: (rows[i], cols[i]),
89                    shape,
90                });
91            }
92            // Only set non-zero values
93            if !SparseElement::is_zero(&data[i]) {
94                dok.data.insert((rows[i], cols[i]), data[i]);
95            }
96        }
97
98        Ok(dok)
99    }
100
101    /// Returns a reference to the internal HashMap
102    pub fn get_data(&self) -> &HashMap<(usize, usize), T> {
103        &self.data
104    }
105
106    /// Returns the triplet representation (row indices, column indices, data)
107    pub fn to_triplets(&self) -> (Array1<usize>, Array1<usize>, Array1<T>)
108    where
109        T: Float + PartialOrd,
110    {
111        let nnz = self.nnz();
112        let mut row_indices = Vec::with_capacity(nnz);
113        let mut col_indices = Vec::with_capacity(nnz);
114        let mut values = Vec::with_capacity(nnz);
115
116        // Sort by row, then column for deterministic output
117        let mut entries: Vec<_> = self.data.iter().collect();
118        entries.sort_by_key(|(&(row, col), _)| (row, col));
119
120        for (&(row, col), &value) in entries {
121            row_indices.push(row);
122            col_indices.push(col);
123            values.push(value);
124        }
125
126        (
127            Array1::from_vec(row_indices),
128            Array1::from_vec(col_indices),
129            Array1::from_vec(values),
130        )
131    }
132
133    /// Creates a DOK array from a dense ndarray
134    ///
135    /// # Arguments
136    /// * `array` - Dense ndarray
137    ///
138    /// # Returns
139    /// A new `DokArray` containing non-zero elements from the input array
140    pub fn from_array(array: &Array2<T>) -> Self {
141        let shape = (array.shape()[0], array.shape()[1]);
142        let mut dok = Self::new(shape);
143
144        for ((i, j), &value) in array.indexed_iter() {
145            if !SparseElement::is_zero(&value) {
146                dok.data.insert((i, j), value);
147            }
148        }
149
150        dok
151    }
152}
153
154impl<T> SparseArray<T> for DokArray<T>
155where
156    T: SparseElement + Div<Output = T> + Float + PartialOrd + 'static,
157{
158    fn shape(&self) -> (usize, usize) {
159        self.shape
160    }
161
162    fn nnz(&self) -> usize {
163        self.data.len()
164    }
165
166    fn dtype(&self) -> &str {
167        "float" // This is a placeholder; ideally, we'd return the actual type
168    }
169
170    fn to_array(&self) -> Array2<T> {
171        let (rows, cols) = self.shape;
172        let mut result = Array2::zeros((rows, cols));
173
174        for (&(row, col), &value) in &self.data {
175            result[[row, col]] = value;
176        }
177
178        result
179    }
180
181    fn toarray(&self) -> Array2<T> {
182        self.to_array()
183    }
184
185    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
186        let (row_indices, col_indices, data) = self.to_triplets();
187        CooArray::new(data, row_indices, col_indices, self.shape, true)
188            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
189    }
190
191    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
192        // First convert to COO, then to CSR
193        match self.to_coo() {
194            Ok(coo) => coo.to_csr(),
195            Err(e) => Err(e),
196        }
197    }
198
199    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
200        // First convert to COO, then to CSC
201        match self.to_coo() {
202            Ok(coo) => coo.to_csc(),
203            Err(e) => Err(e),
204        }
205    }
206
207    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
208        // We're already a DOK array
209        Ok(Box::new(self.clone()))
210    }
211
212    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
213        Err(SparseError::NotImplemented(
214            "Conversion to LIL array".to_string(),
215        ))
216    }
217
218    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
219        Err(SparseError::NotImplemented(
220            "Conversion to DIA array".to_string(),
221        ))
222    }
223
224    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
225        Err(SparseError::NotImplemented(
226            "Conversion to BSR array".to_string(),
227        ))
228    }
229
230    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
231        if self.shape() != other.shape() {
232            return Err(SparseError::DimensionMismatch {
233                expected: self.shape().0,
234                found: other.shape().0,
235            });
236        }
237
238        let mut result = self.clone();
239        let other_array = other.to_array();
240
241        // Add existing values from self
242        for (&(row, col), &value) in &self.data {
243            result.set(row, col, value + other_array[[row, col]])?;
244        }
245
246        // Add values from other that aren't in self
247        for ((row, col), &value) in other_array.indexed_iter() {
248            if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
249                result.set(row, col, value)?;
250            }
251        }
252
253        Ok(Box::new(result))
254    }
255
256    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
257        if self.shape() != other.shape() {
258            return Err(SparseError::DimensionMismatch {
259                expected: self.shape().0,
260                found: other.shape().0,
261            });
262        }
263
264        let mut result = self.clone();
265        let other_array = other.to_array();
266
267        // Subtract existing values from self
268        for (&(row, col), &value) in &self.data {
269            result.set(row, col, value - other_array[[row, col]])?;
270        }
271
272        // Subtract values from other that aren't in self
273        for ((row, col), &value) in other_array.indexed_iter() {
274            if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
275                result.set(row, col, -value)?;
276            }
277        }
278
279        Ok(Box::new(result))
280    }
281
282    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
283        if self.shape() != other.shape() {
284            return Err(SparseError::DimensionMismatch {
285                expected: self.shape().0,
286                found: other.shape().0,
287            });
288        }
289
290        let mut result = DokArray::new(self.shape());
291        let other_array = other.to_array();
292
293        // Only need to process entries in self
294        // since a*0 = 0 for any a
295        for (&(row, col), &value) in &self.data {
296            let product = value * other_array[[row, col]];
297            if !SparseElement::is_zero(&product) {
298                result.set(row, col, product)?;
299            }
300        }
301
302        Ok(Box::new(result))
303    }
304
305    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
306        if self.shape() != other.shape() {
307            return Err(SparseError::DimensionMismatch {
308                expected: self.shape().0,
309                found: other.shape().0,
310            });
311        }
312
313        let mut result = DokArray::new(self.shape());
314        let other_array = other.to_array();
315
316        for (&(row, col), &value) in &self.data {
317            let divisor = other_array[[row, col]];
318            if SparseElement::is_zero(&divisor) {
319                return Err(SparseError::ComputationError(
320                    "Division by zero".to_string(),
321                ));
322            }
323
324            let quotient = value / divisor;
325            if !SparseElement::is_zero(&quotient) {
326                result.set(row, col, quotient)?;
327            }
328        }
329
330        Ok(Box::new(result))
331    }
332
333    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
334        let (_m, n) = self.shape();
335        let (p, q) = other.shape();
336
337        if n != p {
338            return Err(SparseError::DimensionMismatch {
339                expected: n,
340                found: p,
341            });
342        }
343
344        // Convert to CSR for efficient matrix multiplication
345        let csr_self = self.to_csr()?;
346        let csr_other = other.to_csr()?;
347
348        csr_self.dot(&*csr_other)
349    }
350
351    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
352        let (m, n) = self.shape();
353        if n != other.len() {
354            return Err(SparseError::DimensionMismatch {
355                expected: n,
356                found: other.len(),
357            });
358        }
359
360        let mut result = Array1::zeros(m);
361
362        for (&(row, col), &value) in &self.data {
363            result[row] = result[row] + value * other[col];
364        }
365
366        Ok(result)
367    }
368
369    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
370        let (rows, cols) = self.shape;
371        let mut result = DokArray::new((cols, rows));
372
373        for (&(row, col), &value) in &self.data {
374            result.set(col, row, value)?;
375        }
376
377        Ok(Box::new(result))
378    }
379
380    fn copy(&self) -> Box<dyn SparseArray<T>> {
381        Box::new(self.clone())
382    }
383
384    fn get(&self, i: usize, j: usize) -> T {
385        if i >= self.shape.0 || j >= self.shape.1 {
386            return T::sparse_zero();
387        }
388
389        *self.data.get(&(i, j)).unwrap_or(&T::sparse_zero())
390    }
391
392    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
393        if i >= self.shape.0 || j >= self.shape.1 {
394            return Err(SparseError::IndexOutOfBounds {
395                index: (i, j),
396                shape: self.shape,
397            });
398        }
399
400        if SparseElement::is_zero(&value) {
401            // Remove zero entries
402            self.data.remove(&(i, j));
403        } else {
404            // Set non-zero value
405            self.data.insert((i, j), value);
406        }
407
408        Ok(())
409    }
410
411    fn eliminate_zeros(&mut self) {
412        // DOK format already doesn't store zeros, but just in case
413        self.data
414            .retain(|_, &mut value| !SparseElement::is_zero(&value));
415    }
416
417    fn sort_indices(&mut self) {
418        // No-op for DOK format since it's a HashMap
419    }
420
421    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
422        // DOK doesn't have the concept of sorted indices
423        self.copy()
424    }
425
426    fn has_sorted_indices(&self) -> bool {
427        true // DOK format doesn't have the concept of sorted indices
428    }
429
430    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
431        match axis {
432            None => {
433                // Sum all elements
434                let mut sum = T::sparse_zero();
435                for &value in self.data.values() {
436                    sum = sum + value;
437                }
438                Ok(SparseSum::Scalar(sum))
439            }
440            Some(0) => {
441                // Sum along rows
442                let (_, cols) = self.shape();
443                let mut result = DokArray::new((1, cols));
444
445                for (&(_row, col), &value) in &self.data {
446                    let current = result.get(0, col);
447                    result.set(0, col, current + value)?;
448                }
449
450                Ok(SparseSum::SparseArray(Box::new(result)))
451            }
452            Some(1) => {
453                // Sum along columns
454                let (rows, _) = self.shape();
455                let mut result = DokArray::new((rows, 1));
456
457                for (&(row, col), &value) in &self.data {
458                    let current = result.get(row, 0);
459                    result.set(row, 0, current + value)?;
460                }
461
462                Ok(SparseSum::SparseArray(Box::new(result)))
463            }
464            _ => Err(SparseError::InvalidAxis),
465        }
466    }
467
468    fn max(&self) -> T {
469        if self.data.is_empty() {
470            return T::nan();
471        }
472
473        self.data
474            .values()
475            .fold(T::neg_infinity(), |acc, &x| acc.max(x))
476    }
477
478    fn min(&self) -> T {
479        if self.data.is_empty() {
480            return T::nan();
481        }
482
483        self.data
484            .values()
485            .fold(T::sparse_zero(), |acc, &x| acc.min(x))
486    }
487
488    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
489        self.to_triplets()
490    }
491
492    fn slice(
493        &self,
494        row_range: (usize, usize),
495        col_range: (usize, usize),
496    ) -> SparseResult<Box<dyn SparseArray<T>>> {
497        let (start_row, end_row) = row_range;
498        let (start_col, end_col) = col_range;
499        let (rows, cols) = self.shape;
500
501        if start_row >= rows
502            || end_row > rows
503            || start_col >= cols
504            || end_col > cols
505            || start_row >= end_row
506            || start_col >= end_col
507        {
508            return Err(SparseError::InvalidSliceRange);
509        }
510
511        let sliceshape = (end_row - start_row, end_col - start_col);
512        let mut result = DokArray::new(sliceshape);
513
514        for (&(row, col), &value) in &self.data {
515            if row >= start_row && row < end_row && col >= start_col && col < end_col {
516                result.set(row - start_row, col - start_col, value)?;
517            }
518        }
519
520        Ok(Box::new(result))
521    }
522
523    fn as_any(&self) -> &dyn Any {
524        self
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_dok_array_create_and_access() {
        // Create a 3x3 sparse array
        let mut array = DokArray::<f64>::new((3, 3));

        // Set some values
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 2, 2.0).unwrap();
        array.set(1, 2, 3.0).unwrap();
        array.set(2, 0, 4.0).unwrap();
        array.set(2, 1, 5.0).unwrap();

        assert_eq!(array.nnz(), 5);

        // Access values
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 1), 0.0); // Zero entry
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);

        // Set a value to zero should remove it
        array.set(0, 0, 0.0).unwrap();
        assert_eq!(array.nnz(), 4);
        assert_eq!(array.get(0, 0), 0.0);

        // Out of bounds access should return zero
        assert_eq!(array.get(3, 0), 0.0);
        assert_eq!(array.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_array_set_out_of_bounds() {
        let mut array = DokArray::<f64>::new((2, 2));
        assert!(array.set(2, 0, 1.0).is_err());
        assert!(array.set(0, 2, 1.0).is_err());
        assert_eq!(array.nnz(), 0);
    }

    #[test]
    fn test_dok_array_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3)).unwrap();

        assert_eq!(array.nnz(), 5);
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);
    }

    #[test]
    fn test_dok_array_from_triplets_rejects_bad_input() {
        // Mismatched lengths
        assert!(DokArray::<f64>::from_triplets(&[0, 1], &[0], &[1.0, 2.0], (2, 2)).is_err());
        // Index outside the declared shape
        assert!(DokArray::<f64>::from_triplets(&[2], &[0], &[1.0], (2, 2)).is_err());
        assert!(DokArray::<f64>::from_triplets(&[0], &[2], &[1.0], (2, 2)).is_err());
        // Explicit zeros are not stored
        let array = DokArray::<f64>::from_triplets(&[0], &[0], &[0.0], (2, 2)).unwrap();
        assert_eq!(array.nnz(), 0);
    }

    #[test]
    fn test_dok_array_to_triplets_sorted() {
        let array = DokArray::<f64>::from_triplets(
            &[2, 0, 1, 0],
            &[1, 2, 0, 0],
            &[4.0, 2.0, 3.0, 1.0],
            (3, 3),
        )
        .unwrap();

        let (rows, cols, values) = array.to_triplets();
        assert_eq!(rows.to_vec(), vec![0, 0, 1, 2]);
        assert_eq!(cols.to_vec(), vec![0, 2, 0, 1]);
        assert_eq!(values.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_dok_array_to_array() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3)).unwrap();
        let dense = array.to_array();

        let expected =
            Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
                .unwrap();

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dok_array_from_array() {
        let dense =
            Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
                .unwrap();

        let array = DokArray::from_array(&dense);

        assert_eq!(array.nnz(), 5);
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);
    }

    #[test]
    fn test_dok_array_add() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 4.0).unwrap();
        array2.set(1, 1, 5.0).unwrap();

        let result = array1.add(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 5.0);
        assert_eq!(dense_result[[0, 1]], 2.0);
        assert_eq!(dense_result[[1, 0]], 3.0);
        assert_eq!(dense_result[[1, 1]], 5.0);
    }

    #[test]
    fn test_dok_array_sub() {
        let array1 =
            DokArray::<f64>::from_triplets(&[0, 0, 1], &[0, 1, 0], &[1.0, 2.0, 3.0], (2, 2))
                .unwrap();
        let array2 =
            DokArray::<f64>::from_triplets(&[0, 1], &[0, 1], &[4.0, 5.0], (2, 2)).unwrap();

        let dense_result = array1.sub(&array2).unwrap().to_array();
        let expected = Array::from_shape_vec((2, 2), vec![-3.0, 2.0, 3.0, -5.0]).unwrap();

        assert_eq!(dense_result, expected);
    }

    #[test]
    fn test_dok_array_mul() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();
        array1.set(1, 1, 4.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 5.0).unwrap();
        array2.set(0, 1, 6.0).unwrap();
        array2.set(1, 0, 7.0).unwrap();
        array2.set(1, 1, 8.0).unwrap();

        // Element-wise multiplication
        let result = array1.mul(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 5.0);
        assert_eq!(dense_result[[0, 1]], 12.0);
        assert_eq!(dense_result[[1, 0]], 21.0);
        assert_eq!(dense_result[[1, 1]], 32.0);
    }

    #[test]
    fn test_dok_array_div() {
        let array1 =
            DokArray::<f64>::from_triplets(&[0, 1], &[0, 1], &[6.0, 8.0], (2, 2)).unwrap();
        let array2 = DokArray::<f64>::from_triplets(
            &[0, 0, 1, 1],
            &[0, 1, 0, 1],
            &[2.0, 1.0, 1.0, 4.0],
            (2, 2),
        )
        .unwrap();

        let dense_result = array1.div(&array2).unwrap().to_array();
        assert_eq!(dense_result[[0, 0]], 3.0);
        assert_eq!(dense_result[[0, 1]], 0.0);
        assert_eq!(dense_result[[1, 0]], 0.0);
        assert_eq!(dense_result[[1, 1]], 2.0);

        // A stored value divided by an implicit zero is an error
        assert!(array2.div(&array1).is_err());
    }

    #[test]
    fn test_dok_array_elementwise_shape_mismatch() {
        let array1 = DokArray::<f64>::new((2, 3));
        let array2 = DokArray::<f64>::new((2, 2));

        assert!(array1.add(&array2).is_err());
        assert!(array1.sub(&array2).is_err());
        assert!(array1.mul(&array2).is_err());
        assert!(array1.div(&array2).is_err());
    }

    #[test]
    fn test_dok_array_dot() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();
        array1.set(1, 1, 4.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 5.0).unwrap();
        array2.set(0, 1, 6.0).unwrap();
        array2.set(1, 0, 7.0).unwrap();
        array2.set(1, 1, 8.0).unwrap();

        // Matrix multiplication
        let result = array1.dot(&array2).unwrap();
        let dense_result = result.to_array();

        // [1 2] [5 6] = [1*5 + 2*7, 1*6 + 2*8] = [19, 22]
        // [3 4] [7 8]   [3*5 + 4*7, 3*6 + 4*8]   [43, 50]
        assert_eq!(dense_result[[0, 0]], 19.0);
        assert_eq!(dense_result[[0, 1]], 22.0);
        assert_eq!(dense_result[[1, 0]], 43.0);
        assert_eq!(dense_result[[1, 1]], 50.0);
    }

    #[test]
    fn test_dok_array_dot_vector() {
        let array = DokArray::<f64>::from_triplets(
            &[0, 0, 1, 1],
            &[0, 1, 0, 1],
            &[1.0, 2.0, 3.0, 4.0],
            (2, 2),
        )
        .unwrap();

        let ones = Array1::from_vec(vec![1.0, 1.0]);
        let result = array.dot_vector(&ones.view()).unwrap();
        assert_eq!(result.to_vec(), vec![3.0, 7.0]);

        let too_long = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        assert!(array.dot_vector(&too_long.view()).is_err());
    }

    #[test]
    fn test_dok_array_transpose() {
        let mut array = DokArray::<f64>::new((2, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();

        let transposed = array.transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 2));
        assert_eq!(transposed.get(0, 0), 1.0);
        assert_eq!(transposed.get(1, 0), 2.0);
        assert_eq!(transposed.get(2, 0), 3.0);
        assert_eq!(transposed.get(0, 1), 4.0);
        assert_eq!(transposed.get(1, 1), 5.0);
        assert_eq!(transposed.get(2, 1), 6.0);
    }

    #[test]
    fn test_dok_array_slice() {
        let mut array = DokArray::<f64>::new((3, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();
        array.set(2, 0, 7.0).unwrap();
        array.set(2, 1, 8.0).unwrap();
        array.set(2, 2, 9.0).unwrap();

        let slice = array.slice((0, 2), (1, 3)).unwrap();

        assert_eq!(slice.shape(), (2, 2));
        assert_eq!(slice.get(0, 0), 2.0);
        assert_eq!(slice.get(0, 1), 3.0);
        assert_eq!(slice.get(1, 0), 5.0);
        assert_eq!(slice.get(1, 1), 6.0);
    }

    #[test]
    fn test_dok_array_sum() {
        let mut array = DokArray::<f64>::new((2, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();

        // Sum all elements
        match array.sum(None).unwrap() {
            SparseSum::Scalar(sum) => assert_eq!(sum, 21.0),
            _ => panic!("Expected scalar sum"),
        }

        // Sum along rows (axis 0)
        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (1, 3));
                assert_eq!(sum_array.get(0, 0), 5.0);
                assert_eq!(sum_array.get(0, 1), 7.0);
                assert_eq!(sum_array.get(0, 2), 9.0);
            }
            _ => panic!("Expected sparse array"),
        }

        // Sum along columns (axis 1)
        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (2, 1));
                assert_eq!(sum_array.get(0, 0), 6.0);
                assert_eq!(sum_array.get(1, 0), 15.0);
            }
            _ => panic!("Expected sparse array"),
        }
    }

    #[test]
    fn test_dok_array_max_min_fully_stored() {
        // Every position is stored, so no implicit zeros take part.
        let array = DokArray::<f64>::from_triplets(
            &[0, 0, 1, 1],
            &[0, 1, 0, 1],
            &[1.0, -2.0, 3.0, 4.0],
            (2, 2),
        )
        .unwrap();

        assert_eq!(array.max(), 4.0);
        assert_eq!(array.min(), -2.0);
    }
}