1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8
9pub use super::actions::ScenarioActions;
11pub use super::conditions::EvalConditions;
12pub use super::dependency::DependencyGraphConfig;
13pub use super::llm::{LlmConfig, LlmConfigOverride};
14pub use super::manager::{
15 BatchProcessorConfig, ManagerActivationConfig, ManagerConfig, ManagerTemplate,
16};
17pub use super::milestone::Milestone;
18
/// Task definition for a scenario: a goal, an optional expected result, and context.
///
/// Derives `Default`, so a default task has an empty goal, no expected value, and a
/// default [`TaskContext`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TaskConfig {
    /// Goal text. It carries no serde default, so it must be present whenever a task
    /// table is deserialized.
    pub goal: String,

    /// Optional expected result; `None` when omitted.
    #[serde(default)]
    pub expected: Option<String>,

    /// Extra context for the task; falls back to `TaskContext::default()` when omitted.
    #[serde(default)]
    pub context: TaskContext,
}
37
/// Optional context values attached to a [`TaskConfig`].
///
/// Every named field is optional; any other key found in the same table is collected
/// into `extra` through `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TaskContext {
    /// Optional target path, stored as a plain string (not validated here).
    #[serde(default)]
    pub target_path: Option<String>,

    /// Optional working directory, stored as a plain string (not validated here).
    #[serde(default)]
    pub working_dir: Option<String>,

    /// Optional depth limit. NOTE(review): what the depth applies to is decided by the
    /// consumer of this config and is not visible in this file.
    #[serde(default)]
    pub max_depth: Option<usize>,

    /// Catch-all for keys that do not match a named field above.
    #[serde(default, flatten)]
    pub extra: HashMap<String, toml::Value>,
}
57
/// Newtype wrapper around a scenario identifier string.
///
/// The tests in this file use colon-separated IDs such as `user:troubleshooting:v2`,
/// but plain strings without colons are accepted as well.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ScenarioId(pub String);
65
66impl ScenarioId {
67 pub fn new(id: impl Into<String>) -> Self {
68 Self(id.into())
69 }
70
71 pub fn as_str(&self) -> &str {
72 &self.0
73 }
74
75 pub fn learning_key(&self) -> String {
82 let parts: Vec<&str> = self.0.split(':').collect();
83 if parts.len() >= 2 {
84 parts[1].to_string()
86 } else {
87 self.0
89 .chars()
90 .map(|c| {
91 if c.is_alphanumeric() || c == '-' || c == '_' {
92 c
93 } else {
94 '_'
95 }
96 })
97 .collect()
98 }
99 }
100}
101
102impl std::fmt::Display for ScenarioId {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 write!(f, "{}", self.0)
105 }
106}
107
108impl From<&str> for ScenarioId {
109 fn from(s: &str) -> Self {
110 Self::new(s)
111 }
112}
113
114impl From<String> for ScenarioId {
115 fn from(s: String) -> Self {
116 Self::new(s)
117 }
118}
119
/// A named set of overrides for an [`EvalScenario`].
///
/// Only `name` is required. Every other field has a serde default and is consulted by
/// `EvalScenario::with_variant` when it builds the variant's scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioVariant {
    /// Variant name; `EvalScenario::with_variant` looks variants up by this value.
    pub name: String,

    /// Free-form description; empty when omitted.
    #[serde(default)]
    pub description: String,

    /// Optional LLM override, applied to the scenario's `llm` via `apply_to`.
    #[serde(default)]
    pub llm: Option<LlmConfigOverride>,

    /// Environment parameter overrides; `Value::Null` when omitted.
    /// `EvalScenario::with_variant` reads this when it is a JSON object in order to
    /// override keys of the base scenario's `environment.params`.
    #[serde(default)]
    pub environment_params: serde_json::Value,

    /// Optional dependency graph that replaces the scenario's graph when present.
    #[serde(default)]
    pub dependency_graph: Option<DependencyGraphConfig>,

    /// Optional overrides for individual `app_config` fields.
    #[serde(default)]
    pub app_config: Option<AppConfigOverride>,

    /// Optional replacement for `app_config.max_ticks`.
    #[serde(default)]
    pub max_ticks: Option<u64>,

    /// Optional replacement for the `count` of the first worker template.
    #[serde(default)]
    pub workers_count: Option<usize>,

    /// Optional replacement for the `count` of the first manager template.
    #[serde(default)]
    pub managers_count: Option<usize>,
}
162
/// Complete definition of one evaluation scenario.
///
/// `meta`, `app_config`, `environment`, `agents`, and `conditions` carry no serde
/// default and therefore must be present in the serialized form; all other sections
/// fall back to their defaults when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalScenario {
    /// Name, version, ID, description, and tags of the scenario (required).
    pub meta: ScenarioMeta,

    /// Task definition; defaults to `TaskConfig::default()`.
    #[serde(default)]
    pub task: TaskConfig,

    /// LLM configuration; defaults to `LlmConfig::default()`.
    #[serde(default)]
    pub llm: LlmConfig,

    /// Manager configuration; defaults to `ManagerConfig::default()`.
    #[serde(default)]
    pub manager: ManagerConfig,

    /// Batch processor configuration; defaults to `BatchProcessorConfig::default()`.
    #[serde(default)]
    pub batch_processor: BatchProcessorConfig,

    /// Optional dependency graph configuration; `None` when omitted.
    #[serde(default)]
    pub dependency_graph: Option<DependencyGraphConfig>,

    /// Scenario actions; defaults to `ScenarioActions::default()`.
    #[serde(default)]
    pub actions: ScenarioActions,

    /// Application settings such as tick duration and tick limit (required section).
    pub app_config: AppConfigTemplate,

    /// Environment type, parameters, and optional initial state (required section).
    pub environment: EnvironmentConfig,

    /// Worker and manager templates (required section; its lists may be empty).
    pub agents: AgentsConfig,

    /// Evaluation conditions (required section).
    pub conditions: EvalConditions,

    /// Milestones; empty when omitted.
    #[serde(default)]
    pub milestones: Vec<Milestone>,

    /// Named variants that can be applied with `with_variant`; empty when omitted.
    #[serde(default)]
    pub variants: Vec<ScenarioVariant>,
}
217
218impl EvalScenario {
219 pub fn with_variant(&self, variant_name: &str) -> Option<EvalScenario> {
221 let variant = self.variants.iter().find(|v| v.name == variant_name)?;
222
223 let mut scenario = self.clone();
224
225 if let Some(ref llm_override) = variant.llm {
227 llm_override.apply_to(&mut scenario.llm);
228 }
229
230 if !variant.environment_params.is_null() {
232 if let serde_json::Value::Object(override_map) = &variant.environment_params {
233 if let serde_json::Value::Object(ref mut base_map) = scenario.environment.params {
234 for (key, value) in override_map {
235 base_map.insert(key.clone(), value.clone());
236 }
237 }
238 }
239 }
240
241 if variant.dependency_graph.is_some() {
243 scenario.dependency_graph = variant.dependency_graph.clone();
244 }
245
246 if let Some(ref app_override) = variant.app_config {
248 if let Some(ref strategy) = app_override.management_strategy {
249 scenario.app_config.management_strategy = strategy.clone();
250 }
251 if let Some(tick_ms) = app_override.tick_duration_ms {
252 scenario.app_config.tick_duration_ms = tick_ms;
253 }
254 if let Some(enable_exp) = app_override.enable_exploration {
255 scenario.app_config.enable_exploration = enable_exp;
256 }
257 }
258
259 if let Some(max_ticks) = variant.max_ticks {
261 scenario.app_config.max_ticks = max_ticks;
262 }
263
264 if let Some(workers_count) = variant.workers_count {
266 if let Some(first_worker) = scenario.agents.workers.first_mut() {
267 first_worker.count = workers_count;
268 }
269 }
270
271 if let Some(managers_count) = variant.managers_count {
273 if let Some(first_manager) = scenario.agents.managers.first_mut() {
274 first_manager.count = managers_count;
275 if first_manager.id_pattern.is_none() {
277 if let Some(ref id) = first_manager.id {
278 first_manager.id_pattern = Some(format!("{}_{{i}}", id));
279 first_manager.id = None;
280 }
281 }
282 }
283 }
284
285 scenario.meta.name = format!("{} ({})", self.meta.name, variant_name);
287
288 Some(scenario)
289 }
290
291 pub fn variant_names(&self) -> Vec<&str> {
293 self.variants.iter().map(|v| v.name.as_str()).collect()
294 }
295}
296
/// Descriptive metadata for a scenario.
///
/// `name` and `id` are required; the remaining fields have serde defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioMeta {
    /// Human-readable scenario name.
    pub name: String,

    /// Version string; `"1.0.0"` (from `default_version`) when omitted.
    #[serde(default = "default_version")]
    pub version: String,

    /// Scenario identifier.
    pub id: ScenarioId,

    /// Free-form description; empty when omitted.
    #[serde(default)]
    pub description: String,

    /// Tags; empty when omitted.
    #[serde(default)]
    pub tags: Vec<String>,
}
322
/// Serde default for `ScenarioMeta::version`.
fn default_version() -> String {
    String::from("1.0.0")
}
326
/// Application-level settings of a scenario.
///
/// Every field has a serde default, so an empty table deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfigTemplate {
    /// Tick duration in milliseconds; `default_tick_duration_ms()` (10) when omitted.
    #[serde(default = "default_tick_duration_ms")]
    pub tick_duration_ms: u64,

    /// Tick limit; `default_max_ticks()` (1000) when omitted.
    #[serde(default = "default_max_ticks")]
    pub max_ticks: u64,

    /// Management strategy; `ManagementStrategyConfig::default()` when omitted.
    #[serde(default)]
    pub management_strategy: ManagementStrategyConfig,

    /// Exploration switch; `false` when omitted.
    #[serde(default)]
    pub enable_exploration: bool,
}
353
/// Serde default for `AppConfigTemplate::tick_duration_ms`, in milliseconds.
fn default_tick_duration_ms() -> u64 {
    const DEFAULT_TICK_DURATION_MS: u64 = 10;
    DEFAULT_TICK_DURATION_MS
}
357
/// Serde default for `AppConfigTemplate::max_ticks`.
fn default_max_ticks() -> u64 {
    const DEFAULT_MAX_TICKS: u64 = 1000;
    DEFAULT_MAX_TICKS
}
361
362impl AppConfigTemplate {
363 pub fn tick_duration(&self) -> Duration {
364 Duration::from_millis(self.tick_duration_ms)
365 }
366}
367
368impl Default for AppConfigTemplate {
369 fn default() -> Self {
370 Self {
371 tick_duration_ms: default_tick_duration_ms(),
372 max_ticks: default_max_ticks(),
373 management_strategy: ManagementStrategyConfig::default(),
374 enable_exploration: false,
375 }
376 }
377}
378
/// Per-variant overrides for selected [`AppConfigTemplate`] fields.
///
/// Each field is optional; `EvalScenario::with_variant` copies only the fields that
/// are `Some` onto the scenario's `app_config`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AppConfigOverride {
    /// Replacement management strategy, if any.
    #[serde(default)]
    pub management_strategy: Option<ManagementStrategyConfig>,

    /// Replacement tick duration in milliseconds, if any.
    #[serde(default)]
    pub tick_duration_ms: Option<u64>,

    /// Replacement exploration switch, if any.
    #[serde(default)]
    pub enable_exploration: Option<bool>,
}
396
/// Management strategy selection.
///
/// Serialized as an internally tagged enum: the `type` key holds the variant name in
/// snake_case (for example `{"type": "hybrid", ...}`), and the variant's fields sit
/// next to it. The `Default` impl selects `IntervalBased`.
///
/// NOTE(review): the runtime meaning of each variant is defined by the code that
/// consumes this config, which is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ManagementStrategyConfig {
    /// Tagged `every_tick`; carries no fields.
    EveryTick {},

    /// Tagged `interval_based`.
    IntervalBased {
        /// Interval value; `default_max_interval()` (20) when omitted.
        #[serde(default = "default_max_interval")]
        max_interval: u64,
    },
    /// Tagged `event_driven`.
    EventDriven {
        /// Trigger names; empty when omitted.
        #[serde(default)]
        triggers: Vec<String>,
    },
    /// Tagged `hybrid`; carries both an interval and a trigger list.
    Hybrid {
        /// Interval value; `default_max_interval()` (20) when omitted.
        #[serde(default = "default_max_interval")]
        max_interval: u64,
        /// Trigger names; empty when omitted.
        #[serde(default)]
        triggers: Vec<String>,
    },
    /// Tagged `disabled`; carries no fields. The alias repeats the snake_case name
    /// that `rename_all` already produces.
    #[serde(alias = "disabled")]
    Disabled {},
}
431
/// Serde default for the `max_interval` field of `ManagementStrategyConfig` variants.
fn default_max_interval() -> u64 {
    const DEFAULT_MAX_INTERVAL: u64 = 20;
    DEFAULT_MAX_INTERVAL
}
435
436impl Default for ManagementStrategyConfig {
437 fn default() -> Self {
438 Self::IntervalBased {
439 max_interval: default_max_interval(),
440 }
441 }
442}
443
/// Environment section of a scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvironmentConfig {
    /// Environment type name (required). It is stored as a free-form string; how it is
    /// resolved is not visible in this file.
    pub env_type: String,

    /// Environment parameters as arbitrary JSON; `Value::Null` when omitted.
    #[serde(default)]
    pub params: serde_json::Value,

    /// Optional initial state configuration; `None` when omitted.
    #[serde(default)]
    pub initial_state: Option<InitialStateConfig>,
}
462
/// How the environment's initial state is specified.
///
/// Serialized as an internally tagged enum: the `type` key holds the variant name in
/// snake_case and the variant's fields sit next to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InitialStateConfig {
    /// Tagged `seeded_random`; carries no fields. The alias repeats the snake_case name
    /// that `rename_all` already produces.
    #[serde(alias = "seeded_random")]
    SeededRandom {},
    /// Tagged `fixed`.
    Fixed {
        /// The state as arbitrary JSON (required).
        state: serde_json::Value,
    },
    /// Tagged `custom`.
    Custom {
        /// Generator name (required). NOTE(review): how the name is resolved is not
        /// visible in this file.
        generator: String,
        /// Generator parameters as arbitrary JSON. It has no `#[serde(default)]`
        /// attribute, unlike most optional fields in this file.
        params: serde_json::Value,
    },
}
483
/// Agent templates of a scenario, split into workers and managers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentsConfig {
    /// Worker templates; empty when omitted.
    #[serde(default)]
    pub workers: Vec<WorkerTemplate>,

    /// Manager templates; empty when omitted.
    #[serde(default)]
    pub managers: Vec<ManagerTemplate>,
}
499
/// Template from which one or more worker IDs are generated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerTemplate {
    /// ID pattern (required). `generate_ids` replaces each `{i}` in it with the
    /// worker's zero-based index.
    pub id_pattern: String,

    /// Number of workers to generate; `default_worker_count()` (1) when omitted.
    #[serde(default = "default_worker_count")]
    pub count: usize,

    /// Role name; empty when omitted.
    #[serde(default)]
    pub role: String,

    /// Worker-specific configuration as arbitrary JSON; `Value::Null` when omitted.
    #[serde(default)]
    pub config: serde_json::Value,
}
518
/// Serde default for `WorkerTemplate::count`.
fn default_worker_count() -> usize {
    const DEFAULT_WORKER_COUNT: usize = 1;
    DEFAULT_WORKER_COUNT
}
522
523impl WorkerTemplate {
524 pub fn generate_ids(&self) -> Vec<String> {
526 (0..self.count)
527 .map(|i| self.id_pattern.replace("{i}", &i.to_string()))
528 .collect()
529 }
530}
531
// Unit tests for the scenario types: ID handling, worker ID generation, defaults,
// and (de)serialization of tasks, strategies, and variants.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scenario::llm::LlmProvider;

    /// `as_str` returns the exact string the ID was built from.
    #[test]
    fn test_scenario_id() {
        let id = ScenarioId::new("test:scenario:v1");
        assert_eq!(id.as_str(), "test:scenario:v1");
    }

    /// `learning_key` picks the second colon-separated segment, or sanitizes the
    /// whole ID when there are no colons.
    #[test]
    fn test_scenario_id_learning_key() {
        // Colon-separated ID: the middle segment is the key.
        let id = ScenarioId::new("user:troubleshooting:v2");
        assert_eq!(id.learning_key(), "troubleshooting");

        // Underscores inside the segment are kept.
        let id = ScenarioId::new("builtin:resource_gathering:v1");
        assert_eq!(id.learning_key(), "resource_gathering");

        // No colon: the whole ID is used unchanged when it is already clean.
        let id = ScenarioId::new("simple_scenario");
        assert_eq!(id.learning_key(), "simple_scenario");

        // No colon: the space is replaced with an underscore.
        let id = ScenarioId::new("Service Troubleshooting");
        assert_eq!(id.learning_key(), "Service_Troubleshooting");
    }

    /// `{i}` is replaced with indices `0..count`.
    #[test]
    fn test_worker_template_generate_ids() {
        let template = WorkerTemplate {
            id_pattern: "worker_{i}".to_string(),
            count: 3,
            role: "gatherer".to_string(),
            config: serde_json::Value::Null,
        };

        let ids = template.generate_ids();
        assert_eq!(ids, vec!["worker_0", "worker_1", "worker_2"]);
    }

    /// `AppConfigTemplate::default()` uses the documented default tick values.
    #[test]
    fn test_app_config_template_default() {
        let config = AppConfigTemplate::default();
        assert_eq!(config.tick_duration_ms, 10);
        assert_eq!(config.max_ticks, 1000);
    }

    /// The `type` tag selects the variant and the sibling keys fill its fields.
    #[test]
    fn test_management_strategy_deserialize() {
        let json = r#"{"type": "hybrid", "max_interval": 30, "triggers": ["event_a"]}"#;
        let strategy: ManagementStrategyConfig = serde_json::from_str(json).unwrap();

        match strategy {
            ManagementStrategyConfig::Hybrid {
                max_interval,
                triggers,
            } => {
                assert_eq!(max_interval, 30);
                assert_eq!(triggers, vec!["event_a"]);
            }
            _ => panic!("Expected Hybrid variant"),
        }
    }

    /// A default task has an empty goal and no expected value.
    #[test]
    fn test_task_config_default() {
        let task = TaskConfig::default();
        assert!(task.goal.is_empty());
        assert!(task.expected.is_none());
    }

    /// A task with a `[context]` table deserializes from TOML.
    #[test]
    fn test_task_config_deserialize_toml() {
        let toml_str = r#"
            goal = "Find the function that handles authentication"
            expected = "src/auth/handler.rs:42"
            [context]
            target_path = "/path/to/codebase"
            working_dir = "/path/to/codebase"
            max_depth = 5
        "#;

        let task: TaskConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(task.goal, "Find the function that handles authentication");
        assert_eq!(task.expected, Some("src/auth/handler.rs:42".to_string()));
        assert_eq!(
            task.context.target_path,
            Some("/path/to/codebase".to_string())
        );
        assert_eq!(task.context.max_depth, Some(5));
    }

    /// A variant that overrides provider, model, and GGUF files replaces those values
    /// and renames the resulting scenario.
    #[test]
    fn test_scenario_variant_with_llm_override() {
        let toml_str = r#"
            [meta]
            name = "Test Scenario"
            id = "test:scenario:v1"

            [task]
            goal = "Test goal"

            [llm]
            provider = "ollama"
            model = "llama3:8b"

            [app_config]
            max_ticks = 100

            [environment]
            env_type = "test"

            [agents]

            [conditions]
            on_timeout = "fail"

            [[variants]]
            name = "mistral"
            description = "Use mistral.rs local inference"
            [variants.llm]
            provider = "mistral"
            model = "LiquidAI/LFM2.5-1.2B-Instruct-GGUF"
            gguf_files = ["LFM2.5-1.2B-Instruct-Q4_K_M.gguf"]
        "#;

        // The base scenario keeps its own LLM settings.
        let scenario: EvalScenario = toml::from_str(toml_str).unwrap();
        assert_eq!(scenario.llm.provider, LlmProvider::Ollama);
        assert_eq!(scenario.variants.len(), 1);

        // The variant copy carries the overridden settings and the suffixed name.
        let mistral_scenario = scenario.with_variant("mistral").unwrap();
        assert_eq!(mistral_scenario.llm.provider, LlmProvider::Mistral);
        assert_eq!(
            mistral_scenario.llm.model,
            "LiquidAI/LFM2.5-1.2B-Instruct-GGUF"
        );
        assert!(mistral_scenario.llm.is_gguf());
        assert_eq!(mistral_scenario.meta.name, "Test Scenario (mistral)");
    }

    /// A variant that overrides only `temperature` leaves the other LLM fields as
    /// they were in the base scenario.
    #[test]
    fn test_scenario_variant_partial_llm_override() {
        let toml_str = r#"
            [meta]
            name = "Test"
            id = "test:v1"

            [task]
            goal = "Test"

            [llm]
            provider = "ollama"
            model = "llama3:8b"
            temperature = 0.1
            num_ctx = 4096

            [app_config]
            max_ticks = 100

            [environment]
            env_type = "test"

            [agents]

            [conditions]
            on_timeout = "fail"

            [[variants]]
            name = "high_temp"
            [variants.llm]
            temperature = 0.9
        "#;

        let scenario: EvalScenario = toml::from_str(toml_str).unwrap();
        let variant = scenario.with_variant("high_temp").unwrap();

        // Overridden field.
        assert!((variant.llm.temperature - 0.9).abs() < f32::EPSILON);
        // Fields the variant did not mention are inherited from the base scenario.
        assert_eq!(variant.llm.provider, LlmProvider::Ollama);
        assert_eq!(variant.llm.model, "llama3:8b");
        assert_eq!(variant.llm.num_ctx, Some(4096));
    }
}