// ruvector_math/optimal_transport/config.rs

//! Configuration for optimal transport algorithms

/// Parameters controlling how Wasserstein distances are computed.
///
/// Start from [`Default`] (or `WassersteinConfig::new`) and adjust
/// individual settings with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct WassersteinConfig {
    /// How many random 1-D projections the Sliced Wasserstein estimator uses.
    pub num_projections: usize,
    /// Regularization parameter (epsilon) for the Sinkhorn solver.
    pub regularization: f64,
    /// Upper bound on the number of Sinkhorn iterations.
    pub max_iterations: usize,
    /// Convergence threshold for the Sinkhorn iterations.
    pub threshold: f64,
    /// Exponent p of the Wasserstein-p distance.
    pub p: f64,
    /// Optional random seed; `Some` requests reproducible results.
    pub seed: Option<u64>,
}
20impl Default for WassersteinConfig {
21    fn default() -> Self {
22        Self {
23            num_projections: 100,
24            regularization: 0.1,
25            max_iterations: 100,
26            threshold: 1e-6,
27            p: 2.0,
28            seed: None,
29        }
30    }
31}
33impl WassersteinConfig {
34    /// Create a new configuration with default values
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Set the number of random projections
40    pub fn with_projections(mut self, n: usize) -> Self {
41        self.num_projections = n;
42        self
43    }
44
45    /// Set the regularization parameter
46    pub fn with_regularization(mut self, eps: f64) -> Self {
47        self.regularization = eps;
48        self
49    }
50
51    /// Set the maximum iterations
52    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
53        self.max_iterations = max_iter;
54        self
55    }
56
57    /// Set the convergence threshold
58    pub fn with_threshold(mut self, threshold: f64) -> Self {
59        self.threshold = threshold;
60        self
61    }
62
63    /// Set the Wasserstein power
64    pub fn with_power(mut self, p: f64) -> Self {
65        self.p = p;
66        self
67    }
68
69    /// Set the random seed
70    pub fn with_seed(mut self, seed: u64) -> Self {
71        self.seed = Some(seed);
72        self
73    }
74
75    /// Validate the configuration
76    pub fn validate(&self) -> crate::Result<()> {
77        if self.num_projections == 0 {
78            return Err(crate::MathError::invalid_parameter(
79                "num_projections",
80                "must be > 0",
81            ));
82        }
83        if self.regularization <= 0.0 {
84            return Err(crate::MathError::invalid_parameter(
85                "regularization",
86                "must be > 0",
87            ));
88        }
89        if self.p <= 0.0 {
90            return Err(crate::MathError::invalid_parameter("p", "must be > 0"));
91        }
92        Ok(())
93    }
94}