steer_core/app/agent_executor.rs

1use crate::api::{ApiError, Client as ApiClient};
2use crate::app::conversation::{Message, MessageData};
3use crate::config::model::ModelId;
4use futures::{StreamExt, stream::FuturesUnordered};
5use std::future::Future;
6use std::sync::Arc;
7use std::time::{SystemTime, UNIX_EPOCH};
8use steer_tools::{ToolCall, ToolError, ToolSchema, result::ToolResult as SteerToolResult};
9use thiserror::Error;
10use tokio::sync::mpsc;
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, error, info, instrument, warn};
13use uuid::Uuid;
14
/// Verdict returned by the tool-approval callback for a single tool call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApprovalDecision {
    /// The tool call may be executed.
    Approved,
    /// The tool call must not run; the executor records it as denied by the user.
    Denied,
}
20
21#[derive(Debug)]
22pub enum AgentEvent {
23    MessageFinal(Message),
24    ExecutingTool {
25        tool_call_id: String,
26        name: String,
27        parameters: serde_json::Value,
28    },
29}
30
31#[derive(Error, Debug)]
32pub enum AgentExecutorError {
33    #[error(transparent)]
34    Api(#[from] ApiError),
35    #[error(transparent)]
36    Tool(#[from] ToolError),
37    #[error("Event channel send error: {0}")]
38    SendError(String),
39    #[error("Operation cancelled")]
40    Cancelled,
41    #[error("Internal error: {0}")]
42    Internal(String),
43    #[error("Unexpected API response structure")]
44    UnexpectedResponse,
45}
46
47impl<T> From<mpsc::error::SendError<T>> for AgentExecutorError {
48    fn from(err: mpsc::error::SendError<T>) -> Self {
49        AgentExecutorError::SendError(err.to_string())
50    }
51}
52
53#[derive(Clone)]
54pub struct AgentExecutor {
55    api_client: Arc<ApiClient>,
56}
57
58pub struct AgentExecutorRunRequest<A, E> {
59    pub model: ModelId,
60    pub initial_messages: Vec<Message>,
61    pub system_prompt: Option<String>,
62    pub available_tools: Vec<ToolSchema>,
63    pub tool_approval_callback: A,
64    pub tool_execution_callback: E,
65}
66
67impl AgentExecutor {
68    pub fn new(api_client: Arc<ApiClient>) -> Self {
69        Self { api_client }
70    }
71
72    #[instrument(skip_all, name = "AgentExecutor::run")]
73    pub async fn run<A, AFut, E, EFut>(
74        &self,
75        request: AgentExecutorRunRequest<A, E>,
76        event_sender: mpsc::Sender<AgentEvent>,
77        token: CancellationToken,
78    ) -> Result<Message, AgentExecutorError>
79    where
80        A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
81        AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
82        E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
83        EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
84    {
85        let mut messages = request.initial_messages.clone();
86        let tools = if request.available_tools.is_empty() {
87            None
88        } else {
89            Some(request.available_tools)
90        };
91
92        debug!(target: "AgentExecutor::run", "About to start completion loop with model: {:?}", request.model);
93
94        loop {
95            if token.is_cancelled() {
96                info!("Operation cancelled before API call.");
97                return Err(AgentExecutorError::Cancelled);
98            }
99
100            info!(target: "AgentExecutor::run", model = ?&request.model, "Calling LLM API");
101            let completion_response = self
102                .api_client
103                .complete_with_retry(
104                    &request.model,
105                    &messages,
106                    &request.system_prompt,
107                    &tools,
108                    token.clone(),
109                    3,
110                )
111                .await?;
112            let tool_calls = completion_response.extract_tool_calls();
113
114            // Get parent info from the last message
115            let parent_id = if let Some(last_msg) = messages.last() {
116                last_msg.id().to_string()
117            } else {
118                // This shouldn't happen
119                return Err(AgentExecutorError::Internal(
120                    "No messages in conversation when adding assistant message".to_string(),
121                ));
122            };
123
124            let full_assistant_message = Message {
125                data: MessageData::Assistant {
126                    content: completion_response.content,
127                },
128                timestamp: SystemTime::now()
129                    .duration_since(UNIX_EPOCH)
130                    .unwrap()
131                    .as_secs(),
132                id: Uuid::new_v4().to_string(),
133                parent_message_id: Some(parent_id),
134            };
135
136            messages.push(full_assistant_message.clone());
137
138            if tool_calls.is_empty() {
139                info!("LLM response received, no tool calls requested.");
140                event_sender
141                    .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
142                    .await?;
143                debug!(target: "AgentExecutor::run_operation", "Operation finished successfully (no tool calls), returning final message.");
144                return Ok(full_assistant_message);
145            } else {
146                info!(count = tool_calls.len(), "LLM requested tool calls.");
147                event_sender
148                    .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
149                    .await?;
150
151                // Create concurrent futures for every tool call
152                let mut pending_tools: FuturesUnordered<_> = tool_calls
153                    .into_iter()
154                    .map(|call| {
155                        let event_sender_clone = event_sender.clone();
156                        let token_clone = token.clone();
157                        let approval_callback = &request.tool_approval_callback;
158                        let execution_callback = &request.tool_execution_callback;
159
160                        async move {
161                            let message_id = uuid::Uuid::new_v4().to_string();
162                            let call_id = call.id.clone();
163
164                            // Handle single tool call
165                            let result = Self::handle_single_tool_call(
166                                call,
167                                approval_callback,
168                                execution_callback,
169                                &event_sender_clone,
170                                token_clone,
171                            )
172                            .await;
173
174                            (call_id, message_id, result)
175                        }
176                    })
177                    .collect();
178
179                // Pull results as they finish and emit events
180                while let Some((tool_call_id, message_id, result)) = pending_tools.next().await {
181                    if token.is_cancelled() {
182                        info!("Operation cancelled during tool handling.");
183                        return Err(AgentExecutorError::Cancelled);
184                    }
185
186                    // Get parent info from the last message
187                    let parent_id = if let Some(last_msg) = messages.last() {
188                        last_msg.id().to_string()
189                    } else {
190                        return Err(AgentExecutorError::Internal(
191                            "No messages in conversation when adding tool results".to_string(),
192                        ));
193                    };
194
195                    // Add tool result message
196                    let tool_message = Message {
197                        data: MessageData::Tool {
198                            tool_use_id: tool_call_id,
199                            result,
200                        },
201                        timestamp: SystemTime::now()
202                            .duration_since(UNIX_EPOCH)
203                            .unwrap()
204                            .as_secs(),
205                        id: message_id,
206                        parent_message_id: Some(parent_id),
207                    };
208
209                    messages.push(tool_message.clone());
210                    event_sender
211                        .send(AgentEvent::MessageFinal(tool_message))
212                        .await?;
213                }
214
215                debug!("Looping back to LLM with tool results.");
216            }
217        }
218    }
219
220    #[instrument(
221        skip(tool_call, approval_callback, execution_callback, event_sender, token),
222        name = "AgentExecutor::handle_single_tool_call"
223    )]
224    async fn handle_single_tool_call<A, AFut, E, EFut>(
225        tool_call: ToolCall,
226        approval_callback: &A,
227        execution_callback: &E,
228        event_sender: &mpsc::Sender<AgentEvent>,
229        token: CancellationToken,
230    ) -> SteerToolResult
231    where
232        A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
233        AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
234        E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
235        EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
236    {
237        let call_id = tool_call.id.clone();
238        let tool_name = tool_call.name.clone();
239
240        // First, check approval
241        let approval_result = tokio::select! {
242            biased;
243            _ = token.cancelled() => {
244                warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool approval");
245                Err(ToolError::Cancelled(tool_name.clone()))
246            }
247            res = approval_callback(tool_call.clone()) => res,
248        };
249
250        match approval_result {
251            Ok(ApprovalDecision::Approved) => {
252                debug!(tool_id=%call_id, tool_name=%tool_name, "Tool approved, executing");
253
254                // Send ExecutingTool event for approved execution
255                if let Err(e) = event_sender
256                    .send(AgentEvent::ExecutingTool {
257                        tool_call_id: call_id.clone(),
258                        name: tool_name.clone(),
259                        parameters: tool_call.parameters.clone(),
260                    })
261                    .await
262                {
263                    warn!(tool_id=%call_id, tool_name=%tool_name, "Failed to send ExecutingTool event: {}", e);
264                }
265
266                // Execute the tool
267                let execution_result = tokio::select! {
268                    biased;
269                    _ = token.cancelled() => {
270                        warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool execution");
271                        Err(ToolError::Cancelled(tool_name.clone()))
272                    }
273                    res = execution_callback(tool_call, token.clone()) => res,
274                };
275
276                match execution_result {
277                    Ok(output) => {
278                        debug!(tool_id=%call_id, tool_name=%tool_name, "Tool executed successfully");
279                        output
280                    }
281                    Err(e) => {
282                        error!(tool_id=%call_id, tool_name=%tool_name, "Tool execution failed: {}", e);
283                        SteerToolResult::Error(e)
284                    }
285                }
286            }
287            Ok(ApprovalDecision::Denied) => {
288                warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval denied");
289                SteerToolResult::Error(ToolError::DeniedByUser(tool_name))
290            }
291            Err(e @ ToolError::Cancelled(_)) => {
292                warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval cancelled: {}", e);
293                SteerToolResult::Error(e)
294            }
295            Err(e) => {
296                error!(tool_id=%call_id, tool_name=%tool_name, "Tool approval failed: {}", e);
297                SteerToolResult::Error(e)
298            }
299        }
300    }
301}