1use serde::{Deserialize, Serialize};
65use std::collections::HashMap;
66use std::time::Duration;
67
68pub use super::actions::ScenarioActions;
70pub use super::conditions::EvalConditions;
71pub use super::dependency::DependencyGraphConfig;
72pub use super::llm::{LlmConfig, LlmConfigOverride};
73pub use super::manager::{
74 BatchProcessorConfig, ManagerActivationConfig, ManagerConfig, ManagerTemplate,
75};
76pub use super::milestone::Milestone;
77
78#[derive(Debug, Clone, Serialize, Deserialize, Default)]
84pub struct TaskConfig {
85 pub goal: String,
87
88 #[serde(default)]
90 pub expected: Option<String>,
91
92 #[serde(default)]
94 pub context: TaskContext,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, Default)]
99pub struct TaskContext {
100 #[serde(default)]
102 pub target_path: Option<String>,
103
104 #[serde(default)]
106 pub working_dir: Option<String>,
107
108 #[serde(default)]
110 pub max_depth: Option<usize>,
111
112 #[serde(default, flatten)]
114 pub extra: HashMap<String, toml::Value>,
115}
116
117#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
123pub struct ScenarioId(pub String);
124
125impl ScenarioId {
126 pub fn new(id: impl Into<String>) -> Self {
127 Self(id.into())
128 }
129
130 pub fn as_str(&self) -> &str {
131 &self.0
132 }
133
134 pub fn learning_key(&self) -> String {
141 let parts: Vec<&str> = self.0.split(':').collect();
142 if parts.len() >= 2 {
143 parts[1].to_string()
145 } else {
146 self.0
148 .chars()
149 .map(|c| {
150 if c.is_alphanumeric() || c == '-' || c == '_' {
151 c
152 } else {
153 '_'
154 }
155 })
156 .collect()
157 }
158 }
159}
160
161impl std::fmt::Display for ScenarioId {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 write!(f, "{}", self.0)
164 }
165}
166
167impl From<&str> for ScenarioId {
168 fn from(s: &str) -> Self {
169 Self::new(s)
170 }
171}
172
173impl From<String> for ScenarioId {
174 fn from(s: String) -> Self {
175 Self::new(s)
176 }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct ScenarioVariant {
186 pub name: String,
188
189 #[serde(default)]
191 pub description: String,
192
193 #[serde(default)]
195 pub llm: Option<LlmConfigOverride>,
196
197 #[serde(default)]
199 pub environment_params: serde_json::Value,
200
201 #[serde(default)]
203 pub dependency_graph: Option<DependencyGraphConfig>,
204
205 #[serde(default)]
207 pub app_config: Option<AppConfigOverride>,
208
209 #[serde(default)]
211 pub max_ticks: Option<u64>,
212
213 #[serde(default)]
215 pub workers_count: Option<usize>,
216
217 #[serde(default)]
219 pub managers_count: Option<usize>,
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct EvalScenario {
229 pub meta: ScenarioMeta,
231
232 #[serde(default)]
234 pub task: TaskConfig,
235
236 #[serde(default)]
238 pub llm: LlmConfig,
239
240 #[serde(default)]
242 pub manager: ManagerConfig,
243
244 #[serde(default)]
246 pub batch_processor: BatchProcessorConfig,
247
248 #[serde(default)]
250 pub dependency_graph: Option<DependencyGraphConfig>,
251
252 #[serde(default)]
254 pub actions: ScenarioActions,
255
256 pub app_config: AppConfigTemplate,
258
259 pub environment: EnvironmentConfig,
261
262 pub agents: AgentsConfig,
264
265 pub conditions: EvalConditions,
267
268 #[serde(default)]
270 pub milestones: Vec<Milestone>,
271
272 #[serde(default)]
274 pub variants: Vec<ScenarioVariant>,
275}
276
277impl EvalScenario {
278 pub fn with_variant(&self, variant_name: &str) -> Option<EvalScenario> {
280 let variant = self.variants.iter().find(|v| v.name == variant_name)?;
281
282 let mut scenario = self.clone();
283
284 if let Some(ref llm_override) = variant.llm {
286 llm_override.apply_to(&mut scenario.llm);
287 }
288
289 if !variant.environment_params.is_null() {
291 if let serde_json::Value::Object(override_map) = &variant.environment_params {
292 if let serde_json::Value::Object(ref mut base_map) = scenario.environment.params {
293 for (key, value) in override_map {
294 base_map.insert(key.clone(), value.clone());
295 }
296 }
297 }
298 }
299
300 if variant.dependency_graph.is_some() {
302 scenario.dependency_graph = variant.dependency_graph.clone();
303 }
304
305 if let Some(ref app_override) = variant.app_config {
307 if let Some(ref strategy) = app_override.management_strategy {
308 scenario.app_config.management_strategy = strategy.clone();
309 }
310 if let Some(tick_ms) = app_override.tick_duration_ms {
311 scenario.app_config.tick_duration_ms = tick_ms;
312 }
313 if let Some(enable_exp) = app_override.enable_exploration {
314 scenario.app_config.enable_exploration = enable_exp;
315 }
316 }
317
318 if let Some(max_ticks) = variant.max_ticks {
320 scenario.app_config.max_ticks = max_ticks;
321 }
322
323 if let Some(workers_count) = variant.workers_count {
325 if let Some(first_worker) = scenario.agents.workers.first_mut() {
326 first_worker.count = workers_count;
327 }
328 }
329
330 if let Some(managers_count) = variant.managers_count {
332 if let Some(first_manager) = scenario.agents.managers.first_mut() {
333 first_manager.count = managers_count;
334 if first_manager.id_pattern.is_none() {
336 if let Some(ref id) = first_manager.id {
337 first_manager.id_pattern = Some(format!("{}_{{i}}", id));
338 first_manager.id = None;
339 }
340 }
341 }
342 }
343
344 scenario.meta.name = format!("{} ({})", self.meta.name, variant_name);
346
347 Some(scenario)
348 }
349
350 pub fn variant_names(&self) -> Vec<&str> {
352 self.variants.iter().map(|v| v.name.as_str()).collect()
353 }
354}
355
356#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct ScenarioMeta {
363 pub name: String,
365
366 #[serde(default = "default_version")]
368 pub version: String,
369
370 pub id: ScenarioId,
372
373 #[serde(default)]
375 pub description: String,
376
377 #[serde(default)]
379 pub tags: Vec<String>,
380}
381
/// Serde default for `ScenarioMeta::version`.
fn default_version() -> String {
    String::from("1.0.0")
}
385
386#[derive(Debug, Clone, Serialize, Deserialize)]
395pub struct AppConfigTemplate {
396 #[serde(default = "default_tick_duration_ms")]
398 pub tick_duration_ms: u64,
399
400 #[serde(default = "default_max_ticks")]
402 pub max_ticks: u64,
403
404 #[serde(default)]
406 pub management_strategy: ManagementStrategyConfig,
407
408 #[serde(default)]
410 pub enable_exploration: bool,
411}
412
/// Default for `AppConfigTemplate::tick_duration_ms`, in milliseconds.
fn default_tick_duration_ms() -> u64 {
    const DEFAULT_TICK_DURATION_MS: u64 = 10;
    DEFAULT_TICK_DURATION_MS
}
416
/// Default for `AppConfigTemplate::max_ticks`.
fn default_max_ticks() -> u64 {
    const DEFAULT_MAX_TICKS: u64 = 1000;
    DEFAULT_MAX_TICKS
}
420
421impl AppConfigTemplate {
422 pub fn tick_duration(&self) -> Duration {
423 Duration::from_millis(self.tick_duration_ms)
424 }
425}
426
427impl Default for AppConfigTemplate {
428 fn default() -> Self {
429 Self {
430 tick_duration_ms: default_tick_duration_ms(),
431 max_ticks: default_max_ticks(),
432 management_strategy: ManagementStrategyConfig::default(),
433 enable_exploration: false,
434 }
435 }
436}
437
438#[derive(Debug, Clone, Serialize, Deserialize, Default)]
442pub struct AppConfigOverride {
443 #[serde(default)]
445 pub management_strategy: Option<ManagementStrategyConfig>,
446
447 #[serde(default)]
449 pub tick_duration_ms: Option<u64>,
450
451 #[serde(default)]
453 pub enable_exploration: Option<bool>,
454}
455
456#[derive(Debug, Clone, Serialize, Deserialize)]
462#[serde(tag = "type", rename_all = "snake_case")]
463pub enum ManagementStrategyConfig {
464 EveryTick {},
468
469 IntervalBased {
471 #[serde(default = "default_max_interval")]
472 max_interval: u64,
473 },
474 EventDriven {
476 #[serde(default)]
477 triggers: Vec<String>,
478 },
479 Hybrid {
481 #[serde(default = "default_max_interval")]
482 max_interval: u64,
483 #[serde(default)]
484 triggers: Vec<String>,
485 },
486 #[serde(alias = "disabled")]
488 Disabled {},
489}
490
/// Default `max_interval` for the interval-based and hybrid strategies.
fn default_max_interval() -> u64 {
    const DEFAULT_MAX_INTERVAL: u64 = 20;
    DEFAULT_MAX_INTERVAL
}
494
495impl Default for ManagementStrategyConfig {
496 fn default() -> Self {
497 Self::IntervalBased {
498 max_interval: default_max_interval(),
499 }
500 }
501}
502
503#[derive(Debug, Clone, Serialize, Deserialize)]
509pub struct EnvironmentConfig {
510 pub env_type: String,
512
513 #[serde(default)]
515 pub params: serde_json::Value,
516
517 #[serde(default)]
519 pub initial_state: Option<InitialStateConfig>,
520}
521
522#[derive(Debug, Clone, Serialize, Deserialize)]
524#[serde(tag = "type", rename_all = "snake_case")]
525pub enum InitialStateConfig {
526 #[serde(alias = "seeded_random")]
528 SeededRandom {},
529 Fixed {
531 state: serde_json::Value,
533 },
534 Custom {
536 generator: String,
538 params: serde_json::Value,
540 },
541}
542
543#[derive(Debug, Clone, Default, Serialize, Deserialize)]
549pub struct AgentsConfig {
550 #[serde(default)]
552 pub workers: Vec<WorkerTemplate>,
553
554 #[serde(default)]
556 pub managers: Vec<ManagerTemplate>,
557}
558
559#[derive(Debug, Clone, Serialize, Deserialize)]
561pub struct WorkerTemplate {
562 pub id_pattern: String,
564
565 #[serde(default = "default_worker_count")]
567 pub count: usize,
568
569 #[serde(default)]
571 pub role: String,
572
573 #[serde(default)]
575 pub config: serde_json::Value,
576}
577
/// Default for `WorkerTemplate::count`: a single worker.
fn default_worker_count() -> usize {
    const DEFAULT_WORKER_COUNT: usize = 1;
    DEFAULT_WORKER_COUNT
}
581
582impl WorkerTemplate {
583 pub fn generate_ids(&self) -> Vec<String> {
585 (0..self.count)
586 .map(|i| self.id_pattern.replace("{i}", &i.to_string()))
587 .collect()
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scenario::llm::LlmProvider;

    /// `as_str` returns exactly the text passed to `new`.
    #[test]
    fn test_scenario_id() {
        let raw = "test:scenario:v1";
        let id = ScenarioId::new(raw);
        assert_eq!(id.as_str(), raw);
    }

    /// Covers namespaced ids (second segment wins) and plain ids (sanitized).
    #[test]
    fn test_scenario_id_learning_key() {
        let cases = [
            ("user:troubleshooting:v2", "troubleshooting"),
            ("builtin:resource_gathering:v1", "resource_gathering"),
            ("simple_scenario", "simple_scenario"),
            ("Service Troubleshooting", "Service_Troubleshooting"),
        ];

        for &(raw, expected) in cases.iter() {
            assert_eq!(ScenarioId::new(raw).learning_key(), expected);
        }
    }

    /// `{i}` is replaced by the zero-based worker index.
    #[test]
    fn test_worker_template_generate_ids() {
        let template = WorkerTemplate {
            id_pattern: String::from("worker_{i}"),
            count: 3,
            role: String::from("gatherer"),
            config: serde_json::Value::Null,
        };

        assert_eq!(
            template.generate_ids(),
            vec!["worker_0", "worker_1", "worker_2"]
        );
    }

    /// The hand-written `Default` uses the documented numeric defaults.
    #[test]
    fn test_app_config_template_default() {
        let AppConfigTemplate {
            tick_duration_ms,
            max_ticks,
            ..
        } = AppConfigTemplate::default();

        assert_eq!(tick_duration_ms, 10);
        assert_eq!(max_ticks, 1000);
    }

    /// The `type` tag selects the variant; sibling keys fill its fields.
    #[test]
    fn test_management_strategy_deserialize() {
        let json = r#"{"type": "hybrid", "max_interval": 30, "triggers": ["event_a"]}"#;
        let parsed: ManagementStrategyConfig = serde_json::from_str(json).unwrap();

        if let ManagementStrategyConfig::Hybrid {
            max_interval,
            triggers,
        } = parsed
        {
            assert_eq!(max_interval, 30);
            assert_eq!(triggers, vec!["event_a"]);
        } else {
            panic!("Expected Hybrid variant");
        }
    }

    /// A default task has an empty goal and no expected value.
    #[test]
    fn test_task_config_default() {
        let TaskConfig { goal, expected, .. } = TaskConfig::default();
        assert!(goal.is_empty());
        assert!(expected.is_none());
    }

    /// A full task section, including the nested context table, round-trips
    /// from TOML.
    #[test]
    fn test_task_config_deserialize_toml() {
        let toml_str = r#"
            goal = "Find the function that handles authentication"
            expected = "src/auth/handler.rs:42"
            [context]
            target_path = "/path/to/codebase"
            working_dir = "/path/to/codebase"
            max_depth = 5
        "#;

        let parsed: TaskConfig = toml::from_str(toml_str).unwrap();

        assert_eq!(parsed.goal, "Find the function that handles authentication");
        assert_eq!(parsed.expected, Some("src/auth/handler.rs:42".to_string()));
        assert_eq!(
            parsed.context.target_path,
            Some("/path/to/codebase".to_string())
        );
        assert_eq!(parsed.context.max_depth, Some(5));
    }

    /// A variant that overrides the whole LLM section switches provider and
    /// model, and the variant name is appended to the scenario name.
    #[test]
    fn test_scenario_variant_with_llm_override() {
        let toml_str = r#"
            [meta]
            name = "Test Scenario"
            id = "test:scenario:v1"

            [task]
            goal = "Test goal"

            [llm]
            provider = "ollama"
            model = "llama3:8b"

            [app_config]
            max_ticks = 100

            [environment]
            env_type = "test"

            [agents]

            [conditions]
            on_timeout = "fail"

            [[variants]]
            name = "mistral"
            description = "Use mistral.rs local inference"
            [variants.llm]
            provider = "mistral"
            model = "LiquidAI/LFM2.5-1.2B-Instruct-GGUF"
            gguf_files = ["LFM2.5-1.2B-Instruct-Q4_K_M.gguf"]
        "#;

        let base: EvalScenario = toml::from_str(toml_str).unwrap();
        assert_eq!(base.llm.provider, LlmProvider::Ollama);
        assert_eq!(base.variants.len(), 1);

        let derived = base.with_variant("mistral").unwrap();
        assert_eq!(derived.llm.provider, LlmProvider::Mistral);
        assert_eq!(derived.llm.model, "LiquidAI/LFM2.5-1.2B-Instruct-GGUF");
        assert!(derived.llm.is_gguf());
        assert_eq!(derived.meta.name, "Test Scenario (mistral)");
    }

    /// A variant that overrides a single LLM key leaves the other keys at
    /// their base values.
    #[test]
    fn test_scenario_variant_partial_llm_override() {
        let toml_str = r#"
            [meta]
            name = "Test"
            id = "test:v1"

            [task]
            goal = "Test"

            [llm]
            provider = "ollama"
            model = "llama3:8b"
            temperature = 0.1
            num_ctx = 4096

            [app_config]
            max_ticks = 100

            [environment]
            env_type = "test"

            [agents]

            [conditions]
            on_timeout = "fail"

            [[variants]]
            name = "high_temp"
            [variants.llm]
            temperature = 0.9
        "#;

        let base: EvalScenario = toml::from_str(toml_str).unwrap();
        let derived = base.with_variant("high_temp").unwrap();

        // Overridden key.
        assert!((derived.llm.temperature - 0.9).abs() < f32::EPSILON);
        // Keys inherited from the base scenario.
        assert_eq!(derived.llm.provider, LlmProvider::Ollama);
        assert_eq!(derived.llm.model, "llama3:8b");
        assert_eq!(derived.llm.num_ctx, Some(4096));
    }
}