tensor_rs/tensor_impl/gen_tensor/convolution.rs

use std::collections::BTreeMap;
use crate::tensor::PaddingMode;
use super::GenTensor;
use crate::tensor_trait::convolution::Convolution;

impl<T> Convolution for GenTensor<T> where T: num_traits::Float {

    // conv2d ops
    fn conv2d(&self, filter: &GenTensor<T>,
              stride: (usize, usize),
              padding: (usize, usize),
              dilation: (usize, usize),
              padding_mode: PaddingMode
    ) -> Self {
        self.conv_gen(filter,
                      &[stride.0, stride.1],
                      &[padding.0, padding.1],
                      &[dilation.0, dilation.1],
                      padding_mode)
    }
    fn conv2d_grad(&self, filter: &GenTensor<T>,
                   stride: (usize, usize),
                   padding: (usize, usize),
                   dilation: (usize, usize),
                   padding_mode: PaddingMode,
                   output_grad: &GenTensor<T>
    ) -> (Self, Self) {
        self.conv_grad_gen(filter,
                           &[stride.0, stride.1],
                           &[padding.0, padding.1],
                           &[dilation.0, dilation.1],
                           padding_mode,
                           output_grad)
    }

    // General convolution operator; works for both the 2d and 3d cases.
    fn conv_gen(&self, filter: &GenTensor<T>,
                stride: &[usize],
                padding: &[usize],
                dilation: &[usize],
                padding_mode: PaddingMode
    ) -> GenTensor<T> {
        let self_dim = self.size();
        let filter_dim = filter.size();
        if self_dim.len() != filter_dim.len() {
            panic!("conv_gen expects input and filter to have the same number of dims, got {:?}, {:?}", self_dim, filter_dim);
        }
        if stride.len() != padding.len() || stride.len() != dilation.len() || stride.len() != (self_dim.len() - 2) {
            panic!("stride, padding, and dilation should have the same number of dims, got {:?}, {:?}, {:?}", stride, padding, dilation);
        }
        if stride.iter().any(|x| *x < 1) {
            panic!("stride should be at least 1, got {:?}", stride);
        }
        if dilation.iter().any(|x| *x < 1) {
            panic!("dilation should be at least 1, got {:?}", dilation);
        }

        let out_channels = filter_dim[0];
        let in_channels = filter_dim[1];
        let sample_size = self_dim[0];
        let data_channels = self_dim[1];
        if in_channels != data_channels {
            panic!("conv_gen expects the input channel count to match the filter depth, got {:?}, {:?}", self_dim, filter_dim);
        }

        // prepare the padded input
        let mut padded_dim = Vec::new();
        for i in 2..self_dim.len() {
            padded_dim.push(self_dim[i] + padding[i-2]*2);
        }
        //println!("padded_dim: {:?}", padded_dim);

        // Find the coordinate of the filter's starting center point in the
        // padded dimensions:
        // if filter_dim[i] is even, start_point is at the half point;
        // if filter_dim[i] is odd, start_point is at the center.
        let mut start_point = Vec::new();
        for i in 0..stride.len() {
            let half = filter_dim[2+i]/2;
            let dilated = half*dilation[i];
            start_point.push(dilated);
        }
        //println!("start_point: {:?}", start_point);

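        // The output extent along each inner dimension follows the usual
        // convolution size formula (stated here as a reference comment):
        //   out[i] = (in[i] + 2*padding[i] - dilation[i]*(filter[i] - 1) - 1) / stride[i] + 1
        // e.g. in = 5, padding = 1, filter = 3, dilation = 1, stride = 2
        // gives (5 + 2 - 2 - 1)/2 + 1 = 3.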
        let mut output_size = Vec::new();
        //println!("{:?}, {:?}", padded_dim, stride);
        for i in 0..stride.len() {
            let output_dim = (padded_dim[i] - dilation[i]*(filter_dim[2+i]-1)-1)/stride[i] + 1;
            output_size.push(output_dim);
        }
        let mut output_tensor_size = vec![sample_size, out_channels];
        output_tensor_size.append(&mut output_size.clone()); // the clone keeps output_size from being moved.
        let output_inner_size = output_size.iter().product::<usize>();
        //println!("output_size: {:?}", output_size);
        //println!("{:?}", output_inner_size);
        //println!("{:?}", output_tensor_size);

        let mut ret = GenTensor::<T>::zeros(&output_tensor_size);

        let conv_size = filter_dim.iter().product::<usize>()/out_channels; // conv_size = Cin * d1 * d2 * d3 ...
        let mut data_block = vec![T::zero(); conv_size];
        let mut filter_block = vec![T::zero(); conv_size];

        let inner_steps = output_inner_size*out_channels;
        let filter_step = conv_size;

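        // Main loop: for every sample and every output channel, slide the
        // (dilated) filter window across the padded input. left_upper is the
        // upper-left corner of the window in padded coordinates; k walks the
        // flattened output positions.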
        for i in 0..sample_size {
            for j in 0..out_channels {
                filter_block.copy_from_slice(&filter.get_data()[j*filter_step..(j+1)*filter_step]);

                let mut left_upper = vec![0; stride.len()];
                for k in 0..output_inner_size {
                    //println!("left_upper: {:?}", left_upper);

                    // get_data_block
                    let mut current_data_elem = left_upper.to_vec();
                    for in_channel_index in 0..in_channels {
                        for inner_index in 0..conv_size/in_channels {

                            // Fetch a single scalar into the data block;
                            // d walks the inner (spatial) dimensions.
                            let mut push_value = T::zero();
                            let mut in_margin = false;
                            for d in 0..current_data_elem.len() {
                                if current_data_elem[d] < padding[d] || current_data_elem[d] >= (padding[d] + self_dim[d+2]) {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            push_value = T::zero();
                                            in_margin = true;
                                            break;
                                        },
                                        _ => {unimplemented!();}
                                    }
                                }
                            }
                            if !in_margin {
                                let real_data_elem = current_data_elem.iter().zip(padding.iter()).map(|(x, y)| x - y).collect::<Vec<usize>>();
                                let mut real_data_elem2 = vec![i, in_channel_index];
                                real_data_elem2.append(&mut real_data_elem.clone());
                                push_value = self.get(&real_data_elem2);
                            }

                            data_block[in_channel_index*(conv_size/in_channels) + inner_index] = push_value;

                            // Advance current_data_elem to the next position,
                            // odometer style: step the last coordinate by its
                            // dilation and carry into earlier coordinates on overflow.
                            let mut current_pos = current_data_elem.len()-1;
                            loop {
                                current_data_elem[current_pos] += dilation[current_pos];
                                if current_data_elem[current_pos] >= dilation[current_pos]*filter_dim[current_pos+2] + left_upper[current_pos] {
                                    current_data_elem[current_pos] = left_upper[current_pos];
                                    if current_pos > 0 {
                                        current_pos -= 1;
                                    } else {
                                        break;
                                    }
                                } else {
                                    break;
                                }
                            }
                        }
                    }

                    // A fused iterator version would need a T: Sum bound:
                    //let value = data_block.iter().zip(&filter_block).map(|(x, y)| (*x)*(*y)).sum::<T>();
                    let mut value = T::zero();
                    for (x, y) in data_block.iter().zip(&filter_block) {
                        value = value + (*x)*(*y);
                    }
                    //println!("index: {}, {}, {}", i, j, k);
                    //println!("raw index: {}", i*inner_steps + j*output_inner_size + k);
                    //ret.d[i*inner_steps + j*output_inner_size + k] = value;
                    ret.set_1d(i*inner_steps + j*output_inner_size + k, value);

                    // Advance left_upper to the next output position.
                    let mut current_pos = left_upper.len()-1;
                    loop {
                        left_upper[current_pos] += stride[current_pos];
                        let mut compare_pos = padded_dim[current_pos] - start_point[current_pos]*2;
                        if filter_dim[current_pos+2] % 2 == 0 {
                            compare_pos += 1;
                        }
                        if left_upper[current_pos] >= compare_pos {
                            left_upper[current_pos] = 0;
                            if current_pos > 0 {
                                current_pos -= 1;
                            } else {
                                break;
                            }
                        } else {
                            break;
                        }
                    }

                }
            }
        }

        ret
    }

    // The first return value is the gradient for the filter w;
    // the second is the gradient for the input, given output_grad.
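    // For each output position o with incoming gradient g[n, j, o], filter
    // element w[j, c, p] sees input x[n, c, o*stride + p*dilation - padding], so
    //   dL/dw[j, c, p] += g[n, j, o] * x[n, c, o*stride + p*dilation - padding]
    //   dL/dx[n, c, q] += g[n, j, o] * w[j, c, p],  q = o*stride + p*dilation - padding.
    // The loops below enumerate every (output position, filter element) pair
    // and record one product term for each side.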
    fn conv_grad_gen(&self, filter: &GenTensor<T>,
                     stride: &[usize],
                     padding: &[usize],
                     dilation: &[usize],
                     padding_mode: PaddingMode,
                     output_grad: &GenTensor<T>,
    ) -> (GenTensor<T>, GenTensor<T>) {
        let self_dim = self.size();
        let filter_dim = filter.size();
        let output_grad_dim = output_grad.size();
        if self_dim.len() <= 2 {
            panic!("input data for conv does not have enough dims {:?}.", self_dim);
        }
        if filter_dim.len() <= 2 {
            panic!("filter for conv does not have enough dims {:?}.", filter_dim);
        }
        if output_grad_dim.len() <= 2 {
            panic!("output gradient for conv does not have enough dims {:?}.", output_grad_dim);
        }
        if self_dim.len() != filter_dim.len() || self_dim.len() != output_grad_dim.len() {
            panic!("conv_grad_gen expects input, output gradient, and filter to have the same number of dims, got {:?}, {:?}, {:?}", self_dim, filter_dim, output_grad_dim);
        }
        if filter_dim[1] != self_dim[1] {
            panic!("conv_grad_gen expects the input channel count to match the filter depth, got {:?}, {:?}", self_dim, filter_dim);
        }
        if self_dim[0] != output_grad_dim[0] {
            panic!("conv_grad_gen expects input and output to have the same N: {:?}, {:?}", self_dim, output_grad_dim);
        }
        if filter_dim[0] != output_grad_dim[1] {
            panic!("conv_grad_gen expects filter and output to have the same Cout: {:?}, {:?}", filter_dim, output_grad_dim);
        }
        if stride.len() != padding.len() || stride.len() != dilation.len() {
            panic!("stride, padding, and dilation should have the same number of dims, got {:?}, {:?}, {:?}", stride, padding, dilation);
        }
        if stride.len()+2 != filter_dim.len() {
            panic!("stride should cover the filter's inner dims, got {:?}, {:?}", stride, filter_dim);
        }

        let filter_size = filter.size();
        let n_c_out = filter_size[0];
        let n_c_in = filter_size[1];
        let n_n = self_dim[0];
        //let n_d_dd = self_dim.iter().product::<usize>()/n_n/n_c_in;
        let n_f_dd = filter_dim.iter().product::<usize>()/n_c_out/n_c_in;
        let d_inner = self_dim.len() - 2;

        let output_dd = output_grad_dim.iter().product::<usize>()/n_n/n_c_out;

        // Accumulate per-element gradient contributions; they are summed at the end.
        let mut w_grad: BTreeMap<usize, Vec<T>> = BTreeMap::new();
        let mut x_grad: BTreeMap<usize, Vec<T>> = BTreeMap::new();
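        // Each key is a flattened index into the filter (w_grad) or the input
        // (x_grad); every (output position, filter element) pair that touches
        // that index pushes one product term onto the corresponding list.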

        for i in 0..n_n {
            for j in 0..n_c_out {
                // left_upper in padded dimensions.
                let mut left_upper = vec![0; d_inner];

                let mut output_index = 0;

                loop {
                    //println!("left_upper: {:?}", left_upper);

                    // get the current output gradient
                    let output_real_index = j*output_dd + i*n_c_out*output_dd + output_index;
                    //println!("output_real_index: {:?}", output_real_index);
                    let output_dimpos = output_grad.index2dimpos(output_real_index);
                    //println!("output_dimpos: {:?}", output_dimpos);
                    let output_gradient_value = output_grad.get(&output_dimpos);
                    //println!("output_gradient_value: {:?}", output_gradient_value.to_f32());

                    // remember where to get data.
                    // let mut data_loc = BTreeMap::<Vec::<usize>, >::new();

                    for cin_index in 0..n_c_in {
                        for dd_index in 0..n_f_dd {

                            // Recover the filter element's inner coordinates
                            // from the flat index dd_index.
                            let mut filter_elem = Vec::new();
                            let mut remainder = dd_index;
                            for dim_pos in 0..d_inner {
                                let left_product = filter_size[dim_pos+3..filter_size.len()]
                                    .iter()
                                    .product::<usize>();
                                filter_elem.push(remainder / left_product);
                                remainder %= left_product;
                            }
                            //println!("filter_elem: {:?}", filter_elem);

                            // Get the current data element's position in padded dimensions.
                            let mut data_elem = left_upper.to_vec();
                            for dim_pos in 0..d_inner {
                                data_elem[dim_pos] += filter_elem[dim_pos]*dilation[dim_pos];
                            }
                            //println!("data_elem: {:?}", data_elem);

                            // find the real current position in the filter
                            let mut full_filter_elem = vec![j, cin_index];
                            full_filter_elem.append(&mut filter_elem.clone());
                            // println!("filter_value: {}", filter_value.to_f32().expect(""));
                            // println!("full_filter_elem: {:?}", full_filter_elem);

                            // find the real current position in the data
                            let mut zero_padded_flag = false;
                            let mut unpadded_elem = data_elem.clone();
                            //println!("data_elem: {:?}", data_elem);
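                            // Map an out-of-range padded coordinate back to an
                            // unpadded source coordinate, per padding mode:
                            //   Zeros:     no source element; the term is skipped below.
                            //   Reflect:   mirror about the edge, without repeating the edge element.
                            //   Replicate: clamp to the nearest edge element.
                            //   Circular:  wrap around to the opposite edge.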
                            for dim_pos in 0..d_inner {
                                if data_elem[dim_pos] < padding[dim_pos] {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            zero_padded_flag = true;
                                        },
                                        PaddingMode::Reflect => {
                                            unpadded_elem[dim_pos] = padding[dim_pos] - data_elem[dim_pos] - 1;
                                        },
                                        PaddingMode::Replicate => {
                                            unpadded_elem[dim_pos] = 0;
                                        },
                                        PaddingMode::Circular => {
                                            unpadded_elem[dim_pos] = self_dim[dim_pos+2] - (padding[dim_pos] - data_elem[dim_pos]);
                                        },
                                    }
                                } else if data_elem[dim_pos] >= self_dim[dim_pos + 2] + padding[dim_pos] {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            zero_padded_flag = true;
                                        },
                                        PaddingMode::Reflect => {
                                            unpadded_elem[dim_pos] = self_dim[dim_pos+2] - (data_elem[dim_pos] - (self_dim[dim_pos + 2] + padding[dim_pos]) + 1);
                                        },
                                        PaddingMode::Replicate => {
                                            unpadded_elem[dim_pos] = self_dim[dim_pos + 2]-1;
                                        },
                                        PaddingMode::Circular => {
                                            unpadded_elem[dim_pos] = data_elem[dim_pos] - (self_dim[dim_pos + 2] + padding[dim_pos]);
                                        },
                                    }
                                } else {
                                    unpadded_elem[dim_pos] -= padding[dim_pos];
                                }
                            }

                            if zero_padded_flag {
                                continue;
                            } else {
                                //println!("unpadded_elem: {:?}", unpadded_elem);
                                let mut full_data_elem = vec![i, cin_index];
                                full_data_elem.append(&mut unpadded_elem.clone());
                                //println!("full_data_elem: {:?}", full_data_elem);

                                let filter_value = filter.get(&full_filter_elem);
                                let data_value = self.get(&full_data_elem);

                                // collect all the data.
                                let w_grad_value = output_gradient_value*data_value;
                                let x_grad_value = output_gradient_value*filter_value;

                                let total_w_index = filter.dimpos2index(&full_filter_elem);
                                let total_x_index = self.dimpos2index(&full_data_elem);

                                //println!("full_data_elem: {:?}, total_x_index: {:?}, data_value: {:?}",
                                //         full_data_elem,
                                //         total_x_index,
                                //         data_value.to_f32());
                                //println!("full_filter_elem: {:?}, total_w_index: {:?}, filter_value: {:?}, w_grad_value: {:?}, output_gradient_value: {:?}, data_value: {:?}",
                                //         full_filter_elem,
                                //         total_w_index,
                                //         filter_value.to_f32(),
                                //         w_grad_value.to_f32(),
                                //         output_gradient_value.to_f32(),
                                //         data_value.to_f32());

                                w_grad.entry(total_w_index).or_default().push(w_grad_value);
                                x_grad.entry(total_x_index).or_default().push(x_grad_value);
                            }

                        }
                    }

                    // Advance left_upper to the next position.
                    for current_pos in 0..d_inner {
                        let real_pos = d_inner - current_pos - 1;
                        left_upper[real_pos] += stride[real_pos];

                        let compare_pos = self_dim[real_pos+2]
                            + padding[real_pos]*2
                            - ((filter_dim[real_pos + 2]-1)*dilation[real_pos] + 1);

                        if left_upper[real_pos] > compare_pos {
                            left_upper[real_pos] = 0;
                        } else {
                            break;
                        }
                    }
                    // left_upper wrapping back to all zeros means every
                    // output position has been visited.
                    if left_upper.iter().sum::<usize>() == 0 {
                        break;
                    }
                    output_index += 1;
                }
            }
        }

        let mut ret_w_grad = GenTensor::zeros(filter.size());
        let mut ret_x_grad = GenTensor::zeros(self.size());

        for i in w_grad.keys() {
            //println!("i: {:?}", i);
            let mut sum = T::zero();
            for w_value in w_grad.get(i).expect("") {
                sum = sum + *w_value;
                //println!("w_value: {}", w_value.to_f32().expect("") );
            }
            //ret_w_grad.d[*i] = sum/T::from(w_grad.get(i).expect("").len()).expect("");
            //ret_w_grad.d[*i] = sum;
            ret_w_grad.set_1d(*i, sum);
        }
        for i in x_grad.keys() {
            //println!("i: {:?}", i);
            let mut sum = T::zero();
            for x_value in x_grad.get(i).expect("") {
                sum = sum + *x_value;
                //println!("x_value: {}", x_value.to_f32().expect("") );
            }
            //ret_x_grad.d[*i] = sum/T::from(x_grad.get(i).expect("").len()).expect("");
            //ret_x_grad.d[*i] = sum;
            ret_x_grad.set_1d(*i, sum);
        }

        (ret_w_grad, ret_x_grad)
    }
}

#[cfg(test)]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use crate::tensor_trait::index_slicing::IndexSlicing;
    use super::*;

    #[test]
    fn conv_gen() {

        {
            let data = GenTensor::<f32>::arange(30).reshape(&vec![2, 3, 5]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 3, 3]);
            let stride = vec![1];
            let padding = vec![0];
            let dilation = vec![1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![312.0, 348.0, 384.0, 798.0, 915.0, 1032.0, 852.0, 888.0, 924.0, 2553.0, 2670.0, 2787.0], &vec![2, 2, 3]));
        }

        {
            let mut raw_data = Vec::new();
            for i in 0..75 {
                raw_data.push(i as f32);
            }
            let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 5]);
            let mut raw_data = Vec::new();
            for i in 0..54 {
                raw_data.push(i as f32);
            }
            let filter = GenTensor::<f32>::new_raw(&raw_data, &vec![2, 3, 3, 3]);

            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);

            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![15219.0, 15570.0, 15921.0, 16974.0, 17325.0, 17676.0, 18729.0, 19080.0, 19431.0, 37818.0, 38898.0, 39978.0, 43218.0, 44298.0, 45378.0, 48618.0, 49698.0, 50778.0], &vec![1, 2, 3, 3]));
        }

        {
            let mut raw_data = Vec::new();
            for i in 0..60 {
                raw_data.push(i as f32);
            }
            let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 4]);
            let mut raw_data = Vec::new();
            for i in 0..36 {
                raw_data.push(i as f32);
            }
            let filter = GenTensor::<f32>::new_raw(&raw_data, &vec![2, 3, 3, 2]);

            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);

            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![5289.0, 5442.0, 5595.0, 5901.0, 6054.0, 6207.0, 6513.0, 6666.0, 6819.0, 13227.0, 13704.0, 14181.0, 15135.0, 15612.0, 16089.0, 17043.0, 17520.0, 17997.0], &vec![1, 2, 3, 3]));
        }

        {
            let data = GenTensor::<f32>::arange(375).reshape(&vec![1, 3, 5, 5, 5]);
            let filter = GenTensor::<f32>::arange(162).reshape(&vec![2, 3, 3, 3, 3]);
            let stride = vec![1, 1, 1];
            let padding = vec![0, 0, 0];
            let dilation = vec![1, 1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![700704.0, 703944.0, 707184.0, 716904.0, 720144.0, 723384.0, 733104.0, 736344.0, 739584.0, 781704.0, 784944.0, 788184.0, 797904.0, 801144.0, 804384.0, 814104.0, 817344.0, 820584.0, 862704.0, 865944.0, 869184.0, 878904.0, 882144.0, 885384.0, 895104.0, 898344.0, 901584.0, 1724220.0, 1734021.0, 1743822.0, 1773225.0, 1783026.0, 1792827.0, 1822230.0, 1832031.0, 1841832.0, 1969245.0, 1979046.0, 1988847.0, 2018250.0, 2028051.0, 2037852.0, 2067255.0, 2077056.0, 2086857.0, 2214270.0, 2224071.0, 2233872.0, 2263275.0, 2273076.0, 2282877.0, 2312280.0, 2322081.0, 2331882.0], &vec![1, 2, 3, 3, 3]));
        }

        {
            let data = GenTensor::<f32>::arange(16).reshape(&vec![1, 1, 4, 4]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![1, 1];
            let padding = vec![1, 1];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("final output size: {:?}", result.size());
            println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![73.0, 121.0, 154.0, 103.0, 171.0, 258.0, 294.0, 186.0, 279.0, 402.0, 438.0, 270.0, 139.0, 187.0, 202.0, 113.0, 163.0, 283.0, 370.0, 265.0, 414.0, 663.0, 780.0, 537.0, 738.0, 1131.0, 1248.0, 837.0, 517.0, 781.0, 850.0, 563.0], &vec![1, 2, 4, 4]));
        }

        {
            let data = GenTensor::<f32>::arange(49).reshape(&vec![1, 1, 7, 7]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![2, 2];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("final output size: {:?}", result.size());
            println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
        }
    }
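
    // A minimal extra check, assuming the same arange/reshape helpers used
    // above: input 5x5, one channel, kernel 3x3, stride 2, padding 1,
    // dilation 1. The size formula gives (5 + 2 - 2 - 1)/2 + 1 = 3 per
    // dimension; the expected values were worked out by hand.
    #[test]
    fn conv_gen_stride_padding() {
        let data = GenTensor::<f32>::arange(25).reshape(&vec![1, 1, 5, 5]);
        let filter = GenTensor::<f32>::arange(9).reshape(&vec![1, 1, 3, 3]);
        let stride = vec![2, 2];
        let padding = vec![1, 1];
        let dilation = vec![1, 1];
        let padding_mode = PaddingMode::Zeros;
        let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
        assert_eq!(result, GenTensor::<f32>::new_raw(&vec![88.0, 175.0, 136.0, 345.0, 528.0, 345.0, 232.0, 319.0, 184.0], &vec![1, 1, 3, 3]));
    }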

    #[test]
    fn conv_grad_gen() {

        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(54).reshape(&vec![2, 3, 3, 3]);
            let output_grad = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);

            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![312.0, 348.0, 384.0, 492.0, 528.0, 564.0, 672.0, 708.0, 744.0, 1212.0, 1248.0, 1284.0, 1392.0, 1428.0, 1464.0, 1572.0, 1608.0, 1644.0, 2112.0, 2148.0, 2184.0, 2292.0, 2328.0, 2364.0, 2472.0, 2508.0, 2544.0, 798.0, 915.0, 1032.0, 1383.0, 1500.0, 1617.0, 1968.0, 2085.0, 2202.0, 3723.0, 3840.0, 3957.0, 4308.0, 4425.0, 4542.0, 4893.0, 5010.0, 5127.0, 6648.0, 6765.0, 6882.0, 7233.0, 7350.0, 7467.0, 7818.0, 7935.0, 8052.0], &vec![2, 3, 3, 3]));
        }

        {
            let data = GenTensor::<f32>::arange(60).reshape(&vec![1, 3, 5, 4]);
            let filter = GenTensor::<f32>::arange(36).reshape(&vec![2, 3, 3, 2]);
            let output_grad = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);
            //println!("output_grad: {:?}", output_grad);

            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("{:?}, {:?}, {:?}", w_grad, x_grad, output_grad);
            //println!("w_grad: {:?}", w_grad);
            assert_eq!(w_grad, GenTensor::new_raw(&vec![258.0, 294.0, 402.0, 438.0, 546.0, 582.0, 978.0, 1014.0, 1122.0, 1158.0, 1266.0, 1302.0, 1698.0, 1734.0, 1842.0, 1878.0, 1986.0, 2022.0, 663.0, 780.0, 1131.0, 1248.0, 1599.0, 1716.0, 3003.0, 3120.0, 3471.0, 3588.0, 3939.0, 4056.0, 5343.0, 5460.0, 5811.0, 5928.0, 6279.0, 6396.0], &vec![2, 3, 3, 2]));
        }

        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(54).reshape(&vec![2, 3, 3, 3]);
            let output_grad = GenTensor::<f32>::arange(50).reshape(&vec![1, 2, 5, 5]);

            let stride = vec![1, 1];
            let padding = vec![1, 1]; // <- THIS IS THE CHANGE
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![2680.0, 3420.0, 2760.0, 3900.0, 4900.0, 3900.0, 2760.0, 3420.0, 2680.0, 8680.0, 10670.0, 8360.0, 10150.0, 12400.0, 9650.0, 6760.0, 8170.0, 6280.0, 14680.0, 17920.0, 13960.0, 16400.0, 19900.0, 15400.0, 10760.0, 12920.0, 9880.0, 6280.0, 8170.0, 6760.0, 9650.0, 12400.0, 10150.0, 8360.0, 10670.0, 8680.0, 22280.0, 27920.0, 22360.0, 28400.0, 35525.0, 28400.0, 22360.0, 27920.0, 22280.0, 38280.0, 47670.0, 37960.0, 47150.0, 58650.0, 46650.0, 36360.0, 45170.0, 35880.0], &vec![2, 3, 3, 3]));
        }

        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(150).reshape(&vec![2, 3, 5, 5]);
            let output_grad = GenTensor::<f32>::arange(50).reshape(&vec![1, 2, 5, 5]);

            let stride = vec![1, 1];
            let padding = vec![2, 2]; // <- THIS IS THE CHANGE
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![1128.0, 1580.0, 2065.0, 1700.0, 1308.0, 1964.0, 2680.0, 3420.0, 2760.0, 2084.0, 2905.0, 3900.0, 4900.0, 3900.0, 2905.0, 2084.0, 2760.0, 3420.0, 2680.0, 1964.0, 1308.0, 1700.0, 2065.0, 1580.0, 1128.0, 5178.0, 6830.0, 8440.0, 6650.0, 4908.0, 6614.0, 8680.0, 10670.0, 8360.0, 6134.0, 7780.0, 10150.0, 12400.0, 9650.0, 7030.0, 5234.0, 6760.0, 8170.0, 6280.0, 4514.0, 3108.0, 3950.0, 4690.0, 3530.0, 2478.0, 9228.0, 12080.0, 14815.0, 11600.0, 8508.0, 11264.0, 14680.0, 17920.0, 13960.0, 10184.0, 12655.0, 16400.0, 19900.0, 15400.0, 11155.0, 8384.0, 10760.0, 12920.0, 9880.0, 7064.0, 4908.0, 6200.0, 7315.0, 5480.0, 3828.0, 2478.0, 3530.0, 4690.0, 3950.0, 3108.0, 4514.0, 6280.0, 8170.0, 6760.0, 5234.0, 7030.0, 9650.0, 12400.0, 10150.0, 7780.0, 6134.0, 8360.0, 10670.0, 8680.0, 6614.0, 4908.0, 6650.0, 8440.0, 6830.0, 5178.0, 12153.0, 16280.0, 20440.0, 16400.0, 12333.0, 16664.0, 22280.0, 27920.0, 22360.0, 16784.0, 21280.0, 28400.0, 35525.0, 28400.0, 21280.0, 16784.0, 22360.0, 27920.0, 22280.0, 16664.0, 12333.0, 16400.0, 20440.0, 16280.0, 12153.0, 21828.0, 29030.0, 36190.0, 28850.0, 21558.0, 28814.0, 38280.0, 47670.0, 37960.0, 28334.0, 35530.0, 47150.0, 58650.0, 46650.0, 34780.0, 27434.0, 36360.0, 45170.0, 35880.0, 26714.0, 19758.0, 26150.0, 32440.0, 25730.0, 19128.0], &vec![2, 3, 5, 5]));
        }

        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(150).reshape(&vec![2, 3, 5, 5]);
            let output_grad = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);

            let stride = vec![2, 2]; // <- THIS IS THE CHANGE
            let padding = vec![2, 2];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![176.0, 200.0, 284.0, 172.0, 192.0, 296.0, 320.0, 449.0, 272.0, 292.0, 420.0, 447.0, 624.0, 375.0, 396.0, 164.0, 176.0, 233.0, 128.0, 136.0, 224.0, 236.0, 308.0, 168.0, 176.0, 776.0, 800.0, 1109.0, 672.0, 692.0, 896.0, 920.0, 1274.0, 772.0, 792.0, 1095.0, 1122.0, 1524.0, 900.0, 921.0, 464.0, 476.0, 608.0, 328.0, 336.0, 524.0, 536.0, 683.0, 368.0, 376.0, 1376.0, 1400.0, 1934.0, 1172.0, 1192.0, 1496.0, 1520.0, 2099.0, 1272.0, 1292.0, 1770.0, 1797.0, 2424.0, 1425.0, 1446.0, 764.0, 776.0, 983.0, 528.0, 536.0, 824.0, 836.0, 1058.0, 568.0, 576.0, 392.0, 452.0, 662.0, 424.0, 480.0, 692.0, 752.0, 1097.0, 704.0, 760.0, 1014.0, 1095.0, 1596.0, 1023.0, 1098.0, 560.0, 608.0, 881.0, 560.0, 604.0, 800.0, 848.0, 1226.0, 780.0, 824.0, 1892.0, 1952.0, 2837.0, 1824.0, 1880.0, 2192.0, 2252.0, 3272.0, 2104.0, 2160.0, 3039.0, 3120.0, 4521.0, 2898.0, 2973.0, 1760.0, 1808.0, 2606.0, 1660.0, 1704.0, 2000.0, 2048.0, 2951.0, 1880.0, 1924.0, 3392.0, 3452.0, 5012.0, 3224.0, 3280.0, 3692.0, 3752.0, 5447.0, 3504.0, 3560.0, 5064.0, 5145.0, 7446.0, 4773.0, 4848.0, 2960.0, 3008.0, 4331.0, 2760.0, 2804.0, 3200.0, 3248.0, 4676.0, 2980.0, 3024.0], &vec![2, 3, 5, 5]));
        }
    }

}