1use std::{
2 path::{Path, PathBuf},
3 sync::{Arc, Mutex},
4};
5
6use crate::{
7 chat_completion::{ChatMessage, ToolCall},
8 indexing::IndexingStream,
9};
10use anyhow::Result;
11use async_trait::async_trait;
12use dyn_clone::DynClone;
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15
16#[async_trait]
28pub trait ToolExecutor: Send + Sync + DynClone {
29 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
31
32 async fn stream_files(
34 &self,
35 path: &Path,
36 extensions: Option<Vec<String>>,
37 ) -> Result<IndexingStream>;
38}
39
40dyn_clone::clone_trait_object!(ToolExecutor);
41
42#[async_trait]
43impl<T: ToolExecutor> ToolExecutor for &T {
44 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
45 (*self).exec_cmd(cmd).await
46 }
47
48 async fn stream_files(
49 &self,
50 path: &Path,
51 extensions: Option<Vec<String>>,
52 ) -> Result<IndexingStream> {
53 (*self).stream_files(path, extensions).await
54 }
55}
56
57#[async_trait]
58impl ToolExecutor for Arc<dyn ToolExecutor> {
59 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
60 self.as_ref().exec_cmd(cmd).await
61 }
62 async fn stream_files(
63 &self,
64 path: &Path,
65 extensions: Option<Vec<String>>,
66 ) -> Result<IndexingStream> {
67 self.as_ref().stream_files(path, extensions).await
68 }
69}
70
71#[async_trait]
72impl ToolExecutor for Box<dyn ToolExecutor> {
73 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
74 self.as_ref().exec_cmd(cmd).await
75 }
76 async fn stream_files(
77 &self,
78 path: &Path,
79 extensions: Option<Vec<String>>,
80 ) -> Result<IndexingStream> {
81 self.as_ref().stream_files(path, extensions).await
82 }
83}
84
85#[async_trait]
86impl ToolExecutor for &dyn ToolExecutor {
87 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
88 (*self).exec_cmd(cmd).await
89 }
90 async fn stream_files(
91 &self,
92 path: &Path,
93 extensions: Option<Vec<String>>,
94 ) -> Result<IndexingStream> {
95 (*self).stream_files(path, extensions).await
96 }
97}
98
99#[derive(Debug, Error)]
100pub enum CommandError {
101 #[error("executor error: {0:#}")]
103 ExecutorError(#[from] anyhow::Error),
104
105 #[error("command failed with NonZeroExit: {0}")]
107 NonZeroExit(CommandOutput),
108}
109
110impl From<std::io::Error> for CommandError {
111 fn from(err: std::io::Error) -> Self {
112 CommandError::NonZeroExit(err.to_string().into())
113 }
114}
115
/// A command that a [`ToolExecutor`] can run.
///
/// The enum is `#[non_exhaustive]`, so matches on it outside this crate need a
/// wildcard arm. `PartialEq`/`Eq` are derived so commands can be compared, e.g. in
/// tests or when deduplicating.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Command {
    /// Run the given string as a shell command.
    Shell(String),
    /// Read the file at the given path.
    ReadFile(PathBuf),
    /// Write the given content to the file at the given path.
    WriteFile(PathBuf, String),
}

impl Command {
    /// Builds a [`Command::Shell`] from anything convertible into a `String`.
    pub fn shell<S: Into<String>>(cmd: S) -> Self {
        Command::Shell(cmd.into())
    }

    /// Builds a [`Command::ReadFile`] from anything convertible into a `PathBuf`.
    pub fn read_file<P: Into<PathBuf>>(path: P) -> Self {
        Command::ReadFile(path.into())
    }

    /// Builds a [`Command::WriteFile`] that writes `content` to `path`.
    pub fn write_file<P: Into<PathBuf>, S: Into<String>>(path: P, content: S) -> Self {
        Command::WriteFile(path.into(), content.into())
    }
}
144
/// Output produced by running a [`Command`].
///
/// Wraps the captured text as a single `String`. It displays as that text, can be
/// borrowed as `&str`, and can be built from anything that is `Into<String>`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CommandOutput {
    /// The captured output text.
    pub output: String,
}

