scirs2_sparse/
csc_array.rs

1// CSC Array implementation
2//
3// This module provides the CSC (Compressed Sparse Column) array format,
4// which is efficient for column-wise operations.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csr_array::CsrArray;
13use crate::error::{SparseError, SparseResult};
14use crate::sparray::{SparseArray, SparseSum};
15
16/// CSC Array format
17///
18/// The CSC (Compressed Sparse Column) format stores a sparse array in three arrays:
19/// - data: array of non-zero values
20/// - indices: row indices of the non-zero values
21/// - indptr: index pointers; for each column, points to the first non-zero element
22///
23/// # Notes
24///
25/// - Efficient for column-oriented operations
26/// - Fast matrix-vector multiplications for A^T x
27/// - Fast column slicing
28/// - Slow row slicing
29/// - Slow constructing by setting individual elements
30///
31#[derive(Clone)]
32pub struct CscArray<T>
33where
34    T: Float
35        + Add<Output = T>
36        + Sub<Output = T>
37        + Mul<Output = T>
38        + Div<Output = T>
39        + Debug
40        + Copy
41        + 'static,
42{
43    /// Non-zero values
44    data: Array1<T>,
45    /// Row indices of non-zero values
46    indices: Array1<usize>,
47    /// Column pointers (indices into data/indices for the start of each column)
48    indptr: Array1<usize>,
49    /// Shape of the sparse array
50    shape: (usize, usize),
51    /// Whether indices are sorted for each column
52    has_sorted_indices: bool,
53}
54
55impl<T> CscArray<T>
56where
57    T: Float
58        + Add<Output = T>
59        + Sub<Output = T>
60        + Mul<Output = T>
61        + Div<Output = T>
62        + Debug
63        + Copy
64        + 'static,
65{
66    /// Creates a new CSC array from raw components
67    ///
68    /// # Arguments
69    /// * `data` - Array of non-zero values
70    /// * `indices` - Row indices of non-zero values
71    /// * `indptr` - Index pointers for the start of each column
72    /// * `shape` - Shape of the sparse array
73    ///
74    /// # Returns
75    /// A new `CscArray`
76    ///
77    /// # Errors
78    /// Returns an error if the data is not consistent
79    pub fn new(
80        data: Array1<T>,
81        indices: Array1<usize>,
82        indptr: Array1<usize>,
83        shape: (usize, usize),
84    ) -> SparseResult<Self> {
85        // Validation
86        if data.len() != indices.len() {
87            return Err(SparseError::InconsistentData {
88                reason: "data and indices must have the same length".to_string(),
89            });
90        }
91
92        if indptr.len() != shape.1 + 1 {
93            return Err(SparseError::InconsistentData {
94                reason: format!(
95                    "indptr length ({}) must be one more than the number of columns ({})",
96                    indptr.len(),
97                    shape.1
98                ),
99            });
100        }
101
102        if let Some(&max_idx) = indices.iter().max() {
103            if max_idx >= shape.0 {
104                return Err(SparseError::IndexOutOfBounds {
105                    index: (max_idx, 0),
106                    shape,
107                });
108            }
109        }
110
111        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
112            if first != 0 {
113                return Err(SparseError::InconsistentData {
114                    reason: "first element of indptr must be 0".to_string(),
115                });
116            }
117
118            if last != data.len() {
119                return Err(SparseError::InconsistentData {
120                    reason: format!(
121                        "last element of indptr ({}) must equal data length ({})",
122                        last,
123                        data.len()
124                    ),
125                });
126            }
127        }
128
129        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
130
131        Ok(Self {
132            data,
133            indices,
134            indptr,
135            shape,
136            has_sorted_indices,
137        })
138    }
139
140    /// Create a CSC array from triplet format (COO-like)
141    ///
142    /// # Arguments
143    /// * `rows` - Row indices
144    /// * `cols` - Column indices
145    /// * `data` - Values
146    /// * `shape` - Shape of the sparse array
147    /// * `sorted` - Whether the triplets are sorted by column
148    ///
149    /// # Returns
150    /// A new `CscArray`
151    ///
152    /// # Errors
153    /// Returns an error if the data is not consistent
154    pub fn from_triplets(
155        rows: &[usize],
156        cols: &[usize],
157        data: &[T],
158        shape: (usize, usize),
159        sorted: bool,
160    ) -> SparseResult<Self> {
161        if rows.len() != cols.len() || rows.len() != data.len() {
162            return Err(SparseError::InconsistentData {
163                reason: "rows, cols, and data must have the same length".to_string(),
164            });
165        }
166
167        if rows.is_empty() {
168            // Empty matrix
169            let indptr = Array1::zeros(shape.1 + 1);
170            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
171        }
172
173        let nnz = rows.len();
174        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
175
176        for i in 0..nnz {
177            if rows[i] >= shape.0 || cols[i] >= shape.1 {
178                return Err(SparseError::IndexOutOfBounds {
179                    index: (rows[i], cols[i]),
180                    shape,
181                });
182            }
183            all_data.push((rows[i], cols[i], data[i]));
184        }
185
186        if !sorted {
187            all_data.sort_by_key(|&(_, col_, _)| col_);
188        }
189
190        // Count elements per column
191        let mut col_counts = vec![0; shape.1];
192        for &(_, col_, _) in &all_data {
193            col_counts[col_] += 1;
194        }
195
196        // Create indptr
197        let mut indptr = Vec::with_capacity(shape.1 + 1);
198        indptr.push(0);
199        let mut cumsum = 0;
200        for &count in &col_counts {
201            cumsum += count;
202            indptr.push(cumsum);
203        }
204
205        // Create indices and data arrays
206        let mut indices = Vec::with_capacity(nnz);
207        let mut values = Vec::with_capacity(nnz);
208
209        for (row_, _, val) in all_data {
210            indices.push(row_);
211            values.push(val);
212        }
213
214        Self::new(
215            Array1::from_vec(values),
216            Array1::from_vec(indices),
217            Array1::from_vec(indptr),
218            shape,
219        )
220    }
221
222    /// Checks if row indices are sorted for each column
223    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
224        for col in 0..indptr.len() - 1 {
225            let start = indptr[col];
226            let end = indptr[col + 1];
227
228            for i in start..end.saturating_sub(1) {
229                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
230                    return false;
231                }
232            }
233        }
234        true
235    }
236
237    /// Get the raw data array
238    pub fn get_data(&self) -> &Array1<T> {
239        &self.data
240    }
241
242    /// Get the raw indices array
243    pub fn get_indices(&self) -> &Array1<usize> {
244        &self.indices
245    }
246
247    /// Get the raw indptr array
248    pub fn get_indptr(&self) -> &Array1<usize> {
249        &self.indptr
250    }
251}
252
253impl<T> SparseArray<T> for CscArray<T>
254where
255    T: Float
256        + Add<Output = T>
257        + Sub<Output = T>
258        + Mul<Output = T>
259        + Div<Output = T>
260        + Debug
261        + Copy
262        + 'static,
263{
264    fn shape(&self) -> (usize, usize) {
265        self.shape
266    }
267
268    fn nnz(&self) -> usize {
269        self.data.len()
270    }
271
272    fn dtype(&self) -> &str {
273        "float" // Placeholder, ideally we would return the actual type
274    }
275
276    fn to_array(&self) -> Array2<T> {
277        let (rows, cols) = self.shape;
278        let mut result = Array2::zeros((rows, cols));
279
280        for col in 0..cols {
281            let start = self.indptr[col];
282            let end = self.indptr[col + 1];
283
284            for i in start..end {
285                let row = self.indices[i];
286                result[[row, col]] = self.data[i];
287            }
288        }
289
290        result
291    }
292
293    fn toarray(&self) -> Array2<T> {
294        self.to_array()
295    }
296
297    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
298        // Convert to COO format
299        let nnz = self.nnz();
300        let mut row_indices = Vec::with_capacity(nnz);
301        let mut col_indices = Vec::with_capacity(nnz);
302        let mut values = Vec::with_capacity(nnz);
303
304        for col in 0..self.shape.1 {
305            let start = self.indptr[col];
306            let end = self.indptr[col + 1];
307
308            for idx in start..end {
309                row_indices.push(self.indices[idx]);
310                col_indices.push(col);
311                values.push(self.data[idx]);
312            }
313        }
314
315        CooArray::from_triplets(&row_indices, &col_indices, &values, self.shape, false)
316            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
317    }
318
319    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
320        // For efficiency, we'll go via COO format
321        self.to_coo()?.to_csr()
322    }
323
324    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
325        Ok(Box::new(self.clone()))
326    }
327
328    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
329        // This would convert to DOK format
330        // For now, we'll go via COO format
331        self.to_coo()?.to_dok()
332    }
333
334    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
335        // This would convert to LIL format
336        // For now, we'll go via COO format
337        self.to_coo()?.to_lil()
338    }
339
340    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
341        // This would convert to DIA format
342        // For now, we'll go via COO format
343        self.to_coo()?.to_dia()
344    }
345
346    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
347        // This would convert to BSR format
348        // For now, we'll go via COO format
349        self.to_coo()?.to_bsr()
350    }
351
352    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
353        // For efficiency, convert to CSR format and then add
354        let self_csr = self.to_csr()?;
355        self_csr.add(other)
356    }
357
358    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
359        // For efficiency, convert to CSR format and then subtract
360        let self_csr = self.to_csr()?;
361        self_csr.sub(other)
362    }
363
364    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
365        // Element-wise multiplication (Hadamard product)
366        // Convert to CSR for efficiency
367        let self_csr = self.to_csr()?;
368        self_csr.mul(other)
369    }
370
371    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
372        // Element-wise division
373        // Convert to CSR for efficiency
374        let self_csr = self.to_csr()?;
375        self_csr.div(other)
376    }
377
378    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
379        // Matrix multiplication
380        // Convert to CSR for efficiency
381        let self_csr = self.to_csr()?;
382        self_csr.dot(other)
383    }
384
385    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
386        let (m, n) = self.shape();
387        if n != other.len() {
388            return Err(SparseError::DimensionMismatch {
389                expected: n,
390                found: other.len(),
391            });
392        }
393
394        let mut result = Array1::zeros(m);
395
396        for col in 0..n {
397            let start = self.indptr[col];
398            let end = self.indptr[col + 1];
399
400            let val = other[col];
401            if !val.is_zero() {
402                for idx in start..end {
403                    let row = self.indices[idx];
404                    result[row] = result[row] + self.data[idx] * val;
405                }
406            }
407        }
408
409        Ok(result)
410    }
411
412    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
413        // CSC transposed is effectively CSR (swap rows/cols)
414        CsrArray::new(
415            self.data.clone(),
416            self.indices.clone(),
417            self.indptr.clone(),
418            (self.shape.1, self.shape.0), // Swap dimensions
419        )
420        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
421    }
422
423    fn copy(&self) -> Box<dyn SparseArray<T>> {
424        Box::new(self.clone())
425    }
426
427    fn get(&self, i: usize, j: usize) -> T {
428        if i >= self.shape.0 || j >= self.shape.1 {
429            return T::zero();
430        }
431
432        let start = self.indptr[j];
433        let end = self.indptr[j + 1];
434
435        for idx in start..end {
436            if self.indices[idx] == i {
437                return self.data[idx];
438            }
439
440            // If indices are sorted, we can break early
441            if self.has_sorted_indices && self.indices[idx] > i {
442                break;
443            }
444        }
445
446        T::zero()
447    }
448
449    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
450        // Setting elements in CSC format is non-trivial
451        // This is a placeholder implementation that doesn't actually modify the structure
452        if i >= self.shape.0 || j >= self.shape.1 {
453            return Err(SparseError::IndexOutOfBounds {
454                index: (i, j),
455                shape: self.shape,
456            });
457        }
458
459        let start = self.indptr[j];
460        let end = self.indptr[j + 1];
461
462        // Try to find existing element
463        for idx in start..end {
464            if self.indices[idx] == i {
465                // Update value
466                self.data[idx] = value;
467                return Ok(());
468            }
469
470            // If indices are sorted, we can insert at the right position
471            if self.has_sorted_indices && self.indices[idx] > i {
472                // Insert here - this would require restructuring indices and data
473                // Not implemented in this placeholder
474                return Err(SparseError::NotImplemented(
475                    "Inserting new elements in CSC format".to_string(),
476                ));
477            }
478        }
479
480        // Element not found, would need to insert
481        // This would require restructuring indices and data
482        Err(SparseError::NotImplemented(
483            "Inserting new elements in CSC format".to_string(),
484        ))
485    }
486
487    fn eliminate_zeros(&mut self) {
488        // Find all non-zero entries
489        let mut new_data = Vec::new();
490        let mut new_indices = Vec::new();
491        let mut new_indptr = vec![0];
492
493        let (_, cols) = self.shape;
494
495        for col in 0..cols {
496            let start = self.indptr[col];
497            let end = self.indptr[col + 1];
498
499            for idx in start..end {
500                if !self.data[idx].is_zero() {
501                    new_data.push(self.data[idx]);
502                    new_indices.push(self.indices[idx]);
503                }
504            }
505            new_indptr.push(new_data.len());
506        }
507
508        // Replace data with filtered data
509        self.data = Array1::from_vec(new_data);
510        self.indices = Array1::from_vec(new_indices);
511        self.indptr = Array1::from_vec(new_indptr);
512    }
513
514    fn sort_indices(&mut self) {
515        if self.has_sorted_indices {
516            return;
517        }
518
519        let (_, cols) = self.shape;
520
521        for col in 0..cols {
522            let start = self.indptr[col];
523            let end = self.indptr[col + 1];
524
525            if start == end {
526                continue;
527            }
528
529            // Extract the non-zero elements for this column
530            let mut col_data = Vec::with_capacity(end - start);
531            for idx in start..end {
532                col_data.push((self.indices[idx], self.data[idx]));
533            }
534
535            // Sort by row index
536            col_data.sort_by_key(|&(row_, _)| row_);
537
538            // Put the sorted data back
539            for (i, (row, val)) in col_data.into_iter().enumerate() {
540                self.indices[start + i] = row;
541                self.data[start + i] = val;
542            }
543        }
544
545        self.has_sorted_indices = true;
546    }
547
548    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
549        if self.has_sorted_indices {
550            return Box::new(self.clone());
551        }
552
553        let mut sorted = self.clone();
554        sorted.sort_indices();
555        Box::new(sorted)
556    }
557
558    fn has_sorted_indices(&self) -> bool {
559        self.has_sorted_indices
560    }
561
562    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
563        match axis {
564            None => {
565                // Sum all elements
566                let mut sum = T::zero();
567                for &val in self.data.iter() {
568                    sum = sum + val;
569                }
570                Ok(SparseSum::Scalar(sum))
571            }
572            Some(0) => {
573                // Sum along rows (result is a row vector)
574                // For efficiency, convert to CSR and sum
575                let self_csr = self.to_csr()?;
576                self_csr.sum(Some(0))
577            }
578            Some(1) => {
579                // Sum along columns (result is a column vector)
580                let mut result = Vec::with_capacity(self.shape.1);
581
582                for col in 0..self.shape.1 {
583                    let start = self.indptr[col];
584                    let end = self.indptr[col + 1];
585
586                    let mut col_sum = T::zero();
587                    for idx in start..end {
588                        col_sum = col_sum + self.data[idx];
589                    }
590                    result.push(col_sum);
591                }
592
593                // Convert to COO format for the column vector
594                let mut row_indices = Vec::new();
595                let mut col_indices = Vec::new();
596                let mut values = Vec::new();
597
598                for (col, &val) in result.iter().enumerate() {
599                    if !val.is_zero() {
600                        row_indices.push(0);
601                        col_indices.push(col);
602                        values.push(val);
603                    }
604                }
605
606                let coo = CooArray::from_triplets(
607                    &row_indices,
608                    &col_indices,
609                    &values,
610                    (1, self.shape.1),
611                    true,
612                )?;
613
614                Ok(SparseSum::SparseArray(Box::new(coo)))
615            }
616            _ => Err(SparseError::InvalidAxis),
617        }
618    }
619
620    fn max(&self) -> T {
621        if self.data.is_empty() {
622            return T::neg_infinity();
623        }
624
625        let mut max_val = self.data[0];
626        for &val in self.data.iter().skip(1) {
627            if val > max_val {
628                max_val = val;
629            }
630        }
631
632        // Check if max_val is less than zero, as zeros aren't explicitly stored
633        if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
634            max_val = T::zero();
635        }
636
637        max_val
638    }
639
640    fn min(&self) -> T {
641        if self.data.is_empty() {
642            return T::infinity();
643        }
644
645        let mut min_val = self.data[0];
646        for &val in self.data.iter().skip(1) {
647            if val < min_val {
648                min_val = val;
649            }
650        }
651
652        // Check if min_val is greater than zero, as zeros aren't explicitly stored
653        if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
654            min_val = T::zero();
655        }
656
657        min_val
658    }
659
660    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
661        let nnz = self.nnz();
662        let mut rows = Vec::with_capacity(nnz);
663        let mut cols = Vec::with_capacity(nnz);
664        let mut values = Vec::with_capacity(nnz);
665
666        for col in 0..self.shape.1 {
667            let start = self.indptr[col];
668            let end = self.indptr[col + 1];
669
670            for idx in start..end {
671                let row = self.indices[idx];
672                rows.push(row);
673                cols.push(col);
674                values.push(self.data[idx]);
675            }
676        }
677
678        (
679            Array1::from_vec(rows),
680            Array1::from_vec(cols),
681            Array1::from_vec(values),
682        )
683    }
684
685    fn slice(
686        &self,
687        row_range: (usize, usize),
688        col_range: (usize, usize),
689    ) -> SparseResult<Box<dyn SparseArray<T>>> {
690        let (start_row, end_row) = row_range;
691        let (start_col, end_col) = col_range;
692
693        if start_row >= self.shape.0
694            || end_row > self.shape.0
695            || start_col >= self.shape.1
696            || end_col > self.shape.1
697        {
698            return Err(SparseError::InvalidSliceRange);
699        }
700
701        if start_row >= end_row || start_col >= end_col {
702            return Err(SparseError::InvalidSliceRange);
703        }
704
705        // CSC format is efficient for column slicing
706        let mut data = Vec::new();
707        let mut indices = Vec::new();
708        let mut indptr = vec![0];
709
710        for col in start_col..end_col {
711            let start = self.indptr[col];
712            let end = self.indptr[col + 1];
713
714            for idx in start..end {
715                let row = self.indices[idx];
716                if row >= start_row && row < end_row {
717                    data.push(self.data[idx]);
718                    indices.push(row - start_row); // Adjust indices
719                }
720            }
721            indptr.push(data.len());
722        }
723
724        CscArray::new(
725            Array1::from_vec(data),
726            Array1::from_vec(indices),
727            Array1::from_vec(indptr),
728            (end_row - start_row, end_col - start_col),
729        )
730        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
731    }
732
733    fn as_any(&self) -> &dyn std::any::Any {
734        self
735    }
736
737    fn get_indptr(&self) -> Option<&Array1<usize>> {
738        Some(&self.indptr)
739    }
740
741    fn indptr(&self) -> Option<&Array1<usize>> {
742        Some(&self.indptr)
743    }
744}
745
746impl<T> fmt::Debug for CscArray<T>
747where
748    T: Float
749        + Add<Output = T>
750        + Sub<Output = T>
751        + Mul<Output = T>
752        + Div<Output = T>
753        + Debug
754        + Copy
755        + 'static,
756{
757    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
758        write!(
759            f,
760            "CscArray<{}x{}, nnz={}>",
761            self.shape.0,
762            self.shape.1,
763            self.nnz()
764        )
765    }
766}
767
768#[cfg(test)]
769mod tests {
770    use super::*;
771    use approx::assert_relative_eq;
772
773    #[test]
774    fn test_csc_array_construction() {
775        let data = Array1::from_vec(vec![1.0, 4.0, 2.0, 3.0, 5.0]);
776        let indices = Array1::from_vec(vec![0, 2, 0, 1, 2]);
777        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
778        let shape = (3, 3);
779
780        let csc = CscArray::new(data, indices, indptr, shape).unwrap();
781
782        assert_eq!(csc.shape(), (3, 3));
783        assert_eq!(csc.nnz(), 5);
784        assert_eq!(csc.get(0, 0), 1.0);
785        assert_eq!(csc.get(2, 0), 4.0);
786        assert_eq!(csc.get(0, 1), 2.0);
787        assert_eq!(csc.get(1, 2), 3.0);
788        assert_eq!(csc.get(2, 2), 5.0);
789        assert_eq!(csc.get(1, 0), 0.0);
790    }
791
792    #[test]
793    fn test_csc_from_triplets() {
794        let rows = vec![0, 2, 0, 1, 2];
795        let cols = vec![0, 0, 1, 2, 2];
796        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
797        let shape = (3, 3);
798
799        let csc = CscArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
800
801        assert_eq!(csc.shape(), (3, 3));
802        assert_eq!(csc.nnz(), 5);
803        assert_eq!(csc.get(0, 0), 1.0);
804        assert_eq!(csc.get(2, 0), 4.0);
805        assert_eq!(csc.get(0, 1), 2.0);
806        assert_eq!(csc.get(1, 2), 3.0);
807        assert_eq!(csc.get(2, 2), 5.0);
808        assert_eq!(csc.get(1, 0), 0.0);
809    }
810
811    #[test]
812    fn test_csc_array_to_array() {
813        let rows = vec![0, 2, 0, 1, 2];
814        let cols = vec![0, 0, 1, 2, 2];
815        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
816        let shape = (3, 3);
817
818        let csc = CscArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
819        let dense = csc.to_array();
820
821        assert_eq!(dense.shape(), &[3, 3]);
822        assert_eq!(dense[[0, 0]], 1.0);
823        assert_eq!(dense[[1, 0]], 0.0);
824        assert_eq!(dense[[2, 0]], 4.0);
825        assert_eq!(dense[[0, 1]], 2.0);
826        assert_eq!(dense[[1, 1]], 0.0);
827        assert_eq!(dense[[2, 1]], 0.0);
828        assert_eq!(dense[[0, 2]], 0.0);
829        assert_eq!(dense[[1, 2]], 3.0);
830        assert_eq!(dense[[2, 2]], 5.0);
831    }
832
833    #[test]
834    fn test_csc_to_csr_conversion() {
835        let rows = vec![0, 2, 0, 1, 2];
836        let cols = vec![0, 0, 1, 2, 2];
837        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
838        let shape = (3, 3);
839
840        let csc = CscArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
841        let csr = csc.to_csr().unwrap();
842
843        // Check that the conversion preserved values
844        let csc_array = csc.to_array();
845        let csr_array = csr.to_array();
846
847        for i in 0..shape.0 {
848            for j in 0..shape.1 {
849                assert_relative_eq!(csc_array[[i, j]], csr_array[[i, j]]);
850            }
851        }
852    }
853
854    #[test]
855    fn test_csc_dot_vector() {
856        let rows = vec![0, 2, 0, 1, 2];
857        let cols = vec![0, 0, 1, 2, 2];
858        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
859        let shape = (3, 3);
860
861        let csc = CscArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
862        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);
863
864        let result = csc.dot_vector(&vec.view()).unwrap();
865
866        // Expected:
867        // [0,0]*1 + [0,1]*2 + [0,2]*3 = 1*1 + 2*2 + 0*3 = 5
868        // [1,0]*1 + [1,1]*2 + [1,2]*3 = 0*1 + 0*2 + 3*3 = 9
869        // [2,0]*1 + [2,1]*2 + [2,2]*3 = 4*1 + 0*2 + 5*3 = 19
870        assert_eq!(result.len(), 3);
871        assert_relative_eq!(result[0], 5.0);
872        assert_relative_eq!(result[1], 9.0);
873        assert_relative_eq!(result[2], 19.0);
874    }
875
876    #[test]
877    fn test_csc_transpose() {
878        let rows = vec![0, 2, 0, 1, 2];
879        let cols = vec![0, 0, 1, 2, 2];
880        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
881        let shape = (3, 3);
882
883        let csc = CscArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
884        let transposed = csc.transpose().unwrap();
885
886        // Check dimensions are swapped
887        assert_eq!(transposed.shape(), (3, 3));
888
889        // Check values are correctly transposed
890        let dense = transposed.to_array();
891        assert_eq!(dense[[0, 0]], 1.0);
892        assert_eq!(dense[[0, 2]], 4.0);
893        assert_eq!(dense[[1, 0]], 2.0);
894        assert_eq!(dense[[2, 1]], 3.0);
895        assert_eq!(dense[[2, 2]], 5.0);
896    }
897
898    #[test]
899    fn test_csc_find() {
900        let rows = vec![0, 2, 0, 1, 2];
901        let cols = vec![0, 0, 1, 2, 2];
902        let data = vec![1.0, 4.0, 2.0, 3.0, 5.0];
903        let shape = (3, 3);
904
905        let csc = CscArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
906        let (result_rows, result_cols, result_data) = csc.find();
907
908        // Check that the find operation returned all non-zero elements
909        assert_eq!(result_rows.len(), 5);
910        assert_eq!(result_cols.len(), 5);
911        assert_eq!(result_data.len(), 5);
912
913        // Create vectors of tuples to compare
914        let mut original: Vec<_> = rows
915            .iter()
916            .zip(cols.iter())
917            .zip(data.iter())
918            .map(|((r, c), d)| (*r, *c, *d))
919            .collect();
920
921        let mut result: Vec<_> = result_rows
922            .iter()
923            .zip(result_cols.iter())
924            .zip(result_data.iter())
925            .map(|((r, c), d)| (*r, *c, *d))
926            .collect();
927
928        // Sort the vectors before comparing
929        original.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
930        result.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
931
932        assert_eq!(original, result);
933    }
934}