scirs2_sparse/
dok_array.rs

1// Dictionary of Keys (DOK) Array implementation
2//
3// This module provides the DOK (Dictionary of Keys) array format,
4// which is efficient for incremental construction of sparse arrays.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::error::{SparseError, SparseResult};
15use crate::sparray::{SparseArray, SparseSum};
16
17/// DOK Array format
18///
19/// The DOK (Dictionary of Keys) format stores a sparse array in a dictionary (HashMap)
20/// mapping (row, col) coordinate tuples to values.
21///
22/// # Notes
23///
24/// - Efficient for incremental construction (setting elements one by one)
25/// - Fast random access to individual elements (get/set)
26/// - Slow operations that require iterating over all elements
27/// - Slow arithmetic operations
28/// - Not suitable for large-scale computational operations
29///
30#[derive(Clone)]
31pub struct DokArray<T>
32where
33    T: SparseElement + Div<Output = T> + 'static,
34{
35    /// Dictionary mapping (row, col) to value
36    data: HashMap<(usize, usize), T>,
37    /// Shape of the sparse array
38    shape: (usize, usize),
39}
40
41impl<T> DokArray<T>
42where
43    T: SparseElement + Div<Output = T> + 'static,
44{
45    /// Creates a new DOK array with the given shape
46    ///
47    /// # Arguments
48    /// * `shape` - Shape of the sparse array (rows, cols)
49    ///
50    /// # Returns
51    /// A new empty `DokArray`
52    pub fn new(shape: (usize, usize)) -> Self {
53        Self {
54            data: HashMap::new(),
55            shape,
56        }
57    }
58
59    /// Creates a DOK array from triplet format (COO-like)
60    ///
61    /// # Arguments
62    /// * `rows` - Row indices
63    /// * `cols` - Column indices
64    /// * `data` - Values
65    /// * `shape` - Shape of the sparse array
66    ///
67    /// # Returns
68    /// A new `DokArray`
69    ///
70    /// # Errors
71    /// Returns an error if the data is not consistent
72    pub fn from_triplets(
73        rows: &[usize],
74        cols: &[usize],
75        data: &[T],
76        shape: (usize, usize),
77    ) -> SparseResult<Self> {
78        if rows.len() != cols.len() || rows.len() != data.len() {
79            return Err(SparseError::InconsistentData {
80                reason: "rows, cols, and data must have the same length".to_string(),
81            });
82        }
83
84        let mut dok = Self::new(shape);
85        for i in 0..rows.len() {
86            if rows[i] >= shape.0 || cols[i] >= shape.1 {
87                return Err(SparseError::IndexOutOfBounds {
88                    index: (rows[i], cols[i]),
89                    shape,
90                });
91            }
92            // Only set non-zero values
93            if !SparseElement::is_zero(&data[i]) {
94                dok.data.insert((rows[i], cols[i]), data[i]);
95            }
96        }
97
98        Ok(dok)
99    }
100
101    /// Returns a reference to the internal HashMap
102    pub fn get_data(&self) -> &HashMap<(usize, usize), T> {
103        &self.data
104    }
105
106    /// Returns the triplet representation (row indices, column indices, data)
107    pub fn to_triplets(&self) -> (Array1<usize>, Array1<usize>, Array1<T>)
108    where
109        T: Float + PartialOrd,
110    {
111        let nnz = self.nnz();
112        let mut row_indices = Vec::with_capacity(nnz);
113        let mut col_indices = Vec::with_capacity(nnz);
114        let mut values = Vec::with_capacity(nnz);
115
116        // Sort by row, then column for deterministic output
117        let mut entries: Vec<_> = self.data.iter().collect();
118        entries.sort_by_key(|(&(row, col), _)| (row, col));
119
120        for (&(row, col), &value) in entries {
121            row_indices.push(row);
122            col_indices.push(col);
123            values.push(value);
124        }
125
126        (
127            Array1::from_vec(row_indices),
128            Array1::from_vec(col_indices),
129            Array1::from_vec(values),
130        )
131    }
132
133    /// Creates a DOK array from a dense ndarray
134    ///
135    /// # Arguments
136    /// * `array` - Dense ndarray
137    ///
138    /// # Returns
139    /// A new `DokArray` containing non-zero elements from the input array
140    pub fn from_array(array: &Array2<T>) -> Self {
141        let shape = (array.shape()[0], array.shape()[1]);
142        let mut dok = Self::new(shape);
143
144        for ((i, j), &value) in array.indexed_iter() {
145            if !SparseElement::is_zero(&value) {
146                dok.data.insert((i, j), value);
147            }
148        }
149
150        dok
151    }
152}
153
154impl<T> SparseArray<T> for DokArray<T>
155where
156    T: SparseElement + Div<Output = T> + Float + PartialOrd + 'static,
157{
158    fn shape(&self) -> (usize, usize) {
159        self.shape
160    }
161
162    fn nnz(&self) -> usize {
163        self.data.len()
164    }
165
166    fn dtype(&self) -> &str {
167        "float" // This is a placeholder; ideally, we'd return the actual type
168    }
169
170    fn to_array(&self) -> Array2<T> {
171        let (rows, cols) = self.shape;
172        let mut result = Array2::zeros((rows, cols));
173
174        for (&(row, col), &value) in &self.data {
175            result[[row, col]] = value;
176        }
177
178        result
179    }
180
181    fn toarray(&self) -> Array2<T> {
182        self.to_array()
183    }
184
185    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
186        let (row_indices, col_indices, data) = self.to_triplets();
187        CooArray::new(data, row_indices, col_indices, self.shape, true)
188            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
189    }
190
191    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
192        // First convert to COO, then to CSR
193        match self.to_coo() {
194            Ok(coo) => coo.to_csr(),
195            Err(e) => Err(e),
196        }
197    }
198
199    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
200        // First convert to COO, then to CSC
201        match self.to_coo() {
202            Ok(coo) => coo.to_csc(),
203            Err(e) => Err(e),
204        }
205    }
206
207    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
208        // We're already a DOK array
209        Ok(Box::new(self.clone()))
210    }
211
212    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
213        Err(SparseError::NotImplemented(
214            "Conversion to LIL array".to_string(),
215        ))
216    }
217
218    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
219        Err(SparseError::NotImplemented(
220            "Conversion to DIA array".to_string(),
221        ))
222    }
223
224    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
225        Err(SparseError::NotImplemented(
226            "Conversion to BSR array".to_string(),
227        ))
228    }
229
230    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
231        if self.shape() != other.shape() {
232            return Err(SparseError::DimensionMismatch {
233                expected: self.shape().0,
234                found: other.shape().0,
235            });
236        }
237
238        let mut result = self.clone();
239        let other_array = other.to_array();
240
241        // Add existing values from self
242        for (&(row, col), &value) in &self.data {
243            result.set(row, col, value + other_array[[row, col]])?;
244        }
245
246        // Add values from other that aren't in self
247        for ((row, col), &value) in other_array.indexed_iter() {
248            if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
249                result.set(row, col, value)?;
250            }
251        }
252
253        Ok(Box::new(result))
254    }
255
256    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
257        if self.shape() != other.shape() {
258            return Err(SparseError::DimensionMismatch {
259                expected: self.shape().0,
260                found: other.shape().0,
261            });
262        }
263
264        let mut result = self.clone();
265        let other_array = other.to_array();
266
267        // Subtract existing values from self
268        for (&(row, col), &value) in &self.data {
269            result.set(row, col, value - other_array[[row, col]])?;
270        }
271
272        // Subtract values from other that aren't in self
273        for ((row, col), &value) in other_array.indexed_iter() {
274            if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
275                result.set(row, col, -value)?;
276            }
277        }
278
279        Ok(Box::new(result))
280    }
281
282    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
283        if self.shape() != other.shape() {
284            return Err(SparseError::DimensionMismatch {
285                expected: self.shape().0,
286                found: other.shape().0,
287            });
288        }
289
290        let mut result = DokArray::new(self.shape());
291        let other_array = other.to_array();
292
293        // Only need to process entries in self
294        // since a*0 = 0 for any a
295        for (&(row, col), &value) in &self.data {
296            let product = value * other_array[[row, col]];
297            if !SparseElement::is_zero(&product) {
298                result.set(row, col, product)?;
299            }
300        }
301
302        Ok(Box::new(result))
303    }
304
305    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
306        if self.shape() != other.shape() {
307            return Err(SparseError::DimensionMismatch {
308                expected: self.shape().0,
309                found: other.shape().0,
310            });
311        }
312
313        let mut result = DokArray::new(self.shape());
314        let other_array = other.to_array();
315
316        for (&(row, col), &value) in &self.data {
317            let divisor = other_array[[row, col]];
318            if SparseElement::is_zero(&divisor) {
319                return Err(SparseError::ComputationError(
320                    "Division by zero".to_string(),
321                ));
322            }
323
324            let quotient = value / divisor;
325            if !SparseElement::is_zero(&quotient) {
326                result.set(row, col, quotient)?;
327            }
328        }
329
330        Ok(Box::new(result))
331    }
332
333    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
334        let (_m, n) = self.shape();
335        let (p, q) = other.shape();
336
337        if n != p {
338            return Err(SparseError::DimensionMismatch {
339                expected: n,
340                found: p,
341            });
342        }
343
344        // Convert to CSR for efficient matrix multiplication
345        let csr_self = self.to_csr()?;
346        let csr_other = other.to_csr()?;
347
348        csr_self.dot(&*csr_other)
349    }
350
351    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
352        let (m, n) = self.shape();
353        if n != other.len() {
354            return Err(SparseError::DimensionMismatch {
355                expected: n,
356                found: other.len(),
357            });
358        }
359
360        let mut result = Array1::zeros(m);
361
362        for (&(row, col), &value) in &self.data {
363            result[row] = result[row] + value * other[col];
364        }
365
366        Ok(result)
367    }
368
369    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
370        let (rows, cols) = self.shape;
371        let mut result = DokArray::new((cols, rows));
372
373        for (&(row, col), &value) in &self.data {
374            result.set(col, row, value)?;
375        }
376
377        Ok(Box::new(result))
378    }
379
380    fn copy(&self) -> Box<dyn SparseArray<T>> {
381        Box::new(self.clone())
382    }
383
384    fn get(&self, i: usize, j: usize) -> T {
385        if i >= self.shape.0 || j >= self.shape.1 {
386            return T::sparse_zero();
387        }
388
389        *self.data.get(&(i, j)).unwrap_or(&T::sparse_zero())
390    }
391
392    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
393        if i >= self.shape.0 || j >= self.shape.1 {
394            return Err(SparseError::IndexOutOfBounds {
395                index: (i, j),
396                shape: self.shape,
397            });
398        }
399
400        if SparseElement::is_zero(&value) {
401            // Remove zero entries
402            self.data.remove(&(i, j));
403        } else {
404            // Set non-zero value
405            self.data.insert((i, j), value);
406        }
407
408        Ok(())
409    }
410
411    fn eliminate_zeros(&mut self) {
412        // DOK format already doesn't store zeros, but just in case
413        self.data
414            .retain(|_, &mut value| !SparseElement::is_zero(&value));
415    }
416
417    fn sort_indices(&mut self) {
418        // No-op for DOK format since it's a HashMap
419    }
420
421    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
422        // DOK doesn't have the concept of sorted indices
423        self.copy()
424    }
425
426    fn has_sorted_indices(&self) -> bool {
427        true // DOK format doesn't have the concept of sorted indices
428    }
429
430    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
431        match axis {
432            None => {
433                // Sum all elements
434                let mut sum = T::sparse_zero();
435                for &value in self.data.values() {
436                    sum = sum + value;
437                }
438                Ok(SparseSum::Scalar(sum))
439            }
440            Some(0) => {
441                // Sum along rows
442                let (_, cols) = self.shape();
443                let mut result = DokArray::new((1, cols));
444
445                for (&(_row, col), &value) in &self.data {
446                    let current = result.get(0, col);
447                    result.set(0, col, current + value)?;
448                }
449
450                Ok(SparseSum::SparseArray(Box::new(result)))
451            }
452            Some(1) => {
453                // Sum along columns
454                let (rows, _) = self.shape();
455                let mut result = DokArray::new((rows, 1));
456
457                for (&(row, col), &value) in &self.data {
458                    let current = result.get(row, 0);
459                    result.set(row, 0, current + value)?;
460                }
461
462                Ok(SparseSum::SparseArray(Box::new(result)))
463            }
464            _ => Err(SparseError::InvalidAxis),
465        }
466    }
467
468    fn max(&self) -> T {
469        if self.data.is_empty() {
470            return T::nan();
471        }
472
473        self.data
474            .values()
475            .fold(T::neg_infinity(), |acc, &x| acc.max(x))
476    }
477
478    fn min(&self) -> T {
479        if self.data.is_empty() {
480            return T::nan();
481        }
482
483        self.data
484            .values()
485            .fold(T::sparse_zero(), |acc, &x| acc.min(x))
486    }
487
488    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
489        self.to_triplets()
490    }
491
492    fn slice(
493        &self,
494        row_range: (usize, usize),
495        col_range: (usize, usize),
496    ) -> SparseResult<Box<dyn SparseArray<T>>> {
497        let (start_row, end_row) = row_range;
498        let (start_col, end_col) = col_range;
499        let (rows, cols) = self.shape;
500
501        if start_row >= rows
502            || end_row > rows
503            || start_col >= cols
504            || end_col > cols
505            || start_row >= end_row
506            || start_col >= end_col
507        {
508            return Err(SparseError::InvalidSliceRange);
509        }
510
511        let sliceshape = (end_row - start_row, end_col - start_col);
512        let mut result = DokArray::new(sliceshape);
513
514        for (&(row, col), &value) in &self.data {
515            if row >= start_row && row < end_row && col >= start_col && col < end_col {
516                result.set(row - start_row, col - start_col, value)?;
517            }
518        }
519
520        Ok(Box::new(result))
521    }
522
523    fn as_any(&self) -> &dyn Any {
524        self
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Non-zero entries of the shared 3x3 fixture `[[1, 0, 2], [0, 0, 3], [4, 5, 0]]`.
    const FIXTURE: [(usize, usize, f64); 5] = [
        (0, 0, 1.0),
        (0, 2, 2.0),
        (1, 2, 3.0),
        (2, 0, 4.0),
        (2, 1, 5.0),
    ];

    /// Builds an `f64` DOK array of `shape`, storing each `(row, col, value)` via `set`.
    fn dok_with(shape: (usize, usize), entries: &[(usize, usize, f64)]) -> DokArray<f64> {
        let mut array = DokArray::<f64>::new(shape);
        for &(row, col, value) in entries {
            array
                .set(row, col, value)
                .expect("Test: failed to set array element");
        }
        array
    }

    /// Splits `FIXTURE` into parallel (rows, cols, values) vectors.
    fn fixture_triplets() -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        let rows = FIXTURE.iter().map(|&(r, _, _)| r).collect();
        let cols = FIXTURE.iter().map(|&(_, c, _)| c).collect();
        let values = FIXTURE.iter().map(|&(_, _, v)| v).collect();
        (rows, cols, values)
    }

    /// Dense form of `FIXTURE`.
    fn fixture_dense() -> Array2<f64> {
        Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
            .expect("Test: failed to create array from shape vec")
    }

    /// Asserts that `array` stores exactly the entries of `FIXTURE`.
    fn assert_matches_fixture(array: &DokArray<f64>) {
        assert_eq!(array.nnz(), FIXTURE.len());
        for &(row, col, value) in &FIXTURE {
            assert_eq!(array.get(row, col), value);
        }
    }

    /// Asserts every element of `actual` against the row-major rows of `expected`.
    fn assert_dense(actual: &Array2<f64>, expected: &[&[f64]]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(actual[[i, j]], value);
            }
        }
    }

    #[test]
    fn test_dok_array_create_and_access() {
        let mut array = dok_with((3, 3), &FIXTURE);
        assert_matches_fixture(&array);
        assert_eq!(array.get(0, 1), 0.0); // implicit zero

        // Writing a zero removes the stored entry.
        array
            .set(0, 0, 0.0)
            .expect("Test: failed to set array element");
        assert_eq!(array.nnz(), 4);
        assert_eq!(array.get(0, 0), 0.0);

        // Out-of-bounds reads yield zero rather than panicking.
        assert_eq!(array.get(3, 0), 0.0);
        assert_eq!(array.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_array_from_triplets() {
        let (rows, cols, data) = fixture_triplets();
        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3))
            .expect("Test: failed to create DokArray from triplets");
        assert_matches_fixture(&array);
    }

    #[test]
    fn test_dok_array_to_array() {
        let (rows, cols, data) = fixture_triplets();
        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3))
            .expect("Test: failed to create DokArray from triplets");
        assert_eq!(array.to_array(), fixture_dense());
    }

    #[test]
    fn test_dok_array_from_array() {
        let dense = fixture_dense();
        let array = DokArray::from_array(&dense);
        assert_matches_fixture(&array);
    }

    #[test]
    fn test_dok_array_add() {
        let lhs = dok_with((2, 2), &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0)]);
        let rhs = dok_with((2, 2), &[(0, 0, 4.0), (1, 1, 5.0)]);

        let sum = lhs.add(&rhs).expect("Test: array addition failed");
        assert_dense(&sum.to_array(), &[&[5.0, 2.0], &[3.0, 5.0]]);
    }

    #[test]
    fn test_dok_array_mul() {
        let lhs = dok_with((2, 2), &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)]);
        let rhs = dok_with((2, 2), &[(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)]);

        // Element-wise multiplication
        let product = lhs.mul(&rhs).expect("Test: array multiplication failed");
        assert_dense(&product.to_array(), &[&[5.0, 12.0], &[21.0, 32.0]]);
    }

    #[test]
    fn test_dok_array_dot() {
        let lhs = dok_with((2, 2), &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)]);
        let rhs = dok_with((2, 2), &[(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)]);

        // [1 2] [5 6]   [1*5 + 2*7, 1*6 + 2*8]   [19, 22]
        // [3 4] [7 8] = [3*5 + 4*7, 3*6 + 4*8] = [43, 50]
        let product = lhs.dot(&rhs).expect("Test: array dot product failed");
        assert_dense(&product.to_array(), &[&[19.0, 22.0], &[43.0, 50.0]]);
    }

    #[test]
    fn test_dok_array_transpose() {
        let entries = [
            (0, 0, 1.0),
            (0, 1, 2.0),
            (0, 2, 3.0),
            (1, 0, 4.0),
            (1, 1, 5.0),
            (1, 2, 6.0),
        ];
        let array = dok_with((2, 3), &entries);

        let transposed = array.transpose().expect("Test: array transpose failed");
        assert_eq!(transposed.shape(), (3, 2));
        for &(row, col, value) in &entries {
            assert_eq!(transposed.get(col, row), value);
        }
    }

    #[test]
    fn test_dok_array_slice() {
        // 3x3 array holding 1..=9 in row-major order.
        let entries: Vec<(usize, usize, f64)> =
            (0..9).map(|k| (k / 3, k % 3, (k + 1) as f64)).collect();
        let array = dok_with((3, 3), &entries);

        let slice = array
            .slice((0, 2), (1, 3))
            .expect("Test: array slice failed");

        assert_eq!(slice.shape(), (2, 2));
        assert_eq!(slice.get(0, 0), 2.0);
        assert_eq!(slice.get(0, 1), 3.0);
        assert_eq!(slice.get(1, 0), 5.0);
        assert_eq!(slice.get(1, 1), 6.0);
    }

    #[test]
    fn test_dok_array_sum() {
        let array = dok_with(
            (2, 3),
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (1, 0, 4.0),
                (1, 1, 5.0),
                (1, 2, 6.0),
            ],
        );

        // Sum all elements
        match array.sum(None).expect("Test: array sum failed") {
            SparseSum::Scalar(total) => assert_eq!(total, 21.0),
            _ => panic!("Expected scalar sum"),
        }

        // Collapse axis 0: one total per column.
        match array.sum(Some(0)).expect("Test: array sum failed") {
            SparseSum::SparseArray(column_totals) => {
                assert_eq!(column_totals.shape(), (1, 3));
                assert_eq!(column_totals.get(0, 0), 5.0);
                assert_eq!(column_totals.get(0, 1), 7.0);
                assert_eq!(column_totals.get(0, 2), 9.0);
            }
            _ => panic!("Expected sparse array"),
        }

        // Collapse axis 1: one total per row.
        match array.sum(Some(1)).expect("Test: array sum failed") {
            SparseSum::SparseArray(row_totals) => {
                assert_eq!(row_totals.shape(), (2, 1));
                assert_eq!(row_totals.get(0, 0), 6.0);
                assert_eq!(row_totals.get(1, 0), 15.0);
            }
            _ => panic!("Expected sparse array"),
        }
    }
}