// steer_core/primary_agents.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::sync::RwLock;
5
6use crate::agents::DEFAULT_AGENT_SPEC_ID;
7use crate::config::model::ModelId;
8use crate::prompts::system_prompt_for_model;
9use crate::session::state::{
10    ApprovalRules, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicy, ToolRule,
11    ToolVisibility, UnapprovedBehavior,
12};
13use crate::tools::DISPATCH_AGENT_TOOL_NAME;
14use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
15
/// Identifier of the "normal" primary agent (all tools visible, write-capable tools need approval).
pub const NORMAL_PRIMARY_AGENT_ID: &str = "normal";
/// Primary agent used when a session selects none, or selects an id that is not registered.
pub const DEFAULT_PRIMARY_AGENT_ID: &str = NORMAL_PRIMARY_AGENT_ID;
/// Identifier of the plan-only primary agent restricted to read-only tools.
pub const PLANNER_PRIMARY_AGENT_ID: &str = "plan";
/// Identifier of the primary agent whose policy allows every tool call without approval.
pub const YOLO_PRIMARY_AGENT_ID: &str = "yolo";
20
21static PLANNER_SYSTEM_PROMPT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
22    format!(
23        r#"You are in plan mode. Produce a concise, step-by-step plan only.
24
25Rules:
26- Use read-only tools to gather the context you need before planning.
27- When broader search is needed, use dispatch_agent with the "explore" sub-agent.
28- Do not make changes or write code/patches.
29- If key details are missing, ask up to three targeted questions and stop.
30
31When you can proceed, respond using this structure (omit empty sections):
32Plan:
331. ...
342. ...
353. ...
36
37Assumptions:
38- ...
39
40Risks:
41- ...
42
43Validation:
44- ...
45
46Finish by asking the user to switch back to "{NORMAL_PRIMARY_AGENT_ID}" or "{YOLO_PRIMARY_AGENT_ID}" to execute."#,
47    )
48});
49
50#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
51pub struct PrimaryAgentSpec {
52    pub id: String,
53    pub name: String,
54    pub description: String,
55    pub model: Option<ModelId>,
56    pub system_prompt: Option<String>,
57    pub tool_visibility: ToolVisibility,
58    pub approval_policy: ToolApprovalPolicy,
59}
60
61static PRIMARY_AGENT_SPECS: std::sync::LazyLock<RwLock<HashMap<String, PrimaryAgentSpec>>> =
62    std::sync::LazyLock::new(|| {
63        let mut specs = HashMap::new();
64        for spec in default_primary_agent_specs() {
65            specs.insert(spec.id.clone(), spec);
66        }
67        RwLock::new(specs)
68    });
69
70pub fn primary_agent_spec(id: &str) -> Option<PrimaryAgentSpec> {
71    let registry = PRIMARY_AGENT_SPECS.read().ok()?;
72    registry.get(id).cloned()
73}
74
75pub fn primary_agent_specs() -> Vec<PrimaryAgentSpec> {
76    let registry = match PRIMARY_AGENT_SPECS.read() {
77        Ok(registry) => registry,
78        Err(_) => return Vec::new(),
79    };
80    let mut specs: Vec<_> = registry.values().cloned().collect();
81    specs.sort_by(|a, b| a.id.cmp(&b.id));
82    specs
83}
84
85pub fn default_primary_agent_id() -> &'static str {
86    DEFAULT_PRIMARY_AGENT_ID
87}
88
89pub fn resolve_effective_config(base_config: &SessionConfig) -> SessionConfig {
90    let mut config = base_config.clone();
91
92    let requested_agent_id = base_config
93        .primary_agent_id
94        .clone()
95        .unwrap_or_else(|| DEFAULT_PRIMARY_AGENT_ID.to_string());
96
97    let (primary_agent_id, spec) = if let Some(spec) = primary_agent_spec(&requested_agent_id) {
98        (requested_agent_id, spec)
99    } else {
100        let fallback_spec =
101            primary_agent_spec(DEFAULT_PRIMARY_AGENT_ID).unwrap_or_else(|| PrimaryAgentSpec {
102                id: DEFAULT_PRIMARY_AGENT_ID.to_string(),
103                name: "Normal".to_string(),
104                description: "Default agent".to_string(),
105                model: None,
106                system_prompt: None,
107                tool_visibility: ToolVisibility::All,
108                approval_policy: ToolApprovalPolicy::default(),
109            });
110        (DEFAULT_PRIMARY_AGENT_ID.to_string(), fallback_spec)
111    };
112
113    let effective_model = resolve_default_model(&config, &spec, &base_config.policy_overrides);
114    let effective_visibility = base_config
115        .policy_overrides
116        .tool_visibility
117        .clone()
118        .unwrap_or_else(|| spec.tool_visibility.clone());
119    let effective_approval_policy = base_config
120        .policy_overrides
121        .approval_policy
122        .apply_to(&spec.approval_policy);
123    let effective_system_prompt = resolve_system_prompt(
124        &config,
125        &spec,
126        &effective_model,
127        &base_config.policy_overrides,
128    );
129
130    config.primary_agent_id = Some(primary_agent_id);
131    config.default_model = effective_model;
132    config.tool_config.visibility = effective_visibility;
133    config.tool_config.approval_policy = effective_approval_policy;
134    config.system_prompt = effective_system_prompt;
135
136    config
137}
138
139fn resolve_default_model(
140    config: &SessionConfig,
141    spec: &PrimaryAgentSpec,
142    overrides: &SessionPolicyOverrides,
143) -> ModelId {
144    let mut model = config.default_model.clone();
145    if let Some(spec_model) = spec.model.as_ref() {
146        model = spec_model.clone();
147    }
148    if let Some(override_model) = overrides.default_model.as_ref() {
149        model = override_model.clone();
150    }
151    model
152}
153
154fn resolve_system_prompt(
155    config: &SessionConfig,
156    spec: &PrimaryAgentSpec,
157    model: &ModelId,
158    _overrides: &SessionPolicyOverrides,
159) -> Option<String> {
160    if let Some(prompt) = spec.system_prompt.as_ref() {
161        return Some(prompt.clone());
162    }
163
164    if let Some(prompt) = config.system_prompt.as_ref()
165        && !prompt.trim().is_empty()
166        && !is_known_primary_agent_prompt(prompt)
167    {
168        return Some(prompt.clone());
169    }
170
171    Some(system_prompt_for_model(model))
172}
173
174fn is_known_primary_agent_prompt(prompt: &str) -> bool {
175    primary_agent_specs()
176        .iter()
177        .any(|spec| spec.system_prompt.as_deref() == Some(prompt))
178}
179
180fn approval_policy_with_dispatch_explore_preapproval() -> ToolApprovalPolicy {
181    let mut policy = ToolApprovalPolicy::default();
182    policy.preapproved.per_tool.insert(
183        DISPATCH_AGENT_TOOL_NAME.to_string(),
184        ToolRule::DispatchAgent {
185            agent_patterns: vec![DEFAULT_AGENT_SPEC_ID.to_string()],
186        },
187    );
188    policy
189}
190
191fn default_primary_agent_specs() -> Vec<PrimaryAgentSpec> {
192    let planner_tool_visibility = ToolVisibility::Whitelist(
193        READ_ONLY_TOOL_NAMES
194            .iter()
195            .map(|name| (*name).to_string())
196            .chain(std::iter::once(DISPATCH_AGENT_TOOL_NAME.to_string()))
197            .collect::<HashSet<_>>(),
198    );
199
200    vec![
201        PrimaryAgentSpec {
202            id: NORMAL_PRIMARY_AGENT_ID.to_string(),
203            name: "Normal".to_string(),
204            description: "Default agent with full tool visibility. Tools which can write require explicit approvals."
205                .to_string(),
206            model: None,
207            system_prompt: None,
208            tool_visibility: ToolVisibility::All,
209            approval_policy: approval_policy_with_dispatch_explore_preapproval(),
210        },
211        PrimaryAgentSpec {
212            id: PLANNER_PRIMARY_AGENT_ID.to_string(),
213            name: "Plan".to_string(),
214            description: "Plan-only agent with read-only tools.".to_string(),
215            model: None,
216            system_prompt: Some(PLANNER_SYSTEM_PROMPT.clone()),
217            tool_visibility: planner_tool_visibility,
218            approval_policy: approval_policy_with_dispatch_explore_preapproval(),
219        },
220        PrimaryAgentSpec {
221            id: YOLO_PRIMARY_AGENT_ID.to_string(),
222            name: "Yolo".to_string(),
223            description: "Full tool visibility with auto-approval for all tools.".to_string(),
224            model: None,
225            system_prompt: None,
226            tool_visibility: ToolVisibility::All,
227            approval_policy: ToolApprovalPolicy {
228                default_behavior: UnapprovedBehavior::Allow,
229                preapproved: ApprovalRules::default(),
230            },
231        },
232    ]
233}
234
#[cfg(test)]
mod tests {
    // Unit tests for primary agent specs and effective-config resolution.
    use super::*;
    use crate::config::model::builtin;
    use crate::session::state::{
        ApprovalRulesOverrides, SessionPolicyOverrides, SessionToolConfig,
        ToolApprovalPolicyOverrides, ToolDecision, ToolRule, ToolRuleOverrides, WorkspaceConfig,
    };
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use std::collections::HashMap;
    use std::path::PathBuf;
    use steer_tools::tools::{TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME};

