// steer_core/tools/dispatch_agent.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4
5use crate::{
6    app::{
7        ApprovalDecision,
8        conversation::{Message, MessageData, UserContent},
9    },
10    config::LlmConfigProvider,
11    tools::ToolExecutor,
12};
13
14use crate::app::{AgentEvent, AgentExecutor, AgentExecutorRunRequest};
15use steer_macros::tool_external as tool;
16use steer_tools::tools::{GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
17use steer_tools::{ToolCall, ToolError, ToolSchema};
18use tokio_util::sync::CancellationToken;
19
20#[derive(Deserialize, Debug, Serialize, JsonSchema)]
21pub struct DispatchAgentParams {
22    /// The task for the agent to perform
23    pub prompt: String,
24}
25
26const DISPATCH_AGENT_TOOLS: [&str; 4] =
27    [GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME];
28
29fn format_dispatch_agent_tools() -> String {
30    DISPATCH_AGENT_TOOLS
31        .iter()
32        .map(|tool| tool.to_string())
33        .collect::<Vec<String>>()
34        .join(", ")
35}
36
37tool! {
38    pub struct DispatchAgentTool {
39        pub llm_config_provider: Arc<LlmConfigProvider>,
40        pub workspace: Arc<dyn crate::workspace::Workspace>,
41    } {
42        params: DispatchAgentParams,
43        output: steer_tools::result::AgentResult,
44        variant: Agent,
45        description: &format!(r#"Launch a new agent that has access to the following tools: {}. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
46
47When to use the Agent tool:
48- If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
49
50When NOT to use the Agent tool:
51- If you want to read a specific file path, use the {VIEW_TOOL_NAME} or {GLOB_TOOL_NAME} tool instead of the Agent tool, to find the match more quickly
52- If you are searching for a specific class definition like "class Foo", use the {GREP_TOOL_NAME} tool instead, to find the match more quickly
53- If you are searching for code within a specific file or set of 2-3 files, use the {GREP_TOOL_NAME} tool instead, to find the match more quickly
54
55Usage notes:
561. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
572. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
583. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.
594. The agent's outputs should generally be trusted
605. IMPORTANT: The agent can not modify files. If you want to modify files, do it directly instead of going through the agent."#, format_dispatch_agent_tools()),
61        name: "dispatch_agent",
62        require_approval: false
63    }
64
65    async fn run(
66        tool: &DispatchAgentTool,
67        params: DispatchAgentParams,
68        context: &steer_tools::ExecutionContext,
69    ) -> std::result::Result<steer_tools::result::AgentResult, ToolError> {
70        let token = context.cancellation_token.clone();
71
72        // Load registries for API client - these are lightweight to load
73        let model_registry = Arc::new(crate::model_registry::ModelRegistry::load(&[])
74            .map_err(|e| ToolError::execution("dispatch_agent", format!("Failed to load model registry: {e}")))?);
75        let provider_registry = Arc::new(crate::auth::ProviderRegistry::load(&[])
76            .map_err(|e| ToolError::execution("dispatch_agent", format!("Failed to load provider registry: {e}")))?);
77
78        let api_client = Arc::new(crate::api::Client::new_with_deps(
79            (*tool.llm_config_provider).clone(),
80            provider_registry,
81            model_registry,
82        )); // Create ApiClient and wrap in Arc
83        let agent_executor = AgentExecutor::new(api_client);
84
85        let tool_executor = Arc::new(ToolExecutor::with_workspace(tool.workspace.clone()));
86
87        let available_tools: Vec<ToolSchema> = tool_executor.get_tool_schemas().await;
88        let tool_approval_callback = move |_tool_call: ToolCall| {
89            async move { Ok(ApprovalDecision::Approved) }
90        };
91
92        let tool_execution_callback =
93            move |tool_call: ToolCall, callback_token: CancellationToken| {
94                let executor = tool_executor.clone();
95                async move {
96                    executor
97                        .execute_tool_with_cancellation(&tool_call, callback_token)
98                        .await
99                }
100            };
101
102        // --- Prepare for AgentExecutor ---
103        let initial_messages = vec![Message {
104            data: MessageData::User {
105                content: vec![UserContent::Text { text: params.prompt }],
106            },
107            timestamp: Message::current_timestamp(),
108            id: Message::generate_id("user", Message::current_timestamp()),
109            parent_message_id: None,
110        }];
111
112        let system_prompt = create_dispatch_agent_system_prompt(&tool.workspace)
113            .await
114            .map_err(|e| ToolError::execution(DISPATCH_AGENT_TOOL_NAME, format!("Failed to create system prompt: {e}")))?;
115
116        // Use a channel to receive events, though we might just aggregate the final result here.
117        let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(100);
118
119        // --- Run AgentExecutor ---
120        let operation_result = agent_executor
121            .run(
122                AgentExecutorRunRequest
123                 {
124                    model: crate::config::model::builtin::claude_3_7_sonnet_20250219(), // Or make configurable?
125                    initial_messages,
126                    system_prompt: Some(system_prompt),
127                    available_tools,
128                    tool_approval_callback,
129                    tool_execution_callback,
130                },
131                event_tx,
132                token,
133            )
134            .await;
135
136        // --- Process Result ---
137        // We need the final text response from the agent.
138        // Collect text from events or the final message.
139        let mut final_text = String::new();
140        // let mut final_message_content: Option<ApiMessage> = None;
141
142        // Drain remaining events
143        while let Ok(event) = event_rx.try_recv() {
144            if let AgentEvent::MessageFinal(msg) = event {
145                if final_text.is_empty() {
146                    final_text = msg.extract_text();
147                }
148            }
149        }
150
151
152        match operation_result {
153            Ok(message) => {
154                 // If we still don't have text, extract from final message object
155                 if final_text.is_empty() {
156                     final_text = message.extract_text();
157                 }
158                 Ok(steer_tools::result::AgentResult {
159                     content: final_text,
160                 })
161            }
162            Err(e) => {
163                 Err(ToolError::execution(DISPATCH_AGENT_TOOL_NAME, e.to_string()))
164            }
165        }
166    }
167}
168
169pub async fn create_dispatch_agent_system_prompt(
170    workspace: &Arc<dyn crate::workspace::Workspace>,
171) -> crate::error::Result<String> {
172    // Get full environment context
173    let env_info = workspace.environment().await?;
174    let env_context = env_info.as_context();
175
176    let dispatch_prompt = format!(
177        r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
178
179Notes:
1801. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
1812. When relevant, share file names and code snippets relevant to the query
1823. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
183
184{env_context}
185"#
186    );
187
188    Ok(dispatch_prompt)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;
    use steer_workspace::local::LocalWorkspace;

    // End-to-end smoke test against the live API. The sub-agent should locate the
    // single file in a scratch workspace that defines search/find-related functions.
    #[tokio::test]
    #[ignore] // Requires API key and network call
    async fn test_dispatch_agent() {
        // A local .env file, if present, may supply CLAUDE_API_KEY.
        dotenv().ok();
        let _claude_key =
            std::env::var("CLAUDE_API_KEY").expect("CLAUDE_API_KEY must be set for this test");

        // Scratch workspace holding one Rust file with two functions.
        let scratch = tempfile::tempdir().unwrap();
        let scratch_path = scratch.path().to_path_buf();
        std::fs::write(
            scratch_path.join("search_code.rs"),
            "fn find_stuff() {}\nfn search_database() {}\n",
        )
        .unwrap();

        let auth = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let context = steer_tools::ExecutionContext::new("test_tool_call".to_string())
            .with_working_directory(scratch_path.clone())
            .with_cancellation_token(tokio_util::sync::CancellationToken::new());

        let tool_instance = DispatchAgentTool {
            llm_config_provider: Arc::new(LlmConfigProvider::new(auth)),
            workspace: Arc::new(LocalWorkspace::with_path(scratch_path).await.unwrap()),
        };

        let params = DispatchAgentParams {
            prompt: String::from(
                "Find all files that contain definitions of functions or methods related to search or find operations. Return only the absolute file path.",
            ),
        };

        let result = run(&tool_instance, params, &context).await;
        assert!(result.is_ok(), "Agent execution failed: {:?}", result.err());

        let content = result.unwrap().content;
        assert!(!content.is_empty(), "Response should not be empty");
        assert!(
            content.contains("search_code.rs"),
            "Response should contain the file path"
        );

        println!("Dispatch agent response: {content}");
        println!("Dispatch agent test passed successfully!");
    }
}
257}