Skip to main content

scirs2_optimize/dro/
types.rs

1//! Types for Distributionally Robust Optimization (DRO).
2//!
3//! Provides configuration structs, result types, and core abstractions for
4//! Wasserstein-ball and CVaR-based DRO.
5
6use crate::error::{OptimizeError, OptimizeResult};
7
8// ---------------------------------------------------------------------------
9// Configuration
10// ---------------------------------------------------------------------------
11
/// Solver settings for distributionally robust optimization.
///
/// Bundles the Wasserstein ambiguity radius, the empirical sample count, and
/// the stopping / step-size parameters consumed by the iterative solver.
#[derive(Debug, Clone)]
pub struct DroConfig {
    /// Radius ε of the Wasserstein ball (ε ≥ 0).  The larger the radius, the
    /// more conservative the resulting solution.
    pub radius: f64,
    /// How many empirical samples make up the reference distribution P_N.
    pub n_samples: usize,
    /// Upper bound on the number of outer solver iterations.
    pub max_iter: usize,
    /// Gradient-norm threshold used as the convergence tolerance.
    pub tol: f64,
    /// Fixed (sub)gradient step size.  `None` selects the adaptive
    /// 1/√t schedule instead.
    pub step_size: Option<f64>,
}
29
30impl Default for DroConfig {
31    fn default() -> Self {
32        Self {
33            radius: 0.1,
34            n_samples: 100,
35            max_iter: 500,
36            tol: 1e-6,
37            step_size: None,
38        }
39    }
40}
41
42impl DroConfig {
43    /// Validate configuration parameters.
44    pub fn validate(&self) -> OptimizeResult<()> {
45        if self.radius < 0.0 {
46            return Err(OptimizeError::InvalidParameter(
47                "radius must be non-negative".into(),
48            ));
49        }
50        if self.n_samples == 0 {
51            return Err(OptimizeError::InvalidParameter(
52                "n_samples must be positive".into(),
53            ));
54        }
55        if self.max_iter == 0 {
56            return Err(OptimizeError::InvalidParameter(
57                "max_iter must be positive".into(),
58            ));
59        }
60        if self.tol <= 0.0 {
61            return Err(OptimizeError::InvalidParameter(
62                "tol must be positive".into(),
63            ));
64        }
65        Ok(())
66    }
67}
68
69// ---------------------------------------------------------------------------
70// Result types
71// ---------------------------------------------------------------------------
72
/// Outcome reported by a distributionally robust optimization run.
#[derive(Debug, Clone)]
pub struct DroResult {
    /// Decision-variable weights found by the solver (for instance,
    /// portfolio weights).
    pub optimal_weights: Vec<f64>,
    /// Expected loss in the worst case over the Wasserstein-ball ambiguity set.
    pub worst_case_loss: f64,
    /// Value of the primal objective evaluated at `optimal_weights`.
    pub primal_obj: f64,
    /// Count of iterations the solver performed.
    pub n_iter: usize,
    /// `true` when the solver reached the requested tolerance.
    pub converged: bool,
}
87
88// ---------------------------------------------------------------------------
89// Wasserstein ball description
90// ---------------------------------------------------------------------------
91
/// A Wasserstein-1 ball centred on an empirical distribution.
///
/// Represents B_ε(P_N) = {Q : W_1(Q, P_N) ≤ ε}, i.e. every probability
/// measure whose Wasserstein distance to the empirical distribution P_N is
/// at most ε.
#[derive(Debug, Clone)]
pub struct WassersteinBall {
    /// The samples {x_1, …, x_N} that make up the empirical distribution P_N.
    pub center_samples: Vec<Vec<f64>>,
    /// Radius ε of the ball (ε ≥ 0).
    pub radius: f64,
}
103
104impl WassersteinBall {
105    /// Create a new Wasserstein ball.
106    ///
107    /// Returns an error when `radius < 0` or `center_samples` is empty.
108    pub fn new(center_samples: Vec<Vec<f64>>, radius: f64) -> OptimizeResult<Self> {
109        if radius < 0.0 {
110            return Err(OptimizeError::InvalidParameter(
111                "Wasserstein ball radius must be non-negative".into(),
112            ));
113        }
114        if center_samples.is_empty() {
115            return Err(OptimizeError::InvalidParameter(
116                "center_samples must be non-empty".into(),
117            ));
118        }
119        Ok(Self {
120            center_samples,
121            radius,
122        })
123    }
124
125    /// Wasserstein-1 distance from the empirical centre to a single point `q`.
126    ///
127    /// For a discrete empirical distribution P_N the W_1 distance to the
128    /// Dirac mass δ_q is min_i ‖x_i − q‖_2 (the nearest-centre distance).
129    pub fn distance_to_point(&self, q: &[f64]) -> f64 {
130        self.center_samples
131            .iter()
132            .map(|c| {
133                c.iter()
134                    .zip(q.iter())
135                    .map(|(a, b)| (a - b).powi(2))
136                    .sum::<f64>()
137                    .sqrt()
138            })
139            .fold(f64::INFINITY, f64::min)
140    }
141
142    /// Check whether `q` is within the Wasserstein ball of the empirical centre.
143    ///
144    /// Returns `true` iff `distance_to_point(q) ≤ self.radius`.
145    pub fn contains_point(&self, q: &[f64]) -> bool {
146        self.distance_to_point(q) <= self.radius + f64::EPSILON
147    }
148}
149
150// ---------------------------------------------------------------------------
151// Robust objective variants
152// ---------------------------------------------------------------------------
153
/// The criterion a distributionally robust solver optimizes.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum RobustObjective {
    /// Mean-variance trade-off, `E[loss] + lambda * Var[loss]`.
    MeanVariance {
        /// Variance weight λ ≥ 0; a larger value penalises variance more
        /// heavily.
        lambda: f64,
    },
    /// Conditional Value-at-Risk at level α, with α ∈ (0,1).
    CVaR {
        /// Confidence level α, required to lie strictly between 0 and 1.
        alpha: f64,
    },
    /// Pure minimax (worst-case) objective, `max_{Q in B_eps} E_Q[loss]`.
    WorstCase,
}
171
172impl Default for RobustObjective {
173    fn default() -> Self {
174        Self::CVaR { alpha: 0.95 }
175    }
176}
177
178// ---------------------------------------------------------------------------
179// DroSolver
180// ---------------------------------------------------------------------------
181
182/// High-level handle to a DRO solver.
183///
184/// Stores configuration and exposes a uniform interface for different DRO
185/// objectives.  Actual computation is delegated to the functions in
186/// [`super::wasserstein_dro`] and [`super::cvar_dro`].
187#[derive(Debug, Clone)]
188pub struct DroSolver {
189    /// Solver configuration.
190    pub config: DroConfig,
191    /// Objective criterion.
192    pub objective: RobustObjective,
193}
194
195impl Default for DroSolver {
196    fn default() -> Self {
197        Self {
198            config: DroConfig::default(),
199            objective: RobustObjective::default(),
200        }
201    }
202}
203
204impl DroSolver {
205    /// Create a new DRO solver with the given config and objective.
206    pub fn new(config: DroConfig, objective: RobustObjective) -> OptimizeResult<Self> {
207        config.validate()?;
208        Ok(Self { config, objective })
209    }
210}
211
212// ---------------------------------------------------------------------------
213// Tests
214// ---------------------------------------------------------------------------
215
#[cfg(test)]
mod tests {
    use super::*;

