1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::app::domain::runtime::RuntimeService;
9use crate::app::domain::session::EventStore;
10use crate::error::Error;
11use crate::model_registry::ModelRegistry;
12use crate::runners::OneShotRunner;
13use crate::session::state::{
14 ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, SessionToolConfig,
15 ToolApprovalPolicy, ToolApprovalPolicyOverrides, ToolVisibility, WorkspaceConfig,
16};
17use crate::tools::{ToolExecutor, ToolSystemBuilder};
18use crate::workspace::{RepoManager, Workspace, WorkspaceManager};
19
20use super::services::{AgentSpawner, SubAgentConfig, SubAgentError, SubAgentResult};
21
22pub struct DefaultAgentSpawner {
23 event_store: Arc<dyn EventStore>,
24 api_client: Arc<ApiClient>,
25 workspace: Arc<dyn Workspace>,
26 model_registry: Arc<ModelRegistry>,
27 workspace_manager: Option<Arc<dyn WorkspaceManager>>,
28 repo_manager: Option<Arc<dyn RepoManager>>,
29}
30
31impl DefaultAgentSpawner {
32 pub fn new(
33 event_store: Arc<dyn EventStore>,
34 api_client: Arc<ApiClient>,
35 workspace: Arc<dyn Workspace>,
36 model_registry: Arc<ModelRegistry>,
37 workspace_manager: Option<Arc<dyn WorkspaceManager>>,
38 repo_manager: Option<Arc<dyn RepoManager>>,
39 ) -> Self {
40 Self {
41 event_store,
42 api_client,
43 workspace,
44 model_registry,
45 workspace_manager,
46 repo_manager,
47 }
48 }
49
50 fn build_tool_executor(&self, workspace: Arc<dyn Workspace>) -> Arc<ToolExecutor> {
51 let mut tool_builder = ToolSystemBuilder::new(
52 workspace,
53 self.event_store.clone(),
54 self.api_client.clone(),
55 self.model_registry.clone(),
56 );
57
58 if let Some(manager) = &self.workspace_manager {
59 tool_builder = tool_builder.with_workspace_manager(manager.clone());
60 }
61 if let Some(manager) = &self.repo_manager {
62 tool_builder = tool_builder.with_repo_manager(manager.clone());
63 }
64
65 tool_builder.build()
66 }
67}
68
69#[async_trait]
70impl AgentSpawner for DefaultAgentSpawner {
71 async fn spawn(
72 &self,
73 config: SubAgentConfig,
74 cancel_token: CancellationToken,
75 ) -> Result<SubAgentResult, SubAgentError> {
76 let workspace = config
77 .workspace
78 .clone()
79 .unwrap_or_else(|| self.workspace.clone());
80 let workspace_path = workspace.working_directory().to_path_buf();
81
82 let visibility_tools: HashSet<String> = config.allowed_tools.iter().cloned().collect();
83 let mcp_backends = if config.allow_mcp_tools {
84 config.mcp_backends.clone()
85 } else {
86 Vec::new()
87 };
88
89 let tool_config = SessionToolConfig {
90 backends: mcp_backends,
91 visibility: ToolVisibility::All,
92 approval_policy: ToolApprovalPolicy::default(),
93 metadata: HashMap::new(),
94 };
95
96 let policy_overrides = SessionPolicyOverrides {
97 default_model: Some(config.model.clone()),
98 tool_visibility: Some(ToolVisibility::Whitelist(visibility_tools.clone())),
99 approval_policy: ToolApprovalPolicyOverrides {
100 preapproved: ApprovalRulesOverrides {
101 tools: visibility_tools,
102 per_tool: HashMap::new(),
103 },
104 },
105 };
106
107 let session_config = SessionConfig {
108 workspace: WorkspaceConfig::Local {
109 path: workspace_path,
110 },
111 workspace_ref: config.workspace_ref.clone(),
112 workspace_id: config.workspace_id,
113 repo_ref: config.repo_ref.clone(),
114 parent_session_id: Some(config.parent_session_id),
115 workspace_name: config.workspace_name.clone(),
116 tool_config,
117 system_prompt: config
118 .system_context
119 .as_ref()
120 .map(|context| context.prompt.clone()),
121 primary_agent_id: None,
122 policy_overrides,
123 title: None,
124 metadata: HashMap::new(),
125 default_model: config.model.clone(),
126 auto_compaction: crate::session::state::AutoCompactionConfig::default(),
127 };
128
129 let tool_executor = self.build_tool_executor(workspace);
130 let runtime = RuntimeService::spawn(
131 self.event_store.clone(),
132 self.api_client.clone(),
133 tool_executor,
134 );
135
136 let run_result = OneShotRunner::run_new_session_with_cancel(
137 &runtime.handle,
138 session_config,
139 config.prompt,
140 config.model.clone(),
141 cancel_token,
142 )
143 .await;
144
145 runtime.shutdown().await;
146
147 let run_result = run_result.map_err(|err| match err {
148 Error::Cancelled => SubAgentError::Cancelled,
149 Error::Api(error) => SubAgentError::Api(error.to_string()),
150 other => SubAgentError::Agent(other.to_string()),
151 })?;
152
153 Ok(SubAgentResult {
154 session_id: run_result.session_id,
155 final_message: run_result.final_message,
156 })
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::DefaultAgentSpawner;
    use crate::api::Client as ApiClient;
    use crate::api::{ApiError, CompletionResponse, Provider};
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::event::SessionEvent;
    use crate::app::domain::session::EventStore;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::auth::ProviderRegistry;
    use crate::config::model::builtin;
    use crate::model_registry::ModelRegistry;
    use crate::session::state::ToolVisibility;
    use crate::test_utils::test_llm_config_provider;
    use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
    use crate::tools::services::AgentSpawner;
    use crate::tools::services::SubAgentConfig;
    use crate::workspace::WorkspaceConfig;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::Mutex as StdMutex;
    use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
    use steer_tools::tools::replace::REPLACE_TOOL_NAME;
    use steer_tools::tools::{
        BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME,
        VIEW_TOOL_NAME,
    };
    use tempfile::TempDir;
    use tokio_util::sync::CancellationToken;

    /// Stub provider that always answers with a fixed text and stores the
    /// rendered system prompt of the most recent request in `last_system`.
    #[derive(Clone)]
    struct RecordingProvider {
        response: String,
        last_system: Arc<StdMutex<Option<String>>>,
    }

    impl RecordingProvider {
        fn new(response: impl Into<String>, last_system: Arc<StdMutex<Option<String>>>) -> Self {
            let response = response.into();
            Self {
                response,
                last_system,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for RecordingProvider {
        fn name(&self) -> &'static str {
            "recording"
        }

        async fn complete(
            &self,
            _model_id: &crate::config::model::ModelId,
            _messages: Vec<crate::app::conversation::Message>,
            system: Option<crate::app::SystemContext>,
            _tools: Option<Vec<steer_tools::ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            // Render first, then take the lock only for the store itself.
            let rendered = system.and_then(|context| context.render());
            *self
                .last_system
                .lock()
                .expect("system prompt lock poisoned") = rendered;

            let text = self.response.clone();
            Ok(CompletionResponse {
                content: vec![AssistantContent::Text { text }],
                usage: None,
            })
        }
    }

    /// Builds an empty model registry plus an API client wired to it and to an
    /// empty provider registry.
    fn test_registry_and_client() -> (Arc<ModelRegistry>, Arc<ApiClient>) {
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let api_client = Arc::new(ApiClient::new_with_deps(
            test_llm_config_provider().unwrap(),
            provider_registry,
            Arc::clone(&model_registry),
        ));
        (model_registry, api_client)
    }

    /// The sub-agent tool executor must expose the standard builtin tools.
    #[tokio::test]
    async fn sub_agent_tool_executor_includes_builtin_tools() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let workspace_config = WorkspaceConfig::Local {
            path: temp_dir.path().to_path_buf(),
        };
        let workspace = crate::workspace::create_workspace(&workspace_config)
            .await
            .expect("create workspace");
        let event_store = Arc::new(InMemoryEventStore::new());
        let (model_registry, api_client) = test_registry_and_client();

        let spawner = DefaultAgentSpawner::new(
            event_store,
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        let tool_executor = spawner.build_tool_executor(workspace);
        let expected_builtins = [
            GLOB_TOOL_NAME,
            GREP_TOOL_NAME,
            LS_TOOL_NAME,
            VIEW_TOOL_NAME,
            EDIT_TOOL_NAME,
            MULTI_EDIT_TOOL_NAME,
            REPLACE_TOOL_NAME,
            BASH_TOOL_NAME,
        ];
        for tool_name in expected_builtins {
            assert!(
                tool_executor.is_builtin_tool(tool_name),
                "expected sub-agent to have builtin tool: {tool_name}"
            );
        }
    }

    /// A spawned sub-agent persists its events, keeps the configured system
    /// prompt, and restricts and pre-approves exactly the allowed tools (plus
    /// the read-only defaults on the approval side).
    #[tokio::test]
    async fn sub_agent_persists_events_and_uses_whitelist_visibility() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let workspace_config = WorkspaceConfig::Local {
            path: temp_dir.path().to_path_buf(),
        };
        let workspace = crate::workspace::create_workspace(&workspace_config)
            .await
            .expect("create workspace");
        let event_store = Arc::new(InMemoryEventStore::new());
        let (model_registry, api_client) = test_registry_and_client();

        // Route the model's provider to a stub that records the system prompt.
        let system_capture: Arc<StdMutex<Option<String>>> = Arc::new(StdMutex::new(None));
        let model = builtin::claude_sonnet_4_5();
        let recorder = RecordingProvider::new("ok", Arc::clone(&system_capture));
        api_client.insert_test_provider(model.provider.clone(), Arc::new(recorder));

        let spawner = DefaultAgentSpawner::new(
            event_store.clone(),
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        let parent_session_id = crate::app::domain::types::SessionId::new();
        let allowed_tools = vec![
            VIEW_TOOL_NAME.to_string(),
            "mcp__alpha__allowed".to_string(),
        ];
        let system_prompt = String::from("subagent system");

        let config = SubAgentConfig {
            parent_session_id,
            prompt: "hello".to_string(),
            allowed_tools: allowed_tools.clone(),
            model: model.clone(),
            system_context: Some(crate::app::SystemContext::new(system_prompt.clone())),
            workspace: Some(workspace),
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            workspace_name: None,
            mcp_backends: Vec::new(),
            allow_mcp_tools: true,
        };

        let result = spawner
            .spawn(config, CancellationToken::new())
            .await
            .expect("spawn sub-agent");
        let events = event_store
            .load_events(result.session_id)
            .await
            .expect("load events");

        let mut saw_session_created = false;
        let mut saw_assistant_message = false;
        let mut seen_visibility = None;
        let mut seen_preapproved = None;

        for event in events.into_iter().map(|(_, event)| event) {
            match event {
                SessionEvent::SessionCreated { config, .. } => {
                    saw_session_created = true;
                    assert_eq!(config.parent_session_id, Some(parent_session_id));

                    let configured_system = config
                        .system_prompt
                        .as_deref()
                        .expect("expected system prompt in session config");
                    assert!(
                        configured_system.starts_with(system_prompt.as_str()),
                        "expected system prompt prefix, got: {configured_system:?}"
                    );

                    seen_visibility = match &config.tool_config.visibility {
                        ToolVisibility::Whitelist(allowed) => Some(allowed.clone()),
                        other => panic!("expected whitelist visibility, got {other:?}"),
                    };
                    seen_preapproved =
                        Some(config.tool_config.approval_policy.preapproved.tools.clone());
                }
                SessionEvent::AssistantMessageAdded { .. } => saw_assistant_message = true,
                _ => {}
            }
        }

        assert!(saw_session_created, "expected SessionCreated event");
        assert!(
            saw_assistant_message,
            "expected AssistantMessageAdded event"
        );

        let expected_visibility: HashSet<String> = allowed_tools.into_iter().collect();
        let mut expected_preapproved: HashSet<String> = READ_ONLY_TOOL_NAMES
            .iter()
            .map(|name| (*name).to_string())
            .collect();
        expected_preapproved.extend(expected_visibility.iter().cloned());

        assert_eq!(seen_visibility, Some(expected_visibility));
        assert_eq!(seen_preapproved, Some(expected_preapproved));

        let captured_system = system_capture
            .lock()
            .expect("system capture lock poisoned")
            .clone()
            .expect("expected captured system prompt");
        assert!(
            captured_system.starts_with(system_prompt.as_str()),
            "expected system prompt prefix, got: {captured_system:?}"
        );
    }
}