tenflowers_neural/continuous_normalizing_flows/mlp.rs

use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
use scirs2_core::RngExt;
use std::f64::consts::PI;
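/// Plain multilayer perceptron used as the CNF vector field: tanh hidden
/// layers, a linear output layer, and Glorot-uniform weight initialization.
/// `weights[l][j][i]` maps input `i` to unit `j` of layer `l`; `biases[l][j]`
/// is the matching bias.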
#[derive(Clone)]
pub struct CnfMlp {
    pub weights: Vec<Vec<Vec<f64>>>,
    pub biases: Vec<Vec<f64>>,
}

impl CnfMlp {
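    /// Builds the network from `layer_sizes` = `[input, hidden..., output]`.
    /// Weights are Glorot-uniform; biases start at zero. A fixed seed keeps
    /// construction deterministic.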
    pub fn new(layer_sizes: &[usize]) -> Self {
        assert!(layer_sizes.len() >= 2, "need at least input + output layer");
        let n_layers = layer_sizes.len() - 1;
        let mut weights = Vec::with_capacity(n_layers);
        let mut biases = Vec::with_capacity(n_layers);
        let mut rng = StdRng::seed_from_u64(0xabcdef01_u64);

        for l in 0..n_layers {
            let fan_in = layer_sizes[l];
            let fan_out = layer_sizes[l + 1];
            // Glorot-uniform bound: sqrt(6 / (fan_in + fan_out)).
            let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
            let layer_w: Vec<Vec<f64>> = (0..fan_out)
                .map(|_| {
                    (0..fan_in)
                        .map(|_| {
                            let u: f64 = rng.random();
                            u * 2.0 * limit - limit
                        })
                        .collect()
                })
                .collect();
            let layer_b: Vec<f64> = vec![0.0; fan_out];
            weights.push(layer_w);
            biases.push(layer_b);
        }

        CnfMlp { weights, biases }
    }

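    /// Forward pass: each layer is affine, with `tanh` applied to every
    /// layer except the last, which stays linear.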
    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
        let n_layers = self.weights.len();
        let mut h: Vec<f64> = x.to_vec();

        for (l, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            let out_dim = w.len();
            let mut z = vec![0.0_f64; out_dim];
            for j in 0..out_dim {
                let mut acc = b[j];
                for (i, hi) in h.iter().enumerate() {
                    acc += w[j][i] * hi;
                }
                z[j] = acc;
            }
            if l < n_layers - 1 {
                for v in z.iter_mut() {
                    *v = v.tanh();
                }
            }
            h = z;
        }
        h
    }

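    /// Central-difference estimate of the Jacobian diagonal d f_i / d x_i.
    /// The bounds check skips entries where input and output dimensions
    /// do not overlap.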
    pub fn jacobian_diagonal_approx(&self, x: &[f64]) -> Vec<f64> {
        let eps = 1e-5;
        let d = x.len();
        let mut diag = vec![0.0_f64; d];
        for i in 0..d {
            let mut xp = x.to_vec();
            let mut xm = x.to_vec();
            xp[i] += eps;
            xm[i] -= eps;
            let fp = self.forward(&xp);
            let fm = self.forward(&xm);
            if i < fp.len() {
                diag[i] = (fp[i] - fm[i]) / (2.0 * eps);
            }
        }
        diag
    }

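    /// In-place SGD step: `param -= lr * grad`. Loop bounds are clamped with
    /// `min`, so gradients with mismatched shapes update the overlapping
    /// entries and skip the rest instead of panicking.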
    pub fn update(&mut self, grad_w: &[Vec<Vec<f64>>], grad_b: &[Vec<f64>], lr: f64) {
        for (l, (gw, gb)) in grad_w.iter().zip(grad_b.iter()).enumerate() {
            if l >= self.weights.len() {
                break;
            }
            for j in 0..self.weights[l].len().min(gw.len()) {
                for i in 0..self.weights[l][j].len().min(gw[j].len()) {
                    self.weights[l][j][i] -= lr * gw[j][i];
                }
            }
            for j in 0..self.biases[l].len().min(gb.len()) {
                self.biases[l][j] -= lr * gb[j];
            }
        }
    }

    /// Number of weight layers.
    pub fn n_layers(&self) -> usize {
        self.weights.len()
    }

    /// Output dimension: rows of the last weight matrix, or 0 if empty.
    pub fn out_dim(&self) -> usize {
        self.weights.last().map(|w| w.len()).unwrap_or(0)
    }
}

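/// Time-dependent vector field f(z, t) of the CNF, backed by a `CnfMlp`.
/// When `include_time` is true, t is appended to z as an extra input.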
#[derive(Clone)]
pub struct CnfDynamics {
    pub mlp: CnfMlp,
    pub z_dim: usize,
    pub include_time: bool,
}

impl CnfDynamics {
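    /// Builds the field with `n_layers` hidden layers of width `hidden_dim`;
    /// the MLP input is z (plus t when `include_time`) and the output has
    /// dimension `z_dim`.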
    pub fn new(z_dim: usize, hidden_dim: usize, n_layers: usize, include_time: bool) -> Self {
        let in_dim = if include_time { z_dim + 1 } else { z_dim };
        let mut sizes = vec![in_dim];
        for _ in 0..n_layers {
            sizes.push(hidden_dim);
        }
        sizes.push(z_dim);
        CnfDynamics {
            mlp: CnfMlp::new(&sizes),
            z_dim,
            include_time,
        }
    }

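    /// Evaluates f(z, t), appending t to the input when `include_time` is set.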
    pub fn forward(&self, z: &[f64], t: f64) -> Vec<f64> {
        if self.include_time {
            let mut inp = z.to_vec();
            inp.push(t);
            self.mlp.forward(&inp)
        } else {
            self.mlp.forward(z)
        }
    }

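    /// Hutchinson trace estimator for tr(df/dz): with Rademacher probes
    /// epsilon, E[epsilon^T J epsilon] = tr(J); the product J epsilon is
    /// approximated by the central difference
    /// (f(z + h*epsilon, t) - f(z - h*epsilon, t)) / (2h).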
    pub fn trace_jac_approx(&self, z: &[f64], t: f64, n_samples: usize, rng: &mut StdRng) -> f64 {
        let eps = 1e-4;
        let mut trace_est = 0.0_f64;
        let n = n_samples.max(1);

        for _ in 0..n {
            // Rademacher probe: each entry is +1 or -1 with equal probability.
            let epsilon: Vec<f64> = (0..self.z_dim)
                .map(|_| if rng.random::<f64>() < 0.5 { 1.0 } else { -1.0 })
                .collect();

            let z_plus: Vec<f64> = z
                .iter()
                .zip(epsilon.iter())
                .map(|(zi, ei)| zi + eps * ei)
                .collect();
            let z_minus: Vec<f64> = z
                .iter()
                .zip(epsilon.iter())
                .map(|(zi, ei)| zi - eps * ei)
                .collect();

            let f_plus = self.forward(&z_plus, t);
            let f_minus = self.forward(&z_minus, t);

            let sample_est: f64 = epsilon
                .iter()
                .zip(f_plus.iter())
                .zip(f_minus.iter())
                .map(|((ei, fp_i), fm_i)| ei * (fp_i - fm_i) / (2.0 * eps))
                .sum();
            trace_est += sample_est;
        }
        trace_est / n as f64
    }
}

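/// A continuous normalizing flow: a diagonal-Gaussian base distribution
/// transported by the ODE dz/dt = f(z, t), integrated with fixed-step Euler.
/// Log-densities follow the instantaneous change of variables,
/// d log p / dt = -tr(df/dz), with the trace estimated stochastically.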
pub struct ContinuousNormalizingFlow {
    pub dynamics: CnfDynamics,
    pub base_mean: Vec<f64>,
    pub base_std: Vec<f64>,
}

impl ContinuousNormalizingFlow {
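    /// Builds a flow with a standard-normal base over `z_dim` dimensions and
    /// a time-conditioned vector field.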
    pub fn new(z_dim: usize, hidden_dim: usize, n_layers: usize) -> Self {
        ContinuousNormalizingFlow {
            dynamics: CnfDynamics::new(z_dim, hidden_dim, n_layers, true),
            base_mean: vec![0.0; z_dim],
            base_std: vec![1.0; z_dim],
        }
    }

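    /// Euler integration of dz/dt = f(z, t) from `t_start` to `t_end`.
    /// Returns the final state together with the accumulated trace integral,
    /// i.e. log|det dz(t_end)/dz0| of the forward map.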
    pub fn integrate_forward(
        &self,
        z0: &[f64],
        n_steps: usize,
        t_start: f64,
        t_end: f64,
    ) -> (Vec<f64>, f64) {
        let n = n_steps.max(1);
        let dt = (t_end - t_start) / n as f64;
        let mut z = z0.to_vec();
        let mut log_det = 0.0_f64;
        // Fixed seed: trace probes are deterministic across calls.
        let mut rng = StdRng::seed_from_u64(0xdeadbeef_u64);

        for step in 0..n {
            let t = t_start + step as f64 * dt;
            let dz = self.dynamics.forward(&z, t);
            let tr = self.dynamics.trace_jac_approx(&z, t, 1, &mut rng);
            for (zi, dzi) in z.iter_mut().zip(dz.iter()) {
                *zi += dt * dzi;
            }
            log_det += dt * tr;
        }
        (z, log_det)
    }

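    /// Recovers z0 from a data point x by integrating the reversed dynamics
    /// from t = 1 down to t = 0, accumulating the inverse map's
    /// log-determinant, log|det dz0/dx| = -(integral of tr(df/dz) dt),
    /// so that `log_prob` can simply add it to the base log-density.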
    pub fn integrate_backward(&self, x: &[f64], n_steps: usize) -> (Vec<f64>, f64) {
        let n = n_steps.max(1);
        let dt = 1.0 / n as f64;
        let mut z = x.to_vec();
        let mut log_det = 0.0_f64;
        let mut rng = StdRng::seed_from_u64(0xcafe1234_u64);

        for step in 0..n {
            let t = 1.0 - step as f64 * dt;
            let dz = self.dynamics.forward(&z, t);
            let tr = self.dynamics.trace_jac_approx(&z, t, 1, &mut rng);
            for (zi, dzi) in z.iter_mut().zip(dz.iter()) {
                *zi -= dt * dzi;
            }
            // The inverse map's log-determinant is minus the trace integral;
            // accumulating with a plus sign here would flip the sign of
            // log_prob's change-of-variables correction.
            log_det -= dt * tr;
        }
        (z, log_det)
    }

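    /// Log-density of `x` under the flow:
    /// log p(x) = log p0(z0) + log|det dz0/dx|.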
    pub fn log_prob(&self, x: &[f64], n_steps: usize) -> f64 {
        let (z0, log_det) = self.integrate_backward(x, n_steps);
        let log_p0 = self.log_base_prob(&z0);
        log_p0 + log_det
    }

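    /// Diagonal-Gaussian log-density of the base distribution, with sigma
    /// clamped away from zero for numerical safety.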
    pub(crate) fn log_base_prob(&self, z: &[f64]) -> f64 {
        let d = z.len().min(self.base_mean.len()).min(self.base_std.len());
        let mut lp = 0.0_f64;
        for i in 0..d {
            let sigma = self.base_std[i].max(1e-15);
            let diff = z[i] - self.base_mean[i];
            lp -= 0.5 * (diff * diff / (sigma * sigma) + (2.0 * PI * sigma * sigma).ln());
        }
        lp
    }

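    /// Draws z0 from the base Gaussian via Box-Muller and pushes it forward
    /// through the flow from t = 0 to t = 1.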
    pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
        let d = self.dynamics.z_dim;
        let z0: Vec<f64> = (0..d)
            .map(|i| {
                let u1: f64 = rng.random::<f64>().max(1e-15);
                let u2: f64 = rng.random::<f64>();
                let g = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
                self.base_mean[i] + self.base_std[i] * g
            })
            .collect();
        let (x, _log_det) = self.integrate_forward(&z0, n_steps, 0.0, 1.0);
        x
    }

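    /// One derivative-free training step. The loss is the batch-mean
    /// negative log-likelihood; its gradient is estimated by forward finite
    /// differences over a random ~5% of parameters per call, so each step
    /// updates only that sampled subset (the rest keep a zero gradient).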
    pub fn train_step(&mut self, x_batch: &[Vec<f64>], n_steps: usize, lr: f64) -> f64 {
        if x_batch.is_empty() {
            return 0.0;
        }
        let batch_size = x_batch.len();

        let base_loss: f64 = x_batch
            .iter()
            .map(|x| -self.log_prob(x, n_steps))
            .sum::<f64>()
            / batch_size as f64;

        let fd_eps = 1e-4;
        let n_layers = self.dynamics.mlp.n_layers();

        let mut rng = StdRng::seed_from_u64(0x98765432_u64);
        let mut grad_w: Vec<Vec<Vec<f64>>> = self
            .dynamics
            .mlp
            .weights
            .iter()
            .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
            .collect();
        let mut grad_b: Vec<Vec<f64>> = self
            .dynamics
            .mlp
            .biases
            .iter()
            .map(|lb| vec![0.0; lb.len()])
            .collect();

        for l in 0..n_layers {
            for j in 0..self.dynamics.mlp.weights[l].len() {
                for i in 0..self.dynamics.mlp.weights[l][j].len() {
                    // Probe ~5% of weights: perturb, re-evaluate the loss,
                    // restore, and take a forward difference.
                    if rng.random::<f64>() < 0.05 {
                        self.dynamics.mlp.weights[l][j][i] += fd_eps;
                        let perturbed_loss: f64 = x_batch
                            .iter()
                            .map(|x| -self.log_prob(x, n_steps))
                            .sum::<f64>()
                            / batch_size as f64;
                        self.dynamics.mlp.weights[l][j][i] -= fd_eps;
                        grad_w[l][j][i] = (perturbed_loss - base_loss) / fd_eps;
                    }
                }
            }
            for j in 0..self.dynamics.mlp.biases[l].len() {
                if rng.random::<f64>() < 0.05 {
                    self.dynamics.mlp.biases[l][j] += fd_eps;
                    let perturbed_loss: f64 = x_batch
                        .iter()
                        .map(|x| -self.log_prob(x, n_steps))
                        .sum::<f64>()
                        / batch_size as f64;
                    self.dynamics.mlp.biases[l][j] -= fd_eps;
                    grad_b[l][j] = (perturbed_loss - base_loss) / fd_eps;
                }
            }
        }

        self.dynamics.mlp.update(&grad_w, &grad_b, lr);
        base_loss
    }
}
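
// A minimal usage sketch, added as a smoke test (not part of the original
// file): builds a 2-D flow, runs finite-difference training steps on toy
// data, and draws a sample. The hyperparameters (hidden width, step counts,
// learning rate) are illustrative assumptions, not tuned values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn train_and_sample_smoke() {
        let mut flow = ContinuousNormalizingFlow::new(2, 8, 2);
        let mut rng = StdRng::seed_from_u64(42);

        // Toy batch: points near (1, 1) with Box-Muller Gaussian noise.
        let data: Vec<Vec<f64>> = (0..8)
            .map(|_| {
                (0..2)
                    .map(|_| {
                        let u1: f64 = rng.random::<f64>().max(1e-15);
                        let u2: f64 = rng.random::<f64>();
                        1.0 + 0.5 * (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
                    })
                    .collect()
            })
            .collect();

        // With a stochastic, sparse gradient estimate we only assert the
        // loss is finite, not that it decreases monotonically.
        let loss = flow.train_step(&data, 4, 1e-2);
        assert!(loss.is_finite());

        // A sample from the flow has the model's dimensionality.
        let x = flow.sample(4, &mut rng);
        assert_eq!(x.len(), 2);
    }
}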