Skip to main content

steer_core/tools/
agent_spawner_impl.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use tokio_util::sync::CancellationToken;
6
7use crate::api::Client as ApiClient;
8use crate::app::domain::runtime::RuntimeService;
9use crate::app::domain::session::EventStore;
10use crate::error::Error;
11use crate::model_registry::ModelRegistry;
12use crate::runners::OneShotRunner;
13use crate::session::state::{
14    ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, SessionToolConfig,
15    ToolApprovalPolicy, ToolApprovalPolicyOverrides, ToolVisibility, UnapprovedBehavior,
16    WorkspaceConfig,
17};
18use crate::tools::{ToolExecutor, ToolSystemBuilder};
19use crate::workspace::{RepoManager, Workspace, WorkspaceManager};
20
21use super::services::{AgentSpawner, SubAgentConfig, SubAgentError, SubAgentResult};
22
23pub struct DefaultAgentSpawner {
24    event_store: Arc<dyn EventStore>,
25    api_client: Arc<ApiClient>,
26    workspace: Arc<dyn Workspace>,
27    model_registry: Arc<ModelRegistry>,
28    workspace_manager: Option<Arc<dyn WorkspaceManager>>,
29    repo_manager: Option<Arc<dyn RepoManager>>,
30}
31
32impl DefaultAgentSpawner {
33    pub fn new(
34        event_store: Arc<dyn EventStore>,
35        api_client: Arc<ApiClient>,
36        workspace: Arc<dyn Workspace>,
37        model_registry: Arc<ModelRegistry>,
38        workspace_manager: Option<Arc<dyn WorkspaceManager>>,
39        repo_manager: Option<Arc<dyn RepoManager>>,
40    ) -> Self {
41        Self {
42            event_store,
43            api_client,
44            workspace,
45            model_registry,
46            workspace_manager,
47            repo_manager,
48        }
49    }
50
51    fn build_tool_executor(&self, workspace: Arc<dyn Workspace>) -> Arc<ToolExecutor> {
52        let mut tool_builder = ToolSystemBuilder::new(
53            workspace,
54            self.event_store.clone(),
55            self.api_client.clone(),
56            self.model_registry.clone(),
57        );
58
59        if let Some(manager) = &self.workspace_manager {
60            tool_builder = tool_builder.with_workspace_manager(manager.clone());
61        }
62        if let Some(manager) = &self.repo_manager {
63            tool_builder = tool_builder.with_repo_manager(manager.clone());
64        }
65
66        tool_builder.build()
67    }
68}
69
70#[async_trait]
71impl AgentSpawner for DefaultAgentSpawner {
72    async fn spawn(
73        &self,
74        config: SubAgentConfig,
75        cancel_token: CancellationToken,
76    ) -> Result<SubAgentResult, SubAgentError> {
77        let workspace = config
78            .workspace
79            .clone()
80            .unwrap_or_else(|| self.workspace.clone());
81        let workspace_path = workspace.working_directory().to_path_buf();
82
83        let visibility_tools: HashSet<String> = config.allowed_tools.iter().cloned().collect();
84        let mcp_backends = if config.allow_mcp_tools {
85            config.mcp_backends.clone()
86        } else {
87            Vec::new()
88        };
89
90        let tool_config = SessionToolConfig {
91            backends: mcp_backends,
92            visibility: ToolVisibility::All,
93            approval_policy: ToolApprovalPolicy::default(),
94            metadata: HashMap::new(),
95        };
96
97        let policy_overrides = SessionPolicyOverrides {
98            default_model: Some(config.model.clone()),
99            tool_visibility: Some(ToolVisibility::Whitelist(visibility_tools.clone())),
100            approval_policy: ToolApprovalPolicyOverrides {
101                default_behavior: Some(UnapprovedBehavior::Prompt),
102                preapproved: ApprovalRulesOverrides {
103                    tools: visibility_tools,
104                    per_tool: HashMap::new(),
105                },
106            },
107        };
108
109        let session_config = SessionConfig {
110            workspace: WorkspaceConfig::Local {
111                path: workspace_path,
112            },
113            workspace_ref: config.workspace_ref.clone(),
114            workspace_id: config.workspace_id,
115            repo_ref: config.repo_ref.clone(),
116            parent_session_id: Some(config.parent_session_id),
117            workspace_name: config.workspace_name.clone(),
118            tool_config,
119            system_prompt: config
120                .system_context
121                .as_ref()
122                .map(|context| context.prompt.clone()),
123            primary_agent_id: None,
124            policy_overrides,
125            metadata: HashMap::new(),
126            default_model: config.model.clone(),
127        };
128
129        let tool_executor = self.build_tool_executor(workspace);
130        let runtime = RuntimeService::spawn(
131            self.event_store.clone(),
132            self.api_client.clone(),
133            tool_executor,
134        );
135
136        let run_result = OneShotRunner::run_new_session_with_cancel(
137            &runtime.handle,
138            session_config,
139            config.prompt,
140            config.model.clone(),
141            cancel_token,
142        )
143        .await;
144
145        runtime.shutdown().await;
146
147        let run_result = run_result.map_err(|err| match err {
148            Error::Cancelled => SubAgentError::Cancelled,
149            Error::Api(error) => SubAgentError::Api(error.to_string()),
150            other => SubAgentError::Agent(other.to_string()),
151        })?;
152
153        Ok(SubAgentResult {
154            session_id: run_result.session_id,
155            final_message: run_result.final_message,
156        })
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::DefaultAgentSpawner;
    use crate::api::Client as ApiClient;
    use crate::api::{ApiError, CompletionResponse, Provider};
    use crate::app::conversation::AssistantContent;
    use crate::app::domain::event::SessionEvent;
    use crate::app::domain::session::EventStore;
    use crate::app::domain::session::event_store::InMemoryEventStore;
    use crate::auth::ProviderRegistry;
    use crate::config::model::builtin;
    use crate::model_registry::ModelRegistry;
    use crate::session::state::ToolVisibility;
    use crate::test_utils::test_llm_config_provider;
    use crate::tools::services::AgentSpawner;
    use crate::tools::services::SubAgentConfig;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use crate::workspace::WorkspaceConfig;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::Mutex as StdMutex;
    use steer_tools::tools::edit::multi_edit::MULTI_EDIT_TOOL_NAME;
    use steer_tools::tools::replace::REPLACE_TOOL_NAME;
    use steer_tools::tools::{
        BASH_TOOL_NAME, EDIT_TOOL_NAME, GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME,
        VIEW_TOOL_NAME,
    };
    use tempfile::TempDir;
    use tokio_util::sync::CancellationToken;

    /// Test provider that returns a fixed text reply and records the rendered system prompt
    /// of the most recent completion request.
    #[derive(Clone)]
    struct RecordingProvider {
        response: String,
        last_system: Arc<StdMutex<Option<String>>>,
    }

    impl RecordingProvider {
        fn new(response: impl Into<String>, last_system: Arc<StdMutex<Option<String>>>) -> Self {
            let response = response.into();
            Self {
                response,
                last_system,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for RecordingProvider {
        fn name(&self) -> &'static str {
            "recording"
        }

        async fn complete(
            &self,
            _model_id: &crate::config::model::ModelId,
            _messages: Vec<crate::app::conversation::Message>,
            system: Option<crate::app::SystemContext>,
            _tools: Option<Vec<steer_tools::ToolSchema>>,
            _call_options: Option<crate::config::model::ModelParameters>,
            _token: CancellationToken,
        ) -> Result<CompletionResponse, ApiError> {
            let rendered = system.and_then(|context| context.render());
            let mut slot = self
                .last_system
                .lock()
                .expect("system prompt lock poisoned");
            *slot = rendered;
            drop(slot);

            let reply = AssistantContent::Text {
                text: self.response.clone(),
            };
            Ok(CompletionResponse {
                content: vec![reply],
            })
        }
    }

    /// Shared test environment: a local workspace backed by a temporary directory, an
    /// in-memory event store, empty model/provider registries, and an API client built on them.
    struct Fixture {
        /// Kept so the temporary directory outlives the workspace rooted in it.
        _temp_dir: TempDir,
        workspace: Arc<dyn crate::workspace::Workspace>,
        event_store: Arc<InMemoryEventStore>,
        model_registry: Arc<ModelRegistry>,
        api_client: Arc<ApiClient>,
    }

    impl Fixture {
        async fn new() -> Self {
            let temp_dir = TempDir::new().expect("create temp dir");
            let workspace = crate::workspace::create_workspace(&WorkspaceConfig::Local {
                path: temp_dir.path().to_path_buf(),
            })
            .await
            .expect("create workspace");
            let event_store = Arc::new(InMemoryEventStore::new());
            let model_registry = Arc::new(ModelRegistry::load(&[]).expect("model registry"));
            let provider_registry =
                Arc::new(ProviderRegistry::load(&[]).expect("provider registry"));
            let api_client = Arc::new(ApiClient::new_with_deps(
                test_llm_config_provider().unwrap(),
                provider_registry,
                model_registry.clone(),
            ));

            Self {
                _temp_dir: temp_dir,
                workspace,
                event_store,
                model_registry,
                api_client,
            }
        }
    }

    #[tokio::test]
    async fn sub_agent_tool_executor_includes_static_tools() {
        let Fixture {
            _temp_dir,
            workspace,
            event_store,
            model_registry,
            api_client,
        } = Fixture::new().await;

        let spawner = DefaultAgentSpawner::new(
            event_store,
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        let executor = spawner.build_tool_executor(workspace);
        let expected_static_tools = [
            GLOB_TOOL_NAME,
            GREP_TOOL_NAME,
            LS_TOOL_NAME,
            VIEW_TOOL_NAME,
            EDIT_TOOL_NAME,
            MULTI_EDIT_TOOL_NAME,
            REPLACE_TOOL_NAME,
            BASH_TOOL_NAME,
        ];
        for tool_name in expected_static_tools {
            assert!(
                executor.is_static_tool(tool_name),
                "expected sub-agent to have static tool: {tool_name}"
            );
        }
    }

    #[tokio::test]
    async fn sub_agent_persists_events_and_uses_whitelist_visibility() {
        let Fixture {
            _temp_dir,
            workspace,
            event_store,
            model_registry,
            api_client,
        } = Fixture::new().await;

        // Route the chosen model's provider to a stub that records the system prompt it sees.
        let system_capture = Arc::new(StdMutex::new(None));
        let model = builtin::claude_sonnet_4_5();
        api_client.insert_test_provider(
            model.provider.clone(),
            Arc::new(RecordingProvider::new("ok", system_capture.clone())),
        );

        let spawner = DefaultAgentSpawner::new(
            event_store.clone(),
            api_client,
            workspace.clone(),
            model_registry,
            None,
            None,
        );

        let parent_session_id = crate::app::domain::types::SessionId::new();
        let allowed_tools = vec![
            VIEW_TOOL_NAME.to_string(),
            "mcp__alpha__allowed".to_string(),
        ];
        let system_prompt = "subagent system".to_string();

        let config = SubAgentConfig {
            parent_session_id,
            prompt: "hello".to_string(),
            allowed_tools: allowed_tools.clone(),
            model: model.clone(),
            system_context: Some(crate::app::SystemContext::new(system_prompt.clone())),
            workspace: Some(workspace),
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            workspace_name: None,
            mcp_backends: Vec::new(),
            allow_mcp_tools: true,
        };

        let result = spawner
            .spawn(config, CancellationToken::new())
            .await
            .expect("spawn sub-agent");

        let events = event_store
            .load_events(result.session_id)
            .await
            .expect("load events");

        let mut saw_session_created = false;
        let mut saw_assistant_message = false;
        let mut seen_visibility = None;
        let mut seen_preapproved = None;

        for (_, event) in events {
            match event {
                SessionEvent::SessionCreated { config, .. } => {
                    saw_session_created = true;
                    assert_eq!(config.parent_session_id, Some(parent_session_id));

                    let stored_prompt = config
                        .system_prompt
                        .as_deref()
                        .expect("expected system prompt in session config");
                    assert!(
                        stored_prompt.starts_with(system_prompt.as_str()),
                        "expected system prompt prefix, got: {stored_prompt:?}"
                    );

                    let ToolVisibility::Whitelist(whitelisted) = &config.tool_config.visibility
                    else {
                        let other = &config.tool_config.visibility;
                        panic!("expected whitelist visibility, got {other:?}");
                    };
                    seen_visibility = Some(whitelisted.clone());
                    seen_preapproved =
                        Some(config.tool_config.approval_policy.preapproved.tools.clone());
                }
                SessionEvent::AssistantMessageAdded { .. } => saw_assistant_message = true,
                _ => {}
            }
        }

        assert!(saw_session_created, "expected SessionCreated event");
        assert!(
            saw_assistant_message,
            "expected AssistantMessageAdded event"
        );

        let expected_visibility: HashSet<String> = allowed_tools.into_iter().collect();
        let mut expected_preapproved: HashSet<String> = READ_ONLY_TOOL_NAMES
            .iter()
            .map(|name| (*name).to_string())
            .collect();
        expected_preapproved.extend(expected_visibility.iter().cloned());
        assert_eq!(seen_visibility, Some(expected_visibility));
        assert_eq!(seen_preapproved, Some(expected_preapproved));

        let recorded_prompt = system_capture
            .lock()
            .expect("system capture lock poisoned")
            .clone()
            .expect("expected captured system prompt");
        assert!(
            recorded_prompt.starts_with(system_prompt.as_str()),
            "expected system prompt prefix, got: {recorded_prompt:?}"
        );
    }
}