Skip to main content

tenflowers_neural/continuous_normalizing_flows/
flow_matching.rs

//! Flow Matching, OT-CFM, and Rectified Flow
//!
//! Implements:
//! - **CFM** (Lipman et al. 2022): Conditional Flow Matching
//! - **OT-CFM** (Tong et al. 2023): Optimal Transport CFM
//! - **Rectified Flow** (Liu et al. 2022): straight-line interpolation paths
7
8use super::mlp::CnfMlp;
9use super::utils::sample_standard_normal;
10use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
11use scirs2_core::RngExt;
12
// ────────────────────────────────────────────────────────────────────────────
// Flow Matching Configuration
// ────────────────────────────────────────────────────────────────────────────
16
/// Configuration for Flow Matching / CFM models.
///
/// Shared by `FlowMatchingModel` and `OtCfmModel`. The velocity network built
/// from it has layer widths `[z_dim + 1, hidden_dim × n_layers, z_dim]`.
#[derive(Debug, Clone, PartialEq)]
pub struct FlowMatchingConfig {
    /// Dimensionality of the data/latent space.
    pub z_dim: usize,
    /// Hidden layer width.
    pub hidden_dim: usize,
    /// Number of hidden layers in the velocity network.
    pub n_layers: usize,
    /// Small constant `σ_min` (default `1e-4`) used in the conditional path
    /// `x_t = (1 - (1 - σ_min)·t)·x_0 + t·x_1`.
    pub sigma_min: f64,
    /// Number of Euler steps during inference.
    ///
    /// The `sample` methods in this file take the step count as an explicit
    /// argument; this field is the value callers are expected to pass.
    pub n_steps: usize,
    /// Learning rate.
    ///
    /// The `train_step` methods in this file take `lr` as an explicit argument;
    /// this field is the value callers are expected to pass.
    pub lr: f64,
}

