tensor_rs/tensor_impl/lapack_tensor/
elemwise.rs

1use crate::tensor_impl::gen_tensor::GenTensor;
2#[cfg(feature = "use-blas-lapack")]
3use super::blas_api::BlasAPI;
4
5
/// Generates a BLAS-backed element-wise addition function named `$b` for the
/// floating point type `$a`.
#[cfg(feature = "use-blas-lapack")]
macro_rules! blas_add {
    ($a:ty, $b: ident) => {
        /// Element-wise `x + y`, evaluated with a single `axpy` call
        /// (`alpha = 1`) that accumulates `x` into a working copy of `y`.
        ///
        /// Supported operand combinations:
        /// * one operand holds a single element: it is expanded to the
        ///   length of the other operand and the result takes the other
        ///   operand's shape;
        /// * both operands hold the same number of elements: the result
        ///   takes the shape of `x`;
        /// * otherwise `y` must match the trailing dimensions of `x`
        ///   (right-hand broadcast); `y` is tiled over the leading
        ///   dimensions of `x` and the result takes the shape of `x`.
        ///
        /// # Panics
        ///
        /// Panics when the element counts differ and `y` has more elements
        /// than `x`, has at least as many dimensions as `x`, or does not
        /// match the trailing dimensions of `x`.
        pub fn $b(
            x: &GenTensor<$a>,
            y: &GenTensor<$a>,
        ) -> GenTensor<$a> {
            let x_len = x.numel();
            let y_len = y.numel();

            // Owns the expanded copy of a scalar `x`; declared here so the
            // borrow handed to `axpy` outlives the branch that fills it.
            let expanded_x;

            // (input vector, accumulator that becomes the result, result shape)
            let (lhs, mut acc, out_shape) = if x_len == 1 && y_len > 1 {
                expanded_x = vec![x.get_data()[0]; y_len];
                // The result has `y_len` elements, so it must carry the
                // shape of `y`, not the single-element shape of `x`.
                (&expanded_x, y.get_data().clone(), y.size().clone())
            } else if x_len > 1 && y_len == 1 {
                (x.get_data(), vec![y.get_data()[0]; x_len], x.size().clone())
            } else if x_len == y_len {
                (x.get_data(), y.get_data().clone(), x.size().clone())
            } else {
                if x_len < y_len {
                    panic!("right-hand broadcast only.");
                }
                let x_rank = x.size().len();
                let y_rank = y.size().len();
                if x_rank <= y_rank {
                    panic!("unmatched dimension. {}, {}", x_rank, y_rank);
                }
                // `y` has to line up with the trailing dimensions of `x`.
                let x_tail = &x.size()[x_rank - y_rank..];
                if x_tail.iter().rev().ne(y.size().iter().rev()) {
                    panic!("unmatched size.");
                }
                (x.get_data(), y.get_data().repeat(x_len / y_len), x.size().clone())
            };

            // Use the accumulator's real length so BLAS never walks past
            // either buffer.
            let n = acc.len();
            BlasAPI::<$a>::axpy(n,
                                1.0 as $a,
                                lhs, 1,
                                &mut acc, 1);
            GenTensor::<$a>::new_move(acc, out_shape)
        }
    }
}

// Single precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_add!(f32, add_f32);

