tensor_rs/tensor_impl/lapack_tensor/convolution.rs

1use crate::tensor_impl::gen_tensor::GenTensor;
2use crate::tensor_trait::index_slicing::IndexSlicing;
3use crate::tensor::PaddingMode;
4#[cfg(feature = "use-blas-lapack")]
5use super::blas_api::BlasAPI;
6
#[cfg(feature = "use-blas-lapack")]
macro_rules! blas_conv {
    ($a:ty, $b: ident) => {
        /// N-dimensional convolution (cross-correlation: the kernel is not
        /// flipped) computed with im2col followed by a single BLAS GEMM.
        ///
        /// * `data` - input of shape `[N, C_in, D1, D2, ...]`.
        /// * `filter` - kernel of shape `[C_out, C_in, K1, K2, ...]`.
        /// * `stride`, `padding`, `dilation` - one entry per spatial axis;
        ///   `padding[i]` cells are added on *both* sides of spatial axis `i`.
        /// * `padding_mode` - only `PaddingMode::Zeros` is implemented; any other
        ///   mode panics (`unimplemented!`) as soon as a padded cell is read.
        ///
        /// Returns a tensor of shape `[N, C_out, O1, O2, ...]` where
        /// `Oi = (Di + 2*padding[i] - dilation[i]*(Ki - 1) - 1) / stride[i] + 1`.
        ///
        /// # Panics
        ///
        /// * if `data` and `filter` ranks differ or are below 3, or if `stride`,
        ///   `padding` or `dilation` do not have one entry per spatial axis;
        /// * if the channel count of `data` differs from the filter's `C_in`;
        /// * if a stride is zero, or a dilated kernel is larger than the padded
        ///   input along some axis;
        /// * on a padded read with a `padding_mode` other than `Zeros`.
        pub fn $b(
            data: &GenTensor<$a>,
            filter: &GenTensor<$a>,
            stride: &[usize],
            padding: &[usize],
            dilation: &[usize],
            padding_mode: PaddingMode
        ) -> GenTensor<$a> {
            // Odometer step over a multi-dimensional position: bump the last
            // axis by `step`; whenever an axis reaches `origin + extent`, reset it
            // to `origin` and carry into the previous axis. After the final
            // combination every axis has wrapped back to `origin`.
            fn advance(pos: &mut [usize], origin: &[usize], step: &[usize], extent: &[usize]) {
                for axis in (0..pos.len()).rev() {
                    pos[axis] += step[axis];
                    if pos[axis] < origin[axis] + extent[axis] {
                        return;
                    }
                    pos[axis] = origin[axis];
                }
            }

            let self_dim = data.size();
            let filter_dim = filter.size();

            assert!(self_dim.len() >= 3,
                    "data must be [N, C_in, D1, ...], got rank {}", self_dim.len());
            assert_eq!(filter_dim.len(), self_dim.len(),
                       "data and filter must have the same rank");
            let spatial_rank = self_dim.len() - 2;
            assert_eq!(stride.len(), spatial_rank, "stride needs one entry per spatial axis");
            assert_eq!(padding.len(), spatial_rank, "padding needs one entry per spatial axis");
            assert_eq!(dilation.len(), spatial_rank, "dilation needs one entry per spatial axis");
            assert!(stride.iter().all(|&s| s > 0), "stride must be positive");

            let sample_size = self_dim[0];
            let out_channels = filter_dim[0];
            let in_channels = filter_dim[1];
            assert_eq!(self_dim[1], in_channels,
                       "data has {} channels but filter expects {}", self_dim[1], in_channels);

            // Per spatial axis:
            //   output_size  - number of output positions;
            //   window_limit - exclusive upper bound of a window's first cell
            //                  (padded coordinates);
            //   kernel_span  - distance after which a kernel tap wraps back to
            //                  the window origin (dilation * kernel size).
            let mut output_size = Vec::with_capacity(spatial_rank);
            let mut window_limit = Vec::with_capacity(spatial_rank);
            let mut kernel_span = Vec::with_capacity(spatial_rank);
            for axis in 0..spatial_rank {
                let padded = self_dim[axis + 2] + 2 * padding[axis];
                let kernel = filter_dim[axis + 2];
                // Number of input cells covered by one dilated kernel.
                let kernel_extent = dilation[axis] * (kernel - 1) + 1;
                assert!(kernel_extent <= padded,
                        "dilated kernel ({}) is larger than padded input ({}) on spatial axis {}",
                        kernel_extent, padded, axis);
                output_size.push((padded - kernel_extent) / stride[axis] + 1);
                // A window starting at `o` is valid iff `o + kernel_extent <= padded`.
                // Deriving the bound from the extent (not the kernel centre) keeps it
                // correct for even kernel sizes combined with dilation > 1.
                window_limit.push(padded - kernel_extent + 1);
                kernel_span.push(dilation[axis] * kernel);
            }

            let output_inner_size: usize = output_size.iter().product();
            // Taps per input channel (K1*K2*...) and the im2col row length C_in*K1*K2*...
            let kernel_inner_size: usize = filter_dim[2..].iter().product();
            let conv_size = in_channels * kernel_inner_size;
            let rows = sample_size * output_inner_size;

            // im2col: one row per (sample, output position) with `conv_size` values
            // ordered [channel, k1, k2, ...], i.e. the same order as the row-major
            // filter slice `filter[o, ..]`.
            let mut columned_data: Vec<$a> = Vec::with_capacity(rows * conv_size);
            let zeros = vec![0; spatial_rank];
            // First cell of the current window, in padded coordinates.
            let mut window_origin = vec![0; spatial_rank];
            // Current kernel tap, in padded coordinates.
            let mut tap = vec![0; spatial_rank];
            // Reused [sample, channel, x1, x2, ...] index into `data`.
            let mut data_index = vec![0; spatial_rank + 2];

            for sample in 0..sample_size {
                data_index[0] = sample;
                window_origin.iter_mut().for_each(|x| *x = 0);
                for _ in 0..output_inner_size {
                    for channel in 0..in_channels {
                        data_index[1] = channel;
                        tap.copy_from_slice(&window_origin);
                        for _ in 0..kernel_inner_size {
                            let in_margin = (0..spatial_rank).any(|axis| {
                                tap[axis] < padding[axis]
                                    || tap[axis] >= padding[axis] + self_dim[axis + 2]
                            });
                            let value: $a = if in_margin {
                                match padding_mode {
                                    PaddingMode::Zeros => 0.,
                                    _ => unimplemented!(),
                                }
                            } else {
                                // Translate padded coordinates back to `data` coordinates.
                                for axis in 0..spatial_rank {
                                    data_index[axis + 2] = tap[axis] - padding[axis];
                                }
                                data.get(&data_index)
                            };
                            columned_data.push(value);
                            advance(&mut tap, &window_origin, dilation, &kernel_span);
                        }
                    }
                    advance(&mut window_origin, &zeros, stride, &window_limit);
                }
            }

            // columned_result = columned_data * filter^T, with `columned_data`
            // rows x conv_size and `filter` C_out x conv_size (both row-major).
            // The transpose flags / leading dimensions follow the column-major
            // BLAS ?gemm convention (presumably what BlasAPI::gemm forwards to),
            // so the product lands as [C_out, N, O1, O2, ...] in row-major order.
            // beta = 0: overwrite the output instead of accumulating into it.
            let mut columned_result: Vec<$a> = vec![0.; rows * out_channels];
            BlasAPI::<$a>::gemm(true, false, rows, out_channels, conv_size,
                                1., &columned_data, conv_size,
                                filter.get_data(), conv_size,
                                0., &mut columned_result, rows
            );

            let mut result_dim = vec![out_channels, sample_size];
            result_dim.extend_from_slice(&output_size);
            // Swap the first two axes to get [N, C_out, O1, O2, ...].
            let mut permute_dim: Vec<usize> = (0..result_dim.len()).collect();
            permute_dim.swap(0, 1);
            GenTensor::<$a>::new_move(columned_result, result_dim).permute(&permute_dim)
        }
    }
}

