ruvector_math/optimal_transport/
sliced_wasserstein.rs

//! Sliced Wasserstein Distance
//!
//! The Sliced Wasserstein distance projects high-dimensional distributions
//! onto random 1D lines and averages the 1D Wasserstein distances.
//!
//! ## Algorithm
//!
//! 1. Generate L random unit vectors (directions) in R^d
//! 2. For each direction θ:
//!    a. Project all source and target points onto θ
//!    b. Compute the 1D Wasserstein distance (closed-form on sorted projections)
//! 3. Average over all directions (see the usage example below)
//!
//! ## Complexity
//!
//! - O(L × n × (d + log n)) where L = number of projections, n = number of points
//! - Only linear in the dimension d (projections are plain dot products)
//!
//! ## Advantages
//!
//! - **Fast**: near-linear scaling in the number of points
//! - **SIMD-friendly**: projections are just dot products
//! - **Statistically consistent**: the Monte Carlo average converges to the sliced
//!   Wasserstein distance as L → ∞ (a lower bound on, not equal to, the true W2)
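//!
//! ## Example
//!
//! A minimal usage sketch. The import paths below are assumed from this file's
//! location (`ruvector_math/optimal_transport/`), so the doctest is marked
//! `ignore`; adjust them to the actual crate layout.
//!
//! ```ignore
//! use ruvector_math::optimal_transport::OptimalTransport;
//! use ruvector_math::optimal_transport::sliced_wasserstein::SlicedWasserstein;
//!
//! // Two small 2D point clouds.
//! let source = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
//! let target = vec![vec![2.0, 0.0], vec![3.0, 1.0]];
//!
//! // 200 random projections, W2, fixed seed for reproducibility.
//! let sw = SlicedWasserstein::new(200).with_power(2.0).with_seed(42);
//! let dist = sw.distance(&source, &target);
//! assert!(dist > 0.0);
//! ```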

use rand::prelude::*;
use rand_distr::StandardNormal;
use crate::utils::{argsort, EPS};
use super::{OptimalTransport, WassersteinConfig};

/// Sliced Wasserstein distance calculator
#[derive(Debug, Clone)]
pub struct SlicedWasserstein {
    /// Number of random projection directions
    num_projections: usize,
    /// Power for Wasserstein-p (typically 1 or 2)
    p: f64,
    /// Random seed for reproducibility
    seed: Option<u64>,
}

impl SlicedWasserstein {
    /// Create a new Sliced Wasserstein calculator
    ///
    /// # Arguments
    /// * `num_projections` - Number of random 1D projections (100-1000 typical)
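    ///
    /// # Examples
    ///
    /// A minimal builder-style sketch (doctest marked `ignore` since imports
    /// are omitted and the crate path is assumed):
    ///
    /// ```ignore
    /// let sw = SlicedWasserstein::new(200) // 200 random directions
    ///     .with_power(1.0)                 // use W1 instead of the default W2
    ///     .with_seed(7);                   // reproducible projections
    /// ```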
    pub fn new(num_projections: usize) -> Self {
        Self {
            num_projections: num_projections.max(1),
            p: 2.0,
            seed: None,
        }
    }

    /// Create from configuration
    pub fn from_config(config: &WassersteinConfig) -> Self {
        Self {
            num_projections: config.num_projections.max(1),
            p: config.p,
            seed: config.seed,
        }
    }

    /// Set the Wasserstein power (1 for W1, 2 for W2)
    pub fn with_power(mut self, p: f64) -> Self {
        self.p = p.max(1.0);
        self
    }

    /// Set random seed for reproducibility
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Generate random unit directions
    fn generate_directions(&self, dim: usize) -> Vec<Vec<f64>> {
        let mut rng = match self.seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => StdRng::from_entropy(),
        };

        (0..self.num_projections)
            .map(|_| {
                let mut direction: Vec<f64> = (0..dim)
                    .map(|_| rng.sample(StandardNormal))
                    .collect();

                // Normalize to unit vector
                let norm: f64 = direction.iter().map(|&x| x * x).sum::<f64>().sqrt();
                if norm > EPS {
                    for x in &mut direction {
                        *x /= norm;
                    }
                }
                direction
            })
            .collect()
    }

    /// Project points onto a direction (SIMD-friendly dot product)
    #[inline(always)]
    fn project(points: &[Vec<f64>], direction: &[f64]) -> Vec<f64> {
        points
            .iter()
            .map(|p| Self::dot_product(p, direction))
            .collect()
    }

    /// Project points into pre-allocated buffer (reduces allocations)
    #[inline(always)]
    fn project_into(points: &[Vec<f64>], direction: &[f64], out: &mut [f64]) {
        for (i, p) in points.iter().enumerate() {
            out[i] = Self::dot_product(p, direction);
        }
    }

    /// SIMD-friendly dot product using a 4-way unrolled accumulator;
    /// the compiler can auto-vectorize this pattern effectively.
    #[inline(always)]
    fn dot_product(a: &[f64], b: &[f64]) -> f64 {
        // Use four independent accumulators for better SIMD utilization
        let len = a.len();
        let chunks = len / 4;
        let remainder = len % 4;

        let mut sum0 = 0.0f64;
        let mut sum1 = 0.0f64;
        let mut sum2 = 0.0f64;
        let mut sum3 = 0.0f64;

        // Process 4 elements at a time (helps SIMD vectorization)
        for i in 0..chunks {
            let base = i * 4;
            sum0 += a[base] * b[base];
            sum1 += a[base + 1] * b[base + 1];
            sum2 += a[base + 2] * b[base + 2];
            sum3 += a[base + 3] * b[base + 3];
        }

        // Handle remainder
        let base = chunks * 4;
        for i in 0..remainder {
            sum0 += a[base + i] * b[base + i];
        }

        sum0 + sum1 + sum2 + sum3
    }

    /// Compute the p-th power of the 1D Wasserstein distance between two samples
    /// with uniform weights (the projections are sorted in place).
    ///
    /// For equal sample sizes this is simply the mean of |sorted_a[i] - sorted_b[i]|^p;
    /// otherwise quantiles are interpolated.
    #[inline]
    fn wasserstein_1d_uniform(&self, mut proj_a: Vec<f64>, mut proj_b: Vec<f64>) -> f64 {
        let n = proj_a.len();
        let m = proj_b.len();

        // Sort projections using fast f64 comparison
        proj_a.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        proj_b.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        if n == m {
            // Same size: direct comparison with SIMD-friendly accumulator
            self.wasserstein_1d_equal_size(&proj_a, &proj_b)
        } else {
            // Different sizes: interpolate via quantiles
            self.wasserstein_1d_quantile(&proj_a, &proj_b, n.max(m))
        }
    }

    /// Optimized equal-size 1D Wasserstein (p-th power) with a SIMD-friendly accumulator
    #[inline(always)]
    fn wasserstein_1d_equal_size(&self, sorted_a: &[f64], sorted_b: &[f64]) -> f64 {
        let n = sorted_a.len();
        if n == 0 {
            return 0.0;
        }

        // Use p=2 fast path (most common case)
        if (self.p - 2.0).abs() < 1e-10 {
            // W2: mean of squared differences
            let mut sum0 = 0.0f64;
            let mut sum1 = 0.0f64;
            let mut sum2 = 0.0f64;
            let mut sum3 = 0.0f64;

            let chunks = n / 4;
            let remainder = n % 4;

            for i in 0..chunks {
                let base = i * 4;
                let d0 = sorted_a[base] - sorted_b[base];
                let d1 = sorted_a[base + 1] - sorted_b[base + 1];
                let d2 = sorted_a[base + 2] - sorted_b[base + 2];
                let d3 = sorted_a[base + 3] - sorted_b[base + 3];
                sum0 += d0 * d0;
                sum1 += d1 * d1;
                sum2 += d2 * d2;
                sum3 += d3 * d3;
            }

            let base = chunks * 4;
            for i in 0..remainder {
                let d = sorted_a[base + i] - sorted_b[base + i];
                sum0 += d * d;
            }

            (sum0 + sum1 + sum2 + sum3) / n as f64
        } else if (self.p - 1.0).abs() < 1e-10 {
            // W1: mean of absolute differences
            let mut sum = 0.0f64;
            for i in 0..n {
                sum += (sorted_a[i] - sorted_b[i]).abs();
            }
            sum / n as f64
        } else {
            // General case
            sorted_a
                .iter()
                .zip(sorted_b.iter())
                .map(|(&a, &b)| (a - b).abs().powf(self.p))
                .sum::<f64>()
                / n as f64
        }
    }

    /// Compute 1D Wasserstein via quantile interpolation
    fn wasserstein_1d_quantile(&self, sorted_a: &[f64], sorted_b: &[f64], num_samples: usize) -> f64 {
        let mut total = 0.0;

        for i in 0..num_samples {
            let q = (i as f64 + 0.5) / num_samples as f64;

            let val_a = quantile_sorted(sorted_a, q);
            let val_b = quantile_sorted(sorted_b, q);

            total += (val_a - val_b).abs().powf(self.p);
        }

        total / num_samples as f64
    }

    /// Compute 1D Wasserstein with weights
    fn wasserstein_1d_weighted(
        &self,
        proj_a: &[f64],
        weights_a: &[f64],
        proj_b: &[f64],
        weights_b: &[f64],
    ) -> f64 {
        // Sort by projected values
        let idx_a = argsort(proj_a);
        let idx_b = argsort(proj_b);

        let sorted_a: Vec<f64> = idx_a.iter().map(|&i| proj_a[i]).collect();
        let sorted_w_a: Vec<f64> = idx_a.iter().map(|&i| weights_a[i]).collect();
        let sorted_b: Vec<f64> = idx_b.iter().map(|&i| proj_b[i]).collect();
        let sorted_w_b: Vec<f64> = idx_b.iter().map(|&i| weights_b[i]).collect();

        // Compute cumulative weights
        let cdf_a = compute_cdf(&sorted_w_a);
        let cdf_b = compute_cdf(&sorted_w_b);

        // Merge and compute
        self.wasserstein_1d_from_cdfs(&sorted_a, &cdf_a, &sorted_b, &cdf_b)
    }

    /// Compute the p-th power of the 1D Wasserstein distance from sorted values
    /// and their cumulative weights (CDFs).
    ///
    /// Uses the quantile form W_p^p = ∫_0^1 |F_a^{-1}(t) - F_b^{-1}(t)|^p dt.
    /// (Integrating |F_a(x) - F_b(x)|^p over x would only match W_p^p for p = 1.)
    /// Both quantile functions are piecewise constant, so we walk the merged
    /// CDF breakpoints and accumulate each interval's contribution.
    fn wasserstein_1d_from_cdfs(
        &self,
        values_a: &[f64],
        cdf_a: &[f64],
        values_b: &[f64],
        cdf_b: &[f64],
    ) -> f64 {
        let mut total = 0.0;
        let mut ia = 0;
        let mut ib = 0;
        let mut prev_t = 0.0;

        while ia < values_a.len() && ib < values_b.len() {
            // Next breakpoint of the merged CDFs
            let t = cdf_a[ia].min(cdf_b[ib]);
            let diff = (values_a[ia] - values_b[ib]).abs();
            total += (t - prev_t).max(0.0) * diff.powf(self.p);
            prev_t = t;

            // Advance whichever side reached this breakpoint
            if cdf_a[ia] <= cdf_b[ib] {
                ia += 1;
            } else {
                ib += 1;
            }
        }

        total
    }
}

impl OptimalTransport for SlicedWasserstein {
    fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> f64 {
        if source.is_empty() || target.is_empty() {
            return 0.0;
        }

        let dim = source[0].len();
        if dim == 0 {
            return 0.0;
        }

        let directions = self.generate_directions(dim);
        let n_source = source.len();
        let n_target = target.len();

        // Pre-allocate projection buffers (reduces allocations per direction)
        let mut proj_source = vec![0.0; n_source];
        let mut proj_target = vec![0.0; n_target];

        let total: f64 = directions
            .iter()
            .map(|dir| {
                // Project into pre-allocated buffers
                Self::project_into(source, dir, &mut proj_source);
                Self::project_into(target, dir, &mut proj_target);

                // Clone for sorting (wasserstein_1d_uniform sorts in place)
                self.wasserstein_1d_uniform(proj_source.clone(), proj_target.clone())
            })
            .sum();

        (total / self.num_projections as f64).powf(1.0 / self.p)
    }

    fn weighted_distance(
        &self,
        source: &[Vec<f64>],
        source_weights: &[f64],
        target: &[Vec<f64>],
        target_weights: &[f64],
    ) -> f64 {
        if source.is_empty() || target.is_empty() {
            return 0.0;
        }

        let dim = source[0].len();
        if dim == 0 {
            return 0.0;
        }

        // Normalize weights (guard against all-zero weights, which would
        // otherwise propagate NaN through the division below)
        let sum_a: f64 = source_weights.iter().sum();
        let sum_b: f64 = target_weights.iter().sum();
        if sum_a <= EPS || sum_b <= EPS {
            return 0.0;
        }
        let weights_a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
        let weights_b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();

        let directions = self.generate_directions(dim);

        let total: f64 = directions
            .iter()
            .map(|dir| {
                let proj_source = Self::project(source, dir);
                let proj_target = Self::project(target, dir);
                self.wasserstein_1d_weighted(&proj_source, &weights_a, &proj_target, &weights_b)
            })
            .sum();

        (total / self.num_projections as f64).powf(1.0 / self.p)
    }
}

/// Quantile of sorted data
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    if sorted.is_empty() {
        return 0.0;
    }

    let q = q.clamp(0.0, 1.0);
    let n = sorted.len();

    if n == 1 {
        return sorted[0];
    }

    let idx_f = q * (n - 1) as f64;
    let idx_low = idx_f.floor() as usize;
    let idx_high = (idx_low + 1).min(n - 1);
    let frac = idx_f - idx_low as f64;

    sorted[idx_low] * (1.0 - frac) + sorted[idx_high] * frac
}

/// Compute CDF from weights
fn compute_cdf(weights: &[f64]) -> Vec<f64> {
    let total: f64 = weights.iter().sum();
    let mut cdf = Vec::with_capacity(weights.len());
    let mut cumsum = 0.0;

    for &w in weights {
        cumsum += w / total;
        cdf.push(cumsum);
    }

    cdf
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sliced_wasserstein_identical() {
        let sw = SlicedWasserstein::new(100).with_seed(42);

        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Distance to itself should be very small
        let dist = sw.distance(&points, &points);
        assert!(dist < 0.01, "Self-distance should be ~0, got {}", dist);
    }

    #[test]
    fn test_sliced_wasserstein_translation() {
        let sw = SlicedWasserstein::new(500).with_seed(42);

        let source = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Translate by (1, 1)
        let target: Vec<Vec<f64>> = source.iter().map(|p| vec![p[0] + 1.0, p[1] + 1.0]).collect();

        let dist = sw.distance(&source, &target);

        // The true W2 for a translation by (1, 1) is sqrt(2) ≈ 1.414, but sliced W2
        // averages the projected shift over random directions, giving roughly
        // ||t|| / sqrt(d) = 1 here, so allow a wide band around that value.
        assert!(
            dist > 0.5 && dist < 2.0,
            "Translation distance should be ≈1 for sliced W2, got {:.3}",
            dist
        );
    }

    #[test]
    fn test_sliced_wasserstein_scaling() {
        let sw = SlicedWasserstein::new(500).with_seed(42);

        let source = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Scale by 2
        let target: Vec<Vec<f64>> = source.iter().map(|p| vec![p[0] * 2.0, p[1] * 2.0]).collect();

        let dist = sw.distance(&source, &target);

        // Should be positive for scaled distribution
        assert!(dist > 0.0, "Scaling should produce positive distance");
    }

    #[test]
    fn test_weighted_distance() {
        let sw = SlicedWasserstein::new(100).with_seed(42);

        let source = vec![vec![0.0], vec![1.0]];
        let target = vec![vec![2.0], vec![3.0]];

        // Uniform weights
        let weights_s = vec![0.5, 0.5];
        let weights_t = vec![0.5, 0.5];

        let dist = sw.weighted_distance(&source, &weights_s, &target, &weights_t);
        assert!(dist > 0.0);
    }
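
    // Added exact-value sanity checks (not in the original suite). In 1D the
    // only unit directions are ±1, so these distances have closed forms.
    #[test]
    fn test_1d_translation_exact() {
        let sw = SlicedWasserstein::new(50).with_seed(7);

        // Translating a 1D point set by 2.0 gives a sliced W2 of exactly 2.0.
        let source = vec![vec![0.0], vec![1.0]];
        let target: Vec<Vec<f64>> = source.iter().map(|p| vec![p[0] + 2.0]).collect();

        let dist = sw.distance(&source, &target);
        assert!((dist - 2.0).abs() < 1e-9, "expected ~2.0, got {}", dist);
    }

    #[test]
    fn test_weighted_point_masses_exact() {
        let sw = SlicedWasserstein::new(20).with_seed(7);

        // Two unit point masses at 0 and 1: the 1D W2 (and hence sliced W2) is 1.0.
        let source = vec![vec![0.0]];
        let target = vec![vec![1.0]];

        let dist = sw.weighted_distance(&source, &[1.0], &target, &[1.0]);
        assert!((dist - 1.0).abs() < 1e-9, "expected ~1.0, got {}", dist);
    }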

    #[test]
    fn test_1d_projections() {
        let sw = SlicedWasserstein::new(10);
        let directions = sw.generate_directions(3);

        assert_eq!(directions.len(), 10);

        // Each direction should be unit length
        for dir in &directions {
            let norm: f64 = dir.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6, "Direction not unit: {}", norm);
        }
    }
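
    // Added coverage for the private helpers: exercises the linear-interpolation
    // quantile and the normalized cumulative weights as currently implemented.
    #[test]
    fn test_quantile_sorted_interpolation() {
        let data = vec![0.0, 1.0, 2.0, 3.0];
        assert!((quantile_sorted(&data, 0.0) - 0.0).abs() < 1e-12);
        assert!((quantile_sorted(&data, 0.5) - 1.5).abs() < 1e-12);
        assert!((quantile_sorted(&data, 1.0) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_compute_cdf_normalized() {
        let cdf = compute_cdf(&[1.0, 1.0, 2.0]);
        assert_eq!(cdf.len(), 3);
        assert!((cdf[0] - 0.25).abs() < 1e-12);
        assert!((cdf[1] - 0.5).abs() < 1e-12);
        assert!((cdf[2] - 1.0).abs() < 1e-12);
    }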
}