scirs2_sparse/
csr_array.rs

1// CSR Array implementation
2//
3// This module provides the CSR (Compressed Sparse Row) array format,
4// which is efficient for row-wise operations and is one of the most common formats.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
/// CSR Array format
///
/// The CSR (Compressed Sparse Row) format stores a sparse array in three arrays:
/// - data: array of non-zero values
/// - indices: column indices of the non-zero values
/// - indptr: index pointers; for each row, points to the first non-zero element
///
/// The entries of row `i` live at positions `indptr[i]..indptr[i + 1]` of both
/// `data` and `indices`, so `indptr` has one more element than there are rows.
///
/// # Notes
///
/// - Efficient for row-oriented operations
/// - Fast matrix-vector multiplications
/// - Fast row slicing
/// - Slow column slicing
/// - Slow constructing by setting individual elements
///
#[derive(Clone)]
pub struct CsrArray<T>
where
    T: Float
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Debug
        + Copy
        + 'static,
{
    /// Stored values, laid out row after row
    data: Array1<T>,
    /// Column index of each stored value (same length as `data`)
    indices: Array1<usize>,
    /// Row pointers (indices into data/indices for the start of each row)
    indptr: Array1<usize>,
    /// Shape of the sparse array as `(rows, cols)`
    shape: (usize, usize),
    /// Whether the column indices are in non-decreasing order within every row
    has_sorted_indices: bool,
}
52
53impl<T> CsrArray<T>
54where
55    T: Float
56        + Add<Output = T>
57        + Sub<Output = T>
58        + Mul<Output = T>
59        + Div<Output = T>
60        + Debug
61        + Copy
62        + 'static,
63{
64    /// Creates a new CSR array from raw components
65    ///
66    /// # Arguments
67    /// * `data` - Array of non-zero values
68    /// * `indices` - Column indices of non-zero values
69    /// * `indptr` - Index pointers for the start of each row
70    /// * `shape` - Shape of the sparse array
71    ///
72    /// # Returns
73    /// A new `CsrArray`
74    ///
75    /// # Errors
76    /// Returns an error if the data is not consistent
77    pub fn new(
78        data: Array1<T>,
79        indices: Array1<usize>,
80        indptr: Array1<usize>,
81        shape: (usize, usize),
82    ) -> SparseResult<Self> {
83        // Validation
84        if data.len() != indices.len() {
85            return Err(SparseError::InconsistentData {
86                reason: "data and indices must have the same length".to_string(),
87            });
88        }
89
90        if indptr.len() != shape.0 + 1 {
91            return Err(SparseError::InconsistentData {
92                reason: format!(
93                    "indptr length ({}) must be one more than the number of rows ({})",
94                    indptr.len(),
95                    shape.0
96                ),
97            });
98        }
99
100        if let Some(&max_idx) = indices.iter().max() {
101            if max_idx >= shape.1 {
102                return Err(SparseError::IndexOutOfBounds {
103                    index: (0, max_idx),
104                    shape,
105                });
106            }
107        }
108
109        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
110            if first != 0 {
111                return Err(SparseError::InconsistentData {
112                    reason: "first element of indptr must be 0".to_string(),
113                });
114            }
115
116            if last != data.len() {
117                return Err(SparseError::InconsistentData {
118                    reason: format!(
119                        "last element of indptr ({}) must equal data length ({})",
120                        last,
121                        data.len()
122                    ),
123                });
124            }
125        }
126
127        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
128
129        Ok(Self {
130            data,
131            indices,
132            indptr,
133            shape,
134            has_sorted_indices,
135        })
136    }
137
138    /// Create a CSR array from triplet format (COO-like)
139    ///
140    /// # Arguments
141    /// * `rows` - Row indices
142    /// * `cols` - Column indices
143    /// * `data` - Values
144    /// * `shape` - Shape of the sparse array
145    /// * `sorted` - Whether the triplets are sorted by row
146    ///
147    /// # Returns
148    /// A new `CsrArray`
149    ///
150    /// # Errors
151    /// Returns an error if the data is not consistent
152    pub fn from_triplets(
153        rows: &[usize],
154        cols: &[usize],
155        data: &[T],
156        shape: (usize, usize),
157        sorted: bool,
158    ) -> SparseResult<Self> {
159        if rows.len() != cols.len() || rows.len() != data.len() {
160            return Err(SparseError::InconsistentData {
161                reason: "rows, cols, and data must have the same length".to_string(),
162            });
163        }
164
165        if rows.is_empty() {
166            // Empty matrix
167            let indptr = Array1::zeros(shape.0 + 1);
168            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
169        }
170
171        let nnz = rows.len();
172        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
173
174        for i in 0..nnz {
175            if rows[i] >= shape.0 || cols[i] >= shape.1 {
176                return Err(SparseError::IndexOutOfBounds {
177                    index: (rows[i], cols[i]),
178                    shape,
179                });
180            }
181            all_data.push((rows[i], cols[i], data[i]));
182        }
183
184        if !sorted {
185            all_data.sort_by_key(|&(row, col, _)| (row, col));
186        }
187
188        // Count elements per row
189        let mut row_counts = vec![0; shape.0];
190        for &(row, _, _) in &all_data {
191            row_counts[row] += 1;
192        }
193
194        // Create indptr
195        let mut indptr = Vec::with_capacity(shape.0 + 1);
196        indptr.push(0);
197        let mut cumsum = 0;
198        for &count in &row_counts {
199            cumsum += count;
200            indptr.push(cumsum);
201        }
202
203        // Create indices and data arrays
204        let mut indices = Vec::with_capacity(nnz);
205        let mut values = Vec::with_capacity(nnz);
206
207        for (_, col, val) in all_data {
208            indices.push(col);
209            values.push(val);
210        }
211
212        Self::new(
213            Array1::from_vec(values),
214            Array1::from_vec(indices),
215            Array1::from_vec(indptr),
216            shape,
217        )
218    }
219
220    /// Checks if column indices are sorted for each row
221    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
222        for row in 0..indptr.len() - 1 {
223            let start = indptr[row];
224            let end = indptr[row + 1];
225
226            for i in start..end.saturating_sub(1) {
227                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
228                    return false;
229                }
230            }
231        }
232        true
233    }
234
    /// Get the raw data array
    ///
    /// Borrows the stored values; nothing is copied.
    pub fn get_data(&self) -> &Array1<T> {
        &self.data
    }
