1use crate::agent::{Agent, AgentError, Decision};
27use crate::context::AgentContext;
28use crate::registry::ToolRegistry;
29use crate::types::Message;
30use serde_json::Value;
31
32#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
34pub struct Plan {
35 pub summary: String,
36 pub steps: Vec<PlanStep>,
37}
38
39#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
41pub struct PlanStep {
42 pub description: String,
43 #[serde(default)]
44 pub files: Vec<String>,
45 #[serde(default)]
46 pub tool_hints: Vec<String>,
47}
48
49impl Plan {
50 pub fn from_context(ctx: &AgentContext) -> Option<Self> {
52 ctx.get("plan")
53 .and_then(|v| serde_json::from_value(v.clone()).ok())
54 }
55
56 pub fn to_message(&self) -> Message {
58 let mut text = format!("## Implementation Plan\n\n{}\n\n", self.summary);
59 for (i, step) in self.steps.iter().enumerate() {
60 text.push_str(&format!("{}. {}\n", i + 1, step.description));
61 if !step.files.is_empty() {
62 text.push_str(&format!(" Files: {}\n", step.files.join(", ")));
63 }
64 }
65 Message::system(&text)
66 }
67}
68
/// Tool names a `PlanningAgent` exposes by default (in addition to system
/// tools).
///
/// Covers file and directory inspection, code search, git queries,
/// working-directory navigation, asking the user, and the planning-mode
/// control tools `submit_plan` / `finish_task`.
pub const READ_ONLY_TOOLS: &[&str] = &[
    // Filesystem inspection.
    "read_file", "list_files", "list_dir",
    // Searching.
    "search", "search_code", "grep", "glob",
    // Git queries.
    "git_status", "git_diff", "git_log",
    // Working directory and user interaction.
    "get_cwd", "change_dir", "ask_user",
    // Planning-mode control.
    "submit_plan", "finish_task",
];
88
/// An [`Agent`] wrapper for planning mode.
///
/// Decisions are delegated unchanged to the wrapped agent. What differs is
/// which tools are exposed: only system tools plus the names in
/// `allowed_tools`, which defaults to `READ_ONLY_TOOLS`. It also sets
/// `"plan_mode"` in the context before the wrapped agent prepares it.
pub struct PlanningAgent {
    /// The agent that actually makes decisions.
    inner: Box<dyn Agent>,
    /// Non-system tool names that may be exposed; matched against tool names
    /// ASCII case-insensitively.
    allowed_tools: Vec<String>,
}
97
98impl PlanningAgent {
99 pub fn new(inner: Box<dyn Agent>) -> Self {
100 Self {
101 inner,
102 allowed_tools: READ_ONLY_TOOLS.iter().map(|s| s.to_string()).collect(),
103 }
104 }
105
106 pub fn with_allowed_tools(mut self, tools: Vec<String>) -> Self {
108 self.allowed_tools = tools;
109 self
110 }
111
112 pub fn allow_tool(mut self, name: impl Into<String>) -> Self {
114 self.allowed_tools.push(name.into());
115 self
116 }
117}
118
119#[async_trait::async_trait]
120impl Agent for PlanningAgent {
121 async fn decide(
122 &self,
123 messages: &[Message],
124 tools: &ToolRegistry,
125 ) -> Result<Decision, AgentError> {
126 self.inner.decide(messages, tools).await
127 }
128
129 fn prepare_tools(&self, _ctx: &AgentContext, tools: &ToolRegistry) -> Vec<String> {
130 tools
131 .list()
132 .iter()
133 .filter(|t| {
134 t.is_system()
135 || self
136 .allowed_tools
137 .iter()
138 .any(|a| a.eq_ignore_ascii_case(t.name()))
139 })
140 .map(|t| t.name().to_string())
141 .collect()
142 }
143
144 fn prepare_context(&self, ctx: &mut AgentContext, messages: &[Message]) {
145 ctx.set("plan_mode", Value::Bool(true));
146 self.inner.prepare_context(ctx, messages);
147 }
148
149 fn after_action(&self, ctx: &mut AgentContext, tool_name: &str, output: &str) {
150 self.inner.after_action(ctx, tool_name, output);
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::registry::ToolRegistry;

    /// Inner agent stub: always returns a completed decision with no tool calls.
    struct MockAgent;

    #[async_trait::async_trait]
    impl Agent for MockAgent {
        async fn decide(&self, _: &[Message], _: &ToolRegistry) -> Result<Decision, AgentError> {
            Ok(Decision {
                situation: "planning".into(),
                task: vec![],
                tool_calls: vec![],
                completed: true,
            })
        }
    }

    /// Stub for a tool that is in `READ_ONLY_TOOLS`.
    struct ReadFileTool;
    #[async_trait::async_trait]
    impl Tool for ReadFileTool {
        fn name(&self) -> &str {
            "read_file"
        }
        fn description(&self) -> &str {
            "read"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("content"))
        }
    }

    /// Stub for a mutating tool that planning mode must hide.
    struct WriteFileTool;
    #[async_trait::async_trait]
    impl Tool for WriteFileTool {
        fn name(&self) -> &str {
            "write_file"
        }
        fn description(&self) -> &str {
            "write"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("written"))
        }
    }

    /// Stub for a shell tool that planning mode must hide.
    struct BashTool;
    #[async_trait::async_trait]
    impl Tool for BashTool {
        fn name(&self) -> &str {
            "bash"
        }
        fn description(&self) -> &str {
            "bash"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("output"))
        }
    }

    /// Stub for a tool outside the default allow-list, used to verify that
    /// `allow_tool` actually widens what is exposed.
    struct CustomSearchTool;
    #[async_trait::async_trait]
    impl Tool for CustomSearchTool {
        fn name(&self) -> &str {
            "custom_search"
        }
        fn description(&self) -> &str {
            "custom search"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("results"))
        }
    }

    #[test]
    fn planning_filters_write_tools() {
        let planner = PlanningAgent::new(Box::new(MockAgent));
        let tools = ToolRegistry::new()
            .register(ReadFileTool)
            .register(WriteFileTool)
            .register(BashTool);

        let ctx = AgentContext::new();
        let allowed = planner.prepare_tools(&ctx, &tools);

        assert!(allowed.contains(&"read_file".to_string()));
        assert!(!allowed.contains(&"write_file".to_string()));
        assert!(!allowed.contains(&"bash".to_string()));
    }

    #[test]
    fn planning_sets_plan_mode_in_context() {
        let planner = PlanningAgent::new(Box::new(MockAgent));
        let mut ctx = AgentContext::new();
        let msgs = vec![Message::user("plan this")];

        planner.prepare_context(&mut ctx, &msgs);
        assert_eq!(ctx.get("plan_mode"), Some(&Value::Bool(true)));
    }

    #[test]
    fn plan_from_context() {
        let mut ctx = AgentContext::new();
        ctx.set(
            "plan",
            serde_json::json!({
                "summary": "Add auth",
                "steps": [
                    {"description": "Create module", "files": ["src/auth.rs"]},
                    {"description": "Write tests"}
                ]
            }),
        );

        let plan = Plan::from_context(&ctx).unwrap();
        assert_eq!(plan.summary, "Add auth");
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].files, vec!["src/auth.rs"]);
        assert!(plan.steps[1].files.is_empty());
    }

    #[test]
    fn plan_to_message() {
        let plan = Plan {
            summary: "Refactor auth".into(),
            steps: vec![
                PlanStep {
                    description: "Extract trait".into(),
                    files: vec!["src/auth.rs".into()],
                    tool_hints: vec![],
                },
                PlanStep {
                    description: "Add tests".into(),
                    files: vec![],
                    tool_hints: vec![],
                },
            ],
        };
        let msg = plan.to_message();
        assert!(msg.content.contains("Refactor auth"));
        assert!(msg.content.contains("1. Extract trait"));
        assert!(msg.content.contains("src/auth.rs"));
    }

    #[test]
    fn allow_extra_tools() {
        let planner = PlanningAgent::new(Box::new(MockAgent)).allow_tool("custom_search");

        let tools = ToolRegistry::new()
            .register(ReadFileTool)
            .register(WriteFileTool)
            .register(CustomSearchTool);

        let ctx = AgentContext::new();
        let allowed = planner.prepare_tools(&ctx, &tools);
        // Defaults are kept...
        assert!(allowed.contains(&"read_file".to_string()));
        // ...the explicitly allowed extra tool is now exposed...
        assert!(allowed.contains(&"custom_search".to_string()));
        // ...and tools outside the allow-list stay hidden.
        assert!(!allowed.contains(&"write_file".to_string()));
    }

    #[test]
    fn with_allowed_tools_replaces_defaults() {
        let planner = PlanningAgent::new(Box::new(MockAgent))
            .with_allowed_tools(vec!["write_file".to_string()]);

        let tools = ToolRegistry::new()
            .register(ReadFileTool)
            .register(WriteFileTool);

        let ctx = AgentContext::new();
        let allowed = planner.prepare_tools(&ctx, &tools);
        assert!(allowed.contains(&"write_file".to_string()));
        assert!(!allowed.contains(&"read_file".to_string()));
    }
}