1use crate::{nn, nn::Conv2D, nn::FuncT, nn::ModuleT};
6
7fn conv2d(p: nn::Path, c_in: i64, c_out: i64, ksize: i64, padding: i64, stride: i64) -> Conv2D {
8 let conv2d_cfg = nn::ConvConfig { stride, padding, bias: false, ..Default::default() };
9 nn::conv2d(p, c_in, c_out, ksize, conv2d_cfg)
10}
11
12fn downsample(p: nn::Path, c_in: i64, c_out: i64, stride: i64) -> impl ModuleT {
13 if stride != 1 || c_in != c_out {
14 nn::seq_t().add(conv2d(&p / "0", c_in, c_out, 1, 0, stride)).add(nn::batch_norm2d(
15 &p / "1",
16 c_out,
17 Default::default(),
18 ))
19 } else {
20 nn::seq_t()
21 }
22}
23
24fn basic_block(p: nn::Path, c_in: i64, c_out: i64, stride: i64) -> impl ModuleT {
25 let conv1 = conv2d(&p / "conv1", c_in, c_out, 3, 1, stride);
26 let bn1 = nn::batch_norm2d(&p / "bn1", c_out, Default::default());
27 let conv2 = conv2d(&p / "conv2", c_out, c_out, 3, 1, 1);
28 let bn2 = nn::batch_norm2d(&p / "bn2", c_out, Default::default());
29 let downsample = downsample(&p / "downsample", c_in, c_out, stride);
30 nn::func_t(move |xs, train| {
31 let ys = xs.apply(&conv1).apply_t(&bn1, train).relu().apply(&conv2).apply_t(&bn2, train);
32 (xs.apply_t(&downsample, train) + ys).relu()
33 })
34}
35
36fn basic_layer(p: nn::Path, c_in: i64, c_out: i64, stride: i64, cnt: i64) -> impl ModuleT {
37 let mut layer = nn::seq_t().add(basic_block(&p / "0", c_in, c_out, stride));
38 for block_index in 1..cnt {
39 layer = layer.add(basic_block(&p / &block_index.to_string(), c_out, c_out, 1))
40 }
41 layer
42}
43
44fn resnet(
45 p: &nn::Path,
46 nclasses: Option<i64>,
47 c1: i64,
48 c2: i64,
49 c3: i64,
50 c4: i64,
51) -> FuncT<'static> {
52 let conv1 = conv2d(p / "conv1", 3, 64, 7, 3, 2);
53 let bn1 = nn::batch_norm2d(p / "bn1", 64, Default::default());
54 let layer1 = basic_layer(p / "layer1", 64, 64, 1, c1);
55 let layer2 = basic_layer(p / "layer2", 64, 128, 2, c2);
56 let layer3 = basic_layer(p / "layer3", 128, 256, 2, c3);
57 let layer4 = basic_layer(p / "layer4", 256, 512, 2, c4);
58 let fc = nclasses.map(|n| nn::linear(p / "fc", 512, n, Default::default()));
59 nn::func_t(move |xs, train| {
60 xs.apply(&conv1)
61 .apply_t(&bn1, train)
62 .relu()
63 .max_pool2d([3, 3], [2, 2], [1, 1], [1, 1], false)
64 .apply_t(&layer1, train)
65 .apply_t(&layer2, train)
66 .apply_t(&layer3, train)
67 .apply_t(&layer4, train)
68 .adaptive_avg_pool2d([1, 1])
69 .flat_view()
70 .apply_opt(&fc)
71 })
72}
73
74pub fn resnet18(p: &nn::Path, num_classes: i64) -> FuncT<'static> {
79 resnet(p, Some(num_classes), 2, 2, 2, 2)
80}
81
82pub fn resnet18_no_final_layer(p: &nn::Path) -> FuncT<'static> {
83 resnet(p, None, 2, 2, 2, 2)
84}
85
86pub fn resnet34(p: &nn::Path, num_classes: i64) -> FuncT<'static> {
91 resnet(p, Some(num_classes), 3, 4, 6, 3)
92}
93
94pub fn resnet34_no_final_layer(p: &nn::Path) -> FuncT<'static> {
95 resnet(p, None, 3, 4, 6, 3)
96}
97
98fn bottleneck_block(p: nn::Path, c_in: i64, c_out: i64, stride: i64, e: i64) -> impl ModuleT {
101 let e_dim = e * c_out;
102 let conv1 = conv2d(&p / "conv1", c_in, c_out, 1, 0, 1);
103 let bn1 = nn::batch_norm2d(&p / "bn1", c_out, Default::default());
104 let conv2 = conv2d(&p / "conv2", c_out, c_out, 3, 1, stride);
105 let bn2 = nn::batch_norm2d(&p / "bn2", c_out, Default::default());
106 let conv3 = conv2d(&p / "conv3", c_out, e_dim, 1, 0, 1);
107 let bn3 = nn::batch_norm2d(&p / "bn3", e_dim, Default::default());
108 let downsample = downsample(&p / "downsample", c_in, e_dim, stride);
109 nn::func_t(move |xs, train| {
110 let ys = xs
111 .apply(&conv1)
112 .apply_t(&bn1, train)
113 .relu()
114 .apply(&conv2)
115 .apply_t(&bn2, train)
116 .relu()
117 .apply(&conv3)
118 .apply_t(&bn3, train);
119 (xs.apply_t(&downsample, train) + ys).relu()
120 })
121}
122
123fn bottleneck_layer(p: nn::Path, c_in: i64, c_out: i64, stride: i64, cnt: i64) -> impl ModuleT {
124 let mut layer = nn::seq_t().add(bottleneck_block(&p / "0", c_in, c_out, stride, 4));
125 for block_index in 1..cnt {
126 layer = layer.add(bottleneck_block(&p / &block_index.to_string(), 4 * c_out, c_out, 1, 4))
127 }
128 layer
129}
130
131fn bottleneck_resnet(
132 p: &nn::Path,
133 nclasses: Option<i64>,
134 c1: i64,
135 c2: i64,
136 c3: i64,
137 c4: i64,
138) -> impl ModuleT {
139 let conv1 = conv2d(p / "conv1", 3, 64, 7, 3, 2);
140 let bn1 = nn::batch_norm2d(p / "bn1", 64, Default::default());
141 let layer1 = bottleneck_layer(p / "layer1", 64, 64, 1, c1);
142 let layer2 = bottleneck_layer(p / "layer2", 4 * 64, 128, 2, c2);
143 let layer3 = bottleneck_layer(p / "layer3", 4 * 128, 256, 2, c3);
144 let layer4 = bottleneck_layer(p / "layer4", 4 * 256, 512, 2, c4);
145 let fc = nclasses.map(|n| nn::linear(p / "fc", 4 * 512, n, Default::default()));
146 nn::func_t(move |xs, train| {
147 xs.apply(&conv1)
148 .apply_t(&bn1, train)
149 .relu()
150 .max_pool2d([3, 3], [2, 2], [1, 1], [1, 1], false)
151 .apply_t(&layer1, train)
152 .apply_t(&layer2, train)
153 .apply_t(&layer3, train)
154 .apply_t(&layer4, train)
155 .adaptive_avg_pool2d([1, 1])
156 .flat_view()
157 .apply_opt(&fc)
158 })
159}
160
161pub fn resnet50(p: &nn::Path, num_classes: i64) -> impl ModuleT {
162 bottleneck_resnet(p, Some(num_classes), 3, 4, 6, 3)
163}
164
165pub fn resnet50_no_final_layer(p: &nn::Path) -> impl ModuleT {
166 bottleneck_resnet(p, None, 3, 4, 6, 3)
167}
168
169pub fn resnet101(p: &nn::Path, num_classes: i64) -> impl ModuleT {
170 bottleneck_resnet(p, Some(num_classes), 3, 4, 23, 3)
171}
172
173pub fn resnet101_no_final_layer(p: &nn::Path) -> impl ModuleT {
174 bottleneck_resnet(p, None, 3, 4, 23, 3)
175}
176
177pub fn resnet152(p: &nn::Path, num_classes: i64) -> impl ModuleT {
178 bottleneck_resnet(p, Some(num_classes), 3, 8, 36, 3)
179}
180
181pub fn resnet150_no_final_layer(p: &nn::Path) -> impl ModuleT {
182 bottleneck_resnet(p, None, 3, 8, 36, 3)
183}