1use std::{
2 path::{Path, PathBuf},
3 sync::{Arc, Mutex},
4};
5
6use crate::{
7 chat_completion::{ChatMessage, ToolCall},
8 indexing::IndexingStream,
9};
10use anyhow::Result;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15#[async_trait]
27pub trait ToolExecutor: Send + Sync {
28 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
30
31 async fn stream_files(
33 &self,
34 path: &Path,
35 extensions: Option<Vec<String>>,
36 ) -> Result<IndexingStream>;
37}
38
39#[async_trait]
40impl<T: ToolExecutor> ToolExecutor for &T {
41 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
42 (*self).exec_cmd(cmd).await
43 }
44
45 async fn stream_files(
46 &self,
47 path: &Path,
48 extensions: Option<Vec<String>>,
49 ) -> Result<IndexingStream> {
50 (*self).stream_files(path, extensions).await
51 }
52}
53
54#[async_trait]
55impl ToolExecutor for Arc<dyn ToolExecutor> {
56 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
57 (**self).exec_cmd(cmd).await
58 }
59 async fn stream_files(
60 &self,
61 path: &Path,
62 extensions: Option<Vec<String>>,
63 ) -> Result<IndexingStream> {
64 (*self).stream_files(path, extensions).await
65 }
66}
67
68#[async_trait]
69impl ToolExecutor for Box<dyn ToolExecutor> {
70 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
71 (**self).exec_cmd(cmd).await
72 }
73 async fn stream_files(
74 &self,
75 path: &Path,
76 extensions: Option<Vec<String>>,
77 ) -> Result<IndexingStream> {
78 (*self).stream_files(path, extensions).await
79 }
80}
81
82#[async_trait]
83impl ToolExecutor for &dyn ToolExecutor {
84 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
85 (**self).exec_cmd(cmd).await
86 }
87 async fn stream_files(
88 &self,
89 path: &Path,
90 extensions: Option<Vec<String>>,
91 ) -> Result<IndexingStream> {
92 (*self).stream_files(path, extensions).await
93 }
94}
95
96#[derive(Debug, Error)]
97pub enum CommandError {
98 #[error("executor error: {0:#}")]
100 ExecutorError(#[from] anyhow::Error),
101
102 #[error("command failed with NonZeroExit: {0}")]
104 NonZeroExit(CommandOutput),
105}
106
107impl From<std::io::Error> for CommandError {
108 fn from(err: std::io::Error) -> Self {
109 CommandError::NonZeroExit(err.to_string().into())
110 }
111}
112
/// A unit of work a [`ToolExecutor`] can perform.
///
/// Marked `#[non_exhaustive]` so new kinds of command can be added without breaking
/// downstream `match` expressions.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum Command {
    /// A shell command line.
    Shell(String),
    /// Read the file at the given path.
    ReadFile(PathBuf),
    /// Write the given content to the file at the given path.
    WriteFile(PathBuf, String),
}
127
128impl Command {
129 pub fn shell<S: Into<String>>(cmd: S) -> Self {
130 Command::Shell(cmd.into())
131 }
132
133 pub fn read_file<P: Into<PathBuf>>(path: P) -> Self {
134 Command::ReadFile(path.into())
135 }
136
137 pub fn write_file<P: Into<PathBuf>, S: Into<String>>(path: P, content: S) -> Self {
138 Command::WriteFile(path.into(), content.into())
139 }
140}
141
/// Textual result of running a [`Command`].
#[derive(Debug, Clone)]
pub struct CommandOutput {
    /// The output text.
    pub output: String,
}
149
150impl CommandOutput {
151 pub fn empty() -> Self {
152 CommandOutput {
153 output: String::new(),
154 }
155 }
156
157 pub fn new(output: impl Into<String>) -> Self {
158 CommandOutput {
159 output: output.into(),
160 }
161 }
162 pub fn is_empty(&self) -> bool {
163 self.output.is_empty()
164 }
165}
166
167impl std::fmt::Display for CommandOutput {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 self.output.fmt(f)
170 }
171}
172
173impl<T: Into<String>> From<T> for CommandOutput {
174 fn from(value: T) -> Self {
175 CommandOutput {
176 output: value.into(),
177 }
178 }
179}
180
181impl AsRef<str> for CommandOutput {
182 fn as_ref(&self) -> &str {
183 &self.output
184 }
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
189pub enum ToolFeedback {
190 Approved { payload: Option<serde_json::Value> },
191 Refused { payload: Option<serde_json::Value> },
192}
193
194impl ToolFeedback {
195 pub fn approved() -> Self {
196 ToolFeedback::Approved { payload: None }
197 }
198
199 pub fn refused() -> Self {
200 ToolFeedback::Refused { payload: None }
201 }
202
203 #[must_use]
204 pub fn with_payload(self, payload: serde_json::Value) -> Self {
205 match self {
206 ToolFeedback::Approved { .. } => ToolFeedback::Approved {
207 payload: Some(payload),
208 },
209 ToolFeedback::Refused { .. } => ToolFeedback::Refused {
210 payload: Some(payload),
211 },
212 }
213 }
214}
215
216#[async_trait]
218pub trait AgentContext: Send + Sync {
219 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>>;
229
230 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>>;
232
233 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()>;
235
236 async fn add_message(&self, item: ChatMessage) -> Result<()>;
238
239 #[deprecated(note = "use executor instead")]
243 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
244
245 fn executor(&self) -> &Arc<dyn ToolExecutor>;
246
247 async fn history(&self) -> Result<Vec<ChatMessage>>;
248
249 async fn redrive(&self) -> Result<()>;
254
255 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback>;
258
259 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()>;
260}
261
262#[async_trait]
263impl AgentContext for Box<dyn AgentContext> {
264 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
265 (**self).next_completion().await
266 }
267
268 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
269 (**self).current_new_messages().await
270 }
271
272 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
273 (**self).add_messages(item).await
274 }
275
276 async fn add_message(&self, item: ChatMessage) -> Result<()> {
277 (**self).add_message(item).await
278 }
279
280 #[allow(deprecated)]
281 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
282 (**self).exec_cmd(cmd).await
283 }
284
285 fn executor(&self) -> &Arc<dyn ToolExecutor> {
286 (**self).executor()
287 }
288
289 async fn history(&self) -> Result<Vec<ChatMessage>> {
290 (**self).history().await
291 }
292
293 async fn redrive(&self) -> Result<()> {
294 (**self).redrive().await
295 }
296
297 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
298 (**self).has_received_feedback(tool_call).await
299 }
300
301 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
302 (**self).feedback_received(tool_call, feedback).await
303 }
304}
305
306#[async_trait]
307impl AgentContext for Arc<dyn AgentContext> {
308 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
309 (**self).next_completion().await
310 }
311
312 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
313 (**self).current_new_messages().await
314 }
315
316 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
317 (**self).add_messages(item).await
318 }
319
320 async fn add_message(&self, item: ChatMessage) -> Result<()> {
321 (**self).add_message(item).await
322 }
323
324 #[allow(deprecated)]
325 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
326 (**self).exec_cmd(cmd).await
327 }
328
329 fn executor(&self) -> &Arc<dyn ToolExecutor> {
330 (**self).executor()
331 }
332
333 async fn history(&self) -> Result<Vec<ChatMessage>> {
334 (**self).history().await
335 }
336
337 async fn redrive(&self) -> Result<()> {
338 (**self).redrive().await
339 }
340
341 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
342 (**self).has_received_feedback(tool_call).await
343 }
344
345 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
346 (**self).feedback_received(tool_call, feedback).await
347 }
348}
349
350#[async_trait]
351impl AgentContext for &dyn AgentContext {
352 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
353 (**self).next_completion().await
354 }
355
356 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
357 (**self).current_new_messages().await
358 }
359
360 async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
361 (**self).add_messages(item).await
362 }
363
364 async fn add_message(&self, item: ChatMessage) -> Result<()> {
365 (**self).add_message(item).await
366 }
367
368 #[allow(deprecated)]
369 async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
370 (**self).exec_cmd(cmd).await
371 }
372
373 fn executor(&self) -> &Arc<dyn ToolExecutor> {
374 (**self).executor()
375 }
376
377 async fn history(&self) -> Result<Vec<ChatMessage>> {
378 (**self).history().await
379 }
380
381 async fn redrive(&self) -> Result<()> {
382 (**self).redrive().await
383 }
384
385 async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
386 (**self).has_received_feedback(tool_call).await
387 }
388
389 async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
390 (**self).feedback_received(tool_call, feedback).await
391 }
392}
393
394#[async_trait]
398impl AgentContext for () {
399 async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
400 Ok(None)
401 }
402
403 async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
404 Ok(Vec::new())
405 }
406
407 async fn add_messages(&self, _item: Vec<ChatMessage>) -> Result<()> {
408 Ok(())
409 }
410
411 async fn add_message(&self, _item: ChatMessage) -> Result<()> {
412 Ok(())
413 }
414
415 async fn exec_cmd(&self, _cmd: &Command) -> Result<CommandOutput, CommandError> {
416 Err(CommandError::ExecutorError(anyhow::anyhow!(
417 "Empty agent context does not have a tool executor"
418 )))
419 }
420
421 fn executor(&self) -> &Arc<dyn ToolExecutor> {
422 unimplemented!("Empty agent context does not have a tool executor")
423 }
424
425 async fn history(&self) -> Result<Vec<ChatMessage>> {
426 Ok(Vec::new())
427 }
428
429 async fn redrive(&self) -> Result<()> {
430 Ok(())
431 }
432
433 async fn has_received_feedback(&self, _tool_call: &ToolCall) -> Option<ToolFeedback> {
434 Some(ToolFeedback::Approved { payload: None })
435 }
436
437 async fn feedback_received(
438 &self,
439 _tool_call: &ToolCall,
440 _feedback: &ToolFeedback,
441 ) -> Result<()> {
442 Ok(())
443 }
444}
445
446#[async_trait]
451pub trait MessageHistory: Send + Sync + std::fmt::Debug {
452 async fn history(&self) -> Result<Vec<ChatMessage>>;
454
455 async fn push_owned(&self, item: ChatMessage) -> Result<()>;
457
458 async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()>;
460
461 async fn push(&self, item: &ChatMessage) -> Result<()> {
463 self.push_owned(item.clone()).await
464 }
465
466 async fn extend(&self, items: &[ChatMessage]) -> Result<()> {
468 self.extend_owned(items.to_vec()).await
469 }
470
471 async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
473 for item in items {
474 self.push_owned(item).await?;
475 }
476
477 Ok(())
478 }
479}
480
481#[async_trait]
482impl MessageHistory for Mutex<Vec<ChatMessage>> {
483 async fn history(&self) -> Result<Vec<ChatMessage>> {
484 Ok(self.lock().unwrap().clone())
485 }
486
487 async fn push_owned(&self, item: ChatMessage) -> Result<()> {
488 self.lock().unwrap().push(item);
489
490 Ok(())
491 }
492
493 async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
494 let mut lock = self.lock().unwrap();
495 *lock = items;
496
497 Ok(())
498 }
499}