Skip to main content

swarm_engine_eval/
config.rs

1//! Evaluation configuration
2//!
3//! TOML 設定ファイルから評価設定を読み込みます。
4
5use std::path::Path;
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{EvalError, Result};
11
12/// Evaluation configuration
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct EvalConfig {
15    /// Orchestrator settings
16    #[serde(default)]
17    pub orchestrator: OrchestratorSettings,
18
19    /// Evaluation-specific settings
20    #[serde(default)]
21    pub eval: EvalSettings,
22
23    /// Assertions to verify
24    #[serde(default)]
25    pub assertions: Vec<AssertionConfig>,
26
27    /// Fault injection configurations
28    #[serde(default)]
29    pub faults: Vec<FaultConfig>,
30}
31
32/// Orchestrator settings
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct OrchestratorSettings {
35    /// Tick duration in milliseconds
36    #[serde(default = "default_tick_duration_ms")]
37    pub tick_duration_ms: u64,
38
39    /// Maximum ticks
40    #[serde(default = "default_max_ticks")]
41    pub max_ticks: u64,
42
43    /// DependencyGraph プロバイダーの種類
44    #[serde(default)]
45    pub dependency_provider: DependencyProviderKind,
46}
47
48/// DependencyGraph プロバイダーの種類
49///
50/// 学習済みアクション順序からグラフを提供する方式を選択する。
51///
52/// Note: `Smart` と `Learned` は統合され、どちらも `LearnedDependencyProvider` を使用。
53/// 後方互換性のため両方の値を受け付けるが、動作は同一。
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
55#[serde(rename_all = "snake_case")]
56pub enum DependencyProviderKind {
57    /// LearnedDependencyProvider
58    ///
59    /// 100% 一致時は学習済みグラフを使用。
60    /// 部分一致時は `select()` で投票戦略を決定。
61    Learned,
62
63    /// LearnedDependencyProvider(Smart は Learned に統合)
64    ///
65    /// 後方互換性のためのエイリアス。動作は `Learned` と同一。
66    #[default]
67    Smart,
68}
69
/// Default tick length: 10 milliseconds.
fn default_tick_duration_ms() -> u64 {
    const TEN_MS: u64 = 10;
    TEN_MS
}
73
/// Default cap on the number of ticks: one thousand.
fn default_max_ticks() -> u64 {
    1_000
}
77
78impl Default for OrchestratorSettings {
79    fn default() -> Self {
80        Self {
81            tick_duration_ms: default_tick_duration_ms(),
82            max_ticks: default_max_ticks(),
83            dependency_provider: DependencyProviderKind::default(),
84        }
85    }
86}
87
88impl EvalConfig {
89    /// Load from TOML file
90    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
91        let content = std::fs::read_to_string(path)?;
92        Self::from_toml_str(&content)
93    }
94
95    /// Parse from TOML string
96    pub fn from_toml_str(content: &str) -> Result<Self> {
97        let config: EvalConfig = toml::from_str(content)?;
98        config.validate()?;
99        Ok(config)
100    }
101
102    /// Validate configuration
103    fn validate(&self) -> Result<()> {
104        if self.eval.runs == 0 {
105            return Err(EvalError::Config("runs must be > 0".to_string()));
106        }
107        Ok(())
108    }
109}
110
111/// Evaluation-specific settings
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct EvalSettings {
114    /// Number of runs for statistical analysis
115    #[serde(default = "default_runs")]
116    pub runs: usize,
117
118    /// Base seed for reproducibility (None = use current time)
119    pub base_seed: Option<u64>,
120
121    /// Record seeds in report
122    #[serde(default = "default_true")]
123    pub record_seeds: bool,
124
125    /// Parallel execution (number of concurrent runs)
126    #[serde(default = "default_parallel")]
127    pub parallel: usize,
128
129    /// Target tick duration for miss rate calculation
130    #[serde(default)]
131    pub target_tick_duration_ms: Option<u64>,
132}
133
/// Default number of evaluation runs: thirty.
fn default_runs() -> usize {
    30_usize
}
137
/// Serde default helper for flags that are enabled unless the TOML sets them
/// to `false`.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
141
/// Default concurrency: one run at a time.
fn default_parallel() -> usize {
    1_usize
}
145
146impl Default for EvalSettings {
147    fn default() -> Self {
148        Self {
149            runs: default_runs(),
150            base_seed: None,
151            record_seeds: true,
152            parallel: default_parallel(),
153            target_tick_duration_ms: None,
154        }
155    }
156}
157
158impl EvalSettings {
159    /// Get target tick duration
160    pub fn target_tick_duration(&self) -> Option<Duration> {
161        self.target_tick_duration_ms.map(Duration::from_millis)
162    }
163}
164
165/// Assertion configuration
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct AssertionConfig {
168    /// Assertion name
169    pub name: String,
170
171    /// Metric to check
172    pub metric: String,
173
174    /// Comparison operator
175    pub op: ComparisonOp,
176
177    /// Expected value
178    pub expected: f64,
179}
180
181/// Comparison operator
182#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
183#[serde(rename_all = "snake_case")]
184pub enum ComparisonOp {
185    /// Greater than
186    Gt,
187    /// Greater than or equal
188    Gte,
189    /// Less than
190    Lt,
191    /// Less than or equal
192    Lte,
193    /// Equal (within epsilon)
194    Eq,
195}
196
197impl ComparisonOp {
198    /// Check if actual value satisfies the comparison
199    pub fn check(&self, actual: f64, expected: f64) -> bool {
200        const EPSILON: f64 = 1e-9;
201        match self {
202            ComparisonOp::Gt => actual > expected,
203            ComparisonOp::Gte => actual >= expected - EPSILON,
204            ComparisonOp::Lt => actual < expected,
205            ComparisonOp::Lte => actual <= expected + EPSILON,
206            ComparisonOp::Eq => (actual - expected).abs() < EPSILON,
207        }
208    }
209}
210
211/// Fault injection configuration
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct FaultConfig {
214    /// Fault type
215    pub fault_type: FaultType,
216
217    /// Tick range for fault injection (start, end)
218    #[serde(default)]
219    pub tick_range: Option<(u64, u64)>,
220
221    /// Probability of fault occurrence (0.0 - 1.0)
222    #[serde(default = "default_probability")]
223    pub probability: f64,
224
225    /// Duration in ticks (for delay injection)
226    #[serde(default)]
227    pub duration_ticks: Option<u64>,
228
229    /// Target workers (None = all workers)
230    #[serde(default)]
231    pub target_workers: Option<Vec<usize>>,
232}
233
/// Default fault probability: always fire (1.0).
fn default_probability() -> f64 {
    1.0_f64
}
237
238/// Fault type
239#[derive(Debug, Clone, Serialize, Deserialize)]
240#[serde(tag = "type", rename_all = "snake_case")]
241pub enum FaultType {
242    /// Inject delay into tick processing
243    DelayInjection {
244        /// Delay in milliseconds
245        delay_ms: u64,
246    },
247
248    /// Skip worker execution
249    WorkerSkip,
250
251    /// Override worker guidance
252    GuidanceOverride {
253        /// Goal to inject
254        goal: String,
255    },
256
257    /// Cause action to fail
258    ActionFailure,
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a TOML snippet, panicking if it is rejected.
    fn parse_ok(src: &str) -> EvalConfig {
        EvalConfig::from_toml_str(src).unwrap()
    }

    #[test]
    fn test_default_config() {
        let eval = EvalConfig::default().eval;
        assert_eq!(eval.runs, 30);
        assert!(eval.record_seeds);
        assert_eq!(eval.parallel, 1);
    }

    #[test]
    fn test_parse_minimal_toml() {
        let config = parse_ok(
            r#"
[eval]
runs = 10
"#,
        );
        assert_eq!(config.eval.runs, 10);
    }

    #[test]
    fn test_parse_with_assertions() {
        let config = parse_ok(
            r#"
[eval]
runs = 30

[[assertions]]
name = "success_rate_threshold"
metric = "success_rate"
op = "gte"
expected = 0.8
"#,
        );
        assert_eq!(config.assertions.len(), 1);
        let first = &config.assertions[0];
        assert_eq!(first.name, "success_rate_threshold");
        assert_eq!(first.op, ComparisonOp::Gte);
    }

    #[test]
    fn test_parse_with_faults() {
        let config = parse_ok(
            r#"
[eval]
runs = 10

[[faults]]
fault_type = { type = "delay_injection", delay_ms = 100 }
tick_range = [10, 50]
probability = 0.1
"#,
        );
        assert_eq!(config.faults.len(), 1);
        let first = &config.faults[0];
        assert_eq!(first.tick_range, Some((10, 50)));
    }

    #[test]
    fn test_comparison_op() {
        // (operator, actual, expected, expected outcome)
        let cases = [
            (ComparisonOp::Gt, 0.9, 0.8, true),
            (ComparisonOp::Gt, 0.8, 0.8, false),
            (ComparisonOp::Gte, 0.8, 0.8, true),
            (ComparisonOp::Gte, 0.9, 0.8, true),
            (ComparisonOp::Lt, 0.7, 0.8, true),
            (ComparisonOp::Lt, 0.8, 0.8, false),
            (ComparisonOp::Lte, 0.8, 0.8, true),
            (ComparisonOp::Lte, 0.7, 0.8, true),
            (ComparisonOp::Eq, 0.8, 0.8, true),
            (ComparisonOp::Eq, 0.81, 0.8, false),
        ];
        for (op, actual, expected, want) in cases {
            assert_eq!(
                op.check(actual, expected),
                want,
                "{:?}.check({}, {})",
                op,
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_invalid_runs() {
        let outcome = EvalConfig::from_toml_str(
            r#"
[eval]
runs = 0
"#,
        );
        assert!(outcome.is_err());
    }
}