Skip to main content

swarm_engine_eval/scenario/
types.rs

1//! Eval シナリオの型定義
2//!
3//! 評価専用のシナリオ定義。SwarmApp との直接連携を前提に設計。
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Duration;
8
9// Re-exports for backward compatibility
10pub use super::actions::ScenarioActions;
11pub use super::conditions::EvalConditions;
12pub use super::dependency::DependencyGraphConfig;
13pub use super::llm::{LlmConfig, LlmConfigOverride};
14pub use super::manager::{
15    BatchProcessorConfig, ManagerActivationConfig, ManagerConfig, ManagerTemplate,
16};
17pub use super::milestone::Milestone;
18
19// ============================================================================
20// Task Configuration
21// ============================================================================
22
23/// Task definition for swarm evaluation
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
25pub struct TaskConfig {
26    /// The goal/objective for the swarm to achieve
27    pub goal: String,
28
29    /// Expected result for evaluation (e.g., "src/auth/handler.rs:42")
30    #[serde(default)]
31    pub expected: Option<String>,
32
33    /// Additional context for the task
34    #[serde(default)]
35    pub context: TaskContext,
36}
37
38/// Additional context passed to the task
39#[derive(Debug, Clone, Serialize, Deserialize, Default)]
40pub struct TaskContext {
41    /// Target path for exploration (e.g., codebase root)
42    #[serde(default)]
43    pub target_path: Option<String>,
44
45    /// Working directory for workers
46    #[serde(default)]
47    pub working_dir: Option<String>,
48
49    /// Maximum exploration depth
50    #[serde(default)]
51    pub max_depth: Option<usize>,
52
53    /// Additional key-value context
54    #[serde(default, flatten)]
55    pub extra: HashMap<String, toml::Value>,
56}
57
58// ============================================================================
59// Scenario Identification
60// ============================================================================
61
62/// シナリオ識別子
63#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
64pub struct ScenarioId(pub String);
65
66impl ScenarioId {
67    pub fn new(id: impl Into<String>) -> Self {
68        Self(id.into())
69    }
70
71    pub fn as_str(&self) -> &str {
72        &self.0
73    }
74
75    /// IDから学習用キーを抽出
76    ///
77    /// ID形式: `namespace:name:version` (例: `user:troubleshooting:v2`)
78    /// → 中央の `name` 部分を返す
79    ///
80    /// フォールバック: コロンがない場合はIDをそのままファイルシステム安全な形式に変換
81    pub fn learning_key(&self) -> String {
82        let parts: Vec<&str> = self.0.split(':').collect();
83        if parts.len() >= 2 {
84            // namespace:name:version → name を返す
85            parts[1].to_string()
86        } else {
87            // フォールバック: スペース・特殊文字をアンダースコアに
88            self.0
89                .chars()
90                .map(|c| {
91                    if c.is_alphanumeric() || c == '-' || c == '_' {
92                        c
93                    } else {
94                        '_'
95                    }
96                })
97                .collect()
98        }
99    }
100}
101
102impl std::fmt::Display for ScenarioId {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        write!(f, "{}", self.0)
105    }
106}
107
108impl From<&str> for ScenarioId {
109    fn from(s: &str) -> Self {
110        Self::new(s)
111    }
112}
113
114impl From<String> for ScenarioId {
115    fn from(s: String) -> Self {
116        Self::new(s)
117    }
118}
119
120// ============================================================================
121// Scenario Variant
122// ============================================================================
123
124/// シナリオバリアント(パラメータのオーバーライド)
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct ScenarioVariant {
127    /// バリアント名(CLI で指定する識別子)
128    pub name: String,
129
130    /// 説明
131    #[serde(default)]
132    pub description: String,
133
134    /// LLM 設定のオーバーライド(provider 切り替え等)
135    #[serde(default)]
136    pub llm: Option<LlmConfigOverride>,
137
138    /// 環境パラメータのオーバーライド
139    #[serde(default)]
140    pub environment_params: serde_json::Value,
141
142    /// 依存グラフ設定のオーバーライド(オプション)
143    #[serde(default)]
144    pub dependency_graph: Option<DependencyGraphConfig>,
145
146    /// AppConfig のオーバーライド(オプション)
147    #[serde(default)]
148    pub app_config: Option<AppConfigOverride>,
149
150    /// max_ticks のオーバーライド(オプション)
151    #[serde(default)]
152    pub max_ticks: Option<u64>,
153
154    /// Worker 数のオーバーライド(最初の worker template の count を上書き)
155    #[serde(default)]
156    pub workers_count: Option<usize>,
157
158    /// Manager 数のオーバーライド(最初の manager template の count を上書き)
159    #[serde(default)]
160    pub managers_count: Option<usize>,
161}
162
163// ============================================================================
164// Eval Scenario
165// ============================================================================
166
167/// 評価シナリオの完全な定義
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct EvalScenario {
170    /// メタ情報
171    pub meta: ScenarioMeta,
172
173    /// タスク定義 (goal, expected, context)
174    #[serde(default)]
175    pub task: TaskConfig,
176
177    /// LLM 設定 (provider, model, endpoint, etc.)
178    #[serde(default)]
179    pub llm: LlmConfig,
180
181    /// Manager 動作設定 (interval, confidence_threshold, etc.)
182    #[serde(default)]
183    pub manager: ManagerConfig,
184
185    /// BatchProcessor 設定 (parallel, max_concurrency)
186    #[serde(default)]
187    pub batch_processor: BatchProcessorConfig,
188
189    /// 依存グラフ設定 (アクション間の依存関係)
190    #[serde(default)]
191    pub dependency_graph: Option<DependencyGraphConfig>,
192
193    /// 利用可能なアクション設定
194    #[serde(default)]
195    pub actions: ScenarioActions,
196
197    /// SwarmApp 構築設定
198    pub app_config: AppConfigTemplate,
199
200    /// 環境設定 (タスク定義)
201    pub environment: EnvironmentConfig,
202
203    /// エージェント設定
204    pub agents: AgentsConfig,
205
206    /// 成功/失敗条件
207    pub conditions: EvalConditions,
208
209    /// マイルストーン (kpi_score 計算用)
210    #[serde(default)]
211    pub milestones: Vec<Milestone>,
212
213    /// バリアント定義(パラメータの組み合わせ)
214    #[serde(default)]
215    pub variants: Vec<ScenarioVariant>,
216}
217
218impl EvalScenario {
219    /// バリアントを適用した新しいシナリオを返す
220    pub fn with_variant(&self, variant_name: &str) -> Option<EvalScenario> {
221        let variant = self.variants.iter().find(|v| v.name == variant_name)?;
222
223        let mut scenario = self.clone();
224
225        // LLM 設定をマージ
226        if let Some(ref llm_override) = variant.llm {
227            llm_override.apply_to(&mut scenario.llm);
228        }
229
230        // 環境パラメータをマージ
231        if !variant.environment_params.is_null() {
232            if let serde_json::Value::Object(override_map) = &variant.environment_params {
233                if let serde_json::Value::Object(ref mut base_map) = scenario.environment.params {
234                    for (key, value) in override_map {
235                        base_map.insert(key.clone(), value.clone());
236                    }
237                }
238            }
239        }
240
241        // 依存グラフをオーバーライド
242        if variant.dependency_graph.is_some() {
243            scenario.dependency_graph = variant.dependency_graph.clone();
244        }
245
246        // app_config をオーバーライド
247        if let Some(ref app_override) = variant.app_config {
248            if let Some(ref strategy) = app_override.management_strategy {
249                scenario.app_config.management_strategy = strategy.clone();
250            }
251            if let Some(tick_ms) = app_override.tick_duration_ms {
252                scenario.app_config.tick_duration_ms = tick_ms;
253            }
254            if let Some(enable_exp) = app_override.enable_exploration {
255                scenario.app_config.enable_exploration = enable_exp;
256            }
257        }
258
259        // max_ticks をオーバーライド
260        if let Some(max_ticks) = variant.max_ticks {
261            scenario.app_config.max_ticks = max_ticks;
262        }
263
264        // workers_count をオーバーライド(最初の worker template)
265        if let Some(workers_count) = variant.workers_count {
266            if let Some(first_worker) = scenario.agents.workers.first_mut() {
267                first_worker.count = workers_count;
268            }
269        }
270
271        // managers_count をオーバーライド(最初の manager template)
272        if let Some(managers_count) = variant.managers_count {
273            if let Some(first_manager) = scenario.agents.managers.first_mut() {
274                first_manager.count = managers_count;
275                // id_pattern が未設定の場合、id から自動生成
276                if first_manager.id_pattern.is_none() {
277                    if let Some(ref id) = first_manager.id {
278                        first_manager.id_pattern = Some(format!("{}_{{i}}", id));
279                        first_manager.id = None;
280                    }
281                }
282            }
283        }
284
285        // メタ情報を更新(バリアント名を追加)
286        scenario.meta.name = format!("{} ({})", self.meta.name, variant_name);
287
288        Some(scenario)
289    }
290
291    /// 利用可能なバリアント名のリストを返す
292    pub fn variant_names(&self) -> Vec<&str> {
293        self.variants.iter().map(|v| v.name.as_str()).collect()
294    }
295}
296
297// ============================================================================
298// Scenario Meta
299// ============================================================================
300
301/// シナリオのメタ情報
302#[derive(Debug, Clone, Serialize, Deserialize)]
303pub struct ScenarioMeta {
304    /// シナリオ名
305    pub name: String,
306
307    /// バージョン (semver形式推奨)
308    #[serde(default = "default_version")]
309    pub version: String,
310
311    /// 一意識別子
312    pub id: ScenarioId,
313
314    /// 説明文
315    #[serde(default)]
316    pub description: String,
317
318    /// タグ (検索・フィルタ用)
319    #[serde(default)]
320    pub tags: Vec<String>,
321}
322
323fn default_version() -> String {
324    "1.0.0".to_string()
325}
326
327// ============================================================================
328// App Config Template
329// ============================================================================
330
331/// SwarmApp 構築用のテンプレート
332///
333/// 実際の AppConfig を生成するためのテンプレート。
334/// seed や LLM provider は評価時に注入。
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct AppConfigTemplate {
337    /// Tick 間隔 (ミリ秒)
338    #[serde(default = "default_tick_duration_ms")]
339    pub tick_duration_ms: u64,
340
341    /// 最大 Tick 数
342    #[serde(default = "default_max_ticks")]
343    pub max_ticks: u64,
344
345    /// Management Strategy 設定
346    #[serde(default)]
347    pub management_strategy: ManagementStrategyConfig,
348
349    /// ExplorationSpace を有効化
350    #[serde(default)]
351    pub enable_exploration: bool,
352}
353
354fn default_tick_duration_ms() -> u64 {
355    10
356}
357
358fn default_max_ticks() -> u64 {
359    1000
360}
361
362impl AppConfigTemplate {
363    pub fn tick_duration(&self) -> Duration {
364        Duration::from_millis(self.tick_duration_ms)
365    }
366}
367
368impl Default for AppConfigTemplate {
369    fn default() -> Self {
370        Self {
371            tick_duration_ms: default_tick_duration_ms(),
372            max_ticks: default_max_ticks(),
373            management_strategy: ManagementStrategyConfig::default(),
374            enable_exploration: false,
375        }
376    }
377}
378
379/// AppConfig のオーバーライド用構造体
380///
381/// variant で指定されたフィールドだけが base にマージされる。
382#[derive(Debug, Clone, Serialize, Deserialize, Default)]
383pub struct AppConfigOverride {
384    /// Management Strategy のオーバーライド
385    #[serde(default)]
386    pub management_strategy: Option<ManagementStrategyConfig>,
387
388    /// tick_duration_ms のオーバーライド
389    #[serde(default)]
390    pub tick_duration_ms: Option<u64>,
391
392    /// enable_exploration のオーバーライド
393    #[serde(default)]
394    pub enable_exploration: Option<bool>,
395}
396
397// ============================================================================
398// Management Strategy
399// ============================================================================
400
401/// Management Strategy 設定
402#[derive(Debug, Clone, Serialize, Deserialize)]
403#[serde(tag = "type", rename_all = "snake_case")]
404pub enum ManagementStrategyConfig {
405    /// 毎 Tick 起動(LLM 不要なフロー向け)
406    ///
407    /// V2 ExplorationSpace など、LLM 呼び出しなしで Guidance を生成できる場合に使用。
408    EveryTick {},
409
410    /// インターバルベース
411    IntervalBased {
412        #[serde(default = "default_max_interval")]
413        max_interval: u64,
414    },
415    /// イベントドリブン
416    EventDriven {
417        #[serde(default)]
418        triggers: Vec<String>,
419    },
420    /// ハイブリッド
421    Hybrid {
422        #[serde(default = "default_max_interval")]
423        max_interval: u64,
424        #[serde(default)]
425        triggers: Vec<String>,
426    },
427    /// 無効化
428    #[serde(alias = "disabled")]
429    Disabled {},
430}
431
432fn default_max_interval() -> u64 {
433    20
434}
435
436impl Default for ManagementStrategyConfig {
437    fn default() -> Self {
438        Self::IntervalBased {
439            max_interval: default_max_interval(),
440        }
441    }
442}
443
444// ============================================================================
445// Environment Configuration
446// ============================================================================
447
448/// 環境設定
449#[derive(Debug, Clone, Serialize, Deserialize)]
450pub struct EnvironmentConfig {
451    /// 環境タイプ識別子 (e.g., "grid_world", "task_queue")
452    pub env_type: String,
453
454    /// 環境固有パラメータ
455    #[serde(default)]
456    pub params: serde_json::Value,
457
458    /// 初期状態設定
459    #[serde(default)]
460    pub initial_state: Option<InitialStateConfig>,
461}
462
463/// 初期状態設定
464#[derive(Debug, Clone, Serialize, Deserialize)]
465#[serde(tag = "type", rename_all = "snake_case")]
466pub enum InitialStateConfig {
467    /// Seed から決定論的に生成
468    #[serde(alias = "seeded_random")]
469    SeededRandom {},
470    /// 固定状態
471    Fixed {
472        /// 固定状態の定義
473        state: serde_json::Value,
474    },
475    /// カスタム生成器
476    Custom {
477        /// 生成器識別子
478        generator: String,
479        /// 生成器パラメータ
480        params: serde_json::Value,
481    },
482}
483
484// ============================================================================
485// Agents Configuration
486// ============================================================================
487
488/// エージェント設定
489#[derive(Debug, Clone, Default, Serialize, Deserialize)]
490pub struct AgentsConfig {
491    /// Worker テンプレート
492    #[serde(default)]
493    pub workers: Vec<WorkerTemplate>,
494
495    /// Manager テンプレート
496    #[serde(default)]
497    pub managers: Vec<ManagerTemplate>,
498}
499
500/// Worker テンプレート
501#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct WorkerTemplate {
503    /// ID生成パターン (e.g., "worker_{i}")
504    pub id_pattern: String,
505
506    /// 生成数
507    #[serde(default = "default_worker_count")]
508    pub count: usize,
509
510    /// 役割
511    #[serde(default)]
512    pub role: String,
513
514    /// Worker 固有設定
515    #[serde(default)]
516    pub config: serde_json::Value,
517}
518
519fn default_worker_count() -> usize {
520    1
521}
522
523impl WorkerTemplate {
524    /// Worker IDを生成
525    pub fn generate_ids(&self) -> Vec<String> {
526        (0..self.count)
527            .map(|i| self.id_pattern.replace("{i}", &i.to_string()))
528            .collect()
529    }
530}
531
532// ============================================================================
533// Tests
534// ============================================================================
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scenario::llm::LlmProvider;

    #[test]
    fn test_scenario_id() {
        let scenario_id = ScenarioId::new("test:scenario:v1");
        assert_eq!("test:scenario:v1", scenario_id.as_str());
    }

    #[test]
    fn test_scenario_id_learning_key() {
        // (raw ID, expected learning key)
        let cases = [
            // Standard `namespace:name:version` form.
            ("user:troubleshooting:v2", "troubleshooting"),
            // Built-in scenarios follow the same form.
            ("builtin:resource_gathering:v1", "resource_gathering"),
            // No colon: fall back to the whole ID.
            ("simple_scenario", "simple_scenario"),
            // No colon and a space: the space is sanitized to `_`.
            ("Service Troubleshooting", "Service_Troubleshooting"),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(ScenarioId::new(raw).learning_key(), expected, "id = {:?}", raw);
        }
    }

    #[test]
    fn test_worker_template_generate_ids() {
        let template = WorkerTemplate {
            id_pattern: String::from("worker_{i}"),
            count: 3,
            role: String::from("gatherer"),
            config: serde_json::Value::Null,
        };

        let expected = ["worker_0", "worker_1", "worker_2"];
        assert_eq!(template.generate_ids(), expected);
    }

    #[test]
    fn test_app_config_template_default() {
        let AppConfigTemplate {
            tick_duration_ms,
            max_ticks,
            ..
        } = AppConfigTemplate::default();
        assert_eq!(tick_duration_ms, 10);
        assert_eq!(max_ticks, 1000);
    }

    #[test]
    fn test_management_strategy_deserialize() {
        let json = r#"{"type": "hybrid", "max_interval": 30, "triggers": ["event_a"]}"#;
        let strategy: ManagementStrategyConfig =
            serde_json::from_str(json).expect("valid hybrid strategy JSON");

        if let ManagementStrategyConfig::Hybrid {
            max_interval,
            triggers,
        } = strategy
        {
            assert_eq!(max_interval, 30);
            assert_eq!(triggers, ["event_a"]);
        } else {
            panic!("Expected Hybrid variant");
        }
    }

    #[test]
    fn test_task_config_default() {
        let TaskConfig { goal, expected, .. } = TaskConfig::default();
        assert!(goal.is_empty());
        assert!(expected.is_none());
    }

    #[test]
    fn test_task_config_deserialize_toml() {
        let toml_str = r#"
            goal = "Find the function that handles authentication"
            expected = "src/auth/handler.rs:42"
            [context]
            target_path = "/path/to/codebase"
            working_dir = "/path/to/codebase"
            max_depth = 5
        "#;

        let task: TaskConfig = toml::from_str(toml_str).expect("valid task TOML");
        assert_eq!(task.goal, "Find the function that handles authentication");
        assert_eq!(task.expected.as_deref(), Some("src/auth/handler.rs:42"));
        assert_eq!(
            task.context.target_path.as_deref(),
            Some("/path/to/codebase")
        );
        assert_eq!(task.context.max_depth, Some(5));
    }

    #[test]
    fn test_scenario_variant_with_llm_override() {
        let toml_str = r#"
            [meta]
            name = "Test Scenario"
            id = "test:scenario:v1"

            [task]
            goal = "Test goal"

            [llm]
            provider = "ollama"
            model = "llama3:8b"

            [app_config]
            max_ticks = 100

            [environment]
            env_type = "test"

            [agents]

            [conditions]
            on_timeout = "fail"

            [[variants]]
            name = "mistral"
            description = "Use mistral.rs local inference"
            [variants.llm]
            provider = "mistral"
            model = "LiquidAI/LFM2.5-1.2B-Instruct-GGUF"
            gguf_files = ["LFM2.5-1.2B-Instruct-Q4_K_M.gguf"]
        "#;

        let base: EvalScenario = toml::from_str(toml_str).expect("valid scenario TOML");
        assert_eq!(base.llm.provider, LlmProvider::Ollama);
        assert_eq!(base.variants.len(), 1);

        // Applying the variant swaps the provider/model and tags the name.
        let applied = base
            .with_variant("mistral")
            .expect("variant `mistral` is defined");
        assert_eq!(applied.llm.provider, LlmProvider::Mistral);
        assert_eq!(applied.llm.model, "LiquidAI/LFM2.5-1.2B-Instruct-GGUF");
        assert!(applied.llm.is_gguf());
        assert_eq!(applied.meta.name, "Test Scenario (mistral)");
    }

    #[test]
    fn test_scenario_variant_partial_llm_override() {
        let toml_str = r#"
            [meta]
            name = "Test"
            id = "test:v1"

            [task]
            goal = "Test"

            [llm]
            provider = "ollama"
            model = "llama3:8b"
            temperature = 0.1
            num_ctx = 4096

            [app_config]
            max_ticks = 100

            [environment]
            env_type = "test"

            [agents]

            [conditions]
            on_timeout = "fail"

            [[variants]]
            name = "high_temp"
            [variants.llm]
            temperature = 0.9
        "#;

        let base: EvalScenario = toml::from_str(toml_str).expect("valid scenario TOML");
        let applied = base
            .with_variant("high_temp")
            .expect("variant `high_temp` is defined");
        let llm = &applied.llm;

        // Only the temperature is overridden...
        assert!((llm.temperature - 0.9).abs() < f32::EPSILON);
        // ...everything else keeps its base value.
        assert_eq!(llm.provider, LlmProvider::Ollama);
        assert_eq!(llm.model, "llama3:8b");
        assert_eq!(llm.num_ctx, Some(4096));
    }
}