1use std::{cell::RefCell, collections::HashMap};
2
3use rand_distr::{Distribution, StandardNormal};
4use zenu_autograd::{
5 creator::{rand::normal, zeros::zeros},
6 nn::conv2d::{conv2d, Conv2dConfigs},
7 Variable,
8};
9use zenu_matrix::{device::Device, dim::DimTrait, nn::conv2d::conv2d_out_size, num::Num};
10
11use crate::{Module, Parameters};
12
/// A 2-D convolution layer holding a learnable filter and an optional bias.
///
/// The `Conv2dConfigs` used by `conv2d` is built lazily on the first forward call
/// (see [`Module::call`]) and cached in `config`.
pub struct Conv2d<T: Num, D: Device> {
    /// Convolution kernel. [`Conv2d::new`] creates it with shape
    /// `[output_channel, input_channel, kernel_size.0, kernel_size.1]`.
    pub filter: Variable<T, D>,
    /// Optional bias. [`Conv2d::new`] creates it zero-filled with shape
    /// `[1, output_channel, 1, 1]` when requested, otherwise `None`.
    pub bias: Option<Variable<T, D>>,
    /// Cached convolution configuration; `None` until the first forward call and
    /// reset to `None` by [`Conv2d::to`].
    config: RefCell<Option<Conv2dConfigs<T>>>,
    /// Stride pair forwarded unchanged to `conv2d` / `conv2d_out_size`
    /// (presumably `(height, width)` — TODO confirm axis order).
    stride: (usize, usize),
    /// Padding pair forwarded unchanged to `conv2d` / `conv2d_out_size`
    /// (presumably `(height, width)` — TODO confirm axis order).
    padding: (usize, usize),
}
20
21impl<T: Num, D: Device> Module<T, D> for Conv2d<T, D> {
22 type Input = Variable<T, D>;
23 type Output = Variable<T, D>;
24 fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
25 if self.config.borrow().is_none() {
26 let input_shape = input.get_data().shape();
27 let filter_shape = self.filter.get_data().shape();
28 let output_shape = conv2d_out_size(
29 input_shape.slice(),
30 filter_shape.slice(),
31 self.padding,
32 self.stride,
33 );
34 let config = Conv2dConfigs::new(
35 input_shape,
36 output_shape.into(),
37 filter_shape,
38 self.stride,
39 self.padding,
40 20,
41 );
42 *self.config.borrow_mut() = Some(config);
43 }
44 conv2d(
45 input,
46 self.filter.clone(),
47 self.stride,
48 self.padding,
49 self.bias.clone(),
50 Some(self.config.borrow().as_ref().unwrap().clone()),
51 )
52 }
53}
54
55impl<T: Num, D: Device> Parameters<T, D> for Conv2d<T, D> {
56 fn weights(&self) -> HashMap<String, Variable<T, D>> {
57 HashMap::new()
58 .into_iter()
59 .chain(std::iter::once((
60 String::from("conv2d.filter"),
61 self.filter.clone(),
62 )))
63 .collect()
64 }
65
66 fn biases(&self) -> HashMap<String, Variable<T, D>> {
67 self.bias
68 .as_ref()
69 .map(|bias| {
70 HashMap::new()
71 .into_iter()
72 .chain(std::iter::once((String::from("conv2d.bias"), bias.clone())))
73 .collect()
74 })
75 .unwrap_or_default()
76 }
77}
78
79impl<T: Num, D: Device> Conv2d<T, D> {
80 #[must_use]
81 pub fn new(
82 input_channel: usize,
83 output_channel: usize,
84 kernel_size: (usize, usize),
85 stride: (usize, usize),
86 padding: (usize, usize),
87 bias: bool,
88 ) -> Self
89 where
90 StandardNormal: Distribution<T>,
91 {
92 let filter_shape = [output_channel, input_channel, kernel_size.0, kernel_size.1];
93 let bias = if bias {
94 let bias = zeros([1, output_channel, 1, 1]);
95 bias.set_is_train(true);
96 bias.set_name("conv2d.bias");
97 Some(bias)
98 } else {
99 None
100 };
101 let filter = normal(T::zero(), T::one(), None, filter_shape);
102
103 filter.set_is_train(true);
104 filter.set_name("conv2d.filter");
105
106 Conv2d {
107 filter,
108 bias,
109 config: RefCell::new(None),
110 stride,
111 padding,
112 }
113 }
114
115 pub fn to<Dout: Device>(self) -> Conv2d<T, Dout> {
116 Conv2d {
117 filter: self.filter.to(),
118 bias: self.bias.map(|b| b.to()),
119 config: RefCell::new(None),
120 stride: self.stride,
121 padding: self.padding,
122 }
123 }
124}