1use crate::{nn, nn::ModuleT, Tensor};
3
4fn conv_bn(p: nn::Path, c_in: i64, c_out: i64, ksize: i64, pad: i64, stride: i64) -> impl ModuleT {
5    let conv2d_cfg = nn::ConvConfig { stride, padding: pad, bias: false, ..Default::default() };
6    let bn_cfg = nn::BatchNormConfig { eps: 0.001, ..Default::default() };
7    nn::seq_t()
8        .add(nn::conv2d(&p / "conv", c_in, c_out, ksize, conv2d_cfg))
9        .add(nn::batch_norm2d(&p / "bn", c_out, bn_cfg))
10        .add_fn(|xs| xs.relu())
11}
12
13fn conv_bn2(p: nn::Path, c_in: i64, c_out: i64, ksize: [i64; 2], pad: [i64; 2]) -> impl ModuleT {
14    let conv2d_cfg =
15        nn::ConvConfigND::<[i64; 2]> { padding: pad, bias: false, ..Default::default() };
16    let bn_cfg = nn::BatchNormConfig { eps: 0.001, ..Default::default() };
17    nn::seq_t()
18        .add(nn::conv(&p / "conv", c_in, c_out, ksize, conv2d_cfg))
19        .add(nn::batch_norm2d(&p / "bn", c_out, bn_cfg))
20        .add_fn(|xs| xs.relu())
21}
22
23fn max_pool2d(xs: &Tensor, ksize: i64, stride: i64) -> Tensor {
24    xs.max_pool2d([ksize, ksize], [stride, stride], [0, 0], [1, 1], false)
25}
26
27fn inception_a(p: nn::Path, c_in: i64, c_pool: i64) -> impl ModuleT {
28    let b1 = conv_bn(&p / "branch1x1", c_in, 64, 1, 0, 1);
29    let b2_1 = conv_bn(&p / "branch5x5_1", c_in, 48, 1, 0, 1);
30    let b2_2 = conv_bn(&p / "branch5x5_2", 48, 64, 5, 2, 1);
31    let b3_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 64, 1, 0, 1);
32    let b3_2 = conv_bn(&p / "branch3x3dbl_2", 64, 96, 3, 1, 1);
33    let b3_3 = conv_bn(&p / "branch3x3dbl_3", 96, 96, 3, 1, 1);
34    let bpool = conv_bn(&p / "branch_pool", c_in, c_pool, 1, 0, 1);
35    nn::func_t(move |xs, tr| {
36        let b1 = xs.apply_t(&b1, tr);
37        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr);
38        let b3 = xs.apply_t(&b3_1, tr).apply_t(&b3_2, tr).apply_t(&b3_3, tr);
39        let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
40        Tensor::cat(&[b1, b2, b3, bpool], 1)
41    })
42}
43
44fn inception_b(p: nn::Path, c_in: i64) -> impl ModuleT {
45    let b1 = conv_bn(&p / "branch3x3", c_in, 384, 3, 0, 2);
46    let b2_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 64, 1, 0, 1);
47    let b2_2 = conv_bn(&p / "branch3x3dbl_2", 64, 96, 3, 1, 1);
48    let b2_3 = conv_bn(&p / "branch3x3dbl_3", 96, 96, 3, 0, 2);
49    nn::func_t(move |xs, tr| {
50        let b1 = xs.apply_t(&b1, tr);
51        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr);
52        let bpool = max_pool2d(xs, 3, 2);
53        Tensor::cat(&[b1, b2, bpool], 1)
54    })
55}
56
57fn inception_c(p: nn::Path, c_in: i64, c7: i64) -> impl ModuleT {
58    let b1 = conv_bn(&p / "branch1x1", c_in, 192, 1, 0, 1);
59
60    let b2_1 = conv_bn(&p / "branch7x7_1", c_in, c7, 1, 0, 1);
61    let b2_2 = conv_bn2(&p / "branch7x7_2", c7, c7, [1, 7], [0, 3]);
62    let b2_3 = conv_bn2(&p / "branch7x7_3", c7, 192, [7, 1], [3, 0]);
63
64    let b3_1 = conv_bn(&p / "branch7x7dbl_1", c_in, c7, 1, 0, 1);
65    let b3_2 = conv_bn2(&p / "branch7x7dbl_2", c7, c7, [7, 1], [3, 0]);
66    let b3_3 = conv_bn2(&p / "branch7x7dbl_3", c7, c7, [1, 7], [0, 3]);
67    let b3_4 = conv_bn2(&p / "branch7x7dbl_4", c7, c7, [7, 1], [3, 0]);
68    let b3_5 = conv_bn2(&p / "branch7x7dbl_5", c7, 192, [1, 7], [0, 3]);
69
70    let bpool = conv_bn(&p / "branch_pool", c_in, 192, 1, 0, 1);
71
72    nn::func_t(move |xs, tr| {
73        let b1 = xs.apply_t(&b1, tr);
74        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr);
75        let b3 = xs
76            .apply_t(&b3_1, tr)
77            .apply_t(&b3_2, tr)
78            .apply_t(&b3_3, tr)
79            .apply_t(&b3_4, tr)
80            .apply_t(&b3_5, tr);
81        let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
82        Tensor::cat(&[b1, b2, b3, bpool], 1)
83    })
84}
85
86fn inception_d(p: nn::Path, c_in: i64) -> impl ModuleT {
87    let b1_1 = conv_bn(&p / "branch3x3_1", c_in, 192, 1, 0, 1);
88    let b1_2 = conv_bn(&p / "branch3x3_2", 192, 320, 3, 0, 2);
89
90    let b2_1 = conv_bn(&p / "branch7x7x3_1", c_in, 192, 1, 0, 1);
91    let b2_2 = conv_bn2(&p / "branch7x7x3_2", 192, 192, [1, 7], [0, 3]);
92    let b2_3 = conv_bn2(&p / "branch7x7x3_3", 192, 192, [7, 1], [3, 0]);
93    let b2_4 = conv_bn(&p / "branch7x7x3_4", 192, 192, 3, 0, 2);
94
95    nn::func_t(move |xs, tr| {
96        let b1 = xs.apply_t(&b1_1, tr).apply_t(&b1_2, tr);
97        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr).apply_t(&b2_4, tr);
98        let bpool = max_pool2d(xs, 3, 2);
99        Tensor::cat(&[b1, b2, bpool], 1)
100    })
101}
102
103fn inception_e(p: nn::Path, c_in: i64) -> impl ModuleT {
104    let b1 = conv_bn(&p / "branch1x1", c_in, 320, 1, 0, 1);
105
106    let b2_1 = conv_bn(&p / "branch3x3_1", c_in, 384, 1, 0, 1);
107    let b2_2a = conv_bn2(&p / "branch3x3_2a", 384, 384, [1, 3], [0, 1]);
108    let b2_2b = conv_bn2(&p / "branch3x3_2b", 384, 384, [3, 1], [1, 0]);
109
110    let b3_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 448, 1, 0, 1);
111    let b3_2 = conv_bn(&p / "branch3x3dbl_2", 448, 384, 3, 1, 1);
112    let b3_3a = conv_bn2(&p / "branch3x3dbl_3a", 384, 384, [1, 3], [0, 1]);
113    let b3_3b = conv_bn2(&p / "branch3x3dbl_3b", 384, 384, [3, 1], [1, 0]);
114
115    let bpool = conv_bn(&p / "branch_pool", c_in, 192, 1, 0, 1);
116
117    nn::func_t(move |xs, tr| {
118        let b1 = xs.apply_t(&b1, tr);
119
120        let b2 = xs.apply_t(&b2_1, tr);
121        let b2 = Tensor::cat(&[b2.apply_t(&b2_2a, tr), b2.apply_t(&b2_2b, tr)], 1);
122
123        let b3 = xs.apply_t(&b3_1, tr).apply_t(&b3_2, tr);
124        let b3 = Tensor::cat(&[b3.apply_t(&b3_3a, tr), b3.apply_t(&b3_3b, tr)], 1);
125
126        let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
127
128        Tensor::cat(&[b1, b2, b3, bpool], 1)
129    })
130}
131
132pub fn v3(p: &nn::Path, nclasses: i64) -> impl ModuleT {
133    nn::seq_t()
134        .add(conv_bn(p / "Conv2d_1a_3x3", 3, 32, 3, 0, 2))
135        .add(conv_bn(p / "Conv2d_2a_3x3", 32, 32, 3, 0, 1))
136        .add(conv_bn(p / "Conv2d_2b_3x3", 32, 64, 3, 1, 1))
137        .add_fn(|xs| max_pool2d(&xs.relu(), 3, 2))
138        .add(conv_bn(p / "Conv2d_3b_1x1", 64, 80, 1, 0, 1))
139        .add(conv_bn(p / "Conv2d_4a_3x3", 80, 192, 3, 0, 1))
140        .add_fn(|xs| max_pool2d(&xs.relu(), 3, 2))
141        .add(inception_a(p / "Mixed_5b", 192, 32))
142        .add(inception_a(p / "Mixed_5c", 256, 64))
143        .add(inception_a(p / "Mixed_5d", 288, 64))
144        .add(inception_b(p / "Mixed_6a", 288))
145        .add(inception_c(p / "Mixed_6b", 768, 128))
146        .add(inception_c(p / "Mixed_6c", 768, 160))
147        .add(inception_c(p / "Mixed_6d", 768, 160))
148        .add(inception_c(p / "Mixed_6e", 768, 192))
149        .add(inception_d(p / "Mixed_7a", 768))
150        .add(inception_e(p / "Mixed_7b", 1280))
151        .add(inception_e(p / "Mixed_7c", 2048))
152        .add_fn_t(|xs, train| xs.adaptive_avg_pool2d([1, 1]).dropout(0.5, train).flat_view())
153        .add(nn::linear(p / "fc", 2048, nclasses, Default::default()))
154}