//! tenflowers_neural/continuous_normalizing_flows/flow_matching.rs
//!
//! Flow-matching generative models: conditional flow matching (CFM),
//! optimal-transport CFM, and rectified flow, trained with sparse
//! finite-difference gradient estimates.

use super::mlp::CnfMlp;
9use super::utils::sample_standard_normal;
10use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
11use scirs2_core::RngExt;
12
/// Hyper-parameters shared by the flow-matching models in this module.
#[derive(Clone)]
pub struct FlowMatchingConfig {
    // Dimensionality of the data / latent space z.
    pub z_dim: usize,
    // Width of each hidden layer in the velocity MLP.
    pub hidden_dim: usize,
    // Number of hidden layers in the velocity MLP.
    pub n_layers: usize,
    // Minimum noise scale sigma_min of the CFM probability path
    // (appears in the interpolant (1 - (1 - sigma_min) t) x0 + t x1).
    pub sigma_min: f64,
    // Number of Euler integration steps for sampling.
    // NOTE(review): not read by the code visible in this file — callers
    // pass n_steps explicitly to `sample`; confirm intended use.
    pub n_steps: usize,
    // Learning rate.
    // NOTE(review): not read by the code visible in this file — callers
    // pass lr explicitly to `train_step`; confirm intended use.
    pub lr: f64,
}
33
34impl Default for FlowMatchingConfig {
35 fn default() -> Self {
36 FlowMatchingConfig {
37 z_dim: 2,
38 hidden_dim: 64,
39 n_layers: 2,
40 sigma_min: 1e-4,
41 n_steps: 100,
42 lr: 1e-3,
43 }
44 }
45}
46
/// Flow-matching generative model: an MLP velocity field trained with the
/// conditional flow-matching (CFM) objective.
pub struct FlowMatchingModel {
    // MLP predicting the velocity v(z, t); its input is z with the scalar
    // time t appended as the last coordinate.
    pub velocity_net: CnfMlp,
    // Hyper-parameters used to build and train the model.
    pub config: FlowMatchingConfig,
}
64
65impl FlowMatchingModel {
66 pub fn new(config: FlowMatchingConfig) -> Self {
68 let in_dim = config.z_dim + 1; let mut sizes = vec![in_dim];
70 for _ in 0..config.n_layers {
71 sizes.push(config.hidden_dim);
72 }
73 sizes.push(config.z_dim);
74 FlowMatchingModel {
75 velocity_net: CnfMlp::new(&sizes),
76 config,
77 }
78 }
79
80 pub fn velocity(&self, z: &[f64], t: f64) -> Vec<f64> {
82 let mut inp = z.to_vec();
83 inp.push(t);
84 self.velocity_net.forward(&inp)
85 }
86
87 pub fn cfm_loss(&self, x0_batch: &[Vec<f64>], x1_batch: &[Vec<f64>], rng: &mut StdRng) -> f64 {
95 let n = x0_batch.len().min(x1_batch.len());
96 if n == 0 {
97 return 0.0;
98 }
99 let sigma_min = self.config.sigma_min;
100 let mut total_loss = 0.0_f64;
101
102 for i in 0..n {
103 let t: f64 = rng.random();
104 let x0 = &x0_batch[i];
105 let x1 = &x1_batch[i];
106 let d = x0.len().min(x1.len()).min(self.config.z_dim);
107
108 let xt: Vec<f64> = (0..d)
110 .map(|j| (1.0 - (1.0 - sigma_min) * t) * x0[j] + t * x1[j])
111 .collect();
112
113 let ut: Vec<f64> = (0..d).map(|j| x1[j] - (1.0 - sigma_min) * x0[j]).collect();
115
116 let vt = self.velocity(&xt, t);
118
119 let loss: f64 = vt
121 .iter()
122 .zip(ut.iter())
123 .map(|(v, u)| (v - u) * (v - u))
124 .sum::<f64>();
125 total_loss += loss / d.max(1) as f64;
126 }
127 total_loss / n as f64
128 }
129
130 pub fn train_step(&mut self, x1_batch: &[Vec<f64>], lr: f64, rng: &mut StdRng) -> f64 {
134 if x1_batch.is_empty() {
135 return 0.0;
136 }
137 let d = self.config.z_dim;
138 let n = x1_batch.len();
139
140 let x0_batch: Vec<Vec<f64>> = (0..n).map(|_| sample_standard_normal(d, rng)).collect();
142
143 let base_loss = self.cfm_loss(&x0_batch, x1_batch, rng);
144
145 let fd_eps = 1e-4;
147 let mut update_rng = StdRng::seed_from_u64(0x246810ac_u64);
148 let n_layers = self.velocity_net.n_layers();
149 let mut grad_w: Vec<Vec<Vec<f64>>> = self
150 .velocity_net
151 .weights
152 .iter()
153 .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
154 .collect();
155 let mut grad_b: Vec<Vec<f64>> = self
156 .velocity_net
157 .biases
158 .iter()
159 .map(|lb| vec![0.0; lb.len()])
160 .collect();
161
162 for l in 0..n_layers {
163 for j in 0..self.velocity_net.weights[l].len() {
164 for i in 0..self.velocity_net.weights[l][j].len() {
165 if update_rng.random::<f64>() < 0.05 {
166 self.velocity_net.weights[l][j][i] += fd_eps;
167 let perturbed = self.cfm_loss(&x0_batch, x1_batch, rng);
168 self.velocity_net.weights[l][j][i] -= fd_eps;
169 grad_w[l][j][i] = (perturbed - base_loss) / fd_eps;
170 }
171 }
172 }
173 for j in 0..self.velocity_net.biases[l].len() {
174 if update_rng.random::<f64>() < 0.05 {
175 self.velocity_net.biases[l][j] += fd_eps;
176 let perturbed = self.cfm_loss(&x0_batch, x1_batch, rng);
177 self.velocity_net.biases[l][j] -= fd_eps;
178 grad_b[l][j] = (perturbed - base_loss) / fd_eps;
179 }
180 }
181 }
182 self.velocity_net.update(&grad_w, &grad_b, lr);
183 base_loss
184 }
185
186 pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
188 let d = self.config.z_dim;
189 let mut x = sample_standard_normal(d, rng);
190 let n = n_steps.max(1);
191 let dt = 1.0 / n as f64;
192
193 for step in 0..n {
194 let t = step as f64 * dt;
195 let v = self.velocity(&x, t);
196 for (xi, vi) in x.iter_mut().zip(v.iter()) {
197 *xi += dt * vi;
198 }
199 }
200 x
201 }
202
203 pub fn sample_batch(
205 &self,
206 n_samples: usize,
207 n_steps: usize,
208 rng: &mut StdRng,
209 ) -> Vec<Vec<f64>> {
210 (0..n_samples).map(|_| self.sample(n_steps, rng)).collect()
211 }
212}
213
/// Optimal-transport conditional flow matching (OT-CFM): like
/// [`FlowMatchingModel`], but source samples are greedily matched to data
/// samples by squared distance before the CFM loss is computed.
pub struct OtCfmModel {
    // MLP predicting the velocity v(z, t); input is (z, t) concatenated.
    pub velocity_net: CnfMlp,
    // Hyper-parameters used to build and train the model.
    pub config: FlowMatchingConfig,
}
228
impl OtCfmModel {
    /// Build an OT-CFM model whose velocity MLP maps (z, t) in
    /// R^{z_dim + 1} to a velocity in R^{z_dim}.
    pub fn new(config: FlowMatchingConfig) -> Self {
        let in_dim = config.z_dim + 1;
        let mut sizes = vec![in_dim];
        for _ in 0..config.n_layers {
            sizes.push(config.hidden_dim);
        }
        sizes.push(config.z_dim);
        OtCfmModel {
            velocity_net: CnfMlp::new(&sizes),
            config,
        }
    }

