scirs2_sparse/
csr_array.rs

1// CSR Array implementation
2//
3// This module provides the CSR (Compressed Sparse Row) array format,
4// which is efficient for row-wise operations and is one of the most common formats.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14/// CSR Array format - Compressed Sparse Row matrix representation
15///
16/// The CSR (Compressed Sparse Row) format is one of the most popular sparse matrix formats,
17/// storing a sparse array using three arrays:
18/// - `data`: array of non-zero values in row-major order
19/// - `indices`: column indices of the non-zero values
20/// - `indptr`: row pointers; `indptr[i]` is the index into `data`/`indices` where row `i` starts
21///
22/// The CSR format is particularly efficient for:
23/// - ✅ Matrix-vector multiplications (`A * x`)
24/// - ✅ Matrix-matrix multiplications with other sparse matrices
25/// - ✅ Row-wise operations and row slicing
26/// - ✅ Iterating over non-zero elements row by row
27/// - ✅ Adding and subtracting sparse matrices
28///
29/// But less efficient for:
30/// - ❌ Column-wise operations and column slicing
31/// - ❌ Inserting or modifying individual elements after construction
32/// - ❌ Operations that require column access patterns
33///
34/// # Memory Layout
35///
36/// For a matrix with `m` rows, `n` columns, and `nnz` non-zero elements:
37/// - `data`: length `nnz` - stores the actual non-zero values
38/// - `indices`: length `nnz` - stores column indices for each non-zero value
39/// - `indptr`: length `m+1` - stores cumulative count of non-zeros per row
40///
41/// # Examples
42///
43/// ## Basic Construction and Access
44/// ```
45/// use scirs2_sparse::csr_array::CsrArray;
46/// use scirs2_sparse::SparseArray;
47///
48/// // Create a 3x3 matrix:
49/// // [1.0, 0.0, 2.0]
50/// // [0.0, 3.0, 0.0]
51/// // [4.0, 0.0, 5.0]
52/// let rows = vec![0, 0, 1, 2, 2];
53/// let cols = vec![0, 2, 1, 0, 2];
54/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
55/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
56///
57/// // Access elements
58/// assert_eq!(matrix.get(0, 0), 1.0);
59/// assert_eq!(matrix.get(0, 1), 0.0);  // Zero element
60/// assert_eq!(matrix.get(1, 1), 3.0);
61///
62/// // Get matrix properties
63/// assert_eq!(matrix.shape(), (3, 3));
64/// assert_eq!(matrix.nnz(), 5);
65/// ```
66///
67/// ## Matrix Operations
68/// ```
69/// use scirs2_sparse::csr_array::CsrArray;
70/// use scirs2_sparse::SparseArray;
71/// use ndarray::Array1;
72///
73/// let rows = vec![0, 1, 2];
74/// let cols = vec![0, 1, 2];
75/// let data = vec![2.0, 3.0, 4.0];
76/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
77///
78/// // Matrix-vector multiplication
79/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
80/// let y = matrix.dot_vector(&x.view()).unwrap();
81/// assert_eq!(y[0], 2.0);  // 2.0 * 1.0
82/// assert_eq!(y[1], 6.0);  // 3.0 * 2.0
83/// assert_eq!(y[2], 12.0); // 4.0 * 3.0
84/// ```
85///
86/// ## Format Conversion
87/// ```
88/// use scirs2_sparse::csr_array::CsrArray;
89/// use scirs2_sparse::SparseArray;
90///
91/// let rows = vec![0, 1];
92/// let cols = vec![0, 1];
93/// let data = vec![1.0, 2.0];
94/// let csr = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
95///
96/// // Convert to dense array
97/// let dense = csr.to_array();
98/// assert_eq!(dense[[0, 0]], 1.0);
99/// assert_eq!(dense[[1, 1]], 2.0);
100///
101/// // Convert to other sparse formats
102/// let coo = csr.to_coo();
103/// let csc = csr.to_csc();
104/// ```
105#[derive(Clone)]
106pub struct CsrArray<T>
107where
108    T: Float
109        + Add<Output = T>
110        + Sub<Output = T>
111        + Mul<Output = T>
112        + Div<Output = T>
113        + Debug
114        + Copy
115        + 'static,
116{
117    /// Non-zero values
118    data: Array1<T>,
119    /// Column indices of non-zero values
120    indices: Array1<usize>,
121    /// Row pointers (indices into data/indices for the start of each row)
122    indptr: Array1<usize>,
123    /// Shape of the sparse array
124    shape: (usize, usize),
125    /// Whether indices are sorted for each row
126    has_sorted_indices: bool,
127}
128
129impl<T> CsrArray<T>
130where
131    T: Float
132        + Add<Output = T>
133        + Sub<Output = T>
134        + Mul<Output = T>
135        + Div<Output = T>
136        + Debug
137        + Copy
138        + 'static,
139{
140    /// Creates a new CSR array from raw components
141    ///
142    /// # Arguments
143    /// * `data` - Array of non-zero values
144    /// * `indices` - Column indices of non-zero values
145    /// * `indptr` - Index pointers for the start of each row
146    /// * `shape` - Shape of the sparse array
147    ///
148    /// # Returns
149    /// A new `CsrArray`
150    ///
151    /// # Errors
152    /// Returns an error if the data is not consistent
153    pub fn new(
154        data: Array1<T>,
155        indices: Array1<usize>,
156        indptr: Array1<usize>,
157        shape: (usize, usize),
158    ) -> SparseResult<Self> {
159        // Validation
160        if data.len() != indices.len() {
161            return Err(SparseError::InconsistentData {
162                reason: "data and indices must have the same length".to_string(),
163            });
164        }
165
166        if indptr.len() != shape.0 + 1 {
167            return Err(SparseError::InconsistentData {
168                reason: format!(
169                    "indptr length ({}) must be one more than the number of rows ({})",
170                    indptr.len(),
171                    shape.0
172                ),
173            });
174        }
175
176        if let Some(&max_idx) = indices.iter().max() {
177            if max_idx >= shape.1 {
178                return Err(SparseError::IndexOutOfBounds {
179                    index: (0, max_idx),
180                    shape,
181                });
182            }
183        }
184
185        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
186            if first != 0 {
187                return Err(SparseError::InconsistentData {
188                    reason: "first element of indptr must be 0".to_string(),
189                });
190            }
191
192            if last != data.len() {
193                return Err(SparseError::InconsistentData {
194                    reason: format!(
195                        "last element of indptr ({}) must equal data length ({})",
196                        last,
197                        data.len()
198                    ),
199                });
200            }
201        }
202
203        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
204
205        Ok(Self {
206            data,
207            indices,
208            indptr,
209            shape,
210            has_sorted_indices,
211        })
212    }
213
214    /// Create a CSR array from triplet format (COO-like)
215    ///
216    /// This function creates a CSR (Compressed Sparse Row) array from coordinate triplets.
217    /// The triplets represent non-zero elements as (row, column, value) tuples.
218    ///
219    /// # Arguments
220    /// * `rows` - Row indices of non-zero elements
221    /// * `cols` - Column indices of non-zero elements  
222    /// * `data` - Values of non-zero elements
223    /// * `shape` - Shape of the sparse array (nrows, ncols)
224    /// * `sorted` - Whether the triplets are already sorted by (row, col). If false, sorting will be performed.
225    ///
226    /// # Returns
227    /// A new `CsrArray` containing the sparse matrix
228    ///
229    /// # Errors
230    /// Returns an error if:
231    /// - `rows`, `cols`, and `data` have different lengths
232    /// - Any index is out of bounds for the given shape
233    /// - The resulting data structure is inconsistent
234    ///
235    /// # Examples
236    ///
237    /// Create a simple 3x3 sparse matrix:
238    /// ```
239    /// use scirs2_sparse::csr_array::CsrArray;
240    /// use scirs2_sparse::SparseArray;
241    ///
242    /// // Create a 3x3 matrix with the following structure:
243    /// // [1.0, 0.0, 2.0]
244    /// // [0.0, 3.0, 0.0]
245    /// // [4.0, 0.0, 5.0]
246    /// let rows = vec![0, 0, 1, 2, 2];
247    /// let cols = vec![0, 2, 1, 0, 2];
248    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
249    /// let shape = (3, 3);
250    ///
251    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
252    /// assert_eq!(matrix.get(0, 0), 1.0);
253    /// assert_eq!(matrix.get(0, 1), 0.0);
254    /// assert_eq!(matrix.get(1, 1), 3.0);
255    /// ```
256    ///
257    /// Create an empty sparse matrix:
258    /// ```
259    /// use scirs2_sparse::csr_array::CsrArray;
260    /// use scirs2_sparse::SparseArray;
261    ///
262    /// let rows: Vec<usize> = vec![];
263    /// let cols: Vec<usize> = vec![];
264    /// let data: Vec<f64> = vec![];
265    /// let shape = (5, 5);
266    ///
267    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
268    /// assert_eq!(matrix.nnz(), 0);
269    /// assert_eq!(matrix.shape(), (5, 5));
270    /// ```
271    ///
272    /// Handle duplicate entries (they will be preserved):
273    /// ```
274    /// use scirs2_sparse::csr_array::CsrArray;
275    /// use scirs2_sparse::SparseArray;
276    ///
277    /// // Multiple entries at the same position
278    /// let rows = vec![0, 0];
279    /// let cols = vec![0, 0];
280    /// let data = vec![1.0, 2.0];
281    /// let shape = (2, 2);
282    ///
283    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
284    /// // Note: CSR format preserves duplicates; use sum_duplicates() to combine them
285    /// assert_eq!(matrix.nnz(), 2);
286    /// ```
287    pub fn from_triplets(
288        rows: &[usize],
289        cols: &[usize],
290        data: &[T],
291        shape: (usize, usize),
292        sorted: bool,
293    ) -> SparseResult<Self> {
294        if rows.len() != cols.len() || rows.len() != data.len() {
295            return Err(SparseError::InconsistentData {
296                reason: "rows, cols, and data must have the same length".to_string(),
297            });
298        }
299
300        if rows.is_empty() {
301            // Empty matrix
302            let indptr = Array1::zeros(shape.0 + 1);
303            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
304        }
305
306        let nnz = rows.len();
307        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
308
309        for i in 0..nnz {
310            if rows[i] >= shape.0 || cols[i] >= shape.1 {
311                return Err(SparseError::IndexOutOfBounds {
312                    index: (rows[i], cols[i]),
313                    shape,
314                });
315            }
316            all_data.push((rows[i], cols[i], data[i]));
317        }
318
319        if !sorted {
320            all_data.sort_by_key(|&(row, col_, _)| (row, col_));
321        }
322
323        // Count elements per row
324        let mut row_counts = vec![0; shape.0];
325        for &(row_, _, _) in &all_data {
326            row_counts[row_] += 1;
327        }
328
329        // Create indptr
330        let mut indptr = Vec::with_capacity(shape.0 + 1);
331        indptr.push(0);
332        let mut cumsum = 0;
333        for &count in &row_counts {
334            cumsum += count;
335            indptr.push(cumsum);
336        }
337
338        // Create indices and data arrays
339        let mut indices = Vec::with_capacity(nnz);
340        let mut values = Vec::with_capacity(nnz);
341
342        for (_, col, val) in all_data {
343            indices.push(col);
344            values.push(val);
345        }
346
347        Self::new(
348            Array1::from_vec(values),
349            Array1::from_vec(indices),
350            Array1::from_vec(indptr),
351            shape,
352        )
353    }
354
355    /// Checks if column indices are sorted for each row
356    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
357        for row in 0..indptr.len() - 1 {
358            let start = indptr[row];
359            let end = indptr[row + 1];
360
361            for i in start..end.saturating_sub(1) {
362                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
363                    return false;
364                }
365            }
366        }
367        true
368    }
369
370    /// Get the raw data array
371    pub fn get_data(&self) -> &Array1<T> {
372        &self.data
373    }
374
375    /// Get the raw indices array
376    pub fn get_indices(&self) -> &Array1<usize> {
377        &self.indices
378    }
379
380    /// Get the raw indptr array
381    pub fn get_indptr(&self) -> &Array1<usize> {
382        &self.indptr
383    }
384
385    /// Get the number of rows
386    pub fn nrows(&self) -> usize {
387        self.shape.0
388    }
389
390    /// Get the number of columns  
391    pub fn ncols(&self) -> usize {
392        self.shape.1
393    }
394
395    /// Get the shape (rows, cols)
396    pub fn shape(&self) -> (usize, usize) {
397        self.shape
398    }
399}
400
401impl<T> SparseArray<T> for CsrArray<T>
402where
403    T: Float
404        + Add<Output = T>
405        + Sub<Output = T>
406        + Mul<Output = T>
407        + Div<Output = T>
408        + Debug
409        + Copy
410        + 'static,
411{
412    fn shape(&self) -> (usize, usize) {
413        self.shape
414    }
415
416    fn nnz(&self) -> usize {
417        self.data.len()
418    }
419
420    fn dtype(&self) -> &str {
421        "float" // Placeholder, ideally we would return the actual type
422    }
423
424    fn to_array(&self) -> Array2<T> {
425        let (rows, cols) = self.shape;
426        let mut result = Array2::zeros((rows, cols));
427
428        for row in 0..rows {
429            let start = self.indptr[row];
430            let end = self.indptr[row + 1];
431
432            for i in start..end {
433                let col = self.indices[i];
434                result[[row, col]] = self.data[i];
435            }
436        }
437
438        result
439    }
440
441    fn toarray(&self) -> Array2<T> {
442        self.to_array()
443    }
444
445    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
446        // This would convert to COO format
447        // For now we just return self
448        Ok(Box::new(self.clone()))
449    }
450
451    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
452        Ok(Box::new(self.clone()))
453    }
454
455    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
456        // This would convert to CSC format
457        // For now we just return self
458        Ok(Box::new(self.clone()))
459    }
460
461    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
462        // This would convert to DOK format
463        // For now we just return self
464        Ok(Box::new(self.clone()))
465    }
466
467    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
468        // This would convert to LIL format
469        // For now we just return self
470        Ok(Box::new(self.clone()))
471    }
472
473    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
474        // This would convert to DIA format
475        // For now we just return self
476        Ok(Box::new(self.clone()))
477    }
478
479    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
480        // This would convert to BSR format
481        // For now we just return self
482        Ok(Box::new(self.clone()))
483    }
484
485    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
486        // Placeholder implementation - a full implementation would handle direct
487        // sparse matrix addition more efficiently
488        let self_array = self.to_array();
489        let other_array = other.to_array();
490
491        if self.shape() != other.shape() {
492            return Err(SparseError::DimensionMismatch {
493                expected: self.shape().0,
494                found: other.shape().0,
495            });
496        }
497
498        let result = &self_array + &other_array;
499
500        // Convert back to CSR
501        let (rows, cols) = self.shape();
502        let mut data = Vec::new();
503        let mut indices = Vec::new();
504        let mut indptr = vec![0];
505
506        for row in 0..rows {
507            for col in 0..cols {
508                let val = result[[row, col]];
509                if !val.is_zero() {
510                    data.push(val);
511                    indices.push(col);
512                }
513            }
514            indptr.push(data.len());
515        }
516
517        CsrArray::new(
518            Array1::from_vec(data),
519            Array1::from_vec(indices),
520            Array1::from_vec(indptr),
521            self.shape(),
522        )
523        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
524    }
525
526    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
527        // Similar to add, this is a placeholder
528        let self_array = self.to_array();
529        let other_array = other.to_array();
530
531        if self.shape() != other.shape() {
532            return Err(SparseError::DimensionMismatch {
533                expected: self.shape().0,
534                found: other.shape().0,
535            });
536        }
537
538        let result = &self_array - &other_array;
539
540        // Convert back to CSR
541        let (rows, cols) = self.shape();
542        let mut data = Vec::new();
543        let mut indices = Vec::new();
544        let mut indptr = vec![0];
545
546        for row in 0..rows {
547            for col in 0..cols {
548                let val = result[[row, col]];
549                if !val.is_zero() {
550                    data.push(val);
551                    indices.push(col);
552                }
553            }
554            indptr.push(data.len());
555        }
556
557        CsrArray::new(
558            Array1::from_vec(data),
559            Array1::from_vec(indices),
560            Array1::from_vec(indptr),
561            self.shape(),
562        )
563        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
564    }
565
566    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
567        // This is element-wise multiplication (Hadamard product)
568        // In the sparse array API, * is element-wise, not matrix multiplication
569        let self_array = self.to_array();
570        let other_array = other.to_array();
571
572        if self.shape() != other.shape() {
573            return Err(SparseError::DimensionMismatch {
574                expected: self.shape().0,
575                found: other.shape().0,
576            });
577        }
578
579        let result = &self_array * &other_array;
580
581        // Convert back to CSR
582        let (rows, cols) = self.shape();
583        let mut data = Vec::new();
584        let mut indices = Vec::new();
585        let mut indptr = vec![0];
586
587        for row in 0..rows {
588            for col in 0..cols {
589                let val = result[[row, col]];
590                if !val.is_zero() {
591                    data.push(val);
592                    indices.push(col);
593                }
594            }
595            indptr.push(data.len());
596        }
597
598        CsrArray::new(
599            Array1::from_vec(data),
600            Array1::from_vec(indices),
601            Array1::from_vec(indptr),
602            self.shape(),
603        )
604        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
605    }
606
607    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
608        // Element-wise division
609        let self_array = self.to_array();
610        let other_array = other.to_array();
611
612        if self.shape() != other.shape() {
613            return Err(SparseError::DimensionMismatch {
614                expected: self.shape().0,
615                found: other.shape().0,
616            });
617        }
618
619        let result = &self_array / &other_array;
620
621        // Convert back to CSR
622        let (rows, cols) = self.shape();
623        let mut data = Vec::new();
624        let mut indices = Vec::new();
625        let mut indptr = vec![0];
626
627        for row in 0..rows {
628            for col in 0..cols {
629                let val = result[[row, col]];
630                if !val.is_zero() {
631                    data.push(val);
632                    indices.push(col);
633                }
634            }
635            indptr.push(data.len());
636        }
637
638        CsrArray::new(
639            Array1::from_vec(data),
640            Array1::from_vec(indices),
641            Array1::from_vec(indptr),
642            self.shape(),
643        )
644        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
645    }
646
647    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
648        // Matrix multiplication
649        // This is a placeholder; a full implementation would be optimized for sparse arrays
650        let (m, n) = self.shape();
651        let (p, q) = other.shape();
652
653        if n != p {
654            return Err(SparseError::DimensionMismatch {
655                expected: n,
656                found: p,
657            });
658        }
659
660        let mut result = Array2::zeros((m, q));
661        let other_array = other.to_array();
662
663        for row in 0..m {
664            let start = self.indptr[row];
665            let end = self.indptr[row + 1];
666
667            for j in 0..q {
668                let mut sum = T::zero();
669                for idx in start..end {
670                    let col = self.indices[idx];
671                    sum = sum + self.data[idx] * other_array[[col, j]];
672                }
673                if !sum.is_zero() {
674                    result[[row, j]] = sum;
675                }
676            }
677        }
678
679        // Convert result back to CSR
680        let mut data = Vec::new();
681        let mut indices = Vec::new();
682        let mut indptr = vec![0];
683
684        for row in 0..m {
685            for col in 0..q {
686                let val = result[[row, col]];
687                if !val.is_zero() {
688                    data.push(val);
689                    indices.push(col);
690                }
691            }
692            indptr.push(data.len());
693        }
694
695        CsrArray::new(
696            Array1::from_vec(data),
697            Array1::from_vec(indices),
698            Array1::from_vec(indptr),
699            (m, q),
700        )
701        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
702    }
703
704    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
705        let (m, n) = self.shape();
706        if n != other.len() {
707            return Err(SparseError::DimensionMismatch {
708                expected: n,
709                found: other.len(),
710            });
711        }
712
713        let mut result = Array1::zeros(m);
714
715        for row in 0..m {
716            let start = self.indptr[row];
717            let end = self.indptr[row + 1];
718
719            let mut sum = T::zero();
720            for idx in start..end {
721                let col = self.indices[idx];
722                sum = sum + self.data[idx] * other[col];
723            }
724            result[row] = sum;
725        }
726
727        Ok(result)
728    }
729
730    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
731        // Transpose is non-trivial for CSR format
732        // A full implementation would convert to CSC format or implement
733        // an efficient algorithm
734        let (rows, cols) = self.shape();
735        let mut row_indices = Vec::with_capacity(self.nnz());
736        let mut col_indices = Vec::with_capacity(self.nnz());
737        let mut values = Vec::with_capacity(self.nnz());
738
739        for row in 0..rows {
740            let start = self.indptr[row];
741            let end = self.indptr[row + 1];
742
743            for idx in start..end {
744                let col = self.indices[idx];
745                row_indices.push(col); // Note: rows and cols are swapped for transposition
746                col_indices.push(row);
747                values.push(self.data[idx]);
748            }
749        }
750
751        // We need to create CSR from this "COO" representation
752        CsrArray::from_triplets(
753            &row_indices,
754            &col_indices,
755            &values,
756            (cols, rows), // Swapped dimensions
757            false,
758        )
759        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
760    }
761
762    fn copy(&self) -> Box<dyn SparseArray<T>> {
763        Box::new(self.clone())
764    }
765
766    fn get(&self, i: usize, j: usize) -> T {
767        if i >= self.shape.0 || j >= self.shape.1 {
768            return T::zero();
769        }
770
771        let start = self.indptr[i];
772        let end = self.indptr[i + 1];
773
774        for idx in start..end {
775            if self.indices[idx] == j {
776                return self.data[idx];
777            }
778            // If indices are sorted, we can break early
779            if self.has_sorted_indices && self.indices[idx] > j {
780                break;
781            }
782        }
783
784        T::zero()
785    }
786
787    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
788        // Setting elements in CSR format is non-trivial
789        // This is a placeholder implementation that doesn't actually modify the structure
790        if i >= self.shape.0 || j >= self.shape.1 {
791            return Err(SparseError::IndexOutOfBounds {
792                index: (i, j),
793                shape: self.shape,
794            });
795        }
796
797        let start = self.indptr[i];
798        let end = self.indptr[i + 1];
799
800        // Try to find existing element
801        for idx in start..end {
802            if self.indices[idx] == j {
803                // Update value
804                self.data[idx] = value;
805                return Ok(());
806            }
807            // If indices are sorted, we can insert at the right position
808            if self.has_sorted_indices && self.indices[idx] > j {
809                // Insert here - this would require restructuring indices and data
810                // Not implemented in this placeholder
811                return Err(SparseError::NotImplemented(
812                    "Inserting new elements in CSR format".to_string(),
813                ));
814            }
815        }
816
817        // Element not found, would need to insert
818        // This would require restructuring indices and data
819        Err(SparseError::NotImplemented(
820            "Inserting new elements in CSR format".to_string(),
821        ))
822    }
823
824    fn eliminate_zeros(&mut self) {
825        // Find all non-zero entries
826        let mut new_data = Vec::new();
827        let mut new_indices = Vec::new();
828        let mut new_indptr = vec![0];
829
830        let (rows, _) = self.shape();
831
832        for row in 0..rows {
833            let start = self.indptr[row];
834            let end = self.indptr[row + 1];
835
836            for idx in start..end {
837                if !self.data[idx].is_zero() {
838                    new_data.push(self.data[idx]);
839                    new_indices.push(self.indices[idx]);
840                }
841            }
842            new_indptr.push(new_data.len());
843        }
844
845        // Replace data with filtered data
846        self.data = Array1::from_vec(new_data);
847        self.indices = Array1::from_vec(new_indices);
848        self.indptr = Array1::from_vec(new_indptr);
849    }
850
851    fn sort_indices(&mut self) {
852        if self.has_sorted_indices {
853            return;
854        }
855
856        let (rows, _) = self.shape();
857
858        for row in 0..rows {
859            let start = self.indptr[row];
860            let end = self.indptr[row + 1];
861
862            if start == end {
863                continue;
864            }
865
866            // Extract the non-zero elements for this row
867            let mut row_data = Vec::with_capacity(end - start);
868            for idx in start..end {
869                row_data.push((self.indices[idx], self.data[idx]));
870            }
871
872            // Sort by column index
873            row_data.sort_by_key(|&(col_, _)| col_);
874
875            // Put the sorted data back
876            for (i, (col, val)) in row_data.into_iter().enumerate() {
877                self.indices[start + i] = col;
878                self.data[start + i] = val;
879            }
880        }
881
882        self.has_sorted_indices = true;
883    }
884
885    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
886        if self.has_sorted_indices {
887            return Box::new(self.clone());
888        }
889
890        let mut sorted = self.clone();
891        sorted.sort_indices();
892        Box::new(sorted)
893    }
894
895    fn has_sorted_indices(&self) -> bool {
896        self.has_sorted_indices
897    }
898
899    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
900        match axis {
901            None => {
902                // Sum all elements
903                let mut sum = T::zero();
904                for &val in self.data.iter() {
905                    sum = sum + val;
906                }
907                Ok(SparseSum::Scalar(sum))
908            }
909            Some(0) => {
910                // Sum along rows (result is a row vector)
911                let (_, cols) = self.shape();
912                let mut result = vec![T::zero(); cols];
913
914                for row in 0..self.shape.0 {
915                    let start = self.indptr[row];
916                    let end = self.indptr[row + 1];
917
918                    for idx in start..end {
919                        let col = self.indices[idx];
920                        result[col] = result[col] + self.data[idx];
921                    }
922                }
923
924                // Convert to CSR format
925                let mut data = Vec::new();
926                let mut indices = Vec::new();
927                let mut indptr = vec![0];
928
929                for (col, &val) in result.iter().enumerate() {
930                    if !val.is_zero() {
931                        data.push(val);
932                        indices.push(col);
933                    }
934                }
935                indptr.push(data.len());
936
937                let result_array = CsrArray::new(
938                    Array1::from_vec(data),
939                    Array1::from_vec(indices),
940                    Array1::from_vec(indptr),
941                    (1, cols),
942                )?;
943
944                Ok(SparseSum::SparseArray(Box::new(result_array)))
945            }
946            Some(1) => {
947                // Sum along columns (result is a column vector)
948                let mut result = Vec::with_capacity(self.shape.0);
949
950                for row in 0..self.shape.0 {
951                    let start = self.indptr[row];
952                    let end = self.indptr[row + 1];
953
954                    let mut row_sum = T::zero();
955                    for idx in start..end {
956                        row_sum = row_sum + self.data[idx];
957                    }
958                    result.push(row_sum);
959                }
960
961                // Convert to CSR format
962                let mut data = Vec::new();
963                let mut indices = Vec::new();
964                let mut indptr = vec![0];
965
966                for &val in result.iter() {
967                    if !val.is_zero() {
968                        data.push(val);
969                        indices.push(0);
970                        indptr.push(data.len());
971                    } else {
972                        indptr.push(data.len());
973                    }
974                }
975
976                let result_array = CsrArray::new(
977                    Array1::from_vec(data),
978                    Array1::from_vec(indices),
979                    Array1::from_vec(indptr),
980                    (self.shape.0, 1),
981                )?;
982
983                Ok(SparseSum::SparseArray(Box::new(result_array)))
984            }
985            _ => Err(SparseError::InvalidAxis),
986        }
987    }
988
989    fn max(&self) -> T {
990        if self.data.is_empty() {
991            return T::neg_infinity();
992        }
993
994        let mut max_val = self.data[0];
995        for &val in self.data.iter().skip(1) {
996            if val > max_val {
997                max_val = val;
998            }
999        }
1000
1001        // Check if max_val is less than zero, as zeros aren't explicitly stored
1002        if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
1003            max_val = T::zero();
1004        }
1005
1006        max_val
1007    }
1008
1009    fn min(&self) -> T {
1010        if self.data.is_empty() {
1011            return T::infinity();
1012        }
1013
1014        let mut min_val = self.data[0];
1015        for &val in self.data.iter().skip(1) {
1016            if val < min_val {
1017                min_val = val;
1018            }
1019        }
1020
1021        // Check if min_val is greater than zero, as zeros aren't explicitly stored
1022        if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
1023            min_val = T::zero();
1024        }
1025
1026        min_val
1027    }
1028
1029    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1030        let nnz = self.nnz();
1031        let mut rows = Vec::with_capacity(nnz);
1032        let mut cols = Vec::with_capacity(nnz);
1033        let mut values = Vec::with_capacity(nnz);
1034
1035        for row in 0..self.shape.0 {
1036            let start = self.indptr[row];
1037            let end = self.indptr[row + 1];
1038
1039            for idx in start..end {
1040                let col = self.indices[idx];
1041                rows.push(row);
1042                cols.push(col);
1043                values.push(self.data[idx]);
1044            }
1045        }
1046
1047        (
1048            Array1::from_vec(rows),
1049            Array1::from_vec(cols),
1050            Array1::from_vec(values),
1051        )
1052    }
1053
1054    fn slice(
1055        &self,
1056        row_range: (usize, usize),
1057        col_range: (usize, usize),
1058    ) -> SparseResult<Box<dyn SparseArray<T>>> {
1059        let (start_row, end_row) = row_range;
1060        let (start_col, end_col) = col_range;
1061
1062        if start_row >= self.shape.0
1063            || end_row > self.shape.0
1064            || start_col >= self.shape.1
1065            || end_col > self.shape.1
1066        {
1067            return Err(SparseError::InvalidSliceRange);
1068        }
1069
1070        if start_row >= end_row || start_col >= end_col {
1071            return Err(SparseError::InvalidSliceRange);
1072        }
1073
1074        let mut data = Vec::new();
1075        let mut indices = Vec::new();
1076        let mut indptr = vec![0];
1077
1078        for row in start_row..end_row {
1079            let start = self.indptr[row];
1080            let end = self.indptr[row + 1];
1081
1082            for idx in start..end {
1083                let col = self.indices[idx];
1084                if col >= start_col && col < end_col {
1085                    data.push(self.data[idx]);
1086                    indices.push(col - start_col);
1087                }
1088            }
1089            indptr.push(data.len());
1090        }
1091
1092        CsrArray::new(
1093            Array1::from_vec(data),
1094            Array1::from_vec(indices),
1095            Array1::from_vec(indptr),
1096            (end_row - start_row, end_col - start_col),
1097        )
1098        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1099    }
1100
1101    fn as_any(&self) -> &dyn std::any::Any {
1102        self
1103    }
1104
1105    fn get_indptr(&self) -> Option<&Array1<usize>> {
1106        Some(&self.indptr)
1107    }
1108
1109    fn indptr(&self) -> Option<&Array1<usize>> {
1110        Some(&self.indptr)
1111    }
1112}
1113
1114impl<T> fmt::Debug for CsrArray<T>
1115where
1116    T: Float
1117        + Add<Output = T>
1118        + Sub<Output = T>
1119        + Mul<Output = T>
1120        + Div<Output = T>
1121        + Debug
1122        + Copy
1123        + 'static,
1124{
1125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1126        write!(
1127            f,
1128            "CsrArray<{}x{}, nnz={}>",
1129            self.shape.0,
1130            self.shape.1,
1131            self.nnz()
1132        )
1133    }
1134}
1135
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Builds the 3x3 fixture shared by most tests:
    //
    //   [1.0, 0.0, 2.0]
    //   [0.0, 3.0, 0.0]
    //   [4.0, 0.0, 5.0]
    fn sample_csr() -> CsrArray<f64> {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        CsrArray::new(data, indices, indptr, (3, 3)).unwrap()
    }

    #[test]
    fn test_csr_array_construction() {
        let csr = sample_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 1]], 0.0);
        assert_eq!(dense[[0, 2]], 2.0);
        assert_eq!(dense[[1, 0]], 0.0);
        assert_eq!(dense[[1, 1]], 3.0);
        assert_eq!(dense[[1, 2]], 0.0);
        assert_eq!(dense[[2, 0]], 4.0);
        assert_eq!(dense[[2, 1]], 0.0);
        assert_eq!(dense[[2, 2]], 5.0);
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let csr = sample_csr();
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csr.dot_vector(&vec.view()).unwrap();

        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3] = [7, 6, 19]
        assert_eq!(result.len(), 3);
        assert_relative_eq!(result[0], 7.0);
        assert_relative_eq!(result[1], 6.0);
        assert_relative_eq!(result[2], 19.0);
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        // Sum all elements
        if let SparseSum::Scalar(sum) = csr.sum(None).unwrap() {
            assert_relative_eq!(sum, 15.0);
        } else {
            panic!("Expected scalar sum");
        }

        // Sum along rows
        if let SparseSum::SparseArray(row_sum) = csr.sum(Some(0)).unwrap() {
            let row_sum_array = row_sum.to_array();
            assert_eq!(row_sum_array.shape(), &[1, 3]);
            assert_relative_eq!(row_sum_array[[0, 0]], 5.0);
            assert_relative_eq!(row_sum_array[[0, 1]], 3.0);
            assert_relative_eq!(row_sum_array[[0, 2]], 7.0);
        } else {
            panic!("Expected sparse array sum");
        }

        // Sum along columns
        if let SparseSum::SparseArray(col_sum) = csr.sum(Some(1)).unwrap() {
            let col_sum_array = col_sum.to_array();
            assert_eq!(col_sum_array.shape(), &[3, 1]);
            assert_relative_eq!(col_sum_array[[0, 0]], 3.0);
            assert_relative_eq!(col_sum_array[[1, 0]], 3.0);
            assert_relative_eq!(col_sum_array[[2, 0]], 9.0);
        } else {
            panic!("Expected sparse array sum");
        }
    }

    #[test]
    fn test_csr_array_max_min_account_for_implicit_zeros() {
        // All stored values are positive, but the fixture has implicit zeros.
        let csr = sample_csr();
        assert_relative_eq!(csr.max(), 5.0);
        assert_relative_eq!(csr.min(), 0.0);

        // All stored values are negative; implicit zeros raise the maximum.
        // [-1.0,  0.0]
        // [ 0.0, -2.0]
        let negative = CsrArray::new(
            Array1::from_vec(vec![-1.0, -2.0]),
            Array1::from_vec(vec![0, 1]),
            Array1::from_vec(vec![0, 1, 2]),
            (2, 2),
        )
        .unwrap();
        assert_relative_eq!(negative.max(), 0.0);
        assert_relative_eq!(negative.min(), -2.0);
    }

    #[test]
    fn test_csr_array_find() {
        let (rows, cols, vals) = sample_csr().find();

        assert_eq!(rows.to_vec(), vec![0, 0, 1, 2, 2]);
        assert_eq!(cols.to_vec(), vec![0, 2, 1, 0, 2]);
        assert_eq!(vals.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_csr_array_slice() {
        let csr = sample_csr();

        // Rows 1..3, columns 0..2 of the fixture:
        // [0.0, 3.0]
        // [4.0, 0.0]
        let sub = csr.slice((1, 3), (0, 2)).unwrap();
        let dense = sub.to_array();
        assert_eq!(dense.shape(), &[2, 2]);
        assert_relative_eq!(dense[[0, 0]], 0.0);
        assert_relative_eq!(dense[[0, 1]], 3.0);
        assert_relative_eq!(dense[[1, 0]], 4.0);
        assert_relative_eq!(dense[[1, 1]], 0.0);

        // Empty and out-of-bounds ranges are rejected.
        assert!(csr.slice((2, 2), (0, 3)).is_err());
        assert!(csr.slice((0, 4), (0, 3)).is_err());
        assert!(csr.slice((0, 3), (1, 1)).is_err());
    }
}