1use crate::{
2 device::{cpu::Cpu, Device, DeviceBase},
3 dim::{DimDyn, DimTrait},
4 matrix::{Matrix, Owned, Ref},
5 num::Num,
6};
7
8mod conv2d_bckwd_filter_cpu;
9mod conv2d_cpu_impl;
10mod deconv2d_cpu_impl;
11
12use self::{
13 conv2d_bckwd_filter_cpu::conv2d_bckwd_fileter, conv2d_cpu_impl::conv2d_inner,
14 deconv2d_cpu_impl::deconv2d_inner,
15};
16
17#[cfg(feature = "nvidia")]
18use zenu_cuda::{
19 cudnn::{
20 conv::{backward_bias, ConvolutionBuilder, ConvolutionConfig},
21 TensorFormat,
22 },
23 kernel::conv_bias_add,
24};
25
26#[cfg(feature = "nvidia")]
27use crate::device::nvidia::Nvidia;
28
/// Computes the NCHW output shape of a 2-D convolution.
///
/// `img_shape` is `[batch, in_channels, height, width]` and
/// `kernel_shape` is `[out_channels, in_channels, kh, kw]`.
/// Uses the standard formula `(in + 2*pad - kernel) / stride + 1`
/// per spatial dimension (dilation is not accounted for here).
#[must_use]
pub fn conv2d_out_size(
    img_shape: &[usize],
    kernel_shape: &[usize],
    padding: (usize, usize),
    stride: (usize, usize),
) -> [usize; 4] {
    let out_h = (img_shape[2] + 2 * padding.0 - kernel_shape[2]) / stride.0 + 1;
    let out_w = (img_shape[3] + 2 * padding.1 - kernel_shape[3]) / stride.1 + 1;
    // Batch comes from the image, channel count from the filter.
    [img_shape[0], kernel_shape[0], out_h, out_w]
}
43
44pub(super) fn get_deconv_outsize_(size: usize, k: usize, s: usize, p: usize) -> usize {
45 s * (size - 1) + k - 2 * p
46}
47
/// Computes the NCHW output shape of a 2-D transposed convolution
/// (the data-gradient shape of the corresponding forward conv).
///
/// `img_shape` is `[batch, channels, height, width]` of the conv output,
/// `kernel_shape` is `[out_channels, in_channels, kh, kw]`; the result's
/// channel count is the filter's *input* channel dimension.
#[must_use]
pub fn deconv2d_out_size(
    img_shape: &[usize],
    kernel_shape: &[usize],
    padding: (usize, usize),
    stride: (usize, usize),
) -> [usize; 4] {
    // Per-dimension transposed-conv size: stride*(n-1) + kernel - 2*padding.
    let out_h = stride.0 * (img_shape[2] - 1) + kernel_shape[2] - 2 * padding.0;
    let out_w = stride.1 * (img_shape[3] - 1) + kernel_shape[3] - 2 * padding.1;
    [img_shape[0], kernel_shape[1], out_h, out_w]
}
65
/// Backend-specific configuration for a 2-D convolution.
///
/// With the `nvidia` feature enabled this owns a cuDNN `ConvolutionConfig`
/// (descriptors plus selected algorithms); without it the type is an empty
/// marker that only carries `T`.
pub struct Conv2dConfig<T: Num> {
    #[cfg(feature = "nvidia")]
    pub conv: ConvolutionConfig<T>,
    // Keeps `T` used in the type when the cuDNN field is compiled out.
    _phantom: std::marker::PhantomData<T>,
}
71
72#[expect(clippy::too_many_arguments)]
73#[allow(unused_variables)]
74#[must_use]
75pub fn create_conv_descriptor<T: Num>(
76 input_shape: &[usize],
77 output_shape: &[usize],
78 filter_shape: &[usize],
79 pad_h: usize,
80 pad_w: usize,
81 stride_h: usize,
82 stride_w: usize,
83 dilation_h: usize,
84 dilation_w: usize,
85 groups: usize,
86) -> Conv2dConfig<T> {
87 #[cfg(feature = "nvidia")]
88 let conv = {
89 let input_shape = input_shape
90 .iter()
91 .map(|x| i32::try_from(*x).unwrap())
92 .collect::<Vec<_>>();
93 let output_shape = output_shape
94 .iter()
95 .map(|x| i32::try_from(*x).unwrap())
96 .collect::<Vec<_>>();
97 let filter_shape = filter_shape
98 .iter()
99 .map(|x| i32::try_from(*x).unwrap())
100 .collect::<Vec<_>>();
101
102 let input_shape_0: i32 = input_shape[0];
103 let input_shape_1: i32 = input_shape[1];
104 let input_shape_2: i32 = input_shape[2];
105 let input_shape_3: i32 = input_shape[3];
106
107 let output_shape_0: i32 = output_shape[0];
108 let output_shape_1: i32 = output_shape[1];
109 let output_shape_2: i32 = output_shape[2];
110 let output_shape_3: i32 = output_shape[3];
111
112 let filter_shape_0: i32 = filter_shape[0];
113 let filter_shape_1: i32 = filter_shape[1];
114 let filter_shape_2: i32 = filter_shape[2];
115 let filter_shape_3: i32 = filter_shape[3];
116
117 let pad_h: i32 = pad_h.try_into().unwrap();
118 let pad_w: i32 = pad_w.try_into().unwrap();
119
120 let stride_h: i32 = stride_h.try_into().unwrap();
121 let stride_w: i32 = stride_w.try_into().unwrap();
122
123 let dilation_h: i32 = dilation_h.try_into().unwrap();
124 let dilation_w: i32 = dilation_w.try_into().unwrap();
125
126 let conv = ConvolutionBuilder::<T>::default()
127 .input(
128 input_shape_0,
129 input_shape_1,
130 input_shape_2,
131 input_shape_3,
132 TensorFormat::NCHW,
133 )
134 .unwrap()
135 .filter(
136 filter_shape_0,
137 filter_shape_1,
138 filter_shape_2,
139 filter_shape_3,
140 TensorFormat::NCHW,
141 )
142 .unwrap()
143 .output(
144 output_shape_0,
145 output_shape_1,
146 output_shape_2,
147 output_shape_3,
148 TensorFormat::NCHW,
149 )
150 .unwrap()
151 .conv(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)
152 .unwrap()
153 .algorithms(groups);
154
155 conv.build()
156 };
157
158 Conv2dConfig {
159 #[cfg(feature = "nvidia")]
160 conv,
161 _phantom: std::marker::PhantomData,
162 }
163}
164
/// Device-level 2-D convolution primitives, implemented per backend
/// (`Cpu`, and `Nvidia` behind the `nvidia` feature).
///
/// All tensors are NCHW-ordered `DimDyn` matrices. Output buffers are
/// preallocated by the caller with the correct shape and are overwritten.
/// `config` may carry a prebuilt descriptor; backends may ignore it.
pub trait Conv2d: DeviceBase {
    /// Forward convolution of `input` with `filter`, written into `y`.
    #[expect(clippy::too_many_arguments)]
    fn conv2d<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    );

    /// Gradient w.r.t. the convolution input: propagates `dy` back
    /// through `filter` into `dx`.
    #[expect(clippy::too_many_arguments)]
    fn conv2d_bckwd_data<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    );

    /// Gradient w.r.t. the filter, computed from `input` and `dy`,
    /// written into `df`.
    #[expect(clippy::too_many_arguments)]
    fn conv2d_bckwd_filter<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        df: Matrix<Ref<&mut T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    );

    /// Adds `bias` to `input`, writing the result into `y`.
    fn conv2d_forward_bias<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
    );

    /// Bias gradient: reduces `dy` into `dx` (see the CPU impl, which
    /// sums over the batch and spatial axes).
    fn conv2d_bckwd_bias<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
    );
}
219
impl Conv2d for Cpu {
    /// Forward convolution on the CPU, delegating to `conv2d_inner`.
    /// The optional config is ignored — there is nothing to precompute here.
    fn conv2d<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        _config: Option<&Conv2dConfig<T>>,
    ) {
        // Dilated convolution is not implemented on the CPU backend yet.
        if dilation_h != 1 || dilation_w != 1 {
            todo!();
        }
        // Compute into a fresh matrix, then copy into the caller's buffer.
        y.copy_from(&conv2d_inner(
            input,
            filter,
            None,
            (pad_h, pad_w),
            (stride_h, stride_w),
        ));
    }

    /// Input gradient via the transposed convolution `deconv2d_inner`.
    fn conv2d_bckwd_data<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        _config: Option<&Conv2dConfig<T>>,
    ) {
        // Dilation unsupported on CPU, as in the forward pass.
        if dilation_h != 1 || dilation_w != 1 {
            todo!();
        }
        dx.copy_from(&deconv2d_inner(
            dy,
            filter,
            None,
            (pad_h, pad_w),
            (stride_h, stride_w),
        ));
    }

    /// Filter gradient; the target shape is read back from `df` itself.
    fn conv2d_bckwd_filter<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        df: Matrix<Ref<&mut T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        _config: Option<&Conv2dConfig<T>>,
    ) {
        if dilation_h != 1 || dilation_w != 1 {
            todo!();
        }
        // NOTE: the helper's name carries a historical typo ("fileter");
        // it is defined in conv2d_bckwd_filter_cpu.
        df.copy_from(&conv2d_bckwd_fileter(
            input,
            df.to_ref().shape(),
            dy,
            (pad_h, pad_w),
            (stride_h, stride_w),
        ));
    }

    /// y = input + bias (bias is broadcast by `add_array` —
    /// broadcasting semantics live in that helper).
    fn conv2d_forward_bias<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        mut y: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
    ) {
        y.add_array(&input, &bias);
    }

    /// Bias gradient: sum `dy` over batch (axis 0) and both spatial axes
    /// (2 and 3), keeping dims so the result matches the bias shape.
    fn conv2d_bckwd_bias<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
    ) {
        let dy_0 = dy.sum(0, true);
        let dy_0_2 = dy_0.to_ref().sum(2, true);
        dx.copy_from(&dy_0_2.to_ref().sum(3, true));
    }
}
310
#[cfg(feature = "nvidia")]
impl Conv2d for Nvidia {
    /// Forward convolution via cuDNN. If no `config` is supplied, a
    /// descriptor is built ad hoc from the tensor shapes.
    /// NOTE(review): building a descriptor (including algorithm search)
    /// on every call is likely expensive — callers should pass `config`
    /// on hot paths; confirm cost against zenu_cuda.
    fn conv2d<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    ) {
        let config = match config {
            Some(config) => &config.conv,
            None => {
                // Temporary lifetime extension keeps this alive for the
                // whole `let` binding. groups hard-coded to 1 here.
                &create_conv_descriptor::<T>(
                    input.shape().slice(),
                    y.shape().slice(),
                    filter.shape().slice(),
                    pad_h,
                    pad_w,
                    stride_h,
                    stride_w,
                    dilation_h,
                    dilation_w,
                    1,
                )
                .conv
            }
        };

        // y = 1 * conv(input, filter) + 0 * y
        config.forward(
            T::one(),
            input.as_ptr(),
            filter.as_ptr(),
            T::zero(),
            y.as_mut_ptr(),
        );
    }

    /// Input gradient via cuDNN's backward-data pass.
    fn conv2d_bckwd_data<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    ) {
        let config = match config {
            Some(config) => &config.conv,
            None => {
                // Descriptor geometry: dx plays "input", dy plays "output".
                &create_conv_descriptor::<T>(
                    dx.shape().slice(),
                    dy.shape().slice(),
                    filter.shape().slice(),
                    pad_h,
                    pad_w,
                    stride_h,
                    stride_w,
                    dilation_h,
                    dilation_w,
                    1,
                )
                .conv
            }
        };

        config.backward_data(
            T::one(),
            filter.as_ptr(),
            dy.as_ptr(),
            T::zero(),
            dx.as_mut_ptr(),
        );
    }

    /// Filter gradient via cuDNN's backward-filter pass.
    fn conv2d_bckwd_filter<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        df: Matrix<Ref<&mut T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    ) {
        let config = match config {
            Some(config) => &config.conv,
            None => {
                &create_conv_descriptor::<T>(
                    input.shape().slice(),
                    dy.shape().slice(),
                    df.shape().slice(),
                    pad_h,
                    pad_w,
                    stride_h,
                    stride_w,
                    dilation_h,
                    dilation_w,
                    1,
                )
                .conv
            }
        };

        config.backward_filter(
            T::one(),
            input.as_ptr(),
            dy.as_ptr(),
            T::zero(),
            df.as_mut_ptr(),
        );
    }

    /// y = input + bias via the custom `conv_bias_add` kernel; the bias is
    /// broadcast along the channel stride of `input`.
    fn conv2d_forward_bias<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
    ) {
        // stride[1] is the distance between consecutive channels; the kernel
        // presumably uses it to map element index -> bias index. TODO confirm.
        let input_channel_stride = input.stride()[1];
        let input_num_elm = input.shape().num_elm();
        let bias_num_elm = bias.shape().num_elm();

        conv_bias_add(
            input.as_ptr(),
            bias.as_ptr(),
            input_num_elm,
            input_channel_stride,
            bias_num_elm,
            y.as_mut_ptr(),
        );
    }

    /// Bias gradient via cuDNN's backward-bias reduction.
    fn conv2d_bckwd_bias<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
    ) {
        let dy_shape = dy.shape();
        let dy_shape = dy_shape.slice();
        backward_bias(T::one(), dy.as_ptr(), T::zero(), dx.as_mut_ptr(), dy_shape);
    }
}
461
462#[expect(clippy::too_many_arguments)]
463#[must_use]
464pub fn conv2d_forward<T: Num, D: Device>(
465 input: Matrix<Ref<&T>, DimDyn, D>,
466 filter: Matrix<Ref<&T>, DimDyn, D>,
467 pad_h: usize,
468 pad_w: usize,
469 stride_h: usize,
470 stride_w: usize,
471 dilation_h: usize,
472 dilation_w: usize,
473 config: Option<&Conv2dConfig<T>>,
474) -> Matrix<Owned<T>, DimDyn, D> {
475 let out_size = conv2d_out_size(
476 input.shape().slice(),
477 filter.shape().slice(),
478 (pad_h, pad_w),
479 (stride_h, stride_w),
480 );
481 let mut y = Matrix::<Owned<T>, DimDyn, D>::alloc(out_size);
482 D::conv2d(
483 input,
484 y.to_ref_mut(),
485 filter,
486 pad_h,
487 pad_w,
488 stride_h,
489 stride_w,
490 dilation_h,
491 dilation_w,
492 config,
493 );
494 y
495}
496
497#[expect(clippy::too_many_arguments)]
498#[must_use]
499pub fn conv2d_bckwd_data<T: Num, D: Device>(
500 dy: Matrix<Ref<&T>, DimDyn, D>,
501 filter: Matrix<Ref<&T>, DimDyn, D>,
502 pad_h: usize,
503 pad_w: usize,
504 stride_h: usize,
505 stride_w: usize,
506 dilation_h: usize,
507 dilation_w: usize,
508 config: Option<&Conv2dConfig<T>>,
509) -> Matrix<Owned<T>, DimDyn, D> {
510 let input_shape = deconv2d_out_size(
511 dy.shape().slice(),
512 filter.shape().slice(),
513 (pad_h, pad_w),
514 (stride_h, stride_w),
515 );
516 let mut dx = Matrix::<Owned<T>, DimDyn, D>::alloc(input_shape);
517 D::conv2d_bckwd_data(
518 dy,
519 dx.to_ref_mut(),
520 filter,
521 pad_h,
522 pad_w,
523 stride_h,
524 stride_w,
525 dilation_h,
526 dilation_w,
527 config,
528 );
529 dx
530}
531
532#[expect(clippy::too_many_arguments)]
533#[must_use]
534pub fn conv2d_bckwd_filter<T: Num, D: Device>(
535 input: Matrix<Ref<&T>, DimDyn, D>,
536 dy: Matrix<Ref<&T>, DimDyn, D>,
537 pad_h: usize,
538 pad_w: usize,
539 stride_h: usize,
540 stride_w: usize,
541 dilation_h: usize,
542 dilation_w: usize,
543 filter_shape: DimDyn,
544 config: Option<&Conv2dConfig<T>>,
545) -> Matrix<Owned<T>, DimDyn, D> {
546 let mut df = Matrix::<Owned<T>, DimDyn, D>::alloc(filter_shape);
547 D::conv2d_bckwd_filter(
548 input,
549 dy,
550 df.to_ref_mut(),
551 pad_h,
552 pad_w,
553 stride_h,
554 stride_w,
555 dilation_h,
556 dilation_w,
557 config,
558 );
559 df
560}
561
/// Adds `bias` to `input`, writing the result into `output`.
///
/// Thin wrapper over [`Conv2d::conv2d_forward_bias`]; note the argument
/// order differs from the trait method (`bias` and `output` swapped).
pub fn conv2d_bias_add<T: Num, D: Device>(
    input: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    output: Matrix<Ref<&mut T>, DimDyn, D>,
) {
    D::conv2d_forward_bias(input, output, bias);
}
569
/// Bias gradient: reduces `dy` into `dx` via the backend
/// (the CPU backend sums over axes 0, 2, and 3 with dims kept).
pub fn conv2d_bckwd_data_bias<T: Num, D: Device>(
    dy: Matrix<Ref<&T>, DimDyn, D>,
    dx: Matrix<Ref<&mut T>, DimDyn, D>,
) {
    D::conv2d_bckwd_bias(dy, dx);
}
576
// Regression tests for forward conv, backward-data, and backward-filter.
// Reference values are hard-coded fixtures — presumably generated with an
// external framework (e.g. PyTorch); TODO confirm provenance.
#[expect(clippy::unreadable_literal, clippy::too_many_lines)]
#[cfg(test)]
mod conv2d {
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    use super::{conv2d_bckwd_data, conv2d_bckwd_filter, conv2d_forward};

