1use std::{
2 path::PathBuf,
3 sync::{Arc, Mutex},
4};
5
6use crate::chat_completion::{ChatMessage, ToolCall};
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use thiserror::Error;
11
/// Something that can execute a [`Command`] asynchronously.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Executes `cmd`, yielding its [`CommandOutput`] on success or a [`CommandError`] on
    /// failure.
    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
}
17
18#[async_trait]
19impl<T: ToolExecutor> ToolExecutor for &T {
20 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
21 (*self).exec_cmd(cmd).await
22 }
23}
24
25#[async_trait]
26impl ToolExecutor for Arc<dyn ToolExecutor> {
27 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
28 (**self).exec_cmd(cmd).await
29 }
30}
31
32#[async_trait]
33impl ToolExecutor for Box<dyn ToolExecutor> {
34 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
35 (**self).exec_cmd(cmd).await
36 }
37}
38
39#[async_trait]
40impl ToolExecutor for &dyn ToolExecutor {
41 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
42 (**self).exec_cmd(cmd).await
43 }
44}
45
/// Errors returned when executing a [`Command`].
#[derive(Debug, Error)]
pub enum CommandError {
    /// Wraps an arbitrary `anyhow::Error`; any `anyhow::Error` converts into this variant
    /// through the generated `From` impl.
    ///
    /// The message uses the alternate (`{:#}`) rendering of the wrapped error.
    /// NOTE(review): presumably used when the executor itself fails rather than the command —
    /// confirm against the executors.
    #[error("executor error: {0:#}")]
    ExecutorError(#[from] anyhow::Error),

    /// Carries the [`CommandOutput`] of a failed command, which is included in the message.
    ///
    /// NOTE(review): the name suggests the command ran and exited with a non-zero status —
    /// confirm against the executors.
    #[error("command failed with NonZeroExit: {0}")]
    NonZeroExit(CommandOutput),
}
56
57impl From<std::io::Error> for CommandError {
58 fn from(err: std::io::Error) -> Self {
59 CommandError::NonZeroExit(err.to_string().into())
60 }
61}
62
/// A request handed to a tool executor.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a wildcard arm when
/// matching on it.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum Command {
    /// Carries a single string; presumably a shell command line (interpretation is up to the
    /// executor).
    Shell(String),
    /// Carries the path of a file; presumably the file to read.
    ReadFile(PathBuf),
    /// Carries a path and a string; presumably the target file and the content to write.
    WriteFile(PathBuf, String),
}

impl Command {
    /// Builds a [`Command::Shell`] from anything convertible into a `String`.
    pub fn shell<S>(cmd: S) -> Self
    where
        S: Into<String>,
    {
        let cmd: String = cmd.into();
        Self::Shell(cmd)
    }

    /// Builds a [`Command::ReadFile`] from anything convertible into a `PathBuf`.
    pub fn read_file<P>(path: P) -> Self
    where
        P: Into<PathBuf>,
    {
        let path: PathBuf = path.into();
        Self::ReadFile(path)
    }

    /// Builds a [`Command::WriteFile`] from a path-like value and string-like content.
    pub fn write_file<P, S>(path: P, content: S) -> Self
    where
        P: Into<PathBuf>,
        S: Into<String>,
    {
        let path: PathBuf = path.into();
        let content: String = content.into();
        Self::WriteFile(path, content)
    }
}
91
/// Textual result produced by executing a [`Command`].
#[derive(Debug, Clone)]
pub struct CommandOutput {
    /// The output text.
    pub output: String,
}

impl CommandOutput {
    /// Creates an output that holds no text.
    pub fn empty() -> Self {
        Self::new(String::new())
    }

    /// Creates an output from anything convertible into a `String`.
    pub fn new(output: impl Into<String>) -> Self {
        let output = output.into();
        Self { output }
    }

    /// Returns `true` when the output text has zero length.
    pub fn is_empty(&self) -> bool {
        self.output.as_str().is_empty()
    }
}

/// Displays exactly the output text, honouring the formatter's width and alignment flags.
impl std::fmt::Display for CommandOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.output, f)
    }
}

/// Any string-like value converts directly into a [`CommandOutput`].
impl<T: Into<String>> From<T> for CommandOutput {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}

