//! EfficientNet implementation (B0 to B7 variants).
use crate::nn::{self, ConvConfig, Module, ModuleT};
use crate::Tensor;

// Batch-norm hyper-parameters; the momentum is converted to
// `1. - BATCH_NORM_MOMENTUM` where the batch-norm layers are created below.
const BATCH_NORM_MOMENTUM: f64 = 0.99;
const BATCH_NORM_EPSILON: f64 = 1e-3;

/// Configuration for a single MBConv (mobile inverted bottleneck) stage.
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
    expand_ratio: f64,
    kernel: i64,
    stride: i64,
    input_channels: i64,
    out_channels: i64,
    num_layers: usize,
}

/// Rounds `v` to a multiple of `divisor`, never going below `divisor`, and
/// bumps the result up by one step if rounding would lose more than 10% of `v`.
fn make_divisible(v: f64, divisor: i64) -> i64 {
    let min_value = divisor;
    let new_v = i64::max(min_value, (v + divisor as f64 * 0.5) as i64 / divisor * divisor);
    if (new_v as f64) < 0.9 * v {
        new_v + divisor
    } else {
        new_v
    }
}
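
// Worked example of the rounding above: with the B2 width multiplier of 1.1, the
// 32-channel stem becomes make_divisible(32.0 * 1.1, 8): (35.2 + 4.0) as i64 = 39,
// then 39 / 8 * 8 = 32, and since 32.0 >= 0.9 * 35.2 no upward correction is applied.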

fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
    let bneck_conf = |e, k, s, i, o, n| {
        let input_channels = make_divisible(i as f64 * width_mult, 8);
        let out_channels = make_divisible(o as f64 * width_mult, 8);
        let num_layers = (n as f64 * depth_mult).ceil() as usize;
        MBConvConfig {
            expand_ratio: e,
            kernel: k,
            stride: s,
            input_channels,
            out_channels,
            num_layers,
        }
    };
    // Baseline (B0) stage settings: expand ratio, kernel size, stride, input
    // channels, output channels, number of layers. The width and depth
    // multipliers scale these per variant.
    vec![
        bneck_conf(1., 3, 1, 32, 16, 1),
        bneck_conf(6., 3, 2, 16, 24, 2),
        bneck_conf(6., 5, 2, 24, 40, 2),
        bneck_conf(6., 3, 2, 40, 80, 3),
        bneck_conf(6., 5, 1, 80, 112, 3),
        bneck_conf(6., 5, 2, 112, 192, 4),
        bneck_conf(6., 3, 1, 192, 320, 1),
    ]
}

impl MBConvConfig {
    fn b0() -> Vec<Self> {
        bneck_confs(1.0, 1.0)
    }
    fn b1() -> Vec<Self> {
        bneck_confs(1.0, 1.1)
    }
    fn b2() -> Vec<Self> {
        bneck_confs(1.1, 1.2)
    }
    fn b3() -> Vec<Self> {
        bneck_confs(1.2, 1.4)
    }
    fn b4() -> Vec<Self> {
        bneck_confs(1.4, 1.8)
    }
    fn b5() -> Vec<Self> {
        bneck_confs(1.6, 2.2)
    }
    fn b6() -> Vec<Self> {
        bneck_confs(1.8, 2.6)
    }
    fn b7() -> Vec<Self> {
        bneck_confs(2.0, 3.1)
    }
}

/// 2D convolution with TensorFlow-style "same" padding: the input is
/// zero-padded so that the output spatial dims are `ceil(input / stride)`.
#[derive(Debug)]
struct Conv2DSame {
    conv2d: nn::Conv2D,
    s: i64,
    k: i64,
}

impl Conv2DSame {
    fn new(vs: nn::Path, i: i64, o: i64, k: i64, stride: i64, groups: i64, b: bool) -> Self {
        let conv_config = nn::ConvConfig { stride, groups, bias: b, ..Default::default() };
        let conv2d = nn::conv2d(vs, i, o, k, conv_config);
        Self { conv2d, s: stride, k }
    }
}

impl Module for Conv2DSame {
    fn forward(&self, xs: &Tensor) -> Tensor {
        let s = self.s;
        let k = self.k;
        let size = xs.size();
        let ih = size[2];
        let iw = size[3];
        // Output dims with "same" padding: ceil(input / stride).
        let oh = (ih + s - 1) / s;
        let ow = (iw + s - 1) / s;
        let pad_h = i64::max((oh - 1) * s + k - ih, 0);
        let pad_w = i64::max((ow - 1) * s + k - iw, 0);
        if pad_h > 0 || pad_w > 0 {
            xs.zero_pad2d(pad_w / 2, pad_w - pad_w / 2, pad_h / 2, pad_h - pad_h / 2)
                .apply(&self.conv2d)
        } else {
            xs.apply(&self.conv2d)
        }
    }
}
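
// Worked example of the padding arithmetic above: for a 224x224 input with a
// 3x3 kernel and stride 2, oh = (224 + 1) / 2 = 112 and
// pad_h = (112 - 1) * 2 + 3 - 224 = 1, so a single row/column of zeros is
// added, all on the bottom/right since pad_h / 2 = 0 goes on the top/left.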

/// Convolution followed by batch-norm and an optional swish activation.
#[derive(Debug)]
struct ConvNormActivation {
    conv2d: Conv2DSame,
    bn2d: nn::BatchNorm,
    activation: bool,
}

impl ConvNormActivation {
    fn new(vs: nn::Path, i: i64, o: i64, k: i64, stride: i64, groups: i64) -> Self {
        let conv2d = Conv2DSame::new(&vs / 0, i, o, k, stride, groups, false);
        let bn_config = nn::BatchNormConfig {
            momentum: 1.0 - BATCH_NORM_MOMENTUM,
            eps: BATCH_NORM_EPSILON,
            ..Default::default()
        };
        let bn2d = nn::batch_norm2d(&vs / 1, o, bn_config);
        Self { conv2d, bn2d, activation: true }
    }

    fn no_activation(self) -> Self {
        Self { activation: false, ..self }
    }
}

impl ModuleT for ConvNormActivation {
    fn forward_t(&self, xs: &Tensor, t: bool) -> Tensor {
        let xs = xs.apply(&self.conv2d).apply_t(&self.bn2d, t);
        if self.activation {
            xs.swish()
        } else {
            xs
        }
    }
}

/// Squeeze-and-excitation block: a globally pooled channel descriptor is passed
/// through two 1x1 convolutions and used to rescale the input channels.
#[derive(Debug)]
struct SqueezeExcitation {
    fc1: Conv2DSame,
    fc2: Conv2DSame,
}

impl SqueezeExcitation {
    fn new(vs: nn::Path, input_channels: i64, squeeze_channels: i64) -> Self {
        let fc1 = Conv2DSame::new(&vs / "fc1", input_channels, squeeze_channels, 1, 1, 1, true);
        let fc2 = Conv2DSame::new(&vs / "fc2", squeeze_channels, input_channels, 1, 1, 1, true);
        Self { fc1, fc2 }
    }
}

impl ModuleT for SqueezeExcitation {
    fn forward_t(&self, xs: &Tensor, t: bool) -> Tensor {
        let scale = xs
            .adaptive_avg_pool2d([1, 1])
            .apply_t(&self.fc1, t)
            .swish()
            .apply_t(&self.fc2, t)
            .sigmoid();
        scale * xs
    }
}

/// Mobile inverted bottleneck block: optional 1x1 expansion, depthwise
/// convolution, squeeze-and-excitation, and a 1x1 projection, with a residual
/// connection when the stride is 1 and the channel count is unchanged.
#[derive(Debug)]
struct MBConv {
    expand_cna: Option<ConvNormActivation>,
    depthwise_cna: ConvNormActivation,
    squeeze_excitation: SqueezeExcitation,
    project_cna: ConvNormActivation,
    config: MBConvConfig,
}

impl MBConv {
    fn new(vs: nn::Path, c: MBConvConfig) -> Self {
        let vs = &vs / "block";
        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
        let expand_cna = if exp != c.input_channels {
            Some(ConvNormActivation::new(&vs / 0, c.input_channels, exp, 1, 1, 1))
        } else {
            None
        };
        let start_index = if expand_cna.is_some() { 1 } else { 0 };
        let depthwise_cna =
            ConvNormActivation::new(&vs / start_index, exp, exp, c.kernel, c.stride, exp);
        let squeeze_channels = i64::max(1, c.input_channels / 4);
        let squeeze_excitation =
            SqueezeExcitation::new(&vs / (start_index + 1), exp, squeeze_channels);
        let project_cna =
            ConvNormActivation::new(&vs / (start_index + 2), exp, c.out_channels, 1, 1, 1)
                .no_activation();
        Self { expand_cna, depthwise_cna, squeeze_excitation, project_cna, config: c }
    }
}

impl ModuleT for MBConv {
    fn forward_t(&self, xs: &Tensor, t: bool) -> Tensor {
        let use_res_connect =
            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
        let ys = match &self.expand_cna {
            Some(expand_cna) => xs.apply_t(expand_cna, t),
            None => xs.shallow_clone(),
        };
        let ys = ys
            .apply_t(&self.depthwise_cna, t)
            .apply_t(&self.squeeze_excitation, t)
            .apply_t(&self.project_cna, t);
        if use_res_connect {
            ys + xs
        } else {
            ys
        }
    }
}

impl Tensor {
    /// Swish (a.k.a. SiLU) activation: `x * sigmoid(x)`.
    fn swish(&self) -> Tensor {
        self * self.sigmoid()
    }
}

#[derive(Debug)]
struct EfficientNet {
    init_cna: ConvNormActivation,
    blocks: Vec<MBConv>,
    final_cna: ConvNormActivation,
    classifier: nn::Linear,
}

impl EfficientNet {
    fn new(p: &nn::Path, configs: Vec<MBConvConfig>, nclasses: i64) -> Self {
        let f_p = p / "features";
        let first_in_c = configs[0].input_channels;
        let last_out_c = configs.last().unwrap().out_channels;
        let final_out_c = 4 * last_out_c;
        let init_cna = ConvNormActivation::new(&f_p / 0, 3, first_in_c, 3, 2, 1);
        let nconfigs = configs.len();
        let mut blocks = vec![];
        for (index, cnf) in configs.into_iter().enumerate() {
            let f_p = &f_p / (index + 1);
            for r_index in 0..cnf.num_layers {
                // Only the first block of each stage changes the channel count
                // or applies a stride; the remaining blocks keep the shape.
                let cnf = if r_index == 0 {
                    cnf
                } else {
                    MBConvConfig { input_channels: cnf.out_channels, stride: 1, ..cnf }
                };
                blocks.push(MBConv::new(&f_p / r_index, cnf))
            }
        }
        let final_cna =
            ConvNormActivation::new(&f_p / (nconfigs + 1), last_out_c, final_out_c, 1, 1, 1);
        let classifier =
            nn::linear(p / "classifier" / 1, final_out_c, nclasses, Default::default());
        Self { init_cna, blocks, final_cna, classifier }
    }
}

impl ModuleT for EfficientNet {
    fn forward_t(&self, xs: &Tensor, t: bool) -> Tensor {
        let mut xs = xs.apply_t(&self.init_cna, t);
        for block in self.blocks.iter() {
            xs = xs.apply_t(block, t)
        }
        xs.apply_t(&self.final_cna, t)
            .adaptive_avg_pool2d([1, 1])
            .squeeze_dim(-1)
            .squeeze_dim(-1)
            .apply(&self.classifier)
    }
}

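/// Creates an EfficientNet-B0 model.
///
/// A minimal usage sketch, assuming this module is exposed as
/// `tch::vision::efficientnet` and that an ImageNet-style 1000-class head with
/// a 224x224 input is wanted:
///
/// ```no_run
/// use tch::nn::ModuleT;
/// use tch::{nn, Device, Kind, Tensor};
///
/// let vs = nn::VarStore::new(Device::Cpu);
/// let net = tch::vision::efficientnet::b0(&vs.root(), 1000);
/// // Load pretrained weights into `vs` here if available.
/// let image = Tensor::zeros([1, 3, 224, 224], (Kind::Float, Device::Cpu));
/// let _logits = net.forward_t(&image, /* train = */ false);
/// ```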
pub fn b0(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b0(), nclasses)
}

/// Creates an EfficientNet-B1 model.
pub fn b1(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b1(), nclasses)
}

/// Creates an EfficientNet-B2 model.
pub fn b2(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b2(), nclasses)
}

/// Creates an EfficientNet-B3 model.
pub fn b3(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b3(), nclasses)
}

/// Creates an EfficientNet-B4 model.
pub fn b4(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b4(), nclasses)
}

/// Creates an EfficientNet-B5 model.
pub fn b5(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b5(), nclasses)
}

/// Creates an EfficientNet-B6 model.
pub fn b6(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b6(), nclasses)
}

/// Creates an EfficientNet-B7 model.
pub fn b7(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    EfficientNet::new(p, MBConvConfig::b7(), nclasses)
}

/// 2D convolution with TensorFlow-style "same" padding built from an arbitrary
/// `ConvConfig`: with the default padding and dilation, the input is zero-padded
/// so that the output spatial dims are `ceil(input / stride)`.
#[allow(clippy::many_single_char_names)]
pub fn conv2d_same(vs: nn::Path, i: i64, o: i64, k: i64, c: ConvConfig) -> impl Module {
    let conv2d = nn::conv2d(vs, i, o, k, c);
    let s = c.stride;
    nn::func(move |xs| {
        let size = xs.size();
        let ih = size[2];
        let iw = size[3];
        let oh = (ih + s - 1) / s;
        let ow = (iw + s - 1) / s;
        let pad_h = i64::max((oh - 1) * s + k - ih, 0);
        let pad_w = i64::max((ow - 1) * s + k - iw, 0);
        if pad_h > 0 || pad_w > 0 {
            xs.zero_pad2d(pad_w / 2, pad_w - pad_w / 2, pad_h / 2, pad_h - pad_h / 2).apply(&conv2d)
        } else {
            xs.apply(&conv2d)
        }
    })
}
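
// A small sanity-check sketch (not part of the original model code): it only
// exercises the channel-rounding helper above, assuming `make_divisible`
// keeps its current semantics.
#[cfg(test)]
mod tests {
    use super::make_divisible;

    #[test]
    fn make_divisible_rounds_to_multiples_of_eight() {
        // Exact multiples are left untouched.
        assert_eq!(make_divisible(32.0, 8), 32);
        // 35.2 (the B2 width multiplier applied to 32) rounds down to 32.
        assert_eq!(make_divisible(32.0 * 1.1, 8), 32);
        // Values that would lose more than 10% get bumped up by one divisor.
        assert_eq!(make_divisible(35.9, 8), 40);
        // The result never drops below the divisor itself.
        assert_eq!(make_divisible(1.0, 8), 8);
    }
}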