tch_plus/nn/
group_norm.rs1use crate::Tensor;
4use std::borrow::Borrow;
5
6#[derive(Debug, Clone, Copy)]
8pub struct GroupNormConfig {
9 pub cudnn_enabled: bool,
10 pub eps: f64,
11 pub affine: bool,
12 pub ws_init: super::Init,
13 pub bs_init: super::Init,
14}
15
16impl Default for GroupNormConfig {
17 fn default() -> Self {
18 GroupNormConfig {
19 cudnn_enabled: true,
20 eps: 1e-5,
21 affine: true,
22 ws_init: super::Init::Const(1.),
23 bs_init: super::Init::Const(0.),
24 }
25 }
26}
27
28#[derive(Debug)]
30pub struct GroupNorm {
31 config: GroupNormConfig,
32 pub ws: Option<Tensor>,
33 pub bs: Option<Tensor>,
34 pub num_groups: i64,
35 pub num_channels: i64,
36}
37
38pub fn group_norm<'a, T: Borrow<super::Path<'a>>>(
39 vs: T,
40 num_groups: i64,
41 num_channels: i64,
42 config: GroupNormConfig,
43) -> GroupNorm {
44 let vs = vs.borrow();
45 let (ws, bs) = if config.affine {
46 let ws = vs.var("weight", &[num_channels], config.ws_init);
47 let bs = vs.var("bias", &[num_channels], config.bs_init);
48 (Some(ws), Some(bs))
49 } else {
50 (None, None)
51 };
52 GroupNorm { config, ws, bs, num_groups, num_channels }
53}
54
55impl super::module::Module for GroupNorm {
56 fn forward(&self, xs: &Tensor) -> Tensor {
57 Tensor::group_norm(
58 xs,
59 self.num_groups,
60 self.ws.as_ref(),
61 self.bs.as_ref(),
62 self.config.eps,
63 self.config.cudnn_enabled,
64 )
65 }
66}