1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::rand_prelude::IteratorRandom;
8use scirs2_core::random::essentials::Uniform as RandUniform;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::RngExt;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::{
13 error::{Result as SklResult, SklearsError},
14 prelude::{Fit, Transform},
15};
16
/// Random-feature sampler for Wasserstein-style (optimal transport) kernels.
///
/// Configure it with the builder methods, then call `fit` to draw random projection
/// directions. The fitted sampler's `transform` emits `n_components` features per
/// sample using the selected [`TransportMethod`].
#[derive(Debug, Clone)]
pub struct WassersteinKernelSampler {
    /// Number of output features per sample.
    n_components: usize,
    /// Point-to-point cost; only the `Sinkhorn` feature map reads it.
    ground_metric: GroundMetric,
    /// Temperature of the `exp(-d / epsilon)` kernel used by the `Sinkhorn` feature map.
    epsilon: f64,
    /// NOTE(review): stored and settable, but not read by any feature map in this file.
    max_iter: usize,
    /// NOTE(review): stored and settable, but not read by any feature map in this file.
    tolerance: f64,
    /// Seed for projection sampling and Sinkhorn subset sampling; `None` seeds from the thread RNG.
    random_state: Option<u64>,
    /// `(n_features, n_components)` unit-norm projection directions; `Some` only after `fit`.
    projections: Option<Array2<f64>>,
    /// Feature map selected for `transform`.
    transport_method: TransportMethod,
}
41
/// Point-to-point cost used by [`WassersteinKernelSampler`]'s Sinkhorn-style features.
#[derive(Debug, Clone)]
pub enum GroundMetric {
    /// Sum of squared coordinate differences.
    SquaredEuclidean,
    /// Square root of the sum of squared coordinate differences.
    Euclidean,
    /// Sum of absolute coordinate differences.
    Manhattan,
    /// `(sum_i |x_i - y_i|^p)^(1/p)` for the given `p`.
    /// NOTE(review): `p` is not validated here.
    Minkowski(f64),
    /// User-supplied distance function, called as `func(x, y)`.
    Custom(fn(&Array1<f64>, &Array1<f64>) -> f64),
}
57
/// Feature map used by a fitted [`WassersteinKernelSampler`].
#[derive(Debug, Clone)]
pub enum TransportMethod {
    /// Per component, averages `exp(-d / epsilon)` against a random subset of the input rows.
    /// NOTE(review): no Sinkhorn iterations are actually performed.
    Sinkhorn,
    /// Empirical-CDF rank of each sample along each random projection.
    SlicedWasserstein,
    /// Index of a 10-bin equal-width histogram along each projection, scaled to `[0, 1]`.
    TreeWasserstein,
    /// Dispatched to the same implementation as `SlicedWasserstein`.
    ProjectionBased,
}
71
72impl Default for WassersteinKernelSampler {
73 fn default() -> Self {
74 Self::new(100)
75 }
76}
77
78impl WassersteinKernelSampler {
79 pub fn new(n_components: usize) -> Self {
90 Self {
91 n_components,
92 ground_metric: GroundMetric::SquaredEuclidean,
93 epsilon: 0.1,
94 max_iter: 1000,
95 tolerance: 1e-9,
96 random_state: None,
97 projections: None,
98 transport_method: TransportMethod::SlicedWasserstein,
99 }
100 }
101
102 pub fn ground_metric(mut self, metric: GroundMetric) -> Self {
104 self.ground_metric = metric;
105 self
106 }
107
108 pub fn epsilon(mut self, epsilon: f64) -> Self {
110 self.epsilon = epsilon;
111 self
112 }
113
114 pub fn max_iter(mut self, max_iter: usize) -> Self {
116 self.max_iter = max_iter;
117 self
118 }
119
120 pub fn tolerance(mut self, tolerance: f64) -> Self {
122 self.tolerance = tolerance;
123 self
124 }
125
126 pub fn random_state(mut self, seed: u64) -> Self {
128 self.random_state = Some(seed);
129 self
130 }
131
132 pub fn transport_method(mut self, method: TransportMethod) -> Self {
134 self.transport_method = method;
135 self
136 }
137
138 fn compute_ground_distance(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
140 match &self.ground_metric {
141 GroundMetric::SquaredEuclidean => (x - y).mapv(|v| v * v).sum(),
142 GroundMetric::Euclidean => (x - y).mapv(|v| v * v).sum().sqrt(),
143 GroundMetric::Manhattan => (x - y).mapv(|v| v.abs()).sum(),
144 GroundMetric::Minkowski(p) => (x - y).mapv(|v| v.abs().powf(*p)).sum().powf(1.0 / p),
145 GroundMetric::Custom(func) => func(x, y),
146 }
147 }
148
149 fn sliced_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
151 let projections = self
152 .projections
153 .as_ref()
154 .ok_or_else(|| SklearsError::NotFitted {
155 operation: "transform".to_string(),
156 })?;
157
158 let n_samples = x.nrows();
159 let mut features = Array2::zeros((n_samples, self.n_components));
160
161 let projected = x.dot(projections);
163
164 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
166 let mut sorted_proj: Vec<f64> = proj_col.to_vec();
167 sorted_proj.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
168
169 for (i, &val) in proj_col.iter().enumerate() {
171 let quantile = sorted_proj
173 .binary_search_by(|&probe| {
174 probe.partial_cmp(&val).expect("operation should succeed")
175 })
176 .unwrap_or_else(|e| e) as f64
177 / sorted_proj.len() as f64;
178
179 features[[i, j]] = quantile;
180 }
181 }
182
183 Ok(features)
184 }
185
186 fn sinkhorn_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
188 let n_samples = x.nrows();
189 let mut features = Array2::zeros((n_samples, self.n_components));
190
191 let mut rng = if let Some(seed) = self.random_state {
193 RealStdRng::seed_from_u64(seed)
194 } else {
195 RealStdRng::from_seed(thread_rng().random())
196 };
197
198 for j in 0..self.n_components {
199 let subset_size = (n_samples as f64).sqrt() as usize + 1;
201 let indices: Vec<usize> = (0..n_samples)
202 .choose_multiple(&mut rng, subset_size)
203 .into_iter()
204 .collect();
205
206 for (i, x_i) in x.axis_iter(Axis(0)).enumerate() {
207 let mut total_divergence = 0.0;
208
209 for &idx in &indices {
210 if idx != i {
211 let x_j = x.row(idx);
212 let distance =
213 self.compute_ground_distance(&x_i.to_owned(), &x_j.to_owned());
214 total_divergence += (-distance / self.epsilon).exp();
215 }
216 }
217
218 features[[i, j]] = total_divergence / indices.len() as f64;
219 }
220 }
221
222 Ok(features)
223 }
224
225 fn tree_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
227 let n_samples = x.nrows();
228 let mut features = Array2::zeros((n_samples, self.n_components));
229
230 let projections = self
232 .projections
233 .as_ref()
234 .ok_or_else(|| SklearsError::NotFitted {
235 operation: "transform".to_string(),
236 })?;
237
238 let projected = x.dot(projections);
239
240 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
241 let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
243 let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
244 let bin_width = (max_val - min_val) / 10.0; for (i, &val) in proj_col.iter().enumerate() {
247 let bin = ((val - min_val) / bin_width).floor() as usize;
248 let bin = bin.min(9); features[[i, j]] = bin as f64 / 9.0; }
251 }
252
253 Ok(features)
254 }
255}
256
257impl Fit<Array2<f64>, ()> for WassersteinKernelSampler {
258 type Fitted = FittedWassersteinSampler;
259
260 fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
261 let n_features = x.ncols();
262
263 let mut rng = if let Some(seed) = self.random_state {
265 RealStdRng::seed_from_u64(seed)
266 } else {
267 RealStdRng::from_seed(thread_rng().random())
268 };
269
270 let uniform = RandUniform::new(-1.0, 1.0).expect("operation should succeed");
271 let mut projections = Array2::zeros((n_features, self.n_components));
272
273 for mut col in projections.axis_iter_mut(Axis(1)) {
274 for elem in col.iter_mut() {
275 *elem = rng.sample(uniform);
276 }
277 let norm: f64 = col.mapv(|v| v * v).sum();
279 let norm = norm.sqrt();
280 col /= norm;
281 }
282
283 Ok(FittedWassersteinSampler {
284 sampler: WassersteinKernelSampler {
285 projections: Some(projections),
286 ..self.clone()
287 },
288 })
289 }
290}
291
292pub struct FittedWassersteinSampler {
294 sampler: WassersteinKernelSampler,
295}
296
297impl Transform<Array2<f64>, Array2<f64>> for FittedWassersteinSampler {
298 fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
299 match self.sampler.transport_method {
300 TransportMethod::SlicedWasserstein => self.sampler.sliced_wasserstein_features(x),
301 TransportMethod::Sinkhorn => self.sampler.sinkhorn_features(x),
302 TransportMethod::TreeWasserstein => self.sampler.tree_wasserstein_features(x),
303 TransportMethod::ProjectionBased => self.sampler.sliced_wasserstein_features(x),
304 }
305 }
306}
307
/// Random-projection sampler producing earth-mover's-distance-inspired features.
#[derive(Debug, Clone)]
pub struct EMDKernelSampler {
    /// Number of output features per sample.
    n_components: usize,
    /// Scale in the `exp(-t / bandwidth)` feature, where `t` is the min-max-normalized projection.
    bandwidth: f64,
    /// Seed for projection sampling; `None` seeds from the thread RNG.
    random_state: Option<u64>,
    /// `(n_features, n_components)` unit-norm projection directions; `Some` only after `fit`.
    projections: Option<Array2<f64>>,
    /// NOTE(review): settable via `n_bins`, but not read by `fit` or `transform` in this file.
    n_bins: usize,
}
326
327impl Default for EMDKernelSampler {
328 fn default() -> Self {
329 Self::new(100)
330 }
331}
332
333impl EMDKernelSampler {
334 pub fn new(n_components: usize) -> Self {
336 Self {
337 n_components,
338 bandwidth: 1.0,
339 random_state: None,
340 projections: None,
341 n_bins: 20,
342 }
343 }
344
345 pub fn bandwidth(mut self, bandwidth: f64) -> Self {
347 self.bandwidth = bandwidth;
348 self
349 }
350
351 pub fn random_state(mut self, seed: u64) -> Self {
353 self.random_state = Some(seed);
354 self
355 }
356
357 pub fn n_bins(mut self, n_bins: usize) -> Self {
359 self.n_bins = n_bins;
360 self
361 }
362}
363
364impl Fit<Array2<f64>, ()> for EMDKernelSampler {
365 type Fitted = FittedEMDSampler;
366
367 fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
368 let n_features = x.ncols();
369
370 let mut rng = if let Some(seed) = self.random_state {
371 RealStdRng::seed_from_u64(seed)
372 } else {
373 RealStdRng::from_seed(thread_rng().random())
374 };
375
376 let uniform = RandUniform::new(-1.0, 1.0).expect("operation should succeed");
377 let mut projections = Array2::zeros((n_features, self.n_components));
378
379 for mut col in projections.axis_iter_mut(Axis(1)) {
380 for elem in col.iter_mut() {
381 *elem = rng.sample(uniform);
382 }
383 let norm: f64 = col.mapv(|v| v * v).sum();
384 let norm = norm.sqrt();
385 col /= norm;
386 }
387
388 Ok(FittedEMDSampler {
389 sampler: EMDKernelSampler {
390 projections: Some(projections),
391 ..self.clone()
392 },
393 })
394 }
395}
396
397pub struct FittedEMDSampler {
399 sampler: EMDKernelSampler,
400}
401
402impl Transform<Array2<f64>, Array2<f64>> for FittedEMDSampler {
403 fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
404 let projections =
405 self.sampler
406 .projections
407 .as_ref()
408 .ok_or_else(|| SklearsError::NotFitted {
409 operation: "transform".to_string(),
410 })?;
411
412 let n_samples = x.nrows();
413 let mut features = Array2::zeros((n_samples, self.sampler.n_components));
414
415 let projected = x.dot(projections);
416
417 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
418 let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
420 let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
421 let range = max_val - min_val;
422
423 if range > 1e-10 {
424 for (i, &val) in proj_col.iter().enumerate() {
425 let normalized_val = (val - min_val) / range;
427 let emd_feature = (-normalized_val.abs() / self.sampler.bandwidth).exp();
428 features[[i, j]] = emd_feature;
429 }
430 }
431 }
432
433 Ok(features)
434 }
435}
436
/// Random-projection sampler nominally targeting Gromov-Wasserstein kernels.
///
/// NOTE(review): the `transform` in this file only applies `tanh` to random projections;
/// `loss_function`, `max_iter` and the stored reference distances are not used there.
#[derive(Debug, Clone)]
pub struct GromovWassersteinSampler {
    /// Number of output features per sample.
    n_components: usize,
    /// Loss comparing distance pairs. NOTE(review): not read by any code in this file.
    loss_function: GWLossFunction,
    /// NOTE(review): fixed at 100 by `new`, has no builder setter, and is not read in this file.
    max_iter: usize,
    /// Seed for projection sampling; `None` seeds from the thread RNG.
    random_state: Option<u64>,
    /// Projections and training distances; `Some` only after `fit`.
    fitted_params: Option<GWFittedParams>,
}
455
/// Loss for comparing intra-space distance pairs in a Gromov-Wasserstein objective.
///
/// NOTE(review): no code in this file evaluates the selected loss.
#[derive(Debug, Clone)]
pub enum GWLossFunction {
    /// Squared-difference loss (presumably `(d1 - d2)^2`; not evaluated here).
    Square,
    /// KL-divergence-style loss (not evaluated here).
    KlLoss,
    /// User-supplied loss over four `f64` values; argument meaning is not defined in this file.
    Custom(fn(f64, f64, f64, f64) -> f64),
}
467
/// State captured by `GromovWassersteinSampler::fit`.
#[derive(Debug, Clone)]
struct GWFittedParams {
    /// `(n_samples, n_samples)` pairwise Euclidean distances between training rows.
    /// NOTE(review): not read by `transform`; memory grows quadratically with `n_samples`.
    reference_distances: Array2<f64>,
    /// `(n_features, n_components)` unit-norm random projection directions.
    projections: Array2<f64>,
}
473
474impl Default for GromovWassersteinSampler {
475 fn default() -> Self {
476 Self::new(50)
477 }
478}
479
480impl GromovWassersteinSampler {
481 pub fn new(n_components: usize) -> Self {
483 Self {
484 n_components,
485 loss_function: GWLossFunction::Square,
486 max_iter: 100,
487 random_state: None,
488 fitted_params: None,
489 }
490 }
491
492 pub fn loss_function(mut self, loss: GWLossFunction) -> Self {
494 self.loss_function = loss;
495 self
496 }
497
498 pub fn random_state(mut self, seed: u64) -> Self {
500 self.random_state = Some(seed);
501 self
502 }
503
504 fn compute_distance_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
506 let n = x.nrows();
507 let mut distances = Array2::zeros((n, n));
508
509 for i in 0..n {
510 for j in i..n {
511 let diff = &x.row(i).to_owned() - &x.row(j);
512 let dist: f64 = diff.mapv(|v| v * v).sum();
513 let dist = dist.sqrt();
514 distances[[i, j]] = dist;
515 distances[[j, i]] = dist;
516 }
517 }
518
519 distances
520 }
521}
522
523impl Fit<Array2<f64>, ()> for GromovWassersteinSampler {
524 type Fitted = FittedGromovWassersteinSampler;
525
526 fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
527 let n_features = x.ncols();
528 let _n_samples = x.nrows();
529
530 let distance_matrix = self.compute_distance_matrix(x);
532
533 let mut rng = if let Some(seed) = self.random_state {
535 RealStdRng::seed_from_u64(seed)
536 } else {
537 RealStdRng::from_seed(thread_rng().random())
538 };
539
540 let uniform = RandUniform::new(-1.0, 1.0).expect("operation should succeed");
541 let mut projections = Array2::zeros((n_features, self.n_components));
542
543 for mut col in projections.axis_iter_mut(Axis(1)) {
544 for elem in col.iter_mut() {
545 *elem = rng.sample(uniform);
546 }
547 let norm: f64 = col.mapv(|v| v * v).sum();
548 let norm = norm.sqrt();
549 col /= norm;
550 }
551
552 let fitted_params = GWFittedParams {
553 reference_distances: distance_matrix,
554 projections,
555 };
556
557 Ok(FittedGromovWassersteinSampler {
558 sampler: GromovWassersteinSampler {
559 fitted_params: Some(fitted_params),
560 ..self.clone()
561 },
562 })
563 }
564}
565
566pub struct FittedGromovWassersteinSampler {
568 sampler: GromovWassersteinSampler,
569}
570
571impl Transform<Array2<f64>, Array2<f64>> for FittedGromovWassersteinSampler {
572 fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
573 let params =
574 self.sampler
575 .fitted_params
576 .as_ref()
577 .ok_or_else(|| SklearsError::NotFitted {
578 operation: "transform".to_string(),
579 })?;
580
581 let n_samples = x.nrows();
582 let mut features = Array2::zeros((n_samples, self.sampler.n_components));
583
584 let projected = x.dot(¶ms.projections);
586
587 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
588 for (i, &val) in proj_col.iter().enumerate() {
589 features[[i, j]] = val.tanh(); }
592 }
593
594 Ok(features)
595 }
596}
597
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Transform};

    #[test]
    fn test_wasserstein_kernel_sampler() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];
        let sampler = WassersteinKernelSampler::new(10)
            .random_state(42)
            .transport_method(TransportMethod::SlicedWasserstein);

        let out = sampler
            .fit(&data, &())
            .expect("operation should succeed")
            .transform(&data)
            .expect("operation should succeed");

        assert_eq!(out.shape(), &[4, 10]);
        // Rank-based features are fractions of the batch size.
        assert!(out.iter().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn test_emd_kernel_sampler() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];
        let out = EMDKernelSampler::new(15)
            .bandwidth(0.5)
            .random_state(123)
            .fit(&data, &())
            .expect("operation should succeed")
            .transform(&data)
            .expect("operation should succeed");

        assert_eq!(out.shape(), &[3, 15]);
        // Exponentials of non-positive arguments are non-negative.
        assert!(out.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_gromov_wasserstein_sampler() {
        let data = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0],];
        let out = GromovWassersteinSampler::new(8)
            .random_state(456)
            .fit(&data, &())
            .expect("operation should succeed")
            .transform(&data)
            .expect("operation should succeed");

        assert_eq!(out.shape(), &[4, 8]);
        // tanh output is bounded by +/-1.
        assert!(out.iter().all(|v| (-1.0..=1.0).contains(v)));
    }

    #[test]
    fn test_wasserstein_different_methods() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        for method in [
            TransportMethod::SlicedWasserstein,
            TransportMethod::Sinkhorn,
            TransportMethod::TreeWasserstein,
        ] {
            let out = WassersteinKernelSampler::new(5)
                .transport_method(method)
                .random_state(42)
                .fit(&data, &())
                .expect("operation should succeed")
                .transform(&data)
                .expect("operation should succeed");

            assert_eq!(out.shape(), &[3, 5]);
        }
    }

    #[test]
    fn test_different_ground_metrics() {
        let data = array![[1.0, 2.0], [2.0, 3.0],];

        for metric in [
            GroundMetric::SquaredEuclidean,
            GroundMetric::Euclidean,
            GroundMetric::Manhattan,
            GroundMetric::Minkowski(3.0),
        ] {
            let out = WassersteinKernelSampler::new(3)
                .ground_metric(metric)
                .random_state(42)
                .fit(&data, &())
                .expect("operation should succeed")
                .transform(&data)
                .expect("operation should succeed");

            assert_eq!(out.shape(), &[2, 3]);
        }
    }

    #[test]
    fn test_reproducibility() {
        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        // Same seed, same configuration => identical features.
        let run = |seed: u64| {
            WassersteinKernelSampler::new(5)
                .random_state(seed)
                .fit(&data, &())
                .expect("operation should succeed")
                .transform(&data)
                .expect("operation should succeed")
        };

        let first = run(42);
        let second = run(42);
        let diff = &first - &second;

        assert!(diff.mapv(f64::abs).sum() < 1e-10);
    }
}