// steer_core/tools/agent_spawner_impl.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::app::domain::runtime::RuntimeService;
9use crate::app::domain::session::EventStore;
10use crate::error::Error;
11use crate::model_registry::ModelRegistry;
12use crate::runners::OneShotRunner;
13use crate::session::state::{
14    ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, SessionToolConfig,
15    ToolApprovalPolicy, ToolApprovalPolicyOverrides, ToolVisibility, WorkspaceConfig,
16};
17use crate::tools::{ToolExecutor, ToolSystemBuilder};
18use crate::workspace::{RepoManager, Workspace, WorkspaceManager};
19
20use super::services::{AgentSpawner, SubAgentConfig, SubAgentError, SubAgentResult};
21
/// Default [`AgentSpawner`] implementation.
///
/// Each spawned sub-agent gets its own [`RuntimeService`] and tool executor, both built from
/// the shared services held here (event store, API client, model registry and the optional
/// workspace/repo managers).
pub struct DefaultAgentSpawner {
    /// Event store shared with every sub-agent runtime and tool system.
    event_store: Arc<dyn EventStore>,
    /// API client shared with every sub-agent runtime and tool system.
    api_client: Arc<ApiClient>,
    /// Fallback workspace, used when a [`SubAgentConfig`] does not supply its own.
    workspace: Arc<dyn Workspace>,
    /// Model registry handed to each sub-agent's tool system.
    model_registry: Arc<ModelRegistry>,
    /// Optional workspace manager; registered with each tool system when present.
    workspace_manager: Option<Arc<dyn WorkspaceManager>>,
    /// Optional repository manager; registered with each tool system when present.
    repo_manager: Option<Arc<dyn RepoManager>>,
}
30
31impl DefaultAgentSpawner {
32    pub fn new(
33        event_store: Arc<dyn EventStore>,
34        api_client: Arc<ApiClient>,
35        workspace: Arc<dyn Workspace>,
36        model_registry: Arc<ModelRegistry>,
37        workspace_manager: Option<Arc<dyn WorkspaceManager>>,
38        repo_manager: Option<Arc<dyn RepoManager>>,
39    ) -> Self {
40        Self {
41            event_store,
42            api_client,
43            workspace,
44            model_registry,
45            workspace_manager,
46            repo_manager,
47        }
48    }
49
50    fn build_tool_executor(&self, workspace: Arc<dyn Workspace>) -> Arc<ToolExecutor> {
51        let mut tool_builder = ToolSystemBuilder::new(
52            workspace,
53            self.event_store.clone(),
54            self.api_client.clone(),
55            self.model_registry.clone(),
56        );
57
58        if let Some(manager) = &self.workspace_manager {
59            tool_builder = tool_builder.with_workspace_manager(manager.clone());
60        }
61        if let Some(manager) = &self.repo_manager {
62            tool_builder = tool_builder.with_repo_manager(manager.clone());
63        }
64
65        tool_builder.build()
66    }
67}
68
69#[async_trait]
70impl AgentSpawner for DefaultAgentSpawner {
71    async fn spawn(
72        &self,
73        config: SubAgentConfig,
74        cancel_token: CancellationToken,
75    ) -> Result<SubAgentResult, SubAgentError> {
76        let workspace = config
77            .workspace
78            .clone()
79            .unwrap_or_else(|| self.workspace.clone());
80        let workspace_path = workspace.working_directory().to_path_buf();
81
82        let visibility_tools: HashSet<String> = config.allowed_tools.iter().cloned().collect();
83        let mcp_backends = if config.allow_mcp_tools {
84            config.mcp_backends.clone()
85        } else {
86            Vec::new()
87        };
88
89        let tool_config = SessionToolConfig {
90            backends: mcp_backends,
91            visibility: ToolVisibility::All,
92            approval_policy: ToolApprovalPolicy::default(),
93            metadata: HashMap::new(),
94        };
95
96        let policy_overrides = SessionPolicyOverrides {
97            default_model: Some(config.model.clone()),
98            tool_visibility: Some(ToolVisibility::Whitelist(visibility_tools.clone())),
99            approval_policy: ToolApprovalPolicyOverrides {
100                preapproved: ApprovalRulesOverrides {
101                    tools: visibility_tools,
102                    per_tool: HashMap::new(),
103                },
104            },
105        };
106
107        let session_config = SessionConfig {
108            workspace: WorkspaceConfig::Local {
109                path: workspace_path,
110            },
111            workspace_ref: config.workspace_ref.clone(),
112            workspace_id: config.workspace_id,
113            repo_ref: config.repo_ref.clone(),
114            parent_session_id: Some(config.parent_session_id),
115            workspace_name: config.workspace_name.clone(),
116            tool_config,
117            system_prompt: config
118                .system_context
119                .as_ref()
120                .map(|context| context.prompt.clone()),
121            primary_agent_id: None,
122            policy_overrides,
123            title: None,
124            metadata: HashMap::new(),
125            default_model: config.model.clone(),
126            auto_compaction: crate::session::state::AutoCompactionConfig::default(),
127        };
128
129        let tool_executor = self.build_tool_executor(workspace);
130        let runtime = RuntimeService::spawn(
131            self.event_store.clone(),
132            self.api_client.clone(),
133            tool_executor,
134        );
135
136        let run_result = OneShotRunner::run_new_session_with_cancel(
137            &runtime.handle,
138            session_config,
139            config.prompt,
140            config.model.clone(),
141            cancel_token,
142        )
143        .await;
144
145        runtime.shutdown().await;
146
147        let run_result = run_result.map_err(|err| match err {
148            Error::Cancelled => SubAgentError::Cancelled,
149            Error::Api(error) => SubAgentError::Api(error.to_string()),
150            other => SubAgentError::Agent(other.to_string()),
151        })?;
152
153        Ok(SubAgentResult {
154            session_id: run_result.session_id,
155            final_message: run_result.final_message,
156        })
157    }
158}
159
// Tests for `DefaultAgentSpawner`: one checks which builtin tools the sub-agent executor
// registers; the other drives `spawn` end-to-end against a recording fake provider.
#[cfg(test)]
mod tests {
    use super::DefaultAgentSpawner;
    use crate::api::Client as ApiClient;
    use crate::api::{ApiError, CompletionResponse, Provider};
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::event::SessionEvent;
    use crate::app::domain::session::EventStore;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::auth::ProviderRegistry;
    use crate::config::model::builtin;
    use crate::model_registry::ModelRegistry;
    use crate::session::state::ToolVisibility;
    use crate::test_utils::test_llm_config_provider;
    use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
    use crate::tools::services::AgentSpawner;
    use crate::tools::services::SubAgentConfig;
    use crate::workspace::WorkspaceConfig;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::Mutex as StdMutex;
    use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
    use steer_tools::tools::replace::REPLACE_TOOL_NAME;
    use steer_tools::tools::{
        BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME,
        VIEW_TOOL_NAME,
    };
    use tempfile::TempDir;
    use tokio_util::sync::CancellationToken;

