1use crate::attribute::Transpose;
6use crate::default::Default;
7use crate::math::Mat;
8use crate::math::Trans;
9use crate::matrix::ops::*;
10use crate::matrix::Matrix;
11use crate::vector::ops::*;
12use num_complex::{Complex32, Complex64};
13use std::ops::{Add, Mul};
14
15impl<'a, T> Add for &'a dyn Matrix<T>
16where
17 T: Axpy + Copy + Default,
18{
19 type Output = Mat<T>;
20
21 fn add(self, b: &dyn Matrix<T>) -> Mat<T> {
22 if self.cols() != b.cols() || self.rows() != b.rows() {
23 panic!("Dimension mismatch")
24 }
25
26 let scale = Default::one();
27 let mut result = Mat::from(self);
28 Axpy::axpy_mat(&scale, b, &mut result);
29 result
30 }
31}
32
33impl<'a, T> Mul<T> for &'a dyn Matrix<T>
34where
35 T: Sized + Copy + Scal,
36{
37 type Output = Mat<T>;
38
39 fn mul(self, alpha: T) -> Mat<T> {
40 let mut result = Mat::from(self);
41 Scal::scal_mat(&alpha, &mut result);
42 result
43 }
44}
45
46macro_rules! left_scale(($($t: ident), +) => (
47 $(
48 impl<'a> Mul<&'a dyn Matrix<$t>> for $t
49 {
50 type Output = Mat<$t>;
51
52 fn mul(self, x: &dyn Matrix<$t>) -> Mat<$t> {
53 let mut result = Mat::from(x);
54 Scal::scal_mat(&self, &mut result);
55 result
56 }
57 }
58 )+
59));
60
61left_scale!(f32, f64, Complex32, Complex64);
62
63impl<'a, T> Mul<&'a dyn Matrix<T>> for &'a dyn Matrix<T>
64where
65 T: Default + Gemm,
66{
67 type Output = Mat<T>;
68
69 fn mul(self, b: &dyn Matrix<T>) -> Mat<T> {
70 if self.cols() != b.rows() {
71 panic!("Dimension mismatch");
72 }
73
74 let n = self.rows() as usize;
75 let m = b.cols() as usize;
76 let mut result = Mat::new(n, m);
77 let t = Transpose::NoTrans;
78
79 Gemm::gemm(
80 &Default::one(),
81 t,
82 self,
83 t,
84 b,
85 &Default::zero(),
86 &mut result,
87 );
88 result
89 }
90}
91
92impl<'a, T> Mul<&'a dyn Matrix<T>> for Trans<&'a dyn Matrix<T>>
93where
94 T: Default + Gemm,
95{
96 type Output = Mat<T>;
97
98 fn mul(self, b: &dyn Matrix<T>) -> Mat<T> {
99 let (a, at) = match self {
100 Trans::T(a) => (a, Transpose::Trans),
101 Trans::H(a) => (a, Transpose::ConjTrans),
102 };
103
104 if a.rows() != b.rows() {
105 panic!("Dimension mismatch");
106 }
107
108 let n = a.cols() as usize;
109 let m = b.cols() as usize;
110 let mut result = Mat::new(n, m);
111 let bt = Transpose::NoTrans;
112
113 Gemm::gemm(&Default::one(), at, a, bt, b, &Default::zero(), &mut result);
114 result
115 }
116}
117
118impl<'a, T> Mul<Trans<&'a dyn Matrix<T>>> for &'a dyn Matrix<T>
119where
120 T: Default + Gemm,
121{
122 type Output = Mat<T>;
123
124 fn mul(self, rhs: Trans<&dyn Matrix<T>>) -> Mat<T> {
125 let (b, bt) = match rhs {
126 Trans::T(a) => (a, Transpose::Trans),
127 Trans::H(a) => (a, Transpose::ConjTrans),
128 };
129
130 if self.cols() != b.cols() {
131 panic!("Dimension mismatch");
132 }
133
134 let n = self.rows() as usize;
135 let m = b.rows() as usize;
136 let mut result = Mat::new(n, m);
137 let at = Transpose::NoTrans;
138
139 Gemm::gemm(
140 &Default::one(),
141 at,
142 self,
143 bt,
144 b,
145 &Default::zero(),
146 &mut result,
147 );
148 result
149 }
150}
151
152impl<'a, T> Mul<Trans<&'a dyn Matrix<T>>> for Trans<&'a dyn Matrix<T>>
153where
154 T: Default + Gemm,
155{
156 type Output = Mat<T>;
157
158 fn mul(self, rhs: Trans<&dyn Matrix<T>>) -> Mat<T> {
159 let (a, at) = match self {
160 Trans::T(a) => (a, Transpose::Trans),
161 Trans::H(a) => (a, Transpose::ConjTrans),
162 };
163
164 let (b, bt) = match rhs {
165 Trans::T(a) => (a, Transpose::Trans),
166 Trans::H(a) => (a, Transpose::ConjTrans),
167 };
168
169 if self.rows() != b.cols() {
170 panic!("Dimension mismatch");
171 }
172
173 let n = self.cols() as usize;
174 let m = b.rows() as usize;
175 let mut result = Mat::new(n, m);
176
177 Gemm::gemm(&Default::one(), at, a, bt, b, &Default::zero(), &mut result);
178 result
179 }
180}
181
#[cfg(test)]
mod tests {
    use crate::math::Marker::T;
    use crate::math::Mat;
    use crate::Matrix;

