// tch_plus/vision/resnet.rs

//! ResNet implementation.
//!
//! See "Deep Residual Learning for Image Recognition" He et al. 2015
//! <https://arxiv.org/abs/1512.03385>
use crate::{nn, nn::Conv2D, nn::FuncT, nn::ModuleT};

7fn conv2d(p: nn::Path, c_in: i64, c_out: i64, ksize: i64, padding: i64, stride: i64) -> Conv2D {
8    let conv2d_cfg = nn::ConvConfig { stride, padding, bias: false, ..Default::default() };
9    nn::conv2d(p, c_in, c_out, ksize, conv2d_cfg)
10}
11
12fn downsample(p: nn::Path, c_in: i64, c_out: i64, stride: i64) -> impl ModuleT {
13    if stride != 1 || c_in != c_out {
14        nn::seq_t().add(conv2d(&p / "0", c_in, c_out, 1, 0, stride)).add(nn::batch_norm2d(
15            &p / "1",
16            c_out,
17            Default::default(),
18        ))
19    } else {
20        nn::seq_t()
21    }
22}
23
24fn basic_block(p: nn::Path, c_in: i64, c_out: i64, stride: i64) -> impl ModuleT {
25    let conv1 = conv2d(&p / "conv1", c_in, c_out, 3, 1, stride);
26    let bn1 = nn::batch_norm2d(&p / "bn1", c_out, Default::default());
27    let conv2 = conv2d(&p / "conv2", c_out, c_out, 3, 1, 1);
28    let bn2 = nn::batch_norm2d(&p / "bn2", c_out, Default::default());
29    let downsample = downsample(&p / "downsample", c_in, c_out, stride);
30    nn::func_t(move |xs, train| {
31        let ys = xs.apply(&conv1).apply_t(&bn1, train).relu().apply(&conv2).apply_t(&bn2, train);
32        (xs.apply_t(&downsample, train) + ys).relu()
33    })
34}
35
36fn basic_layer(p: nn::Path, c_in: i64, c_out: i64, stride: i64, cnt: i64) -> impl ModuleT {
37    let mut layer = nn::seq_t().add(basic_block(&p / "0", c_in, c_out, stride));
38    for block_index in 1..cnt {
39        layer = layer.add(basic_block(&p / &block_index.to_string(), c_out, c_out, 1))
40    }
41    layer
42}
43
44fn resnet(
45    p: &nn::Path,
46    nclasses: Option<i64>,
47    c1: i64,
48    c2: i64,
49    c3: i64,
50    c4: i64,
51) -> FuncT<'static> {
52    let conv1 = conv2d(p / "conv1", 3, 64, 7, 3, 2);
53    let bn1 = nn::batch_norm2d(p / "bn1", 64, Default::default());
54    let layer1 = basic_layer(p / "layer1", 64, 64, 1, c1);
55    let layer2 = basic_layer(p / "layer2", 64, 128, 2, c2);
56    let layer3 = basic_layer(p / "layer3", 128, 256, 2, c3);
57    let layer4 = basic_layer(p / "layer4", 256, 512, 2, c4);
58    let fc = nclasses.map(|n| nn::linear(p / "fc", 512, n, Default::default()));
59    nn::func_t(move |xs, train| {
60        xs.apply(&conv1)
61            .apply_t(&bn1, train)
62            .relu()
63            .max_pool2d([3, 3], [2, 2], [1, 1], [1, 1], false)
64            .apply_t(&layer1, train)
65            .apply_t(&layer2, train)
66            .apply_t(&layer3, train)
67            .apply_t(&layer4, train)
68            .adaptive_avg_pool2d([1, 1])
69            .flat_view()
70            .apply_opt(&fc)
71    })
72}
73
74/// Creates a ResNet-18 model.
75///
76/// Pre-trained weights can be downloaded at the following link:
77/// <https://github.com/LaurentMazare/tch-rs/releases/download/untagged-eb220e5c19f9bb250bd1/resnet18.ot>
78pub fn resnet18(p: &nn::Path, num_classes: i64) -> FuncT<'static> {
79    resnet(p, Some(num_classes), 2, 2, 2, 2)
80}
81
82pub fn resnet18_no_final_layer(p: &nn::Path) -> FuncT<'static> {
83    resnet(p, None, 2, 2, 2, 2)
84}
85
86/// Creates a ResNet-34 model.
87///
88/// Pre-trained weights can be downloaded at the following link:
89/// <https://github.com/LaurentMazare/tch-rs/releases/download/untagged-eb220e5c19f9bb250bd1/resnet34.ot>
90pub fn resnet34(p: &nn::Path, num_classes: i64) -> FuncT<'static> {
91    resnet(p, Some(num_classes), 3, 4, 6, 3)
92}
93
94pub fn resnet34_no_final_layer(p: &nn::Path) -> FuncT<'static> {
95    resnet(p, None, 3, 4, 6, 3)
96}
97
// Bottleneck versions for ResNet 50, 101, and 152.

