1use crate::StrError;
2use crate::{ComplexCooMatrix, CooMatrix};
3use russell_lab::cpx;
4
5impl ComplexCooMatrix {
6 pub fn assign_real(&mut self, alpha: f64, beta: f64, other: &CooMatrix) -> Result<(), StrError> {
26 if other.nrow != self.nrow {
27 return Err("matrices must have the same nrow");
28 }
29 if other.ncol != self.ncol {
30 return Err("matrices must have the same ncol");
31 }
32 if other.symmetric != self.symmetric {
33 return Err("matrices must have the same symmetric flag");
34 }
35 self.reset();
36 for p in 0..other.nnz {
37 let i = other.indices_i[p] as usize;
38 let j = other.indices_j[p] as usize;
39 self.put(i, j, cpx!(alpha * other.values[p], beta * other.values[p]))?;
40 }
41 Ok(())
42 }
43
44 pub fn add_real(&mut self, alpha: f64, beta: f64, other: &CooMatrix) -> Result<(), StrError> {
77 if other.nrow > self.nrow {
78 return Err("other.nrow must be ≤ this.nrow");
79 }
80 if other.ncol > self.ncol {
81 return Err("other.ncol must be ≤ this.ncol");
82 }
83 if other.symmetric != self.symmetric {
84 return Err("matrices must have the same symmetric flag");
85 }
86 for p in 0..other.nnz {
87 let i = other.indices_i[p] as usize;
88 let j = other.indices_j[p] as usize;
89 self.put(i, j, cpx!(alpha * other.values[p], beta * other.values[p]))?;
90 }
91 Ok(())
92 }
93}
94
#[cfg(test)]
mod tests {
    use crate::{ComplexCooMatrix, CooMatrix, Sym};
    use russell_lab::cpx;

    #[test]
    fn assign_real_capture_errors() {
        let nnz_a = 1;
        let nnz_b = 2;
        let mut a_1x2 = ComplexCooMatrix::new(1, 2, nnz_a, Sym::No).unwrap();
        let b_2x1 = CooMatrix::new(2, 1, nnz_b, Sym::No).unwrap();
        let b_1x3 = CooMatrix::new(1, 3, nnz_b, Sym::No).unwrap();
        let mut b_1x2 = CooMatrix::new(1, 2, nnz_b, Sym::No).unwrap();
        a_1x2.put(0, 0, cpx!(123.0, 321.0)).unwrap();
        b_1x2.put(0, 0, 456.0).unwrap();
        b_1x2.put(0, 1, 654.0).unwrap();
        assert_eq!(
            a_1x2.assign_real(2.0, 3.0, &b_2x1).err(),
            Some("matrices must have the same nrow")
        );
        assert_eq!(
            a_1x2.assign_real(2.0, 3.0, &b_1x3).err(),
            Some("matrices must have the same ncol")
        );
        // `b_1x2` holds two entries but `a_1x2` can only hold one
        assert_eq!(
            a_1x2.assign_real(2.0, 3.0, &b_1x2).err(),
            Some("COO matrix: max number of items has been reached")
        );
        let mut a_2x2 = ComplexCooMatrix::new(2, 2, 1, Sym::YesLower).unwrap();
        let b_2x2 = CooMatrix::new(2, 2, 1, Sym::YesFull).unwrap();
        assert_eq!(
            a_2x2.assign_real(2.0, 3.0, &b_2x2).err(),
            Some("matrices must have the same symmetric flag")
        );
    }

