steer_core/tools/
dispatch_agent.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4
5use crate::{
6    api::Model,
7    app::{
8        ApprovalDecision,
9        conversation::{Message, MessageData, UserContent},
10    },
11    config::LlmConfigProvider,
12    tools::ToolExecutor,
13};
14
15use crate::app::{AgentEvent, AgentExecutor, AgentExecutorRunRequest};
16use steer_macros::tool_external as tool;
17use steer_tools::tools::{GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};
18use steer_tools::{ToolCall, ToolError, ToolSchema};
19use tokio_util::sync::CancellationToken;
20
21#[derive(Deserialize, Debug, Serialize, JsonSchema)]
22pub struct DispatchAgentParams {
23    /// The task for the agent to perform
24    pub prompt: String,
25}
26
27const DISPATCH_AGENT_TOOLS: [&str; 4] =
28    [GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME];
29
30fn format_dispatch_agent_tools() -> String {
31    DISPATCH_AGENT_TOOLS
32        .iter()
33        .map(|tool| tool.to_string())
34        .collect::<Vec<String>>()
35        .join(", ")
36}
37
38tool! {
39    pub struct DispatchAgentTool {
40        pub llm_config_provider: Arc<LlmConfigProvider>,
41        pub workspace: Arc<dyn crate::workspace::Workspace>,
42    } {
43        params: DispatchAgentParams,
44        output: steer_tools::result::AgentResult,
45        variant: Agent,
46        description: &format!(r#"Launch a new agent that has access to the following tools: {}. When you are searching for a keyword or file and are not confident that you will find the right match on the first try, use the Agent tool to perform the search for you.
47
48When to use the Agent tool:
49- If you are searching for a keyword like "config" or "logger", or for questions like "which file does X?", the Agent tool is strongly recommended
50
51When NOT to use the Agent tool:
52- If you want to read a specific file path, use the {VIEW_TOOL_NAME} or {GLOB_TOOL_NAME} tool instead of the Agent tool, to find the match more quickly
53- If you are searching for a specific class definition like "class Foo", use the {GREP_TOOL_NAME} tool instead, to find the match more quickly
54- If you are searching for code within a specific file or set of 2-3 files, use the {GREP_TOOL_NAME} tool instead, to find the match more quickly
55
56Usage notes:
571. Launch multiple agents concurrently whenever possible, to maximize performance; to do that, use a single message with multiple tool uses
582. When the agent is done, it will return a single message back to you. The result returned by the agent is not visible to the user. To show the user the result, you should send a text message back to the user with a concise summary of the result.
593. Each agent invocation is stateless. You will not be able to send additional messages to the agent, nor will the agent be able to communicate with you outside of its final report. Therefore, your prompt should contain a highly detailed task description for the agent to perform autonomously and you should specify exactly what information the agent should return back to you in its final and only message to you.
604. The agent's outputs should generally be trusted
615. IMPORTANT: The agent can not modify files. If you want to modify files, do it directly instead of going through the agent."#, format_dispatch_agent_tools()),
62        name: "dispatch_agent",
63        require_approval: false
64    }
65
66    async fn run(
67        tool: &DispatchAgentTool,
68        params: DispatchAgentParams,
69        context: &steer_tools::ExecutionContext,
70    ) -> std::result::Result<steer_tools::result::AgentResult, ToolError> {
71        let token = context.cancellation_token.clone();
72
73        let api_client = Arc::new(crate::api::Client::new_with_provider((*tool.llm_config_provider).clone())); // Create ApiClient and wrap in Arc
74        let agent_executor = AgentExecutor::new(api_client);
75
76        let tool_executor = Arc::new(ToolExecutor::with_workspace(tool.workspace.clone()));
77
78        let available_tools: Vec<ToolSchema> = tool_executor.get_tool_schemas().await;
79        let tool_approval_callback = move |_tool_call: ToolCall| {
80            async move { Ok(ApprovalDecision::Approved) }
81        };
82
83        let tool_execution_callback =
84            move |tool_call: ToolCall, callback_token: CancellationToken| {
85                let executor = tool_executor.clone();
86                async move {
87                    executor
88                        .execute_tool_with_cancellation(&tool_call, callback_token)
89                        .await
90                }
91            };
92
93        // --- Prepare for AgentExecutor ---
94        let initial_messages = vec![Message {
95            data: MessageData::User {
96                content: vec![UserContent::Text { text: params.prompt }],
97            },
98            timestamp: Message::current_timestamp(),
99            id: Message::generate_id("user", Message::current_timestamp()),
100            parent_message_id: None,
101        }];
102
103        let system_prompt = create_dispatch_agent_system_prompt(&tool.workspace)
104            .await
105            .map_err(|e| ToolError::execution(DISPATCH_AGENT_TOOL_NAME, format!("Failed to create system prompt: {e}")))?;
106
107        // Use a channel to receive events, though we might just aggregate the final result here.
108        let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(100);
109
110        // --- Run AgentExecutor ---
111        let operation_result = agent_executor
112            .run(
113                AgentExecutorRunRequest
114                 {
115                    model: Model::Claude3_7Sonnet20250219, // Or make configurable?
116                    initial_messages,
117                    system_prompt: Some(system_prompt),
118                    available_tools,
119                    tool_approval_callback,
120                    tool_execution_callback,
121                },
122                event_tx,
123                token,
124            )
125            .await;
126
127        // --- Process Result ---
128        // We need the final text response from the agent.
129        // Collect text from events or the final message.
130        let mut final_text = String::new();
131        // let mut final_message_content: Option<ApiMessage> = None;
132
133        // Drain remaining events
134        while let Ok(event) = event_rx.try_recv() {
135            if let AgentEvent::MessageFinal(msg) = event {
136                if final_text.is_empty() {
137                    final_text = msg.extract_text();
138                }
139            }
140        }
141
142
143        match operation_result {
144            Ok(message) => {
145                 // If we still don't have text, extract from final message object
146                 if final_text.is_empty() {
147                     final_text = message.extract_text();
148                 }
149                 Ok(steer_tools::result::AgentResult {
150                     content: final_text,
151                 })
152            }
153            Err(e) => {
154                 Err(ToolError::execution(DISPATCH_AGENT_TOOL_NAME, e.to_string()))
155            }
156        }
157    }
158}
159
160pub async fn create_dispatch_agent_system_prompt(
161    workspace: &Arc<dyn crate::workspace::Workspace>,
162) -> crate::error::Result<String> {
163    // Get full environment context
164    let env_info = workspace.environment().await?;
165    let env_context = env_info.as_context();
166
167    let dispatch_prompt = format!(
168        r#"You are an agent for a CLI-based coding tool. Given the user's prompt, you should use the tools available to you to answer the user's question.
169
170Notes:
1711. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
1722. When relevant, share file names and code snippets relevant to the query
1733. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.
174
175{env_context}
176"#
177    );
178
179    Ok(dispatch_prompt)
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;
    use steer_workspace::local::LocalWorkspace;

    /// End-to-end run of the dispatch agent against a temporary workspace that
    /// contains a single Rust file; the reply is expected to mention that file.
    #[tokio::test]
    #[ignore] // Requires API key and network call
    async fn test_dispatch_agent() {
        // Pick up variables from a local .env file when one exists.
        dotenv().ok();

        // Fail early with a clear message when no key is configured.
        let _api_key =
            std::env::var("CLAUDE_API_KEY").expect("CLAUDE_API_KEY must be set for this test");

        // Temporary workspace holding the file the agent should discover.
        let workspace_dir = tempfile::tempdir().unwrap();
        let workspace_root = workspace_dir.path().to_path_buf();
        std::fs::write(
            workspace_root.join("search_code.rs"),
            "fn find_stuff() {}\nfn search_database() {}\n",
        )
        .unwrap();

        // Tool state: an LLM config provider plus a workspace rooted at the temp dir.
        let auth_storage = Arc::new(crate::test_utils::InMemoryAuthStorage::new());
        let tool_instance = DispatchAgentTool {
            llm_config_provider: Arc::new(LlmConfigProvider::new(auth_storage)),
            workspace: Arc::new(
                LocalWorkspace::with_path(workspace_root.clone())
                    .await
                    .unwrap(),
            ),
        };

        // Execution context pointing at the same directory.
        let context = steer_tools::ExecutionContext::new("test_tool_call".to_string())
            .with_working_directory(workspace_root)
            .with_cancellation_token(tokio_util::sync::CancellationToken::new());

        // Prompt that should lead the agent to the file written above.
        let params = DispatchAgentParams {
            prompt: "Find all files that contain definitions of functions or methods related to search or find operations. Return only the absolute file path.".to_string(),
        };

        let result = run(&tool_instance, params, &context).await;

        // The run must succeed and the reply must name the expected file.
        assert!(result.is_ok(), "Agent execution failed: {:?}", result.err());
        let response = result.unwrap();
        assert!(!response.content.is_empty(), "Response should not be empty");
        assert!(
            response.content.contains("search_code.rs"),
            "Response should contain the file path"
        );

        println!("Dispatch agent response: {}", response.content);
        println!("Dispatch agent test passed successfully!");
    }
}
248}