239
    /// Get the raw indices array
    ///
    /// Borrows the column index of every stored value; nothing is copied.
    pub fn get_indices(&self) -> &Array1<usize> {
        &self.indices
    }
244
    /// Get the raw indptr array
    ///
    /// Borrows the row pointer array; nothing is copied.
    pub fn get_indptr(&self) -> &Array1<usize> {
        &self.indptr
    }
249
250    /// Get the number of rows
251    pub fn nrows(&self) -> usize {
252        self.shape.0
253    }
254
255    /// Get the number of columns  
256    pub fn ncols(&self) -> usize {
257        self.shape.1
258    }
259
    /// Get the shape (rows, cols)
    ///
    /// Returns a copy of the stored `(rows, cols)` tuple.
    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }
264}
265
266impl<T> SparseArray<T> for CsrArray<T>
267where
268    T: Float
269        + Add<Output = T>
270        + Sub<Output = T>
271        + Mul<Output = T>
272        + Div<Output = T>
273        + Debug
274        + Copy
275        + 'static,
276{
    /// Returns the `(rows, cols)` shape of the array.
    fn shape(&self) -> (usize, usize) {
        self.shape
    }
280
    /// Returns the number of stored entries (the length of `data`).
    ///
    /// Explicitly stored zeros are counted as well.
    fn nnz(&self) -> usize {
        self.data.len()
    }
284
    /// Returns a name for the element type.
    ///
    /// Always the literal `"float"`, whatever `T` is.
    fn dtype(&self) -> &str {
        "float" // Placeholder, ideally we would return the actual type
    }
288
289    fn to_array(&self) -> Array2<T> {
290        let (rows, cols) = self.shape;
291        let mut result = Array2::zeros((rows, cols));
292
293        for row in 0..rows {
294            let start = self.indptr[row];
295            let end = self.indptr[row + 1];
296
297            for i in start..end {
298                let col = self.indices[i];
299                result[[row, col]] = self.data[i];
300            }
301        }
302
303        result
304    }
305
    /// Alias for `to_array`: returns the dense representation.
    fn toarray(&self) -> Array2<T> {
        self.to_array()
    }
309
    /// Placeholder COO conversion.
    ///
    /// No format change happens yet: the result is a boxed clone of this
    /// CSR array, and the call always succeeds.
    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        // This would convert to COO format
        // For now we just return self
        Ok(Box::new(self.clone()))
    }
