scirs2_sparse/
csc_array.rs

// CSC Array implementation
//
// This module provides the CSC (Compressed Sparse Column) array format,
// which is efficient for column-wise operations.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::error::{SparseError, SparseResult};
14use crate::sparray::{SparseArray, SparseSum};
15
16/// CSC Array format
17///
18/// The CSC (Compressed Sparse Column) format stores a sparse array in three arrays:
19/// - data: array of non-zero values
20/// - indices: row indices of the non-zero values
21/// - indptr: index pointers; for each column, points to the first non-zero element
22///
23/// # Notes
24///
25/// - Efficient for column-oriented operations
26/// - Fast matrix-vector multiplications for A^T x
27/// - Fast column slicing
28/// - Slow row slicing
29/// - Slow constructing by setting individual elements
30///
31#[derive(Clone)]
32pub struct CscArray<T>
33where
34    T: SparseElement + Div<Output = T> + 'static,
35{
36    /// Non-zero values
37    data: Array1<T>,
38    /// Row indices of non-zero values
39    indices: Array1<usize>,
40    /// Column pointers (indices into data/indices for the start of each column)
41    indptr: Array1<usize>,
42    /// Shape of the sparse array
43    shape: (usize, usize),
44    /// Whether indices are sorted for each column
45    has_sorted_indices: bool,
46}
47
48impl<T> CscArray<T>
49where
50    T: SparseElement + Div<Output = T> + Zero + 'static,
51{
52    /// Creates a new CSC array from raw components
53    ///
54    /// # Arguments
55    /// * `data` - Array of non-zero values
56    /// * `indices` - Row indices of non-zero values
57    /// * `indptr` - Index pointers for the start of each column
58    /// * `shape` - Shape of the sparse array
59    ///
60    /// # Returns
61    /// A new `CscArray`
62    ///
63    /// # Errors
64    /// Returns an error if the data is not consistent
65    pub fn new(
66        data: Array1<T>,
67        indices: Array1<usize>,
68        indptr: Array1<usize>,
69        shape: (usize, usize),
70    ) -> SparseResult<Self> {
71        // Validation
72        if data.len() != indices.len() {
73            return Err(SparseError::InconsistentData {
74                reason: "data and indices must have the same length".to_string(),
75            });
76        }
77
78        if indptr.len() != shape.1 + 1 {
79            return Err(SparseError::InconsistentData {
80                reason: format!(
81                    "indptr length ({}) must be one more than the number of columns ({})",
82                    indptr.len(),
83                    shape.1
84                ),
85            });
86        }
87
88        if let Some(&max_idx) = indices.iter().max() {
89            if max_idx >= shape.0 {
90                return Err(SparseError::IndexOutOfBounds {
91                    index: (max_idx, 0),
92                    shape,
93                });
94            }
95        }
96
97        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
98            if first != 0 {
99                return Err(SparseError::InconsistentData {
100                    reason: "first element of indptr must be 0".to_string(),
101                });
102            }
103
104            if last != data.len() {
105                return Err(SparseError::InconsistentData {
106                    reason: format!(
107                        "last element of indptr ({}) must equal data length ({})",
108                        last,
109                        data.len()
110                    ),
111                });
112            }
113        }
114
115        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
116
117        Ok(Self {
118            data,
119            indices,
120            indptr,
121            shape,
122            has_sorted_indices,
123        })
124    }
125
126    /// Create a CSC array from triplet format (COO-like)
127    ///
128    /// # Arguments
129    /// * `rows` - Row indices
130    /// * `cols` - Column indices
131    /// * `data` - Values
132    /// * `shape` - Shape of the sparse array
133    /// * `sorted` - Whether the triplets are sorted by column
134    ///
135    /// # Returns
136    /// A new `CscArray`
137    ///
138    /// # Errors
139    /// Returns an error if the data is not consistent
140    pub fn from_triplets(
141        rows: &[usize],
142        cols: &[usize],
143        data: &[T],
144        shape: (usize, usize),
145        sorted: bool,
146    ) -> SparseResult<Self> {
147        if rows.len() != cols.len() || rows.len() != data.len() {
148            return Err(SparseError::InconsistentData {
149                reason: "rows, cols, and data must have the same length".to_string(),
150            });
151        }
152
153        if rows.is_empty() {
154            // Empty matrix
155            let indptr = Array1::zeros(shape.1 + 1);
156            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
157        }
158
159        let nnz = rows.len();
160        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
161
162        for i in 0..nnz {
163            if rows[i] >= shape.0 || cols[i] >= shape.1 {
164                return Err(SparseError::IndexOutOfBounds {
165                    index: (rows[i], cols[i]),
166                    shape,
167                });
168            }
169            all_data.push((rows[i], cols[i], data[i]));
170        }
171
172        if !sorted {
173            all_data.sort_by_key(|&(_, col_, _)| col_);
174        }
175
176        // Count elements per column
177        let mut col_counts = vec![0; shape.1];
178        for &(_, col_, _) in &all_data {
179            col_counts[col_] += 1;
180        }
181
182        // Create indptr
183        let mut indptr = Vec::with_capacity(shape.1 + 1);
184        indptr.push(0);
185        let mut cumsum = 0;
186        for &count in &col_counts {
187            cumsum += count;
188            indptr.push(cumsum);
189        }
190
191        // Create indices and data arrays
192        let mut indices = Vec::with_capacity(nnz);
193        let mut values = Vec::with_capacity(nnz);
194
195        for (row_, _, val) in all_data {
196            indices.push(row_);
197            values.push(val);
198        }
199
200        Self::new(
201            Array1::from_vec(values),
202            Array1::from_vec(indices),
203            Array1::from_vec(indptr),
204            shape,
205        )
206    }
207
208    /// Checks if row indices are sorted for each column
209    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
210        for col in 0..indptr.len() - 1 {
211            let start = indptr[col];
212            let end = indptr[col + 1];
213
214            for i in start..end.saturating_sub(1) {
215                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
216                    return false;
217                }
218            }
219        }
220        true
221    }
222
223    /// Get the raw data array
224    pub fn get_data(&self) -> &Array1<T> {
225        &self.data
226    }
227
228    /// Get the raw indices array
229    pub fn get_indices(&self) -> &Array1<usize> {
230        &self.indices
231    }
232
233    /// Get the raw indptr array
234    pub fn get_indptr(&self) -> &Array1<usize> {
235        &self.indptr
236    }
237}
238
239impl<T> SparseArray<T> for CscArray<T>
240where
241    T: SparseElement + Div<Output = T> + Float + 'static,
242{
243    fn shape(&self) -> (usize, usize) {
244        self.shape
245    }
246
247    fn nnz(&self) -> usize {
248        self.data.len()
249    }
250
251    fn dtype(&self) -> &str {
252        "float" // Placeholder, ideally we would return the actual type
253    }
254
255    fn to_array(&self) -> Array2<T> {
256        let (rows, cols) = self.shape;
257        let mut result = Array2::zeros((rows, cols));
258
259        for col in 0..cols {
260            let start = self.indptr[col];
261            let end = self.indptr[col + 1];
262
263            for i in start..end {
264                let row = self.indices[i];
265                result[[row, col]] = self.data[i];
266            }
267        }
268
269        result
270    }
271
272    fn toarray(&self) -> Array2<T> {
273        self.to_array()
274    }
275
276    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
277        // Convert to COO format
278        let nnz = self.nnz();
279        let mut row_indices = Vec::with_capacity(nnz);
280        let mut col_indices = Vec::with_capacity(nnz);
281        let mut values = Vec::with_capacity(nnz);
282
283        for col in 0..self.shape.1 {
284            let start = self.indptr[col];
285            let end = self.indptr[col + 1];
286
287            for idx in start..end {
288                row_indices.push(self.indices[idx]);
289                col_indices.push(col);
290                values.push(self.data[idx]);
291            }
292        }
293
294        CooArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
295            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
296    }
297
298    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
299        // For efficiency, we'll go via COO format
300        self.to_coo()?.to_csr()
301    }
302
303    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
304        Ok(Box::new(self.clone()))
305    }
306
307    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
308        // This would convert to DOK format
309        // For now, we'll go via COO format
310        self.to_coo()?.to_dok()
311    }
312
313    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
314        // This would convert to LIL format
315        // For now, we'll go via COO format
316        self.to_coo()?.to_lil()
317    }
318
319    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
320        // This would convert to DIA format
321        // For now, we'll go via COO format
322        self.to_coo()?.to_dia()
323    }
324
325    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
326        // This would convert to BSR format
327        // For now, we'll go via COO format
328        self.to_coo()?.to_bsr()
329    }
330
331    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
332        // For efficiency, convert to CSR format and then add
333        let self_csr = self.to_csr()?;
334        self_csr.add(other)
335    }
336
337    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
338        // For efficiency, convert to CSR format and then subtract
339        let self_csr = self.to_csr()?;
340        self_csr.sub(other)
341    }
342
343    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
344        // Element-wise multiplication (Hadamard product)
345        // Convert to CSR for efficiency
346        let self_csr = self.to_csr()?;
347        self_csr.mul(other)
348    }
349
350    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
351        // Element-wise division
352        // Convert to CSR for efficiency
353        let self_csr = self.to_csr()?;
354        self_csr.div(other)
355    }
356
357    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
358        // Matrix multiplication
359        // Convert to CSR for efficiency
360        let self_csr = self.to_csr()?;
361        self_csr.dot(other)
362    }
363
364    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
365        let (m, n) = self.shape();
366        if n != other.len() {
367            return Err(SparseError::DimensionMismatch {
368                expected: n,
369                found: other.len(),
370            });
371        }
372
373        let mut result = Array1::zeros(m);
374
375        for col in 0..n {
376            let start = self.indptr[col];
377            let end = self.indptr[col + 1];
378
379            let val = other[col];
380            if !SparseElement::is_zero(&val) {
381                for idx in start..end {
382                    let row = self.indices[idx];
383                    result[row] = result[row] + self.data[idx] * val;
384                }
385            }
386        }
387
388        Ok(result)
389    }
390
391    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
392        // CSC transposed is effectively CSR (swap rows/cols)
393        CsrArray::new(
394            self.data.clone(),
395            self.indices.clone(),
396            self.indptr.clone(),
397            (self.shape.1, self.shape.0), // Swap dimensions
398        )
399        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
400    }
401
402    fn copy(&self) -> Box<dyn SparseArray<T>> {
403        Box::new(self.clone())
404    }
405
406    fn get(&self, i: usize, j: usize) -> T {
407        if i >= self.shape.0 || j >= self.shape.1 {
408            return T::sparse_zero();
409        }
410
411        let start = self.indptr[j];
412        let end = self.indptr[j + 1];
413
414        for idx in start..end {
415            if self.indices[idx] == i {
416                return self.data[idx];
417            }
418
419            // If indices are sorted, we can break early
420            if self.has_sorted_indices && self.indices[idx] > i {
421                break;
422            }
423        }
424
425        T::sparse_zero()
426    }
427
428    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
429        // Setting elements in CSC format is non-trivial
430        // This is a placeholder implementation that doesn't actually modify the structure
431        if i >= self.shape.0 || j >= self.shape.1 {
432            return Err(SparseError::IndexOutOfBounds {
433                index: (i, j),
434                shape: self.shape,
435            });
436        }
437
438        let start = self.indptr[j];
439        let end = self.indptr[j + 1];
440
441        // Try to find existing element
442        for idx in start..end {
443            if self.indices[idx] == i {
444                // Update value
445                self.data[idx] = value;
446                return Ok(());
447            }
448
449            // If indices are sorted, we can insert at the right position
450            if self.has_sorted_indices && self.indices[idx] > i {
451                // Insert here - this would require restructuring indices and data
452                // Not implemented in this placeholder
453                return Err(SparseError::NotImplemented(
454                    "Inserting new elements in CSC format".to_string(),
455                ));
456            }
457        }
458
459        // Element not found, would need to insert
460        // This would require restructuring indices and data
461        Err(SparseError::NotImplemented(
462            "Inserting new elements in CSC format".to_string(),
463        ))
464    }
465
466    fn eliminate_zeros(&mut self) {
467        // Find all non-zero entries
468        let mut new_data = Vec::new();
469        let mut new_indices = Vec::new();
470        let mut new_indptr = vec![0];
471
472        let (_, cols) = self.shape;
473
474        for col in 0..cols {
475            let start = self.indptr[col];
476            let end = self.indptr[col + 1];
477
478            for idx in start..end {
479                if !SparseElement::is_zero(&self.data[idx]) {
480                    new_data.push(self.data[idx]);
481                    new_indices.push(self.indices[idx]);
482                }
483            }
484            new_indptr.push(new_data.len());
485        }
486
487        // Replace data with filtered data
488        self.data = Array1::from_vec(new_data);
489        self.indices = Array1::from_vec(new_indices);
490        self.indptr = Array1::from_vec(new_indptr);
491    }
492
493    fn sort_indices(&mut self) {
494        if self.has_sorted_indices {
495            return;
496        }
497
498        let (_, cols) = self.shape;
499
500        for col in 0..cols {
501            let start = self.indptr[col];
502            let end = self.indptr[col + 1];
503
504            if start == end {
505                continue;
506            }
507
508            // Extract the non-zero elements for this column
509            let mut col_data = Vec::with_capacity(end - start);
510            for idx in start..end {
511                col_data.push((self.indices[idx], self.data[idx]));
512            }
513
514            // Sort by row index
515            col_data.sort_by_key(|&(row_, _)| row_);
516
517            // Put the sorted data back
518            for (i, (row, val)) in col_data.into_iter().enumerate() {
519                self.indices[start + i] = row;
520                self.data[start + i] = val;
521            }
522        }
523
524        self.has_sorted_indices = true;
525    }
526
527    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
528        if self.has_sorted_indices {
529            return Box::new(self.clone());
530        }
531
532        let mut sorted = self.clone();
533        sorted.sort_indices();
534        Box::new(sorted)
535    }
536
537    fn has_sorted_indices(&self) -> bool {
538        self.has_sorted_indices
539    }
540
541    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
542        match axis {
543            None => {
544                // Sum all elements
545                let mut sum = T::sparse_zero();
546                for &val in self.data.iter() {
547                    sum = sum + val;
548                }
549                Ok(SparseSum::Scalar(sum))
550            }
551            Some(0) => {
552                // Sum along rows (result is a row vector)
553                // For efficiency, convert to CSR and sum
554                let self_csr = self.to_csr()?;
555                self_csr.sum(Some(0))
556            }
557            Some(1) => {
558                // Sum along columns (result is a column vector)
559                let mut result = Vec::with_capacity(self.shape.1);
560
561                for col in 0..self.shape.1 {
562                    let start = self.indptr[col];
563                    let end = self.indptr[col + 1];
564
565                    let mut col_sum = T::sparse_zero();
566                    for idx in start..end {
567                        col_sum = col_sum + self.data[idx];
568                    }
569                    result.push(col_sum);
570                }
571
572                // Convert to COO format for the column vector
573                let mut row_indices = Vec::new();
574                let mut col_indices = Vec::new();
575                let mut values = Vec::new();
576
577                for (col, &val) in result.iter().enumerate() {
578                    if !SparseElement::is_zero(&val) {
579                        row_indices.push(0);
580                        col_indices.push(col);
581                        values.push(val);
582                    }
583                }
584
585                let coo = CooArray::from_triplets(
586                    &row_indices,
587                    &col_indices,
588                    &values,
589                    (1, self.shape.1),
590                    true,
591                )?;
592
593                Ok(SparseSum::SparseArray(Box::new(coo)))
594            }
595            _ => Err(SparseError::InvalidAxis),
596        }
597    }
598
599    fn max(&self) -> T {
600        if self.data.is_empty() {
601            // Empty sparse matrix - all elements are implicitly zero
602            return T::sparse_zero();
603        }
604
605        let mut max_val = self.data[0];
606        for &val in self.data.iter().skip(1) {
607            if val > max_val {
608                max_val = val;
609            }
610        }
611
612        // Check if max_val is less than zero, as zeros aren't explicitly stored
613        let zero = T::sparse_zero();
614        if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
615            max_val = zero;
616        }
617
618        max_val
619    }
620
621    fn min(&self) -> T {
622        if self.data.is_empty() {
623            return T::sparse_zero();
624        }
625
626        let mut min_val = self.data[0];
627        for &val in self.data.iter().skip(1) {
628            if val < min_val {
629                min_val = val;
630            }
631        }
632
633        // Check if min_val is greater than zero, as zeros aren't explicitly stored
634        if min_val > T::sparse_zero() && self.nnz() < self.shape.0 * self.shape.1 {
635            min_val = T::sparse_zero();
636        }
637
638        min_val
639    }
640
641    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
642        let nnz = self.nnz();
643        let mut rows = Vec::with_capacity(nnz);
644        let mut cols = Vec::with_capacity(nnz);
645        let mut values = Vec::with_capacity(nnz);
646
647        for col in 0..self.shape.1 {
648            let start = self.indptr[col];
649            let end = self.indptr[col + 1];
650
651            for idx in start..end {
652                let row = self.indices[idx];
653                rows.push(row);
654                cols.push(col);
655                values.push(self.data[idx]);
656            }
657        }
658
659        (
660            Array1::from_vec(rows),
661            Array1::from_vec(cols),
662            Array1::from_vec(values),
663        )
664    }
665
666    fn slice(
667        &self,
668        row_range: (usize, usize),
669        col_range: (usize, usize),
670    ) -> SparseResult<Box<dyn SparseArray<T>>> {
671        let (start_row, end_row) = row_range;
672        let (start_col, end_col) = col_range;
673
674        if start_row >= self.shape.0
675            || end_row > self.shape.0
676            || start_col >= self.shape.1
677            || end_col > self.shape.1
678        {
679            return Err(SparseError::InvalidSliceRange);
680        }
681
682        if start_row >= end_row || start_col >= end_col {
683            return Err(SparseError::InvalidSliceRange);
684        }
685
686        // CSC format is efficient for column slicing
687        let mut data = Vec::new();
688        let mut indices = Vec::new();
689        let mut indptr = vec![0];
690
691        for col in start_col..end_col {
692            let start = self.indptr[col];
693            let end = self.indptr[col + 1];
694
695            for idx in start..end {
696                let row = self.indices[idx];
697                if row >= start_row && row < end_row {
698                    data.push(self.data[idx]);
699                    indices.push(row - start_row); // Adjust indices
700                }
701            }
702            indptr.push(data.len());
703        }
704
705        CscArray::new(
706            Array1::from_vec(data),
707            Array1::from_vec(indices),
708            Array1::from_vec(indptr),
709            (end_row - start_row, end_col - start_col),
710        )
711        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
712    }
713
714    fn as_any(&self) -> &dyn std::any::Any {
715        self
716    }
717
718    fn get_indptr(&self) -> Option<&Array1<usize>> {
719        Some(&self.indptr)
720    }
721
722    fn indptr(&self) -> Option<&Array1<usize>> {
723        Some(&self.indptr)
724    }
725}
726
727impl<T> fmt::Debug for CscArray<T>
728where
729    T: SparseElement + Div<Output = T> + Float + 'static,
730{
731    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
732        write!(
733            f,
734            "CscArray<{}x{}, nnz={}>",
735            self.shape.0,
736            self.shape.1,
737            self.nnz()
738        )
739    }
740}
741
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Shape of the fixture matrix shared by every test.
    const SHAPE: (usize, usize) = (3, 3);

    /// Triplets `(rows, cols, values)` of the fixture matrix
    ///
    /// ```text
    /// [1 2 0]
    /// [0 0 3]
    /// [4 0 5]
    /// ```
    fn sample_triplets() -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        (
            vec![0, 2, 0, 1, 2],
            vec![0, 0, 1, 2, 2],
            vec![1.0, 4.0, 2.0, 3.0, 5.0],
        )
    }

    /// Fixture matrix built through `from_triplets` with `sorted = false`.
    fn sample_csc() -> CscArray<f64> {
        let (rows, cols, data) = sample_triplets();
        CscArray::from_triplets(&rows, &cols, &data, SHAPE, false).unwrap()
    }

    /// Asserts shape, entry count, the five stored values and one implicit zero.
    fn assert_sample_contents(csc: &CscArray<f64>) {
        assert_eq!(csc.shape(), (3, 3));
        assert_eq!(csc.nnz(), 5);

        let expected: [(usize, usize, f64); 6] = [
            (0, 0, 1.0),
            (2, 0, 4.0),
            (0, 1, 2.0),
            (1, 2, 3.0),
            (2, 2, 5.0),
            (1, 0, 0.0),
        ];
        for &(i, j, value) in expected.iter() {
            assert_eq!(csc.get(i, j), value);
        }
    }

    #[test]
    fn test_csc_array_construction() {
        let csc = CscArray::new(
            Array1::from_vec(vec![1.0, 4.0, 2.0, 3.0, 5.0]),
            Array1::from_vec(vec![0, 2, 0, 1, 2]),
            Array1::from_vec(vec![0, 2, 3, 5]),
            (3, 3),
        )
        .unwrap();

        assert_sample_contents(&csc);
    }

    #[test]
    fn test_csc_from_triplets() {
        assert_sample_contents(&sample_csc());
    }

    #[test]
    fn test_csc_array_to_array() {
        let dense = sample_csc().to_array();

        assert_eq!(dense.shape(), &[3, 3]);

        let expected = [[1.0, 2.0, 0.0], [0.0, 0.0, 3.0], [4.0, 0.0, 5.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(dense[[i, j]], value);
            }
        }
    }

    #[test]
    fn test_csc_to_csr_conversion() {
        let csc = sample_csc();
        let csr = csc.to_csr().unwrap();

        // The dense views of both formats must agree cell by cell.
        let csc_array = csc.to_array();
        let csr_array = csr.to_array();

        for i in 0..SHAPE.0 {
            for j in 0..SHAPE.1 {
                assert_relative_eq!(csc_array[[i, j]], csr_array[[i, j]]);
            }
        }
    }

    #[test]
    fn test_csc_dot_vector() {
        let csc = sample_csc();
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csc.dot_vector(&vec.view()).unwrap();

        // Row by row:
        // 1*1 + 2*2 + 0*3 = 5
        // 0*1 + 0*2 + 3*3 = 9
        // 4*1 + 0*2 + 5*3 = 19
        assert_eq!(result.len(), 3);
        assert_relative_eq!(result[0], 5.0);
        assert_relative_eq!(result[1], 9.0);
        assert_relative_eq!(result[2], 19.0);
    }

    #[test]
    fn test_csc_transpose() {
        let transposed = sample_csc().transpose().unwrap();

        // Square fixture: swapped dimensions are still 3x3.
        assert_eq!(transposed.shape(), (3, 3));

        // Every stored (row, col) of the fixture shows up at (col, row).
        let dense = transposed.to_array();
        let expected: [(usize, usize, f64); 5] = [
            (0, 0, 1.0),
            (0, 2, 4.0),
            (1, 0, 2.0),
            (2, 1, 3.0),
            (2, 2, 5.0),
        ];
        for &(i, j, value) in expected.iter() {
            assert_eq!(dense[[i, j]], value);
        }
    }

    #[test]
    fn test_csc_find() {
        let (rows, cols, data) = sample_triplets();
        let csc = CscArray::from_triplets(&rows, &cols, &data, SHAPE, false).unwrap();
        let (result_rows, result_cols, result_data) = csc.find();

        // All five stored entries must come back.
        assert_eq!(result_rows.len(), 5);
        assert_eq!(result_cols.len(), 5);
        assert_eq!(result_data.len(), 5);

        let mut original: Vec<(usize, usize, f64)> = (0..rows.len())
            .map(|k| (rows[k], cols[k], data[k]))
            .collect();

        let mut result: Vec<(usize, usize, f64)> = (0..result_rows.len())
            .map(|k| (result_rows[k], result_cols[k], result_data[k]))
            .collect();

        // Storage order may differ from input order, so compare after ordering both
        // by (row, col).
        original.sort_by_key(|entry| (entry.0, entry.1));
        result.sort_by_key(|entry| (entry.0, entry.1));

        assert_eq!(original, result);
    }
}