scirs2_sparse/
csr_array.rs

1// CSR Array implementation
2//
3// This module provides the CSR (Compressed Sparse Row) array format,
4// which is efficient for row-wise operations and is one of the most common formats.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14/// CSR Array format - Compressed Sparse Row matrix representation
15///
16/// The CSR (Compressed Sparse Row) format is one of the most popular sparse matrix formats,
17/// storing a sparse array using three arrays:
18/// - `data`: array of non-zero values in row-major order
19/// - `indices`: column indices of the non-zero values
20/// - `indptr`: row pointers; `indptr[i]` is the index into `data`/`indices` where row `i` starts
21///
22/// The CSR format is particularly efficient for:
23/// - ✅ Matrix-vector multiplications (`A * x`)
24/// - ✅ Matrix-matrix multiplications with other sparse matrices
25/// - ✅ Row-wise operations and row slicing
26/// - ✅ Iterating over non-zero elements row by row
27/// - ✅ Adding and subtracting sparse matrices
28///
29/// But less efficient for:
30/// - ❌ Column-wise operations and column slicing
31/// - ❌ Inserting or modifying individual elements after construction
32/// - ❌ Operations that require column access patterns
33///
34/// # Memory Layout
35///
36/// For a matrix with `m` rows, `n` columns, and `nnz` non-zero elements:
37/// - `data`: length `nnz` - stores the actual non-zero values
38/// - `indices`: length `nnz` - stores column indices for each non-zero value
39/// - `indptr`: length `m+1` - stores cumulative count of non-zeros per row
40///
41/// # Examples
42///
43/// ## Basic Construction and Access
44/// ```
45/// use scirs2_sparse::csr_array::CsrArray;
46///
47/// // Create a 3x3 matrix:
48/// // [1.0, 0.0, 2.0]
49/// // [0.0, 3.0, 0.0]
50/// // [4.0, 0.0, 5.0]
51/// let rows = vec![0, 0, 1, 2, 2];
52/// let cols = vec![0, 2, 1, 0, 2];
53/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
54/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
55///
56/// // Access elements
57/// assert_eq!(matrix.get(0, 0), 1.0);
58/// assert_eq!(matrix.get(0, 1), 0.0);  // Zero element
59/// assert_eq!(matrix.get(1, 1), 3.0);
60///
61/// // Get matrix properties
62/// assert_eq!(matrix.shape(), (3, 3));
63/// assert_eq!(matrix.nnz(), 5);
64/// ```
65///
66/// ## Matrix Operations
67/// ```
68/// use scirs2_sparse::csr_array::CsrArray;
69/// use ndarray::Array1;
70///
71/// let rows = vec![0, 1, 2];
72/// let cols = vec![0, 1, 2];
73/// let data = vec![2.0, 3.0, 4.0];
74/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
75///
76/// // Matrix-vector multiplication
77/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
78/// let y = matrix.dot(&x).unwrap();
79/// assert_eq!(y[0], 2.0);  // 2.0 * 1.0
80/// assert_eq!(y[1], 6.0);  // 3.0 * 2.0
81/// assert_eq!(y[2], 12.0); // 4.0 * 3.0
82/// ```
83///
84/// ## Format Conversion
85/// ```
86/// use scirs2_sparse::csr_array::CsrArray;
87///
88/// let rows = vec![0, 1];
89/// let cols = vec![0, 1];
90/// let data = vec![1.0, 2.0];
91/// let csr = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
92///
93/// // Convert to dense array
94/// let dense = csr.to_array();
95/// assert_eq!(dense[[0, 0]], 1.0);
96/// assert_eq!(dense[[1, 1]], 2.0);
97///
98/// // Convert to other sparse formats
99/// let coo = csr.to_coo();
100/// let csc = csr.to_csc();
101/// ```
102#[derive(Clone)]
103pub struct CsrArray<T>
104where
105    T: Float
106        + Add<Output = T>
107        + Sub<Output = T>
108        + Mul<Output = T>
109        + Div<Output = T>
110        + Debug
111        + Copy
112        + 'static,
113{
114    /// Non-zero values
115    data: Array1<T>,
116    /// Column indices of non-zero values
117    indices: Array1<usize>,
118    /// Row pointers (indices into data/indices for the start of each row)
119    indptr: Array1<usize>,
120    /// Shape of the sparse array
121    shape: (usize, usize),
122    /// Whether indices are sorted for each row
123    has_sorted_indices: bool,
124}
125
126impl<T> CsrArray<T>
127where
128    T: Float
129        + Add<Output = T>
130        + Sub<Output = T>
131        + Mul<Output = T>
132        + Div<Output = T>
133        + Debug
134        + Copy
135        + 'static,
136{
137    /// Creates a new CSR array from raw components
138    ///
139    /// # Arguments
140    /// * `data` - Array of non-zero values
141    /// * `indices` - Column indices of non-zero values
142    /// * `indptr` - Index pointers for the start of each row
143    /// * `shape` - Shape of the sparse array
144    ///
145    /// # Returns
146    /// A new `CsrArray`
147    ///
148    /// # Errors
149    /// Returns an error if the data is not consistent
150    pub fn new(
151        data: Array1<T>,
152        indices: Array1<usize>,
153        indptr: Array1<usize>,
154        shape: (usize, usize),
155    ) -> SparseResult<Self> {
156        // Validation
157        if data.len() != indices.len() {
158            return Err(SparseError::InconsistentData {
159                reason: "data and indices must have the same length".to_string(),
160            });
161        }
162
163        if indptr.len() != shape.0 + 1 {
164            return Err(SparseError::InconsistentData {
165                reason: format!(
166                    "indptr length ({}) must be one more than the number of rows ({})",
167                    indptr.len(),
168                    shape.0
169                ),
170            });
171        }
172
173        if let Some(&max_idx) = indices.iter().max() {
174            if max_idx >= shape.1 {
175                return Err(SparseError::IndexOutOfBounds {
176                    index: (0, max_idx),
177                    shape,
178                });
179            }
180        }
181
182        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
183            if first != 0 {
184                return Err(SparseError::InconsistentData {
185                    reason: "first element of indptr must be 0".to_string(),
186                });
187            }
188
189            if last != data.len() {
190                return Err(SparseError::InconsistentData {
191                    reason: format!(
192                        "last element of indptr ({}) must equal data length ({})",
193                        last,
194                        data.len()
195                    ),
196                });
197            }
198        }
199
200        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
201
202        Ok(Self {
203            data,
204            indices,
205            indptr,
206            shape,
207            has_sorted_indices,
208        })
209    }
210
211    /// Create a CSR array from triplet format (COO-like)
212    ///
213    /// This function creates a CSR (Compressed Sparse Row) array from coordinate triplets.
214    /// The triplets represent non-zero elements as (row, column, value) tuples.
215    ///
216    /// # Arguments
217    /// * `rows` - Row indices of non-zero elements
218    /// * `cols` - Column indices of non-zero elements  
219    /// * `data` - Values of non-zero elements
220    /// * `shape` - Shape of the sparse array (nrows, ncols)
221    /// * `sorted` - Whether the triplets are already sorted by (row, col). If false, sorting will be performed.
222    ///
223    /// # Returns
224    /// A new `CsrArray` containing the sparse matrix
225    ///
226    /// # Errors
227    /// Returns an error if:
228    /// - `rows`, `cols`, and `data` have different lengths
229    /// - Any index is out of bounds for the given shape
230    /// - The resulting data structure is inconsistent
231    ///
232    /// # Examples
233    ///
234    /// Create a simple 3x3 sparse matrix:
235    /// ```
236    /// use scirs2_sparse::csr_array::CsrArray;
237    ///
238    /// // Create a 3x3 matrix with the following structure:
239    /// // [1.0, 0.0, 2.0]
240    /// // [0.0, 3.0, 0.0]
241    /// // [4.0, 0.0, 5.0]
242    /// let rows = vec![0, 0, 1, 2, 2];
243    /// let cols = vec![0, 2, 1, 0, 2];
244    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
245    /// let shape = (3, 3);
246    ///
247    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
248    /// assert_eq!(matrix.get(0, 0), 1.0);
249    /// assert_eq!(matrix.get(0, 1), 0.0);
250    /// assert_eq!(matrix.get(1, 1), 3.0);
251    /// ```
252    ///
253    /// Create an empty sparse matrix:
254    /// ```
255    /// use scirs2_sparse::csr_array::CsrArray;
256    ///
257    /// let rows: Vec<usize> = vec![];
258    /// let cols: Vec<usize> = vec![];
259    /// let data: Vec<f64> = vec![];
260    /// let shape = (5, 5);
261    ///
262    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
263    /// assert_eq!(matrix.nnz(), 0);
264    /// assert_eq!(matrix.shape(), (5, 5));
265    /// ```
266    ///
267    /// Handle duplicate entries (they will be preserved):
268    /// ```
269    /// use scirs2_sparse::csr_array::CsrArray;
270    ///
271    /// // Multiple entries at the same position
272    /// let rows = vec![0, 0];
273    /// let cols = vec![0, 0];
274    /// let data = vec![1.0, 2.0];
275    /// let shape = (2, 2);
276    ///
277    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();
278    /// // Note: CSR format preserves duplicates; use sum_duplicates() to combine them
279    /// assert_eq!(matrix.nnz(), 2);
280    /// ```
281    pub fn from_triplets(
282        rows: &[usize],
283        cols: &[usize],
284        data: &[T],
285        shape: (usize, usize),
286        sorted: bool,
287    ) -> SparseResult<Self> {
288        if rows.len() != cols.len() || rows.len() != data.len() {
289            return Err(SparseError::InconsistentData {
290                reason: "rows, cols, and data must have the same length".to_string(),
291            });
292        }
293
294        if rows.is_empty() {
295            // Empty matrix
296            let indptr = Array1::zeros(shape.0 + 1);
297            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
298        }
299
300        let nnz = rows.len();
301        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
302
303        for i in 0..nnz {
304            if rows[i] >= shape.0 || cols[i] >= shape.1 {
305                return Err(SparseError::IndexOutOfBounds {
306                    index: (rows[i], cols[i]),
307                    shape,
308                });
309            }
310            all_data.push((rows[i], cols[i], data[i]));
311        }
312
313        if !sorted {
314            all_data.sort_by_key(|&(row, col_, _)| (row, col_));
315        }
316
317        // Count elements per row
318        let mut row_counts = vec![0; shape.0];
319        for &(row_, _, _) in &all_data {
320            row_counts[row_] += 1;
321        }
322
323        // Create indptr
324        let mut indptr = Vec::with_capacity(shape.0 + 1);
325        indptr.push(0);
326        let mut cumsum = 0;
327        for &count in &row_counts {
328            cumsum += count;
329            indptr.push(cumsum);
330        }
331
332        // Create indices and data arrays
333        let mut indices = Vec::with_capacity(nnz);
334        let mut values = Vec::with_capacity(nnz);
335
336        for (_, col, val) in all_data {
337            indices.push(col);
338            values.push(val);
339        }
340
341        Self::new(
342            Array1::from_vec(values),
343            Array1::from_vec(indices),
344            Array1::from_vec(indptr),
345            shape,
346        )
347    }
348
349    /// Checks if column indices are sorted for each row
350    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
351        for row in 0..indptr.len() - 1 {
352            let start = indptr[row];
353            let end = indptr[row + 1];
354
355            for i in start..end.saturating_sub(1) {
356                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
357                    return false;
358                }
359            }
360        }
361        true
362    }
363
364    /// Get the raw data array
365    pub fn get_data(&self) -> &Array1<T> {
366        &self.data
367    }
368
369    /// Get the raw indices array
370    pub fn get_indices(&self) -> &Array1<usize> {
371        &self.indices
372    }
373
374    /// Get the raw indptr array
375    pub fn get_indptr(&self) -> &Array1<usize> {
376        &self.indptr
377    }
378
379    /// Get the number of rows
380    pub fn nrows(&self) -> usize {
381        self.shape.0
382    }
383
384    /// Get the number of columns  
385    pub fn ncols(&self) -> usize {
386        self.shape.1
387    }
388
389    /// Get the shape (rows, cols)
390    pub fn shape(&self) -> (usize, usize) {
391        self.shape
392    }
393}
394
395impl<T> SparseArray<T> for CsrArray<T>
396where
397    T: Float
398        + Add<Output = T>
399        + Sub<Output = T>
400        + Mul<Output = T>
401        + Div<Output = T>
402        + Debug
403        + Copy
404        + 'static,
405{
406    fn shape(&self) -> (usize, usize) {
407        self.shape
408    }
409
410    fn nnz(&self) -> usize {
411        self.data.len()
412    }
413
414    fn dtype(&self) -> &str {
415        "float" // Placeholder, ideally we would return the actual type
416    }
417
418    fn to_array(&self) -> Array2<T> {
419        let (rows, cols) = self.shape;
420        let mut result = Array2::zeros((rows, cols));
421
422        for row in 0..rows {
423            let start = self.indptr[row];
424            let end = self.indptr[row + 1];
425
426            for i in start..end {
427                let col = self.indices[i];
428                result[[row, col]] = self.data[i];
429            }
430        }
431
432        result
433    }
434
435    fn toarray(&self) -> Array2<T> {
436        self.to_array()
437    }
438
439    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
440        // This would convert to COO format
441        // For now we just return self
442        Ok(Box::new(self.clone()))
443    }
444
445    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
446        Ok(Box::new(self.clone()))
447    }
448
449    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
450        // This would convert to CSC format
451        // For now we just return self
452        Ok(Box::new(self.clone()))
453    }
454
455    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
456        // This would convert to DOK format
457        // For now we just return self
458        Ok(Box::new(self.clone()))
459    }
460
461    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
462        // This would convert to LIL format
463        // For now we just return self
464        Ok(Box::new(self.clone()))
465    }
466
467    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
468        // This would convert to DIA format
469        // For now we just return self
470        Ok(Box::new(self.clone()))
471    }
472
473    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
474        // This would convert to BSR format
475        // For now we just return self
476        Ok(Box::new(self.clone()))
477    }
478
479    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
480        // Placeholder implementation - a full implementation would handle direct
481        // sparse matrix addition more efficiently
482        let self_array = self.to_array();
483        let other_array = other.to_array();
484
485        if self.shape() != other.shape() {
486            return Err(SparseError::DimensionMismatch {
487                expected: self.shape().0,
488                found: other.shape().0,
489            });
490        }
491
492        let result = &self_array + &other_array;
493
494        // Convert back to CSR
495        let (rows, cols) = self.shape();
496        let mut data = Vec::new();
497        let mut indices = Vec::new();
498        let mut indptr = vec![0];
499
500        for row in 0..rows {
501            for col in 0..cols {
502                let val = result[[row, col]];
503                if !val.is_zero() {
504                    data.push(val);
505                    indices.push(col);
506                }
507            }
508            indptr.push(data.len());
509        }
510
511        CsrArray::new(
512            Array1::from_vec(data),
513            Array1::from_vec(indices),
514            Array1::from_vec(indptr),
515            self.shape(),
516        )
517        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
518    }
519
520    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
521        // Similar to add, this is a placeholder
522        let self_array = self.to_array();
523        let other_array = other.to_array();
524
525        if self.shape() != other.shape() {
526            return Err(SparseError::DimensionMismatch {
527                expected: self.shape().0,
528                found: other.shape().0,
529            });
530        }
531
532        let result = &self_array - &other_array;
533
534        // Convert back to CSR
535        let (rows, cols) = self.shape();
536        let mut data = Vec::new();
537        let mut indices = Vec::new();
538        let mut indptr = vec![0];
539
540        for row in 0..rows {
541            for col in 0..cols {
542                let val = result[[row, col]];
543                if !val.is_zero() {
544                    data.push(val);
545                    indices.push(col);
546                }
547            }
548            indptr.push(data.len());
549        }
550
551        CsrArray::new(
552            Array1::from_vec(data),
553            Array1::from_vec(indices),
554            Array1::from_vec(indptr),
555            self.shape(),
556        )
557        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
558    }
559
560    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
561        // This is element-wise multiplication (Hadamard product)
562        // In the sparse array API, * is element-wise, not matrix multiplication
563        let self_array = self.to_array();
564        let other_array = other.to_array();
565
566        if self.shape() != other.shape() {
567            return Err(SparseError::DimensionMismatch {
568                expected: self.shape().0,
569                found: other.shape().0,
570            });
571        }
572
573        let result = &self_array * &other_array;
574
575        // Convert back to CSR
576        let (rows, cols) = self.shape();
577        let mut data = Vec::new();
578        let mut indices = Vec::new();
579        let mut indptr = vec![0];
580
581        for row in 0..rows {
582            for col in 0..cols {
583                let val = result[[row, col]];
584                if !val.is_zero() {
585                    data.push(val);
586                    indices.push(col);
587                }
588            }
589            indptr.push(data.len());
590        }
591
592        CsrArray::new(
593            Array1::from_vec(data),
594            Array1::from_vec(indices),
595            Array1::from_vec(indptr),
596            self.shape(),
597        )
598        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
599    }
600
601    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
602        // Element-wise division
603        let self_array = self.to_array();
604        let other_array = other.to_array();
605
606        if self.shape() != other.shape() {
607            return Err(SparseError::DimensionMismatch {
608                expected: self.shape().0,
609                found: other.shape().0,
610            });
611        }
612
613        let result = &self_array / &other_array;
614
615        // Convert back to CSR
616        let (rows, cols) = self.shape();
617        let mut data = Vec::new();
618        let mut indices = Vec::new();
619        let mut indptr = vec![0];
620
621        for row in 0..rows {
622            for col in 0..cols {
623                let val = result[[row, col]];
624                if !val.is_zero() {
625                    data.push(val);
626                    indices.push(col);
627                }
628            }
629            indptr.push(data.len());
630        }
631
632        CsrArray::new(
633            Array1::from_vec(data),
634            Array1::from_vec(indices),
635            Array1::from_vec(indptr),
636            self.shape(),
637        )
638        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
639    }
640
641    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
642        // Matrix multiplication
643        // This is a placeholder; a full implementation would be optimized for sparse arrays
644        let (m, n) = self.shape();
645        let (p, q) = other.shape();
646
647        if n != p {
648            return Err(SparseError::DimensionMismatch {
649                expected: n,
650                found: p,
651            });
652        }
653
654        let mut result = Array2::zeros((m, q));
655        let other_array = other.to_array();
656
657        for row in 0..m {
658            let start = self.indptr[row];
659            let end = self.indptr[row + 1];
660
661            for j in 0..q {
662                let mut sum = T::zero();
663                for idx in start..end {
664                    let col = self.indices[idx];
665                    sum = sum + self.data[idx] * other_array[[col, j]];
666                }
667                if !sum.is_zero() {
668                    result[[row, j]] = sum;
669                }
670            }
671        }
672
673        // Convert result back to CSR
674        let mut data = Vec::new();
675        let mut indices = Vec::new();
676        let mut indptr = vec![0];
677
678        for row in 0..m {
679            for col in 0..q {
680                let val = result[[row, col]];
681                if !val.is_zero() {
682                    data.push(val);
683                    indices.push(col);
684                }
685            }
686            indptr.push(data.len());
687        }
688
689        CsrArray::new(
690            Array1::from_vec(data),
691            Array1::from_vec(indices),
692            Array1::from_vec(indptr),
693            (m, q),
694        )
695        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
696    }
697
698    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
699        let (m, n) = self.shape();
700        if n != other.len() {
701            return Err(SparseError::DimensionMismatch {
702                expected: n,
703                found: other.len(),
704            });
705        }
706
707        let mut result = Array1::zeros(m);
708
709        for row in 0..m {
710            let start = self.indptr[row];
711            let end = self.indptr[row + 1];
712
713            let mut sum = T::zero();
714            for idx in start..end {
715                let col = self.indices[idx];
716                sum = sum + self.data[idx] * other[col];
717            }
718            result[row] = sum;
719        }
720
721        Ok(result)
722    }
723
724    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
725        // Transpose is non-trivial for CSR format
726        // A full implementation would convert to CSC format or implement
727        // an efficient algorithm
728        let (rows, cols) = self.shape();
729        let mut row_indices = Vec::with_capacity(self.nnz());
730        let mut col_indices = Vec::with_capacity(self.nnz());
731        let mut values = Vec::with_capacity(self.nnz());
732
733        for row in 0..rows {
734            let start = self.indptr[row];
735            let end = self.indptr[row + 1];
736
737            for idx in start..end {
738                let col = self.indices[idx];
739                row_indices.push(col); // Note: rows and cols are swapped for transposition
740                col_indices.push(row);
741                values.push(self.data[idx]);
742            }
743        }
744
745        // We need to create CSR from this "COO" representation
746        CsrArray::from_triplets(
747            &row_indices,
748            &col_indices,
749            &values,
750            (cols, rows), // Swapped dimensions
751            false,
752        )
753        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
754    }
755
756    fn copy(&self) -> Box<dyn SparseArray<T>> {
757        Box::new(self.clone())
758    }
759
760    fn get(&self, i: usize, j: usize) -> T {
761        if i >= self.shape.0 || j >= self.shape.1 {
762            return T::zero();
763        }
764
765        let start = self.indptr[i];
766        let end = self.indptr[i + 1];
767
768        for idx in start..end {
769            if self.indices[idx] == j {
770                return self.data[idx];
771            }
772            // If indices are sorted, we can break early
773            if self.has_sorted_indices && self.indices[idx] > j {
774                break;
775            }
776        }
777
778        T::zero()
779    }
780
781    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
782        // Setting elements in CSR format is non-trivial
783        // This is a placeholder implementation that doesn't actually modify the structure
784        if i >= self.shape.0 || j >= self.shape.1 {
785            return Err(SparseError::IndexOutOfBounds {
786                index: (i, j),
787                shape: self.shape,
788            });
789        }
790
791        let start = self.indptr[i];
792        let end = self.indptr[i + 1];
793
794        // Try to find existing element
795        for idx in start..end {
796            if self.indices[idx] == j {
797                // Update value
798                self.data[idx] = value;
799                return Ok(());
800            }
801            // If indices are sorted, we can insert at the right position
802            if self.has_sorted_indices && self.indices[idx] > j {
803                // Insert here - this would require restructuring indices and data
804                // Not implemented in this placeholder
805                return Err(SparseError::NotImplemented(
806                    "Inserting new elements in CSR format".to_string(),
807                ));
808            }
809        }
810
811        // Element not found, would need to insert
812        // This would require restructuring indices and data
813        Err(SparseError::NotImplemented(
814            "Inserting new elements in CSR format".to_string(),
815        ))
816    }
817
818    fn eliminate_zeros(&mut self) {
819        // Find all non-zero entries
820        let mut new_data = Vec::new();
821        let mut new_indices = Vec::new();
822        let mut new_indptr = vec![0];
823
824        let (rows, _) = self.shape();
825
826        for row in 0..rows {
827            let start = self.indptr[row];
828            let end = self.indptr[row + 1];
829
830            for idx in start..end {
831                if !self.data[idx].is_zero() {
832                    new_data.push(self.data[idx]);
833                    new_indices.push(self.indices[idx]);
834                }
835            }
836            new_indptr.push(new_data.len());
837        }
838
839        // Replace data with filtered data
840        self.data = Array1::from_vec(new_data);
841        self.indices = Array1::from_vec(new_indices);
842        self.indptr = Array1::from_vec(new_indptr);
843    }
844
845    fn sort_indices(&mut self) {
846        if self.has_sorted_indices {
847            return;
848        }
849
850        let (rows, _) = self.shape();
851
852        for row in 0..rows {
853            let start = self.indptr[row];
854            let end = self.indptr[row + 1];
855
856            if start == end {
857                continue;
858            }
859
860            // Extract the non-zero elements for this row
861            let mut row_data = Vec::with_capacity(end - start);
862            for idx in start..end {
863                row_data.push((self.indices[idx], self.data[idx]));
864            }
865
866            // Sort by column index
867            row_data.sort_by_key(|&(col_, _)| col_);
868
869            // Put the sorted data back
870            for (i, (col, val)) in row_data.into_iter().enumerate() {
871                self.indices[start + i] = col;
872                self.data[start + i] = val;
873            }
874        }
875
876        self.has_sorted_indices = true;
877    }
878
879    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
880        if self.has_sorted_indices {
881            return Box::new(self.clone());
882        }
883
884        let mut sorted = self.clone();
885        sorted.sort_indices();
886        Box::new(sorted)
887    }
888
889    fn has_sorted_indices(&self) -> bool {
890        self.has_sorted_indices
891    }
892
893    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
894        match axis {
895            None => {
896                // Sum all elements
897                let mut sum = T::zero();
898                for &val in self.data.iter() {
899                    sum = sum + val;
900                }
901                Ok(SparseSum::Scalar(sum))
902            }
903            Some(0) => {
904                // Sum along rows (result is a row vector)
905                let (_, cols) = self.shape();
906                let mut result = vec![T::zero(); cols];
907
908                for row in 0..self.shape.0 {
909                    let start = self.indptr[row];
910                    let end = self.indptr[row + 1];
911
912                    for idx in start..end {
913                        let col = self.indices[idx];
914                        result[col] = result[col] + self.data[idx];
915                    }
916                }
917
918                // Convert to CSR format
919                let mut data = Vec::new();
920                let mut indices = Vec::new();
921                let mut indptr = vec![0];
922
923                for (col, &val) in result.iter().enumerate() {
924                    if !val.is_zero() {
925                        data.push(val);
926                        indices.push(col);
927                    }
928                }
929                indptr.push(data.len());
930
931                let result_array = CsrArray::new(
932                    Array1::from_vec(data),
933                    Array1::from_vec(indices),
934                    Array1::from_vec(indptr),
935                    (1, cols),
936                )?;
937
938                Ok(SparseSum::SparseArray(Box::new(result_array)))
939            }
940            Some(1) => {
941                // Sum along columns (result is a column vector)
942                let mut result = Vec::with_capacity(self.shape.0);
943
944                for row in 0..self.shape.0 {
945                    let start = self.indptr[row];
946                    let end = self.indptr[row + 1];
947
948                    let mut row_sum = T::zero();
949                    for idx in start..end {
950                        row_sum = row_sum + self.data[idx];
951                    }
952                    result.push(row_sum);
953                }
954
955                // Convert to CSR format
956                let mut data = Vec::new();
957                let mut indices = Vec::new();
958                let mut indptr = vec![0];
959
960                for &val in result.iter() {
961                    if !val.is_zero() {
962                        data.push(val);
963                        indices.push(0);
964                        indptr.push(data.len());
965                    } else {
966                        indptr.push(data.len());
967                    }
968                }
969
970                let result_array = CsrArray::new(
971                    Array1::from_vec(data),
972                    Array1::from_vec(indices),
973                    Array1::from_vec(indptr),
974                    (self.shape.0, 1),
975                )?;
976
977                Ok(SparseSum::SparseArray(Box::new(result_array)))
978            }
979            _ => Err(SparseError::InvalidAxis),
980        }
981    }
982
983    fn max(&self) -> T {
984        if self.data.is_empty() {
985            return T::neg_infinity();
986        }
987
988        let mut max_val = self.data[0];
989        for &val in self.data.iter().skip(1) {
990            if val > max_val {
991                max_val = val;
992            }
993        }
994
995        // Check if max_val is less than zero, as zeros aren't explicitly stored
996        if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
997            max_val = T::zero();
998        }
999
1000        max_val
1001    }
1002
1003    fn min(&self) -> T {
1004        if self.data.is_empty() {
1005            return T::infinity();
1006        }
1007
1008        let mut min_val = self.data[0];
1009        for &val in self.data.iter().skip(1) {
1010            if val < min_val {
1011                min_val = val;
1012            }
1013        }
1014
1015        // Check if min_val is greater than zero, as zeros aren't explicitly stored
1016        if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
1017            min_val = T::zero();
1018        }
1019
1020        min_val
1021    }
1022
1023    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1024        let nnz = self.nnz();
1025        let mut rows = Vec::with_capacity(nnz);
1026        let mut cols = Vec::with_capacity(nnz);
1027        let mut values = Vec::with_capacity(nnz);
1028
1029        for row in 0..self.shape.0 {
1030            let start = self.indptr[row];
1031            let end = self.indptr[row + 1];
1032
1033            for idx in start..end {
1034                let col = self.indices[idx];
1035                rows.push(row);
1036                cols.push(col);
1037                values.push(self.data[idx]);
1038            }
1039        }
1040
1041        (
1042            Array1::from_vec(rows),
1043            Array1::from_vec(cols),
1044            Array1::from_vec(values),
1045        )
1046    }
1047
1048    fn slice(
1049        &self,
1050        row_range: (usize, usize),
1051        col_range: (usize, usize),
1052    ) -> SparseResult<Box<dyn SparseArray<T>>> {
1053        let (start_row, end_row) = row_range;
1054        let (start_col, end_col) = col_range;
1055
1056        if start_row >= self.shape.0
1057            || end_row > self.shape.0
1058            || start_col >= self.shape.1
1059            || end_col > self.shape.1
1060        {
1061            return Err(SparseError::InvalidSliceRange);
1062        }
1063
1064        if start_row >= end_row || start_col >= end_col {
1065            return Err(SparseError::InvalidSliceRange);
1066        }
1067
1068        let mut data = Vec::new();
1069        let mut indices = Vec::new();
1070        let mut indptr = vec![0];
1071
1072        for row in start_row..end_row {
1073            let start = self.indptr[row];
1074            let end = self.indptr[row + 1];
1075
1076            for idx in start..end {
1077                let col = self.indices[idx];
1078                if col >= start_col && col < end_col {
1079                    data.push(self.data[idx]);
1080                    indices.push(col - start_col);
1081                }
1082            }
1083            indptr.push(data.len());
1084        }
1085
1086        CsrArray::new(
1087            Array1::from_vec(data),
1088            Array1::from_vec(indices),
1089            Array1::from_vec(indptr),
1090            (end_row - start_row, end_col - start_col),
1091        )
1092        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1093    }
1094
1095    fn as_any(&self) -> &dyn std::any::Any {
1096        self
1097    }
1098
1099    fn get_indptr(&self) -> Option<&Array1<usize>> {
1100        Some(&self.indptr)
1101    }
1102
1103    fn indptr(&self) -> Option<&Array1<usize>> {
1104        Some(&self.indptr)
1105    }
1106}
1107
1108impl<T> fmt::Debug for CsrArray<T>
1109where
1110    T: Float
1111        + Add<Output = T>
1112        + Sub<Output = T>
1113        + Mul<Output = T>
1114        + Div<Output = T>
1115        + Debug
1116        + Copy
1117        + 'static,
1118{
1119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1120        write!(
1121            f,
1122            "CsrArray<{}x{}, nnz={}>",
1123            self.shape.0,
1124            self.shape.1,
1125            self.nnz()
1126        )
1127    }
1128}
1129
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the 3x3 fixture shared by most tests:
    /// [1.0, 0.0, 2.0]
    /// [0.0, 3.0, 0.0]
    /// [4.0, 0.0, 5.0]
    fn sample_csr() -> CsrArray<f64> {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        CsrArray::new(data, indices, indptr, (3, 3)).unwrap()
    }

    #[test]
    fn test_csr_array_construction() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        let shape = (3, 3);

        let csr = CsrArray::new(data, indices, indptr, shape).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 0), 4.0);
        assert_eq!(csr.get(2, 2), 5.0);
        assert_eq!(csr.get(0, 1), 0.0);
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 1]], 0.0);
        assert_eq!(dense[[0, 2]], 2.0);
        assert_eq!(dense[[1, 0]], 0.0);
        assert_eq!(dense[[1, 1]], 3.0);
        assert_eq!(dense[[1, 2]], 0.0);
        assert_eq!(dense[[2, 0]], 4.0);
        assert_eq!(dense[[2, 1]], 0.0);
        assert_eq!(dense[[2, 2]], 5.0);
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let csr = sample_csr();
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csr.dot_vector(&vec.view()).unwrap();

        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3] = [7, 6, 19]
        assert_eq!(result.len(), 3);
        assert_relative_eq!(result[0], 7.0);
        assert_relative_eq!(result[1], 6.0);
        assert_relative_eq!(result[2], 19.0);
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        // Sum all elements
        if let SparseSum::Scalar(sum) = csr.sum(None).unwrap() {
            assert_relative_eq!(sum, 15.0);
        } else {
            panic!("Expected scalar sum");
        }

        // Sum along rows (axis 0): one total per column
        if let SparseSum::SparseArray(row_sum) = csr.sum(Some(0)).unwrap() {
            let row_sum_array = row_sum.to_array();
            assert_eq!(row_sum_array.shape(), &[1, 3]);
            assert_relative_eq!(row_sum_array[[0, 0]], 5.0);
            assert_relative_eq!(row_sum_array[[0, 1]], 3.0);
            assert_relative_eq!(row_sum_array[[0, 2]], 7.0);
        } else {
            panic!("Expected sparse array sum");
        }

        // Sum along columns (axis 1): one total per row
        if let SparseSum::SparseArray(col_sum) = csr.sum(Some(1)).unwrap() {
            let col_sum_array = col_sum.to_array();
            assert_eq!(col_sum_array.shape(), &[3, 1]);
            assert_relative_eq!(col_sum_array[[0, 0]], 3.0);
            assert_relative_eq!(col_sum_array[[1, 0]], 3.0);
            assert_relative_eq!(col_sum_array[[2, 0]], 9.0);
        } else {
            panic!("Expected sparse array sum");
        }
    }

    #[test]
    fn test_csr_array_max_min() {
        let csr = sample_csr();
        // The largest stored value wins.
        assert_relative_eq!(csr.max(), 5.0);
        // Every stored value is positive, but implicit zeros exist, so the minimum is zero.
        assert_relative_eq!(csr.min(), 0.0);
    }

    #[test]
    fn test_csr_array_find() {
        let (rows, cols, values) = sample_csr().find();
        assert_eq!(rows.to_vec(), vec![0, 0, 1, 2, 2]);
        assert_eq!(cols.to_vec(), vec![0, 2, 1, 0, 2]);
        assert_eq!(values.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_csr_array_slice() {
        let csr = sample_csr();

        // Rows 0..2, columns 1..3 -> [[0, 2], [3, 0]]
        let sub = csr.slice((0, 2), (1, 3)).unwrap().to_array();
        assert_eq!(sub.shape(), &[2, 2]);
        assert_relative_eq!(sub[[0, 0]], 0.0);
        assert_relative_eq!(sub[[0, 1]], 2.0);
        assert_relative_eq!(sub[[1, 0]], 3.0);
        assert_relative_eq!(sub[[1, 1]], 0.0);

        // Empty and out-of-bounds ranges are rejected.
        assert!(csr.slice((1, 1), (0, 3)).is_err());
        assert!(csr.slice((0, 4), (0, 3)).is_err());
    }
}