rust_blas/math/matrix.rs

1// Copyright 2015 Michael Yang. All rights reserved.
2// Use of this source code is governed by a MIT-style
3// license that can be found in the LICENSE file.
4
5use crate::attribute::Transpose;
6use crate::default::Default;
7use crate::math::Mat;
8use crate::math::Trans;
9use crate::matrix::ops::*;
10use crate::matrix::Matrix;
11use crate::vector::ops::*;
12use num_complex::{Complex32, Complex64};
13use std::ops::{Add, Mul};
14
15impl<'a, T> Add for &'a dyn Matrix<T>
16where
17    T: Axpy + Copy + Default,
18{
19    type Output = Mat<T>;
20
21    fn add(self, b: &dyn Matrix<T>) -> Mat<T> {
22        if self.cols() != b.cols() || self.rows() != b.rows() {
23            panic!("Dimension mismatch")
24        }
25
26        let scale = Default::one();
27        let mut result = Mat::from(self);
28        Axpy::axpy_mat(&scale, b, &mut result);
29        result
30    }
31}
32
33impl<'a, T> Mul<T> for &'a dyn Matrix<T>
34where
35    T: Sized + Copy + Scal,
36{
37    type Output = Mat<T>;
38
39    fn mul(self, alpha: T) -> Mat<T> {
40        let mut result = Mat::from(self);
41        Scal::scal_mat(&alpha, &mut result);
42        result
43    }
44}
45
46macro_rules! left_scale(($($t: ident), +) => (
47    $(
48        impl<'a> Mul<&'a dyn Matrix<$t>> for $t
49        {
50            type Output = Mat<$t>;
51
52            fn mul(self, x: &dyn Matrix<$t>) -> Mat<$t> {
53                let mut result = Mat::from(x);
54                Scal::scal_mat(&self, &mut result);
55                result
56            }
57        }
58    )+
59));
60
61left_scale!(f32, f64, Complex32, Complex64);
62
63impl<'a, T> Mul<&'a dyn Matrix<T>> for &'a dyn Matrix<T>
64where
65    T: Default + Gemm,
66{
67    type Output = Mat<T>;
68
69    fn mul(self, b: &dyn Matrix<T>) -> Mat<T> {
70        if self.cols() != b.rows() {
71            panic!("Dimension mismatch");
72        }
73
74        let n = self.rows() as usize;
75        let m = b.cols() as usize;
76        let mut result = Mat::new(n, m);
77        let t = Transpose::NoTrans;
78
79        Gemm::gemm(
80            &Default::one(),
81            t,
82            self,
83            t,
84            b,
85            &Default::zero(),
86            &mut result,
87        );
88        result
89    }
90}
91
92impl<'a, T> Mul<&'a dyn Matrix<T>> for Trans<&'a dyn Matrix<T>>
93where
94    T: Default + Gemm,
95{
96    type Output = Mat<T>;
97
98    fn mul(self, b: &dyn Matrix<T>) -> Mat<T> {
99        let (a, at) = match self {
100            Trans::T(a) => (a, Transpose::Trans),
101            Trans::H(a) => (a, Transpose::ConjTrans),
102        };
103
104        if a.rows() != b.rows() {
105            panic!("Dimension mismatch");
106        }
107
108        let n = a.cols() as usize;
109        let m = b.cols() as usize;
110        let mut result = Mat::new(n, m);
111        let bt = Transpose::NoTrans;
112
113        Gemm::gemm(&Default::one(), at, a, bt, b, &Default::zero(), &mut result);
114        result
115    }
116}
117
118impl<'a, T> Mul<Trans<&'a dyn Matrix<T>>> for &'a dyn Matrix<T>
119where
120    T: Default + Gemm,
121{
122    type Output = Mat<T>;
123
124    fn mul(self, rhs: Trans<&dyn Matrix<T>>) -> Mat<T> {
125        let (b, bt) = match rhs {
126            Trans::T(a) => (a, Transpose::Trans),
127            Trans::H(a) => (a, Transpose::ConjTrans),
128        };
129
130        if self.cols() != b.cols() {
131            panic!("Dimension mismatch");
132        }
133
134        let n = self.rows() as usize;
135        let m = b.rows() as usize;
136        let mut result = Mat::new(n, m);
137        let at = Transpose::NoTrans;
138
139        Gemm::gemm(
140            &Default::one(),
141            at,
142            self,
143            bt,
144            b,
145            &Default::zero(),
146            &mut result,
147        );
148        result
149    }
150}
151
152impl<'a, T> Mul<Trans<&'a dyn Matrix<T>>> for Trans<&'a dyn Matrix<T>>
153where
154    T: Default + Gemm,
155{
156    type Output = Mat<T>;
157
158    fn mul(self, rhs: Trans<&dyn Matrix<T>>) -> Mat<T> {
159        let (a, at) = match self {
160            Trans::T(a) => (a, Transpose::Trans),
161            Trans::H(a) => (a, Transpose::ConjTrans),
162        };
163
164        let (b, bt) = match rhs {
165            Trans::T(a) => (a, Transpose::Trans),
166            Trans::H(a) => (a, Transpose::ConjTrans),
167        };
168
169        if self.rows() != b.cols() {
170            panic!("Dimension mismatch");
171        }
172
173        let n = self.cols() as usize;
174        let m = b.rows() as usize;
175        let mut result = Mat::new(n, m);
176
177        Gemm::gemm(&Default::one(), at, a, bt, b, &Default::zero(), &mut result);
178        result
179    }
180}
181
#[cfg(test)]
mod tests {
    use crate::math::Marker::T;
    use crate::math::Mat;
    use crate::Matrix;