315
    /// Returns a boxed clone of this array, which is already in CSR format.
    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(Box::new(self.clone()))
    }
319
    /// Placeholder CSC conversion.
    ///
    /// No format change happens yet: the result is a boxed clone of this
    /// CSR array, and the call always succeeds.
    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        // This would convert to CSC format
        // For now we just return self
        Ok(Box::new(self.clone()))
    }
325
    /// Placeholder DOK conversion.
    ///
    /// No format change happens yet: the result is a boxed clone of this
    /// CSR array, and the call always succeeds.
    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        // This would convert to DOK format
        // For now we just return self
        Ok(Box::new(self.clone()))
    }
331
    /// Placeholder LIL conversion.
    ///
    /// No format change happens yet: the result is a boxed clone of this
    /// CSR array, and the call always succeeds.
    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        // This would convert to LIL format
        // For now we just return self
        Ok(Box::new(self.clone()))
    }
337
    /// Placeholder DIA conversion.
    ///
    /// No format change happens yet: the result is a boxed clone of this
    /// CSR array, and the call always succeeds.
    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        // This would convert to DIA format
        // For now we just return self
        Ok(Box::new(self.clone()))
    }
343
    /// Placeholder BSR conversion.
    ///
    /// No format change happens yet: the result is a boxed clone of this
    /// CSR array, and the call always succeeds.
    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        // This would convert to BSR format
        // For now we just return self
        Ok(Box::new(self.clone()))
    }
