1use crate::api::{ApiError, Client as ApiClient, Model};
2use crate::app::conversation::{Message, MessageData};
3use futures::{StreamExt, stream::FuturesUnordered};
4use std::future::Future;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use steer_tools::{ToolCall, ToolError, ToolSchema, result::ToolResult as SteerToolResult};
8use thiserror::Error;
9use tokio::sync::mpsc;
10use tokio_util::sync::CancellationToken;
11use tracing::{debug, error, info, instrument, warn};
12use uuid::Uuid;
13
/// Outcome of asking whether a single requested tool call may run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApprovalDecision {
    /// The tool call may be executed.
    Approved,
    /// The tool call must not be executed.
    Denied,
}
19
20#[derive(Debug)]
21pub enum AgentEvent {
22 MessageFinal(Message),
23 ExecutingTool {
24 tool_call_id: String,
25 name: String,
26 parameters: serde_json::Value,
27 },
28}
29
30#[derive(Error, Debug)]
31pub enum AgentExecutorError {
32 #[error(transparent)]
33 Api(#[from] ApiError),
34 #[error(transparent)]
35 Tool(#[from] ToolError),
36 #[error("Event channel send error: {0}")]
37 SendError(String),
38 #[error("Operation cancelled")]
39 Cancelled,
40 #[error("Internal error: {0}")]
41 Internal(String),
42 #[error("Unexpected API response structure")]
43 UnexpectedResponse,
44}
45
46impl<T> From<mpsc::error::SendError<T>> for AgentExecutorError {
47 fn from(err: mpsc::error::SendError<T>) -> Self {
48 AgentExecutorError::SendError(err.to_string())
49 }
50}
51
52#[derive(Clone)]
53pub struct AgentExecutor {
54 api_client: Arc<ApiClient>,
55}
56
57pub struct AgentExecutorRunRequest<A, E> {
58 pub model: Model,
59 pub initial_messages: Vec<Message>,
60 pub system_prompt: Option<String>,
61 pub available_tools: Vec<ToolSchema>,
62 pub tool_approval_callback: A,
63 pub tool_execution_callback: E,
64}
65
66impl AgentExecutor {
67 pub fn new(api_client: Arc<ApiClient>) -> Self {
68 Self { api_client }
69 }
70
71 #[instrument(skip_all, name = "AgentExecutor::run")]
72 pub async fn run<A, AFut, E, EFut>(
73 &self,
74 request: AgentExecutorRunRequest<A, E>,
75 event_sender: mpsc::Sender<AgentEvent>,
76 token: CancellationToken,
77 ) -> Result<Message, AgentExecutorError>
78 where
79 A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
80 AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
81 E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
82 EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
83 {
84 let mut messages = request.initial_messages.clone();
85 let tools = if request.available_tools.is_empty() {
86 None
87 } else {
88 Some(request.available_tools)
89 };
90
91 debug!(target: "AgentExecutor::run", "About to start completion loop with model: {:?}", request.model);
92
93 loop {
94 if token.is_cancelled() {
95 info!("Operation cancelled before API call.");
96 return Err(AgentExecutorError::Cancelled);
97 }
98
99 info!(target: "AgentExecutor::run", model = ?request.model, "Calling LLM API");
100 let completion_response = self
101 .api_client
102 .complete_with_retry(
103 request.model,
104 &messages,
105 &request.system_prompt,
106 &tools,
107 token.clone(),
108 3,
109 )
110 .await?;
111 let tool_calls = completion_response.extract_tool_calls();
112
113 let parent_id = if let Some(last_msg) = messages.last() {
115 last_msg.id().to_string()
116 } else {
117 return Err(AgentExecutorError::Internal(
119 "No messages in conversation when adding assistant message".to_string(),
120 ));
121 };
122
123 let full_assistant_message = Message {
124 data: MessageData::Assistant {
125 content: completion_response.content,
126 },
127 timestamp: SystemTime::now()
128 .duration_since(UNIX_EPOCH)
129 .unwrap()
130 .as_secs(),
131 id: Uuid::new_v4().to_string(),
132 parent_message_id: Some(parent_id),
133 };
134
135 messages.push(full_assistant_message.clone());
136
137 if tool_calls.is_empty() {
138 info!("LLM response received, no tool calls requested.");
139 event_sender
140 .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
141 .await?;
142 debug!(target: "AgentExecutor::run_operation", "Operation finished successfully (no tool calls), returning final message.");
143 return Ok(full_assistant_message);
144 } else {
145 info!(count = tool_calls.len(), "LLM requested tool calls.");
146 event_sender
147 .send(AgentEvent::MessageFinal(full_assistant_message.clone()))
148 .await?;
149
150 let mut pending_tools: FuturesUnordered<_> = tool_calls
152 .into_iter()
153 .map(|call| {
154 let event_sender_clone = event_sender.clone();
155 let token_clone = token.clone();
156 let approval_callback = &request.tool_approval_callback;
157 let execution_callback = &request.tool_execution_callback;
158
159 async move {
160 let message_id = uuid::Uuid::new_v4().to_string();
161 let call_id = call.id.clone();
162
163 let result = Self::handle_single_tool_call(
165 call,
166 approval_callback,
167 execution_callback,
168 &event_sender_clone,
169 token_clone,
170 )
171 .await;
172
173 (call_id, message_id, result)
174 }
175 })
176 .collect();
177
178 while let Some((tool_call_id, message_id, result)) = pending_tools.next().await {
180 if token.is_cancelled() {
181 info!("Operation cancelled during tool handling.");
182 return Err(AgentExecutorError::Cancelled);
183 }
184
185 let parent_id = if let Some(last_msg) = messages.last() {
187 last_msg.id().to_string()
188 } else {
189 return Err(AgentExecutorError::Internal(
190 "No messages in conversation when adding tool results".to_string(),
191 ));
192 };
193
194 let tool_message = Message {
196 data: MessageData::Tool {
197 tool_use_id: tool_call_id,
198 result,
199 },
200 timestamp: SystemTime::now()
201 .duration_since(UNIX_EPOCH)
202 .unwrap()
203 .as_secs(),
204 id: message_id,
205 parent_message_id: Some(parent_id),
206 };
207
208 messages.push(tool_message.clone());
209 event_sender
210 .send(AgentEvent::MessageFinal(tool_message))
211 .await?;
212 }
213
214 debug!("Looping back to LLM with tool results.");
215 }
216 }
217 }
218
219 #[instrument(
220 skip(tool_call, approval_callback, execution_callback, event_sender, token),
221 name = "AgentExecutor::handle_single_tool_call"
222 )]
223 async fn handle_single_tool_call<A, AFut, E, EFut>(
224 tool_call: ToolCall,
225 approval_callback: &A,
226 execution_callback: &E,
227 event_sender: &mpsc::Sender<AgentEvent>,
228 token: CancellationToken,
229 ) -> SteerToolResult
230 where
231 A: Fn(ToolCall) -> AFut + Send + Sync + 'static,
232 AFut: Future<Output = Result<ApprovalDecision, ToolError>> + Send + 'static,
233 E: Fn(ToolCall, CancellationToken) -> EFut + Send + Sync + 'static,
234 EFut: Future<Output = Result<SteerToolResult, ToolError>> + Send + 'static,
235 {
236 let call_id = tool_call.id.clone();
237 let tool_name = tool_call.name.clone();
238
239 let approval_result = tokio::select! {
241 biased;
242 _ = token.cancelled() => {
243 warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool approval");
244 Err(ToolError::Cancelled(tool_name.clone()))
245 }
246 res = approval_callback(tool_call.clone()) => res,
247 };
248
249 match approval_result {
250 Ok(ApprovalDecision::Approved) => {
251 debug!(tool_id=%call_id, tool_name=%tool_name, "Tool approved, executing");
252
253 if let Err(e) = event_sender
255 .send(AgentEvent::ExecutingTool {
256 tool_call_id: call_id.clone(),
257 name: tool_name.clone(),
258 parameters: tool_call.parameters.clone(),
259 })
260 .await
261 {
262 warn!(tool_id=%call_id, tool_name=%tool_name, "Failed to send ExecutingTool event: {}", e);
263 }
264
265 let execution_result = tokio::select! {
267 biased;
268 _ = token.cancelled() => {
269 warn!(tool_id=%call_id, tool_name=%tool_name, "Cancellation detected during tool execution");
270 Err(ToolError::Cancelled(tool_name.clone()))
271 }
272 res = execution_callback(tool_call, token.clone()) => res,
273 };
274
275 match execution_result {
276 Ok(output) => {
277 debug!(tool_id=%call_id, tool_name=%tool_name, "Tool executed successfully");
278 output
279 }
280 Err(e) => {
281 error!(tool_id=%call_id, tool_name=%tool_name, "Tool execution failed: {}", e);
282 SteerToolResult::Error(e)
283 }
284 }
285 }
286 Ok(ApprovalDecision::Denied) => {
287 warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval denied");
288 SteerToolResult::Error(ToolError::DeniedByUser(tool_name))
289 }
290 Err(e @ ToolError::Cancelled(_)) => {
291 warn!(tool_id=%call_id, tool_name=%tool_name, "Tool approval cancelled: {}", e);
292 SteerToolResult::Error(e)
293 }
294 Err(e) => {
295 error!(tool_id=%call_id, tool_name=%tool_name, "Tool approval failed: {}", e);
296 SteerToolResult::Error(e)
297 }
298 }
299 }
300}