1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::error::{SparseError, SparseResult};
15use crate::lil_array::LilArray;
16use crate::sparray::{SparseArray, SparseSum};
17
/// Dictionary-of-keys (DOK) sparse array.
///
/// Entries are kept in a hash map keyed by `(row, col)`; positions that are
/// absent from the map read back as zero through `get`.
#[derive(Clone)]
pub struct DokArray<T>
where
    T: SparseElement + Div<Output = T> + 'static,
{
    /// Stored entries, keyed by their `(row, col)` position.
    data: HashMap<(usize, usize), T>,
    /// Logical dimensions of the array as `(rows, cols)`.
    shape: (usize, usize),
}
41
42impl<T> DokArray<T>
43where
44 T: SparseElement + Div<Output = T> + 'static,
45{
46 pub fn new(shape: (usize, usize)) -> Self {
54 Self {
55 data: HashMap::new(),
56 shape,
57 }
58 }
59
60 pub fn from_triplets(
74 rows: &[usize],
75 cols: &[usize],
76 data: &[T],
77 shape: (usize, usize),
78 ) -> SparseResult<Self> {
79 if rows.len() != cols.len() || rows.len() != data.len() {
80 return Err(SparseError::InconsistentData {
81 reason: "rows, cols, and data must have the same length".to_string(),
82 });
83 }
84
85 let mut dok = Self::new(shape);
86 for i in 0..rows.len() {
87 if rows[i] >= shape.0 || cols[i] >= shape.1 {
88 return Err(SparseError::IndexOutOfBounds {
89 index: (rows[i], cols[i]),
90 shape,
91 });
92 }
93 if !SparseElement::is_zero(&data[i]) {
95 dok.data.insert((rows[i], cols[i]), data[i]);
96 }
97 }
98
99 Ok(dok)
100 }
101
102 pub fn get_data(&self) -> &HashMap<(usize, usize), T> {
104 &self.data
105 }
106
107 pub fn to_triplets(&self) -> (Array1<usize>, Array1<usize>, Array1<T>)
109 where
110 T: Float + PartialOrd,
111 {
112 let nnz = self.nnz();
113 let mut row_indices = Vec::with_capacity(nnz);
114 let mut col_indices = Vec::with_capacity(nnz);
115 let mut values = Vec::with_capacity(nnz);
116
117 let mut entries: Vec<_> = self.data.iter().collect();
119 entries.sort_by_key(|(&(row, col), _)| (row, col));
120
121 for (&(row, col), &value) in entries {
122 row_indices.push(row);
123 col_indices.push(col);
124 values.push(value);
125 }
126
127 (
128 Array1::from_vec(row_indices),
129 Array1::from_vec(col_indices),
130 Array1::from_vec(values),
131 )
132 }
133
134 pub fn from_array(array: &Array2<T>) -> Self {
142 let shape = (array.shape()[0], array.shape()[1]);
143 let mut dok = Self::new(shape);
144
145 for ((i, j), &value) in array.indexed_iter() {
146 if !SparseElement::is_zero(&value) {
147 dok.data.insert((i, j), value);
148 }
149 }
150
151 dok
152 }
153}
154
155impl<T> SparseArray<T> for DokArray<T>
156where
157 T: SparseElement + Div<Output = T> + Float + PartialOrd + 'static,
158{
159 fn shape(&self) -> (usize, usize) {
160 self.shape
161 }
162
163 fn nnz(&self) -> usize {
164 self.data.len()
165 }
166
167 fn dtype(&self) -> &str {
168 "float" }
170
171 fn to_array(&self) -> Array2<T> {
172 let (rows, cols) = self.shape;
173 let mut result = Array2::zeros((rows, cols));
174
175 for (&(row, col), &value) in &self.data {
176 result[[row, col]] = value;
177 }
178
179 result
180 }
181
182 fn toarray(&self) -> Array2<T> {
183 self.to_array()
184 }
185
186 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
187 let (row_indices, col_indices, data) = self.to_triplets();
188 CooArray::new(data, row_indices, col_indices, self.shape, true)
189 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
190 }
191
192 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
193 match self.to_coo() {
195 Ok(coo) => coo.to_csr(),
196 Err(e) => Err(e),
197 }
198 }
199
200 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
201 match self.to_coo() {
203 Ok(coo) => coo.to_csc(),
204 Err(e) => Err(e),
205 }
206 }
207
208 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
209 Ok(Box::new(self.clone()))
211 }
212
213 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
214 let (rows_arr, cols_arr, vals_arr) = self.to_triplets();
215 let rows_slice = rows_arr
216 .as_slice()
217 .ok_or_else(|| SparseError::ValueError("non-contiguous row indices".to_string()))?;
218 let cols_slice = cols_arr
219 .as_slice()
220 .ok_or_else(|| SparseError::ValueError("non-contiguous col indices".to_string()))?;
221 let vals_slice = vals_arr
222 .as_slice()
223 .ok_or_else(|| SparseError::ValueError("non-contiguous values".to_string()))?;
224 let lil = LilArray::from_triplets(rows_slice, cols_slice, vals_slice, self.shape)?;
225 Ok(Box::new(lil))
226 }
227
228 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
229 self.to_csr()?.to_dia()
230 }
231
232 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
233 self.to_csr()?.to_bsr()
234 }
235
236 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
237 if self.shape() != other.shape() {
238 return Err(SparseError::DimensionMismatch {
239 expected: self.shape().0,
240 found: other.shape().0,
241 });
242 }
243
244 let mut result = self.clone();
245 let other_array = other.to_array();
246
247 for (&(row, col), &value) in &self.data {
249 result.set(row, col, value + other_array[[row, col]])?;
250 }
251
252 for ((row, col), &value) in other_array.indexed_iter() {
254 if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
255 result.set(row, col, value)?;
256 }
257 }
258
259 Ok(Box::new(result))
260 }
261
262 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
263 if self.shape() != other.shape() {
264 return Err(SparseError::DimensionMismatch {
265 expected: self.shape().0,
266 found: other.shape().0,
267 });
268 }
269
270 let mut result = self.clone();
271 let other_array = other.to_array();
272
273 for (&(row, col), &value) in &self.data {
275 result.set(row, col, value - other_array[[row, col]])?;
276 }
277
278 for ((row, col), &value) in other_array.indexed_iter() {
280 if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
281 result.set(row, col, -value)?;
282 }
283 }
284
285 Ok(Box::new(result))
286 }
287
288 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
289 if self.shape() != other.shape() {
290 return Err(SparseError::DimensionMismatch {
291 expected: self.shape().0,
292 found: other.shape().0,
293 });
294 }
295
296 let mut result = DokArray::new(self.shape());
297 let other_array = other.to_array();
298
299 for (&(row, col), &value) in &self.data {
302 let product = value * other_array[[row, col]];
303 if !SparseElement::is_zero(&product) {
304 result.set(row, col, product)?;
305 }
306 }
307
308 Ok(Box::new(result))
309 }
310
311 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
312 if self.shape() != other.shape() {
313 return Err(SparseError::DimensionMismatch {
314 expected: self.shape().0,
315 found: other.shape().0,
316 });
317 }
318
319 let mut result = DokArray::new(self.shape());
320 let other_array = other.to_array();
321
322 for (&(row, col), &value) in &self.data {
323 let divisor = other_array[[row, col]];
324 if SparseElement::is_zero(&divisor) {
325 return Err(SparseError::ComputationError(
326 "Division by zero".to_string(),
327 ));
328 }
329
330 let quotient = value / divisor;
331 if !SparseElement::is_zero("ient) {
332 result.set(row, col, quotient)?;
333 }
334 }
335
336 Ok(Box::new(result))
337 }
338
339 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
340 let (_m, n) = self.shape();
341 let (p, q) = other.shape();
342
343 if n != p {
344 return Err(SparseError::DimensionMismatch {
345 expected: n,
346 found: p,
347 });
348 }
349
350 let csr_self = self.to_csr()?;
352 let csr_other = other.to_csr()?;
353
354 csr_self.dot(&*csr_other)
355 }
356
357 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
358 let (m, n) = self.shape();
359 if n != other.len() {
360 return Err(SparseError::DimensionMismatch {
361 expected: n,
362 found: other.len(),
363 });
364 }
365
366 let mut result = Array1::zeros(m);
367
368 for (&(row, col), &value) in &self.data {
369 result[row] = result[row] + value * other[col];
370 }
371
372 Ok(result)
373 }
374
375 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
376 let (rows, cols) = self.shape;
377 let mut result = DokArray::new((cols, rows));
378
379 for (&(row, col), &value) in &self.data {
380 result.set(col, row, value)?;
381 }
382
383 Ok(Box::new(result))
384 }
385
386 fn copy(&self) -> Box<dyn SparseArray<T>> {
387 Box::new(self.clone())
388 }
389
390 fn get(&self, i: usize, j: usize) -> T {
391 if i >= self.shape.0 || j >= self.shape.1 {
392 return T::sparse_zero();
393 }
394
395 *self.data.get(&(i, j)).unwrap_or(&T::sparse_zero())
396 }
397
398 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
399 if i >= self.shape.0 || j >= self.shape.1 {
400 return Err(SparseError::IndexOutOfBounds {
401 index: (i, j),
402 shape: self.shape,
403 });
404 }
405
406 if SparseElement::is_zero(&value) {
407 self.data.remove(&(i, j));
409 } else {
410 self.data.insert((i, j), value);
412 }
413
414 Ok(())
415 }
416
417 fn eliminate_zeros(&mut self) {
418 self.data
420 .retain(|_, &mut value| !SparseElement::is_zero(&value));
421 }
422
    /// No-op: this implementation does nothing when asked to sort indices.
    fn sort_indices(&mut self) {
        // Intentionally empty.
    }
