svod_model/blocks/
basic_block.rs1use snafu::ResultExt;
2use svod_tensor::Tensor;
3
4use crate::state::{self, HasStateDict, StateDict, prefixed};
5
6use super::batchnorm::BatchNormWeights;
7use super::conv::Conv2dWeights;
8use super::error::{Result, TensorSnafu};
9
/// Selects which block variant a layer is built from.
///
/// Every variant has a fixed expansion factor, available through
/// [`BlockKind::expansion`].
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum BlockKind {
    /// Variant with an expansion factor of 1.
    Basic,
    /// Variant with an expansion factor of 4.
    Bottleneck,
}

impl BlockKind {
    /// Returns the factor by which a block's `planes` is multiplied when its
    /// output is sized (`planes * expansion`).
    ///
    /// This is a `const fn`, so the factor can also be used in constant
    /// expressions.
    #[must_use]
    pub const fn expansion(self) -> usize {
        match self {
            Self::Basic => 1,
            Self::Bottleneck => 4,
        }
    }
}
27
28#[derive(Clone)]
29pub struct BasicBlock {
30 pub conv1: Conv2dWeights,
31 pub bn1: BatchNormWeights,
32 pub conv2: Conv2dWeights,
33 pub bn2: BatchNormWeights,
34 pub downsample: Option<(Conv2dWeights, BatchNormWeights)>,
35}
36
37impl BasicBlock {
38 pub fn empty(in_planes: usize, planes: usize, stride: usize) -> Self {
39 let expansion = BlockKind::Basic.expansion();
40 let downsample = if stride != 1 || in_planes != planes * expansion {
41 Some((
42 Conv2dWeights::empty(planes * expansion, in_planes, 1, stride, 0),
43 BatchNormWeights::empty(planes * expansion),
44 ))
45 } else {
46 None
47 };
48 Self {
49 conv1: Conv2dWeights::empty(planes, in_planes, 3, stride, 1),
50 bn1: BatchNormWeights::empty(planes),
51 conv2: Conv2dWeights::empty(planes, planes, 3, 1, 1),
52 bn2: BatchNormWeights::empty(planes),
53 downsample,
54 }
55 }
56
57 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
58 let out = self.bn1.forward(&self.conv1.forward(x)?)?;
59 let out = out.relu().context(TensorSnafu)?;
60 let out = self.bn2.forward(&self.conv2.forward(&out)?)?;
61 let shortcut = match &self.downsample {
62 Some((c, b)) => b.forward(&c.forward(x)?)?,
63 None => x.clone(),
64 };
65 out.try_add(&shortcut).context(TensorSnafu)?.relu().context(TensorSnafu)
66 }
67}
68
69impl HasStateDict for BasicBlock {
70 fn state_dict(&self, prefix: &str) -> StateDict {
71 let mut sd = self.conv1.state_dict(&prefixed(prefix, "conv1"));
72 sd.extend(self.bn1.state_dict(&prefixed(prefix, "bn1")));
73 sd.extend(self.conv2.state_dict(&prefixed(prefix, "conv2")));
74 sd.extend(self.bn2.state_dict(&prefixed(prefix, "bn2")));
75 if let Some((c, b)) = &self.downsample {
76 sd.extend(c.state_dict(&prefixed(prefix, "downsample.0")));
77 sd.extend(b.state_dict(&prefixed(prefix, "downsample.1")));
78 }
79 sd
80 }
81
82 fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
83 self.conv1.load_state_dict(sd, &prefixed(prefix, "conv1"))?;
84 self.bn1.load_state_dict(sd, &prefixed(prefix, "bn1"))?;
85 self.conv2.load_state_dict(sd, &prefixed(prefix, "conv2"))?;
86 self.bn2.load_state_dict(sd, &prefixed(prefix, "bn2"))?;
87 if let Some((c, b)) = &mut self.downsample {
88 c.load_state_dict(sd, &prefixed(prefix, "downsample.0"))?;
89 b.load_state_dict(sd, &prefixed(prefix, "downsample.1"))?;
90 }
91 Ok(())
92 }
93}