// zenu_matrix/nn/conv2d/mod.rs

1use crate::{
2    device::{cpu::Cpu, Device, DeviceBase},
3    dim::{DimDyn, DimTrait},
4    matrix::{Matrix, Owned, Ref},
5    num::Num,
6};
7
8mod conv2d_bckwd_filter_cpu;
9mod conv2d_cpu_impl;
10mod deconv2d_cpu_impl;
11
12use self::{
13    conv2d_bckwd_filter_cpu::conv2d_bckwd_fileter, conv2d_cpu_impl::conv2d_inner,
14    deconv2d_cpu_impl::deconv2d_inner,
15};
16
17#[cfg(feature = "nvidia")]
18use zenu_cuda::{
19    cudnn::{
20        conv::{backward_bias, ConvolutionBuilder, ConvolutionConfig},
21        TensorFormat,
22    },
23    kernel::conv_bias_add,
24};
25
26#[cfg(feature = "nvidia")]
27use crate::device::nvidia::Nvidia;
28
/// Computes the `[batch, out_channels, out_height, out_width]` shape produced by a 2-D
/// convolution.
///
/// `img_shape` is read as `[batch, _, height, width]` and `kernel_shape` as
/// `[out_channels, _, kernel_height, kernel_width]`. `padding` and `stride` are
/// `(vertical, horizontal)` pairs. Each spatial extent is
/// `(size + 2 * pad - kernel) / stride + 1`, using integer division.
///
/// Dilation is not taken into account; callers that need it must pass the effective
/// (dilated) kernel extents in `kernel_shape`.
///
/// # Panics
///
/// Panics if either shape has fewer than four entries, if a stride is zero, if the padded
/// input extent overflows `usize`, or if the kernel is larger than the padded input.
#[must_use]
pub fn conv2d_out_size(
    img_shape: &[usize],
    kernel_shape: &[usize],
    padding: (usize, usize),
    stride: (usize, usize),
) -> [usize; 4] {
    assert!(
        img_shape.len() >= 4,
        "conv2d_out_size: img_shape needs at least 4 entries"
    );
    assert!(
        kernel_shape.len() >= 4,
        "conv2d_out_size: kernel_shape needs at least 4 entries"
    );

    // Output extent along one spatial axis. Checked arithmetic turns what used to be a silent
    // wrap-around in release builds into an explicit, descriptive panic.
    let extent = |axis: &str, size: usize, kernel: usize, pad: usize, step: usize| -> usize {
        assert!(step != 0, "conv2d_out_size: {axis} stride must be non-zero");
        let padded = pad
            .checked_mul(2)
            .and_then(|both_sides| size.checked_add(both_sides))
            .expect("conv2d_out_size: padded input extent overflows usize");
        assert!(
            kernel <= padded,
            "conv2d_out_size: kernel {axis} ({kernel}) exceeds padded input {axis} ({padded})"
        );
        (padded - kernel) / step + 1
    };

    let (batch, in_h, in_w) = (img_shape[0], img_shape[2], img_shape[3]);
    let (out_channels, kernel_h, kernel_w) = (kernel_shape[0], kernel_shape[2], kernel_shape[3]);
    let (pad_h, pad_w) = padding;
    let (stride_h, stride_w) = stride;

    let out_h = extent("height", in_h, kernel_h, pad_h, stride_h);
    let out_w = extent("width", in_w, kernel_w, pad_w, stride_w);
    [batch, out_channels, out_h, out_w]
}
43
44pub(super) fn get_deconv_outsize_(size: usize, k: usize, s: usize, p: usize) -> usize {
45    s * (size - 1) + k - 2 * p
46}
47
48#[must_use]
49pub fn deconv2d_out_size(
50    img_shape: &[usize],
51    kernel_shape: &[usize],
52    padding: (usize, usize),
53    stride: (usize, usize),
54) -> [usize; 4] {
55    let (b, h, w) = (img_shape[0], img_shape[2], img_shape[3]);
56    let (ic, kh, kw) = (kernel_shape[1], kernel_shape[2], kernel_shape[3]);
57    let (ph, pw) = padding;
58    let (sh, sw) = stride;
59    let (h, w) = (
60        get_deconv_outsize_(h, kh, sh, ph),
61        get_deconv_outsize_(w, kw, sw, pw),
62    );
63    [b, ic, h, w]
64}
65
/// Backend-specific configuration for a 2-D convolution.
///
/// With the `nvidia` feature enabled it carries the cuDNN `ConvolutionConfig` built by
/// [`create_conv_descriptor`]; without that feature the struct holds no data.
pub struct Conv2dConfig<T: Num> {
    /// cuDNN convolution configuration used by the `Nvidia` implementation of [`Conv2d`].
    #[cfg(feature = "nvidia")]
    pub conv: ConvolutionConfig<T>,
    // Keeps `T` in use when the `nvidia` feature (and with it `conv`) is compiled out.
    _phantom: std::marker::PhantomData<T>,
}
71
72#[expect(clippy::too_many_arguments)]
73#[allow(unused_variables)]
74#[must_use]
75pub fn create_conv_descriptor<T: Num>(
76    input_shape: &[usize],
77    output_shape: &[usize],
78    filter_shape: &[usize],
79    pad_h: usize,
80    pad_w: usize,
81    stride_h: usize,
82    stride_w: usize,
83    dilation_h: usize,
84    dilation_w: usize,
85    groups: usize,
86) -> Conv2dConfig<T> {
87    #[cfg(feature = "nvidia")]
88    let conv = {
89        let input_shape = input_shape
90            .iter()
91            .map(|x| i32::try_from(*x).unwrap())
92            .collect::<Vec<_>>();
93        let output_shape = output_shape
94            .iter()
95            .map(|x| i32::try_from(*x).unwrap())
96            .collect::<Vec<_>>();
97        let filter_shape = filter_shape
98            .iter()
99            .map(|x| i32::try_from(*x).unwrap())
100            .collect::<Vec<_>>();
101
102        let input_shape_0: i32 = input_shape[0];
103        let input_shape_1: i32 = input_shape[1];
104        let input_shape_2: i32 = input_shape[2];
105        let input_shape_3: i32 = input_shape[3];
106
107        let output_shape_0: i32 = output_shape[0];
108        let output_shape_1: i32 = output_shape[1];
109        let output_shape_2: i32 = output_shape[2];
110        let output_shape_3: i32 = output_shape[3];
111
112        let filter_shape_0: i32 = filter_shape[0];
113        let filter_shape_1: i32 = filter_shape[1];
114        let filter_shape_2: i32 = filter_shape[2];
115        let filter_shape_3: i32 = filter_shape[3];
116
117        let pad_h: i32 = pad_h.try_into().unwrap();
118        let pad_w: i32 = pad_w.try_into().unwrap();
119
120        let stride_h: i32 = stride_h.try_into().unwrap();
121        let stride_w: i32 = stride_w.try_into().unwrap();
122
123        let dilation_h: i32 = dilation_h.try_into().unwrap();
124        let dilation_w: i32 = dilation_w.try_into().unwrap();
125
126        let conv = ConvolutionBuilder::<T>::default()
127            .input(
128                input_shape_0,
129                input_shape_1,
130                input_shape_2,
131                input_shape_3,
132                TensorFormat::NCHW,
133            )
134            .unwrap()
135            .filter(
136                filter_shape_0,
137                filter_shape_1,
138                filter_shape_2,
139                filter_shape_3,
140                TensorFormat::NCHW,
141            )
142            .unwrap()
143            .output(
144                output_shape_0,
145                output_shape_1,
146                output_shape_2,
147                output_shape_3,
148                TensorFormat::NCHW,
149            )
150            .unwrap()
151            .conv(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)
152            .unwrap()
153            .algorithms(groups);
154
155        conv.build()
156    };
157
158    Conv2dConfig {
159        #[cfg(feature = "nvidia")]
160        conv,
161        _phantom: std::marker::PhantomData,
162    }
163}
164
/// Device-level 2-D convolution primitives.
///
/// Every method writes its result into a caller-provided mutable matrix view. The `config`
/// argument is an optional pre-built [`Conv2dConfig`]: the `Cpu` implementation ignores it,
/// while the `Nvidia` implementation builds one from the argument shapes when it is `None`.
pub trait Conv2d: DeviceBase {
    /// Forward pass: convolves `input` with `filter` and stores the result in `y`.
    #[expect(clippy::too_many_arguments)]
    fn conv2d<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    );

    /// Backward pass with respect to the input: computes `dx` from the output gradient `dy`
    /// and `filter`.
    #[expect(clippy::too_many_arguments)]
    fn conv2d_bckwd_data<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    );

    /// Backward pass with respect to the filter: computes `df` from `input` and the output
    /// gradient `dy`.
    #[expect(clippy::too_many_arguments)]
    fn conv2d_bckwd_filter<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        df: Matrix<Ref<&mut T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    );

    /// Combines `input` with `bias` and writes the result into `y`.
    fn conv2d_forward_bias<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
    );

    /// Reduces the output gradient `dy` into `dx`, the gradient for the bias.
    ///
    /// The `Cpu` implementation sums over axes 0, 2 and 3.
    fn conv2d_bckwd_bias<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
    );
}
219
220impl Conv2d for Cpu {
221    fn conv2d<T: Num>(
222        input: Matrix<Ref<&T>, DimDyn, Self>,
223        y: Matrix<Ref<&mut T>, DimDyn, Self>,
224        filter: Matrix<Ref<&T>, DimDyn, Self>,
225        pad_h: usize,
226        pad_w: usize,
227        stride_h: usize,
228        stride_w: usize,
229        dilation_h: usize,
230        dilation_w: usize,
231        _config: Option<&Conv2dConfig<T>>,
232    ) {
233        if dilation_h != 1 || dilation_w != 1 {
234            todo!();
235        }
236        y.copy_from(&conv2d_inner(
237            input,
238            filter,
239            None,
240            (pad_h, pad_w),
241            (stride_h, stride_w),
242        ));
243    }
244
245    fn conv2d_bckwd_data<T: Num>(
246        dy: Matrix<Ref<&T>, DimDyn, Self>,
247        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
248        filter: Matrix<Ref<&T>, DimDyn, Self>,
249        pad_h: usize,
250        pad_w: usize,
251        stride_h: usize,
252        stride_w: usize,
253        dilation_h: usize,
254        dilation_w: usize,
255        _config: Option<&Conv2dConfig<T>>,
256    ) {
257        if dilation_h != 1 || dilation_w != 1 {
258            todo!();
259        }
260        dx.copy_from(&deconv2d_inner(
261            dy,
262            filter,
263            None,
264            (pad_h, pad_w),
265            (stride_h, stride_w),
266        ));
267    }
268
269    fn conv2d_bckwd_filter<T: Num>(
270        input: Matrix<Ref<&T>, DimDyn, Self>,
271        dy: Matrix<Ref<&T>, DimDyn, Self>,
272        df: Matrix<Ref<&mut T>, DimDyn, Self>,
273        pad_h: usize,
274        pad_w: usize,
275        stride_h: usize,
276        stride_w: usize,
277        dilation_h: usize,
278        dilation_w: usize,
279        _config: Option<&Conv2dConfig<T>>,
280    ) {
281        if dilation_h != 1 || dilation_w != 1 {
282            todo!();
283        }
284        df.copy_from(&conv2d_bckwd_fileter(
285            input,
286            df.to_ref().shape(),
287            dy,
288            (pad_h, pad_w),
289            (stride_h, stride_w),
290        ));
291    }
292
293    fn conv2d_forward_bias<T: Num>(
294        input: Matrix<Ref<&T>, DimDyn, Self>,
295        mut y: Matrix<Ref<&mut T>, DimDyn, Self>,
296        bias: Matrix<Ref<&T>, DimDyn, Self>,
297    ) {
298        y.add_array(&input, &bias);
299    }
300
301    fn conv2d_bckwd_bias<T: Num>(
302        dy: Matrix<Ref<&T>, DimDyn, Self>,
303        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
304    ) {
305        let dy_0 = dy.sum(0, true);
306        let dy_0_2 = dy_0.to_ref().sum(2, true);
307        dx.copy_from(&dy_0_2.to_ref().sum(3, true));
308    }
309}
310
/// NVIDIA implementation of [`Conv2d`], backed by `zenu_cuda`.
///
/// The three convolution methods use the caller's [`Conv2dConfig`] when one is supplied and
/// otherwise build a descriptor from the argument shapes with `groups` set to 1.
#[cfg(feature = "nvidia")]
impl Conv2d for Nvidia {
    /// Calls `forward` on the convolution configuration, writing into `y`.
    fn conv2d<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    ) {
        // Owns the descriptor built on demand so that `descriptor` can borrow from it.
        let fallback;
        let descriptor = if let Some(prebuilt) = config {
            &prebuilt.conv
        } else {
            fallback = create_conv_descriptor::<T>(
                input.shape().slice(),
                y.shape().slice(),
                filter.shape().slice(),
                pad_h,
                pad_w,
                stride_h,
                stride_w,
                dilation_h,
                dilation_w,
                1,
            );
            &fallback.conv
        };

        descriptor.forward(
            T::one(),
            input.as_ptr(),
            filter.as_ptr(),
            T::zero(),
            y.as_mut_ptr(),
        );
    }

    /// Calls `backward_data` on the convolution configuration, writing into `dx`.
    ///
    /// When no configuration is supplied, `dx` provides the descriptor's input shape and `dy`
    /// its output shape.
    fn conv2d_bckwd_data<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
        filter: Matrix<Ref<&T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    ) {
        // Owns the descriptor built on demand so that `descriptor` can borrow from it.
        let fallback;
        let descriptor = if let Some(prebuilt) = config {
            &prebuilt.conv
        } else {
            fallback = create_conv_descriptor::<T>(
                dx.shape().slice(),
                dy.shape().slice(),
                filter.shape().slice(),
                pad_h,
                pad_w,
                stride_h,
                stride_w,
                dilation_h,
                dilation_w,
                1,
            );
            &fallback.conv
        };

        descriptor.backward_data(
            T::one(),
            filter.as_ptr(),
            dy.as_ptr(),
            T::zero(),
            dx.as_mut_ptr(),
        );
    }

    /// Calls `backward_filter` on the convolution configuration, writing into `df`.
    ///
    /// When no configuration is supplied, `df` provides the descriptor's filter shape.
    fn conv2d_bckwd_filter<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        df: Matrix<Ref<&mut T>, DimDyn, Self>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        config: Option<&Conv2dConfig<T>>,
    ) {
        // Owns the descriptor built on demand so that `descriptor` can borrow from it.
        let fallback;
        let descriptor = if let Some(prebuilt) = config {
            &prebuilt.conv
        } else {
            fallback = create_conv_descriptor::<T>(
                input.shape().slice(),
                dy.shape().slice(),
                df.shape().slice(),
                pad_h,
                pad_w,
                stride_h,
                stride_w,
                dilation_h,
                dilation_w,
                1,
            );
            &fallback.conv
        };

        descriptor.backward_filter(
            T::one(),
            input.as_ptr(),
            dy.as_ptr(),
            T::zero(),
            df.as_mut_ptr(),
        );
    }

    /// Hands `input`, `bias` and `y` to the `conv_bias_add` kernel together with the element
    /// counts of `input` and `bias` and the stride of `input` along axis 1.
    fn conv2d_forward_bias<T: Num>(
        input: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
    ) {
        let channel_stride = input.stride()[1];
        let input_len = input.shape().num_elm();
        let bias_len = bias.shape().num_elm();

        conv_bias_add(
            input.as_ptr(),
            bias.as_ptr(),
            input_len,
            channel_stride,
            bias_len,
            y.as_mut_ptr(),
        );
    }

    /// Calls `backward_bias` with the shape of `dy`, writing into `dx`.
    fn conv2d_bckwd_bias<T: Num>(
        dy: Matrix<Ref<&T>, DimDyn, Self>,
        dx: Matrix<Ref<&mut T>, DimDyn, Self>,
    ) {
        let grad_shape = dy.shape();
        backward_bias(
            T::one(),
            dy.as_ptr(),
            T::zero(),
            dx.as_mut_ptr(),
            grad_shape.slice(),
        );
    }
}
461
462#[expect(clippy::too_many_arguments)]
463#[must_use]
464pub fn conv2d_forward<T: Num, D: Device>(
465    input: Matrix<Ref<&T>, DimDyn, D>,
466    filter: Matrix<Ref<&T>, DimDyn, D>,
467    pad_h: usize,
468    pad_w: usize,
469    stride_h: usize,
470    stride_w: usize,
471    dilation_h: usize,
472    dilation_w: usize,
473    config: Option<&Conv2dConfig<T>>,
474) -> Matrix<Owned<T>, DimDyn, D> {
475    let out_size = conv2d_out_size(
476        input.shape().slice(),
477        filter.shape().slice(),
478        (pad_h, pad_w),
479        (stride_h, stride_w),
480    );
481    let mut y = Matrix::<Owned<T>, DimDyn, D>::alloc(out_size);
482    D::conv2d(
483        input,
484        y.to_ref_mut(),
485        filter,
486        pad_h,
487        pad_w,
488        stride_h,
489        stride_w,
490        dilation_h,
491        dilation_w,
492        config,
493    );
494    y
495}
496
497#[expect(clippy::too_many_arguments)]
498#[must_use]
499pub fn conv2d_bckwd_data<T: Num, D: Device>(
500    dy: Matrix<Ref<&T>, DimDyn, D>,
501    filter: Matrix<Ref<&T>, DimDyn, D>,
502    pad_h: usize,
503    pad_w: usize,
504    stride_h: usize,
505    stride_w: usize,
506    dilation_h: usize,
507    dilation_w: usize,
508    config: Option<&Conv2dConfig<T>>,
509) -> Matrix<Owned<T>, DimDyn, D> {
510    let input_shape = deconv2d_out_size(
511        dy.shape().slice(),
512        filter.shape().slice(),
513        (pad_h, pad_w),
514        (stride_h, stride_w),
515    );
516    let mut dx = Matrix::<Owned<T>, DimDyn, D>::alloc(input_shape);
517    D::conv2d_bckwd_data(
518        dy,
519        dx.to_ref_mut(),
520        filter,
521        pad_h,
522        pad_w,
523        stride_h,
524        stride_w,
525        dilation_h,
526        dilation_w,
527        config,
528    );
529    dx
530}
531
532#[expect(clippy::too_many_arguments)]
533#[must_use]
534pub fn conv2d_bckwd_filter<T: Num, D: Device>(
535    input: Matrix<Ref<&T>, DimDyn, D>,
536    dy: Matrix<Ref<&T>, DimDyn, D>,
537    pad_h: usize,
538    pad_w: usize,
539    stride_h: usize,
540    stride_w: usize,
541    dilation_h: usize,
542    dilation_w: usize,
543    filter_shape: DimDyn,
544    config: Option<&Conv2dConfig<T>>,
545) -> Matrix<Owned<T>, DimDyn, D> {
546    let mut df = Matrix::<Owned<T>, DimDyn, D>::alloc(filter_shape);
547    D::conv2d_bckwd_filter(
548        input,
549        dy,
550        df.to_ref_mut(),
551        pad_h,
552        pad_w,
553        stride_h,
554        stride_w,
555        dilation_h,
556        dilation_w,
557        config,
558    );
559    df
560}
561
/// Applies the bias step of a convolution on device `D`, writing the result into `output`.
///
/// Thin wrapper around [`Conv2d::conv2d_forward_bias`]; note that this function takes `bias`
/// before `output`, whereas the trait method takes the output matrix second.
pub fn conv2d_bias_add<T: Num, D: Device>(
    input: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    output: Matrix<Ref<&mut T>, DimDyn, D>,
) {
    D::conv2d_forward_bias(input, output, bias);
}
569
/// Reduces the output gradient `dy` into `dx` on device `D`.
///
/// Thin wrapper around [`Conv2d::conv2d_bckwd_bias`].
pub fn conv2d_bckwd_data_bias<T: Num, D: Device>(
    dy: Matrix<Ref<&T>, DimDyn, D>,
    dx: Matrix<Ref<&mut T>, DimDyn, D>,
) {
    D::conv2d_bckwd_bias(dy, dx);
}
576
// The fixtures below are long lists of full-precision float literals, hence the lint
// expectations.
#[expect(clippy::unreadable_literal, clippy::too_many_lines)]
#[cfg(test)]
mod conv2d {
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    use super::{conv2d_bckwd_data, conv2d_bckwd_filter, conv2d_forward};

