1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::agents::{
6 McpAccessPolicy, agent_spec, agent_specs, agent_specs_prompt, default_agent_spec_id,
7};
8use crate::app::domain::event::SessionEvent;
9use crate::app::domain::runtime::RuntimeService;
10use crate::app::domain::types::SessionId;
11use crate::app::validation::ValidatorRegistry;
12use crate::config::model::builtin::claude_sonnet_4_5 as default_model;
13use crate::runners::OneShotRunner;
14use crate::session::state::BackendConfig;
15use crate::tools::capability::Capabilities;
16use crate::tools::services::{SubAgentConfig, SubAgentError, ToolServices};
17use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
18use crate::tools::{BackendRegistry, ToolExecutor, ToolRegistry};
19use crate::workspace::{
20 CreateWorkspaceRequest, EnvironmentId, RepoRef, VcsKind, VcsStatus, Workspace,
21 WorkspaceCreateStrategy, WorkspaceRef, create_workspace_from_session_config,
22};
23use steer_tools::ToolSpec;
24use steer_tools::result::{AgentResult, AgentWorkspaceInfo, AgentWorkspaceRevision};
25use steer_tools::tools::dispatch_agent::{
26 DispatchAgentError, DispatchAgentParams, DispatchAgentTarget, DispatchAgentToolSpec,
27 WorkspaceTarget,
28};
29use steer_tools::tools::{GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
30use tracing::warn;
31
32use super::{
33 AstGrepTool, BashTool, EditTool, FetchTool, GlobTool, GrepTool, LsTool, MultiEditTool,
34 ReplaceTool, TodoReadTool, TodoWriteTool, ViewTool, workspace_manager_op_error,
35 workspace_op_error,
36};
37
38fn dispatch_agent_description() -> String {
39 let agent_specs = agent_specs_prompt();
40 let agent_specs_block = if agent_specs.is_empty() {
41 "No agent specs registered.".to_string()
42 } else {
43 agent_specs
44 };
45
46 format!(
47 r#"Launch a new agent to help with a focused task. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
48
49 When to use the Agent tool:
50 - If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
51
52 When NOT to use the Agent tool:
53 - If you want to read a specific file path, use the {} or {} tool instead of the Agent tool, to find the match more quickly
54 - If you are searching for a specific class definition like "class Foo", use the {} tool instead, to find the match more quickly
55 - If you are searching for code within a specific file or set of 2-3 files, use the {} tool instead, to find the match more quickly
56
57 Usage notes:
58 1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
59 2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
60 3. Each invocation returns a session_id. Pass it back via `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue the conversation with the same agent.
61 4. When `target.session` is `resume`, the session_id must refer to a child of the current session. The `agent` and `workspace` options are ignored and the existing session config is used.
62 5. The agent's outputs should generally be trusted
63 6. IMPORTANT: Only some agent specs include write tools. Use a build agent if the task requires editing files.
64 7. IMPORTANT: New workspaces are preserved (not auto-deleted). Clean them up manually if needed.
65 8. If the agent spec omits a model, the parent session's default model is used.
66 9. IMPORTANT: Do NOT include `Repo: <path>`, `CWD: <path>`, or similar path headers in your prompt. The sub-agent already receives its working directory via system instructions.
67 10. IMPORTANT: If `target.session` is `new` and `workspace.location` is `new`, the sub-agent runs in the newly created workspace path, which may differ from the caller's current directory.
68
69Workspace options:
70- `workspace: {{ "location": "current" }}` to run in the current workspace
71- `workspace: {{ "location": "new", "name": "..." }}` to run in a fresh workspace (jj workspace or git worktree)
72- `location` is a logical workspace selector, not a filesystem path
73
74 Session options:
75 - `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue a prior dispatch_agent session
76
77 New session options:
78 - `target: {{ "session": "new", "workspace": {{ "location": "current" }} }}` to run in the current workspace
79 - `target: {{ "session": "new", "workspace": {{ "location": "new", "name": "..." }} }}` to run in a new workspace
80 - `target: {{ "session": "new", "workspace": {{ "location": "current" }}, "agent": "<id>" }}` selects an agent spec (defaults to "{default_agent}")
81
82 {agent_specs_block}"#,
83 VIEW_TOOL_NAME,
84 LS_TOOL_NAME,
85 GREP_TOOL_NAME,
86 GREP_TOOL_NAME,
87 default_agent = default_agent_spec_id(),
88 agent_specs_block = agent_specs_block
89 )
90}
91
92pub struct DispatchAgentTool;
93
94#[async_trait]
95impl StaticTool for DispatchAgentTool {
96 type Params = DispatchAgentParams;
97 type Output = AgentResult;
98 type Spec = DispatchAgentToolSpec;
99
100 const DESCRIPTION: &'static str = "Launch a sub-agent to search for files or code";
101 const REQUIRES_APPROVAL: bool = false;
102 const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;
103
104 fn schema() -> steer_tools::ToolSchema {
105 let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
106 s.inline_subschemas = true;
107 });
108 let schema_gen = settings.into_generator();
109 let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
110
111 steer_tools::ToolSchema {
112 name: Self::Spec::NAME.to_string(),
113 display_name: Self::Spec::DISPLAY_NAME.to_string(),
114 description: dispatch_agent_description(),
115 input_schema: input_schema.into(),
116 }
117 }
118
119 async fn execute(
120 &self,
121 params: Self::Params,
122 ctx: &StaticToolContext,
123 ) -> Result<Self::Output, StaticToolError<DispatchAgentError>> {
124 let DispatchAgentParams { prompt, target } = params;
125
126 let (workspace_target, agent) = match target {
127 DispatchAgentTarget::Resume { session_id } => {
128 let session_id = SessionId::parse(&session_id).ok_or_else(|| {
129 StaticToolError::invalid_params(format!("Invalid session_id '{session_id}'"))
130 })?;
131 return resume_agent_session(session_id, prompt, ctx).await;
132 }
133 DispatchAgentTarget::New { workspace, agent } => (workspace, agent),
134 };
135
136 let spawner = ctx
137 .services
138 .agent_spawner()
139 .ok_or_else(|| StaticToolError::missing_capability("agent_spawner"))?;
140
141 let base_workspace = ctx.services.workspace.clone();
142 let base_path = base_workspace.working_directory().to_path_buf();
143
144 let mut workspace = base_workspace.clone();
145 let mut workspace_ref = None;
146 let mut workspace_id = None;
147 let mut workspace_name = None;
148 let mut repo_id = None;
149 let mut repo_ref = None;
150
151 if let Some(manager) = ctx.services.workspace_manager()
152 && let Ok(info) = manager.resolve_workspace(&base_path).await
153 {
154 workspace_id = Some(info.workspace_id);
155 workspace_name.clone_from(&info.name);
156 repo_id = Some(info.repo_id);
157 workspace_ref = Some(WorkspaceRef {
158 environment_id: info.environment_id,
159 workspace_id: info.workspace_id,
160 repo_id: info.repo_id,
161 });
162 }
163
164 if let Some(manager) = ctx.services.repo_manager() {
165 let repo_env_id = workspace_ref
166 .as_ref()
167 .map_or_else(EnvironmentId::local, |reference| reference.environment_id);
168 if let Ok(info) = manager.resolve_repo(repo_env_id, &base_path).await {
169 if repo_id.is_none() {
170 repo_id = Some(info.repo_id);
171 }
172 repo_ref = Some(RepoRef {
173 environment_id: info.environment_id,
174 repo_id: info.repo_id,
175 root_path: info.root_path,
176 vcs_kind: info.vcs_kind,
177 });
178 }
179 }
180
181 let mut new_workspace = false;
182 let mut requested_workspace_name = None;
183
184 match &workspace_target {
185 WorkspaceTarget::Current => {}
186 WorkspaceTarget::New { name } => {
187 new_workspace = true;
188 requested_workspace_name = Some(name.clone());
189 }
190 }
191
192 let mut created_workspace_id = None;
193 let mut status_manager = None;
194
195 if new_workspace {
196 let manager = ctx
197 .services
198 .workspace_manager()
199 .ok_or_else(|| StaticToolError::missing_capability("workspace_manager"))?;
200 status_manager = Some(manager.clone());
201
202 let base_repo_id = repo_id.ok_or_else(|| {
203 StaticToolError::execution(DispatchAgentError::WorkspaceUnavailable {
204 message:
205 "Current path is not a supported workspace; cannot create new workspace"
206 .to_string(),
207 })
208 })?;
209
210 let strategy = match repo_ref
211 .as_ref()
212 .and_then(|reference| reference.vcs_kind.as_ref())
213 {
214 Some(VcsKind::Git) => WorkspaceCreateStrategy::GitWorktree,
215 _ => WorkspaceCreateStrategy::JjWorkspace,
216 };
217
218 let create_request = CreateWorkspaceRequest {
219 repo_id: base_repo_id,
220 name: requested_workspace_name.clone(),
221 parent_workspace_id: workspace_id,
222 strategy,
223 };
224
225 let info = manager
226 .create_workspace(create_request)
227 .await
228 .map_err(|e| {
229 StaticToolError::execution(DispatchAgentError::Workspace(
230 workspace_manager_op_error(e),
231 ))
232 })?;
233
234 workspace = manager
235 .open_workspace(info.workspace_id)
236 .await
237 .map_err(|e| {
238 StaticToolError::execution(DispatchAgentError::Workspace(
239 workspace_manager_op_error(e),
240 ))
241 })?;
242
243 workspace_id = Some(info.workspace_id);
244 created_workspace_id = Some(info.workspace_id);
245 workspace_name.clone_from(&info.name);
246 workspace_ref = Some(WorkspaceRef {
247 environment_id: info.environment_id,
248 workspace_id: info.workspace_id,
249 repo_id: info.repo_id,
250 });
251
252 if let Some(repo_manager) = ctx.services.repo_manager()
253 && let Ok(info) = repo_manager
254 .resolve_repo(info.environment_id, workspace.working_directory())
255 .await
256 {
257 repo_ref = Some(RepoRef {
258 environment_id: info.environment_id,
259 repo_id: info.repo_id,
260 root_path: info.root_path,
261 vcs_kind: info.vcs_kind,
262 });
263 }
264 }
265
266 let env_info = workspace.environment().await.map_err(|e| {
267 StaticToolError::execution(DispatchAgentError::Workspace(workspace_op_error(e)))
268 })?;
269
270 let system_prompt = format!(
271 r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
272
273Notes:
2741. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2752. When relevant, share file names and code snippets relevant to the query
2763. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
277
278{}
279"#,
280 env_info.as_context()
281 );
282
283 let agent_id = agent
284 .as_deref()
285 .filter(|value| !value.trim().is_empty())
286 .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
287
288 let agent_spec = agent_spec(&agent_id).ok_or_else(|| {
289 let available = agent_specs()
290 .into_iter()
291 .map(|spec| spec.id)
292 .collect::<Vec<_>>()
293 .join(", ");
294 StaticToolError::invalid_params(format!(
295 "Unknown agent spec '{agent_id}'. Available: {available}"
296 ))
297 })?;
298
299 let parent_session_config = match ctx.services.event_store.load_events(ctx.session_id).await
300 {
301 Ok(events) => events.into_iter().find_map(|(_, event)| match event {
302 SessionEvent::SessionCreated { config, .. } => Some(*config),
303 _ => None,
304 }),
305 Err(err) => {
306 warn!(
307 session_id = %ctx.session_id,
308 "Failed to load parent session config for MCP servers: {err}"
309 );
310 None
311 }
312 };
313
314 let parent_mcp_backends = parent_session_config
315 .as_ref()
316 .map(|config| config.tool_config.backends.clone())
317 .unwrap_or_default();
318
319 let parent_model = parent_session_config
320 .as_ref()
321 .map_or_else(default_model, |config| config.default_model.clone());
322
323 let allow_mcp_tools = agent_spec.mcp_access.allow_mcp_tools();
324 let mcp_backends = match &agent_spec.mcp_access {
325 McpAccessPolicy::None => Vec::new(),
326 McpAccessPolicy::All => parent_mcp_backends,
327 McpAccessPolicy::Allowlist(servers) => parent_mcp_backends
328 .into_iter()
329 .filter(|backend| match backend {
330 BackendConfig::Mcp { server_name, .. } => {
331 servers.iter().any(|allowed| allowed == server_name)
332 }
333 })
334 .collect(),
335 };
336
337 let config = SubAgentConfig {
338 parent_session_id: ctx.session_id,
339 prompt,
340 allowed_tools: agent_spec.tools.clone(),
341 model: agent_spec.model.clone().unwrap_or(parent_model),
342 system_context: Some(crate::app::SystemContext::new(system_prompt)),
343 workspace: Some(workspace),
344 workspace_ref,
345 workspace_id,
346 repo_ref,
347 workspace_name,
348 mcp_backends,
349 allow_mcp_tools,
350 };
351
352 let spawn_result = spawner.spawn(config, ctx.cancellation_token.clone()).await;
353
354 let mut workspace_info = None;
355
356 if let (Some(manager), Some(workspace_id)) = (status_manager, created_workspace_id) {
357 let revision = match manager.get_workspace_status(workspace_id).await {
358 Ok(status) => match status.vcs {
359 Some(vcs) => match vcs.status {
360 VcsStatus::Jj(jj_status) => {
361 jj_status.working_copy.map(|wc| AgentWorkspaceRevision {
362 vcs_kind: "jj".to_string(),
363 revision_id: wc.commit_id,
364 summary: wc.description,
365 change_id: Some(wc.change_id),
366 })
367 }
368 VcsStatus::Git(_) => None,
369 },
370 None => None,
371 },
372 Err(err) => {
373 warn!(
374 workspace_id = %workspace_id.as_uuid(),
375 "Failed to get workspace status for sub-agent: {err}"
376 );
377 None
378 }
379 };
380
381 workspace_info = Some(AgentWorkspaceInfo {
382 workspace_id: Some(workspace_id.as_uuid().to_string()),
383 revision,
384 });
385 }
386
387 let result = spawn_result.map_err(|e| match e {
388 SubAgentError::Cancelled => StaticToolError::Cancelled,
389 other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
390 message: other.to_string(),
391 }),
392 })?;
393
394 Ok(AgentResult {
395 content: result.final_message.extract_text(),
396 session_id: Some(result.session_id.to_string()),
397 workspace: workspace_info,
398 })
399 }
400}
401
402fn build_runtime_tool_executor(
403 workspace: Arc<dyn Workspace>,
404 parent_services: &Arc<ToolServices>,
405) -> Arc<ToolExecutor> {
406 let mut services = ToolServices::new(
407 workspace.clone(),
408 parent_services.event_store.clone(),
409 parent_services.api_client.clone(),
410 );
411
412 if let Some(spawner) = parent_services.agent_spawner() {
413 services = services.with_agent_spawner(spawner.clone());
414 }
415 if let Some(caller) = parent_services.model_caller() {
416 services = services.with_model_caller(caller.clone());
417 }
418 if let Some(manager) = parent_services.workspace_manager() {
419 services = services.with_workspace_manager(manager.clone());
420 }
421 if let Some(manager) = parent_services.repo_manager() {
422 services = services.with_repo_manager(manager.clone());
423 }
424 if parent_services
425 .capabilities()
426 .contains(Capabilities::NETWORK)
427 {
428 services = services.with_network();
429 }
430
431 let mut registry = ToolRegistry::new();
432 registry.register_static(GrepTool);
433 registry.register_static(GlobTool);
434 registry.register_static(LsTool);
435 registry.register_static(ViewTool);
436 registry.register_static(BashTool);
437 registry.register_static(EditTool);
438 registry.register_static(MultiEditTool);
439 registry.register_static(ReplaceTool);
440 registry.register_static(AstGrepTool);
441 registry.register_static(TodoReadTool);
442 registry.register_static(TodoWriteTool);
443 registry.register_static(DispatchAgentTool);
444 registry.register_static(FetchTool);
445
446 Arc::new(
447 ToolExecutor::with_components(
448 Arc::new(BackendRegistry::new()),
449 Arc::new(ValidatorRegistry::new()),
450 )
451 .with_static_tools(Arc::new(registry), Arc::new(services)),
452 )
453}
454
455async fn resume_agent_session(
456 session_id: SessionId,
457 prompt: String,
458 ctx: &StaticToolContext,
459) -> Result<AgentResult, StaticToolError<DispatchAgentError>> {
460 let events = ctx
461 .services
462 .event_store
463 .load_events(session_id)
464 .await
465 .map_err(|e| {
466 StaticToolError::execution(DispatchAgentError::SpawnFailed {
467 message: format!("Failed to load session {session_id}: {e}"),
468 })
469 })?;
470
471 let session_config = events
472 .into_iter()
473 .find_map(|(_, event)| match event {
474 SessionEvent::SessionCreated { config, .. } => Some(*config),
475 _ => None,
476 })
477 .ok_or_else(|| {
478 StaticToolError::execution(DispatchAgentError::SpawnFailed {
479 message: format!("Session {session_id} is missing a SessionCreated event"),
480 })
481 })?;
482
483 if session_config.parent_session_id != Some(ctx.session_id) {
484 return Err(StaticToolError::invalid_params(format!(
485 "Session {session_id} is not a child of current session {}",
486 ctx.session_id
487 )));
488 }
489
490 let workspace = create_workspace_from_session_config(&session_config.workspace)
491 .await
492 .map_err(|e| {
493 StaticToolError::execution(DispatchAgentError::SpawnFailed {
494 message: format!("Failed to open workspace for session {session_id}: {e}"),
495 })
496 })?;
497
498 let tool_executor = build_runtime_tool_executor(workspace, &ctx.services);
499 let runtime = RuntimeService::spawn(
500 ctx.services.event_store.clone(),
501 ctx.services.api_client.clone(),
502 tool_executor,
503 );
504
505 let run_result = OneShotRunner::run_in_session_with_cancel(
506 &runtime.handle,
507 session_id,
508 prompt,
509 session_config.default_model.clone(),
510 ctx.cancellation_token.clone(),
511 )
512 .await;
513
514 runtime.shutdown().await;
515
516 let run_result = run_result.map_err(|e| match e {
517 crate::error::Error::Cancelled => StaticToolError::Cancelled,
518 other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
519 message: other.to_string(),
520 }),
521 })?;
522
523 Ok(AgentResult {
524 content: run_result.final_message.extract_text(),
525 session_id: Some(run_result.session_id.to_string()),
526 workspace: None,
527 })
528}
529
530#[cfg(test)]
531mod tests {
532 use super::*;
533 use crate::agents::{AgentSpec, AgentSpecError, McpAccessPolicy, register_agent_spec};
534 use crate::api::Client as ApiClient;
535 use crate::api::{ApiError, CompletionResponse, Provider};
536 use crate::app::conversation::{AssistantContent, Message, MessageData};
537 use crate::app::domain::session::EventStore;
538 use crate::app::domain::session::event_store::InMemoryEventStore;
539 use crate::app::domain::types::ToolCallId;
540 use crate::config::model::builtin;
541 use crate::model_registry::ModelRegistry;
542 use crate::session::state::{
543 ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicyOverrides,
544 ToolFilter, ToolVisibility, UnapprovedBehavior,
545 };
546 use crate::tools::McpTransport;
547 use crate::tools::services::{AgentSpawner, SubAgentError, SubAgentResult, ToolServices};
548 use async_trait::async_trait;
549 use std::collections::{HashMap, HashSet};
550 use std::sync::Mutex as StdMutex;
551 use tokio::time::{Duration, sleep};
552 use tokio_util::sync::CancellationToken;
553 use uuid::Uuid;
554
555 #[derive(Clone)]
556 struct StubProvider {
557 response: String,
558 }
559
560 impl StubProvider {
561 fn new(response: impl Into<String>) -> Self {
562 Self {
563 response: response.into(),
564 }
565 }
566 }
567
568 #[derive(Clone)]
569 struct CancelAwareProvider;
570
571 #[async_trait]
572 impl Provider for CancelAwareProvider {
573 fn name(&self) -> &'static str {
574 "cancel-aware"
575 }
576
577 async fn complete(
578 &self,
579 _model_id: &crate::config::model::ModelId,
580 _messages: Vec<Message>,
581 _system: Option<crate::app::SystemContext>,
582 _tools: Option<Vec<steer_tools::ToolSchema>>,
583 _call_options: Option<crate::config::model::ModelParameters>,
584 token: CancellationToken,
585 ) -> Result<CompletionResponse, ApiError> {
586 token.cancelled().await;
587 Err(ApiError::Cancelled {
588 provider: self.name().to_string(),
589 })
590 }
591 }
592
593 #[async_trait]
594 impl Provider for StubProvider {
595 fn name(&self) -> &'static str {
596 "stub"
597 }
598
599 async fn complete(
600 &self,
601 _model_id: &crate::config::model::ModelId,
602 _messages: Vec<Message>,
603 _system: Option<crate::app::SystemContext>,
604 _tools: Option<Vec<steer_tools::ToolSchema>>,
605 _call_options: Option<crate::config::model::ModelParameters>,
606 _token: CancellationToken,
607 ) -> Result<CompletionResponse, ApiError> {
608 Ok(CompletionResponse {
609 content: vec![AssistantContent::Text {
610 text: self.response.clone(),
611 }],
612 usage: None,
613 })
614 }
615 }
616
617 #[derive(Clone)]
618 struct StubAgentSpawner {
619 session_id: SessionId,
620 response: String,
621 }
622
623 #[async_trait]
624 impl AgentSpawner for StubAgentSpawner {
625 async fn spawn(
626 &self,
627 _config: crate::tools::services::SubAgentConfig,
628 _cancel_token: CancellationToken,
629 ) -> Result<SubAgentResult, SubAgentError> {
630 let timestamp = Message::current_timestamp();
631 let message = Message {
632 timestamp,
633 id: Message::generate_id("assistant", timestamp),
634 parent_message_id: None,
635 data: MessageData::Assistant {
636 content: vec![AssistantContent::Text {
637 text: self.response.clone(),
638 }],
639 },
640 };
641
642 Ok(SubAgentResult {
643 session_id: self.session_id,
644 final_message: message,
645 })
646 }
647 }
648
649 #[derive(Clone)]
650 struct CapturingAgentSpawner {
651 session_id: SessionId,
652 response: String,
653 captured: Arc<tokio::sync::Mutex<Option<crate::tools::services::SubAgentConfig>>>,
654 }
655
656 #[async_trait]
657 impl AgentSpawner for CapturingAgentSpawner {
658 async fn spawn(
659 &self,
660 config: crate::tools::services::SubAgentConfig,
661 _cancel_token: CancellationToken,
662 ) -> Result<SubAgentResult, SubAgentError> {
663 let mut guard = self.captured.lock().await;
664 *guard = Some(config);
665
666 let timestamp = Message::current_timestamp();
667 let message = Message {
668 timestamp,
669 id: Message::generate_id("assistant", timestamp),
670 parent_message_id: None,
671 data: MessageData::Assistant {
672 content: vec![AssistantContent::Text {
673 text: self.response.clone(),
674 }],
675 },
676 };
677
678 Ok(SubAgentResult {
679 session_id: self.session_id,
680 final_message: message,
681 })
682 }
683 }
684
685 #[derive(Clone)]
686 struct ToolCallThenTextProvider {
687 tool_call: steer_tools::ToolCall,
688 final_text: String,
689 call_count: Arc<StdMutex<usize>>,
690 }
691
692 impl ToolCallThenTextProvider {
693 fn new(tool_call: steer_tools::ToolCall, final_text: impl Into<String>) -> Self {
694 Self {
695 tool_call,
696 final_text: final_text.into(),
697 call_count: Arc::new(StdMutex::new(0)),
698 }
699 }
700 }
701
702 #[async_trait]
703 impl Provider for ToolCallThenTextProvider {
704 fn name(&self) -> &'static str {
705 "stub-tool-call"
706 }
707
708 async fn complete(
709 &self,
710 _model_id: &crate::config::model::ModelId,
711 _messages: Vec<Message>,
712 _system: Option<crate::app::SystemContext>,
713 _tools: Option<Vec<steer_tools::ToolSchema>>,
714 _call_options: Option<crate::config::model::ModelParameters>,
715 _token: CancellationToken,
716 ) -> Result<CompletionResponse, ApiError> {
717 let mut count = self
718 .call_count
719 .lock()
720 .expect("tool call counter lock poisoned");
721 let response = if *count == 0 {
722 CompletionResponse {
723 content: vec![AssistantContent::ToolCall {
724 tool_call: self.tool_call.clone(),
725 thought_signature: None,
726 }],
727 usage: None,
728 }
729 } else {
730 CompletionResponse {
731 content: vec![AssistantContent::Text {
732 text: self.final_text.clone(),
733 }],
734 usage: None,
735 }
736 };
737 *count += 1;
738 Ok(response)
739 }
740 }
741
742 #[tokio::test]
743 async fn resume_session_rejects_non_child() {
744 let event_store = Arc::new(InMemoryEventStore::new());
745 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
746 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
747 let api_client = Arc::new(ApiClient::new_with_deps(
748 crate::test_utils::test_llm_config_provider().unwrap(),
749 provider_registry,
750 model_registry,
751 ));
752 let workspace =
753 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
754 path: std::env::current_dir().unwrap(),
755 })
756 .await
757 .unwrap();
758
759 let parent_session_id = SessionId::new();
760 let session_id = SessionId::new();
761 let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
762 session_config.parent_session_id = Some(parent_session_id);
763
764 event_store.create_session(session_id).await.unwrap();
765 event_store
766 .append(
767 session_id,
768 &SessionEvent::SessionCreated {
769 config: Box::new(session_config),
770 metadata: std::collections::HashMap::new(),
771 parent_session_id: Some(parent_session_id),
772 },
773 )
774 .await
775 .unwrap();
776
777 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
778
779 let ctx = StaticToolContext {
780 tool_call_id: ToolCallId::new(),
781 session_id: SessionId::new(),
782 cancellation_token: CancellationToken::new(),
783 services,
784 };
785
786 let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
787
788 assert!(matches!(result, Err(StaticToolError::InvalidParams(_))));
789 }
790
791 #[tokio::test]
792 async fn resume_session_accepts_child_and_returns_message() {
793 let event_store = Arc::new(InMemoryEventStore::new());
794 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
795 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
796 let api_client = Arc::new(ApiClient::new_with_deps(
797 crate::test_utils::test_llm_config_provider().unwrap(),
798 provider_registry,
799 model_registry,
800 ));
801 let model = builtin::claude_sonnet_4_5();
802 api_client.insert_test_provider(
803 model.provider.clone(),
804 Arc::new(StubProvider::new("stub-response")),
805 );
806 let workspace =
807 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
808 path: std::env::current_dir().unwrap(),
809 })
810 .await
811 .unwrap();
812
813 let parent_session_id = SessionId::new();
814 let session_id = SessionId::new();
815 let mut session_config = SessionConfig::read_only(model.clone());
816 session_config.parent_session_id = Some(parent_session_id);
817
818 event_store.create_session(session_id).await.unwrap();
819 event_store
820 .append(
821 session_id,
822 &SessionEvent::SessionCreated {
823 config: Box::new(session_config),
824 metadata: std::collections::HashMap::new(),
825 parent_session_id: Some(parent_session_id),
826 },
827 )
828 .await
829 .unwrap();
830
831 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
832
833 let ctx = StaticToolContext {
834 tool_call_id: ToolCallId::new(),
835 session_id: parent_session_id,
836 cancellation_token: CancellationToken::new(),
837 services,
838 };
839
840 let result = resume_agent_session(session_id, "ping".to_string(), &ctx)
841 .await
842 .unwrap();
843
844 assert!(result.content.contains("stub-response"));
845 assert_eq!(
846 result.session_id.as_deref(),
847 Some(session_id.to_string().as_str())
848 );
849 }
850
851 #[tokio::test]
852 async fn resume_session_honors_cancellation() {
853 let event_store = Arc::new(InMemoryEventStore::new());
854 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
855 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
856 let api_client = Arc::new(ApiClient::new_with_deps(
857 crate::test_utils::test_llm_config_provider().unwrap(),
858 provider_registry,
859 model_registry,
860 ));
861 let model = builtin::claude_sonnet_4_5();
862 api_client.insert_test_provider(model.provider.clone(), Arc::new(CancelAwareProvider));
863 let workspace =
864 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
865 path: std::env::current_dir().unwrap(),
866 })
867 .await
868 .unwrap();
869
870 let parent_session_id = SessionId::new();
871 let session_id = SessionId::new();
872 let mut session_config = SessionConfig::read_only(model);
873 session_config.parent_session_id = Some(parent_session_id);
874
875 event_store.create_session(session_id).await.unwrap();
876 event_store
877 .append(
878 session_id,
879 &SessionEvent::SessionCreated {
880 config: Box::new(session_config),
881 metadata: std::collections::HashMap::new(),
882 parent_session_id: Some(parent_session_id),
883 },
884 )
885 .await
886 .unwrap();
887
888 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
889
890 let cancel_token = CancellationToken::new();
891 let ctx = StaticToolContext {
892 tool_call_id: ToolCallId::new(),
893 session_id: parent_session_id,
894 cancellation_token: cancel_token.clone(),
895 services,
896 };
897
898 let cancel_task = tokio::spawn(async move {
899 sleep(Duration::from_millis(10)).await;
900 cancel_token.cancel();
901 });
902
903 let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
904 let _ = cancel_task.await;
905
906 assert!(matches!(result, Err(StaticToolError::Cancelled)));
907 }
908
909 #[tokio::test]
910 async fn dispatch_agent_returns_session_id() {
911 let event_store = Arc::new(InMemoryEventStore::new());
912 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
913 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
914 let api_client = Arc::new(ApiClient::new_with_deps(
915 crate::test_utils::test_llm_config_provider().unwrap(),
916 provider_registry,
917 model_registry,
918 ));
919 let workspace =
920 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
921 path: std::env::current_dir().unwrap(),
922 })
923 .await
924 .unwrap();
925
926 let session_id = SessionId::new();
927 let spawner = StubAgentSpawner {
928 session_id,
929 response: "done".to_string(),
930 };
931
932 let services = Arc::new(
933 ToolServices::new(workspace, event_store, api_client)
934 .with_agent_spawner(Arc::new(spawner)),
935 );
936
937 let ctx = StaticToolContext {
938 tool_call_id: ToolCallId::new(),
939 session_id: SessionId::new(),
940 cancellation_token: CancellationToken::new(),
941 services,
942 };
943
944 let params = DispatchAgentParams {
945 prompt: "hello".to_string(),
946 target: DispatchAgentTarget::New {
947 workspace: WorkspaceTarget::Current,
948 agent: None,
949 },
950 };
951
952 let result = DispatchAgentTool.execute(params, &ctx).await.unwrap();
953 assert_eq!(
954 result.session_id.as_deref(),
955 Some(session_id.to_string().as_str())
956 );
957 }
958
959 #[tokio::test]
960 async fn dispatch_agent_filters_mcp_backends_by_allowlist() {
961 let event_store = Arc::new(InMemoryEventStore::new());
962 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
963 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
964 let api_client = Arc::new(ApiClient::new_with_deps(
965 crate::test_utils::test_llm_config_provider().unwrap(),
966 provider_registry,
967 model_registry,
968 ));
969 let workspace =
970 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
971 path: std::env::current_dir().unwrap(),
972 })
973 .await
974 .unwrap();
975
976 let parent_session_id = SessionId::new();
977 let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
978 session_config
979 .tool_config
980 .backends
981 .push(BackendConfig::Mcp {
982 server_name: "allowed-server".to_string(),
983 transport: McpTransport::Tcp {
984 host: "127.0.0.1".to_string(),
985 port: 1111,
986 },
987 tool_filter: ToolFilter::All,
988 });
989 session_config
990 .tool_config
991 .backends
992 .push(BackendConfig::Mcp {
993 server_name: "blocked-server".to_string(),
994 transport: McpTransport::Tcp {
995 host: "127.0.0.1".to_string(),
996 port: 2222,
997 },
998 tool_filter: ToolFilter::All,
999 });
1000
1001 event_store.create_session(parent_session_id).await.unwrap();
1002 event_store
1003 .append(
1004 parent_session_id,
1005 &SessionEvent::SessionCreated {
1006 config: Box::new(session_config),
1007 metadata: HashMap::new(),
1008 parent_session_id: None,
1009 },
1010 )
1011 .await
1012 .unwrap();
1013
1014 let agent_id = format!("allowlist_{}", Uuid::new_v4());
1015 let spec = AgentSpec {
1016 id: agent_id.clone(),
1017 name: "allowlist test".to_string(),
1018 description: "allowlist test".to_string(),
1019 tools: vec![VIEW_TOOL_NAME.to_string()],
1020 mcp_access: McpAccessPolicy::Allowlist(vec!["allowed-server".to_string()]),
1021 model: None,
1022 };
1023 match register_agent_spec(spec) {
1024 Ok(()) => {}
1025 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1026 Err(AgentSpecError::RegistryPoisoned) => {}
1027 }
1028
1029 let captured = Arc::new(tokio::sync::Mutex::new(None));
1030 let spawner = CapturingAgentSpawner {
1031 session_id: SessionId::new(),
1032 response: "ok".to_string(),
1033 captured: captured.clone(),
1034 };
1035
1036 let services = Arc::new(
1037 ToolServices::new(workspace, event_store, api_client)
1038 .with_agent_spawner(Arc::new(spawner)),
1039 );
1040
1041 let ctx = StaticToolContext {
1042 tool_call_id: ToolCallId::new(),
1043 session_id: parent_session_id,
1044 cancellation_token: CancellationToken::new(),
1045 services,
1046 };
1047
1048 let params = DispatchAgentParams {
1049 prompt: "test".to_string(),
1050 target: DispatchAgentTarget::New {
1051 workspace: WorkspaceTarget::Current,
1052 agent: Some(agent_id),
1053 },
1054 };
1055
1056 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1057 let captured = captured.lock().await.clone().expect("no config captured");
1058
1059 let backend_names: Vec<String> = captured
1060 .mcp_backends
1061 .iter()
1062 .map(|backend| match backend {
1063 BackendConfig::Mcp { server_name, .. } => server_name.clone(),
1064 })
1065 .collect();
1066
1067 assert_eq!(backend_names, vec!["allowed-server".to_string()]);
1068 assert!(captured.allow_mcp_tools);
1069 }
1070
1071 #[tokio::test]
1072 async fn dispatch_agent_uses_parent_model_when_spec_missing_model() {
1073 let event_store = Arc::new(InMemoryEventStore::new());
1074 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1075 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1076 let api_client = Arc::new(ApiClient::new_with_deps(
1077 crate::test_utils::test_llm_config_provider().unwrap(),
1078 provider_registry,
1079 model_registry,
1080 ));
1081 let workspace =
1082 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1083 path: std::env::current_dir().unwrap(),
1084 })
1085 .await
1086 .unwrap();
1087
1088 let parent_session_id = SessionId::new();
1089 let parent_model = builtin::claude_sonnet_4_5();
1090 let session_config = SessionConfig::read_only(parent_model.clone());
1091
1092 event_store.create_session(parent_session_id).await.unwrap();
1093 event_store
1094 .append(
1095 parent_session_id,
1096 &SessionEvent::SessionCreated {
1097 config: Box::new(session_config),
1098 metadata: HashMap::new(),
1099 parent_session_id: None,
1100 },
1101 )
1102 .await
1103 .unwrap();
1104
1105 let agent_id = format!("inherit_model_{}", Uuid::new_v4());
1106 let spec = AgentSpec {
1107 id: agent_id.clone(),
1108 name: "inherit model test".to_string(),
1109 description: "inherit model test".to_string(),
1110 tools: vec![VIEW_TOOL_NAME.to_string()],
1111 mcp_access: McpAccessPolicy::None,
1112 model: None,
1113 };
1114 match register_agent_spec(spec) {
1115 Ok(()) => {}
1116 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1117 Err(AgentSpecError::RegistryPoisoned) => {}
1118 }
1119
1120 let captured = Arc::new(tokio::sync::Mutex::new(None));
1121 let spawner = CapturingAgentSpawner {
1122 session_id: SessionId::new(),
1123 response: "ok".to_string(),
1124 captured: captured.clone(),
1125 };
1126
1127 let services = Arc::new(
1128 ToolServices::new(workspace, event_store, api_client)
1129 .with_agent_spawner(Arc::new(spawner)),
1130 );
1131
1132 let ctx = StaticToolContext {
1133 tool_call_id: ToolCallId::new(),
1134 session_id: parent_session_id,
1135 cancellation_token: CancellationToken::new(),
1136 services,
1137 };
1138
1139 let params = DispatchAgentParams {
1140 prompt: "test".to_string(),
1141 target: DispatchAgentTarget::New {
1142 workspace: WorkspaceTarget::Current,
1143 agent: Some(agent_id),
1144 },
1145 };
1146
1147 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1148 let captured = captured.lock().await.clone().expect("no config captured");
1149
1150 assert_eq!(captured.model, parent_model);
1151 }
1152
1153 #[tokio::test]
1154 async fn dispatch_agent_uses_spec_model_when_set() {
1155 let event_store = Arc::new(InMemoryEventStore::new());
1156 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1157 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1158 let api_client = Arc::new(ApiClient::new_with_deps(
1159 crate::test_utils::test_llm_config_provider().unwrap(),
1160 provider_registry,
1161 model_registry,
1162 ));
1163 let workspace =
1164 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1165 path: std::env::current_dir().unwrap(),
1166 })
1167 .await
1168 .unwrap();
1169
1170 let parent_session_id = SessionId::new();
1171 let parent_model = builtin::claude_sonnet_4_5();
1172 let session_config = SessionConfig::read_only(parent_model);
1173
1174 event_store.create_session(parent_session_id).await.unwrap();
1175 event_store
1176 .append(
1177 parent_session_id,
1178 &SessionEvent::SessionCreated {
1179 config: Box::new(session_config),
1180 metadata: HashMap::new(),
1181 parent_session_id: None,
1182 },
1183 )
1184 .await
1185 .unwrap();
1186
1187 let spec_model = builtin::claude_haiku_4_5();
1188 let agent_id = format!("spec_model_{}", Uuid::new_v4());
1189 let spec = AgentSpec {
1190 id: agent_id.clone(),
1191 name: "spec model test".to_string(),
1192 description: "spec model test".to_string(),
1193 tools: vec![VIEW_TOOL_NAME.to_string()],
1194 mcp_access: McpAccessPolicy::None,
1195 model: Some(spec_model.clone()),
1196 };
1197 match register_agent_spec(spec) {
1198 Ok(()) => {}
1199 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1200 Err(AgentSpecError::RegistryPoisoned) => {}
1201 }
1202
1203 let captured = Arc::new(tokio::sync::Mutex::new(None));
1204 let spawner = CapturingAgentSpawner {
1205 session_id: SessionId::new(),
1206 response: "ok".to_string(),
1207 captured: captured.clone(),
1208 };
1209
1210 let services = Arc::new(
1211 ToolServices::new(workspace, event_store, api_client)
1212 .with_agent_spawner(Arc::new(spawner)),
1213 );
1214
1215 let ctx = StaticToolContext {
1216 tool_call_id: ToolCallId::new(),
1217 session_id: parent_session_id,
1218 cancellation_token: CancellationToken::new(),
1219 services,
1220 };
1221
1222 let params = DispatchAgentParams {
1223 prompt: "test".to_string(),
1224 target: DispatchAgentTarget::New {
1225 workspace: WorkspaceTarget::Current,
1226 agent: Some(agent_id),
1227 },
1228 };
1229
1230 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1231 let captured = captured.lock().await.clone().expect("no config captured");
1232
1233 assert_eq!(captured.model, spec_model);
1234 }
1235
1236 #[tokio::test]
1237 async fn resume_session_denies_disallowed_tools() {
1238 let event_store = Arc::new(InMemoryEventStore::new());
1239 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1240 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1241 let api_client = Arc::new(ApiClient::new_with_deps(
1242 crate::test_utils::test_llm_config_provider().unwrap(),
1243 provider_registry,
1244 model_registry,
1245 ));
1246 let workspace =
1247 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1248 path: std::env::current_dir().unwrap(),
1249 })
1250 .await
1251 .unwrap();
1252
1253 let parent_session_id = SessionId::new();
1254 let session_id = SessionId::new();
1255 let model = builtin::claude_sonnet_4_5();
1256
1257 let tool_call = steer_tools::ToolCall {
1258 name: "bash".to_string(),
1259 parameters: serde_json::json!({ "command": "echo denied" }),
1260 id: "tool_denied".to_string(),
1261 };
1262 api_client.insert_test_provider(
1263 model.provider.clone(),
1264 Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
1265 );
1266
1267 let mut session_config = SessionConfig::read_only(model);
1268 session_config.parent_session_id = Some(parent_session_id);
1269 session_config.policy_overrides = SessionPolicyOverrides {
1270 default_model: None,
1271 tool_visibility: Some(ToolVisibility::Whitelist(HashSet::from([
1272 VIEW_TOOL_NAME.to_string()
1273 ]))),
1274 approval_policy: ToolApprovalPolicyOverrides {
1275 default_behavior: Some(UnapprovedBehavior::Deny),
1276 preapproved: ApprovalRulesOverrides {
1277 tools: HashSet::from([VIEW_TOOL_NAME.to_string()]),
1278 per_tool: HashMap::new(),
1279 },
1280 },
1281 };
1282
1283 event_store.create_session(session_id).await.unwrap();
1284 event_store
1285 .append(
1286 session_id,
1287 &SessionEvent::SessionCreated {
1288 config: Box::new(session_config),
1289 metadata: HashMap::new(),
1290 parent_session_id: Some(parent_session_id),
1291 },
1292 )
1293 .await
1294 .unwrap();
1295
1296 let services = Arc::new(ToolServices::new(
1297 workspace,
1298 event_store.clone(),
1299 api_client,
1300 ));
1301
1302 let ctx = StaticToolContext {
1303 tool_call_id: ToolCallId::new(),
1304 session_id: parent_session_id,
1305 cancellation_token: CancellationToken::new(),
1306 services,
1307 };
1308
1309 let _ = resume_agent_session(session_id, "trigger".to_string(), &ctx)
1310 .await
1311 .unwrap();
1312
1313 let events = event_store.load_events(session_id).await.unwrap();
1314 let denied = events.iter().any(|(_, event)| match event {
1315 SessionEvent::ToolCallFailed { name, error, .. } => {
1316 name == "bash" && error.contains("denied by policy")
1317 }
1318 _ => false,
1319 });
1320
1321 assert!(denied, "expected denied ToolCallFailed event for bash");
1322 }
1323}