use crate::coo::CooMatrix;
use crate::coo_array::CooArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use num_traits::Float;
use std::collections::{BTreeMap, HashMap};
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
14
15#[derive(Debug, Clone)]
25pub struct SymCooMatrix<T>
26where
27 T: Float + Debug + Copy,
28{
29 pub data: Vec<T>,
31
32 pub rows: Vec<usize>,
34
35 pub cols: Vec<usize>,
37
38 pub shape: (usize, usize),
40}
41
42impl<T> SymCooMatrix<T>
43where
44 T: Float + Debug + Copy,
45{
46 pub fn new(
67 data: Vec<T>,
68 rows: Vec<usize>,
69 cols: Vec<usize>,
70 shape: (usize, usize),
71 ) -> SparseResult<Self> {
72 let (nrows, ncols) = shape;
73
74 if nrows != ncols {
76 return Err(SparseError::ValueError(
77 "Symmetric matrix must be square".to_string(),
78 ));
79 }
80
81 let nnz = data.len();
83 if rows.len() != nnz || cols.len() != nnz {
84 return Err(SparseError::ValueError(format!(
85 "Data ({}), row ({}) and column ({}) arrays must have same length",
86 nnz,
87 rows.len(),
88 cols.len()
89 )));
90 }
91
92 for i in 0..nnz {
94 let row = rows[i];
95 let col = cols[i];
96
97 if row >= nrows {
98 return Err(SparseError::IndexOutOfBounds {
99 index: (row, 0),
100 shape: (nrows, ncols),
101 });
102 }
103
104 if col >= ncols {
105 return Err(SparseError::IndexOutOfBounds {
106 index: (row, col),
107 shape: (nrows, ncols),
108 });
109 }
110
111 if col > row {
113 return Err(SparseError::ValueError(
114 "Symmetric COO should only store the lower triangular part".to_string(),
115 ));
116 }
117 }
118
119 Ok(Self {
120 data,
121 rows,
122 cols,
123 shape,
124 })
125 }
126
127 pub fn from_coo(matrix: &CooMatrix<T>) -> SparseResult<Self> {
140 let (rows, cols) = matrix.shape();
141
142 if rows != cols {
144 return Err(SparseError::ValueError(
145 "Symmetric matrix must be square".to_string(),
146 ));
147 }
148
149 if !Self::is_symmetric(matrix) {
151 return Err(SparseError::ValueError(
152 "Matrix must be symmetric to convert to SymCOO format".to_string(),
153 ));
154 }
155
156 let mut data = Vec::new();
158 let mut row_indices = Vec::new();
159 let mut col_indices = Vec::new();
160
161 let rows_vec = matrix.row_indices();
162 let cols_vec = matrix.col_indices();
163 let data_vec = matrix.data();
164
165 for i in 0..data_vec.len() {
166 let row = rows_vec[i];
167 let col = cols_vec[i];
168
169 if col <= row {
171 data.push(data_vec[i]);
172 row_indices.push(row);
173 col_indices.push(col);
174 }
175 }
176
177 Ok(Self {
178 data,
179 rows: row_indices,
180 cols: col_indices,
181 shape: (rows, cols),
182 })
183 }
184
185 pub fn is_symmetric(matrix: &CooMatrix<T>) -> bool {
195 let (rows, cols) = matrix.shape();
196
197 if rows != cols {
199 return false;
200 }
201
202 let dense = matrix.to_dense();
204
205 for i in 0..rows {
206 for j in 0..i {
207 let diff = (dense[i][j] - dense[j][i]).abs();
210 let epsilon = T::epsilon() * T::from(100.0).unwrap();
211 if diff > epsilon {
212 return false;
213 }
214 }
215 }
216
217 true
218 }
219
220 pub fn shape(&self) -> (usize, usize) {
226 self.shape
227 }
228
229 pub fn nnz_stored(&self) -> usize {
235 self.data.len()
236 }
237
238 pub fn nnz(&self) -> usize {
244 let mut count = 0;
245
246 for i in 0..self.data.len() {
247 let row = self.rows[i];
248 let col = self.cols[i];
249
250 if row == col {
251 count += 1;
253 } else {
254 count += 2;
256 }
257 }
258
259 count
260 }
261
262 pub fn get(&self, row: usize, col: usize) -> T {
273 if row >= self.shape.0 || col >= self.shape.1 {
275 return T::zero();
276 }
277
278 let (actual_row, actual_col) = if row < col { (col, row) } else { (row, col) };
281
282 for i in 0..self.data.len() {
284 if self.rows[i] == actual_row && self.cols[i] == actual_col {
285 return self.data[i];
286 }
287 }
288
289 T::zero()
290 }
291
292 pub fn to_coo(&self) -> SparseResult<CooMatrix<T>> {
298 let mut data = Vec::new();
299 let mut rows = Vec::new();
300 let mut cols = Vec::new();
301
302 data.extend_from_slice(&self.data);
304 rows.extend_from_slice(&self.rows);
305 cols.extend_from_slice(&self.cols);
306
307 for i in 0..self.data.len() {
309 let row = self.rows[i];
310 let col = self.cols[i];
311
312 if row != col {
314 data.push(self.data[i]);
316 rows.push(col);
317 cols.push(row);
318 }
319 }
320
321 CooMatrix::new(data, rows, cols, self.shape)
322 }
323
324 pub fn to_dense(&self) -> Vec<Vec<T>> {
330 let n = self.shape.0;
331 let mut dense = vec![vec![T::zero(); n]; n];
332
333 for i in 0..self.data.len() {
335 let row = self.rows[i];
336 let col = self.cols[i];
337 dense[row][col] = self.data[i];
338
339 if row != col {
341 dense[col][row] = self.data[i];
342 }
343 }
344
345 dense
346 }
347}
348
349#[derive(Debug, Clone)]
351pub struct SymCooArray<T>
352where
353 T: Float + Debug + Copy,
354{
355 inner: SymCooMatrix<T>,
357}
358
359impl<T> SymCooArray<T>
360where
361 T: Float
362 + Debug
363 + Copy
364 + 'static
365 + Add<Output = T>
366 + Sub<Output = T>
367 + Mul<Output = T>
368 + Div<Output = T>,
369{
370 pub fn new(matrix: SymCooMatrix<T>) -> Self {
380 Self { inner: matrix }
381 }
382
383 pub fn from_triplets(
397 rows: &[usize],
398 cols: &[usize],
399 data: &[T],
400 shape: (usize, usize),
401 enforce_symmetric: bool,
402 ) -> SparseResult<Self> {
403 if shape.0 != shape.1 {
404 return Err(SparseError::ValueError(
405 "Symmetric matrix must be square".to_string(),
406 ));
407 }
408
409 if !enforce_symmetric {
410 let n = shape.0;
412 let mut dense = vec![vec![T::zero(); n]; n];
413 let nnz = data.len().min(rows.len().min(cols.len()));
414
415 for i in 0..nnz {
417 let row = rows[i];
418 let col = cols[i];
419
420 if row >= n || col >= n {
421 return Err(SparseError::IndexOutOfBounds {
422 index: (row, col),
423 shape,
424 });
425 }
426
427 dense[row][col] = data[i];
428 }
429
430 for i in 0..n {
432 for j in 0..i {
433 if (dense[i][j] - dense[j][i]).abs() > T::epsilon() {
434 return Err(SparseError::ValueError(
435 "Input is not symmetric. Use enforce_symmetric=true to force symmetry"
436 .to_string(),
437 ));
438 }
439 }
440 }
441
442 let mut sym_data = Vec::new();
444 let mut sym_rows = Vec::new();
445 let mut sym_cols = Vec::new();
446
447 for (i, row) in dense.iter().enumerate().take(n) {
448 for (j, &val) in row.iter().enumerate().take(i + 1) {
449 if val != T::zero() {
450 sym_data.push(val);
451 sym_rows.push(i);
452 sym_cols.push(j);
453 }
454 }
455 }
456
457 let sym_coo = SymCooMatrix::new(sym_data, sym_rows, sym_cols, shape)?;
459 return Ok(Self { inner: sym_coo });
460 }
461
462 let n = shape.0;
464
465 let mut dense = vec![vec![T::zero(); n]; n];
467 let nnz = data.len();
468
469 for i in 0..nnz {
471 if i >= rows.len() || i >= cols.len() {
472 return Err(SparseError::ValueError(
473 "Inconsistent input arrays".to_string(),
474 ));
475 }
476
477 let row = rows[i];
478 let col = cols[i];
479
480 if row >= n || col >= n {
481 return Err(SparseError::IndexOutOfBounds {
482 index: (row, col),
483 shape: (n, n),
484 });
485 }
486
487 dense[row][col] = data[i];
488 }
489
490 for i in 0..n {
492 for j in 0..i {
493 let avg = (dense[i][j] + dense[j][i]) / (T::one() + T::one());
494 dense[i][j] = avg;
495 dense[j][i] = avg;
496 }
497 }
498
499 let mut sym_data = Vec::new();
501 let mut sym_rows = Vec::new();
502 let mut sym_cols = Vec::new();
503
504 for (i, row) in dense.iter().enumerate().take(n) {
505 for (j, &val) in row.iter().enumerate().take(i + 1) {
506 if val != T::zero() {
507 sym_data.push(val);
508 sym_rows.push(i);
509 sym_cols.push(j);
510 }
511 }
512 }
513
514 let sym_coo = SymCooMatrix::new(sym_data, sym_rows, sym_cols, shape)?;
515 Ok(Self { inner: sym_coo })
516 }
517
518 pub fn from_coo_array(array: &CooArray<T>) -> SparseResult<Self> {
528 let shape = array.shape();
529 let (rows, cols) = shape;
530
531 if rows != cols {
533 return Err(SparseError::ValueError(
534 "Symmetric matrix must be square".to_string(),
535 ));
536 }
537
538 let coo_matrix = CooMatrix::new(
540 array.get_data().to_vec(),
541 array.get_rows().to_vec(),
542 array.get_cols().to_vec(),
543 shape,
544 )?;
545
546 let sym_coo = SymCooMatrix::from_coo(&coo_matrix)?;
548
549 Ok(Self { inner: sym_coo })
550 }
551
552 pub fn inner(&self) -> &SymCooMatrix<T> {
558 &self.inner
559 }
560
561 pub fn data(&self) -> &[T] {
567 &self.inner.data
568 }
569
570 pub fn rows(&self) -> &[usize] {
576 &self.inner.rows
577 }
578
579 pub fn cols(&self) -> &[usize] {
585 &self.inner.cols
586 }
587
588 pub fn shape(&self) -> (usize, usize) {
594 self.inner.shape
595 }
596
597 pub fn to_coo_array(&self) -> SparseResult<CooArray<T>> {
603 let coo = self.inner.to_coo()?;
604
605 let rows = coo.row_indices();
607 let cols = coo.col_indices();
608 let data = coo.data();
609
610 CooArray::from_triplets(rows, cols, data, coo.shape(), false)
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparray::SparseArray;

    /// Dense form of the 3x3 symmetric matrix used throughout these tests.
    const EXPECTED: [[f64; 3]; 3] = [[2.0, 1.0, 0.0], [1.0, 2.0, 3.0], [0.0, 3.0, 1.0]];

    /// Lower triangle of `EXPECTED`, in row-major order.
    fn sample_matrix() -> SymCooMatrix<f64> {
        SymCooMatrix::new(
            vec![2.0, 1.0, 2.0, 3.0, 1.0],
            vec![0, 1, 1, 2, 2],
            vec![0, 0, 1, 1, 2],
            (3, 3),
        )
        .unwrap()
    }

    #[test]
    fn test_sym_coo_creation() {
        let sym = sample_matrix();

        assert_eq!(sym.shape(), (3, 3));
        assert_eq!(sym.nnz_stored(), 5);
        // Three diagonal entries plus two off-diagonal entries counted twice.
        assert_eq!(sym.nnz(), 7);

        for (i, expected_row) in EXPECTED.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert_eq!(sym.get(i, j), want, "mismatch at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_sym_coo_from_standard() {
        // Full (both-triangle) COO representation of EXPECTED.
        let coo = CooMatrix::new(
            vec![2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 1.0],
            vec![0, 0, 1, 1, 1, 2, 2],
            vec![0, 1, 0, 1, 2, 1, 2],
            (3, 3),
        )
        .unwrap();

        let sym = SymCooMatrix::from_coo(&coo).unwrap();
        assert_eq!(sym.shape(), (3, 3));

        let round_trip = sym.to_coo().unwrap().to_dense();
        for (i, expected_row) in EXPECTED.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert_eq!(round_trip[i][j], want, "mismatch at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_sym_coo_array() {
        let wrapped = SymCooArray::new(sample_matrix());
        assert_eq!(wrapped.inner().shape(), (3, 3));

        let expanded = wrapped.to_coo_array().unwrap();
        assert_eq!(expanded.shape(), (3, 3));

        for (i, expected_row) in EXPECTED.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert_eq!(expanded.get(i, j), want, "mismatch at ({}, {})", i, j);
            }
        }
    }

    #[test]
    fn test_sym_coo_array_from_triplets() {
        // Already-symmetric input, supplied in arbitrary order and both triangles.
        let sym_rows = vec![0, 1, 1, 2, 1, 0, 2];
        let sym_cols = vec![0, 1, 2, 2, 0, 1, 1];
        let sym_vals = vec![2.0, 2.0, 3.0, 1.0, 1.0, 1.0, 3.0];

        let exact =
            SymCooArray::from_triplets(&sym_rows, &sym_cols, &sym_vals, (3, 3), false).unwrap();
        assert_eq!(exact.shape(), (3, 3));

        // (0, 1) = 1.0 and (1, 0) = 2.0 disagree; enforcement averages them to 1.5.
        let skew_rows = vec![0, 0, 1, 1, 2, 1];
        let skew_cols = vec![0, 1, 1, 2, 2, 0];
        let skew_vals = vec![2.0, 1.0, 2.0, 3.0, 1.0, 2.0];

        let averaged =
            SymCooArray::from_triplets(&skew_rows, &skew_cols, &skew_vals, (3, 3), true).unwrap();
        let storage = averaged.inner();
        assert_eq!(storage.get(1, 0), 1.5);
        assert_eq!(storage.get(0, 1), 1.5);
    }
}