// sgr_agent/agents/planning.rs

1//! PlanningAgent — read-only agent variant that produces structured plans.
2//!
3//! Wraps any `Agent` impl and restricts tools to a read-only subset.
4//! The agent explores the codebase, then calls `submit_plan` with a structured plan.
5//!
6//! # Usage
7//!
8//! ```ignore
9//! let inner = SgrAgent::new(client, PLAN_SYSTEM_PROMPT);
10//! let planner = PlanningAgent::new(Box::new(inner));
11//!
12//! // Register read-only tools + PlanTool
13//! let tools = ToolRegistry::new()
14//!     .register(ReadFile)
15//!     .register(ListDir)
16//!     .register(SearchCode)
17//!     .register(PlanTool)
18//!     .register(ClarificationTool);
19//!
20//! run_loop(&planner, &tools, &mut ctx, &mut msgs, &config, |e| { ... }).await?;
21//!
22//! // Extract the plan
23//! let plan: Plan = Plan::from_context(&ctx).unwrap();
24//! ```
25
26use crate::agent::{Agent, AgentError, Decision};
27use crate::context::AgentContext;
28use crate::registry::ToolRegistry;
29use crate::types::Message;
30use serde_json::Value;
31
/// A structured plan produced by the PlanningAgent.
///
/// Deserialized from the `"plan"` context entry by `Plan::from_context` and
/// rendered as text by `Plan::to_message`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Plan {
    /// Free-form overview; `to_message` prints it directly under the plan heading.
    pub summary: String,
    /// Steps in execution order; `to_message` numbers them starting at 1.
    pub steps: Vec<PlanStep>,
}
38
/// A single step in the plan.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PlanStep {
    /// What to do in this step. The only field that must be present when deserializing.
    pub description: String,
    /// Files associated with the step. `#[serde(default)]` makes the key optional:
    /// an absent key deserializes to an empty list.
    #[serde(default)]
    pub files: Vec<String>,
    /// Tool-name hints for the step (presumably the tools an executor should
    /// reach for — confirm against the tool that submits plans). Optional in the
    /// serialized form; an absent key deserializes to an empty list.
    #[serde(default)]
    pub tool_hints: Vec<String>,
}
48
49impl Plan {
50    /// Extract plan from AgentContext (set by PlanTool).
51    pub fn from_context(ctx: &AgentContext) -> Option<Self> {
52        ctx.get("plan")
53            .and_then(|v| serde_json::from_value(v.clone()).ok())
54    }
55
56    /// Convert plan to a message for injection into a build agent's context.
57    pub fn to_message(&self) -> Message {
58        let mut text = format!("## Implementation Plan\n\n{}\n\n", self.summary);
59        for (i, step) in self.steps.iter().enumerate() {
60            text.push_str(&format!("{}. {}\n", i + 1, step.description));
61            if !step.files.is_empty() {
62                text.push_str(&format!("   Files: {}\n", step.files.join(", ")));
63            }
64        }
65        Message::system(&text)
66    }
67}
68
/// Tool names that are safe for read-only plan mode.
///
/// This is the default allow-list installed by `PlanningAgent::new`; names are
/// compared ASCII case-insensitively in `prepare_tools`.
pub const READ_ONLY_TOOLS: &[&str] = &[
    // File and directory inspection
    "read_file",
    "list_files",
    "list_dir",
    // Content / path search
    "search",
    "search_code",
    "grep",
    "glob",
    // Repository inspection
    "git_status",
    "git_diff",
    "git_log",
    // Working-directory helpers
    "get_cwd",
    // NOTE(review): presumably changes the working directory rather than only
    // reading state — confirm it is intentionally treated as read-only.
    "change_dir",
    // System tools (always allowed). `prepare_tools` also admits any tool whose
    // `is_system()` is true, independent of this list.
    "ask_user",
    "submit_plan",
    "finish_task",
];
88
/// Wraps any Agent to enforce read-only tool access for planning.
///
/// Filters tools via `prepare_tools` to only allow read-only operations.
/// Sets `plan_mode: true` in context so tools can check and adapt behavior.
pub struct PlanningAgent {
    /// Wrapped agent; `decide`, `prepare_context` and `after_action` are delegated to it.
    inner: Box<dyn Agent>,
    /// Names of the tools `prepare_tools` lets through (matched ASCII
    /// case-insensitively). Starts out as a copy of `READ_ONLY_TOOLS`.
    allowed_tools: Vec<String>,
}
97
98impl PlanningAgent {
99    pub fn new(inner: Box<dyn Agent>) -> Self {
100        Self {
101            inner,
102            allowed_tools: READ_ONLY_TOOLS.iter().map(|s| s.to_string()).collect(),
103        }
104    }
105
106    /// Override the set of allowed tools (replaces default READ_ONLY_TOOLS).
107    pub fn with_allowed_tools(mut self, tools: Vec<String>) -> Self {
108        self.allowed_tools = tools;
109        self
110    }
111
112    /// Add extra tools to the allowed set (e.g. custom read-only tools).
113    pub fn allow_tool(mut self, name: impl Into<String>) -> Self {
114        self.allowed_tools.push(name.into());
115        self
116    }
117}
118
119#[async_trait::async_trait]
120impl Agent for PlanningAgent {
121    async fn decide(
122        &self,
123        messages: &[Message],
124        tools: &ToolRegistry,
125    ) -> Result<Decision, AgentError> {
126        self.inner.decide(messages, tools).await
127    }
128
129    fn prepare_tools(&self, _ctx: &AgentContext, tools: &ToolRegistry) -> Vec<String> {
130        tools
131            .list()
132            .iter()
133            .filter(|t| {
134                t.is_system()
135                    || self
136                        .allowed_tools
137                        .iter()
138                        .any(|a| a.eq_ignore_ascii_case(t.name()))
139            })
140            .map(|t| t.name().to_string())
141            .collect()
142    }
143
144    fn prepare_context(&self, ctx: &mut AgentContext, messages: &[Message]) {
145        ctx.set("plan_mode", Value::Bool(true));
146        self.inner.prepare_context(ctx, messages);
147    }
148
149    fn after_action(&self, ctx: &mut AgentContext, tool_name: &str, output: &str) {
150        self.inner.after_action(ctx, tool_name, output);
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::registry::ToolRegistry;

    /// Mock agent whose `decide` issues no tool calls and reports completion immediately.
    struct MockAgent;

    #[async_trait::async_trait]
    impl Agent for MockAgent {
        async fn decide(&self, _: &[Message], _: &ToolRegistry) -> Result<Decision, AgentError> {
            Ok(Decision {
                situation: "planning".into(),
                task: vec![],
                tool_calls: vec![],
                completed: true,
            })
        }
    }

    /// Declares a unit-struct `Tool` stub with a fixed name, description and
    /// text output. Each stub stays a distinct type, as in hand-written mocks.
    macro_rules! stub_tool {
        ($ty:ident, $name:literal, $description:literal, $output:literal) => {
            struct $ty;

            #[async_trait::async_trait]
            impl Tool for $ty {
                fn name(&self) -> &str {
                    $name
                }
                fn description(&self) -> &str {
                    $description
                }
                fn parameters_schema(&self) -> Value {
                    serde_json::json!({"type": "object"})
                }
                async fn execute(
                    &self,
                    _: Value,
                    _: &mut AgentContext,
                ) -> Result<ToolOutput, ToolError> {
                    Ok(ToolOutput::text($output))
                }
            }
        };
    }

    stub_tool!(ReadFileTool, "read_file", "read", "content");
    stub_tool!(WriteFileTool, "write_file", "write", "written");
    stub_tool!(BashTool, "bash", "bash", "output");
    stub_tool!(CustomSearchTool, "custom_search", "custom search", "results");

    /// True when `names` contains exactly `wanted`.
    fn has(names: &[String], wanted: &str) -> bool {
        names.iter().any(|name| name == wanted)
    }

    #[test]
    fn planning_filters_write_tools() {
        let planner = PlanningAgent::new(Box::new(MockAgent));
        let tools = ToolRegistry::new()
            .register(ReadFileTool)
            .register(WriteFileTool)
            .register(BashTool);

        let ctx = AgentContext::new();
        let allowed = planner.prepare_tools(&ctx, &tools);

        assert!(has(&allowed, "read_file"));
        assert!(!has(&allowed, "write_file"));
        assert!(!has(&allowed, "bash"));
    }

    #[test]
    fn planning_sets_plan_mode_in_context() {
        let planner = PlanningAgent::new(Box::new(MockAgent));
        let mut ctx = AgentContext::new();
        let msgs = vec![Message::user("plan this")];

        planner.prepare_context(&mut ctx, &msgs);
        assert_eq!(ctx.get("plan_mode"), Some(&Value::Bool(true)));
    }

    #[test]
    fn plan_from_context() {
        let mut ctx = AgentContext::new();
        ctx.set(
            "plan",
            serde_json::json!({
                "summary": "Add auth",
                "steps": [
                    {"description": "Create module", "files": ["src/auth.rs"]},
                    {"description": "Write tests"}
                ]
            }),
        );

        let plan = Plan::from_context(&ctx).unwrap();
        assert_eq!(plan.summary, "Add auth");
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].files, vec!["src/auth.rs"]);
        assert!(plan.steps[1].files.is_empty());
        assert!(plan.steps[1].tool_hints.is_empty());
    }

    #[test]
    fn plan_from_context_missing_or_malformed_is_none() {
        let mut ctx = AgentContext::new();
        assert!(Plan::from_context(&ctx).is_none());

        // Present but not shaped like a Plan
        ctx.set("plan", serde_json::json!("not a plan"));
        assert!(Plan::from_context(&ctx).is_none());
    }

    #[test]
    fn plan_to_message() {
        let plan = Plan {
            summary: "Refactor auth".into(),
            steps: vec![
                PlanStep {
                    description: "Extract trait".into(),
                    files: vec!["src/auth.rs".into()],
                    tool_hints: vec![],
                },
                PlanStep {
                    description: "Add tests".into(),
                    files: vec![],
                    tool_hints: vec![],
                },
            ],
        };
        let msg = plan.to_message();
        assert!(msg.content.contains("Refactor auth"));
        assert!(msg.content.contains("1. Extract trait"));
        assert!(msg.content.contains("2. Add tests"));
        assert!(msg.content.contains("src/auth.rs"));
    }

    #[test]
    fn allow_extra_tools() {
        let tools = ToolRegistry::new()
            .register(ReadFileTool)
            .register(WriteFileTool)
            .register(CustomSearchTool);
        let ctx = AgentContext::new();

        // Without the opt-in, the custom tool is filtered out.
        let default_planner = PlanningAgent::new(Box::new(MockAgent));
        assert!(!has(&default_planner.prepare_tools(&ctx, &tools), "custom_search"));

        // With the opt-in it is exposed, and write tools stay blocked.
        let planner = PlanningAgent::new(Box::new(MockAgent)).allow_tool("custom_search");
        let allowed = planner.prepare_tools(&ctx, &tools);
        assert!(has(&allowed, "read_file"));
        assert!(has(&allowed, "custom_search"));
        assert!(!has(&allowed, "write_file"));
    }

    #[test]
    fn with_allowed_tools_replaces_defaults() {
        let planner =
            PlanningAgent::new(Box::new(MockAgent)).with_allowed_tools(vec!["bash".to_string()]);
        let tools = ToolRegistry::new().register(ReadFileTool).register(BashTool);

        let ctx = AgentContext::new();
        let allowed = planner.prepare_tools(&ctx, &tools);
        assert!(has(&allowed, "bash"));
        assert!(!has(&allowed, "read_file"));
    }
}