svod_model/resnet/
config.rs1use crate::blocks::BlockKind;
2
/// Standard ResNet variant, identified by its conventional depth
/// (ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152).
///
/// `Hash` is derived alongside `Eq` so a depth can be used as a map/set key.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ResNetDepth {
    /// ResNet-18.
    R18,
    /// ResNet-34.
    R34,
    /// ResNet-50.
    R50,
    /// ResNet-101.
    R101,
    /// ResNet-152.
    R152,
}
13
14impl ResNetDepth {
15 pub fn layers(self) -> [usize; 4] {
17 match self {
18 ResNetDepth::R18 => [2, 2, 2, 2],
19 ResNetDepth::R34 => [3, 4, 6, 3],
20 ResNetDepth::R50 => [3, 4, 6, 3],
21 ResNetDepth::R101 => [3, 4, 23, 3],
22 ResNetDepth::R152 => [3, 8, 36, 3],
23 }
24 }
25
26 pub fn block(self) -> BlockKind {
27 match self {
28 ResNetDepth::R18 | ResNetDepth::R34 => BlockKind::Basic,
29 ResNetDepth::R50 | ResNetDepth::R101 | ResNetDepth::R152 => BlockKind::Bottleneck,
30 }
31 }
32
33 pub fn expansion(self) -> usize {
34 self.block().expansion()
35 }
36}
37
/// Selects what the configured network produces.
///
/// `PartialEq`/`Eq`/`Hash` are derived so modes can be compared and used as keys,
/// matching what `ResNetDepth` already supports.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum OutputMode {
    /// Classification output; `num_classes` is the number of target classes.
    // NOTE(review): presumably this means a classification head of width
    // `num_classes` — confirm against the model builder. A value of 0 is not
    // rejected by this type.
    Classification { num_classes: usize },
    /// Feature output rather than class scores.
    // NOTE(review): presumably the backbone output without a classification head —
    // confirm against the model builder.
    Features,
}
49
50#[derive(Clone, Debug)]
51pub struct ResNetConfig {
52 pub depth: ResNetDepth,
53 pub output: OutputMode,
54 pub max_batch_size: usize,
59}
60
61impl ResNetConfig {
62 pub fn new(depth: ResNetDepth, output: OutputMode) -> Self {
63 Self { depth, output, max_batch_size: 1 }
64 }
65
66 pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
67 self.max_batch_size = max_batch_size;
68 self
69 }
70}