426
427 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
428 self.copy()
430 }
431
    /// Always reports `true`; no state is inspected.
    fn has_sorted_indices(&self) -> bool {
        true
    }
435
436 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
437 match axis {
438 None => {
439 let mut sum = T::sparse_zero();
441 for &value in self.data.values() {
442 sum = sum + value;
443 }
444 Ok(SparseSum::Scalar(sum))
445 }
446 Some(0) => {
447 let (_, cols) = self.shape();
449 let mut result = DokArray::new((1, cols));
450
451 for (&(_row, col), &value) in &self.data {
452 let current = result.get(0, col);
453 result.set(0, col, current + value)?;
454 }
455
456 Ok(SparseSum::SparseArray(Box::new(result)))
457 }
458 Some(1) => {
459 let (rows, _) = self.shape();
461 let mut result = DokArray::new((rows, 1));
462
463 for (&(row, col), &value) in &self.data {
464 let current = result.get(row, 0);
465 result.set(row, 0, current + value)?;
466 }
467
468 Ok(SparseSum::SparseArray(Box::new(result)))
469 }
470 _ => Err(SparseError::InvalidAxis),
471 }
472 }
473
474 fn max(&self) -> T {
475 if self.data.is_empty() {
476 return T::nan();
477 }
478
479 self.data
480 .values()
481 .fold(T::neg_infinity(), |acc, &x| acc.max(x))
482 }
483
484 fn min(&self) -> T {
485 if self.data.is_empty() {
486 return T::nan();
487 }
488
489 self.data
490 .values()
491 .fold(T::sparse_zero(), |acc, &x| acc.min(x))
492 }
493
494 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
495 self.to_triplets()
496 }
497
498 fn slice(
499 &self,
500 row_range: (usize, usize),
501 col_range: (usize, usize),
502 ) -> SparseResult<Box<dyn SparseArray<T>>> {
503 let (start_row, end_row) = row_range;
504 let (start_col, end_col) = col_range;
505 let (rows, cols) = self.shape;
506
507 if start_row >= rows
508 || end_row > rows
509 || start_col >= cols
510 || end_col > cols
511 || start_row >= end_row
512 || start_col >= end_col
513 {
514 return Err(SparseError::InvalidSliceRange);
515 }
516
517 let sliceshape = (end_row - start_row, end_col - start_col);
518 let mut result = DokArray::new(sliceshape);
519
520 for (&(row, col), &value) in &self.data {
521 if row >= start_row && row < end_row && col >= start_col && col < end_col {
522 result.set(row - start_row, col - start_col, value)?;
523 }
524 }
525
526 Ok(Box::new(result))
527 }
528
    /// Exposes this array as `&dyn Any`, which lets callers attempt a downcast
    /// to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
