tch_plus/nn/
group_norm.rs

//! A group-normalization layer.
//! Group Normalization <https://arxiv.org/abs/1803.08494>
3use crate::Tensor;
4use std::borrow::Borrow;
5
/// Group-normalization config.
///
/// The `Default` implementation sets `cudnn_enabled = true`, `eps = 1e-5`,
/// `affine = true`, a constant-1 weight initializer and a constant-0 bias initializer.
#[derive(Debug, Clone, Copy)]
pub struct GroupNormConfig {
    /// Forwarded unchanged as the `cudnn_enabled` flag of `Tensor::group_norm`.
    pub cudnn_enabled: bool,
    /// Forwarded unchanged as `eps` to `Tensor::group_norm`
    /// (presumably the numerical-stability term added to the variance — confirm in libtorch docs).
    pub eps: f64,
    /// When true, `group_norm` creates learnable `"weight"` and `"bias"` variables;
    /// when false, the layer has no affine parameters.
    pub affine: bool,
    /// Initializer for the `"weight"` variable; only used when `affine` is true.
    pub ws_init: super::Init,
    /// Initializer for the `"bias"` variable; only used when `affine` is true.
    pub bs_init: super::Init,
}
15
16impl Default for GroupNormConfig {
17    fn default() -> Self {
18        GroupNormConfig {
19            cudnn_enabled: true,
20            eps: 1e-5,
21            affine: true,
22            ws_init: super::Init::Const(1.),
23            bs_init: super::Init::Const(0.),
24        }
25    }
26}
27
/// A group-normalization layer.
///
/// Built by `group_norm`. Its `forward` calls `Tensor::group_norm` with the stored
/// group count, the optional weight/bias, and the config's `eps` and `cudnn_enabled`.
#[derive(Debug)]
pub struct GroupNorm {
    /// Configuration the layer was built with; `forward` reads `eps` and `cudnn_enabled`.
    config: GroupNormConfig,
    /// Per-channel weight of shape `[num_channels]`; `group_norm` sets it only when
    /// `config.affine` is true, otherwise `None`.
    pub ws: Option<Tensor>,
    /// Per-channel bias of shape `[num_channels]`; `group_norm` sets it only when
    /// `config.affine` is true, otherwise `None`.
    pub bs: Option<Tensor>,
    /// Number of groups passed to `Tensor::group_norm`.
    pub num_groups: i64,
    /// Channel count used to size `ws` and `bs`.
    pub num_channels: i64,
}
37
38pub fn group_norm<'a, T: Borrow<super::Path<'a>>>(
39    vs: T,
40    num_groups: i64,
41    num_channels: i64,
42    config: GroupNormConfig,
43) -> GroupNorm {
44    let vs = vs.borrow();
45    let (ws, bs) = if config.affine {
46        let ws = vs.var("weight", &[num_channels], config.ws_init);
47        let bs = vs.var("bias", &[num_channels], config.bs_init);
48        (Some(ws), Some(bs))
49    } else {
50        (None, None)
51    };
52    GroupNorm { config, ws, bs, num_groups, num_channels }
53}
54
55impl super::module::Module for GroupNorm {
56    fn forward(&self, xs: &Tensor) -> Tensor {
57        Tensor::group_norm(
58            xs,
59            self.num_groups,
60            self.ws.as_ref(),
61            self.bs.as_ref(),
62            self.config.eps,
63            self.config.cudnn_enabled,
64        )
65    }
66}