Skip to main content

steer_core/
agents.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::RwLock;
5
6use thiserror::Error;
7
8use crate::config::model::ModelId;
9use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
10use steer_tools::tools::replace::REPLACE_TOOL_NAME;
11use steer_tools::tools::{
12    BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME,
13};
14
/// Id of the agent spec treated as the default; it names the built-in,
/// read-only `explore` spec that seeds the registry.
pub const DEFAULT_AGENT_SPEC_ID: &str = "explore";
16
17#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
18#[serde(rename_all = "snake_case")]
19pub enum McpAccessPolicy {
20    None,
21    Allowlist(Vec<String>),
22    All,
23}
24
25impl McpAccessPolicy {
26    pub fn allows_server(&self, server_name: &str) -> bool {
27        match self {
28            McpAccessPolicy::None => false,
29            McpAccessPolicy::All => true,
30            McpAccessPolicy::Allowlist(servers) => servers.iter().any(|s| s == server_name),
31        }
32    }
33
34    pub fn allow_mcp_tools(&self) -> bool {
35        !matches!(self, McpAccessPolicy::None)
36    }
37
38    pub fn describe(&self) -> String {
39        match self {
40            McpAccessPolicy::None => "none".to_string(),
41            McpAccessPolicy::All => "all".to_string(),
42            McpAccessPolicy::Allowlist(servers) => {
43                let list = if servers.is_empty() {
44                    "<empty>".to_string()
45                } else {
46                    servers.join(", ")
47                };
48                format!("allowlist({list})")
49            }
50        }
51    }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
55pub struct AgentSpec {
56    pub id: String,
57    pub name: String,
58    pub description: String,
59    pub tools: Vec<String>,
60    pub mcp_access: McpAccessPolicy,
61    #[serde(default)]
62    pub model: Option<ModelId>,
63}
64
65#[derive(Debug, Error)]
66pub enum AgentSpecError {
67    #[error("Agent spec already registered: {0}")]
68    AlreadyRegistered(String),
69    #[error("Agent spec registry lock poisoned")]
70    RegistryPoisoned,
71}
72
73static AGENT_SPECS: std::sync::LazyLock<RwLock<HashMap<String, AgentSpec>>> =
74    std::sync::LazyLock::new(|| {
75        let mut specs = HashMap::new();
76        for spec in default_agent_specs() {
77            specs.insert(spec.id.clone(), spec);
78        }
79        RwLock::new(specs)
80    });
81
82pub fn register_agent_spec(spec: AgentSpec) -> Result<(), AgentSpecError> {
83    let mut registry = AGENT_SPECS
84        .write()
85        .map_err(|_| AgentSpecError::RegistryPoisoned)?;
86    if registry.contains_key(&spec.id) {
87        return Err(AgentSpecError::AlreadyRegistered(spec.id));
88    }
89    registry.insert(spec.id.clone(), spec);
90    Ok(())
91}
92
93pub fn agent_spec(id: &str) -> Option<AgentSpec> {
94    let registry = AGENT_SPECS.read().ok()?;
95    registry.get(id).cloned()
96}
97
98pub fn agent_specs() -> Vec<AgentSpec> {
99    let registry = match AGENT_SPECS.read() {
100        Ok(registry) => registry,
101        Err(_) => return Vec::new(),
102    };
103    let mut specs: Vec<_> = registry.values().cloned().collect();
104    specs.sort_by(|a, b| a.id.cmp(&b.id));
105    specs
106}
107
108pub fn default_agent_spec_id() -> &'static str {
109    DEFAULT_AGENT_SPEC_ID
110}
111
112pub fn agent_specs_prompt() -> String {
113    let specs = agent_specs();
114    if specs.is_empty() {
115        return String::new();
116    }
117
118    let mut lines = Vec::new();
119    lines.push("Available sub-agent specs:".to_string());
120    for spec in specs {
121        let tools = spec.tools.join(", ");
122        let mcp = spec.mcp_access.describe();
123        let model = spec
124            .model
125            .as_ref()
126            .map(|model| format!("{}/{}", model.provider.storage_key(), model.id));
127        let mut details = format!("tools: {tools}; mcp: {mcp}");
128        if let Some(model) = model {
129            details.push_str(&format!("; model: {model}"));
130        }
131        lines.push(format!("- {}: {} ({details})", spec.id, spec.description));
132    }
133    lines.join("\n")
134}
135
136fn default_agent_specs() -> Vec<AgentSpec> {
137    let explore_tools = vec![GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME]
138        .into_iter()
139        .map(|tool| tool.to_string())
140        .collect();
141
142    let build_tools = vec![
143        GLOB_TOOL_NAME,
144        GREP_TOOL_NAME,
145        LS_TOOL_NAME,
146        VIEW_TOOL_NAME,
147        EDIT_TOOL_NAME,
148        MULTI_EDIT_TOOL_NAME,
149        REPLACE_TOOL_NAME,
150        BASH_TOOL_NAME,
151    ]
152    .into_iter()
153    .map(|tool| tool.to_string())
154    .collect();
155
156    vec![
157        AgentSpec {
158            id: "explore".to_string(),
159            name: "Explore agent".to_string(),
160            description: "Read-only search and inspection".to_string(),
161            tools: explore_tools,
162            mcp_access: McpAccessPolicy::None,
163            model: None,
164        },
165        AgentSpec {
166            id: "build".to_string(),
167            name: "Build agent".to_string(),
168            description: "Read/write changes plus build commands".to_string(),
169            tools: build_tools,
170            mcp_access: McpAccessPolicy::All,
171            model: None,
172        },
173    ]
174}