// svod_model/resnet/config.rs: ResNet depth, output-mode, and model configuration types.

1use crate::blocks::BlockKind;
2
3/// Canonical ResNet depths. The depth selects both the block type and the
4/// per-stage block count schedule used by the original paper.
5#[derive(Copy, Clone, Debug, Eq, PartialEq)]
6pub enum ResNetDepth {
7    R18,
8    R34,
9    R50,
10    R101,
11    R152,
12}
13
14impl ResNetDepth {
15    /// Per-stage block count `[stage1, stage2, stage3, stage4]`.
16    pub fn layers(self) -> [usize; 4] {
17        match self {
18            ResNetDepth::R18 => [2, 2, 2, 2],
19            ResNetDepth::R34 => [3, 4, 6, 3],
20            ResNetDepth::R50 => [3, 4, 6, 3],
21            ResNetDepth::R101 => [3, 4, 23, 3],
22            ResNetDepth::R152 => [3, 8, 36, 3],
23        }
24    }
25
26    pub fn block(self) -> BlockKind {
27        match self {
28            ResNetDepth::R18 | ResNetDepth::R34 => BlockKind::Basic,
29            ResNetDepth::R50 | ResNetDepth::R101 | ResNetDepth::R152 => BlockKind::Bottleneck,
30        }
31    }
32
33    pub fn expansion(self) -> usize {
34        self.block().expansion()
35    }
36}
37
/// Which forward tail the model executes. Switch with [`ResNet::with_output`]
/// or set at construction time.
///
/// Derives `Eq`/`PartialEq`/`Hash`, like [`ResNetDepth`], so callers can
/// compare modes with `==` and use them as map/set keys.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum OutputMode {
    /// Add the FC head; forward returns logits `[B, num_classes]`. The FC
    /// `weight` / `bias` tensors are loaded from `fc.weight` / `fc.bias`.
    Classification { num_classes: usize },
    /// Stop after stage 4; forward returns the final feature map
    /// `[B, 512 * expansion, H/32, W/32]`. The FC weights are not loaded.
    Features,
}
49
50#[derive(Clone, Debug)]
51pub struct ResNetConfig {
52    pub depth: ResNetDepth,
53    pub output: OutputMode,
54    /// Upper bound on the symbolic `b` variable exposed by the JIT wrapper.
55    /// The prepared plan's image buffer is allocated to `max_batch_size`; the
56    /// per-call `execute_with_vars(&[("b", actual)])` shrinks the batch dim to
57    /// the live size.
58    pub max_batch_size: usize,
59}
60
61impl ResNetConfig {
62    pub fn new(depth: ResNetDepth, output: OutputMode) -> Self {
63        Self { depth, output, max_batch_size: 1 }
64    }
65
66    pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
67        self.max_batch_size = max_batch_size;
68        self
69    }
70}