Skip to main content

steer_core/tools/
agent_spawner_impl.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::app::domain::runtime::RuntimeService;
9use crate::app::domain::session::EventStore;
10use crate::error::Error;
11use crate::model_registry::ModelRegistry;
12use crate::runners::OneShotRunner;
13use crate::session::state::{
14    ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, SessionToolConfig,
15    ToolApprovalPolicy, ToolApprovalPolicyOverrides, ToolVisibility, UnapprovedBehavior,
16    WorkspaceConfig,
17};
18use crate::tools::{ToolExecutor, ToolSystemBuilder};
19use crate::workspace::{RepoManager, Workspace, WorkspaceManager};
20
21use super::services::{AgentSpawner, SubAgentConfig, SubAgentError, SubAgentResult};
22
23pub struct DefaultAgentSpawner {
24    event_store: Arc<dyn EventStore>,
25    api_client: Arc<ApiClient>,
26    workspace: Arc<dyn Workspace>,
27    model_registry: Arc<ModelRegistry>,
28    workspace_manager: Option<Arc<dyn WorkspaceManager>>,
29    repo_manager: Option<Arc<dyn RepoManager>>,
30}
31
32impl DefaultAgentSpawner {
33    pub fn new(
34        event_store: Arc<dyn EventStore>,
35        api_client: Arc<ApiClient>,
36        workspace: Arc<dyn Workspace>,
37        model_registry: Arc<ModelRegistry>,
38        workspace_manager: Option<Arc<dyn WorkspaceManager>>,
39        repo_manager: Option<Arc<dyn RepoManager>>,
40    ) -> Self {
41        Self {
42            event_store,
43            api_client,
44            workspace,
45            model_registry,
46            workspace_manager,
47            repo_manager,
48        }
49    }
50
51    fn build_tool_executor(&self, workspace: Arc<dyn Workspace>) -> Arc<ToolExecutor> {
52        let mut tool_builder = ToolSystemBuilder::new(
53            workspace,
54            self.event_store.clone(),
55            self.api_client.clone(),
56            self.model_registry.clone(),
57        );
58
59        if let Some(manager) = &self.workspace_manager {
60            tool_builder = tool_builder.with_workspace_manager(manager.clone());
61        }
62        if let Some(manager) = &self.repo_manager {
63            tool_builder = tool_builder.with_repo_manager(manager.clone());
64        }
65
66        tool_builder.build()
67    }
68}
69
70#[async_trait]
71impl AgentSpawner for DefaultAgentSpawner {
72    async fn spawn(
73        &self,
74        config: SubAgentConfig,
75        cancel_token: CancellationToken,
76    ) -> Result<SubAgentResult, SubAgentError> {
77        let workspace = config
78            .workspace
79            .clone()
80            .unwrap_or_else(|| self.workspace.clone());
81        let workspace_path = workspace.working_directory().to_path_buf();
82
83        let visibility_tools: HashSet<String> = config.allowed_tools.iter().cloned().collect();
84        let mcp_backends = if config.allow_mcp_tools {
85            config.mcp_backends.clone()
86        } else {
87            Vec::new()
88        };
89
90        let tool_config = SessionToolConfig {
91            backends: mcp_backends,
92            visibility: ToolVisibility::All,
93            approval_policy: ToolApprovalPolicy::default(),
94            metadata: HashMap::new(),
95        };
96
97        let policy_overrides = SessionPolicyOverrides {
98            default_model: Some(config.model.clone()),
99            tool_visibility: Some(ToolVisibility::Whitelist(visibility_tools.clone())),
100            approval_policy: ToolApprovalPolicyOverrides {
101                default_behavior: Some(UnapprovedBehavior::Prompt),
102                preapproved: ApprovalRulesOverrides {
103                    tools: visibility_tools,
104                    per_tool: HashMap::new(),
105                },
106            },
107        };
108
109        let session_config = SessionConfig {
110            workspace: WorkspaceConfig::Local {
111                path: workspace_path,
112            },
113            workspace_ref: config.workspace_ref.clone(),
114            workspace_id: config.workspace_id,
115            repo_ref: config.repo_ref.clone(),
116            parent_session_id: Some(config.parent_session_id),
117            workspace_name: config.workspace_name.clone(),
118            tool_config,
119            system_prompt: config
120                .system_context
121                .as_ref()
122                .map(|context| context.prompt.clone()),
123            primary_agent_id: None,
124            policy_overrides,
125            metadata: HashMap::new(),
126            default_model: config.model.clone(),
127            auto_compaction: crate::session::state::AutoCompactionConfig::default(),
128        };
129
130        let tool_executor = self.build_tool_executor(workspace);
131        let runtime = RuntimeService::spawn(
132            self.event_store.clone(),
133            self.api_client.clone(),
134            tool_executor,
135        );
136
137        let run_result = OneShotRunner::run_new_session_with_cancel(
138            &runtime.handle,
139            session_config,
140            config.prompt,
141            config.model.clone(),
142            cancel_token,
143        )
144        .await;
145
146        runtime.shutdown().await;
147
148        let run_result = run_result.map_err(|err| match err {
149            Error::Cancelled => SubAgentError::Cancelled,
150            Error::Api(error) => SubAgentError::Api(error.to_string()),
151            other => SubAgentError::Agent(other.to_string()),
152        })?;
153
154        Ok(SubAgentResult {
155            session_id: run_result.session_id,
156            final_message: run_result.final_message,
157        })
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::DefaultAgentSpawner;
    use crate::api::Client as ApiClient;
    use crate::api::{ApiError, CompletionResponse, Provider};
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::event::SessionEvent;
    use crate::app::domain::session::EventStore;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::auth::ProviderRegistry;
    use crate::config::model::builtin;
    use crate::model_registry::ModelRegistry;
    use crate::session::state::ToolVisibility;
    use crate::test_utils::test_llm_config_provider;
    use crate::tools::services::AgentSpawner;
    use crate::tools::services::SubAgentConfig;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use crate::workspace::WorkspaceConfig;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::Mutex as StdMutex;
    use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
    use steer_tools::tools::replace::REPLACE_TOOL_NAME;
    use steer_tools::tools::{
        BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME,
        VIEW_TOOL_NAME,
    };
    use tempfile::TempDir;
    use tokio_util::sync::CancellationToken;

    /// Test provider that answers every completion with a fixed text and
    /// records the rendered system context of the most recent request.
    #[derive(Clone)]
    struct RecordingProvider {
        response: String,
        last_system: Arc<StdMutex<Option<String>>>,
    }

    impl RecordingProvider {
        /// Creates a provider that replies with `response` and writes the
        /// rendered system context into `last_system`.
        fn new(response: impl Into<String>, last_system: Arc<StdMutex<Option<String>>>) -> Self {
            let response = response.into();
            Self {
                response,
                last_system,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for RecordingProvider {
        fn name(&self) -> &'static str {
            "recording"
        }

        async fn complete(
            &self,
            _model_id: &crate::config::model::ModelId,
            _messages: Vec<crate::app::conversation::Message>,
            system: Option<crate::app::SystemContext>,
            _tools: Option<Vec<steer_tools::ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            // Render first, then take the lock only for the store itself.
            let rendered = system.and_then(|context| context.render());
            *self
                .last_system
                .lock()
                .expect("system prompt lock poisoned") = rendered;

            let reply = AssistantContent::Text {
                text: self.response.clone(),
            };
            Ok(CompletionResponse {
                content: vec![reply],
                usage: None,
            })
        }
    }

    /// Builds an API client over empty model/provider registries and returns
    /// it together with the model registry it was built from.
    fn api_client_with_registry() -> (Arc<ApiClient>, Arc<ModelRegistry>) {
        let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
        let provider_registry = Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
        let api_client = Arc::new(ApiClient::new_with_deps(
            test_llm_config_provider().unwrap(),
            provider_registry,
            model_registry.clone(),
        ));
        (api_client, model_registry)
    }

    /// The executor built for sub-agents must expose every static tool.
    #[tokio::test]
    async fn sub_agent_tool_executor_includes_static_tools() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let local_config = WorkspaceConfig::Local {
            path: temp_dir.path().to_path_buf(),
        };
        let workspace = crate::workspace::create_workspace(&local_config)
            .await
            .expect("create workspace");
        let event_store = Arc::new(InMemoryEventStore::new());
        let (api_client, model_registry) = api_client_with_registry();

        let spawner = DefaultAgentSpawner::new(
            event_store,
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );
        let tool_executor = spawner.build_tool_executor(workspace);

        let expected_static_tools = [
            GLOB_TOOL_NAME,
            GREP_TOOL_NAME,
            LS_TOOL_NAME,
            VIEW_TOOL_NAME,
            EDIT_TOOL_NAME,
            MULTI_EDIT_TOOL_NAME,
            REPLACE_TOOL_NAME,
            BASH_TOOL_NAME,
        ];
        for tool_name in expected_static_tools {
            assert!(
                tool_executor.is_static_tool(tool_name),
                "expected sub-agent to have static tool: {tool_name}"
            );
        }
    }

    /// A spawned sub-agent must persist its session events, record the
    /// whitelist visibility and pre-approved tools, and pass the system
    /// prompt through to the provider.
    #[tokio::test]
    async fn sub_agent_persists_events_and_uses_whitelist_visibility() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let local_config = WorkspaceConfig::Local {
            path: temp_dir.path().to_path_buf(),
        };
        let workspace = crate::workspace::create_workspace(&local_config)
            .await
            .expect("create workspace");
        let event_store = Arc::new(InMemoryEventStore::new());
        let (api_client, model_registry) = api_client_with_registry();

        let system_capture = Arc::new(StdMutex::new(None));
        let model = builtin::claude_sonnet_4_5();
        let recorder = RecordingProvider::new("ok", system_capture.clone());
        api_client.insert_test_provider(model.provider.clone(), Arc::new(recorder));

        let spawner = DefaultAgentSpawner::new(
            event_store.clone(),
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        let parent_session_id = crate::app::domain::types::SessionId::new();
        let allowed_tools = vec![
            VIEW_TOOL_NAME.to_string(),
            "mcp__alpha__allowed".to_string(),
        ];
        let system_prompt = "subagent system".to_string();

        let request = SubAgentConfig {
            parent_session_id,
            prompt: "hello".to_string(),
            allowed_tools: allowed_tools.clone(),
            model: model.clone(),
            system_context: Some(crate::app::SystemContext::new(system_prompt.clone())),
            workspace: Some(workspace),
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            workspace_name: None,
            mcp_backends: Vec::new(),
            allow_mcp_tools: true,
        };

        let result = spawner
            .spawn(request, CancellationToken::new())
            .await
            .expect("spawn sub-agent");

        let events = event_store
            .load_events(result.session_id)
            .await
            .expect("load events");

        let mut saw_session_created = false;
        let mut saw_assistant_message = false;
        let mut seen_visibility = None;
        let mut seen_preapproved = None;

        for (_, event) in events {
            match event {
                SessionEvent::SessionCreated {
                    config: created, ..
                } => {
                    saw_session_created = true;
                    assert_eq!(created.parent_session_id, Some(parent_session_id));

                    let configured_system = created
                        .system_prompt
                        .as_deref()
                        .expect("expected system prompt in session config");
                    assert!(
                        configured_system.starts_with(system_prompt.as_str()),
                        "expected system prompt prefix, got: {configured_system:?}"
                    );

                    let other = &created.tool_config.visibility;
                    if let ToolVisibility::Whitelist(allowed) = other {
                        seen_visibility = Some(allowed.clone());
                    } else {
                        panic!("expected whitelist visibility, got {other:?}");
                    }

                    let preapproved = &created.tool_config.approval_policy.preapproved;
                    seen_preapproved = Some(preapproved.tools.clone());
                }
                SessionEvent::AssistantMessageAdded { .. } => saw_assistant_message = true,
                _ => {}
            }
        }

        assert!(saw_session_created, "expected SessionCreated event");
        assert!(
            saw_assistant_message,
            "expected AssistantMessageAdded event"
        );

        // Pre-approved tools are the read-only defaults plus the whitelist.
        let expected_visibility: HashSet<String> = allowed_tools.into_iter().collect();
        let mut expected_preapproved: HashSet<String> = expected_visibility.clone();
        expected_preapproved.extend(READ_ONLY_TOOL_NAMES.iter().map(|name| (*name).to_string()));
        assert_eq!(seen_visibility, Some(expected_visibility));
        assert_eq!(seen_preapproved, Some(expected_preapproved));

        let captured_system = system_capture
            .lock()
            .expect("system capture lock poisoned")
            .clone()
            .expect("expected captured system prompt");
        assert!(
            captured_system.starts_with(system_prompt.as_str()),
            "expected system prompt prefix, got: {captured_system:?}"
        );
    }
}