// ruvector_math/optimal_transport/config.rs

/// Configuration shared by the Wasserstein / optimal-transport routines.
///
/// Start from [`WassersteinConfig::default`] (or `WassersteinConfig::new()`),
/// adjust it with the `with_*` builder methods, and call `validate()` to
/// reject out-of-range parameters before use.
#[derive(Debug, Clone)]
pub struct WassersteinConfig {
    /// Number of projections; must be > 0. Default: 100.
    /// NOTE(review): presumably the number of random 1-D directions used by a
    /// sliced-Wasserstein estimate — confirm against the solver.
    pub num_projections: usize,
    /// Regularization strength (`eps` in the builder); must be > 0. Default: 0.1.
    /// NOTE(review): likely the entropic regularization of a Sinkhorn-style
    /// solver — confirm against the solver.
    pub regularization: f64,
    /// Maximum number of solver iterations. Default: 100.
    pub max_iterations: usize,
    /// Convergence threshold. Default: 1e-6.
    pub threshold: f64,
    /// Power `p` of the distance (set via `with_power`); must be > 0. Default: 2.0.
    pub p: f64,
    /// Optional RNG seed. Default: `None`.
    /// NOTE(review): presumably `None` means a non-deterministic seed — confirm.
    pub seed: Option<u64>,
}

20impl Default for WassersteinConfig {
21 fn default() -> Self {
22 Self {
23 num_projections: 100,
24 regularization: 0.1,
25 max_iterations: 100,
26 threshold: 1e-6,
27 p: 2.0,
28 seed: None,
29 }
30 }
31}
32
33impl WassersteinConfig {
34 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn with_projections(mut self, n: usize) -> Self {
41 self.num_projections = n;
42 self
43 }
44
45 pub fn with_regularization(mut self, eps: f64) -> Self {
47 self.regularization = eps;
48 self
49 }
50
51 pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
53 self.max_iterations = max_iter;
54 self
55 }
56
57 pub fn with_threshold(mut self, threshold: f64) -> Self {
59 self.threshold = threshold;
60 self
61 }
62
63 pub fn with_power(mut self, p: f64) -> Self {
65 self.p = p;
66 self
67 }
68
69 pub fn with_seed(mut self, seed: u64) -> Self {
71 self.seed = Some(seed);
72 self
73 }
74
75 pub fn validate(&self) -> crate::Result<()> {
77 if self.num_projections == 0 {
78 return Err(crate::MathError::invalid_parameter(
79 "num_projections",
80 "must be > 0",
81 ));
82 }
83 if self.regularization <= 0.0 {
84 return Err(crate::MathError::invalid_parameter(
85 "regularization",
86 "must be > 0",
87 ));
88 }
89 if self.p <= 0.0 {
90 return Err(crate::MathError::invalid_parameter("p", "must be > 0"));
91 }
92 Ok(())
93 }
94}