    /// Inputs, hyper-parameters and reference results for one convolution test.
    struct Conv2dTestCase<D: Device> {
        input: Matrix<Owned<f32>, DimDyn, D>,
        filter: Matrix<Owned<f32>, DimDyn, D>,
        pad_h: usize,
        pad_w: usize,
        stride_h: usize,
        stride_w: usize,
        dilation_h: usize,
        dilation_w: usize,
        // Reference output of the forward pass.
        expected: Matrix<Owned<f32>, DimDyn, D>,
        // Reference gradient with respect to `input`.
        input_grad: Matrix<Owned<f32>, DimDyn, D>,
        // Reference gradient with respect to `filter`.
        filter_grad: Matrix<Owned<f32>, DimDyn, D>,
        // Upstream gradient fed into both backward passes.
        output_grad: Matrix<Owned<f32>, DimDyn, D>,
    }

    /// Checks the forward pass and both backward passes against the `small` fixture with a
    /// tolerance of `1e-4`. No pre-built configuration is passed (`None`).
    fn conv2d_test<D: Device>() {
        let test_case = small::<D>();

        let forward_pred = conv2d_forward(
            test_case.input.to_ref(),
            test_case.filter.to_ref(),
            test_case.pad_h,
            test_case.pad_w,
            test_case.stride_h,
            test_case.stride_w,
            test_case.dilation_h,
            test_case.dilation_w,
            None,
        );
        assert_mat_eq_epsilon!(forward_pred, test_case.expected, 1e-4);
        let input_grad = conv2d_bckwd_data(
            test_case.output_grad.to_ref(),
            test_case.filter.to_ref(),
            test_case.pad_h,
            test_case.pad_w,
            test_case.stride_h,
            test_case.stride_w,
            test_case.dilation_h,
            test_case.dilation_w,
            None,
        );
        assert_mat_eq_epsilon!(input_grad, test_case.input_grad, 1e-4);

        let filter_grad = conv2d_bckwd_filter(
            test_case.input.to_ref(),
            test_case.output_grad.to_ref(),
            test_case.pad_h,
            test_case.pad_w,
            test_case.stride_h,
            test_case.stride_w,
            test_case.dilation_h,
            test_case.dilation_w,
            test_case.filter.shape(),
            None,
        );
        assert_mat_eq_epsilon!(filter_grad, test_case.filter_grad, 1e-4);
    }
    // NOTE(review): presumably generates one test per device (CPU and NVIDIA) from the generic
    // function above; confirm against the `zenu_test` macro definition.
    run_mat_test!(conv2d_test, conv2d_test_cpu, conv2d_test_nvidia);

