Skip to main content

ruvector_math/optimal_transport/
sliced_wasserstein.rs

//! Sliced Wasserstein Distance
//!
//! The Sliced Wasserstein distance projects high-dimensional distributions
//! onto random 1D lines and averages the 1D Wasserstein distances.
//!
//! ## Algorithm
//!
//! 1. Generate L random unit vectors (directions) in R^d
//! 2. For each direction θ:
//!    a. Project all source and target points onto θ
//!    b. Compute the 1D Wasserstein cost W_p^p (closed-form via sorted quantiles)
//! 3. Average the costs over all directions and take the p-th root
//!
//! ## Complexity
//!
//! - O(L × n × (d + log n)) where L = number of projections, n = number of points
//! - Linear in dimension d (only dot products)
//!
//! ## Advantages
//!
//! - **Fast**: Near-linear scaling to millions of points
//! - **SIMD-friendly**: Projections are just dot products
//! - **Statistically consistent**: The Monte Carlo average converges to the exact
//!   sliced Wasserstein distance as L → ∞ (a lower bound on the true W_p, not W_p itself)
24
25use super::{OptimalTransport, WassersteinConfig};
26use crate::utils::{argsort, EPS};
27use rand::prelude::*;
28use rand_distr::StandardNormal;
29
/// Configuration for estimating the Sliced Wasserstein distance between point sets.
///
/// The value only stores settings; random directions are drawn afresh on every
/// distance evaluation, using `seed` when one is set.
#[derive(Clone, Debug)]
pub struct SlicedWasserstein {
    /// How many random 1D directions are averaged per evaluation
    /// (the constructors in this file raise values below 1 to 1).
    num_projections: usize,
    /// Exponent `p` of the Wasserstein-p cost (1 gives W1, 2 gives W2).
    p: f64,
    /// Seed for the direction generator; `None` seeds from entropy instead.
    seed: Option<u64>,
}
40
41impl SlicedWasserstein {
42    /// Create a new Sliced Wasserstein calculator
43    ///
44    /// # Arguments
45    /// * `num_projections` - Number of random 1D projections (100-1000 typical)
46    pub fn new(num_projections: usize) -> Self {
47        Self {
48            num_projections: num_projections.max(1),
49            p: 2.0,
50            seed: None,
51        }
52    }
53
54    /// Create from configuration
55    pub fn from_config(config: &WassersteinConfig) -> Self {
56        Self {
57            num_projections: config.num_projections.max(1),
58            p: config.p,
59            seed: config.seed,
60        }
61    }
62
63    /// Set the Wasserstein power (1 for W1, 2 for W2)
64    pub fn with_power(mut self, p: f64) -> Self {
65        self.p = p.max(1.0);
66        self
67    }
68
69    /// Set random seed for reproducibility
70    pub fn with_seed(mut self, seed: u64) -> Self {
71        self.seed = Some(seed);
72        self
73    }
74
75    /// Generate random unit directions
76    fn generate_directions(&self, dim: usize) -> Vec<Vec<f64>> {
77        let mut rng = match self.seed {
78            Some(s) => StdRng::seed_from_u64(s),
79            None => StdRng::from_entropy(),
80        };
81
82        (0..self.num_projections)
83            .map(|_| {
84                let mut direction: Vec<f64> =
85                    (0..dim).map(|_| rng.sample(StandardNormal)).collect();
86
87                // Normalize to unit vector
88                let norm: f64 = direction.iter().map(|&x| x * x).sum::<f64>().sqrt();
89                if norm > EPS {
90                    for x in &mut direction {
91                        *x /= norm;
92                    }
93                }
94                direction
95            })
96            .collect()
97    }
98
99    /// Project points onto a direction (SIMD-friendly dot product)
100    #[inline(always)]
101    fn project(points: &[Vec<f64>], direction: &[f64]) -> Vec<f64> {
102        points
103            .iter()
104            .map(|p| Self::dot_product(p, direction))
105            .collect()
106    }
107
108    /// Project points into pre-allocated buffer (reduces allocations)
109    #[inline(always)]
110    fn project_into(points: &[Vec<f64>], direction: &[f64], out: &mut [f64]) {
111        for (i, p) in points.iter().enumerate() {
112            out[i] = Self::dot_product(p, direction);
113        }
114    }
115
116    /// SIMD-friendly dot product using fold pattern
117    /// Compiler can auto-vectorize this pattern effectively
118    #[inline(always)]
119    fn dot_product(a: &[f64], b: &[f64]) -> f64 {
120        // Use 4-way unrolled accumulator for better SIMD utilization
121        let len = a.len();
122        let chunks = len / 4;
123        let remainder = len % 4;
124
125        let mut sum0 = 0.0f64;
126        let mut sum1 = 0.0f64;
127        let mut sum2 = 0.0f64;
128        let mut sum3 = 0.0f64;
129
130        // Process 4 elements at a time (helps SIMD vectorization)
131        for i in 0..chunks {
132            let base = i * 4;
133            sum0 += a[base] * b[base];
134            sum1 += a[base + 1] * b[base + 1];
135            sum2 += a[base + 2] * b[base + 2];
136            sum3 += a[base + 3] * b[base + 3];
137        }
138
139        // Handle remainder
140        let base = chunks * 4;
141        for i in 0..remainder {
142            sum0 += a[base + i] * b[base + i];
143        }
144
145        sum0 + sum1 + sum2 + sum3
146    }
147
148    /// Compute 1D Wasserstein distance between two sorted distributions
149    ///
150    /// For uniform weights, this is simply the sum of |sorted_a[i] - sorted_b[i]|^p
151    #[inline]
152    fn wasserstein_1d_uniform(&self, mut proj_a: Vec<f64>, mut proj_b: Vec<f64>) -> f64 {
153        let n = proj_a.len();
154        let m = proj_b.len();
155
156        // Sort projections using fast f64 comparison
157        proj_a.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
158        proj_b.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
159
160        if n == m {
161            // Same size: direct comparison with SIMD-friendly accumulator
162            self.wasserstein_1d_equal_size(&proj_a, &proj_b)
163        } else {
164            // Different sizes: interpolate via quantiles
165            self.wasserstein_1d_quantile(&proj_a, &proj_b, n.max(m))
166        }
167    }
168
169    /// Optimized equal-size 1D Wasserstein with SIMD-friendly pattern
170    #[inline(always)]
171    fn wasserstein_1d_equal_size(&self, sorted_a: &[f64], sorted_b: &[f64]) -> f64 {
172        let n = sorted_a.len();
173        if n == 0 {
174            return 0.0;
175        }
176
177        // Use p=2 fast path (most common case)
178        if (self.p - 2.0).abs() < 1e-10 {
179            // L2 Wasserstein: sum of squared differences
180            let mut sum0 = 0.0f64;
181            let mut sum1 = 0.0f64;
182            let mut sum2 = 0.0f64;
183            let mut sum3 = 0.0f64;
184
185            let chunks = n / 4;
186            let remainder = n % 4;
187
188            for i in 0..chunks {
189                let base = i * 4;
190                let d0 = sorted_a[base] - sorted_b[base];
191                let d1 = sorted_a[base + 1] - sorted_b[base + 1];
192                let d2 = sorted_a[base + 2] - sorted_b[base + 2];
193                let d3 = sorted_a[base + 3] - sorted_b[base + 3];
194                sum0 += d0 * d0;
195                sum1 += d1 * d1;
196                sum2 += d2 * d2;
197                sum3 += d3 * d3;
198            }
199
200            let base = chunks * 4;
201            for i in 0..remainder {
202                let d = sorted_a[base + i] - sorted_b[base + i];
203                sum0 += d * d;
204            }
205
206            (sum0 + sum1 + sum2 + sum3) / n as f64
207        } else if (self.p - 1.0).abs() < 1e-10 {
208            // L1 Wasserstein: sum of absolute differences
209            let mut sum = 0.0f64;
210            for i in 0..n {
211                sum += (sorted_a[i] - sorted_b[i]).abs();
212            }
213            sum / n as f64
214        } else {
215            // General case
216            sorted_a
217                .iter()
218                .zip(sorted_b.iter())
219                .map(|(&a, &b)| (a - b).abs().powf(self.p))
220                .sum::<f64>()
221                / n as f64
222        }
223    }
224
225    /// Compute 1D Wasserstein via quantile interpolation
226    fn wasserstein_1d_quantile(
227        &self,
228        sorted_a: &[f64],
229        sorted_b: &[f64],
230        num_samples: usize,
231    ) -> f64 {
232        let mut total = 0.0;
233
234        for i in 0..num_samples {
235            let q = (i as f64 + 0.5) / num_samples as f64;
236
237            let val_a = quantile_sorted(sorted_a, q);
238            let val_b = quantile_sorted(sorted_b, q);
239
240            total += (val_a - val_b).abs().powf(self.p);
241        }
242
243        total / num_samples as f64
244    }
245
246    /// Compute 1D Wasserstein with weights
247    fn wasserstein_1d_weighted(
248        &self,
249        proj_a: &[f64],
250        weights_a: &[f64],
251        proj_b: &[f64],
252        weights_b: &[f64],
253    ) -> f64 {
254        // Sort by projected values
255        let idx_a = argsort(proj_a);
256        let idx_b = argsort(proj_b);
257
258        let sorted_a: Vec<f64> = idx_a.iter().map(|&i| proj_a[i]).collect();
259        let sorted_w_a: Vec<f64> = idx_a.iter().map(|&i| weights_a[i]).collect();
260        let sorted_b: Vec<f64> = idx_b.iter().map(|&i| proj_b[i]).collect();
261        let sorted_w_b: Vec<f64> = idx_b.iter().map(|&i| weights_b[i]).collect();
262
263        // Compute cumulative weights
264        let cdf_a = compute_cdf(&sorted_w_a);
265        let cdf_b = compute_cdf(&sorted_w_b);
266
267        // Merge and compute
268        self.wasserstein_1d_from_cdfs(&sorted_a, &cdf_a, &sorted_b, &cdf_b)
269    }
270
271    /// Compute 1D Wasserstein from CDFs
272    fn wasserstein_1d_from_cdfs(
273        &self,
274        values_a: &[f64],
275        cdf_a: &[f64],
276        values_b: &[f64],
277        cdf_b: &[f64],
278    ) -> f64 {
279        // Merge all CDF points
280        let mut events: Vec<(f64, f64, f64)> = Vec::new(); // (position, cdf_a, cdf_b)
281
282        let mut ia = 0;
283        let mut ib = 0;
284        let mut current_cdf_a = 0.0;
285        let mut current_cdf_b = 0.0;
286
287        while ia < values_a.len() || ib < values_b.len() {
288            let pos = match (ia < values_a.len(), ib < values_b.len()) {
289                (true, true) => {
290                    if values_a[ia] <= values_b[ib] {
291                        current_cdf_a = cdf_a[ia];
292                        ia += 1;
293                        values_a[ia - 1]
294                    } else {
295                        current_cdf_b = cdf_b[ib];
296                        ib += 1;
297                        values_b[ib - 1]
298                    }
299                }
300                (true, false) => {
301                    current_cdf_a = cdf_a[ia];
302                    ia += 1;
303                    values_a[ia - 1]
304                }
305                (false, true) => {
306                    current_cdf_b = cdf_b[ib];
307                    ib += 1;
308                    values_b[ib - 1]
309                }
310                (false, false) => break,
311            };
312
313            events.push((pos, current_cdf_a, current_cdf_b));
314        }
315
316        // Integrate |F_a - F_b|^p
317        let mut total = 0.0;
318        for i in 1..events.len() {
319            let width = events[i].0 - events[i - 1].0;
320            let height = (events[i - 1].1 - events[i - 1].2).abs();
321            total += width * height.powf(self.p);
322        }
323
324        total
325    }
326}
327
328impl OptimalTransport for SlicedWasserstein {
329    fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> f64 {
330        if source.is_empty() || target.is_empty() {
331            return 0.0;
332        }
333
334        let dim = source[0].len();
335        if dim == 0 {
336            return 0.0;
337        }
338
339        let directions = self.generate_directions(dim);
340        let n_source = source.len();
341        let n_target = target.len();
342
343        // Pre-allocate projection buffers (reduces allocations per direction)
344        let mut proj_source = vec![0.0; n_source];
345        let mut proj_target = vec![0.0; n_target];
346
347        let total: f64 = directions
348            .iter()
349            .map(|dir| {
350                // Project into pre-allocated buffers
351                Self::project_into(source, dir, &mut proj_source);
352                Self::project_into(target, dir, &mut proj_target);
353
354                // Clone for sorting (wasserstein_1d_uniform sorts in place)
355                self.wasserstein_1d_uniform(proj_source.clone(), proj_target.clone())
356            })
357            .sum();
358
359        (total / self.num_projections as f64).powf(1.0 / self.p)
360    }
361
362    fn weighted_distance(
363        &self,
364        source: &[Vec<f64>],
365        source_weights: &[f64],
366        target: &[Vec<f64>],
367        target_weights: &[f64],
368    ) -> f64 {
369        if source.is_empty() || target.is_empty() {
370            return 0.0;
371        }
372
373        let dim = source[0].len();
374        if dim == 0 {
375            return 0.0;
376        }
377
378        // Normalize weights
379        let sum_a: f64 = source_weights.iter().sum();
380        let sum_b: f64 = target_weights.iter().sum();
381        let weights_a: Vec<f64> = source_weights.iter().map(|&w| w / sum_a).collect();
382        let weights_b: Vec<f64> = target_weights.iter().map(|&w| w / sum_b).collect();
383
384        let directions = self.generate_directions(dim);
385
386        let total: f64 = directions
387            .iter()
388            .map(|dir| {
389                let proj_source = Self::project(source, dir);
390                let proj_target = Self::project(target, dir);
391                self.wasserstein_1d_weighted(&proj_source, &weights_a, &proj_target, &weights_b)
392            })
393            .sum();
394
395        (total / self.num_projections as f64).powf(1.0 / self.p)
396    }
397}
398
/// Linearly interpolated quantile of an ascending-sorted slice.
///
/// `q` is clamped to `[0, 1]` and mapped to the fractional index `q * (n - 1)`;
/// the result blends the two neighbouring elements. An empty slice yields 0.0
/// and a single-element slice yields that element.
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            let position = q.clamp(0.0, 1.0) * last as f64;

            let lower = position.floor() as usize;
            let upper = if lower + 1 < last { lower + 1 } else { last };
            let frac = position - lower as f64;

            sorted[lower] * (1.0 - frac) + sorted[upper] * frac
        }
    }
}
419
/// Running cumulative distribution of `weights`.
///
/// Each weight is divided by the total of all weights and added to a running
/// sum; entry `i` of the result is the normalized mass of `weights[..=i]`.
/// The total is not checked for zero.
fn compute_cdf(weights: &[f64]) -> Vec<f64> {
    let total: f64 = weights.iter().sum();

    let mut running = 0.0;
    weights
        .iter()
        .map(|&w| {
            running += w / total;
            running
        })
        .collect()
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// The four corners of the unit square in R^2, shared by several tests.
    fn unit_square() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ]
    }

    /// Apply `f` to both coordinates' point, producing a transformed copy of `points`.
    fn map_points(points: &[Vec<f64>], f: impl Fn(&[f64]) -> Vec<f64>) -> Vec<Vec<f64>> {
        points.iter().map(|p| f(p)).collect()
    }

    #[test]
    fn test_sliced_wasserstein_identical() {
        let sw = SlicedWasserstein::new(100).with_seed(42);
        let points = unit_square();

        // A point set compared with itself projects to identical 1D samples.
        let dist = sw.distance(&points, &points);
        assert!(dist < 0.01, "Self-distance should be ~0, got {}", dist);
    }

    #[test]
    fn test_sliced_wasserstein_translation() {
        let sw = SlicedWasserstein::new(500).with_seed(42);
        let source = unit_square();

        // Shift every point by t = (1, 1).
        let target = map_points(&source, |p| vec![p[0] + 1.0, p[1] + 1.0]);

        let dist = sw.distance(&source, &target);

        // The exact W2 for this shift is |t| = sqrt(2) ≈ 1.414. The sliced value
        // averages the squared projected shift (θ·t)^2 over directions, which is
        // |t|^2 / d = 1 in 2D, so the estimate should sit near 1.0. The bounds are
        // deliberately loose because the directions are sampled.
        assert!(
            dist > 0.5 && dist < 2.0,
            "Translation distance should be positive, got {:.3}",
            dist
        );
    }

    #[test]
    fn test_sliced_wasserstein_scaling() {
        let sw = SlicedWasserstein::new(500).with_seed(42);
        let source = unit_square();

        // Double every coordinate.
        let target = map_points(&source, |p| vec![p[0] * 2.0, p[1] * 2.0]);

        let dist = sw.distance(&source, &target);

        // A scaled copy is a different distribution, so the distance must be positive.
        assert!(dist > 0.0, "Scaling should produce positive distance");
    }

    #[test]
    fn test_weighted_distance() {
        let sw = SlicedWasserstein::new(100).with_seed(42);

        // Two disjoint 1D samples.
        let source = vec![vec![0.0], vec![1.0]];
        let target = vec![vec![2.0], vec![3.0]];

        // Equal mass on every point.
        let weights_s = vec![0.5, 0.5];
        let weights_t = vec![0.5, 0.5];

        let dist = sw.weighted_distance(&source, &weights_s, &target, &weights_t);
        assert!(dist > 0.0);
    }

    #[test]
    fn test_1d_projections() {
        let sw = SlicedWasserstein::new(10);
        let directions = sw.generate_directions(3);

        assert_eq!(directions.len(), 10);

        // Every sampled direction must have Euclidean norm 1.
        for dir in &directions {
            let norm: f64 = dir.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6, "Direction not unit: {}", norm);
        }
    }
}