Skip to main content

sublinear_solver/matrix/
sparse.rs

1//! Sparse matrix storage implementations.
2//!
3//! This module provides efficient storage formats for sparse matrices,
4//! including CSR, CSC, COO, and graph adjacency representations.
5
6use crate::error::Result;
7use crate::types::{DimensionType, IndexType, NodeId, Precision};
8use alloc::vec::Vec;
9
/// Compressed Sparse Row (CSR) storage format.
///
/// Efficient for row-wise operations and matrix-vector multiplication.
///
/// Layout produced by `CSRStorage::from_coo`: `row_ptr` has one more entry
/// than the matrix has rows, `row_ptr[i]..row_ptr[i + 1]` is the range of
/// row `i` inside `values`/`col_indices`, and column indices within a row
/// are sorted ascending (lookups binary-search them). The fields are public,
/// so hand-built instances must uphold the same layout themselves.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CSRStorage {
    /// Non-zero values in row-major order
    pub values: Vec<Precision>,
    /// Column indices corresponding to values (parallel to `values`)
    pub col_indices: Vec<IndexType>,
    /// Row pointers: row_ptr[i] is the start of row i in values/col_indices
    pub row_ptr: Vec<IndexType>,
}
23
/// Compressed Sparse Column (CSC) storage format.
///
/// Efficient for column-wise operations.
///
/// Layout produced by `CSCStorage::from_coo`: `col_ptr` has one more entry
/// than the matrix has columns, `col_ptr[j]..col_ptr[j + 1]` is the range of
/// column `j` inside `values`/`row_indices`, and row indices within a column
/// are sorted ascending (lookups binary-search them). The fields are public,
/// so hand-built instances must uphold the same layout themselves.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CSCStorage {
    /// Non-zero values in column-major order
    pub values: Vec<Precision>,
    /// Row indices corresponding to values (parallel to `values`)
    pub row_indices: Vec<IndexType>,
    /// Column pointers: col_ptr[j] is the start of column j in values/row_indices
    pub col_ptr: Vec<IndexType>,
}
37
/// Coordinate (COO) storage format.
///
/// Efficient for construction and random access patterns.
///
/// The three vectors are parallel: entry `i` is
/// `(row_indices[i], col_indices[i], values[i])`. Entries are kept in
/// insertion order and are not deduplicated; `multiply_vector_add`
/// accumulates every stored entry, so repeated positions act as a sum.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct COOStorage {
    /// Row indices
    pub row_indices: Vec<IndexType>,
    /// Column indices
    pub col_indices: Vec<IndexType>,
    /// Values
    pub values: Vec<Precision>,
}
51
/// Graph adjacency list storage format.
///
/// Optimized for graph algorithms like push methods.
///
/// Matrix entry `(row, col, w)` is stored as an out-edge of `row` with
/// `target == col`. As built by `GraphStorage::from_triplets`, off-diagonal
/// entries are also stored as an in-edge of `col` with `target == row`, but
/// self-loops (diagonal entries) appear only in `out_edges`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GraphStorage {
    /// Outgoing edges for each node (one row of the matrix per node)
    pub out_edges: Vec<Vec<GraphEdge>>,
    /// Incoming edges for each node (for backward push); excludes self-loops
    pub in_edges: Vec<Vec<GraphEdge>>,
    /// Node degrees for normalization: sum of absolute out-edge weights
    pub degrees: Vec<Precision>,
}
65
/// Graph edge representation.
///
/// In `GraphStorage::out_edges[u]` the `target` is the column (destination
/// node); in `GraphStorage::in_edges[v]` the `target` is the source row.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GraphEdge {
    /// Target node (column for out-edges, source row for in-edges)
    pub target: NodeId,
    /// Edge weight (the matrix entry value)
    pub weight: Precision,
}
75
76// CSR Implementation
77impl CSRStorage {
78    /// Create CSR storage from COO format.
79    pub fn from_coo(coo: &COOStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
80        if coo.is_empty() {
81            return Ok(Self {
82                values: Vec::new(),
83                col_indices: Vec::new(),
84                row_ptr: vec![0; rows + 1],
85            });
86        }
87
88        // Sort by row, then by column
89        let mut sorted_entries: Vec<_> = coo
90            .row_indices
91            .iter()
92            .zip(&coo.col_indices)
93            .zip(&coo.values)
94            .map(|((&r, &c), &v)| (r as usize, c, v))
95            .collect();
96        sorted_entries.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
97
98        let mut values = Vec::new();
99        let mut col_indices = Vec::new();
100        let mut row_ptr = vec![0; rows + 1];
101
102        let mut current_row = 0;
103        let mut nnz_count = 0;
104
105        for (row, col, value) in sorted_entries {
106            // Skip zeros
107            if value == 0.0 {
108                continue;
109            }
110
111            // Update row pointers
112            while current_row < row {
113                current_row += 1;
114                row_ptr[current_row] = nnz_count as IndexType;
115            }
116
117            values.push(value);
118            col_indices.push(col);
119            nnz_count += 1;
120        }
121
122        // Finalize remaining row pointers
123        while current_row < rows {
124            current_row += 1;
125            row_ptr[current_row] = nnz_count as IndexType;
126        }
127
128        Ok(Self {
129            values,
130            col_indices,
131            row_ptr,
132        })
133    }
134
135    /// Create CSR storage from CSC format.
136    pub fn from_csc(csc: &CSCStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
137        let triplets = csc.to_triplets()?;
138        let coo = COOStorage::from_triplets(triplets)?;
139        Self::from_coo(&coo, rows, cols)
140    }
141
142    /// Get matrix element at (row, col).
143    pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
144        if row >= self.row_ptr.len() - 1 {
145            return None;
146        }
147
148        let start = self.row_ptr[row] as usize;
149        let end = self.row_ptr[row + 1] as usize;
150
151        // Binary search for the column
152        match self.col_indices[start..end].binary_search(&(col as IndexType)) {
153            Ok(pos) => Some(self.values[start + pos]),
154            Err(_) => None,
155        }
156    }
157
158    /// Iterate over non-zero elements in a row.
159    pub fn row_iter(&self, row: usize) -> CSRRowIter<'_> {
160        if row >= self.row_ptr.len() - 1 {
161            return CSRRowIter {
162                col_indices: &[],
163                values: &[],
164                pos: 0,
165            };
166        }
167
168        let start = self.row_ptr[row] as usize;
169        let end = self.row_ptr[row + 1] as usize;
170
171        CSRRowIter {
172            col_indices: &self.col_indices[start..end],
173            values: &self.values[start..end],
174            pos: 0,
175        }
176    }
177
178    /// Iterate over non-zero elements in a column (slow for CSR).
179    pub fn col_iter(&self, col: usize) -> CSRColIter<'_> {
180        CSRColIter {
181            storage: self,
182            col: col as IndexType,
183            row: 0,
184        }
185    }
186
187    /// Matrix-vector multiplication: result = A * x
188    pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
189        result.fill(0.0);
190        self.multiply_vector_add(x, result);
191    }
192
193    /// Matrix-vector multiplication with accumulation: result += A * x
194    pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
195        for (row, mut row_sum) in result.iter_mut().enumerate() {
196            let start = self.row_ptr[row] as usize;
197            let end = self.row_ptr[row + 1] as usize;
198
199            for i in start..end {
200                let col = self.col_indices[i] as usize;
201                *row_sum += self.values[i] * x[col];
202            }
203        }
204    }
205
206    /// Get number of non-zero elements.
207    pub fn nnz(&self) -> usize {
208        self.values.len()
209    }
210
211    /// Extract as coordinate triplets.
212    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
213        let mut triplets = Vec::new();
214
215        for row in 0..self.row_ptr.len() - 1 {
216            let start = self.row_ptr[row] as usize;
217            let end = self.row_ptr[row + 1] as usize;
218
219            for i in start..end {
220                let col = self.col_indices[i] as usize;
221                let value = self.values[i];
222                triplets.push((row, col, value));
223            }
224        }
225
226        Ok(triplets)
227    }
228
229    /// Scale all values by a factor.
230    pub fn scale(&mut self, factor: Precision) {
231        for value in &mut self.values {
232            *value *= factor;
233        }
234    }
235
236    /// Add a value to the diagonal.
237    pub fn add_diagonal(&mut self, alpha: Precision) {
238        for row in 0..self.row_ptr.len() - 1 {
239            let start = self.row_ptr[row] as usize;
240            let end = self.row_ptr[row + 1] as usize;
241
242            // Look for diagonal element
243            if let Ok(pos) = self.col_indices[start..end].binary_search(&(row as IndexType)) {
244                self.values[start + pos] += alpha;
245            }
246            // Note: If diagonal element doesn't exist, we'd need to restructure the matrix
247        }
248    }
249}
250
/// Iterator over non-zero elements in a CSR row.
///
/// Yields `(column, value)` pairs from one row's slices of `col_indices` and
/// `values`, in stored order.
pub struct CSRRowIter<'a> {
    /// Column indices of the row being iterated.
    col_indices: &'a [IndexType],
    /// Values paired element-wise with `col_indices`.
    values: &'a [Precision],
    /// Index of the next element to yield.
    pos: usize,
}
257
258impl<'a> Iterator for CSRRowIter<'a> {
259    type Item = (IndexType, Precision);
260
261    fn next(&mut self) -> Option<Self::Item> {
262        if self.pos < self.col_indices.len() {
263            let col = self.col_indices[self.pos];
264            let val = self.values[self.pos];
265            self.pos += 1;
266            Some((col, val))
267        } else {
268            None
269        }
270    }
271}
272
/// Iterator over non-zero elements in a CSR column (inefficient).
///
/// Scans rows one at a time, binary-searching each row for `col`, and yields
/// `(row, value)` pairs. A single call to `next` may inspect many rows.
pub struct CSRColIter<'a> {
    /// Matrix being scanned.
    storage: &'a CSRStorage,
    /// Column being searched for.
    col: IndexType,
    /// Next row to inspect.
    row: usize,
}
279
280impl<'a> Iterator for CSRColIter<'a> {
281    type Item = (IndexType, Precision);
282
283    fn next(&mut self) -> Option<Self::Item> {
284        while self.row < self.storage.row_ptr.len() - 1 {
285            let start = self.storage.row_ptr[self.row] as usize;
286            let end = self.storage.row_ptr[self.row + 1] as usize;
287
288            if let Ok(pos) = self.storage.col_indices[start..end].binary_search(&self.col) {
289                let value = self.storage.values[start + pos];
290                let row = self.row as IndexType;
291                self.row += 1;
292                return Some((row, value));
293            }
294
295            self.row += 1;
296        }
297        None
298    }
299}
300
301// CSC Implementation
302impl CSCStorage {
303    /// Create CSC storage from COO format.
304    pub fn from_coo(coo: &COOStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
305        if coo.is_empty() {
306            return Ok(Self {
307                values: Vec::new(),
308                row_indices: Vec::new(),
309                col_ptr: vec![0; cols + 1],
310            });
311        }
312
313        // Sort by column, then by row
314        let mut sorted_entries: Vec<_> = coo
315            .row_indices
316            .iter()
317            .zip(&coo.col_indices)
318            .zip(&coo.values)
319            .map(|((&r, &c), &v)| (r, c as usize, v))
320            .collect();
321        sorted_entries.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
322
323        let mut values = Vec::new();
324        let mut row_indices = Vec::new();
325        let mut col_ptr = vec![0; cols + 1];
326
327        let mut current_col = 0;
328        let mut nnz_count = 0;
329
330        for (row, col, value) in sorted_entries {
331            // Skip zeros
332            if value == 0.0 {
333                continue;
334            }
335
336            // Update column pointers
337            while current_col < col {
338                current_col += 1;
339                col_ptr[current_col] = nnz_count as IndexType;
340            }
341
342            values.push(value);
343            row_indices.push(row);
344            nnz_count += 1;
345        }
346
347        // Finalize remaining column pointers
348        while current_col < cols {
349            current_col += 1;
350            col_ptr[current_col] = nnz_count as IndexType;
351        }
352
353        Ok(Self {
354            values,
355            row_indices,
356            col_ptr,
357        })
358    }
359
360    /// Create CSC storage from CSR format.
361    pub fn from_csr(csr: &CSRStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
362        let triplets = csr.to_triplets()?;
363        let coo = COOStorage::from_triplets(triplets)?;
364        Self::from_coo(&coo, rows, cols)
365    }
366
367    /// Get matrix element at (row, col).
368    pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
369        if col >= self.col_ptr.len() - 1 {
370            return None;
371        }
372
373        let start = self.col_ptr[col] as usize;
374        let end = self.col_ptr[col + 1] as usize;
375
376        // Binary search for the row
377        match self.row_indices[start..end].binary_search(&(row as IndexType)) {
378            Ok(pos) => Some(self.values[start + pos]),
379            Err(_) => None,
380        }
381    }
382
383    /// Iterate over non-zero elements in a row (slow for CSC).
384    pub fn row_iter(&self, row: usize) -> CSCRowIter<'_> {
385        CSCRowIter {
386            storage: self,
387            row: row as IndexType,
388            col: 0,
389        }
390    }
391
392    /// Iterate over non-zero elements in a column.
393    pub fn col_iter(&self, col: usize) -> CSCColIter<'_> {
394        if col >= self.col_ptr.len() - 1 {
395            return CSCColIter {
396                row_indices: &[],
397                values: &[],
398                pos: 0,
399            };
400        }
401
402        let start = self.col_ptr[col] as usize;
403        let end = self.col_ptr[col + 1] as usize;
404
405        CSCColIter {
406            row_indices: &self.row_indices[start..end],
407            values: &self.values[start..end],
408            pos: 0,
409        }
410    }
411
412    /// Matrix-vector multiplication: result = A * x
413    pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
414        result.fill(0.0);
415        self.multiply_vector_add(x, result);
416    }
417
418    /// Matrix-vector multiplication with accumulation: result += A * x
419    pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
420        for col in 0..self.col_ptr.len() - 1 {
421            let x_col = x[col];
422            if x_col == 0.0 {
423                continue;
424            }
425
426            let start = self.col_ptr[col] as usize;
427            let end = self.col_ptr[col + 1] as usize;
428
429            for i in start..end {
430                let row = self.row_indices[i] as usize;
431                result[row] += self.values[i] * x_col;
432            }
433        }
434    }
435
436    /// Get number of non-zero elements.
437    pub fn nnz(&self) -> usize {
438        self.values.len()
439    }
440
441    /// Extract as coordinate triplets.
442    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
443        let mut triplets = Vec::new();
444
445        for col in 0..self.col_ptr.len() - 1 {
446            let start = self.col_ptr[col] as usize;
447            let end = self.col_ptr[col + 1] as usize;
448
449            for i in start..end {
450                let row = self.row_indices[i] as usize;
451                let value = self.values[i];
452                triplets.push((row, col, value));
453            }
454        }
455
456        Ok(triplets)
457    }
458
459    /// Scale all values by a factor.
460    pub fn scale(&mut self, factor: Precision) {
461        for value in &mut self.values {
462            *value *= factor;
463        }
464    }
465
466    /// Add a value to the diagonal.
467    pub fn add_diagonal(&mut self, alpha: Precision) {
468        for col in 0..self.col_ptr.len() - 1 {
469            let start = self.col_ptr[col] as usize;
470            let end = self.col_ptr[col + 1] as usize;
471
472            // Look for diagonal element
473            if let Ok(pos) = self.row_indices[start..end].binary_search(&(col as IndexType)) {
474                self.values[start + pos] += alpha;
475            }
476        }
477    }
478}
479
/// Iterator over non-zero elements in a CSC row (inefficient).
///
/// Scans columns one at a time, binary-searching each column for `row`, and
/// yields `(column, value)` pairs. A single call to `next` may inspect many
/// columns.
pub struct CSCRowIter<'a> {
    /// Matrix being scanned.
    storage: &'a CSCStorage,
    /// Row being searched for.
    row: IndexType,
    /// Next column to inspect.
    col: usize,
}
486
487impl<'a> Iterator for CSCRowIter<'a> {
488    type Item = (IndexType, Precision);
489
490    fn next(&mut self) -> Option<Self::Item> {
491        while self.col < self.storage.col_ptr.len() - 1 {
492            let start = self.storage.col_ptr[self.col] as usize;
493            let end = self.storage.col_ptr[self.col + 1] as usize;
494
495            if let Ok(pos) = self.storage.row_indices[start..end].binary_search(&self.row) {
496                let value = self.storage.values[start + pos];
497                let col = self.col as IndexType;
498                self.col += 1;
499                return Some((col, value));
500            }
501
502            self.col += 1;
503        }
504        None
505    }
506}
507
/// Iterator over non-zero elements in a CSC column.
///
/// Yields `(row, value)` pairs from one column's slices of `row_indices` and
/// `values`, in stored order.
pub struct CSCColIter<'a> {
    /// Row indices of the column being iterated.
    row_indices: &'a [IndexType],
    /// Values paired element-wise with `row_indices`.
    values: &'a [Precision],
    /// Index of the next element to yield.
    pos: usize,
}
514
515impl<'a> Iterator for CSCColIter<'a> {
516    type Item = (IndexType, Precision);
517
518    fn next(&mut self) -> Option<Self::Item> {
519        if self.pos < self.row_indices.len() {
520            let row = self.row_indices[self.pos];
521            let val = self.values[self.pos];
522            self.pos += 1;
523            Some((row, val))
524        } else {
525            None
526        }
527    }
528}
529
530// COO Implementation
531impl COOStorage {
532    /// Create COO storage from triplets.
533    pub fn from_triplets(triplets: Vec<(usize, usize, Precision)>) -> Result<Self> {
534        let mut row_indices = Vec::new();
535        let mut col_indices = Vec::new();
536        let mut values = Vec::new();
537
538        for (row, col, value) in triplets {
539            if value != 0.0 {
540                row_indices.push(row as IndexType);
541                col_indices.push(col as IndexType);
542                values.push(value);
543            }
544        }
545
546        Ok(Self {
547            row_indices,
548            col_indices,
549            values,
550        })
551    }
552
553    /// Check if the storage is empty.
554    pub fn is_empty(&self) -> bool {
555        self.values.is_empty()
556    }
557
558    /// Get matrix element at (row, col) - O(n) search.
559    pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
560        for i in 0..self.values.len() {
561            if self.row_indices[i] as usize == row && self.col_indices[i] as usize == col {
562                return Some(self.values[i]);
563            }
564        }
565        None
566    }
567
568    /// Iterate over non-zero elements in a row.
569    pub fn row_iter(&self, row: usize) -> COORowIter<'_> {
570        COORowIter {
571            storage: self,
572            target_row: row as IndexType,
573            pos: 0,
574        }
575    }
576
577    /// Iterate over non-zero elements in a column.
578    pub fn col_iter(&self, col: usize) -> COOColIter<'_> {
579        COOColIter {
580            storage: self,
581            target_col: col as IndexType,
582            pos: 0,
583        }
584    }
585
586    /// Matrix-vector multiplication: result = A * x
587    pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
588        result.fill(0.0);
589        self.multiply_vector_add(x, result);
590    }
591
592    /// Matrix-vector multiplication with accumulation: result += A * x
593    pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
594        for i in 0..self.values.len() {
595            let row = self.row_indices[i] as usize;
596            let col = self.col_indices[i] as usize;
597            result[row] += self.values[i] * x[col];
598        }
599    }
600
601    /// Get number of non-zero elements.
602    pub fn nnz(&self) -> usize {
603        self.values.len()
604    }
605
606    /// Extract as coordinate triplets.
607    pub fn to_triplets(&self) -> Vec<(usize, usize, Precision)> {
608        self.row_indices
609            .iter()
610            .zip(&self.col_indices)
611            .zip(&self.values)
612            .map(|((&r, &c), &v)| (r as usize, c as usize, v))
613            .collect()
614    }
615
616    /// Scale all values by a factor.
617    pub fn scale(&mut self, factor: Precision) {
618        for value in &mut self.values {
619            *value *= factor;
620        }
621    }
622
623    /// Add a value to the diagonal.
624    pub fn add_diagonal(&mut self, alpha: Precision, rows: DimensionType) {
625        // For COO, we'd need to add new diagonal entries if they don't exist
626        // This is a simplified implementation that only modifies existing diagonal entries
627        for i in 0..self.values.len() {
628            if self.row_indices[i] == self.col_indices[i] {
629                self.values[i] += alpha;
630            }
631        }
632    }
633}
634
/// Iterator over non-zero elements in a COO row.
///
/// Linearly scans all stored entries from `pos` onward and yields
/// `(column, value)` for those whose row equals `target_row`.
pub struct COORowIter<'a> {
    /// Storage being scanned.
    storage: &'a COOStorage,
    /// Row whose entries are yielded.
    target_row: IndexType,
    /// Index of the next stored entry to inspect.
    pos: usize,
}
641
642impl<'a> Iterator for COORowIter<'a> {
643    type Item = (IndexType, Precision);
644
645    fn next(&mut self) -> Option<Self::Item> {
646        while self.pos < self.storage.values.len() {
647            if self.storage.row_indices[self.pos] == self.target_row {
648                let col = self.storage.col_indices[self.pos];
649                let val = self.storage.values[self.pos];
650                self.pos += 1;
651                return Some((col, val));
652            }
653            self.pos += 1;
654        }
655        None
656    }
657}
658
/// Iterator over non-zero elements in a COO column.
///
/// Linearly scans all stored entries from `pos` onward and yields
/// `(row, value)` for those whose column equals `target_col`.
pub struct COOColIter<'a> {
    /// Storage being scanned.
    storage: &'a COOStorage,
    /// Column whose entries are yielded.
    target_col: IndexType,
    /// Index of the next stored entry to inspect.
    pos: usize,
}
665
666impl<'a> Iterator for COOColIter<'a> {
667    type Item = (IndexType, Precision);
668
669    fn next(&mut self) -> Option<Self::Item> {
670        while self.pos < self.storage.values.len() {
671            if self.storage.col_indices[self.pos] == self.target_col {
672                let row = self.storage.row_indices[self.pos];
673                let val = self.storage.values[self.pos];
674                self.pos += 1;
675                return Some((row, val));
676            }
677            self.pos += 1;
678        }
679        None
680    }
681}
682
683// Graph Implementation
684impl GraphStorage {
685    /// Create graph storage from triplets.
686    pub fn from_triplets(
687        triplets: Vec<(usize, usize, Precision)>,
688        nodes: DimensionType,
689    ) -> Result<Self> {
690        let mut out_edges = vec![Vec::new(); nodes];
691        let mut in_edges = vec![Vec::new(); nodes];
692        let mut degrees = vec![0.0; nodes];
693
694        for (row, col, weight) in triplets {
695            if weight != 0.0 && row < nodes && col < nodes {
696                out_edges[row].push(GraphEdge {
697                    target: col as NodeId,
698                    weight,
699                });
700
701                if row != col {
702                    // Don't double-count self-loops for in_edges
703                    in_edges[col].push(GraphEdge {
704                        target: row as NodeId,
705                        weight,
706                    });
707                }
708
709                degrees[row] += weight.abs();
710            }
711        }
712
713        Ok(Self {
714            out_edges,
715            in_edges,
716            degrees,
717        })
718    }
719
720    /// Get matrix element at (row, col).
721    pub fn get(&self, row: usize, col: usize) -> Option<Precision> {
722        if row >= self.out_edges.len() {
723            return None;
724        }
725
726        for edge in &self.out_edges[row] {
727            if edge.target as usize == col {
728                return Some(edge.weight);
729            }
730        }
731        None
732    }
733
734    /// Iterate over non-zero elements in a row.
735    pub fn row_iter(&self, row: usize) -> GraphRowIter<'_> {
736        if row >= self.out_edges.len() {
737            GraphRowIter { edges: &[], pos: 0 }
738        } else {
739            GraphRowIter {
740                edges: &self.out_edges[row],
741                pos: 0,
742            }
743        }
744    }
745
746    /// Iterate over non-zero elements in a column.
747    pub fn col_iter(&self, col: usize) -> GraphColIter<'_> {
748        if col >= self.in_edges.len() {
749            GraphColIter { edges: &[], pos: 0 }
750        } else {
751            GraphColIter {
752                edges: &self.in_edges[col],
753                pos: 0,
754            }
755        }
756    }
757
758    /// Matrix-vector multiplication: result = A * x
759    pub fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) {
760        result.fill(0.0);
761        self.multiply_vector_add(x, result);
762    }
763
764    /// Matrix-vector multiplication with accumulation: result += A * x
765    pub fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) {
766        for (row, edges) in self.out_edges.iter().enumerate() {
767            for edge in edges {
768                let col = edge.target as usize;
769                if col < x.len() {
770                    result[row] += edge.weight * x[col];
771                }
772            }
773        }
774    }
775
776    /// Get number of non-zero elements.
777    pub fn nnz(&self) -> usize {
778        self.out_edges.iter().map(|edges| edges.len()).sum()
779    }
780
781    /// Extract as coordinate triplets.
782    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
783        let mut triplets = Vec::new();
784
785        for (row, edges) in self.out_edges.iter().enumerate() {
786            for edge in edges {
787                triplets.push((row, edge.target as usize, edge.weight));
788            }
789        }
790
791        Ok(triplets)
792    }
793
794    /// Scale all edge weights by a factor.
795    pub fn scale(&mut self, factor: Precision) {
796        for edges in &mut self.out_edges {
797            for edge in edges {
798                edge.weight *= factor;
799            }
800        }
801
802        for edges in &mut self.in_edges {
803            for edge in edges {
804                edge.weight *= factor;
805            }
806        }
807
808        for degree in &mut self.degrees {
809            *degree *= factor.abs();
810        }
811    }
812
813    /// Add a value to the diagonal.
814    pub fn add_diagonal(&mut self, alpha: Precision) {
815        for (node, edges) in self.out_edges.iter_mut().enumerate() {
816            // Look for self-loop
817            let mut found = false;
818            for edge in edges.iter_mut() {
819                if edge.target as usize == node {
820                    edge.weight += alpha;
821                    found = true;
822                    break;
823                }
824            }
825
826            // Add self-loop if it doesn't exist
827            if !found && alpha != 0.0 {
828                edges.push(GraphEdge {
829                    target: node as NodeId,
830                    weight: alpha,
831                });
832            }
833
834            // Update degree
835            self.degrees[node] += alpha.abs();
836        }
837    }
838
839    /// Get outgoing edges for a node.
840    pub fn out_neighbors(&self, node: usize) -> &[GraphEdge] {
841        if node < self.out_edges.len() {
842            &self.out_edges[node]
843        } else {
844            &[]
845        }
846    }
847
848    /// Get incoming edges for a node.
849    pub fn in_neighbors(&self, node: usize) -> &[GraphEdge] {
850        if node < self.in_edges.len() {
851            &self.in_edges[node]
852        } else {
853            &[]
854        }
855    }
856
857    /// Get node degree.
858    pub fn degree(&self, node: usize) -> Precision {
859        if node < self.degrees.len() {
860            self.degrees[node]
861        } else {
862            0.0
863        }
864    }
865}
866
/// Iterator over non-zero elements in a graph row.
///
/// Yields `(target, weight)` for each out-edge of one node, in stored order;
/// `target` is the column index.
pub struct GraphRowIter<'a> {
    /// Out-edges of the node being iterated.
    edges: &'a [GraphEdge],
    /// Index of the next edge to yield.
    pos: usize,
}
872
873impl<'a> Iterator for GraphRowIter<'a> {
874    type Item = (IndexType, Precision);
875
876    fn next(&mut self) -> Option<Self::Item> {
877        if self.pos < self.edges.len() {
878            let edge = self.edges[self.pos];
879            self.pos += 1;
880            Some((edge.target, edge.weight))
881        } else {
882            None
883        }
884    }
885}
886
/// Iterator over non-zero elements in a graph column.
///
/// Yields `(target, weight)` for each in-edge of one node, in stored order;
/// here `target` is the source (row) index of the edge.
pub struct GraphColIter<'a> {
    /// In-edges of the node being iterated.
    edges: &'a [GraphEdge],
    /// Index of the next edge to yield.
    pos: usize,
}
892
893impl<'a> Iterator for GraphColIter<'a> {
894    type Item = (IndexType, Precision);
895
896    fn next(&mut self) -> Option<Self::Item> {
897        if self.pos < self.edges.len() {
898            let edge = self.edges[self.pos];
899            self.pos += 1;
900            Some((edge.target, edge.weight))
901        } else {
902            None
903        }
904    }
905}
906
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Build a CSR matrix of the given shape by way of COO.
    fn csr_from(
        entries: Vec<(usize, usize, Precision)>,
        rows: DimensionType,
        cols: DimensionType,
    ) -> CSRStorage {
        let coo = COOStorage::from_triplets(entries).unwrap();
        CSRStorage::from_coo(&coo, rows, cols).unwrap()
    }

    #[test]
    fn test_csr_creation() {
        let csr = csr_from(
            vec![
                (0, 0, 1.0),
                (0, 2, 2.0),
                (1, 1, 3.0),
                (2, 0, 4.0),
                (2, 2, 5.0),
            ],
            3,
            3,
        );

        assert_eq!(csr.nnz(), 5);

        let probes = [
            ((0, 0), Some(1.0)),
            ((0, 2), Some(2.0)),
            ((1, 1), Some(3.0)),
            ((0, 1), None),
        ];
        for &((r, c), want) in &probes {
            assert_eq!(csr.get(r, c), want, "entry ({}, {})", r, c);
        }
    }

    #[test]
    fn test_csr_matrix_vector_multiply() {
        let csr = csr_from(
            vec![(0, 0, 2.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)],
            2,
            2,
        );

        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 2];
        csr.multiply_vector(&input, &mut output);

        // Row 0: 2*1 + 1*2 = 4; row 1: 1*1 + 3*2 = 7.
        assert_eq!(output, vec![4.0, 7.0]);
    }

    #[test]
    fn test_graph_storage() {
        let edges = vec![(0, 1, 0.5), (1, 0, 0.3), (1, 2, 0.7), (2, 1, 0.2)];
        let graph = GraphStorage::from_triplets(edges, 3).unwrap();

        assert_eq!(graph.nnz(), 4);
        assert_eq!(graph.out_neighbors(1).len(), 2);
        assert_eq!(graph.in_neighbors(1).len(), 2);
        assert!(graph.degree(1) > 0.0);
    }

    #[test]
    fn test_format_conversions() {
        let source = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)];

        // COO -> CSR -> CSC -> triplets roundtrip.
        let csr = csr_from(source.clone(), 2, 3);
        let csc = CSCStorage::from_csr(&csr, 2, 3).unwrap();
        let mut roundtrip = csc.to_triplets().unwrap();
        let mut expected = source;

        // `f64` is not `Ord` (NaN), so order by (row, col) and break any tie
        // with `total_cmp` instead of relying on `PartialOrd`.
        let by_position = |a: &(usize, usize, f64), b: &(usize, usize, f64)| {
            (a.0, a.1)
                .cmp(&(b.0, b.1))
                .then_with(|| a.2.total_cmp(&b.2))
        };
        expected.sort_by(by_position);
        roundtrip.sort_by(by_position);

        assert_eq!(expected, roundtrip);
    }
}