1use crate::tensor_impl::gen_tensor::GenTensor;
2use crate::tensor_trait::index_slicing::IndexSlicing;
3use crate::tensor::PaddingMode;
4#[cfg(feature = "use-blas-lapack")]
5use super::blas_api::BlasAPI;
6
#[cfg(feature = "use-blas-lapack")]
macro_rules! blas_conv {
    ($a:ty, $b: ident) => {
        /// N-dimensional convolution evaluated as a single matrix product.
        ///
        /// Every receptive field of `data` is unrolled into one row of a column
        /// buffer (im2col); the buffer is then multiplied with the flattened
        /// `filter` through `BlasAPI::gemm`.
        ///
        /// * `data` - input, indexed as `[sample, channel, spatial...]`.
        /// * `filter` - kernel, indexed as `[out_channels, in_channels, kernel...]`.
        /// * `stride`, `padding`, `dilation` - one entry per spatial axis.
        /// * `padding_mode` - only `PaddingMode::Zeros` is implemented.
        ///
        /// Returns a tensor of shape `[sample, out_channels, out_spatial...]`,
        /// where each spatial extent is
        /// `(input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1`.
        ///
        /// # Panics
        ///
        /// Panics if `stride` is empty, if `padding`/`dilation` are shorter than
        /// the number of spatial axes, if the dilated kernel does not fit inside
        /// the padded input (the subtraction underflows in checked builds), or
        /// if a padded element is read with a mode other than
        /// `PaddingMode::Zeros`.
        pub fn $b(
            data: &GenTensor<$a>,
            filter: &GenTensor<$a>,
            stride: &[usize],
            padding: &[usize],
            dilation: &[usize],
            padding_mode: PaddingMode,
        ) -> GenTensor<$a> {
            let self_dim = data.size();
            let filter_dim = filter.size();

            let out_channels = filter_dim[0];
            let in_channels = filter_dim[1];
            let sample_size = self_dim[0];
            let n_spatial = stride.len();

            // Spatial extents after padding is added on both sides.
            let padded_dim: Vec<usize> = (2..self_dim.len())
                .map(|i| self_dim[i] + padding[i - 2] * 2)
                .collect();

            // Exclusive upper bound of a window's origin on each axis: the
            // dilated kernel spans `dilation * (kernel - 1) + 1` cells, so the
            // origin must stay below `padded - dilation * (kernel - 1)`.
            // This holds for odd and even kernels and for any dilation.
            let window_limit: Vec<usize> = (0..n_spatial)
                .map(|i| padded_dim[i] - dilation[i] * (filter_dim[2 + i] - 1))
                .collect();

            let output_size: Vec<usize> = (0..n_spatial)
                .map(|i| (window_limit[i] - 1) / stride[i] + 1)
                .collect();

            let output_inner_size = output_size.iter().product::<usize>();
            // Elements of one filter: in_channels * kernel volume.
            let conv_size = filter_dim.iter().product::<usize>() / out_channels;
            let kernel_volume = conv_size / in_channels;

            let mut columned_data =
                Vec::<$a>::with_capacity(sample_size * output_inner_size * conv_size);

            // Origin of the current window, in padded coordinates.
            let mut left_upper = vec![0usize; n_spatial];
            // Kernel tap currently being gathered, in padded coordinates.
            let mut cursor = vec![0usize; n_spatial];
            // Reusable `[sample, channel, spatial...]` index into `data`.
            let mut data_index = vec![0usize; n_spatial + 2];

            for sample in 0..sample_size {
                for origin in left_upper.iter_mut() {
                    *origin = 0;
                }
                data_index[0] = sample;

                for _window in 0..output_inner_size {
                    cursor.copy_from_slice(&left_upper);

                    for in_channel_index in 0..in_channels {
                        data_index[1] = in_channel_index;

                        for _tap in 0..kernel_volume {
                            let in_margin = (0..n_spatial).any(|d| {
                                cursor[d] < padding[d]
                                    || cursor[d] >= padding[d] + self_dim[d + 2]
                            });

                            let value: $a = if in_margin {
                                match padding_mode {
                                    PaddingMode::Zeros => 0.,
                                    _ => unimplemented!(),
                                }
                            } else {
                                // Translate padded coordinates back to `data`.
                                for d in 0..n_spatial {
                                    data_index[d + 2] = cursor[d] - padding[d];
                                }
                                data.get(&data_index)
                            };
                            columned_data.push(value);

                            // Step to the next kernel tap, last axis fastest,
                            // carrying into earlier axes on wrap-around.
                            let mut axis = n_spatial - 1;
                            loop {
                                cursor[axis] += dilation[axis];
                                if cursor[axis]
                                    < left_upper[axis] + dilation[axis] * filter_dim[axis + 2]
                                {
                                    break;
                                }
                                cursor[axis] = left_upper[axis];
                                if axis == 0 {
                                    break;
                                }
                                axis -= 1;
                            }
                        }
                    }

                    // Slide the window to the next output position, last axis
                    // fastest, carrying into earlier axes on wrap-around.
                    let mut axis = n_spatial - 1;
                    loop {
                        left_upper[axis] += stride[axis];
                        if left_upper[axis] < window_limit[axis] {
                            break;
                        }
                        left_upper[axis] = 0;
                        if axis == 0 {
                            break;
                        }
                        axis -= 1;
                    }
                }
            }

            let mut columned_result = vec![0.; sample_size * out_channels * output_inner_size];
            // Flags, dimensions and leading strides are passed exactly as the
            // crate's `BlasAPI::gemm` wrapper was already being called.
            BlasAPI::<$a>::gemm(
                true,
                false,
                sample_size * output_inner_size,
                out_channels,
                conv_size,
                1.,
                &columned_data,
                conv_size,
                filter.get_data(),
                conv_size,
                1.,
                &mut columned_result,
                sample_size * output_inner_size,
            );

            // The product buffer is read as `[out_channels, sample, spatial...]`;
            // swapping the two leading axes gives `[sample, out_channels, ...]`.
            let mut gemm_dim = vec![out_channels, sample_size];
            gemm_dim.extend_from_slice(&output_size);
            let mut swap_leading: Vec<usize> = (0..gemm_dim.len()).collect();
            swap_leading.swap(0, 1);

            GenTensor::<$a>::new_move(columned_result, gemm_dim).permute(&swap_leading)
        }
    };
}