    // One complete conv scenario: inputs, parameters, and the expected
    // forward output plus both gradients for output_grad of all ones.
    struct Conv2dTestCase<D: Device> {
        input: Matrix<Owned<f32>, DimDyn, D>,
        filter: Matrix<Owned<f32>, DimDyn, D>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        expected: Matrix<Owned<f32>, DimDyn, D>,
        input_grad: Matrix<Owned<f32>, DimDyn, D>,
        filter_grad: Matrix<Owned<f32>, DimDyn, D>,
        output_grad: Matrix<Owned<f32>, DimDyn, D>,
    }

    // Runs forward, backward-data, and backward-filter against the fixture
    // and checks each result to 1e-4 absolute tolerance.
    fn conv2d_test<D: Device>() {
        let test_case = small::<D>();

        let forward_pred = conv2d_forward(
            test_case.input.to_ref(),
            test_case.filter.to_ref(),
            test_case.pad_h,
            test_case.pad_w,
            test_case.stride_h,
            test_case.stride_w,
            test_case.dilation_h,
            test_case.dilation_w,
            None,
        );
        assert_mat_eq_epsilon!(forward_pred, test_case.expected, 1e-4);
        let input_grad = conv2d_bckwd_data(
            test_case.output_grad.to_ref(),
            test_case.filter.to_ref(),
            test_case.pad_h,
            test_case.pad_w,
            test_case.stride_h,
            test_case.stride_w,
            test_case.dilation_h,
            test_case.dilation_w,
            None,
        );
        assert_mat_eq_epsilon!(input_grad, test_case.input_grad, 1e-4);

        let filter_grad = conv2d_bckwd_filter(
            test_case.input.to_ref(),
            test_case.output_grad.to_ref(),
            test_case.pad_h,
            test_case.pad_w,
            test_case.stride_h,
            test_case.stride_w,
            test_case.dilation_h,
            test_case.dilation_w,
            test_case.filter.shape(),
            None,
        );
        assert_mat_eq_epsilon!(filter_grad, test_case.filter_grad, 1e-4);
    }
    // Expands to #[test] fns for the CPU and (feature-gated) Nvidia backends.
    run_mat_test!(conv2d_test, conv2d_test_cpu, conv2d_test_nvidia);