    fn base_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/tmp"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::read_only(),
            system_prompt: Some("base prompt".to_string()),
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            metadata: HashMap::new(),
            default_model: builtin::claude_sonnet_4_5(),
            auto_compaction: crate::session::state::AutoCompactionConfig::default(),
        }
    }

    #[test]
    fn default_primary_agent_exists() {
        let id = default_primary_agent_id();
        let spec = primary_agent_spec(id);
        assert!(spec.is_some());
        assert_eq!(spec.unwrap().id, id);
    }

    #[test]
    fn resolve_effective_config_preserves_base_when_unset() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        let updated = resolve_effective_config(&config);

        assert_eq!(updated.default_model, config.default_model);
        assert_eq!(updated.system_prompt, config.system_prompt);
        assert_eq!(updated.tool_config.visibility, ToolVisibility::All);
    }

    #[test]
    fn resolve_effective_config_applies_overrides() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        config.policy_overrides = SessionPolicyOverrides {
            default_model: Some(builtin::claude_haiku_4_5()),
            tool_visibility: Some(ToolVisibility::ReadOnly),
            approval_policy: ToolApprovalPolicyOverrides {
                default_behavior: Some(UnapprovedBehavior::Allow),
                preapproved: ApprovalRulesOverrides {
                    tools: ["custom_tool".to_string()].into_iter().collect(),
                    per_tool: [(
                        "bash".to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    )]
                    .into_iter()
                    .collect(),
                },
            },
        };

        let updated = resolve_effective_config(&config);
        assert_eq!(updated.default_model, builtin::claude_haiku_4_5());
        assert_eq!(updated.tool_config.visibility, ToolVisibility::ReadOnly);
        assert_eq!(
            updated.tool_config.approval_policy.default_behavior,
            UnapprovedBehavior::Allow
        );
        assert!(
            updated
                .tool_config
                .approval_policy
                .preapproved
                .tools
                .contains("custom_tool")
        );

        let rule = updated
            .tool_config
            .approval_policy
            .preapproved
            .per_tool
            .get("bash")
            .expect("bash rule");
        match rule {
            ToolRule::Bash { patterns } => {
                assert!(patterns.contains(&"git status".to_string()));
            }
            ToolRule::DispatchAgent { .. } => panic!("Unexpected dispatch agent rule"),
        }
    }

    #[test]
    fn resolve_effective_config_falls_back_to_default_for_unknown_agent() {
        let mut config = base_config();
        config.primary_agent_id = Some("no-such-agent".to_string());
        let updated = resolve_effective_config(&config);

        assert_eq!(
            updated.primary_agent_id.as_deref(),
            Some(DEFAULT_PRIMARY_AGENT_ID)
        );
        assert_eq!(updated.tool_config.visibility, ToolVisibility::All);
    }

    #[test]
    fn plan_prompt_is_dropped_when_switching_back_to_normal() {
        let mut config = base_config();
        config.primary_agent_id = Some(PLANNER_PRIMARY_AGENT_ID.to_string());
        let planned = resolve_effective_config(&config);
        assert_eq!(
            planned.system_prompt.as_deref(),
            Some(PLANNER_SYSTEM_PROMPT.as_str())
        );

        // The planner prompt now stored in the config must not be kept as a "custom" prompt.
        let mut switched = planned.clone();
        switched.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        let normal = resolve_effective_config(&switched);
        assert_eq!(
            normal.system_prompt,
            Some(system_prompt_for_model(&normal.default_model))
        );
    }

    #[test]
    fn normal_spec_preapproves_dispatch_agent_for_explore() {
        let spec = primary_agent_spec(NORMAL_PRIMARY_AGENT_ID).expect("normal spec");

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    #[test]
    fn plan_spec_limits_tools_and_dispatch_agent() {
        let spec = primary_agent_spec(PLANNER_PRIMARY_AGENT_ID).expect("plan spec");

        match &spec.tool_visibility {
            ToolVisibility::Whitelist(allowed) => {
                assert!(allowed.contains(DISPATCH_AGENT_TOOL_NAME));
                for name in READ_ONLY_TOOL_NAMES {
                    assert!(allowed.contains(*name));
                }
                assert_eq!(allowed.len(), READ_ONLY_TOOL_NAMES.len() + 1);
            }
            other => panic!("Unexpected tool visibility: {other:?}"),
        }

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    #[test]
    fn all_primary_agents_allow_todos_by_default() {
        for agent_id in [
            NORMAL_PRIMARY_AGENT_ID,
            PLANNER_PRIMARY_AGENT_ID,
            YOLO_PRIMARY_AGENT_ID,
        ] {
            let spec = primary_agent_spec(agent_id).expect("primary agent spec");
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_READ_TOOL_NAME),
                ToolDecision::Allow
            );
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_WRITE_TOOL_NAME),
                ToolDecision::Allow
            );
        }
    }
}