1use std::path::Path;
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{EvalError, Result};
11
/// Top-level evaluation configuration, normally loaded from TOML via
/// [`EvalConfig::from_toml_file`] or [`EvalConfig::from_toml_str`].
///
/// Every section is optional: an omitted section takes its `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvalConfig {
    /// `[orchestrator]` section: tick timing and tick limit.
    #[serde(default)]
    pub orchestrator: OrchestratorSettings,

    /// `[eval]` section: run count, seeding and parallelism.
    #[serde(default)]
    pub eval: EvalSettings,

    /// `[[assertions]]` entries, each naming a metric, an operator and an
    /// expected value.
    #[serde(default)]
    pub assertions: Vec<AssertionConfig>,

    /// `[[faults]]` entries describing faults to inject.
    #[serde(default)]
    pub faults: Vec<FaultConfig>,
}
31
/// `[orchestrator]` settings. Any field left out of the TOML uses the
/// default noted on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestratorSettings {
    /// Tick duration, in milliseconds per the field name (default `10`).
    #[serde(default = "default_tick_duration_ms")]
    pub tick_duration_ms: u64,

    /// Upper bound on the number of ticks (default `1000`).
    /// NOTE(review): presumably applied per run — confirm with the orchestrator.
    #[serde(default = "default_max_ticks")]
    pub max_ticks: u64,
}
43
/// Serde default for `OrchestratorSettings::tick_duration_ms`; returns `10`.
fn default_tick_duration_ms() -> u64 {
    const DEFAULT_TICK_DURATION_MS: u64 = 10;
    DEFAULT_TICK_DURATION_MS
}
47
/// Serde default for `OrchestratorSettings::max_ticks`; returns `1000`.
fn default_max_ticks() -> u64 {
    const DEFAULT_MAX_TICKS: u64 = 1_000;
    DEFAULT_MAX_TICKS
}
51
52impl Default for OrchestratorSettings {
53 fn default() -> Self {
54 Self {
55 tick_duration_ms: default_tick_duration_ms(),
56 max_ticks: default_max_ticks(),
57 }
58 }
59}
60
61impl EvalConfig {
62 pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
64 let content = std::fs::read_to_string(path)?;
65 Self::from_toml_str(&content)
66 }
67
68 pub fn from_toml_str(content: &str) -> Result<Self> {
70 let config: EvalConfig = toml::from_str(content)?;
71 config.validate()?;
72 Ok(config)
73 }
74
75 fn validate(&self) -> Result<()> {
77 if self.eval.runs == 0 {
78 return Err(EvalError::Config("runs must be > 0".to_string()));
79 }
80 Ok(())
81 }
82}
83
/// `[eval]` settings. Any field left out of the TOML uses the default noted
/// on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSettings {
    /// Number of runs (default `30`). `EvalConfig` rejects `0` when loading.
    #[serde(default = "default_runs")]
    pub runs: usize,

    /// Optional base seed; `None` when omitted.
    /// NOTE(review): how per-run seeds derive from this is not visible here.
    pub base_seed: Option<u64>,

    /// Whether seeds are recorded (default `true`).
    #[serde(default = "default_true")]
    pub record_seeds: bool,

    /// Degree of parallelism (default `1`).
    /// NOTE(review): `0` is not rejected by validation — confirm callers handle it.
    #[serde(default = "default_parallel")]
    pub parallel: usize,

    /// Optional target tick duration in milliseconds; `None` when omitted.
    /// Use [`EvalSettings::target_tick_duration`] to get it as a `Duration`.
    #[serde(default)]
    pub target_tick_duration_ms: Option<u64>,
}
106
/// Serde default for `EvalSettings::runs`; returns `30`.
fn default_runs() -> usize {
    const DEFAULT_RUNS: usize = 30;
    DEFAULT_RUNS
}
110
/// Serde default helper for boolean fields that should be enabled when omitted.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
114
/// Serde default for `EvalSettings::parallel`; returns `1`.
fn default_parallel() -> usize {
    const DEFAULT_PARALLEL: usize = 1;
    DEFAULT_PARALLEL
}
118
119impl Default for EvalSettings {
120 fn default() -> Self {
121 Self {
122 runs: default_runs(),
123 base_seed: None,
124 record_seeds: true,
125 parallel: default_parallel(),
126 target_tick_duration_ms: None,
127 }
128 }
129}
130
131impl EvalSettings {
132 pub fn target_tick_duration(&self) -> Option<Duration> {
134 self.target_tick_duration_ms.map(Duration::from_millis)
135 }
136}
137
/// One `[[assertions]]` entry: a named check of a metric against `expected`
/// using `op`. All fields are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionConfig {
    /// Identifier for this assertion (e.g. `"success_rate_threshold"`).
    pub name: String,

    /// Name of the metric to check (e.g. `"success_rate"`).
    pub metric: String,

    /// Comparison operator, written as `"gt"`, `"gte"`, `"lt"`, `"lte"` or `"eq"`.
    pub op: ComparisonOp,

    /// Value the metric is compared against.
    pub expected: f64,
}
153
/// Comparison operator for an assertion, serialized in snake_case
/// (`"gt"`, `"gte"`, `"lt"`, `"lte"`, `"eq"`).
///
/// Evaluated by [`ComparisonOp::check`], which applies a small
/// floating-point tolerance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ComparisonOp {
    /// `actual > expected`.
    Gt,
    /// `actual >= expected`.
    Gte,
    /// `actual < expected`.
    Lt,
    /// `actual <= expected`.
    Lte,
    /// `actual == expected`, within tolerance.
    Eq,
}
169
170impl ComparisonOp {
171 pub fn check(&self, actual: f64, expected: f64) -> bool {
173 const EPSILON: f64 = 1e-9;
174 match self {
175 ComparisonOp::Gt => actual > expected,
176 ComparisonOp::Gte => actual >= expected - EPSILON,
177 ComparisonOp::Lt => actual < expected,
178 ComparisonOp::Lte => actual <= expected + EPSILON,
179 ComparisonOp::Eq => (actual - expected).abs() < EPSILON,
180 }
181 }
182}
183
/// One `[[faults]]` entry: which fault to inject, plus optional constraints
/// on when and where it applies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultConfig {
    /// Kind of fault, written as an inline table tagged by `type`,
    /// e.g. `{ type = "delay_injection", delay_ms = 100 }`.
    pub fault_type: FaultType,

    /// Optional `(start, end)` tick window, written as a two-element array
    /// such as `[10, 50]`; `None` when omitted.
    /// NOTE(review): whether `end` is inclusive, and whether `None` means
    /// "all ticks", is decided by the consumer — confirm.
    #[serde(default)]
    pub tick_range: Option<(u64, u64)>,

    /// Probability that the fault fires (default `1.0`).
    /// NOTE(review): presumably evaluated per eligible tick — confirm.
    #[serde(default = "default_probability")]
    pub probability: f64,

    /// Optional fault duration in ticks; `None` when omitted.
    #[serde(default)]
    pub duration_ticks: Option<u64>,

    /// Optional list of targeted workers (presumably worker indices);
    /// `None` when omitted. NOTE(review): presumably `None` targets all
    /// workers — confirm.
    #[serde(default)]
    pub target_workers: Option<Vec<usize>>,
}
206
/// Serde default for `FaultConfig::probability`; returns `1.0`.
fn default_probability() -> f64 {
    const DEFAULT_PROBABILITY: f64 = 1.0;
    DEFAULT_PROBABILITY
}
210
/// Kind of fault to inject.
///
/// The enum is internally tagged: the snake_case variant name goes in a
/// `type` key, next to the variant's own fields. For example:
/// `{ type = "delay_injection", delay_ms = 100 }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FaultType {
    /// `type = "delay_injection"`: injects a delay of `delay_ms`.
    DelayInjection {
        /// Delay length, in milliseconds per the field name.
        delay_ms: u64,
    },

    /// `type = "worker_skip"`.
    /// NOTE(review): exact effect is defined by the fault injector, not visible here.
    WorkerSkip,

    /// `type = "guidance_override"`: presumably replaces guidance with `goal`.
    GuidanceOverride {
        /// Replacement goal text.
        goal: String,
    },

    /// `type = "action_failure"`.
    /// NOTE(review): exact effect is defined by the fault injector, not visible here.
    ActionFailure,
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// `EvalConfig::default()` must agree with the documented serde defaults.
    #[test]
    fn test_default_config() {
        let config = EvalConfig::default();
        assert_eq!(config.eval.runs, 30);
        assert!(config.eval.record_seeds);
        assert_eq!(config.eval.parallel, 1);
        assert_eq!(config.eval.base_seed, None);
        assert_eq!(config.eval.target_tick_duration(), None);
        assert_eq!(config.orchestrator.tick_duration_ms, 10);
        assert_eq!(config.orchestrator.max_ticks, 1000);
        assert!(config.assertions.is_empty());
        assert!(config.faults.is_empty());
    }

    /// Fields and sections omitted from the TOML fall back to their defaults.
    #[test]
    fn test_parse_minimal_toml() {
        let toml = r#"
[eval]
runs = 10
"#;
        let config = EvalConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.eval.runs, 10);
        assert!(config.eval.record_seeds);
        assert_eq!(config.eval.parallel, 1);
        assert_eq!(config.orchestrator.tick_duration_ms, 10);
        assert_eq!(config.orchestrator.max_ticks, 1000);
        assert!(config.assertions.is_empty());
        assert!(config.faults.is_empty());
    }

    /// An empty document is valid and yields the default configuration.
    #[test]
    fn test_parse_empty_toml() {
        let config = EvalConfig::from_toml_str("").unwrap();
        assert_eq!(config.eval.runs, 30);
        assert_eq!(config.orchestrator.max_ticks, 1000);
    }

    /// `target_tick_duration_ms` is exposed as a `Duration` in milliseconds.
    #[test]
    fn test_target_tick_duration() {
        let toml = r#"
[eval]
target_tick_duration_ms = 25
"#;
        let config = EvalConfig::from_toml_str(toml).unwrap();
        assert_eq!(
            config.eval.target_tick_duration(),
            Some(Duration::from_millis(25))
        );
    }

    #[test]
    fn test_parse_with_assertions() {
        let toml = r#"
[eval]
runs = 30

[[assertions]]
name = "success_rate_threshold"
metric = "success_rate"
op = "gte"
expected = 0.8
"#;
        let config = EvalConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.assertions.len(), 1);
        assert_eq!(config.assertions[0].name, "success_rate_threshold");
        assert_eq!(config.assertions[0].metric, "success_rate");
        assert_eq!(config.assertions[0].op, ComparisonOp::Gte);
        assert_eq!(config.assertions[0].expected, 0.8);
    }

    /// Operators outside the snake_case variant set are rejected at parse time.
    #[test]
    fn test_parse_unknown_op_fails() {
        let toml = r#"
[[assertions]]
name = "x"
metric = "m"
op = "ne"
expected = 1.0
"#;
        assert!(EvalConfig::from_toml_str(toml).is_err());
    }

    #[test]
    fn test_parse_with_faults() {
        let toml = r#"
[eval]
runs = 10

[[faults]]
fault_type = { type = "delay_injection", delay_ms = 100 }
tick_range = [10, 50]
probability = 0.1
"#;
        let config = EvalConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.faults.len(), 1);
        let fault = &config.faults[0];
        assert!(matches!(
            fault.fault_type,
            FaultType::DelayInjection { delay_ms: 100 }
        ));
        assert_eq!(fault.tick_range, Some((10, 50)));
        assert_eq!(fault.probability, 0.1);
        assert_eq!(fault.duration_ticks, None);
        assert_eq!(fault.target_workers, None);
    }

    /// Unit variants parse from a bare tag, and optional fault fields default.
    #[test]
    fn test_parse_fault_defaults() {
        let toml = r#"
[[faults]]
fault_type = { type = "worker_skip" }
"#;
        let config = EvalConfig::from_toml_str(toml).unwrap();
        let fault = &config.faults[0];
        assert!(matches!(fault.fault_type, FaultType::WorkerSkip));
        assert_eq!(fault.probability, 1.0);
        assert_eq!(fault.tick_range, None);
        assert_eq!(fault.duration_ticks, None);
        assert_eq!(fault.target_workers, None);
    }

    #[test]
    fn test_parse_fault_with_targets() {
        let toml = r#"
[[faults]]
fault_type = { type = "guidance_override", goal = "explore" }
duration_ticks = 5
target_workers = [0, 2]
"#;
        let config = EvalConfig::from_toml_str(toml).unwrap();
        let fault = &config.faults[0];
        assert!(matches!(
            &fault.fault_type,
            FaultType::GuidanceOverride { goal } if goal == "explore"
        ));
        assert_eq!(fault.duration_ticks, Some(5));
        assert_eq!(fault.target_workers, Some(vec![0, 2]));
    }

    #[test]
    fn test_comparison_op() {
        assert!(ComparisonOp::Gt.check(0.9, 0.8));
        assert!(!ComparisonOp::Gt.check(0.8, 0.8));

        assert!(ComparisonOp::Gte.check(0.8, 0.8));
        assert!(ComparisonOp::Gte.check(0.9, 0.8));

        assert!(ComparisonOp::Lt.check(0.7, 0.8));
        assert!(!ComparisonOp::Lt.check(0.8, 0.8));

        assert!(ComparisonOp::Lte.check(0.8, 0.8));
        assert!(ComparisonOp::Lte.check(0.7, 0.8));

        assert!(ComparisonOp::Eq.check(0.8, 0.8));
        assert!(!ComparisonOp::Eq.check(0.81, 0.8));

        // 0.1 + 0.2 is 0.30000000000000004; the tolerance absorbs the rounding.
        assert!(ComparisonOp::Eq.check(0.1 + 0.2, 0.3));
        assert!(ComparisonOp::Gte.check(0.3, 0.1 + 0.2));
        assert!(ComparisonOp::Lte.check(0.1 + 0.2, 0.3));
    }

    #[test]
    fn test_invalid_runs() {
        let toml = r#"
[eval]
runs = 0
"#;
        let result = EvalConfig::from_toml_str(toml);
        assert!(result.is_err());
    }
}