Skip to main content

sh_layer2/
planner.rs

//! # Agent Planner
//!
//! Task decomposition and execution-plan generation.
//!
//! Supports breaking a complex task down into an executable sequence of subtasks.
6
7use crate::types::{Layer2Result, TaskId};
8use crate::workflow_engine::{Dag, Node};
9use serde::{Deserialize, Serialize};
10
11/// 任务分解策略
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum DecompositionStrategy {
14    /// 顺序分解:按步骤顺序执行
15    Sequential,
16    /// 并行分解:独立子任务并行执行
17    Parallel,
18    /// 混合分解:根据依赖关系自动选择
19    Hybrid,
20    /// 层次分解:先粗粒度再细粒度
21    Hierarchical,
22}
23
24impl Default for DecompositionStrategy {
25    fn default() -> Self {
26        Self::Hybrid
27    }
28}
29
30/// 子任务定义
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct SubTask {
33    /// 子任务ID
34    pub id: String,
35    /// 子任务名称
36    pub name: String,
37    /// 子任务描述
38    pub description: String,
39    /// 执行优先级 (0最高)
40    pub priority: u32,
41    /// 依赖的任务ID列表
42    pub dependencies: Vec<String>,
43    /// 预估复杂度 (1-10)
44    pub estimated_complexity: u32,
45    /// 执行工具
46    pub tool: Option<String>,
47    /// 工具参数
48    pub tool_args: Option<serde_json::Value>,
49    /// 验证条件
50    pub validation_criteria: Vec<String>,
51    /// 失败时的替代方案
52    pub fallback: Option<Box<SubTask>>,
53}
54
55impl SubTask {
56    /// 创建新的子任务
57    pub fn new(
58        id: impl Into<String>,
59        name: impl Into<String>,
60        description: impl Into<String>,
61    ) -> Self {
62        Self {
63            id: id.into(),
64            name: name.into(),
65            description: description.into(),
66            priority: 0,
67            dependencies: Vec::new(),
68            estimated_complexity: 5,
69            tool: None,
70            tool_args: None,
71            validation_criteria: Vec::new(),
72            fallback: None,
73        }
74    }
75
76    /// 设置优先级
77    pub fn with_priority(mut self, priority: u32) -> Self {
78        self.priority = priority;
79        self
80    }
81
82    /// 添加依赖
83    pub fn with_dependency(mut self, dep_id: impl Into<String>) -> Self {
84        self.dependencies.push(dep_id.into());
85        self
86    }
87
88    /// 设置执行工具
89    pub fn with_tool(mut self, tool: impl Into<String>, args: serde_json::Value) -> Self {
90        self.tool = Some(tool.into());
91        self.tool_args = Some(args);
92        self
93    }
94
95    /// 设置验证条件
96    pub fn with_validation(mut self, criteria: impl Into<String>) -> Self {
97        self.validation_criteria.push(criteria.into());
98        self
99    }
100
101    /// 设置失败替代方案
102    pub fn with_fallback(mut self, fallback: SubTask) -> Self {
103        self.fallback = Some(Box::new(fallback));
104        self
105    }
106}
107
108/// 执行计划
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ExecutionPlan {
111    /// 计划ID
112    pub id: String,
113    /// 原始任务描述
114    pub original_task: String,
115    /// 分解策略
116    pub strategy: DecompositionStrategy,
117    /// 子任务列表
118    pub subtasks: Vec<SubTask>,
119    /// 执行顺序(拓扑排序后的ID列表)
120    pub execution_order: Vec<String>,
121    /// 估算总步数
122    pub estimated_steps: u32,
123    /// 风险评估
124    pub risk_level: RiskLevel,
125    /// 创建时间
126    pub created_at: chrono::DateTime<chrono::Utc>,
127}
128
129/// 风险等级
130#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
131pub enum RiskLevel {
132    /// 低风险:简单的确定性任务
133    Low,
134    /// 中风险:需要多个工具协作
135    Medium,
136    /// 高风险:涉及复杂逻辑或外部系统
137    High,
138    /// 极高风险:可能需要用户介入
139    Critical,
140}
141
142impl Default for RiskLevel {
143    fn default() -> Self {
144        Self::Medium
145    }
146}
147
148impl ExecutionPlan {
149    /// 创建新的执行计划
150    pub fn new(original_task: impl Into<String>) -> Self {
151        Self {
152            id: TaskId::new().to_string(),
153            original_task: original_task.into(),
154            strategy: DecompositionStrategy::default(),
155            subtasks: Vec::new(),
156            execution_order: Vec::new(),
157            estimated_steps: 1,
158            risk_level: RiskLevel::default(),
159            created_at: chrono::Utc::now(),
160        }
161    }
162
163    /// 添加子任务
164    pub fn add_subtask(&mut self, subtask: SubTask) -> &mut Self {
165        self.subtasks.push(subtask);
166        self
167    }
168
169    /// 计算执行顺序(拓扑排序)
170    pub fn compute_execution_order(&mut self) -> Layer2Result<()> {
171        let mut dag = Dag::new();
172
173        // 添加节点
174        for subtask in &self.subtasks {
175            let node = Node::new(&subtask.id, &subtask.name);
176            dag.add_node(node)?;
177        }
178
179        // 添加边(依赖关系)
180        for subtask in &self.subtasks {
181            for dep in &subtask.dependencies {
182                dag.add_edge(dep, &subtask.id)?;
183            }
184        }
185
186        // 检查循环依赖
187        if dag.has_cycle() {
188            return Err(anyhow::anyhow!(
189                "Circular dependency detected in execution plan"
190            ));
191        }
192
193        // 拓扑排序
194        self.execution_order = dag.topological_sort()?;
195
196        // 估算步数
197        self.estimated_steps = self.subtasks.len() as u32;
198
199        Ok(())
200    }
201
202    /// 转换为 DAG 工作流
203    pub fn to_dag(&self) -> Layer2Result<Dag> {
204        let mut dag = Dag::new();
205
206        for subtask in &self.subtasks {
207            let mut node = Node::new(&subtask.id, &subtask.name);
208            // 使用 config 字段存储子任务信息
209            node.config = serde_json::json!({
210                "description": subtask.description,
211                "tool": subtask.tool,
212                "tool_args": subtask.tool_args,
213            });
214            dag.add_node(node)?;
215        }
216
217        for subtask in &self.subtasks {
218            for dep in &subtask.dependencies {
219                dag.add_edge(dep, &subtask.id)?;
220            }
221        }
222
223        Ok(dag)
224    }
225}
226
227/// 任务分解器
228pub struct TaskDecomposer {
229    /// 分解策略
230    strategy: DecompositionStrategy,
231    /// 最大分解深度
232    max_depth: u32,
233    /// 最小子任务粒度
234    min_granularity: u32,
235}
236
237impl Default for TaskDecomposer {
238    fn default() -> Self {
239        Self {
240            strategy: DecompositionStrategy::Hybrid,
241            max_depth: 3,
242            min_granularity: 1,
243        }
244    }
245}
246
247impl TaskDecomposer {
248    /// 创建新的任务分解器
249    pub fn new() -> Self {
250        Self::default()
251    }
252
253    /// 设置分解策略
254    pub fn with_strategy(mut self, strategy: DecompositionStrategy) -> Self {
255        self.strategy = strategy;
256        self
257    }
258
259    /// 设置最大分解深度
260    pub fn with_max_depth(mut self, depth: u32) -> Self {
261        self.max_depth = depth;
262        self
263    }
264
265    /// 分解任务为执行计划
266    pub fn decompose(&self, task: &str) -> Layer2Result<ExecutionPlan> {
267        let mut plan = ExecutionPlan::new(task);
268        plan.strategy = self.strategy;
269
270        // 分析任务复杂度
271        let complexity = self.analyze_complexity(task);
272        plan.risk_level = self.estimate_risk(task, complexity);
273
274        // 根据策略分解任务
275        let subtasks = match self.strategy {
276            DecompositionStrategy::Sequential => self.decompose_sequential(task),
277            DecompositionStrategy::Parallel => self.decompose_parallel(task),
278            DecompositionStrategy::Hierarchical => self.decompose_hierarchical(task, 0),
279            DecompositionStrategy::Hybrid => self.decompose_hybrid(task),
280        };
281
282        for subtask in subtasks {
283            plan.add_subtask(subtask);
284        }
285
286        plan.compute_execution_order()?;
287
288        Ok(plan)
289    }
290
291    /// 分析任务复杂度
292    fn analyze_complexity(&self, task: &str) -> u32 {
293        let mut complexity = 1u32;
294
295        // 关键词分析
296        let task_lower = task.to_lowercase();
297
298        // 复杂度增加因素
299        if task_lower.contains("implement") || task_lower.contains("create") {
300            complexity += 2;
301        }
302        if task_lower.contains("refactor") || task_lower.contains("rewrite") {
303            complexity += 2;
304        }
305        if task_lower.contains("integrate") || task_lower.contains("connect") {
306            complexity += 1;
307        }
308        if task_lower.contains("test") || task_lower.contains("verify") {
309            complexity += 1;
310        }
311        if task_lower.contains("and") || task_lower.contains("then") {
312            complexity += 1;
313        }
314        if task_lower.contains("multiple") || task_lower.contains("several") {
315            complexity += 1;
316        }
317
318        // 长度因素
319        let word_count = task.split_whitespace().count();
320        if word_count > 20 {
321            complexity += 1;
322        }
323        if word_count > 50 {
324            complexity += 1;
325        }
326
327        complexity.min(10)
328    }
329
330    /// 估算风险等级
331    fn estimate_risk(&self, task: &str, complexity: u32) -> RiskLevel {
332        let task_lower = task.to_lowercase();
333
334        // 检查高风险关键词
335        if task_lower.contains("delete")
336            || task_lower.contains("remove")
337            || task_lower.contains("drop")
338        {
339            return RiskLevel::Critical;
340        }
341        if task_lower.contains("production")
342            || task_lower.contains("live")
343            || task_lower.contains("deploy")
344        {
345            return RiskLevel::High;
346        }
347        if task_lower.contains("database") || task_lower.contains("migration") {
348            return RiskLevel::High;
349        }
350
351        // 基于复杂度
352        match complexity {
353            1..=3 => RiskLevel::Low,
354            4..=6 => RiskLevel::Medium,
355            7..=8 => RiskLevel::High,
356            _ => RiskLevel::Critical,
357        }
358    }
359
360    /// 顺序分解
361    fn decompose_sequential(&self, task: &str) -> Vec<SubTask> {
362        let steps = self.extract_steps(task);
363        let mut subtasks = Vec::new();
364        let mut prev_id: Option<String> = None;
365
366        for (i, step) in steps.into_iter().enumerate() {
367            let id = format!("step_{}", i + 1);
368            let mut subtask = SubTask::new(&id, format!("Step {}", i + 1), step);
369            subtask.priority = i as u32;
370
371            if let Some(prev) = prev_id {
372                subtask = subtask.with_dependency(prev);
373            }
374
375            prev_id = Some(id);
376            subtasks.push(subtask);
377        }
378
379        if subtasks.is_empty() {
380            subtasks.push(SubTask::new("step_1", "Execute task", task));
381        }
382
383        subtasks
384    }
385
386    /// 并行分解
387    fn decompose_parallel(&self, task: &str) -> Vec<SubTask> {
388        let parts = self.extract_parallel_parts(task);
389        let mut subtasks = Vec::new();
390
391        for (i, part) in parts.into_iter().enumerate() {
392            let id = format!("parallel_{}", i + 1);
393            let subtask = SubTask::new(&id, format!("Task {}", i + 1), part);
394            subtasks.push(subtask);
395        }
396
397        if subtasks.is_empty() {
398            subtasks.push(SubTask::new("parallel_1", "Execute task", task));
399        }
400
401        subtasks
402    }
403
404    /// 层次分解
405    fn decompose_hierarchical(&self, task: &str, depth: u32) -> Vec<SubTask> {
406        if depth >= self.max_depth {
407            return vec![SubTask::new(&format!("leaf_{}", depth), "Execute", task)];
408        }
409
410        let main_steps = self.extract_steps(task);
411        let mut subtasks = Vec::new();
412
413        for (i, step) in main_steps.into_iter().enumerate() {
414            let id = format!("h{}_{}", depth, i + 1);
415            let mut subtask = SubTask::new(&id, format!("Phase {}", i + 1), step.clone());
416            subtask.estimated_complexity = self.analyze_complexity(&step);
417
418            // 如果步骤仍然复杂,继续分解
419            if subtask.estimated_complexity > 5 && depth < self.max_depth - 1 {
420                let sub_subtasks = self.decompose_hierarchical(&step, depth + 1);
421                for (j, sub_sub) in sub_subtasks.into_iter().enumerate() {
422                    let mut sub_sub_id = sub_sub;
423                    sub_sub_id.id = format!("{}_{}", id, j + 1);
424                    sub_sub_id.dependencies.push(id.clone());
425                    subtasks.push(sub_sub_id);
426                }
427            }
428
429            subtasks.push(subtask);
430        }
431
432        subtasks
433    }
434
435    /// 混合分解
436    fn decompose_hybrid(&self, task: &str) -> Vec<SubTask> {
437        let complexity = self.analyze_complexity(task);
438
439        if complexity <= 3 {
440            // 简单任务,单步执行
441            vec![SubTask::new("execute", "Execute task", task)]
442        } else if complexity <= 6 {
443            // 中等复杂度,顺序分解
444            self.decompose_sequential(task)
445        } else {
446            // 高复杂度,层次分解
447            self.decompose_hierarchical(task, 0)
448        }
449    }
450
451    /// 提取步骤
452    fn extract_steps(&self, task: &str) -> Vec<String> {
453        let mut steps = Vec::new();
454
455        // 按句号、分号、换行分割
456        let sentences: Vec<&str> = task
457            .split(&['.', ';', '\n'][..])
458            .map(|s| s.trim())
459            .filter(|s| !s.is_empty())
460            .collect();
461
462        if sentences.len() > 1 {
463            steps = sentences.into_iter().map(|s| s.to_string()).collect();
464        } else {
465            // 尝试按 "and then" 或 "then" 分割
466            let mut then_parts: Vec<&str> = task.split("and then").collect();
467            if then_parts.len() == 1 {
468                then_parts = task.split("then").collect();
469            }
470            if then_parts.len() == 1 {
471                then_parts = task.split("after that").collect();
472            }
473
474            if then_parts.len() > 1 {
475                steps = then_parts
476                    .into_iter()
477                    .map(|s| s.trim().to_string())
478                    .collect();
479            } else {
480                // 单一任务
481                steps.push(task.to_string());
482            }
483        }
484
485        steps
486    }
487
488    /// 提取并行部分
489    fn extract_parallel_parts(&self, task: &str) -> Vec<String> {
490        // 按 "and" 或逗号分割
491        let mut parts: Vec<&str> = task.split(", and ").collect();
492        if parts.len() == 1 {
493            parts = task.split(" and ").collect();
494        }
495        if parts.len() == 1 {
496            parts = task.split(", ").collect();
497        }
498
499        let parts: Vec<&str> = parts
500            .into_iter()
501            .map(|s| s.trim())
502            .filter(|s| !s.is_empty() && s.len() > 3)
503            .collect();
504
505        if parts.len() > 1 {
506            parts.into_iter().map(|s| s.to_string()).collect()
507        } else {
508            vec![task.to_string()]
509        }
510    }
511}
512
513/// 规划结果
514#[derive(Debug, Clone)]
515pub struct PlanResult {
516    /// 执行计划
517    pub plan: ExecutionPlan,
518    /// 分解质量评分 (0-100)
519    pub quality_score: u32,
520    /// 分解建议
521    pub suggestions: Vec<String>,
522}
523
524impl PlanResult {
525    /// 创建新的规划结果
526    pub fn new(plan: ExecutionPlan) -> Self {
527        let quality_score = Self::calculate_quality(&plan);
528        let suggestions = Self::generate_suggestions(&plan);
529
530        Self {
531            plan,
532            quality_score,
533            suggestions,
534        }
535    }
536
537    /// 计算分解质量
538    fn calculate_quality(plan: &ExecutionPlan) -> u32 {
539        let mut score = 100u32;
540
541        // 检查子任务数量
542        if plan.subtasks.is_empty() {
543            score = 0;
544        } else if plan.subtasks.len() == 1 {
545            score -= 20; // 未分解
546        }
547
548        // 检查执行顺序是否合理
549        if plan.execution_order.len() != plan.subtasks.len() {
550            score -= 30;
551        }
552
553        // 检查是否有验证条件
554        let has_validation = plan
555            .subtasks
556            .iter()
557            .any(|s| !s.validation_criteria.is_empty());
558        if !has_validation {
559            score -= 10;
560        }
561
562        // 检查是否有失败处理
563        let has_fallback = plan.subtasks.iter().any(|s| s.fallback.is_some());
564        if !has_fallback && plan.risk_level >= RiskLevel::High {
565            score -= 15;
566        }
567
568        score
569    }
570
571    /// 生成建议
572    fn generate_suggestions(plan: &ExecutionPlan) -> Vec<String> {
573        let mut suggestions = Vec::new();
574
575        if plan.subtasks.len() == 1 {
576            suggestions.push("Consider breaking down the task into smaller subtasks".to_string());
577        }
578
579        if plan.risk_level >= RiskLevel::High {
580            suggestions.push("High-risk task: consider adding validation steps".to_string());
581        }
582
583        let has_fallback = plan.subtasks.iter().any(|s| s.fallback.is_some());
584        if !has_fallback && !plan.subtasks.is_empty() {
585            suggestions
586                .push("Consider adding fallback strategies for critical subtasks".to_string());
587        }
588
589        suggestions
590    }
591}
592
#[cfg(test)]
mod tests {
    use super::{
        DecompositionStrategy, ExecutionPlan, PlanResult, RiskLevel, SubTask, TaskDecomposer,
    };