349
350    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
351        // Placeholder implementation - a full implementation would handle direct
352        // sparse matrix addition more efficiently
353        let self_array = self.to_array();
354        let other_array = other.to_array();
355
356        if self.shape() != other.shape() {
357            return Err(SparseError::DimensionMismatch {
358                expected: self.shape().0,
359                found: other.shape().0,
360            });
361        }
362
363        let result = &self_array + &other_array;
364
365        // Convert back to CSR
366        let (rows, cols) = self.shape();
367        let mut data = Vec::new();
368        let mut indices = Vec::new();
369        let mut indptr = vec![0];
370
371        for row in 0..rows {
372            for col in 0..cols {
373                let val = result[[row, col]];
374                if !val.is_zero() {
375                    data.push(val);
376                    indices.push(col);
377                }
378            }
379            indptr.push(data.len());
380        }
381
382        CsrArray::new(
383            Array1::from_vec(data),
384            Array1::from_vec(indices),
385            Array1::from_vec(indptr),
386            self.shape(),
387        )
388        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
389    }
390
391    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
392        // Similar to add, this is a placeholder
393        let self_array = self.to_array();
394        let other_array = other.to_array();
395
396        if self.shape() != other.shape() {
397            return Err(SparseError::DimensionMismatch {
398                expected: self.shape().0,
399                found: other.shape().0,
400            });
401        }
402
403        let result = &self_array - &other_array;
404
405        // Convert back to CSR
406        let (rows, cols) = self.shape();
407        let mut data = Vec::new();
408        let mut indices = Vec::new();
409        let mut indptr = vec![0];
410
411        for row in 0..rows {
412            for col in 0..cols {
413                let val = result[[row, col]];
414                if !val.is_zero() {
415                    data.push(val);
416                    indices.push(col);
417                }
418            }
419            indptr.push(data.len());
420        }
421
422        CsrArray::new(
423            Array1::from_vec(data),
424            Array1::from_vec(indices),
425            Array1::from_vec(indptr),
426            self.shape(),
427        )
428        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
429    }
430
431    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
432        // This is element-wise multiplication (Hadamard product)
433        // In the sparse array API, * is element-wise, not matrix multiplication
434        let self_array = self.to_array();
435        let other_array = other.to_array();
436
437        if self.shape() != other.shape() {
438            return Err(SparseError::DimensionMismatch {
439                expected: self.shape().0,
440                found: other.shape().0,
441            });
442        }
443
444        let result = &self_array * &other_array;
445
446        // Convert back to CSR
447        let (rows, cols) = self.shape();
448        let mut data = Vec::new();
449        let mut indices = Vec::new();
450        let mut indptr = vec![0];
451
452        for row in 0..rows {
453            for col in 0..cols {
454                let val = result[[row, col]];
455                if !val.is_zero() {
456                    data.push(val);
457                    indices.push(col);
458                }
459            }
460            indptr.push(data.len());
461        }
462
463        CsrArray::new(
464            Array1::from_vec(data),
465            Array1::from_vec(indices),
466            Array1::from_vec(indptr),
467            self.shape(),
468        )
469        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
470    }
471
472    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
473        // Element-wise division
474        let self_array = self.to_array();
475        let other_array = other.to_array();
476
477        if self.shape() != other.shape() {
478            return Err(SparseError::DimensionMismatch {
479                expected: self.shape().0,
480                found: other.shape().0,
481            });
482        }
483
484        let result = &self_array / &other_array;
485
486        // Convert back to CSR
487        let (rows, cols) = self.shape();
488        let mut data = Vec::new();
489        let mut indices = Vec::new();
490        let mut indptr = vec![0];
491
492        for row in 0..rows {
493            for col in 0..cols {
494                let val = result[[row, col]];
495                if !val.is_zero() {
496                    data.push(val);
497                    indices.push(col);
498                }
499            }
500            indptr.push(data.len());
501        }
502
503        CsrArray::new(
504            Array1::from_vec(data),
505            Array1::from_vec(indices),
506            Array1::from_vec(indptr),
507            self.shape(),
508        )
509        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
510    }
511
512    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
513        // Matrix multiplication
514        // This is a placeholder; a full implementation would be optimized for sparse arrays
515        let (m, n) = self.shape();
516        let (p, q) = other.shape();
517
518        if n != p {
519            return Err(SparseError::DimensionMismatch {
520                expected: n,
521                found: p,
522            });
523        }
524
525        let mut result = Array2::zeros((m, q));
526        let other_array = other.to_array();
527
528        for row in 0..m {
529            let start = self.indptr[row];
530            let end = self.indptr[row + 1];
531
532            for j in 0..q {
533                let mut sum = T::zero();
534                for idx in start..end {
535                    let col = self.indices[idx];
536                    sum = sum + self.data[idx] * other_array[[col, j]];
537                }
538                if !sum.is_zero() {
539                    result[[row, j]] = sum;
540                }
541            }
542        }
543
544        // Convert result back to CSR
545        let mut data = Vec::new();
546        let mut indices = Vec::new();
547        let mut indptr = vec![0];
548
549        for row in 0..m {
550            for col in 0..q {
551                let val = result[[row, col]];
552                if !val.is_zero() {
553                    data.push(val);
554                    indices.push(col);
555                }
556            }
557            indptr.push(data.len());
558        }
559
560        CsrArray::new(
561            Array1::from_vec(data),
562            Array1::from_vec(indices),
563            Array1::from_vec(indptr),
564            (m, q),
565        )
566        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
567    }
568
569    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
570        let (m, n) = self.shape();
571        if n != other.len() {
572            return Err(SparseError::DimensionMismatch {
573                expected: n,
574                found: other.len(),
575            });
576        }
577
578        let mut result = Array1::zeros(m);
579
580        for row in 0..m {
581            let start = self.indptr[row];
582            let end = self.indptr[row + 1];
583
584            let mut sum = T::zero();
585            for idx in start..end {
586                let col = self.indices[idx];
587                sum = sum + self.data[idx] * other[col];
588            }
589            result[row] = sum;
590        }
591
592        Ok(result)
593    }
594
595    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
596        // Transpose is non-trivial for CSR format
597        // A full implementation would convert to CSC format or implement
598        // an efficient algorithm
599        let (rows, cols) = self.shape();
600        let mut row_indices = Vec::with_capacity(self.nnz());
601        let mut col_indices = Vec::with_capacity(self.nnz());
602        let mut values = Vec::with_capacity(self.nnz());
603
604        for row in 0..rows {
605            let start = self.indptr[row];
606            let end = self.indptr[row + 1];
607
608            for idx in start..end {
609                let col = self.indices[idx];
610                row_indices.push(col); // Note: rows and cols are swapped for transposition
611                col_indices.push(row);
612                values.push(self.data[idx]);
613            }
614        }
615
616        // We need to create CSR from this "COO" representation
617        CsrArray::from_triplets(
618            &row_indices,
619            &col_indices,
620            &values,
621            (cols, rows), // Swapped dimensions
622            false,
623        )
624        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
625    }
626
    /// Returns a boxed clone of this array.
    fn copy(&self) -> Box<dyn SparseArray<T>> {
        Box::new(self.clone())
    }
