1use crate::api::{ApiError, Client as ApiClient};
2use crate::app::conversation::{Message, MessageData};
3use crate::config::model::ModelId;
4use futures::{StreamExt, stream::FuturesUnordered};
5use std::future::Future;
6use std::sync::Arc;
7use std::time::{SystemTime, UNIX_EPOCH};
8use steer_tools::{ToolCall, ToolError, ToolSchema, result::ToolResult as SteerToolResult};
9use thiserror::Error;
10use tokio::sync::mpsc;
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, error, info, instrument, warn};
13use uuid::Uuid;
14
/// Outcome of asking whether a requested tool call may run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApprovalDecision {
    /// The tool call may be executed.
    Approved,
    /// The tool call was rejected and will not be executed.
    Denied,
}
20
21#[derive(Debug)]
22pub enum AgentEvent {
23 MessageFinal(Message),
24 ExecutingTool {
25 tool_call_id: String,
26 name: String,
27 parameters: serde_json::Value,
28 },
29}
30
31#[derive(Error, Debug)]
32pub enum AgentExecutorError {
33 #[error(transparent)]
34 Api(#[from] ApiError),
35 #[error(transparent)]
36 Tool(#[from] ToolError),
37 #[error("Event channel send error: {0}")]
38 SendError(String),
39 #[error("Operation cancelled")]
40 Cancelled,
41 #[error("Internal error: {0}")]
42 Internal(String),
43 #[error("Unexpected API response structure")]
44 UnexpectedResponse,
45}
46
47impl<T> From<mpsc::error::SendError<T>> for AgentExecutorError {
48 fn from(err: mpsc::error::SendError<T>) -> Self {
49 AgentExecutorError::SendError(err.to_string())
50 }
51}
52
53#[derive(Clone)]
54pub struct AgentExecutor {
55 api_client: Arc<ApiClient>,
56}
57
58pub struct AgentExecutorRunRequest<A, E> {
59 pub model: ModelId,
60 pub initial_messages: Vec<Message>,
61 pub system_prompt: Option<String>,
62 pub available_tools: Vec<ToolSchema>,
63 pub tool_approval_callback: A,
64 pub tool_execution_callback: E,
65}
66
67impl AgentExecutor {
68 pub fn new(api_client: Arc<ApiClient>) -> Self {
69 Self { api_client }
70 }
71
72 #[instrument(skip_all, name = "AgentExecutor::run")]
73 pub async fn run<A, AFut, E, EFut>(
74 &self,
75 request: AgentExecutorRunRequest<A, E>,
76 event_sender: mpsc::Sender<AgentEvent>,
77 token: CancellationToken,
78 ) -> Result<Message, AgentExecutorError>
79 where
80 A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
81 AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
82 E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
83 EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
84 {
85 let mut messages = request.initial_messages.clone();
86 let tools = if request.available_tools.is_empty() {
87 None
88 } else {
89 Some(request.available_tools)
90 };
91
92 debug!(target: "AgentExecutor::run", "About to start completion loop with model: {:?}", request.model);
93
94 loop {
95 if token.is_cancelled() {
96 info!("Operation cancelled before API call.");
97 return Err(AgentExecutorError::Cancelled);
98 }
99
100 info!(target: "AgentExecutor::run", model = ?&request.model, "Calling LLM API");
101 let completion_response = self
102 .api_client
103 .complete_with_retry(
104 &request.model,
105 &messages,
106 &request.system_prompt,
107 &tools,
108 token.clone(),
109 3,
110 )
111 .await?;
112 let tool_calls = completion_response.extract_tool_calls();
113
114 let parent_id = if let Some(last_msg) = messages.last() {
116 last_msg.id().to_string()
117 } else {
118 return Err(AgentExecutorError::Internal(
120 "No messages in conversation when adding assistant message".to_string(),
121 ));
122 };
123
124 let full_assistant_message = Message {
125 data: MessageData::Assistant {
126 content: completion_response.content,
127 },
128 timestamp: SystemTime::now()
129 .duration_since(UNIX_EPOCH)
130 .unwrap()
131 .as_secs(),
132 id: Uuid::new_v4().to_string(),
133 parent_message_id: Some(parent_id),
134 };
135
136 messages.push(full_assistant_message.clone());
137
138 if tool_calls.is_empty() {
139 info!("LLM response received, no tool calls requested.");
140 event_sender
141 .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
142 .await?;
143 debug!(target: "AgentExecutor::run_operation", "Operation finished successfully (no tool calls), returning final message.");
144 return Ok(full_assistant_message);
145 } else {
146 info!(count = tool_calls.len(), "LLM requested tool calls.");
147 event_sender
148 .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
149 .await?;
150
151 let mut pending_tools: FuturesUnordered<_> = tool_calls
153 .into_iter()
154 .map(|call| {
155 let event_sender_clone = event_sender.clone();
156 let token_clone = token.clone();
157 let approval_callback = &request.tool_approval_callback;
158 let execution_callback = &request.tool_execution_callback;
159
160 async move {
161 let message_id = uuid::Uuid::new_v4().to_string();
162 let call_id = call.id.clone();
163
164 let result = Self::handle_single_tool_call(
166 call,
167 approval_callback,
168 execution_callback,
169 &event_sender_clone,
170 token_clone,
171 )
172 .await;
173
174 (call_id, message_id, result)
175 }
176 })
177 .collect();
178
179 while let Some((tool_call_id, message_id, result)) = pending_tools.next().await {
181 if token.is_cancelled() {
182 info!("Operation cancelled during tool handling.");
183 return Err(AgentExecutorError::Cancelled);
184 }
185
186 let parent_id = if let Some(last_msg) = messages.last() {
188 last_msg.id().to_string()
189 } else {
190 return Err(AgentExecutorError::Internal(
191 "No messages in conversation when adding tool results".to_string(),
192 ));
193 };
194
195 let tool_message = Message {
197 data: MessageData::Tool {
198 tool_use_id: tool_call_id,
199 result,
200 },
201 timestamp: SystemTime::now()
202 .duration_since(UNIX_EPOCH)
203 .unwrap()
204 .as_secs(),
205 id: message_id,
206 parent_message_id: Some(parent_id),
207 };
208
209 messages.push(tool_message.clone());
210 event_sender
211 .send(AgentEvent::MessageFinal(tool_message))
212 .await?;
213 }
214
215 debug!("Looping back to LLM with tool results.");
216 }
217 }
218 }
219
220 #[instrument(
221 skip(tool_call, approval_callback, execution_callback, event_sender, token),
222 name = "AgentExecutor::handle_single_tool_call"
223 )]
224 async fn handle_single_tool_call<A, AFut, E, EFut>(
225 tool_call: ToolCall,
226 approval_callback: &A,
227 execution_callback: &E,
228 event_sender: &mpsc::Sender<AgentEvent>,
229 token: CancellationToken,
230 ) -> SteerToolResult
231 where
232 A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
233 AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
234 E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
235 EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
236 {
237 let call_id = tool_call.id.clone();
238 let tool_name = tool_call.name.clone();
239
240 let approval_result = tokio::select! {
242 biased;
243 _ = token.cancelled() => {
244 warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool approval");
245 Err(ToolError::Cancelled(tool_name.clone()))
246 }
247 res = approval_callback(tool_call.clone()) => res,
248 };
249
250 match approval_result {
251 Ok(ApprovalDecision::Approved) => {
252 debug!(tool_id=%call_id, tool_name=%tool_name, "Tool approved, executing");
253
254 if let Err(e) = event_sender
256 .send(AgentEvent::ExecutingTool {
257 tool_call_id: call_id.clone(),
258 name: tool_name.clone(),
259 parameters: tool_call.parameters.clone(),
260 })
261 .await
262 {
263 warn!(tool_id=%call_id, tool_name=%tool_name, "Failed to send ExecutingTool event: {}", e);
264 }
265
266 let execution_result = tokio::select! {
268 biased;
269 _ = token.cancelled() => {
270 warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool execution");
271 Err(ToolError::Cancelled(tool_name.clone()))
272 }
273 res = execution_callback(tool_call, token.clone()) => res,
274 };
275
276 match execution_result {
277 Ok(output) => {
278 debug!(tool_id=%call_id, tool_name=%tool_name, "Tool executed successfully");
279 output
280 }
281 Err(e) => {
282 error!(tool_id=%call_id, tool_name=%tool_name, "Tool execution failed: {}", e);
283 SteerToolResult::Error(e)
284 }
285 }
286 }
287 Ok(ApprovalDecision::Denied) => {
288 warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval denied");
289 SteerToolResult::Error(ToolError::DeniedByUser(tool_name))
290 }
291 Err(e @ ToolError::Cancelled(_)) => {
292 warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval cancelled: {}", e);
293 SteerToolResult::Error(e)
294 }
295 Err(e) => {
296 error!(tool_id=%call_id, tool_name=%tool_name, "Tool approval failed: {}", e);
297 SteerToolResult::Error(e)
298 }
299 }
300 }
301}