tensor_rs/tensor_impl/lapack_tensor/
elemwise.rs1use crate::tensor_impl::gen_tensor::GenTensor;
2#[cfg(feature = "use-blas-lapack")]
3use super::blas_api::BlasAPI;
4
5
6#[cfg(feature = "use-blas-lapack")]
// Generates a BLAS-backed element-wise addition function named `$b` for the
// floating-point element type `$a`.
macro_rules! blas_add {
    ($a:ty, $b: ident) => {
        /// Element-wise `x + y`, evaluated with a single BLAS `axpy` call
        /// (by BLAS convention `acc <- alpha * v + acc`, here with `alpha = 1`).
        ///
        /// Supported operand combinations:
        /// * `x` has exactly one element and `y` has more: the scalar is added
        ///   to every element of `y`; the result has `y`'s shape.
        /// * `y` has exactly one element and `x` has more: the scalar is added
        ///   to every element of `x`; the result has `x`'s shape.
        /// * both have the same number of elements: added position by
        ///   position; the result has `x`'s shape.
        /// * otherwise `y` is broadcast along the leading dimensions of `x`
        ///   (its whole shape must equal the trailing dimensions of `x`);
        ///   the result has `x`'s shape.
        ///
        /// # Panics
        ///
        /// In the general broadcast case, panics when `x` has fewer elements
        /// than `y`, when `x` does not have strictly more dimensions than `y`,
        /// or when the trailing dimensions of `x` differ from `y`'s shape.
        pub fn $b(
            x: &GenTensor<$a>,
            y: &GenTensor<$a>,
        ) -> GenTensor<$a> {
            let x_len = x.numel();
            let y_len = y.numel();

            if x_len == 1 && y_len > 1 {
                // Scalar on the left: start from the expanded scalar and
                // accumulate `y` into it. The result must carry `y`'s shape;
                // labelling `y_len` elements with `x`'s one-element shape
                // would produce an inconsistent tensor.
                let mut sum = vec![x.get_data()[0]; y_len];
                BlasAPI::<$a>::axpy(y_len,
                                    1.0 as $a,
                                    y.get_data(), 1,
                                    &mut sum, 1);
                return GenTensor::<$a>::new_move(sum, y.size().clone());
            }

            // In every remaining case the result has `x`'s shape: build an
            // `x`-sized copy of `y` and accumulate `x` into it.
            let mut sum = if y_len == 1 && x_len > 1 {
                // Scalar on the right, expanded to the length of `x`.
                vec![y.get_data()[0]; x_len]
            } else if x_len == y_len {
                y.get_data().clone()
            } else {
                if x_len < y_len {
                    panic!("right-hand broadcast only.");
                }
                let x_rank = x.size().len();
                let y_rank = y.size().len();
                if x_rank <= y_rank {
                    panic!("unmatched dimension. {}, {}", x_rank, y_rank);
                }
                // Compare dimensions from the innermost one outwards.
                for i in 0..y_rank {
                    if y.size()[y_rank - i - 1] != x.size()[x_rank - i - 1] {
                        panic!("unmatched size.");
                    }
                }
                // Tile `y` once per leading-dimension slice of `x`.
                y.get_data().repeat(x_len / y_len)
            };

            BlasAPI::<$a>::axpy(x_len,
                                1.0 as $a,
                                x.get_data(), 1,
                                &mut sum, 1);
            GenTensor::<$a>::new_move(sum, x.size().clone())
        }
    }
}

// Single-precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_add!(f32, add_f32);

