1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::app::domain::runtime::RuntimeService;
9use crate::app::domain::session::EventStore;
10use crate::error::Error;
11use crate::model_registry::ModelRegistry;
12use crate::runners::OneShotRunner;
13use crate::session::state::{
14 ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, SessionToolConfig,
15 ToolApprovalPolicy, ToolApprovalPolicyOverrides, ToolVisibility, UnapprovedBehavior,
16 WorkspaceConfig,
17};
18use crate::tools::{ToolExecutor, ToolSystemBuilder};
19use crate::workspace::{RepoManager, Workspace, WorkspaceManager};
20
21use super::services::{AgentSpawner, SubAgentConfig, SubAgentError, SubAgentResult};
22
23pub struct DefaultAgentSpawner {
24 event_store: Arc<dyn EventStore>,
25 api_client: Arc<ApiClient>,
26 workspace: Arc<dyn Workspace>,
27 model_registry: Arc<ModelRegistry>,
28 workspace_manager: Option<Arc<dyn WorkspaceManager>>,
29 repo_manager: Option<Arc<dyn RepoManager>>,
30}
31
32impl DefaultAgentSpawner {
33 pub fn new(
34 event_store: Arc<dyn EventStore>,
35 api_client: Arc<ApiClient>,
36 workspace: Arc<dyn Workspace>,
37 model_registry: Arc<ModelRegistry>,
38 workspace_manager: Option<Arc<dyn WorkspaceManager>>,
39 repo_manager: Option<Arc<dyn RepoManager>>,
40 ) -> Self {
41 Self {
42 event_store,
43 api_client,
44 workspace,
45 model_registry,
46 workspace_manager,
47 repo_manager,
48 }
49 }
50
51 fn build_tool_executor(&self, workspace: Arc<dyn Workspace>) -> Arc<ToolExecutor> {
52 let mut tool_builder = ToolSystemBuilder::new(
53 workspace,
54 self.event_store.clone(),
55 self.api_client.clone(),
56 self.model_registry.clone(),
57 );
58
59 if let Some(manager) = &self.workspace_manager {
60 tool_builder = tool_builder.with_workspace_manager(manager.clone());
61 }
62 if let Some(manager) = &self.repo_manager {
63 tool_builder = tool_builder.with_repo_manager(manager.clone());
64 }
65
66 tool_builder.build()
67 }
68}
69
70#[async_trait]
71impl AgentSpawner for DefaultAgentSpawner {
72 async fn spawn(
73 &self,
74 config: SubAgentConfig,
75 cancel_token: CancellationToken,
76 ) -> Result<SubAgentResult, SubAgentError> {
77 let workspace = config
78 .workspace
79 .clone()
80 .unwrap_or_else(|| self.workspace.clone());
81 let workspace_path = workspace.working_directory().to_path_buf();
82
83 let visibility_tools: HashSet<String> = config.allowed_tools.iter().cloned().collect();
84 let mcp_backends = if config.allow_mcp_tools {
85 config.mcp_backends.clone()
86 } else {
87 Vec::new()
88 };
89
90 let tool_config = SessionToolConfig {
91 backends: mcp_backends,
92 visibility: ToolVisibility::All,
93 approval_policy: ToolApprovalPolicy::default(),
94 metadata: HashMap::new(),
95 };
96
97 let policy_overrides = SessionPolicyOverrides {
98 default_model: Some(config.model.clone()),
99 tool_visibility: Some(ToolVisibility::Whitelist(visibility_tools.clone())),
100 approval_policy: ToolApprovalPolicyOverrides {
101 default_behavior: Some(UnapprovedBehavior::Prompt),
102 preapproved: ApprovalRulesOverrides {
103 tools: visibility_tools,
104 per_tool: HashMap::new(),
105 },
106 },
107 };
108
109 let session_config = SessionConfig {
110 workspace: WorkspaceConfig::Local {
111 path: workspace_path,
112 },
113 workspace_ref: config.workspace_ref.clone(),
114 workspace_id: config.workspace_id,
115 repo_ref: config.repo_ref.clone(),
116 parent_session_id: Some(config.parent_session_id),
117 workspace_name: config.workspace_name.clone(),
118 tool_config,
119 system_prompt: config
120 .system_context
121 .as_ref()
122 .map(|context| context.prompt.clone()),
123 primary_agent_id: None,
124 policy_overrides,
125 metadata: HashMap::new(),
126 default_model: config.model.clone(),
127 };
128
129 let tool_executor = self.build_tool_executor(workspace);
130 let runtime = RuntimeService::spawn(
131 self.event_store.clone(),
132 self.api_client.clone(),
133 tool_executor,
134 );
135
136 let run_result = OneShotRunner::run_new_session_with_cancel(
137 &runtime.handle,
138 session_config,
139 config.prompt,
140 config.model.clone(),
141 cancel_token,
142 )
143 .await;
144
145 runtime.shutdown().await;
146
147 let run_result = run_result.map_err(|err| match err {
148 Error::Cancelled => SubAgentError::Cancelled,
149 Error::Api(error) => SubAgentError::Api(error.to_string()),
150 other => SubAgentError::Agent(other.to_string()),
151 })?;
152
153 Ok(SubAgentResult {
154 session_id: run_result.session_id,
155 final_message: run_result.final_message,
156 })
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::DefaultAgentSpawner;
    use crate::api::Client as ApiClient;
    use crate::api::{ApiError, CompletionResponse, Provider};
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::event::SessionEvent;
    use crate::app::domain::session::EventStore;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::auth::ProviderRegistry;
    use crate::config::model::builtin;
    use crate::model_registry::ModelRegistry;
    use crate::session::state::ToolVisibility;
    use crate::test_utils::test_llm_config_provider;
    use crate::tools::services::AgentSpawner;
    use crate::tools::services::SubAgentConfig;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use crate::workspace::{Workspace, WorkspaceConfig};
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::Mutex as StdMutex;
    use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
    use steer_tools::tools::replace::REPLACE_TOOL_NAME;
    use steer_tools::tools::{
        BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME,
        VIEW_TOOL_NAME,
    };
    use tempfile::TempDir;
    use tokio_util::sync::CancellationToken;

    /// Provider stub that always answers with a fixed text response and
    /// records the rendered system prompt of the most recent request.
    #[derive(Clone)]
    struct RecordingProvider {
        response: String,
        last_system: Arc<StdMutex<Option<String>>>,
    }

    impl RecordingProvider {
        fn new(response: impl Into<String>, last_system: Arc<StdMutex<Option<String>>>) -> Self {
            Self {
                response: response.into(),
                last_system,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for RecordingProvider {
        fn name(&self) -> &'static str {
            "recording"
        }

        async fn complete(
            &self,
            _model_id: &crate::config::model::ModelId,
            _messages: Vec<crate::app::conversation::Message>,
            system: Option<crate::app::SystemContext>,
            _tools: Option<Vec<steer_tools::ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            *self
                .last_system
                .lock()
                .expect("system prompt lock poisoned") =
                system.and_then(|context| context.render());

            Ok(CompletionResponse {
                content: vec![AssistantContent::Text {
                    text: self.response.clone(),
                }],
            })
        }
    }

    /// Environment shared by the spawner tests: a temp-dir backed workspace
    /// plus the event store, model registry and API client that a
    /// [`DefaultAgentSpawner`] is constructed from.
    struct TestEnv {
        workspace: Arc<dyn Workspace>,
        event_store: Arc<InMemoryEventStore>,
        model_registry: Arc<ModelRegistry>,
        api_client: Arc<ApiClient>,
        // Declared last so it is dropped last: the directory must outlive the
        // workspace and everything else that may point into it.
        _temp_dir: TempDir,
    }

    impl TestEnv {
        /// Creates a fresh temp-dir workspace with empty registries and store.
        async fn new() -> Self {
            let temp_dir = TempDir::new().expect("create temp dir");
            let workspace: Arc<dyn Workspace> =
                crate::workspace::create_workspace(&WorkspaceConfig::Local {
                    path: temp_dir.path().to_path_buf(),
                })
                .await
                .expect("create workspace");
            let event_store = Arc::new(InMemoryEventStore::new());
            let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
            let provider_registry =
                Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
            let api_client = Arc::new(ApiClient::new_with_deps(
                test_llm_config_provider().unwrap(),
                provider_registry,
                model_registry.clone(),
            ));

            Self {
                workspace,
                event_store,
                model_registry,
                api_client,
                _temp_dir: temp_dir,
            }
        }

        /// Builds a spawner over this environment without workspace or repo
        /// managers.
        fn spawner(&self) -> DefaultAgentSpawner {
            DefaultAgentSpawner::new(
                self.event_store.clone(),
                self.api_client.clone(),
                self.workspace.clone(),
                self.model_registry.clone(),
                None,
                None,
            )
        }
    }

    #[tokio::test]
    async fn sub_agent_tool_executor_includes_static_tools() {
        let env = TestEnv::new().await;
        let spawner = env.spawner();

        let tool_executor = spawner.build_tool_executor(env.workspace.clone());
        for tool_name in [
            GLOB_TOOL_NAME,
            GREP_TOOL_NAME,
            LS_TOOL_NAME,
            VIEW_TOOL_NAME,
            EDIT_TOOL_NAME,
            MULTI_EDIT_TOOL_NAME,
            REPLACE_TOOL_NAME,
            BASH_TOOL_NAME,
        ] {
            assert!(
                tool_executor.is_static_tool(tool_name),
                "expected sub-agent to have static tool: {tool_name}"
            );
        }
    }

    #[tokio::test]
    async fn sub_agent_persists_events_and_uses_whitelist_visibility() {
        let env = TestEnv::new().await;

        // Route the model's provider to the recording stub before spawning.
        let system_capture = Arc::new(StdMutex::new(None));
        let model = builtin::claude_sonnet_4_5();
        env.api_client.insert_test_provider(
            model.provider.clone(),
            Arc::new(RecordingProvider::new("ok", system_capture.clone())),
        );

        let spawner = env.spawner();

        let parent_session_id = crate::app::domain::types::SessionId::new();
        let allowed_tools = vec![
            VIEW_TOOL_NAME.to_string(),
            "mcp__alpha__allowed".to_string(),
        ];
        let system_prompt = "subagent system".to_string();

        let config = SubAgentConfig {
            parent_session_id,
            prompt: "hello".to_string(),
            allowed_tools: allowed_tools.clone(),
            model: model.clone(),
            system_context: Some(crate::app::SystemContext::new(system_prompt.clone())),
            workspace: Some(env.workspace.clone()),
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            workspace_name: None,
            mcp_backends: Vec::new(),
            allow_mcp_tools: true,
        };

        let result = spawner
            .spawn(config, CancellationToken::new())
            .await
            .expect("spawn sub-agent");

        let events = env
            .event_store
            .load_events(result.session_id)
            .await
            .expect("load events");

        let mut saw_session_created = false;
        let mut saw_assistant_message = false;
        let mut seen_visibility = None;
        let mut seen_preapproved = None;

        for (_, event) in events {
            match event {
                SessionEvent::SessionCreated { config, .. } => {
                    saw_session_created = true;
                    assert_eq!(config.parent_session_id, Some(parent_session_id));
                    let configured_system = config
                        .system_prompt
                        .as_deref()
                        .expect("expected system prompt in session config");
                    assert!(
                        configured_system.starts_with(system_prompt.as_str()),
                        "expected system prompt prefix, got: {configured_system:?}"
                    );
                    match &config.tool_config.visibility {
                        ToolVisibility::Whitelist(allowed) => {
                            seen_visibility = Some(allowed.clone());
                        }
                        other => panic!("expected whitelist visibility, got {other:?}"),
                    }
                    seen_preapproved =
                        Some(config.tool_config.approval_policy.preapproved.tools.clone());
                }
                SessionEvent::AssistantMessageAdded { .. } => {
                    saw_assistant_message = true;
                }
                _ => {}
            }
        }

        assert!(saw_session_created, "expected SessionCreated event");
        assert!(
            saw_assistant_message,
            "expected AssistantMessageAdded event"
        );

        // Pre-approved tools are the read-only defaults plus the whitelist.
        let expected_visibility: HashSet<String> = allowed_tools.into_iter().collect();
        let expected_preapproved: HashSet<String> = READ_ONLY_TOOL_NAMES
            .iter()
            .map(|name| (*name).to_string())
            .chain(expected_visibility.iter().cloned())
            .collect();
        assert_eq!(seen_visibility, Some(expected_visibility));
        assert_eq!(seen_preapproved, Some(expected_preapproved));

        let captured_system = system_capture
            .lock()
            .expect("system capture lock poisoned")
            .take()
            .expect("expected captured system prompt");
        assert!(
            captured_system.starts_with(system_prompt.as_str()),
            "expected system prompt prefix, got: {captured_system:?}"
        );
    }
}