630
631    fn get(&self, i: usize, j: usize) -> T {
632        if i >= self.shape.0 || j >= self.shape.1 {
633            return T::zero();
634        }
635
636        let start = self.indptr[i];
637        let end = self.indptr[i + 1];
638
639        for idx in start..end {
640            if self.indices[idx] == j {
641                return self.data[idx];
642            }
643            // If indices are sorted, we can break early
644            if self.has_sorted_indices && self.indices[idx] > j {
645                break;
646            }
647        }
648
649        T::zero()
650    }
651
652    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
653        // Setting elements in CSR format is non-trivial
654        // This is a placeholder implementation that doesn't actually modify the structure
655        if i >= self.shape.0 || j >= self.shape.1 {
656            return Err(SparseError::IndexOutOfBounds {
657                index: (i, j),
658                shape: self.shape,
659            });
660        }
661
662        let start = self.indptr[i];
663        let end = self.indptr[i + 1];
664
665        // Try to find existing element
666        for idx in start..end {
667            if self.indices[idx] == j {
668                // Update value
669                self.data[idx] = value;
670                return Ok(());
671            }
672            // If indices are sorted, we can insert at the right position
673            if self.has_sorted_indices && self.indices[idx] > j {
674                // Insert here - this would require restructuring indices and data
675                // Not implemented in this placeholder
676                return Err(SparseError::NotImplemented(
677                    "Inserting new elements in CSR format".to_string(),
678                ));
679            }
680        }
681
682        // Element not found, would need to insert
683        // This would require restructuring indices and data
684        Err(SparseError::NotImplemented(
685            "Inserting new elements in CSR format".to_string(),
686        ))
687    }
688
689    fn eliminate_zeros(&mut self) {
690        // Find all non-zero entries
691        let mut new_data = Vec::new();
692        let mut new_indices = Vec::new();
693        let mut new_indptr = vec![0];
694
695        let (rows, _) = self.shape();
696
697        for row in 0..rows {
698            let start = self.indptr[row];
699            let end = self.indptr[row + 1];
700
701            for idx in start..end {
702                if !self.data[idx].is_zero() {
703                    new_data.push(self.data[idx]);
704                    new_indices.push(self.indices[idx]);
705                }
706            }
707            new_indptr.push(new_data.len());
708        }
709
710        // Replace data with filtered data
711        self.data = Array1::from_vec(new_data);
712        self.indices = Array1::from_vec(new_indices);
713        self.indptr = Array1::from_vec(new_indptr);
714    }
715
716    fn sort_indices(&mut self) {
717        if self.has_sorted_indices {
718            return;
719        }
720
721        let (rows, _) = self.shape();
722
723        for row in 0..rows {
724            let start = self.indptr[row];
725            let end = self.indptr[row + 1];
726
727            if start == end {
728                continue;
729            }
730
731            // Extract the non-zero elements for this row
732            let mut row_data = Vec::with_capacity(end - start);
733            for idx in start..end {
734                row_data.push((self.indices[idx], self.data[idx]));
735            }
736
737            // Sort by column index
738            row_data.sort_by_key(|&(col, _)| col);
739
740            // Put the sorted data back
741            for (i, (col, val)) in row_data.into_iter().enumerate() {
742                self.indices[start + i] = col;
743                self.data[start + i] = val;
744            }
745        }
746
747        self.has_sorted_indices = true;
748    }
749
750    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
751        if self.has_sorted_indices {
752            return Box::new(self.clone());
753        }
754
755        let mut sorted = self.clone();
756        sorted.sort_indices();
757        Box::new(sorted)
758    }
759
    /// Returns the cached flag telling whether every row's column indices are
    /// sorted.
    fn has_sorted_indices(&self) -> bool {
        self.has_sorted_indices
    }