// Double-precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_add!(f64, add_f64);
57
58
59#[cfg(feature = "use-blas-lapack")]
// Generates a BLAS-backed element-wise subtraction function named `$b` for
// the floating-point element type `$a`.
macro_rules! blas_sub {
    ($a:ty, $b: ident) => {
        /// Element-wise `x - y`, evaluated with a single BLAS `axpy` call
        /// (by BLAS convention `acc <- alpha * v + acc`, here with `alpha = -1`).
        ///
        /// Supported operand combinations:
        /// * `x` has exactly one element and `y` has more: every element of
        ///   `y` is subtracted from the scalar; the result has `y`'s shape.
        /// * `y` has exactly one element and `x` has more: the scalar is
        ///   subtracted from every element of `x`; the result has `x`'s shape.
        /// * `x` and `y` have identical shapes: subtracted position by
        ///   position.
        /// * otherwise `y` is broadcast along the leading dimensions of `x`
        ///   (its whole shape must equal the trailing dimensions of `x`);
        ///   the result has `x`'s shape.
        ///
        /// # Panics
        ///
        /// In the general broadcast case, panics when `x` has fewer elements
        /// than `y`, when `x` does not have strictly more dimensions than `y`,
        /// or when the trailing dimensions of `x` differ from `y`'s shape.
        pub fn $b(
            x: &GenTensor<$a>,
            y: &GenTensor<$a>,
        ) -> GenTensor<$a> {
            let x_len = x.numel();
            let y_len = y.numel();

            if x_len == 1 && y_len > 1 {
                // Scalar on the left: expand it to the length of `y`, then
                // subtract `y` from it. The result has `y`'s shape.
                let mut diff = vec![x.get_data()[0]; y_len];
                BlasAPI::<$a>::axpy(y_len,
                                    -1.0 as $a,
                                    y.get_data(), 1,
                                    &mut diff, 1);
                return GenTensor::<$a>::new_move(diff, y.size().clone());
            }

            // In every remaining case the result has `x`'s shape; `rhs` is
            // `y` brought to the length of `x`. `expanded` owns the
            // temporary buffer when one is needed.
            let expanded;
            let rhs = if y_len == 1 && x_len > 1 {
                // Scalar on the right: it must be expanded to `x_len`
                // elements, because `axpy` reads `x_len` values with stride 1
                // and the one-element buffer of `y` would be overrun.
                expanded = vec![y.get_data()[0]; x_len];
                &expanded
            } else if x.size() == y.size() {
                y.get_data()
            } else {
                if x_len < y_len {
                    panic!("right-hand broadcast only.");
                }
                let x_rank = x.size().len();
                let y_rank = y.size().len();
                if x_rank <= y_rank {
                    panic!("unmatched dimension and right-hand broadcast only. {}, {}",
                           x_rank, y_rank);
                }
                // Compare dimensions from the innermost one outwards.
                for i in 0..y_rank {
                    if y.size()[y_rank - i - 1] != x.size()[x_rank - i - 1] {
                        panic!("unmatched size.");
                    }
                }
                // Tile `y` once per leading-dimension slice of `x`.
                expanded = y.get_data().repeat(x_len / y_len);
                &expanded
            };

            let mut diff = x.get_data().clone();
            BlasAPI::<$a>::axpy(x_len,
                                -1.0 as $a,
                                rhs, 1,
                                &mut diff, 1);
            GenTensor::<$a>::new_move(diff, x.size().clone())
        }
    }
}

// Single-precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_sub!(f32, sub_f32);

// Double-precision instantiation.
#[cfg(feature = "use-blas-lapack")]
blas_sub!(f64, sub_f64);
121
122#[cfg(test)]
123mod tests {
124 use crate::tensor_impl::gen_tensor::GenTensor;
125 use super::*;
126
127 #[test]
128 #[cfg(feature = "use-blas-lapack")]
129 fn test_add() {
130 let a = GenTensor::<f32>::ones(&[1, 2, 3]);
131 let b = GenTensor::<f32>::ones(&[1, 2, 3]);
132 let c = add_f32(&a, &b);
133 let em = GenTensor::<f32>::new_raw(&[2.0, 2.0, 2.0, 2.0, 2.0, 2.0], &[1, 2, 3]);
134 assert_eq!(c, em);
135
136 let a = GenTensor::<f64>::ones(&[1, 2, 3]);
137 let b = GenTensor::<f64>::ones(&[1, 2, 3]);
138 let c = add_f64(&a, &b);
139 let em = GenTensor::<f64>::new_raw(&[2.0, 2.0, 2.0, 2.0, 2.0, 2.0], &[1, 2, 3]);
140 assert_eq!(c, em);
141
142 let a = GenTensor::<f64>::ones(&[1, 2, 3]);
143 let b = GenTensor::<f64>::ones(&[3]);
144 let c = add_f64(&a, &b);
145 let em = GenTensor::<f64>::new_raw(&[2.0, 2.0, 2.0, 2.0, 2.0, 2.0], &[1, 2, 3]);
146 assert_eq!(c, em);
147 }
148
149 #[test]
150 #[cfg(feature = "use-blas-lapack")]
151 fn test_sub() {
152 let a = GenTensor::<f32>::ones(&[1, 2, 3]);
153 let b = GenTensor::<f32>::ones(&[1, 2, 3]);
154 let c = sub_f32(&a, &b);
155 let em = GenTensor::<f32>::new_raw(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[1, 2, 3]);
156 assert_eq!(c, em);
157
158 let a = GenTensor::<f64>::ones(&[1, 2, 3]);
159 let b = GenTensor::<f64>::ones(&[1, 2, 3]);
160 let c = sub_f64(&a, &b);
161 let em = GenTensor::<f64>::new_raw(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[1, 2, 3]);
162 assert_eq!(c, em);
163
164 let a = GenTensor::<f64>::ones(&[1, 2, 3]);
165 let b = GenTensor::<f64>::ones(&[3]);
166 let c = sub_f64(&a, &b);
167 let em = GenTensor::<f64>::new_raw(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[1, 2, 3]);
168 assert_eq!(c, em);
169 }
170}