532}
533
534#[cfg(test)]
535mod tests {
536 use super::*;
537 use scirs2_core::ndarray::Array;
538
539 #[test]
540 fn test_dok_array_create_and_access() {
541 let mut array = DokArray::<f64>::new((3, 3));
543
544 array
546 .set(0, 0, 1.0)
547 .expect("Test: failed to set array element");
548 array
549 .set(0, 2, 2.0)
550 .expect("Test: failed to set array element");
551 array
552 .set(1, 2, 3.0)
553 .expect("Test: failed to set array element");
554 array
555 .set(2, 0, 4.0)
556 .expect("Test: failed to set array element");
557 array
558 .set(2, 1, 5.0)
559 .expect("Test: failed to set array element");
560
561 assert_eq!(array.nnz(), 5);
562
563 assert_eq!(array.get(0, 0), 1.0);
565 assert_eq!(array.get(0, 1), 0.0); assert_eq!(array.get(0, 2), 2.0);
567 assert_eq!(array.get(1, 2), 3.0);
568 assert_eq!(array.get(2, 0), 4.0);
569 assert_eq!(array.get(2, 1), 5.0);
570
571 array
573 .set(0, 0, 0.0)
574 .expect("Test: failed to set array element");
575 assert_eq!(array.nnz(), 4);
576 assert_eq!(array.get(0, 0), 0.0);
577
578 assert_eq!(array.get(3, 0), 0.0);
580 assert_eq!(array.get(0, 3), 0.0);
581 }
582
583 #[test]
584 fn test_dok_array_from_triplets() {
585 let rows = vec![0, 0, 1, 2, 2];
586 let cols = vec![0, 2, 2, 0, 1];
587 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
588
589 let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3))
590 .expect("Test: failed to create DokArray from triplets");
591
592 assert_eq!(array.nnz(), 5);
593 assert_eq!(array.get(0, 0), 1.0);
594 assert_eq!(array.get(0, 2), 2.0);
595 assert_eq!(array.get(1, 2), 3.0);
596 assert_eq!(array.get(2, 0), 4.0);
597 assert_eq!(array.get(2, 1), 5.0);
598 }
599
600 #[test]
601 fn test_dok_array_to_array() {
602 let rows = vec![0, 0, 1, 2, 2];
603 let cols = vec![0, 2, 2, 0, 1];
604 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
605
606 let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3))
607 .expect("Test: failed to create DokArray from triplets");
608 let dense = array.to_array();
609
610 let expected =
611 Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
612 .expect("Test: failed to create array from shape vec");
613
614 assert_eq!(dense, expected);
615 }
616
617 #[test]
618 fn test_dok_array_from_array() {
619 let dense =
620 Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
621 .expect("Test: failed to create array from shape vec");
622
623 let array = DokArray::from_array(&dense);
624
625 assert_eq!(array.nnz(), 5);
626 assert_eq!(array.get(0, 0), 1.0);
627 assert_eq!(array.get(0, 2), 2.0);
628 assert_eq!(array.get(1, 2), 3.0);
629 assert_eq!(array.get(2, 0), 4.0);
630 assert_eq!(array.get(2, 1), 5.0);
631 }
632
633 #[test]
634 fn test_dok_array_add() {
635 let mut array1 = DokArray::<f64>::new((2, 2));
636 array1
637 .set(0, 0, 1.0)
638 .expect("Test: failed to set array element");
639 array1
640 .set(0, 1, 2.0)
641 .expect("Test: failed to set array element");
642 array1
643 .set(1, 0, 3.0)
644 .expect("Test: failed to set array element");
645
646 let mut array2 = DokArray::<f64>::new((2, 2));
647 array2
648 .set(0, 0, 4.0)
649 .expect("Test: failed to set array element");
650 array2
651 .set(1, 1, 5.0)
652 .expect("Test: failed to set array element");
653
654 let result = array1.add(&array2).expect("Test: array addition failed");
655 let dense_result = result.to_array();
656
657 assert_eq!(dense_result[[0, 0]], 5.0);
658 assert_eq!(dense_result[[0, 1]], 2.0);
659 assert_eq!(dense_result[[1, 0]], 3.0);
660 assert_eq!(dense_result[[1, 1]], 5.0);
661 }
662
663 #[test]
664 fn test_dok_array_mul() {
665 let mut array1 = DokArray::<f64>::new((2, 2));
666 array1
667 .set(0, 0, 1.0)
668 .expect("Test: failed to set array element");
669 array1
670 .set(0, 1, 2.0)
671 .expect("Test: failed to set array element");
672 array1
673 .set(1, 0, 3.0)
674 .expect("Test: failed to set array element");
675 array1
676 .set(1, 1, 4.0)
677 .expect("Test: failed to set array element");
678
679 let mut array2 = DokArray::<f64>::new((2, 2));
680 array2
681 .set(0, 0, 5.0)
682 .expect("Test: failed to set array element");
683 array2
684 .set(0, 1, 6.0)
685 .expect("Test: failed to set array element");
686 array2
687 .set(1, 0, 7.0)
688 .expect("Test: failed to set array element");
689 array2
690 .set(1, 1, 8.0)
691 .expect("Test: failed to set array element");
692
693 let result = array1
695 .mul(&array2)
696 .expect("Test: array multiplication failed");
697 let dense_result = result.to_array();
698
699 assert_eq!(dense_result[[0, 0]], 5.0);
700 assert_eq!(dense_result[[0, 1]], 12.0);
701 assert_eq!(dense_result[[1, 0]], 21.0);
702 assert_eq!(dense_result[[1, 1]], 32.0);
703 }
704
705 #[test]
706 fn test_dok_array_dot() {
707 let mut array1 = DokArray::<f64>::new((2, 2));
708 array1
709 .set(0, 0, 1.0)
710 .expect("Test: failed to set array element");
711 array1
712 .set(0, 1, 2.0)
713 .expect("Test: failed to set array element");
714 array1
715 .set(1, 0, 3.0)
716 .expect("Test: failed to set array element");
717 array1
718 .set(1, 1, 4.0)
719 .expect("Test: failed to set array element");
720
721 let mut array2 = DokArray::<f64>::new((2, 2));
722 array2
723 .set(0, 0, 5.0)
724 .expect("Test: failed to set array element");
725 array2
726 .set(0, 1, 6.0)
727 .expect("Test: failed to set array element");
728 array2
729 .set(1, 0, 7.0)
730 .expect("Test: failed to set array element");
731 array2
732 .set(1, 1, 8.0)
733 .expect("Test: failed to set array element");
734
735 let result = array1.dot(&array2).expect("Test: array dot product failed");
737 let dense_result = result.to_array();
738
739 assert_eq!(dense_result[[0, 0]], 19.0);
742 assert_eq!(dense_result[[0, 1]], 22.0);
743 assert_eq!(dense_result[[1, 0]], 43.0);
744 assert_eq!(dense_result[[1, 1]], 50.0);
745 }
746
747 #[test]
748 fn test_dok_array_transpose() {
749 let mut array = DokArray::<f64>::new((2, 3));
750 array
751 .set(0, 0, 1.0)
752 .expect("Test: failed to set array element");
753 array
754 .set(0, 1, 2.0)
755 .expect("Test: failed to set array element");
756 array
757 .set(0, 2, 3.0)
758 .expect("Test: failed to set array element");
759 array
760 .set(1, 0, 4.0)
761 .expect("Test: failed to set array element");
762 array
763 .set(1, 1, 5.0)
764 .expect("Test: failed to set array element");
765 array
766 .set(1, 2, 6.0)
767 .expect("Test: failed to set array element");
768
769 let transposed = array.transpose().expect("Test: array transpose failed");
770
771 assert_eq!(transposed.shape(), (3, 2));
772 assert_eq!(transposed.get(0, 0), 1.0);
773 assert_eq!(transposed.get(1, 0), 2.0);
774 assert_eq!(transposed.get(2, 0), 3.0);
775 assert_eq!(transposed.get(0, 1), 4.0);
776 assert_eq!(transposed.get(1, 1), 5.0);
777 assert_eq!(transposed.get(2, 1), 6.0);
778 }
779
780 #[test]
781 fn test_dok_array_slice() {
782 let mut array = DokArray::<f64>::new((3, 3));
783 array
784 .set(0, 0, 1.0)
785 .expect("Test: failed to set array element");
786 array
787 .set(0, 1, 2.0)
788 .expect("Test: failed to set array element");
789 array
790 .set(0, 2, 3.0)
791 .expect("Test: failed to set array element");
792 array
793 .set(1, 0, 4.0)
794 .expect("Test: failed to set array element");
795 array
796 .set(1, 1, 5.0)
797 .expect("Test: failed to set array element");
798 array
799 .set(1, 2, 6.0)
800 .expect("Test: failed to set array element");
801 array
802 .set(2, 0, 7.0)
803 .expect("Test: failed to set array element");
804 array
805 .set(2, 1, 8.0)
806 .expect("Test: failed to set array element");
807 array
808 .set(2, 2, 9.0)
809 .expect("Test: failed to set array element");
810
811 let slice = array
812 .slice((0, 2), (1, 3))
813 .expect("Test: array slice failed");
814
815 assert_eq!(slice.shape(), (2, 2));
816 assert_eq!(slice.get(0, 0), 2.0);
817 assert_eq!(slice.get(0, 1), 3.0);
818 assert_eq!(slice.get(1, 0), 5.0);
819 assert_eq!(slice.get(1, 1), 6.0);
820 }
821
822 #[test]
823 fn test_dok_array_sum() {
824 let mut array = DokArray::<f64>::new((2, 3));
825 array
826 .set(0, 0, 1.0)
827 .expect("Test: failed to set array element");
828 array
829 .set(0, 1, 2.0)
830 .expect("Test: failed to set array element");
831 array
832 .set(0, 2, 3.0)
833 .expect("Test: failed to set array element");
834 array
835 .set(1, 0, 4.0)
836 .expect("Test: failed to set array element");
837 array
838 .set(1, 1, 5.0)
839 .expect("Test: failed to set array element");
840 array
841 .set(1, 2, 6.0)
842 .expect("Test: failed to set array element");
843
844 match array.sum(None).expect("Test: array sum failed") {
846 SparseSum::Scalar(sum) => assert_eq!(sum, 21.0),
847 _ => panic!("Expected scalar sum"),
848 }
849
850 match array.sum(Some(0)).expect("Test: array sum failed") {
852 SparseSum::SparseArray(sum_array) => {
853 assert_eq!(sum_array.shape(), (1, 3));
854 assert_eq!(sum_array.get(0, 0), 5.0);
855 assert_eq!(sum_array.get(0, 1), 7.0);
856 assert_eq!(sum_array.get(0, 2), 9.0);
857 }
858 _ => panic!("Expected sparse array"),
859 }
860
861 match array.sum(Some(1)).expect("Test: array sum failed") {
863 SparseSum::SparseArray(sum_array) => {
864 assert_eq!(sum_array.shape(), (2, 1));
865 assert_eq!(sum_array.get(0, 0), 6.0);
866 assert_eq!(sum_array.get(1, 0), 15.0);
867 }
868 _ => panic!("Expected sparse array"),
869 }
870 }
871}