1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::csr_array::CsrArray;
12use crate::error::{SparseError, SparseResult};
13use crate::sparray::{SparseArray, SparseSum};
14
15#[derive(Clone)]
31pub struct CooArray<T>
32where
33 T: SparseElement + Div<Output = T> + PartialOrd + 'static,
34{
35 row: Array1<usize>,
37 col: Array1<usize>,
39 data: Array1<T>,
41 shape: (usize, usize),
43 has_canonical_format: bool,
45}
46
47impl<T> CooArray<T>
48where
49 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
50{
51 pub fn new(
66 data: Array1<T>,
67 row: Array1<usize>,
68 col: Array1<usize>,
69 shape: (usize, usize),
70 has_canonical_format: bool,
71 ) -> SparseResult<Self> {
72 if data.len() != row.len() || data.len() != col.len() {
74 return Err(SparseError::InconsistentData {
75 reason: "data, row, and col must have the same length".to_string(),
76 });
77 }
78
79 if let Some(&max_row) = row.iter().max() {
80 if max_row >= shape.0 {
81 return Err(SparseError::IndexOutOfBounds {
82 index: (max_row, 0),
83 shape,
84 });
85 }
86 }
87
88 if let Some(&max_col) = col.iter().max() {
89 if max_col >= shape.1 {
90 return Err(SparseError::IndexOutOfBounds {
91 index: (0, max_col),
92 shape,
93 });
94 }
95 }
96
97 Ok(Self {
98 data,
99 row,
100 col,
101 shape,
102 has_canonical_format,
103 })
104 }
105
106 pub fn from_triplets(
121 row: &[usize],
122 col: &[usize],
123 data: &[T],
124 shape: (usize, usize),
125 sorted: bool,
126 ) -> SparseResult<Self> {
127 let row_array = Array1::from_vec(row.to_vec());
128 let col_array = Array1::from_vec(col.to_vec());
129 let data_array = Array1::from_vec(data.to_vec());
130
131 Self::new(data_array, row_array, col_array, shape, sorted)
132 }
133
134 pub fn get_rows(&self) -> &Array1<usize> {
136 &self.row
137 }
138
139 pub fn get_cols(&self) -> &Array1<usize> {
141 &self.col
142 }
143
144 pub fn get_data(&self) -> &Array1<T> {
146 &self.data
147 }
148
149 pub fn canonical_format(&mut self) {
151 if self.has_canonical_format {
152 return;
153 }
154
155 let n = self.data.len();
156 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(n);
157
158 for i in 0..n {
159 triplets.push((self.row[i], self.col[i], self.data[i]));
160 }
161
162 triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
163
164 for (i, &(r, c, v)) in triplets.iter().enumerate() {
165 self.row[i] = r;
166 self.col[i] = c;
167 self.data[i] = v;
168 }
169
170 self.has_canonical_format = true;
171 }
172
173 pub fn sum_duplicates(&mut self) {
175 self.canonical_format();
176
177 let n = self.data.len();
178 if n == 0 {
179 return;
180 }
181
182 let mut new_data = Vec::new();
183 let mut new_row = Vec::new();
184 let mut new_col = Vec::new();
185
186 let mut curr_row = self.row[0];
187 let mut curr_col = self.col[0];
188 let mut curr_sum = self.data[0];
189
190 for i in 1..n {
191 if self.row[i] == curr_row && self.col[i] == curr_col {
192 curr_sum = curr_sum + self.data[i];
193 } else {
194 if curr_sum != T::sparse_zero() {
195 new_data.push(curr_sum);
196 new_row.push(curr_row);
197 new_col.push(curr_col);
198 }
199 curr_row = self.row[i];
200 curr_col = self.col[i];
201 curr_sum = self.data[i];
202 }
203 }
204
205 if curr_sum != T::sparse_zero() {
207 new_data.push(curr_sum);
208 new_row.push(curr_row);
209 new_col.push(curr_col);
210 }
211
212 self.data = Array1::from_vec(new_data);
213 self.row = Array1::from_vec(new_row);
214 self.col = Array1::from_vec(new_col);
215 }
216}
217
218impl<T> SparseArray<T> for CooArray<T>
219where
220 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
221{
222 fn shape(&self) -> (usize, usize) {
223 self.shape
224 }
225
226 fn nnz(&self) -> usize {
227 self.data.len()
228 }
229
230 fn dtype(&self) -> &str {
231 "float" }
233
234 fn to_array(&self) -> Array2<T> {
235 let (rows, cols) = self.shape;
236 let mut result = Array2::zeros((rows, cols));
237
238 for i in 0..self.data.len() {
239 let r = self.row[i];
240 let c = self.col[i];
241 result[[r, c]] = result[[r, c]] + self.data[i]; }
243
244 result
245 }
246
247 fn toarray(&self) -> Array2<T> {
248 self.to_array()
249 }
250
251 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
252 let mut new_coo = self.clone();
254 new_coo.sum_duplicates();
255 Ok(Box::new(new_coo))
256 }
257
258 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
259 let mut data_vec = self.data.to_vec();
261 let mut row_vec = self.row.to_vec();
262 let mut col_vec = self.col.to_vec();
263
264 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
266 for i in 0..data_vec.len() {
267 triplets.push((row_vec[i], col_vec[i], data_vec[i]));
268 }
269 triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
270
271 for (i, &(r, c, v)) in triplets.iter().enumerate() {
272 row_vec[i] = r;
273 col_vec[i] = c;
274 data_vec[i] = v;
275 }
276
277 CsrArray::from_triplets(&row_vec, &col_vec, &data_vec, self.shape, true)
279 .map(|csr| Box::new(csr) as Box<dyn SparseArray<T>>)
280 }
281
282 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
283 let csr = self.to_csr()?;
286 csr.transpose()
287 }
288
289 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
290 Ok(Box::new(self.clone()))
292 }
293
294 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
295 Ok(Box::new(self.clone()))
297 }
298
299 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
300 Ok(Box::new(self.clone()))
302 }
303
304 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
305 Ok(Box::new(self.clone()))
307 }
308
309 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
310 let self_csr = self.to_csr()?;
312 self_csr.add(other)
313 }
314
315 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
316 let self_csr = self.to_csr()?;
318 self_csr.sub(other)
319 }
320
321 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
322 let self_csr = self.to_csr()?;
324 self_csr.mul(other)
325 }
326
327 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
328 let self_csr = self.to_csr()?;
330 self_csr.div(other)
331 }
332
333 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
334 let self_csr = self.to_csr()?;
336 self_csr.dot(other)
337 }
338
339 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
340 let (m, n) = self.shape();
341 if n != other.len() {
342 return Err(SparseError::DimensionMismatch {
343 expected: n,
344 found: other.len(),
345 });
346 }
347
348 let mut result = Array1::zeros(m);
349
350 for i in 0..self.data.len() {
351 let row = self.row[i];
352 let col = self.col[i];
353 result[row] = result[row] + self.data[i] * other[col];
354 }
355
356 Ok(result)
357 }
358
359 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
360 CooArray::new(
362 self.data.clone(),
363 self.col.clone(), self.row.clone(), (self.shape.1, self.shape.0), self.has_canonical_format,
367 )
368 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
369 }
370
371 fn copy(&self) -> Box<dyn SparseArray<T>> {
372 Box::new(self.clone())
373 }
374
375 fn get(&self, i: usize, j: usize) -> T {
376 if i >= self.shape.0 || j >= self.shape.1 {
377 return T::sparse_zero();
378 }
379
380 let mut sum = T::sparse_zero();
381 for idx in 0..self.data.len() {
382 if self.row[idx] == i && self.col[idx] == j {
383 sum = sum + self.data[idx];
384 }
385 }
386
387 sum
388 }
389
390 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
391 if i >= self.shape.0 || j >= self.shape.1 {
392 return Err(SparseError::IndexOutOfBounds {
393 index: (i, j),
394 shape: self.shape,
395 });
396 }
397
398 if value == T::sparse_zero() {
399 let mut new_data = Vec::new();
401 let mut new_row = Vec::new();
402 let mut new_col = Vec::new();
403
404 for idx in 0..self.data.len() {
405 if !(self.row[idx] == i && self.col[idx] == j) {
406 new_data.push(self.data[idx]);
407 new_row.push(self.row[idx]);
408 new_col.push(self.col[idx]);
409 }
410 }
411
412 self.data = Array1::from_vec(new_data);
413 self.row = Array1::from_vec(new_row);
414 self.col = Array1::from_vec(new_col);
415 } else {
416 self.set(i, j, T::sparse_zero())?;
418
419 let mut new_data = self.data.to_vec();
421 let mut new_row = self.row.to_vec();
422 let mut new_col = self.col.to_vec();
423
424 new_data.push(value);
425 new_row.push(i);
426 new_col.push(j);
427
428 self.data = Array1::from_vec(new_data);
429 self.row = Array1::from_vec(new_row);
430 self.col = Array1::from_vec(new_col);
431
432 self.has_canonical_format = false;
434 }
435
436 Ok(())
437 }
438
439 fn eliminate_zeros(&mut self) {
440 let mut new_data = Vec::new();
441 let mut new_row = Vec::new();
442 let mut new_col = Vec::new();
443
444 for i in 0..self.data.len() {
445 if !SparseElement::is_zero(&self.data[i]) {
446 new_data.push(self.data[i]);
447 new_row.push(self.row[i]);
448 new_col.push(self.col[i]);
449 }
450 }
451
452 self.data = Array1::from_vec(new_data);
453 self.row = Array1::from_vec(new_row);
454 self.col = Array1::from_vec(new_col);
455 }
456
457 fn sort_indices(&mut self) {
458 self.canonical_format();
459 }
460
461 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
462 if self.has_canonical_format {
463 return Box::new(self.clone());
464 }
465
466 let mut sorted = self.clone();
467 sorted.canonical_format();
468 Box::new(sorted)
469 }
470
471 fn has_sorted_indices(&self) -> bool {
472 self.has_canonical_format
473 }
474
475 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
476 let self_csr = self.to_csr()?;
478 self_csr.sum(axis)
479 }
480
481 fn max(&self) -> T {
482 if self.data.is_empty() {
483 return T::sparse_zero();
485 }
486
487 let mut max_val = self.data[0];
488 for &val in self.data.iter().skip(1) {
489 if val > max_val {
490 max_val = val;
491 }
492 }
493
494 let zero = T::sparse_zero();
496 if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
497 max_val = zero;
498 }
499
500 max_val
501 }
502
503 fn min(&self) -> T {
504 if self.data.is_empty() {
505 return T::sparse_zero();
507 }
508
509 let mut min_val = self.data[0];
510 for &val in self.data.iter().skip(1) {
511 if val < min_val {
512 min_val = val;
513 }
514 }
515
516 if min_val > T::sparse_zero() && self.nnz() < self.shape.0 * self.shape.1 {
518 min_val = T::sparse_zero();
519 }
520
521 min_val
522 }
523
524 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
525 let data_vec = self.data.to_vec();
527 let row_vec = self.row.to_vec();
528 let col_vec = self.col.to_vec();
529
530 if self.has_canonical_format {
532 (self.row.clone(), self.col.clone(), self.data.clone())
534 } else {
535 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
537 for i in 0..data_vec.len() {
538 triplets.push((row_vec[i], col_vec[i], data_vec[i]));
539 }
540 triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
541
542 let mut result_row = Vec::new();
543 let mut result_col = Vec::new();
544 let mut result_data = Vec::new();
545
546 if !triplets.is_empty() {
547 let mut curr_row = triplets[0].0;
548 let mut curr_col = triplets[0].1;
549 let mut curr_sum = triplets[0].2;
550
551 for &(r, c, v) in triplets.iter().skip(1) {
552 if r == curr_row && c == curr_col {
553 curr_sum = curr_sum + v;
554 } else {
555 if curr_sum != T::sparse_zero() {
556 result_row.push(curr_row);
557 result_col.push(curr_col);
558 result_data.push(curr_sum);
559 }
560 curr_row = r;
561 curr_col = c;
562 curr_sum = v;
563 }
564 }
565
566 if curr_sum != T::sparse_zero() {
568 result_row.push(curr_row);
569 result_col.push(curr_col);
570 result_data.push(curr_sum);
571 }
572 }
573
574 (
575 Array1::from_vec(result_row),
576 Array1::from_vec(result_col),
577 Array1::from_vec(result_data),
578 )
579 }
580 }
581
582 fn slice(
583 &self,
584 row_range: (usize, usize),
585 col_range: (usize, usize),
586 ) -> SparseResult<Box<dyn SparseArray<T>>> {
587 let (start_row, end_row) = row_range;
588 let (start_col, end_col) = col_range;
589
590 if start_row >= self.shape.0
591 || end_row > self.shape.0
592 || start_col >= self.shape.1
593 || end_col > self.shape.1
594 {
595 return Err(SparseError::InvalidSliceRange);
596 }
597
598 if start_row >= end_row || start_col >= end_col {
599 return Err(SparseError::InvalidSliceRange);
600 }
601
602 let mut new_data = Vec::new();
603 let mut new_row = Vec::new();
604 let mut new_col = Vec::new();
605
606 for i in 0..self.data.len() {
607 let r = self.row[i];
608 let c = self.col[i];
609
610 if r >= start_row && r < end_row && c >= start_col && c < end_col {
611 new_data.push(self.data[i]);
612 new_row.push(r - start_row); new_col.push(c - start_col); }
615 }
616
617 CooArray::new(
618 Array1::from_vec(new_data),
619 Array1::from_vec(new_row),
620 Array1::from_vec(new_col),
621 (end_row - start_row, end_col - start_col),
622 false,
623 )
624 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
625 }
626
627 fn as_any(&self) -> &dyn std::any::Any {
628 self
629 }
630}
631
632impl<T> fmt::Debug for CooArray<T>
633where
634 T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
635{
636 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
637 write!(
638 f,
639 "CooArray<{}x{}, nnz={}>",
640 self.shape.0,
641 self.shape.1,
642 self.nnz()
643 )
644 }
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 3x3 fixture with five stored entries:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    const SAMPLE_ROWS: [usize; 5] = [0, 0, 1, 2, 2];
    const SAMPLE_COLS: [usize; 5] = [0, 2, 1, 0, 2];
    const SAMPLE_DATA: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];

    /// Expected `get` results for the fixture, including one implicit zero.
    const SAMPLE_LOOKUPS: [(usize, usize, f64); 6] = [
        (0, 0, 1.0),
        (0, 2, 2.0),
        (1, 1, 3.0),
        (2, 0, 4.0),
        (2, 2, 5.0),
        (0, 1, 0.0),
    ];

    /// Builds the shared fixture through `CooArray::new`, marked unsorted.
    fn sample_3x3() -> CooArray<f64> {
        CooArray::new(
            Array1::from_vec(SAMPLE_DATA.to_vec()),
            Array1::from_vec(SAMPLE_ROWS.to_vec()),
            Array1::from_vec(SAMPLE_COLS.to_vec()),
            (3, 3),
            false,
        )
        .unwrap()
    }

    /// Asserts `coo.get(i, j) == value` for every `(i, j, value)` triple.
    fn assert_entries(coo: &CooArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(i, j, value) in expected {
            assert_eq!(coo.get(i, j), value, "mismatch at ({}, {})", i, j);
        }
    }

    /// Asserts `dense[[i, j]] == value` for every `(i, j, value)` triple.
    fn assert_dense(dense: &Array2<f64>, expected: &[(usize, usize, f64)]) {
        for &(i, j, value) in expected {
            assert_eq!(dense[[i, j]], value, "mismatch at ({}, {})", i, j);
        }
    }

    #[test]
    fn test_coo_array_construction() {
        let coo = sample_3x3();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        assert_entries(&coo, &SAMPLE_LOOKUPS);
    }

    #[test]
    fn test_coo_from_triplets() {
        let coo =
            CooArray::from_triplets(&SAMPLE_ROWS, &SAMPLE_COLS, &SAMPLE_DATA, (3, 3), false)
                .unwrap();

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        assert_entries(&coo, &SAMPLE_LOOKUPS);
    }

    #[test]
    fn test_coo_array_to_array() {
        let dense = sample_3x3().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        let expected: [[f64; 3]; 3] = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, &value) in expected_row.iter().enumerate() {
                assert_eq!(dense[[i, j]], value);
            }
        }
    }

    #[test]
    fn test_coo_array_duplicate_entries() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Array1::from_vec(vec![0, 0, 0, 1, 1]),
            Array1::from_vec(vec![0, 0, 1, 0, 0]),
            (2, 2),
            false,
        )
        .unwrap();

        coo.sum_duplicates();

        // (0,0): 1+2, (0,1): 3, (1,0): 4+5
        assert_eq!(coo.nnz(), 3);
        assert_entries(&coo, &[(0, 0, 3.0), (0, 1, 3.0), (1, 0, 9.0)]);
    }

    #[test]
    fn test_coo_set_get() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![0, 1]),
            Array1::from_vec(vec![0, 1]),
            (2, 2),
            false,
        )
        .unwrap();

        // Insert a new coordinate.
        coo.set(0, 1, 3.0).unwrap();
        assert_eq!(coo.get(0, 1), 3.0);

        // Overwrite an existing coordinate.
        coo.set(0, 0, 4.0).unwrap();
        assert_eq!(coo.get(0, 0), 4.0);

        // Writing zero removes the coordinate.
        coo.set(0, 0, 0.0).unwrap();
        assert_eq!(coo.get(0, 0), 0.0);

        assert_eq!(coo.nnz(), 2);
    }

    #[test]
    fn test_coo_canonical_format() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![3.0, 1.0, 5.0, 2.0]),
            Array1::from_vec(vec![1, 0, 2, 0]),
            Array1::from_vec(vec![1, 0, 2, 2]),
            (3, 3),
            false,
        )
        .unwrap();

        assert!(!coo.has_canonical_format);
        coo.canonical_format();
        assert!(coo.has_canonical_format);

        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 2, 5.0)];
        for (k, &(r, c, v)) in expected.iter().enumerate() {
            assert_eq!(coo.row[k], r);
            assert_eq!(coo.col[k], c);
            assert_eq!(coo.data[k], v);
        }
    }

    #[test]
    fn test_coo_to_csr() {
        let csr = sample_3x3().to_csr().unwrap();
        let dense = csr.to_array();

        assert_dense(
            &dense,
            &[(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 0, 4.0), (2, 2, 5.0)],
        );
    }

    #[test]
    fn test_coo_transpose() {
        let transposed = sample_3x3().transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 3));

        let dense = transposed.to_array();
        assert_dense(
            &dense,
            &[(0, 0, 1.0), (2, 0, 2.0), (1, 1, 3.0), (0, 2, 4.0), (2, 2, 5.0)],
        );
    }
}