    /// Builds a two-step plan in which `s2` depends on `s1`, with its execution
    /// order already computed.
    fn two_step_plan() -> ExecutionPlan {
        let mut plan = ExecutionPlan::new("Test task");
        plan.add_subtask(SubTask::new("s1", "Step 1", "First step"));
        plan.add_subtask(SubTask::new("s2", "Step 2", "Second step").with_dependency("s1"));
        plan.compute_execution_order().unwrap();
        plan
    }

    #[test]
    fn test_subtask_creation() {
        let created = SubTask::new("test_1", "Test", "Test subtask");
        assert_eq!(created.id, "test_1");
        assert_eq!(created.name, "Test");
    }

    #[test]
    fn test_subtask_with_dependencies() {
        let dependent = SubTask::new("test_2", "Test", "Test").with_dependency("test_1");
        assert_eq!(dependent.dependencies.len(), 1);
    }

    #[test]
    fn test_execution_plan_creation() {
        let fresh = ExecutionPlan::new("Test task");
        assert!(!fresh.original_task.is_empty());
        assert!(fresh.subtasks.is_empty());
    }

    #[test]
    fn test_task_decomposer() {
        let plan = TaskDecomposer::new()
            .decompose("Create a file and write some content")
            .unwrap();

        assert!(!plan.subtasks.is_empty());
        assert!(!plan.execution_order.is_empty());
    }

