Skip to main content

russell_sparse/
complex_coo_matrix.rs

1use crate::StrError;
2use crate::{ComplexCooMatrix, CooMatrix};
3use russell_lab::cpx;
4
5impl ComplexCooMatrix {
6    /// Assigns this matrix to the values of another real matrix (scaled)
7    ///
8    /// Performs:
9    ///
10    /// ```text
11    /// this = (α + βi) · other
12    /// ```
13    ///
14    /// Thus:
15    ///
16    /// ```text
17    /// this[p].real = α · other[p]
18    /// this[p].imag = β · other[p]
19    ///
20    /// other[p] ∈ Reals
21    /// p = [0, nnz(other)]
22    /// ```
23    ///
24    /// **Warning:** make sure to allocate `max_nnz ≥ nnz(other)`.
25    pub fn assign_real(&mut self, alpha: f64, beta: f64, other: &CooMatrix) -> Result<(), StrError> {
26        if other.nrow != self.nrow {
27            return Err("matrices must have the same nrow");
28        }
29        if other.ncol != self.ncol {
30            return Err("matrices must have the same ncol");
31        }
32        if other.symmetric != self.symmetric {
33            return Err("matrices must have the same symmetric flag");
34        }
35        self.reset();
36        for p in 0..other.nnz {
37            let i = other.indices_i[p] as usize;
38            let j = other.indices_j[p] as usize;
39            self.put(i, j, cpx!(alpha * other.values[p], beta * other.values[p]))?;
40        }
41        Ok(())
42    }
43
44    /// Puts the entries of another real matrix into this matrix
45    ///
46    /// Effectively, performs:
47    ///
48    /// ```text
49    /// this += (α + βi) · other
50    /// ```
51    ///
52    /// Thus:
53    ///
54    /// ```text
55    /// this[p].real += α · other[p]
56    /// this[p].imag += β · other[p]
57    ///
58    /// other[p] ∈ Reals
59    /// p = [0, nnz(other)]
60    /// ```
61    ///
62    /// # Arguments
63    ///
64    /// * `alpha` -- scaling factor
65    /// * `other` -- the other matrix to be added. It must be at most as large as `this`.
66    ///
67    /// # Requirements
68    ///
69    /// * `other.nrow ≤ this.nrow`
70    /// * `other.ncol ≤ this.ncol`
71    /// * `other.symmetric == this.symmetric`
72    ///
73    /// # Note
74    ///
75    /// * make sure to allocate `max_nnz ≥ nnz(this) + nnz(other)`.
76    pub fn add_real(&mut self, alpha: f64, beta: f64, other: &CooMatrix) -> Result<(), StrError> {
77        if other.nrow > self.nrow {
78            return Err("other.nrow must be ≤ this.nrow");
79        }
80        if other.ncol > self.ncol {
81            return Err("other.ncol must be ≤ this.ncol");
82        }
83        if other.symmetric != self.symmetric {
84            return Err("matrices must have the same symmetric flag");
85        }
86        for p in 0..other.nnz {
87            let i = other.indices_i[p] as usize;
88            let j = other.indices_j[p] as usize;
89            self.put(i, j, cpx!(alpha * other.values[p], beta * other.values[p]))?;
90        }
91        Ok(())
92    }
93}
94
95////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
96
#[cfg(test)]
mod tests {
    use crate::{ComplexCooMatrix, CooMatrix, Sym};
    use russell_lab::cpx;

    #[test]
    fn assign_real_capture_errors() {
        let nnz_a = 1;
        let nnz_b = 2; // b_1x2 receives 2 entries, more than a_1x2 can hold (nnz_a)
        let mut a_1x2 = ComplexCooMatrix::new(1, 2, nnz_a, Sym::No).unwrap();
        let b_2x1 = CooMatrix::new(2, 1, nnz_b, Sym::No).unwrap();
        let b_1x3 = CooMatrix::new(1, 3, nnz_b, Sym::No).unwrap();
        let mut b_1x2 = CooMatrix::new(1, 2, nnz_b, Sym::No).unwrap();
        a_1x2.put(0, 0, cpx!(123.0, 321.0)).unwrap();
        b_1x2.put(0, 0, 456.0).unwrap();
        b_1x2.put(0, 1, 654.0).unwrap();
        assert_eq!(
            a_1x2.assign_real(2.0, 3.0, &b_2x1).err(),
            Some("matrices must have the same nrow")
        );
        assert_eq!(
            a_1x2.assign_real(2.0, 3.0, &b_1x3).err(),
            Some("matrices must have the same ncol")
        );
        assert_eq!(
            a_1x2.assign_real(2.0, 3.0, &b_1x2).err(),
            Some("COO matrix: max number of items has been reached")
        );
        let mut a_2x2 = ComplexCooMatrix::new(2, 2, 1, Sym::YesLower).unwrap();
        let b_2x2 = CooMatrix::new(2, 2, 1, Sym::YesFull).unwrap();
        assert_eq!(
            a_2x2.assign_real(2.0, 3.0, &b_2x2).err(),
            Some("matrices must have the same symmetric flag")
        );
    }