100fn bottleneck_block(p: nn::Path, c_in: i64, c_out: i64, stride: i64, e: i64) -> impl ModuleT {
101    let e_dim = e * c_out;
102    let conv1 = conv2d(&p / "conv1", c_in, c_out, 1, 0, 1);
103    let bn1 = nn::batch_norm2d(&p / "bn1", c_out, Default::default());
104    let conv2 = conv2d(&p / "conv2", c_out, c_out, 3, 1, stride);
105    let bn2 = nn::batch_norm2d(&p / "bn2", c_out, Default::default());
106    let conv3 = conv2d(&p / "conv3", c_out, e_dim, 1, 0, 1);
107    let bn3 = nn::batch_norm2d(&p / "bn3", e_dim, Default::default());
108    let downsample = downsample(&p / "downsample", c_in, e_dim, stride);
109    nn::func_t(move |xs, train| {
110        let ys = xs
111            .apply(&conv1)
112            .apply_t(&bn1, train)
113            .relu()
114            .apply(&conv2)
115            .apply_t(&bn2, train)
116            .relu()
117            .apply(&conv3)
118            .apply_t(&bn3, train);
119        (xs.apply_t(&downsample, train) + ys).relu()
120    })
121}
122
123fn bottleneck_layer(p: nn::Path, c_in: i64, c_out: i64, stride: i64, cnt: i64) -> impl ModuleT {
124    let mut layer = nn::seq_t().add(bottleneck_block(&p / "0", c_in, c_out, stride, 4));
125    for block_index in 1..cnt {
126        layer = layer.add(bottleneck_block(&p / &block_index.to_string(), 4 * c_out, c_out, 1, 4))
127    }
128    layer
129}
130
131fn bottleneck_resnet(
132    p: &nn::Path,
133    nclasses: Option<i64>,
134    c1: i64,
135    c2: i64,
136    c3: i64,
137    c4: i64,
138) -> impl ModuleT {
139    let conv1 = conv2d(p / "conv1", 3, 64, 7, 3, 2);
140    let bn1 = nn::batch_norm2d(p / "bn1", 64, Default::default());
141    let layer1 = bottleneck_layer(p / "layer1", 64, 64, 1, c1);
142    let layer2 = bottleneck_layer(p / "layer2", 4 * 64, 128, 2, c2);
143    let layer3 = bottleneck_layer(p / "layer3", 4 * 128, 256, 2, c3);
144    let layer4 = bottleneck_layer(p / "layer4", 4 * 256, 512, 2, c4);
145    let fc = nclasses.map(|n| nn::linear(p / "fc", 4 * 512, n, Default::default()));
146    nn::func_t(move |xs, train| {
147        xs.apply(&conv1)
148            .apply_t(&bn1, train)
149            .relu()
150            .max_pool2d([3, 3], [2, 2], [1, 1], [1, 1], false)
151            .apply_t(&layer1, train)
152            .apply_t(&layer2, train)
153            .apply_t(&layer3, train)
154            .apply_t(&layer4, train)
155            .adaptive_avg_pool2d([1, 1])
156            .flat_view()
157            .apply_opt(&fc)
158    })
159}
160
161pub fn resnet50(p: &nn::Path, num_classes: i64) -> impl ModuleT {
162    bottleneck_resnet(p, Some(num_classes), 3, 4, 6, 3)
163}
164
165pub fn resnet50_no_final_layer(p: &nn::Path) -> impl ModuleT {
166    bottleneck_resnet(p, None, 3, 4, 6, 3)
167}
168
169pub fn resnet101(p: &nn::Path, num_classes: i64) -> impl ModuleT {
170    bottleneck_resnet(p, Some(num_classes), 3, 4, 23, 3)
171}
172
173pub fn resnet101_no_final_layer(p: &nn::Path) -> impl ModuleT {
174    bottleneck_resnet(p, None, 3, 4, 23, 3)
175}
176
177pub fn resnet152(p: &nn::Path, num_classes: i64) -> impl ModuleT {
178    bottleneck_resnet(p, Some(num_classes), 3, 8, 36, 3)
179}
180
181pub fn resnet150_no_final_layer(p: &nn::Path) -> impl ModuleT {
182    bottleneck_resnet(p, None, 3, 8, 36, 3)
183}