// sklears_kernel_approximation/optimal_transport.rs
1//! Optimal transport kernel approximation methods
2//!
3//! This module implements kernel approximation methods based on optimal transport theory,
4//! including Wasserstein kernels, Sinkhorn divergences, and earth mover's distance approximations.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::rand_prelude::IteratorRandom;
8use scirs2_core::random::essentials::Uniform as RandUniform;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::RngExt;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::{
13    error::{Result as SklResult, SklearsError},
14    prelude::{Fit, Transform},
15};
16
/// Wasserstein kernel approximation using random features
///
/// This implements approximations to the Wasserstein kernel which computes
/// optimal transport distances between empirical distributions.
#[derive(Debug, Clone)]
/// Configurable sampler; call `fit` to obtain a `FittedWassersteinSampler`.
pub struct WassersteinKernelSampler {
    /// Number of random features to generate
    n_components: usize,
    /// Ground metric for optimal transport (default: squared Euclidean)
    ground_metric: GroundMetric,
    /// Regularization parameter for Sinkhorn divergence (default: 0.1)
    epsilon: f64,
    /// Maximum number of Sinkhorn iterations (stored but not read by the
    /// current subsampling-based Sinkhorn approximation)
    max_iter: usize,
    /// Convergence tolerance for Sinkhorn (stored; currently unused)
    tolerance: f64,
    /// Random state for reproducibility (`None` = seed from thread RNG)
    random_state: Option<u64>,
    /// Fitted random projections, shape (n_features, n_components);
    /// `None` until `fit` is called
    projections: Option<Array2<f64>>,
    /// Transport plan approximation method (default: sliced Wasserstein)
    transport_method: TransportMethod,
}
41
/// Ground metric options for optimal transport
///
/// The ground metric defines the per-pair point cost used by
/// `WassersteinKernelSampler::compute_ground_distance`.
#[derive(Debug, Clone)]
/// Distance function applied between two individual points.
pub enum GroundMetric {
    /// Squared Euclidean distance
    SquaredEuclidean,
    /// Euclidean distance
    Euclidean,
    /// Manhattan (L1) distance
    Manhattan,
    /// Minkowski distance with parameter p
    Minkowski(f64),
    /// Custom metric: takes two points, returns their distance
    Custom(fn(&Array1<f64>, &Array1<f64>) -> f64),
}
57
/// Transport plan approximation methods
#[derive(Debug, Clone)]
/// Strategy used by `transform` to approximate optimal-transport features.
pub enum TransportMethod {
    /// Sinkhorn divergence approximation (entropic, subsampled)
    Sinkhorn,
    /// Sliced Wasserstein using random projections
    SlicedWasserstein,
    /// Tree-Wasserstein using hierarchical decomposition (binned projections)
    TreeWasserstein,
    /// Projection-based approximation (currently shares the sliced
    /// Wasserstein implementation)
    ProjectionBased,
}
71
72impl Default for WassersteinKernelSampler {
73    fn default() -> Self {
74        Self::new(100)
75    }
76}
77
78impl WassersteinKernelSampler {
79    /// Create a new Wasserstein kernel approximation sampler
80    ///
81    /// # Arguments
82    /// * `n_components` - Number of random features to generate
83    ///
84    /// # Examples
85    /// ```
86    /// use sklears_kernel_approximation::WassersteinKernelSampler;
87    /// let sampler = WassersteinKernelSampler::new(100);
88    /// ```
89    pub fn new(n_components: usize) -> Self {
90        Self {
91            n_components,
92            ground_metric: GroundMetric::SquaredEuclidean,
93            epsilon: 0.1,
94            max_iter: 1000,
95            tolerance: 1e-9,
96            random_state: None,
97            projections: None,
98            transport_method: TransportMethod::SlicedWasserstein,
99        }
100    }
101
102    /// Set the ground metric for optimal transport
103    pub fn ground_metric(mut self, metric: GroundMetric) -> Self {
104        self.ground_metric = metric;
105        self
106    }
107
108    /// Set the regularization parameter for Sinkhorn divergence
109    pub fn epsilon(mut self, epsilon: f64) -> Self {
110        self.epsilon = epsilon;
111        self
112    }
113
114    /// Set the maximum number of iterations for Sinkhorn
115    pub fn max_iter(mut self, max_iter: usize) -> Self {
116        self.max_iter = max_iter;
117        self
118    }
119
120    /// Set the convergence tolerance
121    pub fn tolerance(mut self, tolerance: f64) -> Self {
122        self.tolerance = tolerance;
123        self
124    }
125
126    /// Set the random state for reproducibility
127    pub fn random_state(mut self, seed: u64) -> Self {
128        self.random_state = Some(seed);
129        self
130    }
131
132    /// Set the transport approximation method
133    pub fn transport_method(mut self, method: TransportMethod) -> Self {
134        self.transport_method = method;
135        self
136    }
137
138    /// Compute distance using the specified ground metric
139    fn compute_ground_distance(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
140        match &self.ground_metric {
141            GroundMetric::SquaredEuclidean => (x - y).mapv(|v| v * v).sum(),
142            GroundMetric::Euclidean => (x - y).mapv(|v| v * v).sum().sqrt(),
143            GroundMetric::Manhattan => (x - y).mapv(|v| v.abs()).sum(),
144            GroundMetric::Minkowski(p) => (x - y).mapv(|v| v.abs().powf(*p)).sum().powf(1.0 / p),
145            GroundMetric::Custom(func) => func(x, y),
146        }
147    }
148
149    /// Compute sliced Wasserstein distance using random projections
150    fn sliced_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
151        let projections = self
152            .projections
153            .as_ref()
154            .ok_or_else(|| SklearsError::NotFitted {
155                operation: "transform".to_string(),
156            })?;
157
158        let n_samples = x.nrows();
159        let mut features = Array2::zeros((n_samples, self.n_components));
160
161        // Project data onto random directions
162        let projected = x.dot(projections);
163
164        // For each projection, compute cumulative distribution features
165        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
166            let mut sorted_proj: Vec<f64> = proj_col.to_vec();
167            sorted_proj.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
168
169            // Use quantile-based features
170            for (i, &val) in proj_col.iter().enumerate() {
171                // Find quantile of this value in the sorted distribution
172                let quantile = sorted_proj
173                    .binary_search_by(|&probe| {
174                        probe.partial_cmp(&val).expect("operation should succeed")
175                    })
176                    .unwrap_or_else(|e| e) as f64
177                    / sorted_proj.len() as f64;
178
179                features[[i, j]] = quantile;
180            }
181        }
182
183        Ok(features)
184    }
185
186    /// Compute Sinkhorn divergence features
187    fn sinkhorn_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
188        let n_samples = x.nrows();
189        let mut features = Array2::zeros((n_samples, self.n_components));
190
191        // Use random subsampling for Sinkhorn approximation
192        let mut rng = if let Some(seed) = self.random_state {
193            RealStdRng::seed_from_u64(seed)
194        } else {
195            RealStdRng::from_seed(thread_rng().random())
196        };
197
198        for j in 0..self.n_components {
199            // Sample a subset of points for comparison
200            let subset_size = (n_samples as f64).sqrt() as usize + 1;
201            let indices: Vec<usize> = (0..n_samples)
202                .choose_multiple(&mut rng, subset_size)
203                .into_iter()
204                .collect();
205
206            for (i, x_i) in x.axis_iter(Axis(0)).enumerate() {
207                let mut total_divergence = 0.0;
208
209                for &idx in &indices {
210                    if idx != i {
211                        let x_j = x.row(idx);
212                        let distance =
213                            self.compute_ground_distance(&x_i.to_owned(), &x_j.to_owned());
214                        total_divergence += (-distance / self.epsilon).exp();
215                    }
216                }
217
218                features[[i, j]] = total_divergence / indices.len() as f64;
219            }
220        }
221
222        Ok(features)
223    }
224
225    /// Compute tree-Wasserstein features using hierarchical decomposition
226    fn tree_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
227        let n_samples = x.nrows();
228        let mut features = Array2::zeros((n_samples, self.n_components));
229
230        // Use hierarchical clustering approximation
231        let projections = self
232            .projections
233            .as_ref()
234            .ok_or_else(|| SklearsError::NotFitted {
235                operation: "transform".to_string(),
236            })?;
237
238        let projected = x.dot(projections);
239
240        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
241            // Create hierarchical bins
242            let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
243            let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
244            let bin_width = (max_val - min_val) / 10.0; // 10 bins
245
246            for (i, &val) in proj_col.iter().enumerate() {
247                let bin = ((val - min_val) / bin_width).floor() as usize;
248                let bin = bin.min(9); // Ensure we don't exceed bounds
249                features[[i, j]] = bin as f64 / 9.0; // Normalize
250            }
251        }
252
253        Ok(features)
254    }
255}
256
257impl Fit<Array2<f64>, ()> for WassersteinKernelSampler {
258    type Fitted = FittedWassersteinSampler;
259
260    fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
261        let n_features = x.ncols();
262
263        // Generate random projections for sliced Wasserstein
264        let mut rng = if let Some(seed) = self.random_state {
265            RealStdRng::seed_from_u64(seed)
266        } else {
267            RealStdRng::from_seed(thread_rng().random())
268        };
269
270        let uniform = RandUniform::new(-1.0, 1.0).expect("operation should succeed");
271        let mut projections = Array2::zeros((n_features, self.n_components));
272
273        for mut col in projections.axis_iter_mut(Axis(1)) {
274            for elem in col.iter_mut() {
275                *elem = rng.sample(uniform);
276            }
277            // Normalize the projection vector
278            let norm: f64 = col.mapv(|v| v * v).sum();
279            let norm = norm.sqrt();
280            col /= norm;
281        }
282
283        Ok(FittedWassersteinSampler {
284            sampler: WassersteinKernelSampler {
285                projections: Some(projections),
286                ..self.clone()
287            },
288        })
289    }
290}
291
/// Fitted Wasserstein kernel sampler
///
/// Produced by `WassersteinKernelSampler::fit`; wraps the configured sampler
/// with its `projections` field populated.
pub struct FittedWassersteinSampler {
    // Sampler configuration; `projections` is Some(...) after fitting.
    sampler: WassersteinKernelSampler,
}
296
297impl Transform<Array2<f64>, Array2<f64>> for FittedWassersteinSampler {
298    fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
299        match self.sampler.transport_method {
300            TransportMethod::SlicedWasserstein => self.sampler.sliced_wasserstein_features(x),
301            TransportMethod::Sinkhorn => self.sampler.sinkhorn_features(x),
302            TransportMethod::TreeWasserstein => self.sampler.tree_wasserstein_features(x),
303            TransportMethod::ProjectionBased => self.sampler.sliced_wasserstein_features(x),
304        }
305    }
306}
307
/// Earth Mover's Distance (EMD) kernel approximation
///
/// Implements approximations to kernels based on the earth mover's distance,
/// also known as the Wasserstein-1 distance.
#[derive(Debug, Clone)]
/// Configurable sampler; call `fit` to obtain a `FittedEMDSampler`.
pub struct EMDKernelSampler {
    /// Number of random features
    n_components: usize,
    /// Bandwidth parameter for the kernel (default: 1.0)
    bandwidth: f64,
    /// Random state for reproducibility (`None` = seed from thread RNG)
    random_state: Option<u64>,
    /// Fitted random projections, shape (n_features, n_components);
    /// `None` until `fit` is called
    projections: Option<Array2<f64>>,
    /// Number of quantile bins for approximation (default: 20; stored but not
    /// read by the current `transform` implementation)
    n_bins: usize,
}
326
327impl Default for EMDKernelSampler {
328    fn default() -> Self {
329        Self::new(100)
330    }
331}
332
333impl EMDKernelSampler {
334    /// Create a new EMD kernel sampler
335    pub fn new(n_components: usize) -> Self {
336        Self {
337            n_components,
338            bandwidth: 1.0,
339            random_state: None,
340            projections: None,
341            n_bins: 20,
342        }
343    }
344
345    /// Set the bandwidth parameter
346    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
347        self.bandwidth = bandwidth;
348        self
349    }
350
351    /// Set the random state
352    pub fn random_state(mut self, seed: u64) -> Self {
353        self.random_state = Some(seed);
354        self
355    }
356
357    /// Set the number of quantile bins
358    pub fn n_bins(mut self, n_bins: usize) -> Self {
359        self.n_bins = n_bins;
360        self
361    }
362}
363
364impl Fit<Array2<f64>, ()> for EMDKernelSampler {
365    type Fitted = FittedEMDSampler;
366
367    fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
368        let n_features = x.ncols();
369
370        let mut rng = if let Some(seed) = self.random_state {
371            RealStdRng::seed_from_u64(seed)
372        } else {
373            RealStdRng::from_seed(thread_rng().random())
374        };
375
376        let uniform = RandUniform::new(-1.0, 1.0).expect("operation should succeed");
377        let mut projections = Array2::zeros((n_features, self.n_components));
378
379        for mut col in projections.axis_iter_mut(Axis(1)) {
380            for elem in col.iter_mut() {
381                *elem = rng.sample(uniform);
382            }
383            let norm: f64 = col.mapv(|v| v * v).sum();
384            let norm = norm.sqrt();
385            col /= norm;
386        }
387
388        Ok(FittedEMDSampler {
389            sampler: EMDKernelSampler {
390                projections: Some(projections),
391                ..self.clone()
392            },
393        })
394    }
395}
396
/// Fitted EMD kernel sampler
///
/// Produced by `EMDKernelSampler::fit`; wraps the configured sampler with
/// its `projections` field populated.
pub struct FittedEMDSampler {
    // Sampler configuration; `projections` is Some(...) after fitting.
    sampler: EMDKernelSampler,
}
401
402impl Transform<Array2<f64>, Array2<f64>> for FittedEMDSampler {
403    fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
404        let projections =
405            self.sampler
406                .projections
407                .as_ref()
408                .ok_or_else(|| SklearsError::NotFitted {
409                    operation: "transform".to_string(),
410                })?;
411
412        let n_samples = x.nrows();
413        let mut features = Array2::zeros((n_samples, self.sampler.n_components));
414
415        let projected = x.dot(projections);
416
417        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
418            // Compute quantile-based features for EMD approximation
419            let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
420            let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
421            let range = max_val - min_val;
422
423            if range > 1e-10 {
424                for (i, &val) in proj_col.iter().enumerate() {
425                    // Compute cumulative distribution function approximation
426                    let normalized_val = (val - min_val) / range;
427                    let emd_feature = (-normalized_val.abs() / self.sampler.bandwidth).exp();
428                    features[[i, j]] = emd_feature;
429                }
430            }
431        }
432
433        Ok(features)
434    }
435}
436
/// Gromov-Wasserstein kernel approximation
///
/// Implements approximations to Gromov-Wasserstein distances which compare
/// metric measure spaces by their intrinsic geometry.
#[derive(Debug, Clone)]
/// Configurable sampler; call `fit` to obtain a `FittedGromovWassersteinSampler`.
pub struct GromovWassersteinSampler {
    /// Number of random features
    n_components: usize,
    /// Loss function for Gromov-Wasserstein (stored but not used by the
    /// current projection-based approximation)
    loss_function: GWLossFunction,
    /// Number of iterations for optimization (stored; currently unused)
    max_iter: usize,
    /// Random state for reproducible projections
    random_state: Option<u64>,
    /// Fitted parameters; `None` until `fit` is called
    fitted_params: Option<GWFittedParams>,
}
455
/// Loss functions for Gromov-Wasserstein
#[derive(Debug, Clone)]
/// Per-entry loss used when comparing pairwise-distance structures.
pub enum GWLossFunction {
    /// Squared loss
    Square,
    /// KL divergence
    KlLoss,
    /// Custom loss function
    /// NOTE(review): the four arguments' semantics are not exercised anywhere
    /// in this file — confirm against callers before relying on them.
    Custom(fn(f64, f64, f64, f64) -> f64),
}
467
/// Parameters captured during `fit` for the Gromov-Wasserstein sampler.
#[derive(Debug, Clone)]
struct GWFittedParams {
    /// Pairwise Euclidean distances of the training data (reference geometry).
    reference_distances: Array2<f64>,
    /// Random unit projection directions, shape (n_features, n_components).
    projections: Array2<f64>,
}
473
474impl Default for GromovWassersteinSampler {
475    fn default() -> Self {
476        Self::new(50)
477    }
478}
479
480impl GromovWassersteinSampler {
481    /// Create a new Gromov-Wasserstein sampler
482    pub fn new(n_components: usize) -> Self {
483        Self {
484            n_components,
485            loss_function: GWLossFunction::Square,
486            max_iter: 100,
487            random_state: None,
488            fitted_params: None,
489        }
490    }
491
492    /// Set the loss function
493    pub fn loss_function(mut self, loss: GWLossFunction) -> Self {
494        self.loss_function = loss;
495        self
496    }
497
498    /// Set random state
499    pub fn random_state(mut self, seed: u64) -> Self {
500        self.random_state = Some(seed);
501        self
502    }
503
504    /// Compute pairwise distances for metric space comparison
505    fn compute_distance_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
506        let n = x.nrows();
507        let mut distances = Array2::zeros((n, n));
508
509        for i in 0..n {
510            for j in i..n {
511                let diff = &x.row(i).to_owned() - &x.row(j);
512                let dist: f64 = diff.mapv(|v| v * v).sum();
513                let dist = dist.sqrt();
514                distances[[i, j]] = dist;
515                distances[[j, i]] = dist;
516            }
517        }
518
519        distances
520    }
521}
522
523impl Fit<Array2<f64>, ()> for GromovWassersteinSampler {
524    type Fitted = FittedGromovWassersteinSampler;
525
526    fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
527        let n_features = x.ncols();
528        let _n_samples = x.nrows();
529
530        // Compute reference distance matrix
531        let distance_matrix = self.compute_distance_matrix(x);
532
533        // Generate random projections for dimensionality reduction
534        let mut rng = if let Some(seed) = self.random_state {
535            RealStdRng::seed_from_u64(seed)
536        } else {
537            RealStdRng::from_seed(thread_rng().random())
538        };
539
540        let uniform = RandUniform::new(-1.0, 1.0).expect("operation should succeed");
541        let mut projections = Array2::zeros((n_features, self.n_components));
542
543        for mut col in projections.axis_iter_mut(Axis(1)) {
544            for elem in col.iter_mut() {
545                *elem = rng.sample(uniform);
546            }
547            let norm: f64 = col.mapv(|v| v * v).sum();
548            let norm = norm.sqrt();
549            col /= norm;
550        }
551
552        let fitted_params = GWFittedParams {
553            reference_distances: distance_matrix,
554            projections,
555        };
556
557        Ok(FittedGromovWassersteinSampler {
558            sampler: GromovWassersteinSampler {
559                fitted_params: Some(fitted_params),
560                ..self.clone()
561            },
562        })
563    }
564}
565
/// Fitted Gromov-Wasserstein sampler
///
/// Produced by `GromovWassersteinSampler::fit`; wraps the configured sampler
/// with its `fitted_params` field populated.
pub struct FittedGromovWassersteinSampler {
    // Sampler configuration; `fitted_params` is Some(...) after fitting.
    sampler: GromovWassersteinSampler,
}
570
571impl Transform<Array2<f64>, Array2<f64>> for FittedGromovWassersteinSampler {
572    fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
573        let params =
574            self.sampler
575                .fitted_params
576                .as_ref()
577                .ok_or_else(|| SklearsError::NotFitted {
578                    operation: "transform".to_string(),
579                })?;
580
581        let n_samples = x.nrows();
582        let mut features = Array2::zeros((n_samples, self.sampler.n_components));
583
584        // Project data and compute distance-based features
585        let projected = x.dot(&params.projections);
586
587        for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
588            for (i, &val) in proj_col.iter().enumerate() {
589                // Use projected coordinates as distance-based features
590                features[[i, j]] = val.tanh(); // Bounded activation
591            }
592        }
593
594        Ok(features)
595    }
596}
597
#[cfg(test)]
mod tests {
    // The previous `#[allow(non_snake_case)]` was removed: every identifier
    // in this module is already snake_case.
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Transform};

    /// Sliced Wasserstein features have the requested shape and are quantiles.
    #[test]
    fn test_wasserstein_kernel_sampler() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        let sampler = WassersteinKernelSampler::new(10)
            .random_state(42)
            .transport_method(TransportMethod::SlicedWasserstein);

        let fitted = sampler.fit(&x, &()).expect("fit should succeed");
        let features = fitted.transform(&x).expect("transform should succeed");

        assert_eq!(features.shape(), &[4, 10]);

        // Quantile features are bounded in [0, 1].
        assert!(features.iter().all(|f| (0.0..=1.0).contains(f)));
    }

    /// EMD features have the requested shape and are non-negative.
    #[test]
    fn test_emd_kernel_sampler() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let sampler = EMDKernelSampler::new(15).bandwidth(0.5).random_state(123);

        let fitted = sampler.fit(&x, &()).expect("fit should succeed");
        let features = fitted.transform(&x).expect("transform should succeed");

        assert_eq!(features.shape(), &[3, 15]);

        // Exponential features are non-negative.
        assert!(features.iter().all(|&f| f >= 0.0));
    }

    /// Gromov-Wasserstein features are tanh-bounded with the requested shape.
    #[test]
    fn test_gromov_wasserstein_sampler() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let sampler = GromovWassersteinSampler::new(8).random_state(456);

        let fitted = sampler.fit(&x, &()).expect("fit should succeed");
        let features = fitted.transform(&x).expect("transform should succeed");

        assert_eq!(features.shape(), &[4, 8]);

        // tanh output is bounded in [-1, 1].
        assert!(features.iter().all(|f| (-1.0..=1.0).contains(f)));
    }

    /// Every transport method produces features of the expected shape.
    #[test]
    fn test_wasserstein_different_methods() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let methods = [
            TransportMethod::SlicedWasserstein,
            TransportMethod::Sinkhorn,
            TransportMethod::TreeWasserstein,
        ];

        for method in methods {
            let sampler = WassersteinKernelSampler::new(5)
                .transport_method(method)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).expect("fit should succeed");
            let features = fitted.transform(&x).expect("transform should succeed");

            assert_eq!(features.shape(), &[3, 5]);
        }
    }

    /// Every ground metric produces features of the expected shape.
    #[test]
    fn test_different_ground_metrics() {
        let x = array![[1.0, 2.0], [2.0, 3.0]];

        let metrics = [
            GroundMetric::SquaredEuclidean,
            GroundMetric::Euclidean,
            GroundMetric::Manhattan,
            GroundMetric::Minkowski(3.0),
        ];

        for metric in metrics {
            let sampler = WassersteinKernelSampler::new(3)
                .ground_metric(metric)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).expect("fit should succeed");
            let features = fitted.transform(&x).expect("transform should succeed");

            assert_eq!(features.shape(), &[2, 3]);
        }
    }

    /// Equal seeds yield identical projections and therefore identical features.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let sampler1 = WassersteinKernelSampler::new(5).random_state(42);
        let sampler2 = WassersteinKernelSampler::new(5).random_state(42);

        let fitted1 = sampler1.fit(&x, &()).expect("fit should succeed");
        let fitted2 = sampler2.fit(&x, &()).expect("fit should succeed");

        let features1 = fitted1.transform(&x).expect("transform should succeed");
        let features2 = fitted2.transform(&x).expect("transform should succeed");

        assert!((features1 - features2).mapv(f64::abs).sum() < 1e-10);
    }
}