1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::error::{SparseError, SparseResult};
14use crate::sparray::{SparseArray, SparseSum};
15
16#[derive(Clone)]
32pub struct CscArray<T>
33where
34 T: SparseElement + Div<Output = T> + 'static,
35{
36 data: Array1<T>,
38 indices: Array1<usize>,
40 indptr: Array1<usize>,
42 shape: (usize, usize),
44 has_sorted_indices: bool,
46}
47
48impl<T> CscArray<T>
49where
50 T: SparseElement + Div<Output = T> + Zero + 'static,
51{
52 pub fn new(
66 data: Array1<T>,
67 indices: Array1<usize>,
68 indptr: Array1<usize>,
69 shape: (usize, usize),
70 ) -> SparseResult<Self> {
71 if data.len() != indices.len() {
73 return Err(SparseError::InconsistentData {
74 reason: "data and indices must have the same length".to_string(),
75 });
76 }
77
78 if indptr.len() != shape.1 + 1 {
79 return Err(SparseError::InconsistentData {
80 reason: format!(
81 "indptr length ({}) must be one more than the number of columns ({})",
82 indptr.len(),
83 shape.1
84 ),
85 });
86 }
87
88 if let Some(&max_idx) = indices.iter().max() {
89 if max_idx >= shape.0 {
90 return Err(SparseError::IndexOutOfBounds {
91 index: (max_idx, 0),
92 shape,
93 });
94 }
95 }
96
97 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
98 if first != 0 {
99 return Err(SparseError::InconsistentData {
100 reason: "first element of indptr must be 0".to_string(),
101 });
102 }
103
104 if last != data.len() {
105 return Err(SparseError::InconsistentData {
106 reason: format!(
107 "last element of indptr ({}) must equal data length ({})",
108 last,
109 data.len()
110 ),
111 });
112 }
113 }
114
115 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
116
117 Ok(Self {
118 data,
119 indices,
120 indptr,
121 shape,
122 has_sorted_indices,
123 })
124 }
125
126 pub fn from_triplets(
141 rows: &[usize],
142 cols: &[usize],
143 data: &[T],
144 shape: (usize, usize),
145 sorted: bool,
146 ) -> SparseResult<Self> {
147 if rows.len() != cols.len() || rows.len() != data.len() {
148 return Err(SparseError::InconsistentData {
149 reason: "rows, cols, and data must have the same length".to_string(),
150 });
151 }
152
153 if rows.is_empty() {
154 let indptr = Array1::zeros(shape.1 + 1);
156 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
157 }
158
159 let nnz = rows.len();
160 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
161
162 for i in 0..nnz {
163 if rows[i] >= shape.0 || cols[i] >= shape.1 {
164 return Err(SparseError::IndexOutOfBounds {
165 index: (rows[i], cols[i]),
166 shape,
167 });
168 }
169 all_data.push((rows[i], cols[i], data[i]));
170 }
171
172 if !sorted {
173 all_data.sort_by_key(|&(_, col_, _)| col_);
174 }
175
176 let mut col_counts = vec![0; shape.1];
178 for &(_, col_, _) in &all_data {
179 col_counts[col_] += 1;
180 }
181
182 let mut indptr = Vec::with_capacity(shape.1 + 1);
184 indptr.push(0);
185 let mut cumsum = 0;
186 for &count in &col_counts {
187 cumsum += count;
188 indptr.push(cumsum);
189 }
190
191 let mut indices = Vec::with_capacity(nnz);
193 let mut values = Vec::with_capacity(nnz);
194
195 for (row_, _, val) in all_data {
196 indices.push(row_);
197 values.push(val);
198 }
199
200 Self::new(
201 Array1::from_vec(values),
202 Array1::from_vec(indices),
203 Array1::from_vec(indptr),
204 shape,
205 )
206 }
207
208 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
210 for col in 0..indptr.len() - 1 {
211 let start = indptr[col];
212 let end = indptr[col + 1];
213
214 for i in start..end.saturating_sub(1) {
215 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
216 return false;
217 }
218 }
219 }
220 true
221 }
222
223 pub fn get_data(&self) -> &Array1<T> {
225 &self.data
226 }
227
228 pub fn get_indices(&self) -> &Array1<usize> {
230 &self.indices
231 }
232
233 pub fn get_indptr(&self) -> &Array1<usize> {
235 &self.indptr
236 }
237}
238
239impl<T> SparseArray<T> for CscArray<T>
240where
241 T: SparseElement + Div<Output = T> + Float + 'static,
242{
243 fn shape(&self) -> (usize, usize) {
244 self.shape
245 }
246
247 fn nnz(&self) -> usize {
248 self.data.len()
249 }
250
251 fn dtype(&self) -> &str {
252 "float" }
254
255 fn to_array(&self) -> Array2<T> {
256 let (rows, cols) = self.shape;
257 let mut result = Array2::zeros((rows, cols));
258
259 for col in 0..cols {
260 let start = self.indptr[col];
261 let end = self.indptr[col + 1];
262
263 for i in start..end {
264 let row = self.indices[i];
265 result[[row, col]] = self.data[i];
266 }
267 }
268
269 result
270 }
271
272 fn toarray(&self) -> Array2<T> {
273 self.to_array()
274 }
275
276 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
277 let nnz = self.nnz();
279 let mut row_indices = Vec::with_capacity(nnz);
280 let mut col_indices = Vec::with_capacity(nnz);
281 let mut values = Vec::with_capacity(nnz);
282
283 for col in 0..self.shape.1 {
284 let start = self.indptr[col];
285 let end = self.indptr[col + 1];
286
287 for idx in start..end {
288 row_indices.push(self.indices[idx]);
289 col_indices.push(col);
290 values.push(self.data[idx]);
291 }
292 }
293
294 CooArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
295 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
296 }
297
298 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
299 self.to_coo()?.to_csr()
301 }
302
303 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
304 Ok(Box::new(self.clone()))
305 }
306
307 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
308 self.to_coo()?.to_dok()
311 }
312
313 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
314 self.to_coo()?.to_lil()
317 }
318
319 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
320 self.to_coo()?.to_dia()
323 }
324
325 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
326 self.to_coo()?.to_bsr()
329 }
330
331 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
332 let self_csr = self.to_csr()?;
334 self_csr.add(other)
335 }
336
337 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
338 let self_csr = self.to_csr()?;
340 self_csr.sub(other)
341 }
342
343 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
344 let self_csr = self.to_csr()?;
347 self_csr.mul(other)
348 }
349
350 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
351 let self_csr = self.to_csr()?;
354 self_csr.div(other)
355 }
356
357 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
358 let self_csr = self.to_csr()?;
361 self_csr.dot(other)
362 }
363
364 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
365 let (m, n) = self.shape();
366 if n != other.len() {
367 return Err(SparseError::DimensionMismatch {
368 expected: n,
369 found: other.len(),
370 });
371 }
372
373 let mut result = Array1::zeros(m);
374
375 for col in 0..n {
376 let start = self.indptr[col];
377 let end = self.indptr[col + 1];
378
379 let val = other[col];
380 if !SparseElement::is_zero(&val) {
381 for idx in start..end {
382 let row = self.indices[idx];
383 result[row] = result[row] + self.data[idx] * val;
384 }
385 }
386 }
387
388 Ok(result)
389 }
390
391 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
392 CsrArray::new(
394 self.data.clone(),
395 self.indices.clone(),
396 self.indptr.clone(),
397 (self.shape.1, self.shape.0), )
399 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
400 }
401
402 fn copy(&self) -> Box<dyn SparseArray<T>> {
403 Box::new(self.clone())
404 }
405
406 fn get(&self, i: usize, j: usize) -> T {
407 if i >= self.shape.0 || j >= self.shape.1 {
408 return T::sparse_zero();
409 }
410
411 let start = self.indptr[j];
412 let end = self.indptr[j + 1];
413
414 for idx in start..end {
415 if self.indices[idx] == i {
416 return self.data[idx];
417 }
418
419 if self.has_sorted_indices && self.indices[idx] > i {
421 break;
422 }
423 }
424
425 T::sparse_zero()
426 }
427
428 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
429 if i >= self.shape.0 || j >= self.shape.1 {
432 return Err(SparseError::IndexOutOfBounds {
433 index: (i, j),
434 shape: self.shape,
435 });
436 }
437
438 let start = self.indptr[j];
439 let end = self.indptr[j + 1];
440
441 for idx in start..end {
443 if self.indices[idx] == i {
444 self.data[idx] = value;
446 return Ok(());
447 }
448
449 if self.has_sorted_indices && self.indices[idx] > i {
451 return Err(SparseError::NotImplemented(
454 "Inserting new elements in CSC format".to_string(),
455 ));
456 }
457 }
458
459 Err(SparseError::NotImplemented(
462 "Inserting new elements in CSC format".to_string(),
463 ))
464 }
465
466 fn eliminate_zeros(&mut self) {
467 let mut new_data = Vec::new();
469 let mut new_indices = Vec::new();
470 let mut new_indptr = vec![0];
471
472 let (_, cols) = self.shape;
473
474 for col in 0..cols {
475 let start = self.indptr[col];
476 let end = self.indptr[col + 1];
477
478 for idx in start..end {
479 if !SparseElement::is_zero(&self.data[idx]) {
480 new_data.push(self.data[idx]);
481 new_indices.push(self.indices[idx]);
482 }
483 }
484 new_indptr.push(new_data.len());
485 }
486
487 self.data = Array1::from_vec(new_data);
489 self.indices = Array1::from_vec(new_indices);
490 self.indptr = Array1::from_vec(new_indptr);
491 }
492
493 fn sort_indices(&mut self) {
494 if self.has_sorted_indices {
495 return;
496 }
497
498 let (_, cols) = self.shape;
499
500 for col in 0..cols {
501 let start = self.indptr[col];
502 let end = self.indptr[col + 1];
503
504 if start == end {
505 continue;
506 }
507
508 let mut col_data = Vec::with_capacity(end - start);
510 for idx in start..end {
511 col_data.push((self.indices[idx], self.data[idx]));
512 }
513
514 col_data.sort_by_key(|&(row_, _)| row_);
516
517 for (i, (row, val)) in col_data.into_iter().enumerate() {
519 self.indices[start + i] = row;
520 self.data[start + i] = val;
521 }
522 }
523
524 self.has_sorted_indices = true;
525 }
526
527 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
528 if self.has_sorted_indices {
529 return Box::new(self.clone());
530 }
531
532 let mut sorted = self.clone();
533 sorted.sort_indices();
534 Box::new(sorted)
535 }
536
537 fn has_sorted_indices(&self) -> bool {
538 self.has_sorted_indices
539 }
540
541 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
542 match axis {
543 None => {
544 let mut sum = T::sparse_zero();
546 for &val in self.data.iter() {
547 sum = sum + val;
548 }
549 Ok(SparseSum::Scalar(sum))
550 }
551 Some(0) => {
552 let self_csr = self.to_csr()?;
555 self_csr.sum(Some(0))
556 }
557 Some(1) => {
558 let mut result = Vec::with_capacity(self.shape.1);
560
561 for col in 0..self.shape.1 {
562 let start = self.indptr[col];
563 let end = self.indptr[col + 1];
564
565 let mut col_sum = T::sparse_zero();
566 for idx in start..end {
567 col_sum = col_sum + self.data[idx];
568 }
569 result.push(col_sum);
570 }
571
572 let mut row_indices = Vec::new();
574 let mut col_indices = Vec::new();
575 let mut values = Vec::new();
576
577 for (col, &val) in result.iter().enumerate() {
578 if !SparseElement::is_zero(&val) {
579 row_indices.push(0);
580 col_indices.push(col);
581 values.push(val);
582 }
583 }
584
585 let coo = CooArray::from_triplets(
586 &row_indices,
587 &col_indices,
588 &values,
589 (1, self.shape.1),
590 true,
591 )?;
592
593 Ok(SparseSum::SparseArray(Box::new(coo)))
594 }
595 _ => Err(SparseError::InvalidAxis),
596 }
597 }
598
599 fn max(&self) -> T {
600 if self.data.is_empty() {
601 return T::sparse_zero();
603 }
604
605 let mut max_val = self.data[0];
606 for &val in self.data.iter().skip(1) {
607 if val > max_val {
608 max_val = val;
609 }
610 }
611
612 let zero = T::sparse_zero();
614 if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
615 max_val = zero;
616 }
617
618 max_val
619 }
620
621 fn min(&self) -> T {
622 if self.data.is_empty() {
623 return T::sparse_zero();
624 }
625
626 let mut min_val = self.data[0];
627 for &val in self.data.iter().skip(1) {
628 if val < min_val {
629 min_val = val;
630 }
631 }
632
633 if min_val > T::sparse_zero() && self.nnz() < self.shape.0 * self.shape.1 {
635 min_val = T::sparse_zero();
636 }
637
638 min_val
639 }
640
641 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
642 let nnz = self.nnz();
643 let mut rows = Vec::with_capacity(nnz);
644 let mut cols = Vec::with_capacity(nnz);
645 let mut values = Vec::with_capacity(nnz);
646
647 for col in 0..self.shape.1 {
648 let start = self.indptr[col];
649 let end = self.indptr[col + 1];
650
651 for idx in start..end {
652 let row = self.indices[idx];
653 rows.push(row);
654 cols.push(col);
655 values.push(self.data[idx]);
656 }
657 }
658
659 (
660 Array1::from_vec(rows),
661 Array1::from_vec(cols),
662 Array1::from_vec(values),
663 )
664 }
665
666 fn slice(
667 &self,
668 row_range: (usize, usize),
669 col_range: (usize, usize),
670 ) -> SparseResult<Box<dyn SparseArray<T>>> {
671 let (start_row, end_row) = row_range;
672 let (start_col, end_col) = col_range;
673
674 if start_row >= self.shape.0
675 || end_row > self.shape.0
676 || start_col >= self.shape.1
677 || end_col > self.shape.1
678 {
679 return Err(SparseError::InvalidSliceRange);
680 }
681
682 if start_row >= end_row || start_col >= end_col {
683 return Err(SparseError::InvalidSliceRange);
684 }
685
686 let mut data = Vec::new();
688 let mut indices = Vec::new();
689 let mut indptr = vec![0];
690
691 for col in start_col..end_col {
692 let start = self.indptr[col];
693 let end = self.indptr[col + 1];
694
695 for idx in start..end {
696 let row = self.indices[idx];
697 if row >= start_row && row < end_row {
698 data.push(self.data[idx]);
699 indices.push(row - start_row); }
701 }
702 indptr.push(data.len());
703 }
704
705 CscArray::new(
706 Array1::from_vec(data),
707 Array1::from_vec(indices),
708 Array1::from_vec(indptr),
709 (end_row - start_row, end_col - start_col),
710 )
711 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
712 }
713
714 fn as_any(&self) -> &dyn std::any::Any {
715 self
716 }
717
718 fn get_indptr(&self) -> Option<&Array1<usize>> {
719 Some(&self.indptr)
720 }
721
722 fn indptr(&self) -> Option<&Array1<usize>> {
723 Some(&self.indptr)
724 }
725}
726
727impl<T> fmt::Debug for CscArray<T>
728where
729 T: SparseElement + Div<Output = T> + Float + 'static,
730{
731 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
732 write!(
733 f,
734 "CscArray<{}x{}, nnz={}>",
735 self.shape.0,
736 self.shape.1,
737 self.nnz()
738 )
739 }
740}
741
742#[cfg(test)]
743mod tests {
744 use super::*;
745 use approx::assert_relative_eq;
746
// Builds a 3x3 matrix from raw CSC components and checks shape, entry count
// and element lookup, including one position that is not stored.
#[test]
fn test_csc_array_construction() {
    let csc = CscArray::new(
        Array1::from_vec(vec![1.0, 4.0, 2.0, 3.0, 5.0]),
        Array1::from_vec(vec![0, 2, 0, 1, 2]),
        Array1::from_vec(vec![0, 2, 3, 5]),
        (3, 3),
    )
    .unwrap();

    assert_eq!(csc.shape(), (3, 3));
    assert_eq!(csc.nnz(), 5);

    let expected = [
        (0, 0, 1.0),
        (2, 0, 4.0),
        (0, 1, 2.0),
        (1, 2, 3.0),
        (2, 2, 5.0),
        (1, 0, 0.0),
    ];
    for &(row, col, value) in expected.iter() {
        assert_eq!(csc.get(row, col), value);
    }
}
765
// Builds the same 3x3 matrix from coordinate triplets and checks shape,
// entry count and element lookup.
#[test]
fn test_csc_from_triplets() {
    let row_idx = vec![0, 2, 0, 1, 2];
    let col_idx = vec![0, 0, 1, 2, 2];
    let values = vec![1.0, 4.0, 2.0, 3.0, 5.0];

    let csc = CscArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false).unwrap();

    assert_eq!(csc.shape(), (3, 3));
    assert_eq!(csc.nnz(), 5);

    let expected = [
        (0, 0, 1.0),
        (2, 0, 4.0),
        (0, 1, 2.0),
        (1, 2, 3.0),
        (2, 2, 5.0),
        (1, 0, 0.0),
    ];
    for &(row, col, value) in expected.iter() {
        assert_eq!(csc.get(row, col), value);
    }
}
784
// Converts the matrix to dense form and compares every cell.
#[test]
fn test_csc_array_to_array() {
    let row_idx = vec![0, 2, 0, 1, 2];
    let col_idx = vec![0, 0, 1, 2, 2];
    let values = vec![1.0, 4.0, 2.0, 3.0, 5.0];

    let csc = CscArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false).unwrap();
    let dense = csc.to_array();

    assert_eq!(dense.shape(), &[3, 3]);

    // Expected dense matrix, one inner array per row.
    let expected = [[1.0, 2.0, 0.0], [0.0, 0.0, 3.0], [4.0, 0.0, 5.0]];
    for (i, row) in expected.iter().enumerate() {
        for (j, &want) in row.iter().enumerate() {
            assert_eq!(dense[[i, j]], want);
        }
    }
}
806
// Converting to CSR must leave the dense contents unchanged.
#[test]
fn test_csc_to_csr_conversion() {
    let row_idx = vec![0, 2, 0, 1, 2];
    let col_idx = vec![0, 0, 1, 2, 2];
    let values = vec![1.0, 4.0, 2.0, 3.0, 5.0];
    let shape = (3, 3);

    let csc = CscArray::from_triplets(&row_idx, &col_idx, &values, shape, false).unwrap();
    let csr = csc.to_csr().unwrap();

    let from_csc = csc.to_array();
    let from_csr = csr.to_array();

    let cells = (0..shape.0).flat_map(|i| (0..shape.1).map(move |j| (i, j)));
    for (i, j) in cells {
        assert_relative_eq!(from_csc[[i, j]], from_csr[[i, j]]);
    }
}
827
// Matrix-vector product against a hand-computed result:
// [[1,2,0],[0,0,3],[4,0,5]] * [1,2,3] = [5,9,19].
#[test]
fn test_csc_dot_vector() {
    let row_idx = vec![0, 2, 0, 1, 2];
    let col_idx = vec![0, 0, 1, 2, 2];
    let values = vec![1.0, 4.0, 2.0, 3.0, 5.0];

    let csc = CscArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false).unwrap();
    let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

    let product = csc.dot_vector(&x.view()).unwrap();

    assert_eq!(product.len(), 3);
    assert_relative_eq!(product[0], 5.0);
    assert_relative_eq!(product[1], 9.0);
    assert_relative_eq!(product[2], 19.0);
}
849
// Every stored entry (i, j) must appear at (j, i) in the transpose.
#[test]
fn test_csc_transpose() {
    let row_idx = vec![0, 2, 0, 1, 2];
    let col_idx = vec![0, 0, 1, 2, 2];
    let values = vec![1.0, 4.0, 2.0, 3.0, 5.0];

    let csc = CscArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false).unwrap();
    let transposed = csc.transpose().unwrap();

    assert_eq!(transposed.shape(), (3, 3));

    let dense = transposed.to_array();
    let expected = [
        (0, 0, 1.0),
        (0, 2, 4.0),
        (1, 0, 2.0),
        (2, 1, 3.0),
        (2, 2, 5.0),
    ];
    for &(row, col, value) in expected.iter() {
        assert_eq!(dense[[row, col]], value);
    }
}
871
// `find` must return exactly the input triplets, in any order.
#[test]
fn test_csc_find() {
    let rows = vec![0, 2, 0, 1, 2];
    let cols = vec![0, 0, 1, 2, 2];
    let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];

    let csc = CscArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
    let (found_rows, found_cols, found_data) = csc.find();

    assert_eq!(found_rows.len(), 5);
    assert_eq!(found_cols.len(), 5);
    assert_eq!(found_data.len(), 5);

    let mut original: Vec<_> = rows
        .iter()
        .zip(cols.iter())
        .zip(data.iter())
        .map(|((&r, &c), &d)| (r, c, d))
        .collect();

    let mut result: Vec<_> = found_rows
        .iter()
        .zip(found_cols.iter())
        .zip(found_data.iter())
        .map(|((&r, &c), &d)| (r, c, d))
        .collect();

    // Order both lists by (row, col) so storage order does not matter.
    original.sort_by_key(|&(r, c, _)| (r, c));
    result.sort_by_key(|&(r, c, _)| (r, c));

    assert_eq!(original, result);
}
908}