Skip to main content

steer_core/
agents.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::RwLock;
5
6use thiserror::Error;
7
8use crate::config::model::ModelId;
9use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
10use steer_tools::tools::replace::REPLACE_TOOL_NAME;
11use steer_tools::tools::{
12    BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME,
13    READ_FILE_TOOL_NAME,
14};
15
/// Id of the agent spec used by default, as returned by [`default_agent_spec_id`].
///
/// This is the id of the built-in `explore` spec that seeds the registry.
pub const DEFAULT_AGENT_SPEC_ID: &str = "explore";
17
// Which MCP servers an agent is permitted to use.
//
// Variant names serialize in snake_case (`none`, `allowlist`, `all`).
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// schemars copies `///` doc comments into the generated schema's descriptions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum McpAccessPolicy {
    // No MCP server is allowed.
    None,
    // Only servers whose names appear in the list are allowed (exact string match).
    Allowlist(Vec<String>),
    // Every MCP server is allowed.
    All,
}
25
26impl McpAccessPolicy {
27    pub fn allows_server(&self, server_name: &str) -> bool {
28        match self {
29            McpAccessPolicy::None => false,
30            McpAccessPolicy::All => true,
31            McpAccessPolicy::Allowlist(servers) => servers.iter().any(|s| s == server_name),
32        }
33    }
34
35    pub fn allow_mcp_tools(&self) -> bool {
36        !matches!(self, McpAccessPolicy::None)
37    }
38
39    pub fn describe(&self) -> String {
40        match self {
41            McpAccessPolicy::None => "none".to_string(),
42            McpAccessPolicy::All => "all".to_string(),
43            McpAccessPolicy::Allowlist(servers) => {
44                let list = if servers.is_empty() {
45                    "<empty>".to_string()
46                } else {
47                    servers.join(", ")
48                };
49                format!("allowlist({list})")
50            }
51        }
52    }
53}
54
// Definition of a sub-agent: its identity, the tools it may call, its MCP access
// policy and an optional model.
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`, and
// schemars copies `///` doc comments into the generated schema's descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentSpec {
    // Registry key; `register_agent_spec` rejects a second spec with the same id.
    pub id: String,
    // Display name, e.g. "Explore agent" for the built-in explore spec.
    pub name: String,
    // Included in the listing produced by `agent_specs_prompt`.
    pub description: String,
    // Tool names available to the agent (the built-in specs use the `*_TOOL_NAME` constants).
    pub tools: Vec<String>,
    // Which MCP servers the agent may use.
    pub mcp_access: McpAccessPolicy,
    // Deserializes to `None` when the field is absent (`#[serde(default)]`).
    // NOTE(review): presumably `None` means "use the caller's default model" — confirm.
    #[serde(default)]
    pub model: Option<ModelId>,
}
65
/// Errors returned by [`register_agent_spec`].
#[derive(Debug, Error)]
pub enum AgentSpecError {
    /// A spec with this id is already present in the registry.
    #[error("Agent spec already registered: {0}")]
    AlreadyRegistered(String),
    /// The registry's `RwLock` is poisoned (a thread panicked while holding it for writing).
    #[error("Agent spec registry lock poisoned")]
    RegistryPoisoned,
}
73
74static AGENT_SPECS: std::sync::LazyLock<RwLock<HashMap<String, AgentSpec>>> =
75    std::sync::LazyLock::new(|| {
76        let mut specs = HashMap::new();
77        for spec in default_agent_specs() {
78            specs.insert(spec.id.clone(), spec);
79        }
80        RwLock::new(specs)
81    });
82
83pub fn register_agent_spec(spec: AgentSpec) -> Result<(), AgentSpecError> {
84    let mut registry = AGENT_SPECS
85        .write()
86        .map_err(|_| AgentSpecError::RegistryPoisoned)?;
87    if registry.contains_key(&spec.id) {
88        return Err(AgentSpecError::AlreadyRegistered(spec.id));
89    }
90    registry.insert(spec.id.clone(), spec);
91    Ok(())
92}
93
94pub fn agent_spec(id: &str) -> Option<AgentSpec> {
95    let registry = AGENT_SPECS.read().ok()?;
96    registry.get(id).cloned()
97}
98
99pub fn agent_specs() -> Vec<AgentSpec> {
100    let registry = match AGENT_SPECS.read() {
101        Ok(registry) => registry,
102        Err(_) => return Vec::new(),
103    };
104    let mut specs: Vec<_> = registry.values().cloned().collect();
105    specs.sort_by(|a, b| a.id.cmp(&b.id));
106    specs
107}
108
/// Returns the id of the default agent spec ([`DEFAULT_AGENT_SPEC_ID`]).
pub fn default_agent_spec_id() -> &'static str {
    DEFAULT_AGENT_SPEC_ID
}
112
113pub fn agent_specs_prompt() -> String {
114    let specs = agent_specs();
115    if specs.is_empty() {
116        return String::new();
117    }
118
119    let mut lines = Vec::new();
120    lines.push("Available sub-agent specs:".to_string());
121    for spec in specs {
122        let tools = spec.tools.join(", ");
123        let mcp = spec.mcp_access.describe();
124        let model = spec
125            .model
126            .as_ref()
127            .map(|model| format!("{}/{}", model.provider.storage_key(), model.id));
128        let mut details = format!("tools: {tools}; mcp: {mcp}");
129        if let Some(model) = model {
130            details.push_str(&format!("; model: {model}"));
131        }
132        lines.push(format!("- {}: {} ({details})", spec.id, spec.description));
133    }
134    lines.join("\n")
135}
136
137fn default_agent_specs() -> Vec<AgentSpec> {
138    let explore_tools = vec![
139        GLOB_TOOL_NAME,
140        GREP_TOOL_NAME,
141        LS_TOOL_NAME,
142        READ_FILE_TOOL_NAME,
143    ]
144    .into_iter()
145    .map(|tool| tool.to_string())
146    .collect();
147
148    let build_tools = vec![
149        GLOB_TOOL_NAME,
150        GREP_TOOL_NAME,
151        LS_TOOL_NAME,
152        READ_FILE_TOOL_NAME,
153        EDIT_TOOL_NAME,
154        MULTI_EDIT_TOOL_NAME,
155        REPLACE_TOOL_NAME,
156        BASH_TOOL_NAME,
157    ]
158    .into_iter()
159    .map(|tool| tool.to_string())
160    .collect();
161
162    vec![
163        AgentSpec {
164            id: "explore".to_string(),
165            name: "Explore agent".to_string(),
166            description: "Use for code reviews, exploration, and any other read-only task"
167                .to_string(),
168            tools: explore_tools,
169            mcp_access: McpAccessPolicy::None,
170            model: None,
171        },
172        AgentSpec {
173            id: "build".to_string(),
174            name: "Build agent".to_string(),
175            description:
176                "Use only when the sub-agent needs to modify files (includes build commands)"
177                    .to_string(),
178            tools: build_tools,
179            mcp_access: McpAccessPolicy::All,
180            model: None,
181        },
182    ]
183}