763
764    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
765        match axis {
766            None => {
767                // Sum all elements
768                let mut sum = T::zero();
769                for &val in self.data.iter() {
770                    sum = sum + val;
771                }
772                Ok(SparseSum::Scalar(sum))
773            }
774            Some(0) => {
775                // Sum along rows (result is a row vector)
776                let (_, cols) = self.shape();
777                let mut result = vec![T::zero(); cols];
778
779                for row in 0..self.shape.0 {
780                    let start = self.indptr[row];
781                    let end = self.indptr[row + 1];
782
783                    for idx in start..end {
784                        let col = self.indices[idx];
785                        result[col] = result[col] + self.data[idx];
786                    }
787                }
788
789                // Convert to CSR format
790                let mut data = Vec::new();
791                let mut indices = Vec::new();
792                let mut indptr = vec![0];
793
794                for (col, &val) in result.iter().enumerate() {
795                    if !val.is_zero() {
796                        data.push(val);
797                        indices.push(col);
798                    }
799                }
800                indptr.push(data.len());
801
802                let result_array = CsrArray::new(
803                    Array1::from_vec(data),
804                    Array1::from_vec(indices),
805                    Array1::from_vec(indptr),
806                    (1, cols),
807                )?;
808
809                Ok(SparseSum::SparseArray(Box::new(result_array)))
810            }
811            Some(1) => {
812                // Sum along columns (result is a column vector)
813                let mut result = Vec::with_capacity(self.shape.0);
814
815                for row in 0..self.shape.0 {
816                    let start = self.indptr[row];
817                    let end = self.indptr[row + 1];
818
819                    let mut row_sum = T::zero();
820                    for idx in start..end {
821                        row_sum = row_sum + self.data[idx];
822                    }
823                    result.push(row_sum);
824                }
825
826                // Convert to CSR format
827                let mut data = Vec::new();
828                let mut indices = Vec::new();
829                let mut indptr = vec![0];
830
831                for &val in result.iter() {
832                    if !val.is_zero() {
833                        data.push(val);
834                        indices.push(0);
835                        indptr.push(data.len());
836                    } else {
837                        indptr.push(data.len());
838                    }
839                }
840
841                let result_array = CsrArray::new(
842                    Array1::from_vec(data),
843                    Array1::from_vec(indices),
844                    Array1::from_vec(indptr),
845                    (self.shape.0, 1),
846                )?;
847
848                Ok(SparseSum::SparseArray(Box::new(result_array)))
849            }
850            _ => Err(SparseError::InvalidAxis),
851        }
852    }
853
854    fn max(&self) -> T {
855        if self.data.is_empty() {
856            return T::neg_infinity();
857        }
858
859        let mut max_val = self.data[0];
860        for &val in self.data.iter().skip(1) {
861            if val > max_val {
862                max_val = val;
863            }
864        }
865
866        // Check if max_val is less than zero, as zeros aren't explicitly stored
867        if max_val < T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
868            max_val = T::zero();
869        }
870
871        max_val
872    }
873
874    fn min(&self) -> T {
875        if self.data.is_empty() {
876            return T::infinity();
877        }
878
879        let mut min_val = self.data[0];
880        for &val in self.data.iter().skip(1) {
881            if val < min_val {
882                min_val = val;
883            }
884        }
885
886        // Check if min_val is greater than zero, as zeros aren't explicitly stored
887        if min_val > T::zero() && self.nnz() < self.shape.0 * self.shape.1 {
888            min_val = T::zero();
889        }
890
891        min_val
892    }
893
894    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
895        let nnz = self.nnz();
896        let mut rows = Vec::with_capacity(nnz);
897        let mut cols = Vec::with_capacity(nnz);
898        let mut values = Vec::with_capacity(nnz);
899
900        for row in 0..self.shape.0 {
901            let start = self.indptr[row];
902            let end = self.indptr[row + 1];
903
904            for idx in start..end {
905                let col = self.indices[idx];
906                rows.push(row);
907                cols.push(col);
908                values.push(self.data[idx]);
909            }
910        }
911
912        (
913            Array1::from_vec(rows),
914            Array1::from_vec(cols),
915            Array1::from_vec(values),
916        )
917    }
918
919    fn slice(
920        &self,
921        row_range: (usize, usize),
922        col_range: (usize, usize),
923    ) -> SparseResult<Box<dyn SparseArray<T>>> {
924        let (start_row, end_row) = row_range;
925        let (start_col, end_col) = col_range;
926
927        if start_row >= self.shape.0
928            || end_row > self.shape.0
929            || start_col >= self.shape.1
930            || end_col > self.shape.1
931        {
932            return Err(SparseError::InvalidSliceRange);
933        }
934
935        if start_row >= end_row || start_col >= end_col {
936            return Err(SparseError::InvalidSliceRange);
937        }
938
939        let mut data = Vec::new();
940        let mut indices = Vec::new();
941        let mut indptr = vec![0];
942
943        for row in start_row..end_row {
944            let start = self.indptr[row];
945            let end = self.indptr[row + 1];
946
947            for idx in start..end {
948                let col = self.indices[idx];
949                if col >= start_col && col < end_col {
950                    data.push(self.data[idx]);
951                    indices.push(col - start_col);
952                }
953            }
954            indptr.push(data.len());
955        }
956
957        CsrArray::new(
958            Array1::from_vec(data),
959            Array1::from_vec(indices),
960            Array1::from_vec(indptr),
961            (end_row - start_row, end_col - start_col),
962        )
963        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
964    }
965
    /// Exposes the array as `&dyn Any` so callers can downcast it to the
    /// concrete `CsrArray<T>` type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