/// Borrows the output text as a string slice.
impl AsRef<str> for CommandOutput {
    fn as_ref(&self) -> &str {
        self.output.as_str()
    }
}
136
/// Feedback given on a tool call: either approval or refusal, each with an optional JSON
/// payload.
///
/// NOTE(review): presumably supplied by a human or external reviewer — confirm against
/// callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolFeedback {
    /// The tool call was approved; `payload` optionally carries extra JSON data.
    Approved { payload: Option<serde_json::Value> },
    /// The tool call was refused; `payload` optionally carries extra JSON data.
    Refused { payload: Option<serde_json::Value> },
}
143
144impl ToolFeedback {
145 pub fn approved() -> Self {
146 ToolFeedback::Approved { payload: None }
147 }
148
149 pub fn refused() -> Self {
150 ToolFeedback::Refused { payload: None }
151 }
152
153 #[must_use]
154 pub fn with_payload(self, payload: serde_json::Value) -> Self {
155 match self {
156 ToolFeedback::Approved { .. } => ToolFeedback::Approved {
157 payload: Some(payload),
158 },
159 ToolFeedback::Refused { .. } => ToolFeedback::Refused {
160 payload: Some(payload),
161 },
162 }
163 }
164}
165
/// Conversation state and tool-execution access used by an agent.
///
/// NOTE(review): the precise semantics of each method are defined by the implementors; the
/// descriptions below follow from the signatures alone and should be confirmed.
#[async_trait]
pub trait AgentContext: Send + Sync {
    /// Yields the messages to use for the next completion, or `None` when there are none
    /// (presumably meaning nothing new needs completing — TODO confirm).
    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>>;

    /// Returns the messages currently considered new (presumably those added since the last
    /// completion — TODO confirm).
    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>>;

    /// Adds several messages to the context.
    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()>;

    /// Adds a single message to the context.
    async fn add_message(&self, item: ChatMessage) -> Result<()>;

    /// Executes `cmd` through the context.
    ///
    /// Deprecated: obtain the executor via [`AgentContext::executor`] and call it directly.
    #[deprecated(note = "use executor instead")]
    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;

    /// Returns the shared [`ToolExecutor`] of this context.
    fn executor(&self) -> &Arc<dyn ToolExecutor>;

    /// Returns the message history of this context.
    async fn history(&self) -> Result<Vec<ChatMessage>>;

    /// NOTE(review): the name suggests this rewinds state so the last completion can be
    /// retried — confirm against implementors.
    async fn redrive(&self) -> Result<()>;

    /// Looks up the feedback recorded for `tool_call`; `None` when there is none.
    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback>;

