svod_model/blocks/
bottleneck.rs1use snafu::ResultExt;
2use svod_tensor::Tensor;
3
4use crate::state::{self, HasStateDict, StateDict, prefixed};
5
6use super::basic_block::BlockKind;
7use super::batchnorm::BatchNormWeights;
8use super::conv::Conv2dWeights;
9use super::error::{Result, TensorSnafu};
10
11#[derive(Clone)]
12pub struct Bottleneck {
13 pub conv1: Conv2dWeights,
14 pub bn1: BatchNormWeights,
15 pub conv2: Conv2dWeights,
16 pub bn2: BatchNormWeights,
17 pub conv3: Conv2dWeights,
18 pub bn3: BatchNormWeights,
19 pub downsample: Option<(Conv2dWeights, BatchNormWeights)>,
20}
21
22impl Bottleneck {
23 pub fn empty(in_planes: usize, planes: usize, stride: usize) -> Self {
24 let expansion = BlockKind::Bottleneck.expansion();
25 let out_ch = planes * expansion;
26 let downsample = if stride != 1 || in_planes != out_ch {
27 Some((Conv2dWeights::empty(out_ch, in_planes, 1, stride, 0), BatchNormWeights::empty(out_ch)))
28 } else {
29 None
30 };
31 Self {
32 conv1: Conv2dWeights::empty(planes, in_planes, 1, 1, 0),
33 bn1: BatchNormWeights::empty(planes),
34 conv2: Conv2dWeights::empty(planes, planes, 3, stride, 1),
35 bn2: BatchNormWeights::empty(planes),
36 conv3: Conv2dWeights::empty(out_ch, planes, 1, 1, 0),
37 bn3: BatchNormWeights::empty(out_ch),
38 downsample,
39 }
40 }
41
42 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
43 let out = self.bn1.forward(&self.conv1.forward(x)?)?.relu().context(TensorSnafu)?;
44 let out = self.bn2.forward(&self.conv2.forward(&out)?)?.relu().context(TensorSnafu)?;
45 let out = self.bn3.forward(&self.conv3.forward(&out)?)?;
46 let shortcut = match &self.downsample {
47 Some((c, b)) => b.forward(&c.forward(x)?)?,
48 None => x.clone(),
49 };
50 out.try_add(&shortcut).context(TensorSnafu)?.relu().context(TensorSnafu)
51 }
52}
53
54impl HasStateDict for Bottleneck {
55 fn state_dict(&self, prefix: &str) -> StateDict {
56 let mut sd = self.conv1.state_dict(&prefixed(prefix, "conv1"));
57 sd.extend(self.bn1.state_dict(&prefixed(prefix, "bn1")));
58 sd.extend(self.conv2.state_dict(&prefixed(prefix, "conv2")));
59 sd.extend(self.bn2.state_dict(&prefixed(prefix, "bn2")));
60 sd.extend(self.conv3.state_dict(&prefixed(prefix, "conv3")));
61 sd.extend(self.bn3.state_dict(&prefixed(prefix, "bn3")));
62 if let Some((c, b)) = &self.downsample {
63 sd.extend(c.state_dict(&prefixed(prefix, "downsample.0")));
64 sd.extend(b.state_dict(&prefixed(prefix, "downsample.1")));
65 }
66 sd
67 }
68
69 fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
70 self.conv1.load_state_dict(sd, &prefixed(prefix, "conv1"))?;
71 self.bn1.load_state_dict(sd, &prefixed(prefix, "bn1"))?;
72 self.conv2.load_state_dict(sd, &prefixed(prefix, "conv2"))?;
73 self.bn2.load_state_dict(sd, &prefixed(prefix, "bn2"))?;
74 self.conv3.load_state_dict(sd, &prefixed(prefix, "conv3"))?;
75 self.bn3.load_state_dict(sd, &prefixed(prefix, "bn3"))?;
76 if let Some((c, b)) = &mut self.downsample {
77 c.load_state_dict(sd, &prefixed(prefix, "downsample.0"))?;
78 b.load_state_dict(sd, &prefixed(prefix, "downsample.1"))?;
79 }
80 Ok(())
81 }
82}