Skip to main content

svod_model/blocks/
stage.rs

1use svod_tensor::Tensor;
2
3use crate::state::{self, HasStateDict, StateDict, prefixed};
4
5use super::basic_block::{BasicBlock, BlockKind};
6use super::bottleneck::Bottleneck;
7use super::error::Result;
8
9#[derive(Clone)]
10pub enum Block {
11    Basic(BasicBlock),
12    Bottleneck(Bottleneck),
13}
14
15impl Block {
16    fn forward(&self, x: &Tensor) -> Result<Tensor> {
17        match self {
18            Block::Basic(b) => b.forward(x),
19            Block::Bottleneck(b) => b.forward(x),
20        }
21    }
22}
23
24impl HasStateDict for Block {
25    fn state_dict(&self, prefix: &str) -> StateDict {
26        match self {
27            Block::Basic(b) => b.state_dict(prefix),
28            Block::Bottleneck(b) => b.state_dict(prefix),
29        }
30    }
31
32    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
33        match self {
34            Block::Basic(b) => b.load_state_dict(sd, prefix),
35            Block::Bottleneck(b) => b.load_state_dict(sd, prefix),
36        }
37    }
38}
39
40#[derive(Clone)]
41pub struct ResidualStage {
42    pub blocks: Vec<Block>,
43}
44
45impl ResidualStage {
46    /// Construct a fresh stage. The first block may downsample (`stride`);
47    /// remaining blocks always have stride 1. Channel width follows the
48    /// canonical schedule: every block in the stage emits `planes * expansion`
49    /// channels, and the next block sees that as its `in_planes`.
50    pub fn empty(kind: BlockKind, in_planes: usize, planes: usize, num_blocks: usize, stride: usize) -> Self {
51        let expansion = kind.expansion();
52        let mut blocks = Vec::with_capacity(num_blocks);
53        let mut current_in = in_planes;
54        for i in 0..num_blocks {
55            let s = if i == 0 { stride } else { 1 };
56            let block = match kind {
57                BlockKind::Basic => Block::Basic(BasicBlock::empty(current_in, planes, s)),
58                BlockKind::Bottleneck => Block::Bottleneck(Bottleneck::empty(current_in, planes, s)),
59            };
60            blocks.push(block);
61            current_in = planes * expansion;
62        }
63        Self { blocks }
64    }
65
66    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
67        let mut x = x.clone();
68        for b in &self.blocks {
69            x = b.forward(&x)?;
70        }
71        Ok(x)
72    }
73}
74
75impl HasStateDict for ResidualStage {
76    fn state_dict(&self, prefix: &str) -> StateDict {
77        let mut sd = StateDict::new();
78        for (i, b) in self.blocks.iter().enumerate() {
79            sd.extend(b.state_dict(&prefixed(prefix, &i.to_string())));
80        }
81        sd
82    }
83
84    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
85        for (i, b) in self.blocks.iter_mut().enumerate() {
86            b.load_state_dict(sd, &prefixed(prefix, &i.to_string()))?;
87        }
88        Ok(())
89    }
90}