1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::agents::{
6 McpAccessPolicy, agent_spec, agent_specs, agent_specs_prompt, default_agent_spec_id,
7};
8use crate::app::domain::event::SessionEvent;
9use crate::app::domain::runtime::RuntimeService;
10use crate::app::domain::types::SessionId;
11use crate::app::validation::ValidatorRegistry;
12use crate::config::model::builtin::claude_sonnet_4_5 as default_model;
13use crate::runners::OneShotRunner;
14use crate::session::state::BackendConfig;
15use crate::tools::capability::Capabilities;
16use crate::tools::services::{SubAgentConfig, SubAgentError, ToolServices};
17use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
18use crate::tools::{BackendRegistry, ToolExecutor, ToolRegistry};
19use crate::workspace::{
20 CreateWorkspaceRequest, EnvironmentId, RepoRef, VcsKind, VcsStatus, Workspace,
21 WorkspaceCreateStrategy, WorkspaceRef, create_workspace_from_session_config,
22};
23use steer_tools::ToolSpec;
24use steer_tools::result::{AgentResult, AgentWorkspaceInfo, AgentWorkspaceRevision};
25use steer_tools::tools::dispatch_agent::{
26 DispatchAgentError, DispatchAgentParams, DispatchAgentTarget, DispatchAgentToolSpec,
27 WorkspaceTarget,
28};
29use steer_tools::tools::{GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
30use tracing::warn;
31
32use super::{
33 AstGrepTool, BashTool, EditTool, FetchTool, GlobTool, GrepTool, LsTool, MultiEditTool,
34 ReplaceTool, TodoReadTool, TodoWriteTool, ViewTool, workspace_manager_op_error,
35 workspace_op_error,
36};
37
38fn dispatch_agent_description() -> String {
39 let agent_specs = agent_specs_prompt();
40 let agent_specs_block = if agent_specs.is_empty() {
41 "No agent specs registered.".to_string()
42 } else {
43 agent_specs
44 };
45
46 format!(
47 r#"Launch a new agent to help with a focused task. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
48
49 When to use the Agent tool:
50 - If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
51
52 When NOT to use the Agent tool:
53 - If you want to read a specific file path, use the {} or {} tool instead of the Agent tool, to find the match more quickly
54 - If you are searching for a specific class definition like "class Foo", use the {} tool instead, to find the match more quickly
55 - If you are searching for code within a specific file or set of 2-3 files, use the {} tool instead, to find the match more quickly
56
57 Usage notes:
58 1. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
59 2. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
60 3. Each invocation returns a session_id. Pass it back via `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue the conversation with the same agent.
61 4. When `target.session` is `resume`, the session_id must refer to a child of the current session. The `agent` and `workspace` options are ignored and the existing session config is used.
62 5. The agent's outputs should generally be trusted
63 6. IMPORTANT: Only some agent specs include write tools. Use a build agent if the task requires editing files.
64 7. IMPORTANT: New workspaces are preserved (not auto-deleted). Clean them up manually if needed.
65 8. If the agent spec omits a model, the parent session's default model is used.
66 9. IMPORTANT: Do NOT include `Repo: <path>`, `CWD: <path>`, or similar path headers in your prompt. The sub-agent already receives its working directory via system instructions.
67 10. IMPORTANT: If `target.session` is `new` and `workspace.location` is `new`, the sub-agent runs in the newly created workspace path, which may differ from the caller's current directory.
68
69Workspace options:
70- `workspace: {{ "location": "current" }}` to run in the current workspace
71- `workspace: {{ "location": "new", "name": "..." }}` to run in a fresh workspace (jj workspace or git worktree)
72- `location` is a logical workspace selector, not a filesystem path
73
74 Session options:
75 - `target: {{ "session": "resume", "session_id": "<uuid>" }}` to continue a prior dispatch_agent session
76
77 New session options:
78 - `target: {{ "session": "new", "workspace": {{ "location": "current" }} }}` to run in the current workspace
79 - `target: {{ "session": "new", "workspace": {{ "location": "new", "name": "..." }} }}` to run in a new workspace
80 - `target: {{ "session": "new", "workspace": {{ "location": "current" }}, "agent": "<id>" }}` selects an agent spec (defaults to "{default_agent}")
81
82 {agent_specs_block}"#,
83 VIEW_TOOL_NAME,
84 LS_TOOL_NAME,
85 GREP_TOOL_NAME,
86 GREP_TOOL_NAME,
87 default_agent = default_agent_spec_id(),
88 agent_specs_block = agent_specs_block
89 )
90}
91
/// Static tool that launches a new sub-agent session, or resumes an existing child
/// session, to work on a focused task delegated by the current session.
#[derive(Debug, Clone, Copy, Default)]
pub struct DispatchAgentTool;
93
94#[async_trait]
95impl StaticTool for DispatchAgentTool {
96 type Params = DispatchAgentParams;
97 type Output = AgentResult;
98 type Spec = DispatchAgentToolSpec;
99
100 const DESCRIPTION: &'static str = "Launch a sub-agent to search for files or code";
101 const REQUIRES_APPROVAL: bool = false;
102 const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;
103
104 fn schema() -> steer_tools::ToolSchema {
105 let settings = schemars::generate::SchemaSettings::draft07().with(|s| {
106 s.inline_subschemas = true;
107 });
108 let schema_gen = settings.into_generator();
109 let input_schema = schema_gen.into_root_schema_for::<Self::Params>();
110
111 steer_tools::ToolSchema {
112 name: Self::Spec::NAME.to_string(),
113 display_name: Self::Spec::DISPLAY_NAME.to_string(),
114 description: dispatch_agent_description(),
115 input_schema: input_schema.into(),
116 }
117 }
118
119 async fn execute(
120 &self,
121 params: Self::Params,
122 ctx: &StaticToolContext,
123 ) -> Result<Self::Output, StaticToolError<DispatchAgentError>> {
124 let DispatchAgentParams { prompt, target } = params;
125
126 let (workspace_target, agent) = match target {
127 DispatchAgentTarget::Resume { session_id } => {
128 let session_id = SessionId::parse(&session_id).ok_or_else(|| {
129 StaticToolError::invalid_params(format!("Invalid session_id '{session_id}'"))
130 })?;
131 return resume_agent_session(session_id, prompt, ctx).await;
132 }
133 DispatchAgentTarget::New { workspace, agent } => (workspace, agent),
134 };
135
136 let spawner = ctx
137 .services
138 .agent_spawner()
139 .ok_or_else(|| StaticToolError::missing_capability("agent_spawner"))?;
140
141 let base_workspace = ctx.services.workspace.clone();
142 let base_path = base_workspace.working_directory().to_path_buf();
143
144 let mut workspace = base_workspace.clone();
145 let mut workspace_ref = None;
146 let mut workspace_id = None;
147 let mut workspace_name = None;
148 let mut repo_id = None;
149 let mut repo_ref = None;
150
151 if let Some(manager) = ctx.services.workspace_manager()
152 && let Ok(info) = manager.resolve_workspace(&base_path).await
153 {
154 workspace_id = Some(info.workspace_id);
155 workspace_name.clone_from(&info.name);
156 repo_id = Some(info.repo_id);
157 workspace_ref = Some(WorkspaceRef {
158 environment_id: info.environment_id,
159 workspace_id: info.workspace_id,
160 repo_id: info.repo_id,
161 });
162 }
163
164 if let Some(manager) = ctx.services.repo_manager() {
165 let repo_env_id = workspace_ref
166 .as_ref()
167 .map_or_else(EnvironmentId::local, |reference| reference.environment_id);
168 if let Ok(info) = manager.resolve_repo(repo_env_id, &base_path).await {
169 if repo_id.is_none() {
170 repo_id = Some(info.repo_id);
171 }
172 repo_ref = Some(RepoRef {
173 environment_id: info.environment_id,
174 repo_id: info.repo_id,
175 root_path: info.root_path,
176 vcs_kind: info.vcs_kind,
177 });
178 }
179 }
180
181 let mut new_workspace = false;
182 let mut requested_workspace_name = None;
183
184 match &workspace_target {
185 WorkspaceTarget::Current => {}
186 WorkspaceTarget::New { name } => {
187 new_workspace = true;
188 requested_workspace_name = Some(name.clone());
189 }
190 }
191
192 let mut created_workspace_id = None;
193 let mut status_manager = None;
194
195 if new_workspace {
196 let manager = ctx
197 .services
198 .workspace_manager()
199 .ok_or_else(|| StaticToolError::missing_capability("workspace_manager"))?;
200 status_manager = Some(manager.clone());
201
202 let base_repo_id = repo_id.ok_or_else(|| {
203 StaticToolError::execution(DispatchAgentError::WorkspaceUnavailable {
204 message:
205 "Current path is not a supported workspace; cannot create new workspace"
206 .to_string(),
207 })
208 })?;
209
210 let strategy = match repo_ref
211 .as_ref()
212 .and_then(|reference| reference.vcs_kind.as_ref())
213 {
214 Some(VcsKind::Git) => WorkspaceCreateStrategy::GitWorktree,
215 _ => WorkspaceCreateStrategy::JjWorkspace,
216 };
217
218 let create_request = CreateWorkspaceRequest {
219 repo_id: base_repo_id,
220 name: requested_workspace_name.clone(),
221 parent_workspace_id: workspace_id,
222 strategy,
223 };
224
225 let info = manager
226 .create_workspace(create_request)
227 .await
228 .map_err(|e| {
229 StaticToolError::execution(DispatchAgentError::Workspace(
230 workspace_manager_op_error(e),
231 ))
232 })?;
233
234 workspace = manager
235 .open_workspace(info.workspace_id)
236 .await
237 .map_err(|e| {
238 StaticToolError::execution(DispatchAgentError::Workspace(
239 workspace_manager_op_error(e),
240 ))
241 })?;
242
243 workspace_id = Some(info.workspace_id);
244 created_workspace_id = Some(info.workspace_id);
245 workspace_name.clone_from(&info.name);
246 workspace_ref = Some(WorkspaceRef {
247 environment_id: info.environment_id,
248 workspace_id: info.workspace_id,
249 repo_id: info.repo_id,
250 });
251
252 if let Some(repo_manager) = ctx.services.repo_manager()
253 && let Ok(info) = repo_manager
254 .resolve_repo(info.environment_id, workspace.working_directory())
255 .await
256 {
257 repo_ref = Some(RepoRef {
258 environment_id: info.environment_id,
259 repo_id: info.repo_id,
260 root_path: info.root_path,
261 vcs_kind: info.vcs_kind,
262 });
263 }
264 }
265
266 let env_info = workspace.environment().await.map_err(|e| {
267 StaticToolError::execution(DispatchAgentError::Workspace(workspace_op_error(e)))
268 })?;
269
270 let system_prompt = format!(
271 r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
272
273Notes:
2741. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2752. When relevant, share file names and code snippets relevant to the query
2763. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
277
278{}
279"#,
280 env_info.as_context()
281 );
282
283 let agent_id = agent
284 .as_deref()
285 .filter(|value| !value.trim().is_empty())
286 .map_or_else(|| default_agent_spec_id().to_string(), str::to_string);
287
288 let agent_spec = agent_spec(&agent_id).ok_or_else(|| {
289 let available = agent_specs()
290 .into_iter()
291 .map(|spec| spec.id)
292 .collect::<Vec<_>>()
293 .join(", ");
294 StaticToolError::invalid_params(format!(
295 "Unknown agent spec '{agent_id}'. Available: {available}"
296 ))
297 })?;
298
299 let parent_session_config = match ctx.services.event_store.load_events(ctx.session_id).await
300 {
301 Ok(events) => events.into_iter().find_map(|(_, event)| match event {
302 SessionEvent::SessionCreated { config, .. } => Some(*config),
303 _ => None,
304 }),
305 Err(err) => {
306 warn!(
307 session_id = %ctx.session_id,
308 "Failed to load parent session config for MCP servers: {err}"
309 );
310 None
311 }
312 };
313
314 let parent_mcp_backends = parent_session_config
315 .as_ref()
316 .map(|config| config.tool_config.backends.clone())
317 .unwrap_or_default();
318
319 let parent_model = parent_session_config
320 .as_ref()
321 .map_or_else(default_model, |config| config.default_model.clone());
322
323 let allow_mcp_tools = agent_spec.mcp_access.allow_mcp_tools();
324 let mcp_backends = match &agent_spec.mcp_access {
325 McpAccessPolicy::None => Vec::new(),
326 McpAccessPolicy::All => parent_mcp_backends,
327 McpAccessPolicy::Allowlist(servers) => parent_mcp_backends
328 .into_iter()
329 .filter(|backend| match backend {
330 BackendConfig::Mcp { server_name, .. } => {
331 servers.iter().any(|allowed| allowed == server_name)
332 }
333 })
334 .collect(),
335 };
336
337 let config = SubAgentConfig {
338 parent_session_id: ctx.session_id,
339 prompt,
340 allowed_tools: agent_spec.tools.clone(),
341 model: agent_spec.model.clone().unwrap_or(parent_model),
342 system_context: Some(crate::app::SystemContext::new(system_prompt)),
343 workspace: Some(workspace),
344 workspace_ref,
345 workspace_id,
346 repo_ref,
347 workspace_name,
348 mcp_backends,
349 allow_mcp_tools,
350 };
351
352 let spawn_result = spawner.spawn(config, ctx.cancellation_token.clone()).await;
353
354 let mut workspace_info = None;
355
356 if let (Some(manager), Some(workspace_id)) = (status_manager, created_workspace_id) {
357 let revision = match manager.get_workspace_status(workspace_id).await {
358 Ok(status) => match status.vcs {
359 Some(vcs) => match vcs.status {
360 VcsStatus::Jj(jj_status) => {
361 jj_status.working_copy.map(|wc| AgentWorkspaceRevision {
362 vcs_kind: "jj".to_string(),
363 revision_id: wc.commit_id,
364 summary: wc.description,
365 change_id: Some(wc.change_id),
366 })
367 }
368 VcsStatus::Git(_) => None,
369 },
370 None => None,
371 },
372 Err(err) => {
373 warn!(
374 workspace_id = %workspace_id.as_uuid(),
375 "Failed to get workspace status for sub-agent: {err}"
376 );
377 None
378 }
379 };
380
381 workspace_info = Some(AgentWorkspaceInfo {
382 workspace_id: Some(workspace_id.as_uuid().to_string()),
383 revision,
384 });
385 }
386
387 let result = spawn_result.map_err(|e| match e {
388 SubAgentError::Cancelled => StaticToolError::Cancelled,
389 other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
390 message: other.to_string(),
391 }),
392 })?;
393
394 Ok(AgentResult {
395 content: result.final_message.extract_text(),
396 session_id: Some(result.session_id.to_string()),
397 workspace: workspace_info,
398 })
399 }
400}
401
402fn build_runtime_tool_executor(
403 workspace: Arc<dyn Workspace>,
404 parent_services: &Arc<ToolServices>,
405) -> Arc<ToolExecutor> {
406 let mut services = ToolServices::new(
407 workspace.clone(),
408 parent_services.event_store.clone(),
409 parent_services.api_client.clone(),
410 );
411
412 if let Some(spawner) = parent_services.agent_spawner() {
413 services = services.with_agent_spawner(spawner.clone());
414 }
415 if let Some(caller) = parent_services.model_caller() {
416 services = services.with_model_caller(caller.clone());
417 }
418 if let Some(manager) = parent_services.workspace_manager() {
419 services = services.with_workspace_manager(manager.clone());
420 }
421 if let Some(manager) = parent_services.repo_manager() {
422 services = services.with_repo_manager(manager.clone());
423 }
424 if parent_services
425 .capabilities()
426 .contains(Capabilities::NETWORK)
427 {
428 services = services.with_network();
429 }
430
431 let mut registry = ToolRegistry::new();
432 registry.register_static(GrepTool);
433 registry.register_static(GlobTool);
434 registry.register_static(LsTool);
435 registry.register_static(ViewTool);
436 registry.register_static(BashTool);
437 registry.register_static(EditTool);
438 registry.register_static(MultiEditTool);
439 registry.register_static(ReplaceTool);
440 registry.register_static(AstGrepTool);
441 registry.register_static(TodoReadTool);
442 registry.register_static(TodoWriteTool);
443 registry.register_static(DispatchAgentTool);
444 registry.register_static(FetchTool);
445
446 Arc::new(
447 ToolExecutor::with_components(
448 Arc::new(BackendRegistry::new()),
449 Arc::new(ValidatorRegistry::new()),
450 )
451 .with_static_tools(Arc::new(registry), Arc::new(services)),
452 )
453}
454
455async fn resume_agent_session(
456 session_id: SessionId,
457 prompt: String,
458 ctx: &StaticToolContext,
459) -> Result<AgentResult, StaticToolError<DispatchAgentError>> {
460 let events = ctx
461 .services
462 .event_store
463 .load_events(session_id)
464 .await
465 .map_err(|e| {
466 StaticToolError::execution(DispatchAgentError::SpawnFailed {
467 message: format!("Failed to load session {session_id}: {e}"),
468 })
469 })?;
470
471 let session_config = events
472 .into_iter()
473 .find_map(|(_, event)| match event {
474 SessionEvent::SessionCreated { config, .. } => Some(*config),
475 _ => None,
476 })
477 .ok_or_else(|| {
478 StaticToolError::execution(DispatchAgentError::SpawnFailed {
479 message: format!("Session {session_id} is missing a SessionCreated event"),
480 })
481 })?;
482
483 if session_config.parent_session_id != Some(ctx.session_id) {
484 return Err(StaticToolError::invalid_params(format!(
485 "Session {session_id} is not a child of current session {}",
486 ctx.session_id
487 )));
488 }
489
490 let workspace = create_workspace_from_session_config(&session_config.workspace)
491 .await
492 .map_err(|e| {
493 StaticToolError::execution(DispatchAgentError::SpawnFailed {
494 message: format!("Failed to open workspace for session {session_id}: {e}"),
495 })
496 })?;
497
498 let tool_executor = build_runtime_tool_executor(workspace, &ctx.services);
499 let runtime = RuntimeService::spawn(
500 ctx.services.event_store.clone(),
501 ctx.services.api_client.clone(),
502 tool_executor,
503 );
504
505 let run_result = OneShotRunner::run_in_session_with_cancel(
506 &runtime.handle,
507 session_id,
508 prompt,
509 session_config.default_model.clone(),
510 ctx.cancellation_token.clone(),
511 )
512 .await;
513
514 runtime.shutdown().await;
515
516 let run_result = run_result.map_err(|e| match e {
517 crate::error::Error::Cancelled => StaticToolError::Cancelled,
518 other => StaticToolError::execution(DispatchAgentError::SpawnFailed {
519 message: other.to_string(),
520 }),
521 })?;
522
523 Ok(AgentResult {
524 content: run_result.final_message.extract_text(),
525 session_id: Some(run_result.session_id.to_string()),
526 workspace: None,
527 })
528}
529
530#[cfg(test)]
531mod tests {
532 use super::*;
533 use crate::agents::{AgentSpec, AgentSpecError, McpAccessPolicy, register_agent_spec};
534 use crate::api::Client as ApiClient;
535 use crate::api::{ApiError, CompletionResponse, Provider};
536 use crate::app::conversation::{AssistantContent, Message, MessageData};
537 use crate::app::domain::session::EventStore;
538 use crate::app::domain::session::event_store::InMemoryEventStore;
539 use crate::app::domain::types::ToolCallId;
540 use crate::config::model::builtin;
541 use crate::model_registry::ModelRegistry;
542 use crate::session::state::{
543 ApprovalRulesOverrides, SessionConfig, SessionPolicyOverrides, ToolApprovalPolicyOverrides,
544 ToolFilter, ToolVisibility, UnapprovedBehavior,
545 };
546 use crate::tools::McpTransport;
547 use crate::tools::services::{AgentSpawner, SubAgentError, SubAgentResult, ToolServices};
548 use async_trait::async_trait;
549 use std::collections::{HashMap, HashSet};
550 use std::sync::Mutex as StdMutex;
551 use tokio::time::{Duration, sleep};
552 use tokio_util::sync::CancellationToken;
553 use uuid::Uuid;
554
555 #[derive(Clone)]
556 struct StubProvider {
557 response: String,
558 }
559
560 impl StubProvider {
561 fn new(response: impl Into<String>) -> Self {
562 Self {
563 response: response.into(),
564 }
565 }
566 }
567
568 #[derive(Clone)]
569 struct CancelAwareProvider;
570
571 #[async_trait]
572 impl Provider for CancelAwareProvider {
573 fn name(&self) -> &'static str {
574 "cancel-aware"
575 }
576
577 async fn complete(
578 &self,
579 _model_id: &crate::config::model::ModelId,
580 _messages: Vec<Message>,
581 _system: Option<crate::app::SystemContext>,
582 _tools: Option<Vec<steer_tools::ToolSchema>>,
583 _call_options: Option<crate::config::model::ModelParameters>,
584 token: CancellationToken,
585 ) -> Result<CompletionResponse, ApiError> {
586 token.cancelled().await;
587 Err(ApiError::Cancelled {
588 provider: self.name().to_string(),
589 })
590 }
591 }
592
593 #[async_trait]
594 impl Provider for StubProvider {
595 fn name(&self) -> &'static str {
596 "stub"
597 }
598
599 async fn complete(
600 &self,
601 _model_id: &crate::config::model::ModelId,
602 _messages: Vec<Message>,
603 _system: Option<crate::app::SystemContext>,
604 _tools: Option<Vec<steer_tools::ToolSchema>>,
605 _call_options: Option<crate::config::model::ModelParameters>,
606 _token: CancellationToken,
607 ) -> Result<CompletionResponse, ApiError> {
608 Ok(CompletionResponse {
609 content: vec![AssistantContent::Text {
610 text: self.response.clone(),
611 }],
612 })
613 }
614 }
615
616 #[derive(Clone)]
617 struct StubAgentSpawner {
618 session_id: SessionId,
619 response: String,
620 }
621
622 #[async_trait]
623 impl AgentSpawner for StubAgentSpawner {
624 async fn spawn(
625 &self,
626 _config: crate::tools::services::SubAgentConfig,
627 _cancel_token: CancellationToken,
628 ) -> Result<SubAgentResult, SubAgentError> {
629 let timestamp = Message::current_timestamp();
630 let message = Message {
631 timestamp,
632 id: Message::generate_id("assistant", timestamp),
633 parent_message_id: None,
634 data: MessageData::Assistant {
635 content: vec![AssistantContent::Text {
636 text: self.response.clone(),
637 }],
638 },
639 };
640
641 Ok(SubAgentResult {
642 session_id: self.session_id,
643 final_message: message,
644 })
645 }
646 }
647
648 #[derive(Clone)]
649 struct CapturingAgentSpawner {
650 session_id: SessionId,
651 response: String,
652 captured: Arc<tokio::sync::Mutex<Option<crate::tools::services::SubAgentConfig>>>,
653 }
654
655 #[async_trait]
656 impl AgentSpawner for CapturingAgentSpawner {
657 async fn spawn(
658 &self,
659 config: crate::tools::services::SubAgentConfig,
660 _cancel_token: CancellationToken,
661 ) -> Result<SubAgentResult, SubAgentError> {
662 let mut guard = self.captured.lock().await;
663 *guard = Some(config);
664
665 let timestamp = Message::current_timestamp();
666 let message = Message {
667 timestamp,
668 id: Message::generate_id("assistant", timestamp),
669 parent_message_id: None,
670 data: MessageData::Assistant {
671 content: vec![AssistantContent::Text {
672 text: self.response.clone(),
673 }],
674 },
675 };
676
677 Ok(SubAgentResult {
678 session_id: self.session_id,
679 final_message: message,
680 })
681 }
682 }
683
684 #[derive(Clone)]
685 struct ToolCallThenTextProvider {
686 tool_call: steer_tools::ToolCall,
687 final_text: String,
688 call_count: Arc<StdMutex<usize>>,
689 }
690
691 impl ToolCallThenTextProvider {
692 fn new(tool_call: steer_tools::ToolCall, final_text: impl Into<String>) -> Self {
693 Self {
694 tool_call,
695 final_text: final_text.into(),
696 call_count: Arc::new(StdMutex::new(0)),
697 }
698 }
699 }
700
701 #[async_trait]
702 impl Provider for ToolCallThenTextProvider {
703 fn name(&self) -> &'static str {
704 "stub-tool-call"
705 }
706
707 async fn complete(
708 &self,
709 _model_id: &crate::config::model::ModelId,
710 _messages: Vec<Message>,
711 _system: Option<crate::app::SystemContext>,
712 _tools: Option<Vec<steer_tools::ToolSchema>>,
713 _call_options: Option<crate::config::model::ModelParameters>,
714 _token: CancellationToken,
715 ) -> Result<CompletionResponse, ApiError> {
716 let mut count = self
717 .call_count
718 .lock()
719 .expect("tool call counter lock poisoned");
720 let response = if *count == 0 {
721 CompletionResponse {
722 content: vec![AssistantContent::ToolCall {
723 tool_call: self.tool_call.clone(),
724 thought_signature: None,
725 }],
726 }
727 } else {
728 CompletionResponse {
729 content: vec![AssistantContent::Text {
730 text: self.final_text.clone(),
731 }],
732 }
733 };
734 *count += 1;
735 Ok(response)
736 }
737 }
738
739 #[tokio::test]
740 async fn resume_session_rejects_non_child() {
741 let event_store = Arc::new(InMemoryEventStore::new());
742 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
743 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
744 let api_client = Arc::new(ApiClient::new_with_deps(
745 crate::test_utils::test_llm_config_provider().unwrap(),
746 provider_registry,
747 model_registry,
748 ));
749 let workspace =
750 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
751 path: std::env::current_dir().unwrap(),
752 })
753 .await
754 .unwrap();
755
756 let parent_session_id = SessionId::new();
757 let session_id = SessionId::new();
758 let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
759 session_config.parent_session_id = Some(parent_session_id);
760
761 event_store.create_session(session_id).await.unwrap();
762 event_store
763 .append(
764 session_id,
765 &SessionEvent::SessionCreated {
766 config: Box::new(session_config),
767 metadata: std::collections::HashMap::new(),
768 parent_session_id: Some(parent_session_id),
769 },
770 )
771 .await
772 .unwrap();
773
774 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
775
776 let ctx = StaticToolContext {
777 tool_call_id: ToolCallId::new(),
778 session_id: SessionId::new(),
779 cancellation_token: CancellationToken::new(),
780 services,
781 };
782
783 let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
784
785 assert!(matches!(result, Err(StaticToolError::InvalidParams(_))));
786 }
787
788 #[tokio::test]
789 async fn resume_session_accepts_child_and_returns_message() {
790 let event_store = Arc::new(InMemoryEventStore::new());
791 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
792 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
793 let api_client = Arc::new(ApiClient::new_with_deps(
794 crate::test_utils::test_llm_config_provider().unwrap(),
795 provider_registry,
796 model_registry,
797 ));
798 let model = builtin::claude_sonnet_4_5();
799 api_client.insert_test_provider(
800 model.provider.clone(),
801 Arc::new(StubProvider::new("stub-response")),
802 );
803 let workspace =
804 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
805 path: std::env::current_dir().unwrap(),
806 })
807 .await
808 .unwrap();
809
810 let parent_session_id = SessionId::new();
811 let session_id = SessionId::new();
812 let mut session_config = SessionConfig::read_only(model.clone());
813 session_config.parent_session_id = Some(parent_session_id);
814
815 event_store.create_session(session_id).await.unwrap();
816 event_store
817 .append(
818 session_id,
819 &SessionEvent::SessionCreated {
820 config: Box::new(session_config),
821 metadata: std::collections::HashMap::new(),
822 parent_session_id: Some(parent_session_id),
823 },
824 )
825 .await
826 .unwrap();
827
828 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
829
830 let ctx = StaticToolContext {
831 tool_call_id: ToolCallId::new(),
832 session_id: parent_session_id,
833 cancellation_token: CancellationToken::new(),
834 services,
835 };
836
837 let result = resume_agent_session(session_id, "ping".to_string(), &ctx)
838 .await
839 .unwrap();
840
841 assert!(result.content.contains("stub-response"));
842 assert_eq!(
843 result.session_id.as_deref(),
844 Some(session_id.to_string().as_str())
845 );
846 }
847
848 #[tokio::test]
849 async fn resume_session_honors_cancellation() {
850 let event_store = Arc::new(InMemoryEventStore::new());
851 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
852 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
853 let api_client = Arc::new(ApiClient::new_with_deps(
854 crate::test_utils::test_llm_config_provider().unwrap(),
855 provider_registry,
856 model_registry,
857 ));
858 let model = builtin::claude_sonnet_4_5();
859 api_client.insert_test_provider(model.provider.clone(), Arc::new(CancelAwareProvider));
860 let workspace =
861 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
862 path: std::env::current_dir().unwrap(),
863 })
864 .await
865 .unwrap();
866
867 let parent_session_id = SessionId::new();
868 let session_id = SessionId::new();
869 let mut session_config = SessionConfig::read_only(model);
870 session_config.parent_session_id = Some(parent_session_id);
871
872 event_store.create_session(session_id).await.unwrap();
873 event_store
874 .append(
875 session_id,
876 &SessionEvent::SessionCreated {
877 config: Box::new(session_config),
878 metadata: std::collections::HashMap::new(),
879 parent_session_id: Some(parent_session_id),
880 },
881 )
882 .await
883 .unwrap();
884
885 let services = Arc::new(ToolServices::new(workspace, event_store, api_client));
886
887 let cancel_token = CancellationToken::new();
888 let ctx = StaticToolContext {
889 tool_call_id: ToolCallId::new(),
890 session_id: parent_session_id,
891 cancellation_token: cancel_token.clone(),
892 services,
893 };
894
895 let cancel_task = tokio::spawn(async move {
896 sleep(Duration::from_millis(10)).await;
897 cancel_token.cancel();
898 });
899
900 let result = resume_agent_session(session_id, "ping".to_string(), &ctx).await;
901 let _ = cancel_task.await;
902
903 assert!(matches!(result, Err(StaticToolError::Cancelled)));
904 }
905
906 #[tokio::test]
907 async fn dispatch_agent_returns_session_id() {
908 let event_store = Arc::new(InMemoryEventStore::new());
909 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
910 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
911 let api_client = Arc::new(ApiClient::new_with_deps(
912 crate::test_utils::test_llm_config_provider().unwrap(),
913 provider_registry,
914 model_registry,
915 ));
916 let workspace =
917 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
918 path: std::env::current_dir().unwrap(),
919 })
920 .await
921 .unwrap();
922
923 let session_id = SessionId::new();
924 let spawner = StubAgentSpawner {
925 session_id,
926 response: "done".to_string(),
927 };
928
929 let services = Arc::new(
930 ToolServices::new(workspace, event_store, api_client)
931 .with_agent_spawner(Arc::new(spawner)),
932 );
933
934 let ctx = StaticToolContext {
935 tool_call_id: ToolCallId::new(),
936 session_id: SessionId::new(),
937 cancellation_token: CancellationToken::new(),
938 services,
939 };
940
941 let params = DispatchAgentParams {
942 prompt: "hello".to_string(),
943 target: DispatchAgentTarget::New {
944 workspace: WorkspaceTarget::Current,
945 agent: None,
946 },
947 };
948
949 let result = DispatchAgentTool.execute(params, &ctx).await.unwrap();
950 assert_eq!(
951 result.session_id.as_deref(),
952 Some(session_id.to_string().as_str())
953 );
954 }
955
956 #[tokio::test]
957 async fn dispatch_agent_filters_mcp_backends_by_allowlist() {
958 let event_store = Arc::new(InMemoryEventStore::new());
959 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
960 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
961 let api_client = Arc::new(ApiClient::new_with_deps(
962 crate::test_utils::test_llm_config_provider().unwrap(),
963 provider_registry,
964 model_registry,
965 ));
966 let workspace =
967 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
968 path: std::env::current_dir().unwrap(),
969 })
970 .await
971 .unwrap();
972
973 let parent_session_id = SessionId::new();
974 let mut session_config = SessionConfig::read_only(builtin::claude_sonnet_4_5());
975 session_config
976 .tool_config
977 .backends
978 .push(BackendConfig::Mcp {
979 server_name: "allowed-server".to_string(),
980 transport: McpTransport::Tcp {
981 host: "127.0.0.1".to_string(),
982 port: 1111,
983 },
984 tool_filter: ToolFilter::All,
985 });
986 session_config
987 .tool_config
988 .backends
989 .push(BackendConfig::Mcp {
990 server_name: "blocked-server".to_string(),
991 transport: McpTransport::Tcp {
992 host: "127.0.0.1".to_string(),
993 port: 2222,
994 },
995 tool_filter: ToolFilter::All,
996 });
997
998 event_store.create_session(parent_session_id).await.unwrap();
999 event_store
1000 .append(
1001 parent_session_id,
1002 &SessionEvent::SessionCreated {
1003 config: Box::new(session_config),
1004 metadata: HashMap::new(),
1005 parent_session_id: None,
1006 },
1007 )
1008 .await
1009 .unwrap();
1010
1011 let agent_id = format!("allowlist_{}", Uuid::new_v4());
1012 let spec = AgentSpec {
1013 id: agent_id.clone(),
1014 name: "allowlist test".to_string(),
1015 description: "allowlist test".to_string(),
1016 tools: vec![VIEW_TOOL_NAME.to_string()],
1017 mcp_access: McpAccessPolicy::Allowlist(vec!["allowed-server".to_string()]),
1018 model: None,
1019 };
1020 match register_agent_spec(spec) {
1021 Ok(()) => {}
1022 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1023 Err(AgentSpecError::RegistryPoisoned) => {}
1024 }
1025
1026 let captured = Arc::new(tokio::sync::Mutex::new(None));
1027 let spawner = CapturingAgentSpawner {
1028 session_id: SessionId::new(),
1029 response: "ok".to_string(),
1030 captured: captured.clone(),
1031 };
1032
1033 let services = Arc::new(
1034 ToolServices::new(workspace, event_store, api_client)
1035 .with_agent_spawner(Arc::new(spawner)),
1036 );
1037
1038 let ctx = StaticToolContext {
1039 tool_call_id: ToolCallId::new(),
1040 session_id: parent_session_id,
1041 cancellation_token: CancellationToken::new(),
1042 services,
1043 };
1044
1045 let params = DispatchAgentParams {
1046 prompt: "test".to_string(),
1047 target: DispatchAgentTarget::New {
1048 workspace: WorkspaceTarget::Current,
1049 agent: Some(agent_id),
1050 },
1051 };
1052
1053 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1054 let captured = captured.lock().await.clone().expect("no config captured");
1055
1056 let backend_names: Vec<String> = captured
1057 .mcp_backends
1058 .iter()
1059 .map(|backend| match backend {
1060 BackendConfig::Mcp { server_name, .. } => server_name.clone(),
1061 })
1062 .collect();
1063
1064 assert_eq!(backend_names, vec!["allowed-server".to_string()]);
1065 assert!(captured.allow_mcp_tools);
1066 }
1067
1068 #[tokio::test]
1069 async fn dispatch_agent_uses_parent_model_when_spec_missing_model() {
1070 let event_store = Arc::new(InMemoryEventStore::new());
1071 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1072 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1073 let api_client = Arc::new(ApiClient::new_with_deps(
1074 crate::test_utils::test_llm_config_provider().unwrap(),
1075 provider_registry,
1076 model_registry,
1077 ));
1078 let workspace =
1079 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1080 path: std::env::current_dir().unwrap(),
1081 })
1082 .await
1083 .unwrap();
1084
1085 let parent_session_id = SessionId::new();
1086 let parent_model = builtin::claude_sonnet_4_5();
1087 let session_config = SessionConfig::read_only(parent_model.clone());
1088
1089 event_store.create_session(parent_session_id).await.unwrap();
1090 event_store
1091 .append(
1092 parent_session_id,
1093 &SessionEvent::SessionCreated {
1094 config: Box::new(session_config),
1095 metadata: HashMap::new(),
1096 parent_session_id: None,
1097 },
1098 )
1099 .await
1100 .unwrap();
1101
1102 let agent_id = format!("inherit_model_{}", Uuid::new_v4());
1103 let spec = AgentSpec {
1104 id: agent_id.clone(),
1105 name: "inherit model test".to_string(),
1106 description: "inherit model test".to_string(),
1107 tools: vec![VIEW_TOOL_NAME.to_string()],
1108 mcp_access: McpAccessPolicy::None,
1109 model: None,
1110 };
1111 match register_agent_spec(spec) {
1112 Ok(()) => {}
1113 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1114 Err(AgentSpecError::RegistryPoisoned) => {}
1115 }
1116
1117 let captured = Arc::new(tokio::sync::Mutex::new(None));
1118 let spawner = CapturingAgentSpawner {
1119 session_id: SessionId::new(),
1120 response: "ok".to_string(),
1121 captured: captured.clone(),
1122 };
1123
1124 let services = Arc::new(
1125 ToolServices::new(workspace, event_store, api_client)
1126 .with_agent_spawner(Arc::new(spawner)),
1127 );
1128
1129 let ctx = StaticToolContext {
1130 tool_call_id: ToolCallId::new(),
1131 session_id: parent_session_id,
1132 cancellation_token: CancellationToken::new(),
1133 services,
1134 };
1135
1136 let params = DispatchAgentParams {
1137 prompt: "test".to_string(),
1138 target: DispatchAgentTarget::New {
1139 workspace: WorkspaceTarget::Current,
1140 agent: Some(agent_id),
1141 },
1142 };
1143
1144 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1145 let captured = captured.lock().await.clone().expect("no config captured");
1146
1147 assert_eq!(captured.model, parent_model);
1148 }
1149
1150 #[tokio::test]
1151 async fn dispatch_agent_uses_spec_model_when_set() {
1152 let event_store = Arc::new(InMemoryEventStore::new());
1153 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1154 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1155 let api_client = Arc::new(ApiClient::new_with_deps(
1156 crate::test_utils::test_llm_config_provider().unwrap(),
1157 provider_registry,
1158 model_registry,
1159 ));
1160 let workspace =
1161 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1162 path: std::env::current_dir().unwrap(),
1163 })
1164 .await
1165 .unwrap();
1166
1167 let parent_session_id = SessionId::new();
1168 let parent_model = builtin::claude_sonnet_4_5();
1169 let session_config = SessionConfig::read_only(parent_model);
1170
1171 event_store.create_session(parent_session_id).await.unwrap();
1172 event_store
1173 .append(
1174 parent_session_id,
1175 &SessionEvent::SessionCreated {
1176 config: Box::new(session_config),
1177 metadata: HashMap::new(),
1178 parent_session_id: None,
1179 },
1180 )
1181 .await
1182 .unwrap();
1183
1184 let spec_model = builtin::claude_haiku_4_5();
1185 let agent_id = format!("spec_model_{}", Uuid::new_v4());
1186 let spec = AgentSpec {
1187 id: agent_id.clone(),
1188 name: "spec model test".to_string(),
1189 description: "spec model test".to_string(),
1190 tools: vec![VIEW_TOOL_NAME.to_string()],
1191 mcp_access: McpAccessPolicy::None,
1192 model: Some(spec_model.clone()),
1193 };
1194 match register_agent_spec(spec) {
1195 Ok(()) => {}
1196 Err(AgentSpecError::AlreadyRegistered(_)) => {}
1197 Err(AgentSpecError::RegistryPoisoned) => {}
1198 }
1199
1200 let captured = Arc::new(tokio::sync::Mutex::new(None));
1201 let spawner = CapturingAgentSpawner {
1202 session_id: SessionId::new(),
1203 response: "ok".to_string(),
1204 captured: captured.clone(),
1205 };
1206
1207 let services = Arc::new(
1208 ToolServices::new(workspace, event_store, api_client)
1209 .with_agent_spawner(Arc::new(spawner)),
1210 );
1211
1212 let ctx = StaticToolContext {
1213 tool_call_id: ToolCallId::new(),
1214 session_id: parent_session_id,
1215 cancellation_token: CancellationToken::new(),
1216 services,
1217 };
1218
1219 let params = DispatchAgentParams {
1220 prompt: "test".to_string(),
1221 target: DispatchAgentTarget::New {
1222 workspace: WorkspaceTarget::Current,
1223 agent: Some(agent_id),
1224 },
1225 };
1226
1227 let _ = DispatchAgentTool.execute(params, &ctx).await.unwrap();
1228 let captured = captured.lock().await.clone().expect("no config captured");
1229
1230 assert_eq!(captured.model, spec_model);
1231 }
1232
1233 #[tokio::test]
1234 async fn resume_session_denies_disallowed_tools() {
1235 let event_store = Arc::new(InMemoryEventStore::new());
1236 let model_registry = Arc::new(ModelRegistry::load(&[]).unwrap());
1237 let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[]).unwrap());
1238 let api_client = Arc::new(ApiClient::new_with_deps(
1239 crate::test_utils::test_llm_config_provider().unwrap(),
1240 provider_registry,
1241 model_registry,
1242 ));
1243 let workspace =
1244 crate::workspace::create_workspace(&steer_workspace::WorkspaceConfig::Local {
1245 path: std::env::current_dir().unwrap(),
1246 })
1247 .await
1248 .unwrap();
1249
1250 let parent_session_id = SessionId::new();
1251 let session_id = SessionId::new();
1252 let model = builtin::claude_sonnet_4_5();
1253
1254 let tool_call = steer_tools::ToolCall {
1255 name: "bash".to_string(),
1256 parameters: serde_json::json!({ "command": "echo denied" }),
1257 id: "tool_denied".to_string(),
1258 };
1259 api_client.insert_test_provider(
1260 model.provider.clone(),
1261 Arc::new(ToolCallThenTextProvider::new(tool_call, "done")),
1262 );
1263
1264 let mut session_config = SessionConfig::read_only(model);
1265 session_config.parent_session_id = Some(parent_session_id);
1266 session_config.policy_overrides = SessionPolicyOverrides {
1267 default_model: None,
1268 tool_visibility: Some(ToolVisibility::Whitelist(HashSet::from([
1269 VIEW_TOOL_NAME.to_string()
1270 ]))),
1271 approval_policy: ToolApprovalPolicyOverrides {
1272 default_behavior: Some(UnapprovedBehavior::Deny),
1273 preapproved: ApprovalRulesOverrides {
1274 tools: HashSet::from([VIEW_TOOL_NAME.to_string()]),
1275 per_tool: HashMap::new(),
1276 },
1277 },
1278 };
1279
1280 event_store.create_session(session_id).await.unwrap();
1281 event_store
1282 .append(
1283 session_id,
1284 &SessionEvent::SessionCreated {
1285 config: Box::new(session_config),
1286 metadata: HashMap::new(),
1287 parent_session_id: Some(parent_session_id),
1288 },
1289 )
1290 .await
1291 .unwrap();
1292
1293 let services = Arc::new(ToolServices::new(
1294 workspace,
1295 event_store.clone(),
1296 api_client,
1297 ));
1298
1299 let ctx = StaticToolContext {
1300 tool_call_id: ToolCallId::new(),
1301 session_id: parent_session_id,
1302 cancellation_token: CancellationToken::new(),
1303 services,
1304 };
1305
1306 let _ = resume_agent_session(session_id, "trigger".to_string(), &ctx)
1307 .await
1308 .unwrap();
1309
1310 let events = event_store.load_events(session_id).await.unwrap();
1311 let denied = events.iter().any(|(_, event)| match event {
1312 SessionEvent::ToolCallFailed { name, error, .. } => {
1313 name == "bash" && error.contains("denied by policy")
1314 }
1315 _ => false,
1316 });
1317
1318 assert!(denied, "expected denied ToolCallFailed event for bash");
1319 }
1320}