// Double precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_add!(f64, add_f64);
57
58
/// Generates a BLAS-backed element-wise subtraction function named `$b` for
/// the floating point type `$a`.
#[cfg(feature = "use-blas-lapack")]
macro_rules! blas_sub {
    ($a:ty, $b: ident) => {
        /// Element-wise `x - y`, evaluated with a single `axpy` call
        /// (`alpha = -1`) that accumulates `-y` into a working copy of `x`.
        ///
        /// Supported operand combinations:
        /// * `x` holds a single element and `y` more than one: `x` is
        ///   expanded to the length of `y`; the result takes the shape of
        ///   `y`;
        /// * `y` holds a single element and `x` more than one: `y` is
        ///   expanded to the length of `x`; the result takes the shape of
        ///   `x`;
        /// * both operands have the same shape;
        /// * otherwise `y` must match the trailing dimensions of `x`
        ///   (right-hand broadcast); `y` is tiled over the leading
        ///   dimensions of `x` and the result takes the shape of `x`.
        ///
        /// # Panics
        ///
        /// Panics when none of the first three cases apply and `y` has more
        /// elements than `x`, has at least as many dimensions as `x`, or
        /// does not match the trailing dimensions of `x`.
        pub fn $b(
            x: &GenTensor<$a>,
            y: &GenTensor<$a>,
        ) -> GenTensor<$a> {
            let x_len = x.numel();
            let y_len = y.numel();

            // Owns the expanded / tiled copy of `y` when one is needed;
            // declared here so the borrow handed to `axpy` outlives the
            // branch that fills it.
            let expanded_y;

            // (accumulator seeded with the minuend, subtrahend, result shape)
            let (mut acc, rhs, out_shape) = if x_len == 1 && y_len > 1 {
                (vec![x.get_data()[0]; y_len], y.get_data(), y.size().clone())
            } else if x_len > 1 && y_len == 1 {
                // A one-element `y` must be expanded before the call:
                // reading `x_len` values with stride 1 from its buffer
                // would run past the end. The result keeps the shape of `x`.
                expanded_y = vec![y.get_data()[0]; x_len];
                (x.get_data().clone(), &expanded_y, x.size().clone())
            } else if x.size() == y.size() {
                (x.get_data().clone(), y.get_data(), x.size().clone())
            } else {
                if x_len < y_len {
                    panic!("right-hand broadcast only.");
                }
                let x_rank = x.size().len();
                let y_rank = y.size().len();
                if x_rank <= y_rank {
                    panic!("unmatched dimension and right-hand broadcast only. {}, {}",
                           x_rank, y_rank);
                }
                // `y` has to line up with the trailing dimensions of `x`.
                let x_tail = &x.size()[x_rank - y_rank..];
                if x_tail.iter().rev().ne(y.size().iter().rev()) {
                    panic!("unmatched size.");
                }
                expanded_y = y.get_data().repeat(x_len / y_len);
                (x.get_data().clone(), &expanded_y, x.size().clone())
            };

            // Use the accumulator's real length so BLAS never walks past
            // either buffer.
            let n = acc.len();
            BlasAPI::<$a>::axpy(n,
                                -1.0 as $a,
                                rhs, 1,
                                &mut acc, 1);
            GenTensor::<$a>::new_move(acc, out_shape)
        }
    }
}

// Single precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_sub!(f32, sub_f32);

// Double precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_sub!(f64, sub_f64);
121
#[cfg(test)]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use super::*;

    /// ones + ones == twos, for both precisions and for a right-hand
    /// broadcast of a `[3]` tensor over a `[1, 2, 3]` tensor.
    #[test]
    #[cfg(feature = "use-blas-lapack")]
    fn test_add() {
        let full: [usize; 3] = [1, 2, 3];
        let tail: [usize; 1] = [3];

        // f32, identical shapes.
        {
            let lhs = GenTensor::<f32>::ones(&full);
            let rhs = GenTensor::<f32>::ones(&full);
            let expected = GenTensor::<f32>::new_raw(&[2.0; 6], &full);
            assert_eq!(add_f32(&lhs, &rhs), expected);
        }

        // f64, identical shapes.
        {
            let lhs = GenTensor::<f64>::ones(&full);
            let rhs = GenTensor::<f64>::ones(&full);
            let expected = GenTensor::<f64>::new_raw(&[2.0; 6], &full);
            assert_eq!(add_f64(&lhs, &rhs), expected);
        }

        // f64, right-hand operand broadcast over the leading dimensions.
        {
            let lhs = GenTensor::<f64>::ones(&full);
            let rhs = GenTensor::<f64>::ones(&tail);
            let expected = GenTensor::<f64>::new_raw(&[2.0; 6], &full);
            assert_eq!(add_f64(&lhs, &rhs), expected);
        }
    }

    /// ones - ones == zeros, for both precisions and for a right-hand
    /// broadcast of a `[3]` tensor over a `[1, 2, 3]` tensor.
    #[test]
    #[cfg(feature = "use-blas-lapack")]
    fn test_sub() {
        let full: [usize; 3] = [1, 2, 3];
        let tail: [usize; 1] = [3];

        // f32, identical shapes.
        {
            let lhs = GenTensor::<f32>::ones(&full);
            let rhs = GenTensor::<f32>::ones(&full);
            let expected = GenTensor::<f32>::new_raw(&[0.0; 6], &full);
            assert_eq!(sub_f32(&lhs, &rhs), expected);
        }

        // f64, identical shapes.
        {
            let lhs = GenTensor::<f64>::ones(&full);
            let rhs = GenTensor::<f64>::ones(&full);
            let expected = GenTensor::<f64>::new_raw(&[0.0; 6], &full);
            assert_eq!(sub_f64(&lhs, &rhs), expected);
        }

        // f64, right-hand operand broadcast over the leading dimensions.
        {
            let lhs = GenTensor::<f64>::ones(&full);
            let rhs = GenTensor::<f64>::ones(&tail);
            let expected = GenTensor::<f64>::new_raw(&[0.0; 6], &full);
            assert_eq!(sub_f64(&lhs, &rhs), expected);
        }
    }
}