1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::ops::{Add, Div, Mul, Sub};
12
13use crate::coo_array::CooArray;
14use crate::error::{SparseError, SparseResult};
15use crate::sparray::{SparseArray, SparseSum};
16
17#[derive(Clone)]
31pub struct DokArray<T>
32where
33 T: Float
34 + Add<Output = T>
35 + Sub<Output = T>
36 + Mul<Output = T>
37 + Div<Output = T>
38 + Debug
39 + Copy
40 + 'static,
41{
42 data: HashMap<(usize, usize), T>,
44 shape: (usize, usize),
46}
47
48impl<T> DokArray<T>
49where
50 T: Float
51 + Add<Output = T>
52 + Sub<Output = T>
53 + Mul<Output = T>
54 + Div<Output = T>
55 + Debug
56 + Copy
57 + 'static,
58{
59 pub fn new(shape: (usize, usize)) -> Self {
67 Self {
68 data: HashMap::new(),
69 shape,
70 }
71 }
72
73 pub fn from_triplets(
87 rows: &[usize],
88 cols: &[usize],
89 data: &[T],
90 shape: (usize, usize),
91 ) -> SparseResult<Self> {
92 if rows.len() != cols.len() || rows.len() != data.len() {
93 return Err(SparseError::InconsistentData {
94 reason: "rows, cols, and data must have the same length".to_string(),
95 });
96 }
97
98 let mut dok = Self::new(shape);
99 for i in 0..rows.len() {
100 if rows[i] >= shape.0 || cols[i] >= shape.1 {
101 return Err(SparseError::IndexOutOfBounds {
102 index: (rows[i], cols[i]),
103 shape,
104 });
105 }
106 if !data[i].is_zero() {
108 dok.data.insert((rows[i], cols[i]), data[i]);
109 }
110 }
111
112 Ok(dok)
113 }
114
115 pub fn get_data(&self) -> &HashMap<(usize, usize), T> {
117 &self.data
118 }
119
120 pub fn to_triplets(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
122 let nnz = self.nnz();
123 let mut row_indices = Vec::with_capacity(nnz);
124 let mut col_indices = Vec::with_capacity(nnz);
125 let mut values = Vec::with_capacity(nnz);
126
127 let mut entries: Vec<_> = self.data.iter().collect();
129 entries.sort_by_key(|&(&(row, col), _)| (row, col));
130
131 for (&(row, col), &value) in entries {
132 row_indices.push(row);
133 col_indices.push(col);
134 values.push(value);
135 }
136
137 (
138 Array1::from_vec(row_indices),
139 Array1::from_vec(col_indices),
140 Array1::from_vec(values),
141 )
142 }
143
144 pub fn from_array(array: &Array2<T>) -> Self {
152 let shape = (array.shape()[0], array.shape()[1]);
153 let mut dok = Self::new(shape);
154
155 for ((i, j), &value) in array.indexed_iter() {
156 if !value.is_zero() {
157 dok.data.insert((i, j), value);
158 }
159 }
160
161 dok
162 }
163}
164
165impl<T> SparseArray<T> for DokArray<T>
166where
167 T: Float
168 + Add<Output = T>
169 + Sub<Output = T>
170 + Mul<Output = T>
171 + Div<Output = T>
172 + Debug
173 + Copy
174 + 'static,
175{
176 fn shape(&self) -> (usize, usize) {
177 self.shape
178 }
179
180 fn nnz(&self) -> usize {
181 self.data.len()
182 }
183
184 fn dtype(&self) -> &str {
185 "float" }
187
188 fn to_array(&self) -> Array2<T> {
189 let (rows, cols) = self.shape;
190 let mut result = Array2::zeros((rows, cols));
191
192 for (&(row, col), &value) in &self.data {
193 result[[row, col]] = value;
194 }
195
196 result
197 }
198
199 fn toarray(&self) -> Array2<T> {
200 self.to_array()
201 }
202
203 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
204 let (row_indices, col_indices, data) = self.to_triplets();
205 CooArray::new(data, row_indices, col_indices, self.shape, true)
206 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
207 }
208
209 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
210 match self.to_coo() {
212 Ok(coo) => coo.to_csr(),
213 Err(e) => Err(e),
214 }
215 }
216
217 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
218 match self.to_coo() {
220 Ok(coo) => coo.to_csc(),
221 Err(e) => Err(e),
222 }
223 }
224
225 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
226 Ok(Box::new(self.clone()))
228 }
229
230 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
231 Err(SparseError::NotImplemented(
232 "Conversion to LIL array".to_string(),
233 ))
234 }
235
236 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
237 Err(SparseError::NotImplemented(
238 "Conversion to DIA array".to_string(),
239 ))
240 }
241
242 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
243 Err(SparseError::NotImplemented(
244 "Conversion to BSR array".to_string(),
245 ))
246 }
247
248 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
249 if self.shape() != other.shape() {
250 return Err(SparseError::DimensionMismatch {
251 expected: self.shape().0,
252 found: other.shape().0,
253 });
254 }
255
256 let mut result = self.clone();
257 let other_array = other.to_array();
258
259 for (&(row, col), &value) in &self.data {
261 result.set(row, col, value + other_array[[row, col]])?;
262 }
263
264 for ((row, col), &value) in other_array.indexed_iter() {
266 if !self.data.contains_key(&(row, col)) && !value.is_zero() {
267 result.set(row, col, value)?;
268 }
269 }
270
271 Ok(Box::new(result))
272 }
273
274 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
275 if self.shape() != other.shape() {
276 return Err(SparseError::DimensionMismatch {
277 expected: self.shape().0,
278 found: other.shape().0,
279 });
280 }
281
282 let mut result = self.clone();
283 let other_array = other.to_array();
284
285 for (&(row, col), &value) in &self.data {
287 result.set(row, col, value - other_array[[row, col]])?;
288 }
289
290 for ((row, col), &value) in other_array.indexed_iter() {
292 if !self.data.contains_key(&(row, col)) && !value.is_zero() {
293 result.set(row, col, -value)?;
294 }
295 }
296
297 Ok(Box::new(result))
298 }
299
300 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
301 if self.shape() != other.shape() {
302 return Err(SparseError::DimensionMismatch {
303 expected: self.shape().0,
304 found: other.shape().0,
305 });
306 }
307
308 let mut result = DokArray::new(self.shape());
309 let other_array = other.to_array();
310
311 for (&(row, col), &value) in &self.data {
314 let product = value * other_array[[row, col]];
315 if !product.is_zero() {
316 result.set(row, col, product)?;
317 }
318 }
319
320 Ok(Box::new(result))
321 }
322
323 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
324 if self.shape() != other.shape() {
325 return Err(SparseError::DimensionMismatch {
326 expected: self.shape().0,
327 found: other.shape().0,
328 });
329 }
330
331 let mut result = DokArray::new(self.shape());
332 let other_array = other.to_array();
333
334 for (&(row, col), &value) in &self.data {
335 let divisor = other_array[[row, col]];
336 if divisor.is_zero() {
337 return Err(SparseError::ComputationError(
338 "Division by zero".to_string(),
339 ));
340 }
341
342 let quotient = value / divisor;
343 if !quotient.is_zero() {
344 result.set(row, col, quotient)?;
345 }
346 }
347
348 Ok(Box::new(result))
349 }
350
351 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
352 let (_m, n) = self.shape();
353 let (p, _q) = other.shape();
354
355 if n != p {
356 return Err(SparseError::DimensionMismatch {
357 expected: n,
358 found: p,
359 });
360 }
361
362 let csr_self = self.to_csr()?;
364 let csr_other = other.to_csr()?;
365
366 csr_self.dot(&*csr_other)
367 }
368
369 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
370 let (m, n) = self.shape();
371 if n != other.len() {
372 return Err(SparseError::DimensionMismatch {
373 expected: n,
374 found: other.len(),
375 });
376 }
377
378 let mut result = Array1::zeros(m);
379
380 for (&(row, col), &value) in &self.data {
381 result[row] = result[row] + value * other[col];
382 }
383
384 Ok(result)
385 }
386
387 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
388 let (rows, cols) = self.shape;
389 let mut result = DokArray::new((cols, rows));
390
391 for (&(row, col), &value) in &self.data {
392 result.set(col, row, value)?;
393 }
394
395 Ok(Box::new(result))
396 }
397
398 fn copy(&self) -> Box<dyn SparseArray<T>> {
399 Box::new(self.clone())
400 }
401
402 fn get(&self, i: usize, j: usize) -> T {
403 if i >= self.shape.0 || j >= self.shape.1 {
404 return T::zero();
405 }
406
407 *self.data.get(&(i, j)).unwrap_or(&T::zero())
408 }
409
410 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
411 if i >= self.shape.0 || j >= self.shape.1 {
412 return Err(SparseError::IndexOutOfBounds {
413 index: (i, j),
414 shape: self.shape,
415 });
416 }
417
418 if value.is_zero() {
419 self.data.remove(&(i, j));
421 } else {
422 self.data.insert((i, j), value);
424 }
425
426 Ok(())
427 }
428
429 fn eliminate_zeros(&mut self) {
430 self.data.retain(|_, &mut value| !value.is_zero());
432 }
433
434 fn sort_indices(&mut self) {
435 }
437
438 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
439 self.copy()
441 }
442
443 fn has_sorted_indices(&self) -> bool {
444 true }
446
447 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
448 match axis {
449 None => {
450 let mut sum = T::zero();
452 for &value in self.data.values() {
453 sum = sum + value;
454 }
455 Ok(SparseSum::Scalar(sum))
456 }
457 Some(0) => {
458 let (_, cols) = self.shape();
460 let mut result = DokArray::new((1, cols));
461
462 for (&(_row, col), &value) in &self.data {
463 let current = result.get(0, col);
464 result.set(0, col, current + value)?;
465 }
466
467 Ok(SparseSum::SparseArray(Box::new(result)))
468 }
469 Some(1) => {
470 let (rows, _) = self.shape();
472 let mut result = DokArray::new((rows, 1));
473
474 for (&(row, _col), &value) in &self.data {
475 let current = result.get(row, 0);
476 result.set(row, 0, current + value)?;
477 }
478
479 Ok(SparseSum::SparseArray(Box::new(result)))
480 }
481 _ => Err(SparseError::InvalidAxis),
482 }
483 }
484
485 fn max(&self) -> T {
486 if self.data.is_empty() {
487 return T::nan();
488 }
489
490 self.data
491 .values()
492 .fold(T::neg_infinity(), |acc, &x| acc.max(x))
493 }
494
495 fn min(&self) -> T {
496 if self.data.is_empty() {
497 return T::nan();
498 }
499
500 self.data.values().fold(T::infinity(), |acc, &x| acc.min(x))
501 }
502
503 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
504 self.to_triplets()
505 }
506
507 fn slice(
508 &self,
509 row_range: (usize, usize),
510 col_range: (usize, usize),
511 ) -> SparseResult<Box<dyn SparseArray<T>>> {
512 let (start_row, end_row) = row_range;
513 let (start_col, end_col) = col_range;
514 let (rows, cols) = self.shape;
515
516 if start_row >= rows
517 || end_row > rows
518 || start_col >= cols
519 || end_col > cols
520 || start_row >= end_row
521 || start_col >= end_col
522 {
523 return Err(SparseError::InvalidSliceRange);
524 }
525
526 let slice_shape = (end_row - start_row, end_col - start_col);
527 let mut result = DokArray::new(slice_shape);
528
529 for (&(row, col), &value) in &self.data {
530 if row >= start_row && row < end_row && col >= start_col && col < end_col {
531 result.set(row - start_row, col - start_col, value)?;
532 }
533 }
534
535 Ok(Box::new(result))
536 }
537
538 fn as_any(&self) -> &dyn Any {
539 self
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Row indices of the shared 3x3 fixture.
    const ROWS: [usize; 5] = [0, 0, 1, 2, 2];
    /// Column indices of the shared 3x3 fixture.
    const COLS: [usize; 5] = [0, 2, 2, 0, 1];
    /// Values of the shared 3x3 fixture.
    const VALUES: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];

    /// Builds a DOK array of `shape` by calling `set` for each `(row, col, value)`.
    fn dok_from_entries(shape: (usize, usize), entries: &[(usize, usize, f64)]) -> DokArray<f64> {
        let mut array = DokArray::<f64>::new(shape);
        for &(row, col, value) in entries {
            array.set(row, col, value).unwrap();
        }
        array
    }

    /// Asserts that `array.get(row, col)` equals `value` for every listed triple.
    fn assert_entries(array: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(array.get(row, col), value);
        }
    }

    /// Dense row-major form of the shared 3x3 fixture.
    fn fixture_dense() -> Array2<f64> {
        Array::from_shape_vec((3, 3), vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0]).unwrap()
    }

    #[test]
    fn test_dok_array_create_and_access() {
        let mut array = dok_from_entries(
            (3, 3),
            &[
                (0, 0, 1.0),
                (0, 2, 2.0),
                (1, 2, 3.0),
                (2, 0, 4.0),
                (2, 1, 5.0),
            ],
        );

        assert_eq!(array.nnz(), 5);

        assert_entries(
            &array,
            &[
                (0, 0, 1.0),
                (0, 1, 0.0),
                (0, 2, 2.0),
                (1, 2, 3.0),
                (2, 0, 4.0),
                (2, 1, 5.0),
            ],
        );

        // Writing a zero removes the stored entry.
        array.set(0, 0, 0.0).unwrap();
        assert_eq!(array.nnz(), 4);
        assert_eq!(array.get(0, 0), 0.0);

        // Out-of-bounds reads yield zero.
        assert_eq!(array.get(3, 0), 0.0);
        assert_eq!(array.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_array_from_triplets() {
        let array = DokArray::from_triplets(&ROWS, &COLS, &VALUES, (3, 3)).unwrap();

        assert_eq!(array.nnz(), 5);
        assert_entries(
            &array,
            &[
                (0, 0, 1.0),
                (0, 2, 2.0),
                (1, 2, 3.0),
                (2, 0, 4.0),
                (2, 1, 5.0),
            ],
        );
    }

    #[test]
    fn test_dok_array_to_array() {
        let array = DokArray::from_triplets(&ROWS, &COLS, &VALUES, (3, 3)).unwrap();

        assert_eq!(array.to_array(), fixture_dense());
    }

    #[test]
    fn test_dok_array_from_array() {
        let array = DokArray::from_array(&fixture_dense());

        assert_eq!(array.nnz(), 5);
        assert_entries(
            &array,
            &[
                (0, 0, 1.0),
                (0, 2, 2.0),
                (1, 2, 3.0),
                (2, 0, 4.0),
                (2, 1, 5.0),
            ],
        );
    }

    #[test]
    fn test_dok_array_add() {
        let lhs = dok_from_entries((2, 2), &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0)]);
        let rhs = dok_from_entries((2, 2), &[(0, 0, 4.0), (1, 1, 5.0)]);

        let dense = lhs.add(&rhs).unwrap().to_array();

        assert_eq!(dense[[0, 0]], 5.0);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 0]], 3.0);
        assert_eq!(dense[[1, 1]], 5.0);
    }

    #[test]
    fn test_dok_array_mul() {
        let lhs = dok_from_entries(
            (2, 2),
            &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)],
        );
        let rhs = dok_from_entries(
            (2, 2),
            &[(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)],
        );

        // Element-wise product.
        let dense = lhs.mul(&rhs).unwrap().to_array();

        assert_eq!(dense[[0, 0]], 5.0);
        assert_eq!(dense[[0, 1]], 12.0);
        assert_eq!(dense[[1, 0]], 21.0);
        assert_eq!(dense[[1, 1]], 32.0);
    }

    #[test]
    fn test_dok_array_dot() {
        let lhs = dok_from_entries(
            (2, 2),
            &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)],
        );
        let rhs = dok_from_entries(
            (2, 2),
            &[(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)],
        );

        // Matrix product.
        let dense = lhs.dot(&rhs).unwrap().to_array();

        assert_eq!(dense[[0, 0]], 19.0);
        assert_eq!(dense[[0, 1]], 22.0);
        assert_eq!(dense[[1, 0]], 43.0);
        assert_eq!(dense[[1, 1]], 50.0);
    }

    #[test]
    fn test_dok_array_transpose() {
        let array = dok_from_entries(
            (2, 3),
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (1, 0, 4.0),
                (1, 1, 5.0),
                (1, 2, 6.0),
            ],
        );

        let transposed = array.transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 2));
        assert_entries(
            &*transposed,
            &[
                (0, 0, 1.0),
                (1, 0, 2.0),
                (2, 0, 3.0),
                (0, 1, 4.0),
                (1, 1, 5.0),
                (2, 1, 6.0),
            ],
        );
    }

    #[test]
    fn test_dok_array_slice() {
        let array = dok_from_entries(
            (3, 3),
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (1, 0, 4.0),
                (1, 1, 5.0),
                (1, 2, 6.0),
                (2, 0, 7.0),
                (2, 1, 8.0),
                (2, 2, 9.0),
            ],
        );

        let slice = array.slice((0, 2), (1, 3)).unwrap();

        assert_eq!(slice.shape(), (2, 2));
        assert_entries(
            &*slice,
            &[(0, 0, 2.0), (0, 1, 3.0), (1, 0, 5.0), (1, 1, 6.0)],
        );
    }

    #[test]
    fn test_dok_array_sum() {
        let array = dok_from_entries(
            (2, 3),
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (1, 0, 4.0),
                (1, 1, 5.0),
                (1, 2, 6.0),
            ],
        );

        // Sum over every element.
        match array.sum(None).unwrap() {
            SparseSum::Scalar(sum) => assert_eq!(sum, 21.0),
            _ => panic!("Expected scalar sum"),
        }

        // Sum down the columns.
        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (1, 3));
                assert_entries(&*sum_array, &[(0, 0, 5.0), (0, 1, 7.0), (0, 2, 9.0)]);
            }
            _ => panic!("Expected sparse array"),
        }

        // Sum across the rows.
        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(sum_array) => {
                assert_eq!(sum_array.shape(), (2, 1));
                assert_entries(&*sum_array, &[(0, 0, 6.0), (1, 0, 15.0)]);
            }
            _ => panic!("Expected sparse array"),
        }
    }
}