1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::error::{SparseError, SparseResult};
14use crate::sparray::{SparseArray, SparseSum};
15
16#[derive(Clone)]
32pub struct CscArray<T>
33where
34 T: Float
35 + Add<Output = T>
36 + Sub<Output = T>
37 + Mul<Output = T>
38 + Div<Output = T>
39 + Debug
40 + Copy
41 + 'static,
42{
43 data: Array1<T>,
45 indices: Array1<usize>,
47 indptr: Array1<usize>,
49 shape: (usize, usize),
51 has_sorted_indices: bool,
53}
54
55impl<T> CscArray<T>
56where
57 T: Float
58 + Add<Output = T>
59 + Sub<Output = T>
60 + Mul<Output = T>
61 + Div<Output = T>
62 + Debug
63 + Copy
64 + 'static,
65{
66 pub fn new(
80 data: Array1<T>,
81 indices: Array1<usize>,
82 indptr: Array1<usize>,
83 shape: (usize, usize),
84 ) -> SparseResult<Self> {
85 if data.len() != indices.len() {
87 return Err(SparseError::InconsistentData {
88 reason: "data and indices must have the same length".to_string(),
89 });
90 }
91
92 if indptr.len() != shape.1 + 1 {
93 return Err(SparseError::InconsistentData {
94 reason: format!(
95 "indptr length ({}) must be one more than the number of columns ({})",
96 indptr.len(),
97 shape.1
98 ),
99 });
100 }
101
102 if let Some(&max_idx) = indices.iter().max() {
103 if max_idx >= shape.0 {
104 return Err(SparseError::IndexOutOfBounds {
105 index: (max_idx, 0),
106 shape,
107 });
108 }
109 }
110
111 if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
112 if first != 0 {
113 return Err(SparseError::InconsistentData {
114 reason: "first element of indptr must be 0".to_string(),
115 });
116 }
117
118 if last != data.len() {
119 return Err(SparseError::InconsistentData {
120 reason: format!(
121 "last element of indptr ({}) must equal data length ({})",
122 last,
123 data.len()
124 ),
125 });
126 }
127 }
128
129 let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
130
131 Ok(Self {
132 data,
133 indices,
134 indptr,
135 shape,
136 has_sorted_indices,
137 })
138 }
139
140 pub fn from_triplets(
155 rows: &[usize],
156 cols: &[usize],
157 data: &[T],
158 shape: (usize, usize),
159 sorted: bool,
160 ) -> SparseResult<Self> {
161 if rows.len() != cols.len() || rows.len() != data.len() {
162 return Err(SparseError::InconsistentData {
163 reason: "rows, cols, and data must have the same length".to_string(),
164 });
165 }
166
167 if rows.is_empty() {
168 let indptr = Array1::zeros(shape.1 + 1);
170 return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
171 }
172
173 let nnz = rows.len();
174 let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
175
176 for i in 0..nnz {
177 if rows[i] >= shape.0 || cols[i] >= shape.1 {
178 return Err(SparseError::IndexOutOfBounds {
179 index: (rows[i], cols[i]),
180 shape,
181 });
182 }
183 all_data.push((rows[i], cols[i], data[i]));
184 }
185
186 if !sorted {
187 all_data.sort_by_key(|&(_, col, _)| col);
188 }
189
190 let mut col_counts = vec![0; shape.1];
192 for &(_, col, _) in &all_data {
193 col_counts[col] += 1;
194 }
195
196 let mut indptr = Vec::with_capacity(shape.1 + 1);
198 indptr.push(0);
199 let mut cumsum = 0;
200 for &count in &col_counts {
201 cumsum += count;
202 indptr.push(cumsum);
203 }
204
205 let mut indices = Vec::with_capacity(nnz);
207 let mut values = Vec::with_capacity(nnz);
208
209 for (row, _, val) in all_data {
210 indices.push(row);
211 values.push(val);
212 }
213
214 Self::new(
215 Array1::from_vec(values),
216 Array1::from_vec(indices),
217 Array1::from_vec(indptr),
218 shape,
219 )
220 }
221
222 fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
224 for col in 0..indptr.len() - 1 {
225 let start = indptr[col];
226 let end = indptr[col + 1];
227
228 for i in start..end.saturating_sub(1) {
229 if i + 1 < indices.len() && indices[i] > indices[i + 1] {
230 return false;
231 }
232 }
233 }
234 true
235 }
236
237 pub fn get_data(&self) -> &Array1<T> {
239 &self.data
240 }
241
242 pub fn get_indices(&self) -> &Array1<usize> {
244 &self.indices
245 }
246
247 pub fn get_indptr(&self) -> &Array1<usize> {
249 &self.indptr
250 }
251}
252
253impl<T> SparseArray<T> for CscArray<T>
254where
255 T: Float
256 + Add<Output = T>
257 + Sub<Output = T>
258 + Mul<Output = T>
259 + Div<Output = T>
260 + Debug
261 + Copy
262 + 'static,
263{
264 fn shape(&self) -> (usize, usize) {
265 self.shape
266 }
267
268 fn nnz(&self) -> usize {
269 self.data.len()
270 }
271
272 fn dtype(&self) -> &str {
273 "float" }
275
276 fn to_array(&self) -> Array2<T> {
277 let (rows, cols) = self.shape;
278 let mut result = Array2::zeros((rows, cols));
279
280 for col in 0..cols {
281 let start = self.indptr[col];
282 let end = self.indptr[col + 1];
283
284 for i in start..end {
285 let row = self.indices[i];
286 result[[row, col]] = self.data[i];
287 }
288 }
289
290 result
291 }
292
293 fn toarray(&self) -> Array2<T> {
294 self.to_array()
295 }
296
297 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
298 let nnz = self.nnz();
300 let mut row_indices = Vec::with_capacity(nnz);
301 let mut col_indices = Vec::with_capacity(nnz);
302 let mut values = Vec::with_capacity(nnz);
303
304 for col in 0..self.shape.1 {
305 let start = self.indptr[col];
306 let end = self.indptr[col + 1];
307
308 for idx in start..end {
309 row_indices.push(self.indices[idx]);
310 col_indices.push(col);
311 values.push(self.data[idx]);
312 }
313 }
314
315 CooArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
316 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
317 }
318
319 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
320 self.to_coo()?.to_csr()
322 }
323
324 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
325 Ok(Box::new(self.clone()))
326 }
327
328 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
329 self.to_coo()?.to_dok()
332 }
333
334 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
335 self.to_coo()?.to_lil()
338 }
339
340 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
341 self.to_coo()?.to_dia()
344 }
345
346 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
347 self.to_coo()?.to_bsr()
350 }
351
352 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
353 let self_csr = self.to_csr()?;
355 self_csr.add(other)
356 }
357
358 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
359 let self_csr = self.to_csr()?;
361 self_csr.sub(other)
362 }
363
364 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
365 let self_csr = self.to_csr()?;
368 self_csr.mul(other)
369 }
370
371 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
372 let self_csr = self.to_csr()?;
375 self_csr.div(other)
376 }
377
378 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
379 let self_csr = self.to_csr()?;
382 self_csr.dot(other)
383 }
384
385 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
386 let (m, n) = self.shape();
387 if n != other.len() {
388 return Err(SparseError::DimensionMismatch {
389 expected: n,
390 found: other.len(),
391 });
392 }
393
394 let mut result = Array1::zeros(m);
395
396 for col in 0..n {
397 let start = self.indptr[col];
398 let end = self.indptr[col + 1];
399
400 let val = other[col];
401 if !val.is_zero() {
402 for idx in start..end {
403 let row = self.indices[idx];
404 result[row] = result[row] + self.data[idx] * val;
405 }
406 }
407 }
408
409 Ok(result)
410 }
411
412 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
413 CsrArray::new(
415 self.data.clone(),
416 self.indices.clone(),
417 self.indptr.clone(),
418 (self.shape.1, self.shape.0), )
420 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
421 }
422
423 fn copy(&self) -> Box<dyn SparseArray<T>> {
424 Box::new(self.clone())
425 }
426
427 fn get(&self, i: usize, j: usize) -> T {
428 if i >= self.shape.0 || j >= self.shape.1 {
429 return T::zero();
430 }
431
432 let start = self.indptr[j];
433 let end = self.indptr[j + 1];
434
435 for idx in start..end {
436 if self.indices[idx] == i {
437 return self.data[idx];
438 }
439
440 if self.has_sorted_indices && self.indices[idx] > i {
442 break;
443 }
444 }
445
446 T::zero()
447 }
448
449 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
450 if i >= self.shape.0 || j >= self.shape.1 {
453 return Err(SparseError::IndexOutOfBounds {
454 index: (i, j),
455 shape: self.shape,
456 });
457 }
458
459 let start = self.indptr[j];
460 let end = self.indptr[j + 1];
461
462 for idx in start..end {
464 if self.indices[idx] == i {
465 self.data[idx] = value;
467 return Ok(());
468 }
469
470 if self.has_sorted_indices && self.indices[idx] > i {
472 return Err(SparseError::NotImplemented(
475 "Inserting new elements in CSC format".to_string(),
476 ));
477 }
478 }
479
480 Err(SparseError::NotImplemented(
483 "Inserting new elements in CSC format".to_string(),
484 ))
485 }
486
487 fn eliminate_zeros(&mut self) {
488 let mut new_data = Vec::new();
490 let mut new_indices = Vec::new();
491 let mut new_indptr = vec![0];
492
493 let (_, cols) = self.shape;
494
495 for col in 0..cols {
496 let start = self.indptr[col];
497 let end = self.indptr[col + 1];
498
499 for idx in start..end {
500 if !self.data[idx].is_zero() {
501 new_data.push(self.data[idx]);
502 new_indices.push(self.indices[idx]);
503 }
504 }
505 new_indptr.push(new_data.len());
506 }
507
508 self.data = Array1::from_vec(new_data);
510 self.indices = Array1::from_vec(new_indices);
511 self.indptr = Array1::from_vec(new_indptr);
512 }
513
514 fn sort_indices(&mut self) {
515 if self.has_sorted_indices {
516 return;
517 }
518
519 let (_, cols) = self.shape;
520
521 for col in 0..cols {
522 let start = self.indptr[col];
523 let end = self.indptr[col + 1];
524
525 if start == end {
526 continue;
527 }
528
529 let mut col_data = Vec::with_capacity(end - start);
531 for idx in start..end {
532 col_data.push((self.indices[idx], self.data[idx]));
533 }
534
535 col_data.sort_by_key(|&(row, _)| row);
537
538 for (i, (row, val)) in col_data.into_iter().enumerate() {
540 self.indices[start + i] = row;
541 self.data[start + i] = val;
542 }
543 }
544
545 self.has_sorted_indices = true;
546 }
547
548 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
549 if self.has_sorted_indices {
550 return Box::new(self.clone());
551 }
552
553 let mut sorted = self.clone();
554 sorted.sort_indices();
555 Box::new(sorted)
556 }
557
558 fn has_sorted_indices(&self) -> bool {
559 self.has_sorted_indices
560 }
561
562 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
563 match axis {
564 None => {
565 let mut sum = T::zero();
567 for &val in self.data.iter() {
568 sum = sum + val;
569 }
570 Ok(SparseSum::Scalar(sum))
571 }
572 Some(0) => {
573 let self_csr = self.to_csr()?;
576 self_csr.sum(Some(0))
577 }
578 Some(1) => {
579 let mut result = Vec::with_capacity(self.shape.1);
581
582 for col in 0..self.shape.1 {
583 let start = self.indptr[col];
584 let end = self.indptr[col + 1];
585
586 let mut col_sum = T::zero();
587 for idx in start..end {
588 col_sum = col_sum + self.data[idx];
589 }
590 result.push(col_sum);
591 }
592
593 let mut row_indices = Vec::new();
595 let mut col_indices = Vec::new();
596 let mut values = Vec::new();
597
598 for (col, &val) in result.iter().enumerate() {
599 if !val.is_zero() {
600 row_indices.push(0);
601 col_indices.push(col);
602 values.push(val);
603 }
604 }
605
606 let coo = CooArray::from_triplets(
607 &row_indices,
608 &col_indices,
609 &values,
610 (1, self.shape.1),
611 true,
612 )?;
613
614 Ok(SparseSum::SparseArray(Box::new(coo)))
615 }
616 _ => Err(SparseError::InvalidAxis),
617 }
618 }
619
620 fn max(&self) -> T {
621 if self.data.is_empty() {
622 return T::neg_infinity();
623 }
624
625 let mut max_val = self.data[0];
626 for &val in self.data.iter().skip(1) {
627 if val > max_val {
628 max_val = val;
629 }
630 }
631
632 if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
634 max_val = T::zero();
635 }
636
637 max_val
638 }
639
640 fn min(&self) -> T {
641 if self.data.is_empty() {
642 return T::infinity();
643 }
644
645 let mut min_val = self.data[0];
646 for &val in self.data.iter().skip(1) {
647 if val < min_val {
648 min_val = val;
649 }
650 }
651
652 if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
654 min_val = T::zero();
655 }
656
657 min_val
658 }
659
660 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
661 let nnz = self.nnz();
662 let mut rows = Vec::with_capacity(nnz);
663 let mut cols = Vec::with_capacity(nnz);
664 let mut values = Vec::with_capacity(nnz);
665
666 for col in 0..self.shape.1 {
667 let start = self.indptr[col];
668 let end = self.indptr[col + 1];
669
670 for idx in start..end {
671 let row = self.indices[idx];
672 rows.push(row);
673 cols.push(col);
674 values.push(self.data[idx]);
675 }
676 }
677
678 (
679 Array1::from_vec(rows),
680 Array1::from_vec(cols),
681 Array1::from_vec(values),
682 )
683 }
684
685 fn slice(
686 &self,
687 row_range: (usize, usize),
688 col_range: (usize, usize),
689 ) -> SparseResult<Box<dyn SparseArray<T>>> {
690 let (start_row, end_row) = row_range;
691 let (start_col, end_col) = col_range;
692
693 if start_row >= self.shape.0
694 || end_row > self.shape.0
695 || start_col >= self.shape.1
696 || end_col > self.shape.1
697 {
698 return Err(SparseError::InvalidSliceRange);
699 }
700
701 if start_row >= end_row || start_col >= end_col {
702 return Err(SparseError::InvalidSliceRange);
703 }
704
705 let mut data = Vec::new();
707 let mut indices = Vec::new();
708 let mut indptr = vec![0];
709
710 for col in start_col..end_col {
711 let start = self.indptr[col];
712 let end = self.indptr[col + 1];
713
714 for idx in start..end {
715 let row = self.indices[idx];
716 if row >= start_row && row < end_row {
717 data.push(self.data[idx]);
718 indices.push(row - start_row); }
720 }
721 indptr.push(data.len());
722 }
723
724 CscArray::new(
725 Array1::from_vec(data),
726 Array1::from_vec(indices),
727 Array1::from_vec(indptr),
728 (end_row - start_row, end_col - start_col),
729 )
730 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
731 }
732
733 fn as_any(&self) -> &dyn std::any::Any {
734 self
735 }
736}
737
738impl<T> fmt::Debug for CscArray<T>
739where
740 T: Float
741 + Add<Output = T>
742 + Sub<Output = T>
743 + Mul<Output = T>
744 + Div<Output = T>
745 + Debug
746 + Copy
747 + 'static,
748{
749 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
750 write!(
751 f,
752 "CscArray<{}x{}, nnz={}>",
753 self.shape.0,
754 self.shape.1,
755 self.nnz()
756 )
757 }
758}
759
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Shared 3x3 fixture, given as coordinate triplets:
    //   [[1, 2, 0],
    //    [0, 0, 3],
    //    [4, 0, 5]]
    const ROWS: [usize; 5] = [0, 2, 0, 1, 2];
    const COLS: [usize; 5] = [0, 0, 1, 2, 2];
    const VALS: [f64; 5] = [1.0, 4.0, 2.0, 3.0, 5.0];
    const SHAPE: (usize, usize) = (3, 3);

    /// Builds the fixture through `from_triplets` with unsorted input.
    fn fixture() -> CscArray<f64> {
        CscArray::from_triplets(&ROWS, &COLS, &VALS, SHAPE, false).unwrap()
    }

    /// Checks shape, entry count, every stored entry and one implicit zero.
    fn assert_fixture_entries(csc: &CscArray<f64>) {
        assert_eq!(csc.shape(), (3, 3));
        assert_eq!(csc.nnz(), 5);
        for ((&r, &c), &v) in ROWS.iter().zip(COLS.iter()).zip(VALS.iter()) {
            assert_eq!(csc.get(r, c), v);
        }
        assert_eq!(csc.get(1, 0), 0.0);
    }

    #[test]
    fn test_csc_array_construction() {
        let csc = CscArray::new(
            Array1::from_vec(vec![1.0, 4.0, 2.0, 3.0, 5.0]),
            Array1::from_vec(vec![0, 2, 0, 1, 2]),
            Array1::from_vec(vec![0, 2, 3, 5]),
            (3, 3),
        )
        .unwrap();

        assert_fixture_entries(&csc);
    }

    #[test]
    fn test_csc_from_triplets() {
        assert_fixture_entries(&fixture());
    }

    #[test]
    fn test_csc_array_to_array() {
        let dense = fixture().to_array();
        let expected = [[1.0, 2.0, 0.0], [0.0, 0.0, 3.0], [4.0, 0.0, 5.0]];

        assert_eq!(dense.shape(), &[3, 3]);
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert_eq!(dense[[i, j]], want);
            }
        }
    }

    #[test]
    fn test_csc_to_csr_conversion() {
        let csc = fixture();
        let csr = csc.to_csr().unwrap();

        // Both formats must expand to the same dense matrix.
        let from_csc = csc.to_array();
        let from_csr = csr.to_array();

        for i in 0..SHAPE.0 {
            for j in 0..SHAPE.1 {
                assert_relative_eq!(from_csc[[i, j]], from_csr[[i, j]]);
            }
        }
    }

    #[test]
    fn test_csc_dot_vector() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = fixture().dot_vector(&x.view()).unwrap();

        // Row-wise: 1*1 + 2*2, 3*3, 4*1 + 5*3.
        let expected = [5.0, 9.0, 19.0];
        assert_eq!(y.len(), 3);
        for (k, &want) in expected.iter().enumerate() {
            assert_relative_eq!(y[k], want);
        }
    }

    #[test]
    fn test_csc_transpose() {
        let transposed = fixture().transpose().unwrap();
        assert_eq!(transposed.shape(), (3, 3));

        // Every stored entry (r, c, v) must show up at (c, r).
        let dense = transposed.to_array();
        for ((&r, &c), &v) in ROWS.iter().zip(COLS.iter()).zip(VALS.iter()) {
            assert_eq!(dense[[c, r]], v);
        }
    }

    #[test]
    fn test_csc_find() {
        let (found_rows, found_cols, found_vals) = fixture().find();

        assert_eq!(found_rows.len(), 5);
        assert_eq!(found_cols.len(), 5);
        assert_eq!(found_vals.len(), 5);

        let mut expected: Vec<(usize, usize, f64)> = (0..ROWS.len())
            .map(|k| (ROWS[k], COLS[k], VALS[k]))
            .collect();

        let mut actual: Vec<(usize, usize, f64)> = (0..found_rows.len())
            .map(|k| (found_rows[k], found_cols[k], found_vals[k]))
            .collect();

        // Storage order differs from input order; compare as sorted sets.
        expected.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
        actual.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));

        assert_eq!(expected, actual);
    }
}