    /// Greedily match each x1 to its nearest unused x0 by squared Euclidean
    /// distance. Returns `perm` with `perm[j]` = index into `x0_batch`
    /// matched to `x1_batch[j]`, of length min(|x0|, |x1|).
    ///
    /// NOTE(review): this is a greedy O(n^2) approximation of the optimal
    /// assignment, not an exact OT solve; the result depends on the order
    /// of `x1_batch`.
    pub fn ot_match(x0_batch: &[Vec<f64>], x1_batch: &[Vec<f64>]) -> Vec<usize> {
        let n0 = x0_batch.len();
        let n1 = x1_batch.len();
        let n = n0.min(n1);
        let mut perm = vec![0usize; n];
        let mut used = vec![false; n0];

        for j in 0..n {
            let x1 = &x1_batch[j];
            let mut best_idx = 0usize;
            let mut best_dist = f64::INFINITY;
            for i in 0..n0 {
                if used[i] {
                    continue;
                }
                let dist: f64 = x0_batch[i]
                    .iter()
                    .zip(x1.iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum();
                if dist < best_dist {
                    best_dist = dist;
                    best_idx = i;
                }
            }
            perm[j] = best_idx;
            used[best_idx] = true;
        }
        perm
    }

    /// Evaluate the learned velocity field v(z, t): the scalar time is
    /// appended to `z` as the last network input.
    pub fn velocity(&self, z: &[f64], t: f64) -> Vec<f64> {
        let mut inp = z.to_vec();
        inp.push(t);
        self.velocity_net.forward(&inp)
    }

    /// One training step using sparse forward-difference gradients.
    ///
    /// Samples a Gaussian source batch, OT-matches it to `x1_batch`,
    /// evaluates the CFM loss on matched pairs, estimates gradients for a
    /// random ~4% parameter subset by finite differences (base and
    /// perturbed losses share a fixed-seed eval RNG so both see the same
    /// random t draws), then applies an SGD update. Returns the base loss.
    pub fn train_step(&mut self, x1_batch: &[Vec<f64>], lr: f64, rng: &mut StdRng) -> f64 {
        if x1_batch.is_empty() {
            return 0.0;
        }
        let d = self.config.z_dim;
        let n = x1_batch.len();

        let x0_raw: Vec<Vec<f64>> = (0..n).map(|_| sample_standard_normal(d, rng)).collect();

        // Reorder the source samples so x0_matched[i] pairs with x1_batch[i].
        let perm = Self::ot_match(&x0_raw, x1_batch);
        let x0_matched: Vec<Vec<f64>> = perm.iter().map(|&idx| x0_raw[idx].clone()).collect();

        let sigma_min = self.config.sigma_min;

        // CFM loss on the matched pairs; same interpolant/target as
        // FlowMatchingModel::cfm_loss, but parameterized by the network so
        // it can be re-evaluated after a parameter perturbation.
        let compute_loss = |vel_net: &CnfMlp, rng_inner: &mut StdRng| -> f64 {
            let mut total = 0.0_f64;
            for i in 0..n.min(x0_matched.len()) {
                let t: f64 = rng_inner.random();
                let x0 = &x0_matched[i];
                let x1 = &x1_batch[i];
                let dim = x0.len().min(x1.len()).min(d);
                let xt: Vec<f64> = (0..dim)
                    .map(|j| (1.0 - (1.0 - sigma_min) * t) * x0[j] + t * x1[j])
                    .collect();
                let ut: Vec<f64> = (0..dim)
                    .map(|j| x1[j] - (1.0 - sigma_min) * x0[j])
                    .collect();
                let mut inp = xt.clone();
                inp.push(t);
                let vt = vel_net.forward(&inp);
                let loss: f64 = vt
                    .iter()
                    .zip(ut.iter())
                    .map(|(v, u)| (v - u) * (v - u))
                    .sum::<f64>();
                total += loss / dim.max(1) as f64;
            }
            total / n as f64
        };

        // Fixed seed: every loss evaluation (base and perturbed) replays
        // the identical sequence of t draws, so finite differences isolate
        // the parameter perturbation.
        let mut eval_rng = StdRng::seed_from_u64(0xf0e1d2c3_u64);
        let base_loss = compute_loss(&self.velocity_net, &mut eval_rng);

        let fd_eps = 1e-4;
        // Deterministic mask RNG selecting which parameters get a gradient.
        let mut update_rng = StdRng::seed_from_u64(0xa1b2c3d4_u64);
        let n_layers = self.velocity_net.n_layers();
        let mut grad_w: Vec<Vec<Vec<f64>>> = self
            .velocity_net
            .weights
            .iter()
            .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
            .collect();
        let mut grad_b: Vec<Vec<f64>> = self
            .velocity_net
            .biases
            .iter()
            .map(|lb| vec![0.0; lb.len()])
            .collect();

        for l in 0..n_layers {
            for j in 0..self.velocity_net.weights[l].len() {
                for i in 0..self.velocity_net.weights[l][j].len() {
                    // Only ~4% of weights get a gradient estimate per step.
                    if update_rng.random::<f64>() < 0.04 {
                        self.velocity_net.weights[l][j][i] += fd_eps;
                        // Same seed as eval_rng above: identical t draws.
                        let mut r = StdRng::seed_from_u64(0xf0e1d2c3_u64);
                        let perturbed = compute_loss(&self.velocity_net, &mut r);
                        self.velocity_net.weights[l][j][i] -= fd_eps;
                        grad_w[l][j][i] = (perturbed - base_loss) / fd_eps;
                    }
                }
            }
            for j in 0..self.velocity_net.biases[l].len() {
                if update_rng.random::<f64>() < 0.04 {
                    self.velocity_net.biases[l][j] += fd_eps;
                    let mut r = StdRng::seed_from_u64(0xf0e1d2c3_u64);
                    let perturbed = compute_loss(&self.velocity_net, &mut r);
                    self.velocity_net.biases[l][j] -= fd_eps;
                    grad_b[l][j] = (perturbed - base_loss) / fd_eps;
                }
            }
        }
        self.velocity_net.update(&grad_w, &grad_b, lr);
        base_loss
    }

    /// Draw one sample by integrating dz/dt = v(z, t) from t = 0 to 1 with
    /// explicit (left-endpoint) Euler, starting from a standard normal.
    pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
        let d = self.config.z_dim;
        let mut x = sample_standard_normal(d, rng);
        let n = n_steps.max(1);
        let dt = 1.0 / n as f64;

        for step in 0..n {
            let t = step as f64 * dt;
            let v = self.velocity(&x, t);
            for (xi, vi) in x.iter_mut().zip(v.iter()) {
                *xi += dt * vi;
            }
        }
        x
    }
}
395
/// Hyper-parameters for [`RectifiedFlow`] (no sigma_min: the rectified
/// path is the straight line x0 + t (x1 - x0)).
#[derive(Clone)]
pub struct RectifiedFlowConfig {
    // Dimensionality of the data / latent space z.
    pub z_dim: usize,
    // Width of each hidden layer in the velocity MLP.
    pub hidden_dim: usize,
    // Number of hidden layers in the velocity MLP.
    pub n_layers: usize,
    // Number of Euler integration steps for sampling.
    // NOTE(review): not read by the code visible in this file — callers
    // pass n_steps explicitly to `sample`; confirm intended use.
    pub n_steps: usize,
    // Learning rate.
    // NOTE(review): not read by the code visible in this file — callers
    // pass lr explicitly to `train_step`; confirm intended use.
    pub lr: f64,
}
414
415impl Default for RectifiedFlowConfig {
416 fn default() -> Self {
417 RectifiedFlowConfig {
418 z_dim: 2,
419 hidden_dim: 64,
420 n_layers: 2,
421 n_steps: 100,
422 lr: 1e-3,
423 }
424 }
425}
426
/// Rectified flow model: a velocity-field MLP trained to predict the
/// constant straight-line velocity x1 - x0 along linear interpolants.
pub struct RectifiedFlow {
    // MLP predicting the velocity v(z, t); input is (z, t) concatenated.
    pub velocity_net: CnfMlp,
    // Hyper-parameters used to build and train the model.
    pub config: RectifiedFlowConfig,
}
437
impl RectifiedFlow {
    /// Build a rectified-flow model whose velocity MLP maps (z, t) in
    /// R^{z_dim + 1} to a velocity in R^{z_dim}.
    pub fn new(config: RectifiedFlowConfig) -> Self {
        let in_dim = config.z_dim + 1;
        let mut sizes = vec![in_dim];
        for _ in 0..config.n_layers {
            sizes.push(config.hidden_dim);
        }
        sizes.push(config.z_dim);
        RectifiedFlow {
            velocity_net: CnfMlp::new(&sizes),
            config,
        }
    }

