1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::app::domain::runtime::RuntimeService;
9use crate::app::domain::session::EventStore;
10use crate::error::Error;
11use crate::model_registry::ModelRegistry;
12use crate::runners::OneShotRunner;
13use crate::session::state::{
14 ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, SessionToolConfig,
15 ToolApprovalPolicy, ToolApprovalPolicyOverrides, ToolVisibility, UnapprovedBehavior,
16 WorkspaceConfig,
17};
18use crate::tools::{ToolExecutor, ToolSystemBuilder};
19use crate::workspace::{RepoManager, Workspace, WorkspaceManager};
20
21use super::services::{AgentSpawner, SubAgentConfig, SubAgentError, SubAgentResult};
22
23pub struct DefaultAgentSpawner {
24 event_store: Arc<dyn EventStore>,
25 api_client: Arc<ApiClient>,
26 workspace: Arc<dyn Workspace>,
27 model_registry: Arc<ModelRegistry>,
28 workspace_manager: Option<Arc<dyn WorkspaceManager>>,
29 repo_manager: Option<Arc<dyn RepoManager>>,
30}
31
32impl DefaultAgentSpawner {
33 pub fn new(
34 event_store: Arc<dyn EventStore>,
35 api_client: Arc<ApiClient>,
36 workspace: Arc<dyn Workspace>,
37 model_registry: Arc<ModelRegistry>,
38 workspace_manager: Option<Arc<dyn WorkspaceManager>>,
39 repo_manager: Option<Arc<dyn RepoManager>>,
40 ) -> Self {
41 Self {
42 event_store,
43 api_client,
44 workspace,
45 model_registry,
46 workspace_manager,
47 repo_manager,
48 }
49 }
50
51 fn build_tool_executor(&self, workspace: Arc<dyn Workspace>) -> Arc<ToolExecutor> {
52 let mut tool_builder = ToolSystemBuilder::new(
53 workspace,
54 self.event_store.clone(),
55 self.api_client.clone(),
56 self.model_registry.clone(),
57 );
58
59 if let Some(manager) = &self.workspace_manager {
60 tool_builder = tool_builder.with_workspace_manager(manager.clone());
61 }
62 if let Some(manager) = &self.repo_manager {
63 tool_builder = tool_builder.with_repo_manager(manager.clone());
64 }
65
66 tool_builder.build()
67 }
68}
69
70#[async_trait]
71impl AgentSpawner for DefaultAgentSpawner {
72 async fn spawn(
73 &self,
74 config: SubAgentConfig,
75 cancel_token: CancellationToken,
76 ) -> Result<SubAgentResult, SubAgentError> {
77 let workspace = config
78 .workspace
79 .clone()
80 .unwrap_or_else(|| self.workspace.clone());
81 let workspace_path = workspace.working_directory().to_path_buf();
82
83 let visibility_tools: HashSet<String> = config.allowed_tools.iter().cloned().collect();
84 let mcp_backends = if config.allow_mcp_tools {
85 config.mcp_backends.clone()
86 } else {
87 Vec::new()
88 };
89
90 let tool_config = SessionToolConfig {
91 backends: mcp_backends,
92 visibility: ToolVisibility::All,
93 approval_policy: ToolApprovalPolicy::default(),
94 metadata: HashMap::new(),
95 };
96
97 let policy_overrides = SessionPolicyOverrides {
98 default_model: Some(config.model.clone()),
99 tool_visibility: Some(ToolVisibility::Whitelist(visibility_tools.clone())),
100 approval_policy: ToolApprovalPolicyOverrides {
101 default_behavior: Some(UnapprovedBehavior::Prompt),
102 preapproved: ApprovalRulesOverrides {
103 tools: visibility_tools,
104 per_tool: HashMap::new(),
105 },
106 },
107 };
108
109 let session_config = SessionConfig {
110 workspace: WorkspaceConfig::Local {
111 path: workspace_path,
112 },
113 workspace_ref: config.workspace_ref.clone(),
114 workspace_id: config.workspace_id,
115 repo_ref: config.repo_ref.clone(),
116 parent_session_id: Some(config.parent_session_id),
117 workspace_name: config.workspace_name.clone(),
118 tool_config,
119 system_prompt: config
120 .system_context
121 .as_ref()
122 .map(|context| context.prompt.clone()),
123 primary_agent_id: None,
124 policy_overrides,
125 metadata: HashMap::new(),
126 default_model: config.model.clone(),
127 auto_compaction: crate::session::state::AutoCompactionConfig::default(),
128 };
129
130 let tool_executor = self.build_tool_executor(workspace);
131 let runtime = RuntimeService::spawn(
132 self.event_store.clone(),
133 self.api_client.clone(),
134 tool_executor,
135 );
136
137 let run_result = OneShotRunner::run_new_session_with_cancel(
138 &runtime.handle,
139 session_config,
140 config.prompt,
141 config.model.clone(),
142 cancel_token,
143 )
144 .await;
145
146 runtime.shutdown().await;
147
148 let run_result = run_result.map_err(|err| match err {
149 Error::Cancelled => SubAgentError::Cancelled,
150 Error::Api(error) => SubAgentError::Api(error.to_string()),
151 other => SubAgentError::Agent(other.to_string()),
152 })?;
153
154 Ok(SubAgentResult {
155 session_id: run_result.session_id,
156 final_message: run_result.final_message,
157 })
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::DefaultAgentSpawner;
    use crate::api::Client as ApiClient;
    use crate::api::{ApiError, CompletionResponse, Provider};
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::event::SessionEvent;
    use crate::app::domain::session::EventStore;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::auth::ProviderRegistry;
    use crate::config::model::builtin;
    use crate::model_registry::ModelRegistry;
    use crate::session::state::ToolVisibility;
    use crate::test_utils::test_llm_config_provider;
    use crate::tools::services::AgentSpawner;
    use crate::tools::services::SubAgentConfig;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use crate::workspace::WorkspaceConfig;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::Mutex as StdMutex;
    use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
    use steer_tools::tools::replace::REPLACE_TOOL_NAME;
    use steer_tools::tools::{
        BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME,
        VIEW_TOOL_NAME,
    };
    use tempfile::TempDir;
    use tokio_util::sync::CancellationToken;

    /// Stub provider: always replies with a fixed text and records the rendered
    /// system prompt of the most recent completion request.
    #[derive(Clone)]
    struct RecordingProvider {
        response: String,
        last_system: Arc<StdMutex<Option<String>>>,
    }

    impl RecordingProvider {
        fn new(response: impl Into<String>, last_system: Arc<StdMutex<Option<String>>>) -> Self {
            let response = response.into();
            Self {
                last_system,
                response,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for RecordingProvider {
        fn name(&self) -> &'static str {
            "recording"
        }

        async fn complete(
            &self,
            _model_id: &crate::config::model::ModelId,
            _messages: Vec<crate::app::conversation::Message>,
            system: Option<crate::app::SystemContext>,
            _tools: Option<Vec<steer_tools::ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            // Remember the system prompt the caller actually sent.
            let rendered = system.and_then(|context| context.render());
            *self.last_system.lock().expect("system prompt lock poisoned") = rendered;

            let reply = AssistantContent::Text {
                text: self.response.clone(),
            };
            Ok(CompletionResponse {
                content: vec![reply],
                usage: None,
            })
        }
    }

    /// Builds the event store, model registry and API client shared by the tests.
    fn test_services() -> (Arc<InMemoryEventStore>, Arc<ModelRegistry>, Arc<ApiClient>) {
        let event_store = Arc::new(InMemoryEventStore::new());
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let api_client = Arc::new(ApiClient::new_with_deps(
            test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry.clone(),
        ));
        (event_store, model_registry, api_client)
    }

    #[tokio::test]
    async fn sub_agent_tool_executor_includes_static_tools() {
        // Keep the directory alive for the whole test.
        let scratch = TempDir::new().expect("create temp dir");
        let workspace = crate::workspace::create_workspace(&WorkspaceConfig::Local {
            path: scratch.path().to_path_buf(),
        })
        .await
        .expect("create workspace");
        let (event_store, model_registry, api_client) = test_services();

        let spawner = DefaultAgentSpawner::new(
            event_store,
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );
        let executor = spawner.build_tool_executor(workspace);

        let expected_static_tools = [
            GLOB_TOOL_NAME,
            GREP_TOOL_NAME,
            LS_TOOL_NAME,
            VIEW_TOOL_NAME,
            EDIT_TOOL_NAME,
            MULTI_EDIT_TOOL_NAME,
            REPLACE_TOOL_NAME,
            BASH_TOOL_NAME,
        ];
        for tool_name in expected_static_tools {
            assert!(
                executor.is_static_tool(tool_name),
                "expected sub-agent to have static tool: {tool_name}"
            );
        }
    }

    #[tokio::test]
    async fn sub_agent_persists_events_and_uses_whitelist_visibility() {
        // Keep the directory alive for the whole test.
        let scratch = TempDir::new().expect("create temp dir");
        let workspace = crate::workspace::create_workspace(&WorkspaceConfig::Local {
            path: scratch.path().to_path_buf(),
        })
        .await
        .expect("create workspace");
        let (event_store, model_registry, api_client) = test_services();

        // Route the model's provider to a stub that records the system prompt.
        let captured_prompt = Arc::new(StdMutex::new(None));
        let model = builtin::claude_sonnet_4_5();
        api_client.insert_test_provider(
            model.provider.clone(),
            Arc::new(RecordingProvider::new("ok", Arc::clone(&captured_prompt))),
        );

        let spawner = DefaultAgentSpawner::new(
            event_store.clone(),
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        let parent_session_id = crate::app::domain::types::SessionId::new();
        let allowed_tools = vec![
            VIEW_TOOL_NAME.to_string(),
            "mcp__alpha__allowed".to_string(),
        ];
        let system_prompt = "subagent system".to_string();

        let sub_agent_config = SubAgentConfig {
            parent_session_id,
            prompt: "hello".to_string(),
            allowed_tools: allowed_tools.clone(),
            model: model.clone(),
            system_context: Some(crate::app::SystemContext::new(system_prompt.clone())),
            workspace: Some(workspace),
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            workspace_name: None,
            mcp_backends: Vec::new(),
            allow_mcp_tools: true,
        };

        let result = spawner
            .spawn(sub_agent_config, CancellationToken::new())
            .await
            .expect("spawn sub-agent");

        let events = event_store
            .load_events(result.session_id)
            .await
            .expect("load events");

        let mut saw_session_created = false;
        let mut saw_assistant_message = false;
        let mut seen_visibility = None;
        let mut seen_preapproved = None;

        for (_, event) in events {
            if let SessionEvent::SessionCreated { config, .. } = event {
                saw_session_created = true;
                assert_eq!(config.parent_session_id, Some(parent_session_id));

                let configured_system = config
                    .system_prompt
                    .as_deref()
                    .expect("expected system prompt in session config");
                assert!(
                    configured_system.starts_with(system_prompt.as_str()),
                    "expected system prompt prefix, got: {configured_system:?}"
                );

                seen_visibility = match &config.tool_config.visibility {
                    ToolVisibility::Whitelist(allowed) => Some(allowed.clone()),
                    other => panic!("expected whitelist visibility, got {other:?}"),
                };
                seen_preapproved =
                    Some(config.tool_config.approval_policy.preapproved.tools.clone());
            } else if matches!(event, SessionEvent::AssistantMessageAdded { .. }) {
                saw_assistant_message = true;
            }
        }

        assert!(saw_session_created, "expected SessionCreated event");
        assert!(
            saw_assistant_message,
            "expected AssistantMessageAdded event"
        );

        // Pre-approval is expected to be the read-only defaults plus the allowed tools.
        let expected_visibility: HashSet<String> = allowed_tools.into_iter().collect();
        let mut expected_preapproved: HashSet<String> = READ_ONLY_TOOL_NAMES
            .iter()
            .map(|name| (*name).to_string())
            .collect();
        expected_preapproved.extend(expected_visibility.iter().cloned());

        assert_eq!(seen_visibility, Some(expected_visibility));
        assert_eq!(seen_preapproved, Some(expected_preapproved));

        let captured_system = captured_prompt
            .lock()
            .expect("system capture lock poisoned")
            .clone()
            .expect("expected captured system prompt");
        assert!(
            captured_system.starts_with(system_prompt.as_str()),
            "expected system prompt prefix, got: {captured_system:?}"
        );
    }
}