    /// Records `feedback` for `tool_call`.
    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()>;
}
211
212#[async_trait]
213impl AgentContext for Box<dyn AgentContext> {
214 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
215 (**self).next_completion().await
216 }
217
218 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
219 (**self).current_new_messages().await
220 }
221
222 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
223 (**self).add_messages(item).await
224 }
225
226 async fn add_message(&self, item: ChatMessage) -> Result<()> {
227 (**self).add_message(item).await
228 }
229
230 #[allow(deprecated)]
231 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
232 (**self).exec_cmd(cmd).await
233 }
234
235 fn executor(&self) -> &Arc<dyn ToolExecutor> {
236 (**self).executor()
237 }
238
239 async fn history(&self) -> Result<Vec<ChatMessage>> {
240 (**self).history().await
241 }
242
243 async fn redrive(&self) -> Result<()> {
244 (**self).redrive().await
245 }
246
247 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
248 (**self).has_received_feedback(tool_call).await
249 }
250
251 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
252 (**self).feedback_received(tool_call, feedback).await
253 }
254}
255
256#[async_trait]
257impl AgentContext for Arc<dyn AgentContext> {
258 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
259 (**self).next_completion().await
260 }
261
262 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
263 (**self).current_new_messages().await
264 }
265
266 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
267 (**self).add_messages(item).await
268 }
269
270 async fn add_message(&self, item: ChatMessage) -> Result<()> {
271 (**self).add_message(item).await
272 }
273
274 #[allow(deprecated)]
275 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
276 (**self).exec_cmd(cmd).await
277 }
278
279 fn executor(&self) -> &Arc<dyn ToolExecutor> {
280 (**self).executor()
281 }
282
283 async fn history(&self) -> Result<Vec<ChatMessage>> {
284 (**self).history().await
285 }
286
287 async fn redrive(&self) -> Result<()> {
288 (**self).redrive().await
289 }
290
291 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
292 (**self).has_received_feedback(tool_call).await
293 }
294
295 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
296 (**self).feedback_received(tool_call, feedback).await
297 }
298}
299
300#[async_trait]
301impl AgentContext for &dyn AgentContext {
302 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
303 (**self).next_completion().await
304 }
305
306 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
307 (**self).current_new_messages().await
308 }
309
310 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
311 (**self).add_messages(item).await
312 }
313
314 async fn add_message(&self, item: ChatMessage) -> Result<()> {
315 (**self).add_message(item).await
316 }
317
318 #[allow(deprecated)]
319 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
320 (**self).exec_cmd(cmd).await
321 }
322
323 fn executor(&self) -> &Arc<dyn ToolExecutor> {
324 (**self).executor()
325 }
326
327 async fn history(&self) -> Result<Vec<ChatMessage>> {
328 (**self).history().await
329 }
330
331 async fn redrive(&self) -> Result<()> {
332 (**self).redrive().await
333 }
334
335 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
336 (**self).has_received_feedback(tool_call).await
337 }
338
339 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
340 (**self).feedback_received(tool_call, feedback).await
341 }
342}
343
344#[async_trait]
348impl AgentContext for () {
349 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
350 Ok(None)
351 }
352
353 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
354 Ok(Vec::new())
355 }
356
357 async fn add_messages(&self, _item: Vec<ChatMessage>) -> Result<()> {
358 Ok(())
359 }
360
361 async fn add_message(&self, _item: ChatMessage) -> Result<()> {
362 Ok(())
363 }
364
365 async fn exec_cmd(&self, _cmd: &Command) -> Result<CommandOutput, CommandError> {
366 Err(CommandError::ExecutorError(anyhow::anyhow!(
367 "Empty agent context does not have a tool executor"
368 )))
369 }
370
371 fn executor(&self) -> &Arc<dyn ToolExecutor> {
372 unimplemented!("Empty agent context does not have a tool executor")
373 }
374
375 async fn history(&self) -> Result<Vec<ChatMessage>> {
376 Ok(Vec::new())
377 }
378
379 async fn redrive(&self) -> Result<()> {
380 Ok(())
381 }
382
383 async fn has_received_feedback(&self, _tool_call: &ToolCall) -> Option<ToolFeedback> {
384 Some(ToolFeedback::Approved { payload: None })
385 }
386
387 async fn feedback_received(
388 &self,
389 _tool_call: &ToolCall,
390 _feedback: &ToolFeedback,
391 ) -> Result<()> {
392 Ok(())
393 }
394}
395
396#[async_trait]
401pub trait MessageHistory: Send + Sync + std::fmt::Debug {
402 async fn history(&self) -> Result<Vec<ChatMessage>>;
404
405 async fn push_owned(&self, item: ChatMessage) -> Result<()>;
407
408 async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()>;
410
411 async fn push(&self, item: &ChatMessage) -> Result<()> {
413 self.push_owned(item.clone()).await
414 }
415
416 async fn extend(&self, items: &[ChatMessage]) -> Result<()> {
418 self.extend_owned(items.to_vec()).await
419 }
420
421 async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
423 for item in items {
424 self.push_owned(item).await?;
425 }
426
427 Ok(())
428 }
429}
430
431#[async_trait]
432impl MessageHistory for Mutex<Vec<ChatMessage>> {
433 async fn history(&self) -> Result<Vec<ChatMessage>> {
434 Ok(self.lock().unwrap().clone())
435 }
436
437 async fn push_owned(&self, item: ChatMessage) -> Result<()> {
438 self.lock().unwrap().push(item);
439
440 Ok(())
441 }
442
443 async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
444 let mut lock = self.lock().unwrap();
445 *lock = items;
446
447 Ok(())
448 }
449}