svod_model/blocks/
stage.rs1use svod_tensor::Tensor;
2
3use crate::state::{self, HasStateDict, StateDict, prefixed};
4
5use super::basic_block::{BasicBlock, BlockKind};
6use super::bottleneck::Bottleneck;
7use super::error::Result;
8
9#[derive(Clone)]
10pub enum Block {
11 Basic(BasicBlock),
12 Bottleneck(Bottleneck),
13}
14
15impl Block {
16 fn forward(&self, x: &Tensor) -> Result<Tensor> {
17 match self {
18 Block::Basic(b) => b.forward(x),
19 Block::Bottleneck(b) => b.forward(x),
20 }
21 }
22}
23
24impl HasStateDict for Block {
25 fn state_dict(&self, prefix: &str) -> StateDict {
26 match self {
27 Block::Basic(b) => b.state_dict(prefix),
28 Block::Bottleneck(b) => b.state_dict(prefix),
29 }
30 }
31
32 fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
33 match self {
34 Block::Basic(b) => b.load_state_dict(sd, prefix),
35 Block::Bottleneck(b) => b.load_state_dict(sd, prefix),
36 }
37 }
38}
39
40#[derive(Clone)]
41pub struct ResidualStage {
42 pub blocks: Vec<Block>,
43}
44
45impl ResidualStage {
46 pub fn empty(kind: BlockKind, in_planes: usize, planes: usize, num_blocks: usize, stride: usize) -> Self {
51 let expansion = kind.expansion();
52 let mut blocks = Vec::with_capacity(num_blocks);
53 let mut current_in = in_planes;
54 for i in 0..num_blocks {
55 let s = if i == 0 { stride } else { 1 };
56 let block = match kind {
57 BlockKind::Basic => Block::Basic(BasicBlock::empty(current_in, planes, s)),
58 BlockKind::Bottleneck => Block::Bottleneck(Bottleneck::empty(current_in, planes, s)),
59 };
60 blocks.push(block);
61 current_in = planes * expansion;
62 }
63 Self { blocks }
64 }
65
66 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
67 let mut x = x.clone();
68 for b in &self.blocks {
69 x = b.forward(&x)?;
70 }
71 Ok(x)
72 }
73}
74
75impl HasStateDict for ResidualStage {
76 fn state_dict(&self, prefix: &str) -> StateDict {
77 let mut sd = StateDict::new();
78 for (i, b) in self.blocks.iter().enumerate() {
79 sd.extend(b.state_dict(&prefixed(prefix, &i.to_string())));
80 }
81 sd
82 }
83
84 fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
85 for (i, b) in self.blocks.iter_mut().enumerate() {
86 b.load_state_dict(sd, &prefixed(prefix, &i.to_string()))?;
87 }
88 Ok(())
89 }
90}