1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::csr_array::CsrArray;
12use crate::error::{SparseError, SparseResult};
13use crate::sparray::{SparseArray, SparseSum};
14
/// Sparse 2-D array stored in coordinate (COO) format.
///
/// The array is held as three parallel arrays: stored entry `k` has value
/// `data[k]` at position `(row[k], col[k])`.
#[derive(Clone)]
pub struct CooArray<T>
where
 T: Float
 + Add<Output = T>
 + Sub<Output = T>
 + Mul<Output = T>
 + Div<Output = T>
 + Debug
 + Copy
 + 'static,
{
 /// Row index of each stored entry.
 row: Array1<usize>,
 /// Column index of each stored entry.
 col: Array1<usize>,
 /// Value of each stored entry.
 data: Array1<T>,
 /// Logical dimensions of the array as `(rows, cols)`.
 shape: (usize, usize),
 /// Whether the stored entries are known to be sorted by `(row, col)`.
 has_canonical_format: bool,
}
53
54impl<T> CooArray<T>
55where
56 T: Float
57 + Add<Output = T>
58 + Sub<Output = T>
59 + Mul<Output = T>
60 + Div<Output = T>
61 + Debug
62 + Copy
63 + 'static,
64{
65 pub fn new(
80 data: Array1<T>,
81 row: Array1<usize>,
82 col: Array1<usize>,
83 shape: (usize, usize),
84 has_canonical_format: bool,
85 ) -> SparseResult<Self> {
86 if data.len() != row.len() || data.len() != col.len() {
88 return Err(SparseError::InconsistentData {
89 reason: "data, row, and col must have the same length".to_string(),
90 });
91 }
92
93 if let Some(&max_row) = row.iter().max() {
94 if max_row >= shape.0 {
95 return Err(SparseError::IndexOutOfBounds {
96 index: (max_row, 0),
97 shape,
98 });
99 }
100 }
101
102 if let Some(&max_col) = col.iter().max() {
103 if max_col >= shape.1 {
104 return Err(SparseError::IndexOutOfBounds {
105 index: (0, max_col),
106 shape,
107 });
108 }
109 }
110
111 Ok(Self {
112 data,
113 row,
114 col,
115 shape,
116 has_canonical_format,
117 })
118 }
119
120 pub fn from_triplets(
135 row: &[usize],
136 col: &[usize],
137 data: &[T],
138 shape: (usize, usize),
139 sorted: bool,
140 ) -> SparseResult<Self> {
141 let row_array = Array1::from_vec(row.to_vec());
142 let col_array = Array1::from_vec(col.to_vec());
143 let data_array = Array1::from_vec(data.to_vec());
144
145 Self::new(data_array, row_array, col_array, shape, sorted)
146 }
147
148 pub fn get_rows(&self) -> &Array1<usize> {
150 &self.row
151 }
152
153 pub fn get_cols(&self) -> &Array1<usize> {
155 &self.col
156 }
157
158 pub fn get_data(&self) -> &Array1<T> {
160 &self.data
161 }
162
163 pub fn canonical_format(&mut self) {
165 if self.has_canonical_format {
166 return;
167 }
168
169 let n = self.data.len();
170 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(n);
171
172 for i in 0..n {
173 triplets.push((self.row[i], self.col[i], self.data[i]));
174 }
175
176 triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
177
178 for (i, &(r, c, v)) in triplets.iter().enumerate() {
179 self.row[i] = r;
180 self.col[i] = c;
181 self.data[i] = v;
182 }
183
184 self.has_canonical_format = true;
185 }
186
187 pub fn sum_duplicates(&mut self) {
189 self.canonical_format();
190
191 let n = self.data.len();
192 if n == 0 {
193 return;
194 }
195
196 let mut new_data = Vec::new();
197 let mut new_row = Vec::new();
198 let mut new_col = Vec::new();
199
200 let mut curr_row = self.row[0];
201 let mut curr_col = self.col[0];
202 let mut curr_sum = self.data[0];
203
204 for i in 1..n {
205 if self.row[i] == curr_row && self.col[i] == curr_col {
206 curr_sum = curr_sum + self.data[i];
207 } else {
208 if !curr_sum.is_zero() {
209 new_data.push(curr_sum);
210 new_row.push(curr_row);
211 new_col.push(curr_col);
212 }
213 curr_row = self.row[i];
214 curr_col = self.col[i];
215 curr_sum = self.data[i];
216 }
217 }
218
219 if !curr_sum.is_zero() {
221 new_data.push(curr_sum);
222 new_row.push(curr_row);
223 new_col.push(curr_col);
224 }
225
226 self.data = Array1::from_vec(new_data);
227 self.row = Array1::from_vec(new_row);
228 self.col = Array1::from_vec(new_col);
229 }
230}
231
232impl<T> SparseArray<T> for CooArray<T>
233where
234 T: Float
235 + Add<Output = T>
236 + Sub<Output = T>
237 + Mul<Output = T>
238 + Div<Output = T>
239 + Debug
240 + Copy
241 + 'static,
242{
243 fn shape(&self) -> (usize, usize) {
244 self.shape
245 }
246
247 fn nnz(&self) -> usize {
248 self.data.len()
249 }
250
251 fn dtype(&self) -> &str {
252 "float" }
254
255 fn to_array(&self) -> Array2<T> {
256 let (rows, cols) = self.shape;
257 let mut result = Array2::zeros((rows, cols));
258
259 for i in 0..self.data.len() {
260 let r = self.row[i];
261 let c = self.col[i];
262 result[[r, c]] = result[[r, c]] + self.data[i]; }
264
265 result
266 }
267
268 fn toarray(&self) -> Array2<T> {
269 self.to_array()
270 }
271
272 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
273 let mut new_coo = self.clone();
275 new_coo.sum_duplicates();
276 Ok(Box::new(new_coo))
277 }
278
279 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
280 let mut data_vec = self.data.to_vec();
282 let mut row_vec = self.row.to_vec();
283 let mut col_vec = self.col.to_vec();
284
285 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
287 for i in 0..data_vec.len() {
288 triplets.push((row_vec[i], col_vec[i], data_vec[i]));
289 }
290 triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
291
292 for (i, &(r, c, v)) in triplets.iter().enumerate() {
293 row_vec[i] = r;
294 col_vec[i] = c;
295 data_vec[i] = v;
296 }
297
298 CsrArray::from_triplets(&row_vec, &col_vec, &data_vec, self.shape, true)
300 .map(|csr| Box::new(csr) as Box<dyn SparseArray<T>>)
301 }
302
303 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
304 let csr = self.to_csr()?;
307 csr.transpose()
308 }
309
310 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
311 Ok(Box::new(self.clone()))
313 }
314
315 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
316 Ok(Box::new(self.clone()))
318 }
319
320 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
321 Ok(Box::new(self.clone()))
323 }
324
325 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
326 Ok(Box::new(self.clone()))
328 }
329
330 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
331 let self_csr = self.to_csr()?;
333 self_csr.add(other)
334 }
335
336 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
337 let self_csr = self.to_csr()?;
339 self_csr.sub(other)
340 }
341
342 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
343 let self_csr = self.to_csr()?;
345 self_csr.mul(other)
346 }
347
348 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
349 let self_csr = self.to_csr()?;
351 self_csr.div(other)
352 }
353
354 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
355 let self_csr = self.to_csr()?;
357 self_csr.dot(other)
358 }
359
360 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
361 let (m, n) = self.shape();
362 if n != other.len() {
363 return Err(SparseError::DimensionMismatch {
364 expected: n,
365 found: other.len(),
366 });
367 }
368
369 let mut result = Array1::zeros(m);
370
371 for i in 0..self.data.len() {
372 let row = self.row[i];
373 let col = self.col[i];
374 result[row] = result[row] + self.data[i] * other[col];
375 }
376
377 Ok(result)
378 }
379
380 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
381 CooArray::new(
383 self.data.clone(),
384 self.col.clone(), self.row.clone(), (self.shape.1, self.shape.0), self.has_canonical_format,
388 )
389 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
390 }
391
392 fn copy(&self) -> Box<dyn SparseArray<T>> {
393 Box::new(self.clone())
394 }
395
396 fn get(&self, i: usize, j: usize) -> T {
397 if i >= self.shape.0 || j >= self.shape.1 {
398 return T::zero();
399 }
400
401 let mut sum = T::zero();
402 for idx in 0..self.data.len() {
403 if self.row[idx] == i && self.col[idx] == j {
404 sum = sum + self.data[idx];
405 }
406 }
407
408 sum
409 }
410
411 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
412 if i >= self.shape.0 || j >= self.shape.1 {
413 return Err(SparseError::IndexOutOfBounds {
414 index: (i, j),
415 shape: self.shape,
416 });
417 }
418
419 if value.is_zero() {
420 let mut new_data = Vec::new();
422 let mut new_row = Vec::new();
423 let mut new_col = Vec::new();
424
425 for idx in 0..self.data.len() {
426 if !(self.row[idx] == i && self.col[idx] == j) {
427 new_data.push(self.data[idx]);
428 new_row.push(self.row[idx]);
429 new_col.push(self.col[idx]);
430 }
431 }
432
433 self.data = Array1::from_vec(new_data);
434 self.row = Array1::from_vec(new_row);
435 self.col = Array1::from_vec(new_col);
436 } else {
437 self.set(i, j, T::zero())?;
439
440 let mut new_data = self.data.to_vec();
442 let mut new_row = self.row.to_vec();
443 let mut new_col = self.col.to_vec();
444
445 new_data.push(value);
446 new_row.push(i);
447 new_col.push(j);
448
449 self.data = Array1::from_vec(new_data);
450 self.row = Array1::from_vec(new_row);
451 self.col = Array1::from_vec(new_col);
452
453 self.has_canonical_format = false;
455 }
456
457 Ok(())
458 }
459
460 fn eliminate_zeros(&mut self) {
461 let mut new_data = Vec::new();
462 let mut new_row = Vec::new();
463 let mut new_col = Vec::new();
464
465 for i in 0..self.data.len() {
466 if !self.data[i].is_zero() {
467 new_data.push(self.data[i]);
468 new_row.push(self.row[i]);
469 new_col.push(self.col[i]);
470 }
471 }
472
473 self.data = Array1::from_vec(new_data);
474 self.row = Array1::from_vec(new_row);
475 self.col = Array1::from_vec(new_col);
476 }
477
478 fn sort_indices(&mut self) {
479 self.canonical_format();
480 }
481
482 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
483 if self.has_canonical_format {
484 return Box::new(self.clone());
485 }
486
487 let mut sorted = self.clone();
488 sorted.canonical_format();
489 Box::new(sorted)
490 }
491
492 fn has_sorted_indices(&self) -> bool {
493 self.has_canonical_format
494 }
495
496 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
497 let self_csr = self.to_csr()?;
499 self_csr.sum(axis)
500 }
501
502 fn max(&self) -> T {
503 if self.data.is_empty() {
504 return T::neg_infinity();
505 }
506
507 let mut max_val = self.data[0];
508 for &val in self.data.iter().skip(1) {
509 if val > max_val {
510 max_val = val;
511 }
512 }
513
514 if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
516 max_val = T::zero();
517 }
518
519 max_val
520 }
521
522 fn min(&self) -> T {
523 if self.data.is_empty() {
524 return T::infinity();
525 }
526
527 let mut min_val = self.data[0];
528 for &val in self.data.iter().skip(1) {
529 if val < min_val {
530 min_val = val;
531 }
532 }
533
534 if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
536 min_val = T::zero();
537 }
538
539 min_val
540 }
541
542 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
543 let data_vec = self.data.to_vec();
545 let row_vec = self.row.to_vec();
546 let col_vec = self.col.to_vec();
547
548 if self.has_canonical_format {
550 (self.row.clone(), self.col.clone(), self.data.clone())
552 } else {
553 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
555 for i in 0..data_vec.len() {
556 triplets.push((row_vec[i], col_vec[i], data_vec[i]));
557 }
558 triplets.sort_by(|&(r1, c1_, _), &(r2, c2_, _)| (r1, c1_).cmp(&(r2, c2_)));
559
560 let mut result_row = Vec::new();
561 let mut result_col = Vec::new();
562 let mut result_data = Vec::new();
563
564 if !triplets.is_empty() {
565 let mut curr_row = triplets[0].0;
566 let mut curr_col = triplets[0].1;
567 let mut curr_sum = triplets[0].2;
568
569 for &(r, c, v) in triplets.iter().skip(1) {
570 if r == curr_row && c == curr_col {
571 curr_sum = curr_sum + v;
572 } else {
573 if !curr_sum.is_zero() {
574 result_row.push(curr_row);
575 result_col.push(curr_col);
576 result_data.push(curr_sum);
577 }
578 curr_row = r;
579 curr_col = c;
580 curr_sum = v;
581 }
582 }
583
584 if !curr_sum.is_zero() {
586 result_row.push(curr_row);
587 result_col.push(curr_col);
588 result_data.push(curr_sum);
589 }
590 }
591
592 (
593 Array1::from_vec(result_row),
594 Array1::from_vec(result_col),
595 Array1::from_vec(result_data),
596 )
597 }
598 }
599
600 fn slice(
601 &self,
602 row_range: (usize, usize),
603 col_range: (usize, usize),
604 ) -> SparseResult<Box<dyn SparseArray<T>>> {
605 let (start_row, end_row) = row_range;
606 let (start_col, end_col) = col_range;
607
608 if start_row >= self.shape.0
609 || end_row > self.shape.0
610 || start_col >= self.shape.1
611 || end_col > self.shape.1
612 {
613 return Err(SparseError::InvalidSliceRange);
614 }
615
616 if start_row >= end_row || start_col >= end_col {
617 return Err(SparseError::InvalidSliceRange);
618 }
619
620 let mut new_data = Vec::new();
621 let mut new_row = Vec::new();
622 let mut new_col = Vec::new();
623
624 for i in 0..self.data.len() {
625 let r = self.row[i];
626 let c = self.col[i];
627
628 if r >= start_row && r < end_row && c >= start_col && c < end_col {
629 new_data.push(self.data[i]);
630 new_row.push(r - start_row); new_col.push(c - start_col); }
633 }
634
635 CooArray::new(
636 Array1::from_vec(new_data),
637 Array1::from_vec(new_row),
638 Array1::from_vec(new_col),
639 (end_row - start_row, end_col - start_col),
640 false,
641 )
642 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
643 }
644
645 fn as_any(&self) -> &dyn std::any::Any {
646 self
647 }
648}
649
650impl<T> fmt::Debug for CooArray<T>
651where
652 T: Float
653 + Add<Output = T>
654 + Sub<Output = T>
655 + Mul<Output = T>
656 + Div<Output = T>
657 + Debug
658 + Copy
659 + 'static,
660{
661 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
662 write!(
663 f,
664 "CooArray<{}x{}, nnz={}>",
665 self.shape.0,
666 self.shape.1,
667 self.nnz()
668 )
669 }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// Row indices, column indices and values of the 3x3 fixture
    /// `[[1, 0, 2], [0, 3, 0], [4, 0, 5]]` shared by several tests.
    fn sample_triplets() -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        (
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        )
    }

    /// The shared fixture as a non-canonical `CooArray`.
    fn sample_coo() -> CooArray<f64> {
        let (rows, cols, vals) = sample_triplets();
        CooArray::new(
            Array1::from_vec(vals),
            Array1::from_vec(rows),
            Array1::from_vec(cols),
            (3, 3),
            false,
        )
        .unwrap()
    }

    /// Checks shape, stored-entry count and a handful of lookups on the fixture.
    fn assert_sample_contents(coo: &CooArray<f64>) {
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);

        let expected = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
            (0, 1, 0.0),
        ];
        for &(r, c, want) in expected.iter() {
            assert_eq!(coo.get(r, c), want);
        }
    }

    #[test]
    fn test_coo_array_construction() {
        let coo = sample_coo();
        assert_sample_contents(&coo);
    }

    #[test]
    fn test_coo_from_triplets() {
        let (rows, cols, vals) = sample_triplets();
        let coo = CooArray::from_triplets(&rows, &cols, &vals, (3, 3), false).unwrap();
        assert_sample_contents(&coo);
    }

    #[test]
    fn test_coo_array_to_array() {
        let dense = sample_coo().to_array();

        assert_eq!(dense.shape(), &[3, 3]);

        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];
        for (r, expected_row) in expected.iter().enumerate() {
            for (c, &want) in expected_row.iter().enumerate() {
                assert_eq!(dense[[r, c]], want);
            }
        }
    }

    #[test]
    fn test_coo_array_duplicate_entries() {
        // (0, 0) appears twice (1 + 2) and (1, 0) appears twice (4 + 5).
        let mut coo = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Array1::from_vec(vec![0, 0, 0, 1, 1]),
            Array1::from_vec(vec![0, 0, 1, 0, 0]),
            (2, 2),
            false,
        )
        .unwrap();

        coo.sum_duplicates();

        assert_eq!(coo.nnz(), 3);
        assert_eq!(coo.get(0, 0), 3.0);
        assert_eq!(coo.get(0, 1), 3.0);
        assert_eq!(coo.get(1, 0), 9.0);
    }

    #[test]
    fn test_coo_set_get() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![0, 1]),
            Array1::from_vec(vec![0, 1]),
            (2, 2),
            false,
        )
        .unwrap();

        // Insert a brand-new entry.
        coo.set(0, 1, 3.0).unwrap();
        assert_eq!(coo.get(0, 1), 3.0);

        // Overwrite an existing entry.
        coo.set(0, 0, 4.0).unwrap();
        assert_eq!(coo.get(0, 0), 4.0);

        // Writing zero removes the entry.
        coo.set(0, 0, 0.0).unwrap();
        assert_eq!(coo.get(0, 0), 0.0);

        assert_eq!(coo.nnz(), 2);
    }

    #[test]
    fn test_coo_canonical_format() {
        let mut coo = CooArray::new(
            Array1::from_vec(vec![3.0, 1.0, 5.0, 2.0]),
            Array1::from_vec(vec![1, 0, 2, 0]),
            Array1::from_vec(vec![1, 0, 2, 2]),
            (3, 3),
            false,
        )
        .unwrap();

        assert!(!coo.has_canonical_format);

        coo.canonical_format();

        assert!(coo.has_canonical_format);

        // Entries must now be ordered by (row, col).
        assert_eq!(coo.row.to_vec(), vec![0, 0, 1, 2]);
        assert_eq!(coo.col.to_vec(), vec![0, 2, 1, 2]);
        assert_eq!(coo.data.to_vec(), vec![1.0, 2.0, 3.0, 5.0]);
    }

    #[test]
    fn test_coo_to_csr() {
        let csr = sample_coo().to_csr().unwrap();
        let dense = csr.to_array();

        let expected = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
        ];
        for &(r, c, want) in expected.iter() {
            assert_eq!(dense[[r, c]], want);
        }
    }

    #[test]
    fn test_coo_transpose() {
        let transposed = sample_coo().transpose().unwrap();

        assert_eq!(transposed.shape(), (3, 3));

        let dense = transposed.to_array();

        // Same values as the fixture with row and column swapped.
        let expected = [
            (0, 0, 1.0),
            (2, 0, 2.0),
            (1, 1, 3.0),
            (0, 2, 4.0),
            (2, 2, 5.0),
        ];
        for &(r, c, want) in expected.iter() {
            assert_eq!(dense[[r, c]], want);
        }
    }
}