// steer_core/primary_agents.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::sync::RwLock;
5
6use crate::agents::DEFAULT_AGENT_SPEC_ID;
7use crate::config::model::ModelId;
8use crate::prompts::system_prompt_for_model;
9use crate::session::state::{
10    ApprovalRules, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicy, ToolRule,
11    ToolVisibility, UnapprovedBehavior,
12};
13use crate::tools::DISPATCH_AGENT_TOOL_NAME;
14use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
15
/// Id of the default agent: all tools visible, write-capable tools require approval.
pub const NORMAL_PRIMARY_AGENT_ID: &str = "normal";
/// Id of the plan-only agent, limited to read-only tools plus sub-agent dispatch.
pub const PLANNER_PRIMARY_AGENT_ID: &str = "plan";
/// Id of the agent whose approval policy allows every tool call without asking.
pub const YOLO_PRIMARY_AGENT_ID: &str = "yolo";
/// Id of the primary agent used when a session does not request one.
pub const DEFAULT_PRIMARY_AGENT_ID: &str = NORMAL_PRIMARY_AGENT_ID;
20
21static PLANNER_SYSTEM_PROMPT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
22    format!(
23        r#"You are in plan mode. Produce a concise, step-by-step plan only.
24
25Rules:
26- Use read-only tools to gather the context you need before planning.
27- When broader search is needed, use dispatch_agent with the "explore" sub-agent.
28- Do not make changes or write code/patches.
29- If key details are missing, ask up to three targeted questions and stop.
30
31When you can proceed, respond using this structure (omit empty sections):
32Plan:
331. ...
342. ...
353. ...
36
37Assumptions:
38- ...
39
40Risks:
41- ...
42
43Validation:
44- ...
45
46Finish by asking the user to switch back to "{NORMAL_PRIMARY_AGENT_ID}" or "{YOLO_PRIMARY_AGENT_ID}" to execute."#,
47    )
48});
49
// Static description of a primary agent. It holds the id sessions select it by, display
// metadata, and the model / prompt / tool policy that `resolve_effective_config` applies.
// NOTE: plain `//` comments are used on purpose. `///` doc comments would be picked up by the
// `JsonSchema` derive as schema descriptions and change the generated schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PrimaryAgentSpec {
    // Registry key; the value `SessionConfig::primary_agent_id` refers to.
    pub id: String,
    // Human-readable display name.
    pub name: String,
    // Short human-readable summary of the agent.
    pub description: String,
    // Model to use; `None` keeps the session's default model (a session override still wins).
    pub model: Option<ModelId>,
    // Agent-owned prompt. `None` falls back to the session's prompt or the model's default prompt.
    pub system_prompt: Option<String>,
    // Baseline tool visibility; a session-level override replaces it entirely.
    pub tool_visibility: ToolVisibility,
    // Baseline approval policy; session-level overrides are layered on via `apply_to`.
    pub approval_policy: ToolApprovalPolicy,
}
60
61static PRIMARY_AGENT_SPECS: std::sync::LazyLock<RwLock<HashMap<String, PrimaryAgentSpec>>> =
62    std::sync::LazyLock::new(|| {
63        let mut specs = HashMap::new();
64        for spec in default_primary_agent_specs() {
65            specs.insert(spec.id.clone(), spec);
66        }
67        RwLock::new(specs)
68    });
69
70pub fn primary_agent_spec(id: &str) -> Option<PrimaryAgentSpec> {
71    let registry = PRIMARY_AGENT_SPECS.read().ok()?;
72    registry.get(id).cloned()
73}
74
75pub fn primary_agent_specs() -> Vec<PrimaryAgentSpec> {
76    let registry = match PRIMARY_AGENT_SPECS.read() {
77        Ok(registry) => registry,
78        Err(_) => return Vec::new(),
79    };
80    let mut specs: Vec<_> = registry.values().cloned().collect();
81    specs.sort_by(|a, b| a.id.cmp(&b.id));
82    specs
83}
84
85pub fn default_primary_agent_id() -> &'static str {
86    DEFAULT_PRIMARY_AGENT_ID
87}
88
89pub fn resolve_effective_config(base_config: &SessionConfig) -> SessionConfig {
90    let mut config = base_config.clone();
91
92    let requested_agent_id = base_config
93        .primary_agent_id
94        .clone()
95        .unwrap_or_else(|| DEFAULT_PRIMARY_AGENT_ID.to_string());
96
97    let (primary_agent_id, spec) = if let Some(spec) = primary_agent_spec(&requested_agent_id) {
98        (requested_agent_id, spec)
99    } else {
100        let fallback_spec =
101            primary_agent_spec(DEFAULT_PRIMARY_AGENT_ID).unwrap_or_else(|| PrimaryAgentSpec {
102                id: DEFAULT_PRIMARY_AGENT_ID.to_string(),
103                name: "Normal".to_string(),
104                description: "Default agent".to_string(),
105                model: None,
106                system_prompt: None,
107                tool_visibility: ToolVisibility::All,
108                approval_policy: ToolApprovalPolicy::default(),
109            });
110        (DEFAULT_PRIMARY_AGENT_ID.to_string(), fallback_spec)
111    };
112
113    let effective_model = resolve_default_model(&config, &spec, &base_config.policy_overrides);
114    let effective_visibility = base_config
115        .policy_overrides
116        .tool_visibility
117        .clone()
118        .unwrap_or_else(|| spec.tool_visibility.clone());
119    let effective_approval_policy = base_config
120        .policy_overrides
121        .approval_policy
122        .apply_to(&spec.approval_policy);
123    let effective_system_prompt = resolve_system_prompt(
124        &config,
125        &spec,
126        &effective_model,
127        &base_config.policy_overrides,
128    );
129
130    config.primary_agent_id = Some(primary_agent_id);
131    config.default_model = effective_model;
132    config.tool_config.visibility = effective_visibility;
133    config.tool_config.approval_policy = effective_approval_policy;
134    config.system_prompt = effective_system_prompt;
135
136    config
137}
138
139fn resolve_default_model(
140    config: &SessionConfig,
141    spec: &PrimaryAgentSpec,
142    overrides: &SessionPolicyOverrides,
143) -> ModelId {
144    let mut model = config.default_model.clone();
145    if let Some(spec_model) = spec.model.as_ref() {
146        model = spec_model.clone();
147    }
148    if let Some(override_model) = overrides.default_model.as_ref() {
149        model = override_model.clone();
150    }
151    model
152}
153
154fn resolve_system_prompt(
155    config: &SessionConfig,
156    spec: &PrimaryAgentSpec,
157    model: &ModelId,
158    _overrides: &SessionPolicyOverrides,
159) -> Option<String> {
160    if let Some(prompt) = spec.system_prompt.as_ref() {
161        return Some(prompt.clone());
162    }
163
164    if let Some(prompt) = config.system_prompt.as_ref()
165        && !prompt.trim().is_empty()
166        && !is_known_primary_agent_prompt(prompt)
167    {
168        return Some(prompt.clone());
169    }
170
171    Some(system_prompt_for_model(model))
172}
173
174fn is_known_primary_agent_prompt(prompt: &str) -> bool {
175    primary_agent_specs()
176        .iter()
177        .any(|spec| spec.system_prompt.as_deref() == Some(prompt))
178}
179
180fn approval_policy_with_dispatch_explore_preapproval() -> ToolApprovalPolicy {
181    let mut policy = ToolApprovalPolicy::default();
182    policy.preapproved.per_tool.insert(
183        DISPATCH_AGENT_TOOL_NAME.to_string(),
184        ToolRule::DispatchAgent {
185            agent_patterns: vec![DEFAULT_AGENT_SPEC_ID.to_string()],
186        },
187    );
188    policy
189}
190
191fn default_primary_agent_specs() -> Vec<PrimaryAgentSpec> {
192    let planner_tool_visibility = ToolVisibility::Whitelist(
193        READ_ONLY_TOOL_NAMES
194            .iter()
195            .map(|name| (*name).to_string())
196            .chain(std::iter::once(DISPATCH_AGENT_TOOL_NAME.to_string()))
197            .collect::<HashSet<_>>(),
198    );
199
200    vec![
201        PrimaryAgentSpec {
202            id: NORMAL_PRIMARY_AGENT_ID.to_string(),
203            name: "Normal".to_string(),
204            description: "Default agent with full tool visibility. Tools which can write require explicit approvals."
205                .to_string(),
206            model: None,
207            system_prompt: None,
208            tool_visibility: ToolVisibility::All,
209            approval_policy: approval_policy_with_dispatch_explore_preapproval(),
210        },
211        PrimaryAgentSpec {
212            id: PLANNER_PRIMARY_AGENT_ID.to_string(),
213            name: "Plan".to_string(),
214            description: "Plan-only agent with read-only tools.".to_string(),
215            model: None,
216            system_prompt: Some(PLANNER_SYSTEM_PROMPT.clone()),
217            tool_visibility: planner_tool_visibility,
218            approval_policy: approval_policy_with_dispatch_explore_preapproval(),
219        },
220        PrimaryAgentSpec {
221            id: YOLO_PRIMARY_AGENT_ID.to_string(),
222            name: "Yolo".to_string(),
223            description: "Full tool visibility with auto-approval for all tools.".to_string(),
224            model: None,
225            system_prompt: None,
226            tool_visibility: ToolVisibility::All,
227            approval_policy: ToolApprovalPolicy {
228                default_behavior: UnapprovedBehavior::Allow,
229                preapproved: ApprovalRules::default(),
230            },
231        },
232    ]
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::model::builtin;
    use crate::session::state::{
        ApprovalRulesOverrides, SessionPolicyOverrides, SessionToolConfig,
        ToolApprovalPolicyOverrides, ToolDecision, ToolRule, ToolRuleOverrides, WorkspaceConfig,
    };
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use std::collections::HashMap;
    use std::path::PathBuf;
    use steer_tools::tools::{TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME};

    /// Session config with a custom prompt and a read-only tool config, so tests can observe
    /// which fields `resolve_effective_config` replaces.
    fn base_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/tmp"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::read_only(),
            system_prompt: Some("base prompt".to_string()),
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            metadata: HashMap::new(),
            default_model: builtin::claude_sonnet_4_5(),
        }
    }

    #[test]
    fn default_primary_agent_exists() {
        let id = default_primary_agent_id();
        let spec = primary_agent_spec(id);
        assert!(spec.is_some());
        assert_eq!(spec.unwrap().id, id);
    }

    #[test]
    fn resolve_effective_config_preserves_base_when_unset() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        let updated = resolve_effective_config(&config);

        assert_eq!(updated.default_model, config.default_model);
        assert_eq!(updated.system_prompt, config.system_prompt);
        assert_eq!(updated.tool_config.visibility, ToolVisibility::All);
    }

    #[test]
    fn resolve_effective_config_applies_overrides() {
        let mut config = base_config();
        config.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());
        config.policy_overrides = SessionPolicyOverrides {
            default_model: Some(builtin::claude_haiku_4_5()),
            tool_visibility: Some(ToolVisibility::ReadOnly),
            approval_policy: ToolApprovalPolicyOverrides {
                default_behavior: Some(UnapprovedBehavior::Allow),
                preapproved: ApprovalRulesOverrides {
                    tools: ["custom_tool".to_string()].into_iter().collect(),
                    per_tool: [(
                        "bash".to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    )]
                    .into_iter()
                    .collect(),
                },
            },
        };

        let updated = resolve_effective_config(&config);
        assert_eq!(updated.default_model, builtin::claude_haiku_4_5());
        assert_eq!(updated.tool_config.visibility, ToolVisibility::ReadOnly);
        assert_eq!(
            updated.tool_config.approval_policy.default_behavior,
            UnapprovedBehavior::Allow
        );
        assert!(
            updated
                .tool_config
                .approval_policy
                .preapproved
                .tools
                .contains("custom_tool")
        );

        let rule = updated
            .tool_config
            .approval_policy
            .preapproved
            .per_tool
            .get("bash")
            .expect("bash rule");
        match rule {
            ToolRule::Bash { patterns } => {
                assert!(patterns.contains(&"git status".to_string()));
            }
            ToolRule::DispatchAgent { .. } => panic!("Unexpected dispatch agent rule"),
        }
    }

    #[test]
    fn resolve_effective_config_falls_back_to_default_for_unknown_agent() {
        let mut config = base_config();
        config.primary_agent_id = Some("no-such-agent".to_string());

        let updated = resolve_effective_config(&config);

        // Unknown ids resolve to the default agent, and the config records that choice.
        assert_eq!(
            updated.primary_agent_id.as_deref(),
            Some(DEFAULT_PRIMARY_AGENT_ID)
        );
        assert_eq!(updated.tool_config.visibility, ToolVisibility::All);
        assert_eq!(updated.system_prompt, config.system_prompt);
    }

    #[test]
    fn resolve_effective_config_uses_planner_prompt_for_plan_agent() {
        let mut config = base_config();
        config.primary_agent_id = Some(PLANNER_PRIMARY_AGENT_ID.to_string());

        let updated = resolve_effective_config(&config);

        assert_eq!(
            updated.primary_agent_id.as_deref(),
            Some(PLANNER_PRIMARY_AGENT_ID)
        );
        assert_eq!(
            updated.system_prompt.as_deref(),
            Some(PLANNER_SYSTEM_PROMPT.as_str())
        );
    }

    #[test]
    fn switching_from_plan_to_normal_drops_planner_prompt() {
        let mut config = base_config();
        config.primary_agent_id = Some(PLANNER_PRIMARY_AGENT_ID.to_string());
        let mut switched = resolve_effective_config(&config);
        switched.primary_agent_id = Some(NORMAL_PRIMARY_AGENT_ID.to_string());

        let updated = resolve_effective_config(&switched);

        // The planner prompt is agent-owned, so the normal agent must not inherit it. It is
        // replaced by the model's default prompt instead.
        assert_eq!(
            updated.system_prompt,
            Some(system_prompt_for_model(&updated.default_model))
        );
        assert_eq!(updated.tool_config.visibility, ToolVisibility::All);
    }

    #[test]
    fn normal_spec_preapproves_dispatch_agent_for_explore() {
        let spec = primary_agent_spec(NORMAL_PRIMARY_AGENT_ID).expect("normal spec");

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    #[test]
    fn plan_spec_limits_tools_and_dispatch_agent() {
        let spec = primary_agent_spec(PLANNER_PRIMARY_AGENT_ID).expect("plan spec");

        match &spec.tool_visibility {
            ToolVisibility::Whitelist(allowed) => {
                assert!(allowed.contains(DISPATCH_AGENT_TOOL_NAME));
                for name in READ_ONLY_TOOL_NAMES {
                    assert!(allowed.contains(*name));
                }
                assert_eq!(allowed.len(), READ_ONLY_TOOL_NAMES.len() + 1);
            }
            other => panic!("Unexpected tool visibility: {other:?}"),
        }

        let rule = spec
            .approval_policy
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch agent rule");

        match rule {
            ToolRule::DispatchAgent { agent_patterns } => {
                assert_eq!(agent_patterns.as_slice(), [DEFAULT_AGENT_SPEC_ID]);
            }
            ToolRule::Bash { .. } => panic!("Unexpected bash rule"),
        }
    }

    #[test]
    fn all_primary_agents_allow_todos_by_default() {
        for agent_id in [
            NORMAL_PRIMARY_AGENT_ID,
            PLANNER_PRIMARY_AGENT_ID,
            YOLO_PRIMARY_AGENT_ID,
        ] {
            let spec = primary_agent_spec(agent_id).expect("primary agent spec");
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_READ_TOOL_NAME),
                ToolDecision::Allow
            );
            assert_eq!(
                spec.approval_policy.tool_decision(TODO_WRITE_TOOL_NAME),
                ToolDecision::Allow
            );
        }
    }
}