use std::collections::BTreeMap;
use crate::tensor::PaddingMode;
use super::GenTensor;
use crate::tensor_trait::convolution::Convolution;

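// Convolution for GenTensor: conv2d/conv2d_grad are 2-D convenience wrappers
// around the general N-dimensional conv_gen/conv_grad_gen implementations below.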
impl<T> Convolution for GenTensor<T> where T: num_traits::Float {

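    /// 2-D convolution over an input of shape `[N, C_in, H, W]` with a filter
    /// of shape `[C_out, C_in, kH, kW]`; forwards the tuple arguments to
    /// `conv_gen` as slices. A minimal usage sketch, mirroring the `conv_gen`
    /// tests at the bottom of this file:
    ///
    /// ```ignore
    /// let data = GenTensor::<f32>::arange(16).reshape(&vec![1, 1, 4, 4]);
    /// let filter = GenTensor::<f32>::arange(9).reshape(&vec![1, 1, 3, 3]);
    /// // 4x4 input, 3x3 kernel, stride 1, no padding -> 2x2 output.
    /// let out = data.conv2d(&filter, (1, 1), (0, 0), (1, 1), PaddingMode::Zeros);
    /// println!("{:?}", out.size()); // [1, 1, 2, 2]
    /// ```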
    fn conv2d(&self, filter: &GenTensor<T>,
              stride: (usize, usize),
              padding: (usize, usize),
              dilation: (usize, usize),
              padding_mode: PaddingMode
    ) -> Self {
        self.conv_gen(filter,
                      &[stride.0, stride.1],
                      &[padding.0, padding.1],
                      &[dilation.0, dilation.1],
                      padding_mode)
    }
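
    /// Gradients of the 2-D convolution with respect to the filter and the
    /// input, given the gradient flowing back into the output; returns
    /// `(filter_gradient, input_gradient)`. A thin wrapper around `conv_grad_gen`.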
    fn conv2d_grad(&self, filter: &GenTensor<T>,
                   stride: (usize, usize),
                   padding: (usize, usize),
                   dilation: (usize, usize),
                   padding_mode: PaddingMode,
                   output_grad: &GenTensor<T>
    ) -> (Self, Self) {
        self.conv_grad_gen(filter,
                           &[stride.0, stride.1],
                           &[padding.0, padding.1],
                           &[dilation.0, dilation.1],
                           padding_mode,
                           output_grad)
    }

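    /// General N-dimensional convolution. `self` has shape
    /// `[N, C_in, d1, d2, ...]` and `filter` has shape
    /// `[C_out, C_in, k1, k2, ...]`; the result has shape `[N, C_out, o1, o2, ...]`
    /// with `o = (d + 2*padding - dilation*(k - 1) - 1)/stride + 1` per spatial
    /// dimension. Only `PaddingMode::Zeros` is implemented in this forward pass.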
    fn conv_gen(&self, filter: &GenTensor<T>,
                stride: &[usize],
                padding: &[usize],
                dilation: &[usize],
                padding_mode: PaddingMode
    ) -> GenTensor<T> {
        let self_dim = self.size();
        let filter_dim = filter.size();
        if self_dim.len() != filter_dim.len() {
            panic!("conv_gen expects input and filter to have the same number of dimensions, got {:?}, {:?}", self_dim, filter_dim);
        }
        if stride.len() != padding.len() || stride.len() != dilation.len() || stride.len() != (self_dim.len() - 2) {
            panic!("stride, padding, and dilation should have the same number of dimensions, got {:?}, {:?}, {:?}", stride, padding, dilation);
        }
        if stride.iter().any(|x| *x < 1) {
            panic!("stride should be at least 1, got {:?}", stride);
        }
        if dilation.iter().any(|x| *x < 1) {
            panic!("dilation should be at least 1, got {:?}", dilation);
        }

        let out_channels = filter_dim[0];
        let in_channels = filter_dim[1];
        let sample_size = self_dim[0];
        let data_channels = self_dim[1];
        if in_channels != data_channels {
            panic!("conv_gen expects the input channel count to match the filter depth, got {:?}, {:?}", self_dim, filter_dim);
        }

        // Spatial extent of the input once padding is added on both sides.
        let mut padded_dim = Vec::new();
        for i in 2..self_dim.len() {
            padded_dim.push(self_dim[i] + padding[i-2]*2);
        }
        // Offset of the first window, accounting for the dilated kernel radius.
        let mut start_point = Vec::new();
        for i in 0..stride.len() {
            let half = filter_dim[2+i]/2;
            let dilated = half*dilation[i];
            start_point.push(dilated);
        }
        // Output size per spatial dim: (padded - dilation*(kernel - 1) - 1)/stride + 1.
        let mut output_size = Vec::new();
        for i in 0..stride.len() {
            let output_dim = (padded_dim[i] - dilation[i]*(filter_dim[2+i]-1) - 1)/stride[i] + 1;
            output_size.push(output_dim);
        }
        let mut output_tensor_size = vec![sample_size, out_channels];
        output_tensor_size.append(&mut output_size.clone());
        let output_inner_size = output_size.iter().product::<usize>();
        let mut ret = GenTensor::<T>::zeros(&output_tensor_size);

        // Flattened buffers holding one input window and the matching filter weights.
        let conv_size = filter_dim.iter().product::<usize>()/out_channels;
        let mut data_block = vec![T::zero(); conv_size];
        let mut filter_block = vec![T::zero(); conv_size];

        let inner_steps = output_inner_size*out_channels;
        let filter_step = conv_size;

        for i in 0..sample_size {
            for j in 0..out_channels {
                filter_block.copy_from_slice(&filter.get_data()[j*filter_step..(j+1)*filter_step]);

                // Upper-left corner of the current window, in padded coordinates.
                let mut left_upper = vec![0; stride.len()];
                for k in 0..output_inner_size {
                    // Gather the window into data_block, one input channel at a time.
                    let mut current_data_elem = left_upper.to_vec();
                    for in_channel_index in 0..in_channels {
                        for inner_index in 0..conv_size/in_channels {

                            // Elements that fall in the padded margin are filled according
                            // to the padding mode; only zero padding is implemented here.
                            let mut push_value = T::zero();
                            let mut in_margin = false;
                            for d in 0..current_data_elem.len() {
                                if current_data_elem[d] < padding[d] || current_data_elem[d] >= (padding[d] + self_dim[d+2]) {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            push_value = T::zero();
                                            in_margin = true;
                                            break;
                                        },
                                        _ => {unimplemented!();}
                                    }
                                }
                            }
                            if !in_margin {
                                // Translate padded coordinates back to real input indices.
                                let real_data_elem = current_data_elem.iter().zip(padding.iter()).map(|(x, y)| x - y).collect::<Vec::<usize>>();
                                let mut real_data_elem2 = vec![i, in_channel_index];
                                real_data_elem2.append(&mut real_data_elem.clone());
                                push_value = self.get(&real_data_elem2);
                            }

                            data_block[in_channel_index*(conv_size/in_channels) + inner_index] = push_value;

                            // Advance to the next element inside the window, odometer-style:
                            // the last spatial dimension moves fastest, stepping by dilation.
                            let mut current_pos = current_data_elem.len()-1;
                            loop {
                                current_data_elem[current_pos] += dilation[current_pos];
                                if current_data_elem[current_pos] >= dilation[current_pos]*filter_dim[current_pos+2] + left_upper[current_pos] {
                                    current_data_elem[current_pos] = left_upper[current_pos];
                                    if current_pos > 0 {
                                        current_pos -= 1;
                                    } else {
                                        break;
                                    }
                                } else {
                                    break;
                                }
                            }
                        }
                    }

                    // Dot product of the gathered window with the filter weights.
                    let mut value = T::zero();
                    for (x, y) in data_block.iter().zip(&filter_block) {
                        value = value + (*x)*(*y);
                    }
                    ret.set_1d(i*inner_steps + j*output_inner_size + k, value);

                    // Slide the window to the next output position, odometer-style,
                    // stepping by stride.
                    let mut current_pos = left_upper.len()-1;
                    loop {
                        left_upper[current_pos] += stride[current_pos];
                        let mut compare_pos = padded_dim[current_pos] - start_point[current_pos]*2;
                        if filter_dim[current_pos+2] % 2 == 0 {
                            compare_pos += 1;
                        }
                        if left_upper[current_pos] >= compare_pos {
                            left_upper[current_pos] = 0;
                            if current_pos > 0 {
                                current_pos -= 1;
                            } else {
                                break;
                            }
                        } else {
                            break;
                        }
                    }

                }
            }
        }

        ret
    }

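    /// General N-dimensional convolution gradient. Replays the forward sliding
    /// window: for each output position, every filter tap that landed on a real
    /// (non-zero-padded) input element contributes `output_grad * input` to the
    /// filter gradient and `output_grad * filter` to the input gradient.
    /// Returns `(filter_gradient, input_gradient)`, shaped like `filter` and `self`.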
    fn conv_grad_gen(&self, filter: &GenTensor<T>,
                     stride: &[usize],
                     padding: &[usize],
                     dilation: &[usize],
                     padding_mode: PaddingMode,
                     output_grad: &GenTensor<T>,
    ) -> (GenTensor<T>, GenTensor<T>) {
        let self_dim = self.size();
        let filter_dim = filter.size();
        let output_grad_dim = output_grad.size();
        if self_dim.len() <= 2 {
            panic!("input data for conv does not have enough dimensions {:?}.", self_dim);
        }
        if filter_dim.len() <= 2 {
            panic!("filter for conv does not have enough dimensions {:?}.", filter_dim);
        }
        if output_grad_dim.len() <= 2 {
            panic!("output gradient for conv does not have enough dimensions {:?}.", output_grad_dim);
        }
        if self_dim.len() != filter_dim.len() || self_dim.len() != output_grad_dim.len() {
            panic!("conv_grad_gen expects input, output gradient, and filter to have the same number of dimensions, got {:?}, {:?}, {:?}", self_dim, filter_dim, output_grad_dim);
        }
        if filter_dim[1] != self_dim[1] {
            panic!("conv_grad_gen expects the input channel count to match the filter depth, got {:?}, {:?}", self_dim, filter_dim);
        }
        if self_dim[0] != output_grad_dim[0] {
            panic!("conv_grad_gen expects input and output to have the same N: {:?}, {:?}", self_dim, output_grad_dim);
        }
        if filter_dim[0] != output_grad_dim[1] {
            panic!("conv_grad_gen expects filter and output to have the same Cout: {:?}, {:?}", filter_dim, output_grad_dim);
        }
        if stride.len() != padding.len() || stride.len() != dilation.len() {
            panic!("stride, padding, and dilation should have the same number of dimensions, got {:?}, {:?}, {:?}", stride, padding, dilation);
        }
        if stride.len() + 2 != filter_dim.len() {
            panic!("stride should cover all spatial dimensions of the filter, got {:?}, {:?}", stride, filter_dim);
        }

        let filter_size = filter.size();
        let n_c_out = filter_size[0];
        let n_c_in = filter_size[1];
        let n_n = self_dim[0];
        // Number of elements in one kernel (spatial dimensions only).
        let n_f_dd = filter_dim.iter().product::<usize>()/n_c_out/n_c_in;
        let d_inner = self_dim.len() - 2;

        // Number of output elements per (sample, output-channel) pair.
        let output_dd = output_grad_dim.iter().product::<usize>()/n_n/n_c_out;

        // Per-element gradient contributions, keyed by flat tensor index;
        // they are summed into dense tensors at the end.
        let mut w_grad: BTreeMap<usize, Vec<T>> = BTreeMap::new();
        let mut x_grad: BTreeMap<usize, Vec<T>> = BTreeMap::new();

        for i in 0..n_n {
            for j in 0..n_c_out {
                // Upper-left corner of the current window, in padded coordinates.
                let mut left_upper = vec![0; d_inner];

                let mut output_index = 0;

                loop {
                    let output_real_index = j*output_dd + i*n_c_out*output_dd + output_index;
                    let output_dimpos = output_grad.index2dimpos(output_real_index);
                    let output_gradient_value = output_grad.get(&output_dimpos);
                    for cin_index in 0..n_c_in {
                        for dd_index in 0..n_f_dd {

                            // Decompose the flat kernel index into per-dimension coordinates.
                            let mut filter_elem = Vec::new();
                            let mut reminder = dd_index;
                            for dim_pos in 0..d_inner {
                                let left_product = filter_size[dim_pos+3..filter_size.len()]
                                    .iter()
                                    .product::<usize>();
                                filter_elem.push(reminder / left_product);
                                reminder %= left_product;
                            }
                            // Position of this kernel tap in padded input coordinates.
                            let mut data_elem = left_upper.to_vec();
                            for dim_pos in 0..d_inner {
                                data_elem[dim_pos] += filter_elem[dim_pos]*dilation[dim_pos];
                            }
                            let mut full_filter_elem = vec![j, cin_index];
                            full_filter_elem.append(&mut filter_elem.clone());
                            // Map padded coordinates back into the input, honoring the padding
                            // mode; taps in a zero-padded margin are flagged and skipped.
                            let mut zero_padded_flag = false;
                            let mut unpadded_elem = data_elem.clone();
                            for dim_pos in 0..d_inner {
                                if data_elem[dim_pos] < padding[dim_pos] {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            zero_padded_flag = true;
                                        },
                                        PaddingMode::Reflect => {
                                            unpadded_elem[dim_pos] = padding[dim_pos] - data_elem[dim_pos] - 1;
                                        },
                                        PaddingMode::Replicate => {
                                            unpadded_elem[dim_pos] = 0;
                                        },
                                        PaddingMode::Circular => {
                                            unpadded_elem[dim_pos] = self_dim[dim_pos+2] - (padding[dim_pos] - data_elem[dim_pos]);
                                        },
                                    }
                                } else if data_elem[dim_pos] >= self_dim[dim_pos + 2] + padding[dim_pos] {
                                    match padding_mode {
                                        PaddingMode::Zeros => {
                                            zero_padded_flag = true;
                                        },
                                        PaddingMode::Reflect => {
                                            unpadded_elem[dim_pos] = self_dim[dim_pos+2] - (data_elem[dim_pos] - (self_dim[dim_pos + 2] + padding[dim_pos]) + 1);
                                        },
                                        PaddingMode::Replicate => {
                                            unpadded_elem[dim_pos] = self_dim[dim_pos + 2] - 1;
                                        },
                                        PaddingMode::Circular => {
                                            unpadded_elem[dim_pos] = data_elem[dim_pos] - (self_dim[dim_pos + 2] + padding[dim_pos]);
                                        },
                                    }
                                } else {
                                    unpadded_elem[dim_pos] -= padding[dim_pos];
                                }
                            }

                            // Zero-padded taps contribute nothing to either gradient.
                            if zero_padded_flag {
                                continue;
                            } else {
                                let mut full_data_elem = vec![i, cin_index];
                                full_data_elem.append(&mut unpadded_elem.clone());
                                let filter_value = filter.get(&full_filter_elem);
                                let data_value = self.get(&full_data_elem);

                                // dL/dW accumulates grad * input; dL/dX accumulates grad * weight.
                                let w_grad_value = output_gradient_value*data_value;
                                let x_grad_value = output_gradient_value*filter_value;

                                let total_w_index = filter.dimpos2index(&full_filter_elem);
                                let total_x_index = self.dimpos2index(&full_data_elem);

                                w_grad.entry(total_w_index).or_insert_with(Vec::new).push(w_grad_value);
                                x_grad.entry(total_x_index).or_insert_with(Vec::new).push(x_grad_value);
                            }

                        }
                    }

                    // Slide the window to the next output position, odometer-style,
                    // stepping by stride (last spatial dimension fastest).
                    for current_pos in 0..d_inner {
                        let real_pos = d_inner - current_pos - 1;
                        left_upper[real_pos] += stride[real_pos];

                        let compare_pos = self_dim[real_pos+2]
                            + padding[real_pos]*2
                            - ((filter_dim[real_pos + 2]-1)*dilation[real_pos] + 1);

                        if left_upper[real_pos] > compare_pos {
                            left_upper[real_pos] = 0;
                        } else {
                            break;
                        }
                    }
                    // All counters wrapped back to zero: this (sample, channel) pair is done.
                    if left_upper.iter().sum::<usize>() == 0 {
                        break;
                    }
                    output_index += 1;
                }
            }
        }

        // Sum the collected contributions into dense gradient tensors.
        let mut ret_w_grad = GenTensor::zeros(filter.size());
        let mut ret_x_grad = GenTensor::zeros(self.size());

        for i in w_grad.keys() {
            let mut sum = T::zero();
            for w_value in w_grad.get(i).expect("w_grad key") {
                sum = sum + *w_value;
            }
            ret_w_grad.set_1d(*i, sum);
        }
        for i in x_grad.keys() {
            let mut sum = T::zero();
            for x_value in x_grad.get(i).expect("x_grad key") {
                sum = sum + *x_value;
            }
            ret_x_grad.set_1d(*i, sum);
        }

        (ret_w_grad, ret_x_grad)
    }
}

#[cfg(test)]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use crate::tensor_trait::index_slicing::IndexSlicing;
    use super::*;

    #[test]
    fn conv_gen() {

        // 1-D convolution: [2, 3, 5] input, [2, 3, 3] filter.
        {
            let data = GenTensor::<f32>::arange(30).reshape(&vec![2, 3, 5]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 3, 3]);
            let stride = vec![1];
            let padding = vec![0];
            let dilation = vec![1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![312.0, 348.0, 384.0, 798.0, 915.0, 1032.0, 852.0, 888.0, 924.0, 2553.0, 2670.0, 2787.0], &vec![2, 2, 3]));
        }

        // 2-D convolution: [1, 3, 5, 5] input, [2, 3, 3, 3] filter.
        {
            let mut raw_data = Vec::new();
            for i in 0..75 {
                raw_data.push(i as f32);
            }
            let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 5]);
            let mut raw_data = Vec::new();
            for i in 0..54 {
                raw_data.push(i as f32);
            }
            let filter = GenTensor::<f32>::new_raw(&raw_data, &vec![2, 3, 3, 3]);

            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);

            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![15219.0, 15570.0, 15921.0, 16974.0, 17325.0, 17676.0, 18729.0, 19080.0, 19431.0, 37818.0, 38898.0, 39978.0, 43218.0, 44298.0, 45378.0, 48618.0, 49698.0, 50778.0], &vec![1, 2, 3, 3]));
        }

        // 2-D convolution with a non-square kernel: [1, 3, 5, 4] input, [2, 3, 3, 2] filter.
        {
            let mut raw_data = Vec::new();
            for i in 0..60 {
                raw_data.push(i as f32);
            }
            let data = GenTensor::<f32>::new_raw(&raw_data, &vec![1, 3, 5, 4]);
            let mut raw_data = Vec::new();
            for i in 0..36 {
                raw_data.push(i as f32);
            }
            let filter = GenTensor::<f32>::new_raw(&raw_data, &vec![2, 3, 3, 2]);

            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);

            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![5289.0, 5442.0, 5595.0, 5901.0, 6054.0, 6207.0, 6513.0, 6666.0, 6819.0, 13227.0, 13704.0, 14181.0, 15135.0, 15612.0, 16089.0, 17043.0, 17520.0, 17997.0], &vec![1, 2, 3, 3]));
        }

        // 3-D convolution: [1, 3, 5, 5, 5] input, [2, 3, 3, 3, 3] filter.
        {
            let data = GenTensor::<f32>::arange(375).reshape(&vec![1, 3, 5, 5, 5]);
            let filter = GenTensor::<f32>::arange(162).reshape(&vec![2, 3, 3, 3, 3]);
            let stride = vec![1, 1, 1];
            let padding = vec![0, 0, 0];
            let dilation = vec![1, 1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("output size: {:?}", result.size());
            println!("output data: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![700704.0, 703944.0, 707184.0, 716904.0, 720144.0, 723384.0, 733104.0, 736344.0, 739584.0, 781704.0, 784944.0, 788184.0, 797904.0, 801144.0, 804384.0, 814104.0, 817344.0, 820584.0, 862704.0, 865944.0, 869184.0, 878904.0, 882144.0, 885384.0, 895104.0, 898344.0, 901584.0, 1724220.0, 1734021.0, 1743822.0, 1773225.0, 1783026.0, 1792827.0, 1822230.0, 1832031.0, 1841832.0, 1969245.0, 1979046.0, 1988847.0, 2018250.0, 2028051.0, 2037852.0, 2067255.0, 2077056.0, 2086857.0, 2214270.0, 2224071.0, 2233872.0, 2263275.0, 2273076.0, 2282877.0, 2312280.0, 2322081.0, 2331882.0], &vec![1, 2, 3, 3, 3]));
        }

        // 2-D convolution with zero padding of 1 on each side.
        {
            let data = GenTensor::<f32>::arange(16).reshape(&vec![1, 1, 4, 4]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![1, 1];
            let padding = vec![1, 1];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("final output size: {:?}", result.size());
            println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![73.0, 121.0, 154.0, 103.0, 171.0, 258.0, 294.0, 186.0, 279.0, 402.0, 438.0, 270.0, 139.0, 187.0, 202.0, 113.0, 163.0, 283.0, 370.0, 265.0, 414.0, 663.0, 780.0, 537.0, 738.0, 1131.0, 1248.0, 837.0, 517.0, 781.0, 850.0, 563.0], &vec![1, 2, 4, 4]));
        }

        // 2-D convolution with stride 2.
        {
            let data = GenTensor::<f32>::arange(49).reshape(&vec![1, 1, 7, 7]);
            let filter = GenTensor::<f32>::arange(18).reshape(&vec![2, 1, 3, 3]);
            let stride = vec![2, 2];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;
            let result = data.conv_gen(&filter, &stride, &padding, &dilation, padding_mode);
            println!("final output size: {:?}", result.size());
            println!("final output: {:?}", result.get_data());
            assert_eq!(result, GenTensor::<f32>::new_raw(&vec![420.0, 492.0, 564.0, 924.0, 996.0, 1068.0, 1428.0, 1500.0, 1572.0, 1068.0, 1302.0, 1536.0, 2706.0, 2940.0, 3174.0, 4344.0, 4578.0, 4812.0], &vec![1, 2, 3, 3]));
        }
    }

    #[test]
    fn conv_grad_gen() {

        // Basic 2-D gradient: [1, 3, 5, 5] input, [2, 3, 3, 3] filter, no padding.
        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(54).reshape(&vec![2, 3, 3, 3]);
            let output_grad = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);

            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![312.0, 348.0, 384.0, 492.0, 528.0, 564.0, 672.0, 708.0, 744.0, 1212.0, 1248.0, 1284.0, 1392.0, 1428.0, 1464.0, 1572.0, 1608.0, 1644.0, 2112.0, 2148.0, 2184.0, 2292.0, 2328.0, 2364.0, 2472.0, 2508.0, 2544.0, 798.0, 915.0, 1032.0, 1383.0, 1500.0, 1617.0, 1968.0, 2085.0, 2202.0, 3723.0, 3840.0, 3957.0, 4308.0, 4425.0, 4542.0, 4893.0, 5010.0, 5127.0, 6648.0, 6765.0, 6882.0, 7233.0, 7350.0, 7467.0, 7818.0, 7935.0, 8052.0], &vec![2, 3, 3, 3]));
        }

        // Non-square kernel gradient: [1, 3, 5, 4] input, [2, 3, 3, 2] filter.
        {
            let data = GenTensor::<f32>::arange(60).reshape(&vec![1, 3, 5, 4]);
            let filter = GenTensor::<f32>::arange(36).reshape(&vec![2, 3, 3, 2]);
            let output_grad = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);
            let stride = vec![1, 1];
            let padding = vec![0, 0];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("{:?}, {:?}, {:?}", w_grad, x_grad, output_grad);
            assert_eq!(w_grad, GenTensor::new_raw(&vec![258.0, 294.0, 402.0, 438.0, 546.0, 582.0, 978.0, 1014.0, 1122.0, 1158.0, 1266.0, 1302.0, 1698.0, 1734.0, 1842.0, 1878.0, 1986.0, 2022.0, 663.0, 780.0, 1131.0, 1248.0, 1599.0, 1716.0, 3003.0, 3120.0, 3471.0, 3588.0, 3939.0, 4056.0, 5343.0, 5460.0, 5811.0, 5928.0, 6279.0, 6396.0], &vec![2, 3, 3, 2]));
        }

        // Gradient with zero padding of 1; output keeps the input's spatial size.
        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(54).reshape(&vec![2, 3, 3, 3]);
            let output_grad = GenTensor::<f32>::arange(50).reshape(&vec![1, 2, 5, 5]);

            let stride = vec![1, 1];
            let padding = vec![1, 1];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![2680.0, 3420.0, 2760.0, 3900.0, 4900.0, 3900.0, 2760.0, 3420.0, 2680.0, 8680.0, 10670.0, 8360.0, 10150.0, 12400.0, 9650.0, 6760.0, 8170.0, 6280.0, 14680.0, 17920.0, 13960.0, 16400.0, 19900.0, 15400.0, 10760.0, 12920.0, 9880.0, 6280.0, 8170.0, 6760.0, 9650.0, 12400.0, 10150.0, 8360.0, 10670.0, 8680.0, 22280.0, 27920.0, 22360.0, 28400.0, 35525.0, 28400.0, 22360.0, 27920.0, 22280.0, 38280.0, 47670.0, 37960.0, 47150.0, 58650.0, 46650.0, 36360.0, 45170.0, 35880.0], &vec![2, 3, 3, 3]));
        }

        // Gradient with a 5x5 filter and zero padding of 2.
        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(150).reshape(&vec![2, 3, 5, 5]);
            let output_grad = GenTensor::<f32>::arange(50).reshape(&vec![1, 2, 5, 5]);

            let stride = vec![1, 1];
            let padding = vec![2, 2];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![1128.0, 1580.0, 2065.0, 1700.0, 1308.0, 1964.0, 2680.0, 3420.0, 2760.0, 2084.0, 2905.0, 3900.0, 4900.0, 3900.0, 2905.0, 2084.0, 2760.0, 3420.0, 2680.0, 1964.0, 1308.0, 1700.0, 2065.0, 1580.0, 1128.0, 5178.0, 6830.0, 8440.0, 6650.0, 4908.0, 6614.0, 8680.0, 10670.0, 8360.0, 6134.0, 7780.0, 10150.0, 12400.0, 9650.0, 7030.0, 5234.0, 6760.0, 8170.0, 6280.0, 4514.0, 3108.0, 3950.0, 4690.0, 3530.0, 2478.0, 9228.0, 12080.0, 14815.0, 11600.0, 8508.0, 11264.0, 14680.0, 17920.0, 13960.0, 10184.0, 12655.0, 16400.0, 19900.0, 15400.0, 11155.0, 8384.0, 10760.0, 12920.0, 9880.0, 7064.0, 4908.0, 6200.0, 7315.0, 5480.0, 3828.0, 2478.0, 3530.0, 4690.0, 3950.0, 3108.0, 4514.0, 6280.0, 8170.0, 6760.0, 5234.0, 7030.0, 9650.0, 12400.0, 10150.0, 7780.0, 6134.0, 8360.0, 10670.0, 8680.0, 6614.0, 4908.0, 6650.0, 8440.0, 6830.0, 5178.0, 12153.0, 16280.0, 20440.0, 16400.0, 12333.0, 16664.0, 22280.0, 27920.0, 22360.0, 16784.0, 21280.0, 28400.0, 35525.0, 28400.0, 21280.0, 16784.0, 22360.0, 27920.0, 22280.0, 16664.0, 12333.0, 16400.0, 20440.0, 16280.0, 12153.0, 21828.0, 29030.0, 36190.0, 28850.0, 21558.0, 28814.0, 38280.0, 47670.0, 37960.0, 28334.0, 35530.0, 47150.0, 58650.0, 46650.0, 34780.0, 27434.0, 36360.0, 45170.0, 35880.0, 26714.0, 19758.0, 26150.0, 32440.0, 25730.0, 19128.0], &vec![2, 3, 5, 5]));
        }

        // Gradient with stride 2 and zero padding of 2.
        {
            let data = GenTensor::<f32>::arange(75).reshape(&vec![1, 3, 5, 5]);
            let filter = GenTensor::<f32>::arange(150).reshape(&vec![2, 3, 5, 5]);
            let output_grad = GenTensor::<f32>::arange(18).reshape(&vec![1, 2, 3, 3]);

            let stride = vec![2, 2];
            let padding = vec![2, 2];
            let dilation = vec![1, 1];
            let padding_mode = PaddingMode::Zeros;

            let (w_grad, x_grad) = data.conv_grad_gen(&filter, &stride, &padding, &dilation, padding_mode, &output_grad);
            println!("w_grad: {:?}", w_grad);
            println!("x_grad: {:?}", x_grad);

            assert_eq!(w_grad, GenTensor::new_raw(&vec![176.0, 200.0, 284.0, 172.0, 192.0, 296.0, 320.0, 449.0, 272.0, 292.0, 420.0, 447.0, 624.0, 375.0, 396.0, 164.0, 176.0, 233.0, 128.0, 136.0, 224.0, 236.0, 308.0, 168.0, 176.0, 776.0, 800.0, 1109.0, 672.0, 692.0, 896.0, 920.0, 1274.0, 772.0, 792.0, 1095.0, 1122.0, 1524.0, 900.0, 921.0, 464.0, 476.0, 608.0, 328.0, 336.0, 524.0, 536.0, 683.0, 368.0, 376.0, 1376.0, 1400.0, 1934.0, 1172.0, 1192.0, 1496.0, 1520.0, 2099.0, 1272.0, 1292.0, 1770.0, 1797.0, 2424.0, 1425.0, 1446.0, 764.0, 776.0, 983.0, 528.0, 536.0, 824.0, 836.0, 1058.0, 568.0, 576.0, 392.0, 452.0, 662.0, 424.0, 480.0, 692.0, 752.0, 1097.0, 704.0, 760.0, 1014.0, 1095.0, 1596.0, 1023.0, 1098.0, 560.0, 608.0, 881.0, 560.0, 604.0, 800.0, 848.0, 1226.0, 780.0, 824.0, 1892.0, 1952.0, 2837.0, 1824.0, 1880.0, 2192.0, 2252.0, 3272.0, 2104.0, 2160.0, 3039.0, 3120.0, 4521.0, 2898.0, 2973.0, 1760.0, 1808.0, 2606.0, 1660.0, 1704.0, 2000.0, 2048.0, 2951.0, 1880.0, 1924.0, 3392.0, 3452.0, 5012.0, 3224.0, 3280.0, 3692.0, 3752.0, 5447.0, 3504.0, 3560.0, 5064.0, 5145.0, 7446.0, 4773.0, 4848.0, 2960.0, 3008.0, 4331.0, 2760.0, 2804.0, 3200.0, 3248.0, 4676.0, 2980.0, 3024.0], &vec![2, 3, 5, 5]));
        }
    }
}