    /// Expected result shared by every product test below. Each test is arranged
    /// so that its effective operands are `[1 2; 3 4]` and `[-1 3; 1 1]`.
    fn expected_product() -> Mat<f64> {
        mat![1.0, 5.0; 1.0, 13.0]
    }

    #[test]
    fn add() {
        let lhs = mat![1.0, 2.0; 3.0, 4.0];
        let rhs = mat![-1.0, 3.0; 1.0, 1.0];
        let (lr, rr) = (&lhs as &dyn Matrix<_>, &rhs as &dyn Matrix<_>);

        let sum = lr + rr;

        assert_eq!(sum, mat![0.0, 5.0; 4.0, 5.0]);
    }

    #[test]
    fn scale() {
        let m = mat![1f32, 2f32; 3f32, 4f32];
        let mr = &m as &dyn Matrix<_>;

        let right = mr * 3.0;
        let left = 3.0 * mr;

        assert_eq!(right, mat![3f32, 6f32; 9f32, 12f32]);
        assert_eq!(left, right);
    }

    #[test]
    fn mul() {
        let a = mat![1.0, 2.0; 3.0, 4.0];
        let b = mat![-1.0, 3.0; 1.0, 1.0];
        let (ar, br) = (&a as &dyn Matrix<_>, &b as &dyn Matrix<_>);

        let product = ar * br;

        assert_eq!(product, expected_product());
    }

    #[test]
    fn left_mul_trans() {
        // `a` is stored transposed, so `a^T` is the usual left operand.
        let a = mat![1.0, 3.0; 2.0, 4.0];
        let b = mat![-1.0, 3.0; 1.0, 1.0];
        let (ar, br) = (&a as &dyn Matrix<_>, &b as &dyn Matrix<_>);

        let product = (ar ^ T) * br;

        assert_eq!(product, expected_product());
    }

    #[test]
    fn right_mul_trans() {
        // `b` is stored transposed, so `b^T` is the usual right operand.
        let a = mat![1.0, 2.0; 3.0, 4.0];
        let b = mat![-1.0, 1.0; 3.0, 1.0];
        let (ar, br) = (&a as &dyn Matrix<_>, &b as &dyn Matrix<_>);

        let product = ar * (br ^ T);

        assert_eq!(product, expected_product());
    }

    #[test]
    fn mul_trans() {
        // Both operands are stored transposed.
        let a = mat![1.0, 3.0; 2.0, 4.0];
        let b = mat![-1.0, 1.0; 3.0, 1.0];
        let (ar, br) = (&a as &dyn Matrix<_>, &b as &dyn Matrix<_>);

        let product = (ar ^ T) * (br ^ T);

        assert_eq!(product, expected_product());
    }
}