    /// `&dyn Matrix + &dyn Matrix` adds element by element.
    #[test]
    fn add() {
        let lhs = mat![1.0, 2.0; 3.0, 4.0];
        let rhs = mat![-1.0, 3.0; 1.0, 1.0];

        let lhs_ref = &lhs as &dyn Matrix<_>;
        let rhs_ref = &rhs as &dyn Matrix<_>;
        let sum = lhs_ref + rhs_ref;

        assert_eq!(sum, mat![0.0, 5.0; 4.0, 5.0]);
    }

    /// Scaling gives the same result with the scalar on either side.
    #[test]
    fn scale() {
        let m = mat![1f32, 2f32; 3f32, 4f32];
        let m_ref = &m as &dyn Matrix<_>;

        let scaled_right = m_ref * 3.0;
        let scaled_left = 3.0 * m_ref;

        assert_eq!(scaled_right, mat![3f32, 6f32; 9f32, 12f32]);
        assert_eq!(scaled_left, scaled_right);
    }

    /// Plain matrix product with neither operand transposed.
    #[test]
    fn mul() {
        let lhs = mat![1.0, 2.0; 3.0, 4.0];
        let rhs = mat![-1.0, 3.0; 1.0, 1.0];

        let lhs_ref = &lhs as &dyn Matrix<_>;
        let rhs_ref = &rhs as &dyn Matrix<_>;
        let product = lhs_ref * rhs_ref;

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }

    /// The left operand is stored transposed and flipped back with `^ T`.
    #[test]
    fn left_mul_trans() {
        let lhs = mat![1.0, 3.0; 2.0, 4.0];
        let rhs = mat![-1.0, 3.0; 1.0, 1.0];

        let lhs_ref = &lhs as &dyn Matrix<_>;
        let rhs_ref = &rhs as &dyn Matrix<_>;
        let product = (lhs_ref ^ T) * rhs_ref;

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }

    /// The right operand is stored transposed and flipped back with `^ T`.
    #[test]
    fn right_mul_trans() {
        let lhs = mat![1.0, 2.0; 3.0, 4.0];
        let rhs = mat![-1.0, 1.0; 3.0, 1.0];

        let lhs_ref = &lhs as &dyn Matrix<_>;
        let rhs_ref = &rhs as &dyn Matrix<_>;
        let product = lhs_ref * (rhs_ref ^ T);

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }

    /// Both operands are stored transposed and flipped back with `^ T`.
    #[test]
    fn mul_trans() {
        let lhs = mat![1.0, 3.0; 2.0, 4.0];
        let rhs = mat![-1.0, 1.0; 3.0, 1.0];

        let lhs_ref = &lhs as &dyn Matrix<_>;
        let rhs_ref = &rhs as &dyn Matrix<_>;
        let product = (lhs_ref ^ T) * (rhs_ref ^ T);

        assert_eq!(product, mat![1.0, 5.0; 1.0, 13.0]);
    }
}