impl Default for FlowMatchingConfig {
    /// Two-dimensional problem, two hidden layers of width 64, `σ_min = 1e-4`,
    /// 100 Euler steps and a learning rate of `1e-3`.
    fn default() -> Self {
        Self {
            z_dim: 2,
            hidden_dim: 64,
            n_layers: 2,
            sigma_min: 1e-4,
            n_steps: 100,
            lr: 1e-3,
        }
    }
}
46
// ────────────────────────────────────────────────────────────────────────────
// Conditional Flow Matching (Lipman et al. 2022)
// ────────────────────────────────────────────────────────────────────────────
50
/// Conditional Flow Matching model (Lipman et al. 2022).
///
/// Learns a velocity field `v_θ(x, t)` such that Euler integration from `t=0` to `t=1`
/// transforms noise `p_0 = N(0,I)` into data `p_1`.
///
/// Loss: `E[||v_θ(x_t, t) - u_t(x_t|x_0,x_1)||^2]`
/// where `x_t = (1-(1-σ)t)·x_0 + t·x_1` and `u_t = x_1 - (1-σ)·x_0`,
/// with `σ` taken from `config.sigma_min`.
pub struct FlowMatchingModel {
    /// Velocity network: input `[z; t]` (dim `z_dim + 1`) → output `z_dim`.
    pub velocity_net: CnfMlp,
    /// Configuration the network was built from; `z_dim` and `sigma_min` are
    /// read again at training and sampling time.
    pub config: FlowMatchingConfig,
}
64
65impl FlowMatchingModel {
66    /// Create a new FlowMatchingModel.
67    pub fn new(config: FlowMatchingConfig) -> Self {
68        let in_dim = config.z_dim + 1; // [z; t]
69        let mut sizes = vec![in_dim];
70        for _ in 0..config.n_layers {
71            sizes.push(config.hidden_dim);
72        }
73        sizes.push(config.z_dim);
74        FlowMatchingModel {
75            velocity_net: CnfMlp::new(&sizes),
76            config,
77        }
78    }
79
80    /// Evaluate the velocity field `v_θ(z, t)`.
81    pub fn velocity(&self, z: &[f64], t: f64) -> Vec<f64> {
82        let mut inp = z.to_vec();
83        inp.push(t);
84        self.velocity_net.forward(&inp)
85    }
86
87    /// Conditional Flow Matching loss.
88    ///
89    /// For each pair `(x0, x1)`:
90    /// - Sample `t ~ U[0,1]`
91    /// - Compute `x_t = (1 - (1 - σ_min) * t) * x0 + t * x1`
92    /// - Conditional vector field: `u_t = x1 - (1 - σ_min) * x0`
93    /// - Loss term: `||v_θ(x_t, t) - u_t||^2`
94    pub fn cfm_loss(&self, x0_batch: &[Vec<f64>], x1_batch: &[Vec<f64>], rng: &mut StdRng) -> f64 {
95        let n = x0_batch.len().min(x1_batch.len());
96        if n == 0 {
97            return 0.0;
98        }
99        let sigma_min = self.config.sigma_min;
100        let mut total_loss = 0.0_f64;
101
102        for i in 0..n {
103            let t: f64 = rng.random();
104            let x0 = &x0_batch[i];
105            let x1 = &x1_batch[i];
106            let d = x0.len().min(x1.len()).min(self.config.z_dim);
107
108            // x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1
109            let xt: Vec<f64> = (0..d)
110                .map(|j| (1.0 - (1.0 - sigma_min) * t) * x0[j] + t * x1[j])
111                .collect();
112
113            // Conditional vector field u_t = x1 - (1 - sigma_min) * x0
114            let ut: Vec<f64> = (0..d).map(|j| x1[j] - (1.0 - sigma_min) * x0[j]).collect();
115
116            // Velocity prediction
117            let vt = self.velocity(&xt, t);
118
119            // MSE loss
120            let loss: f64 = vt
121                .iter()
122                .zip(ut.iter())
123                .map(|(v, u)| (v - u) * (v - u))
124                .sum::<f64>();
125            total_loss += loss / d.max(1) as f64;
126        }
127        total_loss / n as f64
128    }
129
130    /// One training step: compute CFM loss and update via finite-differences.
131    ///
132    /// Returns the CFM loss.
133    pub fn train_step(&mut self, x1_batch: &[Vec<f64>], lr: f64, rng: &mut StdRng) -> f64 {
134        if x1_batch.is_empty() {
135            return 0.0;
136        }
137        let d = self.config.z_dim;
138        let n = x1_batch.len();
139
140        // Sample noise batch x0 ~ N(0, I)
141        let x0_batch: Vec<Vec<f64>> = (0..n).map(|_| sample_standard_normal(d, rng)).collect();
142
143        let base_loss = self.cfm_loss(&x0_batch, x1_batch, rng);
144
145        // FD gradient update
146        let fd_eps = 1e-4;
147        let mut update_rng = StdRng::seed_from_u64(0x246810ac_u64);
148        let n_layers = self.velocity_net.n_layers();
149        let mut grad_w: Vec<Vec<Vec<f64>>> = self
150            .velocity_net
151            .weights
152            .iter()
153            .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
154            .collect();
155        let mut grad_b: Vec<Vec<f64>> = self
156            .velocity_net
157            .biases
158            .iter()
159            .map(|lb| vec![0.0; lb.len()])
160            .collect();
161
162        for l in 0..n_layers {
163            for j in 0..self.velocity_net.weights[l].len() {
164                for i in 0..self.velocity_net.weights[l][j].len() {
165                    if update_rng.random::<f64>() < 0.05 {
166                        self.velocity_net.weights[l][j][i] += fd_eps;
167                        let perturbed = self.cfm_loss(&x0_batch, x1_batch, rng);
168                        self.velocity_net.weights[l][j][i] -= fd_eps;
169                        grad_w[l][j][i] = (perturbed - base_loss) / fd_eps;
170                    }
171                }
172            }
173            for j in 0..self.velocity_net.biases[l].len() {
174                if update_rng.random::<f64>() < 0.05 {
175                    self.velocity_net.biases[l][j] += fd_eps;
176                    let perturbed = self.cfm_loss(&x0_batch, x1_batch, rng);
177                    self.velocity_net.biases[l][j] -= fd_eps;
178                    grad_b[l][j] = (perturbed - base_loss) / fd_eps;
179                }
180            }
181        }
182        self.velocity_net.update(&grad_w, &grad_b, lr);
183        base_loss
184    }
185
186    /// Sample a single data point by Euler-integrating the velocity field from `t=0` to `t=1`.
187    pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
188        let d = self.config.z_dim;
189        let mut x = sample_standard_normal(d, rng);
190        let n = n_steps.max(1);
191        let dt = 1.0 / n as f64;
192
193        for step in 0..n {
194            let t = step as f64 * dt;
195            let v = self.velocity(&x, t);
196            for (xi, vi) in x.iter_mut().zip(v.iter()) {
197                *xi += dt * vi;
198            }
199        }
200        x
201    }
202
203    /// Sample a batch of data points.
204    pub fn sample_batch(
205        &self,
206        n_samples: usize,
207        n_steps: usize,
208        rng: &mut StdRng,
209    ) -> Vec<Vec<f64>> {
210        (0..n_samples).map(|_| self.sample(n_steps, rng)).collect()
211    }
212}
213
// ────────────────────────────────────────────────────────────────────────────
// OT-CFM — Optimal Transport Conditional Flow Matching (Tong et al. 2023)
// ────────────────────────────────────────────────────────────────────────────
217
/// Optimal-transport conditional flow matching model.
///
/// Uses mini-batch greedy OT to straighten probability paths, reducing the
/// curvature of learned trajectories compared to standard CFM.
///
/// During training the freshly sampled noise batch is re-paired with the data
/// batch by greedy nearest-neighbour matching before the CFM loss is evaluated.
pub struct OtCfmModel {
    /// Velocity network: input `[z; t]` → output `z_dim`.
    pub velocity_net: CnfMlp,
    /// Configuration (shared with FlowMatchingModel).
    pub config: FlowMatchingConfig,
}
228
229impl OtCfmModel {
230    /// Create a new OT-CFM model.
231    pub fn new(config: FlowMatchingConfig) -> Self {
232        let in_dim = config.z_dim + 1;
233        let mut sizes = vec![in_dim];
234        for _ in 0..config.n_layers {
235            sizes.push(config.hidden_dim);
236        }
237        sizes.push(config.z_dim);
238        OtCfmModel {
239            velocity_net: CnfMlp::new(&sizes),
240            config,
241        }
242    }
243
244    /// Greedy OT matching: for each data point `x1[j]`, find the nearest noise point `x0[i]`.
245    ///
246    /// Returns a permutation `perm` of `x0` indices such that `x0[perm[j]]` is paired
247    /// with `x1[j]`.  Uses a greedy nearest-neighbour algorithm in O(n^2).
248    pub fn ot_match(x0_batch: &[Vec<f64>], x1_batch: &[Vec<f64>]) -> Vec<usize> {
249        let n0 = x0_batch.len();
250        let n1 = x1_batch.len();
251        let n = n0.min(n1);
252        let mut perm = vec![0usize; n];
253        let mut used = vec![false; n0];
254
255        for j in 0..n {
256            let x1 = &x1_batch[j];
257            let mut best_idx = 0usize;
258            let mut best_dist = f64::INFINITY;
259            for i in 0..n0 {
260                if used[i] {
261                    continue;
262                }
263                let dist: f64 = x0_batch[i]
264                    .iter()
265                    .zip(x1.iter())
266                    .map(|(a, b)| (a - b) * (a - b))
267                    .sum();
268                if dist < best_dist {
269                    best_dist = dist;
270                    best_idx = i;
271                }
272            }
273            perm[j] = best_idx;
274            used[best_idx] = true;
275        }
276        perm
277    }
278
279    /// Evaluate the velocity field.
280    pub fn velocity(&self, z: &[f64], t: f64) -> Vec<f64> {
281        let mut inp = z.to_vec();
282        inp.push(t);
283        self.velocity_net.forward(&inp)
284    }
285
286    /// One training step with OT matching.
287    ///
288    /// Returns the CFM loss after matching.
289    pub fn train_step(&mut self, x1_batch: &[Vec<f64>], lr: f64, rng: &mut StdRng) -> f64 {
290        if x1_batch.is_empty() {
291            return 0.0;
292        }
293        let d = self.config.z_dim;
294        let n = x1_batch.len();
295
296        // Sample noise batch x0 ~ N(0, I)
297        let x0_raw: Vec<Vec<f64>> = (0..n).map(|_| sample_standard_normal(d, rng)).collect();
298
299        // OT matching: reorder x0 to be close to x1
300        let perm = Self::ot_match(&x0_raw, x1_batch);
301        let x0_matched: Vec<Vec<f64>> = perm.iter().map(|&idx| x0_raw[idx].clone()).collect();
302
303        let sigma_min = self.config.sigma_min;
304
305        // Compute CFM loss with OT-matched pairs
306        let compute_loss = |vel_net: &CnfMlp, rng_inner: &mut StdRng| -> f64 {
307            let mut total = 0.0_f64;
308            for i in 0..n.min(x0_matched.len()) {
309                let t: f64 = rng_inner.random();
310                let x0 = &x0_matched[i];
311                let x1 = &x1_batch[i];
312                let dim = x0.len().min(x1.len()).min(d);
313                let xt: Vec<f64> = (0..dim)
314                    .map(|j| (1.0 - (1.0 - sigma_min) * t) * x0[j] + t * x1[j])
315                    .collect();
316                let ut: Vec<f64> = (0..dim)
317                    .map(|j| x1[j] - (1.0 - sigma_min) * x0[j])
318                    .collect();
319                let mut inp = xt.clone();
320                inp.push(t);
321                let vt = vel_net.forward(&inp);
322                let loss: f64 = vt
323                    .iter()
324                    .zip(ut.iter())
325                    .map(|(v, u)| (v - u) * (v - u))
326                    .sum::<f64>();
327                total += loss / dim.max(1) as f64;
328            }
329            total / n as f64
330        };
331
332        let mut eval_rng = StdRng::seed_from_u64(0xf0e1d2c3_u64);
333        let base_loss = compute_loss(&self.velocity_net, &mut eval_rng);
334
335        // FD gradient update
336        let fd_eps = 1e-4;
337        let mut update_rng = StdRng::seed_from_u64(0xa1b2c3d4_u64);
338        let n_layers = self.velocity_net.n_layers();
339        let mut grad_w: Vec<Vec<Vec<f64>>> = self
340            .velocity_net
341            .weights
342            .iter()
343            .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
344            .collect();
345        let mut grad_b: Vec<Vec<f64>> = self
346            .velocity_net
347            .biases
348            .iter()
349            .map(|lb| vec![0.0; lb.len()])
350            .collect();
351
352        for l in 0..n_layers {
353            for j in 0..self.velocity_net.weights[l].len() {
354                for i in 0..self.velocity_net.weights[l][j].len() {
355                    if update_rng.random::<f64>() < 0.04 {
356                        self.velocity_net.weights[l][j][i] += fd_eps;
357                        let mut r = StdRng::seed_from_u64(0xf0e1d2c3_u64);
358                        let perturbed = compute_loss(&self.velocity_net, &mut r);
359                        self.velocity_net.weights[l][j][i] -= fd_eps;
360                        grad_w[l][j][i] = (perturbed - base_loss) / fd_eps;
361                    }
362                }
363            }
364            for j in 0..self.velocity_net.biases[l].len() {
365                if update_rng.random::<f64>() < 0.04 {
366                    self.velocity_net.biases[l][j] += fd_eps;
367                    let mut r = StdRng::seed_from_u64(0xf0e1d2c3_u64);
368                    let perturbed = compute_loss(&self.velocity_net, &mut r);
369                    self.velocity_net.biases[l][j] -= fd_eps;
370                    grad_b[l][j] = (perturbed - base_loss) / fd_eps;
371                }
372            }
373        }
374        self.velocity_net.update(&grad_w, &grad_b, lr);
375        base_loss
376    }
377
378    /// Sample a single point from the model via Euler integration.
379    pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
380        let d = self.config.z_dim;
381        let mut x = sample_standard_normal(d, rng);
382        let n = n_steps.max(1);
383        let dt = 1.0 / n as f64;
384
385        for step in 0..n {
386            let t = step as f64 * dt;
387            let v = self.velocity(&x, t);
388            for (xi, vi) in x.iter_mut().zip(v.iter()) {
389                *xi += dt * vi;
390            }
391        }
392        x
393    }
394}
395
// ────────────────────────────────────────────────────────────────────────────
// Rectified Flow (Liu et al. 2022)
// ────────────────────────────────────────────────────────────────────────────
399
/// Configuration for Rectified Flow.
///
/// The velocity network built from it has layer widths
/// `[z_dim + 1, hidden_dim × n_layers, z_dim]`.
#[derive(Debug, Clone, PartialEq)]
pub struct RectifiedFlowConfig {
    /// Dimensionality of the data/latent space.
    pub z_dim: usize,
    /// Hidden layer width.
    pub hidden_dim: usize,
    /// Number of hidden layers.
    pub n_layers: usize,
    /// Number of Euler integration steps during inference.
    ///
    /// `RectifiedFlow::sample` takes the step count as an explicit argument;
    /// this field is the value callers are expected to pass.
    pub n_steps: usize,
    /// Learning rate.
    ///
    /// `RectifiedFlow::train_step` takes `lr` as an explicit argument; this
    /// field is the value callers are expected to pass.
    pub lr: f64,
}

