1use crate::error::{Result, SolverError};
8use crate::types::{ConditioningInfo, DimensionType, IndexType, Precision, SparsityInfo};
9use alloc::vec::Vec;
10use core::fmt;
11
12pub mod optimized;
13pub mod sparse;
14
15use sparse::*;
16
17pub use optimized::{BufferPool, OptimizedCSRStorage, StreamingMatrix};
19
20pub trait Matrix: Send + Sync {
26 fn rows(&self) -> DimensionType;
28
29 fn cols(&self) -> DimensionType;
31
32 fn get(&self, row: usize, col: usize) -> Option<Precision>;
34
35 fn row_iter(&self, row: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_>;
38
39 fn col_iter(&self, col: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_>;
42
43 fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) -> Result<()>;
45
46 fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) -> Result<()>;
48
49 fn is_diagonally_dominant(&self) -> bool;
52
53 fn diagonal_dominance_factor(&self) -> Option<Precision>;
55
56 fn nnz(&self) -> usize;
58
59 fn sparsity_info(&self) -> SparsityInfo;
61
62 fn conditioning_info(&self) -> ConditioningInfo;
64
65 fn format_name(&self) -> &'static str;
67
68 fn is_square(&self) -> bool {
70 self.rows() == self.cols()
71 }
72
73 fn frobenius_norm(&self) -> Precision {
75 let mut norm_sq = 0.0;
76 for row in 0..self.rows() {
77 for (_, value) in self.row_iter(row) {
78 norm_sq += value * value;
79 }
80 }
81 norm_sq.sqrt()
82 }
83
84 fn spectral_radius_estimate(&self) -> Precision {
87 let mut max_radius: Precision = 0.0;
88 for row in 0..self.rows() {
89 let mut diagonal = 0.0;
90 let mut off_diagonal_sum = 0.0;
91
92 for (col, value) in self.row_iter(row) {
93 if col as usize == row {
94 diagonal = value.abs();
95 } else {
96 off_diagonal_sum += value.abs();
97 }
98 }
99
100 max_radius = max_radius.max(diagonal + off_diagonal_sum);
101 }
102 max_radius
103 }
104}
105
/// Storage layouts a [`SparseMatrix`] can hold.
///
/// `Hash` is derived alongside `Eq` so formats can be used as keys in hash
/// maps and sets. Variant order is significant (derived discriminants and
/// positional serde encodings), so new variants should be appended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SparseFormat {
    /// Row-compressed layout.
    CSR,
    /// Column-compressed layout.
    CSC,
    /// Coordinate (triplet) layout.
    COO,
    /// Graph adjacency layout.
    GraphAdjacency,
}
119
120#[derive(Debug, Clone)]
122#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
123pub struct SparseMatrix {
124 format: SparseFormat,
126 rows: DimensionType,
128 cols: DimensionType,
129 storage: SparseStorage,
131}
132
133#[derive(Debug, Clone)]
135#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
136enum SparseStorage {
137 CSR(CSRStorage),
138 CSC(CSCStorage),
139 COO(COOStorage),
140 Graph(GraphStorage),
141}
142
143impl SparseMatrix {
144 pub fn from_triplets(
161 triplets: Vec<(usize, usize, Precision)>,
162 rows: DimensionType,
163 cols: DimensionType,
164 ) -> Result<Self> {
165 for &(r, c, v) in &triplets {
167 if r >= rows {
168 return Err(SolverError::IndexOutOfBounds {
169 index: r,
170 max_index: rows - 1,
171 context: "row index in triplet".to_string(),
172 });
173 }
174 if c >= cols {
175 return Err(SolverError::IndexOutOfBounds {
176 index: c,
177 max_index: cols - 1,
178 context: "column index in triplet".to_string(),
179 });
180 }
181 if !v.is_finite() {
182 return Err(SolverError::InvalidInput {
183 message: format!("Non-finite value {} at ({}, {})", v, r, c),
184 parameter: Some("matrix_element".to_string()),
185 });
186 }
187 }
188
189 let coo_storage = COOStorage::from_triplets(triplets)?;
191 let csr_storage = CSRStorage::from_coo(&coo_storage, rows, cols)?;
192
193 Ok(Self {
194 format: SparseFormat::CSR,
195 rows,
196 cols,
197 storage: SparseStorage::CSR(csr_storage),
198 })
199 }
200
201 pub fn from_dense(
205 data: &[Precision],
206 rows: DimensionType,
207 cols: DimensionType,
208 ) -> Result<Self> {
209 if data.len() != rows * cols {
210 return Err(SolverError::DimensionMismatch {
211 expected: rows * cols,
212 actual: data.len(),
213 operation: "dense_to_sparse_conversion".to_string(),
214 });
215 }
216
217 let mut triplets = Vec::new();
218 for (i, &value) in data.iter().enumerate() {
219 if value != 0.0 {
220 let row = i / cols;
221 let col = i % cols;
222 triplets.push((row, col, value));
223 }
224 }
225
226 Self::from_triplets(triplets, rows, cols)
227 }
228
229 pub fn identity(size: DimensionType) -> Result<Self> {
231 let triplets: Vec<_> = (0..size).map(|i| (i, i, 1.0)).collect();
232 Self::from_triplets(triplets, size, size)
233 }
234
235 pub fn diagonal(diag: &[Precision]) -> Result<Self> {
237 let size = diag.len();
238 let triplets: Vec<_> = diag
239 .iter()
240 .enumerate()
241 .filter(|(_, &v)| v != 0.0)
242 .map(|(i, &v)| (i, i, v))
243 .collect();
244 Self::from_triplets(triplets, size, size)
245 }
246
247 pub fn convert_to_format(&mut self, new_format: SparseFormat) -> Result<()> {
251 if self.format == new_format {
252 return Ok(());
253 }
254
255 match (self.format, new_format) {
256 (SparseFormat::CSR, SparseFormat::CSC) => {
257 if let SparseStorage::CSR(ref csr) = self.storage {
258 let csc = CSCStorage::from_csr(csr, self.rows, self.cols)?;
259 self.storage = SparseStorage::CSC(csc);
260 self.format = SparseFormat::CSC;
261 }
262 }
263 (SparseFormat::CSC, SparseFormat::CSR) => {
264 if let SparseStorage::CSC(ref csc) = self.storage {
265 let csr = CSRStorage::from_csc(csc, self.rows, self.cols)?;
266 self.storage = SparseStorage::CSR(csr);
267 self.format = SparseFormat::CSR;
268 }
269 }
270 (_, SparseFormat::GraphAdjacency) => {
271 let triplets = self.to_triplets()?;
273 let graph = GraphStorage::from_triplets(triplets, self.rows)?;
274 self.storage = SparseStorage::Graph(graph);
275 self.format = SparseFormat::GraphAdjacency;
276 }
277 _ => {
278 let triplets = self.to_triplets()?;
280 let coo = COOStorage::from_triplets(triplets)?;
281
282 match new_format {
283 SparseFormat::CSR => {
284 let csr = CSRStorage::from_coo(&coo, self.rows, self.cols)?;
285 self.storage = SparseStorage::CSR(csr);
286 }
287 SparseFormat::CSC => {
288 let csc = CSCStorage::from_coo(&coo, self.rows, self.cols)?;
289 self.storage = SparseStorage::CSC(csc);
290 }
291 SparseFormat::COO => {
292 self.storage = SparseStorage::COO(coo);
293 }
294 _ => unreachable!(),
295 }
296 self.format = new_format;
297 }
298 }
299
300 Ok(())
301 }
302
303 pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
305 match &self.storage {
306 SparseStorage::CSR(csr) => csr.to_triplets(),
307 SparseStorage::CSC(csc) => csc.to_triplets(),
308 SparseStorage::COO(coo) => Ok(coo.to_triplets()),
309 SparseStorage::Graph(graph) => graph.to_triplets(),
310 }
311 }
312
313 pub fn format(&self) -> SparseFormat {
315 self.format
316 }
317
318 pub fn as_csr(&mut self) -> Result<&CSRStorage> {
322 self.convert_to_format(SparseFormat::CSR)?;
323 match &self.storage {
324 SparseStorage::CSR(csr) => Ok(csr),
325 _ => unreachable!(),
326 }
327 }
328
329 pub fn as_csc(&mut self) -> Result<&CSCStorage> {
333 self.convert_to_format(SparseFormat::CSC)?;
334 match &self.storage {
335 SparseStorage::CSC(csc) => Ok(csc),
336 _ => unreachable!(),
337 }
338 }
339
340 pub fn as_graph(&mut self) -> Result<&GraphStorage> {
344 self.convert_to_format(SparseFormat::GraphAdjacency)?;
345 match &self.storage {
346 SparseStorage::Graph(graph) => Ok(graph),
347 _ => unreachable!(),
348 }
349 }
350
351 pub fn scale(&mut self, factor: Precision) {
353 match &mut self.storage {
354 SparseStorage::CSR(csr) => csr.scale(factor),
355 SparseStorage::CSC(csc) => csc.scale(factor),
356 SparseStorage::COO(coo) => coo.scale(factor),
357 SparseStorage::Graph(graph) => graph.scale(factor),
358 }
359 }
360
361 pub fn add_diagonal(&mut self, alpha: Precision) -> Result<()> {
363 if !self.is_square() {
364 return Err(SolverError::InvalidInput {
365 message: "Cannot add diagonal to non-square matrix".to_string(),
366 parameter: Some("matrix_dimensions".to_string()),
367 });
368 }
369
370 match &mut self.storage {
371 SparseStorage::CSR(csr) => csr.add_diagonal(alpha),
372 SparseStorage::CSC(csc) => csc.add_diagonal(alpha),
373 SparseStorage::COO(coo) => coo.add_diagonal(alpha, self.rows),
374 SparseStorage::Graph(graph) => graph.add_diagonal(alpha),
375 }
376
377 Ok(())
378 }
379}
380
381impl Matrix for SparseMatrix {
382 fn rows(&self) -> DimensionType {
383 self.rows
384 }
385
386 fn cols(&self) -> DimensionType {
387 self.cols
388 }
389
390 fn get(&self, row: usize, col: usize) -> Option<Precision> {
391 if row >= self.rows || col >= self.cols {
392 return None;
393 }
394
395 match &self.storage {
396 SparseStorage::CSR(csr) => csr.get(row, col),
397 SparseStorage::CSC(csc) => csc.get(row, col),
398 SparseStorage::COO(coo) => coo.get(row, col),
399 SparseStorage::Graph(graph) => graph.get(row, col),
400 }
401 }
402
403 fn row_iter(&self, row: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_> {
404 match &self.storage {
405 SparseStorage::CSR(csr) => Box::new(csr.row_iter(row)),
406 SparseStorage::CSC(csc) => Box::new(csc.row_iter(row)),
407 SparseStorage::COO(coo) => Box::new(coo.row_iter(row)),
408 SparseStorage::Graph(graph) => Box::new(graph.row_iter(row)),
409 }
410 }
411
412 fn col_iter(&self, col: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_> {
413 match &self.storage {
414 SparseStorage::CSR(csr) => Box::new(csr.col_iter(col)),
415 SparseStorage::CSC(csc) => Box::new(csc.col_iter(col)),
416 SparseStorage::COO(coo) => Box::new(coo.col_iter(col)),
417 SparseStorage::Graph(graph) => Box::new(graph.col_iter(col)),
418 }
419 }
420
421 fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) -> Result<()> {
422 if x.len() != self.cols {
423 return Err(SolverError::DimensionMismatch {
424 expected: self.cols,
425 actual: x.len(),
426 operation: "matrix_vector_multiply".to_string(),
427 });
428 }
429 if result.len() != self.rows {
430 return Err(SolverError::DimensionMismatch {
431 expected: self.rows,
432 actual: result.len(),
433 operation: "matrix_vector_multiply".to_string(),
434 });
435 }
436
437 match &self.storage {
438 SparseStorage::CSR(csr) => csr.multiply_vector(x, result),
439 SparseStorage::CSC(csc) => csc.multiply_vector(x, result),
440 SparseStorage::COO(coo) => coo.multiply_vector(x, result),
441 SparseStorage::Graph(graph) => graph.multiply_vector(x, result),
442 }
443
444 Ok(())
445 }
446
447 fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) -> Result<()> {
448 if x.len() != self.cols {
449 return Err(SolverError::DimensionMismatch {
450 expected: self.cols,
451 actual: x.len(),
452 operation: "matrix_vector_multiply_add".to_string(),
453 });
454 }
455 if result.len() != self.rows {
456 return Err(SolverError::DimensionMismatch {
457 expected: self.rows,
458 actual: result.len(),
459 operation: "matrix_vector_multiply_add".to_string(),
460 });
461 }
462
463 match &self.storage {
464 SparseStorage::CSR(csr) => csr.multiply_vector_add(x, result),
465 SparseStorage::CSC(csc) => csc.multiply_vector_add(x, result),
466 SparseStorage::COO(coo) => coo.multiply_vector_add(x, result),
467 SparseStorage::Graph(graph) => graph.multiply_vector_add(x, result),
468 }
469
470 Ok(())
471 }
472
473 fn is_diagonally_dominant(&self) -> bool {
474 for row in 0..self.rows {
475 let mut diagonal = 0.0;
476 let mut off_diagonal_sum = 0.0;
477
478 for (col, value) in self.row_iter(row) {
479 if col as usize == row {
480 diagonal = value.abs();
481 } else {
482 off_diagonal_sum += value.abs();
483 }
484 }
485
486 if diagonal < off_diagonal_sum {
487 return false;
488 }
489 }
490 true
491 }
492
493 fn diagonal_dominance_factor(&self) -> Option<Precision> {
494 let mut min_factor = Precision::INFINITY;
495
496 for row in 0..self.rows {
497 let mut diagonal = 0.0;
498 let mut off_diagonal_sum = 0.0;
499
500 for (col, value) in self.row_iter(row) {
501 if col as usize == row {
502 diagonal = value.abs();
503 } else {
504 off_diagonal_sum += value.abs();
505 }
506 }
507
508 if off_diagonal_sum > 0.0 {
509 let factor = diagonal / off_diagonal_sum;
510 min_factor = min_factor.min(factor);
511 }
512 }
513
514 if min_factor.is_finite() {
515 Some(min_factor)
516 } else {
517 None
518 }
519 }
520
521 fn nnz(&self) -> usize {
522 match &self.storage {
523 SparseStorage::CSR(csr) => csr.nnz(),
524 SparseStorage::CSC(csc) => csc.nnz(),
525 SparseStorage::COO(coo) => coo.nnz(),
526 SparseStorage::Graph(graph) => graph.nnz(),
527 }
528 }
529
530 fn sparsity_info(&self) -> SparsityInfo {
531 let mut info = SparsityInfo::new(self.nnz(), self.rows, self.cols);
532
533 let mut max_nnz_per_row = 0;
535 for row in 0..self.rows {
536 let row_nnz = self.row_iter(row).count();
537 max_nnz_per_row = max_nnz_per_row.max(row_nnz);
538 }
539 info.max_nnz_per_row = max_nnz_per_row;
540
541 let mut max_bandwidth = 0;
543 for (r, c, _) in self.to_triplets().unwrap_or_default() {
544 let bandwidth = if r > c { r - c } else { c - r };
545 max_bandwidth = max_bandwidth.max(bandwidth);
546 }
547 info.bandwidth = Some(max_bandwidth);
548 info.is_banded = max_bandwidth < self.rows / 4; info
551 }
552
553 fn conditioning_info(&self) -> ConditioningInfo {
554 ConditioningInfo {
555 condition_number: None, is_diagonally_dominant: self.is_diagonally_dominant(),
557 diagonal_dominance_factor: self.diagonal_dominance_factor(),
558 spectral_radius: Some(self.spectral_radius_estimate()),
559 is_positive_definite: None, }
561 }
562
563 fn format_name(&self) -> &'static str {
564 match self.format {
565 SparseFormat::CSR => "CSR",
566 SparseFormat::CSC => "CSC",
567 SparseFormat::COO => "COO",
568 SparseFormat::GraphAdjacency => "GraphAdjacency",
569 }
570 }
571}
572
573impl fmt::Display for SparseMatrix {
574 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
575 write!(
576 f,
577 "{}x{} sparse matrix ({} format, {} nnz)",
578 self.rows,
579 self.cols,
580 self.format_name(),
581 self.nnz()
582 )
583 }
584}
585
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_matrix_creation() {
        let triplets = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 2.0), (1, 1, 5.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();

        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 2);
        assert_eq!(matrix.nnz(), 4);
        assert!(matrix.is_diagonally_dominant());
    }

    #[test]
    fn test_from_triplets_rejects_invalid_input() {
        // Row and column indices must be in range.
        assert!(SparseMatrix::from_triplets(vec![(2, 0, 1.0)], 2, 2).is_err());
        assert!(SparseMatrix::from_triplets(vec![(0, 2, 1.0)], 2, 2).is_err());
        // Values must be finite.
        assert!(SparseMatrix::from_triplets(vec![(0, 0, Precision::NAN)], 2, 2).is_err());
        assert!(SparseMatrix::from_triplets(vec![(0, 0, Precision::INFINITY)], 2, 2).is_err());
    }

    #[test]
    fn test_from_dense_skips_zeros_and_checks_length() {
        let matrix = SparseMatrix::from_dense(&[1.0, 0.0, 0.0, 2.0], 2, 2).unwrap();
        assert_eq!(matrix.nnz(), 2);
        assert_eq!(matrix.get(0, 0), Some(1.0));
        assert_eq!(matrix.get(1, 1), Some(2.0));

        assert!(SparseMatrix::from_dense(&[1.0, 2.0, 3.0], 2, 2).is_err());
    }

    #[test]
    fn test_identity_and_diagonal() {
        let eye = SparseMatrix::identity(3).unwrap();
        assert_eq!(eye.nnz(), 3);
        assert_eq!(eye.get(2, 2), Some(1.0));
        assert_eq!(eye.get(3, 0), None);
        assert!(eye.is_diagonally_dominant());

        let diag = SparseMatrix::diagonal(&[2.0, 0.0, 3.0]).unwrap();
        assert_eq!(diag.rows(), 3);
        assert_eq!(diag.cols(), 3);
        assert_eq!(diag.nnz(), 2);
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let triplets = vec![(0, 0, 2.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();

        let x = vec![1.0, 2.0];
        let mut result = vec![0.0; 2];

        matrix.multiply_vector(&x, &mut result).unwrap();

        assert_eq!(result, vec![4.0, 7.0]);
    }

    #[test]
    fn test_matrix_vector_multiply_dimension_errors() {
        let matrix = SparseMatrix::identity(2).unwrap();
        let mut result = vec![0.0; 2];
        assert!(matrix.multiply_vector(&[1.0], &mut result).is_err());

        let mut short = vec![0.0; 1];
        assert!(matrix.multiply_vector(&[1.0, 2.0], &mut short).is_err());
        assert!(matrix.multiply_vector_add(&[1.0, 2.0], &mut short).is_err());
    }

    #[test]
    fn test_diagonal_dominance() {
        let triplets = vec![(0, 0, 5.0), (0, 1, 1.0), (1, 0, 2.0), (1, 1, 7.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();
        assert!(matrix.is_diagonally_dominant());

        let triplets = vec![(0, 0, 1.0), (0, 1, 3.0), (1, 0, 2.0), (1, 1, 2.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();
        assert!(!matrix.is_diagonally_dominant());
    }

    #[test]
    fn test_format_conversion() {
        let triplets = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)];
        let mut matrix = SparseMatrix::from_triplets(triplets, 2, 3).unwrap();

        assert_eq!(matrix.format(), SparseFormat::CSR);

        matrix.convert_to_format(SparseFormat::CSC).unwrap();
        assert_eq!(matrix.format(), SparseFormat::CSC);

        matrix
            .convert_to_format(SparseFormat::GraphAdjacency)
            .unwrap();
        assert_eq!(matrix.format(), SparseFormat::GraphAdjacency);
    }

    #[test]
    fn test_format_round_trip_preserves_entries() {
        let triplets = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)];
        let mut matrix = SparseMatrix::from_triplets(triplets.clone(), 2, 3).unwrap();

        matrix.convert_to_format(SparseFormat::COO).unwrap();
        matrix.convert_to_format(SparseFormat::CSC).unwrap();
        matrix.convert_to_format(SparseFormat::CSR).unwrap();
        assert_eq!(matrix.format(), SparseFormat::CSR);

        // Storage may emit entries in any order; compare them sorted.
        let mut back = matrix.to_triplets().unwrap();
        back.sort_by(|a, b| (a.0, a.1).cmp(&(b.0, b.1)));
        assert_eq!(back, triplets);
    }

    #[test]
    fn test_display() {
        let eye = SparseMatrix::identity(2).unwrap();
        assert_eq!(eye.to_string(), "2x2 sparse matrix (CSR format, 2 nnz)");
    }
}