    #[test]
    fn assign_real_works() {
        let nnz = 2;
        let mut a = ComplexCooMatrix::new(3, 2, nnz, Sym::No).unwrap();
        let mut b = CooMatrix::new(3, 2, nnz, Sym::No).unwrap();
        a.put(2, 1, cpx!(1000.0, 2000.0)).unwrap();
        b.put(0, 0, 10.0).unwrap();
        b.put(2, 1, 20.0).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌ ┐\n\
             │ 0+0i 0+0i │\n\
             │ 0+0i 0+0i │\n\
             │ 0+0i 1000+2000i │\n\
             └ ┘"
        );
        // the previous (2, 1) entry is discarded, not accumulated
        a.assign_real(3.0, 2.0, &b).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌ ┐\n\
             │ 30+20i 0+0i │\n\
             │ 0+0i 0+0i │\n\
             │ 0+0i 60+40i │\n\
             └ ┘"
        );
    }

    #[test]
    fn add_real_capture_errors() {
        let nnz_a = 1;
        let nnz_b = 1;
        let mut a_1x2 = ComplexCooMatrix::new(1, 2, nnz_a, Sym::No).unwrap();
        let b_2x1 = CooMatrix::new(2, 1, nnz_b, Sym::No).unwrap();
        let b_1x3 = CooMatrix::new(1, 3, nnz_b, Sym::No).unwrap();
        let mut b_1x2 = CooMatrix::new(1, 2, nnz_b, Sym::No).unwrap();
        a_1x2.put(0, 0, cpx!(123.0, 321.0)).unwrap();
        b_1x2.put(0, 0, 456.0).unwrap();
        assert_eq!(
            a_1x2.add_real(2.0, 3.0, &b_2x1).err(),
            Some("other.nrow must be ≤ this.nrow")
        );
        assert_eq!(
            a_1x2.add_real(2.0, 3.0, &b_1x3).err(),
            Some("other.ncol must be ≤ this.ncol")
        );
        // `a_1x2` is already full, so even a single extra entry does not fit
        assert_eq!(
            a_1x2.add_real(2.0, 3.0, &b_1x2).err(),
            Some("COO matrix: max number of items has been reached")
        );
        let mut a_2x2 = ComplexCooMatrix::new(2, 2, 1, Sym::YesLower).unwrap();
        let b_2x2 = CooMatrix::new(2, 2, 1, Sym::YesFull).unwrap();
        assert_eq!(
            a_2x2.add_real(2.0, 3.0, &b_2x2).err(),
            Some("matrices must have the same symmetric flag")
        );
    }

    #[test]
    fn add_real_works() {
        let nnz_a = 1;
        let nnz_b = 2;
        let mut a = ComplexCooMatrix::new(3, 2, nnz_a + nnz_b, Sym::No).unwrap();
        let mut b = CooMatrix::new(3, 2, nnz_b, Sym::No).unwrap();
        a.put(2, 1, cpx!(1000.0, 2000.0)).unwrap();
        b.put(0, 0, 10.0).unwrap();
        b.put(2, 1, 20.0).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌ ┐\n\
             │ 0+0i 0+0i │\n\
             │ 0+0i 0+0i │\n\
             │ 0+0i 1000+2000i │\n\
             └ ┘"
        );
        // duplicate (2, 1) entries are summed when converted to dense
        a.add_real(3.0, 2.0, &b).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            "┌ ┐\n\
             │ 30+20i 0+0i │\n\
             │ 0+0i 0+0i │\n\
             │ 0+0i 1060+2040i │\n\
             └ ┘"
        );
    }

    #[test]
    fn add_real_accepts_smaller_other() {
        // `other` (2x2) is smaller than `a` (3x3): its entries keep their (i, j) positions
        let mut a = ComplexCooMatrix::new(3, 3, 2, Sym::No).unwrap();
        let mut b = CooMatrix::new(2, 2, 1, Sym::No).unwrap();
        a.put(2, 2, cpx!(1.0, 1.0)).unwrap();
        b.put(1, 0, 5.0).unwrap();
        a.add_real(2.0, -1.0, &b).unwrap();
        let mut expected = ComplexCooMatrix::new(3, 3, 2, Sym::No).unwrap();
        expected.put(1, 0, cpx!(10.0, -5.0)).unwrap();
        expected.put(2, 2, cpx!(1.0, 1.0)).unwrap();
        assert_eq!(
            format!("{}", a.as_dense()),
            format!("{}", expected.as_dense())
        );
    }
}