impl CommandOutput {
    /// Creates an output with no text; equal to `CommandOutput::default()`.
    pub fn empty() -> Self {
        CommandOutput {
            output: String::new(),
        }
    }

    /// Creates an output holding `output`.
    pub fn new(output: impl Into<String>) -> Self {
        CommandOutput {
            output: output.into(),
        }
    }

    /// Returns `true` when the output text is empty.
    pub fn is_empty(&self) -> bool {
        self.output.is_empty()
    }
}

impl std::fmt::Display for CommandOutput {
    /// Formats the output text exactly like a plain string, including width, fill and
    /// precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Call `Display` explicitly. A method-call `self.output.fmt(f)` could resolve to
        // `Debug::fmt` (adding quotes and escapes) or become ambiguous if formatting
        // traits are imported into this module.
        std::fmt::Display::fmt(&self.output, f)
    }
}

/// Builds an output from any string-like value.
impl<T: Into<String>> From<T> for CommandOutput {
    fn from(value: T) -> Self {
        CommandOutput {
            output: value.into(),
        }
    }
}

/// Borrows the output text.
impl AsRef<str> for CommandOutput {
    fn as_ref(&self) -> &str {
        &self.output
    }
}
189
190#[derive(Debug, Clone, Serialize, Deserialize, strum_macros::EnumIs)]
192#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
193pub enum ToolFeedback {
194 Approved { payload: Option<serde_json::Value> },
195 Refused { payload: Option<serde_json::Value> },
196}
197
198impl ToolFeedback {
199 pub fn approved() -> Self {
200 ToolFeedback::Approved { payload: None }
201 }
202
203 pub fn refused() -> Self {
204 ToolFeedback::Refused { payload: None }
205 }
206
207 pub fn payload(&self) -> Option<&serde_json::Value> {
208 match self {
209 ToolFeedback::Refused { payload } | ToolFeedback::Approved { payload } => {
210 payload.as_ref()
211 }
212 }
213 }
214
215 #[must_use]
216 pub fn with_payload(self, payload: serde_json::Value) -> Self {
217 match self {
218 ToolFeedback::Approved { .. } => ToolFeedback::Approved {
219 payload: Some(payload),
220 },
221 ToolFeedback::Refused { .. } => ToolFeedback::Refused {
222 payload: Some(payload),
223 },
224 }
225 }
226}
227
228#[async_trait]
230pub trait AgentContext: Send + Sync {
231 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>>;
241
242 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>>;
244
245 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()>;
247
248 async fn add_message(&self, item: ChatMessage) -> Result<()>;
250
251 #[deprecated(note = "use executor instead")]
255 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
256
257 fn executor(&self) -> &Arc<dyn ToolExecutor>;
258
259 async fn history(&self) -> Result<Vec<ChatMessage>>;
260
261 async fn redrive(&self) -> Result<()>;
266
267 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback>;
270
271 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()>;
272}
273
274#[async_trait]
275impl AgentContext for Box<dyn AgentContext> {
276 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
277 (**self).next_completion().await
278 }
279
280 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
281 (**self).current_new_messages().await
282 }
283
284 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
285 (**self).add_messages(item).await
286 }
287
288 async fn add_message(&self, item: ChatMessage) -> Result<()> {
289 (**self).add_message(item).await
290 }
291
292 #[allow(deprecated)]
293 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
294 (**self).exec_cmd(cmd).await
295 }
296
297 fn executor(&self) -> &Arc<dyn ToolExecutor> {
298 (**self).executor()
299 }
300
301 async fn history(&self) -> Result<Vec<ChatMessage>> {
302 (**self).history().await
303 }
304
305 async fn redrive(&self) -> Result<()> {
306 (**self).redrive().await
307 }
308
309 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
310 (**self).has_received_feedback(tool_call).await
311 }
312
313 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
314 (**self).feedback_received(tool_call, feedback).await
315 }
316}
317
318#[async_trait]
319impl AgentContext for Arc<dyn AgentContext> {
320 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
321 (**self).next_completion().await
322 }
323
324 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
325 (**self).current_new_messages().await
326 }
327
328 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
329 (**self).add_messages(item).await
330 }
331
332 async fn add_message(&self, item: ChatMessage) -> Result<()> {
333 (**self).add_message(item).await
334 }
335
336 #[allow(deprecated)]
337 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
338 (**self).exec_cmd(cmd).await
339 }
340
341 fn executor(&self) -> &Arc<dyn ToolExecutor> {
342 (**self).executor()
343 }
344
345 async fn history(&self) -> Result<Vec<ChatMessage>> {
346 (**self).history().await
347 }
348
349 async fn redrive(&self) -> Result<()> {
350 (**self).redrive().await
351 }
352
353 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
354 (**self).has_received_feedback(tool_call).await
355 }
356
357 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
358 (**self).feedback_received(tool_call, feedback).await
359 }
360}
361
362#[async_trait]
363impl AgentContext for &dyn AgentContext {
364 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
365 (**self).next_completion().await
366 }
367
368 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
369 (**self).current_new_messages().await
370 }
371
372 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
373 (**self).add_messages(item).await
374 }
375
376 async fn add_message(&self, item: ChatMessage) -> Result<()> {
377 (**self).add_message(item).await
378 }
379
380 #[allow(deprecated)]
381 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
382 (**self).exec_cmd(cmd).await
383 }
384
385 fn executor(&self) -> &Arc<dyn ToolExecutor> {
386 (**self).executor()
387 }
388
389 async fn history(&self) -> Result<Vec<ChatMessage>> {
390 (**self).history().await
391 }
392
393 async fn redrive(&self) -> Result<()> {
394 (**self).redrive().await
395 }
396
397 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
398 (**self).has_received_feedback(tool_call).await
399 }
400
401 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
402 (**self).feedback_received(tool_call, feedback).await
403 }
404}
405
406#[async_trait]
410impl AgentContext for () {
411 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
412 Ok(None)
413 }
414
415 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
416 Ok(Vec::new())
417 }
418
419 async fn add_messages(&self, _item: Vec<ChatMessage>) -> Result<()> {
420 Ok(())
421 }
422
423 async fn add_message(&self, _item: ChatMessage) -> Result<()> {
424 Ok(())
425 }
426
427 async fn exec_cmd(&self, _cmd: &Command) -> Result<CommandOutput, CommandError> {
428 Err(CommandError::ExecutorError(anyhow::anyhow!(
429 "Empty agent context does not have a tool executor"
430 )))
431 }
432
433 fn executor(&self) -> &Arc<dyn ToolExecutor> {
434 unimplemented!("Empty agent context does not have a tool executor")
435 }
436
437 async fn history(&self) -> Result<Vec<ChatMessage>> {
438 Ok(Vec::new())
439 }
440
441 async fn redrive(&self) -> Result<()> {
442 Ok(())
443 }
444
445 async fn has_received_feedback(&self, _tool_call: &ToolCall) -> Option<ToolFeedback> {
446 Some(ToolFeedback::Approved { payload: None })
447 }
448
449 async fn feedback_received(
450 &self,
451 _tool_call: &ToolCall,
452 _feedback: &ToolFeedback,
453 ) -> Result<()> {
454 Ok(())
455 }
456}
457
458#[async_trait]
463pub trait MessageHistory: Send + Sync + std::fmt::Debug {
464 async fn history(&self) -> Result<Vec<ChatMessage>>;
466
467 async fn push_owned(&self, item: ChatMessage) -> Result<()>;
469
470 async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()>;
472
473 async fn push(&self, item: &ChatMessage) -> Result<()> {
475 self.push_owned(item.clone()).await
476 }
477
478 async fn extend(&self, items: &[ChatMessage]) -> Result<()> {
480 self.extend_owned(items.to_vec()).await
481 }
482
483 async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
485 for item in items {
486 self.push_owned(item).await?;
487 }
488
489 Ok(())
490 }
491}
492
493#[async_trait]
494impl MessageHistory for Mutex<Vec<ChatMessage>> {
495 async fn history(&self) -> Result<Vec<ChatMessage>> {
496 Ok(self.lock().unwrap().clone())
497 }
498
499 async fn push_owned(&self, item: ChatMessage) -> Result<()> {
500 self.lock().unwrap().push(item);
501
502 Ok(())
503 }
504
505 async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
506 let mut lock = self.lock().unwrap();
507 *lock = items;
508
509 Ok(())
510 }
511}