    #[test]
    fn test_complexity_analysis() {
        let decomposer = TaskDecomposer::new();

        let simple_score = decomposer.analyze_complexity("Read a file");
        assert!(simple_score <= 3);

        let complex_score = decomposer.analyze_complexity(
            "Implement a complete authentication system with OAuth2 integration",
        );
        assert!(complex_score > 3);
    }

    #[test]
    fn test_risk_estimation() {
        let decomposer = TaskDecomposer::new();

        assert_eq!(decomposer.estimate_risk("Read a file", 2), RiskLevel::Low);
        assert_eq!(
            decomposer.estimate_risk("Delete the production database", 5),
            RiskLevel::Critical
        );
    }

    #[test]
    fn test_sequential_decomposition() {
        let plan = TaskDecomposer::new()
            .with_strategy(DecompositionStrategy::Sequential)
            .decompose("First step. Second step. Third step.")
            .unwrap();

        assert!(plan.subtasks.len() >= 3);
        // Every subtask must depend on the one directly before it.
        for pair in plan.subtasks.windows(2) {
            assert!(pair[1].dependencies.contains(&pair[0].id));
        }
    }

    #[test]
    fn test_parallel_decomposition() {
        let plan = TaskDecomposer::new()
            .with_strategy(DecompositionStrategy::Parallel)
            .decompose("Task A and Task B and Task C")
            .unwrap();

        assert!(plan.subtasks.len() >= 2);
        // Parallel subtasks must not carry any dependencies.
        assert!(plan.subtasks.iter().all(|s| s.dependencies.is_empty()));
    }

    #[test]
    fn test_plan_result_quality() {
        let result = PlanResult::new(two_step_plan());
        assert!(result.quality_score > 0);
    }

    #[test]
    fn test_dag_conversion() {
        let plan = two_step_plan();
        assert!(plan.to_dag().is_ok());
    }
}