1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::sync::RwLock;
5
6use crate::agents::DEFAULT_AGENT_SPEC_ID;
7use crate::config::model::ModelId;
8use crate::prompts::system_prompt_for_model;
9use crate::session::state::{
10 ApprovalRules, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicy, ToolRule,
11 ToolVisibility, UnapprovedBehavior,
12};
13use crate::tools::DISPATCH_AGENT_TOOL_NAME;
14use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
15
/// Id of the default primary agent: full tool visibility, write-capable tools need approval.
pub const NORMAL_PRIMARY_AGENT_ID: &str = "normal";
/// Id of the plan-only primary agent (read-only tools plus sub-agent dispatch).
pub const PLANNER_PRIMARY_AGENT_ID: &str = "plan";
/// Id of the primary agent that auto-approves every tool call.
pub const YOLO_PRIMARY_AGENT_ID: &str = "yolo";
/// Primary agent used when a session selects none, or selects an id that is not registered.
pub const DEFAULT_PRIMARY_AGENT_ID: &str = NORMAL_PRIMARY_AGENT_ID;
20
21static PLANNER_SYSTEM_PROMPT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
22 format!(
23 r#"You are in plan mode. Produce a concise, step-by-step plan only.
24
25Rules:
26- Use read-only tools to gather the context you need before planning.
27- When broader search is needed, use dispatch_agent with the "explore" sub-agent.
28- Do not make changes or write code/patches.
29- If key details are missing, ask up to three targeted questions and stop.
30
31When you can proceed, respond using this structure (omit empty sections):
32Plan:
331. ...
342. ...
353. ...
36
37Assumptions:
38- ...
39
40Risks:
41- ...
42
43Validation:
44- ...
45
46Finish by asking the user to switch back to "{NORMAL_PRIMARY_AGENT_ID}" or "{YOLO_PRIMARY_AGENT_ID}" to execute."#,
47 )
48});
49
// Declarative description of a primary agent. It covers the model and system prompt the
// agent pins (if any), which tools it can see, and its base tool-approval policy.
//
// NOTE: plain `//` comments are used on this type on purpose. `///` doc comments would be
// picked up by the `JsonSchema` derive as schema descriptions and change the generated schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PrimaryAgentSpec {
    // Unique identifier; the registry is keyed by this value.
    pub id: String,
    // Human-readable name (e.g. "Normal", "Plan", "Yolo" for the built-ins).
    pub name: String,
    // Short human-readable summary of what the agent does.
    pub description: String,
    // Model pinned by this agent; `None` keeps the session's model. A session-level
    // `default_model` override still takes precedence over this value.
    pub model: Option<ModelId>,
    // Agent-specific system prompt; when `Some`, it wins over the session's own prompt.
    pub system_prompt: Option<String>,
    // Tools the agent can see, unless the session overrides tool visibility.
    pub tool_visibility: ToolVisibility,
    // Base approval policy; session approval overrides are applied on top of it.
    pub approval_policy: ToolApprovalPolicy,
}
60
61static PRIMARY_AGENT_SPECS: std::sync::LazyLock<RwLock<HashMap<String, PrimaryAgentSpec>>> =
62 std::sync::LazyLock::new(|| {
63 let mut specs = HashMap::new();
64 for spec in default_primary_agent_specs() {
65 specs.insert(spec.id.clone(), spec);
66 }
67 RwLock::new(specs)
68 });
69
70pub fn primary_agent_spec(id: &str) -> Option<PrimaryAgentSpec> {
71 let registry = PRIMARY_AGENT_SPECS.read().ok()?;
72 registry.get(id).cloned()
73}
74
75pub fn primary_agent_specs() -> Vec<PrimaryAgentSpec> {
76 let registry = match PRIMARY_AGENT_SPECS.read() {
77 Ok(registry) => registry,
78 Err(_) => return Vec::new(),
79 };
80 let mut specs: Vec<_> = registry.values().cloned().collect();
81 specs.sort_by(|a, b| a.id.cmp(&b.id));
82 specs
83}
84
/// Returns the id of the primary agent used when a session does not select one
/// (the value of [`DEFAULT_PRIMARY_AGENT_ID`]).
pub fn default_primary_agent_id() -> &'static str {
    DEFAULT_PRIMARY_AGENT_ID
}
88
89pub fn resolve_effective_config(base_config: &SessionConfig) -> SessionConfig {
90 let mut config = base_config.clone();
91
92 let requested_agent_id = base_config
93 .primary_agent_id
94 .clone()
95 .unwrap_or_else(|| DEFAULT_PRIMARY_AGENT_ID.to_string());
96
97 let (primary_agent_id, spec) = if let Some(spec) = primary_agent_spec(&requested_agent_id) {
98 (requested_agent_id, spec)
99 } else {
100 let fallback_spec =
101 primary_agent_spec(DEFAULT_PRIMARY_AGENT_ID).unwrap_or_else(|| PrimaryAgentSpec {
102 id: DEFAULT_PRIMARY_AGENT_ID.to_string(),
103 name: "Normal".to_string(),
104 description: "Default agent".to_string(),
105 model: None,
106 system_prompt: None,
107 tool_visibility: ToolVisibility::All,
108 approval_policy: ToolApprovalPolicy::default(),
109 });
110 (DEFAULT_PRIMARY_AGENT_ID.to_string(), fallback_spec)
111 };
112
113 let effective_model = resolve_default_model(&config, &spec, &base_config.policy_overrides);
114 let effective_visibility = base_config
115 .policy_overrides
116 .tool_visibility
117 .clone()
118 .unwrap_or_else(|| spec.tool_visibility.clone());
119 let effective_approval_policy = base_config
120 .policy_overrides
121 .approval_policy
122 .apply_to(&spec.approval_policy);
123 let effective_system_prompt = resolve_system_prompt(
124 &config,
125 &spec,
126 &effective_model,
127 &base_config.policy_overrides,
128 );
129
130 config.primary_agent_id = Some(primary_agent_id);
131 config.default_model = effective_model;
132 config.tool_config.visibility = effective_visibility;
133 config.tool_config.approval_policy = effective_approval_policy;
134 config.system_prompt = effective_system_prompt;
135
136 config
137}
138
139fn resolve_default_model(
140 config: &SessionConfig,
141 spec: &PrimaryAgentSpec,
142 overrides: &SessionPolicyOverrides,
143) -> ModelId {
144 let mut model = config.default_model.clone();
145 if let Some(spec_model) = spec.model.as_ref() {
146 model = spec_model.clone();
147 }
148 if let Some(override_model) = overrides.default_model.as_ref() {
149 model = override_model.clone();
150 }
151 model
152}
153
154fn resolve_system_prompt(
155 config: &SessionConfig,
156 spec: &PrimaryAgentSpec,
157 model: &ModelId,
158 _overrides: &SessionPolicyOverrides,
159) -> Option<String> {
160 if let Some(prompt) = spec.system_prompt.as_ref() {
161 return Some(prompt.clone());
162 }
163
164 if let Some(prompt) = config.system_prompt.as_ref()
165 && !prompt.trim().is_empty()
166 && !is_known_primary_agent_prompt(prompt)
167 {
168 return Some(prompt.clone());
169 }
170
171 Some(system_prompt_for_model(model))
172}
173
174fn is_known_primary_agent_prompt(prompt: &str) -> bool {
175 primary_agent_specs()
176 .iter()
177 .any(|spec| spec.system_prompt.as_deref() == Some(prompt))
178}
179
180fn approval_policy_with_dispatch_explore_preapproval() -> ToolApprovalPolicy {
181 let mut policy = ToolApprovalPolicy::default();
182 policy.preapproved.per_tool.insert(
183 DISPATCH_AGENT_TOOL_NAME.to_string(),
184 ToolRule::DispatchAgent {
185 agent_patterns: vec![DEFAULT_AGENT_SPEC_ID.to_string()],
186 },
187 );
188 policy
189}
190
191fn default_primary_agent_specs() -> Vec<PrimaryAgentSpec> {
192 let planner_tool_visibility = ToolVisibility::Whitelist(
193 READ_ONLY_TOOL_NAMES
194 .iter()
195 .map(|name| (*name).to_string())
196 .chain(std::iter::once(DISPATCH_AGENT_TOOL_NAME.to_string()))
197 .collect::<HashSet<_>>(),
198 );
199
200 vec![
201 PrimaryAgentSpec {
202 id: NORMAL_PRIMARY_AGENT_ID.to_string(),
203 name: "Normal".to_string(),
204 description: "Default agent with full tool visibility. Tools which can write require explicit approvals."
205 .to_string(),
206 model: None,
207 system_prompt: None,
208 tool_visibility: ToolVisibility::All,
209 approval_policy: approval_policy_with_dispatch_explore_preapproval(),
210 },
211 PrimaryAgentSpec {
212 id: PLANNER_PRIMARY_AGENT_ID.to_string(),
213 name: "Plan".to_string(),
214 description: "Plan-only agent with read-only tools.".to_string(),
215 model: None,
216 system_prompt: Some(PLANNER_SYSTEM_PROMPT.clone()),
217 tool_visibility: planner_tool_visibility,
218 approval_policy: approval_policy_with_dispatch_explore_preapproval(),
219 },
220 PrimaryAgentSpec {
221 id: YOLO_PRIMARY_AGENT_ID.to_string(),
222 name: "Yolo".to_string(),
223 description: "Full tool visibility with auto-approval for all tools.".to_string(),
224 model: None,
225 system_prompt: None,
226 tool_visibility: ToolVisibility::All,
227 approval_policy: ToolApprovalPolicy {
228 default_behavior: UnapprovedBehavior::Allow,
229 preapproved: ApprovalRules::default(),
230 },
231 },
232 ]
233}
234
// Unit tests for the built-in primary-agent specs and for `resolve_effective_config`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::model::builtin;
    use crate::session::state::{
        ApprovalRulesOverrides, SessionPolicyOverrides, SessionToolConfig,
        ToolApprovalPolicyOverrides, ToolDecision, ToolRule, ToolRuleOverrides, WorkspaceConfig,
    };
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use std::collections::HashMap;
    use std::path::PathBuf;
    use steer_tools::tools::{TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME};

    /// Minimal session config shared by the tests. It has a local `/tmp` workspace,
    /// read-only tools, a custom "base prompt", no primary agent selected, and empty
    /// policy overrides.
    fn base_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/tmp"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::read_only(),
            system_prompt: Some("base prompt".to_string()),
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            metadata: HashMap::new(),
            default_model: builtin::claude_sonnet_4_5(),
        }
    }

    // The advertised default id must actually be registered.
    #[test]
    fn default_primary_agent_exists() {
        let id = default_primary_agent_id();
        let spec = primary_agent_spec(id);
        assert!(spec.is_some());
        assert_eq!(spec.unwrap().id, id);
    }

    // The normal agent pins neither a model nor a prompt, so the base model and the
    // custom base prompt survive. Visibility comes from the spec (`All`).
    #[test]
    fn resolve_effective_config_preserves_base_when_unset() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        let updated = resolve_effective_config(&config);

        assert_eq!(updated.default_model, config.default_model);
        assert_eq!(updated.system_prompt, config.system_prompt);
        assert_eq!(updated.tool_config.visibility, ToolVisibility::All);
    }

    // Session-level overrides take precedence over the spec for model, tool visibility,
    // and approval policy (default behavior, pre-approved tools, per-tool rules).
    #[test]
    fn resolve_effective_config_applies_overrides() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        config.policy_overrides = SessionPolicyOverrides {
            default_model: Some(builtin::claude_haiku_4_5()),
            tool_visibility: Some(ToolVisibility::ReadOnly),
            approval_policy: ToolApprovalPolicyOverrides {
                default_behavior: Some(UnapprovedBehavior::Allow),
                preapproved: ApprovalRulesOverrides {
                    tools: ["custom_tool".to_string()].into_iter().collect(),
                    per_tool: [(
                        "bash".to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    )]
                    .into_iter()
                    .collect(),
                },
            },
        };

        let updated = resolve_effective_config(&config);
        assert_eq!(updated.default_model, builtin::claude_haiku_4_5());
        assert_eq!(updated.tool_config.visibility, ToolVisibility::ReadOnly);
        assert_eq!(
            updated.tool_config.approval_policy.default_behavior,
            UnapprovedBehavior::Allow
        );
        assert!(
            updated
                .tool_config
                .approval_policy
                .preapproved
                .tools
                .contains("custom_tool")
        );

        let rule = updated
            .tool_config
            .approval_policy
            .preapproved
            .per_tool
            .get("bash")
            .expect("bash rule");
        match rule {
            ToolRule::Bash { patterns } => {
                assert!(patterns.contains(&"git status".to_string()));
            }
            ToolRule::DispatchAgent { .. } => panic!("Unexpected dispatch agent rule"),
        }
    }

    // The normal agent pre-approves dispatching only the default sub-agent.
    #[test]
    fn normal_spec_preapproves_dispatch_agent_for_explore() {
        let spec = primary_agent_spec(NORMAL_PRIMARY_AGENT_ID).expect("normal spec");

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    // Plan mode whitelists exactly the read-only tools plus the dispatch tool, and
    // pre-approves dispatching only the default sub-agent.
    #[test]
    fn plan_spec_limits_tools_and_dispatch_agent() {
        let spec = primary_agent_spec(PLANNER_PRIMARY_AGENT_ID).expect("plan spec");

        match &spec.tool_visibility {
            ToolVisibility::Whitelist(allowed) => {
                assert!(allowed.contains(DISPATCH_AGENT_TOOL_NAME));
                for name in READ_ONLY_TOOL_NAMES {
                    assert!(allowed.contains(*name));
                }
                // The "+ 1" is the dispatch tool; no other tools may sneak in.
                assert_eq!(allowed.len(), READ_ONLY_TOOL_NAMES.len() + 1);
            }
            other => panic!("Unexpected tool visibility: {other:?}"),
        }

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    // Every built-in agent's approval policy allows the todo read/write tools.
    #[test]
    fn all_primary_agents_allow_todos_by_default() {
        for agent_id in [
            NORMAL_PRIMARY_AGENT_ID,
            PLANNER_PRIMARY_AGENT_ID,
            YOLO_PRIMARY_AGENT_ID,
        ] {
            let spec = primary_agent_spec(agent_id).expect("primary agent spec");
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_READ_TOOL_NAME),
                ToolDecision::Allow
            );
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_WRITE_TOOL_NAME),
                ToolDecision::Allow
            );
        }
    }
}