// Single-precision GEMM convolution.
#[cfg(feature = "use-blas-lapack")]
blas_conv!(f32, gemm_conv_f32);

// Double-precision GEMM convolution.
#[cfg(feature = "use-blas-lapack")]
blas_conv!(f64, gemm_conv_f64);

/// Regression tests for even-sized kernels combined with dilation > 1, where
/// the window origin must be allowed to reach `padded - dilation * (kernel - 1) - 1`.
#[cfg(all(test, feature = "use-blas-lapack"))]
mod dilated_even_kernel_tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use super::*;

    #[test]
    fn test_gemm_conv_even_kernel_dilation() {
        // Kernel [0, 1] with dilation 2 picks x[p + 2] for p = 0, 1, 2.
        let data = GenTensor::<f32>::arange(5).reshape(&vec![1, 1, 5]);
        let filter = GenTensor::<f32>::arange(2).reshape(&vec![1, 1, 2]);
        let result = gemm_conv_f32(&data, &filter, &[1], &[0], &[2], PaddingMode::Zeros);
        assert_eq!(
            result,
            GenTensor::<f32>::new_raw(&vec![2.0, 3.0, 4.0], &vec![1, 1, 3])
        );
    }

    #[test]
    fn test_gemm_conv_even_kernel_dilation_stride() {
        // Kernel [0, 1] with dilation 3 and stride 2 picks x[3] and x[5].
        let data = GenTensor::<f32>::arange(7).reshape(&vec![1, 1, 7]);
        let filter = GenTensor::<f32>::arange(2).reshape(&vec![1, 1, 2]);
        let result = gemm_conv_f32(&data, &filter, &[2], &[0], &[3], PaddingMode::Zeros);
        assert_eq!(
            result,
            GenTensor::<f32>::new_raw(&vec![3.0, 5.0], &vec![1, 1, 2])
        );
    }
}
182
183
/// Tests for the GEMM-based convolution.
///
/// The whole module is gated on the BLAS feature: every test needs the
/// generated `gemm_conv_*` functions, so compiling the module without the
/// feature would only leave unused imports behind.
#[cfg(all(test, feature = "use-blas-lapack"))]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use super::*;

    // Generates a checker for one element type. The checker fills the input
    // and the filter with 0, 1, 2, ... (`arange`), runs the convolution with
    // zero padding mode, and compares against the expected tensor.
    macro_rules! conv_case_checker {
        ($name:ident, $t:ty, $conv:ident) => {
            fn $name(
                data_shape: Vec<usize>,
                filter_shape: Vec<usize>,
                stride: &[usize],
                padding: &[usize],
                dilation: &[usize],
                expected: Vec<$t>,
                expected_shape: Vec<usize>,
            ) {
                let data = GenTensor::<$t>::arange(data_shape.iter().product::<usize>())
                    .reshape(&data_shape);
                let filter = GenTensor::<$t>::arange(filter_shape.iter().product::<usize>())
                    .reshape(&filter_shape);

                let result = $conv(&data, &filter, stride, padding, dilation, PaddingMode::Zeros);

                assert_eq!(
                    result,
                    GenTensor::<$t>::new_raw(&expected, &expected_shape)
                );
            }
        };
    }

    conv_case_checker!(check_f32, f32, gemm_conv_f32);
    conv_case_checker!(check_f64, f64, gemm_conv_f64);

    #[test]
    fn test_gemm_conv_1d() {
        check_f32(
            vec![2, 3, 5],
            vec![2, 3, 3],
            &[1],
            &[0],
            &[1],
            vec![
                312.0, 348.0, 384.0, 798.0, 915.0, 1032.0,
                852.0, 888.0, 924.0, 2553.0, 2670.0, 2787.0,
            ],
            vec![2, 2, 3],
        );
    }

    #[test]
    fn test_gemm_conv_2d() {
        check_f32(
            vec![1, 3, 5, 5],
            vec![2, 3, 3, 3],
            &[1, 1],
            &[0, 0],
            &[1, 1],
            vec![
                15219.0, 15570.0, 15921.0, 16974.0, 17325.0, 17676.0,
                18729.0, 19080.0, 19431.0, 37818.0, 38898.0, 39978.0,
                43218.0, 44298.0, 45378.0, 48618.0, 49698.0, 50778.0,
            ],
            vec![1, 2, 3, 3],
        );
    }

    #[test]
    fn test_gemm_conv_2d_even_kernel() {
        check_f32(
            vec![1, 3, 5, 4],
            vec![2, 3, 3, 2],
            &[1, 1],
            &[0, 0],
            &[1, 1],
            vec![
                5289.0, 5442.0, 5595.0, 5901.0, 6054.0, 6207.0,
                6513.0, 6666.0, 6819.0, 13227.0, 13704.0, 14181.0,
                15135.0, 15612.0, 16089.0, 17043.0, 17520.0, 17997.0,
            ],
            vec![1, 2, 3, 3],
        );
    }

    #[test]
    fn test_gemm_conv_3d() {
        check_f32(
            vec![1, 3, 5, 5, 5],
            vec![2, 3, 3, 3, 3],
            &[1, 1, 1],
            &[0, 0, 0],
            &[1, 1, 1],
            vec![
                700704.0, 703944.0, 707184.0, 716904.0, 720144.0, 723384.0,
                733104.0, 736344.0, 739584.0, 781704.0, 784944.0, 788184.0,
                797904.0, 801144.0, 804384.0, 814104.0, 817344.0, 820584.0,
                862704.0, 865944.0, 869184.0, 878904.0, 882144.0, 885384.0,
                895104.0, 898344.0, 901584.0, 1724220.0, 1734021.0, 1743822.0,
                1773225.0, 1783026.0, 1792827.0, 1822230.0, 1832031.0, 1841832.0,
                1969245.0, 1979046.0, 1988847.0, 2018250.0, 2028051.0, 2037852.0,
                2067255.0, 2077056.0, 2086857.0, 2214270.0, 2224071.0, 2233872.0,
                2263275.0, 2273076.0, 2282877.0, 2312280.0, 2322081.0, 2331882.0,
            ],
            vec![1, 2, 3, 3, 3],
        );
    }

    #[test]
    fn test_gemm_conv_2d_padding() {
        check_f32(
            vec![1, 1, 4, 4],
            vec![2, 1, 3, 3],
            &[1, 1],
            &[1, 1],
            &[1, 1],
            vec![
                73.0, 121.0, 154.0, 103.0, 171.0, 258.0, 294.0, 186.0,
                279.0, 402.0, 438.0, 270.0, 139.0, 187.0, 202.0, 113.0,
                163.0, 283.0, 370.0, 265.0, 414.0, 663.0, 780.0, 537.0,
                738.0, 1131.0, 1248.0, 837.0, 517.0, 781.0, 850.0, 563.0,
            ],
            vec![1, 2, 4, 4],
        );
    }

    #[test]
    fn test_gemm_conv_2d_stride() {
        check_f32(
            vec![1, 1, 7, 7],
            vec![2, 1, 3, 3],
            &[2, 2],
            &[0, 0],
            &[1, 1],
            vec![
                420.0, 492.0, 564.0, 924.0, 996.0, 1068.0,
                1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0,
                2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0,
            ],
            vec![1, 2, 3, 3],
        );
    }

    #[test]
    fn test_gemm_conv_2d_stride_f64() {
        check_f64(
            vec![1, 1, 7, 7],
            vec![2, 1, 3, 3],
            &[2, 2],
            &[0, 0],
            &[1, 1],
            vec![
                420.0, 492.0, 564.0, 924.0, 996.0, 1068.0,
                1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0,
                2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0,
            ],
            vec![1, 2, 3, 3],
        );
    }
}