Skip to main content

steer_core/tools/static_tools/
dispatch_agent.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::agents::{
6    McpAccessPolicy, agent_spec, agent_specs, agent_specs_prompt, default_agent_spec_id,
7};
8use crate::app::domain::event::SessionEvent;
9use crate::app::domain::runtime::RuntimeService;
10use crate::app::domain::types::SessionId;
11use crate::app::validation::ValidatorRegistry;
12use crate::config::model::builtin::claude_sonnet_4_5 as default_model;
13use crate::runners::OneShotRunner;
14use crate::session::state::BackendConfig;
15use crate::tools::capability::Capabilities;
16use crate::tools::services::{SubAgentConfig, SubAgentError, ToolServices};
17use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
18use crate::tools::{BackendRegistry, ToolExecutor, ToolRegistry};
19use crate::workspace::{
20    CreateWorkspaceRequest, EnvironmentId, RepoRef, VcsKind, VcsStatus, Workspace,
21    WorkspaceCreateStrategy, WorkspaceRef, create_workspace_from_session_config,
22};
23use steer_tools::ToolSpec;
24use steer_tools::result::{AgentResult, AgentWorkspaceInfo, AgentWorkspaceRevision};
25use steer_tools::tools::dispatch_agent::{
26    DispatchAgentError, DispatchAgentParams, DispatchAgentTarget, DispatchAgentToolSpec,
27    WorkspaceTarget,
28};
29use steer_tools::tools::{GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
30use tracing::warn;
31
32use super::{
33    AstGrepTool, BashTool, EditTool, FetchTool, GlobTool, GrepTool, LsTool, MultiEditTool,
34    ReplaceTool, TodoReadTool, TodoWriteTool, ViewTool, workspace_manager_op_error,
35    workspace_op_error,
36};
37
38fn dispatch_agent_description() -> String {
39    let agent_specs = agent_specs_prompt();
40    let agent_specs_block = if agent_specs.is_empty() {
41        "No agent specs registered.".to_string()
42    } else {
43        agent_specs
44    };
45
46    format!(
47        r#"Launch a new agent to help with a focused task. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
48
49    When to use the Agent tool:
50    - If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
51
52    When NOT to use the Agent tool:
53    - If you want to read a specific file path, use the {} or {} tool instead of the Agent tool, to find the match more quickly
54    - If you are searching for a specific class definition like "class Foo", use the {} tool instead, to find the match more quickly
55    - If you are searching for code within a specific file or set of 2-3 files, use the {} tool instead, to find the match more quickly
56
57    Usage notes:
58    1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
59    2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
60    3. Each invocation returns a session_id. Pass it back via `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue the conversation with the same agent.
61    4. When `target.session` is `resume`, the session_id must refer to a child of the current session. The `agent` and `workspace` options are ignored and the existing session config is used.
62    5. The agent's outputs should generally be trusted
63    6. IMPORTANT: Only some agent specs include write tools. Use a build agent if the task requires editing files.
64    7. IMPORTANT: New workspaces are preserved (not auto-deleted). Clean them up manually if needed.
65    8. If the agent spec omits a model, the parent session's default model is used.
66    9. IMPORTANT: Do NOT include `Repo: <path>`, `CWD: <path>`, or similar path headers in your prompt. The sub-agent already receives its working directory via system instructions.
67    10. IMPORTANT: If `target.session` is `new` and `workspace.location` is `new`, the sub-agent runs in the newly created workspace path, which may differ from the caller's current directory.
68
69Workspace options:
70- `workspace: {{ "location": "current" }}` to run in the current workspace
71- `workspace: {{ "location": "new", "name": "..." }}` to run in a fresh workspace (jj workspace or git worktree)
72- `location` is a logical workspace selector, not a filesystem path
73
74    Session options:
75    - `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue a prior dispatch_agent session
76
77    New session options:
78    - `target: {{ "session": "new", "workspace": {{ "location": "current" }} }}` to run in the current workspace
79    - `target: {{ "session": "new", "workspace": {{ "location": "new", "name": "..." }} }}` to run in a new workspace
80    - `target: {{ "session": "new", "workspace": {{ "location": "current" }}, "agent": "<id>" }}` selects an agent spec (defaults to "{default_agent}")
81
82    {agent_specs_block}"#,
83        VIEW_TOOL_NAME,
84        LS_TOOL_NAME,
85        GREP_TOOL_NAME,
86        GREP_TOOL_NAME,
87        default_agent = default_agent_spec_id(),
88        agent_specs_block = agent_specs_block
89    )
90}
91
92pub struct DispatchAgentTool;
93
94#[async_trait]
95impl StaticTool for DispatchAgentTool {
96    type Params = DispatchAgentParams;
97    type Output = AgentResult;
98    type Spec = DispatchAgentToolSpec;
99
100    const DESCRIPTION: &'static str = "Launch a sub-agent to search for files or code";
101    const REQUIRES_APPROVAL: bool = false;
102    const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;
103
104    fn schema() -> steer_tools::ToolSchema {
105        let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
106            s.inline_subschemas = true;
107        });
108        let schema_gen = settings.into_generator();
109        let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
110
111        steer_tools::ToolSchema {
112            name: Self::Spec::NAME.to_string(),
113            display_name: Self::Spec::DISPLAY_NAME.to_string(),
114            description: dispatch_agent_description(),
115            input_schema: input_schema.into(),
116        }
117    }
118
119    async fn execute(
120        &self,
121        params: Self::Params,
122        ctx: &StaticToolContext,
123    ) -> Result<Self::Output, StaticToolError<DispatchAgentError>> {
124        let DispatchAgentParams { prompt, target } = params;
125
126        let (workspace_target, agent) = match target {
127            DispatchAgentTarget::Resume { session_id } => {
128                let session_id = SessionId::parse(&session_id).ok_or_else(|| {
129                    StaticToolError::invalid_params(format!("Invalid session_id '{session_id}'"))
130                })?;
131                return resume_agent_session(session_id, prompt, ctx).await;
132            }
133            DispatchAgentTarget::New { workspace, agent } => (workspace, agent),
134        };
135
136        let spawner = ctx
137            .services
138            .agent_spawner()
139            .ok_or_else(|| StaticToolError::missing_capability("agent_spawner"))?;
140
141        let base_workspace = ctx.services.workspace.clone();
142        let base_path = base_workspace.working_directory().to_path_buf();
143
144        let mut workspace = base_workspace.clone();
145        let mut workspace_ref = None;
146        let mut workspace_id = None;
147        let mut workspace_name = None;
148        let mut repo_id = None;
149        let mut repo_ref = None;
150
151        if let Some(manager) = ctx.services.workspace_manager()
152            && let Ok(info) = manager.resolve_workspace(&base_path).await
153        {
154            workspace_id = Some(info.workspace_id);
155            workspace_name.clone_from(&info.name);
156            repo_id = Some(info.repo_id);
157            workspace_ref = Some(WorkspaceRef {
158                environment_id: info.environment_id,
159                workspace_id: info.workspace_id,
160                repo_id: info.repo_id,
161            });
162        }
163
164        if let Some(manager) = ctx.services.repo_manager() {
165            let repo_env_id = workspace_ref
166                .as_ref()
167                .map_or_else(EnvironmentId::local, |reference| reference.environment_id);
168            if let Ok(info) = manager.resolve_repo(repo_env_id, &base_path).await {
169                if repo_id.is_none() {
170                    repo_id = Some(info.repo_id);
171                }
172                repo_ref = Some(RepoRef {
173                    environment_id: info.environment_id,
174                    repo_id: info.repo_id,
175                    root_path: info.root_path,
176                    vcs_kind: info.vcs_kind,
177                });
178            }
179        }
180
181        let mut new_workspace = false;
182        let mut requested_workspace_name = None;
183
184        match &workspace_target {
185            WorkspaceTarget::Current => {}
186            WorkspaceTarget::New { name } => {
187                new_workspace = true;
188                requested_workspace_name = Some(name.clone());
189            }
190        }
191
192        let mut created_workspace_id = None;
193        let mut status_manager = None;
194
195        if new_workspace {
196            let manager = ctx
197                .services
198                .workspace_manager()
199                .ok_or_else(|| StaticToolError::missing_capability("workspace_manager"))?;
200            status_manager = Some(manager.clone());
201
202            let base_repo_id = repo_id.ok_or_else(|| {
203                StaticToolError::execution(DispatchAgentError::WorkspaceUnavailable {
204                    message:
205                        "Current path is not a supported workspace; cannot create new workspace"
206                            .to_string(),
207                })
208            })?;
209
210            let strategy = match repo_ref
211                .as_ref()
212                .and_then(|reference| reference.vcs_kind.as_ref())
213            {
214                Some(VcsKind::Git) => WorkspaceCreateStrategy::GitWorktree,
215                _ => WorkspaceCreateStrategy::JjWorkspace,
216            };
217
218            let create_request = CreateWorkspaceRequest {
219                repo_id: base_repo_id,
220                name: requested_workspace_name.clone(),
221                parent_workspace_id: workspace_id,
222                strategy,
223            };
224
225            let info = manager
226                .create_workspace(create_request)
227                .await
228                .map_err(|e| {
229                    StaticToolError::execution(DispatchAgentError::Workspace(
230                        workspace_manager_op_error(e),
231                    ))
232                })?;
233
234            workspace = manager
235                .open_workspace(info.workspace_id)
236                .await
237                .map_err(|e| {
238                    StaticToolError::execution(DispatchAgentError::Workspace(
239                        workspace_manager_op_error(e),
240                    ))
241                })?;
242
243            workspace_id = Some(info.workspace_id);
244            created_workspace_id = Some(info.workspace_id);
245            workspace_name.clone_from(&info.name);
246            workspace_ref = Some(WorkspaceRef {
247                environment_id: info.environment_id,
248                workspace_id: info.workspace_id,
249                repo_id: info.repo_id,
250            });
251
252            if let Some(repo_manager) = ctx.services.repo_manager()
253                && let Ok(info) = repo_manager
254                    .resolve_repo(info.environment_id, workspace.working_directory())
255                    .await
256            {
257                repo_ref = Some(RepoRef {
258                    environment_id: info.environment_id,
259                    repo_id: info.repo_id,
260                    root_path: info.root_path,
261                    vcs_kind: info.vcs_kind,
262                });
263            }
264        }
265
266        let env_info = workspace.environment().await.map_err(|e| {
267            StaticToolError::execution(DispatchAgentError::Workspace(workspace_op_error(e)))
268        })?;
269
270        let system_prompt = format!(
271            r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
272
273Notes:
2741. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2752. When relevant, share file names and code snippets relevant to the query
2763. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
277
278{}
279"#,
280            env_info.as_context()
281        );
282
283        let agent_id = agent
284            .as_deref()
285            .filter(|value| !value.trim().is_empty())
286            .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
287
288        let agent_spec = agent_spec(&agent_id).ok_or_else(|| {
289            let available = agent_specs()
290                .into_iter()
291                .map(|spec| spec.id)
292                .collect::<Vec<_>>()
293                .join(", ");
294            StaticToolError::invalid_params(format!(
295                "Unknown agent spec '{agent_id}'. Available: {available}"
296            ))
297        })?;
298
299        let parent_session_config = match ctx.services.event_store.load_events(ctx.session_id).await
300        {
301            Ok(events) => events.into_iter().find_map(|(_, event)| match event {
302                SessionEvent::SessionCreated { config, .. } => Some(*config),
303                _ => None,
304            }),
305            Err(err) => {
306                warn!(
307                    session_id = %ctx.session_id,
308                    "Failed to load parent session config for MCP servers: {err}"
309                );
310                None
311            }
312        };
313
314        let parent_mcp_backends = parent_session_config
315            .as_ref()
316            .map(|config| config.tool_config.backends.clone())
317            .unwrap_or_default();
318
319        let parent_model = parent_session_config
320            .as_ref()
321            .map_or_else(default_model, |config| config.default_model.clone());
322
323        let allow_mcp_tools = agent_spec.mcp_access.allow_mcp_tools();
324        let mcp_backends = match &agent_spec.mcp_access {
325            McpAccessPolicy::None => Vec::new(),
326            McpAccessPolicy::All => parent_mcp_backends,
327            McpAccessPolicy::Allowlist(servers) => parent_mcp_backends
328                .into_iter()
329                .filter(|backend| match backend {
330                    BackendConfig::Mcp { server_name, .. } => {
331                        servers.iter().any(|allowed| allowed == server_name)
332                    }
333                })
334                .collect(),
335        };
336
337        let config = SubAgentConfig {
338            parent_session_id: ctx.session_id,
339            prompt,
340            allowed_tools: agent_spec.tools.clone(),
341            model: agent_spec.model.clone().unwrap_or(parent_model),
342            system_context: Some(crate::app::SystemContext::new(system_prompt)),
343            workspace: Some(workspace),
344            workspace_ref,
345            workspace_id,
346            repo_ref,
347            workspace_name,
348            mcp_backends,
349            allow_mcp_tools,
350        };
351
352        let spawn_result = spawner.spawn(config, ctx.cancellation_token.clone()).await;
353
354        let mut workspace_info = None;
355
356        if let (Some(manager), Some(workspace_id)) = (status_manager, created_workspace_id) {
357            let revision = match manager.get_workspace_status(workspace_id).await {
358                Ok(status) => match status.vcs {
359                    Some(vcs) => match vcs.status {
360                        VcsStatus::Jj(jj_status) => {
361                            jj_status.working_copy.map(|wc| AgentWorkspaceRevision {
362                                vcs_kind: "jj".to_string(),
363                                revision_id: wc.commit_id,
364                                summary: wc.description,
365                                change_id: Some(wc.change_id),
366                            })
367                        }
368                        VcsStatus::Git(_) => None,
369                    },
370                    None => None,
371                },
372                Err(err) => {
373                    warn!(
374                        workspace_id = %workspace_id.as_uuid(),
375                        "Failed to get workspace status for sub-agent: {err}"
376                    );
377                    None
378                }
379            };
380
381            workspace_info = Some(AgentWorkspaceInfo {
382                workspace_id: Some(workspace_id.as_uuid().to_string()),
383                revision,
384            });
385        }
386
387        let result = spawn_result.map_err(|e| match e {
388            SubAgentError::Cancelled => StaticToolError::Cancelled,
389            other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
390                message: other.to_string(),
391            }),
392        })?;
393
394        Ok(AgentResult {
395            content: result.final_message.extract_text(),
396            session_id: Some(result.session_id.to_string()),
397            workspace: workspace_info,
398        })
399    }
400}
401
402fn build_runtime_tool_executor(
403    workspace: Arc<dyn Workspace>,
404    parent_services: &Arc<ToolServices>,
405) -> Arc<ToolExecutor> {
406    let mut services = ToolServices::new(
407        workspace.clone(),
408        parent_services.event_store.clone(),
409        parent_services.api_client.clone(),
410    );
411
412    if let Some(spawner) = parent_services.agent_spawner() {
413        services = services.with_agent_spawner(spawner.clone());
414    }
415    if let Some(caller) = parent_services.model_caller() {
416        services = services.with_model_caller(caller.clone());
417    }
418    if let Some(manager) = parent_services.workspace_manager() {
419        services = services.with_workspace_manager(manager.clone());
420    }
421    if let Some(manager) = parent_services.repo_manager() {
422        services = services.with_repo_manager(manager.clone());
423    }
424    if parent_services
425        .capabilities()
426        .contains(Capabilities::NETWORK)
427    {
428        services = services.with_network();
429    }
430
431    let mut registry = ToolRegistry::new();
432    registry.register_static(GrepTool);
433    registry.register_static(GlobTool);
434    registry.register_static(LsTool);
435    registry.register_static(ViewTool);
436    registry.register_static(BashTool);
437    registry.register_static(EditTool);
438    registry.register_static(MultiEditTool);
439    registry.register_static(ReplaceTool);
440    registry.register_static(AstGrepTool);
441    registry.register_static(TodoReadTool);
442    registry.register_static(TodoWriteTool);
443    registry.register_static(DispatchAgentTool);
444    registry.register_static(FetchTool);
445
446    Arc::new(
447        ToolExecutor::with_components(
448            Arc::new(BackendRegistry::new()),
449            Arc::new(ValidatorRegistry::new()),
450        )
451        .with_static_tools(Arc::new(registry), Arc::new(services)),
452    )
453}
454
455async fn resume_agent_session(
456    session_id: SessionId,
457    prompt: String,
458    ctx: &StaticToolContext,
459) -> Result<AgentResult, StaticToolError<DispatchAgentError>> {
460    let events = ctx
461        .services
462        .event_store
463        .load_events(session_id)
464        .await
465        .map_err(|e| {
466            StaticToolError::execution(DispatchAgentError::SpawnFailed {
467                message: format!("Failed to load session {session_id}: {e}"),
468            })
469        })?;
470
471    let session_config = events
472        .into_iter()
473        .find_map(|(_, event)| match event {
474            SessionEvent::SessionCreated { config, .. } => Some(*config),
475            _ => None,
476        })
477        .ok_or_else(|| {
478            StaticToolError::execution(DispatchAgentError::SpawnFailed {
479                message: format!("Session {session_id} is missing a SessionCreated event"),
480            })
481        })?;
482
483    if session_config.parent_session_id != Some(ctx.session_id) {
484        return Err(StaticToolError::invalid_params(format!(
485            "Session {session_id} is not a child of current session {}",
486            ctx.session_id
487        )));
488    }
489
490    let workspace = create_workspace_from_session_config(&session_config.workspace)
491        .await
492        .map_err(|e| {
493            StaticToolError::execution(DispatchAgentError::SpawnFailed {
494                message: format!("Failed to open workspace for session {session_id}: {e}"),
495            })
496        })?;
497
498    let tool_executor = build_runtime_tool_executor(workspace, &ctx.services);
499    let runtime = RuntimeService::spawn(
500        ctx.services.event_store.clone(),
501        ctx.services.api_client.clone(),
502        tool_executor,
503    );
504
505    let run_result = OneShotRunner::run_in_session_with_cancel(
506        &runtime.handle,
507        session_id,
508        prompt,
509        session_config.default_model.clone(),
510        ctx.cancellation_token.clone(),
511    )
512    .await;
513
514    runtime.shutdown().await;
515
516    let run_result = run_result.map_err(|e| match e {
517        crate::error::Error::Cancelled => StaticToolError::Cancelled,
518        other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
519            message: other.to_string(),
520        }),
521    })?;
522
523    Ok(AgentResult {
524        content: run_result.final_message.extract_text(),
525        session_id: Some(run_result.session_id.to_string()),
526        workspace: None,
527    })
528}
529
530#[cfg(test)]
531mod tests {
532    use super::*;
533    use crate::agents::{AgentSpec, AgentSpecError, McpAccessPolicy, register_agent_spec};
534    use crate::api::Client as ApiClient;
535    use crate::api::{ApiError, CompletionResponse, Provider};
536    use crate::app::conversation::{AssistantContent, Message, MessageData};
537    use crate::app::domain::session::EventStore;
538    use crate::app::domain::session::event_store::InMemoryEventStore;
539    use crate::app::domain::types::ToolCallId;
540    use crate::config::model::builtin;
541    use crate::model_registry::ModelRegistry;
542    use crate::session::state::{
543        ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicyOverrides,
544        ToolFilter, ToolVisibility, UnapprovedBehavior,
545    };
546    use crate::tools::McpTransport;
547    use crate::tools::services::{AgentSpawner, SubAgentError, SubAgentResult, ToolServices};
548    use async_trait::async_trait;
549    use std::collections::{HashMap, HashSet};
550    use std::sync::Mutex as StdMutex;
551    use tokio::time::{Duration, sleep};
552    use tokio_util::sync::CancellationToken;
553    use uuid::Uuid;
554
555    #[derive(Clone)]
556    struct StubProvider {
557        response: String,
558    }
559
560    impl StubProvider {
561        fn new(response: impl Into<String>) -> Self {
562            Self {
563                response: response.into(),
564            }
565        }
566    }
567
568    #[derive(Clone)]
569    struct CancelAwareProvider;
570
571    #[async_trait]
572    impl Provider for CancelAwareProvider {
573        fn name(&self) -> &'static str {
574            "cancel-aware"
575        }
576
577        async fn complete(
578            &self,
579            _model_id: &crate::config::model::ModelId,
580            _messages: Vec<Message>,
581            _system: Option<crate::app::SystemContext>,
582            _tools: Option<Vec<steer_tools::ToolSchema>>,
583            _call_options: Option<crate::config::model::ModelParameters>,
584            token: CancellationToken,
585        ) -> Result<CompletionResponse, ApiError> {
586            token.cancelled().await;
587            Err(ApiError::Cancelled {
588                provider: self.name().to_string(),
589            })
590        }
591    }
592
593    #[async_trait]
594    impl Provider for StubProvider {
595        fn name(&self) -> &'static str {
596            "stub"
597        }
598
599        async fn complete(
600            &self,
601            _model_id: &crate::config::model::ModelId,
602            _messages: Vec<Message>,
603            _system: Option<crate::app::SystemContext>,
604            _tools: Option<Vec<steer_tools::ToolSchema>>,
605            _call_options: Option<crate::config::model::ModelParameters>,
606            _token: CancellationToken,
607        ) -> Result<CompletionResponse, ApiError> {
608            Ok(CompletionResponse {
609                content: vec![AssistantContent::Text {
610                    text: self.response.clone(),
611                }],
612            })
613        }
614    }
615
616    #[derive(Clone)]
617    struct StubAgentSpawner {
618        session_id: SessionId,
619        response: String,
620    }
621
622    #[async_trait]
623    impl AgentSpawner for StubAgentSpawner {
624        async fn spawn(
625            &self,
626            _config: crate::tools::services::SubAgentConfig,
627            _cancel_token: CancellationToken,
628        ) -> Result<SubAgentResult, SubAgentError> {
629            let timestamp = Message::current_timestamp();
630            let message = Message {
631                timestamp,
632                id: Message::generate_id("assistant", timestamp),
633                parent_message_id: None,
634                data: MessageData::Assistant {
635                    content: vec![AssistantContent::Text {
636                        text: self.response.clone(),
637                    }],
638                },
639            };
640
641            Ok(SubAgentResult {
642                session_id: self.session_id,
643                final_message: message,
644            })
645        }
646    }
647
648    #[derive(Clone)]
649    struct CapturingAgentSpawner {
650        session_id: SessionId,
651        response: String,
652        captured: Arc<tokio::sync::Mutex<Option<crate::tools::services::SubAgentConfig>>>,
653    }
654
655    #[async_trait]
656    impl AgentSpawner for CapturingAgentSpawner {
657        async fn spawn(
658            &self,
659            config: crate::tools::services::SubAgentConfig,
660            _cancel_token: CancellationToken,
661        ) -> Result<SubAgentResult, SubAgentError> {
662            let mut guard = self.captured.lock().await;
663            *guard = Some(config);
664
665            let timestamp = Message::current_timestamp();
666            let message = Message {
667                timestamp,
668                id: Message::generate_id("assistant", timestamp),
669                parent_message_id: None,
670                data: MessageData::Assistant {
671                    content: vec![AssistantContent::Text {
672                        text: self.response.clone(),
673                    }],
674                },
675            };
676
677            Ok(SubAgentResult {
678                session_id: self.session_id,
679                final_message: message,
680            })
681        }
682    }
683
684    #[derive(Clone)]
685    struct ToolCallThenTextProvider {
686        tool_call: steer_tools::ToolCall,
687        final_text: String,
688        call_count: Arc<StdMutex<usize>>,
689    }
690
691    impl ToolCallThenTextProvider {
692        fn new(tool_call: steer_tools::ToolCall, final_text: impl Into<String>) -> Self {
693            Self {
694                tool_call,
695                final_text: final_text.into(),
696                call_count: Arc::new(StdMutex::new(0)),
697            }
698        }
699    }
700
701    #[async_trait]
702    impl Provider for ToolCallThenTextProvider {
703        fn name(&self) -> &'static str {
704            "stub-tool-call"
705        }
706
707        async fn complete(
708            &self,
709            _model_id: &crate::config::model::ModelId,
710            _messages: Vec<Message>,
711            _system: Option<crate::app::SystemContext>,
712            _tools: Option<Vec<steer_tools::ToolSchema>>,
713            _call_options: Option<crate::config::model::ModelParameters>,
714            _token: CancellationToken,
715        ) -> Result<CompletionResponse, ApiError> {
716            let mut count = self
717                .call_count
718                .lock()
719                .expect("tool call counter lock poisoned");
720            let response = if *count == 0 {
721                CompletionResponse {
722                    content: vec![AssistantContent::ToolCall {
723                        tool_call: self.tool_call.clone(),
724                        thought_signature: None,
725                    }],
726                }
727            } else {
728                CompletionResponse {
729                    content: vec![AssistantContent::Text {
730                        text: self.final_text.clone(),
731                    }],
732                }
733            };
734            *count += 1;
735            Ok(response)
736        }
737    }
738
739    #[tokio::test]
740    async fn resume_session_rejects_non_child() {
741        let event_store = Arc::new(InMemoryEventStore::new());
742        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
743        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
744        let api_client = Arc::new(ApiClient::new_with_deps(
745            crate::test_utils::test_llm_config_provider().unwrap(),
746            provider_registry,
747            model_registry,
748        ));
749        let workspace =
750            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
751                path: std::env::current_dir().unwrap(),
752            })
753            .await
754            .unwrap();
755
756        let parent_session_id = SessionId::new();
757        let session_id = SessionId::new();
758        let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
759        session_config.parent_session_id = Some(parent_session_id);
760
761        event_store.create_session(session_id).await.unwrap();
762        event_store
763            .append(
764                session_id,
765                &SessionEvent::SessionCreated {
766                    config: Box::new(session_config),
767                    metadata: std::collections::HashMap::new(),
768                    parent_session_id: Some(parent_session_id),
769                },
770            )
771            .await
772            .unwrap();
773
774        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
775
776        let ctx = StaticToolContext {
777            tool_call_id: ToolCallId::new(),
778            session_id: SessionId::new(),
779            cancellation_token: CancellationToken::new(),
780            services,
781        };
782
783        let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
784
785        assert!(matches!(result, Err(StaticToolError::InvalidParams(_))));
786    }
787
788    #[tokio::test]
789    async fn resume_session_accepts_child_and_returns_message() {
790        let event_store = Arc::new(InMemoryEventStore::new());
791        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
792        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
793        let api_client = Arc::new(ApiClient::new_with_deps(
794            crate::test_utils::test_llm_config_provider().unwrap(),
795            provider_registry,
796            model_registry,
797        ));
798        let model = builtin::claude_sonnet_4_5();
799        api_client.insert_test_provider(
800            model.provider.clone(),
801            Arc::new(StubProvider::new("stub-response")),
802        );
803        let workspace =
804            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
805                path: std::env::current_dir().unwrap(),
806            })
807            .await
808            .unwrap();
809
810        let parent_session_id = SessionId::new();
811        let session_id = SessionId::new();
812        let mut session_config = SessionConfig::read_only(model.clone());
813        session_config.parent_session_id = Some(parent_session_id);
814
815        event_store.create_session(session_id).await.unwrap();
816        event_store
817            .append(
818                session_id,
819                &SessionEvent::SessionCreated {
820                    config: Box::new(session_config),
821                    metadata: std::collections::HashMap::new(),
822                    parent_session_id: Some(parent_session_id),
823                },
824            )
825            .await
826            .unwrap();
827
828        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
829
830        let ctx = StaticToolContext {
831            tool_call_id: ToolCallId::new(),
832            session_id: parent_session_id,
833            cancellation_token: CancellationToken::new(),
834            services,
835        };
836
837        let result = resume_agent_session(session_id, "ping".to_string(), &ctx)
838            .await
839            .unwrap();
840
841        assert!(result.content.contains("stub-response"));
842        assert_eq!(
843            result.session_id.as_deref(),
844            Some(session_id.to_string().as_str())
845        );
846    }
847
848    #[tokio::test]
849    async fn resume_session_honors_cancellation() {
850        let event_store = Arc::new(InMemoryEventStore::new());
851        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
852        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
853        let api_client = Arc::new(ApiClient::new_with_deps(
854            crate::test_utils::test_llm_config_provider().unwrap(),
855            provider_registry,
856            model_registry,
857        ));
858        let model = builtin::claude_sonnet_4_5();
859        api_client.insert_test_provider(model.provider.clone(), Arc::new(CancelAwareProvider));
860        let workspace =
861            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
862                path: std::env::current_dir().unwrap(),
863            })
864            .await
865            .unwrap();
866
867        let parent_session_id = SessionId::new();
868        let session_id = SessionId::new();
869        let mut session_config = SessionConfig::read_only(model);
870        session_config.parent_session_id = Some(parent_session_id);
871
872        event_store.create_session(session_id).await.unwrap();
873        event_store
874            .append(
875                session_id,
876                &SessionEvent::SessionCreated {
877                    config: Box::new(session_config),
878                    metadata: std::collections::HashMap::new(),
879                    parent_session_id: Some(parent_session_id),
880                },
881            )
882            .await
883            .unwrap();
884
885        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
886
887        let cancel_token = CancellationToken::new();
888        let ctx = StaticToolContext {
889            tool_call_id: ToolCallId::new(),
890            session_id: parent_session_id,
891            cancellation_token: cancel_token.clone(),
892            services,
893        };
894
895        let cancel_task = tokio::spawn(async move {
896            sleep(Duration::from_millis(10)).await;
897            cancel_token.cancel();
898        });
899
900        let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
901        let _ = cancel_task.await;
902
903        assert!(matches!(result, Err(StaticToolError::Cancelled)));
904    }
905
906    #[tokio::test]
907    async fn dispatch_agent_returns_session_id() {
908        let event_store = Arc::new(InMemoryEventStore::new());
909        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
910        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
911        let api_client = Arc::new(ApiClient::new_with_deps(
912            crate::test_utils::test_llm_config_provider().unwrap(),
913            provider_registry,
914            model_registry,
915        ));
916        let workspace =
917            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
918                path: std::env::current_dir().unwrap(),
919            })
920            .await
921            .unwrap();
922
923        let session_id = SessionId::new();
924        let spawner = StubAgentSpawner {
925            session_id,
926            response: "done".to_string(),
927        };
928
929        let services = Arc::new(
930            ToolServices::new(workspace, event_store, api_client)
931                .with_agent_spawner(Arc::new(spawner)),
932        );
933
934        let ctx = StaticToolContext {
935            tool_call_id: ToolCallId::new(),
936            session_id: SessionId::new(),
937            cancellation_token: CancellationToken::new(),
938            services,
939        };
940
941        let params = DispatchAgentParams {
942            prompt: "hello".to_string(),
943            target: DispatchAgentTarget::New {
944                workspace: WorkspaceTarget::Current,
945                agent: None,
946            },
947        };
948
949        let result = DispatchAgentTool.execute(params, &ctx).await.unwrap();
950        assert_eq!(
951            result.session_id.as_deref(),
952            Some(session_id.to_string().as_str())
953        );
954    }
955
956    #[tokio::test]
957    async fn dispatch_agent_filters_mcp_backends_by_allowlist() {
958        let event_store = Arc::new(InMemoryEventStore::new());
959        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
960        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
961        let api_client = Arc::new(ApiClient::new_with_deps(
962            crate::test_utils::test_llm_config_provider().unwrap(),
963            provider_registry,
964            model_registry,
965        ));
966        let workspace =
967            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
968                path: std::env::current_dir().unwrap(),
969            })
970            .await
971            .unwrap();
972
973        let parent_session_id = SessionId::new();
974        let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
975        session_config
976            .tool_config
977            .backends
978            .push(BackendConfig::Mcp {
979                server_name: "allowed-server".to_string(),
980                transport: McpTransport::Tcp {
981                    host: "127.0.0.1".to_string(),
982                    port: 1111,
983                },
984                tool_filter: ToolFilter::All,
985            });
986        session_config
987            .tool_config
988            .backends
989            .push(BackendConfig::Mcp {
990                server_name: "blocked-server".to_string(),
991                transport: McpTransport::Tcp {
992                    host: "127.0.0.1".to_string(),
993                    port: 2222,
994                },
995                tool_filter: ToolFilter::All,
996            });
997
998        event_store.create_session(parent_session_id).await.unwrap();
999        event_store
1000            .append(
1001                parent_session_id,
1002                &SessionEvent::SessionCreated {
1003                    config: Box::new(session_config),
1004                    metadata: HashMap::new(),
1005                    parent_session_id: None,
1006                },
1007            )
1008            .await
1009            .unwrap();
1010
1011        let agent_id = format!("allowlist_{}", Uuid::new_v4());
1012        let spec = AgentSpec {
1013            id: agent_id.clone(),
1014            name: "allowlist test".to_string(),
1015            description: "allowlist test".to_string(),
1016            tools: vec![VIEW_TOOL_NAME.to_string()],
1017            mcp_access: McpAccessPolicy::Allowlist(vec!["allowed-server".to_string()]),
1018            model: None,
1019        };
1020        match register_agent_spec(spec) {
1021            Ok(()) => {}
1022            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1023            Err(AgentSpecError::RegistryPoisoned) => {}
1024        }
1025
1026        let captured = Arc::new(tokio::sync::Mutex::new(None));
1027        let spawner = CapturingAgentSpawner {
1028            session_id: SessionId::new(),
1029            response: "ok".to_string(),
1030            captured: captured.clone(),
1031        };
1032
1033        let services = Arc::new(
1034            ToolServices::new(workspace, event_store, api_client)
1035                .with_agent_spawner(Arc::new(spawner)),
1036        );
1037
1038        let ctx = StaticToolContext {
1039            tool_call_id: ToolCallId::new(),
1040            session_id: parent_session_id,
1041            cancellation_token: CancellationToken::new(),
1042            services,
1043        };
1044
1045        let params = DispatchAgentParams {
1046            prompt: "test".to_string(),
1047            target: DispatchAgentTarget::New {
1048                workspace: WorkspaceTarget::Current,
1049                agent: Some(agent_id),
1050            },
1051        };
1052
1053        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1054        let captured = captured.lock().await.clone().expect("no config captured");
1055
1056        let backend_names: Vec<String> = captured
1057            .mcp_backends
1058            .iter()
1059            .map(|backend| match backend {
1060                BackendConfig::Mcp { server_name, .. } => server_name.clone(),
1061            })
1062            .collect();
1063
1064        assert_eq!(backend_names, vec!["allowed-server".to_string()]);
1065        assert!(captured.allow_mcp_tools);
1066    }
1067
1068    #[tokio::test]
1069    async fn dispatch_agent_uses_parent_model_when_spec_missing_model() {
1070        let event_store = Arc::new(InMemoryEventStore::new());
1071        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1072        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1073        let api_client = Arc::new(ApiClient::new_with_deps(
1074            crate::test_utils::test_llm_config_provider().unwrap(),
1075            provider_registry,
1076            model_registry,
1077        ));
1078        let workspace =
1079            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1080                path: std::env::current_dir().unwrap(),
1081            })
1082            .await
1083            .unwrap();
1084
1085        let parent_session_id = SessionId::new();
1086        let parent_model = builtin::claude_sonnet_4_5();
1087        let session_config = SessionConfig::read_only(parent_model.clone());
1088
1089        event_store.create_session(parent_session_id).await.unwrap();
1090        event_store
1091            .append(
1092                parent_session_id,
1093                &SessionEvent::SessionCreated {
1094                    config: Box::new(session_config),
1095                    metadata: HashMap::new(),
1096                    parent_session_id: None,
1097                },
1098            )
1099            .await
1100            .unwrap();
1101
1102        let agent_id = format!("inherit_model_{}", Uuid::new_v4());
1103        let spec = AgentSpec {
1104            id: agent_id.clone(),
1105            name: "inherit model test".to_string(),
1106            description: "inherit model test".to_string(),
1107            tools: vec![VIEW_TOOL_NAME.to_string()],
1108            mcp_access: McpAccessPolicy::None,
1109            model: None,
1110        };
1111        match register_agent_spec(spec) {
1112            Ok(()) => {}
1113            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1114            Err(AgentSpecError::RegistryPoisoned) => {}
1115        }
1116
1117        let captured = Arc::new(tokio::sync::Mutex::new(None));
1118        let spawner = CapturingAgentSpawner {
1119            session_id: SessionId::new(),
1120            response: "ok".to_string(),
1121            captured: captured.clone(),
1122        };
1123
1124        let services = Arc::new(
1125            ToolServices::new(workspace, event_store, api_client)
1126                .with_agent_spawner(Arc::new(spawner)),
1127        );
1128
1129        let ctx = StaticToolContext {
1130            tool_call_id: ToolCallId::new(),
1131            session_id: parent_session_id,
1132            cancellation_token: CancellationToken::new(),
1133            services,
1134        };
1135
1136        let params = DispatchAgentParams {
1137            prompt: "test".to_string(),
1138            target: DispatchAgentTarget::New {
1139                workspace: WorkspaceTarget::Current,
1140                agent: Some(agent_id),
1141            },
1142        };
1143
1144        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1145        let captured = captured.lock().await.clone().expect("no config captured");
1146
1147        assert_eq!(captured.model, parent_model);
1148    }
1149
1150    #[tokio::test]
1151    async fn dispatch_agent_uses_spec_model_when_set() {
1152        let event_store = Arc::new(InMemoryEventStore::new());
1153        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1154        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1155        let api_client = Arc::new(ApiClient::new_with_deps(
1156            crate::test_utils::test_llm_config_provider().unwrap(),
1157            provider_registry,
1158            model_registry,
1159        ));
1160        let workspace =
1161            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1162                path: std::env::current_dir().unwrap(),
1163            })
1164            .await
1165            .unwrap();
1166
1167        let parent_session_id = SessionId::new();
1168        let parent_model = builtin::claude_sonnet_4_5();
1169        let session_config = SessionConfig::read_only(parent_model);
1170
1171        event_store.create_session(parent_session_id).await.unwrap();
1172        event_store
1173            .append(
1174                parent_session_id,
1175                &SessionEvent::SessionCreated {
1176                    config: Box::new(session_config),
1177                    metadata: HashMap::new(),
1178                    parent_session_id: None,
1179                },
1180            )
1181            .await
1182            .unwrap();
1183
1184        let spec_model = builtin::claude_haiku_4_5();
1185        let agent_id = format!("spec_model_{}", Uuid::new_v4());
1186        let spec = AgentSpec {
1187            id: agent_id.clone(),
1188            name: "spec model test".to_string(),
1189            description: "spec model test".to_string(),
1190            tools: vec![VIEW_TOOL_NAME.to_string()],
1191            mcp_access: McpAccessPolicy::None,
1192            model: Some(spec_model.clone()),
1193        };
1194        match register_agent_spec(spec) {
1195            Ok(()) => {}
1196            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1197            Err(AgentSpecError::RegistryPoisoned) => {}
1198        }
1199
1200        let captured = Arc::new(tokio::sync::Mutex::new(None));
1201        let spawner = CapturingAgentSpawner {
1202            session_id: SessionId::new(),
1203            response: "ok".to_string(),
1204            captured: captured.clone(),
1205        };
1206
1207        let services = Arc::new(
1208            ToolServices::new(workspace, event_store, api_client)
1209                .with_agent_spawner(Arc::new(spawner)),
1210        );
1211
1212        let ctx = StaticToolContext {
1213            tool_call_id: ToolCallId::new(),
1214            session_id: parent_session_id,
1215            cancellation_token: CancellationToken::new(),
1216            services,
1217        };
1218
1219        let params = DispatchAgentParams {
1220            prompt: "test".to_string(),
1221            target: DispatchAgentTarget::New {
1222                workspace: WorkspaceTarget::Current,
1223                agent: Some(agent_id),
1224            },
1225        };
1226
1227        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1228        let captured = captured.lock().await.clone().expect("no config captured");
1229
1230        assert_eq!(captured.model, spec_model);
1231    }
1232
1233    #[tokio::test]
1234    async fn resume_session_denies_disallowed_tools() {
1235        let event_store = Arc::new(InMemoryEventStore::new());
1236        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1237        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1238        let api_client = Arc::new(ApiClient::new_with_deps(
1239            crate::test_utils::test_llm_config_provider().unwrap(),
1240            provider_registry,
1241            model_registry,
1242        ));
1243        let workspace =
1244            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1245                path: std::env::current_dir().unwrap(),
1246            })
1247            .await
1248            .unwrap();
1249
1250        let parent_session_id = SessionId::new();
1251        let session_id = SessionId::new();
1252        let model = builtin::claude_sonnet_4_5();
1253
1254        let tool_call = steer_tools::ToolCall {
1255            name: "bash".to_string(),
1256            parameters: serde_json::json!({ "command": "echo denied" }),
1257            id: "tool_denied".to_string(),
1258        };
1259        api_client.insert_test_provider(
1260            model.provider.clone(),
1261            Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
1262        );
1263
1264        let mut session_config = SessionConfig::read_only(model);
1265        session_config.parent_session_id = Some(parent_session_id);
1266        session_config.policy_overrides = SessionPolicyOverrides {
1267            default_model: None,
1268            tool_visibility: Some(ToolVisibility::Whitelist(HashSet::from([
1269                VIEW_TOOL_NAME.to_string()
1270            ]))),
1271            approval_policy: ToolApprovalPolicyOverrides {
1272                default_behavior: Some(UnapprovedBehavior::Deny),
1273                preapproved: ApprovalRulesOverrides {
1274                    tools: HashSet::from([VIEW_TOOL_NAME.to_string()]),
1275                    per_tool: HashMap::new(),
1276                },
1277            },
1278        };
1279
1280        event_store.create_session(session_id).await.unwrap();
1281        event_store
1282            .append(
1283                session_id,
1284                &SessionEvent::SessionCreated {
1285                    config: Box::new(session_config),
1286                    metadata: HashMap::new(),
1287                    parent_session_id: Some(parent_session_id),
1288                },
1289            )
1290            .await
1291            .unwrap();
1292
1293        let services = Arc::new(ToolServices::new(
1294            workspace,
1295            event_store.clone(),
1296            api_client,
1297        ));
1298
1299        let ctx = StaticToolContext {
1300            tool_call_id: ToolCallId::new(),
1301            session_id: parent_session_id,
1302            cancellation_token: CancellationToken::new(),
1303            services,
1304        };
1305
1306        let _ = resume_agent_session(session_id, "trigger".to_string(), &ctx)
1307            .await
1308            .unwrap();
1309
1310        let events = event_store.load_events(session_id).await.unwrap();
1311        let denied = events.iter().any(|(_, event)| match event {
1312            SessionEvent::ToolCallFailed { name, error, .. } => {
1313                name == "bash" && error.contains("denied by policy")
1314            }
1315            _ => false,
1316        });
1317
1318        assert!(denied, "expected denied ToolCallFailed event for bash");
1319    }
1320}