1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::error::{SparseError, SparseResult};
15use crate::sparray::{SparseArray, SparseSum};
16
17#[derive(Clone)]
31pub struct DokArray<T>
32where
33 T: SparseElement + Div<Output = T> + 'static,
34{
35 data: HashMap<(usize, usize), T>,
37 shape: (usize, usize),
39}
40
41impl<T> DokArray<T>
42where
43 T: SparseElement + Div<Output = T> + 'static,
44{
45 pub fn new(shape: (usize, usize)) -> Self {
53 Self {
54 data: HashMap::new(),
55 shape,
56 }
57 }
58
59 pub fn from_triplets(
73 rows: &[usize],
74 cols: &[usize],
75 data: &[T],
76 shape: (usize, usize),
77 ) -> SparseResult<Self> {
78 if rows.len() != cols.len() || rows.len() != data.len() {
79 return Err(SparseError::InconsistentData {
80 reason: "rows, cols, and data must have the same length".to_string(),
81 });
82 }
83
84 let mut dok = Self::new(shape);
85 for i in 0..rows.len() {
86 if rows[i] >= shape.0 || cols[i] >= shape.1 {
87 return Err(SparseError::IndexOutOfBounds {
88 index: (rows[i], cols[i]),
89 shape,
90 });
91 }
92 if !SparseElement::is_zero(&data[i]) {
94 dok.data.insert((rows[i], cols[i]), data[i]);
95 }
96 }
97
98 Ok(dok)
99 }
100
101 pub fn get_data(&self) -> &HashMap<(usize, usize), T> {
103 &self.data
104 }
105
106 pub fn to_triplets(&self) -> (Array1<usize>, Array1<usize>, Array1<T>)
108 where
109 T: Float + PartialOrd,
110 {
111 let nnz = self.nnz();
112 let mut row_indices = Vec::with_capacity(nnz);
113 let mut col_indices = Vec::with_capacity(nnz);
114 let mut values = Vec::with_capacity(nnz);
115
116 let mut entries: Vec<_> = self.data.iter().collect();
118 entries.sort_by_key(|(&(row, col), _)| (row, col));
119
120 for (&(row, col), &value) in entries {
121 row_indices.push(row);
122 col_indices.push(col);
123 values.push(value);
124 }
125
126 (
127 Array1::from_vec(row_indices),
128 Array1::from_vec(col_indices),
129 Array1::from_vec(values),
130 )
131 }
132
133 pub fn from_array(array: &Array2<T>) -> Self {
141 let shape = (array.shape()[0], array.shape()[1]);
142 let mut dok = Self::new(shape);
143
144 for ((i, j), &value) in array.indexed_iter() {
145 if !SparseElement::is_zero(&value) {
146 dok.data.insert((i, j), value);
147 }
148 }
149
150 dok
151 }
152}
153
154impl<T> SparseArray<T> for DokArray<T>
155where
156 T: SparseElement + Div<Output = T> + Float + PartialOrd + 'static,
157{
158 fn shape(&self) -> (usize, usize) {
159 self.shape
160 }
161
162 fn nnz(&self) -> usize {
163 self.data.len()
164 }
165
166 fn dtype(&self) -> &str {
167 "float" }
169
170 fn to_array(&self) -> Array2<T> {
171 let (rows, cols) = self.shape;
172 let mut result = Array2::zeros((rows, cols));
173
174 for (&(row, col), &value) in &self.data {
175 result[[row, col]] = value;
176 }
177
178 result
179 }
180
181 fn toarray(&self) -> Array2<T> {
182 self.to_array()
183 }
184
185 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
186 let (row_indices, col_indices, data) = self.to_triplets();
187 CooArray::new(data, row_indices, col_indices, self.shape, true)
188 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
189 }
190
191 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
192 match self.to_coo() {
194 Ok(coo) => coo.to_csr(),
195 Err(e) => Err(e),
196 }
197 }
198
199 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
200 match self.to_coo() {
202 Ok(coo) => coo.to_csc(),
203 Err(e) => Err(e),
204 }
205 }
206
207 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
208 Ok(Box::new(self.clone()))
210 }
211
212 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
213 Err(SparseError::NotImplemented(
214 "Conversion to LIL array".to_string(),
215 ))
216 }
217
218 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
219 Err(SparseError::NotImplemented(
220 "Conversion to DIA array".to_string(),
221 ))
222 }
223
224 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
225 Err(SparseError::NotImplemented(
226 "Conversion to BSR array".to_string(),
227 ))
228 }
229
230 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
231 if self.shape() != other.shape() {
232 return Err(SparseError::DimensionMismatch {
233 expected: self.shape().0,
234 found: other.shape().0,
235 });
236 }
237
238 let mut result = self.clone();
239 let other_array = other.to_array();
240
241 for (&(row, col), &value) in &self.data {
243 result.set(row, col, value + other_array[[row, col]])?;
244 }
245
246 for ((row, col), &value) in other_array.indexed_iter() {
248 if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
249 result.set(row, col, value)?;
250 }
251 }
252
253 Ok(Box::new(result))
254 }
255
256 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
257 if self.shape() != other.shape() {
258 return Err(SparseError::DimensionMismatch {
259 expected: self.shape().0,
260 found: other.shape().0,
261 });
262 }
263
264 let mut result = self.clone();
265 let other_array = other.to_array();
266
267 for (&(row, col), &value) in &self.data {
269 result.set(row, col, value - other_array[[row, col]])?;
270 }
271
272 for ((row, col), &value) in other_array.indexed_iter() {
274 if !self.data.contains_key(&(row, col)) && !SparseElement::is_zero(&value) {
275 result.set(row, col, -value)?;
276 }
277 }
278
279 Ok(Box::new(result))
280 }
281
282 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
283 if self.shape() != other.shape() {
284 return Err(SparseError::DimensionMismatch {
285 expected: self.shape().0,
286 found: other.shape().0,
287 });
288 }
289
290 let mut result = DokArray::new(self.shape());
291 let other_array = other.to_array();
292
293 for (&(row, col), &value) in &self.data {
296 let product = value * other_array[[row, col]];
297 if !SparseElement::is_zero(&product) {
298 result.set(row, col, product)?;
299 }
300 }
301
302 Ok(Box::new(result))
303 }
304
305 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
306 if self.shape() != other.shape() {
307 return Err(SparseError::DimensionMismatch {
308 expected: self.shape().0,
309 found: other.shape().0,
310 });
311 }
312
313 let mut result = DokArray::new(self.shape());
314 let other_array = other.to_array();
315
316 for (&(row, col), &value) in &self.data {
317 let divisor = other_array[[row, col]];
318 if SparseElement::is_zero(&divisor) {
319 return Err(SparseError::ComputationError(
320 "Division by zero".to_string(),
321 ));
322 }
323
324 let quotient = value / divisor;
325 if !SparseElement::is_zero("ient) {
326 result.set(row, col, quotient)?;
327 }
328 }
329
330 Ok(Box::new(result))
331 }
332
333 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
334 let (_m, n) = self.shape();
335 let (p, q) = other.shape();
336
337 if n != p {
338 return Err(SparseError::DimensionMismatch {
339 expected: n,
340 found: p,
341 });
342 }
343
344 let csr_self = self.to_csr()?;
346 let csr_other = other.to_csr()?;
347
348 csr_self.dot(&*csr_other)
349 }
350
351 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
352 let (m, n) = self.shape();
353 if n != other.len() {
354 return Err(SparseError::DimensionMismatch {
355 expected: n,
356 found: other.len(),
357 });
358 }
359
360 let mut result = Array1::zeros(m);
361
362 for (&(row, col), &value) in &self.data {
363 result[row] = result[row] + value * other[col];
364 }
365
366 Ok(result)
367 }
368
369 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
370 let (rows, cols) = self.shape;
371 let mut result = DokArray::new((cols, rows));
372
373 for (&(row, col), &value) in &self.data {
374 result.set(col, row, value)?;
375 }
376
377 Ok(Box::new(result))
378 }
379
380 fn copy(&self) -> Box<dyn SparseArray<T>> {
381 Box::new(self.clone())
382 }
383
384 fn get(&self, i: usize, j: usize) -> T {
385 if i >= self.shape.0 || j >= self.shape.1 {
386 return T::sparse_zero();
387 }
388
389 *self.data.get(&(i, j)).unwrap_or(&T::sparse_zero())
390 }
391
392 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
393 if i >= self.shape.0 || j >= self.shape.1 {
394 return Err(SparseError::IndexOutOfBounds {
395 index: (i, j),
396 shape: self.shape,
397 });
398 }
399
400 if SparseElement::is_zero(&value) {
401 self.data.remove(&(i, j));
403 } else {
404 self.data.insert((i, j), value);
406 }
407
408 Ok(())
409 }
410
411 fn eliminate_zeros(&mut self) {
412 self.data
414 .retain(|_, &mut value| !SparseElement::is_zero(&value));
415 }
416
417 fn sort_indices(&mut self) {
418 }
420
421 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
422 self.copy()
424 }
425
426 fn has_sorted_indices(&self) -> bool {
427 true }
429
430 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
431 match axis {
432 None => {
433 let mut sum = T::sparse_zero();
435 for &value in self.data.values() {
436 sum = sum + value;
437 }
438 Ok(SparseSum::Scalar(sum))
439 }
440 Some(0) => {
441 let (_, cols) = self.shape();
443 let mut result = DokArray::new((1, cols));
444
445 for (&(_row, col), &value) in &self.data {
446 let current = result.get(0, col);
447 result.set(0, col, current + value)?;
448 }
449
450 Ok(SparseSum::SparseArray(Box::new(result)))
451 }
452 Some(1) => {
453 let (rows, _) = self.shape();
455 let mut result = DokArray::new((rows, 1));
456
457 for (&(row, col), &value) in &self.data {
458 let current = result.get(row, 0);
459 result.set(row, 0, current + value)?;
460 }
461
462 Ok(SparseSum::SparseArray(Box::new(result)))
463 }
464 _ => Err(SparseError::InvalidAxis),
465 }
466 }
467
468 fn max(&self) -> T {
469 if self.data.is_empty() {
470 return T::nan();
471 }
472
473 self.data
474 .values()
475 .fold(T::neg_infinity(), |acc, &x| acc.max(x))
476 }
477
478 fn min(&self) -> T {
479 if self.data.is_empty() {
480 return T::nan();
481 }
482
483 self.data
484 .values()
485 .fold(T::sparse_zero(), |acc, &x| acc.min(x))
486 }
487
488 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
489 self.to_triplets()
490 }
491
492 fn slice(
493 &self,
494 row_range: (usize, usize),
495 col_range: (usize, usize),
496 ) -> SparseResult<Box<dyn SparseArray<T>>> {
497 let (start_row, end_row) = row_range;
498 let (start_col, end_col) = col_range;
499 let (rows, cols) = self.shape;
500
501 if start_row >= rows
502 || end_row > rows
503 || start_col >= cols
504 || end_col > cols
505 || start_row >= end_row
506 || start_col >= end_col
507 {
508 return Err(SparseError::InvalidSliceRange);
509 }
510
511 let sliceshape = (end_row - start_row, end_col - start_col);
512 let mut result = DokArray::new(sliceshape);
513
514 for (&(row, col), &value) in &self.data {
515 if row >= start_row && row < end_row && col >= start_col && col < end_col {
516 result.set(row - start_row, col - start_col, value)?;
517 }
518 }
519
520 Ok(Box::new(result))
521 }
522
523 fn as_any(&self) -> &dyn Any {
524 self
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_dok_array_create_and_access() {
        let mut array = DokArray::<f64>::new((3, 3));

        array.set(0, 0, 1.0).unwrap();
        array.set(0, 2, 2.0).unwrap();
        array.set(1, 2, 3.0).unwrap();
        array.set(2, 0, 4.0).unwrap();
        array.set(2, 1, 5.0).unwrap();

        assert_eq!(array.nnz(), 5);

        assert_eq!(array.get(0, 0), 1.0);
        // Unstored positions read as zero.
        assert_eq!(array.get(0, 1), 0.0);
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);

        // Writing zero removes the stored entry.
        array.set(0, 0, 0.0).unwrap();
        assert_eq!(array.nnz(), 4);
        assert_eq!(array.get(0, 0), 0.0);

        // Out-of-bounds reads return zero rather than panicking.
        assert_eq!(array.get(3, 0), 0.0);
        assert_eq!(array.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_array_set_out_of_bounds() {
        let mut array = DokArray::<f64>::new((2, 2));
        assert!(array.set(2, 0, 1.0).is_err());
        assert!(array.set(0, 2, 1.0).is_err());
        assert_eq!(array.nnz(), 0);
    }

    #[test]
    fn test_dok_array_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3)).unwrap();

        assert_eq!(array.nnz(), 5);
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);
    }

    #[test]
    fn test_dok_array_from_triplets_skips_zeros() {
        let array = DokArray::from_triplets(&[0, 1], &[0, 1], &[0.0, 2.0], (2, 2)).unwrap();
        assert_eq!(array.nnz(), 1);
        assert_eq!(array.get(1, 1), 2.0);
    }

    #[test]
    fn test_dok_array_from_triplets_rejects_bad_input() {
        // Slices of differing lengths.
        assert!(DokArray::<f64>::from_triplets(&[0, 1], &[0], &[1.0, 2.0], (2, 2)).is_err());
        // Coordinate outside the declared shape.
        assert!(DokArray::<f64>::from_triplets(&[2], &[0], &[1.0], (2, 2)).is_err());
    }

    #[test]
    fn test_dok_array_to_triplets_sorted() {
        let mut array = DokArray::<f64>::new((2, 2));
        array.set(1, 1, 4.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(1, 0, 3.0).unwrap();

        let (rows, cols, values) = array.to_triplets();
        assert_eq!(rows.to_vec(), vec![0usize, 1, 1]);
        assert_eq!(cols.to_vec(), vec![1usize, 0, 1]);
        assert_eq!(values.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_dok_array_to_array() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let array = DokArray::from_triplets(&rows, &cols, &data, (3, 3)).unwrap();
        let dense = array.to_array();

        let expected =
            Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
                .unwrap();

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dok_array_from_array() {
        let dense =
            Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0])
                .unwrap();

        let array = DokArray::from_array(&dense);

        assert_eq!(array.nnz(), 5);
        assert_eq!(array.get(0, 0), 1.0);
        assert_eq!(array.get(0, 2), 2.0);
        assert_eq!(array.get(1, 2), 3.0);
        assert_eq!(array.get(2, 0), 4.0);
        assert_eq!(array.get(2, 1), 5.0);
    }

    #[test]
    fn test_dok_array_add() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 4.0).unwrap();
        array2.set(1, 1, 5.0).unwrap();

        let result = array1.add(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 5.0);
        assert_eq!(dense_result[[0, 1]], 2.0);
        assert_eq!(dense_result[[1, 0]], 3.0);
        assert_eq!(dense_result[[1, 1]], 5.0);
    }

    #[test]
    fn test_dok_array_sub() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 1.0).unwrap();
        array2.set(1, 1, 5.0).unwrap();

        let result = array1.sub(&array2).unwrap();
        let dense_result = result.to_array();

        // (0, 0) cancels to zero and is no longer stored.
        assert_eq!(result.nnz(), 2);
        assert_eq!(dense_result[[0, 0]], 0.0);
        assert_eq!(dense_result[[0, 1]], 2.0);
        assert_eq!(dense_result[[1, 0]], 0.0);
        assert_eq!(dense_result[[1, 1]], -5.0);
    }

    #[test]
    fn test_dok_array_mul() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();
        array1.set(1, 1, 4.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 5.0).unwrap();
        array2.set(0, 1, 6.0).unwrap();
        array2.set(1, 0, 7.0).unwrap();
        array2.set(1, 1, 8.0).unwrap();

        // Element-wise product.
        let result = array1.mul(&array2).unwrap();
        let dense_result = result.to_array();

        assert_eq!(dense_result[[0, 0]], 5.0);
        assert_eq!(dense_result[[0, 1]], 12.0);
        assert_eq!(dense_result[[1, 0]], 21.0);
        assert_eq!(dense_result[[1, 1]], 32.0);
    }

    #[test]
    fn test_dok_array_div() {
        let mut numerator = DokArray::<f64>::new((2, 2));
        numerator.set(0, 0, 6.0).unwrap();
        numerator.set(1, 1, 8.0).unwrap();

        let mut denominator = DokArray::<f64>::new((2, 2));
        denominator.set(0, 0, 2.0).unwrap();
        denominator.set(0, 1, 1.0).unwrap();
        denominator.set(1, 0, 1.0).unwrap();
        denominator.set(1, 1, 4.0).unwrap();

        let result = numerator.div(&denominator).unwrap();
        assert_eq!(result.nnz(), 2);
        assert_eq!(result.get(0, 0), 3.0);
        assert_eq!(result.get(1, 1), 2.0);

        // A zero divisor at a stored numerator position is an error.
        let mut sparse_denominator = DokArray::<f64>::new((2, 2));
        sparse_denominator.set(0, 0, 2.0).unwrap();
        assert!(numerator.div(&sparse_denominator).is_err());
    }

    #[test]
    fn test_dok_array_elementwise_shape_mismatch() {
        let a = DokArray::<f64>::new((2, 2));
        let b = DokArray::<f64>::new((2, 3));

        assert!(a.add(&b).is_err());
        assert!(a.sub(&b).is_err());
        assert!(a.mul(&b).is_err());
        assert!(a.div(&b).is_err());
    }

    #[test]
    fn test_dok_array_dot() {
        let mut array1 = DokArray::<f64>::new((2, 2));
        array1.set(0, 0, 1.0).unwrap();
        array1.set(0, 1, 2.0).unwrap();
        array1.set(1, 0, 3.0).unwrap();
        array1.set(1, 1, 4.0).unwrap();

        let mut array2 = DokArray::<f64>::new((2, 2));
        array2.set(0, 0, 5.0).unwrap();
        array2.set(0, 1, 6.0).unwrap();
        array2.set(1, 0, 7.0).unwrap();
        array2.set(1, 1, 8.0).unwrap();

        // Matrix product.
        let result = array1.dot(&array2).unwrap();
        let dense_result = result.to_array();

        // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        assert_eq!(dense_result[[0, 0]], 19.0);
        assert_eq!(dense_result[[0, 1]], 22.0);
        assert_eq!(dense_result[[1, 0]], 43.0);
        assert_eq!(dense_result[[1, 1]], 50.0);
    }

    #[test]
    fn test_dok_array_dot_vector() {
        let mut array = DokArray::<f64>::new((2, 2));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(1, 1, 3.0).unwrap();

        let vector = Array1::from_vec(vec![1.0, 1.0]);
        let result = array.dot_vector(&vector.view()).unwrap();
        assert_eq!(result.to_vec(), vec![3.0, 3.0]);

        let wrong_length = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        assert!(array.dot_vector(&wrong_length.view()).is_err());
    }

    #[test]
    fn test_dok_array_transpose() {
        let mut array = DokArray::<f64>::new((2, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();

        let transposed = array.transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 2));
        assert_eq!(transposed.get(0, 0), 1.0);
        assert_eq!(transposed.get(1, 0), 2.0);
        assert_eq!(transposed.get(2, 0), 3.0);
        assert_eq!(transposed.get(0, 1), 4.0);
        assert_eq!(transposed.get(1, 1), 5.0);
        assert_eq!(transposed.get(2, 1), 6.0);
    }

    #[test]
    fn test_dok_array_slice() {
        let mut array = DokArray::<f64>::new((3, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();
        array.set(2, 0, 7.0).unwrap();
        array.set(2, 1, 8.0).unwrap();
        array.set(2, 2, 9.0).unwrap();

        // Rows [0, 2), columns [1, 3).
        let slice = array.slice((0, 2), (1, 3)).unwrap();

        assert_eq!(slice.shape(), (2, 2));
        assert_eq!(slice.get(0, 0), 2.0);
        assert_eq!(slice.get(0, 1), 3.0);
        assert_eq!(slice.get(1, 0), 5.0);
        assert_eq!(slice.get(1, 1), 6.0);
    }

    #[test]
    fn test_dok_array_slice_invalid_range() {
        let array = DokArray::<f64>::new((3, 3));
        // Reversed row range.
        assert!(array.slice((2, 1), (0, 1)).is_err());
        // Row range past the end.
        assert!(array.slice((0, 4), (0, 1)).is_err());
    }

    #[test]
    fn test_dok_array_sum() {
        let mut array = DokArray::<f64>::new((2, 3));
        array.set(0, 0, 1.0).unwrap();
        array.set(0, 1, 2.0).unwrap();
        array.set(0, 2, 3.0).unwrap();
        array.set(1, 0, 4.0).unwrap();
        array.set(1, 1, 5.0).unwrap();
        array.set(1, 2, 6.0).unwrap();

        // Sum over all elements.
        match array.sum(None).unwrap() {
            SparseSum::Scalar(sum) => assert_eq!(sum, 21.0),
            _ => panic!("Expected scalar sum"),
        }

        // Column sums.
        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (1, 3));
                assert_eq!(sum_array.get(0, 0), 5.0);
                assert_eq!(sum_array.get(0, 1), 7.0);
                assert_eq!(sum_array.get(0, 2), 9.0);
            }
            _ => panic!("Expected sparse array"),
        }

        // Row sums.
        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (2, 1));
                assert_eq!(sum_array.get(0, 0), 6.0);
                assert_eq!(sum_array.get(1, 0), 15.0);
            }
            _ => panic!("Expected sparse array"),
        }
    }

    #[test]
    fn test_dok_array_min_max_mixed_signs() {
        let mut array = DokArray::<f64>::new((2, 2));
        array.set(0, 0, -2.0).unwrap();
        array.set(1, 1, 3.0).unwrap();

        assert_eq!(array.min(), -2.0);
        assert_eq!(array.max(), 3.0);
    }
}