Skip to main content

// svod_model/blocks/basic_block.rs

1use snafu::ResultExt;
2use svod_tensor::Tensor;
3
4use crate::state::{self, HasStateDict, StateDict, prefixed};
5
6use super::batchnorm::BatchNormWeights;
7use super::conv::Conv2dWeights;
8use super::error::{Result, TensorSnafu};
9
/// Selects the residual block flavour used by a network stage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockKind {
    /// Block made of two 3×3 convolutions; the channel count is not expanded.
    Basic,
    /// Bottleneck block (1×1 → 3×3 → 1×1) with a 4× channel expansion.
    Bottleneck,
}
18
19impl BlockKind {
20    pub fn expansion(self) -> usize {
21        match self {
22            BlockKind::Basic => 1,
23            BlockKind::Bottleneck => 4,
24        }
25    }
26}
27
28#[derive(Clone)]
29pub struct BasicBlock {
30    pub conv1: Conv2dWeights,
31    pub bn1: BatchNormWeights,
32    pub conv2: Conv2dWeights,
33    pub bn2: BatchNormWeights,
34    pub downsample: Option<(Conv2dWeights, BatchNormWeights)>,
35}
36
37impl BasicBlock {
38    pub fn empty(in_planes: usize, planes: usize, stride: usize) -> Self {
39        let expansion = BlockKind::Basic.expansion();
40        let downsample = if stride != 1 || in_planes != planes * expansion {
41            Some((
42                Conv2dWeights::empty(planes * expansion, in_planes, 1, stride, 0),
43                BatchNormWeights::empty(planes * expansion),
44            ))
45        } else {
46            None
47        };
48        Self {
49            conv1: Conv2dWeights::empty(planes, in_planes, 3, stride, 1),
50            bn1: BatchNormWeights::empty(planes),
51            conv2: Conv2dWeights::empty(planes, planes, 3, 1, 1),
52            bn2: BatchNormWeights::empty(planes),
53            downsample,
54        }
55    }
56
57    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
58        let out = self.bn1.forward(&self.conv1.forward(x)?)?;
59        let out = out.relu().context(TensorSnafu)?;
60        let out = self.bn2.forward(&self.conv2.forward(&out)?)?;
61        let shortcut = match &self.downsample {
62            Some((c, b)) => b.forward(&c.forward(x)?)?,
63            None => x.clone(),
64        };
65        out.try_add(&shortcut).context(TensorSnafu)?.relu().context(TensorSnafu)
66    }
67}
68
69impl HasStateDict for BasicBlock {
70    fn state_dict(&self, prefix: &str) -> StateDict {
71        let mut sd = self.conv1.state_dict(&prefixed(prefix, "conv1"));
72        sd.extend(self.bn1.state_dict(&prefixed(prefix, "bn1")));
73        sd.extend(self.conv2.state_dict(&prefixed(prefix, "conv2")));
74        sd.extend(self.bn2.state_dict(&prefixed(prefix, "bn2")));
75        if let Some((c, b)) = &self.downsample {
76            sd.extend(c.state_dict(&prefixed(prefix, "downsample.0")));
77            sd.extend(b.state_dict(&prefixed(prefix, "downsample.1")));
78        }
79        sd
80    }
81
82    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
83        self.conv1.load_state_dict(sd, &prefixed(prefix, "conv1"))?;
84        self.bn1.load_state_dict(sd, &prefixed(prefix, "bn1"))?;
85        self.conv2.load_state_dict(sd, &prefixed(prefix, "conv2"))?;
86        self.bn2.load_state_dict(sd, &prefixed(prefix, "bn2"))?;
87        if let Some((c, b)) = &mut self.downsample {
88            c.load_state_dict(sd, &prefixed(prefix, "downsample.0"))?;
89            b.load_state_dict(sd, &prefixed(prefix, "downsample.1"))?;
90        }
91        Ok(())
92    }
93}