    /// Builds a fixture with a `[1, 2, 4, 4]` input, a `[3, 2, 3, 3]` filter, padding 1,
    /// stride 1 and dilation 1; the upstream gradient is all ones with shape `[1, 3, 4, 4]`.
    fn small<D: Device>() -> Conv2dTestCase<D> {
        // 32 values, reshaped to [1, 2, 4, 4] below.
        let input = vec![
            0.5432947,
            -0.39515755,
            0.20552567,
            -0.45032975,
            -0.5730771,
            -0.5553584,
            0.59432304,
            1.5419426,
            1.8197253,
            -0.5515287,
            -1.325326,
            0.18855357,
            -0.069072686,
            -0.49492535,
            -1.4959149,
            -0.19383712,
            -0.4731198,
            0.33555076,
            1.5091219,
            2.0819554,
            1.7067116,
            2.3803675,
            -1.1256016,
            -0.3169981,
            -0.14067143,
            0.8057536,
            0.3276143,
            -0.7607072,
            -1.599082,
            0.018486667,
            -0.7504268,
            0.18540798,
        ];
        // 48 values, the expected forward output, reshaped to [1, 3, 4, 4] below.
        let output = vec![
            0.3671525,
            -0.17387724,
            -0.53952014,
            -0.41356063,
            0.13519445,
            -0.6369239,
            -0.5777169,
            -0.07820636,
            -0.6019154,
            -0.85000455,
            -0.227178,
            0.38553098,
            0.53258127,
            0.4952766,
            0.16334829,
            0.5179188,
            -1.1829954,
            -0.15092221,
            0.15374796,
            0.5376092,
            -0.35269666,
            -0.10102463,
            -0.628401,
            -0.40036133,
            -0.5694187,
            -0.1765114,
            -0.05552435,
            -0.3107502,
            -0.6736164,
            -0.44401115,
            -0.1804393,
            0.056986123,
            0.5652461,
            0.8913239,
            0.30458608,
            -0.7666081,
            0.15480474,
            0.14275207,
            0.42336845,
            0.12534592,
            0.5706087,
            0.40240055,
            -0.16282544,
            -0.032061294,
            0.47645676,
            -0.09869753,
            -0.34638345,
            -0.02880986,
        ];
        // 32 values, the expected input gradient, reshaped to [1, 2, 4, 4] below.
        let input_grad = vec![
            -0.06312838,
            0.05240719,
            0.05240719,
            0.21505278,
            -0.07415994,
            0.063570745,
            0.063570745,
            0.22900042,
            -0.07415994,
            0.063570745,
            0.063570745,
            0.22900042,
            -0.0014246926,
            0.13951382,
            0.13951382,
            0.005797662,
            -0.73124456,
            -0.7982433,
            -0.7982433,
            -0.098860174,
            -0.57463914,
            -0.689119,
            -0.689119,
            -0.12428501,
            -0.57463914,
            -0.689119,
            -0.689119,
            -0.12428501,
            -0.22594097,
            -0.37261552,
            -0.37261552,
            -0.085577406,
        ];
        // 54 values, reshaped to [3, 2, 3, 3] below.
        let filter = vec![
            -0.0017646605,
            0.12644097,
            -0.1939936,
            -0.1734625,
            -0.090781756,
            0.063205294,
            -0.0046700113,
            0.18688585,
            -0.020917172,
            0.06236978,
            -0.071232304,
            -0.046330906,
            -0.2251778,
            -0.15610139,
            -0.09716192,
            0.008731253,
            0.0931814,
            0.14142673,
            -0.15979224,
            -0.10263957,
            0.0856111,
            0.19572432,
            -0.048507567,
            0.17637877,
            -0.03799128,
            0.024940623,
            0.21342279,
            -0.218654,
            -0.14838351,
            -0.05967162,
            -0.09187673,
            0.20364694,
            -0.1527774,
            -0.1085015,
            -0.16467114,
            -0.22074954,
            -0.13758895,
            0.2026092,
            0.105174676,
            0.11423842,
            0.01239595,
            -0.12084066,
            0.039877214,
            -0.22007395,
            -0.1703105,
            -0.121511586,
            0.1487135,
            0.13819724,
            -0.104532786,
            -0.0085047,
            0.1507459,
            0.23431942,
            0.093546025,
            0.03184169,
        ];
        // 54 values, the expected filter gradient, reshaped to [3, 2, 3, 3] below. The same 18
        // values repeat three times, once per output channel.
        let filter_grad = vec![
            -0.23757887,
            1.0425875,
            -0.7473556,
            -2.297492,
            -1.2111626,
            -2.932033,
            -2.651155,
            -1.1144958,
            -2.292071,
            5.325727,
            6.329977,
            5.2370563,
            2.994705,
            4.184363,
            4.690524,
            1.6231518,
            0.7308545,
            0.7638962,
            -0.23757887,
            1.0425875,
            -0.7473556,
            -2.297492,
            -1.2111626,
            -2.932033,
            -2.651155,
            -1.1144958,
            -2.292071,
            5.325727,
            6.329977,
            5.2370563,
            2.994705,
            4.184363,
            4.690524,
            1.6231518,
            0.7308545,
            0.7638962,
            -0.23757887,
            1.0425875,
            -0.7473556,
            -2.297492,
            -1.2111626,
            -2.932033,
            -2.651155,
            -1.1144958,
            -2.292071,
            5.325727,
            6.329977,
            5.2370563,
            2.994705,
            4.184363,
            4.690524,
            1.6231518,
            0.7308545,
            0.7638962,
        ];
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input, [1, 2, 4, 4]);
        let input_grad = Matrix::<Owned<f32>, DimDyn, D>::from_vec(input_grad, [1, 2, 4, 4]);
        let filter = Matrix::<Owned<f32>, DimDyn, D>::from_vec(filter, [3, 2, 3, 3]);
        let filter_grad = Matrix::<Owned<f32>, DimDyn, D>::from_vec(filter_grad, [3, 2, 3, 3]);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(output, [1, 3, 4, 4]);
        let output_grad = Matrix::<Owned<f32>, DimDyn, D>::ones([1, 3, 4, 4]);

        Conv2dTestCase {
            input,
            filter,
            pad_h: 1,
            pad_w: 1,
            stride_h: 1,
            stride_w: 1,
            dilation_h: 1,
            dilation_w: 1,
            expected,
            input_grad,
            filter_grad,
            output_grad,
        }
    }
}