    /// Rectified-flow ("reflow") loss over paired (x0, x1) samples.
    ///
    /// For each pair a time t ~ U(0, 1) is drawn, the linear interpolant
    /// x_t = x0 + t (x1 - x0) is formed, and the squared error between the
    /// predicted velocity v(x_t, t) and the straight-line target x1 - x0
    /// is averaged per dimension and over the batch. Returns 0.0 for an
    /// empty batch.
    pub fn reflow_loss(
        &self,
        x0_batch: &[Vec<f64>],
        x1_batch: &[Vec<f64>],
        rng: &mut StdRng,
    ) -> f64 {
        let n = x0_batch.len().min(x1_batch.len());
        if n == 0 {
            return 0.0;
        }
        let d = self.config.z_dim;
        let mut total = 0.0_f64;

        for i in 0..n {
            let t: f64 = rng.random();
            let x0 = &x0_batch[i];
            let x1 = &x1_batch[i];
            // Guard against ragged inputs: operate on the common prefix.
            let dim = x0.len().min(x1.len()).min(d);

            let xt: Vec<f64> = (0..dim).map(|j| x0[j] + t * (x1[j] - x0[j])).collect();

            let target: Vec<f64> = (0..dim).map(|j| x1[j] - x0[j]).collect();

            let mut inp = xt;
            inp.push(t);
            let pred = self.velocity_net.forward(&inp);

            let loss: f64 = pred
                .iter()
                .zip(target.iter())
                .map(|(p, tg)| (p - tg) * (p - tg))
                .sum::<f64>();
            // dim.max(1) avoids division by zero on empty vectors.
            total += loss / dim.max(1) as f64;
        }
        total / n as f64
    }

