Skip to main content

steer_core/tools/static_tools/
dispatch_agent.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::agents::{
6    McpAccessPolicy, agent_spec, agent_specs, agent_specs_prompt, default_agent_spec_id,
7};
8use crate::app::domain::event::SessionEvent;
9use crate::app::domain::runtime::RuntimeService;
10use crate::app::domain::types::SessionId;
11use crate::app::validation::ValidatorRegistry;
12use crate::config::model::builtin::claude_sonnet_4_5 as default_model;
13use crate::runners::OneShotRunner;
14use crate::session::state::BackendConfig;
15use crate::tools::capability::Capabilities;
16use crate::tools::services::{SubAgentConfig, SubAgentError, ToolServices};
17use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
18use crate::tools::{BackendRegistry, ToolExecutor, ToolRegistry};
19use crate::workspace::{
20    CreateWorkspaceRequest, EnvironmentId, RepoRef, VcsKind, VcsStatus, Workspace,
21    WorkspaceCreateStrategy, WorkspaceRef, create_workspace_from_session_config,
22};
23use steer_tools::ToolSpec;
24use steer_tools::result::{AgentResult, AgentWorkspaceInfo, AgentWorkspaceRevision};
25use steer_tools::tools::dispatch_agent::{
26    DispatchAgentError, DispatchAgentParams, DispatchAgentTarget, DispatchAgentToolSpec,
27    WorkspaceTarget,
28};
29use steer_tools::tools::{GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
30use tracing::warn;
31
32use super::{
33    AstGrepTool, BashTool, EditTool, FetchTool, GlobTool, GrepTool, LsTool, MultiEditTool,
34    ReplaceTool, TodoReadTool, TodoWriteTool, ViewTool, workspace_manager_op_error,
35    workspace_op_error,
36};
37
38fn dispatch_agent_description() -> String {
39    let agent_specs = agent_specs_prompt();
40    let agent_specs_block = if agent_specs.is_empty() {
41        "No agent specs registered.".to_string()
42    } else {
43        agent_specs
44    };
45
46    format!(
47        r#"Launch a new agent to help with a focused task. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
48
49    When to use the Agent tool:
50    - If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
51
52    When NOT to use the Agent tool:
53    - If you want to read a specific file path, use the {} or {} tool instead of the Agent tool, to find the match more quickly
54    - If you are searching for a specific class definition like "class Foo", use the {} tool instead, to find the match more quickly
55    - If you are searching for code within a specific file or set of 2-3 files, use the {} tool instead, to find the match more quickly
56
57    Usage notes:
58    1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
59    2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
60    3. Each invocation returns a session_id. Pass it back via `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue the conversation with the same agent.
61    4. When `target.session` is `resume`, the session_id must refer to a child of the current session. The `agent` and `workspace` options are ignored and the existing session config is used.
62    5. The agent's outputs should generally be trusted
63    6. IMPORTANT: Only some agent specs include write tools. Use a build agent if the task requires editing files.
64    7. IMPORTANT: New workspaces are preserved (not auto-deleted). Clean them up manually if needed.
65    8. If the agent spec omits a model, the parent session's default model is used.
66
67Workspace options:
68- `workspace: {{ "location": "current" }}` to run in the current workspace
69- `workspace: {{ "location": "new", "name": "..." }}` to run in a fresh workspace (jj workspace or git worktree)
70- `location` is a logical workspace selector, not a filesystem path
71
72    Session options:
73    - `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue a prior dispatch_agent session
74
75    New session options:
76    - `target: {{ "session": "new", "workspace": {{ "location": "current" }} }}` to run in the current workspace
77    - `target: {{ "session": "new", "workspace": {{ "location": "new", "name": "..." }} }}` to run in a new workspace
78    - `target: {{ "session": "new", "workspace": {{ "location": "current" }}, "agent": "<id>" }}` selects an agent spec (defaults to "{default_agent}")
79
80    {agent_specs_block}"#,
81        VIEW_TOOL_NAME,
82        LS_TOOL_NAME,
83        GREP_TOOL_NAME,
84        GREP_TOOL_NAME,
85        default_agent = default_agent_spec_id(),
86        agent_specs_block = agent_specs_block
87    )
88}
89
90pub struct DispatchAgentTool;
91
92#[async_trait]
93impl StaticTool for DispatchAgentTool {
94    type Params = DispatchAgentParams;
95    type Output = AgentResult;
96    type Spec = DispatchAgentToolSpec;
97
98    const DESCRIPTION: &'static str = "Launch a sub-agent to search for files or code";
99    const REQUIRES_APPROVAL: bool = false;
100    const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;
101
102    fn schema() -> steer_tools::ToolSchema {
103        let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
104            s.inline_subschemas = true;
105        });
106        let schema_gen = settings.into_generator();
107        let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
108
109        steer_tools::ToolSchema {
110            name: Self::Spec::NAME.to_string(),
111            display_name: Self::Spec::DISPLAY_NAME.to_string(),
112            description: dispatch_agent_description(),
113            input_schema: input_schema.into(),
114        }
115    }
116
117    async fn execute(
118        &self,
119        params: Self::Params,
120        ctx: &StaticToolContext,
121    ) -> Result<Self::Output, StaticToolError<DispatchAgentError>> {
122        let DispatchAgentParams { prompt, target } = params;
123
124        let (workspace_target, agent) = match target {
125            DispatchAgentTarget::Resume { session_id } => {
126                let session_id = SessionId::parse(&session_id).ok_or_else(|| {
127                    StaticToolError::invalid_params(format!("Invalid session_id '{session_id}'"))
128                })?;
129                return resume_agent_session(session_id, prompt, ctx).await;
130            }
131            DispatchAgentTarget::New { workspace, agent } => (workspace, agent),
132        };
133
134        let spawner = ctx
135            .services
136            .agent_spawner()
137            .ok_or_else(|| StaticToolError::missing_capability("agent_spawner"))?;
138
139        let base_workspace = ctx.services.workspace.clone();
140        let base_path = base_workspace.working_directory().to_path_buf();
141
142        let mut workspace = base_workspace.clone();
143        let mut workspace_ref = None;
144        let mut workspace_id = None;
145        let mut workspace_name = None;
146        let mut repo_id = None;
147        let mut repo_ref = None;
148
149        if let Some(manager) = ctx.services.workspace_manager()
150            && let Ok(info) = manager.resolve_workspace(&base_path).await
151        {
152            workspace_id = Some(info.workspace_id);
153            workspace_name.clone_from(&info.name);
154            repo_id = Some(info.repo_id);
155            workspace_ref = Some(WorkspaceRef {
156                environment_id: info.environment_id,
157                workspace_id: info.workspace_id,
158                repo_id: info.repo_id,
159            });
160        }
161
162        if let Some(manager) = ctx.services.repo_manager() {
163            let repo_env_id = workspace_ref
164                .as_ref()
165                .map_or_else(EnvironmentId::local, |reference| reference.environment_id);
166            if let Ok(info) = manager.resolve_repo(repo_env_id, &base_path).await {
167                if repo_id.is_none() {
168                    repo_id = Some(info.repo_id);
169                }
170                repo_ref = Some(RepoRef {
171                    environment_id: info.environment_id,
172                    repo_id: info.repo_id,
173                    root_path: info.root_path,
174                    vcs_kind: info.vcs_kind,
175                });
176            }
177        }
178
179        let mut new_workspace = false;
180        let mut requested_workspace_name = None;
181
182        match &workspace_target {
183            WorkspaceTarget::Current => {}
184            WorkspaceTarget::New { name } => {
185                new_workspace = true;
186                requested_workspace_name = Some(name.clone());
187            }
188        }
189
190        let mut created_workspace_id = None;
191        let mut status_manager = None;
192
193        if new_workspace {
194            let manager = ctx
195                .services
196                .workspace_manager()
197                .ok_or_else(|| StaticToolError::missing_capability("workspace_manager"))?;
198            status_manager = Some(manager.clone());
199
200            let base_repo_id = repo_id.ok_or_else(|| {
201                StaticToolError::execution(DispatchAgentError::WorkspaceUnavailable {
202                    message:
203                        "Current path is not a supported workspace; cannot create new workspace"
204                            .to_string(),
205                })
206            })?;
207
208            let strategy = match repo_ref
209                .as_ref()
210                .and_then(|reference| reference.vcs_kind.as_ref())
211            {
212                Some(VcsKind::Git) => WorkspaceCreateStrategy::GitWorktree,
213                _ => WorkspaceCreateStrategy::JjWorkspace,
214            };
215
216            let create_request = CreateWorkspaceRequest {
217                repo_id: base_repo_id,
218                name: requested_workspace_name.clone(),
219                parent_workspace_id: workspace_id,
220                strategy,
221            };
222
223            let info = manager
224                .create_workspace(create_request)
225                .await
226                .map_err(|e| {
227                    StaticToolError::execution(DispatchAgentError::Workspace(
228                        workspace_manager_op_error(e),
229                    ))
230                })?;
231
232            workspace = manager
233                .open_workspace(info.workspace_id)
234                .await
235                .map_err(|e| {
236                    StaticToolError::execution(DispatchAgentError::Workspace(
237                        workspace_manager_op_error(e),
238                    ))
239                })?;
240
241            workspace_id = Some(info.workspace_id);
242            created_workspace_id = Some(info.workspace_id);
243            workspace_name.clone_from(&info.name);
244            workspace_ref = Some(WorkspaceRef {
245                environment_id: info.environment_id,
246                workspace_id: info.workspace_id,
247                repo_id: info.repo_id,
248            });
249
250            if let Some(repo_manager) = ctx.services.repo_manager()
251                && let Ok(info) = repo_manager
252                    .resolve_repo(info.environment_id, workspace.working_directory())
253                    .await
254            {
255                repo_ref = Some(RepoRef {
256                    environment_id: info.environment_id,
257                    repo_id: info.repo_id,
258                    root_path: info.root_path,
259                    vcs_kind: info.vcs_kind,
260                });
261            }
262        }
263
264        let env_info = workspace.environment().await.map_err(|e| {
265            StaticToolError::execution(DispatchAgentError::Workspace(workspace_op_error(e)))
266        })?;
267
268        let system_prompt = format!(
269            r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
270
271Notes:
2721. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2732. When relevant, share file names and code snippets relevant to the query
2743. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
275
276{}
277"#,
278            env_info.as_context()
279        );
280
281        let agent_id = agent
282            .as_deref()
283            .filter(|value| !value.trim().is_empty())
284            .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
285
286        let agent_spec = agent_spec(&agent_id).ok_or_else(|| {
287            let available = agent_specs()
288                .into_iter()
289                .map(|spec| spec.id)
290                .collect::<Vec<_>>()
291                .join(", ");
292            StaticToolError::invalid_params(format!(
293                "Unknown agent spec '{agent_id}'. Available: {available}"
294            ))
295        })?;
296
297        let parent_session_config = match ctx.services.event_store.load_events(ctx.session_id).await
298        {
299            Ok(events) => events.into_iter().find_map(|(_, event)| match event {
300                SessionEvent::SessionCreated { config, .. } => Some(*config),
301                _ => None,
302            }),
303            Err(err) => {
304                warn!(
305                    session_id = %ctx.session_id,
306                    "Failed to load parent session config for MCP servers: {err}"
307                );
308                None
309            }
310        };
311
312        let parent_mcp_backends = parent_session_config
313            .as_ref()
314            .map(|config| config.tool_config.backends.clone())
315            .unwrap_or_default();
316
317        let parent_model = parent_session_config
318            .as_ref()
319            .map_or_else(default_model, |config| config.default_model.clone());
320
321        let allow_mcp_tools = agent_spec.mcp_access.allow_mcp_tools();
322        let mcp_backends = match &agent_spec.mcp_access {
323            McpAccessPolicy::None => Vec::new(),
324            McpAccessPolicy::All => parent_mcp_backends,
325            McpAccessPolicy::Allowlist(servers) => parent_mcp_backends
326                .into_iter()
327                .filter(|backend| match backend {
328                    BackendConfig::Mcp { server_name, .. } => {
329                        servers.iter().any(|allowed| allowed == server_name)
330                    }
331                })
332                .collect(),
333        };
334
335        let config = SubAgentConfig {
336            parent_session_id: ctx.session_id,
337            prompt,
338            allowed_tools: agent_spec.tools.clone(),
339            model: agent_spec.model.clone().unwrap_or(parent_model),
340            system_context: Some(crate::app::SystemContext::new(system_prompt)),
341            workspace: Some(workspace),
342            workspace_ref,
343            workspace_id,
344            repo_ref,
345            workspace_name,
346            mcp_backends,
347            allow_mcp_tools,
348        };
349
350        let spawn_result = spawner.spawn(config, ctx.cancellation_token.clone()).await;
351
352        let mut workspace_info = None;
353
354        if let (Some(manager), Some(workspace_id)) = (status_manager, created_workspace_id) {
355            let revision = match manager.get_workspace_status(workspace_id).await {
356                Ok(status) => match status.vcs {
357                    Some(vcs) => match vcs.status {
358                        VcsStatus::Jj(jj_status) => {
359                            jj_status.working_copy.map(|wc| AgentWorkspaceRevision {
360                                vcs_kind: "jj".to_string(),
361                                revision_id: wc.commit_id,
362                                summary: wc.description,
363                                change_id: Some(wc.change_id),
364                            })
365                        }
366                        VcsStatus::Git(_) => None,
367                    },
368                    None => None,
369                },
370                Err(err) => {
371                    warn!(
372                        workspace_id = %workspace_id.as_uuid(),
373                        "Failed to get workspace status for sub-agent: {err}"
374                    );
375                    None
376                }
377            };
378
379            workspace_info = Some(AgentWorkspaceInfo {
380                workspace_id: Some(workspace_id.as_uuid().to_string()),
381                revision,
382            });
383        }
384
385        let result = spawn_result.map_err(|e| match e {
386            SubAgentError::Cancelled => StaticToolError::Cancelled,
387            other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
388                message: other.to_string(),
389            }),
390        })?;
391
392        Ok(AgentResult {
393            content: result.final_message.extract_text(),
394            session_id: Some(result.session_id.to_string()),
395            workspace: workspace_info,
396        })
397    }
398}
399
400fn build_runtime_tool_executor(
401    workspace: Arc<dyn Workspace>,
402    parent_services: &Arc<ToolServices>,
403) -> Arc<ToolExecutor> {
404    let mut services = ToolServices::new(
405        workspace.clone(),
406        parent_services.event_store.clone(),
407        parent_services.api_client.clone(),
408    );
409
410    if let Some(spawner) = parent_services.agent_spawner() {
411        services = services.with_agent_spawner(spawner.clone());
412    }
413    if let Some(caller) = parent_services.model_caller() {
414        services = services.with_model_caller(caller.clone());
415    }
416    if let Some(manager) = parent_services.workspace_manager() {
417        services = services.with_workspace_manager(manager.clone());
418    }
419    if let Some(manager) = parent_services.repo_manager() {
420        services = services.with_repo_manager(manager.clone());
421    }
422    if parent_services
423        .capabilities()
424        .contains(Capabilities::NETWORK)
425    {
426        services = services.with_network();
427    }
428
429    let mut registry = ToolRegistry::new();
430    registry.register_static(GrepTool);
431    registry.register_static(GlobTool);
432    registry.register_static(LsTool);
433    registry.register_static(ViewTool);
434    registry.register_static(BashTool);
435    registry.register_static(EditTool);
436    registry.register_static(MultiEditTool);
437    registry.register_static(ReplaceTool);
438    registry.register_static(AstGrepTool);
439    registry.register_static(TodoReadTool);
440    registry.register_static(TodoWriteTool);
441    registry.register_static(DispatchAgentTool);
442    registry.register_static(FetchTool);
443
444    Arc::new(
445        ToolExecutor::with_components(
446            Arc::new(BackendRegistry::new()),
447            Arc::new(ValidatorRegistry::new()),
448        )
449        .with_static_tools(Arc::new(registry), Arc::new(services)),
450    )
451}
452
453async fn resume_agent_session(
454    session_id: SessionId,
455    prompt: String,
456    ctx: &StaticToolContext,
457) -> Result<AgentResult, StaticToolError<DispatchAgentError>> {
458    let events = ctx
459        .services
460        .event_store
461        .load_events(session_id)
462        .await
463        .map_err(|e| {
464            StaticToolError::execution(DispatchAgentError::SpawnFailed {
465                message: format!("Failed to load session {session_id}: {e}"),
466            })
467        })?;
468
469    let session_config = events
470        .into_iter()
471        .find_map(|(_, event)| match event {
472            SessionEvent::SessionCreated { config, .. } => Some(*config),
473            _ => None,
474        })
475        .ok_or_else(|| {
476            StaticToolError::execution(DispatchAgentError::SpawnFailed {
477                message: format!("Session {session_id} is missing a SessionCreated event"),
478            })
479        })?;
480
481    if session_config.parent_session_id != Some(ctx.session_id) {
482        return Err(StaticToolError::invalid_params(format!(
483            "Session {session_id} is not a child of current session {}",
484            ctx.session_id
485        )));
486    }
487
488    let workspace = create_workspace_from_session_config(&session_config.workspace)
489        .await
490        .map_err(|e| {
491            StaticToolError::execution(DispatchAgentError::SpawnFailed {
492                message: format!("Failed to open workspace for session {session_id}: {e}"),
493            })
494        })?;
495
496    let tool_executor = build_runtime_tool_executor(workspace, &ctx.services);
497    let runtime = RuntimeService::spawn(
498        ctx.services.event_store.clone(),
499        ctx.services.api_client.clone(),
500        tool_executor,
501    );
502
503    let run_result = OneShotRunner::run_in_session_with_cancel(
504        &runtime.handle,
505        session_id,
506        prompt,
507        session_config.default_model.clone(),
508        ctx.cancellation_token.clone(),
509    )
510    .await;
511
512    runtime.shutdown().await;
513
514    let run_result = run_result.map_err(|e| match e {
515        crate::error::Error::Cancelled => StaticToolError::Cancelled,
516        other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
517            message: other.to_string(),
518        }),
519    })?;
520
521    Ok(AgentResult {
522        content: run_result.final_message.extract_text(),
523        session_id: Some(run_result.session_id.to_string()),
524        workspace: None,
525    })
526}
527
528#[cfg(test)]
529mod tests {
530    use super::*;
531    use crate::agents::{AgentSpec, AgentSpecError, McpAccessPolicy, register_agent_spec};
532    use crate::api::Client as ApiClient;
533    use crate::api::{ApiError, CompletionResponse, Provider};
534    use crate::app::conversation::{AssistantContent, Message, MessageData};
535    use crate::app::domain::session::EventStore;
536    use crate::app::domain::session::event_store::InMemoryEventStore;
537    use crate::app::domain::types::ToolCallId;
538    use crate::config::model::builtin;
539    use crate::model_registry::ModelRegistry;
540    use crate::session::state::{
541        ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicyOverrides,
542        ToolFilter, ToolVisibility, UnapprovedBehavior,
543    };
544    use crate::tools::McpTransport;
545    use crate::tools::services::{AgentSpawner, SubAgentError, SubAgentResult, ToolServices};
546    use async_trait::async_trait;
547    use std::collections::{HashMap, HashSet};
548    use std::sync::Mutex as StdMutex;
549    use tokio::time::{Duration, sleep};
550    use tokio_util::sync::CancellationToken;
551    use uuid::Uuid;
552
553    #[derive(Clone)]
554    struct StubProvider {
555        response: String,
556    }
557
558    impl StubProvider {
559        fn new(response: impl Into<String>) -> Self {
560            Self {
561                response: response.into(),
562            }
563        }
564    }
565
566    #[derive(Clone)]
567    struct CancelAwareProvider;
568
569    #[async_trait]
570    impl Provider for CancelAwareProvider {
571        fn name(&self) -> &'static str {
572            "cancel-aware"
573        }
574
575        async fn complete(
576            &self,
577            _model_id: &crate::config::model::ModelId,
578            _messages: Vec<Message>,
579            _system: Option<crate::app::SystemContext>,
580            _tools: Option<Vec<steer_tools::ToolSchema>>,
581            _call_options: Option<crate::config::model::ModelParameters>,
582            token: CancellationToken,
583        ) -> Result<CompletionResponse, ApiError> {
584            token.cancelled().await;
585            Err(ApiError::Cancelled {
586                provider: self.name().to_string(),
587            })
588        }
589    }
590
591    #[async_trait]
592    impl Provider for StubProvider {
593        fn name(&self) -> &'static str {
594            "stub"
595        }
596
597        async fn complete(
598            &self,
599            _model_id: &crate::config::model::ModelId,
600            _messages: Vec<Message>,
601            _system: Option<crate::app::SystemContext>,
602            _tools: Option<Vec<steer_tools::ToolSchema>>,
603            _call_options: Option<crate::config::model::ModelParameters>,
604            _token: CancellationToken,
605        ) -> Result<CompletionResponse, ApiError> {
606            Ok(CompletionResponse {
607                content: vec![AssistantContent::Text {
608                    text: self.response.clone(),
609                }],
610            })
611        }
612    }
613
614    #[derive(Clone)]
615    struct StubAgentSpawner {
616        session_id: SessionId,
617        response: String,
618    }
619
620    #[async_trait]
621    impl AgentSpawner for StubAgentSpawner {
622        async fn spawn(
623            &self,
624            _config: crate::tools::services::SubAgentConfig,
625            _cancel_token: CancellationToken,
626        ) -> Result<SubAgentResult, SubAgentError> {
627            let timestamp = Message::current_timestamp();
628            let message = Message {
629                timestamp,
630                id: Message::generate_id("assistant", timestamp),
631                parent_message_id: None,
632                data: MessageData::Assistant {
633                    content: vec![AssistantContent::Text {
634                        text: self.response.clone(),
635                    }],
636                },
637            };
638
639            Ok(SubAgentResult {
640                session_id: self.session_id,
641                final_message: message,
642            })
643        }
644    }
645
646    #[derive(Clone)]
647    struct CapturingAgentSpawner {
648        session_id: SessionId,
649        response: String,
650        captured: Arc<tokio::sync::Mutex<Option<crate::tools::services::SubAgentConfig>>>,
651    }
652
653    #[async_trait]
654    impl AgentSpawner for CapturingAgentSpawner {
655        async fn spawn(
656            &self,
657            config: crate::tools::services::SubAgentConfig,
658            _cancel_token: CancellationToken,
659        ) -> Result<SubAgentResult, SubAgentError> {
660            let mut guard = self.captured.lock().await;
661            *guard = Some(config);
662
663            let timestamp = Message::current_timestamp();
664            let message = Message {
665                timestamp,
666                id: Message::generate_id("assistant", timestamp),
667                parent_message_id: None,
668                data: MessageData::Assistant {
669                    content: vec![AssistantContent::Text {
670                        text: self.response.clone(),
671                    }],
672                },
673            };
674
675            Ok(SubAgentResult {
676                session_id: self.session_id,
677                final_message: message,
678            })
679        }
680    }
681
682    #[derive(Clone)]
683    struct ToolCallThenTextProvider {
684        tool_call: steer_tools::ToolCall,
685        final_text: String,
686        call_count: Arc<StdMutex<usize>>,
687    }
688
689    impl ToolCallThenTextProvider {
690        fn new(tool_call: steer_tools::ToolCall, final_text: impl Into<String>) -> Self {
691            Self {
692                tool_call,
693                final_text: final_text.into(),
694                call_count: Arc::new(StdMutex::new(0)),
695            }
696        }
697    }
698
699    #[async_trait]
700    impl Provider for ToolCallThenTextProvider {
701        fn name(&self) -> &'static str {
702            "stub-tool-call"
703        }
704
705        async fn complete(
706            &self,
707            _model_id: &crate::config::model::ModelId,
708            _messages: Vec<Message>,
709            _system: Option<crate::app::SystemContext>,
710            _tools: Option<Vec<steer_tools::ToolSchema>>,
711            _call_options: Option<crate::config::model::ModelParameters>,
712            _token: CancellationToken,
713        ) -> Result<CompletionResponse, ApiError> {
714            let mut count = self
715                .call_count
716                .lock()
717                .expect("tool call counter lock poisoned");
718            let response = if *count == 0 {
719                CompletionResponse {
720                    content: vec![AssistantContent::ToolCall {
721                        tool_call: self.tool_call.clone(),
722                        thought_signature: None,
723                    }],
724                }
725            } else {
726                CompletionResponse {
727                    content: vec![AssistantContent::Text {
728                        text: self.final_text.clone(),
729                    }],
730                }
731            };
732            *count += 1;
733            Ok(response)
734        }
735    }
736
737    #[tokio::test]
738    async fn resume_session_rejects_non_child() {
739        let event_store = Arc::new(InMemoryEventStore::new());
740        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
741        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
742        let api_client = Arc::new(ApiClient::new_with_deps(
743            crate::test_utils::test_llm_config_provider().unwrap(),
744            provider_registry,
745            model_registry,
746        ));
747        let workspace =
748            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
749                path: std::env::current_dir().unwrap(),
750            })
751            .await
752            .unwrap();
753
754        let parent_session_id = SessionId::new();
755        let session_id = SessionId::new();
756        let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
757        session_config.parent_session_id = Some(parent_session_id);
758
759        event_store.create_session(session_id).await.unwrap();
760        event_store
761            .append(
762                session_id,
763                &SessionEvent::SessionCreated {
764                    config: Box::new(session_config),
765                    metadata: std::collections::HashMap::new(),
766                    parent_session_id: Some(parent_session_id),
767                },
768            )
769            .await
770            .unwrap();
771
772        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
773
774        let ctx = StaticToolContext {
775            tool_call_id: ToolCallId::new(),
776            session_id: SessionId::new(),
777            cancellation_token: CancellationToken::new(),
778            services,
779        };
780
781        let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
782
783        assert!(matches!(result, Err(StaticToolError::InvalidParams(_))));
784    }
785
786    #[tokio::test]
787    async fn resume_session_accepts_child_and_returns_message() {
788        let event_store = Arc::new(InMemoryEventStore::new());
789        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
790        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
791        let api_client = Arc::new(ApiClient::new_with_deps(
792            crate::test_utils::test_llm_config_provider().unwrap(),
793            provider_registry,
794            model_registry,
795        ));
796        let model = builtin::claude_sonnet_4_5();
797        api_client.insert_test_provider(
798            model.provider.clone(),
799            Arc::new(StubProvider::new("stub-response")),
800        );
801        let workspace =
802            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
803                path: std::env::current_dir().unwrap(),
804            })
805            .await
806            .unwrap();
807
808        let parent_session_id = SessionId::new();
809        let session_id = SessionId::new();
810        let mut session_config = SessionConfig::read_only(model.clone());
811        session_config.parent_session_id = Some(parent_session_id);
812
813        event_store.create_session(session_id).await.unwrap();
814        event_store
815            .append(
816                session_id,
817                &SessionEvent::SessionCreated {
818                    config: Box::new(session_config),
819                    metadata: std::collections::HashMap::new(),
820                    parent_session_id: Some(parent_session_id),
821                },
822            )
823            .await
824            .unwrap();
825
826        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
827
828        let ctx = StaticToolContext {
829            tool_call_id: ToolCallId::new(),
830            session_id: parent_session_id,
831            cancellation_token: CancellationToken::new(),
832            services,
833        };
834
835        let result = resume_agent_session(session_id, "ping".to_string(), &ctx)
836            .await
837            .unwrap();
838
839        assert!(result.content.contains("stub-response"));
840        assert_eq!(
841            result.session_id.as_deref(),
842            Some(session_id.to_string().as_str())
843        );
844    }
845
846    #[tokio::test]
847    async fn resume_session_honors_cancellation() {
848        let event_store = Arc::new(InMemoryEventStore::new());
849        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
850        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
851        let api_client = Arc::new(ApiClient::new_with_deps(
852            crate::test_utils::test_llm_config_provider().unwrap(),
853            provider_registry,
854            model_registry,
855        ));
856        let model = builtin::claude_sonnet_4_5();
857        api_client.insert_test_provider(model.provider.clone(), Arc::new(CancelAwareProvider));
858        let workspace =
859            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
860                path: std::env::current_dir().unwrap(),
861            })
862            .await
863            .unwrap();
864
865        let parent_session_id = SessionId::new();
866        let session_id = SessionId::new();
867        let mut session_config = SessionConfig::read_only(model);
868        session_config.parent_session_id = Some(parent_session_id);
869
870        event_store.create_session(session_id).await.unwrap();
871        event_store
872            .append(
873                session_id,
874                &SessionEvent::SessionCreated {
875                    config: Box::new(session_config),
876                    metadata: std::collections::HashMap::new(),
877                    parent_session_id: Some(parent_session_id),
878                },
879            )
880            .await
881            .unwrap();
882
883        let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
884
885        let cancel_token = CancellationToken::new();
886        let ctx = StaticToolContext {
887            tool_call_id: ToolCallId::new(),
888            session_id: parent_session_id,
889            cancellation_token: cancel_token.clone(),
890            services,
891        };
892
893        let cancel_task = tokio::spawn(async move {
894            sleep(Duration::from_millis(10)).await;
895            cancel_token.cancel();
896        });
897
898        let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
899        let _ = cancel_task.await;
900
901        assert!(matches!(result, Err(StaticToolError::Cancelled)));
902    }
903
904    #[tokio::test]
905    async fn dispatch_agent_returns_session_id() {
906        let event_store = Arc::new(InMemoryEventStore::new());
907        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
908        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
909        let api_client = Arc::new(ApiClient::new_with_deps(
910            crate::test_utils::test_llm_config_provider().unwrap(),
911            provider_registry,
912            model_registry,
913        ));
914        let workspace =
915            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
916                path: std::env::current_dir().unwrap(),
917            })
918            .await
919            .unwrap();
920
921        let session_id = SessionId::new();
922        let spawner = StubAgentSpawner {
923            session_id,
924            response: "done".to_string(),
925        };
926
927        let services = Arc::new(
928            ToolServices::new(workspace, event_store, api_client)
929                .with_agent_spawner(Arc::new(spawner)),
930        );
931
932        let ctx = StaticToolContext {
933            tool_call_id: ToolCallId::new(),
934            session_id: SessionId::new(),
935            cancellation_token: CancellationToken::new(),
936            services,
937        };
938
939        let params = DispatchAgentParams {
940            prompt: "hello".to_string(),
941            target: DispatchAgentTarget::New {
942                workspace: WorkspaceTarget::Current,
943                agent: None,
944            },
945        };
946
947        let result = DispatchAgentTool.execute(params, &ctx).await.unwrap();
948        assert_eq!(
949            result.session_id.as_deref(),
950            Some(session_id.to_string().as_str())
951        );
952    }
953
954    #[tokio::test]
955    async fn dispatch_agent_filters_mcp_backends_by_allowlist() {
956        let event_store = Arc::new(InMemoryEventStore::new());
957        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
958        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
959        let api_client = Arc::new(ApiClient::new_with_deps(
960            crate::test_utils::test_llm_config_provider().unwrap(),
961            provider_registry,
962            model_registry,
963        ));
964        let workspace =
965            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
966                path: std::env::current_dir().unwrap(),
967            })
968            .await
969            .unwrap();
970
971        let parent_session_id = SessionId::new();
972        let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
973        session_config
974            .tool_config
975            .backends
976            .push(BackendConfig::Mcp {
977                server_name: "allowed-server".to_string(),
978                transport: McpTransport::Tcp {
979                    host: "127.0.0.1".to_string(),
980                    port: 1111,
981                },
982                tool_filter: ToolFilter::All,
983            });
984        session_config
985            .tool_config
986            .backends
987            .push(BackendConfig::Mcp {
988                server_name: "blocked-server".to_string(),
989                transport: McpTransport::Tcp {
990                    host: "127.0.0.1".to_string(),
991                    port: 2222,
992                },
993                tool_filter: ToolFilter::All,
994            });
995
996        event_store.create_session(parent_session_id).await.unwrap();
997        event_store
998            .append(
999                parent_session_id,
1000                &SessionEvent::SessionCreated {
1001                    config: Box::new(session_config),
1002                    metadata: HashMap::new(),
1003                    parent_session_id: None,
1004                },
1005            )
1006            .await
1007            .unwrap();
1008
1009        let agent_id = format!("allowlist_{}", Uuid::new_v4());
1010        let spec = AgentSpec {
1011            id: agent_id.clone(),
1012            name: "allowlist test".to_string(),
1013            description: "allowlist test".to_string(),
1014            tools: vec![VIEW_TOOL_NAME.to_string()],
1015            mcp_access: McpAccessPolicy::Allowlist(vec!["allowed-server".to_string()]),
1016            model: None,
1017        };
1018        match register_agent_spec(spec) {
1019            Ok(()) => {}
1020            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1021            Err(AgentSpecError::RegistryPoisoned) => {}
1022        }
1023
1024        let captured = Arc::new(tokio::sync::Mutex::new(None));
1025        let spawner = CapturingAgentSpawner {
1026            session_id: SessionId::new(),
1027            response: "ok".to_string(),
1028            captured: captured.clone(),
1029        };
1030
1031        let services = Arc::new(
1032            ToolServices::new(workspace, event_store, api_client)
1033                .with_agent_spawner(Arc::new(spawner)),
1034        );
1035
1036        let ctx = StaticToolContext {
1037            tool_call_id: ToolCallId::new(),
1038            session_id: parent_session_id,
1039            cancellation_token: CancellationToken::new(),
1040            services,
1041        };
1042
1043        let params = DispatchAgentParams {
1044            prompt: "test".to_string(),
1045            target: DispatchAgentTarget::New {
1046                workspace: WorkspaceTarget::Current,
1047                agent: Some(agent_id),
1048            },
1049        };
1050
1051        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1052        let captured = captured.lock().await.clone().expect("no config captured");
1053
1054        let backend_names: Vec<String> = captured
1055            .mcp_backends
1056            .iter()
1057            .map(|backend| match backend {
1058                BackendConfig::Mcp { server_name, .. } => server_name.clone(),
1059            })
1060            .collect();
1061
1062        assert_eq!(backend_names, vec!["allowed-server".to_string()]);
1063        assert!(captured.allow_mcp_tools);
1064    }
1065
1066    #[tokio::test]
1067    async fn dispatch_agent_uses_parent_model_when_spec_missing_model() {
1068        let event_store = Arc::new(InMemoryEventStore::new());
1069        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1070        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1071        let api_client = Arc::new(ApiClient::new_with_deps(
1072            crate::test_utils::test_llm_config_provider().unwrap(),
1073            provider_registry,
1074            model_registry,
1075        ));
1076        let workspace =
1077            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1078                path: std::env::current_dir().unwrap(),
1079            })
1080            .await
1081            .unwrap();
1082
1083        let parent_session_id = SessionId::new();
1084        let parent_model = builtin::claude_sonnet_4_5();
1085        let session_config = SessionConfig::read_only(parent_model.clone());
1086
1087        event_store.create_session(parent_session_id).await.unwrap();
1088        event_store
1089            .append(
1090                parent_session_id,
1091                &SessionEvent::SessionCreated {
1092                    config: Box::new(session_config),
1093                    metadata: HashMap::new(),
1094                    parent_session_id: None,
1095                },
1096            )
1097            .await
1098            .unwrap();
1099
1100        let agent_id = format!("inherit_model_{}", Uuid::new_v4());
1101        let spec = AgentSpec {
1102            id: agent_id.clone(),
1103            name: "inherit model test".to_string(),
1104            description: "inherit model test".to_string(),
1105            tools: vec![VIEW_TOOL_NAME.to_string()],
1106            mcp_access: McpAccessPolicy::None,
1107            model: None,
1108        };
1109        match register_agent_spec(spec) {
1110            Ok(()) => {}
1111            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1112            Err(AgentSpecError::RegistryPoisoned) => {}
1113        }
1114
1115        let captured = Arc::new(tokio::sync::Mutex::new(None));
1116        let spawner = CapturingAgentSpawner {
1117            session_id: SessionId::new(),
1118            response: "ok".to_string(),
1119            captured: captured.clone(),
1120        };
1121
1122        let services = Arc::new(
1123            ToolServices::new(workspace, event_store, api_client)
1124                .with_agent_spawner(Arc::new(spawner)),
1125        );
1126
1127        let ctx = StaticToolContext {
1128            tool_call_id: ToolCallId::new(),
1129            session_id: parent_session_id,
1130            cancellation_token: CancellationToken::new(),
1131            services,
1132        };
1133
1134        let params = DispatchAgentParams {
1135            prompt: "test".to_string(),
1136            target: DispatchAgentTarget::New {
1137                workspace: WorkspaceTarget::Current,
1138                agent: Some(agent_id),
1139            },
1140        };
1141
1142        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1143        let captured = captured.lock().await.clone().expect("no config captured");
1144
1145        assert_eq!(captured.model, parent_model);
1146    }
1147
1148    #[tokio::test]
1149    async fn dispatch_agent_uses_spec_model_when_set() {
1150        let event_store = Arc::new(InMemoryEventStore::new());
1151        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1152        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1153        let api_client = Arc::new(ApiClient::new_with_deps(
1154            crate::test_utils::test_llm_config_provider().unwrap(),
1155            provider_registry,
1156            model_registry,
1157        ));
1158        let workspace =
1159            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1160                path: std::env::current_dir().unwrap(),
1161            })
1162            .await
1163            .unwrap();
1164
1165        let parent_session_id = SessionId::new();
1166        let parent_model = builtin::claude_sonnet_4_5();
1167        let session_config = SessionConfig::read_only(parent_model);
1168
1169        event_store.create_session(parent_session_id).await.unwrap();
1170        event_store
1171            .append(
1172                parent_session_id,
1173                &SessionEvent::SessionCreated {
1174                    config: Box::new(session_config),
1175                    metadata: HashMap::new(),
1176                    parent_session_id: None,
1177                },
1178            )
1179            .await
1180            .unwrap();
1181
1182        let spec_model = builtin::claude_haiku_4_5();
1183        let agent_id = format!("spec_model_{}", Uuid::new_v4());
1184        let spec = AgentSpec {
1185            id: agent_id.clone(),
1186            name: "spec model test".to_string(),
1187            description: "spec model test".to_string(),
1188            tools: vec![VIEW_TOOL_NAME.to_string()],
1189            mcp_access: McpAccessPolicy::None,
1190            model: Some(spec_model.clone()),
1191        };
1192        match register_agent_spec(spec) {
1193            Ok(()) => {}
1194            Err(AgentSpecError::AlreadyRegistered(_)) => {}
1195            Err(AgentSpecError::RegistryPoisoned) => {}
1196        }
1197
1198        let captured = Arc::new(tokio::sync::Mutex::new(None));
1199        let spawner = CapturingAgentSpawner {
1200            session_id: SessionId::new(),
1201            response: "ok".to_string(),
1202            captured: captured.clone(),
1203        };
1204
1205        let services = Arc::new(
1206            ToolServices::new(workspace, event_store, api_client)
1207                .with_agent_spawner(Arc::new(spawner)),
1208        );
1209
1210        let ctx = StaticToolContext {
1211            tool_call_id: ToolCallId::new(),
1212            session_id: parent_session_id,
1213            cancellation_token: CancellationToken::new(),
1214            services,
1215        };
1216
1217        let params = DispatchAgentParams {
1218            prompt: "test".to_string(),
1219            target: DispatchAgentTarget::New {
1220                workspace: WorkspaceTarget::Current,
1221                agent: Some(agent_id),
1222            },
1223        };
1224
1225        let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1226        let captured = captured.lock().await.clone().expect("no config captured");
1227
1228        assert_eq!(captured.model, spec_model);
1229    }
1230
1231    #[tokio::test]
1232    async fn resume_session_denies_disallowed_tools() {
1233        let event_store = Arc::new(InMemoryEventStore::new());
1234        let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1235        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1236        let api_client = Arc::new(ApiClient::new_with_deps(
1237            crate::test_utils::test_llm_config_provider().unwrap(),
1238            provider_registry,
1239            model_registry,
1240        ));
1241        let workspace =
1242            crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1243                path: std::env::current_dir().unwrap(),
1244            })
1245            .await
1246            .unwrap();
1247
1248        let parent_session_id = SessionId::new();
1249        let session_id = SessionId::new();
1250        let model = builtin::claude_sonnet_4_5();
1251
1252        let tool_call = steer_tools::ToolCall {
1253            name: "bash".to_string(),
1254            parameters: serde_json::json!({ "command": "echo denied" }),
1255            id: "tool_denied".to_string(),
1256        };
1257        api_client.insert_test_provider(
1258            model.provider.clone(),
1259            Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
1260        );
1261
1262        let mut session_config = SessionConfig::read_only(model);
1263        session_config.parent_session_id = Some(parent_session_id);
1264        session_config.policy_overrides = SessionPolicyOverrides {
1265            default_model: None,
1266            tool_visibility: Some(ToolVisibility::Whitelist(HashSet::from([
1267                VIEW_TOOL_NAME.to_string()
1268            ]))),
1269            approval_policy: ToolApprovalPolicyOverrides {
1270                default_behavior: Some(UnapprovedBehavior::Deny),
1271                preapproved: ApprovalRulesOverrides {
1272                    tools: HashSet::from([VIEW_TOOL_NAME.to_string()]),
1273                    per_tool: HashMap::new(),
1274                },
1275            },
1276        };
1277
1278        event_store.create_session(session_id).await.unwrap();
1279        event_store
1280            .append(
1281                session_id,
1282                &SessionEvent::SessionCreated {
1283                    config: Box::new(session_config),
1284                    metadata: HashMap::new(),
1285                    parent_session_id: Some(parent_session_id),
1286                },
1287            )
1288            .await
1289            .unwrap();
1290
1291        let services = Arc::new(ToolServices::new(
1292            workspace,
1293            event_store.clone(),
1294            api_client,
1295        ));
1296
1297        let ctx = StaticToolContext {
1298            tool_call_id: ToolCallId::new(),
1299            session_id: parent_session_id,
1300            cancellation_token: CancellationToken::new(),
1301            services,
1302        };
1303
1304        let _ = resume_agent_session(session_id, "trigger".to_string(), &ctx)
1305            .await
1306            .unwrap();
1307
1308        let events = event_store.load_events(session_id).await.unwrap();
1309        let denied = events.iter().any(|(_, event)| match event {
1310            SessionEvent::ToolCallFailed { name, error, .. } => {
1311                name == "bash" && error.contains("denied by policy")
1312            }
1313            _ => false,
1314        });
1315
1316        assert!(denied, "expected denied ToolCallFailed event for bash");
1317    }
1318}