scirs2_optimize/dro/types.rs
1//! Types for Distributionally Robust Optimization (DRO).
2//!
3//! Provides configuration structs, result types, and core abstractions for
4//! Wasserstein-ball and CVaR-based DRO.
5
6use crate::error::{OptimizeError, OptimizeResult};
7
8// ---------------------------------------------------------------------------
9// Configuration
10// ---------------------------------------------------------------------------
11
/// Configuration for distributionally robust optimization.
///
/// Holds the Wasserstein ball radius, the empirical sample count, and the
/// iteration, tolerance, and step-size settings. Use [`DroConfig::validate`]
/// to check a hand-built value before handing it to a solver.
///
/// `PartialEq` is derived so that configurations can be compared directly
/// (for example in tests). `Eq` cannot be derived because of the `f64` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct DroConfig {
    /// Wasserstein ball radius ε ≥ 0. Larger values yield more conservative solutions.
    pub radius: f64,
    /// Number of empirical samples drawn from the reference distribution P_N.
    pub n_samples: usize,
    /// Maximum number of outer solver iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the gradient norm.
    pub tol: f64,
    /// Step size for (sub)gradient descent. When `None` the solver uses
    /// the adaptive schedule 1/√t.
    pub step_size: Option<f64>,
}
29
30impl Default for DroConfig {
31 fn default() -> Self {
32 Self {
33 radius: 0.1,
34 n_samples: 100,
35 max_iter: 500,
36 tol: 1e-6,
37 step_size: None,
38 }
39 }
40}
41
42impl DroConfig {
43 /// Validate configuration parameters.
44 pub fn validate(&self) -> OptimizeResult<()> {
45 if self.radius < 0.0 {
46 return Err(OptimizeError::InvalidParameter(
47 "radius must be non-negative".into(),
48 ));
49 }
50 if self.n_samples == 0 {
51 return Err(OptimizeError::InvalidParameter(
52 "n_samples must be positive".into(),
53 ));
54 }
55 if self.max_iter == 0 {
56 return Err(OptimizeError::InvalidParameter(
57 "max_iter must be positive".into(),
58 ));
59 }
60 if self.tol <= 0.0 {
61 return Err(OptimizeError::InvalidParameter(
62 "tol must be positive".into(),
63 ));
64 }
65 Ok(())
66 }
67}
68
69// ---------------------------------------------------------------------------
70// Result types
71// ---------------------------------------------------------------------------
72
/// Result of a distributionally robust optimization run.
///
/// `PartialEq` is derived so that results can be compared directly (for
/// example in regression tests). `Eq` cannot be derived because of the `f64`
/// fields.
#[derive(Debug, Clone, PartialEq)]
pub struct DroResult {
    /// Optimal decision variable weights (e.g. portfolio weights).
    pub optimal_weights: Vec<f64>,
    /// Worst-case expected loss under the Wasserstein-ball ambiguity set.
    pub worst_case_loss: f64,
    /// Primal objective value at the optimal weights.
    pub primal_obj: f64,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Whether the solver converged to the requested tolerance.
    pub converged: bool,
}
87
88// ---------------------------------------------------------------------------
89// Wasserstein ball description
90// ---------------------------------------------------------------------------
91
/// Describes a Wasserstein-1 ball around a set of centre samples.
///
/// The ball B_ε(P_N) = {Q : W_1(Q, P_N) ≤ ε} contains all probability
/// measures within Wasserstein distance ε of the empirical distribution P_N.
///
/// Both fields are public, so a value built with a struct literal bypasses
/// the checks performed by [`WassersteinBall::new`].
///
/// `PartialEq` is derived so that balls can be compared directly; `Eq`
/// cannot be derived because of the `f64` data.
#[derive(Debug, Clone, PartialEq)]
pub struct WassersteinBall {
    /// Centre samples {x_1, …, x_N} defining the empirical distribution P_N.
    pub center_samples: Vec<Vec<f64>>,
    /// Ball radius ε ≥ 0.
    pub radius: f64,
}
103
104impl WassersteinBall {
105 /// Create a new Wasserstein ball.
106 ///
107 /// Returns an error when `radius < 0` or `center_samples` is empty.
108 pub fn new(center_samples: Vec<Vec<f64>>, radius: f64) -> OptimizeResult<Self> {
109 if radius < 0.0 {
110 return Err(OptimizeError::InvalidParameter(
111 "Wasserstein ball radius must be non-negative".into(),
112 ));
113 }
114 if center_samples.is_empty() {
115 return Err(OptimizeError::InvalidParameter(
116 "center_samples must be non-empty".into(),
117 ));
118 }
119 Ok(Self {
120 center_samples,
121 radius,
122 })
123 }
124
125 /// Wasserstein-1 distance from the empirical centre to a single point `q`.
126 ///
127 /// For a discrete empirical distribution P_N the W_1 distance to the
128 /// Dirac mass δ_q is min_i ‖x_i − q‖_2 (the nearest-centre distance).
129 pub fn distance_to_point(&self, q: &[f64]) -> f64 {
130 self.center_samples
131 .iter()
132 .map(|c| {
133 c.iter()
134 .zip(q.iter())
135 .map(|(a, b)| (a - b).powi(2))
136 .sum::<f64>()
137 .sqrt()
138 })
139 .fold(f64::INFINITY, f64::min)
140 }
141
142 /// Check whether `q` is within the Wasserstein ball of the empirical centre.
143 ///
144 /// Returns `true` iff `distance_to_point(q) ≤ self.radius`.
145 pub fn contains_point(&self, q: &[f64]) -> bool {
146 self.distance_to_point(q) <= self.radius + f64::EPSILON
147 }
148}
149
150// ---------------------------------------------------------------------------
151// Robust objective variants
152// ---------------------------------------------------------------------------
153
/// Which distributionally robust criterion a solver should optimise.
///
/// The enum is `#[non_exhaustive]`: downstream crates must include a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum RobustObjective {
    /// Mean-variance trade-off: `E[loss] + lambda * Var[loss]`.
    MeanVariance {
        /// Variance penalty weight λ ≥ 0; a larger λ penalises variance more.
        lambda: f64,
    },
    /// Conditional Value-at-Risk at level α ∈ (0,1).
    CVaR {
        /// Confidence level α, which must lie strictly between 0 and 1.
        alpha: f64,
    },
    /// Pure worst-case (minimax) objective: `max_{Q in B_eps} E_Q[loss]`.
    WorstCase,
}
171
172impl Default for RobustObjective {
173 fn default() -> Self {
174 Self::CVaR { alpha: 0.95 }
175 }
176}
177
178// ---------------------------------------------------------------------------
179// DroSolver
180// ---------------------------------------------------------------------------
181
182/// High-level handle to a DRO solver.
183///
184/// Stores configuration and exposes a uniform interface for different DRO
185/// objectives. Actual computation is delegated to the functions in
186/// [`super::wasserstein_dro`] and [`super::cvar_dro`].
187#[derive(Debug, Clone)]
188pub struct DroSolver {
189 /// Solver configuration.
190 pub config: DroConfig,
191 /// Objective criterion.
192 pub objective: RobustObjective,
193}
194
195impl Default for DroSolver {
196 fn default() -> Self {
197 Self {
198 config: DroConfig::default(),
199 objective: RobustObjective::default(),
200 }
201 }
202}
203
204impl DroSolver {
205 /// Create a new DRO solver with the given config and objective.
206 pub fn new(config: DroConfig, objective: RobustObjective) -> OptimizeResult<Self> {
207 config.validate()?;
208 Ok(Self { config, objective })
209 }
210}
211
212// ---------------------------------------------------------------------------
213// Tests
214// ---------------------------------------------------------------------------
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dro_config_default_valid() {
        let cfg = DroConfig::default();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_dro_config_negative_radius_error() {
        let cfg = DroConfig {
            radius: -0.1,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_dro_config_zero_counts_error() {
        let no_samples = DroConfig {
            n_samples: 0,
            ..Default::default()
        };
        assert!(no_samples.validate().is_err());

        let no_iterations = DroConfig {
            max_iter: 0,
            ..Default::default()
        };
        assert!(no_iterations.validate().is_err());
    }

    #[test]
    fn test_dro_config_non_positive_tol_error() {
        let cfg = DroConfig {
            tol: 0.0,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_wasserstein_ball_contains_center() {
        // Distance from a single-sample centre to itself is 0; should be in ball.
        let sample = vec![1.0, 2.0];
        let ball = WassersteinBall::new(vec![sample.clone()], 0.5).expect("valid ball");
        assert!(ball.contains_point(&sample));
    }

    #[test]
    fn test_wasserstein_ball_outside_radius() {
        let sample = vec![0.0, 0.0];
        let ball = WassersteinBall::new(vec![sample], 0.5).expect("valid ball");
        // Point at distance sqrt(2) ≈ 1.41 > 0.5
        assert!(!ball.contains_point(&[1.0, 1.0]));
    }

    #[test]
    fn test_wasserstein_ball_nearest_centre_distance() {
        // Distances from (3, 0) to the two centres are 3 and 4; the smaller wins.
        let ball = WassersteinBall::new(vec![vec![0.0, 0.0], vec![3.0, 4.0]], 0.5)
            .expect("valid ball");
        assert!((ball.distance_to_point(&[3.0, 0.0]) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_wasserstein_ball_negative_radius_error() {
        assert!(WassersteinBall::new(vec![vec![0.0]], -0.1).is_err());
    }

    #[test]
    fn test_wasserstein_ball_empty_samples_error() {
        assert!(WassersteinBall::new(Vec::new(), 0.1).is_err());
    }

    #[test]
    fn test_robust_objective_default() {
        // A bare `matches!` only yields a bool; without an assertion this
        // test could never fail.
        let obj = RobustObjective::default();
        assert_eq!(obj, RobustObjective::CVaR { alpha: 0.95 });
    }

    #[test]
    fn test_dro_solver_default() {
        let solver = DroSolver::default();
        assert!(solver.config.radius >= 0.0);
        assert!(solver.config.validate().is_ok());
        assert_eq!(solver.objective, RobustObjective::default());
    }

    #[test]
    fn test_dro_solver_new_rejects_invalid_config() {
        let cfg = DroConfig {
            radius: -1.0,
            ..Default::default()
        };
        assert!(DroSolver::new(cfg, RobustObjective::WorstCase).is_err());
    }
}
267}