Skip to main content

steer_core/
primary_agents.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::sync::RwLock;
5
6use crate::agents::DEFAULT_AGENT_SPEC_ID;
7use crate::config::model::ModelId;
8use crate::prompts::system_prompt_for_model;
9use crate::session::state::{
10    ApprovalRules, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicy, ToolRule,
11    ToolVisibility, UnapprovedBehavior,
12};
13use crate::tools::DISPATCH_AGENT_TOOL_NAME;
14use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
15
/// Id of the default, general-purpose primary agent.
pub const NORMAL_PRIMARY_AGENT_ID: &str = "normal";
/// Id of the plan-only primary agent (read-only tools, planner prompt).
pub const PLANNER_PRIMARY_AGENT_ID: &str = "plan";
/// Id of the primary agent that allows unapproved tools.
pub const YOLO_PRIMARY_AGENT_ID: &str = "yolo";
/// Id used when a session does not select a primary agent; an alias of the normal agent.
pub const DEFAULT_PRIMARY_AGENT_ID: &str = NORMAL_PRIMARY_AGENT_ID;
20
21static PLANNER_SYSTEM_PROMPT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
22    format!(
23        r#"You are in plan mode. Your job is to either ask clarifying questions or provide a concise plan.
24
25Rules:
26- Use read-only tools to gather the context you need before planning.
27- When broader search is needed, use dispatch_agent with the "explore" sub-agent.
28- Do not make changes or write code/patches.
29- First check whether scope, constraints, or success criteria are missing.
30- If missing details would materially change the plan, ask targeted clarifying questions and stop.
31- Do not provide a plan in the same response as clarifying questions.
32- Prefer clarifying questions over speculative assumptions.
33
34When details are missing, respond using:
35Questions:
361. ...
372. ...
38
39When you can proceed, respond using this structure:
40Plan:
411. ...
422. ...
433. ...
44
45Assumptions:
46- ... (only include if required and specific)
47
48Risks:
49- ... (only include concrete, non-generic risks)
50
51Validation:
52- ... (only include specific, testable checks)
53
54Execution note:
55- Plan mode cannot execute changes.
56- If execution is needed, mention that switching to "{NORMAL_PRIMARY_AGENT_ID}" or "{YOLO_PRIMARY_AGENT_ID}" is required."#,
57    )
58});
59
// Declarative description of a primary agent (e.g. "normal", "plan", "yolo").
//
// NOTE(review): plain `//` comments are used on purpose. schemars' `JsonSchema` derive turns
// `///` doc comments into schema `description`s, so converting these would change the
// generated schema — confirm before switching.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PrimaryAgentSpec {
    // Registry key. `resolve_effective_config` writes the resolved id into the session's
    // `primary_agent_id`.
    pub id: String,
    // Display name (e.g. "Normal", "Plan", "Yolo").
    pub name: String,
    // Short description of the agent's tool access and approval behavior.
    pub description: String,
    // When `Some`, replaces the session's `default_model`; a policy-override model still wins.
    pub model: Option<ModelId>,
    // When `Some`, used verbatim as the session system prompt. When `None`, the session's own
    // prompt or the model's default prompt is used instead.
    pub system_prompt: Option<String>,
    // Tool visibility used unless `policy_overrides.tool_visibility` is set.
    pub tool_visibility: ToolVisibility,
    // Base approval policy; session approval-policy overrides are applied on top of it.
    pub approval_policy: ToolApprovalPolicy,
}
70
71static PRIMARY_AGENT_SPECS: std::sync::LazyLock<RwLock<HashMap<String, PrimaryAgentSpec>>> =
72    std::sync::LazyLock::new(|| {
73        let mut specs = HashMap::new();
74        for spec in default_primary_agent_specs() {
75            specs.insert(spec.id.clone(), spec);
76        }
77        RwLock::new(specs)
78    });
79
80pub fn primary_agent_spec(id: &str) -> Option<PrimaryAgentSpec> {
81    let registry = PRIMARY_AGENT_SPECS.read().ok()?;
82    registry.get(id).cloned()
83}
84
85pub fn primary_agent_specs() -> Vec<PrimaryAgentSpec> {
86    let registry = match PRIMARY_AGENT_SPECS.read() {
87        Ok(registry) => registry,
88        Err(_) => return Vec::new(),
89    };
90    let mut specs: Vec<_> = registry.values().cloned().collect();
91    specs.sort_by(|a, b| a.id.cmp(&b.id));
92    specs
93}
94
95pub fn default_primary_agent_id() -> &'static str {
96    DEFAULT_PRIMARY_AGENT_ID
97}
98
99pub fn resolve_effective_config(base_config: &SessionConfig) -> SessionConfig {
100    let mut config = base_config.clone();
101
102    let requested_agent_id = base_config
103        .primary_agent_id
104        .clone()
105        .unwrap_or_else(|| DEFAULT_PRIMARY_AGENT_ID.to_string());
106
107    let (primary_agent_id, spec) = if let Some(spec) = primary_agent_spec(&requested_agent_id) {
108        (requested_agent_id, spec)
109    } else {
110        let fallback_spec =
111            primary_agent_spec(DEFAULT_PRIMARY_AGENT_ID).unwrap_or_else(|| PrimaryAgentSpec {
112                id: DEFAULT_PRIMARY_AGENT_ID.to_string(),
113                name: "Normal".to_string(),
114                description: "Default agent".to_string(),
115                model: None,
116                system_prompt: None,
117                tool_visibility: ToolVisibility::All,
118                approval_policy: ToolApprovalPolicy::default(),
119            });
120        (DEFAULT_PRIMARY_AGENT_ID.to_string(), fallback_spec)
121    };
122
123    let effective_model = resolve_default_model(&config, &spec, &base_config.policy_overrides);
124    let effective_visibility = base_config
125        .policy_overrides
126        .tool_visibility
127        .clone()
128        .unwrap_or_else(|| spec.tool_visibility.clone());
129    let effective_approval_policy = base_config
130        .policy_overrides
131        .approval_policy
132        .apply_to(&spec.approval_policy);
133    let effective_system_prompt = resolve_system_prompt(
134        &config,
135        &spec,
136        &effective_model,
137        &base_config.policy_overrides,
138    );
139
140    config.primary_agent_id = Some(primary_agent_id);
141    config.default_model = effective_model;
142    config.tool_config.visibility = effective_visibility;
143    config.tool_config.approval_policy = effective_approval_policy;
144    config.system_prompt = effective_system_prompt;
145
146    config
147}
148
149fn resolve_default_model(
150    config: &SessionConfig,
151    spec: &PrimaryAgentSpec,
152    overrides: &SessionPolicyOverrides,
153) -> ModelId {
154    let mut model = config.default_model.clone();
155    if let Some(spec_model) = spec.model.as_ref() {
156        model = spec_model.clone();
157    }
158    if let Some(override_model) = overrides.default_model.as_ref() {
159        model = override_model.clone();
160    }
161    model
162}
163
164fn resolve_system_prompt(
165    config: &SessionConfig,
166    spec: &PrimaryAgentSpec,
167    model: &ModelId,
168    _overrides: &SessionPolicyOverrides,
169) -> Option<String> {
170    if let Some(prompt) = spec.system_prompt.as_ref() {
171        return Some(prompt.clone());
172    }
173
174    if let Some(prompt) = config.system_prompt.as_ref()
175        && !prompt.trim().is_empty()
176        && !is_known_primary_agent_prompt(prompt)
177    {
178        return Some(prompt.clone());
179    }
180
181    Some(system_prompt_for_model(model))
182}
183
184fn is_known_primary_agent_prompt(prompt: &str) -> bool {
185    primary_agent_specs()
186        .iter()
187        .any(|spec| spec.system_prompt.as_deref() == Some(prompt))
188}
189
190fn approval_policy_with_dispatch_explore_preapproval() -> ToolApprovalPolicy {
191    let mut policy = ToolApprovalPolicy::default();
192    policy.preapproved.per_tool.insert(
193        DISPATCH_AGENT_TOOL_NAME.to_string(),
194        ToolRule::DispatchAgent {
195            agent_patterns: vec![DEFAULT_AGENT_SPEC_ID.to_string()],
196        },
197    );
198    policy
199}
200
201fn default_primary_agent_specs() -> Vec<PrimaryAgentSpec> {
202    let planner_tool_visibility = ToolVisibility::Whitelist(
203        READ_ONLY_TOOL_NAMES
204            .iter()
205            .map(|name| (*name).to_string())
206            .chain(std::iter::once(DISPATCH_AGENT_TOOL_NAME.to_string()))
207            .collect::<HashSet<_>>(),
208    );
209
210    vec![
211        PrimaryAgentSpec {
212            id: NORMAL_PRIMARY_AGENT_ID.to_string(),
213            name: "Normal".to_string(),
214            description: "Default agent with full tool visibility. Tools which can write require explicit approvals."
215                .to_string(),
216            model: None,
217            system_prompt: None,
218            tool_visibility: ToolVisibility::All,
219            approval_policy: approval_policy_with_dispatch_explore_preapproval(),
220        },
221        PrimaryAgentSpec {
222            id: PLANNER_PRIMARY_AGENT_ID.to_string(),
223            name: "Plan".to_string(),
224            description: "Plan-only agent with read-only tools.".to_string(),
225            model: None,
226            system_prompt: Some(PLANNER_SYSTEM_PROMPT.clone()),
227            tool_visibility: planner_tool_visibility,
228            approval_policy: approval_policy_with_dispatch_explore_preapproval(),
229        },
230        PrimaryAgentSpec {
231            id: YOLO_PRIMARY_AGENT_ID.to_string(),
232            name: "Yolo".to_string(),
233            description: "Full tool visibility with auto-approval for all tools.".to_string(),
234            model: None,
235            system_prompt: None,
236            tool_visibility: ToolVisibility::All,
237            approval_policy: ToolApprovalPolicy {
238                default_behavior: UnapprovedBehavior::Allow,
239                preapproved: ApprovalRules::default(),
240            },
241        },
242    ]
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::model::builtin;
    use crate::session::state::{
        ApprovalRulesOverrides, SessionPolicyOverrides, SessionToolConfig,
        ToolApprovalPolicyOverrides, ToolDecision, ToolRule, ToolRuleOverrides, WorkspaceConfig,
    };
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use std::collections::HashMap;
    use std::path::PathBuf;
    use steer_tools::tools::{TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME};

    /// Minimal session config with a custom system prompt and no primary agent selected.
    fn base_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/tmp"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::read_only(),
            system_prompt: Some("base prompt".to_string()),
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            metadata: HashMap::new(),
            default_model: builtin::claude_sonnet_4_5(),
            auto_compaction: crate::session::state::AutoCompactionConfig::default(),
        }
    }

    #[test]
    fn default_primary_agent_exists() {
        let id = default_primary_agent_id();
        let spec = primary_agent_spec(id);
        assert!(spec.is_some());
        assert_eq!(spec.unwrap().id, id);
    }

    #[test]
    fn resolve_effective_config_preserves_base_when_unset() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        let updated = resolve_effective_config(&config);

        assert_eq!(updated.default_model, config.default_model);
        assert_eq!(updated.system_prompt, config.system_prompt);
        assert_eq!(updated.tool_config.visibility, ToolVisibility::All);
    }

    #[test]
    fn resolve_effective_config_applies_overrides() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        config.policy_overrides = SessionPolicyOverrides {
            default_model: Some(builtin::claude_haiku_4_5()),
            tool_visibility: Some(ToolVisibility::ReadOnly),
            approval_policy: ToolApprovalPolicyOverrides {
                preapproved: ApprovalRulesOverrides {
                    tools: ["custom_tool".to_string()].into_iter().collect(),
                    per_tool: [(
                        "bash".to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    )]
                    .into_iter()
                    .collect(),
                },
            },
        };

        let updated = resolve_effective_config(&config);
        assert_eq!(updated.default_model, builtin::claude_haiku_4_5());
        assert_eq!(updated.tool_config.visibility, ToolVisibility::ReadOnly);
        assert_eq!(
            updated.tool_config.approval_policy.default_behavior,
            UnapprovedBehavior::Prompt
        );
        assert!(
            updated
                .tool_config
                .approval_policy
                .preapproved
                .tools
                .contains("custom_tool")
        );

        let rule = updated
            .tool_config
            .approval_policy
            .preapproved
            .per_tool
            .get("bash")
            .expect("bash rule");
        match rule {
            ToolRule::Bash { patterns } => {
                assert!(patterns.contains(&"git status".to_string()));
            }
            ToolRule::DispatchAgent { .. } => panic!("Unexpected dispatch agent rule"),
        }
    }

    #[test]
    fn resolve_effective_config_falls_back_to_default_agent_for_unknown_id() {
        let mut config = base_config();
        config.primary_agent_id = Some("no-such-primary-agent".to_string());
        let updated = resolve_effective_config(&config);

        // The id actually used must be recorded, not the unknown one that was requested.
        assert_eq!(
            updated.primary_agent_id.as_deref(),
            Some(DEFAULT_PRIMARY_AGENT_ID)
        );
        let default_spec = primary_agent_spec(DEFAULT_PRIMARY_AGENT_ID).expect("default spec");
        assert_eq!(updated.tool_config.visibility, default_spec.tool_visibility);
    }

    #[test]
    fn planner_prompt_is_replaced_when_switching_back_to_normal() {
        let mut config = base_config();
        config.primary_agent_id = Some(PLANNER_PRIMARY_AGENT_ID.to_string());
        let planned = resolve_effective_config(&config);
        assert_eq!(
            planned.system_prompt.as_deref(),
            Some(PLANNER_SYSTEM_PROMPT.as_str())
        );

        // The planner prompt is owned by a primary agent, so it must not leak into normal mode.
        let mut switched = planned.clone();
        switched.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        let normal = resolve_effective_config(&switched);
        assert_eq!(
            normal.system_prompt,
            Some(system_prompt_for_model(&normal.default_model))
        );
    }

    #[test]
    fn normal_spec_preapproves_dispatch_agent_for_explore() {
        let spec = primary_agent_spec(NORMAL_PRIMARY_AGENT_ID).expect("normal spec");

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    #[test]
    fn plan_spec_limits_tools_and_dispatch_agent() {
        let spec = primary_agent_spec(PLANNER_PRIMARY_AGENT_ID).expect("plan spec");

        match &spec.tool_visibility {
            ToolVisibility::Whitelist(allowed) => {
                assert!(allowed.contains(DISPATCH_AGENT_TOOL_NAME));
                for name in READ_ONLY_TOOL_NAMES {
                    assert!(allowed.contains(*name));
                }
                assert_eq!(allowed.len(), READ_ONLY_TOOL_NAMES.len() + 1);
            }
            other => panic!("Unexpected tool visibility: {other:?}"),
        }

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    #[test]
    fn all_primary_agents_allow_todos_by_default() {
        for agent_id in [
            NORMAL_PRIMARY_AGENT_ID,
            PLANNER_PRIMARY_AGENT_ID,
            YOLO_PRIMARY_AGENT_ID,
        ] {
            let spec = primary_agent_spec(agent_id).expect("primary agent spec");
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_READ_TOOL_NAME),
                ToolDecision::Allow
            );
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_WRITE_TOOL_NAME),
                ToolDecision::Allow
            );
        }
    }
}