1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_core::rand_prelude::IteratorRandom;
8use scirs2_core::random::essentials::Uniform as RandUniform;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::{thread_rng, SeedableRng};
12use sklears_core::{
13 error::{Result as SklResult, SklearsError},
14 prelude::{Fit, Transform},
15};
16
/// Configuration (and, after fitting, learned projections) for a
/// Wasserstein-inspired random feature map.
///
/// Instances are created with [`WassersteinKernelSampler::new`] and customised
/// through the chained setter methods.
#[derive(Debug, Clone)]
pub struct WassersteinKernelSampler {
    /// Number of random projection directions, and of output feature columns.
    n_components: usize,
    /// Point-to-point distance used by the Sinkhorn feature path.
    ground_metric: GroundMetric,
    /// Divisor applied to ground distances inside `exp(-distance / epsilon)`
    /// on the Sinkhorn feature path.
    epsilon: f64,
    /// Iteration cap set through the builder.
    /// NOTE(review): not read by any code visible in this file.
    max_iter: usize,
    /// Convergence tolerance set through the builder.
    /// NOTE(review): not read by any code visible in this file.
    tolerance: f64,
    /// Optional seed; when `None` a seed is drawn from the thread-local RNG.
    random_state: Option<u64>,
    /// `(n_features, n_components)` projection matrix; `None` until `fit` runs.
    projections: Option<Array2<f64>>,
    /// Which feature construction `transform` dispatches to.
    transport_method: TransportMethod,
}
41
/// Distance between two points, as evaluated by
/// `WassersteinKernelSampler::compute_ground_distance`.
#[derive(Debug, Clone)]
pub enum GroundMetric {
    /// Sum of squared coordinate differences.
    SquaredEuclidean,
    /// Square root of the sum of squared coordinate differences.
    Euclidean,
    /// Sum of absolute coordinate differences.
    Manhattan,
    /// `(sum |x_i - y_i|^p)^(1/p)` for the carried exponent `p`.
    Minkowski(f64),
    /// Caller-supplied distance function applied to the two points.
    Custom(fn(&Array1<f64>, &Array1<f64>) -> f64),
}
57
/// Selects which feature construction `FittedWassersteinSampler::transform`
/// runs.
///
/// The enum carries no data, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransportMethod {
    /// Kernel-style features built from ground distances to a random subset
    /// of the input rows.
    Sinkhorn,
    /// Rank (empirical quantile) of each sample along every random projection.
    SlicedWasserstein,
    /// Equal-width bin index of each sample along every random projection.
    TreeWasserstein,
    /// Dispatched to the same routine as [`TransportMethod::SlicedWasserstein`].
    ProjectionBased,
}
71
72impl Default for WassersteinKernelSampler {
73 fn default() -> Self {
74 Self::new(100)
75 }
76}
77
78impl WassersteinKernelSampler {
79 pub fn new(n_components: usize) -> Self {
90 Self {
91 n_components,
92 ground_metric: GroundMetric::SquaredEuclidean,
93 epsilon: 0.1,
94 max_iter: 1000,
95 tolerance: 1e-9,
96 random_state: None,
97 projections: None,
98 transport_method: TransportMethod::SlicedWasserstein,
99 }
100 }
101
102 pub fn ground_metric(mut self, metric: GroundMetric) -> Self {
104 self.ground_metric = metric;
105 self
106 }
107
108 pub fn epsilon(mut self, epsilon: f64) -> Self {
110 self.epsilon = epsilon;
111 self
112 }
113
114 pub fn max_iter(mut self, max_iter: usize) -> Self {
116 self.max_iter = max_iter;
117 self
118 }
119
120 pub fn tolerance(mut self, tolerance: f64) -> Self {
122 self.tolerance = tolerance;
123 self
124 }
125
126 pub fn random_state(mut self, seed: u64) -> Self {
128 self.random_state = Some(seed);
129 self
130 }
131
132 pub fn transport_method(mut self, method: TransportMethod) -> Self {
134 self.transport_method = method;
135 self
136 }
137
138 fn compute_ground_distance(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
140 match &self.ground_metric {
141 GroundMetric::SquaredEuclidean => (x - y).mapv(|v| v * v).sum(),
142 GroundMetric::Euclidean => (x - y).mapv(|v| v * v).sum().sqrt(),
143 GroundMetric::Manhattan => (x - y).mapv(|v| v.abs()).sum(),
144 GroundMetric::Minkowski(p) => (x - y).mapv(|v| v.abs().powf(*p)).sum().powf(1.0 / p),
145 GroundMetric::Custom(func) => func(x, y),
146 }
147 }
148
149 fn sliced_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
151 let projections = self
152 .projections
153 .as_ref()
154 .ok_or_else(|| SklearsError::NotFitted {
155 operation: "transform".to_string(),
156 })?;
157
158 let n_samples = x.nrows();
159 let mut features = Array2::zeros((n_samples, self.n_components));
160
161 let projected = x.dot(projections);
163
164 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
166 let mut sorted_proj: Vec<f64> = proj_col.to_vec();
167 sorted_proj.sort_by(|a, b| a.partial_cmp(b).unwrap());
168
169 for (i, &val) in proj_col.iter().enumerate() {
171 let quantile = sorted_proj
173 .binary_search_by(|&probe| probe.partial_cmp(&val).unwrap())
174 .unwrap_or_else(|e| e) as f64
175 / sorted_proj.len() as f64;
176
177 features[[i, j]] = quantile;
178 }
179 }
180
181 Ok(features)
182 }
183
184 fn sinkhorn_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
186 let n_samples = x.nrows();
187 let mut features = Array2::zeros((n_samples, self.n_components));
188
189 let mut rng = if let Some(seed) = self.random_state {
191 RealStdRng::seed_from_u64(seed)
192 } else {
193 RealStdRng::from_seed(thread_rng().gen())
194 };
195
196 for j in 0..self.n_components {
197 let subset_size = (n_samples as f64).sqrt() as usize + 1;
199 let indices: Vec<usize> = (0..n_samples)
200 .choose_multiple(&mut rng, subset_size)
201 .into_iter()
202 .collect();
203
204 for (i, x_i) in x.axis_iter(Axis(0)).enumerate() {
205 let mut total_divergence = 0.0;
206
207 for &idx in &indices {
208 if idx != i {
209 let x_j = x.row(idx);
210 let distance =
211 self.compute_ground_distance(&x_i.to_owned(), &x_j.to_owned());
212 total_divergence += (-distance / self.epsilon).exp();
213 }
214 }
215
216 features[[i, j]] = total_divergence / indices.len() as f64;
217 }
218 }
219
220 Ok(features)
221 }
222
223 fn tree_wasserstein_features(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
225 let n_samples = x.nrows();
226 let mut features = Array2::zeros((n_samples, self.n_components));
227
228 let projections = self
230 .projections
231 .as_ref()
232 .ok_or_else(|| SklearsError::NotFitted {
233 operation: "transform".to_string(),
234 })?;
235
236 let projected = x.dot(projections);
237
238 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
239 let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
241 let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
242 let bin_width = (max_val - min_val) / 10.0; for (i, &val) in proj_col.iter().enumerate() {
245 let bin = ((val - min_val) / bin_width).floor() as usize;
246 let bin = bin.min(9); features[[i, j]] = bin as f64 / 9.0; }
249 }
250
251 Ok(features)
252 }
253}
254
255impl Fit<Array2<f64>, ()> for WassersteinKernelSampler {
256 type Fitted = FittedWassersteinSampler;
257
258 fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
259 let n_features = x.ncols();
260
261 let mut rng = if let Some(seed) = self.random_state {
263 RealStdRng::seed_from_u64(seed)
264 } else {
265 RealStdRng::from_seed(thread_rng().gen())
266 };
267
268 let uniform = RandUniform::new(-1.0, 1.0).unwrap();
269 let mut projections = Array2::zeros((n_features, self.n_components));
270
271 for mut col in projections.axis_iter_mut(Axis(1)) {
272 for elem in col.iter_mut() {
273 *elem = rng.sample(uniform);
274 }
275 let norm: f64 = col.mapv(|v| v * v).sum();
277 let norm = norm.sqrt();
278 col /= norm;
279 }
280
281 Ok(FittedWassersteinSampler {
282 sampler: WassersteinKernelSampler {
283 projections: Some(projections),
284 ..self.clone()
285 },
286 })
287 }
288}
289
290pub struct FittedWassersteinSampler {
292 sampler: WassersteinKernelSampler,
293}
294
295impl Transform<Array2<f64>, Array2<f64>> for FittedWassersteinSampler {
296 fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
297 match self.sampler.transport_method {
298 TransportMethod::SlicedWasserstein => self.sampler.sliced_wasserstein_features(x),
299 TransportMethod::Sinkhorn => self.sampler.sinkhorn_features(x),
300 TransportMethod::TreeWasserstein => self.sampler.tree_wasserstein_features(x),
301 TransportMethod::ProjectionBased => self.sampler.sliced_wasserstein_features(x),
302 }
303 }
304}
305
/// Configuration (and, after fitting, learned projections) for a feature map
/// built from range-normalised random projections.
#[derive(Debug, Clone)]
pub struct EMDKernelSampler {
    /// Number of random projection directions, and of output feature columns.
    n_components: usize,
    /// Divisor inside `exp(-normalised_value / bandwidth)` during `transform`.
    bandwidth: f64,
    /// Optional seed; when `None` a seed is drawn from the thread-local RNG.
    random_state: Option<u64>,
    /// `(n_features, n_components)` projection matrix; `None` until `fit` runs.
    projections: Option<Array2<f64>>,
    /// Bin count set through the builder (default 20).
    /// NOTE(review): not read by any code visible in this file.
    n_bins: usize,
}
324
325impl Default for EMDKernelSampler {
326 fn default() -> Self {
327 Self::new(100)
328 }
329}
330
331impl EMDKernelSampler {
332 pub fn new(n_components: usize) -> Self {
334 Self {
335 n_components,
336 bandwidth: 1.0,
337 random_state: None,
338 projections: None,
339 n_bins: 20,
340 }
341 }
342
343 pub fn bandwidth(mut self, bandwidth: f64) -> Self {
345 self.bandwidth = bandwidth;
346 self
347 }
348
349 pub fn random_state(mut self, seed: u64) -> Self {
351 self.random_state = Some(seed);
352 self
353 }
354
355 pub fn n_bins(mut self, n_bins: usize) -> Self {
357 self.n_bins = n_bins;
358 self
359 }
360}
361
362impl Fit<Array2<f64>, ()> for EMDKernelSampler {
363 type Fitted = FittedEMDSampler;
364
365 fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
366 let n_features = x.ncols();
367
368 let mut rng = if let Some(seed) = self.random_state {
369 RealStdRng::seed_from_u64(seed)
370 } else {
371 RealStdRng::from_seed(thread_rng().gen())
372 };
373
374 let uniform = RandUniform::new(-1.0, 1.0).unwrap();
375 let mut projections = Array2::zeros((n_features, self.n_components));
376
377 for mut col in projections.axis_iter_mut(Axis(1)) {
378 for elem in col.iter_mut() {
379 *elem = rng.sample(uniform);
380 }
381 let norm: f64 = col.mapv(|v| v * v).sum();
382 let norm = norm.sqrt();
383 col /= norm;
384 }
385
386 Ok(FittedEMDSampler {
387 sampler: EMDKernelSampler {
388 projections: Some(projections),
389 ..self.clone()
390 },
391 })
392 }
393}
394
395pub struct FittedEMDSampler {
397 sampler: EMDKernelSampler,
398}
399
400impl Transform<Array2<f64>, Array2<f64>> for FittedEMDSampler {
401 fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
402 let projections =
403 self.sampler
404 .projections
405 .as_ref()
406 .ok_or_else(|| SklearsError::NotFitted {
407 operation: "transform".to_string(),
408 })?;
409
410 let n_samples = x.nrows();
411 let mut features = Array2::zeros((n_samples, self.sampler.n_components));
412
413 let projected = x.dot(projections);
414
415 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
416 let min_val = proj_col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
418 let max_val = proj_col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
419 let range = max_val - min_val;
420
421 if range > 1e-10 {
422 for (i, &val) in proj_col.iter().enumerate() {
423 let normalized_val = (val - min_val) / range;
425 let emd_feature = (-normalized_val.abs() / self.sampler.bandwidth).exp();
426 features[[i, j]] = emd_feature;
427 }
428 }
429 }
430
431 Ok(features)
432 }
433}
434
/// Configuration (and, after fitting, learned parameters) for a feature map
/// built from `tanh`-squashed random projections.
#[derive(Debug, Clone)]
pub struct GromovWassersteinSampler {
    /// Number of random projection directions, and of output feature columns.
    n_components: usize,
    /// Loss selected through the builder.
    /// NOTE(review): not read by any code visible in this file.
    loss_function: GWLossFunction,
    /// Iteration cap (default 100).
    /// NOTE(review): not read by any code visible in this file.
    max_iter: usize,
    /// Optional seed; when `None` a seed is drawn from the thread-local RNG.
    random_state: Option<u64>,
    /// Parameters produced by `fit`; `None` until then.
    fitted_params: Option<GWFittedParams>,
}
453
/// Loss choice stored on a [`GromovWassersteinSampler`].
///
/// NOTE(review): no code visible in this file evaluates these variants, so
/// their exact formulas cannot be confirmed from here.
#[derive(Debug, Clone)]
pub enum GWLossFunction {
    /// Presumably a squared-difference loss — TODO confirm.
    Square,
    /// Presumably a Kullback-Leibler style loss — TODO confirm.
    KlLoss,
    /// Caller-supplied function of four `f64` arguments returning an `f64`.
    Custom(fn(f64, f64, f64, f64) -> f64),
}
465
/// Values computed by `GromovWassersteinSampler::fit`.
#[derive(Debug, Clone)]
struct GWFittedParams {
    /// Pairwise Euclidean distances between the rows passed to `fit`.
    /// NOTE(review): stored but not read by any code visible in this file.
    reference_distances: Array2<f64>,
    /// `(n_features, n_components)` projection matrix used by `transform`.
    projections: Array2<f64>,
}
471
472impl Default for GromovWassersteinSampler {
473 fn default() -> Self {
474 Self::new(50)
475 }
476}
477
478impl GromovWassersteinSampler {
479 pub fn new(n_components: usize) -> Self {
481 Self {
482 n_components,
483 loss_function: GWLossFunction::Square,
484 max_iter: 100,
485 random_state: None,
486 fitted_params: None,
487 }
488 }
489
490 pub fn loss_function(mut self, loss: GWLossFunction) -> Self {
492 self.loss_function = loss;
493 self
494 }
495
496 pub fn random_state(mut self, seed: u64) -> Self {
498 self.random_state = Some(seed);
499 self
500 }
501
502 fn compute_distance_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
504 let n = x.nrows();
505 let mut distances = Array2::zeros((n, n));
506
507 for i in 0..n {
508 for j in i..n {
509 let diff = &x.row(i).to_owned() - &x.row(j);
510 let dist: f64 = diff.mapv(|v| v * v).sum();
511 let dist = dist.sqrt();
512 distances[[i, j]] = dist;
513 distances[[j, i]] = dist;
514 }
515 }
516
517 distances
518 }
519}
520
521impl Fit<Array2<f64>, ()> for GromovWassersteinSampler {
522 type Fitted = FittedGromovWassersteinSampler;
523
524 fn fit(self, x: &Array2<f64>, _y: &()) -> SklResult<Self::Fitted> {
525 let n_features = x.ncols();
526 let _n_samples = x.nrows();
527
528 let distance_matrix = self.compute_distance_matrix(x);
530
531 let mut rng = if let Some(seed) = self.random_state {
533 RealStdRng::seed_from_u64(seed)
534 } else {
535 RealStdRng::from_seed(thread_rng().gen())
536 };
537
538 let uniform = RandUniform::new(-1.0, 1.0).unwrap();
539 let mut projections = Array2::zeros((n_features, self.n_components));
540
541 for mut col in projections.axis_iter_mut(Axis(1)) {
542 for elem in col.iter_mut() {
543 *elem = rng.sample(uniform);
544 }
545 let norm: f64 = col.mapv(|v| v * v).sum();
546 let norm = norm.sqrt();
547 col /= norm;
548 }
549
550 let fitted_params = GWFittedParams {
551 reference_distances: distance_matrix,
552 projections,
553 };
554
555 Ok(FittedGromovWassersteinSampler {
556 sampler: GromovWassersteinSampler {
557 fitted_params: Some(fitted_params),
558 ..self.clone()
559 },
560 })
561 }
562}
563
564pub struct FittedGromovWassersteinSampler {
566 sampler: GromovWassersteinSampler,
567}
568
569impl Transform<Array2<f64>, Array2<f64>> for FittedGromovWassersteinSampler {
570 fn transform(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
571 let params =
572 self.sampler
573 .fitted_params
574 .as_ref()
575 .ok_or_else(|| SklearsError::NotFitted {
576 operation: "transform".to_string(),
577 })?;
578
579 let n_samples = x.nrows();
580 let mut features = Array2::zeros((n_samples, self.sampler.n_components));
581
582 let projected = x.dot(¶ms.projections);
584
585 for (j, proj_col) in projected.axis_iter(Axis(1)).enumerate() {
586 for (i, &val) in proj_col.iter().enumerate() {
587 features[[i, j]] = val.tanh(); }
590 }
591
592 Ok(features)
593 }
594}
595
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Transform};

    /// Sliced features are empirical quantiles, so they must lie in `[0, 1]`.
    #[test]
    fn test_wasserstein_kernel_sampler() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];

        let sampler = WassersteinKernelSampler::new(10)
            .random_state(42)
            .transport_method(TransportMethod::SlicedWasserstein);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[4, 10]);

        assert!(features.iter().all(|&f| f >= 0.0 && f <= 1.0));
    }

    /// EMD features are exponentials (or zero), hence non-negative.
    #[test]
    fn test_emd_kernel_sampler() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let sampler = EMDKernelSampler::new(15).bandwidth(0.5).random_state(123);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 15]);

        assert!(features.iter().all(|&f| f >= 0.0));
    }

    /// Gromov-Wasserstein features are `tanh` outputs, hence in `[-1, 1]`.
    #[test]
    fn test_gromov_wasserstein_sampler() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0],];

        let sampler = GromovWassersteinSampler::new(8).random_state(456);

        let fitted = sampler.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[4, 8]);

        assert!(features.iter().all(|&f| f >= -1.0 && f <= 1.0));
    }

    /// Every transport method, including `ProjectionBased`, yields a finite
    /// feature matrix of the requested shape.
    #[test]
    fn test_wasserstein_different_methods() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let methods = vec![
            TransportMethod::SlicedWasserstein,
            TransportMethod::Sinkhorn,
            TransportMethod::TreeWasserstein,
            TransportMethod::ProjectionBased,
        ];

        for method in methods {
            let sampler = WassersteinKernelSampler::new(5)
                .transport_method(method)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[3, 5]);
            assert!(features.iter().all(|f| f.is_finite()));
        }
    }

    /// Only the Sinkhorn path evaluates the ground metric, so it is selected
    /// explicitly; with the default sliced method the metric is never used.
    #[test]
    fn test_different_ground_metrics() {
        let x = array![[1.0, 2.0], [2.0, 3.0],];

        let metrics = vec![
            GroundMetric::SquaredEuclidean,
            GroundMetric::Euclidean,
            GroundMetric::Manhattan,
            GroundMetric::Minkowski(3.0),
        ];

        for metric in metrics {
            let sampler = WassersteinKernelSampler::new(3)
                .ground_metric(metric)
                .transport_method(TransportMethod::Sinkhorn)
                .random_state(42);

            let fitted = sampler.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.shape(), &[2, 3]);
            // Each entry is an average of `exp(-d / epsilon)` terms with d >= 0.
            assert!(features
                .iter()
                .all(|&f| f.is_finite() && (0.0..=1.0).contains(&f)));
        }
    }

    /// Two samplers built with the same seed produce identical features.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let sampler1 = WassersteinKernelSampler::new(5).random_state(42);
        let sampler2 = WassersteinKernelSampler::new(5).random_state(42);

        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let fitted2 = sampler2.fit(&x, &()).unwrap();

        let features1 = fitted1.transform(&x).unwrap();
        let features2 = fitted2.transform(&x).unwrap();

        assert!((features1 - features2).mapv(|v| v.abs()).sum() < 1e-10);
    }
}