swiftide_core/
agent_traits.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::{Arc, Mutex},
4};
5
6use crate::{
7    chat_completion::{ChatMessage, ToolCall},
8    indexing::IndexingStream,
9};
10use anyhow::Result;
11use async_trait::async_trait;
12use dyn_clone::DynClone;
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15
16/// A `ToolExecutor` provides an interface for agents to interact with a system
17/// in an isolated context.
18///
19/// When starting up an agent, it's context expects an executor. For example,
20/// you might want your coding agent to work with a fresh, isolated set of files,
21/// separated from the rest of the system.
22///
23/// See `swiftide-docker-executor` for an executor that uses Docker. By default
24/// the executor is a local executor.
25///
26/// Additionally, the executor can be used stream files files for indexing.
27#[async_trait]
28pub trait ToolExecutor: Send + Sync + DynClone {
29    /// Execute a command in the executor
30    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
31
32    /// Stream files from the executor
33    async fn stream_files(
34        &self,
35        path: &Path,
36        extensions: Option<Vec<String>>,
37    ) -> Result<IndexingStream>;
38}
39
40dyn_clone::clone_trait_object!(ToolExecutor);
41
42#[async_trait]
43impl<T: ToolExecutor> ToolExecutor for &T {
44    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
45        (*self).exec_cmd(cmd).await
46    }
47
48    async fn stream_files(
49        &self,
50        path: &Path,
51        extensions: Option<Vec<String>>,
52    ) -> Result<IndexingStream> {
53        (*self).stream_files(path, extensions).await
54    }
55}
56
57#[async_trait]
58impl ToolExecutor for Arc<dyn ToolExecutor> {
59    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
60        self.as_ref().exec_cmd(cmd).await
61    }
62    async fn stream_files(
63        &self,
64        path: &Path,
65        extensions: Option<Vec<String>>,
66    ) -> Result<IndexingStream> {
67        self.as_ref().stream_files(path, extensions).await
68    }
69}
70
71#[async_trait]
72impl ToolExecutor for Box<dyn ToolExecutor> {
73    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
74        self.as_ref().exec_cmd(cmd).await
75    }
76    async fn stream_files(
77        &self,
78        path: &Path,
79        extensions: Option<Vec<String>>,
80    ) -> Result<IndexingStream> {
81        self.as_ref().stream_files(path, extensions).await
82    }
83}
84
85#[async_trait]
86impl ToolExecutor for &dyn ToolExecutor {
87    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
88        (*self).exec_cmd(cmd).await
89    }
90    async fn stream_files(
91        &self,
92        path: &Path,
93        extensions: Option<Vec<String>>,
94    ) -> Result<IndexingStream> {
95        (*self).stream_files(path, extensions).await
96    }
97}
98
99#[derive(Debug, Error)]
100pub enum CommandError {
101    /// The executor itself failed
102    #[error("executor error: {0:#}")]
103    ExecutorError(#[from] anyhow::Error),
104
105    /// The command failed, i.e. failing tests with stderr. This error might be handled
106    #[error("command failed with NonZeroExit: {0}")]
107    NonZeroExit(CommandOutput),
108}
109
110impl From<std::io::Error> for CommandError {
111    fn from(err: std::io::Error) -> Self {
112        CommandError::NonZeroExit(err.to_string().into())
113    }
114}
115
/// Commands that can be executed by the executor
///
/// Conceptually, `Shell` allows any kind of input; the other variants let
/// executors provide more optimized implementations.
///
/// There is an ongoing consideration to make this an associated type on the executor.
///
/// TODO: Should be able to borrow everything?
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Command {
    /// Run an arbitrary shell command
    Shell(String),
    /// Read the file at the given path
    ReadFile(PathBuf),
    /// Write the given content to the file at the given path
    WriteFile(PathBuf, String),
}

impl Command {
    /// Builds a [`Command::Shell`] from anything convertible into a `String`
    pub fn shell<S>(cmd: S) -> Self
    where
        S: Into<String>,
    {
        let cmd: String = cmd.into();
        Self::Shell(cmd)
    }

    /// Builds a [`Command::ReadFile`] from anything convertible into a `PathBuf`
    pub fn read_file<P>(path: P) -> Self
    where
        P: Into<PathBuf>,
    {
        let path: PathBuf = path.into();
        Self::ReadFile(path)
    }

    /// Builds a [`Command::WriteFile`] from a path and the content to write
    pub fn write_file<P, S>(path: P, content: S) -> Self
    where
        P: Into<PathBuf>,
        S: Into<String>,
    {
        let path: PathBuf = path.into();
        let content: String = content.into();
        Self::WriteFile(path, content)
    }
}
144
/// Output from a `Command`
///
/// Currently carries only the textual output. It can be displayed, borrowed as
/// a `&str`, and built from anything convertible into a `String`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CommandOutput {
    /// The text produced by the command
    pub output: String,
}

impl CommandOutput {
    /// Creates an output without any text (same as `CommandOutput::default()`)
    pub fn empty() -> Self {
        Self::default()
    }

