scirs2_sparse/
csr_array.rs

1// CSR Array implementation
2//
3// This module provides the CSR (Compressed Sparse Row) array format,
4// which is efficient for row-wise operations and is one of the most common formats.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14/// CSR Array format - Compressed Sparse Row matrix representation
15///
16/// The CSR (Compressed Sparse Row) format is one of the most popular sparse matrix formats,
17/// storing a sparse array using three arrays:
18/// - `data`: array of non-zero values in row-major order
19/// - `indices`: column indices of the non-zero values
20/// - `indptr`: row pointers; `indptr[i]` is the index into `data`/`indices` where row `i` starts
21///
22/// The CSR format is particularly efficient for:
23/// - ✅ Matrix-vector multiplications (`A * x`)
24/// - ✅ Matrix-matrix multiplications with other sparse matrices
25/// - ✅ Row-wise operations and row slicing
26/// - ✅ Iterating over non-zero elements row by row
27/// - ✅ Adding and subtracting sparse matrices
28///
29/// But less efficient for:
30/// - ❌ Column-wise operations and column slicing
31/// - ❌ Inserting or modifying individual elements after construction
32/// - ❌ Operations that require column access patterns
33///
34/// # Memory Layout
35///
36/// For a matrix with `m` rows, `n` columns, and `nnz` non-zero elements:
37/// - `data`: length `nnz` - stores the actual non-zero values
38/// - `indices`: length `nnz` - stores column indices for each non-zero value
39/// - `indptr`: length `m+1` - stores cumulative count of non-zeros per row
40///
41/// # Examples
42///
43/// ## Basic Construction and Access
44/// ```
45/// use scirs2_sparse::csr_array::CsrArray;
46/// use scirs2_sparse::SparseArray;
47///
48/// // Create a 3x3 matrix:
49/// // [1.0, 0.0, 2.0]
50/// // [0.0, 3.0, 0.0]
51/// // [4.0, 0.0, 5.0]
52/// let rows = vec![0, 0, 1, 2, 2];
53/// let cols = vec![0, 2, 1, 0, 2];
54/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
55/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
56///
57/// // Access elements
58/// assert_eq!(matrix.get(0, 0), 1.0);
59/// assert_eq!(matrix.get(0, 1), 0.0);  // Zero element
60/// assert_eq!(matrix.get(1, 1), 3.0);
61///
62/// // Get matrix properties
63/// assert_eq!(matrix.shape(), (3, 3));
64/// assert_eq!(matrix.nnz(), 5);
65/// ```
66///
67/// ## Matrix Operations
68/// ```
69/// use scirs2_sparse::csr_array::CsrArray;
70/// use scirs2_sparse::SparseArray;
71/// use scirs2_core::ndarray::Array1;
72///
73/// let rows = vec![0, 1, 2];
74/// let cols = vec![0, 1, 2];
75/// let data = vec![2.0, 3.0, 4.0];
76/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
77///
78/// // Matrix-vector multiplication
79/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
80/// let y = matrix.dot_vector(&x.view()).unwrap();
81/// assert_eq!(y[0], 2.0);  // 2.0 * 1.0
82/// assert_eq!(y[1], 6.0);  // 3.0 * 2.0
83/// assert_eq!(y[2], 12.0); // 4.0 * 3.0
84/// ```
85///
86/// ## Format Conversion
87/// ```
88/// use scirs2_sparse::csr_array::CsrArray;
89/// use scirs2_sparse::SparseArray;
90///
91/// let rows = vec![0, 1];
92/// let cols = vec![0, 1];
93/// let data = vec![1.0, 2.0];
94/// let csr = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
95///
96/// // Convert to dense array
97/// let dense = csr.to_array();
98/// assert_eq!(dense[[0, 0]], 1.0);
99/// assert_eq!(dense[[1, 1]], 2.0);
100///
101/// // Convert to other sparse formats
102/// let coo = csr.to_coo();
103/// let csc = csr.to_csc();
104/// ```
105#[derive(Clone)]
106pub struct CsrArray<T>
107where
108    T: SparseElement + Div<Output = T> + 'static,
109{
110    /// Non-zero values
111    data: Array1<T>,
112    /// Column indices of non-zero values
113    indices: Array1<usize>,
114    /// Row pointers (indices into data/indices for the start of each row)
115    indptr: Array1<usize>,
116    /// Shape of the sparse array
117    shape: (usize, usize),
118    /// Whether indices are sorted for each row
119    has_sorted_indices: bool,
120}
121
122impl<T> CsrArray<T>
123where
124    T: SparseElement + Div<Output = T> + Zero + 'static,
125{
126    /// Creates a new CSR array from raw components
127    ///
128    /// # Arguments
129    /// * `data` - Array of non-zero values
130    /// * `indices` - Column indices of non-zero values
131    /// * `indptr` - Index pointers for the start of each row
132    /// * `shape` - Shape of the sparse array
133    ///
134    /// # Returns
135    /// A new `CsrArray`
136    ///
137    /// # Errors
138    /// Returns an error if the data is not consistent
139    pub fn new(
140        data: Array1<T>,
141        indices: Array1<usize>,
142        indptr: Array1<usize>,
143        shape: (usize, usize),
144    ) -> SparseResult<Self> {
145        // Validation
146        if data.len() != indices.len() {
147            return Err(SparseError::InconsistentData {
148                reason: "data and indices must have the same length".to_string(),
149            });
150        }
151
152        if indptr.len() != shape.0 + 1 {
153            return Err(SparseError::InconsistentData {
154                reason: format!(
155                    "indptr length ({}) must be one more than the number of rows ({})",
156                    indptr.len(),
157                    shape.0
158                ),
159            });
160        }
161
162        if let Some(&max_idx) = indices.iter().max() {
163            if max_idx >= shape.1 {
164                return Err(SparseError::IndexOutOfBounds {
165                    index: (0, max_idx),
166                    shape,
167                });
168            }
169        }
170
171        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
172            if first != 0 {
173                return Err(SparseError::InconsistentData {
174                    reason: "first element of indptr must be 0".to_string(),
175                });
176            }
177
178            if last != data.len() {
179                return Err(SparseError::InconsistentData {
180                    reason: format!(
181                        "last element of indptr ({}) must equal data length ({})",
182                        last,
183                        data.len()
184                    ),
185                });
186            }
187        }
188
189        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
190
191        Ok(Self {
192            data,
193            indices,
194            indptr,
195            shape,
196            has_sorted_indices,
197        })
198    }
199
200    /// Create a CSR array from triplet format (COO-like)
201    ///
202    /// This function creates a CSR (Compressed Sparse Row) array from coordinate triplets.
203    /// The triplets represent non-zero elements as (row, column, value) tuples.
204    ///
205    /// # Arguments
206    /// * `rows` - Row indices of non-zero elements
207    /// * `cols` - Column indices of non-zero elements  
208    /// * `data` - Values of non-zero elements
209    /// * `shape` - Shape of the sparse array (nrows, ncols)
210    /// * `sorted` - Whether the triplets are already sorted by (row, col). If false, sorting will be performed.
211    ///
212    /// # Returns
213    /// A new `CsrArray` containing the sparse matrix
214    ///
215    /// # Errors
216    /// Returns an error if:
217    /// - `rows`, `cols`, and `data` have different lengths
218    /// - Any index is out of bounds for the given shape
219    /// - The resulting data structure is inconsistent
220    ///
221    /// # Examples
222    ///
223    /// Create a simple 3x3 sparse matrix:
224    /// ```
225    /// use scirs2_sparse::csr_array::CsrArray;
226    /// use scirs2_sparse::SparseArray;
227    ///
228    /// // Create a 3x3 matrix with the following structure:
229    /// // [1.0, 0.0, 2.0]
230    /// // [0.0, 3.0, 0.0]
231    /// // [4.0, 0.0, 5.0]
232    /// let rows = vec![0, 0, 1, 2, 2];
233    /// let cols = vec![0, 2, 1, 0, 2];
234    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
235    /// let shape = (3, 3);
236    ///
237    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
238    /// assert_eq!(matrix.get(0, 0), 1.0);
239    /// assert_eq!(matrix.get(0, 1), 0.0);
240    /// assert_eq!(matrix.get(1, 1), 3.0);
241    /// ```
242    ///
243    /// Create an empty sparse matrix:
244    /// ```
245    /// use scirs2_sparse::csr_array::CsrArray;
246    /// use scirs2_sparse::SparseArray;
247    ///
248    /// let rows: Vec<usize> = vec![];
249    /// let cols: Vec<usize> = vec![];
250    /// let data: Vec<f64> = vec![];
251    /// let shape = (5, 5);
252    ///
253    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
254    /// assert_eq!(matrix.nnz(), 0);
255    /// assert_eq!(matrix.shape(), (5, 5));
256    /// ```
257    ///
258    /// Handle duplicate entries (they will be preserved):
259    /// ```
260    /// use scirs2_sparse::csr_array::CsrArray;
261    /// use scirs2_sparse::SparseArray;
262    ///
263    /// // Multiple entries at the same position
264    /// let rows = vec![0, 0];
265    /// let cols = vec![0, 0];
266    /// let data = vec![1.0, 2.0];
267    /// let shape = (2, 2);
268    ///
269    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
270    /// // Note: CSR format preserves duplicates; use sum_duplicates() to combine them
271    /// assert_eq!(matrix.nnz(), 2);
272    /// ```
273    pub fn from_triplets(
274        rows: &[usize],
275        cols: &[usize],
276        data: &[T],
277        shape: (usize, usize),
278        sorted: bool,
279    ) -> SparseResult<Self> {
280        if rows.len() != cols.len() || rows.len() != data.len() {
281            return Err(SparseError::InconsistentData {
282                reason: "rows, cols, and data must have the same length".to_string(),
283            });
284        }
285
286        if rows.is_empty() {
287            // Empty matrix
288            let indptr = Array1::zeros(shape.0 + 1);
289            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
290        }
291
292        let nnz = rows.len();
293        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
294
295        for i in 0..nnz {
296            if rows[i] >= shape.0 || cols[i] >= shape.1 {
297                return Err(SparseError::IndexOutOfBounds {
298                    index: (rows[i], cols[i]),
299                    shape,
300                });
301            }
302            all_data.push((rows[i], cols[i], data[i]));
303        }
304
305        if !sorted {
306            all_data.sort_by_key(|&(row, col_, _)| (row, col_));
307        }
308
309        // Count elements per row
310        let mut row_counts = vec![0; shape.0];
311        for &(row_, _, _) in &all_data {
312            row_counts[row_] += 1;
313        }
314
315        // Create indptr
316        let mut indptr = Vec::with_capacity(shape.0 + 1);
317        indptr.push(0);
318        let mut cumsum = 0;
319        for &count in &row_counts {
320            cumsum += count;
321            indptr.push(cumsum);
322        }
323
324        // Create indices and data arrays
325        let mut indices = Vec::with_capacity(nnz);
326        let mut values = Vec::with_capacity(nnz);
327
328        for (_, col, val) in all_data {
329            indices.push(col);
330            values.push(val);
331        }
332
333        Self::new(
334            Array1::from_vec(values),
335            Array1::from_vec(indices),
336            Array1::from_vec(indptr),
337            shape,
338        )
339    }
340
341    /// Checks if column indices are sorted for each row
342    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
343        for row in 0..indptr.len() - 1 {
344            let start = indptr[row];
345            let end = indptr[row + 1];
346
347            for i in start..end.saturating_sub(1) {
348                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
349                    return false;
350                }
351            }
352        }
353        true
354    }
355
356    /// Get the raw data array
357    pub fn get_data(&self) -> &Array1<T> {
358        &self.data
359    }
360
361    /// Get the raw indices array
362    pub fn get_indices(&self) -> &Array1<usize> {
363        &self.indices
364    }
365
366    /// Get the raw indptr array
367    pub fn get_indptr(&self) -> &Array1<usize> {
368        &self.indptr
369    }
370
371    /// Get the number of rows
372    pub fn nrows(&self) -> usize {
373        self.shape.0
374    }
375
376    /// Get the number of columns  
377    pub fn ncols(&self) -> usize {
378        self.shape.1
379    }
380
381    /// Get the shape (rows, cols)
382    pub fn shape(&self) -> (usize, usize) {
383        self.shape
384    }
385}
386
387impl<T> SparseArray<T> for CsrArray<T>
388where
389    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
390{
391    fn shape(&self) -> (usize, usize) {
392        self.shape
393    }
394
395    fn nnz(&self) -> usize {
396        self.data.len()
397    }
398
399    fn dtype(&self) -> &str {
400        "float" // Placeholder, ideally we would return the actual type
401    }
402
403    fn to_array(&self) -> Array2<T> {
404        let (rows, cols) = self.shape;
405        let mut result = Array2::zeros((rows, cols));
406
407        for row in 0..rows {
408            let start = self.indptr[row];
409            let end = self.indptr[row + 1];
410
411            for i in start..end {
412                let col = self.indices[i];
413                result[[row, col]] = self.data[i];
414            }
415        }
416
417        result
418    }
419
420    fn toarray(&self) -> Array2<T> {
421        self.to_array()
422    }
423
424    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
425        // This would convert to COO format
426        // For now we just return self
427        Ok(Box::new(self.clone()))
428    }
429
430    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
431        Ok(Box::new(self.clone()))
432    }
433
434    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
435        // This would convert to CSC format
436        // For now we just return self
437        Ok(Box::new(self.clone()))
438    }
439
440    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
441        // This would convert to DOK format
442        // For now we just return self
443        Ok(Box::new(self.clone()))
444    }
445
446    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
447        // This would convert to LIL format
448        // For now we just return self
449        Ok(Box::new(self.clone()))
450    }
451
452    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
453        // This would convert to DIA format
454        // For now we just return self
455        Ok(Box::new(self.clone()))
456    }
457
458    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
459        // This would convert to BSR format
460        // For now we just return self
461        Ok(Box::new(self.clone()))
462    }
463
464    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
465        // Placeholder implementation - a full implementation would handle direct
466        // sparse matrix addition more efficiently
467        let self_array = self.to_array();
468        let other_array = other.to_array();
469
470        if self.shape() != other.shape() {
471            return Err(SparseError::DimensionMismatch {
472                expected: self.shape().0,
473                found: other.shape().0,
474            });
475        }
476
477        let result = &self_array + &other_array;
478
479        // Convert back to CSR
480        let (rows, cols) = self.shape();
481        let mut data = Vec::new();
482        let mut indices = Vec::new();
483        let mut indptr = vec![0];
484
485        for row in 0..rows {
486            for col in 0..cols {
487                let val = result[[row, col]];
488                if val != T::sparse_zero() {
489                    data.push(val);
490                    indices.push(col);
491                }
492            }
493            indptr.push(data.len());
494        }
495
496        CsrArray::new(
497            Array1::from_vec(data),
498            Array1::from_vec(indices),
499            Array1::from_vec(indptr),
500            self.shape(),
501        )
502        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
503    }
504
505    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
506        // Similar to add, this is a placeholder
507        let self_array = self.to_array();
508        let other_array = other.to_array();
509
510        if self.shape() != other.shape() {
511            return Err(SparseError::DimensionMismatch {
512                expected: self.shape().0,
513                found: other.shape().0,
514            });
515        }
516
517        let result = &self_array - &other_array;
518
519        // Convert back to CSR
520        let (rows, cols) = self.shape();
521        let mut data = Vec::new();
522        let mut indices = Vec::new();
523        let mut indptr = vec![0];
524
525        for row in 0..rows {
526            for col in 0..cols {
527                let val = result[[row, col]];
528                if val != T::sparse_zero() {
529                    data.push(val);
530                    indices.push(col);
531                }
532            }
533            indptr.push(data.len());
534        }
535
536        CsrArray::new(
537            Array1::from_vec(data),
538            Array1::from_vec(indices),
539            Array1::from_vec(indptr),
540            self.shape(),
541        )
542        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
543    }
544
545    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
546        // This is element-wise multiplication (Hadamard product)
547        // In the sparse array API, * is element-wise, not matrix multiplication
548        let self_array = self.to_array();
549        let other_array = other.to_array();
550
551        if self.shape() != other.shape() {
552            return Err(SparseError::DimensionMismatch {
553                expected: self.shape().0,
554                found: other.shape().0,
555            });
556        }
557
558        let result = &self_array * &other_array;
559
560        // Convert back to CSR
561        let (rows, cols) = self.shape();
562        let mut data = Vec::new();
563        let mut indices = Vec::new();
564        let mut indptr = vec![0];
565
566        for row in 0..rows {
567            for col in 0..cols {
568                let val = result[[row, col]];
569                if val != T::sparse_zero() {
570                    data.push(val);
571                    indices.push(col);
572                }
573            }
574            indptr.push(data.len());
575        }
576
577        CsrArray::new(
578            Array1::from_vec(data),
579            Array1::from_vec(indices),
580            Array1::from_vec(indptr),
581            self.shape(),
582        )
583        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
584    }
585
586    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
587        // Element-wise division
588        let self_array = self.to_array();
589        let other_array = other.to_array();
590
591        if self.shape() != other.shape() {
592            return Err(SparseError::DimensionMismatch {
593                expected: self.shape().0,
594                found: other.shape().0,
595            });
596        }
597
598        let result = &self_array / &other_array;
599
600        // Convert back to CSR
601        let (rows, cols) = self.shape();
602        let mut data = Vec::new();
603        let mut indices = Vec::new();
604        let mut indptr = vec![0];
605
606        for row in 0..rows {
607            for col in 0..cols {
608                let val = result[[row, col]];
609                if val != T::sparse_zero() {
610                    data.push(val);
611                    indices.push(col);
612                }
613            }
614            indptr.push(data.len());
615        }
616
617        CsrArray::new(
618            Array1::from_vec(data),
619            Array1::from_vec(indices),
620            Array1::from_vec(indptr),
621            self.shape(),
622        )
623        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
624    }
625
626    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
627        // Matrix multiplication
628        // This is a placeholder; a full implementation would be optimized for sparse arrays
629        let (m, n) = self.shape();
630        let (p, q) = other.shape();
631
632        if n != p {
633            return Err(SparseError::DimensionMismatch {
634                expected: n,
635                found: p,
636            });
637        }
638
639        let mut result = Array2::zeros((m, q));
640        let other_array = other.to_array();
641
642        for row in 0..m {
643            let start = self.indptr[row];
644            let end = self.indptr[row + 1];
645
646            for j in 0..q {
647                let mut sum = T::sparse_zero();
648                for idx in start..end {
649                    let col = self.indices[idx];
650                    sum = sum + self.data[idx] * other_array[[col, j]];
651                }
652                if sum != T::sparse_zero() {
653                    result[[row, j]] = sum;
654                }
655            }
656        }
657
658        // Convert result back to CSR
659        let mut data = Vec::new();
660        let mut indices = Vec::new();
661        let mut indptr = vec![0];
662
663        for row in 0..m {
664            for col in 0..q {
665                let val = result[[row, col]];
666                if val != T::sparse_zero() {
667                    data.push(val);
668                    indices.push(col);
669                }
670            }
671            indptr.push(data.len());
672        }
673
674        CsrArray::new(
675            Array1::from_vec(data),
676            Array1::from_vec(indices),
677            Array1::from_vec(indptr),
678            (m, q),
679        )
680        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
681    }
682
683    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
684        let (m, n) = self.shape();
685        if n != other.len() {
686            return Err(SparseError::DimensionMismatch {
687                expected: n,
688                found: other.len(),
689            });
690        }
691
692        let mut result = Array1::zeros(m);
693
694        for row in 0..m {
695            let start = self.indptr[row];
696            let end = self.indptr[row + 1];
697
698            let mut sum = T::sparse_zero();
699            for idx in start..end {
700                let col = self.indices[idx];
701                sum = sum + self.data[idx] * other[col];
702            }
703            result[row] = sum;
704        }
705
706        Ok(result)
707    }
708
709    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
710        // Transpose is non-trivial for CSR format
711        // A full implementation would convert to CSC format or implement
712        // an efficient algorithm
713        let (rows, cols) = self.shape();
714        let mut row_indices = Vec::with_capacity(self.nnz());
715        let mut col_indices = Vec::with_capacity(self.nnz());
716        let mut values = Vec::with_capacity(self.nnz());
717
718        for row in 0..rows {
719            let start = self.indptr[row];
720            let end = self.indptr[row + 1];
721
722            for idx in start..end {
723                let col = self.indices[idx];
724                row_indices.push(col); // Note: rows and cols are swapped for transposition
725                col_indices.push(row);
726                values.push(self.data[idx]);
727            }
728        }
729
730        // We need to create CSR from this "COO" representation
731        CsrArray::from_triplets(
732            &row_indices,
733            &col_indices,
734            &values,
735            (cols, rows), // Swapped dimensions
736            false,
737        )
738        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
739    }
740
741    fn copy(&self) -> Box<dyn SparseArray<T>> {
742        Box::new(self.clone())
743    }
744
745    fn get(&self, i: usize, j: usize) -> T {
746        if i >= self.shape.0 || j >= self.shape.1 {
747            return T::sparse_zero();
748        }
749
750        let start = self.indptr[i];
751        let end = self.indptr[i + 1];
752
753        for idx in start..end {
754            if self.indices[idx] == j {
755                return self.data[idx];
756            }
757            // If indices are sorted, we can break early
758            if self.has_sorted_indices && self.indices[idx] > j {
759                break;
760            }
761        }
762
763        T::sparse_zero()
764    }
765
766    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
767        // Setting elements in CSR format is non-trivial
768        // This is a placeholder implementation that doesn't actually modify the structure
769        if i >= self.shape.0 || j >= self.shape.1 {
770            return Err(SparseError::IndexOutOfBounds {
771                index: (i, j),
772                shape: self.shape,
773            });
774        }
775
776        let start = self.indptr[i];
777        let end = self.indptr[i + 1];
778
779        // Try to find existing element
780        for idx in start..end {
781            if self.indices[idx] == j {
782                // Update value
783                self.data[idx] = value;
784                return Ok(());
785            }
786            // If indices are sorted, we can insert at the right position
787            if self.has_sorted_indices && self.indices[idx] > j {
788                // Insert here - this would require restructuring indices and data
789                // Not implemented in this placeholder
790                return Err(SparseError::NotImplemented(
791                    "Inserting new elements in CSR format".to_string(),
792                ));
793            }
794        }
795
796        // Element not found, would need to insert
797        // This would require restructuring indices and data
798        Err(SparseError::NotImplemented(
799            "Inserting new elements in CSR format".to_string(),
800        ))
801    }
802
803    fn eliminate_zeros(&mut self) {
804        // Find all non-zero entries
805        let mut new_data = Vec::new();
806        let mut new_indices = Vec::new();
807        let mut new_indptr = vec![0];
808
809        let (rows, _) = self.shape();
810
811        for row in 0..rows {
812            let start = self.indptr[row];
813            let end = self.indptr[row + 1];
814
815            for idx in start..end {
816                if !SparseElement::is_zero(&self.data[idx]) {
817                    new_data.push(self.data[idx]);
818                    new_indices.push(self.indices[idx]);
819                }
820            }
821            new_indptr.push(new_data.len());
822        }
823
824        // Replace data with filtered data
825        self.data = Array1::from_vec(new_data);
826        self.indices = Array1::from_vec(new_indices);
827        self.indptr = Array1::from_vec(new_indptr);
828    }
829
830    fn sort_indices(&mut self) {
831        if self.has_sorted_indices {
832            return;
833        }
834
835        let (rows, _) = self.shape();
836
837        for row in 0..rows {
838            let start = self.indptr[row];
839            let end = self.indptr[row + 1];
840
841            if start == end {
842                continue;
843            }
844
845            // Extract the non-zero elements for this row
846            let mut row_data = Vec::with_capacity(end - start);
847            for idx in start..end {
848                row_data.push((self.indices[idx], self.data[idx]));
849            }
850
851            // Sort by column index
852            row_data.sort_by_key(|&(col_, _)| col_);
853
854            // Put the sorted data back
855            for (i, (col, val)) in row_data.into_iter().enumerate() {
856                self.indices[start + i] = col;
857                self.data[start + i] = val;
858            }
859        }
860
861        self.has_sorted_indices = true;
862    }
863
864    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
865        if self.has_sorted_indices {
866            return Box::new(self.clone());
867        }
868
869        let mut sorted = self.clone();
870        sorted.sort_indices();
871        Box::new(sorted)
872    }
873
874    fn has_sorted_indices(&self) -> bool {
875        self.has_sorted_indices
876    }
877
878    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
879        match axis {
880            None => {
881                // Sum all elements
882                let mut sum = T::sparse_zero();
883                for &val in self.data.iter() {
884                    sum = sum + val;
885                }
886                Ok(SparseSum::Scalar(sum))
887            }
888            Some(0) => {
889                // Sum along rows (result is a row vector)
890                let (_, cols) = self.shape();
891                let mut result = vec![T::sparse_zero(); cols];
892
893                for row in 0..self.shape.0 {
894                    let start = self.indptr[row];
895                    let end = self.indptr[row + 1];
896
897                    for idx in start..end {
898                        let col = self.indices[idx];
899                        result[col] = result[col] + self.data[idx];
900                    }
901                }
902
903                // Convert to CSR format
904                let mut data = Vec::new();
905                let mut indices = Vec::new();
906                let mut indptr = vec![0];
907
908                for (col, &val) in result.iter().enumerate() {
909                    if val != T::sparse_zero() {
910                        data.push(val);
911                        indices.push(col);
912                    }
913                }
914                indptr.push(data.len());
915
916                let result_array = CsrArray::new(
917                    Array1::from_vec(data),
918                    Array1::from_vec(indices),
919                    Array1::from_vec(indptr),
920                    (1, cols),
921                )?;
922
923                Ok(SparseSum::SparseArray(Box::new(result_array)))
924            }
925            Some(1) => {
926                // Sum along columns (result is a column vector)
927                let mut result = Vec::with_capacity(self.shape.0);
928
929                for row in 0..self.shape.0 {
930                    let start = self.indptr[row];
931                    let end = self.indptr[row + 1];
932
933                    let mut row_sum = T::sparse_zero();
934                    for idx in start..end {
935                        row_sum = row_sum + self.data[idx];
936                    }
937                    result.push(row_sum);
938                }
939
940                // Convert to CSR format
941                let mut data = Vec::new();
942                let mut indices = Vec::new();
943                let mut indptr = vec![0];
944
945                for &val in result.iter() {
946                    if val != T::sparse_zero() {
947                        data.push(val);
948                        indices.push(0);
949                        indptr.push(data.len());
950                    } else {
951                        indptr.push(data.len());
952                    }
953                }
954
955                let result_array = CsrArray::new(
956                    Array1::from_vec(data),
957                    Array1::from_vec(indices),
958                    Array1::from_vec(indptr),
959                    (self.shape.0, 1),
960                )?;
961
962                Ok(SparseSum::SparseArray(Box::new(result_array)))
963            }
964            _ => Err(SparseError::InvalidAxis),
965        }
966    }
967
968    fn max(&self) -> T {
969        if self.data.is_empty() {
970            // Empty sparse matrix - all elements are implicitly zero
971            return T::sparse_zero();
972        }
973
974        let mut max_val = self.data[0];
975        for &val in self.data.iter().skip(1) {
976            if val > max_val {
977                max_val = val;
978            }
979        }
980
981        // Check if max_val is less than zero, as zeros aren't explicitly stored
982        // If the matrix has implicit zeros and max_val < 0, then max is 0
983        let zero = T::sparse_zero();
984        if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
985            max_val = zero;
986        }
987
988        max_val
989    }
990
991    fn min(&self) -> T {
992        if self.data.is_empty() {
993            // Empty sparse matrix - all elements are implicitly zero
994            return T::sparse_zero();
995        }
996
997        let mut min_val = self.data[0];
998        for &val in self.data.iter().skip(1) {
999            if val < min_val {
1000                min_val = val;
1001            }
1002        }
1003
1004        // Check if min_val is greater than zero, as zeros aren't explicitly stored
1005        // If the matrix has implicit zeros and min_val > 0, then min is 0
1006        let zero = T::sparse_zero();
1007        if min_val > zero && self.nnz() < self.shape.0 * self.shape.1 {
1008            min_val = zero;
1009        }
1010
1011        min_val
1012    }
1013
1014    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1015        let nnz = self.nnz();
1016        let mut rows = Vec::with_capacity(nnz);
1017        let mut cols = Vec::with_capacity(nnz);
1018        let mut values = Vec::with_capacity(nnz);
1019
1020        for row in 0..self.shape.0 {
1021            let start = self.indptr[row];
1022            let end = self.indptr[row + 1];
1023
1024            for idx in start..end {
1025                let col = self.indices[idx];
1026                rows.push(row);
1027                cols.push(col);
1028                values.push(self.data[idx]);
1029            }
1030        }
1031
1032        (
1033            Array1::from_vec(rows),
1034            Array1::from_vec(cols),
1035            Array1::from_vec(values),
1036        )
1037    }
1038
1039    fn slice(
1040        &self,
1041        row_range: (usize, usize),
1042        col_range: (usize, usize),
1043    ) -> SparseResult<Box<dyn SparseArray<T>>> {
1044        let (start_row, end_row) = row_range;
1045        let (start_col, end_col) = col_range;
1046
1047        if start_row >= self.shape.0
1048            || end_row > self.shape.0
1049            || start_col >= self.shape.1
1050            || end_col > self.shape.1
1051        {
1052            return Err(SparseError::InvalidSliceRange);
1053        }
1054
1055        if start_row >= end_row || start_col >= end_col {
1056            return Err(SparseError::InvalidSliceRange);
1057        }
1058
1059        let mut data = Vec::new();
1060        let mut indices = Vec::new();
1061        let mut indptr = vec![0];
1062
1063        for row in start_row..end_row {
1064            let start = self.indptr[row];
1065            let end = self.indptr[row + 1];
1066
1067            for idx in start..end {
1068                let col = self.indices[idx];
1069                if col >= start_col && col < end_col {
1070                    data.push(self.data[idx]);
1071                    indices.push(col - start_col);
1072                }
1073            }
1074            indptr.push(data.len());
1075        }
1076
1077        CsrArray::new(
1078            Array1::from_vec(data),
1079            Array1::from_vec(indices),
1080            Array1::from_vec(indptr),
1081            (end_row - start_row, end_col - start_col),
1082        )
1083        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1084    }
1085
1086    fn as_any(&self) -> &dyn std::any::Any {
1087        self
1088    }
1089
1090    fn get_indptr(&self) -> Option<&Array1<usize>> {
1091        Some(&self.indptr)
1092    }
1093
1094    fn indptr(&self) -> Option<&Array1<usize>> {
1095        Some(&self.indptr)
1096    }
1097}
1098
1099impl<T> fmt::Debug for CsrArray<T>
1100where
1101    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
1102{
1103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1104        write!(
1105            f,
1106            "CsrArray<{}x{}, nnz={}>",
1107            self.shape.0,
1108            self.shape.1,
1109            self.nnz()
1110        )
1111    }
1112}
1113
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Shared 3x3 fixture:
    // [[1, 0, 2],
    //  [0, 3, 0],
    //  [4, 0, 5]]
    fn sample_csr() -> CsrArray<f64> {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        CsrArray::new(data, indices, indptr, (3, 3)).unwrap()
    }

    #[test]
    fn test_csr_array_construction() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        let shape = (3, 3);

        let csr = CsrArray::new(data, indices, indptr, shape).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_array_to_array() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        let shape = (3, 3);

        let csr = CsrArray::new(data, indices, indptr, shape).unwrap();
        let dense = csr.to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 1]], 0.0);
        assert_eq!(dense[[0, 2]], 2.0);
        assert_eq!(dense[[1, 0]], 0.0);
        assert_eq!(dense[[1, 1]], 3.0);
        assert_eq!(dense[[1, 2]], 0.0);
        assert_eq!(dense[[2, 0]], 4.0);
        assert_eq!(dense[[2, 1]], 0.0);
        assert_eq!(dense[[2, 2]], 5.0);
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        let shape = (3, 3);

        let csr = CsrArray::new(data, indices, indptr, shape).unwrap();
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csr.dot_vector(&vec.view()).unwrap();

        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3] = [7, 6, 19]
        assert_eq!(result.len(), 3);
        assert_relative_eq!(result[0], 7.0);
        assert_relative_eq!(result[1], 6.0);
        assert_relative_eq!(result[2], 19.0);
    }

    #[test]
    fn test_csr_array_sum() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        let shape = (3, 3);

        let csr = CsrArray::new(data, indices, indptr, shape).unwrap();

        // Sum all elements
        if let SparseSum::Scalar(sum) = csr.sum(None).unwrap() {
            assert_relative_eq!(sum, 15.0);
        } else {
            panic!("Expected scalar sum");
        }

        // Sum along rows
        if let SparseSum::SparseArray(row_sum) = csr.sum(Some(0)).unwrap() {
            let row_sum_array = row_sum.to_array();
            assert_eq!(row_sum_array.shape(), &[1, 3]);
            assert_relative_eq!(row_sum_array[[0, 0]], 5.0);
            assert_relative_eq!(row_sum_array[[0, 1]], 3.0);
            assert_relative_eq!(row_sum_array[[0, 2]], 7.0);
        } else {
            panic!("Expected sparse array sum");
        }

        // Sum along columns
        if let SparseSum::SparseArray(col_sum) = csr.sum(Some(1)).unwrap() {
            let col_sum_array = col_sum.to_array();
            assert_eq!(col_sum_array.shape(), &[3, 1]);
            assert_relative_eq!(col_sum_array[[0, 0]], 3.0);
            assert_relative_eq!(col_sum_array[[1, 0]], 3.0);
            assert_relative_eq!(col_sum_array[[2, 0]], 9.0);
        } else {
            panic!("Expected sparse array sum");
        }
    }

    #[test]
    fn test_csr_array_max_min() {
        let csr = sample_csr();

        // Largest stored value wins; the fixture has implicit zeros, so min is 0.
        assert_eq!(csr.max(), 5.0);
        assert_eq!(csr.min(), 0.0);
    }

    #[test]
    fn test_csr_array_max_min_negative_and_empty() {
        // [[-1, -2],
        //  [ 0,  0]] -> implicit zeros make the maximum 0.
        let partial = CsrArray::new(
            Array1::from_vec(vec![-1.0, -2.0]),
            Array1::from_vec(vec![0, 1]),
            Array1::from_vec(vec![0, 2, 2]),
            (2, 2),
        )
        .unwrap();
        assert_eq!(partial.max(), 0.0);
        assert_eq!(partial.min(), -2.0);

        // Fully populated: no implicit zeros, so the stored extremes are returned.
        let full = CsrArray::new(
            Array1::from_vec(vec![-3.0, -1.0]),
            Array1::from_vec(vec![0, 1]),
            Array1::from_vec(vec![0, 2]),
            (1, 2),
        )
        .unwrap();
        assert_eq!(full.max(), -1.0);
        assert_eq!(full.min(), -3.0);

        // No stored values at all: every element is an implicit zero.
        let empty = CsrArray::new(
            Array1::from_vec(Vec::<f64>::new()),
            Array1::from_vec(Vec::<usize>::new()),
            Array1::from_vec(vec![0, 0]),
            (1, 1),
        )
        .unwrap();
        assert_eq!(empty.max(), 0.0);
        assert_eq!(empty.min(), 0.0);
    }

    #[test]
    fn test_csr_array_find() {
        let csr = sample_csr();
        let (rows, cols, vals) = csr.find();

        assert_eq!(rows.to_vec(), vec![0, 0, 1, 2, 2]);
        assert_eq!(cols.to_vec(), vec![0, 2, 1, 0, 2]);
        assert_eq!(vals.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_csr_array_slice() {
        let csr = sample_csr();

        // Rows 1..3, cols 0..2 -> [[0, 3], [4, 0]]
        let sub = csr.slice((1, 3), (0, 2)).unwrap();
        assert_eq!(sub.shape(), (2, 2));
        assert_eq!(sub.nnz(), 2);

        let dense = sub.to_array();
        assert_eq!(dense[[0, 0]], 0.0);
        assert_eq!(dense[[0, 1]], 3.0);
        assert_eq!(dense[[1, 0]], 4.0);
        assert_eq!(dense[[1, 1]], 0.0);
    }

    #[test]
    fn test_csr_array_slice_invalid_ranges() {
        let csr = sample_csr();

        // Empty row range.
        assert!(csr.slice((2, 2), (0, 3)).is_err());
        // Reversed column range.
        assert!(csr.slice((0, 3), (2, 1)).is_err());
        // Out-of-bounds end.
        assert!(csr.slice((0, 4), (0, 3)).is_err());
        assert!(csr.slice((0, 3), (0, 4)).is_err());
    }

    #[test]
    fn test_csr_array_debug_format() {
        let csr = sample_csr();
        assert_eq!(format!("{:?}", csr), "CsrArray<3x3, nnz=5>");
    }
}