// tenflowers_neural/continuous_normalizing_flows/mlp.rs
1//! Shared MLP and CNF dynamics backbone.
2//!
3//! Contains `CnfMlp`, `CnfDynamics`, and `ContinuousNormalizingFlow`.
4
5use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
6use scirs2_core::RngExt;
7use std::f64::consts::PI;
8
9// ────────────────────────────────────────────────────────────────────────────
10// Shared MLP for dynamics / velocity networks
11// ────────────────────────────────────────────────────────────────────────────
12
/// Fully-connected feed-forward network that serves as the function
/// approximator behind CNF dynamics and flow-matching velocity fields.
///
/// Parameters are stored as plain nested vectors with layout
/// `weights[layer][out_neuron][in_neuron]`. Hidden layers use tanh rather
/// than ReLU: its smoothness keeps the Jacobian continuous, which matters
/// when the network is integrated inside an ODE solver.
#[derive(Clone)]
pub struct CnfMlp {
    /// `weights[l][j][i]`: connection strength from unit `i` of layer `l`
    /// into unit `j` of layer `l + 1`.
    pub weights: Vec<Vec<Vec<f64>>>,
    /// `biases[l][j]`: additive offset for unit `j` of layer `l + 1`.
    pub biases: Vec<Vec<f64>>,
}
24
25impl CnfMlp {
26    /// Construct a new MLP with Xavier-initialised weights.
27    ///
28    /// `layer_sizes` specifies the width of each layer including input and output.
29    /// E.g. `[4, 64, 64, 4]` gives a 2-hidden-layer network mapping R^4 → R^4.
30    pub fn new(layer_sizes: &[usize]) -> Self {
31        assert!(layer_sizes.len() >= 2, "need at least input + output layer");
32        let n_layers = layer_sizes.len() - 1;
33        let mut weights = Vec::with_capacity(n_layers);
34        let mut biases = Vec::with_capacity(n_layers);
35        let mut rng = StdRng::seed_from_u64(0xabcdef01_u64);
36
37        for l in 0..n_layers {
38            let fan_in = layer_sizes[l];
39            let fan_out = layer_sizes[l + 1];
40            // Xavier uniform: U[-limit, limit], limit = sqrt(6 / (fan_in + fan_out))
41            let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
42            let layer_w: Vec<Vec<f64>> = (0..fan_out)
43                .map(|_| {
44                    (0..fan_in)
45                        .map(|_| {
46                            let u: f64 = rng.random();
47                            u * 2.0 * limit - limit
48                        })
49                        .collect()
50                })
51                .collect();
52            let layer_b: Vec<f64> = vec![0.0; fan_out];
53            weights.push(layer_w);
54            biases.push(layer_b);
55        }
56
57        CnfMlp { weights, biases }
58    }
59
60    /// Forward pass: applies affine transform + tanh for hidden layers, linear output.
61    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
62        let n_layers = self.weights.len();
63        let mut h: Vec<f64> = x.to_vec();
64
65        for (l, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
66            let out_dim = w.len();
67            let mut z = vec![0.0_f64; out_dim];
68            for j in 0..out_dim {
69                let mut acc = b[j];
70                for (i, hi) in h.iter().enumerate() {
71                    acc += w[j][i] * hi;
72                }
73                z[j] = acc;
74            }
75            // Apply tanh activation on all layers except the last
76            if l < n_layers - 1 {
77                for v in z.iter_mut() {
78                    *v = v.tanh();
79                }
80            }
81            h = z;
82        }
83        h
84    }
85
86    /// Diagonal of the Jacobian via central finite differences.
87    ///
88    /// Returns `∂f_i/∂x_i` for each dimension `i`.
89    pub fn jacobian_diagonal_approx(&self, x: &[f64]) -> Vec<f64> {
90        let eps = 1e-5;
91        let d = x.len();
92        let mut diag = vec![0.0_f64; d];
93        for i in 0..d {
94            let mut xp = x.to_vec();
95            let mut xm = x.to_vec();
96            xp[i] += eps;
97            xm[i] -= eps;
98            let fp = self.forward(&xp);
99            let fm = self.forward(&xm);
100            if i < fp.len() {
101                diag[i] = (fp[i] - fm[i]) / (2.0 * eps);
102            }
103        }
104        diag
105    }
106
107    /// SGD weight update: `θ ← θ - lr * grad`.
108    pub fn update(&mut self, grad_w: &[Vec<Vec<f64>>], grad_b: &[Vec<f64>], lr: f64) {
109        for (l, (gw, gb)) in grad_w.iter().zip(grad_b.iter()).enumerate() {
110            if l >= self.weights.len() {
111                break;
112            }
113            for j in 0..self.weights[l].len().min(gw.len()) {
114                for i in 0..self.weights[l][j].len().min(gw[j].len()) {
115                    self.weights[l][j][i] -= lr * gw[j][i];
116                }
117            }
118            for j in 0..self.biases[l].len().min(gb.len()) {
119                self.biases[l][j] -= lr * gb[j];
120            }
121        }
122    }
123
124    /// Return the number of layers (including the output layer).
125    pub fn n_layers(&self) -> usize {
126        self.weights.len()
127    }
128
129    /// Output dimension of the MLP.
130    pub fn out_dim(&self) -> usize {
131        self.weights.last().map(|w| w.len()).unwrap_or(0)
132    }
133}
134
135// ────────────────────────────────────────────────────────────────────────────
136// CNF Dynamics — f(z, t)
137// ────────────────────────────────────────────────────────────────────────────
138
139/// ODE dynamics `dz/dt = f(z, t; θ)` backed by a `CnfMlp`.
140///
141/// If `include_time = true`, the input to the MLP is the concatenation `[z; t]`
142/// giving a time-conditioned vector field.
143#[derive(Clone)]
144pub struct CnfDynamics {
145    /// The neural network parameterising the vector field.
146    pub mlp: CnfMlp,
147    /// Dimensionality of the state vector `z`.
148    pub z_dim: usize,
149    /// Whether to include time `t` as an extra input coordinate.
150    pub include_time: bool,
151}
152
153impl CnfDynamics {
154    /// Create CNF dynamics with the given architecture.
155    ///
156    /// Network input: `z_dim` (or `z_dim + 1` if `include_time`), output: `z_dim`.
157    pub fn new(z_dim: usize, hidden_dim: usize, n_layers: usize, include_time: bool) -> Self {
158        let in_dim = if include_time { z_dim + 1 } else { z_dim };
159        let mut sizes = vec![in_dim];
160        for _ in 0..n_layers {
161            sizes.push(hidden_dim);
162        }
163        sizes.push(z_dim);
164        CnfDynamics {
165            mlp: CnfMlp::new(&sizes),
166            z_dim,
167            include_time,
168        }
169    }
170
171    /// Evaluate the vector field `f(z, t)`.
172    pub fn forward(&self, z: &[f64], t: f64) -> Vec<f64> {
173        if self.include_time {
174            let mut inp = z.to_vec();
175            inp.push(t);
176            self.mlp.forward(&inp)
177        } else {
178            self.mlp.forward(z)
179        }
180    }
181
182    /// Hutchinson trace estimator: `E_ε[ε^T (∂f/∂z) ε]` with `ε ~ Rademacher(±1)`.
183    ///
184    /// For each sample, finite-difference JVP: `(∂f/∂z)ε ≈ (f(z+δε) - f(z-δε)) / (2δ)`.
185    /// Then inner-product with `ε` gives an unbiased trace estimate.
186    pub fn trace_jac_approx(&self, z: &[f64], t: f64, n_samples: usize, rng: &mut StdRng) -> f64 {
187        let eps = 1e-4;
188        let mut trace_est = 0.0_f64;
189        let n = n_samples.max(1);
190
191        for _ in 0..n {
192            // Sample Rademacher vector ε ∈ {-1, +1}^d
193            let epsilon: Vec<f64> = (0..self.z_dim)
194                .map(|_| if rng.random::<f64>() < 0.5 { 1.0 } else { -1.0 })
195                .collect();
196
197            // z + eps * ε and z - eps * ε
198            let z_plus: Vec<f64> = z
199                .iter()
200                .zip(epsilon.iter())
201                .map(|(zi, ei)| zi + eps * ei)
202                .collect();
203            let z_minus: Vec<f64> = z
204                .iter()
205                .zip(epsilon.iter())
206                .map(|(zi, ei)| zi - eps * ei)
207                .collect();
208
209            let f_plus = self.forward(&z_plus, t);
210            let f_minus = self.forward(&z_minus, t);
211
212            // JVP: Jε ≈ (f(z + eps*ε) - f(z - eps*ε)) / (2*eps)
213            // Trace estimate: ε^T Jε
214            let sample_est: f64 = epsilon
215                .iter()
216                .zip(f_plus.iter())
217                .zip(f_minus.iter())
218                .map(|((ei, fp_i), fm_i)| ei * (fp_i - fm_i) / (2.0 * eps))
219                .sum();
220            trace_est += sample_est;
221        }
222        trace_est / n as f64
223    }
224}
225
226// ────────────────────────────────────────────────────────────────────────────
227// Continuous Normalizing Flow
228// ────────────────────────────────────────────────────────────────────────────
229
/// Continuous Normalizing Flow model.
///
/// Transforms a base distribution `p_0 = N(base_mean, diag(base_std^2))` into
/// a target distribution via the ODE `dz/dt = f(z, t; θ)`, with the base at
/// `t = 0` and the data at `t = 1`.
///
/// Log-likelihood follows the instantaneous change of variables
/// (`d log p(z(t))/dt = -tr(∂f/∂z)`, Neural ODE / FFJORD):
/// ```text
/// log p(x) = log p_0(z_0) - ∫_0^T tr(∂f/∂z) dt
/// ```
/// (The integral is *subtracted*: the forward flow expands log-volume by
/// exactly `∫ tr(∂f/∂z) dt`.)
pub struct ContinuousNormalizingFlow {
    /// ODE dynamics network.
    pub dynamics: CnfDynamics,
    /// Mean of the base Gaussian distribution.
    pub base_mean: Vec<f64>,
    /// Standard deviation of the base Gaussian distribution (diagonal).
    pub base_std: Vec<f64>,
}
247
248impl ContinuousNormalizingFlow {
249    /// Create a CNF with standard-normal base distribution.
250    pub fn new(z_dim: usize, hidden_dim: usize, n_layers: usize) -> Self {
251        ContinuousNormalizingFlow {
252            dynamics: CnfDynamics::new(z_dim, hidden_dim, n_layers, true),
253            base_mean: vec![0.0; z_dim],
254            base_std: vec![1.0; z_dim],
255        }
256    }
257
258    /// Euler-integrate the augmented ODE `(dz/dt, d_logdet/dt)` from `t_start` to `t_end`.
259    ///
260    /// Returns `(z_T, log_det_jacobian)` where `log_det_jacobian = ∫ tr(∂f/∂z) dt`.
261    pub fn integrate_forward(
262        &self,
263        z0: &[f64],
264        n_steps: usize,
265        t_start: f64,
266        t_end: f64,
267    ) -> (Vec<f64>, f64) {
268        let n = n_steps.max(1);
269        let dt = (t_end - t_start) / n as f64;
270        let mut z = z0.to_vec();
271        let mut log_det = 0.0_f64;
272        let mut rng = StdRng::seed_from_u64(0xdeadbeef_u64);
273
274        for step in 0..n {
275            let t = t_start + step as f64 * dt;
276            let dz = self.dynamics.forward(&z, t);
277            let tr = self.dynamics.trace_jac_approx(&z, t, 1, &mut rng);
278            // Euler step
279            for (zi, dzi) in z.iter_mut().zip(dz.iter()) {
280                *zi += dt * dzi;
281            }
282            log_det += dt * tr;
283        }
284        (z, log_det)
285    }
286
287    /// Euler-integrate backwards from `T=1` to `T=0` (inverse direction).
288    ///
289    /// Returns `(z_0, log_det)` where `log_det = -∫_T^0 tr(∂f/∂z) dt`.
290    pub fn integrate_backward(&self, x: &[f64], n_steps: usize) -> (Vec<f64>, f64) {
291        let n = n_steps.max(1);
292        let dt = 1.0 / n as f64;
293        let mut z = x.to_vec();
294        let mut log_det = 0.0_f64;
295        let mut rng = StdRng::seed_from_u64(0xcafe1234_u64);
296
297        // Integrate from t=1 down to t=0 in steps of -dt
298        for step in 0..n {
299            let t = 1.0 - step as f64 * dt;
300            let dz = self.dynamics.forward(&z, t);
301            let tr = self.dynamics.trace_jac_approx(&z, t, 1, &mut rng);
302            // Euler step backwards: dz/dt in reverse time is -f
303            for (zi, dzi) in z.iter_mut().zip(dz.iter()) {
304                *zi -= dt * dzi;
305            }
306            log_det += dt * tr;
307        }
308        (z, log_det)
309    }
310
311    /// Compute log p(x) = log p_0(z_0) + log_det via backward integration.
312    pub fn log_prob(&self, x: &[f64], n_steps: usize) -> f64 {
313        let (z0, log_det) = self.integrate_backward(x, n_steps);
314        let log_p0 = self.log_base_prob(&z0);
315        log_p0 + log_det
316    }
317
318    /// Log-probability under the base Gaussian `N(base_mean, diag(base_std^2))`.
319    pub(crate) fn log_base_prob(&self, z: &[f64]) -> f64 {
320        let d = z.len().min(self.base_mean.len()).min(self.base_std.len());
321        let mut lp = 0.0_f64;
322        for i in 0..d {
323            let sigma = self.base_std[i].max(1e-15);
324            let diff = z[i] - self.base_mean[i];
325            lp -= 0.5 * (diff * diff / (sigma * sigma) + (2.0 * PI * sigma * sigma).ln());
326        }
327        lp
328    }
329
330    /// Sample a point from the model by sampling from the base distribution and integrating forward.
331    pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
332        let d = self.dynamics.z_dim;
333        // Sample z_0 ~ N(base_mean, diag(base_std^2)) via Box-Muller
334        let z0: Vec<f64> = (0..d)
335            .map(|i| {
336                let u1: f64 = rng.random::<f64>().max(1e-15);
337                let u2: f64 = rng.random::<f64>();
338                let g = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
339                self.base_mean[i] + self.base_std[i] * g
340            })
341            .collect();
342        let (x, _log_det) = self.integrate_forward(&z0, n_steps, 0.0, 1.0);
343        x
344    }
345
346    /// One training step: compute mean negative-log-likelihood over `x_batch`, update via FD gradient.
347    ///
348    /// Returns the mean NLL loss.
349    pub fn train_step(&mut self, x_batch: &[Vec<f64>], n_steps: usize, lr: f64) -> f64 {
350        if x_batch.is_empty() {
351            return 0.0;
352        }
353        let batch_size = x_batch.len();
354
355        // Compute baseline loss
356        let base_loss: f64 = x_batch
357            .iter()
358            .map(|x| -self.log_prob(x, n_steps))
359            .sum::<f64>()
360            / batch_size as f64;
361
362        // Finite-difference gradient w.r.t. each MLP parameter and update
363        let fd_eps = 1e-4;
364        let n_layers = self.dynamics.mlp.n_layers();
365
366        let mut rng = StdRng::seed_from_u64(0x98765432_u64);
367        let mut grad_w: Vec<Vec<Vec<f64>>> = self
368            .dynamics
369            .mlp
370            .weights
371            .iter()
372            .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
373            .collect();
374        let mut grad_b: Vec<Vec<f64>> = self
375            .dynamics
376            .mlp
377            .biases
378            .iter()
379            .map(|lb| vec![0.0; lb.len()])
380            .collect();
381
382        // Stochastic FD: random subset of parameters
383        for l in 0..n_layers {
384            for j in 0..self.dynamics.mlp.weights[l].len() {
385                for i in 0..self.dynamics.mlp.weights[l][j].len() {
386                    if rng.random::<f64>() < 0.05 {
387                        // Perturb weight
388                        self.dynamics.mlp.weights[l][j][i] += fd_eps;
389                        let perturbed_loss: f64 = x_batch
390                            .iter()
391                            .map(|x| -self.log_prob(x, n_steps))
392                            .sum::<f64>()
393                            / batch_size as f64;
394                        self.dynamics.mlp.weights[l][j][i] -= fd_eps;
395                        grad_w[l][j][i] = (perturbed_loss - base_loss) / fd_eps;
396                    }
397                }
398            }
399            for j in 0..self.dynamics.mlp.biases[l].len() {
400                if rng.random::<f64>() < 0.05 {
401                    self.dynamics.mlp.biases[l][j] += fd_eps;
402                    let perturbed_loss: f64 = x_batch
403                        .iter()
404                        .map(|x| -self.log_prob(x, n_steps))
405                        .sum::<f64>()
406                        / batch_size as f64;
407                    self.dynamics.mlp.biases[l][j] -= fd_eps;
408                    grad_b[l][j] = (perturbed_loss - base_loss) / fd_eps;
409                }
410            }
411        }
412
413        self.dynamics.mlp.update(&grad_w, &grad_b, lr);
414        base_loss
415    }
416}