tch_plus/nn/init.rs

//! Variable initialization.
use crate::{Device, Kind, TchError, Tensor};

/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves
/// the magnitude of the variance of the weights in the
/// forward pass; choosing `FanOut` preserves this
/// magnitude in the backward pass.
#[derive(Debug, Copy, Clone)]
pub enum FanInOut {
    FanIn,
    FanOut,
}

impl FanInOut {
    /// Compute the fan-in or fan-out value for a weight tensor of
    /// the specified dimensions.
    /// <https://github.com/pytorch/pytorch/blob/dbeacf11820e336e803bb719b7aaaf2125ae4d9c/torch/nn/init.py#L284>
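    ///
    /// # Example
    ///
    /// A minimal sketch of the fan computation for a `Conv2D`-style weight of shape
    /// `[out_channels, in_channels, kernel_h, kernel_w]`; the `tch_plus` crate name and
    /// module path are assumed from the file header.
    ///
    /// ```ignore
    /// use tch_plus::nn::init::FanInOut;
    ///
    /// // The receptive field size is the product of the kernel dims: 3 * 3 = 9.
    /// let dims = [16i64, 8, 3, 3];
    /// assert_eq!(FanInOut::FanIn.for_weight_dims(&dims), 8 * 9); // in_channels * 9
    /// assert_eq!(FanInOut::FanOut.for_weight_dims(&dims), 16 * 9); // out_channels * 9
    /// ```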
    pub fn for_weight_dims(&self, dims: &[i64]) -> i64 {
        let receptive_field_size: i64 = dims.iter().skip(2).product();
        match &self {
            FanInOut::FanIn => {
                if dims.len() < 2 {
                    1
                } else {
                    dims[1] * receptive_field_size
                }
            }
            FanInOut::FanOut => {
                if dims.is_empty() {
                    1
                } else {
                    dims[0] * receptive_field_size
                }
            }
        }
    }
}

/// Whether to draw initial values from a normal (Gaussian) or a uniform distribution.
#[derive(Debug, Copy, Clone)]
pub enum NormalOrUniform {
    Normal,
    Uniform,
}

/// The non-linear function that follows this layer. ReLU is the
/// recommended value.
#[derive(Debug, Copy, Clone)]
pub enum NonLinearity {
    ReLU,
    Linear,
    Sigmoid,
    Tanh,
    SELU,
    ExplicitGain(f64),
}

impl NonLinearity {
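    /// Returns the recommended gain value for the non-linearity, mirroring PyTorch's
    /// `torch.nn.init.calculate_gain`: `sqrt(2)` for ReLU, `5/3` for tanh, `1` for
    /// linear and sigmoid, and `3/4` for SELU.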
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}

/// Variable initializations.
#[derive(Debug, Copy, Clone)]
pub enum Init {
    /// Constant value.
    Const(f64),

    /// Random normal with some mean and standard deviation.
    Randn { mean: f64, stdev: f64 },

    /// Uniform initialization between some lower and upper bounds.
    Uniform { lo: f64, up: f64 },

    /// Kaiming initialization, drawing from either a normal or a uniform
    /// distribution depending on `dist`.
    /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification",
    /// He, K. et al. (2015).
    Kaiming { dist: NormalOrUniform, fan: FanInOut, non_linearity: NonLinearity },

    /// Orthogonal initialization.
    Orthogonal { gain: f64 },
}

pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
    dist: NormalOrUniform::Uniform,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
    dist: NormalOrUniform::Normal,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

/// Creates a new tensor with the specified shape, device, kind, and initialization.
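///
/// # Example
///
/// A minimal sketch; the `tch_plus` crate name and module path are assumed from
/// the file header.
///
/// ```ignore
/// use tch_plus::{Device, Kind};
/// use tch_plus::nn::init::{f_init, Init, DEFAULT_KAIMING_UNIFORM};
///
/// // A 128x64 weight matrix with the default Kaiming-uniform initialization.
/// let w = f_init(DEFAULT_KAIMING_UNIFORM, &[128, 64], Device::Cpu, Kind::Float)?;
///
/// // A bias vector drawn uniformly from [-0.1, 0.1].
/// let b = f_init(Init::Uniform { lo: -0.1, up: 0.1 }, &[64], Device::Cpu, Kind::Float)?;
/// # Ok::<(), tch_plus::TchError>(())
/// ```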
pub fn f_init(i: Init, dims: &[i64], device: Device, kind: Kind) -> Result<Tensor, TchError> {
    match i {
        Init::Const(cst) => {
            // Optimize the cases that can be handled with a single C++ call.
            if cst == 0. {
                Tensor::f_zeros(dims, (kind, device))
            } else if (cst - 1.).abs() <= f64::EPSILON {
                Tensor::f_ones(dims, (kind, device))
            } else {
                Tensor::f_ones(dims, (kind, device)).map(|t| t * cst)
            }
        }
        Init::Uniform { lo, up } => Tensor::f_zeros(dims, (kind, device))?.f_uniform_(lo, up),
        Init::Randn { mean, stdev } => {
            if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
                Tensor::f_randn(dims, (kind, device))
            } else {
                Tensor::f_randn(dims, (kind, device)).map(|t| t * stdev + mean)
            }
        }
        Init::Kaiming { dist, fan, non_linearity } => {
            let fan = fan.for_weight_dims(dims);
            let gain = non_linearity.gain();
            let std = gain / (fan as f64).sqrt();
            match dist {
                NormalOrUniform::Uniform => {
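                    // A uniform draw on [-bound, bound] has standard deviation
                    // bound / sqrt(3), so bound = sqrt(3) * std hits the target std.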
                    let bound = 3f64.sqrt() * std;
                    Tensor::f_zeros(dims, (kind, device))?.f_uniform_(-bound, bound)
                }
                NormalOrUniform::Normal => {
                    let randn = Tensor::f_randn(dims, (kind, device))?;
                    Ok(randn * std)
                }
            }
        }
        Init::Orthogonal { gain } => {
            if dims.len() < 2 {
                return Err(TchError::Shape(
                    "Only tensors with 2 or more dimensions are supported".to_string(),
                ));
            }
            let rows = dims[0];
            let cols: i64 = dims.iter().skip(1).product();

            let mut flattened =
                Tensor::f_empty([rows, cols], (kind, device))?.f_normal_(0.0, 1.0)?;
            let flattened = if rows < cols { flattened.f_t_()? } else { flattened };

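            // QR-factorize the Gaussian matrix: Q has orthonormal columns. Scaling the
            // columns by the sign of diag(R) removes the sign ambiguity of the
            // factorization; this mirrors PyTorch's `nn.init.orthogonal_`.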
            let (mut q, r) = Tensor::f_linalg_qr(&flattened, "reduced")?;
            let d = r.f_diag(0)?;
            let ph = d.f_sign()?;
            q *= ph;

            let mut q = if rows < cols { q.f_t_()? } else { q };
            crate::no_grad(|| q *= gain);

            q.f_contiguous()
        }
    }
}

/// Creates a new float tensor with the specified shape, device, and initialization.
pub fn init(i: Init, dims: &[i64], device: Device) -> Tensor {
    f_init(i, dims, device, Kind::Float).unwrap()
}

impl Init {
    /// Re-initializes an existing tensor with the specified initialization.
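    ///
    /// # Example
    ///
    /// A minimal sketch; the `tch_plus` crate name and module path are assumed from
    /// the file header.
    ///
    /// ```ignore
    /// use tch_plus::{Device, Kind, Tensor};
    /// use tch_plus::nn::init::Init;
    ///
    /// let mut w = Tensor::zeros(&[4, 4], (Kind::Float, Device::Cpu));
    /// Init::Uniform { lo: -0.5, up: 0.5 }.set(&mut w);
    /// ```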
    pub fn set(self, tensor: &mut Tensor) {
        match self {
            Init::Const(cst) => {
                let _ = tensor.fill_(cst);
            }
            Init::Uniform { lo, up } => {
                let _ = tensor.uniform_(lo, up);
            }
            Init::Kaiming { dist, fan, non_linearity } => {
                let fan = fan.for_weight_dims(&tensor.size());
                let gain = non_linearity.gain();
                let std = gain / (fan as f64).sqrt();
                match dist {
                    NormalOrUniform::Uniform => {
                        let bound = 3f64.sqrt() * std;
                        let _ = tensor.uniform_(-bound, bound);
                    }
                    NormalOrUniform::Normal => {
                        tensor.copy_(&(tensor.randn_like() * std));
                    }
                }
            }
            Init::Randn { mean, stdev } => {
                tensor.copy_(&(tensor.randn_like() * stdev + mean));
            }
            Init::Orthogonal { gain } => {
                let q =
                    f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device(), Kind::Float)
                        .unwrap();
                crate::no_grad(|| tensor.view_as(&q).copy_(&q));
            }
        }
    }
}

impl Tensor {
    /// Re-initializes the tensor using the specified initialization.
    pub fn init(&mut self, i: Init) {
        i.set(self)
    }
}