#[cfg(feature = "use-blas-lapack")]
blas_conv!(f32, gemm_conv_f32);

#[cfg(feature = "use-blas-lapack")]
blas_conv!(f64, gemm_conv_f64);
182
183
// Only meaningful with the BLAS backend: every case exercises gemm_conv_*.
#[cfg(all(test, feature = "use-blas-lapack"))]
mod tests {
    use super::*;

    /// 1-D input `[2, 3, 5]` with a `[2, 3, 3]` kernel, no padding.
    #[test]
    fn test_gemm_conv_1d() {
        let data = GenTensor::<f32>::arange(30).reshape(&vec![2, 3, 5]);
        let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 3, 3]);
        let result = gemm_conv_f32(&data, &filter, &[1], &[0], &[1], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![312.0, 348.0, 384.0, 798.0, 915.0, 1032.0, 852.0, 888.0, 924.0, 2553.0, 2670.0, 2787.0], &vec![2, 2, 3]));
    }

    /// 2-D input `[1, 3, 5, 5]` with a square `[2, 3, 3, 3]` kernel.
    #[test]
    fn test_gemm_conv_2d_square_kernel() {
        let raw_data: Vec<f32> = (0..75).map(|i| i as f32).collect();
        let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 5]);
        let raw_filter: Vec<f32> = (0..54).map(|i| i as f32).collect();
        let filter = GenTensor::<f32>::new_raw(&raw_filter, &vec![2, 3, 3, 3]);
        let result = gemm_conv_f32(&data, &filter, &[1, 1], &[0, 0], &[1, 1], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![15219.0, 15570.0, 15921.0, 16974.0, 17325.0, 17676.0, 18729.0, 19080.0, 19431.0, 37818.0, 38898.0, 39978.0, 43218.0, 44298.0, 45378.0, 48618.0, 49698.0, 50778.0], &vec![1, 2, 3, 3]));
    }

    /// 2-D non-square input `[1, 3, 5, 4]` with an even-width `[2, 3, 3, 2]` kernel.
    #[test]
    fn test_gemm_conv_2d_even_kernel() {
        let raw_data: Vec<f32> = (0..60).map(|i| i as f32).collect();
        let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 4]);
        let raw_filter: Vec<f32> = (0..36).map(|i| i as f32).collect();
        let filter = GenTensor::<f32>::new_raw(&raw_filter, &vec![2, 3, 3, 2]);
        let result = gemm_conv_f32(&data, &filter, &[1, 1], &[0, 0], &[1, 1], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![5289.0, 5442.0, 5595.0, 5901.0, 6054.0, 6207.0, 6513.0, 6666.0, 6819.0, 13227.0, 13704.0, 14181.0, 15135.0, 15612.0, 16089.0, 17043.0, 17520.0, 17997.0], &vec![1, 2, 3, 3]));
    }

    /// 3-D input `[1, 3, 5, 5, 5]` with a `[2, 3, 3, 3, 3]` kernel.
    #[test]
    fn test_gemm_conv_3d() {
        let data = GenTensor::<f32>::arange(375).reshape(&vec![1, 3, 5, 5, 5]);
        let filter = GenTensor::<f32>::arange(162).reshape(&vec![2, 3, 3, 3, 3]);
        let result = gemm_conv_f32(&data, &filter, &[1, 1, 1], &[0, 0, 0], &[1, 1, 1], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![700704.0, 703944.0, 707184.0, 716904.0, 720144.0, 723384.0, 733104.0, 736344.0, 739584.0, 781704.0, 784944.0, 788184.0, 797904.0, 801144.0, 804384.0, 814104.0, 817344.0, 820584.0, 862704.0, 865944.0, 869184.0, 878904.0, 882144.0, 885384.0, 895104.0, 898344.0, 901584.0, 1724220.0, 1734021.0, 1743822.0, 1773225.0, 1783026.0, 1792827.0, 1822230.0, 1832031.0, 1841832.0, 1969245.0, 1979046.0, 1988847.0, 2018250.0, 2028051.0, 2037852.0, 2067255.0, 2077056.0, 2086857.0, 2214270.0, 2224071.0, 2233872.0, 2263275.0, 2273076.0, 2282877.0, 2312280.0, 2322081.0, 2331882.0], &vec![1, 2, 3, 3, 3]));
    }

    /// Zero padding of 1 keeps the 4x4 spatial size with a 3x3 kernel.
    #[test]
    fn test_gemm_conv_2d_zero_padding() {
        let data = GenTensor::<f32>::arange(16).reshape(&vec![1, 1, 4, 4]);
        let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
        let result = gemm_conv_f32(&data, &filter, &[1, 1], &[1, 1], &[1, 1], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![73.0, 121.0, 154.0, 103.0, 171.0, 258.0, 294.0, 186.0, 279.0, 402.0, 438.0, 270.0, 139.0, 187.0, 202.0, 113.0, 163.0, 283.0, 370.0, 265.0, 414.0, 663.0, 780.0, 537.0, 738.0, 1131.0, 1248.0, 837.0, 517.0, 781.0, 850.0, 563.0], &vec![1, 2, 4, 4]));
    }

    /// Stride 2 on a 7x7 input with a 3x3 kernel (f32).
    #[test]
    fn test_gemm_conv_2d_stride_f32() {
        let data = GenTensor::<f32>::arange(49).reshape(&vec![1, 1, 7, 7]);
        let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
        let result = gemm_conv_f32(&data, &filter, &[2, 2], &[0, 0], &[1, 1], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
    }

    /// Same stride-2 case through the f64 instantiation.
    #[test]
    fn test_gemm_conv_2d_stride_f64() {
        let data = GenTensor::<f64>::arange(49).reshape(&vec![1, 1, 7, 7]);
        let filter = GenTensor::<f64>::arange(18).reshape(&vec![2, 1, 3, 3]);
        let result = gemm_conv_f64(&data, &filter, &[2, 2], &[0, 0], &[1, 1], PaddingMode::Zeros);
        assert_eq!(result, GenTensor::<f64>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
    }
}