sklears_kernel_approximation/
optimal_transport.rs

//! Optimal transport kernel approximation methods
//!
//! This module implements kernel approximation methods based on optimal transport theory,
//! including Wasserstein kernels, Sinkhorn divergences, and earth mover's distance approximations.

6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::rand_prelude::IteratorRandom;
8use scirs2_core::random::essentials::Uniform as RandUniform;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::{
13    error::{Result as SklResult, SklearsError},
14    prelude::{Fit, Transform},
15};
16
17/// Wasserstein kernel approximation using random features
18///
19/// This implements approximations to the Wasserstein kernel which computes
20/// optimal transport distances between empirical distributions.
21#[derive(Debug, Clone)]
22/// WassersteinKernelSampler
23pub struct WassersteinKernelSampler {
24    /// Number of random features to generate
25    n_components: usize,
26    /// Ground metric for optimal transport (default: squared Euclidean)
27    ground_metric: GroundMetric,
28    /// Regularization parameter for Sinkhorn divergence
29    epsilon: f64,
30    /// Maximum number of Sinkhorn iterations
31    max_iter: usize,
32    /// Convergence tolerance for Sinkhorn
33    tolerance: f64,
34    /// Random state for reproducibility
35    random_state: Option<u64>,
36    /// Fitted random projections
37    projections: Option<Array2<f64>>,
38    /// Transport plan approximation method
39    transport_method: TransportMethod,
40}
41
42/// Ground metric options for optimal transport
43#[derive(Debug, Clone)]
44/// GroundMetric
45pub enum GroundMetric {
46    /// Squared Euclidean distance
47    SquaredEuclidean,
48    /// Euclidean distance
49    Euclidean,
50    /// Manhattan (L1) distance
51    Manhattan,
52    /// Minkowski distance with parameter p
53    Minkowski(f64),
54    /// Custom metric function
55    Custom(fn(&Array1<f64>, &Array1<f64>) -> f64),
56}
57
/// Transport plan approximation methods
///
/// Selects the algorithm used to approximate optimal transport distances.
#[derive(Debug, Clone)]
pub enum TransportMethod {
    /// Sinkhorn divergence approximation
    Sinkhorn,
    /// Sliced Wasserstein using random projections
    SlicedWasserstein,
    /// Tree-Wasserstein using hierarchical decomposition
    TreeWasserstein,
    /// Projection-based approximation
    ProjectionBased,
}
71
72impl Default for WassersteinKernelSampler {
73    fn default() -> Self {
74        Self::new(100)
75    }
76}
77
78impl WassersteinKernelSampler {
79    /// Create a new Wasserstein kernel approximation sampler
80    ///
81    /// # Arguments
82    /// * `n_components` - Number of random features to generate
83    ///
84    /// # Examples
85    /// ```
86    /// use sklears_kernel_approximation::WassersteinKernelSampler;
87    /// let sampler = WassersteinKernelSampler::new(100);
88    /// ```
89    pub fn new(n_components: usize) -> Self {
90        Self {
91            n_components,
92            ground_metric: GroundMetric::SquaredEuclidean,
93            epsilon: 0.1,
94            max_iter: 1000,
95            tolerance: 1e-9,
96            random_state: None,
97            projections: None,
98            transport_method: TransportMethod::SlicedWasserstein,
99        }
100    }
101
102    /// Set the ground metric for optimal transport
103    pub fn ground_metric(mut self, metric: GroundMetric) -> Self {
104        self.ground_metric = metric;
105        self
106    }
107
108    /// Set the regularization parameter for Sinkhorn divergence
109    pub fn epsilon(mut self, epsilon: f64) -> Self {
110        self.epsilon = epsilon;
111        self
112    }
113
114    /// Set the maximum number of iterations for Sinkhorn
115    pub fn max_iter(mut self, max_iter: usize) -> Self {
116        self.max_iter = max_iter;
117        self
118    }
119
120    /// Set the convergence tolerance
121    pub fn tolerance(mut self, tolerance: f64) -> Self {
122        self.tolerance = tolerance;
123        self
124    }
125
126    /// Set the random state for reproducibility
127    pub fn random_state(mut self, seed: u64) -> Self {
128        self.random_state = Some(seed);
129        self
130    }
131
132    /// Set the transport approximation method
133    pub fn transport_method(mut self, method: TransportMethod) -> Self {
134        self.transport_method = method;
135        self
136    }
137
138    /// Compute distance using the specified ground metric
139    fn compute_ground_distance(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
140        match &self.ground_metric {
141            GroundMetric::SquaredEuclidean => (x - y).mapv(|v| v * v).sum(),
142            GroundMetric::Euclidean => (x - y).mapv(|v| v * v).sum().sqrt(),
143            GroundMetric::Manhattan => (x - y).mapv(|v| v.abs()).sum(),
144            GroundMetric::Minkowski(p) => (x - y).mapv(|v| v.abs().powf(*p)).sum().powf(1.0 / p),
145            GroundMetric::Custom(func) => func(x, y),
146        }
147    }
148
149    /// Compute sliced Wasserstein distance using random projections
150    fn sliced_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
151        let projections = self
152            .projections
153            .as_ref()
154            .ok_or_else(|| SklearsError::NotFitted {
155                operation: "transform".to_string(),
156            })?;
157
158        let n_samples = x.nrows();
159        let mut features = Array2::zeros((n_samples, self.n_components));
160
161        // Project data onto random directions
162        let projected = x.dot(projections);
163
164        // For each projection, compute cumulative distribution features
165        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
166            let mut sorted_proj: Vec<f64> = proj_col.to_vec();
167            sorted_proj.sort_by(|a, b| a.partial_cmp(b).unwrap());
168
169            // Use quantile-based features
170            for (i, &val) in proj_col.iter().enumerate() {
171                // Find quantile of this value in the sorted distribution
172                let quantile = sorted_proj
173                    .binary_search_by(|&probe| probe.partial_cmp(&val).unwrap())
174                    .unwrap_or_else(|e| e) as f64
175                    / sorted_proj.len() as f64;
176
177                features[[i, j]] = quantile;
178            }
179        }
180
181        Ok(features)
182    }
183
184    /// Compute Sinkhorn divergence features
185    fn sinkhorn_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
186        let n_samples = x.nrows();
187        let mut features = Array2::zeros((n_samples, self.n_components));
188
189        // Use random subsampling for Sinkhorn approximation
190        let mut rng = if let Some(seed) = self.random_state {
191            RealStdRng::seed_from_u64(seed)
192        } else {
193            RealStdRng::from_seed(thread_rng().gen())
194        };
195
196        for j in 0..self.n_components {
197            // Sample a subset of points for comparison
198            let subset_size = (n_samples as f64).sqrt() as usize + 1;
199            let indices: Vec<usize> = (0..n_samples)
200                .choose_multiple(&mut rng, subset_size)
201                .into_iter()
202                .collect();
203
204            for (i, x_i) in x.axis_iter(Axis(0)).enumerate() {
205                let mut total_divergence = 0.0;
206
207                for &idx in &indices {
208                    if idx != i {
209                        let x_j = x.row(idx);
210                        let distance =
211                            self.compute_ground_distance(&x_i.to_owned(), &x_j.to_owned());
212                        total_divergence += (-distance / self.epsilon).exp();
213                    }
214                }
215
216                features[[i, j]] = total_divergence / indices.len() as f64;
217            }
218        }
219
220        Ok(features)
221    }
222
223    /// Compute tree-Wasserstein features using hierarchical decomposition
224    fn tree_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
225        let n_samples = x.nrows();
226        let mut features = Array2::zeros((n_samples, self.n_components));
227
228        // Use hierarchical clustering approximation
229        let projections = self
230            .projections
231            .as_ref()
232            .ok_or_else(|| SklearsError::NotFitted {
233                operation: "transform".to_string(),
234            })?;
235
236        let projected = x.dot(projections);
237
238        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
239            // Create hierarchical bins
240            let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
241            let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
242            let bin_width = (max_val - min_val) / 10.0; // 10 bins
243
244            for (i, &val) in proj_col.iter().enumerate() {
245                let bin = ((val - min_val) / bin_width).floor() as usize;
246                let bin = bin.min(9); // Ensure we don't exceed bounds
247                features[[i, j]] = bin as f64 / 9.0; // Normalize
248            }
249        }
250
251        Ok(features)
252    }
253}
254
255impl Fit<Array2<f64>, ()> for WassersteinKernelSampler {
256    type Fitted = FittedWassersteinSampler;
257
258    fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
259        let n_features = x.ncols();
260
261        // Generate random projections for sliced Wasserstein
262        let mut rng = if let Some(seed) = self.random_state {
263            RealStdRng::seed_from_u64(seed)
264        } else {
265            RealStdRng::from_seed(thread_rng().gen())
266        };
267
268        let uniform = RandUniform::new(-1.0, 1.0).unwrap();
269        let mut projections = Array2::zeros((n_features, self.n_components));
270
271        for mut col in projections.axis_iter_mut(Axis(1)) {
272            for elem in col.iter_mut() {
273                *elem = rng.sample(uniform);
274            }
275            // Normalize the projection vector
276            let norm: f64 = col.mapv(|v| v * v).sum();
277            let norm = norm.sqrt();
278            col /= norm;
279        }
280
281        Ok(FittedWassersteinSampler {
282            sampler: WassersteinKernelSampler {
283                projections: Some(projections),
284                ..self.clone()
285            },
286        })
287    }
288}
289
/// Fitted Wasserstein kernel sampler
///
/// Produced by fitting a [`WassersteinKernelSampler`]; holds the configured
/// sampler with its random projections populated, ready for `transform`.
pub struct FittedWassersteinSampler {
    // Sampler configuration plus fitted projections (`Some` after fit).
    sampler: WassersteinKernelSampler,
}
294
295impl Transform<Array2<f64>, Array2<f64>> for FittedWassersteinSampler {
296    fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
297        match self.sampler.transport_method {
298            TransportMethod::SlicedWasserstein => self.sampler.sliced_wasserstein_features(x),
299            TransportMethod::Sinkhorn => self.sampler.sinkhorn_features(x),
300            TransportMethod::TreeWasserstein => self.sampler.tree_wasserstein_features(x),
301            TransportMethod::ProjectionBased => self.sampler.sliced_wasserstein_features(x),
302        }
303    }
304}
305
306/// Earth Mover's Distance (EMD) kernel approximation
307///
308/// Implements approximations to kernels based on the earth mover's distance,
309/// also known as the Wasserstein-1 distance.
310#[derive(Debug, Clone)]
311/// EMDKernelSampler
312pub struct EMDKernelSampler {
313    /// Number of random features
314    n_components: usize,
315    /// Bandwidth parameter for the kernel
316    bandwidth: f64,
317    /// Random state for reproducibility
318    random_state: Option<u64>,
319    /// Fitted random projections
320    projections: Option<Array2<f64>>,
321    /// Number of quantile bins for approximation
322    n_bins: usize,
323}
324
325impl Default for EMDKernelSampler {
326    fn default() -> Self {
327        Self::new(100)
328    }
329}
330
331impl EMDKernelSampler {
332    /// Create a new EMD kernel sampler
333    pub fn new(n_components: usize) -> Self {
334        Self {
335            n_components,
336            bandwidth: 1.0,
337            random_state: None,
338            projections: None,
339            n_bins: 20,
340        }
341    }
342
343    /// Set the bandwidth parameter
344    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
345        self.bandwidth = bandwidth;
346        self
347    }
348
349    /// Set the random state
350    pub fn random_state(mut self, seed: u64) -> Self {
351        self.random_state = Some(seed);
352        self
353    }
354
355    /// Set the number of quantile bins
356    pub fn n_bins(mut self, n_bins: usize) -> Self {
357        self.n_bins = n_bins;
358        self
359    }
360}
361
362impl Fit<Array2<f64>, ()> for EMDKernelSampler {
363    type Fitted = FittedEMDSampler;
364
365    fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
366        let n_features = x.ncols();
367
368        let mut rng = if let Some(seed) = self.random_state {
369            RealStdRng::seed_from_u64(seed)
370        } else {
371            RealStdRng::from_seed(thread_rng().gen())
372        };
373
374        let uniform = RandUniform::new(-1.0, 1.0).unwrap();
375        let mut projections = Array2::zeros((n_features, self.n_components));
376
377        for mut col in projections.axis_iter_mut(Axis(1)) {
378            for elem in col.iter_mut() {
379                *elem = rng.sample(uniform);
380            }
381            let norm: f64 = col.mapv(|v| v * v).sum();
382            let norm = norm.sqrt();
383            col /= norm;
384        }
385
386        Ok(FittedEMDSampler {
387            sampler: EMDKernelSampler {
388                projections: Some(projections),
389                ..self.clone()
390            },
391        })
392    }
393}
394
/// Fitted EMD kernel sampler
///
/// Produced by fitting an [`EMDKernelSampler`]; ready for `transform`.
pub struct FittedEMDSampler {
    // Sampler configuration with `projections` populated after fitting.
    sampler: EMDKernelSampler,
}
399
400impl Transform<Array2<f64>, Array2<f64>> for FittedEMDSampler {
401    fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
402        let projections =
403            self.sampler
404                .projections
405                .as_ref()
406                .ok_or_else(|| SklearsError::NotFitted {
407                    operation: "transform".to_string(),
408                })?;
409
410        let n_samples = x.nrows();
411        let mut features = Array2::zeros((n_samples, self.sampler.n_components));
412
413        let projected = x.dot(projections);
414
415        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
416            // Compute quantile-based features for EMD approximation
417            let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
418            let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
419            let range = max_val - min_val;
420
421            if range > 1e-10 {
422                for (i, &val) in proj_col.iter().enumerate() {
423                    // Compute cumulative distribution function approximation
424                    let normalized_val = (val - min_val) / range;
425                    let emd_feature = (-normalized_val.abs() / self.sampler.bandwidth).exp();
426                    features[[i, j]] = emd_feature;
427                }
428            }
429        }
430
431        Ok(features)
432    }
433}
434
435/// Gromov-Wasserstein kernel approximation
436///
437/// Implements approximations to Gromov-Wasserstein distances which compare
438/// metric measure spaces by their intrinsic geometry.
439#[derive(Debug, Clone)]
440/// GromovWassersteinSampler
441pub struct GromovWassersteinSampler {
442    /// Number of random features
443    n_components: usize,
444    /// Loss function for Gromov-Wasserstein
445    loss_function: GWLossFunction,
446    /// Number of iterations for optimization
447    max_iter: usize,
448    /// Random state
449    random_state: Option<u64>,
450    /// Fitted parameters
451    fitted_params: Option<GWFittedParams>,
452}
453
/// Loss functions for Gromov-Wasserstein
///
/// Determines how discrepancies between pairwise distances are penalized.
#[derive(Debug, Clone)]
pub enum GWLossFunction {
    /// Squared loss
    Square,
    /// KL divergence
    KlLoss,
    /// Custom loss function
    Custom(fn(f64, f64, f64, f64) -> f64),
}
465
/// Internal parameters produced by fitting a Gromov-Wasserstein sampler.
#[derive(Debug, Clone)]
struct GWFittedParams {
    // Pairwise Euclidean distances of the training rows (reference space).
    reference_distances: Array2<f64>,
    // Unit-norm random projection directions, shape (n_features, n_components).
    projections: Array2<f64>,
}
471
472impl Default for GromovWassersteinSampler {
473    fn default() -> Self {
474        Self::new(50)
475    }
476}
477
478impl GromovWassersteinSampler {
479    /// Create a new Gromov-Wasserstein sampler
480    pub fn new(n_components: usize) -> Self {
481        Self {
482            n_components,
483            loss_function: GWLossFunction::Square,
484            max_iter: 100,
485            random_state: None,
486            fitted_params: None,
487        }
488    }
489
490    /// Set the loss function
491    pub fn loss_function(mut self, loss: GWLossFunction) -> Self {
492        self.loss_function = loss;
493        self
494    }
495
496    /// Set random state
497    pub fn random_state(mut self, seed: u64) -> Self {
498        self.random_state = Some(seed);
499        self
500    }
501
502    /// Compute pairwise distances for metric space comparison
503    fn compute_distance_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
504        let n = x.nrows();
505        let mut distances = Array2::zeros((n, n));
506
507        for i in 0..n {
508            for j in i..n {
509                let diff = &x.row(i).to_owned() - &x.row(j);
510                let dist: f64 = diff.mapv(|v| v * v).sum();
511                let dist = dist.sqrt();
512                distances[[i, j]] = dist;
513                distances[[j, i]] = dist;
514            }
515        }
516
517        distances
518    }
519}
520
521impl Fit<Array2<f64>, ()> for GromovWassersteinSampler {
522    type Fitted = FittedGromovWassersteinSampler;
523
524    fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
525        let n_features = x.ncols();
526        let _n_samples = x.nrows();
527
528        // Compute reference distance matrix
529        let distance_matrix = self.compute_distance_matrix(x);
530
531        // Generate random projections for dimensionality reduction
532        let mut rng = if let Some(seed) = self.random_state {
533            RealStdRng::seed_from_u64(seed)
534        } else {
535            RealStdRng::from_seed(thread_rng().gen())
536        };
537
538        let uniform = RandUniform::new(-1.0, 1.0).unwrap();
539        let mut projections = Array2::zeros((n_features, self.n_components));
540
541        for mut col in projections.axis_iter_mut(Axis(1)) {
542            for elem in col.iter_mut() {
543                *elem = rng.sample(uniform);
544            }
545            let norm: f64 = col.mapv(|v| v * v).sum();
546            let norm = norm.sqrt();
547            col /= norm;
548        }
549
550        let fitted_params = GWFittedParams {
551            reference_distances: distance_matrix,
552            projections,
553        };
554
555        Ok(FittedGromovWassersteinSampler {
556            sampler: GromovWassersteinSampler {
557                fitted_params: Some(fitted_params),
558                ..self.clone()
559            },
560        })
561    }
562}
563
/// Fitted Gromov-Wasserstein sampler
///
/// Produced by fitting a [`GromovWassersteinSampler`]; ready for `transform`.
pub struct FittedGromovWassersteinSampler {
    // Sampler configuration with `fitted_params` populated after fitting.
    sampler: GromovWassersteinSampler,
}
568
569impl Transform<Array2<f64>, Array2<f64>> for FittedGromovWassersteinSampler {
570    fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
571        let params =
572            self.sampler
573                .fitted_params
574                .as_ref()
575                .ok_or_else(|| SklearsError::NotFitted {
576                    operation: "transform".to_string(),
577                })?;
578
579        let n_samples = x.nrows();
580        let mut features = Array2::zeros((n_samples, self.sampler.n_components));
581
582        // Project data and compute distance-based features
583        let projected = x.dot(&params.projections);
584
585        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
586            for (i, &val) in proj_col.iter().enumerate() {
587                // Use projected coordinates as distance-based features
588                features[[i, j]] = val.tanh(); // Bounded activation
589            }
590        }
591
592        Ok(features)
593    }
594}
595
// Removed the dead `#[allow(non_snake_case)]`: every item in this module is
// already snake_case, so the allow suppressed nothing.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Transform};

    #[test]
    fn test_wasserstein_kernel_sampler() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];

        let sampler = WassersteinKernelSampler::new(10)
            .random_state(42)
            .transport_method(TransportMethod::SlicedWasserstein);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[4, 10]);

        // Quantile features must lie in [0, 1].
        assert!(features.iter().all(|f| (0.0..=1.0).contains(f)));
    }

    #[test]
    fn test_emd_kernel_sampler() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let sampler = EMDKernelSampler::new(15).bandwidth(0.5).random_state(123);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 15]);

        // Features come from an exponential, so they are non-negative.
        assert!(features.iter().all(|&f| f >= 0.0));
    }

    #[test]
    fn test_gromov_wasserstein_sampler() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0],];

        let sampler = GromovWassersteinSampler::new(8).random_state(456);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[4, 8]);

        // tanh bounds every feature to [-1, 1].
        assert!(features.iter().all(|f| (-1.0..=1.0).contains(f)));
    }

    #[test]
    fn test_wasserstein_different_methods() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let methods = vec![
            TransportMethod::SlicedWasserstein,
            TransportMethod::Sinkhorn,
            TransportMethod::TreeWasserstein,
        ];

        for method in methods {
            let sampler = WassersteinKernelSampler::new(5)
                .transport_method(method)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[3, 5]);
        }
    }

    #[test]
    fn test_different_ground_metrics() {
        let x = array![[1.0, 2.0], [2.0, 3.0],];

        let metrics = vec![
            GroundMetric::SquaredEuclidean,
            GroundMetric::Euclidean,
            GroundMetric::Manhattan,
            GroundMetric::Minkowski(3.0),
        ];

        for metric in metrics {
            let sampler = WassersteinKernelSampler::new(3)
                .ground_metric(metric)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[2, 3]);
        }
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let sampler1 = WassersteinKernelSampler::new(5).random_state(42);
        let sampler2 = WassersteinKernelSampler::new(5).random_state(42);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let features1 = fitted1.transform(&x).unwrap();
        let features2 = fitted2.transform(&x).unwrap();

        // Results should be identical with the same random state.
        assert!((features1 - features2).mapv(|v| v.abs()).sum() < 1e-10);
    }
}