1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::RwLock;
5
6use thiserror::Error;
7
8use crate::config::model::ModelId;
9use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
10use steer_tools::tools::replace::REPLACE_TOOL_NAME;
11use steer_tools::tools::{
12 BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME,
13};
14
/// Id of the default agent spec (also returned by `default_agent_spec_id`).
///
/// Matches the id of the built-in `explore` spec created by `default_agent_specs`.
pub const DEFAULT_AGENT_SPEC_ID: &str = "explore";
16
/// Policy controlling which MCP servers an agent spec may draw tools from.
///
/// With serde's default (externally tagged) enum representation and `snake_case`
/// renaming, this serializes as `"none"`, `"all"`, or `{"allowlist": ["server", ...]}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum McpAccessPolicy {
    /// No MCP server is allowed.
    None,
    /// Only servers whose name exactly equals one of these entries are allowed.
    Allowlist(Vec<String>),
    /// Every MCP server is allowed.
    All,
}
24
25impl McpAccessPolicy {
26 pub fn allows_server(&self, server_name: &str) -> bool {
27 match self {
28 McpAccessPolicy::None => false,
29 McpAccessPolicy::All => true,
30 McpAccessPolicy::Allowlist(servers) => servers.iter().any(|s| s == server_name),
31 }
32 }
33
34 pub fn allow_mcp_tools(&self) -> bool {
35 !matches!(self, McpAccessPolicy::None)
36 }
37
38 pub fn describe(&self) -> String {
39 match self {
40 McpAccessPolicy::None => "none".to_string(),
41 McpAccessPolicy::All => "all".to_string(),
42 McpAccessPolicy::Allowlist(servers) => {
43 let list = if servers.is_empty() {
44 "<empty>".to_string()
45 } else {
46 servers.join(", ")
47 };
48 format!("allowlist({list})")
49 }
50 }
51 }
52}
53
/// Description of a sub-agent configuration: its identity, the tools it lists, its
/// MCP access policy, and an optional model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentSpec {
    /// Registry key; `register_agent_spec` rejects a spec whose id is already present.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Short description; included in the text produced by `agent_specs_prompt`.
    pub description: String,
    /// Tool names associated with this spec (e.g. the steer_tools `*_TOOL_NAME`
    /// constants). NOTE(review): enforcement of this list is not visible here — confirm.
    pub tools: Vec<String>,
    /// Which MCP servers this spec may use.
    pub mcp_access: McpAccessPolicy,
    /// Optional model for this spec; deserializes to `None` when the field is absent.
    /// Presumably `None` means "use the caller's model" — TODO confirm with the consumer.
    #[serde(default)]
    pub model: Option<ModelId>,
}
64
/// Errors returned by agent-spec registry operations.
#[derive(Debug, Error)]
pub enum AgentSpecError {
    /// A spec with this id (carried in the variant) is already in the registry.
    #[error("Agent spec already registered: {0}")]
    AlreadyRegistered(String),
    /// The registry's `RwLock` was poisoned, so it could not be locked for writing.
    #[error("Agent spec registry lock poisoned")]
    RegistryPoisoned,
}
72
73static AGENT_SPECS: std::sync::LazyLock<RwLock<HashMap<String, AgentSpec>>> =
74 std::sync::LazyLock::new(|| {
75 let mut specs = HashMap::new();
76 for spec in default_agent_specs() {
77 specs.insert(spec.id.clone(), spec);
78 }
79 RwLock::new(specs)
80 });
81
82pub fn register_agent_spec(spec: AgentSpec) -> Result<(), AgentSpecError> {
83 let mut registry = AGENT_SPECS
84 .write()
85 .map_err(|_| AgentSpecError::RegistryPoisoned)?;
86 if registry.contains_key(&spec.id) {
87 return Err(AgentSpecError::AlreadyRegistered(spec.id));
88 }
89 registry.insert(spec.id.clone(), spec);
90 Ok(())
91}
92
93pub fn agent_spec(id: &str) -> Option<AgentSpec> {
94 let registry = AGENT_SPECS.read().ok()?;
95 registry.get(id).cloned()
96}
97
98pub fn agent_specs() -> Vec<AgentSpec> {
99 let registry = match AGENT_SPECS.read() {
100 Ok(registry) => registry,
101 Err(_) => return Vec::new(),
102 };
103 let mut specs: Vec<_> = registry.values().cloned().collect();
104 specs.sort_by(|a, b| a.id.cmp(&b.id));
105 specs
106}
107
/// Returns the id of the default agent spec (the value of `DEFAULT_AGENT_SPEC_ID`).
pub fn default_agent_spec_id() -> &'static str {
    DEFAULT_AGENT_SPEC_ID
}
111
112pub fn agent_specs_prompt() -> String {
113 let specs = agent_specs();
114 if specs.is_empty() {
115 return String::new();
116 }
117
118 let mut lines = Vec::new();
119 lines.push("Available sub-agent specs:".to_string());
120 for spec in specs {
121 let tools = spec.tools.join(", ");
122 let mcp = spec.mcp_access.describe();
123 let model = spec
124 .model
125 .as_ref()
126 .map(|model| format!("{}/{}", model.provider.storage_key(), model.id));
127 let mut details = format!("tools: {tools}; mcp: {mcp}");
128 if let Some(model) = model {
129 details.push_str(&format!("; model: {model}"));
130 }
131 lines.push(format!("- {}: {} ({details})", spec.id, spec.description));
132 }
133 lines.join("\n")
134}
135
136fn default_agent_specs() -> Vec<AgentSpec> {
137 let explore_tools = vec![GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME]
138 .into_iter()
139 .map(|tool| tool.to_string())
140 .collect();
141
142 let build_tools = vec![
143 GLOB_TOOL_NAME,
144 GREP_TOOL_NAME,
145 LS_TOOL_NAME,
146 VIEW_TOOL_NAME,
147 EDIT_TOOL_NAME,
148 MULTI_EDIT_TOOL_NAME,
149 REPLACE_TOOL_NAME,
150 BASH_TOOL_NAME,
151 ]
152 .into_iter()
153 .map(|tool| tool.to_string())
154 .collect();
155
156 vec![
157 AgentSpec {
158 id: "explore".to_string(),
159 name: "Explore agent".to_string(),
160 description: "Read-only search and inspection".to_string(),
161 tools: explore_tools,
162 mcp_access: McpAccessPolicy::None,
163 model: None,
164 },
165 AgentSpec {
166 id: "build".to_string(),
167 name: "Build agent".to_string(),
168 description: "Read/write changes plus build commands".to_string(),
169 tools: build_tools,
170 mcp_access: McpAccessPolicy::All,
171 model: None,
172 },
173 ]
174}