1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/// Strategy used to choose the initial value of a network parameter.
///
/// The variant is interpreted by `calc_initialization`, which yields one
/// value per call.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum InitializationType {
    /// Random value in the range -1 to 1.
    Random,
    /// Random value in the range -1 to 1, scaled by
    /// `1 / sqrt(prev_layer_size)` so that larger previous layers produce
    /// smaller initial values.
    Xavier,
    /// Always the given constant.
    Fixed(f32),
}

/// Produces one initial value according to the initialization strategy `typ`.
///
/// * `Random` — a random value in the range -1 to 1 (presumably half-open,
///   `[-1, 1)`, given that `fastrand::f32()` is documented to return `[0, 1)`).
/// * `Xavier` — the same random value scaled by `1 / sqrt(prev_layer_size)`.
/// * `Fixed(val)` — `val` itself; `prev_layer_size` is ignored.
///
/// `prev_layer_size` is the number of units feeding into the layer being
/// initialized. A value of `0` is treated as `1` for the `Xavier` scaling so
/// the result is always finite.
pub fn calc_initialization(typ: InitializationType, prev_layer_size: usize) -> f32 {
    match typ {
        InitializationType::Random => fastrand::f32() * 2. - 1.,
        InitializationType::Xavier => {
            // Guard against a zero-sized previous layer: `1.0 / 0.0` is
            // infinite, which would turn the result into +/-inf or NaN.
            let fan_in = prev_layer_size.max(1) as f32;
            (fastrand::f32() * 2. - 1.) * (1.0 / fan_in).sqrt()
        }
        InitializationType::Fixed(val) => val,
    }
}