    // Fixture: input [1, 2, 4, 4], filter [3, 2, 3, 3], pad 1, stride 1,
    // no dilation; output_grad is all ones, so filter_grad repeats per
    // output channel.
    fn small<D: Device>() -> Conv2dTestCase<D> {
        let input = vec![
            0.5432947,
            -0.39515755,
            0.20552567,
            -0.45032975,
            -0.5730771,
            -0.5553584,
            0.59432304,
            1.5419426,
            1.8197253,
            -0.5515287,
            -1.325326,
            0.18855357,
            -0.069072686,
            -0.49492535,
            -1.4959149,
            -0.19383712,
            -0.4731198,
            0.33555076,
            1.5091219,
            2.0819554,
            1.7067116,
            2.3803675,
            -1.1256016,
            -0.3169981,
            -0.14067143,
            0.8057536,
            0.3276143,
            -0.7607072,
            -1.599082,
            0.018486667,
            -0.7504268,
            0.18540798,
        ];
        let output = vec![
            0.3671525,
            -0.17387724,
            -0.53952014,
            -0.41356063,
            0.13519445,
            -0.6369239,
            -0.5777169,
            -0.07820636,
            -0.6019154,
            -0.85000455,
            -0.227178,
            0.38553098,
            0.53258127,
            0.4952766,
            0.16334829,
            0.5179188,
            -1.1829954,
            -0.15092221,
            0.15374796,
            0.5376092,
            -0.35269666,
            -0.10102463,
            -0.628401,
            -0.40036133,
            -0.5694187,
            -0.1765114,
            -0.05552435,
            -0.3107502,
            -0.6736164,
            -0.44401115,
            -0.1804393,
            0.056986123,
            0.5652461,
            0.8913239,
            0.30458608,
            -0.7666081,
            0.15480474,
            0.14275207,
            0.42336845,
            0.12534592,
            0.5706087,
            0.40240055,
            -0.16282544,
            -0.032061294,
            0.47645676,
            -0.09869753,
            -0.34638345,
            -0.02880986,
        ];
        let input_grad = vec![
            -0.06312838,
            0.05240719,
            0.05240719,
            0.21505278,
            -0.07415994,
            0.063570745,
            0.063570745,
            0.22900042,
            -0.07415994,
            0.063570745,
            0.063570745,
            0.22900042,
            -0.0014246926,
            0.13951382,
            0.13951382,
            0.005797662,
            -0.73124456,
            -0.7982433,
            -0.7982433,
            -0.098860174,
            -0.57463914,
            -0.689119,
            -0.689119,
            -0.12428501,
            -0.57463914,
            -0.689119,
            -0.689119,
            -0.12428501,
            -0.22594097,
            -0.37261552,
            -0.37261552,
            -0.085577406,
        ];
        let filter = vec![
            -0.0017646605,
            0.12644097,
            -0.1939936,
            -0.1734625,
            -0.090781756,
            0.063205294,
            -0.0046700113,
            0.18688585,
            -0.020917172,
            0.06236978,
            -0.071232304,
            -0.046330906,
            -0.2251778,
            -0.15610139,
            -0.09716192,
            0.008731253,
            0.0931814,
            0.14142673,
            -0.15979224,
            -0.10263957,
            0.0856111,
            0.19572432,
            -0.048507567,
            0.17637877,
            -0.03799128,
            0.024940623,
            0.21342279,
            -0.218654,
            -0.14838351,
            -0.05967162,
            -0.09187673,
            0.20364694,
            -0.1527774,
            -0.1085015,
            -0.16467114,
            -0.22074954,
            -0.13758895,
            0.2026092,
            0.105174676,
            0.11423842,
            0.01239595,
            -0.12084066,
            0.039877214,
            -0.22007395,
            -0.1703105,
            -0.121511586,
            0.1487135,
            0.13819724,
            -0.104532786,
            -0.0085047,
            0.1507459,
            0.23431942,
            0.093546025,
            0.03184169,
        ];
        // With output_grad == ones, every output channel sees the same
        // upstream gradient, so the 18-value per-channel block repeats 3x.
        let filter_grad = vec![
            -0.23757887,
            1.0425875,
            -0.7473556,
            -2.297492,
            -1.2111626,
            -2.932033,
            -2.651155,
            -1.1144958,
            -2.292071,
            5.325727,
            6.329977,
            5.2370563,
            2.994705,
            4.184363,
            4.690524,
            1.6231518,
            0.7308545,
            0.7638962,
            -0.23757887,
            1.0425875,
            -0.7473556,
            -2.297492,
            -1.2111626,
            -2.932033,
            -2.651155,
            -1.1144958,
            -2.292071,
            5.325727,
            6.329977,
            5.2370563,
            2.994705,
            4.184363,
            4.690524,
            1.6231518,
            0.7308545,
            0.7638962,
            -0.23757887,
            1.0425875,
            -0.7473556,
            -2.297492,
            -1.2111626,
            -2.932033,
            -2.651155,
            -1.1144958,
            -2.292071,
            5.325727,
            6.329977,
            5.2370563,
            2.994705,
            4.184363,
            4.690524,
            1.6231518,
            0.7308545,
            0.7638962,
        ];
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input, [1, 2, 4, 4]);
        let input_grad = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input_grad, [1, 2, 4, 4]);
        let filter = Matrix::<Owned<f32>, DimDyn, D>::from_vec(filter, [3, 2, 3, 3]);
        let filter_grad = Matrix::<Owned<f32>, DimDyn, D>::from_vec(filter_grad, [3, 2, 3, 3]);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(output, [1, 3, 4, 4]);
        let output_grad = Matrix::<Owned<f32>, DimDyn, D>::ones([1, 3, 4, 4]);

        Conv2dTestCase {
            input,
            filter,
            pad_h: 1,
            pad_w: 1,
            stride_h: 1,
            stride_w: 1,
            dilation_h: 1,
            dilation_w: 1,
            expected,
            input_grad,
            filter_grad,
            output_grad,
        }
    }
}