ruvector_math/optimal_transport/
sliced_wasserstein.rs1use rand::prelude::*;
26use rand_distr::StandardNormal;
27use crate::utils::{argsort, EPS};
28use super::{OptimalTransport, WassersteinConfig};
29
/// Monte-Carlo sliced Wasserstein estimator.
///
/// Each distance evaluation draws `num_projections` random unit directions,
/// projects both point clouds onto every direction, computes the 1-D
/// Wasserstein-`p` cost of the projections, averages over directions and
/// finally takes the `p`-th root of that mean.
#[derive(Debug, Clone)]
pub struct SlicedWasserstein {
    /// Number of random directions averaged per distance call (the constructors keep it >= 1).
    num_projections: usize,
    /// Wasserstein power `p` used both for the ground cost `|x - y|^p` and the final root.
    p: f64,
    /// Optional RNG seed; `None` means directions are drawn from fresh entropy on every call.
    seed: Option<u64>,
}
40
41impl SlicedWasserstein {
42 pub fn new(num_projections: usize) -> Self {
47 Self {
48 num_projections: num_projections.max(1),
49 p: 2.0,
50 seed: None,
51 }
52 }
53
54 pub fn from_config(config: &WassersteinConfig) -> Self {
56 Self {
57 num_projections: config.num_projections.max(1),
58 p: config.p,
59 seed: config.seed,
60 }
61 }
62
63 pub fn with_power(mut self, p: f64) -> Self {
65 self.p = p.max(1.0);
66 self
67 }
68
69 pub fn with_seed(mut self, seed: u64) -> Self {
71 self.seed = Some(seed);
72 self
73 }
74
75 fn generate_directions(&self, dim: usize) -> Vec<Vec<f64>> {
77 let mut rng = match self.seed {
78 Some(s) => StdRng::seed_from_u64(s),
79 None => StdRng::from_entropy(),
80 };
81
82 (0..self.num_projections)
83 .map(|_| {
84 let mut direction: Vec<f64> = (0..dim)
85 .map(|_| rng.sample(StandardNormal))
86 .collect();
87
88 let norm: f64 = direction.iter().map(|&x| x * x).sum::<f64>().sqrt();
90 if norm > EPS {
91 for x in &mut direction {
92 *x /= norm;
93 }
94 }
95 direction
96 })
97 .collect()
98 }
99
100 #[inline(always)]
102 fn project(points: &[Vec<f64>], direction: &[f64]) -> Vec<f64> {
103 points
104 .iter()
105 .map(|p| Self::dot_product(p, direction))
106 .collect()
107 }
108
109 #[inline(always)]
111 fn project_into(points: &[Vec<f64>], direction: &[f64], out: &mut [f64]) {
112 for (i, p) in points.iter().enumerate() {
113 out[i] = Self::dot_product(p, direction);
114 }
115 }
116
117 #[inline(always)]
120 fn dot_product(a: &[f64], b: &[f64]) -> f64 {
121 let len = a.len();
123 let chunks = len / 4;
124 let remainder = len % 4;
125
126 let mut sum0 = 0.0f64;
127 let mut sum1 = 0.0f64;
128 let mut sum2 = 0.0f64;
129 let mut sum3 = 0.0f64;
130
131 for i in 0..chunks {
133 let base = i * 4;
134 sum0 += a[base] * b[base];
135 sum1 += a[base + 1] * b[base + 1];
136 sum2 += a[base + 2] * b[base + 2];
137 sum3 += a[base + 3] * b[base + 3];
138 }
139
140 let base = chunks * 4;
142 for i in 0..remainder {
143 sum0 += a[base + i] * b[base + i];
144 }
145
146 sum0 + sum1 + sum2 + sum3
147 }
148
149 #[inline]
153 fn wasserstein_1d_uniform(&self, mut proj_a: Vec<f64>, mut proj_b: Vec<f64>) -> f64 {
154 let n = proj_a.len();
155 let m = proj_b.len();
156
157 proj_a.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
159 proj_b.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
160
161 if n == m {
162 self.wasserstein_1d_equal_size(&proj_a, &proj_b)
164 } else {
165 self.wasserstein_1d_quantile(&proj_a, &proj_b, n.max(m))
167 }
168 }
169
170 #[inline(always)]
172 fn wasserstein_1d_equal_size(&self, sorted_a: &[f64], sorted_b: &[f64]) -> f64 {
173 let n = sorted_a.len();
174 if n == 0 {
175 return 0.0;
176 }
177
178 if (self.p - 2.0).abs() < 1e-10 {
180 let mut sum0 = 0.0f64;
182 let mut sum1 = 0.0f64;
183 let mut sum2 = 0.0f64;
184 let mut sum3 = 0.0f64;
185
186 let chunks = n / 4;
187 let remainder = n % 4;
188
189 for i in 0..chunks {
190 let base = i * 4;
191 let d0 = sorted_a[base] - sorted_b[base];
192 let d1 = sorted_a[base + 1] - sorted_b[base + 1];
193 let d2 = sorted_a[base + 2] - sorted_b[base + 2];
194 let d3 = sorted_a[base + 3] - sorted_b[base + 3];
195 sum0 += d0 * d0;
196 sum1 += d1 * d1;
197 sum2 += d2 * d2;
198 sum3 += d3 * d3;
199 }
200
201 let base = chunks * 4;
202 for i in 0..remainder {
203 let d = sorted_a[base + i] - sorted_b[base + i];
204 sum0 += d * d;
205 }
206
207 (sum0 + sum1 + sum2 + sum3) / n as f64
208 } else if (self.p - 1.0).abs() < 1e-10 {
209 let mut sum = 0.0f64;
211 for i in 0..n {
212 sum += (sorted_a[i] - sorted_b[i]).abs();
213 }
214 sum / n as f64
215 } else {
216 sorted_a
218 .iter()
219 .zip(sorted_b.iter())
220 .map(|(&a, &b)| (a - b).abs().powf(self.p))
221 .sum::<f64>()
222 / n as f64
223 }
224 }
225
226 fn wasserstein_1d_quantile(&self, sorted_a: &[f64], sorted_b: &[f64], num_samples: usize) -> f64 {
228 let mut total = 0.0;
229
230 for i in 0..num_samples {
231 let q = (i as f64 + 0.5) / num_samples as f64;
232
233 let val_a = quantile_sorted(sorted_a, q);
234 let val_b = quantile_sorted(sorted_b, q);
235
236 total += (val_a - val_b).abs().powf(self.p);
237 }
238
239 total / num_samples as f64
240 }
241
242 fn wasserstein_1d_weighted(
244 &self,
245 proj_a: &[f64],
246 weights_a: &[f64],
247 proj_b: &[f64],
248 weights_b: &[f64],
249 ) -> f64 {
250 let idx_a = argsort(proj_a);
252 let idx_b = argsort(proj_b);
253
254 let sorted_a: Vec<f64> = idx_a.iter().map(|&i| proj_a[i]).collect();
255 let sorted_w_a: Vec<f64> = idx_a.iter().map(|&i| weights_a[i]).collect();
256 let sorted_b: Vec<f64> = idx_b.iter().map(|&i| proj_b[i]).collect();
257 let sorted_w_b: Vec<f64> = idx_b.iter().map(|&i| weights_b[i]).collect();
258
259 let cdf_a = compute_cdf(&sorted_w_a);
261 let cdf_b = compute_cdf(&sorted_w_b);
262
263 self.wasserstein_1d_from_cdfs(&sorted_a, &cdf_a, &sorted_b, &cdf_b)
265 }
266
267 fn wasserstein_1d_from_cdfs(
269 &self,
270 values_a: &[f64],
271 cdf_a: &[f64],
272 values_b: &[f64],
273 cdf_b: &[f64],
274 ) -> f64 {
275 let mut events: Vec<(f64, f64, f64)> = Vec::new(); let mut ia = 0;
279 let mut ib = 0;
280 let mut current_cdf_a = 0.0;
281 let mut current_cdf_b = 0.0;
282
283 while ia < values_a.len() || ib < values_b.len() {
284 let pos = match (ia < values_a.len(), ib < values_b.len()) {
285 (true, true) => {
286 if values_a[ia] <= values_b[ib] {
287 current_cdf_a = cdf_a[ia];
288 ia += 1;
289 values_a[ia - 1]
290 } else {
291 current_cdf_b = cdf_b[ib];
292 ib += 1;
293 values_b[ib - 1]
294 }
295 }
296 (true, false) => {
297 current_cdf_a = cdf_a[ia];
298 ia += 1;
299 values_a[ia - 1]
300 }
301 (false, true) => {
302 current_cdf_b = cdf_b[ib];
303 ib += 1;
304 values_b[ib - 1]
305 }
306 (false, false) => break,
307 };
308
309 events.push((pos, current_cdf_a, current_cdf_b));
310 }
311
312 let mut total = 0.0;
314 for i in 1..events.len() {
315 let width = events[i].0 - events[i - 1].0;
316 let height = (events[i - 1].1 - events[i - 1].2).abs();
317 total += width * height.powf(self.p);
318 }
319
320 total
321 }
322}
323
324impl OptimalTransport for SlicedWasserstein {
325 fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> f64 {
326 if source.is_empty() || target.is_empty() {
327 return 0.0;
328 }
329
330 let dim = source[0].len();
331 if dim == 0 {
332 return 0.0;
333 }
334
335 let directions = self.generate_directions(dim);
336 let n_source = source.len();
337 let n_target = target.len();
338
339 let mut proj_source = vec![0.0; n_source];
341 let mut proj_target = vec![0.0; n_target];
342
343 let total: f64 = directions
344 .iter()
345 .map(|dir| {
346 Self::project_into(source, dir, &mut proj_source);
348 Self::project_into(target, dir, &mut proj_target);
349
350 self.wasserstein_1d_uniform(proj_source.clone(), proj_target.clone())
352 })
353 .sum();
354
355 (total / self.num_projections as f64).powf(1.0 / self.p)
356 }
357
358 fn weighted_distance(
359 &self,
360 source: &[Vec<f64>],
361 source_weights: &[f64],
362 target: &[Vec<f64>],
363 target_weights: &[f64],
364 ) -> f64 {
365 if source.is_empty() || target.is_empty() {
366 return 0.0;
367 }
368
369 let dim = source[0].len();
370 if dim == 0 {
371 return 0.0;
372 }
373
374 let sum_a: f64 = source_weights.iter().sum();
376 let sum_b: f64 = target_weights.iter().sum();
377 let weights_a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
378 let weights_b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
379
380 let directions = self.generate_directions(dim);
381
382 let total: f64 = directions
383 .iter()
384 .map(|dir| {
385 let proj_source = Self::project(source, dir);
386 let proj_target = Self::project(target, dir);
387 self.wasserstein_1d_weighted(&proj_source, &weights_a, &proj_target, &weights_b)
388 })
389 .sum();
390
391 (total / self.num_projections as f64).powf(1.0 / self.p)
392 }
393}
394
/// Quantile `q` of an ascending slice, linearly interpolated between order statistics.
///
/// `q` is clamped to `[0, 1]` and mapped to the fractional index `q * (len - 1)`.
/// An empty slice yields `0.0`, and a single element is returned as-is.
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            let position = q.clamp(0.0, 1.0) * last as f64;
            let lower = position.floor() as usize;
            let upper = (lower + 1).min(last);
            let weight = position - lower as f64;
            sorted[lower] * (1.0 - weight) + sorted[upper] * weight
        }
    }
}
415
/// Running cumulative sum of `weights`, each term divided by the total weight.
///
/// For non-negative weights with a positive sum the last entry is (up to rounding) `1.0`.
/// A zero total makes every entry non-finite.
fn compute_cdf(weights: &[f64]) -> Vec<f64> {
    let total: f64 = weights.iter().sum();
    weights
        .iter()
        .scan(0.0, |running, &w| {
            *running += w / total;
            Some(*running)
        })
        .collect()
}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Corners of the unit square, shared by several tests.
    fn unit_square() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ]
    }

    /// A point cloud compared with itself must be at (numerically) zero distance.
    #[test]
    fn test_sliced_wasserstein_identical() {
        let estimator = SlicedWasserstein::new(100).with_seed(42);
        let square = unit_square();

        let d = estimator.distance(&square, &square);
        assert!(d < 0.01, "Self-distance should be ~0, got {}", d);
    }

    /// Shifting every point by (1, 1) must give a clearly positive, bounded distance.
    #[test]
    fn test_sliced_wasserstein_translation() {
        let estimator = SlicedWasserstein::new(500).with_seed(42);
        let square = unit_square();
        let shifted: Vec<Vec<f64>> = square
            .iter()
            .map(|pt| pt.iter().map(|&c| c + 1.0).collect())
            .collect();

        let d = estimator.distance(&square, &shifted);
        assert!(
            d > 0.5 && d < 2.0,
            "Translation distance should be positive, got {:.3}",
            d
        );
    }

    /// Doubling every coordinate must move the distribution away from the original.
    #[test]
    fn test_sliced_wasserstein_scaling() {
        let estimator = SlicedWasserstein::new(500).with_seed(42);
        let square = unit_square();
        let doubled: Vec<Vec<f64>> = square
            .iter()
            .map(|pt| pt.iter().map(|&c| c * 2.0).collect())
            .collect();

        let d = estimator.distance(&square, &doubled);
        assert!(d > 0.0, "Scaling should produce positive distance");
    }

    /// Disjoint 1-D supports with equal weights must have positive weighted distance.
    #[test]
    fn test_weighted_distance() {
        let estimator = SlicedWasserstein::new(100).with_seed(42);
        let (source, target) = (vec![vec![0.0], vec![1.0]], vec![vec![2.0], vec![3.0]]);
        let uniform = vec![0.5, 0.5];

        let d = estimator.weighted_distance(&source, &uniform, &target, &uniform);
        assert!(d > 0.0);
    }

    /// Generated directions come in the requested number and all have unit length.
    #[test]
    fn test_1d_projections() {
        let directions = SlicedWasserstein::new(10).generate_directions(3);
        assert_eq!(directions.len(), 10);

        directions.iter().for_each(|dir| {
            let norm = dir.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6, "Direction not unit: {}", norm);
        });
    }
}