    /// Fake [`Provider`] that always answers with a fixed text response.
    ///
    /// It also records the rendered system prompt from the most recent `complete` call, so
    /// tests can assert on what the model would have seen.
    #[derive(Clone)]
    struct RecordingProvider {
        /// Text returned as the single assistant content item of every completion.
        response: String,
        /// Shared slot holding the last rendered system prompt. It is `None` when no system
        /// context was passed or when the context rendered to `None`.
        last_system: Arc<StdMutex<Option<String>>>,
    }

    impl RecordingProvider {
        /// Creates a provider that replies with `response` and writes each captured system
        /// prompt into `last_system`.
        fn new(response: impl Into<String>, last_system: Arc<StdMutex<Option<String>>>) -> Self {
            Self {
                response: response.into(),
                last_system,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for RecordingProvider {
        fn name(&self) -> &'static str {
            "recording"
        }

        /// Records the rendered system prompt, then returns `self.response` as plain text
        /// with no usage data. All other inputs are ignored.
        async fn complete(
            &self,
            _model_id: &crate::config::model::ModelId,
            _messages: Vec<crate::app::conversation::Message>,
            system: Option<crate::app::SystemContext>,
            _tools: Option<Vec<steer_tools::ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            // Overwrite (rather than accumulate) so the slot reflects only the latest call.
            *self
                .last_system
                .lock()
                .expect("system prompt lock poisoned") =
                system.and_then(|context| context.render());

            Ok(CompletionResponse {
                content: vec![AssistantContent::Text {
                    text: self.response.clone(),
                }],
                usage: None,
            })
        }
    }

    /// The sub-agent tool executor must register each of the builtin tools listed below,
    /// even when no workspace or repo manager is configured.
    #[tokio::test]
    async fn sub_agent_tool_executor_includes_builtin_tools() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let workspace = crate::workspace::create_workspace(&WorkspaceConfig::Local {
            path: temp_dir.path().to_path_buf(),
        })
        .await
        .expect("create workspace");
        let event_store = Arc::new(InMemoryEventStore::new());
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let api_client = Arc::new(ApiClient::new_with_deps(
            test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry.clone(),
        ));

        // No workspace manager and no repo manager.
        let spawner = DefaultAgentSpawner::new(
            event_store,
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        let tool_executor = spawner.build_tool_executor(workspace);
        for tool_name in [
            GLOB_TOOL_NAME,
            GREP_TOOL_NAME,
            LS_TOOL_NAME,
            VIEW_TOOL_NAME,
            EDIT_TOOL_NAME,
            MULTI_EDIT_TOOL_NAME,
            REPLACE_TOOL_NAME,
            BASH_TOOL_NAME,
        ] {
            assert!(
                tool_executor.is_builtin_tool(tool_name),
                "expected sub-agent to have builtin tool: {tool_name}"
            );
        }
    }

    /// End-to-end check of `spawn` against [`RecordingProvider`].
    ///
    /// It asserts that:
    /// - the sub-agent's events are persisted in the shared event store;
    /// - its `SessionCreated` config carries the parent session id, the configured system
    ///   prompt, a whitelist equal to `allowed_tools`, and the expected pre-approved tools;
    /// - an assistant message is recorded;
    /// - the provider receives the configured system prompt.
    #[tokio::test]
    async fn sub_agent_persists_events_and_uses_whitelist_visibility() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let workspace = crate::workspace::create_workspace(&WorkspaceConfig::Local {
            path: temp_dir.path().to_path_buf(),
        })
        .await
        .expect("create workspace");
        let event_store = Arc::new(InMemoryEventStore::new());
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let api_client = Arc::new(ApiClient::new_with_deps(
            test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry.clone(),
        ));

        // Register the recording fake for the model's provider so completions go through it.
        let system_capture = Arc::new(StdMutex::new(None));
        let model = builtin::claude_sonnet_4_5();
        api_client.insert_test_provider(
            model.provider.clone(),
            Arc::new(RecordingProvider::new("ok", system_capture.clone())),
        );

        let spawner = DefaultAgentSpawner::new(
            event_store.clone(),
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        // Mix a builtin tool with an MCP-style name; both must end up in the whitelist.
        let parent_session_id = crate::app::domain::types::SessionId::new();
        let allowed_tools = vec![
            VIEW_TOOL_NAME.to_string(),
            "mcp__alpha__allowed".to_string(),
        ];
        let system_prompt = "subagent system".to_string();

        let config = SubAgentConfig {
            parent_session_id,
            prompt: "hello".to_string(),
            allowed_tools: allowed_tools.clone(),
            model: model.clone(),
            system_context: Some(crate::app::SystemContext::new(system_prompt.clone())),
            workspace: Some(workspace),
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            workspace_name: None,
            mcp_backends: Vec::new(),
            allow_mcp_tools: true,
        };

        let result = spawner
            .spawn(config, CancellationToken::new())
            .await
            .expect("spawn sub-agent");

        let events = event_store
            .load_events(result.session_id)
            .await
            .expect("load events");

        // Collect what was observed so the assertions below can report clearly.
        let mut saw_session_created = false;
        let mut saw_assistant_message = false;
        let mut seen_visibility = None;
        let mut seen_preapproved = None;

        for (_, event) in events {
            match event {
                SessionEvent::SessionCreated { config, .. } => {
                    saw_session_created = true;
                    assert_eq!(config.parent_session_id, Some(parent_session_id));
                    // Prefix match only: presumably the runtime may append extra context to
                    // the configured prompt. TODO confirm.
                    let configured_system = config
                        .system_prompt
                        .as_deref()
                        .expect("expected system prompt in session config");
                    assert!(
                        configured_system.starts_with(system_prompt.as_str()),
                        "expected system prompt prefix, got: {configured_system:?}"
                    );
                    match &config.tool_config.visibility {
                        ToolVisibility::Whitelist(allowed) => {
                            seen_visibility = Some(allowed.clone());
                        }
                        other => panic!("expected whitelist visibility, got {other:?}"),
                    }
                    seen_preapproved =
                        Some(config.tool_config.approval_policy.preapproved.tools.clone());
                }
                SessionEvent::AssistantMessageAdded { .. } => {
                    saw_assistant_message = true;
                }
                _ => {}
            }
        }

        assert!(saw_session_created, "expected SessionCreated event");
        assert!(
            saw_assistant_message,
            "expected AssistantMessageAdded event"
        );

        // Expected pre-approvals: the read-only builtins plus every whitelisted tool.
        // Presumably the spawner's overrides are merged onto a default policy that already
        // pre-approves read-only tools. Verify against the session state code.
        let expected_visibility: HashSet<String> = allowed_tools.into_iter().collect();
        let expected_preapproved: HashSet<String> = READ_ONLY_TOOL_NAMES
            .iter()
            .map(|name| (*name).to_string())
            .chain(expected_visibility.iter().cloned())
            .collect();
        assert_eq!(seen_visibility, Some(expected_visibility));
        assert_eq!(seen_preapproved, Some(expected_preapproved));

        // The provider itself must also have received the configured system prompt.
        let captured_system = system_capture
            .lock()
            .expect("system capture lock poisoned")
            .clone();
        let captured_system = captured_system.expect("expected captured system prompt");
        assert!(
            captured_system.starts_with(system_prompt.as_str()),
            "expected system prompt prefix, got: {captured_system:?}"
        );
    }
}