Skip to main content

scirs2_sparse/
csc_array.rs

// CSC Array implementation
//
// This module provides the CSC (Compressed Sparse Column) array format,
// which is efficient for column-wise operations such as column slicing.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::error::{SparseError, SparseResult};
14use crate::sparray::{SparseArray, SparseSum};
15
16/// Insert a value into an `Array1` at position `idx`, shifting subsequent
17/// elements to the right.  ndarray's `Array1` does not provide an insert
18/// method, so we convert to `Vec`, insert, and convert back.
19fn array1_insert<T: Clone + Default>(arr: &Array1<T>, idx: usize, value: T) -> Array1<T> {
20    let mut v = arr.to_vec();
21    v.insert(idx, value);
22    Array1::from_vec(v)
23}
24
25/// CSC Array format
26///
27/// The CSC (Compressed Sparse Column) format stores a sparse array in three arrays:
28/// - data: array of non-zero values
29/// - indices: row indices of the non-zero values
30/// - indptr: index pointers; for each column, points to the first non-zero element
31///
32/// # Notes
33///
34/// - Efficient for column-oriented operations
35/// - Fast matrix-vector multiplications for A^T x
36/// - Fast column slicing
37/// - Slow row slicing
38/// - Slow constructing by setting individual elements
39///
40#[derive(Clone)]
41pub struct CscArray<T>
42where
43    T: SparseElement + Div<Output = T> + 'static,
44{
45    /// Non-zero values
46    data: Array1<T>,
47    /// Row indices of non-zero values
48    indices: Array1<usize>,
49    /// Column pointers (indices into data/indices for the start of each column)
50    indptr: Array1<usize>,
51    /// Shape of the sparse array
52    shape: (usize, usize),
53    /// Whether indices are sorted for each column
54    has_sorted_indices: bool,
55}
56
57impl<T> CscArray<T>
58where
59    T: SparseElement + Div<Output = T> + Zero + 'static,
60{
61    /// Creates a new CSC array from raw components
62    ///
63    /// # Arguments
64    /// * `data` - Array of non-zero values
65    /// * `indices` - Row indices of non-zero values
66    /// * `indptr` - Index pointers for the start of each column
67    /// * `shape` - Shape of the sparse array
68    ///
69    /// # Returns
70    /// A new `CscArray`
71    ///
72    /// # Errors
73    /// Returns an error if the data is not consistent
74    pub fn new(
75        data: Array1<T>,
76        indices: Array1<usize>,
77        indptr: Array1<usize>,
78        shape: (usize, usize),
79    ) -> SparseResult<Self> {
80        // Validation
81        if data.len() != indices.len() {
82            return Err(SparseError::InconsistentData {
83                reason: "data and indices must have the same length".to_string(),
84            });
85        }
86
87        if indptr.len() != shape.1 + 1 {
88            return Err(SparseError::InconsistentData {
89                reason: format!(
90                    "indptr length ({}) must be one more than the number of columns ({})",
91                    indptr.len(),
92                    shape.1
93                ),
94            });
95        }
96
97        if let Some(&max_idx) = indices.iter().max() {
98            if max_idx >= shape.0 {
99                return Err(SparseError::IndexOutOfBounds {
100                    index: (max_idx, 0),
101                    shape,
102                });
103            }
104        }
105
106        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
107            if first != 0 {
108                return Err(SparseError::InconsistentData {
109                    reason: "first element of indptr must be 0".to_string(),
110                });
111            }
112
113            if last != data.len() {
114                return Err(SparseError::InconsistentData {
115                    reason: format!(
116                        "last element of indptr ({}) must equal data length ({})",
117                        last,
118                        data.len()
119                    ),
120                });
121            }
122        }
123
124        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
125
126        Ok(Self {
127            data,
128            indices,
129            indptr,
130            shape,
131            has_sorted_indices,
132        })
133    }
134
135    /// Create a CSC array from triplet format (COO-like)
136    ///
137    /// # Arguments
138    /// * `rows` - Row indices
139    /// * `cols` - Column indices
140    /// * `data` - Values
141    /// * `shape` - Shape of the sparse array
142    /// * `sorted` - Whether the triplets are sorted by column
143    ///
144    /// # Returns
145    /// A new `CscArray`
146    ///
147    /// # Errors
148    /// Returns an error if the data is not consistent
149    pub fn from_triplets(
150        rows: &[usize],
151        cols: &[usize],
152        data: &[T],
153        shape: (usize, usize),
154        sorted: bool,
155    ) -> SparseResult<Self> {
156        if rows.len() != cols.len() || rows.len() != data.len() {
157            return Err(SparseError::InconsistentData {
158                reason: "rows, cols, and data must have the same length".to_string(),
159            });
160        }
161
162        if rows.is_empty() {
163            // Empty matrix
164            let indptr = Array1::zeros(shape.1 + 1);
165            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
166        }
167
168        let nnz = rows.len();
169        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
170
171        for i in 0..nnz {
172            if rows[i] >= shape.0 || cols[i] >= shape.1 {
173                return Err(SparseError::IndexOutOfBounds {
174                    index: (rows[i], cols[i]),
175                    shape,
176                });
177            }
178            all_data.push((rows[i], cols[i], data[i]));
179        }
180
181        if !sorted {
182            all_data.sort_by_key(|&(_, col_, _)| col_);
183        }
184
185        // Count elements per column
186        let mut col_counts = vec![0; shape.1];
187        for &(_, col_, _) in &all_data {
188            col_counts[col_] += 1;
189        }
190
191        // Create indptr
192        let mut indptr = Vec::with_capacity(shape.1 + 1);
193        indptr.push(0);
194        let mut cumsum = 0;
195        for &count in &col_counts {
196            cumsum += count;
197            indptr.push(cumsum);
198        }
199
200        // Create indices and data arrays
201        let mut indices = Vec::with_capacity(nnz);
202        let mut values = Vec::with_capacity(nnz);
203
204        for (row_, _, val) in all_data {
205            indices.push(row_);
206            values.push(val);
207        }
208
209        Self::new(
210            Array1::from_vec(values),
211            Array1::from_vec(indices),
212            Array1::from_vec(indptr),
213            shape,
214        )
215    }
216
217    /// Checks if row indices are sorted for each column
218    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
219        for col in 0..indptr.len() - 1 {
220            let start = indptr[col];
221            let end = indptr[col + 1];
222
223            for i in start..end.saturating_sub(1) {
224                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
225                    return false;
226                }
227            }
228        }
229        true
230    }
231
232    /// Get the raw data array
233    pub fn get_data(&self) -> &Array1<T> {
234        &self.data
235    }
236
237    /// Get the raw indices array
238    pub fn get_indices(&self) -> &Array1<usize> {
239        &self.indices
240    }
241
242    /// Get the raw indptr array
243    pub fn get_indptr(&self) -> &Array1<usize> {
244        &self.indptr
245    }
246}
247
248impl<T> SparseArray<T> for CscArray<T>
249where
250    T: SparseElement + Div<Output = T> + Float + 'static,
251{
252    fn shape(&self) -> (usize, usize) {
253        self.shape
254    }
255
256    fn nnz(&self) -> usize {
257        self.data.len()
258    }
259
260    fn dtype(&self) -> &str {
261        "float" // Placeholder, ideally we would return the actual type
262    }
263
264    fn to_array(&self) -> Array2<T> {
265        let (rows, cols) = self.shape;
266        let mut result = Array2::zeros((rows, cols));
267
268        for col in 0..cols {
269            let start = self.indptr[col];
270            let end = self.indptr[col + 1];
271
272            for i in start..end {
273                let row = self.indices[i];
274                result[[row, col]] = self.data[i];
275            }
276        }
277
278        result
279    }
280
281    fn toarray(&self) -> Array2<T> {
282        self.to_array()
283    }
284
285    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
286        // Convert to COO format
287        let nnz = self.nnz();
288        let mut row_indices = Vec::with_capacity(nnz);
289        let mut col_indices = Vec::with_capacity(nnz);
290        let mut values = Vec::with_capacity(nnz);
291
292        for col in 0..self.shape.1 {
293            let start = self.indptr[col];
294            let end = self.indptr[col + 1];
295
296            for idx in start..end {
297                row_indices.push(self.indices[idx]);
298                col_indices.push(col);
299                values.push(self.data[idx]);
300            }
301        }
302
303        CooArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
304            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
305    }
306
307    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
308        // For efficiency, we'll go via COO format
309        self.to_coo()?.to_csr()
310    }
311
312    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
313        Ok(Box::new(self.clone()))
314    }
315
316    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
317        // This would convert to DOK format
318        // For now, we'll go via COO format
319        self.to_coo()?.to_dok()
320    }
321
322    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
323        // This would convert to LIL format
324        // For now, we'll go via COO format
325        self.to_coo()?.to_lil()
326    }
327
328    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
329        // This would convert to DIA format
330        // For now, we'll go via COO format
331        self.to_coo()?.to_dia()
332    }
333
334    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
335        // This would convert to BSR format
336        // For now, we'll go via COO format
337        self.to_coo()?.to_bsr()
338    }
339
340    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
341        // For efficiency, convert to CSR format and then add
342        let self_csr = self.to_csr()?;
343        self_csr.add(other)
344    }
345
346    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
347        // For efficiency, convert to CSR format and then subtract
348        let self_csr = self.to_csr()?;
349        self_csr.sub(other)
350    }
351
352    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
353        // Element-wise multiplication (Hadamard product)
354        // Convert to CSR for efficiency
355        let self_csr = self.to_csr()?;
356        self_csr.mul(other)
357    }
358
359    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
360        // Element-wise division
361        // Convert to CSR for efficiency
362        let self_csr = self.to_csr()?;
363        self_csr.div(other)
364    }
365
366    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
367        // Matrix multiplication
368        // Convert to CSR for efficiency
369        let self_csr = self.to_csr()?;
370        self_csr.dot(other)
371    }
372
373    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
374        let (m, n) = self.shape();
375        if n != other.len() {
376            return Err(SparseError::DimensionMismatch {
377                expected: n,
378                found: other.len(),
379            });
380        }
381
382        let mut result = Array1::zeros(m);
383
384        for col in 0..n {
385            let start = self.indptr[col];
386            let end = self.indptr[col + 1];
387
388            let val = other[col];
389            if !SparseElement::is_zero(&val) {
390                for idx in start..end {
391                    let row = self.indices[idx];
392                    result[row] = result[row] + self.data[idx] * val;
393                }
394            }
395        }
396
397        Ok(result)
398    }
399
400    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
401        // CSC transposed is effectively CSR (swap rows/cols)
402        CsrArray::new(
403            self.data.clone(),
404            self.indices.clone(),
405            self.indptr.clone(),
406            (self.shape.1, self.shape.0), // Swap dimensions
407        )
408        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
409    }
410
411    fn copy(&self) -> Box<dyn SparseArray<T>> {
412        Box::new(self.clone())
413    }
414
415    fn get(&self, i: usize, j: usize) -> T {
416        if i >= self.shape.0 || j >= self.shape.1 {
417            return T::sparse_zero();
418        }
419
420        let start = self.indptr[j];
421        let end = self.indptr[j + 1];
422
423        for idx in start..end {
424            if self.indices[idx] == i {
425                return self.data[idx];
426            }
427
428            // If indices are sorted, we can break early
429            if self.has_sorted_indices && self.indices[idx] > i {
430                break;
431            }
432        }
433
434        T::sparse_zero()
435    }
436
437    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
438        if i >= self.shape.0 || j >= self.shape.1 {
439            return Err(SparseError::IndexOutOfBounds {
440                index: (i, j),
441                shape: self.shape,
442            });
443        }
444
445        let start = self.indptr[j];
446        let end = self.indptr[j + 1];
447
448        // Try to find existing element
449        for idx in start..end {
450            if self.indices[idx] == i {
451                self.data[idx] = value;
452                return Ok(());
453            }
454            if self.has_sorted_indices && self.indices[idx] > i {
455                // Insert at position `idx` to maintain sorted order
456                self.data = array1_insert(&self.data, idx, value);
457                self.indices = array1_insert(&self.indices, idx, i);
458                // Increment indptr for all subsequent columns
459                for col_ptr in self.indptr.iter_mut().skip(j + 1) {
460                    *col_ptr += 1;
461                }
462                return Ok(());
463            }
464        }
465
466        // Element not found - insert at end of this column's range
467        self.data = array1_insert(&self.data, end, value);
468        self.indices = array1_insert(&self.indices, end, i);
469        // Increment indptr for all subsequent columns
470        for col_ptr in self.indptr.iter_mut().skip(j + 1) {
471            *col_ptr += 1;
472        }
473        // Re-check sorted state for this column
474        if self.has_sorted_indices {
475            let new_end = self.indptr[j + 1];
476            let new_start = self.indptr[j];
477            for k in new_start..new_end.saturating_sub(1) {
478                if self.indices[k] > self.indices[k + 1] {
479                    self.has_sorted_indices = false;
480                    break;
481                }
482            }
483        }
484        Ok(())
485    }
486
487    fn eliminate_zeros(&mut self) {
488        // Find all non-zero entries
489        let mut new_data = Vec::new();
490        let mut new_indices = Vec::new();
491        let mut new_indptr = vec![0];
492
493        let (_, cols) = self.shape;
494
495        for col in 0..cols {
496            let start = self.indptr[col];
497            let end = self.indptr[col + 1];
498
499            for idx in start..end {
500                if !SparseElement::is_zero(&self.data[idx]) {
501                    new_data.push(self.data[idx]);
502                    new_indices.push(self.indices[idx]);
503                }
504            }
505            new_indptr.push(new_data.len());
506        }
507
508        // Replace data with filtered data
509        self.data = Array1::from_vec(new_data);
510        self.indices = Array1::from_vec(new_indices);
511        self.indptr = Array1::from_vec(new_indptr);
512    }
513
514    fn sort_indices(&mut self) {
515        if self.has_sorted_indices {
516            return;
517        }
518
519        let (_, cols) = self.shape;
520
521        for col in 0..cols {
522            let start = self.indptr[col];
523            let end = self.indptr[col + 1];
524
525            if start == end {
526                continue;
527            }
528
529            // Extract the non-zero elements for this column
530            let mut col_data = Vec::with_capacity(end - start);
531            for idx in start..end {
532                col_data.push((self.indices[idx], self.data[idx]));
533            }
534
535            // Sort by row index
536            col_data.sort_by_key(|&(row_, _)| row_);
537
538            // Put the sorted data back
539            for (i, (row, val)) in col_data.into_iter().enumerate() {
540                self.indices[start + i] = row;
541                self.data[start + i] = val;
542            }
543        }
544
545        self.has_sorted_indices = true;
546    }
547
548    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
549        if self.has_sorted_indices {
550            return Box::new(self.clone());
551        }
552
553        let mut sorted = self.clone();
554        sorted.sort_indices();
555        Box::new(sorted)
556    }
557
558    fn has_sorted_indices(&self) -> bool {
559        self.has_sorted_indices
560    }
561
562    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
563        match axis {
564            None => {
565                // Sum all elements
566                let mut sum = T::sparse_zero();
567                for &val in self.data.iter() {
568                    sum = sum + val;
569                }
570                Ok(SparseSum::Scalar(sum))
571            }
572            Some(0) => {
573                // Sum along rows (result is a row vector)
574                // For efficiency, convert to CSR and sum
575                let self_csr = self.to_csr()?;
576                self_csr.sum(Some(0))
577            }
578            Some(1) => {
579                // Sum along columns (result is a column vector)
580                let mut result = Vec::with_capacity(self.shape.1);
581
582                for col in 0..self.shape.1 {
583                    let start = self.indptr[col];
584                    let end = self.indptr[col + 1];
585
586                    let mut col_sum = T::sparse_zero();
587                    for idx in start..end {
588                        col_sum = col_sum + self.data[idx];
589                    }
590                    result.push(col_sum);
591                }
592
593                // Convert to COO format for the column vector
594                let mut row_indices = Vec::new();
595                let mut col_indices = Vec::new();
596                let mut values = Vec::new();
597
598                for (col, &val) in result.iter().enumerate() {
599                    if !SparseElement::is_zero(&val) {
600                        row_indices.push(0);
601                        col_indices.push(col);
602                        values.push(val);
603                    }
604                }
605
606                let coo = CooArray::from_triplets(
607                    &row_indices,
608                    &col_indices,
609                    &values,
610                    (1, self.shape.1),
611                    true,
612                )?;
613
614                Ok(SparseSum::SparseArray(Box::new(coo)))
615            }
616            _ => Err(SparseError::InvalidAxis),
617        }
618    }
619
620    fn max(&self) -> T {
621        if self.data.is_empty() {
622            // Empty sparse matrix - all elements are implicitly zero
623            return T::sparse_zero();
624        }
625
626        let mut max_val = self.data[0];
627        for &val in self.data.iter().skip(1) {
628            if val > max_val {
629                max_val = val;
630            }
631        }
632
633        // Check if max_val is less than zero, as zeros aren't explicitly stored
634        let zero = T::sparse_zero();
635        if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
636            max_val = zero;
637        }
638
639        max_val
640    }
641
642    fn min(&self) -> T {
643        if self.data.is_empty() {
644            return T::sparse_zero();
645        }
646
647        let mut min_val = self.data[0];
648        for &val in self.data.iter().skip(1) {
649            if val < min_val {
650                min_val = val;
651            }
652        }
653
654        // Check if min_val is greater than zero, as zeros aren't explicitly stored
655        if min_val > T::sparse_zero() && self.nnz() < self.shape.0 * self.shape.1 {
656            min_val = T::sparse_zero();
657        }
658
659        min_val
660    }
661
662    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
663        let nnz = self.nnz();
664        let mut rows = Vec::with_capacity(nnz);
665        let mut cols = Vec::with_capacity(nnz);
666        let mut values = Vec::with_capacity(nnz);
667
668        for col in 0..self.shape.1 {
669            let start = self.indptr[col];
670            let end = self.indptr[col + 1];
671
672            for idx in start..end {
673                let row = self.indices[idx];
674                rows.push(row);
675                cols.push(col);
676                values.push(self.data[idx]);
677            }
678        }
679
680        (
681            Array1::from_vec(rows),
682            Array1::from_vec(cols),
683            Array1::from_vec(values),
684        )
685    }
686
687    fn slice(
688        &self,
689        row_range: (usize, usize),
690        col_range: (usize, usize),
691    ) -> SparseResult<Box<dyn SparseArray<T>>> {
692        let (start_row, end_row) = row_range;
693        let (start_col, end_col) = col_range;
694
695        if start_row >= self.shape.0
696            || end_row > self.shape.0
697            || start_col >= self.shape.1
698            || end_col > self.shape.1
699        {
700            return Err(SparseError::InvalidSliceRange);
701        }
702
703        if start_row >= end_row || start_col >= end_col {
704            return Err(SparseError::InvalidSliceRange);
705        }
706
707        // CSC format is efficient for column slicing
708        let mut data = Vec::new();
709        let mut indices = Vec::new();
710        let mut indptr = vec![0];
711
712        for col in start_col..end_col {
713            let start = self.indptr[col];
714            let end = self.indptr[col + 1];
715
716            for idx in start..end {
717                let row = self.indices[idx];
718                if row >= start_row && row < end_row {
719                    data.push(self.data[idx]);
720                    indices.push(row - start_row); // Adjust indices
721                }
722            }
723            indptr.push(data.len());
724        }
725
726        CscArray::new(
727            Array1::from_vec(data),
728            Array1::from_vec(indices),
729            Array1::from_vec(indptr),
730            (end_row - start_row, end_col - start_col),
731        )
732        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
733    }
734
735    fn as_any(&self) -> &dyn std::any::Any {
736        self
737    }
738
739    fn get_indptr(&self) -> Option<&Array1<usize>> {
740        Some(&self.indptr)
741    }
742
743    fn indptr(&self) -> Option<&Array1<usize>> {
744        Some(&self.indptr)
745    }
746}
747
748impl<T> fmt::Debug for CscArray<T>
749where
750    T: SparseElement + Div<Output = T> + Float + 'static,
751{
752    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
753        write!(
754            f,
755            "CscArray<{}x{}, nnz={}>",
756            self.shape.0,
757            self.shape.1,
758            self.nnz()
759        )
760    }
761}
762
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_csc_array_construction() {
        let data = Array1::from_vec(vec![1.0, 4.0, 2.0, 3.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 0, 1, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        let shape = (3, 3);

        let csc = CscArray::new(data, indices, indptr, shape).expect("Operation failed");

        assert_eq!(csc.shape(), (3, 3));
        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(2, 0), 4.0);
        assert_eq!(csc.get(0, 1), 2.0);
        assert_eq!(csc.get(1, 2), 3.0);
        assert_eq!(csc.get(2, 2), 5.0);
        assert_eq!(csc.get(1, 0), 0.0);
    }

    #[test]
    fn test_csc_from_triplets() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");

        assert_eq!(csc.shape(), (3, 3));
        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(2, 0), 4.0);
        assert_eq!(csc.get(0, 1), 2.0);
        assert_eq!(csc.get(1, 2), 3.0);
        assert_eq!(csc.get(2, 2), 5.0);
        assert_eq!(csc.get(1, 0), 0.0);
    }

    #[test]
    fn test_csc_array_to_array() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
        let dense = csc.to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[1, 0]], 0.0);
        assert_eq!(dense[[2, 0]], 4.0);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 1]], 0.0);
        assert_eq!(dense[[2, 1]], 0.0);
        assert_eq!(dense[[0, 2]], 0.0);
        assert_eq!(dense[[1, 2]], 3.0);
        assert_eq!(dense[[2, 2]], 5.0);
    }

    #[test]
    fn test_csc_to_csr_conversion() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
        let csr = csc.to_csr().expect("Operation failed");

        // Check that the conversion preserved values
        let csc_array = csc.to_array();
        let csr_array = csr.to_array();

        for i in 0..shape.0 {
            for j in 0..shape.1 {
                assert_relative_eq!(csc_array[[i, j]], csr_array[[i, j]]);
            }
        }
    }

    #[test]
    fn test_csc_dot_vector() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csc.dot_vector(&vec.view()).expect("Operation failed");

        // Expected:
        // [0,0]*1 + [0,1]*2 + [0,2]*3 = 1*1 + 2*2 + 0*3 = 5
        // [1,0]*1 + [1,1]*2 + [1,2]*3 = 0*1 + 0*2 + 3*3 = 9
        // [2,0]*1 + [2,1]*2 + [2,2]*3 = 4*1 + 0*2 + 5*3 = 19
        assert_eq!(result.len(), 3);
        assert_relative_eq!(result[0], 5.0);
        assert_relative_eq!(result[1], 9.0);
        assert_relative_eq!(result[2], 19.0);
    }

    #[test]
    fn test_csc_transpose() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
        let transposed = csc.transpose().expect("Operation failed");

        // A square input keeps its shape; test_csc_transpose_non_square covers the swap.
        assert_eq!(transposed.shape(), (3, 3));

        // Check values are correctly transposed
        let dense = transposed.to_array();
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 2]], 4.0);
        assert_eq!(dense[[1, 0]], 2.0);
        assert_eq!(dense[[2, 1]], 3.0);
        assert_eq!(dense[[2, 2]], 5.0);
    }

    #[test]
    fn test_csc_transpose_non_square() {
        // A = [[1, 0, 3],
        //      [0, 2, 0]]
        // A non-square input makes a missing dimension swap observable.
        let rows = vec![0, 1, 0];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 2.0, 3.0];

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, (2, 3), false).expect("Operation failed");
        let transposed = csc.transpose().expect("Operation failed");

        assert_eq!(transposed.shape(), (3, 2));

        let dense = transposed.to_array();
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 1]], 0.0);
        assert_eq!(dense[[1, 0]], 0.0);
        assert_eq!(dense[[1, 1]], 2.0);
        assert_eq!(dense[[2, 0]], 3.0);
        assert_eq!(dense[[2, 1]], 0.0);
    }

    #[test]
    fn test_csc_set_keeps_indices_sorted() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let mut csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
        assert!(csc.has_sorted_indices());

        // New entry lands between rows 0 and 2 of column 0.
        csc.set(1, 0, 7.0).expect("Operation failed");
        assert_eq!(csc.nnz(), 6);
        assert!(csc.has_sorted_indices());
        assert_eq!(csc.get(0, 0), 1.0);
        assert_eq!(csc.get(1, 0), 7.0);
        assert_eq!(csc.get(2, 0), 4.0);
        // Later columns are still addressed correctly after the indptr shift.
        assert_eq!(csc.get(0, 1), 2.0);
        assert_eq!(csc.get(1, 2), 3.0);
        assert_eq!(csc.get(2, 2), 5.0);

        // Overwriting an existing entry does not add storage.
        csc.set(2, 2, 9.0).expect("Operation failed");
        assert_eq!(csc.nnz(), 6);
        assert_eq!(csc.get(2, 2), 9.0);

        // Out-of-bounds writes are rejected.
        assert!(csc.set(3, 0, 1.0).is_err());
    }

    #[test]
    fn test_csc_sum_all() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");

        match csc.sum(None).expect("Operation failed") {
            SparseSum::Scalar(total) => assert_relative_eq!(total, 15.0),
            _ => panic!("expected a scalar sum"),
        }
    }

    #[test]
    fn test_csc_find() {
        let rows = vec![0, 2, 0, 1, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
        let shape = (3, 3);

        let csc =
            CscArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
        let (result_rows, result_cols, result_data) = csc.find();

        // Check that the find operation returned all non-zero elements
        assert_eq!(result_rows.len(), 5);
        assert_eq!(result_cols.len(), 5);
        assert_eq!(result_data.len(), 5);

        // Create vectors of tuples to compare
        let mut original: Vec<_> = rows
            .iter()
            .zip(cols.iter())
            .zip(data.iter())
            .map(|((r, c), d)| (*r, *c, *d))
            .collect();

        let mut result: Vec<_> = result_rows
            .iter()
            .zip(result_cols.iter())
            .zip(result_data.iter())
            .map(|((r, c), d)| (*r, *c, *d))
            .collect();

        // Sort the vectors before comparing
        original.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
        result.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));

        assert_eq!(original, result);
    }
}