// Source: swarm_engine_eval/config.rs

//! Evaluation configuration
//!
//! Loads evaluation settings from a TOML configuration file.
4
5use std::path::Path;
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{EvalError, Result};
11
12/// Evaluation configuration
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct EvalConfig {
15    /// Orchestrator settings
16    #[serde(default)]
17    pub orchestrator: OrchestratorSettings,
18
19    /// Evaluation-specific settings
20    #[serde(default)]
21    pub eval: EvalSettings,
22
23    /// Assertions to verify
24    #[serde(default)]
25    pub assertions: Vec<AssertionConfig>,
26
27    /// Fault injection configurations
28    #[serde(default)]
29    pub faults: Vec<FaultConfig>,
30}
31
32/// Orchestrator settings
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct OrchestratorSettings {
35    /// Tick duration in milliseconds
36    #[serde(default = "default_tick_duration_ms")]
37    pub tick_duration_ms: u64,
38
39    /// Maximum ticks
40    #[serde(default = "default_max_ticks")]
41    pub max_ticks: u64,
42}
43
/// Default for `OrchestratorSettings::tick_duration_ms`, in milliseconds.
fn default_tick_duration_ms() -> u64 {
    const DEFAULT_TICK_DURATION_MS: u64 = 10;
    DEFAULT_TICK_DURATION_MS
}
47
/// Default for `OrchestratorSettings::max_ticks`.
fn default_max_ticks() -> u64 {
    const DEFAULT_MAX_TICKS: u64 = 1000;
    DEFAULT_MAX_TICKS
}
51
52impl Default for OrchestratorSettings {
53    fn default() -> Self {
54        Self {
55            tick_duration_ms: default_tick_duration_ms(),
56            max_ticks: default_max_ticks(),
57        }
58    }
59}
60
61impl EvalConfig {
62    /// Load from TOML file
63    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
64        let content = std::fs::read_to_string(path)?;
65        Self::from_toml_str(&content)
66    }
67
68    /// Parse from TOML string
69    pub fn from_toml_str(content: &str) -> Result<Self> {
70        let config: EvalConfig = toml::from_str(content)?;
71        config.validate()?;
72        Ok(config)
73    }
74
75    /// Validate configuration
76    fn validate(&self) -> Result<()> {
77        if self.eval.runs == 0 {
78            return Err(EvalError::Config("runs must be > 0".to_string()));
79        }
80        Ok(())
81    }
82}
83
84/// Evaluation-specific settings
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct EvalSettings {
87    /// Number of runs for statistical analysis
88    #[serde(default = "default_runs")]
89    pub runs: usize,
90
91    /// Base seed for reproducibility (None = use current time)
92    pub base_seed: Option<u64>,
93
94    /// Record seeds in report
95    #[serde(default = "default_true")]
96    pub record_seeds: bool,
97
98    /// Parallel execution (number of concurrent runs)
99    #[serde(default = "default_parallel")]
100    pub parallel: usize,
101
102    /// Target tick duration for miss rate calculation
103    #[serde(default)]
104    pub target_tick_duration_ms: Option<u64>,
105}
106
/// Default for `EvalSettings::runs`.
fn default_runs() -> usize {
    const DEFAULT_RUNS: usize = 30;
    DEFAULT_RUNS
}
110
/// Serde default helper for boolean fields that are enabled unless the
/// input says otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
114
/// Default for `EvalSettings::parallel`.
fn default_parallel() -> usize {
    const DEFAULT_PARALLEL: usize = 1;
    DEFAULT_PARALLEL
}
118
119impl Default for EvalSettings {
120    fn default() -> Self {
121        Self {
122            runs: default_runs(),
123            base_seed: None,
124            record_seeds: true,
125            parallel: default_parallel(),
126            target_tick_duration_ms: None,
127        }
128    }
129}
130
131impl EvalSettings {
132    /// Get target tick duration
133    pub fn target_tick_duration(&self) -> Option<Duration> {
134        self.target_tick_duration_ms.map(Duration::from_millis)
135    }
136}
137
138/// Assertion configuration
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct AssertionConfig {
141    /// Assertion name
142    pub name: String,
143
144    /// Metric to check
145    pub metric: String,
146
147    /// Comparison operator
148    pub op: ComparisonOp,
149
150    /// Expected value
151    pub expected: f64,
152}
153
154/// Comparison operator
155#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
156#[serde(rename_all = "snake_case")]
157pub enum ComparisonOp {
158    /// Greater than
159    Gt,
160    /// Greater than or equal
161    Gte,
162    /// Less than
163    Lt,
164    /// Less than or equal
165    Lte,
166    /// Equal (within epsilon)
167    Eq,
168}
169
170impl ComparisonOp {
171    /// Check if actual value satisfies the comparison
172    pub fn check(&self, actual: f64, expected: f64) -> bool {
173        const EPSILON: f64 = 1e-9;
174        match self {
175            ComparisonOp::Gt => actual > expected,
176            ComparisonOp::Gte => actual >= expected - EPSILON,
177            ComparisonOp::Lt => actual < expected,
178            ComparisonOp::Lte => actual <= expected + EPSILON,
179            ComparisonOp::Eq => (actual - expected).abs() < EPSILON,
180        }
181    }
182}
183
184/// Fault injection configuration
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct FaultConfig {
187    /// Fault type
188    pub fault_type: FaultType,
189
190    /// Tick range for fault injection (start, end)
191    #[serde(default)]
192    pub tick_range: Option<(u64, u64)>,
193
194    /// Probability of fault occurrence (0.0 - 1.0)
195    #[serde(default = "default_probability")]
196    pub probability: f64,
197
198    /// Duration in ticks (for delay injection)
199    #[serde(default)]
200    pub duration_ticks: Option<u64>,
201
202    /// Target workers (None = all workers)
203    #[serde(default)]
204    pub target_workers: Option<Vec<usize>>,
205}
206
/// Default for `FaultConfig::probability`.
fn default_probability() -> f64 {
    const DEFAULT_PROBABILITY: f64 = 1.0;
    DEFAULT_PROBABILITY
}
210
211/// Fault type
212#[derive(Debug, Clone, Serialize, Deserialize)]
213#[serde(tag = "type", rename_all = "snake_case")]
214pub enum FaultType {
215    /// Inject delay into tick processing
216    DelayInjection {
217        /// Delay in milliseconds
218        delay_ms: u64,
219    },
220
221    /// Skip worker execution
222    WorkerSkip,
223
224    /// Override worker guidance
225    GuidanceOverride {
226        /// Goal to inject
227        goal: String,
228    },
229
230    /// Cause action to fail
231    ActionFailure,
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a configuration, panicking if parsing or
    /// validation fails.
    fn parse(source: &str) -> EvalConfig {
        EvalConfig::from_toml_str(source).unwrap()
    }

    /// The default configuration carries the documented eval defaults.
    #[test]
    fn test_default_config() {
        let EvalConfig { eval, .. } = EvalConfig::default();
        assert_eq!(eval.runs, 30);
        assert!(eval.record_seeds);
        assert_eq!(eval.parallel, 1);
    }

    /// A document with only `[eval] runs` parses successfully.
    #[test]
    fn test_parse_minimal_toml() {
        let source = r#"
[eval]
runs = 10
"#;
        let parsed = parse(source);
        assert_eq!(parsed.eval.runs, 10);
    }

    /// An `[[assertions]]` table is parsed into an `AssertionConfig`.
    #[test]
    fn test_parse_with_assertions() {
        let source = r#"
[eval]
runs = 30

[[assertions]]
name = "success_rate_threshold"
metric = "success_rate"
op = "gte"
expected = 0.8
"#;
        let parsed = parse(source);
        assert_eq!(parsed.assertions.len(), 1);

        let first = &parsed.assertions[0];
        assert_eq!(first.name, "success_rate_threshold");
        assert_eq!(first.op, ComparisonOp::Gte);
    }

    /// A `[[faults]]` table is parsed, including its tick range.
    #[test]
    fn test_parse_with_faults() {
        let source = r#"
[eval]
runs = 10

[[faults]]
fault_type = { type = "delay_injection", delay_ms = 100 }
tick_range = [10, 50]
probability = 0.1
"#;
        let parsed = parse(source);
        assert_eq!(parsed.faults.len(), 1);

        let first = &parsed.faults[0];
        assert_eq!(first.tick_range, Some((10, 50)));
    }

    /// Each operator accepts and rejects the expected value pairs.
    #[test]
    fn test_comparison_op() {
        // (operator, actual, expected, whether the check should hold)
        let cases = [
            (ComparisonOp::Gt, 0.9, 0.8, true),
            (ComparisonOp::Gt, 0.8, 0.8, false),
            (ComparisonOp::Gte, 0.8, 0.8, true),
            (ComparisonOp::Gte, 0.9, 0.8, true),
            (ComparisonOp::Lt, 0.7, 0.8, true),
            (ComparisonOp::Lt, 0.8, 0.8, false),
            (ComparisonOp::Lte, 0.8, 0.8, true),
            (ComparisonOp::Lte, 0.7, 0.8, true),
            (ComparisonOp::Eq, 0.8, 0.8, true),
            (ComparisonOp::Eq, 0.81, 0.8, false),
        ];

        for &(op, actual, expected, holds) in cases.iter() {
            assert_eq!(
                op.check(actual, expected),
                holds,
                "{:?}.check({}, {})",
                op,
                actual,
                expected
            );
        }
    }

    /// `runs = 0` is rejected by validation.
    #[test]
    fn test_invalid_runs() {
        let source = r#"
[eval]
runs = 0
"#;
        assert!(EvalConfig::from_toml_str(source).is_err());
    }
}