1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::agents::{
6 McpAccessPolicy, agent_spec, agent_specs, agent_specs_prompt, default_agent_spec_id,
7};
8use crate::app::domain::event::SessionEvent;
9use crate::app::domain::runtime::RuntimeService;
10use crate::app::domain::types::SessionId;
11use crate::app::validation::ValidatorRegistry;
12use crate::config::model::builtin::claude_sonnet_4_5 as default_model;
13use crate::runners::OneShotRunner;
14use crate::session::state::BackendConfig;
15use crate::tools::capability::Capabilities;
16use crate::tools::services::{SubAgentConfig, SubAgentError, ToolServices};
17use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
18use crate::tools::{BackendRegistry, ToolExecutor, ToolRegistry};
19use crate::workspace::{
20 CreateWorkspaceRequest, EnvironmentId, RepoRef, VcsKind, VcsStatus, Workspace,
21 WorkspaceCreateStrategy, WorkspaceRef, create_workspace_from_session_config,
22};
23use steer_tools::ToolSpec;
24use steer_tools::result::{AgentResult, AgentWorkspaceInfo, AgentWorkspaceRevision};
25use steer_tools::tools::dispatch_agent::{
26 DispatchAgentError, DispatchAgentParams, DispatchAgentTarget, DispatchAgentToolSpec,
27 WorkspaceTarget,
28};
29use steer_tools::tools::{GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
30use tracing::warn;
31
32use super::{
33 AstGrepTool, BashTool, EditTool, FetchTool, GlobTool, GrepTool, LsTool, MultiEditTool,
34 ReplaceTool, TodoReadTool, TodoWriteTool, ViewTool, workspace_manager_op_error,
35 workspace_op_error,
36};
37
38fn dispatch_agent_description() -> String {
39 let agent_specs = agent_specs_prompt();
40 let agent_specs_block = if agent_specs.is_empty() {
41 "No agent specs registered.".to_string()
42 } else {
43 agent_specs
44 };
45
46 format!(
47 r#"Launch a new agent to help with a focused task. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
48
49 When to use the Agent tool:
50 - If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
51
52 When NOT to use the Agent tool:
53 - If you want to read a specific file path, use the {} or {} tool instead of the Agent tool, to find the match more quickly
54 - If you are searching for a specific class definition like "class Foo", use the {} tool instead, to find the match more quickly
55 - If you are searching for code within a specific file or set of 2-3 files, use the {} tool instead, to find the match more quickly
56
57 Usage notes:
58 1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
59 2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
60 3. Each invocation returns a session_id. Pass it back via `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue the conversation with the same agent.
61 4. When `target.session` is `resume`, the session_id must refer to a child of the current session. The `agent` and `workspace` options are ignored and the existing session config is used.
62 5. The agent's outputs should generally be trusted
63 6. IMPORTANT: Only some agent specs include write tools. Use a build agent if the task requires editing files.
64 7. IMPORTANT: New workspaces are preserved (not auto-deleted). Clean them up manually if needed.
65 8. If the agent spec omits a model, the parent session's default model is used.
66
67Workspace options:
68- `workspace: {{ "location": "current" }}` to run in the current workspace
69- `workspace: {{ "location": "new", "name": "..." }}` to run in a fresh workspace (jj workspace or git worktree)
70- `location` is a logical workspace selector, not a filesystem path
71
72 Session options:
73 - `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue a prior dispatch_agent session
74
75 New session options:
76 - `target: {{ "session": "new", "workspace": {{ "location": "current" }} }}` to run in the current workspace
77 - `target: {{ "session": "new", "workspace": {{ "location": "new", "name": "..." }} }}` to run in a new workspace
78 - `target: {{ "session": "new", "workspace": {{ "location": "current" }}, "agent": "<id>" }}` selects an agent spec (defaults to "{default_agent}")
79
80 {agent_specs_block}"#,
81 VIEW_TOOL_NAME,
82 LS_TOOL_NAME,
83 GREP_TOOL_NAME,
84 GREP_TOOL_NAME,
85 default_agent = default_agent_spec_id(),
86 agent_specs_block = agent_specs_block
87 )
88}
89
90pub struct DispatchAgentTool;
91
92#[async_trait]
93impl StaticTool for DispatchAgentTool {
94 type Params = DispatchAgentParams;
95 type Output = AgentResult;
96 type Spec = DispatchAgentToolSpec;
97
98 const DESCRIPTION: &'static str = "Launch a sub-agent to search for files or code";
99 const REQUIRES_APPROVAL: bool = false;
100 const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;
101
102 fn schema() -> steer_tools::ToolSchema {
103 let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
104 s.inline_subschemas = true;
105 });
106 let schema_gen = settings.into_generator();
107 let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
108
109 steer_tools::ToolSchema {
110 name: Self::Spec::NAME.to_string(),
111 display_name: Self::Spec::DISPLAY_NAME.to_string(),
112 description: dispatch_agent_description(),
113 input_schema: input_schema.into(),
114 }
115 }
116
117 async fn execute(
118 &self,
119 params: Self::Params,
120 ctx: &StaticToolContext,
121 ) -> Result<Self::Output, StaticToolError<DispatchAgentError>> {
122 let DispatchAgentParams { prompt, target } = params;
123
124 let (workspace_target, agent) = match target {
125 DispatchAgentTarget::Resume { session_id } => {
126 let session_id = SessionId::parse(&session_id).ok_or_else(|| {
127 StaticToolError::invalid_params(format!("Invalid session_id '{session_id}'"))
128 })?;
129 return resume_agent_session(session_id, prompt, ctx).await;
130 }
131 DispatchAgentTarget::New { workspace, agent } => (workspace, agent),
132 };
133
134 let spawner = ctx
135 .services
136 .agent_spawner()
137 .ok_or_else(|| StaticToolError::missing_capability("agent_spawner"))?;
138
139 let base_workspace = ctx.services.workspace.clone();
140 let base_path = base_workspace.working_directory().to_path_buf();
141
142 let mut workspace = base_workspace.clone();
143 let mut workspace_ref = None;
144 let mut workspace_id = None;
145 let mut workspace_name = None;
146 let mut repo_id = None;
147 let mut repo_ref = None;
148
149 if let Some(manager) = ctx.services.workspace_manager()
150 && let Ok(info) = manager.resolve_workspace(&base_path).await
151 {
152 workspace_id = Some(info.workspace_id);
153 workspace_name.clone_from(&info.name);
154 repo_id = Some(info.repo_id);
155 workspace_ref = Some(WorkspaceRef {
156 environment_id: info.environment_id,
157 workspace_id: info.workspace_id,
158 repo_id: info.repo_id,
159 });
160 }
161
162 if let Some(manager) = ctx.services.repo_manager() {
163 let repo_env_id = workspace_ref
164 .as_ref()
165 .map_or_else(EnvironmentId::local, |reference| reference.environment_id);
166 if let Ok(info) = manager.resolve_repo(repo_env_id, &base_path).await {
167 if repo_id.is_none() {
168 repo_id = Some(info.repo_id);
169 }
170 repo_ref = Some(RepoRef {
171 environment_id: info.environment_id,
172 repo_id: info.repo_id,
173 root_path: info.root_path,
174 vcs_kind: info.vcs_kind,
175 });
176 }
177 }
178
179 let mut new_workspace = false;
180 let mut requested_workspace_name = None;
181
182 match &workspace_target {
183 WorkspaceTarget::Current => {}
184 WorkspaceTarget::New { name } => {
185 new_workspace = true;
186 requested_workspace_name = Some(name.clone());
187 }
188 }
189
190 let mut created_workspace_id = None;
191 let mut status_manager = None;
192
193 if new_workspace {
194 let manager = ctx
195 .services
196 .workspace_manager()
197 .ok_or_else(|| StaticToolError::missing_capability("workspace_manager"))?;
198 status_manager = Some(manager.clone());
199
200 let base_repo_id = repo_id.ok_or_else(|| {
201 StaticToolError::execution(DispatchAgentError::WorkspaceUnavailable {
202 message:
203 "Current path is not a supported workspace; cannot create new workspace"
204 .to_string(),
205 })
206 })?;
207
208 let strategy = match repo_ref
209 .as_ref()
210 .and_then(|reference| reference.vcs_kind.as_ref())
211 {
212 Some(VcsKind::Git) => WorkspaceCreateStrategy::GitWorktree,
213 _ => WorkspaceCreateStrategy::JjWorkspace,
214 };
215
216 let create_request = CreateWorkspaceRequest {
217 repo_id: base_repo_id,
218 name: requested_workspace_name.clone(),
219 parent_workspace_id: workspace_id,
220 strategy,
221 };
222
223 let info = manager
224 .create_workspace(create_request)
225 .await
226 .map_err(|e| {
227 StaticToolError::execution(DispatchAgentError::Workspace(
228 workspace_manager_op_error(e),
229 ))
230 })?;
231
232 workspace = manager
233 .open_workspace(info.workspace_id)
234 .await
235 .map_err(|e| {
236 StaticToolError::execution(DispatchAgentError::Workspace(
237 workspace_manager_op_error(e),
238 ))
239 })?;
240
241 workspace_id = Some(info.workspace_id);
242 created_workspace_id = Some(info.workspace_id);
243 workspace_name.clone_from(&info.name);
244 workspace_ref = Some(WorkspaceRef {
245 environment_id: info.environment_id,
246 workspace_id: info.workspace_id,
247 repo_id: info.repo_id,
248 });
249
250 if let Some(repo_manager) = ctx.services.repo_manager()
251 && let Ok(info) = repo_manager
252 .resolve_repo(info.environment_id, workspace.working_directory())
253 .await
254 {
255 repo_ref = Some(RepoRef {
256 environment_id: info.environment_id,
257 repo_id: info.repo_id,
258 root_path: info.root_path,
259 vcs_kind: info.vcs_kind,
260 });
261 }
262 }
263
264 let env_info = workspace.environment().await.map_err(|e| {
265 StaticToolError::execution(DispatchAgentError::Workspace(workspace_op_error(e)))
266 })?;
267
268 let system_prompt = format!(
269 r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
270
271Notes:
2721. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2732. When relevant, share file names and code snippets relevant to the query
2743. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
275
276{}
277"#,
278 env_info.as_context()
279 );
280
281 let agent_id = agent
282 .as_deref()
283 .filter(|value| !value.trim().is_empty())
284 .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
285
286 let agent_spec = agent_spec(&agent_id).ok_or_else(|| {
287 let available = agent_specs()
288 .into_iter()
289 .map(|spec| spec.id)
290 .collect::<Vec<_>>()
291 .join(", ");
292 StaticToolError::invalid_params(format!(
293 "Unknown agent spec '{agent_id}'. Available: {available}"
294 ))
295 })?;
296
297 let parent_session_config = match ctx.services.event_store.load_events(ctx.session_id).await
298 {
299 Ok(events) => events.into_iter().find_map(|(_, event)| match event {
300 SessionEvent::SessionCreated { config, .. } => Some(*config),
301 _ => None,
302 }),
303 Err(err) => {
304 warn!(
305 session_id = %ctx.session_id,
306 "Failed to load parent session config for MCP servers: {err}"
307 );
308 None
309 }
310 };
311
312 let parent_mcp_backends = parent_session_config
313 .as_ref()
314 .map(|config| config.tool_config.backends.clone())
315 .unwrap_or_default();
316
317 let parent_model = parent_session_config
318 .as_ref()
319 .map_or_else(default_model, |config| config.default_model.clone());
320
321 let allow_mcp_tools = agent_spec.mcp_access.allow_mcp_tools();
322 let mcp_backends = match &agent_spec.mcp_access {
323 McpAccessPolicy::None => Vec::new(),
324 McpAccessPolicy::All => parent_mcp_backends,
325 McpAccessPolicy::Allowlist(servers) => parent_mcp_backends
326 .into_iter()
327 .filter(|backend| match backend {
328 BackendConfig::Mcp { server_name, .. } => {
329 servers.iter().any(|allowed| allowed == server_name)
330 }
331 })
332 .collect(),
333 };
334
335 let config = SubAgentConfig {
336 parent_session_id: ctx.session_id,
337 prompt,
338 allowed_tools: agent_spec.tools.clone(),
339 model: agent_spec.model.clone().unwrap_or(parent_model),
340 system_context: Some(crate::app::SystemContext::new(system_prompt)),
341 workspace: Some(workspace),
342 workspace_ref,
343 workspace_id,
344 repo_ref,
345 workspace_name,
346 mcp_backends,
347 allow_mcp_tools,
348 };
349
350 let spawn_result = spawner.spawn(config, ctx.cancellation_token.clone()).await;
351
352 let mut workspace_info = None;
353
354 if let (Some(manager), Some(workspace_id)) = (status_manager, created_workspace_id) {
355 let revision = match manager.get_workspace_status(workspace_id).await {
356 Ok(status) => match status.vcs {
357 Some(vcs) => match vcs.status {
358 VcsStatus::Jj(jj_status) => {
359 jj_status.working_copy.map(|wc| AgentWorkspaceRevision {
360 vcs_kind: "jj".to_string(),
361 revision_id: wc.commit_id,
362 summary: wc.description,
363 change_id: Some(wc.change_id),
364 })
365 }
366 VcsStatus::Git(_) => None,
367 },
368 None => None,
369 },
370 Err(err) => {
371 warn!(
372 workspace_id = %workspace_id.as_uuid(),
373 "Failed to get workspace status for sub-agent: {err}"
374 );
375 None
376 }
377 };
378
379 workspace_info = Some(AgentWorkspaceInfo {
380 workspace_id: Some(workspace_id.as_uuid().to_string()),
381 revision,
382 });
383 }
384
385 let result = spawn_result.map_err(|e| match e {
386 SubAgentError::Cancelled => StaticToolError::Cancelled,
387 other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
388 message: other.to_string(),
389 }),
390 })?;
391
392 Ok(AgentResult {
393 content: result.final_message.extract_text(),
394 session_id: Some(result.session_id.to_string()),
395 workspace: workspace_info,
396 })
397 }
398}
399
400fn build_runtime_tool_executor(
401 workspace: Arc<dyn Workspace>,
402 parent_services: &Arc<ToolServices>,
403) -> Arc<ToolExecutor> {
404 let mut services = ToolServices::new(
405 workspace.clone(),
406 parent_services.event_store.clone(),
407 parent_services.api_client.clone(),
408 );
409
410 if let Some(spawner) = parent_services.agent_spawner() {
411 services = services.with_agent_spawner(spawner.clone());
412 }
413 if let Some(caller) = parent_services.model_caller() {
414 services = services.with_model_caller(caller.clone());
415 }
416 if let Some(manager) = parent_services.workspace_manager() {
417 services = services.with_workspace_manager(manager.clone());
418 }
419 if let Some(manager) = parent_services.repo_manager() {
420 services = services.with_repo_manager(manager.clone());
421 }
422 if parent_services
423 .capabilities()
424 .contains(Capabilities::NETWORK)
425 {
426 services = services.with_network();
427 }
428
429 let mut registry = ToolRegistry::new();
430 registry.register_static(GrepTool);
431 registry.register_static(GlobTool);
432 registry.register_static(LsTool);
433 registry.register_static(ViewTool);
434 registry.register_static(BashTool);
435 registry.register_static(EditTool);
436 registry.register_static(MultiEditTool);
437 registry.register_static(ReplaceTool);
438 registry.register_static(AstGrepTool);
439 registry.register_static(TodoReadTool);
440 registry.register_static(TodoWriteTool);
441 registry.register_static(DispatchAgentTool);
442 registry.register_static(FetchTool);
443
444 Arc::new(
445 ToolExecutor::with_components(
446 Arc::new(BackendRegistry::new()),
447 Arc::new(ValidatorRegistry::new()),
448 )
449 .with_static_tools(Arc::new(registry), Arc::new(services)),
450 )
451}
452
453async fn resume_agent_session(
454 session_id: SessionId,
455 prompt: String,
456 ctx: &StaticToolContext,
457) -> Result<AgentResult, StaticToolError<DispatchAgentError>> {
458 let events = ctx
459 .services
460 .event_store
461 .load_events(session_id)
462 .await
463 .map_err(|e| {
464 StaticToolError::execution(DispatchAgentError::SpawnFailed {
465 message: format!("Failed to load session {session_id}: {e}"),
466 })
467 })?;
468
469 let session_config = events
470 .into_iter()
471 .find_map(|(_, event)| match event {
472 SessionEvent::SessionCreated { config, .. } => Some(*config),
473 _ => None,
474 })
475 .ok_or_else(|| {
476 StaticToolError::execution(DispatchAgentError::SpawnFailed {
477 message: format!("Session {session_id} is missing a SessionCreated event"),
478 })
479 })?;
480
481 if session_config.parent_session_id != Some(ctx.session_id) {
482 return Err(StaticToolError::invalid_params(format!(
483 "Session {session_id} is not a child of current session {}",
484 ctx.session_id
485 )));
486 }
487
488 let workspace = create_workspace_from_session_config(&session_config.workspace)
489 .await
490 .map_err(|e| {
491 StaticToolError::execution(DispatchAgentError::SpawnFailed {
492 message: format!("Failed to open workspace for session {session_id}: {e}"),
493 })
494 })?;
495
496 let tool_executor = build_runtime_tool_executor(workspace, &ctx.services);
497 let runtime = RuntimeService::spawn(
498 ctx.services.event_store.clone(),
499 ctx.services.api_client.clone(),
500 tool_executor,
501 );
502
503 let run_result = OneShotRunner::run_in_session_with_cancel(
504 &runtime.handle,
505 session_id,
506 prompt,
507 session_config.default_model.clone(),
508 ctx.cancellation_token.clone(),
509 )
510 .await;
511
512 runtime.shutdown().await;
513
514 let run_result = run_result.map_err(|e| match e {
515 crate::error::Error::Cancelled => StaticToolError::Cancelled,
516 other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
517 message: other.to_string(),
518 }),
519 })?;
520
521 Ok(AgentResult {
522 content: run_result.final_message.extract_text(),
523 session_id: Some(run_result.session_id.to_string()),
524 workspace: None,
525 })
526}
527
528#[cfg(test)]
529mod tests {
530 use super::*;
531 use crate::agents::{AgentSpec, AgentSpecError, McpAccessPolicy, register_agent_spec};
532 use crate::api::Client as ApiClient;
533 use crate::api::{ApiError, CompletionResponse, Provider};
534 use crate::app::conversation::{AssistantContent, Message, MessageData};
535 use crate::app::domain::session::EventStore;
536 use crate::app::domain::session::event_store::InMemoryEventStore;
537 use crate::app::domain::types::ToolCallId;
538 use crate::config::model::builtin;
539 use crate::model_registry::ModelRegistry;
540 use crate::session::state::{
541 ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicyOverrides,
542 ToolFilter, ToolVisibility, UnapprovedBehavior,
543 };
544 use crate::tools::McpTransport;
545 use crate::tools::services::{AgentSpawner, SubAgentError, SubAgentResult, ToolServices};
546 use async_trait::async_trait;
547 use std::collections::{HashMap, HashSet};
548 use std::sync::Mutex as StdMutex;
549 use tokio::time::{Duration, sleep};
550 use tokio_util::sync::CancellationToken;
551 use uuid::Uuid;
552
553 #[derive(Clone)]
554 struct StubProvider {
555 response: String,
556 }
557
558 impl StubProvider {
559 fn new(response: impl Into<String>) -> Self {
560 Self {
561 response: response.into(),
562 }
563 }
564 }
565
566 #[derive(Clone)]
567 struct CancelAwareProvider;
568
569 #[async_trait]
570 impl Provider for CancelAwareProvider {
571 fn name(&self) -> &'static str {
572 "cancel-aware"
573 }
574
575 async fn complete(
576 &self,
577 _model_id: &crate::config::model::ModelId,
578 _messages: Vec<Message>,
579 _system: Option<crate::app::SystemContext>,
580 _tools: Option<Vec<steer_tools::ToolSchema>>,
581 _call_options: Option<crate::config::model::ModelParameters>,
582 token: CancellationToken,
583 ) -> Result<CompletionResponse, ApiError> {
584 token.cancelled().await;
585 Err(ApiError::Cancelled {
586 provider: self.name().to_string(),
587 })
588 }
589 }
590
591 #[async_trait]
592 impl Provider for StubProvider {
593 fn name(&self) -> &'static str {
594 "stub"
595 }
596
597 async fn complete(
598 &self,
599 _model_id: &crate::config::model::ModelId,
600 _messages: Vec<Message>,
601 _system: Option<crate::app::SystemContext>,
602 _tools: Option<Vec<steer_tools::ToolSchema>>,
603 _call_options: Option<crate::config::model::ModelParameters>,
604 _token: CancellationToken,
605 ) -> Result<CompletionResponse, ApiError> {
606 Ok(CompletionResponse {
607 content: vec![AssistantContent::Text {
608 text: self.response.clone(),
609 }],
610 })
611 }
612 }
613
614 #[derive(Clone)]
615 struct StubAgentSpawner {
616 session_id: SessionId,
617 response: String,
618 }
619
620 #[async_trait]
621 impl AgentSpawner for StubAgentSpawner {
622 async fn spawn(
623 &self,
624 _config: crate::tools::services::SubAgentConfig,
625 _cancel_token: CancellationToken,
626 ) -> Result<SubAgentResult, SubAgentError> {
627 let timestamp = Message::current_timestamp();
628 let message = Message {
629 timestamp,
630 id: Message::generate_id("assistant", timestamp),
631 parent_message_id: None,
632 data: MessageData::Assistant {
633 content: vec![AssistantContent::Text {
634 text: self.response.clone(),
635 }],
636 },
637 };
638
639 Ok(SubAgentResult {
640 session_id: self.session_id,
641 final_message: message,
642 })
643 }
644 }
645
646 #[derive(Clone)]
647 struct CapturingAgentSpawner {
648 session_id: SessionId,
649 response: String,
650 captured: Arc<tokio::sync::Mutex<Option<crate::tools::services::SubAgentConfig>>>,
651 }
652
653 #[async_trait]
654 impl AgentSpawner for CapturingAgentSpawner {
655 async fn spawn(
656 &self,
657 config: crate::tools::services::SubAgentConfig,
658 _cancel_token: CancellationToken,
659 ) -> Result<SubAgentResult, SubAgentError> {
660 let mut guard = self.captured.lock().await;
661 *guard = Some(config);
662
663 let timestamp = Message::current_timestamp();
664 let message = Message {
665 timestamp,
666 id: Message::generate_id("assistant", timestamp),
667 parent_message_id: None,
668 data: MessageData::Assistant {
669 content: vec![AssistantContent::Text {
670 text: self.response.clone(),
671 }],
672 },
673 };
674
675 Ok(SubAgentResult {
676 session_id: self.session_id,
677 final_message: message,
678 })
679 }
680 }
681
682 #[derive(Clone)]
683 struct ToolCallThenTextProvider {
684 tool_call: steer_tools::ToolCall,
685 final_text: String,
686 call_count: Arc<StdMutex<usize>>,
687 }
688
689 impl ToolCallThenTextProvider {
690 fn new(tool_call: steer_tools::ToolCall, final_text: impl Into<String>) -> Self {
691 Self {
692 tool_call,
693 final_text: final_text.into(),
694 call_count: Arc::new(StdMutex::new(0)),
695 }
696 }
697 }
698
699 #[async_trait]
700 impl Provider for ToolCallThenTextProvider {
701 fn name(&self) -> &'static str {
702 "stub-tool-call"
703 }
704
705 async fn complete(
706 &self,
707 _model_id: &crate::config::model::ModelId,
708 _messages: Vec<Message>,
709 _system: Option<crate::app::SystemContext>,
710 _tools: Option<Vec<steer_tools::ToolSchema>>,
711 _call_options: Option<crate::config::model::ModelParameters>,
712 _token: CancellationToken,
713 ) -> Result<CompletionResponse, ApiError> {
714 let mut count = self
715 .call_count
716 .lock()
717 .expect("tool call counter lock poisoned");
718 let response = if *count == 0 {
719 CompletionResponse {
720 content: vec![AssistantContent::ToolCall {
721 tool_call: self.tool_call.clone(),
722 thought_signature: None,
723 }],
724 }
725 } else {
726 CompletionResponse {
727 content: vec![AssistantContent::Text {
728 text: self.final_text.clone(),
729 }],
730 }
731 };
732 *count += 1;
733 Ok(response)
734 }
735 }
736
737 #[tokio::test]
738 async fn resume_session_rejects_non_child() {
739 let event_store = Arc::new(InMemoryEventStore::new());
740 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
741 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
742 let api_client = Arc::new(ApiClient::new_with_deps(
743 crate::test_utils::test_llm_config_provider().unwrap(),
744 provider_registry,
745 model_registry,
746 ));
747 let workspace =
748 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
749 path: std::env::current_dir().unwrap(),
750 })
751 .await
752 .unwrap();
753
754 let parent_session_id = SessionId::new();
755 let session_id = SessionId::new();
756 let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
757 session_config.parent_session_id = Some(parent_session_id);
758
759 event_store.create_session(session_id).await.unwrap();
760 event_store
761 .append(
762 session_id,
763 &SessionEvent::SessionCreated {
764 config: Box::new(session_config),
765 metadata: std::collections::HashMap::new(),
766 parent_session_id: Some(parent_session_id),
767 },
768 )
769 .await
770 .unwrap();
771
772 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
773
774 let ctx = StaticToolContext {
775 tool_call_id: ToolCallId::new(),
776 session_id: SessionId::new(),
777 cancellation_token: CancellationToken::new(),
778 services,
779 };
780
781 let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
782
783 assert!(matches!(result, Err(StaticToolError::InvalidParams(_))));
784 }
785
786 #[tokio::test]
787 async fn resume_session_accepts_child_and_returns_message() {
788 let event_store = Arc::new(InMemoryEventStore::new());
789 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
790 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
791 let api_client = Arc::new(ApiClient::new_with_deps(
792 crate::test_utils::test_llm_config_provider().unwrap(),
793 provider_registry,
794 model_registry,
795 ));
796 let model = builtin::claude_sonnet_4_5();
797 api_client.insert_test_provider(
798 model.provider.clone(),
799 Arc::new(StubProvider::new("stub-response")),
800 );
801 let workspace =
802 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
803 path: std::env::current_dir().unwrap(),
804 })
805 .await
806 .unwrap();
807
808 let parent_session_id = SessionId::new();
809 let session_id = SessionId::new();
810 let mut session_config = SessionConfig::read_only(model.clone());
811 session_config.parent_session_id = Some(parent_session_id);
812
813 event_store.create_session(session_id).await.unwrap();
814 event_store
815 .append(
816 session_id,
817 &SessionEvent::SessionCreated {
818 config: Box::new(session_config),
819 metadata: std::collections::HashMap::new(),
820 parent_session_id: Some(parent_session_id),
821 },
822 )
823 .await
824 .unwrap();
825
826 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
827
828 let ctx = StaticToolContext {
829 tool_call_id: ToolCallId::new(),
830 session_id: parent_session_id,
831 cancellation_token: CancellationToken::new(),
832 services,
833 };
834
835 let result = resume_agent_session(session_id, "ping".to_string(), &ctx)
836 .await
837 .unwrap();
838
839 assert!(result.content.contains("stub-response"));
840 assert_eq!(
841 result.session_id.as_deref(),
842 Some(session_id.to_string().as_str())
843 );
844 }
845
846 #[tokio::test]
847 async fn resume_session_honors_cancellation() {
848 let event_store = Arc::new(InMemoryEventStore::new());
849 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
850 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
851 let api_client = Arc::new(ApiClient::new_with_deps(
852 crate::test_utils::test_llm_config_provider().unwrap(),
853 provider_registry,
854 model_registry,
855 ));
856 let model = builtin::claude_sonnet_4_5();
857 api_client.insert_test_provider(model.provider.clone(), Arc::new(CancelAwareProvider));
858 let workspace =
859 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
860 path: std::env::current_dir().unwrap(),
861 })
862 .await
863 .unwrap();
864
865 let parent_session_id = SessionId::new();
866 let session_id = SessionId::new();
867 let mut session_config = SessionConfig::read_only(model);
868 session_config.parent_session_id = Some(parent_session_id);
869
870 event_store.create_session(session_id).await.unwrap();
871 event_store
872 .append(
873 session_id,
874 &SessionEvent::SessionCreated {
875 config: Box::new(session_config),
876 metadata: std::collections::HashMap::new(),
877 parent_session_id: Some(parent_session_id),
878 },
879 )
880 .await
881 .unwrap();
882
883 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
884
885 let cancel_token = CancellationToken::new();
886 let ctx = StaticToolContext {
887 tool_call_id: ToolCallId::new(),
888 session_id: parent_session_id,
889 cancellation_token: cancel_token.clone(),
890 services,
891 };
892
893 let cancel_task = tokio::spawn(async move {
894 sleep(Duration::from_millis(10)).await;
895 cancel_token.cancel();
896 });
897
898 let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
899 let _ = cancel_task.await;
900
901 assert!(matches!(result, Err(StaticToolError::Cancelled)));
902 }
903
904 #[tokio::test]
905 async fn dispatch_agent_returns_session_id() {
906 let event_store = Arc::new(InMemoryEventStore::new());
907 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
908 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
909 let api_client = Arc::new(ApiClient::new_with_deps(
910 crate::test_utils::test_llm_config_provider().unwrap(),
911 provider_registry,
912 model_registry,
913 ));
914 let workspace =
915 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
916 path: std::env::current_dir().unwrap(),
917 })
918 .await
919 .unwrap();
920
921 let session_id = SessionId::new();
922 let spawner = StubAgentSpawner {
923 session_id,
924 response: "done".to_string(),
925 };
926
927 let services = Arc::new(
928 ToolServices::new(workspace, event_store, api_client)
929 .with_agent_spawner(Arc::new(spawner)),
930 );
931
932 let ctx = StaticToolContext {
933 tool_call_id: ToolCallId::new(),
934 session_id: SessionId::new(),
935 cancellation_token: CancellationToken::new(),
936 services,
937 };
938
939 let params = DispatchAgentParams {
940 prompt: "hello".to_string(),
941 target: DispatchAgentTarget::New {
942 workspace: WorkspaceTarget::Current,
943 agent: None,
944 },
945 };
946
947 let result = DispatchAgentTool.execute(params, &ctx).await.unwrap();
948 assert_eq!(
949 result.session_id.as_deref(),
950 Some(session_id.to_string().as_str())
951 );
952 }
953
954 #[tokio::test]
955 async fn dispatch_agent_filters_mcp_backends_by_allowlist() {
956 let event_store = Arc::new(InMemoryEventStore::new());
957 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
958 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
959 let api_client = Arc::new(ApiClient::new_with_deps(
960 crate::test_utils::test_llm_config_provider().unwrap(),
961 provider_registry,
962 model_registry,
963 ));
964 let workspace =
965 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
966 path: std::env::current_dir().unwrap(),
967 })
968 .await
969 .unwrap();
970
971 let parent_session_id = SessionId::new();
972 let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
973 session_config
974 .tool_config
975 .backends
976 .push(BackendConfig::Mcp {
977 server_name: "allowed-server".to_string(),
978 transport: McpTransport::Tcp {
979 host: "127.0.0.1".to_string(),
980 port: 1111,
981 },
982 tool_filter: ToolFilter::All,
983 });
984 session_config
985 .tool_config
986 .backends
987 .push(BackendConfig::Mcp {
988 server_name: "blocked-server".to_string(),
989 transport: McpTransport::Tcp {
990 host: "127.0.0.1".to_string(),
991 port: 2222,
992 },
993 tool_filter: ToolFilter::All,
994 });
995
996 event_store.create_session(parent_session_id).await.unwrap();
997 event_store
998 .append(
999 parent_session_id,
1000 &SessionEvent::SessionCreated {
1001 config: Box::new(session_config),
1002 metadata: HashMap::new(),
1003 parent_session_id: None,
1004 },
1005 )
1006 .await
1007 .unwrap();
1008
1009 let agent_id = format!("allowlist_{}", Uuid::new_v4());
1010 let spec = AgentSpec {
1011 id: agent_id.clone(),
1012 name: "allowlist test".to_string(),
1013 description: "allowlist test".to_string(),
1014 tools: vec![VIEW_TOOL_NAME.to_string()],
1015 mcp_access: McpAccessPolicy::Allowlist(vec!["allowed-server".to_string()]),
1016 model: None,
1017 };
1018 match register_agent_spec(spec) {
1019 Ok(()) => {}
1020 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1021 Err(AgentSpecError::RegistryPoisoned) => {}
1022 }
1023
1024 let captured = Arc::new(tokio::sync::Mutex::new(None));
1025 let spawner = CapturingAgentSpawner {
1026 session_id: SessionId::new(),
1027 response: "ok".to_string(),
1028 captured: captured.clone(),
1029 };
1030
1031 let services = Arc::new(
1032 ToolServices::new(workspace, event_store, api_client)
1033 .with_agent_spawner(Arc::new(spawner)),
1034 );
1035
1036 let ctx = StaticToolContext {
1037 tool_call_id: ToolCallId::new(),
1038 session_id: parent_session_id,
1039 cancellation_token: CancellationToken::new(),
1040 services,
1041 };
1042
1043 let params = DispatchAgentParams {
1044 prompt: "test".to_string(),
1045 target: DispatchAgentTarget::New {
1046 workspace: WorkspaceTarget::Current,
1047 agent: Some(agent_id),
1048 },
1049 };
1050
1051 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1052 let captured = captured.lock().await.clone().expect("no config captured");
1053
1054 let backend_names: Vec<String> = captured
1055 .mcp_backends
1056 .iter()
1057 .map(|backend| match backend {
1058 BackendConfig::Mcp { server_name, .. } => server_name.clone(),
1059 })
1060 .collect();
1061
1062 assert_eq!(backend_names, vec!["allowed-server".to_string()]);
1063 assert!(captured.allow_mcp_tools);
1064 }
1065
1066 #[tokio::test]
1067 async fn dispatch_agent_uses_parent_model_when_spec_missing_model() {
1068 let event_store = Arc::new(InMemoryEventStore::new());
1069 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1070 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1071 let api_client = Arc::new(ApiClient::new_with_deps(
1072 crate::test_utils::test_llm_config_provider().unwrap(),
1073 provider_registry,
1074 model_registry,
1075 ));
1076 let workspace =
1077 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1078 path: std::env::current_dir().unwrap(),
1079 })
1080 .await
1081 .unwrap();
1082
1083 let parent_session_id = SessionId::new();
1084 let parent_model = builtin::claude_sonnet_4_5();
1085 let session_config = SessionConfig::read_only(parent_model.clone());
1086
1087 event_store.create_session(parent_session_id).await.unwrap();
1088 event_store
1089 .append(
1090 parent_session_id,
1091 &SessionEvent::SessionCreated {
1092 config: Box::new(session_config),
1093 metadata: HashMap::new(),
1094 parent_session_id: None,
1095 },
1096 )
1097 .await
1098 .unwrap();
1099
1100 let agent_id = format!("inherit_model_{}", Uuid::new_v4());
1101 let spec = AgentSpec {
1102 id: agent_id.clone(),
1103 name: "inherit model test".to_string(),
1104 description: "inherit model test".to_string(),
1105 tools: vec![VIEW_TOOL_NAME.to_string()],
1106 mcp_access: McpAccessPolicy::None,
1107 model: None,
1108 };
1109 match register_agent_spec(spec) {
1110 Ok(()) => {}
1111 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1112 Err(AgentSpecError::RegistryPoisoned) => {}
1113 }
1114
1115 let captured = Arc::new(tokio::sync::Mutex::new(None));
1116 let spawner = CapturingAgentSpawner {
1117 session_id: SessionId::new(),
1118 response: "ok".to_string(),
1119 captured: captured.clone(),
1120 };
1121
1122 let services = Arc::new(
1123 ToolServices::new(workspace, event_store, api_client)
1124 .with_agent_spawner(Arc::new(spawner)),
1125 );
1126
1127 let ctx = StaticToolContext {
1128 tool_call_id: ToolCallId::new(),
1129 session_id: parent_session_id,
1130 cancellation_token: CancellationToken::new(),
1131 services,
1132 };
1133
1134 let params = DispatchAgentParams {
1135 prompt: "test".to_string(),
1136 target: DispatchAgentTarget::New {
1137 workspace: WorkspaceTarget::Current,
1138 agent: Some(agent_id),
1139 },
1140 };
1141
1142 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1143 let captured = captured.lock().await.clone().expect("no config captured");
1144
1145 assert_eq!(captured.model, parent_model);
1146 }
1147
1148 #[tokio::test]
1149 async fn dispatch_agent_uses_spec_model_when_set() {
1150 let event_store = Arc::new(InMemoryEventStore::new());
1151 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1152 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1153 let api_client = Arc::new(ApiClient::new_with_deps(
1154 crate::test_utils::test_llm_config_provider().unwrap(),
1155 provider_registry,
1156 model_registry,
1157 ));
1158 let workspace =
1159 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1160 path: std::env::current_dir().unwrap(),
1161 })
1162 .await
1163 .unwrap();
1164
1165 let parent_session_id = SessionId::new();
1166 let parent_model = builtin::claude_sonnet_4_5();
1167 let session_config = SessionConfig::read_only(parent_model);
1168
1169 event_store.create_session(parent_session_id).await.unwrap();
1170 event_store
1171 .append(
1172 parent_session_id,
1173 &SessionEvent::SessionCreated {
1174 config: Box::new(session_config),
1175 metadata: HashMap::new(),
1176 parent_session_id: None,
1177 },
1178 )
1179 .await
1180 .unwrap();
1181
1182 let spec_model = builtin::claude_haiku_4_5();
1183 let agent_id = format!("spec_model_{}", Uuid::new_v4());
1184 let spec = AgentSpec {
1185 id: agent_id.clone(),
1186 name: "spec model test".to_string(),
1187 description: "spec model test".to_string(),
1188 tools: vec![VIEW_TOOL_NAME.to_string()],
1189 mcp_access: McpAccessPolicy::None,
1190 model: Some(spec_model.clone()),
1191 };
1192 match register_agent_spec(spec) {
1193 Ok(()) => {}
1194 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1195 Err(AgentSpecError::RegistryPoisoned) => {}
1196 }
1197
1198 let captured = Arc::new(tokio::sync::Mutex::new(None));
1199 let spawner = CapturingAgentSpawner {
1200 session_id: SessionId::new(),
1201 response: "ok".to_string(),
1202 captured: captured.clone(),
1203 };
1204
1205 let services = Arc::new(
1206 ToolServices::new(workspace, event_store, api_client)
1207 .with_agent_spawner(Arc::new(spawner)),
1208 );
1209
1210 let ctx = StaticToolContext {
1211 tool_call_id: ToolCallId::new(),
1212 session_id: parent_session_id,
1213 cancellation_token: CancellationToken::new(),
1214 services,
1215 };
1216
1217 let params = DispatchAgentParams {
1218 prompt: "test".to_string(),
1219 target: DispatchAgentTarget::New {
1220 workspace: WorkspaceTarget::Current,
1221 agent: Some(agent_id),
1222 },
1223 };
1224
1225 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1226 let captured = captured.lock().await.clone().expect("no config captured");
1227
1228 assert_eq!(captured.model, spec_model);
1229 }
1230
1231 #[tokio::test]
1232 async fn resume_session_denies_disallowed_tools() {
1233 let event_store = Arc::new(InMemoryEventStore::new());
1234 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1235 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1236 let api_client = Arc::new(ApiClient::new_with_deps(
1237 crate::test_utils::test_llm_config_provider().unwrap(),
1238 provider_registry,
1239 model_registry,
1240 ));
1241 let workspace =
1242 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1243 path: std::env::current_dir().unwrap(),
1244 })
1245 .await
1246 .unwrap();
1247
1248 let parent_session_id = SessionId::new();
1249 let session_id = SessionId::new();
1250 let model = builtin::claude_sonnet_4_5();
1251
1252 let tool_call = steer_tools::ToolCall {
1253 name: "bash".to_string(),
1254 parameters: serde_json::json!({ "command": "echo denied" }),
1255 id: "tool_denied".to_string(),
1256 };
1257 api_client.insert_test_provider(
1258 model.provider.clone(),
1259 Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
1260 );
1261
1262 let mut session_config = SessionConfig::read_only(model);
1263 session_config.parent_session_id = Some(parent_session_id);
1264 session_config.policy_overrides = SessionPolicyOverrides {
1265 default_model: None,
1266 tool_visibility: Some(ToolVisibility::Whitelist(HashSet::from([
1267 VIEW_TOOL_NAME.to_string()
1268 ]))),
1269 approval_policy: ToolApprovalPolicyOverrides {
1270 default_behavior: Some(UnapprovedBehavior::Deny),
1271 preapproved: ApprovalRulesOverrides {
1272 tools: HashSet::from([VIEW_TOOL_NAME.to_string()]),
1273 per_tool: HashMap::new(),
1274 },
1275 },
1276 };
1277
1278 event_store.create_session(session_id).await.unwrap();
1279 event_store
1280 .append(
1281 session_id,
1282 &SessionEvent::SessionCreated {
1283 config: Box::new(session_config),
1284 metadata: HashMap::new(),
1285 parent_session_id: Some(parent_session_id),
1286 },
1287 )
1288 .await
1289 .unwrap();
1290
1291 let services = Arc::new(ToolServices::new(
1292 workspace,
1293 event_store.clone(),
1294 api_client,
1295 ));
1296
1297 let ctx = StaticToolContext {
1298 tool_call_id: ToolCallId::new(),
1299 session_id: parent_session_id,
1300 cancellation_token: CancellationToken::new(),
1301 services,
1302 };
1303
1304 let _ = resume_agent_session(session_id, "trigger".to_string(), &ctx)
1305 .await
1306 .unwrap();
1307
1308 let events = event_store.load_events(session_id).await.unwrap();
1309 let denied = events.iter().any(|(_, event)| match event {
1310 SessionEvent::ToolCallFailed { name, error, .. } => {
1311 name == "bash" && error.contains("denied by policy")
1312 }
1313 _ => false,
1314 });
1315
1316 assert!(denied, "expected denied ToolCallFailed event for bash");
1317 }
1318}