use crate::{Device, Kind, TchError, Tensor};

/// Specifies whether the fan-in or the fan-out is used as the number of
/// features when computing initialization scales. In Kaiming
/// initialization, choosing `FanIn` preserves the magnitude of the
/// variance of the activations in the forward pass, while `FanOut`
/// preserves this magnitude in the backward pass.
#[derive(Debug, Copy, Clone)]
pub enum FanInOut {
    FanIn,
    FanOut,
}

impl FanInOut {
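    /// Computes the fan-in or fan-out for a weight tensor of the given
    /// dimensions: `dims[1]` (fan-in) or `dims[0]` (fan-out) multiplied by
    /// the receptive field size, i.e. the product of all trailing dimensions.
    ///
    /// A quick sketch of the expected values for a conv2d-style weight laid
    /// out as `[out_channels, in_channels, kh, kw]` (the `tch::nn::init`
    /// path below is an assumption about how this module is re-exported):
    ///
    /// ```
    /// use tch::nn::init::FanInOut;
    ///
    /// // A 3x3 convolution with 16 input and 32 output channels.
    /// let dims = [32, 16, 3, 3];
    /// assert_eq!(FanInOut::FanIn.for_weight_dims(&dims), 16 * 3 * 3);
    /// assert_eq!(FanInOut::FanOut.for_weight_dims(&dims), 32 * 3 * 3);
    /// ```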
    pub fn for_weight_dims(&self, dims: &[i64]) -> i64 {
        // All dimensions past the first two form the receptive field,
        // e.g. the kernel height and width for a convolution.
        let receptive_field_size: i64 = dims.iter().skip(2).product();
        match &self {
            FanInOut::FanIn => {
                if dims.len() < 2 {
                    1
                } else {
                    dims[1] * receptive_field_size
                }
            }
            FanInOut::FanOut => {
                if dims.is_empty() {
                    1
                } else {
                    dims[0] * receptive_field_size
                }
            }
        }
    }
}

/// The distribution to sample from, normal (Gaussian) or uniform.
#[derive(Debug, Copy, Clone)]
pub enum NormalOrUniform {
    Normal,
    Uniform,
}

/// The non-linearity that follows the initialized layer, used to compute
/// the recommended scaling gain.
#[derive(Debug, Copy, Clone)]
pub enum NonLinearity {
    ReLU,
    Linear,
    Sigmoid,
    Tanh,
    SELU,
    ExplicitGain(f64),
}

impl NonLinearity {
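    /// Returns the recommended gain for this non-linearity, following the
    /// same values as PyTorch's `torch.nn.init.calculate_gain`.
    ///
    /// A minimal sketch (the `tch::nn::init` path is an assumption about
    /// how this module is re-exported):
    ///
    /// ```
    /// use tch::nn::init::NonLinearity;
    ///
    /// assert_eq!(NonLinearity::ReLU.gain(), 2f64.sqrt());
    /// assert_eq!(NonLinearity::ExplicitGain(0.5).gain(), 0.5);
    /// ```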
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}

/// Variable initializations.
#[derive(Debug, Copy, Clone)]
pub enum Init {
    /// Constant value.
    Const(f64),

    /// Random normal with some mean and standard deviation.
    Randn { mean: f64, stdev: f64 },

    /// Uniform initialization between some lower and upper bounds.
    Uniform { lo: f64, up: f64 },

    /// Kaiming initialization: samples are scaled by `gain / sqrt(fan)`,
    /// where the fan and gain depend on the weight dimensions and the
    /// following non-linearity.
    Kaiming { dist: NormalOrUniform, fan: FanInOut, non_linearity: NonLinearity },

    /// Orthogonal initialization via QR decomposition, scaled by `gain`.
    Orthogonal { gain: f64 },
}

pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
    dist: NormalOrUniform::Uniform,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
    dist: NormalOrUniform::Normal,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};

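/// Creates a new float tensor with the specified shape, device, and
/// initialization, returning an error rather than panicking on failure.
///
/// A minimal usage sketch (the `tch::nn::init` paths below are an
/// assumption about how this module is re-exported):
///
/// ```no_run
/// use tch::{Device, Kind};
/// use tch::nn::init::{f_init, FanInOut, Init, NonLinearity, NormalOrUniform};
///
/// // Kaiming-uniform initialization for a 128x64 linear weight.
/// let kaiming = Init::Kaiming {
///     dist: NormalOrUniform::Uniform,
///     fan: FanInOut::FanIn,
///     non_linearity: NonLinearity::ReLU,
/// };
/// let ws = f_init(kaiming, &[128, 64], Device::Cpu, Kind::Float).unwrap();
/// assert_eq!(ws.size(), [128, 64]);
/// ```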
pub fn f_init(i: Init, dims: &[i64], device: Device, kind: Kind) -> Result<Tensor, TchError> {
    match i {
        Init::Const(cst) => {
            // Optimize the cases for which a single C++ function can be used.
            if cst == 0. {
                Tensor::f_zeros(dims, (kind, device))
            } else if (cst - 1.).abs() <= f64::EPSILON {
                Tensor::f_ones(dims, (kind, device))
            } else {
                Tensor::f_ones(dims, (kind, device)).map(|t| t * cst)
            }
        }
        Init::Uniform { lo, up } => Tensor::f_zeros(dims, (kind, device))?.f_uniform_(lo, up),
        Init::Randn { mean, stdev } => {
            if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
                Tensor::f_randn(dims, (kind, device))
            } else {
                Tensor::f_randn(dims, (kind, device)).map(|t| t * stdev + mean)
            }
        }
        Init::Kaiming { dist, fan, non_linearity } => {
            let fan = fan.for_weight_dims(dims);
            let gain = non_linearity.gain();
            let std = gain / (fan as f64).sqrt();
            match dist {
                NormalOrUniform::Uniform => {
                    // A uniform distribution on [-bound, bound] has standard
                    // deviation bound / sqrt(3), so this bound matches `std`.
                    let bound = 3f64.sqrt() * std;
                    Tensor::f_zeros(dims, (kind, device))?.f_uniform_(-bound, bound)
                }
                NormalOrUniform::Normal => {
                    let randn = Tensor::f_randn(dims, (kind, device))?;
                    Ok(randn * std)
                }
            }
        }
        Init::Orthogonal { gain } => {
            if dims.len() < 2 {
                return Err(TchError::Shape(
                    "Only tensors with 2 or more dimensions are supported".to_string(),
                ));
            }
            let rows = dims[0];
            let cols: i64 = dims.iter().skip(1).product();

            let mut flattened =
                Tensor::f_empty([rows, cols], (kind, device))?.f_normal_(0.0, 1.0)?;
            // QR yields an orthonormal basis for the columns, so make sure
            // the matrix is tall rather than wide before factorizing.
            let flattened = if rows < cols { flattened.f_t_()? } else { flattened };

            let (mut q, r) = Tensor::f_linalg_qr(&flattened, "reduced")?;
            // Multiply by the signs of the diagonal of `r` to make the
            // decomposition unique and `q` uniformly distributed.
            let d = r.f_diag(0)?;
            let ph = d.f_sign()?;
            q *= ph;

            let mut q = if rows < cols { q.f_t_()? } else { q };
            crate::no_grad(|| q *= gain);

            q.f_contiguous()
        }
    }
}

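/// Creates a new float tensor with the specified shape, device, and
/// initialization, panicking on failure; see `f_init` for the fallible
/// variant.
///
/// A short sketch (the `tch::nn::init` path is an assumption about how
/// this module is re-exported):
///
/// ```no_run
/// use tch::Device;
/// use tch::nn::init::{init, Init};
///
/// // A 4x4 orthogonal matrix, scaled by 1.0.
/// let q = init(Init::Orthogonal { gain: 1.0 }, &[4, 4], Device::Cpu);
/// ```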
pub fn init(i: Init, dims: &[i64], device: Device) -> Tensor {
    f_init(i, dims, device, Kind::Float).unwrap()
}

impl Init {
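    /// Re-initializes an existing tensor in place with the specified
    /// initialization.
    ///
    /// A minimal sketch (the `tch::nn::init` path is an assumption about
    /// how this module is re-exported):
    ///
    /// ```no_run
    /// use tch::{Device, Kind, Tensor};
    /// use tch::nn::init::Init;
    ///
    /// let mut t = Tensor::zeros([256, 128], (Kind::Float, Device::Cpu));
    /// Init::Uniform { lo: -0.1, up: 0.1 }.set(&mut t);
    /// ```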
    pub fn set(self, tensor: &mut Tensor) {
        match self {
            Init::Const(cst) => {
                let _ = tensor.fill_(cst);
            }
            Init::Uniform { lo, up } => {
                let _ = tensor.uniform_(lo, up);
            }
            Init::Kaiming { dist, fan, non_linearity } => {
                let fan = fan.for_weight_dims(&tensor.size());
                let gain = non_linearity.gain();
                let std = gain / (fan as f64).sqrt();
                match dist {
                    NormalOrUniform::Uniform => {
                        let bound = 3f64.sqrt() * std;
                        let _ = tensor.uniform_(-bound, bound);
                    }
                    NormalOrUniform::Normal => {
                        tensor.copy_(&(tensor.randn_like() * std));
                    }
                }
            }
            Init::Randn { mean, stdev } => {
                tensor.copy_(&(tensor.randn_like() * stdev + mean));
            }
            Init::Orthogonal { gain } => {
                // Build the orthogonal matrix as a fresh float tensor and
                // copy it in; `copy_` casts to the target tensor's kind.
                let q =
                    f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device(), Kind::Float)
                        .unwrap();
                crate::no_grad(|| tensor.view_as(&q).copy_(&q));
            }
        }
    }
}

impl Tensor {
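    /// Re-initializes the tensor in place using the specified
    /// initialization; a thin convenience wrapper around `Init::set`.
    ///
    /// A short sketch (the `tch::nn::init` path is an assumption about
    /// how this module is re-exported):
    ///
    /// ```no_run
    /// use tch::{Device, Kind, Tensor};
    /// use tch::nn::init::Init;
    ///
    /// let mut t = Tensor::zeros([8, 8], (Kind::Float, Device::Cpu));
    /// t.init(Init::Randn { mean: 0.0, stdev: 0.02 });
    /// ```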
    pub fn init(&mut self, i: Init) {
        i.set(self)
    }
}