Skip to main content

sh_layer2/
planner.rs

1//! # Agent Planner
2//!
3//! 任务分解和执行计划生成。
4//!
5//! 支持将复杂任务分解为可执行的子任务序列。
6
7use crate::types::{Layer2Result, TaskId};
8use crate::workflow_engine::{Dag, Node};
9use serde::{Deserialize, Serialize};
10
11/// 任务分解策略
12#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
13pub enum DecompositionStrategy {
14    /// 顺序分解:按步骤顺序执行
15    Sequential,
16    /// 并行分解:独立子任务并行执行
17    Parallel,
18    /// 混合分解:根据依赖关系自动选择
19    #[default]
20    Hybrid,
21    /// 层次分解:先粗粒度再细粒度
22    Hierarchical,
23}
24
25/// 子任务定义
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct SubTask {
28    /// 子任务ID
29    pub id: String,
30    /// 子任务名称
31    pub name: String,
32    /// 子任务描述
33    pub description: String,
34    /// 执行优先级 (0最高)
35    pub priority: u32,
36    /// 依赖的任务ID列表
37    pub dependencies: Vec<String>,
38    /// 预估复杂度 (1-10)
39    pub estimated_complexity: u32,
40    /// 执行工具
41    pub tool: Option<String>,
42    /// 工具参数
43    pub tool_args: Option<serde_json::Value>,
44    /// 验证条件
45    pub validation_criteria: Vec<String>,
46    /// 失败时的替代方案
47    pub fallback: Option<Box<SubTask>>,
48}
49
50impl SubTask {
51    /// 创建新的子任务
52    pub fn new(
53        id: impl Into<String>,
54        name: impl Into<String>,
55        description: impl Into<String>,
56    ) -> Self {
57        Self {
58            id: id.into(),
59            name: name.into(),
60            description: description.into(),
61            priority: 0,
62            dependencies: Vec::new(),
63            estimated_complexity: 5,
64            tool: None,
65            tool_args: None,
66            validation_criteria: Vec::new(),
67            fallback: None,
68        }
69    }
70
71    /// 设置优先级
72    pub fn with_priority(mut self, priority: u32) -> Self {
73        self.priority = priority;
74        self
75    }
76
77    /// 添加依赖
78    pub fn with_dependency(mut self, dep_id: impl Into<String>) -> Self {
79        self.dependencies.push(dep_id.into());
80        self
81    }
82
83    /// 设置执行工具
84    pub fn with_tool(mut self, tool: impl Into<String>, args: serde_json::Value) -> Self {
85        self.tool = Some(tool.into());
86        self.tool_args = Some(args);
87        self
88    }
89
90    /// 设置验证条件
91    pub fn with_validation(mut self, criteria: impl Into<String>) -> Self {
92        self.validation_criteria.push(criteria.into());
93        self
94    }
95
96    /// 设置失败替代方案
97    pub fn with_fallback(mut self, fallback: SubTask) -> Self {
98        self.fallback = Some(Box::new(fallback));
99        self
100    }
101}
102
103/// 执行计划
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct ExecutionPlan {
106    /// 计划ID
107    pub id: String,
108    /// 原始任务描述
109    pub original_task: String,
110    /// 分解策略
111    pub strategy: DecompositionStrategy,
112    /// 子任务列表
113    pub subtasks: Vec<SubTask>,
114    /// 执行顺序(拓扑排序后的ID列表)
115    pub execution_order: Vec<String>,
116    /// 估算总步数
117    pub estimated_steps: u32,
118    /// 风险评估
119    pub risk_level: RiskLevel,
120    /// 创建时间
121    pub created_at: chrono::DateTime<chrono::Utc>,
122}
123
124/// 风险等级
125#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
126pub enum RiskLevel {
127    /// 低风险:简单的确定性任务
128    Low,
129    /// 中风险:需要多个工具协作
130    #[default]
131    Medium,
132    /// 高风险:涉及复杂逻辑或外部系统
133    High,
134    /// 极高风险:可能需要用户介入
135    Critical,
136}
137
138impl ExecutionPlan {
139    /// 创建新的执行计划
140    pub fn new(original_task: impl Into<String>) -> Self {
141        Self {
142            id: TaskId::new().to_string(),
143            original_task: original_task.into(),
144            strategy: DecompositionStrategy::default(),
145            subtasks: Vec::new(),
146            execution_order: Vec::new(),
147            estimated_steps: 1,
148            risk_level: RiskLevel::default(),
149            created_at: chrono::Utc::now(),
150        }
151    }
152
153    /// 添加子任务
154    pub fn add_subtask(&mut self, subtask: SubTask) -> &mut Self {
155        self.subtasks.push(subtask);
156        self
157    }
158
159    /// 计算执行顺序(拓扑排序)
160    pub fn compute_execution_order(&mut self) -> Layer2Result<()> {
161        let mut dag = Dag::new();
162
163        // 添加节点
164        for subtask in &self.subtasks {
165            let node = Node::new(&subtask.id, &subtask.name);
166            dag.add_node(node)?;
167        }
168
169        // 添加边(依赖关系)
170        for subtask in &self.subtasks {
171            for dep in &subtask.dependencies {
172                dag.add_edge(dep, &subtask.id)?;
173            }
174        }
175
176        // 检查循环依赖
177        if dag.has_cycle() {
178            return Err(anyhow::anyhow!(
179                "Circular dependency detected in execution plan"
180            ));
181        }
182
183        // 拓扑排序
184        self.execution_order = dag.topological_sort()?;
185
186        // 估算步数
187        self.estimated_steps = self.subtasks.len() as u32;
188
189        Ok(())
190    }
191
192    /// 转换为 DAG 工作流
193    pub fn to_dag(&self) -> Layer2Result<Dag> {
194        let mut dag = Dag::new();
195
196        for subtask in &self.subtasks {
197            let mut node = Node::new(&subtask.id, &subtask.name);
198            // 使用 config 字段存储子任务信息
199            node.config = serde_json::json!({
200                "description": subtask.description,
201                "tool": subtask.tool,
202                "tool_args": subtask.tool_args,
203            });
204            dag.add_node(node)?;
205        }
206
207        for subtask in &self.subtasks {
208            for dep in &subtask.dependencies {
209                dag.add_edge(dep, &subtask.id)?;
210            }
211        }
212
213        Ok(dag)
214    }
215}
216
217/// 任务分解器
218pub struct TaskDecomposer {
219    /// 分解策略
220    strategy: DecompositionStrategy,
221    /// 最大分解深度
222    max_depth: u32,
223    /// 最小子任务粒度
224    #[allow(dead_code)]
225    min_granularity: u32,
226}
227
228impl Default for TaskDecomposer {
229    fn default() -> Self {
230        Self {
231            strategy: DecompositionStrategy::Hybrid,
232            max_depth: 3,
233            min_granularity: 1,
234        }
235    }
236}
237
238impl TaskDecomposer {
239    /// 创建新的任务分解器
240    pub fn new() -> Self {
241        Self::default()
242    }
243
244    /// 设置分解策略
245    pub fn with_strategy(mut self, strategy: DecompositionStrategy) -> Self {
246        self.strategy = strategy;
247        self
248    }
249
250    /// 设置最大分解深度
251    pub fn with_max_depth(mut self, depth: u32) -> Self {
252        self.max_depth = depth;
253        self
254    }
255
256    /// 分解任务为执行计划
257    pub fn decompose(&self, task: &str) -> Layer2Result<ExecutionPlan> {
258        let mut plan = ExecutionPlan::new(task);
259        plan.strategy = self.strategy;
260
261        // 分析任务复杂度
262        let complexity = self.analyze_complexity(task);
263        plan.risk_level = self.estimate_risk(task, complexity);
264
265        // 根据策略分解任务
266        let subtasks = match self.strategy {
267            DecompositionStrategy::Sequential => self.decompose_sequential(task),
268            DecompositionStrategy::Parallel => self.decompose_parallel(task),
269            DecompositionStrategy::Hierarchical => self.decompose_hierarchical(task, 0),
270            DecompositionStrategy::Hybrid => self.decompose_hybrid(task),
271        };
272
273        for subtask in subtasks {
274            plan.add_subtask(subtask);
275        }
276
277        plan.compute_execution_order()?;
278
279        Ok(plan)
280    }
281
282    /// 分析任务复杂度
283    fn analyze_complexity(&self, task: &str) -> u32 {
284        let mut complexity = 1u32;
285
286        // 关键词分析
287        let task_lower = task.to_lowercase();
288
289        // 复杂度增加因素
290        if task_lower.contains("implement") || task_lower.contains("create") {
291            complexity += 2;
292        }
293        if task_lower.contains("refactor") || task_lower.contains("rewrite") {
294            complexity += 2;
295        }
296        if task_lower.contains("integrate") || task_lower.contains("connect") {
297            complexity += 1;
298        }
299        if task_lower.contains("test") || task_lower.contains("verify") {
300            complexity += 1;
301        }
302        if task_lower.contains("and") || task_lower.contains("then") {
303            complexity += 1;
304        }
305        if task_lower.contains("multiple") || task_lower.contains("several") {
306            complexity += 1;
307        }
308
309        // 长度因素
310        let word_count = task.split_whitespace().count();
311        if word_count > 20 {
312            complexity += 1;
313        }
314        if word_count > 50 {
315            complexity += 1;
316        }
317
318        complexity.min(10)
319    }
320
321    /// 估算风险等级
322    fn estimate_risk(&self, task: &str, complexity: u32) -> RiskLevel {
323        let task_lower = task.to_lowercase();
324
325        // 检查高风险关键词
326        if task_lower.contains("delete")
327            || task_lower.contains("remove")
328            || task_lower.contains("drop")
329        {
330            return RiskLevel::Critical;
331        }
332        if task_lower.contains("production")
333            || task_lower.contains("live")
334            || task_lower.contains("deploy")
335        {
336            return RiskLevel::High;
337        }
338        if task_lower.contains("database") || task_lower.contains("migration") {
339            return RiskLevel::High;
340        }
341
342        // 基于复杂度
343        match complexity {
344            1..=3 => RiskLevel::Low,
345            4..=6 => RiskLevel::Medium,
346            7..=8 => RiskLevel::High,
347            _ => RiskLevel::Critical,
348        }
349    }
350
351    /// 顺序分解
352    fn decompose_sequential(&self, task: &str) -> Vec<SubTask> {
353        let steps = self.extract_steps(task);
354        let mut subtasks = Vec::new();
355        let mut prev_id: Option<String> = None;
356
357        for (i, step) in steps.into_iter().enumerate() {
358            let id = format!("step_{}", i + 1);
359            let mut subtask = SubTask::new(&id, format!("Step {}", i + 1), step);
360            subtask.priority = i as u32;
361
362            if let Some(prev) = prev_id {
363                subtask = subtask.with_dependency(prev);
364            }
365
366            prev_id = Some(id);
367            subtasks.push(subtask);
368        }
369
370        if subtasks.is_empty() {
371            subtasks.push(SubTask::new("step_1", "Execute task", task));
372        }
373
374        subtasks
375    }
376
377    /// 并行分解
378    fn decompose_parallel(&self, task: &str) -> Vec<SubTask> {
379        let parts = self.extract_parallel_parts(task);
380        let mut subtasks = Vec::new();
381
382        for (i, part) in parts.into_iter().enumerate() {
383            let id = format!("parallel_{}", i + 1);
384            let subtask = SubTask::new(&id, format!("Task {}", i + 1), part);
385            subtasks.push(subtask);
386        }
387
388        if subtasks.is_empty() {
389            subtasks.push(SubTask::new("parallel_1", "Execute task", task));
390        }
391
392        subtasks
393    }
394
395    /// 层次分解
396    fn decompose_hierarchical(&self, task: &str, depth: u32) -> Vec<SubTask> {
397        if depth >= self.max_depth {
398            return vec![SubTask::new(format!("leaf_{}", depth), "Execute", task)];
399        }
400
401        let main_steps = self.extract_steps(task);
402        let mut subtasks = Vec::new();
403
404        for (i, step) in main_steps.into_iter().enumerate() {
405            let id = format!("h{}_{}", depth, i + 1);
406            let mut subtask = SubTask::new(&id, format!("Phase {}", i + 1), step.clone());
407            subtask.estimated_complexity = self.analyze_complexity(&step);
408
409            // 如果步骤仍然复杂,继续分解
410            if subtask.estimated_complexity > 5 && depth < self.max_depth - 1 {
411                let sub_subtasks = self.decompose_hierarchical(&step, depth + 1);
412                for (j, sub_sub) in sub_subtasks.into_iter().enumerate() {
413                    let mut sub_sub_id = sub_sub;
414                    sub_sub_id.id = format!("{}_{}", id, j + 1);
415                    sub_sub_id.dependencies.push(id.clone());
416                    subtasks.push(sub_sub_id);
417                }
418            }
419
420            subtasks.push(subtask);
421        }
422
423        subtasks
424    }
425
426    /// 混合分解
427    fn decompose_hybrid(&self, task: &str) -> Vec<SubTask> {
428        let complexity = self.analyze_complexity(task);
429
430        if complexity <= 3 {
431            // 简单任务,单步执行
432            vec![SubTask::new("execute", "Execute task", task)]
433        } else if complexity <= 6 {
434            // 中等复杂度,顺序分解
435            self.decompose_sequential(task)
436        } else {
437            // 高复杂度,层次分解
438            self.decompose_hierarchical(task, 0)
439        }
440    }
441
442    /// 提取步骤
443    fn extract_steps(&self, task: &str) -> Vec<String> {
444        let mut steps = Vec::new();
445
446        // 按句号、分号、换行分割
447        let sentences: Vec<&str> = task
448            .split(&['.', ';', '\n'][..])
449            .map(|s| s.trim())
450            .filter(|s| !s.is_empty())
451            .collect();
452
453        if sentences.len() > 1 {
454            steps = sentences.into_iter().map(|s| s.to_string()).collect();
455        } else {
456            // 尝试按 "and then" 或 "then" 分割
457            let mut then_parts: Vec<&str> = task.split("and then").collect();
458            if then_parts.len() == 1 {
459                then_parts = task.split("then").collect();
460            }
461            if then_parts.len() == 1 {
462                then_parts = task.split("after that").collect();
463            }
464
465            if then_parts.len() > 1 {
466                steps = then_parts
467                    .into_iter()
468                    .map(|s| s.trim().to_string())
469                    .collect();
470            } else {
471                // 单一任务
472                steps.push(task.to_string());
473            }
474        }
475
476        steps
477    }
478
479    /// 提取并行部分
480    fn extract_parallel_parts(&self, task: &str) -> Vec<String> {
481        // 按 "and" 或逗号分割
482        let mut parts: Vec<&str> = task.split(", and ").collect();
483        if parts.len() == 1 {
484            parts = task.split(" and ").collect();
485        }
486        if parts.len() == 1 {
487            parts = task.split(", ").collect();
488        }
489
490        let parts: Vec<&str> = parts
491            .into_iter()
492            .map(|s| s.trim())
493            .filter(|s| !s.is_empty() && s.len() > 3)
494            .collect();
495
496        if parts.len() > 1 {
497            parts.into_iter().map(|s| s.to_string()).collect()
498        } else {
499            vec![task.to_string()]
500        }
501    }
502}
503
504/// 规划结果
505#[derive(Debug, Clone)]
506pub struct PlanResult {
507    /// 执行计划
508    pub plan: ExecutionPlan,
509    /// 分解质量评分 (0-100)
510    pub quality_score: u32,
511    /// 分解建议
512    pub suggestions: Vec<String>,
513}
514
515impl PlanResult {
516    /// 创建新的规划结果
517    pub fn new(plan: ExecutionPlan) -> Self {
518        let quality_score = Self::calculate_quality(&plan);
519        let suggestions = Self::generate_suggestions(&plan);
520
521        Self {
522            plan,
523            quality_score,
524            suggestions,
525        }
526    }
527
528    /// 计算分解质量
529    fn calculate_quality(plan: &ExecutionPlan) -> u32 {
530        let mut score = 100u32;
531
532        // 检查子任务数量
533        if plan.subtasks.is_empty() {
534            score = 0;
535        } else if plan.subtasks.len() == 1 {
536            score -= 20; // 未分解
537        }
538
539        // 检查执行顺序是否合理
540        if plan.execution_order.len() != plan.subtasks.len() {
541            score -= 30;
542        }
543
544        // 检查是否有验证条件
545        let has_validation = plan
546            .subtasks
547            .iter()
548            .any(|s| !s.validation_criteria.is_empty());
549        if !has_validation {
550            score -= 10;
551        }
552
553        // 检查是否有失败处理
554        let has_fallback = plan.subtasks.iter().any(|s| s.fallback.is_some());
555        if !has_fallback && plan.risk_level >= RiskLevel::High {
556            score -= 15;
557        }
558
559        score
560    }
561
562    /// 生成建议
563    fn generate_suggestions(plan: &ExecutionPlan) -> Vec<String> {
564        let mut suggestions = Vec::new();
565
566        if plan.subtasks.len() == 1 {
567            suggestions.push("Consider breaking down the task into smaller subtasks".to_string());
568        }
569
570        if plan.risk_level >= RiskLevel::High {
571            suggestions.push("High-risk task: consider adding validation steps".to_string());
572        }
573
574        let has_fallback = plan.subtasks.iter().any(|s| s.fallback.is_some());
575        if !has_fallback && !plan.subtasks.is_empty() {
576            suggestions
577                .push("Consider adding fallback strategies for critical subtasks".to_string());
578        }
579
580        suggestions
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::{
        DecompositionStrategy, ExecutionPlan, PlanResult, RiskLevel, SubTask, TaskDecomposer,
    };

    /// Builds a two-step plan (`s2` depends on `s1`) with its order computed.
    fn two_step_plan() -> ExecutionPlan {
        let mut plan = ExecutionPlan::new("Test task");
        plan.add_subtask(SubTask::new("s1", "Step 1", "First step"));
        plan.add_subtask(SubTask::new("s2", "Step 2", "Second step").with_dependency("s1"));
        plan.compute_execution_order().unwrap();
        plan
    }

    #[test]
    fn test_subtask_creation() {
        let task = SubTask::new("test_1", "Test", "Test subtask");
        assert_eq!(task.id, "test_1");
        assert_eq!(task.name, "Test");
    }

    #[test]
    fn test_subtask_with_dependencies() {
        let task = SubTask::new("test_2", "Test", "Test").with_dependency("test_1");
        assert_eq!(task.dependencies.len(), 1);
    }

    #[test]
    fn test_execution_plan_creation() {
        let plan = ExecutionPlan::new("Test task");
        assert!(!plan.original_task.is_empty());
        assert!(plan.subtasks.is_empty());
    }

    #[test]
    fn test_task_decomposer() {
        let plan = TaskDecomposer::new()
            .decompose("Create a file and write some content")
            .unwrap();

        assert!(!plan.subtasks.is_empty());
        assert!(!plan.execution_order.is_empty());
    }

    #[test]
    fn test_complexity_analysis() {
        let decomposer = TaskDecomposer::new();

        assert!(decomposer.analyze_complexity("Read a file") <= 3);
        assert!(
            decomposer.analyze_complexity(
                "Implement a complete authentication system with OAuth2 integration",
            ) > 3
        );
    }

    #[test]
    fn test_risk_estimation() {
        let decomposer = TaskDecomposer::new();

        assert_eq!(decomposer.estimate_risk("Read a file", 2), RiskLevel::Low);
        assert_eq!(
            decomposer.estimate_risk("Delete the production database", 5),
            RiskLevel::Critical
        );
    }

    #[test]
    fn test_sequential_decomposition() {
        let plan = TaskDecomposer::new()
            .with_strategy(DecompositionStrategy::Sequential)
            .decompose("First step. Second step. Third step.")
            .unwrap();

        assert!(plan.subtasks.len() >= 3);
        // Every subtask must depend on the one right before it.
        for pair in plan.subtasks.windows(2) {
            assert!(pair[1].dependencies.contains(&pair[0].id));
        }
    }

    #[test]
    fn test_parallel_decomposition() {
        let plan = TaskDecomposer::new()
            .with_strategy(DecompositionStrategy::Parallel)
            .decompose("Task A and Task B and Task C")
            .unwrap();

        assert!(plan.subtasks.len() >= 2);
        // Parallel subtasks must not depend on anything.
        assert!(plan.subtasks.iter().all(|s| s.dependencies.is_empty()));
    }

    #[test]
    fn test_plan_result_quality() {
        let result = PlanResult::new(two_step_plan());
        assert!(result.quality_score > 0);
    }

    #[test]
    fn test_dag_conversion() {
        let plan = two_step_plan();
        assert!(plan.to_dag().is_ok());
    }
}