1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
//! Variable initialization.
use crate::{Device, Kind, Tensor};

#[derive(Debug, Copy, Clone)]
pub enum Init {
    Const(f64),
    Randn { mean: f64, stdev: f64 },
    Uniform { lo: f64, up: f64 },
    KaimingUniform,
}

pub fn init(i: Init, dims: &[i64], device: Device) -> Tensor {
    match i {
        Init::Const(cst) => {
            // Optimize the case for which a single C++ code can be done.
            if cst == 0. {
                Tensor::zeros(dims, (Kind::Float, device))
            } else if (cst - 1.) <= std::f64::EPSILON {
                Tensor::ones(dims, (Kind::Float, device))
            } else {
                Tensor::ones(dims, (Kind::Float, device)) * cst
            }
        }
        Init::Uniform { lo, up } => Tensor::zeros(dims, (Kind::Float, device)).uniform_(lo, up),
        Init::Randn { mean, stdev } => {
            if mean == 0. && (stdev - 1.) <= std::f64::EPSILON {
                Tensor::randn(dims, (Kind::Float, device))
            } else {
                Tensor::randn(dims, (Kind::Float, device)) * stdev + mean
            }
        }
        Init::KaimingUniform => {
            let fan_in: i64 = dims.iter().skip(1).product();
            let bound = (1.0 / fan_in as f64).sqrt();
            Tensor::zeros(dims, (Kind::Float, device)).uniform_(-bound, bound)
        }
    }
}

impl Init {
    pub fn set(self, tensor: &mut Tensor) {
        match self {
            Init::Const(cst) => {
                let _ = tensor.fill_(cst);
            }
            Init::Uniform { lo, up } => {
                let _ = tensor.uniform_(lo, up);
            }
            Init::KaimingUniform => {
                let fan_in: i64 = tensor.size().iter().skip(1).product();
                let bound = (1.0 / fan_in as f64).sqrt();
                let _ = tensor.uniform_(-bound, bound);
            }
            Init::Randn { mean, stdev } => {
                tensor.copy_(&(tensor.randn_like() * stdev + mean));
            }
        }
    }
}

impl Tensor {
    pub fn init(&mut self, i: Init) {
        i.set(self)
    }
}