    /// The out-of-the-box configuration must pass its own validation.
    #[test]
    fn test_dro_config_default_valid() {
        let cfg = DroConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_dro_config_negative_radius_error() {
        let cfg = DroConfig {
            radius: -0.1,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    /// Each remaining validation branch rejects its offending field.
    #[test]
    fn test_dro_config_zero_or_non_positive_fields_error() {
        let zero_samples = DroConfig {
            n_samples: 0,
            ..Default::default()
        };
        assert!(zero_samples.validate().is_err());

        let zero_iter = DroConfig {
            max_iter: 0,
            ..Default::default()
        };
        assert!(zero_iter.validate().is_err());

        let zero_tol = DroConfig {
            tol: 0.0,
            ..Default::default()
        };
        assert!(zero_tol.validate().is_err());
    }

    #[test]
    fn test_wasserstein_ball_contains_center() {
        // Distance from a single-sample centre to itself is 0; should be in ball.
        let sample = vec![1.0, 2.0];
        let ball = WassersteinBall::new(vec![sample.clone()], 0.5).expect("valid ball");
        assert!(ball.contains_point(&sample));
    }

    #[test]
    fn test_wasserstein_ball_outside_radius() {
        let sample = vec![0.0, 0.0];
        let ball = WassersteinBall::new(vec![sample], 0.5).expect("valid ball");
        // Point at distance sqrt(2) ≈ 1.41 > 0.5
        assert!(!ball.contains_point(&[1.0, 1.0]));
    }

    #[test]
    fn test_wasserstein_ball_negative_radius_error() {
        assert!(WassersteinBall::new(vec![vec![0.0]], -0.1).is_err());
    }

    #[test]
    fn test_wasserstein_ball_empty_samples_error() {
        assert!(WassersteinBall::new(Vec::new(), 0.5).is_err());
    }

    /// With several centres the reported distance is to the closest one.
    #[test]
    fn test_wasserstein_ball_nearest_centre_distance() {
        let ball = WassersteinBall::new(vec![vec![0.0, 0.0], vec![3.0, 4.0]], 0.5)
            .expect("valid ball");
        // ‖(3,0) − (0,0)‖ = 3 and ‖(3,0) − (3,4)‖ = 4, so the minimum is 3.
        assert_eq!(ball.distance_to_point(&[3.0, 0.0]), 3.0);
    }

    #[test]
    fn test_robust_objective_default() {
        let obj = RobustObjective::default();
        // The comparison must be asserted; a bare `matches!` only yields a
        // bool that is thrown away, so the test could never fail.
        assert_eq!(obj, RobustObjective::CVaR { alpha: 0.95 });
    }

    #[test]
    fn test_dro_solver_default() {
        let solver = DroSolver::default();
        assert!(solver.config.radius >= 0.0);
    }

    #[test]
    fn test_dro_solver_new_rejects_invalid_config() {
        let cfg = DroConfig {
            radius: -1.0,
            ..Default::default()
        };
        assert!(DroSolver::new(cfg, RobustObjective::default()).is_err());
    }
}