    #[test]
    fn assign_real_works() {
        let nnz = 2;
        let mut a = ComplexCooMatrix::new(3, 2, nnz, Sym::No).unwrap();
        let mut b = CooMatrix::new(3, 2, nnz, Sym::No).unwrap();
        a.put(2, 1, cpx!(1000.0, 2000.0)).unwrap();
        b.put(0, 0, 10.0).unwrap();
        b.put(2, 1, 20.0).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌                       ┐\n\
             │       0+0i       0+0i │\n\
             │       0+0i       0+0i │\n\
             │       0+0i 1000+2000i │\n\
             └                       ┘"
        );
        a.assign_real(3.0, 2.0, &b).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌               ┐\n\
             │ 30+20i   0+0i │\n\
             │   0+0i   0+0i │\n\
             │   0+0i 60+40i │\n\
             └               ┘"
        );
    }

    #[test]
    fn add_real_capture_errors() {
        let nnz_a = 1;
        let nnz_b = 1;
        // a_1x2 is deliberately too small: it has no room left for b_1x2's entry
        let mut a_1x2 = ComplexCooMatrix::new(1, 2, nnz_a, Sym::No).unwrap();
        let b_2x1 = CooMatrix::new(2, 1, nnz_b, Sym::No).unwrap();
        let b_1x3 = CooMatrix::new(1, 3, nnz_b, Sym::No).unwrap();
        let mut b_1x2 = CooMatrix::new(1, 2, nnz_b, Sym::No).unwrap();
        a_1x2.put(0, 0, cpx!(123.0, 321.0)).unwrap();
        b_1x2.put(0, 0, 456.0).unwrap();
        assert_eq!(
            a_1x2.add_real(2.0, 3.0, &b_2x1).err(),
            Some("other.nrow must be ≤ this.nrow")
        );
        assert_eq!(
            a_1x2.add_real(2.0, 3.0, &b_1x3).err(),
            Some("other.ncol must be ≤ this.ncol")
        );
        assert_eq!(
            a_1x2.add_real(2.0, 3.0, &b_1x2).err(),
            Some("COO matrix: max number of items has been reached")
        );
        let mut a_2x2 = ComplexCooMatrix::new(2, 2, 1, Sym::YesLower).unwrap();
        let b_2x2 = CooMatrix::new(2, 2, 1, Sym::YesFull).unwrap();
        assert_eq!(
            a_2x2.add_real(2.0, 3.0, &b_2x2).err(),
            Some("matrices must have the same symmetric flag")
        );
    }

    #[test]
    fn add_real_works() {
        let nnz_a = 1;
        let nnz_b = 2;
        let mut a = ComplexCooMatrix::new(3, 2, nnz_a + nnz_b, Sym::No).unwrap();
        let mut b = CooMatrix::new(3, 2, nnz_b, Sym::No).unwrap();
        a.put(2, 1, cpx!(1000.0, 2000.0)).unwrap();
        b.put(0, 0, 10.0).unwrap();
        b.put(2, 1, 20.0).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌                       ┐\n\
             │       0+0i       0+0i │\n\
             │       0+0i       0+0i │\n\
             │       0+0i 1000+2000i │\n\
             └                       ┘"
        );
        a.add_real(3.0, 2.0, &b).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌                       ┐\n\
             │     30+20i       0+0i │\n\
             │       0+0i       0+0i │\n\
             │       0+0i 1060+2040i │\n\
             └                       ┘"
        );
    }

    #[test]
    fn add_real_works_with_smaller_other() {
        // `other` may be smaller than `this`; its entries land in the top-left corner
        let mut a = ComplexCooMatrix::new(3, 3, 3, Sym::No).unwrap();
        let mut b = CooMatrix::new(2, 2, 2, Sym::No).unwrap();
        a.put(2, 2, cpx!(1.0, 2.0)).unwrap();
        b.put(0, 0, 10.0).unwrap();
        b.put(1, 1, 20.0).unwrap();
        a.add_real(1.0, 0.5, &b).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌                      ┐\n\
             │  10+5i   0+0i   0+0i │\n\
             │   0+0i 20+10i   0+0i │\n\
             │   0+0i   0+0i   1+2i │\n\
             └                      ┘"
        );
    }
}