impl Default for RectifiedFlowConfig {
    /// Two-dimensional problem, two hidden layers of width 64, 100 Euler steps
    /// and a learning rate of `1e-3`.
    fn default() -> Self {
        Self {
            z_dim: 2,
            hidden_dim: 64,
            n_layers: 2,
            n_steps: 100,
            lr: 1e-3,
        }
    }
}
426
/// Rectified Flow: learns a vector field `v_θ(x_t, t) ≈ x1 - x0` along straight paths.
///
/// At training time: `x_t = x0 + t * (x1 - x0)`, target = `x1 - x0`.
/// At sampling time: Euler-integrate `dX/dt = v_θ(X, t)` from `t=0` to `t=1`.
pub struct RectifiedFlow {
    /// Velocity network: input `[z; t]` → output `z_dim`.
    pub velocity_net: CnfMlp,
    /// Configuration the network was built from; `z_dim` is read again at
    /// training and sampling time.
    pub config: RectifiedFlowConfig,
}
437
438impl RectifiedFlow {
439    /// Create a new Rectified Flow model.
440    pub fn new(config: RectifiedFlowConfig) -> Self {
441        let in_dim = config.z_dim + 1;
442        let mut sizes = vec![in_dim];
443        for _ in 0..config.n_layers {
444            sizes.push(config.hidden_dim);
445        }
446        sizes.push(config.z_dim);
447        RectifiedFlow {
448            velocity_net: CnfMlp::new(&sizes),
449            config,
450        }
451    }
452
453    /// Reflow loss: `E[||v_θ(x_t, t) - (x1 - x0)||^2]`.
454    ///
455    /// `x_t = x0 + t * (x1 - x0)`, `t ~ U[0,1]`.
456    pub fn reflow_loss(
457        &self,
458        x0_batch: &[Vec<f64>],
459        x1_batch: &[Vec<f64>],
460        rng: &mut StdRng,
461    ) -> f64 {
462        let n = x0_batch.len().min(x1_batch.len());
463        if n == 0 {
464            return 0.0;
465        }
466        let d = self.config.z_dim;
467        let mut total = 0.0_f64;
468
469        for i in 0..n {
470            let t: f64 = rng.random();
471            let x0 = &x0_batch[i];
472            let x1 = &x1_batch[i];
473            let dim = x0.len().min(x1.len()).min(d);
474
475            // x_t = x0 + t * (x1 - x0)
476            let xt: Vec<f64> = (0..dim).map(|j| x0[j] + t * (x1[j] - x0[j])).collect();
477
478            // Target: straight-line velocity = x1 - x0
479            let target: Vec<f64> = (0..dim).map(|j| x1[j] - x0[j]).collect();
480
481            let mut inp = xt;
482            inp.push(t);
483            let pred = self.velocity_net.forward(&inp);
484
485            let loss: f64 = pred
486                .iter()
487                .zip(target.iter())
488                .map(|(p, tg)| (p - tg) * (p - tg))
489                .sum::<f64>();
490            total += loss / dim.max(1) as f64;
491        }
492        total / n as f64
493    }
494
495    /// One training step: compute reflow loss and update via finite-differences.
496    ///
497    /// Returns the reflow loss.
498    pub fn train_step(&mut self, x1_batch: &[Vec<f64>], lr: f64, rng: &mut StdRng) -> f64 {
499        if x1_batch.is_empty() {
500            return 0.0;
501        }
502        let d = self.config.z_dim;
503        let n = x1_batch.len();
504
505        // Sample noise x0 ~ N(0, I)
506        let x0_batch: Vec<Vec<f64>> = (0..n).map(|_| sample_standard_normal(d, rng)).collect();
507
508        let mut eval_rng = StdRng::seed_from_u64(0x55aa77bb_u64);
509        let base_loss = self.reflow_loss(&x0_batch, x1_batch, &mut eval_rng);
510
511        // FD gradient update
512        let fd_eps = 1e-4;
513        let mut update_rng = StdRng::seed_from_u64(0xcc11ee22_u64);
514        let n_layers = self.velocity_net.n_layers();
515        let mut grad_w: Vec<Vec<Vec<f64>>> = self
516            .velocity_net
517            .weights
518            .iter()
519            .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
520            .collect();
521        let mut grad_b: Vec<Vec<f64>> = self
522            .velocity_net
523            .biases
524            .iter()
525            .map(|lb| vec![0.0; lb.len()])
526            .collect();
527
528        for l in 0..n_layers {
529            for j in 0..self.velocity_net.weights[l].len() {
530                for i in 0..self.velocity_net.weights[l][j].len() {
531                    if update_rng.random::<f64>() < 0.05 {
532                        self.velocity_net.weights[l][j][i] += fd_eps;
533                        let mut r = StdRng::seed_from_u64(0x55aa77bb_u64);
534                        let perturbed = self.reflow_loss(&x0_batch, x1_batch, &mut r);
535                        self.velocity_net.weights[l][j][i] -= fd_eps;
536                        grad_w[l][j][i] = (perturbed - base_loss) / fd_eps;
537                    }
538                }
539            }
540            for j in 0..self.velocity_net.biases[l].len() {
541                if update_rng.random::<f64>() < 0.05 {
542                    self.velocity_net.biases[l][j] += fd_eps;
543                    let mut r = StdRng::seed_from_u64(0x55aa77bb_u64);
544                    let perturbed = self.reflow_loss(&x0_batch, x1_batch, &mut r);
545                    self.velocity_net.biases[l][j] -= fd_eps;
546                    grad_b[l][j] = (perturbed - base_loss) / fd_eps;
547                }
548            }
549        }
550        self.velocity_net.update(&grad_w, &grad_b, lr);
551        base_loss
552    }
553
554    /// Sample a data point by Euler-integrating `v_θ` from `t=0` to `t=1`.
555    pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
556        let d = self.config.z_dim;
557        let mut x = sample_standard_normal(d, rng);
558        let n = n_steps.max(1);
559        let dt = 1.0 / n as f64;
560
561        for step in 0..n {
562            let t = step as f64 * dt;
563            let mut inp = x.clone();
564            inp.push(t);
565            let v = self.velocity_net.forward(&inp);
566            for (xi, vi) in x.iter_mut().zip(v.iter()) {
567                *xi += dt * vi;
568            }
569        }
570        x
571    }
572}