ruvector_math/optimal_transport/
sliced_wasserstein.rs1use super::{OptimalTransport, WassersteinConfig};
26use crate::utils::{argsort, EPS};
27use rand::prelude::*;
28use rand_distr::StandardNormal;
29
/// Monte-Carlo estimator of the sliced p-Wasserstein distance.
///
/// Point clouds are projected onto random unit directions, the 1-D optimal
/// transport cost is computed for every projection, and the costs are
/// averaged before taking the `1 / p` root.
#[derive(Debug, Clone, PartialEq)]
pub struct SlicedWasserstein {
    /// Number of random directions averaged over (constructors keep it >= 1).
    num_projections: usize,
    /// Exponent of the ground cost `|x - y|^p`.
    p: f64,
    /// Seed for direction sampling; `None` draws fresh OS entropy per call.
    seed: Option<u64>,
}
40
41impl SlicedWasserstein {
42 pub fn new(num_projections: usize) -> Self {
47 Self {
48 num_projections: num_projections.max(1),
49 p: 2.0,
50 seed: None,
51 }
52 }
53
54 pub fn from_config(config: &WassersteinConfig) -> Self {
56 Self {
57 num_projections: config.num_projections.max(1),
58 p: config.p,
59 seed: config.seed,
60 }
61 }
62
63 pub fn with_power(mut self, p: f64) -> Self {
65 self.p = p.max(1.0);
66 self
67 }
68
69 pub fn with_seed(mut self, seed: u64) -> Self {
71 self.seed = Some(seed);
72 self
73 }
74
75 fn generate_directions(&self, dim: usize) -> Vec<Vec<f64>> {
77 let mut rng = match self.seed {
78 Some(s) => StdRng::seed_from_u64(s),
79 None => StdRng::from_entropy(),
80 };
81
82 (0..self.num_projections)
83 .map(|_| {
84 let mut direction: Vec<f64> =
85 (0..dim).map(|_| rng.sample(StandardNormal)).collect();
86
87 let norm: f64 = direction.iter().map(|&x| x * x).sum::<f64>().sqrt();
89 if norm > EPS {
90 for x in &mut direction {
91 *x /= norm;
92 }
93 }
94 direction
95 })
96 .collect()
97 }
98
99 #[inline(always)]
101 fn project(points: &[Vec<f64>], direction: &[f64]) -> Vec<f64> {
102 points
103 .iter()
104 .map(|p| Self::dot_product(p, direction))
105 .collect()
106 }
107
108 #[inline(always)]
110 fn project_into(points: &[Vec<f64>], direction: &[f64], out: &mut [f64]) {
111 for (i, p) in points.iter().enumerate() {
112 out[i] = Self::dot_product(p, direction);
113 }
114 }
115
116 #[inline(always)]
119 fn dot_product(a: &[f64], b: &[f64]) -> f64 {
120 let len = a.len();
122 let chunks = len / 4;
123 let remainder = len % 4;
124
125 let mut sum0 = 0.0f64;
126 let mut sum1 = 0.0f64;
127 let mut sum2 = 0.0f64;
128 let mut sum3 = 0.0f64;
129
130 for i in 0..chunks {
132 let base = i * 4;
133 sum0 += a[base] * b[base];
134 sum1 += a[base + 1] * b[base + 1];
135 sum2 += a[base + 2] * b[base + 2];
136 sum3 += a[base + 3] * b[base + 3];
137 }
138
139 let base = chunks * 4;
141 for i in 0..remainder {
142 sum0 += a[base + i] * b[base + i];
143 }
144
145 sum0 + sum1 + sum2 + sum3
146 }
147
148 #[inline]
152 fn wasserstein_1d_uniform(&self, mut proj_a: Vec<f64>, mut proj_b: Vec<f64>) -> f64 {
153 let n = proj_a.len();
154 let m = proj_b.len();
155
156 proj_a.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
158 proj_b.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
159
160 if n == m {
161 self.wasserstein_1d_equal_size(&proj_a, &proj_b)
163 } else {
164 self.wasserstein_1d_quantile(&proj_a, &proj_b, n.max(m))
166 }
167 }
168
169 #[inline(always)]
171 fn wasserstein_1d_equal_size(&self, sorted_a: &[f64], sorted_b: &[f64]) -> f64 {
172 let n = sorted_a.len();
173 if n == 0 {
174 return 0.0;
175 }
176
177 if (self.p - 2.0).abs() < 1e-10 {
179 let mut sum0 = 0.0f64;
181 let mut sum1 = 0.0f64;
182 let mut sum2 = 0.0f64;
183 let mut sum3 = 0.0f64;
184
185 let chunks = n / 4;
186 let remainder = n % 4;
187
188 for i in 0..chunks {
189 let base = i * 4;
190 let d0 = sorted_a[base] - sorted_b[base];
191 let d1 = sorted_a[base + 1] - sorted_b[base + 1];
192 let d2 = sorted_a[base + 2] - sorted_b[base + 2];
193 let d3 = sorted_a[base + 3] - sorted_b[base + 3];
194 sum0 += d0 * d0;
195 sum1 += d1 * d1;
196 sum2 += d2 * d2;
197 sum3 += d3 * d3;
198 }
199
200 let base = chunks * 4;
201 for i in 0..remainder {
202 let d = sorted_a[base + i] - sorted_b[base + i];
203 sum0 += d * d;
204 }
205
206 (sum0 + sum1 + sum2 + sum3) / n as f64
207 } else if (self.p - 1.0).abs() < 1e-10 {
208 let mut sum = 0.0f64;
210 for i in 0..n {
211 sum += (sorted_a[i] - sorted_b[i]).abs();
212 }
213 sum / n as f64
214 } else {
215 sorted_a
217 .iter()
218 .zip(sorted_b.iter())
219 .map(|(&a, &b)| (a - b).abs().powf(self.p))
220 .sum::<f64>()
221 / n as f64
222 }
223 }
224
225 fn wasserstein_1d_quantile(
227 &self,
228 sorted_a: &[f64],
229 sorted_b: &[f64],
230 num_samples: usize,
231 ) -> f64 {
232 let mut total = 0.0;
233
234 for i in 0..num_samples {
235 let q = (i as f64 + 0.5) / num_samples as f64;
236
237 let val_a = quantile_sorted(sorted_a, q);
238 let val_b = quantile_sorted(sorted_b, q);
239
240 total += (val_a - val_b).abs().powf(self.p);
241 }
242
243 total / num_samples as f64
244 }
245
246 fn wasserstein_1d_weighted(
248 &self,
249 proj_a: &[f64],
250 weights_a: &[f64],
251 proj_b: &[f64],
252 weights_b: &[f64],
253 ) -> f64 {
254 let idx_a = argsort(proj_a);
256 let idx_b = argsort(proj_b);
257
258 let sorted_a: Vec<f64> = idx_a.iter().map(|&i| proj_a[i]).collect();
259 let sorted_w_a: Vec<f64> = idx_a.iter().map(|&i| weights_a[i]).collect();
260 let sorted_b: Vec<f64> = idx_b.iter().map(|&i| proj_b[i]).collect();
261 let sorted_w_b: Vec<f64> = idx_b.iter().map(|&i| weights_b[i]).collect();
262
263 let cdf_a = compute_cdf(&sorted_w_a);
265 let cdf_b = compute_cdf(&sorted_w_b);
266
267 self.wasserstein_1d_from_cdfs(&sorted_a, &cdf_a, &sorted_b, &cdf_b)
269 }
270
271 fn wasserstein_1d_from_cdfs(
273 &self,
274 values_a: &[f64],
275 cdf_a: &[f64],
276 values_b: &[f64],
277 cdf_b: &[f64],
278 ) -> f64 {
279 let mut events: Vec<(f64, f64, f64)> = Vec::new(); let mut ia = 0;
283 let mut ib = 0;
284 let mut current_cdf_a = 0.0;
285 let mut current_cdf_b = 0.0;
286
287 while ia < values_a.len() || ib < values_b.len() {
288 let pos = match (ia < values_a.len(), ib < values_b.len()) {
289 (true, true) => {
290 if values_a[ia] <= values_b[ib] {
291 current_cdf_a = cdf_a[ia];
292 ia += 1;
293 values_a[ia - 1]
294 } else {
295 current_cdf_b = cdf_b[ib];
296 ib += 1;
297 values_b[ib - 1]
298 }
299 }
300 (true, false) => {
301 current_cdf_a = cdf_a[ia];
302 ia += 1;
303 values_a[ia - 1]
304 }
305 (false, true) => {
306 current_cdf_b = cdf_b[ib];
307 ib += 1;
308 values_b[ib - 1]
309 }
310 (false, false) => break,
311 };
312
313 events.push((pos, current_cdf_a, current_cdf_b));
314 }
315
316 let mut total = 0.0;
318 for i in 1..events.len() {
319 let width = events[i].0 - events[i - 1].0;
320 let height = (events[i - 1].1 - events[i - 1].2).abs();
321 total += width * height.powf(self.p);
322 }
323
324 total
325 }
326}
327
328impl OptimalTransport for SlicedWasserstein {
329 fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> f64 {
330 if source.is_empty() || target.is_empty() {
331 return 0.0;
332 }
333
334 let dim = source[0].len();
335 if dim == 0 {
336 return 0.0;
337 }
338
339 let directions = self.generate_directions(dim);
340 let n_source = source.len();
341 let n_target = target.len();
342
343 let mut proj_source = vec![0.0; n_source];
345 let mut proj_target = vec![0.0; n_target];
346
347 let total: f64 = directions
348 .iter()
349 .map(|dir| {
350 Self::project_into(source, dir, &mut proj_source);
352 Self::project_into(target, dir, &mut proj_target);
353
354 self.wasserstein_1d_uniform(proj_source.clone(), proj_target.clone())
356 })
357 .sum();
358
359 (total / self.num_projections as f64).powf(1.0 / self.p)
360 }
361
362 fn weighted_distance(
363 &self,
364 source: &[Vec<f64>],
365 source_weights: &[f64],
366 target: &[Vec<f64>],
367 target_weights: &[f64],
368 ) -> f64 {
369 if source.is_empty() || target.is_empty() {
370 return 0.0;
371 }
372
373 let dim = source[0].len();
374 if dim == 0 {
375 return 0.0;
376 }
377
378 let sum_a: f64 = source_weights.iter().sum();
380 let sum_b: f64 = target_weights.iter().sum();
381 let weights_a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
382 let weights_b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
383
384 let directions = self.generate_directions(dim);
385
386 let total: f64 = directions
387 .iter()
388 .map(|dir| {
389 let proj_source = Self::project(source, dir);
390 let proj_target = Self::project(target, dir);
391 self.wasserstein_1d_weighted(&proj_source, &weights_a, &proj_target, &weights_b)
392 })
393 .sum();
394
395 (total / self.num_projections as f64).powf(1.0 / self.p)
396 }
397}
398
/// Linearly interpolated quantile of an ascending slice.
///
/// `q` is clamped to `[0, 1]` and mapped to the fractional index
/// `q * (len - 1)`; the result blends the two neighbouring entries.
/// An empty slice yields `0.0`; a single element is returned unchanged.
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            let position = q.clamp(0.0, 1.0) * last as f64;
            let lower = position.floor() as usize;
            let upper = last.min(lower + 1);
            let weight = position - lower as f64;
            sorted[lower] * (1.0 - weight) + sorted[upper] * weight
        }
    }
}
419
/// Normalised cumulative sums of `weights`.
///
/// Entry `i` is `(w_0 + ... + w_i) / sum(w)`. The running sum is divided by
/// the total (rather than summing pre-divided weights), so rounding error does
/// not accumulate across entries. The running sum reaches exactly the same
/// value as the total, so the last entry is exactly `1.0` whenever the total
/// is finite and non-zero. A zero total yields non-finite entries.
fn compute_cdf(weights: &[f64]) -> Vec<f64> {
    let total: f64 = weights.iter().sum();
    weights
        .iter()
        .scan(0.0, |running, &w| {
            *running += w;
            Some(*running / total)
        })
        .collect()
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Corners of the unit square, shared by several tests.
    fn unit_square() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ]
    }

    #[test]
    fn test_sliced_wasserstein_identical() {
        let sw = SlicedWasserstein::new(100).with_seed(42);
        let points = unit_square();

        // Identical clouds have identical sorted projections, so every
        // per-direction cost is exactly zero.
        let dist = sw.distance(&points, &points);
        assert!(dist.abs() < 1e-12, "Self-distance should be 0, got {}", dist);
    }

    #[test]
    fn test_sliced_wasserstein_translation() {
        let sw = SlicedWasserstein::new(500).with_seed(42);
        let source = unit_square();
        let target: Vec<Vec<f64>> = source
            .iter()
            .map(|p| vec![p[0] + 1.0, p[1] + 1.0])
            .collect();

        // Translating by v shifts every projection onto θ by <θ, v>, so
        // SW_2^2 = E_θ[<θ, v>^2] = |v|^2 / d = 2 / 2 = 1.
        let dist = sw.distance(&source, &target);
        assert!(
            (dist - 1.0).abs() < 0.15,
            "Translation by (1, 1) should give SW_2 close to 1, got {:.3}",
            dist
        );
    }

    #[test]
    fn test_sliced_wasserstein_scaling() {
        let sw = SlicedWasserstein::new(500).with_seed(42);
        let source = unit_square();
        let target: Vec<Vec<f64>> = source
            .iter()
            .map(|p| vec![p[0] * 2.0, p[1] * 2.0])
            .collect();

        // Scaling by 2 keeps the projection order, so each direction costs the
        // mean squared source projection; averaged over θ that is
        // mean(|x|^2) / d = 1 / 2, i.e. SW_2 = sqrt(0.5).
        let dist = sw.distance(&source, &target);
        assert!(
            (dist - 0.5f64.sqrt()).abs() < 0.05,
            "Scaling by 2 should give SW_2 close to 0.707, got {:.3}",
            dist
        );
    }

    #[test]
    fn test_one_dimensional_shift_is_exact() {
        let sw = SlicedWasserstein::new(100).with_seed(42);
        let source = vec![vec![0.0], vec![1.0]];
        let target = vec![vec![2.0], vec![3.0]];

        // In 1-D every direction is ±1 and the target is the source shifted by 2.
        let dist = sw.distance(&source, &target);
        assert!((dist - 2.0).abs() < 1e-9, "Expected exactly 2, got {}", dist);
    }

    #[test]
    fn test_weighted_distance() {
        let sw = SlicedWasserstein::new(100).with_seed(42);

        let source = vec![vec![0.0], vec![1.0]];
        let target = vec![vec![2.0], vec![3.0]];
        let weights_s = vec![0.5, 0.5];
        let weights_t = vec![0.5, 0.5];

        let dist = sw.weighted_distance(&source, &weights_s, &target, &weights_t);
        assert!(dist > 0.0);
    }

    #[test]
    fn test_weighted_distance_p1_shift_is_exact() {
        let sw = SlicedWasserstein::new(100).with_seed(42).with_power(1.0);

        let source = vec![vec![0.0], vec![1.0]];
        let target = vec![vec![2.0], vec![3.0]];
        let weights_s = vec![0.5, 0.5];
        let weights_t = vec![0.5, 0.5];

        // Every unit of mass moves a distance of 2, so W_1 = 2 on each direction.
        let dist = sw.weighted_distance(&source, &weights_s, &target, &weights_t);
        assert!((dist - 2.0).abs() < 1e-9, "Expected exactly 2, got {}", dist);
    }

    #[test]
    fn test_1d_projections() {
        let sw = SlicedWasserstein::new(10);
        let directions = sw.generate_directions(3);

        assert_eq!(directions.len(), 10);

        for dir in &directions {
            let norm: f64 = dir.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6, "Direction not unit: {}", norm);
        }
    }
}