1use std::ops::Add;
2
3use crate::{scalar::Scalar, CscMatrix};
4
5impl<T: Scalar> Add for &CscMatrix<T> {
6 type Output = CscMatrix<T>;
7
8 fn add(self, rhs: Self) -> Self::Output {
9 assert_eq!(self.nrows(), rhs.nrows());
10 assert_eq!(self.ncols(), rhs.ncols());
11
12 let (lhs, rhs) = (self.transpose(), rhs.transpose());
14
15 let mut colptr = Vec::with_capacity(self.ncols() + 1);
17 let cap = lhs.nnz() + rhs.nnz();
18 let mut rowind = Vec::with_capacity(cap);
19 let mut values = Vec::with_capacity(cap);
20
21 let mut set = vec![0; lhs.nrows()];
23 let mut vec = vec![T::zero(); lhs.nrows()];
24
25 let mut nz = 0;
27 for col in 0..lhs.ncols() {
28 colptr.push(nz);
29 for ptr in lhs.colptr[col]..lhs.colptr[col + 1] {
30 let row = lhs.rowind[ptr];
31 if set[row] < col + 1 {
32 set[row] = col + 1;
33 rowind.push(row);
34 vec[row] = lhs.values[ptr];
35 nz += 1;
36 } else {
37 vec[row] += lhs.values[ptr];
38 }
39 }
40 for ptr in rhs.colptr[col]..rhs.colptr[col + 1] {
41 let row = rhs.rowind[ptr];
42 if set[row] < col + 1 {
43 set[row] = col + 1;
44 rowind.push(row);
45 vec[row] = rhs.values[ptr];
46 nz += 1;
47 } else {
48 vec[row] += rhs.values[ptr];
49 }
50 }
51 for ptr in colptr[col]..nz {
52 let value = vec[rowind[ptr]];
53 values.push(value)
54 }
55 }
56 colptr.push(nz);
57
58 let output = CscMatrix {
60 nrows: self.nrows(),
61 ncols: self.ncols(),
62 colptr,
63 rowind,
64 values,
65 };
66
67 output.transpose()
69 }
70}
71
72#[cfg(test)]
73mod tests {
74 use super::*;
75
76 #[test]
77 fn add() {
78 let lhs = CscMatrix::new(
79 4,
80 4,
81 vec![0, 2, 4, 6, 7],
82 vec![0, 1, 2, 3, 1, 3, 3],
83 vec![1.0, 2.0, 4.0, 5.0, 3.0, 6.0, 7.0],
84 );
85 let rhs = CscMatrix::new(
86 4,
87 4,
88 vec![0, 1, 2, 4, 5],
89 vec![0, 3, 0, 1, 2],
90 vec![2.0, 6.0, 4.0, 8.0, 10.0],
91 );
92 let mat = &lhs + &rhs;
93 assert_eq!(mat.nrows, 4);
94 assert_eq!(mat.ncols, 4);
95 assert_eq!(mat.colptr, [0, 2, 4, 7, 9]);
96 assert_eq!(mat.rowind, [0, 1, 2, 3, 0, 1, 3, 2, 3]);
97 assert_eq!(mat.values, [3.0, 2.0, 4.0, 11.0, 4.0, 11.0, 6.0, 10.0, 7.0]);
98 assert_eq!(mat.colptr.capacity(), mat.ncols() + 1);
99 assert_eq!(mat.rowind.capacity(), mat.nnz());
100 assert_eq!(mat.values.capacity(), mat.nnz());
101 }
102}