    /// Creates an output from anything convertible into a `String`
    pub fn new(output: impl Into<String>) -> Self {
        Self {
            output: output.into(),
        }
    }

    /// Returns `true` if the output contains no text
    pub fn is_empty(&self) -> bool {
        self.output.is_empty()
    }
}

impl std::fmt::Display for CommandOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `String`'s `Display` so width/fill flags are honored.
        std::fmt::Display::fmt(&self.output, f)
    }
}

impl<T: Into<String>> From<T> for CommandOutput {
    fn from(value: T) -> Self {
        Self {
            output: value.into(),
        }
    }
}

impl AsRef<str> for CommandOutput {
    fn as_ref(&self) -> &str {
        &self.output
    }
}
189
190/// Feedback that can be given on a tool, i.e. with a human in the loop
191#[derive(Debug, Clone, Serialize, Deserialize, strum_macros::EnumIs)]
192#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
193pub enum ToolFeedback {
194    Approved { payload: Option<serde_json::Value> },
195    Refused { payload: Option<serde_json::Value> },
196}
197
198impl ToolFeedback {
199    pub fn approved() -> Self {
200        ToolFeedback::Approved { payload: None }
201    }
202
203    pub fn refused() -> Self {
204        ToolFeedback::Refused { payload: None }
205    }
206
207    pub fn payload(&self) -> Option<&serde_json::Value> {
208        match self {
209            ToolFeedback::Refused { payload } | ToolFeedback::Approved { payload } => {
210                payload.as_ref()
211            }
212        }
213    }
214
215    #[must_use]
216    pub fn with_payload(self, payload: serde_json::Value) -> Self {
217        match self {
218            ToolFeedback::Approved { .. } => ToolFeedback::Approved {
219                payload: Some(payload),
220            },
221            ToolFeedback::Refused { .. } => ToolFeedback::Refused {
222                payload: Some(payload),
223            },
224        }
225    }
226}
227
228/// Acts as the interface to the external world and manages messages for completion
229#[async_trait]
230pub trait AgentContext: Send + Sync {
231    /// List of all messages for this agent
232    ///
233    /// Used as main source for the next completion and expects all
234    /// messages to be returned if new messages are present.
235    ///
236    /// Once this method has been called, there should not be new messages
237    ///
238    /// TODO: Figure out a nice way to return a reference instead while still supporting i.e.
239    /// mutexes
240    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>>;
241
242    /// Lists only the new messages after calling `new_completion`
243    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>>;
244
245    /// Add messages for the next completion
246    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()>;
247
248    /// Add messages for the next completion
249    async fn add_message(&self, item: ChatMessage) -> Result<()>;
250
251    /// Execute a command if the context supports it
252    ///
253    /// Deprecated: use executor instead to access the executor directly
254    #[deprecated(note = "use executor instead")]
255    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError>;
256
257    fn executor(&self) -> &Arc<dyn ToolExecutor>;
258
259    async fn history(&self) -> Result<Vec<ChatMessage>>;
260
261    /// Pops the last messages up until the last completion
262    ///
263    /// LLMs failing completion for various reasons is unfortunately a common occurrence
264    /// This gives a way to redrive the last completion in a generic way
265    async fn redrive(&self) -> Result<()>;
266
267    /// Tools that require feedback or approval (i.e. from a human) can use this to check if the
268    /// feedback is received
269    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback>;
270
271    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()>;
272}
273
274#[async_trait]
275impl AgentContext for Box<dyn AgentContext> {
276    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
277        (**self).next_completion().await
278    }
279
280    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
281        (**self).current_new_messages().await
282    }
283
284    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
285        (**self).add_messages(item).await
286    }
287
288    async fn add_message(&self, item: ChatMessage) -> Result<()> {
289        (**self).add_message(item).await
290    }
291
292    #[allow(deprecated)]
293    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
294        (**self).exec_cmd(cmd).await
295    }
296
297    fn executor(&self) -> &Arc<dyn ToolExecutor> {
298        (**self).executor()
299    }
300
301    async fn history(&self) -> Result<Vec<ChatMessage>> {
302        (**self).history().await
303    }
304
305    async fn redrive(&self) -> Result<()> {
306        (**self).redrive().await
307    }
308
309    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
310        (**self).has_received_feedback(tool_call).await
311    }
312
313    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
314        (**self).feedback_received(tool_call, feedback).await
315    }
316}
317
318#[async_trait]
319impl AgentContext for Arc<dyn AgentContext> {
320    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
321        (**self).next_completion().await
322    }
323
324    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
325        (**self).current_new_messages().await
326    }
327
328    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
329        (**self).add_messages(item).await
330    }
331
332    async fn add_message(&self, item: ChatMessage) -> Result<()> {
333        (**self).add_message(item).await
334    }
335
336    #[allow(deprecated)]
337    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
338        (**self).exec_cmd(cmd).await
339    }
340
341    fn executor(&self) -> &Arc<dyn ToolExecutor> {
342        (**self).executor()
343    }
344
345    async fn history(&self) -> Result<Vec<ChatMessage>> {
346        (**self).history().await
347    }
348
349    async fn redrive(&self) -> Result<()> {
350        (**self).redrive().await
351    }
352
353    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
354        (**self).has_received_feedback(tool_call).await
355    }
356
357    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
358        (**self).feedback_received(tool_call, feedback).await
359    }
360}
361
362#[async_trait]
363impl AgentContext for &dyn AgentContext {
364    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
365        (**self).next_completion().await
366    }
367
368    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
369        (**self).current_new_messages().await
370    }
371
372    async fn add_messages(&self, item: Vec<ChatMessage>) -> Result<()> {
373        (**self).add_messages(item).await
374    }
375
376    async fn add_message(&self, item: ChatMessage) -> Result<()> {
377        (**self).add_message(item).await
378    }
379
380    #[allow(deprecated)]
381    async fn exec_cmd(&self, cmd: &Command) -> Result<CommandOutput, CommandError> {
382        (**self).exec_cmd(cmd).await
383    }
384
385    fn executor(&self) -> &Arc<dyn ToolExecutor> {
386        (**self).executor()
387    }
388
389    async fn history(&self) -> Result<Vec<ChatMessage>> {
390        (**self).history().await
391    }
392
393    async fn redrive(&self) -> Result<()> {
394        (**self).redrive().await
395    }
396
397    async fn has_received_feedback(&self, tool_call: &ToolCall) -> Option<ToolFeedback> {
398        (**self).has_received_feedback(tool_call).await
399    }
400
401    async fn feedback_received(&self, tool_call: &ToolCall, feedback: &ToolFeedback) -> Result<()> {
402        (**self).feedback_received(tool_call, feedback).await
403    }
404}
405
406/// Convenience implementation for empty agent context
407///
408/// Errors if tools attempt to execute commands
409#[async_trait]
410impl AgentContext for () {
411    async fn next_completion(&self) -> Result<Option<Vec<ChatMessage>>> {
412        Ok(None)
413    }
414
415    async fn current_new_messages(&self) -> Result<Vec<ChatMessage>> {
416        Ok(Vec::new())
417    }
418
419    async fn add_messages(&self, _item: Vec<ChatMessage>) -> Result<()> {
420        Ok(())
421    }
422
423    async fn add_message(&self, _item: ChatMessage) -> Result<()> {
424        Ok(())
425    }
426
427    async fn exec_cmd(&self, _cmd: &Command) -> Result<CommandOutput, CommandError> {
428        Err(CommandError::ExecutorError(anyhow::anyhow!(
429            "Empty agent context does not have a tool executor"
430        )))
431    }
432
433    fn executor(&self) -> &Arc<dyn ToolExecutor> {
434        unimplemented!("Empty agent context does not have a tool executor")
435    }
436
437    async fn history(&self) -> Result<Vec<ChatMessage>> {
438        Ok(Vec::new())
439    }
440
441    async fn redrive(&self) -> Result<()> {
442        Ok(())
443    }
444
445    async fn has_received_feedback(&self, _tool_call: &ToolCall) -> Option<ToolFeedback> {
446        Some(ToolFeedback::Approved { payload: None })
447    }
448
449    async fn feedback_received(
450        &self,
451        _tool_call: &ToolCall,
452        _feedback: &ToolFeedback,
453    ) -> Result<()> {
454        Ok(())
455    }
456}
457
458/// A backend for the agent context. A default is provided for Arc<Mutex<Vec<ChatMessage>>>
459///
460/// If you want to use for instance a database, implement this trait and pass it to the agent
461/// context when creating it.
462#[async_trait]
463pub trait MessageHistory: Send + Sync + std::fmt::Debug {
464    /// Returns the history of messages
465    async fn history(&self) -> Result<Vec<ChatMessage>>;
466
467    /// Add a message to the history
468    async fn push_owned(&self, item: ChatMessage) -> Result<()>;
469
470    /// Overwrite the history with the given items
471    async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()>;
472
473    /// Add a message to the history
474    async fn push(&self, item: &ChatMessage) -> Result<()> {
475        self.push_owned(item.clone()).await
476    }
477
478    /// Extend the history with the given items
479    async fn extend(&self, items: &[ChatMessage]) -> Result<()> {
480        self.extend_owned(items.to_vec()).await
481    }
482
483    /// Extend the history with the given items, taking ownership of them
484    async fn extend_owned(&self, items: Vec<ChatMessage>) -> Result<()> {
485        for item in items {
486            self.push_owned(item).await?;
487        }
488
489        Ok(())
490    }
491}
492
493#[async_trait]
494impl MessageHistory for Mutex<Vec<ChatMessage>> {
495    async fn history(&self) -> Result<Vec<ChatMessage>> {
496        Ok(self.lock().unwrap().clone())
497    }
498
499    async fn push_owned(&self, item: ChatMessage) -> Result<()> {
500        self.lock().unwrap().push(item);
501
502        Ok(())
503    }
504
505    async fn overwrite(&self, items: Vec<ChatMessage>) -> Result<()> {
506        let mut lock = self.lock().unwrap();
507        *lock = items;
508
509        Ok(())
510    }
511}