scirs2_sparse/
dok_array.rs

1// Dictionary of Keys (DOK) Array implementation
2//
3// This module provides the DOK (Dictionary of Keys) array format,
4// which is efficient for incremental construction of sparse arrays.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::error::{SparseError, SparseResult};
15use crate::sparray::{SparseArray, SparseSum};
16
17/// DOK Array format
18///
19/// The DOK (Dictionary of Keys) format stores a sparse array in a dictionary (HashMap)
20/// mapping (row, col) coordinate tuples to values.
21///
22/// # Notes
23///
24/// - Efficient for incremental construction (setting elements one by one)
25/// - Fast random access to individual elements (get/set)
26/// - Slow operations that require iterating over all elements
27/// - Slow arithmetic operations
28/// - Not suitable for large-scale computational operations
29///
30#[derive(Clone)]
31pub struct DokArray<T>
32where
33    T: Float
34        + Add<Output = T>
35        + Sub<Output = T>
36        + Mul<Output = T>
37        + Div<Output = T>
38        + Debug
39        + Copy
40        + 'static,
41{
42    /// Dictionary mapping (row, col) to value
43    data: HashMap<(usize, usize), T>,
44    /// Shape of the sparse array
45    shape: (usize, usize),
46}
47
48impl<T> DokArray<T>
49where
50    T: Float
51        + Add<Output = T>
52        + Sub<Output = T>
53        + Mul<Output = T>
54        + Div<Output = T>
55        + Debug
56        + Copy
57        + 'static,
58{
59    /// Creates a new DOK array with the given shape
60    ///
61    /// # Arguments
62    /// * `shape` - Shape of the sparse array (rows, cols)
63    ///
64    /// # Returns
65    /// A new empty `DokArray`
66    pub fn new(shape: (usize, usize)) -> Self {
67        Self {
68            data: HashMap::new(),
69            shape,
70        }
71    }
72
73    /// Creates a DOK array from triplet format (COO-like)
74    ///
75    /// # Arguments
76    /// * `rows` - Row indices
77    /// * `cols` - Column indices
78    /// * `data` - Values
79    /// * `shape` - Shape of the sparse array
80    ///
81    /// # Returns
82    /// A new `DokArray`
83    ///
84    /// # Errors
85    /// Returns an error if the data is not consistent
86    pub fn from_triplets(
87        rows: &[usize],
88        cols: &[usize],
89        data: &[T],
90        shape: (usize, usize),
91    ) -> SparseResult<Self> {
92        if rows.len() != cols.len() || rows.len() != data.len() {
93            return Err(SparseError::InconsistentData {
94                reason: "rows, cols, and data must have the same length".to_string(),
95            });
96        }
97
98        let mut dok = Self::new(shape);
99        for i in 0..rows.len() {
100            if rows[i] >= shape.0 || cols[i] >= shape.1 {
101                return Err(SparseError::IndexOutOfBounds {
102                    index: (rows[i], cols[i]),
103                    shape,
104                });
105            }
106            // Only set non-zero values
107            if !data[i].is_zero() {
108                dok.data.insert((rows[i], cols[i]), data[i]);
109            }
110        }
111
112        Ok(dok)
113    }
114
115    /// Returns a reference to the internal HashMap
116    pub fn get_data(&self) -> &HashMap<(usize, usize), T> {
117        &self.data
118    }
119
120    /// Returns the triplet representation (row indices, column indices, data)
121    pub fn to_triplets(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
122        let nnz = self.nnz();
123        let mut row_indices = Vec::with_capacity(nnz);
124        let mut col_indices = Vec::with_capacity(nnz);
125        let mut values = Vec::with_capacity(nnz);
126
127        // Sort by row, then column for deterministic output
128        let mut entries: Vec<_> = self.data.iter().collect();
129        entries.sort_by_key(|&(&(row, col), _)| (row, col));
130
131        for (&(row, col), &value) in entries {
132            row_indices.push(row);
133            col_indices.push(col);
134            values.push(value);
135        }
136
137        (
138            Array1::from_vec(row_indices),
139            Array1::from_vec(col_indices),
140            Array1::from_vec(values),
141        )
142    }
143
144    /// Creates a DOK array from a dense ndarray
145    ///
146    /// # Arguments
147    /// * `array` - Dense ndarray
148    ///
149    /// # Returns
150    /// A new `DokArray` containing non-zero elements from the input array
151    pub fn from_array(array: &Array2<T>) -> Self {
152        let shape = (array.shape()[0], array.shape()[1]);
153        let mut dok = Self::new(shape);
154
155        for ((i, j), &value) in array.indexed_iter() {
156            if !value.is_zero() {
157                dok.data.insert((i, j), value);
158            }
159        }
160
161        dok
162    }
163}
164
165impl<T> SparseArray<T> for DokArray<T>
166where
167    T: Float
168        + Add<Output = T>
169        + Sub<Output = T>
170        + Mul<Output = T>
171        + Div<Output = T>
172        + Debug
173        + Copy
174        + 'static,
175{
176    fn shape(&self) -> (usize, usize) {
177        self.shape
178    }
179
180    fn nnz(&self) -> usize {
181        self.data.len()
182    }
183
184    fn dtype(&self) -> &str {
185        "float" // This is a placeholder; ideally, we'd return the actual type
186    }
187
188    fn to_array(&self) -> Array2<T> {
189        let (rows, cols) = self.shape;
190        let mut result = Array2::zeros((rows, cols));
191
192        for (&(row, col), &value) in &self.data {
193            result[[row, col]] = value;
194        }
195
196        result
197    }
198
199    fn toarray(&self) -> Array2<T> {
200        self.to_array()
201    }
202
203    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
204        let (row_indices, col_indices, data) = self.to_triplets();
205        CooArray::new(data, row_indices, col_indices, self.shape, true)
206            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
207    }
208
209    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
210        // First convert to COO, then to CSR
211        match self.to_coo() {
212            Ok(coo) => coo.to_csr(),
213            Err(e) => Err(e),
214        }
215    }
216
217    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
218        // First convert to COO, then to CSC
219        match self.to_coo() {
220            Ok(coo) => coo.to_csc(),
221            Err(e) => Err(e),
222        }
223    }
224
225    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
226        // We're already a DOK array
227        Ok(Box::new(self.clone()))
228    }
229
230    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
231        Err(SparseError::NotImplemented(
232            "Conversion to LIL array".to_string(),
233        ))
234    }
235
236    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
237        Err(SparseError::NotImplemented(
238            "Conversion to DIA array".to_string(),
239        ))
240    }
241
242    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
243        Err(SparseError::NotImplemented(
244            "Conversion to BSR array".to_string(),
245        ))
246    }
247
248    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
249        if self.shape() != other.shape() {
250            return Err(SparseError::DimensionMismatch {
251                expected: self.shape().0,
252                found: other.shape().0,
253            });
254        }
255
256        let mut result = self.clone();
257        let other_array = other.to_array();
258
259        // Add existing values from self
260        for (&(row, col), &value) in &self.data {
261            result.set(row, col, value + other_array[[row, col]])?;
262        }
263
264        // Add values from other that aren't in self
265        for ((row, col), &value) in other_array.indexed_iter() {
266            if !self.data.contains_key(&(row, col)) && !value.is_zero() {
267                result.set(row, col, value)?;
268            }
269        }
270
271        Ok(Box::new(result))
272    }
273
274    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
275        if self.shape() != other.shape() {
276            return Err(SparseError::DimensionMismatch {
277                expected: self.shape().0,
278                found: other.shape().0,
279            });
280        }
281
282        let mut result = self.clone();
283        let other_array = other.to_array();
284
285        // Subtract existing values from self
286        for (&(row, col), &value) in &self.data {
287            result.set(row, col, value - other_array[[row, col]])?;
288        }
289
290        // Subtract values from other that aren't in self
291        for ((row, col), &value) in other_array.indexed_iter() {
292            if !self.data.contains_key(&(row, col)) && !value.is_zero() {
293                result.set(row, col, -value)?;
294            }
295        }
296
297        Ok(Box::new(result))
298    }
299
300    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
301        if self.shape() != other.shape() {
302            return Err(SparseError::DimensionMismatch {
303                expected: self.shape().0,
304                found: other.shape().0,
305            });
306        }
307
308        let mut result = DokArray::new(self.shape());
309        let other_array = other.to_array();
310
311        // Only need to process entries in self
312        // since a*0 = 0 for any a
313        for (&(row, col), &value) in &self.data {
314            let product = value * other_array[[row, col]];
315            if !product.is_zero() {
316                result.set(row, col, product)?;
317            }
318        }
319
320        Ok(Box::new(result))
321    }
322
323    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
324        if self.shape() != other.shape() {
325            return Err(SparseError::DimensionMismatch {
326                expected: self.shape().0,
327                found: other.shape().0,
328            });
329        }
330
331        let mut result = DokArray::new(self.shape());
332        let other_array = other.to_array();
333
334        for (&(row, col), &value) in &self.data {
335            let divisor = other_array[[row, col]];
336            if divisor.is_zero() {
337                return Err(SparseError::ComputationError(
338                    "Division by zero".to_string(),
339                ));
340            }
341
342            let quotient = value / divisor;
343            if !quotient.is_zero() {
344                result.set(row, col, quotient)?;
345            }
346        }
347
348        Ok(Box::new(result))
349    }
350
351    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
352        let (_m, n) = self.shape();
353        let (p, _q) = other.shape();
354
355        if n != p {
356            return Err(SparseError::DimensionMismatch {
357                expected: n,
358                found: p,
359            });
360        }
361
362        // Convert to CSR for efficient matrix multiplication
363        let csr_self = self.to_csr()?;
364        let csr_other = other.to_csr()?;
365
366        csr_self.dot(&*csr_other)
367    }
368
369    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
370        let (m, n) = self.shape();
371        if n != other.len() {
372            return Err(SparseError::DimensionMismatch {
373                expected: n,
374                found: other.len(),
375            });
376        }
377
378        let mut result = Array1::zeros(m);
379
380        for (&(row, col), &value) in &self.data {
381            result[row] = result[row] + value * other[col];
382        }
383
384        Ok(result)
385    }
386
387    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
388        let (rows, cols) = self.shape;
389        let mut result = DokArray::new((cols, rows));
390
391        for (&(row, col), &value) in &self.data {
392            result.set(col, row, value)?;
393        }
394
395        Ok(Box::new(result))
396    }
397
398    fn copy(&self) -> Box<dyn SparseArray<T>> {
399        Box::new(self.clone())
400    }
401
402    fn get(&self, i: usize, j: usize) -> T {
403        if i >= self.shape.0 || j >= self.shape.1 {
404            return T::zero();
405        }
406
407        *self.data.get(&(i, j)).unwrap_or(&T::zero())
408    }
409
410    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
411        if i >= self.shape.0 || j >= self.shape.1 {
412            return Err(SparseError::IndexOutOfBounds {
413                index: (i, j),
414                shape: self.shape,
415            });
416        }
417
418        if value.is_zero() {
419            // Remove zero entries
420            self.data.remove(&(i, j));
421        } else {
422            // Set non-zero value
423            self.data.insert((i, j), value);
424        }
425
426        Ok(())
427    }
428
429    fn eliminate_zeros(&mut self) {
430        // DOK format already doesn't store zeros, but just in case
431        self.data.retain(|_, &mut value| !value.is_zero());
432    }
433
434    fn sort_indices(&mut self) {
435        // No-op for DOK format since it's a HashMap
436    }
437
438    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
439        // DOK doesn't have the concept of sorted indices
440        self.copy()
441    }
442
443    fn has_sorted_indices(&self) -> bool {
444        true // DOK format doesn't have the concept of sorted indices
445    }
446
447    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
448        match axis {
449            None => {
450                // Sum all elements
451                let mut sum = T::zero();
452                for &value in self.data.values() {
453                    sum = sum + value;
454                }
455                Ok(SparseSum::Scalar(sum))
456            }
457            Some(0) => {
458                // Sum along rows
459                let (_, cols) = self.shape();
460                let mut result = DokArray::new((1, cols));
461
462                for (&(_row, col), &value) in &self.data {
463                    let current = result.get(0, col);
464                    result.set(0, col, current + value)?;
465                }
466
467                Ok(SparseSum::SparseArray(Box::new(result)))
468            }
469            Some(1) => {
470                // Sum along columns
471                let (rows, _) = self.shape();
472                let mut result = DokArray::new((rows, 1));
473
474                for (&(row, _col), &value) in &self.data {
475                    let current = result.get(row, 0);
476                    result.set(row, 0, current + value)?;
477                }
478
479                Ok(SparseSum::SparseArray(Box::new(result)))
480            }
481            _ => Err(SparseError::InvalidAxis),
482        }
483    }
484
485    fn max(&self) -> T {
486        if self.data.is_empty() {
487            return T::nan();
488        }
489
490        self.data
491            .values()
492            .fold(T::neg_infinity(), |acc, &x| acc.max(x))
493    }
494
495    fn min(&self) -> T {
496        if self.data.is_empty() {
497            return T::nan();
498        }
499
500        self.data.values().fold(T::infinity(), |acc, &x| acc.min(x))
501    }
502
503    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
504        self.to_triplets()
505    }
506
507    fn slice(
508        &self,
509        row_range: (usize, usize),
510        col_range: (usize, usize),
511    ) -> SparseResult<Box<dyn SparseArray<T>>> {
512        let (start_row, end_row) = row_range;
513        let (start_col, end_col) = col_range;
514        let (rows, cols) = self.shape;
515
516        if start_row >= rows
517            || end_row > rows
518            || start_col >= cols
519            || end_col > cols
520            || start_row >= end_row
521            || start_col >= end_col
522        {
523            return Err(SparseError::InvalidSliceRange);
524        }
525
526        let slice_shape = (end_row - start_row, end_col - start_col);
527        let mut result = DokArray::new(slice_shape);
528
529        for (&(row, col), &value) in &self.data {
530            if row >= start_row && row < end_row && col >= start_col && col < end_col {
531                result.set(row - start_row, col - start_col, value)?;
532            }
533        }
534
535        Ok(Box::new(result))
536    }
537
538    fn as_any(&self) -> &dyn Any {
539        self
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    #[test]
    fn test_dok_array_create_and_access() {
        // Create a 3x3 sparse array
        let mut array = DokArray::<f64>::new((3, 3));

        // Set some values
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 2, 2.0).unwrap();
        array.set(1, 2, 3.0).unwrap();
        array.set(2, 0, 4.0).unwrap();
        array.set(2, 1, 5.0).unwrap();

        assert_eq!(array.nnz(), 5);

        // Access values
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 1), 0.0); // Zero entry
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);

        // Set a value to zero should remove it
        array.set(0, 0, 0.0).unwrap();
        assert_eq!(array.nnz(), 4);
        assert_eq!(array.get(0, 0), 0.0);

        // Out of bounds access should return zero
        assert_eq!(array.get(3, 0), 0.0);
        assert_eq!(array.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_array_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3)).unwrap();

        assert_eq!(array.nnz(), 5);
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);
    }

    #[test]
    fn test_dok_array_from_triplets_skips_zeros() {
        let array = DokArray::from_triplets(&[0, 1], &[0, 1], &[0.0, 2.0], (2, 2)).unwrap();

        assert_eq!(array.nnz(), 1);
        assert_eq!(array.get(0, 0), 0.0);
        assert_eq!(array.get(1, 1), 2.0);
    }

    #[test]
    fn test_dok_array_from_triplets_errors() {
        // Slices of different lengths are rejected.
        assert!(DokArray::<f64>::from_triplets(&[0, 1], &[0], &[1.0, 2.0], (2, 2)).is_err());

        // Indices outside the declared shape are rejected.
        assert!(DokArray::<f64>::from_triplets(&[2], &[0], &[1.0], (2, 2)).is_err());
        assert!(DokArray::<f64>::from_triplets(&[0], &[2], &[1.0], (2, 2)).is_err());
    }

    #[test]
    fn test_dok_array_to_array() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3)).unwrap();
        let dense = array.to_array();

        let expected =
            Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
                .unwrap();

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dok_array_from_array() {
        let dense =
            Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
                .unwrap();

        let array = DokArray::from_array(&dense);

        assert_eq!(array.nnz(), 5);
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);
    }

    #[test]
    fn test_dok_array_add() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 4.0).unwrap();
        array2.set(1, 1, 5.0).unwrap();

        let result = array1.add(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 5.0);
        assert_eq!(dense_result[[0, 1]], 2.0);
        assert_eq!(dense_result[[1, 0]], 3.0);
        assert_eq!(dense_result[[1, 1]], 5.0);
    }

    #[test]
    fn test_dok_array_shape_mismatch() {
        let array1 = DokArray::<f64>::new((2, 2));
        let array2 = DokArray::<f64>::new((3, 2));
        let array3 = DokArray::<f64>::new((2, 3));

        // Differing row counts and differing column counts are both rejected.
        assert!(array1.add(&array2).is_err());
        assert!(array1.add(&array3).is_err());
        assert!(array1.sub(&array3).is_err());
        assert!(array1.mul(&array3).is_err());
        assert!(array1.div(&array3).is_err());
    }

    #[test]
    fn test_dok_array_sub() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 5.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 5.0).unwrap();
        array2.set(1, 1, 3.0).unwrap();

        let result = array1.sub(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 0.0); // 5 - 5 cancels out
        assert_eq!(dense_result[[0, 1]], 2.0);
        assert_eq!(dense_result[[1, 0]], 0.0);
        assert_eq!(dense_result[[1, 1]], -3.0);
        // The cancelled entry must not be stored.
        assert_eq!(result.nnz(), 2);
    }

    #[test]
    fn test_dok_array_mul() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();
        array1.set(1, 1, 4.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 5.0).unwrap();
        array2.set(0, 1, 6.0).unwrap();
        array2.set(1, 0, 7.0).unwrap();
        array2.set(1, 1, 8.0).unwrap();

        // Element-wise multiplication
        let result = array1.mul(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 5.0);
        assert_eq!(dense_result[[0, 1]], 12.0);
        assert_eq!(dense_result[[1, 0]], 21.0);
        assert_eq!(dense_result[[1, 1]], 32.0);
    }

    #[test]
    fn test_dok_array_div() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 6.0).unwrap();
        array1.set(1, 1, 8.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 3.0).unwrap();
        array2.set(0, 1, 7.0).unwrap();
        array2.set(1, 1, 2.0).unwrap();

        let result = array1.div(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 2.0);
        assert_eq!(dense_result[[0, 1]], 0.0); // implicit 0 / 7 stays zero
        assert_eq!(dense_result[[1, 1]], 4.0);
        assert_eq!(result.nnz(), 2);

        // A stored numerator over an implicit-zero divisor is an error.
        let mut numerator = DokArray::<f64>::new((2, 2));
        numerator.set(1, 0, 1.0).unwrap();
        assert!(numerator.div(&array2).is_err());
    }

    #[test]
    fn test_dok_array_dot() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();
        array1.set(1, 1, 4.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 5.0).unwrap();
        array2.set(0, 1, 6.0).unwrap();
        array2.set(1, 0, 7.0).unwrap();
        array2.set(1, 1, 8.0).unwrap();

        // Matrix multiplication
        let result = array1.dot(&array2).unwrap();
        let dense_result = result.to_array();

        // [1 2] [5 6] = [1*5 + 2*7, 1*6 + 2*8] = [19, 22]
        // [3 4] [7 8]   [3*5 + 4*7, 3*6 + 4*8]   [43, 50]
        assert_eq!(dense_result[[0, 0]], 19.0);
        assert_eq!(dense_result[[0, 1]], 22.0);
        assert_eq!(dense_result[[1, 0]], 43.0);
        assert_eq!(dense_result[[1, 1]], 50.0);
    }

    #[test]
    fn test_dok_array_dot_vector() {
        let mut array = DokArray::<f64>::new((2, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 2, 2.0).unwrap();
        array.set(1, 1, 3.0).unwrap();

        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = array.dot_vector(&vector.view()).unwrap();

        // [1 0 2] . [1 2 3] = 7, [0 3 0] . [1 2 3] = 6
        assert_eq!(result, Array1::from_vec(vec![7.0, 6.0]));

        // A vector whose length does not match the column count is rejected.
        let short = Array1::from_vec(vec![1.0, 2.0]);
        assert!(array.dot_vector(&short.view()).is_err());
    }

    #[test]
    fn test_dok_array_transpose() {
        let mut array = DokArray::<f64>::new((2, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();

        let transposed = array.transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 2));
        assert_eq!(transposed.get(0, 0), 1.0);
        assert_eq!(transposed.get(1, 0), 2.0);
        assert_eq!(transposed.get(2, 0), 3.0);
        assert_eq!(transposed.get(0, 1), 4.0);
        assert_eq!(transposed.get(1, 1), 5.0);
        assert_eq!(transposed.get(2, 1), 6.0);
    }

    #[test]
    fn test_dok_array_slice() {
        let mut array = DokArray::<f64>::new((3, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();
        array.set(2, 0, 7.0).unwrap();
        array.set(2, 1, 8.0).unwrap();
        array.set(2, 2, 9.0).unwrap();

        let slice = array.slice((0, 2), (1, 3)).unwrap();

        assert_eq!(slice.shape(), (2, 2));
        assert_eq!(slice.get(0, 0), 2.0);
        assert_eq!(slice.get(0, 1), 3.0);
        assert_eq!(slice.get(1, 0), 5.0);
        assert_eq!(slice.get(1, 1), 6.0);
    }

    #[test]
    fn test_dok_array_sum() {
        let mut array = DokArray::<f64>::new((2, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();

        // Sum all elements
        match array.sum(None).unwrap() {
            SparseSum::Scalar(sum) => assert_eq!(sum, 21.0),
            _ => panic!("Expected scalar sum"),
        }

        // Sum along rows (axis 0)
        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (1, 3));
                assert_eq!(sum_array.get(0, 0), 5.0);
                assert_eq!(sum_array.get(0, 1), 7.0);
                assert_eq!(sum_array.get(0, 2), 9.0);
            }
            _ => panic!("Expected sparse array"),
        }

        // Sum along columns (axis 1)
        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (2, 1));
                assert_eq!(sum_array.get(0, 0), 6.0);
                assert_eq!(sum_array.get(1, 0), 15.0);
            }
            _ => panic!("Expected sparse array"),
        }
    }
}