    /// One training step using sparse forward-difference gradients.
    ///
    /// Samples a standard-normal source batch, evaluates the reflow loss,
    /// estimates gradients for a random ~5% parameter subset by finite
    /// differences (base and perturbed losses replay the same fixed-seed
    /// t draws), then applies an SGD update. Returns the base loss.
    pub fn train_step(&mut self, x1_batch: &[Vec<f64>], lr: f64, rng: &mut StdRng) -> f64 {
        if x1_batch.is_empty() {
            return 0.0;
        }
        let d = self.config.z_dim;
        let n = x1_batch.len();

        let x0_batch: Vec<Vec<f64>> = (0..n).map(|_| sample_standard_normal(d, rng)).collect();

        // Fixed seed: every loss evaluation (base and perturbed) sees the
        // identical sequence of t draws, so finite differences isolate the
        // parameter perturbation.
        let mut eval_rng = StdRng::seed_from_u64(0x55aa77bb_u64);
        let base_loss = self.reflow_loss(&x0_batch, x1_batch, &mut eval_rng);

        let fd_eps = 1e-4;
        // Deterministic mask RNG selecting which parameters get a gradient.
        let mut update_rng = StdRng::seed_from_u64(0xcc11ee22_u64);
        let n_layers = self.velocity_net.n_layers();
        let mut grad_w: Vec<Vec<Vec<f64>>> = self
            .velocity_net
            .weights
            .iter()
            .map(|lw| lw.iter().map(|row| vec![0.0; row.len()]).collect())
            .collect();
        let mut grad_b: Vec<Vec<f64>> = self
            .velocity_net
            .biases
            .iter()
            .map(|lb| vec![0.0; lb.len()])
            .collect();

        for l in 0..n_layers {
            for j in 0..self.velocity_net.weights[l].len() {
                for i in 0..self.velocity_net.weights[l][j].len() {
                    // Only ~5% of weights get a gradient estimate per step.
                    if update_rng.random::<f64>() < 0.05 {
                        self.velocity_net.weights[l][j][i] += fd_eps;
                        // Same seed as eval_rng above: identical t draws.
                        let mut r = StdRng::seed_from_u64(0x55aa77bb_u64);
                        let perturbed = self.reflow_loss(&x0_batch, x1_batch, &mut r);
                        self.velocity_net.weights[l][j][i] -= fd_eps;
                        grad_w[l][j][i] = (perturbed - base_loss) / fd_eps;
                    }
                }
            }
            for j in 0..self.velocity_net.biases[l].len() {
                if update_rng.random::<f64>() < 0.05 {
                    self.velocity_net.biases[l][j] += fd_eps;
                    let mut r = StdRng::seed_from_u64(0x55aa77bb_u64);
                    let perturbed = self.reflow_loss(&x0_batch, x1_batch, &mut r);
                    self.velocity_net.biases[l][j] -= fd_eps;
                    grad_b[l][j] = (perturbed - base_loss) / fd_eps;
                }
            }
        }
        self.velocity_net.update(&grad_w, &grad_b, lr);
        base_loss
    }

    /// Draw one sample by integrating dz/dt = v(z, t) from t = 0 to 1 with
    /// explicit (left-endpoint) Euler, starting from a standard normal.
    pub fn sample(&self, n_steps: usize, rng: &mut StdRng) -> Vec<f64> {
        let d = self.config.z_dim;
        let mut x = sample_standard_normal(d, rng);
        let n = n_steps.max(1); // always take at least one step
        let dt = 1.0 / n as f64;

        for step in 0..n {
            let t = step as f64 * dt;
            let mut inp = x.clone();
            inp.push(t);
            let v = self.velocity_net.forward(&inp);
            for (xi, vi) in x.iter_mut().zip(v.iter()) {
                *xi += dt * vi;
            }
        }
        x
    }
}