steer_core/app/
agent_executor.rs

1use crate::api::{ApiError, Client as ApiClient, Model};
2use crate::app::conversation::{Message, MessageData};
3use futures::{StreamExt, stream::FuturesUnordered};
4use std::future::Future;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use steer_tools::{ToolCall, ToolError, ToolSchema, result::ToolResult as SteerToolResult};
8use thiserror::Error;
9use tokio::sync::mpsc;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, error, info, instrument, warn};
12use uuid::Uuid;
13
/// Verdict returned by the tool-approval callback for a single tool call.
///
/// `Approved` lets the tool's execution callback run; `Denied` skips execution and
/// records a `ToolError::DeniedByUser` result for that call instead.
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
pub enum ApprovalDecision {
    /// The tool call may be executed.
    Approved,
    /// The tool call must not be executed.
    Denied,
}
19
20#[derive(Debug)]
21pub enum AgentEvent {
22    MessageFinal(Message),
23    ExecutingTool {
24        tool_call_id: String,
25        name: String,
26        parameters: serde_json::Value,
27    },
28}
29
30#[derive(Error, Debug)]
31pub enum AgentExecutorError {
32    #[error(transparent)]
33    Api(#[from] ApiError),
34    #[error(transparent)]
35    Tool(#[from] ToolError),
36    #[error("Event channel send error: {0}")]
37    SendError(String),
38    #[error("Operation cancelled")]
39    Cancelled,
40    #[error("Internal error: {0}")]
41    Internal(String),
42    #[error("Unexpected API response structure")]
43    UnexpectedResponse,
44}
45
46impl<T> From<mpsc::error::SendError<T>> for AgentExecutorError {
47    fn from(err: mpsc::error::SendError<T>) -> Self {
48        AgentExecutorError::SendError(err.to_string())
49    }
50}
51
52#[derive(Clone)]
53pub struct AgentExecutor {
54    api_client: Arc<ApiClient>,
55}
56
57pub struct AgentExecutorRunRequest<A, E> {
58    pub model: Model,
59    pub initial_messages: Vec<Message>,
60    pub system_prompt: Option<String>,
61    pub available_tools: Vec<ToolSchema>,
62    pub tool_approval_callback: A,
63    pub tool_execution_callback: E,
64}
65
66impl AgentExecutor {
67    pub fn new(api_client: Arc<ApiClient>) -> Self {
68        Self { api_client }
69    }
70
71    #[instrument(skip_all, name = "AgentExecutor::run")]
72    pub async fn run<A, AFut, E, EFut>(
73        &self,
74        request: AgentExecutorRunRequest<A, E>,
75        event_sender: mpsc::Sender<AgentEvent>,
76        token: CancellationToken,
77    ) -> Result<Message, AgentExecutorError>
78    where
79        A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
80        AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
81        E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
82        EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
83    {
84        let mut messages = request.initial_messages.clone();
85        let tools = if request.available_tools.is_empty() {
86            None
87        } else {
88            Some(request.available_tools)
89        };
90
91        debug!(target: "AgentExecutor::run", "About to start completion loop with model: {:?}", request.model);
92
93        loop {
94            if token.is_cancelled() {
95                info!("Operation cancelled before API call.");
96                return Err(AgentExecutorError::Cancelled);
97            }
98
99            info!(target: "AgentExecutor::run", model = ?request.model, "Calling LLM API");
100            let completion_response = self
101                .api_client
102                .complete_with_retry(
103                    request.model,
104                    &messages,
105                    &request.system_prompt,
106                    &tools,
107                    token.clone(),
108                    3,
109                )
110                .await?;
111            let tool_calls = completion_response.extract_tool_calls();
112
113            // Get parent info from the last message
114            let parent_id = if let Some(last_msg) = messages.last() {
115                last_msg.id().to_string()
116            } else {
117                // This shouldn't happen
118                return Err(AgentExecutorError::Internal(
119                    "No messages in conversation when adding assistant message".to_string(),
120                ));
121            };
122
123            let full_assistant_message = Message {
124                data: MessageData::Assistant {
125                    content: completion_response.content,
126                },
127                timestamp: SystemTime::now()
128                    .duration_since(UNIX_EPOCH)
129                    .unwrap()
130                    .as_secs(),
131                id: Uuid::new_v4().to_string(),
132                parent_message_id: Some(parent_id),
133            };
134
135            messages.push(full_assistant_message.clone());
136
137            if tool_calls.is_empty() {
138                info!("LLM response received, no tool calls requested.");
139                event_sender
140                    .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
141                    .await?;
142                debug!(target: "AgentExecutor::run_operation", "Operation finished successfully (no tool calls), returning final message.");
143                return Ok(full_assistant_message);
144            } else {
145                info!(count = tool_calls.len(), "LLM requested tool calls.");
146                event_sender
147                    .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
148                    .await?;
149
150                // Create concurrent futures for every tool call
151                let mut pending_tools: FuturesUnordered<_> = tool_calls
152                    .into_iter()
153                    .map(|call| {
154                        let event_sender_clone = event_sender.clone();
155                        let token_clone = token.clone();
156                        let approval_callback = &request.tool_approval_callback;
157                        let execution_callback = &request.tool_execution_callback;
158
159                        async move {
160                            let message_id = uuid::Uuid::new_v4().to_string();
161                            let call_id = call.id.clone();
162
163                            // Handle single tool call
164                            let result = Self::handle_single_tool_call(
165                                call,
166                                approval_callback,
167                                execution_callback,
168                                &event_sender_clone,
169                                token_clone,
170                            )
171                            .await;
172
173                            (call_id, message_id, result)
174                        }
175                    })
176                    .collect();
177
178                // Pull results as they finish and emit events
179                while let Some((tool_call_id, message_id, result)) = pending_tools.next().await {
180                    if token.is_cancelled() {
181                        info!("Operation cancelled during tool handling.");
182                        return Err(AgentExecutorError::Cancelled);
183                    }
184
185                    // Get parent info from the last message
186                    let parent_id = if let Some(last_msg) = messages.last() {
187                        last_msg.id().to_string()
188                    } else {
189                        return Err(AgentExecutorError::Internal(
190                            "No messages in conversation when adding tool results".to_string(),
191                        ));
192                    };
193
194                    // Add tool result message
195                    let tool_message = Message {
196                        data: MessageData::Tool {
197                            tool_use_id: tool_call_id,
198                            result,
199                        },
200                        timestamp: SystemTime::now()
201                            .duration_since(UNIX_EPOCH)
202                            .unwrap()
203                            .as_secs(),
204                        id: message_id,
205                        parent_message_id: Some(parent_id),
206                    };
207
208                    messages.push(tool_message.clone());
209                    event_sender
210                        .send(AgentEvent::MessageFinal(tool_message))
211                        .await?;
212                }
213
214                debug!("Looping back to LLM with tool results.");
215            }
216        }
217    }
218
219    #[instrument(
220        skip(tool_call, approval_callback, execution_callback, event_sender, token),
221        name = "AgentExecutor::handle_single_tool_call"
222    )]
223    async fn handle_single_tool_call<A, AFut, E, EFut>(
224        tool_call: ToolCall,
225        approval_callback: &A,
226        execution_callback: &E,
227        event_sender: &mpsc::Sender<AgentEvent>,
228        token: CancellationToken,
229    ) -> SteerToolResult
230    where
231        A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
232        AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
233        E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
234        EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
235    {
236        let call_id = tool_call.id.clone();
237        let tool_name = tool_call.name.clone();
238
239        // First, check approval
240        let approval_result = tokio::select! {
241            biased;
242            _ = token.cancelled() => {
243                warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool approval");
244                Err(ToolError::Cancelled(tool_name.clone()))
245            }
246            res = approval_callback(tool_call.clone()) => res,
247        };
248
249        match approval_result {
250            Ok(ApprovalDecision::Approved) => {
251                debug!(tool_id=%call_id, tool_name=%tool_name, "Tool approved, executing");
252
253                // Send ExecutingTool event for approved execution
254                if let Err(e) = event_sender
255                    .send(AgentEvent::ExecutingTool {
256                        tool_call_id: call_id.clone(),
257                        name: tool_name.clone(),
258                        parameters: tool_call.parameters.clone(),
259                    })
260                    .await
261                {
262                    warn!(tool_id=%call_id, tool_name=%tool_name, "Failed to send ExecutingTool event: {}", e);
263                }
264
265                // Execute the tool
266                let execution_result = tokio::select! {
267                    biased;
268                    _ = token.cancelled() => {
269                        warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool execution");
270                        Err(ToolError::Cancelled(tool_name.clone()))
271                    }
272                    res = execution_callback(tool_call, token.clone()) => res,
273                };
274
275                match execution_result {
276                    Ok(output) => {
277                        debug!(tool_id=%call_id, tool_name=%tool_name, "Tool executed successfully");
278                        output
279                    }
280                    Err(e) => {
281                        error!(tool_id=%call_id, tool_name=%tool_name, "Tool execution failed: {}", e);
282                        SteerToolResult::Error(e)
283                    }
284                }
285            }
286            Ok(ApprovalDecision::Denied) => {
287                warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval denied");
288                SteerToolResult::Error(ToolError::DeniedByUser(tool_name))
289            }
290            Err(e @ ToolError::Cancelled(_)) => {
291                warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval cancelled: {}", e);
292                SteerToolResult::Error(e)
293            }
294            Err(e) => {
295                error!(tool_id=%call_id, tool_name=%tool_name, "Tool approval failed: {}", e);
296                SteerToolResult::Error(e)
297            }
298        }
299    }
300}