// rust_blas/math/matrix_vector.rs

// Copyright 2015 Michael Yang. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

5use crate::attribute::Transpose;
6use crate::default::Default;
7use crate::math::Mat;
8use crate::math::Trans;
9use crate::matrix::Matrix;
10use crate::matrix_vector::ops::*;
11use crate::vector::Vector;
12use std::ops::Mul;
13
14impl<'a, T> Mul<&'a dyn Vector<T>> for &'a dyn Matrix<T>
15where
16    T: Default + Copy + Gemv,
17{
18    type Output = Vec<T>;
19
20    fn mul(self, x: &dyn Vector<T>) -> Vec<T> {
21        let n = self.rows() as usize;
22        let mut result = Vec::with_capacity(n);
23        unsafe {
24            result.set_len(n);
25        }
26        let scale = Default::one();
27        let clear = Default::zero();
28        let t = Transpose::NoTrans;
29
30        Gemv::gemv(t, &scale, self, x, &clear, &mut result);
31        result
32    }
33}
34
35impl<'a, T> Mul<Trans<&'a dyn Vector<T>>> for &'a dyn Vector<T>
36where
37    T: Default + Ger + Gerc + Clone,
38{
39    type Output = Mat<T>;
40
41    fn mul(self, x: Trans<&dyn Vector<T>>) -> Mat<T> {
42        let n = self.len() as usize;
43        let m = (*x).len() as usize;
44        let mut result = Mat::fill(Default::zero(), n, m);
45        let scale = Default::one();
46
47        match x {
48            Trans::T(v) => Ger::ger(&scale, self, v, &mut result),
49            Trans::H(v) => Gerc::gerc(&scale, self, v, &mut result),
50        }
51
52        result
53    }
54}
55
#[cfg(test)]
mod tests {
    use crate::math::Marker::T;
    use crate::math::Mat;
    use crate::Matrix;
    use crate::Vector;

    /// `A * x` for a 2x2 `f32` matrix matches the hand-computed product:
    /// row 0 is 2*2 + (-2)*1 = 2, row 1 is 2*2 + (-4)*1 = 0.
    #[test]
    fn mul() {
        let mat_a = mat![2f32, -2.0; 2.0, -4.0];
        let vec_x = vec![2f32, 1.0];

        let product = (&mat_a as &dyn Matrix<f32>) * (&vec_x as &dyn Vector<f32>);

        assert_eq!(product, vec![2.0, 0.0]);
    }

    /// `x * y^T` yields the 3x3 matrix whose (i, j) entry is `x[i] * y[j]`.
    #[test]
    fn outer() {
        let lhs = vec![2.0, 1.0, 4.0];
        let rhs = vec![3.0, 6.0, -1.0];

        let lhs_ref = &lhs as &dyn Vector<_>;
        let rhs_ref = &rhs as &dyn Vector<_>;
        let actual = lhs_ref * (rhs_ref ^ T);

        let expected = mat![  6.0, 12.0, -2.0;
                              3.0,  6.0, -1.0;
                             12.0, 24.0, -4.0];
        assert_eq!(actual, expected);
    }
}