Skip to main content

steer_core/tools/static_tools/
dispatch_agent.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::agents::{
6    McpAccessPolicy, agent_spec, agent_specs, agent_specs_prompt, default_agent_spec_id,
7};
8use crate::app::domain::event::SessionEvent;
9use crate::app::domain::runtime::RuntimeService;
10use crate::app::domain::types::SessionId;
11use crate::app::validation::ValidatorRegistry;
12use crate::config::model::builtin::claude_sonnet_4_5 as default_model;
13use crate::runners::OneShotRunner;
14use crate::session::state::BackendConfig;
15use crate::tools::capability::Capabilities;
16use crate::tools::services::{SubAgentConfig, SubAgentError, ToolServices};
17use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
18use crate::tools::{BackendRegistry, ToolExecutor, ToolRegistry};
19use crate::workspace::{
20    CreateWorkspaceRequest, EnvironmentId, RepoRef, VcsKind, VcsStatus, Workspace,
21    WorkspaceCreateStrategy, WorkspaceRef, create_workspace_from_session_config,
22};
23use steer_tools::ToolSpec;
24use steer_tools::result::{AgentResult, AgentWorkspaceInfo, AgentWorkspaceRevision};
25use steer_tools::tools::dispatch_agent::{
26    DispatchAgentError, DispatchAgentParams, DispatchAgentTarget, DispatchAgentToolSpec,
27    WorkspaceTarget,
28};
29use steer_tools::tools::{GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
30use tracing::warn;
31
32use super::{
33    AstGrepTool, BashTool, EditTool, FetchTool, GlobTool, GrepTool, LsTool, MultiEditTool,
34    ReplaceTool, TodoReadTool, TodoWriteTool, ViewTool, workspace_manager_op_error,
35    workspace_op_error,
36};
37
38fn dispatch_agent_description() -> String {
39    let agent_specs = agent_specs_prompt();
40    let agent_specs_block = if agent_specs.is_empty() {
41        "No agent specs registered.".to_string()
42    } else {
43        agent_specs
44    };
45
46    format!(
47        r#"Launch a new agent to help with a focused task. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
48
49    When to use the Agent tool:
50    - If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
51
52    When NOT to use the Agent tool:
53    - If you want to read a specific file path, use the {} or {} tool instead of the Agent tool, to find the match more quickly
54    - If you are searching for a specific class definition like "class Foo", use the {} tool instead, to find the match more quickly
55    - If you are searching for code within a specific file or set of 2-3 files, use the {} tool instead, to find the match more quickly
56
57    Usage notes:
58    1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
59    2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
60    3. Each invocation returns a session_id. Pass it back via `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue the conversation with the same agent.
61    4. When `target.session` is `resume`, the session_id must refer to a child of the current session. The `agent` and `workspace` options are ignored and the existing session config is used.
62    5. The agent's outputs should generally be trusted
63    6. IMPORTANT: Only some agent specs include write tools. Use a build agent if the task requires editing files.
64    7. IMPORTANT: New workspaces are preserved (not auto-deleted). Clean them up manually if needed.
65    8. If the agent spec omits a model, the parent session's default model is used.
66    9. IMPORTANT: Do NOT include `Repo: <path>`, `CWD: <path>`, or similar path headers in your prompt. The sub-agent already receives its working directory via system instructions.
67    10. IMPORTANT: If `target.session` is `new` and `workspace.location` is `new`, the sub-agent runs in the newly created workspace path, which may differ from the caller's current directory.
68
69Workspace options:
70- `workspace: {{ "location": "current" }}` to run in the current workspace
71- `workspace: {{ "location": "new", "name": "..." }}` to run in a fresh workspace (jj workspace or git worktree)
72- `location` is a logical workspace selector, not a filesystem path
73
74    Session options:
75    - `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue a prior dispatch_agent session
76
77    New session options:
78    - `target: {{ "session": "new", "workspace": {{ "location": "current" }} }}` to run in the current workspace
79    - `target: {{ "session": "new", "workspace": {{ "location": "new", "name": "..." }} }}` to run in a new workspace
80    - `target: {{ "session": "new", "workspace": {{ "location": "current" }}, "agent": "<id>" }}` selects an agent spec (defaults to "{default_agent}")
81
82    {agent_specs_block}"#,
83        VIEW_TOOL_NAME,
84        LS_TOOL_NAME,
85        GREP_TOOL_NAME,
86        GREP_TOOL_NAME,
87        default_agent = default_agent_spec_id(),
88        agent_specs_block = agent_specs_block
89    )
90}
91
92pub struct DispatchAgentTool;
93
94#[async_trait]
95impl StaticTool for DispatchAgentTool {
96    type Params = DispatchAgentParams;
97    type Output = AgentResult;
98    type Spec = DispatchAgentToolSpec;
99
100    const DESCRIPTION: &'static str = "Launch a sub-agent to search for files or code";
101    const REQUIRES_APPROVAL: bool = false;
102    const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;
103
104    fn schema() -> steer_tools::ToolSchema {
105        let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
106            s.inline_subschemas = true;
107        });
108        let schema_gen = settings.into_generator();
109        let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
110
111        steer_tools::ToolSchema {
112            name: Self::Spec::NAME.to_string(),
113            display_name: Self::Spec::DISPLAY_NAME.to_string(),
114            description: dispatch_agent_description(),
115            input_schema: input_schema.into(),
116        }
117    }
118
119    async fn execute(
120        &self,
121        params: Self::Params,
122        ctx: &StaticToolContext,
123    ) -> Result<Self::Output, StaticToolError<DispatchAgentError>> {
124        let DispatchAgentParams { prompt, target } = params;
125
126        let (workspace_target, agent) = match target {
127            DispatchAgentTarget::Resume { session_id } => {
128                let session_id = SessionId::parse(&session_id).ok_or_else(|| {
129                    StaticToolError::invalid_params(format!("Invalid session_id '{session_id}'"))
130                })?;
131                return resume_agent_session(session_id, prompt, ctx).await;
132            }
133            DispatchAgentTarget::New { workspace, agent } => (workspace, agent),
134        };
135
136        let spawner = ctx
137            .services
138            .agent_spawner()
139            .ok_or_else(|| StaticToolError::missing_capability("agent_spawner"))?;
140
141        let base_workspace = ctx.services.workspace.clone();
142        let base_path = base_workspace.working_directory().to_path_buf();
143
144        let mut workspace = base_workspace.clone();
145        let mut workspace_ref = None;
146        let mut workspace_id = None;
147        let mut workspace_name = None;
148        let mut repo_id = None;
149        let mut repo_ref = None;
150
151        if let Some(manager) = ctx.services.workspace_manager()
152            && let Ok(info) = manager.resolve_workspace(&base_path).await
153        {
154            workspace_id = Some(info.workspace_id);
155            workspace_name.clone_from(&info.name);
156            repo_id = Some(info.repo_id);
157            workspace_ref = Some(WorkspaceRef {
158                environment_id: info.environment_id,
159                workspace_id: info.workspace_id,
160                repo_id: info.repo_id,
161            });
162        }
163
164        if let Some(manager) = ctx.services.repo_manager() {
165            let repo_env_id = workspace_ref
166                .as_ref()
167                .map_or_else(EnvironmentId::local, |reference| reference.environment_id);
168            if let Ok(info) = manager.resolve_repo(repo_env_id, &base_path).await {
169                if repo_id.is_none() {
170                    repo_id = Some(info.repo_id);
171                }
172                repo_ref = Some(RepoRef {
173                    environment_id: info.environment_id,
174                    repo_id: info.repo_id,
175                    root_path: info.root_path,
176                    vcs_kind: info.vcs_kind,
177                });
178            }
179        }
180
181        let mut new_workspace = false;
182        let mut requested_workspace_name = None;
183
184        match &workspace_target {
185            WorkspaceTarget::Current => {}
186            WorkspaceTarget::New { name } => {
187                new_workspace = true;
188                requested_workspace_name = Some(name.clone());
189            }
190        }
191
192        let mut created_workspace_id = None;
193        let mut status_manager = None;
194
195        if new_workspace {
196            let manager = ctx
197                .services
198                .workspace_manager()
199                .ok_or_else(|| StaticToolError::missing_capability("workspace_manager"))?;
200            status_manager = Some(manager.clone());
201
202            let base_repo_id = repo_id.ok_or_else(|| {
203                StaticToolError::execution(DispatchAgentError::WorkspaceUnavailable {
204                    message:
205                        "Current path is not a supported workspace; cannot create new workspace"
206                            .to_string(),
207                })
208            })?;
209
210            let strategy = match repo_ref
211                .as_ref()
212                .and_then(|reference| reference.vcs_kind.as_ref())
213            {
214                Some(VcsKind::Git) => WorkspaceCreateStrategy::GitWorktree,
215                _ => WorkspaceCreateStrategy::JjWorkspace,
216            };
217
218            let create_request = CreateWorkspaceRequest {
219                repo_id: base_repo_id,
220                name: requested_workspace_name.clone(),
221                parent_workspace_id: workspace_id,
222                strategy,
223            };
224
225            let info = manager
226                .create_workspace(create_request)
227                .await
228                .map_err(|e| {
229                    StaticToolError::execution(DispatchAgentError::Workspace(
230                        workspace_manager_op_error(e),
231                    ))
232                })?;
233
234            workspace = manager
235                .open_workspace(info.workspace_id)
236                .await
237                .map_err(|e| {
238                    StaticToolError::execution(DispatchAgentError::Workspace(
239                        workspace_manager_op_error(e),
240                    ))
241                })?;
242
243            workspace_id = Some(info.workspace_id);
244            created_workspace_id = Some(info.workspace_id);
245            workspace_name.clone_from(&info.name);
246            workspace_ref = Some(WorkspaceRef {
247                environment_id: info.environment_id,
248                workspace_id: info.workspace_id,
249                repo_id: info.repo_id,
250            });
251
252            if let Some(repo_manager) = ctx.services.repo_manager()
253                && let Ok(info) = repo_manager
254                    .resolve_repo(info.environment_id, workspace.working_directory())
255                    .await
256            {
257                repo_ref = Some(RepoRef {
258                    environment_id: info.environment_id,
259                    repo_id: info.repo_id,
260                    root_path: info.root_path,
261                    vcs_kind: info.vcs_kind,
262                });
263            }
264        }
265
266        let env_info = workspace.environment().await.map_err(|e| {
267            StaticToolError::execution(DispatchAgentError::Workspace(workspace_op_error(e)))
268        })?;
269
270        let system_prompt = format!(
271            r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
272
273Notes:
2741. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2752. When relevant, share file names and code snippets relevant to the query
2763. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
277
278{}
279"#,
280            env_info.as_context()
281        );
282
283        let agent_id = agent
284            .as_deref()
285            .filter(|value| !value.trim().is_empty())
286            .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
287
288        let agent_spec = agent_spec(&agent_id).ok_or_else(|| {
289            let available = agent_specs()
290                .into_iter()
291                .map(|spec| spec.id)
292                .collect::<Vec<_>>()
293                .join(", ");
294            StaticToolError::invalid_params(format!(
295                "Unknown agent spec '{agent_id}'. Available: {available}"
296            ))
297        })?;
298
299        let parent_session_config = match ctx.services.event_store.load_events(ctx.session_id).await
300        {
301            Ok(events) => events.into_iter().find_map(|(_, event)| match event {
302                SessionEvent::SessionCreated { config, .. } => Some(*config),
303                _ => None,
304            }),
305            Err(err) => {
306                warn!(
307                    session_id = %ctx.session_id,
308                    "Failed to load parent session config for MCP servers: {err}"
309                );
310                None
311            }
312        };
313
314        let parent_mcp_backends = parent_session_config
315            .as_ref()
316            .map(|config| config.tool_config.backends.clone())
317            .unwrap_or_default();
318
319        let parent_model = parent_session_config
320            .as_ref()
321            .map_or_else(default_model, |config| config.default_model.clone());
322
323        let allow_mcp_tools = agent_spec.mcp_access.allow_mcp_tools();
324        let mcp_backends = match &agent_spec.mcp_access {
325            McpAccessPolicy::None => Vec::new(),
326            McpAccessPolicy::All => parent_mcp_backends,
327            McpAccessPolicy::Allowlist(servers) => parent_mcp_backends
328                .into_iter()
329                .filter(|backend| match backend {
330                    BackendConfig::Mcp { server_name, .. } => {
331                        servers.iter().any(|allowed| allowed == server_name)
332                    }
333                })
334                .collect(),
335        };
336
337        let config = SubAgentConfig {
338            parent_session_id: ctx.session_id,
339            prompt,
340            allowed_tools: agent_spec.tools.clone(),
341            model: agent_spec.model.clone().unwrap_or(parent_model),
342            system_context: Some(crate::app::SystemContext::new(system_prompt)),
343            workspace: Some(workspace),
344            workspace_ref,
345            workspace_id,
346            repo_ref,
347            workspace_name,
348            mcp_backends,
349            allow_mcp_tools,
350        };
351
352        let spawn_result = spawner.spawn(config, ctx.cancellation_token.clone()).await;
353
354        let mut workspace_info = None;
355
356        if let (Some(manager), Some(workspace_id)) = (status_manager, created_workspace_id) {
357            let revision = match manager.get_workspace_status(workspace_id).await {
358                Ok(status) => match status.vcs {
359                    Some(vcs) => match vcs.status {
360                        VcsStatus::Jj(jj_status) => {
361                            jj_status.working_copy.map(|wc| AgentWorkspaceRevision {
362                                vcs_kind: "jj".to_string(),
363                                revision_id: wc.commit_id,
364                                summary: wc.description,
365                                change_id: Some(wc.change_id),
366                            })
367                        }
368                        VcsStatus::Git(_) => None,
369                    },
370                    None => None,
371                },
372                Err(err) => {
373                    warn!(
374                        workspace_id = %workspace_id.as_uuid(),
375                        "Failed to get workspace status for sub-agent: {err}"
376                    );
377                    None
378                }
379            };
380
381            workspace_info = Some(AgentWorkspaceInfo {
382                workspace_id: Some(workspace_id.as_uuid().to_string()),
383                revision,
384            });
385        }
386
387        let result = spawn_result.map_err(|e| match e {
388            SubAgentError::Cancelled => StaticToolError::Cancelled,
389            other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
390                message: other.to_string(),
391            }),
392        })?;
393
394        Ok(AgentResult {
395            content: result.final_message.extract_text(),
396            session_id: Some(result.session_id.to_string()),
397            workspace: workspace_info,
398        })
399    }
400}
401
402fn build_runtime_tool_executor(
403    workspace: Arc<dyn Workspace>,
404    parent_services: &Arc<ToolServices>,
405) -> Arc<ToolExecutor> {
406    let mut services = ToolServices::new(
407        workspace.clone(),
408        parent_services.event_store.clone(),
409        parent_services.api_client.clone(),
410    );
411
412    if let Some(spawner) = parent_services.agent_spawner() {
413        services = services.with_agent_spawner(spawner.clone());
414    }
415    if let Some(caller) = parent_services.model_caller() {
416        services = services.with_model_caller(caller.clone());
417    }
418    if let Some(manager) = parent_services.workspace_manager() {
419        services = services.with_workspace_manager(manager.clone());
420    }
421    if let Some(manager) = parent_services.repo_manager() {
422        services = services.with_repo_manager(manager.clone());
423    }
424    if parent_services
425        .capabilities()
426        .contains(Capabilities::NETWORK)
427    {
428        services = services.with_network();
429    }
430
431    let mut registry = ToolRegistry::new();
432    registry.register_static(GrepTool);
433    registry.register_static(GlobTool);
434    registry.register_static(LsTool);
435    registry.register_static(ViewTool);
436    registry.register_static(BashTool);
437    registry.register_static(EditTool);
438    registry.register_static(MultiEditTool);
439    registry.register_static(ReplaceTool);
440    registry.register_static(AstGrepTool);
441    registry.register_static(TodoReadTool);
442    registry.register_static(TodoWriteTool);
443    registry.register_static(DispatchAgentTool);
444    registry.register_static(FetchTool);
445
446    Arc::new(
447        ToolExecutor::with_components(
448            Arc::new(BackendRegistry::new()),
449            Arc::new(ValidatorRegistry::new()),
450        )
451        .with_static_tools(Arc::new(registry), Arc::new(services)),
452    )
453}
454
455async fn resume_agent_session(
456    session_id: SessionId,
457    prompt: String,
458    ctx: &StaticToolContext,
459) -> Result<AgentResult, StaticToolError<DispatchAgentError>> {
460    let events = ctx
461        .services
462        .event_store
463        .load_events(session_id)
464        .await
465        .map_err(|e| {
466            StaticToolError::execution(DispatchAgentError::SpawnFailed {
467                message: format!("Failed to load session {session_id}: {e}"),
468            })
469        })?;
470
471    let session_config = events
472        .into_iter()
473        .find_map(|(_, event)| match event {
474            SessionEvent::SessionCreated { config, .. } => Some(*config),
475            _ => None,
476        })
477        .ok_or_else(|| {
478            StaticToolError::execution(DispatchAgentError::SpawnFailed {
479                message: format!("Session {session_id} is missing a SessionCreated event"),
480            })
481        })?;
482
483    if session_config.parent_session_id != Some(ctx.session_id) {
484        return Err(StaticToolError::invalid_params(format!(
485            "Session {session_id} is not a child of current session {}",
486            ctx.session_id
487        )));
488    }
489
490    let workspace = create_workspace_from_session_config(&session_config.workspace)
491        .await
492        .map_err(|e| {
493            StaticToolError::execution(DispatchAgentError::SpawnFailed {
494                message: format!("Failed to open workspace for session {session_id}: {e}"),
495            })
496        })?;
497
498    let tool_executor = build_runtime_tool_executor(workspace, &ctx.services);
499    let runtime = RuntimeService::spawn(
500        ctx.services.event_store.clone(),
501        ctx.services.api_client.clone(),
502        tool_executor,
503    );
504
505    let run_result = OneShotRunner::run_in_session_with_cancel(
506        &runtime.handle,
507        session_id,
508        prompt,
509        session_config.default_model.clone(),
510        ctx.cancellation_token.clone(),
511    )
512    .await;
513
514    runtime.shutdown().await;
515
516    let run_result = run_result.map_err(|e| match e {
517        crate::error::Error::Cancelled => StaticToolError::Cancelled,
518        other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
519            message: other.to_string(),
520        }),
521    })?;
522
523    Ok(AgentResult {
524        content: run_result.final_message.extract_text(),
525        session_id: Some(run_result.session_id.to_string()),
526        workspace: None,
527    })
528}
529
530#[cfg(test)]
531mod tests {
532    use super::*;
533    use crate::agents::{AgentSpec, AgentSpecError, McpAccessPolicy, register_agent_spec};
534    use crate::api::Client as ApiClient;
535    use crate::api::{ApiError, CompletionResponse, Provider};
536    use crate::app::conversation::{AssistantContent, Message, MessageData};
537    use crate::app::domain::session::EventStore;
538    use crate::app::domain::session::event_store::InMemoryEventStore;
539    use crate::app::domain::types::ToolCallId;
540    use crate::config::model::builtin;
541    use crate::model_registry::ModelRegistry;
542    use crate::session::state::{
543        ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicyOverrides,
544        ToolFilter, ToolVisibility, UnapprovedBehavior,
545    };
546    use crate::tools::McpTransport;
547    use crate::tools::services::{AgentSpawner, SubAgentError, SubAgentResult, ToolServices};
548    use async_trait::async_trait;
549    use std::collections::{HashMap, HashSet};
550    use std::sync::Mutex as StdMutex;
551    use tokio::time::{Duration, sleep};
552    use tokio_util::sync::CancellationToken;
553    use uuid::Uuid;
554
555    #[derive(Clone)]
556    struct StubProvider {
557        response: String,
558    }
559
560    impl StubProvider {
561        fn new(response: impl Into<String>) -> Self {
562            Self {
563                response: response.into(),
564            }
565        }
566    }
567
568    #[derive(Clone)]
569    struct CancelAwareProvider;
570
571    #[async_trait]
572    impl Provider for CancelAwareProvider {
573        fn name(&self) -> &'static str {
574            "cancel-aware"
575        }
576
577        async fn complete(
578            &self,
579            _model_id: &crate::config::model::ModelId,
580            _messages: Vec<Message>,
581            _system: Option<crate::app::SystemContext>,
582            _tools: Option<Vec<steer_tools::ToolSchema>>,
583            _call_options: Option<crate::config::model::ModelParameters>,
584            token: CancellationToken,
585        ) -> Result<CompletionResponse, ApiError> {
586            token.cancelled().await;
587            Err(ApiError::Cancelled {
588                provider: self.name().to_string(),
589            })
590        }
591    }
592
593    #[async_trait]
594    impl Provider for StubProvider {
595        fn name(&self) -> &'static str {
596            "stub"
597        }
598
599        async fn complete(
600            &self,
601            _model_id: &crate::config::model::ModelId,
602            _messages: Vec<Message>,
603            _system: Option<crate::app::SystemContext>,
604            _tools: Option<Vec<steer_tools::ToolSchema>>,
605            _call_options: Option<crate::config::model::ModelParameters>,
606            _token: CancellationToken,
607        ) -> Result<CompletionResponse, ApiError> {
608            Ok(CompletionResponse {
609                content: vec![AssistantContent::Text {
610                    text: self.response.clone(),
611                }],
612                usage: None,
613            })
614        }
615    }
616
617    #[derive(Clone)]
618    struct StubAgentSpawner {
619        session_id: SessionId,
620        response: String,
621    }
622
623    #[async_trait]
624    impl AgentSpawner for StubAgentSpawner {
625        async fn spawn(
626            &self,
627            _config: crate::tools::services::SubAgentConfig,
628            _cancel_token: CancellationToken,
629        ) -> Result<SubAgentResult, SubAgentError> {
630            let timestamp = Message::current_timestamp();
631            let message = Message {
632                timestamp,
633                id: Message::generate_id("assistant", timestamp),
634                parent_message_id: None,
635                data: MessageData::Assistant {
636                    content: vec![AssistantContent::Text {
637                        text: self.response.clone(),
638                    }],
639                },
640            };
641
642            Ok(SubAgentResult {
643                session_id: self.session_id,
644                final_message: message,
645            })
646        }
647    }
648
649    #[derive(Clone)]
650    struct CapturingAgentSpawner {
651        session_id: SessionId,
652        response: String,
653        captured: Arc<tokio::sync::Mutex<Option<crate::tools::services::SubAgentConfig>>>,
654    }
655
656    #[async_trait]
657    impl AgentSpawner for CapturingAgentSpawner {
658        async fn spawn(
659            &self,
660            config: crate::tools::services::SubAgentConfig,
661            _cancel_token: CancellationToken,
662        ) -> Result<SubAgentResult, SubAgentError> {
663            let mut guard = self.captured.lock().await;
664            *guard = Some(config);
665
666            let timestamp = Message::current_timestamp();
667            let message = Message {
668                timestamp,
669                id: Message::generate_id("assistant", timestamp),
670                parent_message_id: None,
671                data: MessageData::Assistant {
672                    content: vec![AssistantContent::Text {
673                        text: self.response.clone(),
674                    }],
675                },
676            };
677
678            Ok(SubAgentResult {
679                session_id: self.session_id,
680                final_message: message,
681            })
682        }
683    }
684
685    #[derive(Clone)]
686    struct ToolCallThenTextProvider {
687        tool_call: steer_tools::ToolCall,
688        final_text: String,
689        call_count: Arc<StdMutex<usize>>,
690    }
691
692    impl ToolCallThenTextProvider {
693        fn new(tool_call: steer_tools::ToolCall, final_text: impl Into<String>) -> Self {
694            Self {
695                tool_call,
696                final_text: final_text.into(),
697                call_count: Arc::new(StdMutex::new(0)),
698            }
699        }
700    }
701
702    #[async_trait]
703    impl Provider for ToolCallThenTextProvider {
704        fn name(&self) -> &'static str {
705            "stub-tool-call"
706        }
707
708        async fn complete(
709            &self,
710            _model_id: &crate::config::model::ModelId,
711            _messages: Vec<Message>,
712            _system: Option<crate::app::SystemContext>,
713            _tools: Option<Vec<steer_tools::ToolSchema>>,
714            _call_options: Option<crate::config::model::ModelParameters>,
715            _token: CancellationToken,
716        ) -> Result<CompletionResponse, ApiError> {
717            let mut count = self
718                .call_count
719                .lock()
720                .expect("tool call counter lock poisoned");
721            let response = if *count == 0 {
722                CompletionResponse {
723                    content: vec![AssistantContent::ToolCall {
724                        tool_call: self.tool_call.clone(),
725                        thought_signature: None,
726                    }],
727                    usage: None,
728                }
729            } else {
730                CompletionResponse {
731                    content: vec![AssistantContent::Text {
732                        text: self.final_text.clone(),
733                    }],
734                    usage: None,
735                }
736            };
737            *count += 1;
738            Ok(response)
739        }
740    }
741
742    #[tokio::test]
743    async fn resume_session_rejects_non_child() {
744        let event_store = Arc::new(InMemoryEventStore::new());
745        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
746        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
747        let api_client = Arc::new(ApiClient::new_with_deps(
748            crate::test_utils::test_llm_config_provider().unwrap(),
749            provider_registry,
750            model_registry,
751        ));
752        let workspace =
753            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
754                path: std::env::current_dir().unwrap(),
755            })
756            .await
757            .unwrap();
758
759        let parent_session_id = SessionId::new();
760        let session_id = SessionId::new();
761        let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
762        session_config.parent_session_id = Some(parent_session_id);
763
764        event_store.create_session(session_id).await.unwrap();
765        event_store
766            .append(
767                session_id,
768                &SessionEvent::SessionCreated {
769                    config: Box::new(session_config),
770                    metadata: std::collections::HashMap::new(),
771                    parent_session_id: Some(parent_session_id),
772                },
773            )
774            .await
775            .unwrap();
776
777        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
778
779        let ctx = StaticToolContext {
780            tool_call_id: ToolCallId::new(),
781            session_id: SessionId::new(),
782            cancellation_token: CancellationToken::new(),
783            services,
784        };
785
786        let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
787
788        assert!(matches!(result, Err(StaticToolError::InvalidParams(_))));
789    }
790
791    #[tokio::test]
792    async fn resume_session_accepts_child_and_returns_message() {
793        let event_store = Arc::new(InMemoryEventStore::new());
794        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
795        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
796        let api_client = Arc::new(ApiClient::new_with_deps(
797            crate::test_utils::test_llm_config_provider().unwrap(),
798            provider_registry,
799            model_registry,
800        ));
801        let model = builtin::claude_sonnet_4_5();
802        api_client.insert_test_provider(
803            model.provider.clone(),
804            Arc::new(StubProvider::new("stub-response")),
805        );
806        let workspace =
807            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
808                path: std::env::current_dir().unwrap(),
809            })
810            .await
811            .unwrap();
812
813        let parent_session_id = SessionId::new();
814        let session_id = SessionId::new();
815        let mut session_config = SessionConfig::read_only(model.clone());
816        session_config.parent_session_id = Some(parent_session_id);
817
818        event_store.create_session(session_id).await.unwrap();
819        event_store
820            .append(
821                session_id,
822                &SessionEvent::SessionCreated {
823                    config: Box::new(session_config),
824                    metadata: std::collections::HashMap::new(),
825                    parent_session_id: Some(parent_session_id),
826                },
827            )
828            .await
829            .unwrap();
830
831        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
832
833        let ctx = StaticToolContext {
834            tool_call_id: ToolCallId::new(),
835            session_id: parent_session_id,
836            cancellation_token: CancellationToken::new(),
837            services,
838        };
839
840        let result = resume_agent_session(session_id, "ping".to_string(), &ctx)
841            .await
842            .unwrap();
843
844        assert!(result.content.contains("stub-response"));
845        assert_eq!(
846            result.session_id.as_deref(),
847            Some(session_id.to_string().as_str())
848        );
849    }
850
851    #[tokio::test]
852    async fn resume_session_honors_cancellation() {
853        let event_store = Arc::new(InMemoryEventStore::new());
854        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
855        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
856        let api_client = Arc::new(ApiClient::new_with_deps(
857            crate::test_utils::test_llm_config_provider().unwrap(),
858            provider_registry,
859            model_registry,
860        ));
861        let model = builtin::claude_sonnet_4_5();
862        api_client.insert_test_provider(model.provider.clone(), Arc::new(CancelAwareProvider));
863        let workspace =
864            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
865                path: std::env::current_dir().unwrap(),
866            })
867            .await
868            .unwrap();
869
870        let parent_session_id = SessionId::new();
871        let session_id = SessionId::new();
872        let mut session_config = SessionConfig::read_only(model);
873        session_config.parent_session_id = Some(parent_session_id);
874
875        event_store.create_session(session_id).await.unwrap();
876        event_store
877            .append(
878                session_id,
879                &SessionEvent::SessionCreated {
880                    config: Box::new(session_config),
881                    metadata: std::collections::HashMap::new(),
882                    parent_session_id: Some(parent_session_id),
883                },
884            )
885            .await
886            .unwrap();
887
888        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
889
890        let cancel_token = CancellationToken::new();
891        let ctx = StaticToolContext {
892            tool_call_id: ToolCallId::new(),
893            session_id: parent_session_id,
894            cancellation_token: cancel_token.clone(),
895            services,
896        };
897
898        let cancel_task = tokio::spawn(async move {
899            sleep(Duration::from_millis(10)).await;
900            cancel_token.cancel();
901        });
902
903        let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
904        let _ = cancel_task.await;
905
906        assert!(matches!(result, Err(StaticToolError::Cancelled)));
907    }
908
909    #[tokio::test]
910    async fn dispatch_agent_returns_session_id() {
911        let event_store = Arc::new(InMemoryEventStore::new());
912        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
913        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
914        let api_client = Arc::new(ApiClient::new_with_deps(
915            crate::test_utils::test_llm_config_provider().unwrap(),
916            provider_registry,
917            model_registry,
918        ));
919        let workspace =
920            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
921                path: std::env::current_dir().unwrap(),
922            })
923            .await
924            .unwrap();
925
926        let session_id = SessionId::new();
927        let spawner = StubAgentSpawner {
928            session_id,
929            response: "done".to_string(),
930        };
931
932        let services = Arc::new(
933            ToolServices::new(workspace, event_store, api_client)
934                .with_agent_spawner(Arc::new(spawner)),
935        );
936
937        let ctx = StaticToolContext {
938            tool_call_id: ToolCallId::new(),
939            session_id: SessionId::new(),
940            cancellation_token: CancellationToken::new(),
941            services,
942        };
943
944        let params = DispatchAgentParams {
945            prompt: "hello".to_string(),
946            target: DispatchAgentTarget::New {
947                workspace: WorkspaceTarget::Current,
948                agent: None,
949            },
950        };
951
952        let result = DispatchAgentTool.execute(params, &ctx).await.unwrap();
953        assert_eq!(
954            result.session_id.as_deref(),
955            Some(session_id.to_string().as_str())
956        );
957    }
958
959    #[tokio::test]
960    async fn dispatch_agent_filters_mcp_backends_by_allowlist() {
961        let event_store = Arc::new(InMemoryEventStore::new());
962        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
963        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
964        let api_client = Arc::new(ApiClient::new_with_deps(
965            crate::test_utils::test_llm_config_provider().unwrap(),
966            provider_registry,
967            model_registry,
968        ));
969        let workspace =
970            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
971                path: std::env::current_dir().unwrap(),
972            })
973            .await
974            .unwrap();
975
976        let parent_session_id = SessionId::new();
977        let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
978        session_config
979            .tool_config
980            .backends
981            .push(BackendConfig::Mcp {
982                server_name: "allowed-server".to_string(),
983                transport: McpTransport::Tcp {
984                    host: "127.0.0.1".to_string(),
985                    port: 1111,
986                },
987                tool_filter: ToolFilter::All,
988            });
989        session_config
990            .tool_config
991            .backends
992            .push(BackendConfig::Mcp {
993                server_name: "blocked-server".to_string(),
994                transport: McpTransport::Tcp {
995                    host: "127.0.0.1".to_string(),
996                    port: 2222,
997                },
998                tool_filter: ToolFilter::All,
999            });
1000
1001        event_store.create_session(parent_session_id).await.unwrap();
1002        event_store
1003            .append(
1004                parent_session_id,
1005                &SessionEvent::SessionCreated {
1006                    config: Box::new(session_config),
1007                    metadata: HashMap::new(),
1008                    parent_session_id: None,
1009                },
1010            )
1011            .await
1012            .unwrap();
1013
1014        let agent_id = format!("allowlist_{}", Uuid::new_v4());
1015        let spec = AgentSpec {
1016            id: agent_id.clone(),
1017            name: "allowlist test".to_string(),
1018            description: "allowlist test".to_string(),
1019            tools: vec![VIEW_TOOL_NAME.to_string()],
1020            mcp_access: McpAccessPolicy::Allowlist(vec!["allowed-server".to_string()]),
1021            model: None,
1022        };
1023        match register_agent_spec(spec) {
1024            Ok(()) => {}
1025            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1026            Err(AgentSpecError::RegistryPoisoned) => {}
1027        }
1028
1029        let captured = Arc::new(tokio::sync::Mutex::new(None));
1030        let spawner = CapturingAgentSpawner {
1031            session_id: SessionId::new(),
1032            response: "ok".to_string(),
1033            captured: captured.clone(),
1034        };
1035
1036        let services = Arc::new(
1037            ToolServices::new(workspace, event_store, api_client)
1038                .with_agent_spawner(Arc::new(spawner)),
1039        );
1040
1041        let ctx = StaticToolContext {
1042            tool_call_id: ToolCallId::new(),
1043            session_id: parent_session_id,
1044            cancellation_token: CancellationToken::new(),
1045            services,
1046        };
1047
1048        let params = DispatchAgentParams {
1049            prompt: "test".to_string(),
1050            target: DispatchAgentTarget::New {
1051                workspace: WorkspaceTarget::Current,
1052                agent: Some(agent_id),
1053            },
1054        };
1055
1056        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1057        let captured = captured.lock().await.clone().expect("no config captured");
1058
1059        let backend_names: Vec<String> = captured
1060            .mcp_backends
1061            .iter()
1062            .map(|backend| match backend {
1063                BackendConfig::Mcp { server_name, .. } => server_name.clone(),
1064            })
1065            .collect();
1066
1067        assert_eq!(backend_names, vec!["allowed-server".to_string()]);
1068        assert!(captured.allow_mcp_tools);
1069    }
1070
1071    #[tokio::test]
1072    async fn dispatch_agent_uses_parent_model_when_spec_missing_model() {
1073        let event_store = Arc::new(InMemoryEventStore::new());
1074        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1075        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1076        let api_client = Arc::new(ApiClient::new_with_deps(
1077            crate::test_utils::test_llm_config_provider().unwrap(),
1078            provider_registry,
1079            model_registry,
1080        ));
1081        let workspace =
1082            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1083                path: std::env::current_dir().unwrap(),
1084            })
1085            .await
1086            .unwrap();
1087
1088        let parent_session_id = SessionId::new();
1089        let parent_model = builtin::claude_sonnet_4_5();
1090        let session_config = SessionConfig::read_only(parent_model.clone());
1091
1092        event_store.create_session(parent_session_id).await.unwrap();
1093        event_store
1094            .append(
1095                parent_session_id,
1096                &SessionEvent::SessionCreated {
1097                    config: Box::new(session_config),
1098                    metadata: HashMap::new(),
1099                    parent_session_id: None,
1100                },
1101            )
1102            .await
1103            .unwrap();
1104
1105        let agent_id = format!("inherit_model_{}", Uuid::new_v4());
1106        let spec = AgentSpec {
1107            id: agent_id.clone(),
1108            name: "inherit model test".to_string(),
1109            description: "inherit model test".to_string(),
1110            tools: vec![VIEW_TOOL_NAME.to_string()],
1111            mcp_access: McpAccessPolicy::None,
1112            model: None,
1113        };
1114        match register_agent_spec(spec) {
1115            Ok(()) => {}
1116            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1117            Err(AgentSpecError::RegistryPoisoned) => {}
1118        }
1119
1120        let captured = Arc::new(tokio::sync::Mutex::new(None));
1121        let spawner = CapturingAgentSpawner {
1122            session_id: SessionId::new(),
1123            response: "ok".to_string(),
1124            captured: captured.clone(),
1125        };
1126
1127        let services = Arc::new(
1128            ToolServices::new(workspace, event_store, api_client)
1129                .with_agent_spawner(Arc::new(spawner)),
1130        );
1131
1132        let ctx = StaticToolContext {
1133            tool_call_id: ToolCallId::new(),
1134            session_id: parent_session_id,
1135            cancellation_token: CancellationToken::new(),
1136            services,
1137        };
1138
1139        let params = DispatchAgentParams {
1140            prompt: "test".to_string(),
1141            target: DispatchAgentTarget::New {
1142                workspace: WorkspaceTarget::Current,
1143                agent: Some(agent_id),
1144            },
1145        };
1146
1147        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1148        let captured = captured.lock().await.clone().expect("no config captured");
1149
1150        assert_eq!(captured.model, parent_model);
1151    }
1152
1153    #[tokio::test]
1154    async fn dispatch_agent_uses_spec_model_when_set() {
1155        let event_store = Arc::new(InMemoryEventStore::new());
1156        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1157        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1158        let api_client = Arc::new(ApiClient::new_with_deps(
1159            crate::test_utils::test_llm_config_provider().unwrap(),
1160            provider_registry,
1161            model_registry,
1162        ));
1163        let workspace =
1164            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1165                path: std::env::current_dir().unwrap(),
1166            })
1167            .await
1168            .unwrap();
1169
1170        let parent_session_id = SessionId::new();
1171        let parent_model = builtin::claude_sonnet_4_5();
1172        let session_config = SessionConfig::read_only(parent_model);
1173
1174        event_store.create_session(parent_session_id).await.unwrap();
1175        event_store
1176            .append(
1177                parent_session_id,
1178                &SessionEvent::SessionCreated {
1179                    config: Box::new(session_config),
1180                    metadata: HashMap::new(),
1181                    parent_session_id: None,
1182                },
1183            )
1184            .await
1185            .unwrap();
1186
1187        let spec_model = builtin::claude_haiku_4_5();
1188        let agent_id = format!("spec_model_{}", Uuid::new_v4());
1189        let spec = AgentSpec {
1190            id: agent_id.clone(),
1191            name: "spec model test".to_string(),
1192            description: "spec model test".to_string(),
1193            tools: vec![VIEW_TOOL_NAME.to_string()],
1194            mcp_access: McpAccessPolicy::None,
1195            model: Some(spec_model.clone()),
1196        };
1197        match register_agent_spec(spec) {
1198            Ok(()) => {}
1199            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1200            Err(AgentSpecError::RegistryPoisoned) => {}
1201        }
1202
1203        let captured = Arc::new(tokio::sync::Mutex::new(None));
1204        let spawner = CapturingAgentSpawner {
1205            session_id: SessionId::new(),
1206            response: "ok".to_string(),
1207            captured: captured.clone(),
1208        };
1209
1210        let services = Arc::new(
1211            ToolServices::new(workspace, event_store, api_client)
1212                .with_agent_spawner(Arc::new(spawner)),
1213        );
1214
1215        let ctx = StaticToolContext {
1216            tool_call_id: ToolCallId::new(),
1217            session_id: parent_session_id,
1218            cancellation_token: CancellationToken::new(),
1219            services,
1220        };
1221
1222        let params = DispatchAgentParams {
1223            prompt: "test".to_string(),
1224            target: DispatchAgentTarget::New {
1225                workspace: WorkspaceTarget::Current,
1226                agent: Some(agent_id),
1227            },
1228        };
1229
1230        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1231        let captured = captured.lock().await.clone().expect("no config captured");
1232
1233        assert_eq!(captured.model, spec_model);
1234    }
1235
1236    #[tokio::test]
1237    async fn resume_session_denies_disallowed_tools() {
1238        let event_store = Arc::new(InMemoryEventStore::new());
1239        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1240        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1241        let api_client = Arc::new(ApiClient::new_with_deps(
1242            crate::test_utils::test_llm_config_provider().unwrap(),
1243            provider_registry,
1244            model_registry,
1245        ));
1246        let workspace =
1247            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1248                path: std::env::current_dir().unwrap(),
1249            })
1250            .await
1251            .unwrap();
1252
1253        let parent_session_id = SessionId::new();
1254        let session_id = SessionId::new();
1255        let model = builtin::claude_sonnet_4_5();
1256
1257        let tool_call = steer_tools::ToolCall {
1258            name: "bash".to_string(),
1259            parameters: serde_json::json!({ "command": "echo denied" }),
1260            id: "tool_denied".to_string(),
1261        };
1262        api_client.insert_test_provider(
1263            model.provider.clone(),
1264            Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
1265        );
1266
1267        let mut session_config = SessionConfig::read_only(model);
1268        session_config.parent_session_id = Some(parent_session_id);
1269        session_config.policy_overrides = SessionPolicyOverrides {
1270            default_model: None,
1271            tool_visibility: Some(ToolVisibility::Whitelist(HashSet::from([
1272                VIEW_TOOL_NAME.to_string()
1273            ]))),
1274            approval_policy: ToolApprovalPolicyOverrides {
1275                default_behavior: Some(UnapprovedBehavior::Deny),
1276                preapproved: ApprovalRulesOverrides {
1277                    tools: HashSet::from([VIEW_TOOL_NAME.to_string()]),
1278                    per_tool: HashMap::new(),
1279                },
1280            },
1281        };
1282
1283        event_store.create_session(session_id).await.unwrap();
1284        event_store
1285            .append(
1286                session_id,
1287                &SessionEvent::SessionCreated {
1288                    config: Box::new(session_config),
1289                    metadata: HashMap::new(),
1290                    parent_session_id: Some(parent_session_id),
1291                },
1292            )
1293            .await
1294            .unwrap();
1295
1296        let services = Arc::new(ToolServices::new(
1297            workspace,
1298            event_store.clone(),
1299            api_client,
1300        ));
1301
1302        let ctx = StaticToolContext {
1303            tool_call_id: ToolCallId::new(),
1304            session_id: parent_session_id,
1305            cancellation_token: CancellationToken::new(),
1306            services,
1307        };
1308
1309        let _ = resume_agent_session(session_id, "trigger".to_string(), &ctx)
1310            .await
1311            .unwrap();
1312
1313        let events = event_store.load_events(session_id).await.unwrap();
1314        let denied = events.iter().any(|(_, event)| match event {
1315            SessionEvent::ToolCallFailed { name, error, .. } => {
1316                name == "bash" && error.contains("denied by policy")
1317            }
1318            _ => false,
1319        });
1320
1321        assert!(denied, "expected denied ToolCallFailed event for bash");
1322    }
1323}