969}
970
971impl<T> fmt::Debug for CsrArray<T>
972where
973    T: Float
974        + Add<Output = T>
975        + Sub<Output = T>
976        + Mul<Output = T>
977        + Div<Output = T>
978        + Debug
979        + Copy
980        + 'static,
981{
982    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
983        write!(
984            f,
985            "CsrArray<{}x{}, nnz={}>",
986            self.shape.0,
987            self.shape.1,
988            self.nnz()
989        )
990    }
991}
992
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Entries probed through `get`: the five stored values of the fixture
    /// plus one position (0, 1) that is not stored and must read back as zero.
    const EXPECTED_ENTRIES: [(usize, usize, f64); 6] = [
        (0, 0, 1.0),
        (0, 2, 2.0),
        (1, 1, 3.0),
        (2, 0, 4.0),
        (2, 2, 5.0),
        (0, 1, 0.0),
    ];

    /// Builds the 3x3 fixture shared by the tests below from raw CSR parts:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// [4 0 5]
    /// ```
    fn sample_csr() -> CsrArray<f64> {
        let values = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let col_indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let row_ptr = Array1::from_vec(vec![0, 2, 3, 5]);

        CsrArray::new(values, col_indices, row_ptr, (3, 3)).unwrap()
    }

    /// Constructing from raw components preserves shape, nnz and entries.
    #[test]
    fn test_csr_array_construction() {
        let csr = sample_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(row, col, value) in EXPECTED_ENTRIES.iter() {
            assert_eq!(csr.get(row, col), value);
        }
    }

    /// Constructing from (row, col, value) triplets yields the same matrix.
    #[test]
    fn test_csr_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let csr = CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(row, col, value) in EXPECTED_ENTRIES.iter() {
            assert_eq!(csr.get(row, col), value);
        }
    }

    /// Dense conversion reproduces every cell, including the implicit zeros.
    #[test]
    fn test_csr_array_to_array() {
        let dense = sample_csr().to_array();

        assert_eq!(dense.shape(), &[3, 3]);

        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];
        for (row, expected_row) in expected.iter().enumerate() {
            for (col, &want) in expected_row.iter().enumerate() {
                assert_eq!(dense[[row, col]], want);
            }
        }
    }

    /// Matrix-vector product against a dense vector.
    #[test]
    fn test_csr_array_dot_vector() {
        let csr = sample_csr();
        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let product = csr.dot_vector(&rhs.view()).unwrap();

        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3] = [7, 6, 19]
        assert_eq!(product.len(), 3);
        for (i, &want) in [7.0, 6.0, 19.0].iter().enumerate() {
            assert_relative_eq!(product[i], want);
        }
    }

    /// Total sum, and sums for `Some(0)` and `Some(1)`.
    #[test]
    fn test_csr_array_sum() {
        let csr = sample_csr();

        // No axis: every stored value is added into one scalar.
        match csr.sum(None).unwrap() {
            SparseSum::Scalar(total) => {
                assert_relative_eq!(total, 15.0);
            }
            _ => panic!("Expected scalar sum"),
        }

        // Axis 0: the result is a single 1x3 row.
        match csr.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(summed) => {
                let dense = summed.to_array();
                assert_eq!(dense.shape(), &[1, 3]);
                for (col, &want) in [5.0, 3.0, 7.0].iter().enumerate() {
                    assert_relative_eq!(dense[[0, col]], want);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }

        // Axis 1: the result is a single 3x1 column.
        match csr.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(summed) => {
                let dense = summed.to_array();
                assert_eq!(dense.shape(), &[3, 1]);
                for (row, &want) in [3.0, 3.0, 9.0].iter().enumerate() {
                    assert_relative_eq!(dense[[row, 0]], want);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }
    }
}