//! Matrix abstraction and sparse storage formats used by the solvers.

use crate::types::{Precision, DimensionType, IndexType, SparsityInfo, ConditioningInfo};
use crate::error::{SolverError, Result};
use alloc::{boxed::Box, format, string::{String, ToString}, vec::Vec};
use core::fmt;

pub mod sparse;
pub mod optimized;

use self::sparse::*;

pub use self::optimized::{OptimizedCSRStorage, BufferPool, StreamingMatrix};

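/// Common interface over matrix storage formats: element access, row/column
/// iteration, matrix-vector products, and structural/conditioning diagnostics.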
pub trait Matrix: Send + Sync {
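    /// Number of rows.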
    fn rows(&self) -> DimensionType;

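    /// Number of columns.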
    fn cols(&self) -> DimensionType;

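    /// Looks up the entry at `(row, col)`; returns `None` for out-of-bounds indices.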
    fn get(&self, row: usize, col: usize) -> Option<Precision>;

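    /// Iterates over the stored entries of `row` as `(column_index, value)` pairs.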
    fn row_iter(&self, row: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_>;

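    /// Iterates over the stored entries of `col` as `(row_index, value)` pairs.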
    fn col_iter(&self, col: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_>;

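    /// Computes the matrix-vector product, writing `A * x` into `result`.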
    fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) -> Result<()>;

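    /// Accumulates the matrix-vector product `A * x` into `result`.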
    fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) -> Result<()>;

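    /// Returns `true` if, in every row, the absolute diagonal entry is at least
    /// the sum of the absolute off-diagonal entries.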
    fn is_diagonally_dominant(&self) -> bool;

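    /// Smallest ratio of the absolute diagonal entry to the row's off-diagonal
    /// absolute sum, or `None` if no row has off-diagonal entries.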
    fn diagonal_dominance_factor(&self) -> Option<Precision>;

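    /// Number of explicitly stored entries.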
    fn nnz(&self) -> usize;

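    /// Structural statistics such as per-row fill and bandwidth.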
    fn sparsity_info(&self) -> SparsityInfo;

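    /// Conditioning indicators such as diagonal dominance and an estimated spectral radius.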
    fn conditioning_info(&self) -> ConditioningInfo;

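    /// Short name of the underlying storage format (e.g. "CSR").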
    fn format_name(&self) -> &'static str;

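    /// Returns `true` if the matrix has the same number of rows and columns.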
    fn is_square(&self) -> bool {
        self.rows() == self.cols()
    }

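    /// Frobenius norm: the square root of the sum of squared entries.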
    fn frobenius_norm(&self) -> Precision {
        let mut norm_sq = 0.0;
        for row in 0..self.rows() {
            for (_, value) in self.row_iter(row) {
                norm_sq += value * value;
            }
        }
        norm_sq.sqrt()
    }

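    /// Upper bound on the spectral radius via the Gershgorin circle theorem
    /// (maximum absolute row sum).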
    fn spectral_radius_estimate(&self) -> Precision {
        let mut max_radius: Precision = 0.0;
        for row in 0..self.rows() {
            let mut diagonal = 0.0;
            let mut off_diagonal_sum = 0.0;

            for (col, value) in self.row_iter(row) {
                if col as usize == row {
                    diagonal = value.abs();
                } else {
                    off_diagonal_sum += value.abs();
                }
            }

            max_radius = max_radius.max(diagonal + off_diagonal_sum);
        }
        max_radius
    }
}

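/// Supported sparse storage formats.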
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SparseFormat {
    /// Compressed Sparse Row.
    CSR,
    /// Compressed Sparse Column.
    CSC,
    /// Coordinate (triplet) list.
    COO,
    /// Adjacency representation of the matrix viewed as a weighted graph.
    GraphAdjacency,
}

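/// Sparse matrix with a runtime-convertible storage format.
///
/// New matrices are built in CSR form (see [`SparseMatrix::from_triplets`]) and
/// can be converted between formats with [`SparseMatrix::convert_to_format`].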
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SparseMatrix {
    /// Active storage format.
    format: SparseFormat,
    /// Number of rows.
    rows: DimensionType,
    /// Number of columns.
    cols: DimensionType,
    /// Format-specific storage backend.
    storage: SparseStorage,
}

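/// Format-specific backing storage for [`SparseMatrix`].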
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum SparseStorage {
    CSR(CSRStorage),
    CSC(CSCStorage),
    COO(COOStorage),
    Graph(GraphStorage),
}

impl SparseMatrix {
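    /// Builds a matrix in CSR form from `(row, column, value)` triplets.
    ///
    /// Returns an error if any index is out of bounds or any value is non-finite.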
    pub fn from_triplets(
        triplets: Vec<(usize, usize, Precision)>,
        rows: DimensionType,
        cols: DimensionType,
    ) -> Result<Self> {
        for &(r, c, v) in &triplets {
            if r >= rows {
                return Err(SolverError::IndexOutOfBounds {
                    index: r,
                    max_index: rows - 1,
                    context: "row index in triplet".to_string(),
                });
            }
            if c >= cols {
                return Err(SolverError::IndexOutOfBounds {
                    index: c,
                    max_index: cols - 1,
                    context: "column index in triplet".to_string(),
                });
            }
            if !v.is_finite() {
                return Err(SolverError::InvalidInput {
                    message: format!("Non-finite value {} at ({}, {})", v, r, c),
                    parameter: Some("matrix_element".to_string()),
                });
            }
        }

        let coo_storage = COOStorage::from_triplets(triplets)?;
        let csr_storage = CSRStorage::from_coo(&coo_storage, rows, cols)?;

        Ok(Self {
            format: SparseFormat::CSR,
            rows,
            cols,
            storage: SparseStorage::CSR(csr_storage),
        })
    }

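    /// Converts a dense row-major slice into a sparse matrix, dropping exact zeros.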
    pub fn from_dense(data: &[Precision], rows: DimensionType, cols: DimensionType) -> Result<Self> {
        if data.len() != rows * cols {
            return Err(SolverError::DimensionMismatch {
                expected: rows * cols,
                actual: data.len(),
                operation: "dense_to_sparse_conversion".to_string(),
            });
        }

        let mut triplets = Vec::new();
        for (i, &value) in data.iter().enumerate() {
            if value != 0.0 {
                let row = i / cols;
                let col = i % cols;
                triplets.push((row, col, value));
            }
        }

        Self::from_triplets(triplets, rows, cols)
    }

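    /// Creates the `size` x `size` identity matrix.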
    pub fn identity(size: DimensionType) -> Result<Self> {
        let triplets: Vec<_> = (0..size).map(|i| (i, i, 1.0)).collect();
        Self::from_triplets(triplets, size, size)
    }

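    /// Creates a square diagonal matrix from `diag`, dropping exact zeros.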
    pub fn diagonal(diag: &[Precision]) -> Result<Self> {
        let size = diag.len();
        let triplets: Vec<_> = diag.iter().enumerate()
            .filter(|(_, &v)| v != 0.0)
            .map(|(i, &v)| (i, i, v))
            .collect();
        Self::from_triplets(triplets, size, size)
    }

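    /// Converts the matrix to `new_format` in place.
    ///
    /// CSR <-> CSC conversions are performed directly; all other conversions
    /// round-trip through a triplet representation.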
    pub fn convert_to_format(&mut self, new_format: SparseFormat) -> Result<()> {
        if self.format == new_format {
            return Ok(());
        }

        match (self.format, new_format) {
            (SparseFormat::CSR, SparseFormat::CSC) => {
                if let SparseStorage::CSR(ref csr) = self.storage {
                    let csc = CSCStorage::from_csr(csr, self.rows, self.cols)?;
                    self.storage = SparseStorage::CSC(csc);
                    self.format = SparseFormat::CSC;
                }
            },
            (SparseFormat::CSC, SparseFormat::CSR) => {
                if let SparseStorage::CSC(ref csc) = self.storage {
                    let csr = CSRStorage::from_csc(csc, self.rows, self.cols)?;
                    self.storage = SparseStorage::CSR(csr);
                    self.format = SparseFormat::CSR;
                }
            },
            (_, SparseFormat::GraphAdjacency) => {
                let triplets = self.to_triplets()?;
                let graph = GraphStorage::from_triplets(triplets, self.rows)?;
                self.storage = SparseStorage::Graph(graph);
                self.format = SparseFormat::GraphAdjacency;
            },
            _ => {
                let triplets = self.to_triplets()?;
                let coo = COOStorage::from_triplets(triplets)?;

                match new_format {
                    SparseFormat::CSR => {
                        let csr = CSRStorage::from_coo(&coo, self.rows, self.cols)?;
                        self.storage = SparseStorage::CSR(csr);
                    },
                    SparseFormat::CSC => {
                        let csc = CSCStorage::from_coo(&coo, self.rows, self.cols)?;
                        self.storage = SparseStorage::CSC(csc);
                    },
                    SparseFormat::COO => {
                        self.storage = SparseStorage::COO(coo);
                    },
                    _ => unreachable!(),
                }
                self.format = new_format;
            }
        }

        Ok(())
    }

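    /// Exports all stored entries as `(row, column, value)` triplets.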
    pub fn to_triplets(&self) -> Result<Vec<(usize, usize, Precision)>> {
        match &self.storage {
            SparseStorage::CSR(csr) => csr.to_triplets(),
            SparseStorage::CSC(csc) => csc.to_triplets(),
            SparseStorage::COO(coo) => Ok(coo.to_triplets()),
            SparseStorage::Graph(graph) => graph.to_triplets(),
        }
    }

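    /// Returns the current storage format.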
    pub fn format(&self) -> SparseFormat {
        self.format
    }

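    /// Converts to CSR if necessary and returns a reference to the CSR storage.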
    pub fn as_csr(&mut self) -> Result<&CSRStorage> {
        self.convert_to_format(SparseFormat::CSR)?;
        match &self.storage {
            SparseStorage::CSR(csr) => Ok(csr),
            _ => unreachable!(),
        }
    }

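    /// Converts to CSC if necessary and returns a reference to the CSC storage.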
    pub fn as_csc(&mut self) -> Result<&CSCStorage> {
        self.convert_to_format(SparseFormat::CSC)?;
        match &self.storage {
            SparseStorage::CSC(csc) => Ok(csc),
            _ => unreachable!(),
        }
    }

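    /// Converts to graph-adjacency form if necessary and returns a reference to it.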
    pub fn as_graph(&mut self) -> Result<&GraphStorage> {
        self.convert_to_format(SparseFormat::GraphAdjacency)?;
        match &self.storage {
            SparseStorage::Graph(graph) => Ok(graph),
            _ => unreachable!(),
        }
    }

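    /// Multiplies every stored entry by `factor`.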
    pub fn scale(&mut self, factor: Precision) {
        match &mut self.storage {
            SparseStorage::CSR(csr) => csr.scale(factor),
            SparseStorage::CSC(csc) => csc.scale(factor),
            SparseStorage::COO(coo) => coo.scale(factor),
            SparseStorage::Graph(graph) => graph.scale(factor),
        }
    }

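    /// Adds `alpha` to each diagonal entry; fails if the matrix is not square.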
    pub fn add_diagonal(&mut self, alpha: Precision) -> Result<()> {
        if !self.is_square() {
            return Err(SolverError::InvalidInput {
                message: "Cannot add diagonal to non-square matrix".to_string(),
                parameter: Some("matrix_dimensions".to_string()),
            });
        }

        match &mut self.storage {
            SparseStorage::CSR(csr) => csr.add_diagonal(alpha),
            SparseStorage::CSC(csc) => csc.add_diagonal(alpha),
            SparseStorage::COO(coo) => coo.add_diagonal(alpha, self.rows),
            SparseStorage::Graph(graph) => graph.add_diagonal(alpha),
        }

        Ok(())
    }
}

impl Matrix for SparseMatrix {
    fn rows(&self) -> DimensionType {
        self.rows
    }

    fn cols(&self) -> DimensionType {
        self.cols
    }

    fn get(&self, row: usize, col: usize) -> Option<Precision> {
        if row >= self.rows || col >= self.cols {
            return None;
        }

        match &self.storage {
            SparseStorage::CSR(csr) => csr.get(row, col),
            SparseStorage::CSC(csc) => csc.get(row, col),
            SparseStorage::COO(coo) => coo.get(row, col),
            SparseStorage::Graph(graph) => graph.get(row, col),
        }
    }

    fn row_iter(&self, row: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_> {
        match &self.storage {
            SparseStorage::CSR(csr) => Box::new(csr.row_iter(row)),
            SparseStorage::CSC(csc) => Box::new(csc.row_iter(row)),
            SparseStorage::COO(coo) => Box::new(coo.row_iter(row)),
            SparseStorage::Graph(graph) => Box::new(graph.row_iter(row)),
        }
    }

    fn col_iter(&self, col: usize) -> Box<dyn Iterator<Item = (IndexType, Precision)> + '_> {
        match &self.storage {
            SparseStorage::CSR(csr) => Box::new(csr.col_iter(col)),
            SparseStorage::CSC(csc) => Box::new(csc.col_iter(col)),
            SparseStorage::COO(coo) => Box::new(coo.col_iter(col)),
            SparseStorage::Graph(graph) => Box::new(graph.col_iter(col)),
        }
    }

    fn multiply_vector(&self, x: &[Precision], result: &mut [Precision]) -> Result<()> {
        if x.len() != self.cols {
            return Err(SolverError::DimensionMismatch {
                expected: self.cols,
                actual: x.len(),
                operation: "matrix_vector_multiply".to_string(),
            });
        }
        if result.len() != self.rows {
            return Err(SolverError::DimensionMismatch {
                expected: self.rows,
                actual: result.len(),
                operation: "matrix_vector_multiply".to_string(),
            });
        }

        match &self.storage {
            SparseStorage::CSR(csr) => csr.multiply_vector(x, result),
            SparseStorage::CSC(csc) => csc.multiply_vector(x, result),
            SparseStorage::COO(coo) => coo.multiply_vector(x, result),
            SparseStorage::Graph(graph) => graph.multiply_vector(x, result),
        }

        Ok(())
    }

    fn multiply_vector_add(&self, x: &[Precision], result: &mut [Precision]) -> Result<()> {
        if x.len() != self.cols {
            return Err(SolverError::DimensionMismatch {
                expected: self.cols,
                actual: x.len(),
                operation: "matrix_vector_multiply_add".to_string(),
            });
        }
        if result.len() != self.rows {
            return Err(SolverError::DimensionMismatch {
                expected: self.rows,
                actual: result.len(),
                operation: "matrix_vector_multiply_add".to_string(),
            });
        }

        match &self.storage {
            SparseStorage::CSR(csr) => csr.multiply_vector_add(x, result),
            SparseStorage::CSC(csc) => csc.multiply_vector_add(x, result),
            SparseStorage::COO(coo) => coo.multiply_vector_add(x, result),
            SparseStorage::Graph(graph) => graph.multiply_vector_add(x, result),
        }

        Ok(())
    }

    fn is_diagonally_dominant(&self) -> bool {
        for row in 0..self.rows {
            let mut diagonal = 0.0;
            let mut off_diagonal_sum = 0.0;

            for (col, value) in self.row_iter(row) {
                if col as usize == row {
                    diagonal = value.abs();
                } else {
                    off_diagonal_sum += value.abs();
                }
            }

            if diagonal < off_diagonal_sum {
                return false;
            }
        }
        true
    }

    fn diagonal_dominance_factor(&self) -> Option<Precision> {
        let mut min_factor = Precision::INFINITY;

        for row in 0..self.rows {
            let mut diagonal = 0.0;
            let mut off_diagonal_sum = 0.0;

            for (col, value) in self.row_iter(row) {
                if col as usize == row {
                    diagonal = value.abs();
                } else {
                    off_diagonal_sum += value.abs();
                }
            }

            if off_diagonal_sum > 0.0 {
                let factor = diagonal / off_diagonal_sum;
                min_factor = min_factor.min(factor);
            }
        }

        if min_factor.is_finite() {
            Some(min_factor)
        } else {
            None
        }
    }

    fn nnz(&self) -> usize {
        match &self.storage {
            SparseStorage::CSR(csr) => csr.nnz(),
            SparseStorage::CSC(csc) => csc.nnz(),
            SparseStorage::COO(coo) => coo.nnz(),
            SparseStorage::Graph(graph) => graph.nnz(),
        }
    }

    fn sparsity_info(&self) -> SparsityInfo {
        let mut info = SparsityInfo::new(self.nnz(), self.rows, self.cols);

        // Largest number of stored entries in any single row.
        let mut max_nnz_per_row = 0;
        for row in 0..self.rows {
            let row_nnz = self.row_iter(row).count();
            max_nnz_per_row = max_nnz_per_row.max(row_nnz);
        }
        info.max_nnz_per_row = max_nnz_per_row;

        // Bandwidth: largest distance of a stored entry from the diagonal.
        let mut max_bandwidth = 0;
        for (r, c, _) in self.to_triplets().unwrap_or_default() {
            let bandwidth = if r > c { r - c } else { c - r };
            max_bandwidth = max_bandwidth.max(bandwidth);
        }
        info.bandwidth = Some(max_bandwidth);
        // Heuristic: treat the matrix as banded when the bandwidth is well below the dimension.
        info.is_banded = max_bandwidth < self.rows / 4;

        info
    }

    fn conditioning_info(&self) -> ConditioningInfo {
        ConditioningInfo {
            // Not computed here.
            condition_number: None,
            is_diagonally_dominant: self.is_diagonally_dominant(),
            diagonal_dominance_factor: self.diagonal_dominance_factor(),
            spectral_radius: Some(self.spectral_radius_estimate()),
            // Not determined here.
            is_positive_definite: None,
        }
    }

    fn format_name(&self) -> &'static str {
        match self.format {
            SparseFormat::CSR => "CSR",
            SparseFormat::CSC => "CSC",
            SparseFormat::COO => "COO",
            SparseFormat::GraphAdjacency => "GraphAdjacency",
        }
    }
}

impl fmt::Display for SparseMatrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}x{} sparse matrix ({} format, {} nnz)",
            self.rows, self.cols, self.format_name(), self.nnz())
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_matrix_creation() {
        let triplets = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 2.0), (1, 1, 5.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();

        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 2);
        assert_eq!(matrix.nnz(), 4);
        assert!(matrix.is_diagonally_dominant());
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let triplets = vec![(0, 0, 2.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();

        let x = vec![1.0, 2.0];
        let mut result = vec![0.0; 2];

        matrix.multiply_vector(&x, &mut result).unwrap();

        // [2 1; 1 3] * [1, 2] = [4, 7]
        assert_eq!(result, vec![4.0, 7.0]);
    }

    #[test]
    fn test_diagonal_dominance() {
        let triplets = vec![(0, 0, 5.0), (0, 1, 1.0), (1, 0, 2.0), (1, 1, 7.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();
        assert!(matrix.is_diagonally_dominant());

        let triplets = vec![(0, 0, 1.0), (0, 1, 3.0), (1, 0, 2.0), (1, 1, 2.0)];
        let matrix = SparseMatrix::from_triplets(triplets, 2, 2).unwrap();
        assert!(!matrix.is_diagonally_dominant());
    }

    #[test]
    fn test_format_conversion() {
        let triplets = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)];
        let mut matrix = SparseMatrix::from_triplets(triplets, 2, 3).unwrap();

        assert_eq!(matrix.format(), SparseFormat::CSR);

        matrix.convert_to_format(SparseFormat::CSC).unwrap();
        assert_eq!(matrix.format(), SparseFormat::CSC);

        matrix.convert_to_format(SparseFormat::GraphAdjacency).unwrap();
        assert_eq!(matrix.format(), SparseFormat::GraphAdjacency);
    }
}