spalinalg/csc/ops/add.rs

1use std::ops::Add;
2
3use crate::{scalar::Scalar, CscMatrix};
4
5impl<T: Scalar> Add for &CscMatrix<T> {
6    type Output = CscMatrix<T>;
7
8    fn add(self, rhs: Self) -> Self::Output {
9        assert_eq!(self.nrows(), rhs.nrows());
10        assert_eq!(self.ncols(), rhs.ncols());
11
12        // Transpose inputs
13        let (lhs, rhs) = (self.transpose(), rhs.transpose());
14
15        // Allocate output
16        let mut colptr = Vec::with_capacity(self.ncols() + 1);
17        let cap = lhs.nnz() + rhs.nnz();
18        let mut rowind = Vec::with_capacity(cap);
19        let mut values = Vec::with_capacity(cap);
20
21        // Allocate workspace
22        let mut set = vec![0; lhs.nrows()];
23        let mut vec = vec![T::zero(); lhs.nrows()];
24
25        // Addition
26        let mut nz = 0;
27        for col in 0..lhs.ncols() {
28            colptr.push(nz);
29            for ptr in lhs.colptr[col]..lhs.colptr[col + 1] {
30                let row = lhs.rowind[ptr];
31                if set[row] < col + 1 {
32                    set[row] = col + 1;
33                    rowind.push(row);
34                    vec[row] = lhs.values[ptr];
35                    nz += 1;
36                } else {
37                    vec[row] += lhs.values[ptr];
38                }
39            }
40            for ptr in rhs.colptr[col]..rhs.colptr[col + 1] {
41                let row = rhs.rowind[ptr];
42                if set[row] < col + 1 {
43                    set[row] = col + 1;
44                    rowind.push(row);
45                    vec[row] = rhs.values[ptr];
46                    nz += 1;
47                } else {
48                    vec[row] += rhs.values[ptr];
49                }
50            }
51            for ptr in colptr[col]..nz {
52                let value = vec[rowind[ptr]];
53                values.push(value)
54            }
55        }
56        colptr.push(nz);
57
58        // Construct matrix
59        let output = CscMatrix {
60            nrows: self.nrows(),
61            ncols: self.ncols(),
62            colptr,
63            rowind,
64            values,
65        };
66
67        // Transpose output
68        output.transpose()
69    }
70}
71
72#[cfg(test)]
73mod tests {
74    use super::*;
75
76    #[test]
77    fn add() {
78        let lhs = CscMatrix::new(
79            4,
80            4,
81            vec![0, 2, 4, 6, 7],
82            vec![0, 1, 2, 3, 1, 3, 3],
83            vec![1.0, 2.0, 4.0, 5.0, 3.0, 6.0, 7.0],
84        );
85        let rhs = CscMatrix::new(
86            4,
87            4,
88            vec![0, 1, 2, 4, 5],
89            vec![0, 3, 0, 1, 2],
90            vec![2.0, 6.0, 4.0, 8.0, 10.0],
91        );
92        let mat = &lhs + &rhs;
93        assert_eq!(mat.nrows, 4);
94        assert_eq!(mat.ncols, 4);
95        assert_eq!(mat.colptr, [0, 2, 4, 7, 9]);
96        assert_eq!(mat.rowind, [0, 1, 2, 3, 0, 1, 3, 2, 3]);
97        assert_eq!(mat.values, [3.0, 2.0, 4.0, 11.0, 4.0, 11.0, 6.0, 10.0, 7.0]);
98        assert_eq!(mat.colptr.capacity(), mat.ncols() + 1);
99        assert_eq!(mat.rowind.capacity(), mat.nnz());
100        assert_eq!(mat.values.capacity(), mat.nnz());
101    }
102}